Compare commits

...

24 Commits

Author SHA1 Message Date
JustSong
f8cc63f00b feat: add user info to topup link 2024-04-05 01:23:11 +08:00
JustSong
0a37aa4cbd docs: add API docs 2024-04-05 01:10:30 +08:00
JustSong
054b00b725 docs: add API docs 2024-04-05 00:40:48 +08:00
JustSong
76569bb0b6 chore: disable channel when error message contain credit or balance 2024-04-05 00:31:41 +08:00
JustSong
1994256bac chore: disable channel when error message contain quota 2024-04-05 00:18:26 +08:00
JustSong
1f80b0a39f chore: add omitempty for xunfei functions 2024-04-05 00:13:37 +08:00
manjieqi
f73f2e51df feat: update baidu model name & ratio (#1253)
* 修正百度模型名称

* 更新百度模型名称,并保留旧版兼容以及修正单价

* chore: add more model and adjust order

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-05 00:02:15 +08:00
Yang Fei
6f036bd0c9 feat: add embedding-2 support for zhipu (#1273)
* 增加对智谱embedding-2模型的支持

* fix: fix usage & ratio

---------

Co-authored-by: yangfei <yangfei@xuyao.info>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-04 23:32:59 +08:00
JustSong
fb90747c23 fix: fix /v1/models return null data when no models available 2024-04-04 18:53:42 +08:00
JustSong
ed70881a58 fix: fix token create 2024-04-04 11:18:21 +08:00
JustSong
8b9fa3d6e4 fix: fix GetGroupModels 2024-04-04 02:58:21 +08:00
JustSong
8b9813d63b feat: /v1/models now only return available models 2024-04-04 02:44:59 +08:00
JustSong
dc7aaf2de5 feat: able to set model limitation for token (close #178) 2024-04-04 02:08:18 +08:00
JustSong
065da8ef8c fix: fix ali function call (#1242) 2024-04-04 00:46:30 +08:00
JustSong
e3cfb1fa52 feat: use given usage if available in stream mode 2024-03-31 23:41:52 +08:00
JustSong
f89ae5ad58 feat: initial function call support for xunfei 2024-03-31 23:12:29 +08:00
JustSong
06a3fc5421 chore: update GeneralOpenAIRequest 2024-03-31 22:23:42 +08:00
ManJieqi
a9c464ec5a fix: update model-ratio.go 修正文心计费模型名称
统一文心计费模型名称
2024-03-30 11:06:31 +08:00
JustSong
3f3c13c98c feat: support top_k for claude (close #1239) 2024-03-30 10:47:07 +08:00
JustSong
2ba28c72cb feat: support function call for ali (close #1242) 2024-03-30 10:43:26 +08:00
JustSong
5e81e19bc8 fix: fix SQL channel selection algo (#1197) 2024-03-27 19:09:27 +08:00
JustSong
96d7a99312 fix: fix autofilled models are not correct 2024-03-24 23:12:32 +08:00
JustSong
24be9de098 chore: update copy 2024-03-24 23:01:03 +08:00
JustSong
5b349efff9 chore: fix berry copy 2024-03-24 22:57:24 +08:00
49 changed files with 713 additions and 230 deletions

View File

@@ -87,7 +87,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
5. 支持**多机部署**[详见此处](#多机部署)。
6. 支持**令牌管理**,设置令牌的过期时间和额度。
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
8. 支持**道管理**,批量创建道。
8. 支持**道管理**,批量创建道。
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
10. 支持渠道**设置模型列表**。
11. 支持**查看额度明细**。
@@ -109,6 +109,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。
24. 配合 [Message Pusher](https://github.com/songquanpeng/message-pusher) 可将报警信息推送到多种 App 上。
25. 支持**扩展**,详情请参考此处 [API 文档](./docs/API.md)。
## 部署
### 基于 Docker 进行部署
@@ -421,7 +422,7 @@ https://openai.justsong.cn
+ 检查你的接口地址和 API Key 有没有填对。
+ 检查是否启用了 HTTPS浏览器会拦截 HTTPS 域名下的 HTTP 请求。
6. 报错:`当前分组负载已饱和,请稍后再试`
+ 上游道 429 了。
+ 上游道 429 了。
7. 升级之后我的数据会丢失吗?
+ 如果使用 MySQL不会。
+ 如果使用 SQLite需要按照我所给的部署命令挂载 volume 持久化 one-api.db 数据库文件,否则容器重启后数据会丢失。
@@ -429,8 +430,8 @@ https://openai.justsong.cn
+ 一般情况下不需要,系统将在初始化的时候自动调整。
+ 如果需要的话,我会在更新日志中说明,并给出脚本。
9. 手动修改数据库后报错:`数据库一致性已被破坏,请联系管理员`
+ 这是检测到 ability 表里有些记录的道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的道。
+ 对于每一个道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该道支持该模型。
+ 这是检测到 ability 表里有些记录的道 id 是不存在的,这大概率是因为你删了 channel 表里的记录但是没有同步在 ability 表里清理无效的道。
+ 对于每一个道,其所支持的模型都需要有一个专门的 ability 表的记录,表示该道支持该模型。
## 相关项目
* [FastGPT](https://github.com/labring/FastGPT): 基于 LLM 大语言模型的知识库问答系统

6
common/conv/any.go Normal file
View File

@@ -0,0 +1,6 @@
package conv
func AsString(v any) string {
str, _ := v.(string)
return str
}

View File

@@ -72,14 +72,22 @@ var ModelRatio = map[string]float64{
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
"ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"bge-large-8k": 0.002 * RMB,
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-Bot-8K-0922": 0.024 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Lite-8K": 0.003 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -91,6 +99,7 @@ var ModelRatio = map[string]float64{
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"embedding-2": 0.0005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens

View File

@@ -197,7 +197,7 @@ func testChannels(notify bool, scope string) error {
testAllChannelsRunning = false
testAllChannelsLock.Unlock()
if notify {
err := message.Notify(message.ByAll, "道测试完成", "", "道测试完成,如果没有收到禁用通知,说明所有道都正常")
err := message.Notify(message.ByAll, "道测试完成", "", "道测试完成,如果没有收到禁用通知,说明所有道都正常")
if err != nil {
logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error()))
}

View File

@@ -4,12 +4,14 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
"strings"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -120,9 +122,41 @@ func DashboardListModels(c *gin.Context) {
}
func ListModels(c *gin.Context) {
ctx := c.Request.Context()
var availableModels []string
if c.GetString("available_models") != "" {
availableModels = strings.Split(c.GetString("available_models"), ",")
} else {
userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId)
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
}
modelSet := make(map[string]bool)
for _, availableModel := range availableModels {
modelSet[availableModel] = true
}
availableOpenAIModels := make([]OpenAIModels, 0)
for _, model := range openAIModels {
if _, ok := modelSet[model.Id]; ok {
modelSet[model.Id] = false
availableOpenAIModels = append(availableOpenAIModels, model)
}
}
for modelName, ok := range modelSet {
if ok {
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Root: modelName,
Parent: nil,
})
}
}
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
"data": availableOpenAIModels,
})
}
@@ -142,3 +176,30 @@ func RetrieveModel(c *gin.Context) {
})
}
}
func GetUserAvailableModels(c *gin.Context) {
ctx := c.Request.Context()
id := c.GetInt("id")
userGroup, err := model.CacheGetUserGroup(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
models, err := model.CacheGetGroupModels(ctx, userGroup)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": models,
})
return
}

View File

@@ -130,6 +130,7 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
Models: token.Models,
}
err = cleanToken.Insert()
if err != nil {
@@ -216,6 +217,7 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.Models = token.Models
}
err = cleanToken.Update()
if err != nil {

View File

@@ -180,27 +180,27 @@ func Register(c *gin.Context) {
}
func GetAllUsers(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
order := c.DefaultQuery("order", "")
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
})
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
order := c.DefaultQuery("order", "")
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
})
}
func SearchUsers(c *gin.Context) {
@@ -770,3 +770,38 @@ func TopUp(c *gin.Context) {
})
return
}
type adminTopUpRequest struct {
UserId int `json:"user_id"`
Quota int `json:"quota"`
Remark string `json:"remark"`
}
func AdminTopUp(c *gin.Context) {
req := adminTopUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if req.Remark == "" {
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
}
model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

44
docs/API.md Normal file
View File

@@ -0,0 +1,44 @@
# 使用 API 操控 & 扩展 One API
> 欢迎提交 PR 在此放上你的拓展项目。
例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。
又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。
## 鉴权
One API 支持两种鉴权方式Cookie 和 Token对于 Token参照下图获取
![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c)
之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API
![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8)
## 请求格式与响应格式
One API 使用 JSON 格式进行请求和响应。
对于响应体,一般格式如下:
```json
{
"message": "请求信息",
"success": true,
"data": {}
}
```
## API 列表
> 当前 API 列表不全,请自行通过浏览器抓取前端请求
如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。
### 获取当前登录用户信息
**GET** `/api/user/self`
### 为给定用户充值额度
**POST** `/api/topup`
```json
{
"user_id": 1,
"quota": 100000,
"remark": "充值 100000 额度"
}
```

View File

@@ -8,12 +8,12 @@
"确认删除": "Confirm Delete",
"确认绑定": "Confirm Binding",
"您正在删除自己的帐户,将清空所有数据且不可恢复": "You are deleting your account, all data will be cleared and unrecoverable.",
"\"道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"\"道「%s」#%d已被禁用\"": "\"Channel %s (#%d) has been disabled\"",
"道「%s」#%d已被禁用原因%s": "Channel %s (#%d) has been disabled, reason: %s",
"测试已在运行中": "Test is already running",
"响应时间 %.2fs 超过阈值 %.2fs": "Response time %.2fs exceeds threshold %.2fs",
"道测试完成": "Channel test completed",
"道测试完成,如果没有收到禁用通知,说明所有道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
"道测试完成": "Channel test completed",
"道测试完成,如果没有收到禁用通知,说明所有道都正常": "Channel test completed, if you have not received the disable notification, it means that all channels are normal",
"无法连接至 GitHub 服务器,请稍后重试!": "Unable to connect to GitHub server, please try again later!",
"返回值非法,用户字段为空,请稍后重试!": "The return value is illegal, the user field is empty, please try again later!",
"管理员未开启通过 GitHub 登录以及注册": "The administrator did not turn on login and registration via GitHub",
@@ -119,11 +119,11 @@
" 个月 ": " M ",
" 年 ": " y ",
"未测试": "Not tested",
"道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用道余额!": "The balance of all enabled channels has been updated!",
"道 ${name} 测试成功,耗时 ${time.toFixed(2)} 秒。": "Channel ${name} test succeeded, time consumed ${time.toFixed(2)} s.",
"已成功开始测试所有道,请刷新页面查看结果。": "All channels have been successfully tested, please refresh the page to view the results.",
"已成功开始测试所有已启用道,请刷新页面查看结果。": "All enabled channels have been successfully tested, please refresh the page to view the results.",
"道 ${name} 余额更新成功!": "Channel ${name} balance updated successfully!",
"已更新完毕所有已启用道余额!": "The balance of all enabled channels has been updated!",
"搜索渠道的 ID名称和密钥 ...": "Search for channel ID, name and key ...",
"名称": "Name",
"分组": "Group",
@@ -141,9 +141,9 @@
"启用": "Enable",
"编辑": "Edit",
"添加新的渠道": "Add a new channel",
"测试所有道": "Test all channels",
"测试所有已启用道": "Test all enabled channels",
"更新所有已启用道余额": "Update the balance of all enabled channels",
"测试所有道": "Test all channels",
"测试所有已启用道": "Test all enabled channels",
"更新所有已启用道余额": "Update the balance of all enabled channels",
"刷新": "Refresh",
"处理中...": "Processing...",
"绑定成功!": "Binding succeeded!",
@@ -207,11 +207,11 @@
"监控设置": "Monitoring Settings",
"最长响应时间": "Longest Response Time",
"单位秒": "Unit in seconds",
"当运行道全部测试时": "When all operating channels are tested",
"超过此时间将自动禁用道": "Channels will be automatically disabled if this time is exceeded",
"当运行道全部测试时": "When all operating channels are tested",
"超过此时间将自动禁用道": "Channels will be automatically disabled if this time is exceeded",
"额度提醒阈值": "Quota reminder threshold",
"低于此额度时将发送邮件提醒用户": "Email will be sent to remind users when the quota is below this",
"失败时自动禁用道": "Automatically disable the channel when it fails",
"失败时自动禁用道": "Automatically disable the channel when it fails",
"保存监控设置": "Save Monitoring Settings",
"额度设置": "Quota Settings",
"新用户初始额度": "Initial quota for new users",
@@ -405,7 +405,7 @@
"镜像": "Mirror",
"请输入镜像站地址格式为https://domain.com可不填不填则使用渠道默认值": "Please enter the mirror site address, the format is: https://domain.com, it can be left blank, if left blank, the default value of the channel will be used",
"模型": "Model",
"请选择该道所支持的模型": "Please select the model supported by the channel",
"请选择该道所支持的模型": "Please select the model supported by the channel",
"填入基础模型": "Fill in the basic model",
"填入所有模型": "Fill in all models",
"清除所有模型": "Clear all models",
@@ -515,7 +515,7 @@
"请输入自定义渠道的 Base URL": "Please enter the Base URL of the custom channel",
"Homepage URL 填": "Fill in the Homepage URL",
"Authorization callback URL 填": "Fill in the Authorization callback URL",
"请为道命名": "Please name the channel",
"请为道命名": "Please name the channel",
"此项可选,用于修改请求体中的模型名称,为一个 JSON 字符串,键为请求中模型名称,值为要替换的模型名称,例如:": "This is optional, used to modify the model name in the request body, it's a JSON string, the key is the model name in the request, and the value is the model name to be replaced, for example:",
"模型重定向": "Model redirection",
"请输入渠道对应的鉴权密钥": "Please enter the authentication key corresponding to the channel",

View File

@@ -1,6 +1,7 @@
package middleware
import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
@@ -107,6 +108,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
requestModel, err := getRequestModel(c)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
c.Set("request_model", requestModel)
if token.Models != nil && *token.Models != "" {
c.Set("available_models", *token.Models)
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
return
}
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)

View File

@@ -2,14 +2,12 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type ModelRequest struct {
@@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
requestModel := c.GetString("request_model")
var err error
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
if channel != nil {
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"

View File

@@ -1,9 +1,12 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"strings"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
logger.Error(c.Request.Context(), message)
}
func getRequestModel(c *gin.Context) (string, error) {
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
return modelRequest.Model, nil
}
func isModelInList(modelName string, models string) bool {
modelList := strings.Split(models, ",")
for _, model := range modelList {
if modelName == model {
return true
}
}
return false
}

View File

@@ -1,7 +1,10 @@
package model
import (
"context"
"github.com/songquanpeng/one-api/common"
"gorm.io/gorm"
"sort"
"strings"
)
@@ -13,7 +16,7 @@ type Ability struct {
Priority *int64 `json:"priority" gorm:"bigint;default:0;index"`
}
func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
func GetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
ability := Ability{}
groupCol := "`group`"
trueVal := "1"
@@ -23,8 +26,13 @@ func GetRandomSatisfiedChannel(group string, model string) (*Channel, error) {
}
var err error = nil
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery := DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
var channelQuery *gorm.DB
if ignoreFirstPriority {
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
} else {
maxPrioritySubQuery := DB.Model(&Ability{}).Select("MAX(priority)").Where(groupCol+" = ? and model = ? and enabled = "+trueVal, group, model)
channelQuery = DB.Where(groupCol+" = ? and model = ? and enabled = "+trueVal+" and priority = (?)", group, model, maxPrioritySubQuery)
}
if common.UsingSQLite || common.UsingPostgreSQL {
err = channelQuery.Order("RANDOM()").First(&ability).Error
} else {
@@ -82,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var models []string
err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
if err != nil {
return nil, err
}
sort.Strings(models)
return models, err
}

View File

@@ -21,6 +21,7 @@ var (
UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency
GroupModelsCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
if !common.RedisEnabled {
return GetGroupModels(ctx, group)
}
modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
if err == nil {
return strings.Split(modelsStr, ","), nil
}
models, err := GetGroupModels(ctx, group)
if err != nil {
return nil, err
}
err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set group models error: " + err.Error())
}
return models, nil
}
var group2model2channels map[string]map[string][]*Channel
var channelSyncLock sync.RWMutex
@@ -205,7 +225,7 @@ func SyncChannelCache(frequency int) {
func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPriority bool) (*Channel, error) {
if !config.MemoryCacheEnabled {
return GetRandomSatisfiedChannel(group, model)
return GetRandomSatisfiedChannel(group, model, ignoreFirstPriority)
}
channelSyncLock.RLock()
defer channelSyncLock.RUnlock()

View File

@@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordTopupLog(userId int, content string, quota int) {
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
CreatedAt: helper.GetTimestamp(),
Type: LogTypeTopup,
Content: content,
Quota: quota,
}
err := LOG_DB.Create(log).Error
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {

View File

@@ -12,24 +12,25 @@ import (
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Models *string `json:"models" gorm:"default:''"`
}
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
var tokens []*Token
var err error
query := DB.Where("user_id = ?", userId)
switch order {
case "remain_quota":
query = query.Order("unlimited_quota desc, remain_quota desc")
@@ -38,7 +39,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token
default:
query = query.Order("id desc")
}
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
return tokens, err
}
@@ -121,7 +122,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models").Updates(token).Error
return err
}

View File

@@ -31,17 +31,17 @@ func notifyRootUser(subject string, content string) {
func DisableChannel(channelId int, channelName string, reason string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
subject := fmt.Sprintf("道「%s」#%d已被禁用", channelName, channelId)
content := fmt.Sprintf("道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
notifyRootUser(subject, content)
}
func MetricDisableChannel(channelId int, successRate float64) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
subject := fmt.Sprintf("道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道#%d在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
channelId, config.MetricQueueSize, successRate*100, config.MetricSuccessRateThreshold*100)
notifyRootUser(subject, content)
}
@@ -49,7 +49,7 @@ func MetricDisableChannel(channelId int, successRate float64) {
func EnableChannel(channelId int, channelName string) {
model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("道「%s」#%d已被启用", channelName, channelId)
subject := fmt.Sprintf("道「%s」#%d已被启用", channelName, channelId)
content := fmt.Sprintf("道「%s」#%d已被启用", channelName, channelId)
notifyRootUser(subject, content)
}

View File

@@ -48,6 +48,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
MaxTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
TopK: request.TopK,
ResultFormat: "message",
Tools: request.Tools,
},
}
}
@@ -117,19 +120,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
}
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Choices: response.Output.Choices,
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
@@ -140,10 +135,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
}
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
if len(aliResponse.Output.Choices) == 0 {
return nil
}
aliChoice := aliResponse.Output.Choices[0]
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.Delta = aliChoice.Message
if aliChoice.FinishReason != "null" {
finishReason := aliChoice.FinishReason
choice.FinishReason = &finishReason
}
response := openai.ChatCompletionsStreamResponse{
@@ -204,6 +203,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
if response == nil {
return true
}
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
@@ -226,6 +228,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
ctx := c.Request.Context()
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -235,6 +238,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
logger.Debugf(ctx, "response body: %s\n", responseBody)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -1,5 +1,10 @@
package ali
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
@@ -11,13 +16,15 @@ type Input struct {
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
type ChatRequest struct {
@@ -62,8 +69,9 @@ type Usage struct {
}
type Output struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
//Text string `json:"text"`
//FinishReason string `json:"finish_reason"`
Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {

View File

@@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {

View File

@@ -38,16 +38,26 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
@@ -59,7 +69,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
case "tao-8k":
suffix += "tao_8k"
default:
suffix += meta.ActualModelName
suffix += strings.ToLower(meta.ActualModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
var accessToken string

View File

@@ -1,11 +1,18 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"ERNIE-4.0-8K",
"ERNIE-Bot-8K-0922",
"ERNIE-3.5-8K",
"ERNIE-Lite-8K-0922",
"ERNIE-Speed-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Lite-8K",
"ERNIE-Speed-128K",
"ERNIE-Tiny-8K",
"BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",

View File

@@ -70,8 +70,10 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText, _ = StreamHandler(c, resp, meta.Mode)
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
err, responseText, usage = StreamHandler(c, resp, meta.Mode)
if usage == nil {
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
}
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
continue // just ignore the error
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
responseText += conv.AsString(choice.Delta.Content)
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage

View File

@@ -118,12 +118,9 @@ type ImageResponse struct {
}
type ChatCompletionsStreamResponseChoice struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content"`
Role string `json:"role,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
Index int `json:"index"`
Delta model.Message `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
@@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
jsonResponse, err := json.Marshal(response)
if err != nil {

View File

@@ -26,7 +26,11 @@ import (
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
var lastToolCalls []model.Tool
for _, message := range request.Messages {
if message.ToolCalls != nil {
lastToolCalls = message.ToolCalls
}
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
@@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
if len(lastToolCalls) != 0 {
for _, toolCall := range lastToolCalls {
xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
}
}
return &xunfeiRequest
}
func getToolCalls(response *ChatResponse) []model.Tool {
var toolCalls []model.Tool
if len(response.Payload.Choices.Text) == 0 {
return toolCalls
}
item := response.Payload.Choices.Text[0]
if item.FunctionCall == nil {
return toolCalls
}
toolCall := model.Tool{
Id: fmt.Sprintf("call_%s", helper.GetUUID()),
Type: "function",
Function: *item.FunctionCall,
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []ChatResponseTextItem{
@@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
ToolCalls: getToolCalls(response),
},
FinishReason: constant.StopFinishReason,
}
@@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
choice.Delta.ToolCalls = getToolCalls(xunfeiResponse)
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &constant.StopFinishReason
}

View File

@@ -26,13 +26,18 @@ type ChatRequest struct {
Message struct {
Text []Message `json:"text"`
} `json:"message"`
Functions struct {
Text []model.Function `json:"text,omitempty"`
} `json:"functions,omitempty"`
} `json:"payload"`
}
type ChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
ContentType string `json:"content_type"`
FunctionCall *model.Function `json:"function_call"`
}
type ChatResponse struct {

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
if a.APIVersion == "v4" {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
}
if meta.Mode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
method := "invoke"
if meta.IsStream {
method = "sse-invoke"
@@ -53,18 +57,24 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
// TopP (0.0, 1.0)
request.TopP = math.Min(0.99, request.TopP)
request.TopP = math.Max(0.01, request.TopP)
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
// TopP (0.0, 1.0)
request.TopP = math.Min(0.99, request.TopP)
request.TopP = math.Max(0.01, request.TopP)
// Temperature (0.0, 1.0)
request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = math.Max(0.01, request.Temperature)
a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" {
return request, nil
// Temperature (0.0, 1.0)
request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = math.Max(0.01, request.Temperature)
a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" {
return request, nil
}
return ConvertRequest(*request), nil
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
@@ -84,14 +94,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
if a.APIVersion == "v4" {
return a.DoResponseV4(c, resp, meta)
}
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
if meta.Mode == constant.RelayModeEmbeddings {
err, usage = EmbeddingsHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
}
return
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "embedding-2",
Input: request.Input.(string),
}
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}

View File

@@ -2,5 +2,5 @@ package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
"glm-4", "glm-4v", "glm-3-turbo",
"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
}

View File

@@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var zhipuResponse EmbeddingRespone
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
Model: response.Model,
Usage: model.Usage{
PromptTokens: response.PromptTokens,
CompletionTokens: response.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
for _, item := range response.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}

View File

@@ -44,3 +44,21 @@ type tokenData struct {
Token string
ExpiryTime time.Time
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input string `json:"input"`
}
type EmbeddingRespone struct {
Model string `json:"model"`
Object string `json:"object"`
Embeddings []EmbeddingData `json:"data"`
model.Usage `json:"usage"`
}
type EmbeddingData struct {
Index int `json:"index"`
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
}

View File

@@ -5,25 +5,29 @@ type ResponseFormat struct {
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
Model string `json:"model,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
Functions any `json:"functions,omitempty"`
User string `json:"user,omitempty"`
Prompt any `json:"prompt,omitempty"`
Input any `json:"input,omitempty"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {

View File

@@ -1,9 +1,10 @@
package model
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
Name *string `json:"name,omitempty"`
ToolCalls []Tool `json:"tool_calls,omitempty"`
}
func (m Message) IsStringContent() bool {

14
relay/model/tool.go Normal file
View File

@@ -0,0 +1,14 @@
package model
type Tool struct {
Id string `json:"id,omitempty"`
Type string `json:"type"`
Function Function `json:"function"`
}
type Function struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"` // request
Arguments any `json:"arguments,omitempty"` // response
}

View File

@@ -46,6 +46,15 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
if strings.Contains(err.Message, "quota") {
return true
}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
return true
}
return false
}

View File

@@ -26,6 +26,7 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp)
userRoute := apiRouter.Group("/user")
{
@@ -43,6 +44,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
selfRoute.GET("/available_models", controller.GetUserAvailableModels)
}
adminRoute := userRoute.Group("/")

View File

@@ -437,7 +437,7 @@ const ChannelsTable = () => {
if (success) {
record.response_time = time * 1000;
record.test_time = Date.now() / 1000;
showInfo(`${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
showInfo(`${record.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
}
@@ -447,7 +447,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/test?scope=${scope}`);
const { success, message } = res.data;
if (success) {
showInfo('已成功开始测试道,请刷新页面查看结果。');
showInfo('已成功开始测试道,请刷新页面查看结果。');
} else {
showError(message);
}
@@ -470,7 +470,7 @@ const ChannelsTable = () => {
if (success) {
record.balance = balance;
record.balance_updated_time = Date.now() / 1000;
showInfo(`${record.name} 余额更新成功!`);
showInfo(`${record.name} 余额更新成功!`);
} else {
showError(message);
}
@@ -481,7 +481,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/update_balance`);
const { success, message } = res.data;
if (success) {
showInfo('已更新完毕所有已启用道余额!');
showInfo('已更新完毕所有已启用道余额!');
} else {
showError(message);
}
@@ -490,7 +490,7 @@ const ChannelsTable = () => {
const batchDeleteChannels = async () => {
if (selectedChannels.length === 0) {
showError('请先选择要删除的道!');
showError('请先选择要删除的道!');
return;
}
setLoading(true);
@@ -501,7 +501,7 @@ const ChannelsTable = () => {
const res = await API.post(`/api/channel/batch`, { ids: ids });
const { success, message, data } = res.data;
if (success) {
showSuccess(`已删除 ${data}道!`);
showSuccess(`已删除 ${data}道!`);
await refresh();
} else {
showError(message);
@@ -513,7 +513,7 @@ const ChannelsTable = () => {
const res = await API.post(`/api/channel/fix`);
const { success, message, data } = res.data;
if (success) {
showSuccess(`已修复 ${data}道!`);
showSuccess(`已修复 ${data}道!`);
await refresh();
} else {
showError(message);
@@ -633,7 +633,7 @@ const ChannelsTable = () => {
onConfirm={() => { testChannels("all") }}
position={isMobile() ? 'top' : 'left'}
>
<Button theme="light" type="warning" style={{ marginRight: 8 }}>测试所有</Button>
<Button theme="light" type="warning" style={{ marginRight: 8 }}>测试所有</Button>
</Popconfirm>
<Popconfirm
title="确定?"
@@ -648,16 +648,16 @@ const ChannelsTable = () => {
okType={'secondary'}
onConfirm={updateAllChannelsBalance}
>
<Button theme="light" type="secondary" style={{ marginRight: 8 }}>更新所有已启用道余额</Button>
<Button theme="light" type="secondary" style={{ marginRight: 8 }}>更新所有已启用道余额</Button>
</Popconfirm> */}
<Popconfirm
title="确定是否要删除禁用道?"
title="确定是否要删除禁用道?"
content="此修改将不可逆"
okType={'danger'}
onConfirm={deleteAllDisabledChannels}
position={isMobile() ? 'top' : 'left'}
>
<Button theme="light" type="danger" style={{ marginRight: 8 }}>删除禁用</Button>
<Button theme="light" type="danger" style={{ marginRight: 8 }}>删除禁用</Button>
</Popconfirm>
<Button theme="light" type="primary" style={{ marginRight: 8 }} onClick={refresh}>刷新</Button>
@@ -673,7 +673,7 @@ const ChannelsTable = () => {
setEnableBatchDelete(v);
}}></Switch>
<Popconfirm
title="确定是否要删除所选道?"
title="确定是否要删除所选道?"
content="此修改将不可逆"
okType={'danger'}
onConfirm={batchDeleteChannels}
@@ -681,7 +681,7 @@ const ChannelsTable = () => {
position={'top'}
>
<Button disabled={!enableBatchDelete} theme="light" type="danger"
style={{ marginRight: 8 }}>删除所选道</Button>
style={{ marginRight: 8 }}>删除所选道</Button>
</Popconfirm>
<Popconfirm
title="确定是否要修复数据库一致性?"

View File

@@ -261,7 +261,7 @@ const OperationSetting = () => {
value={inputs.ChannelDisableThreshold}
type='number'
min='0'
placeholder='单位秒,当运行道全部测试时,超过此时间将自动禁用道'
placeholder='单位秒,当运行道全部测试时,超过此时间将自动禁用道'
/>
<Form.Input
label='额度提醒阈值'
@@ -277,13 +277,13 @@ const OperationSetting = () => {
<Form.Group inline>
<Form.Checkbox
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
label='失败时自动禁用道'
label='失败时自动禁用道'
name='AutomaticDisableChannelEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
label='成功时自动启用道'
label='成功时自动启用道'
name='AutomaticEnableChannelEnabled'
onChange={handleInputChange}
/>

View File

@@ -51,7 +51,7 @@ const Register = () => {
<Grid item xs={12}>
<Grid item container direction="column" alignItems="center" xs={12}>
<Typography component={Link} to="/login" variant="subtitle1" sx={{ textDecoration: 'none' }}>
已经有帐号了?点击登录
已经有帐号了点击登录
</Typography>
</Grid>
</Grid>

View File

@@ -296,7 +296,7 @@ const RegisterForm = ({ ...others }) => {
<Box sx={{ mt: 2 }}>
<AnimateButton>
<Button disableElevation disabled={isSubmitting} fullWidth size="large" type="submit" variant="contained" color="primary">
Sign up
注册
</Button>
</AnimateButton>
</Box>

View File

@@ -93,7 +93,7 @@ export default function ChannelTableRow({
test_time: Date.now() / 1000,
response_time: time * 1000,
});
showInfo(`${item.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
showInfo(`${item.name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
}
};
@@ -243,9 +243,9 @@ export default function ChannelTableRow({
</Popover>
<Dialog open={openDelete} onClose={handleDeleteClose}>
<DialogTitle>删除</DialogTitle>
<DialogTitle>删除</DialogTitle>
<DialogContent>
<DialogContentText>是否删除 {item.name}</DialogContentText>
<DialogContentText>是否删除 {item.name}</DialogContentText>
</DialogContent>
<DialogActions>
<Button onClick={handleDeleteClose}>关闭</Button>

View File

@@ -135,7 +135,7 @@ export default function ChannelPage() {
const res = await API.get(`/api/channel/test`);
const { success, message } = res.data;
if (success) {
showInfo('已成功开始测试所有道,请刷新页面查看结果。');
showInfo('已成功开始测试所有道,请刷新页面查看结果。');
} else {
showError(message);
}
@@ -159,7 +159,7 @@ export default function ChannelPage() {
const res = await API.get(`/api/channel/update_balance`);
const { success, message } = res.data;
if (success) {
showInfo('已更新完毕所有已启用道余额!');
showInfo('已更新完毕所有已启用道余额!');
} else {
showError(message);
}

View File

@@ -371,7 +371,7 @@ const OperationSetting = () => {
value={inputs.ChannelDisableThreshold}
onChange={handleInputChange}
label="最长响应时间"
placeholder="单位秒,当运行道全部测试时,超过此时间将自动禁用道"
placeholder="单位秒,当运行道全部测试时,超过此时间将自动禁用道"
disabled={loading}
/>
</FormControl>
@@ -392,7 +392,7 @@ const OperationSetting = () => {
</FormControl>
</Stack>
<FormControlLabel
label="失败时自动禁用道"
label="失败时自动禁用道"
control={
<Checkbox
checked={inputs.AutomaticDisableChannelEnabled === "true"}
@@ -402,7 +402,7 @@ const OperationSetting = () => {
}
/>
<FormControlLabel
label="成功时自动启用道"
label="成功时自动启用道"
control={
<Checkbox
checked={inputs.AutomaticEnableChannelEnabled === "true"}

View File

@@ -234,7 +234,7 @@ const ChannelsTable = () => {
newChannels[realIdx].response_time = time * 1000;
newChannels[realIdx].test_time = Date.now() / 1000;
setChannels(newChannels);
showInfo(`${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
showInfo(`${name} 测试成功,耗时 ${time.toFixed(2)} 秒。`);
} else {
showError(message);
}
@@ -244,7 +244,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/test?scope=${scope}`);
const { success, message } = res.data;
if (success) {
showInfo('已成功开始测试道,请刷新页面查看结果。');
showInfo('已成功开始测试道,请刷新页面查看结果。');
} else {
showError(message);
}
@@ -270,7 +270,7 @@ const ChannelsTable = () => {
newChannels[realIdx].balance = balance;
newChannels[realIdx].balance_updated_time = Date.now() / 1000;
setChannels(newChannels);
showInfo(`${name} 余额更新成功!`);
showInfo(`${name} 余额更新成功!`);
} else {
showError(message);
}
@@ -281,7 +281,7 @@ const ChannelsTable = () => {
const res = await API.get(`/api/channel/update_balance`);
const { success, message } = res.data;
if (success) {
showInfo('已更新完毕所有已启用道余额!');
showInfo('已更新完毕所有已启用道余额!');
} else {
showError(message);
}

View File

@@ -261,7 +261,7 @@ const OperationSetting = () => {
value={inputs.ChannelDisableThreshold}
type='number'
min='0'
placeholder='单位秒,当运行道全部测试时,超过此时间将自动禁用道'
placeholder='单位秒,当运行道全部测试时,超过此时间将自动禁用道'
/>
<Form.Input
label='额度提醒阈值'
@@ -277,13 +277,13 @@ const OperationSetting = () => {
<Form.Group inline>
<Form.Checkbox
checked={inputs.AutomaticDisableChannelEnabled === 'true'}
label='失败时自动禁用道'
label='失败时自动禁用道'
name='AutomaticDisableChannelEnabled'
onChange={handleInputChange}
/>
<Form.Checkbox
checked={inputs.AutomaticEnableChannelEnabled === 'true'}
label='成功时自动启用道'
label='成功时自动启用道'
name='AutomaticEnableChannelEnabled'
onChange={handleInputChange}
/>

View File

@@ -83,6 +83,7 @@ const EditChannel = () => {
data.model_mapping = JSON.stringify(JSON.parse(data.model_mapping), null, 2);
}
setInputs(data);
setBasicModels(getChannelModels(data.type));
} else {
showError(message);
}
@@ -99,9 +100,6 @@ const EditChannel = () => {
}));
setOriginModelOptions(localModelOptions);
setFullModels(res.data.data.map((model) => model.id));
setBasicModels(res.data.data.filter((model) => {
return model.id.startsWith('gpt-3') || model.id.startsWith('text-');
}).map((model) => model.id));
} catch (error) {
showError(error.message);
}
@@ -137,6 +135,9 @@ const EditChannel = () => {
useEffect(() => {
if (isEdit) {
loadChannel().then();
} else {
let localModels = getChannelModels(inputs.type);
setBasicModels(localModels);
}
fetchModels().then();
fetchGroups().then();
@@ -355,7 +356,7 @@ const EditChannel = () => {
<div style={{ lineHeight: '40px', marginBottom: '12px' }}>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: basicModels });
}}>填入基础模型</Button>
}}>填入相关模型</Button>
<Button type={'button'} onClick={() => {
handleInputChange(null, { name: 'models', value: fullModels });
}}>填入所有模型</Button>

View File

@@ -1,19 +1,21 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
import { useParams, useNavigate } from 'react-router-dom';
import { API, showError, showSuccess, timestamp2string } from '../../helpers';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
import { useNavigate, useParams } from 'react-router-dom';
import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers';
import { renderQuotaWithPrompt } from '../../helpers/render';
const EditToken = () => {
const params = useParams();
const tokenId = params.id;
const isEdit = tokenId !== undefined;
const [loading, setLoading] = useState(isEdit);
const [modelOptions, setModelOptions] = useState([]);
const originInputs = {
name: '',
remain_quota: isEdit ? 0 : 500000,
expired_time: -1,
unlimited_quota: false
unlimited_quota: false,
models: []
};
const [inputs, setInputs] = useState(originInputs);
const { name, remain_quota, expired_time, unlimited_quota } = inputs;
@@ -22,8 +24,8 @@ const EditToken = () => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const handleCancel = () => {
navigate("/token");
}
navigate('/token');
};
const setExpiredTime = (month, day, hour, minute) => {
let now = new Date();
let timestamp = now.getTime() / 1000;
@@ -50,6 +52,11 @@ const EditToken = () => {
if (data.expired_time !== -1) {
data.expired_time = timestamp2string(data.expired_time);
}
if (data.models === '') {
data.models = [];
} else {
data.models = data.models.split(',');
}
setInputs(data);
} else {
showError(message);
@@ -60,8 +67,26 @@ const EditToken = () => {
if (isEdit) {
loadToken().then();
}
loadAvailableModels().then();
}, []);
const loadAvailableModels = async () => {
let res = await API.get(`/api/user/available_models`);
const { success, message, data } = res.data;
if (success) {
let options = data.map((model) => {
return {
key: model,
text: model,
value: model
};
});
setModelOptions(options);
} else {
showError(message);
}
};
const submit = async () => {
if (!isEdit && inputs.name === '') return;
let localInputs = inputs;
@@ -74,6 +99,7 @@ const EditToken = () => {
}
localInputs.expired_time = Math.ceil(time / 1000);
}
localInputs.models = localInputs.models.join(',');
let res;
if (isEdit) {
res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) });
@@ -109,6 +135,24 @@ const EditToken = () => {
required={!isEdit}
/>
</Form.Field>
<Form.Field>
<Form.Dropdown
label='模型范围'
placeholder={'请选择允许使用的模型,留空则不进行限制'}
name='models'
fluid
multiple
search
onLabelClick={(e, { value }) => {
copy(value).then();
}}
selection
onChange={handleInputChange}
value={inputs.models}
autoComplete='new-password'
options={modelOptions}
/>
</Form.Field>
<Form.Field>
<Form.Input
label='过期时间'

View File

@@ -8,6 +8,7 @@ const TopUp = () => {
const [topUpLink, setTopUpLink] = useState('');
const [userQuota, setUserQuota] = useState(0);
const [isSubmitting, setIsSubmitting] = useState(false);
const [user, setUser] = useState({});
const topUp = async () => {
if (redemptionCode === '') {
@@ -41,7 +42,14 @@ const TopUp = () => {
showError('超级管理员未设置充值链接!');
return;
}
window.open(topUpLink, '_blank');
let url = new URL(topUpLink);
let username = user.username;
let user_id = user.id;
// add username and user_id to the topup link
url.searchParams.append('username', username);
url.searchParams.append('user_id', user_id);
url.searchParams.append('transaction_id', crypto.randomUUID());
window.open(url.toString(), '_blank');
};
const getUserQuota = async ()=>{
@@ -49,6 +57,7 @@ const TopUp = () => {
const {success, message, data} = res.data;
if (success) {
setUserQuota(data.quota);
setUser(data);
} else {
showError(message);
}
@@ -80,7 +89,7 @@ const TopUp = () => {
}}
/>
<Button color='green' onClick={openTopUpLink}>
获取兑换码
充值
</Button>
<Button color='yellow' onClick={topUp} disabled={isSubmitting}>
{isSubmitting ? '兑换中...' : '兑换'}