Compare commits

...

38 Commits

Author SHA1 Message Date
JustSong
a44fb5d482 fix: fix channel model list is empty 2024-04-05 23:44:57 +08:00
JustSong
eec41849ec chore: fix ali image implementation 2024-04-05 18:25:57 +08:00
Mo
d4347e7a35 feat: support Ali stable-diffusion-xl and wanx-v1 model (#1240)
* Fix ali ConvertRequest function to use baidu keyword

* Support Ali stable-diffusion-xl and wanx-v1 model

* Support Ali stable-diffusion-xl and wanx-v1 model

* Support Ali stable-diffusion-xl and wanx-v1 model

* chore: update ali constants and model ratio

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-04-05 18:09:54 +08:00
manjieqi
b50b43eb65 feat: update baidu model name & ratio (#1277) 2024-04-05 17:30:48 +08:00
JustSong
348adc2b02 feat: able to set multiple subnets 2024-04-05 17:25:28 +08:00
JustSong
dcf24b98dc chore: update berry copy 2024-04-05 14:28:38 +08:00
JustSong
af679e04f4 chore: sort channel type for berry 2024-04-05 14:23:39 +08:00
JustSong
93cbca6a9f chore: update show notice duration 2024-04-05 14:14:21 +08:00
JustSong
840ef80d94 fix: do not try to parse model when requesting /v1/models (close #1272) 2024-04-05 12:50:31 +08:00
JustSong
9a2662af0d feat: show token info when quota is not enough (close #1274) 2024-04-05 12:42:14 +08:00
JustSong
77f9e75654 fix: fix IsValidSubnet 2024-04-05 12:40:03 +08:00
JustSong
5b41f57423 feat: support stepfun's models 2024-04-05 12:32:05 +08:00
JustSong
0bb7db0b44 fix: do not detect quota field in error message (close #1276) 2024-04-05 12:11:50 +08:00
JustSong
4d61b9937b feat: support feishu login now 2024-04-05 12:10:43 +08:00
JustSong
68605800af feat: add subnet validation (#1275) 2024-04-05 10:18:42 +08:00
JustSong
c49778c254 feat: now able to limit ip range for token now (close #1275) 2024-04-05 10:09:16 +08:00
JustSong
f02c7138ea docs: update README 2024-04-05 01:35:14 +08:00
JustSong
ca3228855a docs: update API docs 2024-04-05 01:29:22 +08:00
JustSong
f8cc63f00b feat: add user info to topup link 2024-04-05 01:23:11 +08:00
JustSong
0a37aa4cbd docs: add API docs 2024-04-05 01:10:30 +08:00
JustSong
054b00b725 docs: add API docs 2024-04-05 00:40:48 +08:00
JustSong
76569bb0b6 chore: disable channel when error message contain credit or balance 2024-04-05 00:31:41 +08:00
JustSong
1994256bac chore: disable channel when error message contain quota 2024-04-05 00:18:26 +08:00
JustSong
1f80b0a39f chore: add omitempty for xunfei functions 2024-04-05 00:13:37 +08:00
manjieqi
f73f2e51df feat: update baidu model name & ratio (#1253)
* 修正百度模型名称

* 更新百度模型名称,并保留旧版兼容以及修正单价

* chore: add more model and adjust order

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-05 00:02:15 +08:00
Yang Fei
6f036bd0c9 feat: add embedding-2 support for zhipu (#1273)
* 增加对智谱embedding-2模型的支持

* fix: fix usage & ratio

---------

Co-authored-by: yangfei <yangfei@xuyao.info>
Co-authored-by: JustSong <songquanpeng@foxmail.com>
2024-04-04 23:32:59 +08:00
JustSong
fb90747c23 fix: fix /v1/models return null data when no models available 2024-04-04 18:53:42 +08:00
JustSong
ed70881a58 fix: fix token create 2024-04-04 11:18:21 +08:00
JustSong
8b9fa3d6e4 fix: fix GetGroupModels 2024-04-04 02:58:21 +08:00
JustSong
8b9813d63b feat: /v1/models now only return available models 2024-04-04 02:44:59 +08:00
JustSong
dc7aaf2de5 feat: able to set model limitation for token (close #178) 2024-04-04 02:08:18 +08:00
JustSong
065da8ef8c fix: fix ali function call (#1242) 2024-04-04 00:46:30 +08:00
JustSong
e3cfb1fa52 feat: use given usage if available in stream mode 2024-03-31 23:41:52 +08:00
JustSong
f89ae5ad58 feat: initial function call support for xunfei 2024-03-31 23:12:29 +08:00
JustSong
06a3fc5421 chore: update GeneralOpenAIRequest 2024-03-31 22:23:42 +08:00
ManJieqi
a9c464ec5a fix: update model-ratio.go 修正文心计费模型名称
统一文心计费模型名称
2024-03-30 11:06:31 +08:00
JustSong
3f3c13c98c feat: support top_k for claude (close #1239) 2024-03-30 10:47:07 +08:00
JustSong
2ba28c72cb feat: support function call for ali (close #1242) 2024-03-30 10:43:26 +08:00
80 changed files with 1743 additions and 298 deletions

1
.gitignore vendored
View File

@@ -8,3 +8,4 @@ build
logs
data
/web/node_modules
cmd.md

View File

@@ -81,11 +81,12 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
+ [x] [Groq](https://wow.groq.com/)
+ [x] [Ollama](https://github.com/ollama/ollama)
+ [x] [零一万物](https://platform.lingyiwanwu.com/)
+ [x] [阶跃星辰](https://platform.stepfun.com/)
2. 支持配置镜像以及众多[第三方代理服务](https://iamazing.cn/page/openai-api-third-party-services)。
3. 支持通过**负载均衡**的方式访问多个渠道。
4. 支持 **stream 模式**,可以通过流式传输实现打字机效果。
5. 支持**多机部署**[详见此处](#多机部署)。
6. 支持**令牌管理**,设置令牌的过期时间额度。
6. 支持**令牌管理**,设置令牌的过期时间额度、允许的 IP 范围以及允许的模型访问
7. 支持**兑换码管理**,支持批量生成和导出兑换码,可使用兑换码为账户进行充值。
8. 支持**渠道管理**,批量创建渠道。
9. 支持**用户分组**以及**渠道分组**,支持为不同分组设置不同的倍率。
@@ -101,10 +102,11 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
19. 支持丰富的**自定义**设置,
1. 支持自定义系统名称logo 以及页脚。
2. 支持自定义首页和关于页面,可以选择使用 HTML & Markdown 代码进行自定义,或者使用一个单独的网页通过 iframe 嵌入。
20. 支持通过系统访问令牌访问管理 APIbearer token用以替代 cookie你可以自行抓包来查看 API 的用法)
20. 支持通过系统访问令牌调用管理 API,进而**在无需二开的情况下扩展和自定义** One API 的功能,详情请参考此处 [API 文档](./docs/API.md)。
21. 支持 Cloudflare Turnstile 用户校验。
22. 支持用户管理,支持**多种用户登录注册方式**
+ 邮箱登录注册(支持注册邮箱白名单)以及通过邮箱进行密码重置。
+ 支持使用飞书进行授权登录。
+ [GitHub 开放授权](https://github.com/settings/applications/new)。
+ 微信公众号授权(需要额外部署 [WeChat Server](https://github.com/songquanpeng/wechat-server))。
23. 支持主题切换,设置环境变量 `THEME` 即可,默认为 `default`,欢迎 PR 更多主题,具体参考[此处](./web/README.md)。

View File

@@ -66,6 +66,9 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
var LarkClientId = ""
var LarkClientSecret = ""
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""

View File

@@ -71,6 +71,7 @@ const (
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeStepFun
ChannelTypeDummy
)
@@ -108,6 +109,7 @@ var ChannelBaseURLs = []string{
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
"https://api.stepfun.com", // 32
}
const (

6
common/conv/any.go Normal file
View File

@@ -0,0 +1,6 @@
package conv
func AsString(v any) string {
str, _ := v.(string)
return str
}

View File

@@ -72,14 +72,22 @@ var ModelRatio = map[string]float64{
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
"ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
"ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
"ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"bge-large-8k": 0.002 * RMB,
"ERNIE-4.0-8K": 0.120 * RMB,
"ERNIE-3.5-8K": 0.012 * RMB,
"ERNIE-3.5-8K-0205": 0.024 * RMB,
"ERNIE-3.5-8K-1222": 0.012 * RMB,
"ERNIE-Bot-8K": 0.024 * RMB,
"ERNIE-3.5-4K-0205": 0.012 * RMB,
"ERNIE-Speed-8K": 0.004 * RMB,
"ERNIE-Speed-128K": 0.004 * RMB,
"ERNIE-Lite-8K-0922": 0.008 * RMB,
"ERNIE-Lite-8K-0308": 0.003 * RMB,
"ERNIE-Tiny-8K": 0.001 * RMB,
"BLOOMZ-7B": 0.004 * RMB,
"Embedding-V1": 0.002 * RMB,
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -91,15 +99,20 @@ var ModelRatio = map[string]float64{
"glm-4": 0.1 * RMB,
"glm-4v": 0.1 * RMB,
"glm-3-turbo": 0.005 * RMB,
"embedding-2": 0.0005 * RMB,
"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
"chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
"chatglm_std": 0.3572, // ¥0.005 / 1k tokens
"chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
// https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
"qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
"ali-stable-diffusion-xl": 8,
"ali-stable-diffusion-v1.5": 8,
"wanx-v1": 8,
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens

52
common/network/ip.go Normal file
View File

@@ -0,0 +1,52 @@
package network
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"net"
"strings"
)
func splitSubnets(subnets string) []string {
res := strings.Split(subnets, ",")
for i := 0; i < len(res); i++ {
res[i] = strings.TrimSpace(res[i])
}
return res
}
func isValidSubnet(subnet string) error {
_, _, err := net.ParseCIDR(subnet)
if err != nil {
return fmt.Errorf("failed to parse subnet: %w", err)
}
return nil
}
func isIpInSubnet(ctx context.Context, ip string, subnet string) bool {
_, ipNet, err := net.ParseCIDR(subnet)
if err != nil {
logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
return false
}
return ipNet.Contains(net.ParseIP(ip))
}
func IsValidSubnets(subnets string) error {
for _, subnet := range splitSubnets(subnets) {
if err := isValidSubnet(subnet); err != nil {
return err
}
}
return nil
}
func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
for _, subnet := range splitSubnets(subnets) {
if isIpInSubnet(ctx, ip, subnet) {
return true
}
}
return false
}

19
common/network/ip_test.go Normal file
View File

@@ -0,0 +1,19 @@
package network
import (
"context"
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestIsIpInSubnet(t *testing.T) {
ctx := context.Background()
ip1 := "192.168.0.5"
ip2 := "125.216.250.89"
subnet := "192.168.0.0/24"
Convey("TestIsIpInSubnet", t, func() {
So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
})
}

View File

@@ -1,4 +1,4 @@
package controller
package auth
import (
"bytes"
@@ -11,6 +11,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -159,7 +160,7 @@ func GitHubOAuth(c *gin.Context) {
})
return
}
setupLogin(&user, c)
controller.SetupLogin(&user, c)
}
func GitHubBind(c *gin.Context) {

201
controller/auth/lark.go Normal file
View File

@@ -0,0 +1,201 @@
package auth
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"time"
)
type LarkOAuthResponse struct {
AccessToken string `json:"access_token"`
}
type LarkUser struct {
Name string `json:"name"`
OpenID string `json:"open_id"`
}
func getLarkUserInfoByCode(code string) (*LarkUser, error) {
if code == "" {
return nil, errors.New("无效的参数")
}
values := map[string]string{
"client_id": config.LarkClientId,
"client_secret": config.LarkClientSecret,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
}
jsonData, err := json.Marshal(values)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
client := http.Client{
Timeout: 5 * time.Second,
}
res, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
}
defer res.Body.Close()
var oAuthResponse LarkOAuthResponse
err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
if err != nil {
return nil, err
}
req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
res2, err := client.Do(req)
if err != nil {
logger.SysLog(err.Error())
return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
}
var larkUser LarkUser
err = json.NewDecoder(res2.Body).Decode(&larkUser)
if err != nil {
return nil, err
}
return &larkUser, nil
}
func LarkOAuth(c *gin.Context) {
session := sessions.Default(c)
state := c.Query("state")
if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
c.JSON(http.StatusForbidden, gin.H{
"success": false,
"message": "state is empty or not same",
})
return
}
username := session.Get("username")
if username != nil {
LarkBind(c)
return
}
code := c.Query("code")
larkUser, err := getLarkUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LarkId: larkUser.OpenID,
}
if model.IsLarkIdAlreadyTaken(user.LarkId) {
err := user.FillUserByLarkId()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
if config.RegisterEnabled {
user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
if larkUser.Name != "" {
user.DisplayName = larkUser.Name
} else {
user.DisplayName = "Lark User"
}
user.Role = common.RoleCommonUser
user.Status = common.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
} else {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "管理员关闭了新用户注册",
})
return
}
}
if user.Status != common.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
controller.SetupLogin(&user, c)
}
func LarkBind(c *gin.Context) {
code := c.Query("code")
larkUser, err := getLarkUserInfoByCode(code)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user := model.User{
LarkId: larkUser.OpenID,
}
if model.IsLarkIdAlreadyTaken(user.LarkId) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该飞书账户已被绑定",
})
return
}
session := sessions.Default(c)
id := session.Get("id")
// id := c.GetInt("id") // critical bug!
user.Id = id.(int)
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.LarkId = larkUser.OpenID
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "bind",
})
return
}

View File

@@ -1,4 +1,4 @@
package controller
package auth
import (
"encoding/json"
@@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -109,7 +110,7 @@ func WeChatAuth(c *gin.Context) {
})
return
}
setupLogin(&user, c)
controller.SetupLogin(&user, c)
}
func WeChatBind(c *gin.Context) {

View File

@@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) {
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
"lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,

View File

@@ -4,12 +4,14 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http"
"strings"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -39,8 +41,8 @@ type OpenAIModels struct {
Parent *string `json:"parent"`
}
var openAIModels []OpenAIModels
var openAIModelsMap map[string]OpenAIModels
var models []OpenAIModels
var modelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func init() {
@@ -68,7 +70,7 @@ func init() {
channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{
models = append(models, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
@@ -85,7 +87,7 @@ func init() {
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList {
openAIModels = append(openAIModels, OpenAIModels{
models = append(models, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
@@ -96,9 +98,9 @@ func init() {
})
}
}
openAIModelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels {
openAIModelsMap[model.Id] = model
modelsMap = make(map[string]OpenAIModels)
for _, model := range models {
modelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ {
@@ -119,16 +121,55 @@ func DashboardListModels(c *gin.Context) {
})
}
func ListModels(c *gin.Context) {
func ListAllModels(c *gin.Context) {
c.JSON(200, gin.H{
"object": "list",
"data": openAIModels,
"data": models,
})
}
func ListModels(c *gin.Context) {
ctx := c.Request.Context()
var availableModels []string
if c.GetString("available_models") != "" {
availableModels = strings.Split(c.GetString("available_models"), ",")
} else {
userId := c.GetInt("id")
userGroup, _ := model.CacheGetUserGroup(userId)
availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
}
modelSet := make(map[string]bool)
for _, availableModel := range availableModels {
modelSet[availableModel] = true
}
availableOpenAIModels := make([]OpenAIModels, 0)
for _, model := range models {
if _, ok := modelSet[model.Id]; ok {
modelSet[model.Id] = false
availableOpenAIModels = append(availableOpenAIModels, model)
}
}
for modelName, ok := range modelSet {
if ok {
availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
OwnedBy: "custom",
Root: modelName,
Parent: nil,
})
}
}
c.JSON(200, gin.H{
"object": "list",
"data": availableOpenAIModels,
})
}
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok {
if model, ok := modelsMap[modelId]; ok {
c.JSON(200, model)
} else {
Error := relaymodel.Error{
@@ -142,3 +183,30 @@ func RetrieveModel(c *gin.Context) {
})
}
}
func GetUserAvailableModels(c *gin.Context) {
ctx := c.Request.Context()
id := c.GetInt("id")
userGroup, err := model.CacheGetUserGroup(id)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
models, err := model.CacheGetGroupModels(ctx, userGroup)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": models,
})
return
}

View File

@@ -1,10 +1,12 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -104,6 +106,19 @@ func GetTokenStatus(c *gin.Context) {
})
}
func validateToken(c *gin.Context, token model.Token) error {
if len(token.Name) > 30 {
return fmt.Errorf("令牌名称过长")
}
if token.Subnet != nil && *token.Subnet != "" {
err := network.IsValidSubnets(*token.Subnet)
if err != nil {
return fmt.Errorf("无效的网段:%s", err.Error())
}
}
return nil
}
func AddToken(c *gin.Context) {
token := model.Token{}
err := c.ShouldBindJSON(&token)
@@ -114,13 +129,15 @@ func AddToken(c *gin.Context) {
})
return
}
if len(token.Name) > 30 {
err = validateToken(c, token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
"message": fmt.Sprintf("参数错误:%s", err.Error()),
})
return
}
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
@@ -130,6 +147,8 @@ func AddToken(c *gin.Context) {
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
Models: token.Models,
Subnet: token.Subnet,
}
err = cleanToken.Insert()
if err != nil {
@@ -177,10 +196,11 @@ func UpdateToken(c *gin.Context) {
})
return
}
if len(token.Name) > 30 {
err = validateToken(c, token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌名称过长",
"message": fmt.Sprintf("参数错误:%s", err.Error()),
})
return
}
@@ -216,6 +236,8 @@ func UpdateToken(c *gin.Context) {
cleanToken.ExpiredTime = token.ExpiredTime
cleanToken.RemainQuota = token.RemainQuota
cleanToken.UnlimitedQuota = token.UnlimitedQuota
cleanToken.Models = token.Models
cleanToken.Subnet = token.Subnet
}
err = cleanToken.Update()
if err != nil {

View File

@@ -58,11 +58,11 @@ func Login(c *gin.Context) {
})
return
}
setupLogin(&user, c)
SetupLogin(&user, c)
}
// setup session & cookies and then return user info
func setupLogin(user *model.User, c *gin.Context) {
func SetupLogin(user *model.User, c *gin.Context) {
session := sessions.Default(c)
session.Set("id", user.Id)
session.Set("username", user.Username)
@@ -180,27 +180,27 @@ func Register(c *gin.Context) {
}
func GetAllUsers(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
order := c.DefaultQuery("order", "")
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
})
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
order := c.DefaultQuery("order", "")
users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": users,
})
}
func SearchUsers(c *gin.Context) {
@@ -770,3 +770,38 @@ func TopUp(c *gin.Context) {
})
return
}
type adminTopUpRequest struct {
UserId int `json:"user_id"`
Quota int `json:"quota"`
Remark string `json:"remark"`
}
func AdminTopUp(c *gin.Context) {
req := adminTopUpRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
if req.Remark == "" {
req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
}
model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
})
return
}

53
docs/API.md Normal file
View File

@@ -0,0 +1,53 @@
# 使用 API 操控 & 扩展 One API
> 欢迎提交 PR 在此放上你的拓展项目。
例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。
又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。
## 鉴权
One API 支持两种鉴权方式Cookie 和 Token对于 Token参照下图获取
![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c)
之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API
![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8)
## 请求格式与响应格式
One API 使用 JSON 格式进行请求和响应。
对于响应体,一般格式如下:
```json
{
"message": "请求信息",
"success": true,
"data": {}
}
```
## API 列表
> 当前 API 列表不全,请自行通过浏览器抓取前端请求
如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。
### 获取当前登录用户信息
**GET** `/api/user/self`
### 为给定用户充值额度
**POST** `/api/topup`
```json
{
"user_id": 1,
"quota": 100000,
"remark": "充值 100000 额度"
}
```
## 其他
### 充值链接上的附加参数
One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如:
`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837`
你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。
注意,不是所有主题都支持该功能,欢迎 PR 补齐。

4
go.mod
View File

@@ -15,6 +15,7 @@ require (
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/pkoukk/tiktoken-go v0.1.5
github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.8.3
golang.org/x/crypto v0.17.0
golang.org/x/image v0.14.0
@@ -37,6 +38,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-sql-driver/mysql v1.6.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
@@ -47,6 +49,7 @@ require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/jtolds/gls v4.20.0+incompatible // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
@@ -55,6 +58,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/smarty/assertions v1.15.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect

12
go.sum
View File

@@ -56,11 +56,13 @@ github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keL
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
@@ -85,6 +87,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
@@ -127,6 +131,10 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
@@ -177,8 +185,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=

View File

@@ -1,10 +1,12 @@
package middleware
import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
"net/http"
"strings"
@@ -88,6 +90,7 @@ func RootAuth() func(c *gin.Context) {
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
ctx := c.Request.Context()
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(key, "sk-")
@@ -98,6 +101,12 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
if token.Subnet != nil && *token.Subnet != "" {
if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s当前 ip%s", *token.Subnet, c.ClientIP()))
return
}
}
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
@@ -107,6 +116,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
requestModel, err := getRequestModel(c)
if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") {
abortWithMessage(c, http.StatusBadRequest, err.Error())
return
}
c.Set("request_model", requestModel)
if token.Models != nil && *token.Models != "" {
c.Set("available_models", *token.Models)
if requestModel != "" && !isModelInList(requestModel, *token.Models) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
return
}
}
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)

View File

@@ -2,14 +2,12 @@ package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type ModelRequest struct {
@@ -40,37 +38,11 @@ func Distribute() func(c *gin.Context) {
return
}
} else {
// Select a channel for the user
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
requestModel := c.GetString("request_model")
var err error
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if err != nil {
abortWithMessage(c, http.StatusBadRequest, "无效的请求")
return
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
requestModel = modelRequest.Model
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
if err != nil {
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
if channel != nil {
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"

View File

@@ -1,9 +1,12 @@
package middleware
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"strings"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
logger.Error(c.Request.Context(), message)
}
func getRequestModel(c *gin.Context) (string, error) {
var modelRequest ModelRequest
err := common.UnmarshalBodyReusable(c, &modelRequest)
if err != nil {
return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
if modelRequest.Model == "" {
modelRequest.Model = "text-moderation-stable"
}
}
if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
if modelRequest.Model == "" {
modelRequest.Model = c.Param("model")
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
if modelRequest.Model == "" {
modelRequest.Model = "dall-e-2"
}
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
if modelRequest.Model == "" {
modelRequest.Model = "whisper-1"
}
}
return modelRequest.Model, nil
}
func isModelInList(modelName string, models string) bool {
modelList := strings.Split(models, ",")
for _, model := range modelList {
if modelName == model {
return true
}
}
return false
}

View File

@@ -1,8 +1,10 @@
package model
import (
"context"
"github.com/songquanpeng/one-api/common"
"gorm.io/gorm"
"sort"
"strings"
)
@@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
func GetGroupModels(ctx context.Context, group string) ([]string, error) {
groupCol := "`group`"
trueVal := "1"
if common.UsingPostgreSQL {
groupCol = `"group"`
trueVal = "true"
}
var models []string
err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
if err != nil {
return nil, err
}
sort.Strings(models)
return models, err
}

View File

@@ -21,6 +21,7 @@ var (
UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency
GroupModelsCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -146,6 +147,25 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
if !common.RedisEnabled {
return GetGroupModels(ctx, group)
}
modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
if err == nil {
return strings.Split(modelsStr, ","), nil
}
models, err := GetGroupModels(ctx, group)
if err != nil {
return nil, err
}
err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set group models error: " + err.Error())
}
return models, nil
}
var group2model2channels map[string]map[string][]*Channel
var channelSyncLock sync.RWMutex

View File

@@ -51,6 +51,21 @@ func RecordLog(userId int, logType int, content string) {
}
}
func RecordTopupLog(userId int, content string, quota int) {
log := &Log{
UserId: userId,
Username: GetUsernameById(userId),
CreatedAt: helper.GetTimestamp(),
Type: LogTypeTopup,
Content: content,
Quota: quota,
}
err := LOG_DB.Create(log).Error
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {

View File

@@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) {
config.GitHubClientId = value
case "GitHubClientSecret":
config.GitHubClientSecret = value
case "LarkClientId":
config.LarkClientId = value
case "LarkClientSecret":
config.LarkClientSecret = value
case "Footer":
config.Footer = value
case "SystemName":

View File

@@ -12,24 +12,26 @@ import (
)
type Token struct {
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Id int `json:"id"`
UserId int `json:"user_id"`
Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index" `
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
Models *string `json:"models" gorm:"default:''"` // allowed models
Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
}
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
var tokens []*Token
var err error
query := DB.Where("user_id = ?", userId)
switch order {
case "remain_quota":
query = query.Order("unlimited_quota desc, remain_quota desc")
@@ -38,7 +40,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token
default:
query = query.Order("id desc")
}
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
return tokens, err
}
@@ -61,7 +63,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
return nil, errors.New("令牌验证失败")
}
if token.Status == common.TokenStatusExhausted {
return nil, errors.New("令牌额度已用尽")
return nil, fmt.Errorf("令牌 %s#%d额度已用尽", token.Name, token.Id)
} else if token.Status == common.TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
@@ -121,7 +123,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
return err
}

View File

@@ -24,6 +24,7 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -41,21 +42,21 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err
query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
switch order {
case "quota":
query = query.Order("quota desc")
case "used_quota":
query = query.Order("used_quota desc")
case "request_count":
query = query.Order("request_count desc")
default:
query = query.Order("id desc")
}
err = query.Find(&users).Error
return users, err
}
func SearchUsers(keyword string) (users []*User, err error) {
@@ -206,6 +207,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
func (user *User) FillUserByLarkId() error {
if user.LarkId == "" {
return errors.New("lark id 为空!")
}
DB.Where(User{LarkId: user.LarkId}).First(user)
return nil
}
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -234,6 +243,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsLarkIdAlreadyTaken(githubId string) bool {
return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}

View File

@@ -38,6 +38,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return aiProxyLibraryRequest, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -23,10 +23,16 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
if meta.Mode == constant.RelayModeEmbeddings {
fullRequestURL := ""
switch meta.Mode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
default:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
}
return fullRequestURL, nil
}
@@ -34,10 +40,12 @@ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *ut
channel.SetupCommonRequestHeader(c, req, meta)
if meta.IsStream {
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("X-DashScope-SSE", "enable")
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.IsStream {
req.Header.Set("X-DashScope-SSE", "enable")
if meta.Mode == constant.RelayModeImagesGenerations {
req.Header.Set("X-DashScope-Async", "enable")
}
if c.GetString(common.ConfigKeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
@@ -51,14 +59,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
return aliEmbeddingRequest, nil
default:
baiduRequest := ConvertRequest(*request)
return baiduRequest, nil
aliRequest := ConvertRequest(*request)
return aliRequest, nil
}
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
aliRequest := ConvertImageRequest(*request)
return aliRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
@@ -70,6 +87,8 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
switch meta.Mode {
case constant.RelayModeEmbeddings:
err, usage = EmbeddingHandler(c, resp)
case constant.RelayModeImagesGenerations:
err, usage = ImageHandler(c, resp)
default:
err, usage = Handler(c, resp)
}

View File

@@ -3,4 +3,5 @@ package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
"ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
}

192
relay/channel/ali/image.go Normal file
View File

@@ -0,0 +1,192 @@
package ali
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"time"
)
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse TaskResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
logger.SysError("aliAsyncTask err: " + string(responseBody))
return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey)
if err != nil {
return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, nil
}
func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) {
url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID)
var aliResponse TaskResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.SysError("aliAsyncTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response TaskResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
logger.SysError("aliAsyncTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) {
waitSeconds := 2
step := 0
maxStep := 20
var taskResponse TaskResponse
var responseBody []byte
for {
step++
rsp, err, body := asyncTask(taskID, key)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse {
imageResponse := openai.ImageResponse{
Created: helper.GetTimestamp(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
// 读取 data.Url 的图片数据并转存到 b64Json
imageData, err := getImageData(data.Url)
if err != nil {
// 处理获取图片数据失败的情况
logger.SysError("getImageData Error getting image data: " + err.Error())
continue
}
// 将图片数据转为 Base64 编码的字符串
b64Json = Base64Encode(imageData)
} else {
// 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, openai.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func getImageData(url string) ([]byte, error) {
response, err := http.Get(url)
if err != nil {
return nil, err
}
defer response.Body.Close()
imageData, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
return imageData, nil
}
func Base64Encode(data []byte) string {
b64Json := base64.StdEncoding.EncodeToString(data)
return b64Json
}

View File

@@ -48,6 +48,9 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
MaxTokens: request.MaxTokens,
Temperature: request.Temperature,
TopP: request.TopP,
TopK: request.TopK,
ResultFormat: "message",
Tools: request.Tools,
},
}
}
@@ -63,6 +66,17 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingReque
}
}
func ConvertImageRequest(request model.ImageRequest) *ImageRequest {
var imageRequest ImageRequest
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var aliResponse EmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&aliResponse)
@@ -117,19 +131,11 @@ func embeddingResponseAli2OpenAI(response *EmbeddingResponse) *openai.EmbeddingR
}
func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := openai.TextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Choices: response.Output.Choices,
Usage: model.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
@@ -140,10 +146,14 @@ func responseAli2OpenAI(response *ChatResponse) *openai.TextResponse {
}
func streamResponseAli2OpenAI(aliResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
if len(aliResponse.Output.Choices) == 0 {
return nil
}
aliChoice := aliResponse.Output.Choices[0]
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = aliResponse.Output.Text
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.Delta = aliChoice.Message
if aliChoice.FinishReason != "null" {
finishReason := aliChoice.FinishReason
choice.FinishReason = &finishReason
}
response := openai.ChatCompletionsStreamResponse{
@@ -204,6 +214,9 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
if response == nil {
return true
}
//response.Choices[0].Delta.Content = strings.TrimPrefix(response.Choices[0].Delta.Content, lastResponseText)
//lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
@@ -226,6 +239,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
ctx := c.Request.Context()
var aliResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -235,6 +249,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
logger.Debugf(ctx, "response body: %s\n", responseBody)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil

View File

@@ -1,5 +1,10 @@
package ali
import (
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
@@ -11,13 +16,15 @@ type Input struct {
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
type ChatRequest struct {
@@ -26,6 +33,79 @@ type ChatRequest struct {
Parameters Parameters `json:"parameters,omitempty"`
}
type ImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type TaskResponse struct {
StatusCode int `json:"status_code,omitempty"`
RequestId string `json:"request_id,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Output struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Results []struct {
B64Image string `json:"b64_image,omitempty"`
Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
} `json:"results,omitempty"`
TaskMetrics struct {
Total int `json:"TOTAL,omitempty"`
Succeeded int `json:"SUCCEEDED,omitempty"`
Failed int `json:"FAILED,omitempty"`
} `json:"task_metrics,omitempty"`
} `json:"output,omitempty"`
Usage Usage `json:"usage"`
}
type Header struct {
Action string `json:"action,omitempty"`
Streaming string `json:"streaming,omitempty"`
TaskID string `json:"task_id,omitempty"`
Event string `json:"event,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
Attributes any `json:"attributes,omitempty"`
}
type Payload struct {
Model string `json:"model,omitempty"`
Task string `json:"task,omitempty"`
TaskGroup string `json:"task_group,omitempty"`
Function string `json:"function,omitempty"`
Parameters struct {
SampleRate int `json:"sample_rate,omitempty"`
Rate float64 `json:"rate,omitempty"`
Format string `json:"format,omitempty"`
} `json:"parameters,omitempty"`
Input struct {
Text string `json:"text,omitempty"`
} `json:"input,omitempty"`
Usage struct {
Characters int `json:"characters,omitempty"`
} `json:"usage,omitempty"`
}
type WSSMessage struct {
Header Header `json:"header,omitempty"`
Payload Payload `json:"payload,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
@@ -62,8 +142,9 @@ type Usage struct {
}
type Output struct {
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
//Text string `json:"text"`
//FinishReason string `json:"finish_reason"`
Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {

View File

@@ -41,6 +41,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -38,6 +38,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
MaxTokens: textRequest.MaxTokens,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokens == 0 {

View File

@@ -38,16 +38,34 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-3.5-8K-0205":
suffix += "ernie-3.5-8k-0205"
case "ERNIE-3.5-8K-1222":
suffix += "ernie-3.5-8k-1222"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-3.5-4K-0205":
suffix += "ernie-3.5-4k-0205"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K-0922":
suffix += "eb-instant"
case "ERNIE-Lite-8K-0308":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
@@ -59,7 +77,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
case "tao-8k":
suffix += "tao_8k"
default:
suffix += meta.ActualModelName
suffix += strings.ToLower(meta.ActualModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", meta.BaseURL, suffix)
var accessToken string
@@ -91,6 +109,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -1,11 +1,18 @@
package baidu
var ModelList = []string{
"ERNIE-Bot-4",
"ERNIE-4.0-8K",
"ERNIE-3.5-8K",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K",
"ERNIE-Bot",
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"ERNIE-3.5-4K-0205",
"ERNIE-Speed-8K",
"ERNIE-Speed-128K",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Tiny-8K",
"BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",

View File

@@ -42,6 +42,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -13,6 +13,7 @@ type Adaptor interface {
GetRequestURL(meta *util.RelayMeta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
ConvertImageRequest(request *model.ImageRequest) (any, error)
DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string

View File

@@ -48,6 +48,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@@ -25,6 +26,13 @@ func (a *Adaptor) Init(meta *util.RelayMeta) {
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
switch meta.ChannelType {
case common.ChannelTypeAzure:
if meta.Mode == constant.RelayModeImagesGenerations {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
return fullRequestURL, nil
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
@@ -63,6 +71,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return request, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
@@ -70,10 +85,17 @@ func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText, _ = StreamHandler(c, resp, meta.Mode)
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
err, responseText, usage = StreamHandler(c, resp, meta.Mode)
if usage == nil {
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
}
} else {
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
switch meta.Mode {
case constant.RelayModeImagesGenerations:
err, _ = ImageHandler(c, resp)
default:
err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
}
}
return
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
"github.com/songquanpeng/one-api/relay/channel/stepfun"
)
var CompatibleChannels = []int{
@@ -20,6 +21,7 @@ var CompatibleChannels = []int{
common.ChannelTypeMistral,
common.ChannelTypeGroq,
common.ChannelTypeLingYiWanWu,
common.ChannelTypeStepFun,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
@@ -40,6 +42,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "groq", groq.ModelList
case common.ChannelTypeLingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
case common.ChannelTypeStepFun:
return "stepfun", stepfun.ModelList
default:
return "openai", ModelList
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
@@ -53,7 +54,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
continue // just ignore the error
}
for _, choice := range streamResponse.Choices {
responseText += choice.Delta.Content
responseText += conv.AsString(choice.Delta.Content)
}
if streamResponse.Usage != nil {
usage = streamResponse.Usage
@@ -148,3 +149,37 @@ func Handler(c *gin.Context, resp *http.Response, promptTokens int, modelName st
}
return nil, &textResponse.Usage
}
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}

View File

@@ -110,20 +110,22 @@ type EmbeddingResponse struct {
model.Usage `json:"usage"`
}
type ImageData struct {
Url string `json:"url,omitempty"`
B64Json string `json:"b64_json,omitempty"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
}
type ImageResponse struct {
Created int `json:"created"`
Data []struct {
Url string `json:"url"`
}
Created int64 `json:"created"`
Data []ImageData `json:"data"`
//model.Usage `json:"usage"`
}
type ChatCompletionsStreamResponseChoice struct {
Index int `json:"index"`
Delta struct {
Content string `json:"content"`
Role string `json:"role,omitempty"`
} `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
Index int `json:"index"`
Delta model.Message `json:"delta"`
FinishReason *string `json:"finish_reason,omitempty"`
}
type ChatCompletionsStreamResponse struct {

View File

@@ -36,6 +36,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -0,0 +1,7 @@
package stepfun
var ModelList = []string{
"step-1-32k",
"step-1v-32k",
"step-1-200k",
}

View File

@@ -52,6 +52,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return tencentRequest, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}

View File

@@ -10,6 +10,7 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
@@ -129,7 +130,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
}
response := streamResponseTencent2OpenAI(&TencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.Content
responseText += conv.AsString(response.Choices[0].Delta.Content)
}
jsonResponse, err := json.Marshal(response)
if err != nil {

View File

@@ -38,6 +38,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}

View File

@@ -26,7 +26,11 @@ import (
func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string, domain string) *ChatRequest {
messages := make([]Message, 0, len(request.Messages))
var lastToolCalls []model.Tool
for _, message := range request.Messages {
if message.ToolCalls != nil {
lastToolCalls = message.ToolCalls
}
messages = append(messages, Message{
Role: message.Role,
Content: message.StringContent(),
@@ -39,9 +43,33 @@ func requestOpenAI2Xunfei(request model.GeneralOpenAIRequest, xunfeiAppId string
xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
xunfeiRequest.Payload.Message.Text = messages
if len(lastToolCalls) != 0 {
for _, toolCall := range lastToolCalls {
xunfeiRequest.Payload.Functions.Text = append(xunfeiRequest.Payload.Functions.Text, toolCall.Function)
}
}
return &xunfeiRequest
}
func getToolCalls(response *ChatResponse) []model.Tool {
var toolCalls []model.Tool
if len(response.Payload.Choices.Text) == 0 {
return toolCalls
}
item := response.Payload.Choices.Text[0]
if item.FunctionCall == nil {
return toolCalls
}
toolCall := model.Tool{
Id: fmt.Sprintf("call_%s", helper.GetUUID()),
Type: "function",
Function: *item.FunctionCall,
}
toolCalls = append(toolCalls, toolCall)
return toolCalls
}
func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
if len(response.Payload.Choices.Text) == 0 {
response.Payload.Choices.Text = []ChatResponseTextItem{
@@ -53,8 +81,9 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
Role: "assistant",
Content: response.Payload.Choices.Text[0].Content,
ToolCalls: getToolCalls(response),
},
FinishReason: constant.StopFinishReason,
}
@@ -78,6 +107,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
}
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = xunfeiResponse.Payload.Choices.Text[0].Content
choice.Delta.ToolCalls = getToolCalls(xunfeiResponse)
if xunfeiResponse.Payload.Choices.Status == 2 {
choice.FinishReason = &constant.StopFinishReason
}

View File

@@ -26,13 +26,18 @@ type ChatRequest struct {
Message struct {
Text []Message `json:"text"`
} `json:"message"`
Functions struct {
Text []model.Function `json:"text,omitempty"`
} `json:"functions,omitempty"`
} `json:"payload"`
}
type ChatResponseTextItem struct {
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
Content string `json:"content"`
Role string `json:"role"`
Index int `json:"index"`
ContentType string `json:"content_type"`
FunctionCall *model.Function `json:"function_call"`
}
type ChatResponse struct {

View File

@@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
@@ -35,6 +36,9 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
if a.APIVersion == "v4" {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
}
if meta.Mode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
}
method := "invoke"
if meta.IsStream {
method = "sse-invoke"
@@ -53,18 +57,31 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
if request == nil {
return nil, errors.New("request is nil")
}
// TopP (0.0, 1.0)
request.TopP = math.Min(0.99, request.TopP)
request.TopP = math.Max(0.01, request.TopP)
switch relayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
// TopP (0.0, 1.0)
request.TopP = math.Min(0.99, request.TopP)
request.TopP = math.Max(0.01, request.TopP)
// Temperature (0.0, 1.0)
request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = math.Max(0.01, request.Temperature)
a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" {
return request, nil
// Temperature (0.0, 1.0)
request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = math.Max(0.01, request.Temperature)
a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" {
return request, nil
}
return ConvertRequest(*request), nil
}
return ConvertRequest(*request), nil
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
@@ -84,14 +101,26 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
if a.APIVersion == "v4" {
return a.DoResponseV4(c, resp, meta)
}
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
if meta.Mode == constant.RelayModeEmbeddings {
err, usage = EmbeddingsHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
}
return
}
func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
return &EmbeddingRequest{
Model: "embedding-2",
Input: request.Input.(string),
}
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}

View File

@@ -2,5 +2,5 @@ package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
"glm-4", "glm-4v", "glm-3-turbo",
"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
}

View File

@@ -254,3 +254,50 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var zhipuResponse EmbeddingRespone
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
fullTextResponse := embeddingResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),
Model: response.Model,
Usage: model.Usage{
PromptTokens: response.PromptTokens,
CompletionTokens: response.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
for _, item := range response.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, openai.EmbeddingResponseItem{
Object: `embedding`,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}

View File

@@ -44,3 +44,21 @@ type tokenData struct {
Token string
ExpiryTime time.Time
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input string `json:"input"`
}
type EmbeddingRespone struct {
Model string `json:"model"`
Object string `json:"object"`
Embeddings []EmbeddingData `json:"data"`
model.Usage `json:"usage"`
}
type EmbeddingData struct {
Index int `json:"index"`
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
}

View File

@@ -11,14 +11,31 @@ var DalleSizeRatios = map[string]map[string]float64{
"1024x1792": 2,
"1792x1024": 2,
},
"stable-diffusion-xl": {
"512x1024": 1,
"1024x768": 1,
"1024x1024": 1,
"576x1024": 1,
"1024x576": 1,
},
"wanx-v1": {
"1024x1024": 1,
"720x1280": 1,
"1280x720": 1,
},
}
var DalleGenerationImageAmounts = map[string][2]int{
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
"dall-e-2": {1, 10},
"dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
"stable-diffusion-xl": {1, 4}, // Ali
"wanx-v1": {1, 4}, // Ali
}
var DalleImagePromptLengthLimitations = map[string]int{
"dall-e-2": 1000,
"dall-e-3": 4000,
"dall-e-2": 1000,
"dall-e-3": 4000,
"stable-diffusion-xl": 4000,
"wanx-v1": 4000,
"cogview-3": 833,
}

View File

@@ -36,8 +36,8 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
return textRequest, nil
}
func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) {
imageRequest := &openai.ImageRequest{}
func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
@@ -54,7 +54,7 @@ func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error
return imageRequest, nil
}
func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
// model validation
_, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
if !hasValidSize {
@@ -77,7 +77,7 @@ func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMet
return nil
}
func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}

View File

@@ -6,18 +6,17 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/helper"
relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
)
func isWithinRange(element string, value int) bool {
@@ -56,15 +55,6 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
}
requestURL := c.Request.URL.String()
fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
if meta.ChannelType == common.ChannelTypeAzure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
apiVersion := util.GetAzureAPIVersion(c)
// https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
}
var requestBody io.Reader
if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
@@ -76,6 +66,29 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = c.Request.Body
}
adaptor := helper.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
switch meta.ChannelType {
case common.ChannelTypeAli:
fallthrough
case common.ChannelTypeBaidu:
fallthrough
case common.ChannelTypeZhipu:
finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
}
jsonStr, err := json.Marshal(finalRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
}
requestBody = bytes.NewBuffer(jsonStr)
}
modelRatio := common.GetModelRatio(imageRequest.Model)
groupRatio := common.GetGroupRatio(meta.Group)
ratio := modelRatio * groupRatio
@@ -87,36 +100,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
token := c.Request.Header.Get("Authorization")
if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication
token = strings.TrimPrefix(token, "Bearer ")
req.Header.Set("api-key", token)
} else {
req.Header.Set("Authorization", token)
}
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
resp, err := util.HTTPClient.Do(req)
// do request
resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
err = req.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
err = c.Request.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
var imageResponse openai.ImageResponse
defer func(ctx context.Context) {
if resp.StatusCode != http.StatusOK {
return
@@ -139,34 +129,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}(c.Request.Context())
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
// do response
_, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
return respErr
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
return nil
}

View File

@@ -5,25 +5,29 @@ type ResponseFormat struct {
}
type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Prompt any `json:"prompt,omitempty"`
Stream bool `json:"stream,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
N int `json:"n,omitempty"`
Input any `json:"input,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
Functions any `json:"functions,omitempty"`
Model string `json:"model,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Tools any `json:"tools,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
FunctionCall any `json:"function_call,omitempty"`
Functions any `json:"functions,omitempty"`
User string `json:"user,omitempty"`
Prompt any `json:"prompt,omitempty"`
Input any `json:"input,omitempty"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
Instruction string `json:"instruction,omitempty"`
Size string `json:"size,omitempty"`
}
func (r GeneralOpenAIRequest) ParseInput() []string {

12
relay/model/image.go Normal file
View File

@@ -0,0 +1,12 @@
package model
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt" binding:"required"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
Style string `json:"style,omitempty"`
User string `json:"user,omitempty"`
}

View File

@@ -1,9 +1,10 @@
package model
type Message struct {
Role string `json:"role"`
Content any `json:"content"`
Name *string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Content any `json:"content,omitempty"`
Name *string `json:"name,omitempty"`
ToolCalls []Tool `json:"tool_calls,omitempty"`
}
func (m Message) IsStringContent() bool {

14
relay/model/tool.go Normal file
View File

@@ -0,0 +1,14 @@
package model
type Tool struct {
Id string `json:"id,omitempty"`
Type string `json:"type"`
Function Function `json:"function"`
}
type Function struct {
Description string `json:"description,omitempty"`
Name string `json:"name"`
Parameters any `json:"parameters,omitempty"` // request
Arguments any `json:"arguments,omitempty"` // response
}

View File

@@ -46,6 +46,15 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
return true
}
return false
}

View File

@@ -2,6 +2,7 @@ package router
import (
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/controller/auth"
"github.com/songquanpeng/one-api/middleware"
"github.com/gin-contrib/gzip"
@@ -21,11 +22,13 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification)
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)
apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp)
userRoute := apiRouter.Group("/user")
{
@@ -43,6 +46,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
selfRoute.GET("/available_models", controller.GetUserAvailableModels)
}
adminRoute := userRoute.Group("/")
@@ -68,7 +72,7 @@ func SetApiRouter(router *gin.Engine) {
{
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
channelRoute.GET("/models", controller.ListModels)
channelRoute.GET("/models", controller.ListAllModels)
channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestChannels)
channelRoute.GET("/test/:id", controller.TestChannel)

View File

@@ -2,6 +2,9 @@
> 每个文件夹代表一个主题,欢迎提交你的主题
> [!WARNING]
> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR
## 提交新的主题
> 欢迎在页面底部保留你和 One API 的版权信息以及指向链接

View File

@@ -107,6 +107,12 @@ export const CHANNEL_OPTIONS = {
value: 31,
color: 'primary'
},
32: {
key: 32,
text: '阶跃星辰',
value: 32,
color: 'primary'
},
8: {
key: 8,
text: '自定义渠道',

View File

@@ -18,7 +18,7 @@ export const snackbarConstants = {
},
NOTICE: {
variant: 'info',
autoHideDuration: 20000
autoHideDuration: 7000
}
},
Mobile: {

View File

@@ -51,9 +51,9 @@ export function showError(error) {
export function showNotice(message, isHTML = false) {
if (isHTML) {
enqueueSnackbar(<SnackbarHTMLContent htmlContent={message} />, getSnackbarOptions('INFO'));
enqueueSnackbar(<SnackbarHTMLContent htmlContent={message} />, getSnackbarOptions('NOTICE'));
} else {
enqueueSnackbar(message, getSnackbarOptions('INFO'));
enqueueSnackbar(message, getSnackbarOptions('NOTICE'));
}
}

View File

@@ -340,7 +340,9 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
},
}}
>
{Object.values(CHANNEL_OPTIONS).map((option) => {
{Object.values(CHANNEL_OPTIONS).sort((a, b) => {
return a.text.localeCompare(b.text)
}).map((option) => {
return (
<MenuItem key={option.value} value={option.value}>
{option.text}

View File

@@ -103,7 +103,7 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => {
fontSize: "1.125rem",
}}
>
{tokenId ? "编辑Token" : "新建Token"}
{tokenId ? "编辑令牌" : "新建令牌"}
</DialogTitle>
<Divider />
<DialogContent>

View File

@@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption';
import TopUp from './pages/TopUp';
import Log from './pages/Log';
import Chat from './pages/Chat';
import LarkOAuth from './components/LarkOAuth';
const Home = lazy(() => import('./pages/Home'));
const About = lazy(() => import('./pages/About'));
@@ -239,6 +240,14 @@ function App() {
</Suspense>
}
/>
<Route
path='/oauth/lark'
element={
<Suspense fallback={<Loading></Loading>}>
<LarkOAuth />
</Suspense>
}
/>
<Route
path='/setting'
element={

View File

@@ -0,0 +1,58 @@
import React, { useContext, useEffect, useState } from 'react';
import { Dimmer, Loader, Segment } from 'semantic-ui-react';
import { useNavigate, useSearchParams } from 'react-router-dom';
import { API, showError, showSuccess } from '../helpers';
import { UserContext } from '../context/User';
const LarkOAuth = () => {
const [searchParams, setSearchParams] = useSearchParams();
const [userState, userDispatch] = useContext(UserContext);
const [prompt, setPrompt] = useState('处理中...');
const [processing, setProcessing] = useState(true);
let navigate = useNavigate();
const sendCode = async (code, state, count) => {
const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`);
const { success, message, data } = res.data;
if (success) {
if (message === 'bind') {
showSuccess('绑定成功!');
navigate('/setting');
} else {
userDispatch({ type: 'login', payload: data });
localStorage.setItem('user', JSON.stringify(data));
showSuccess('登录成功!');
navigate('/');
}
} else {
showError(message);
if (count === 0) {
setPrompt(`操作失败,重定向至登录界面中...`);
navigate('/setting'); // in case this is failed to bind lark
return;
}
count++;
setPrompt(`出现错误,第 ${count} 次重试中...`);
await new Promise((resolve) => setTimeout(resolve, count * 2000));
await sendCode(code, state, count);
}
};
useEffect(() => {
let code = searchParams.get('code');
let state = searchParams.get('state');
sendCode(code, state, 0).then();
}, []);
return (
<Segment style={{ minHeight: '300px' }}>
<Dimmer active inverted>
<Loader size='large'>{prompt}</Loader>
</Dimmer>
</Segment>
);
};
export default LarkOAuth;

View File

@@ -3,7 +3,8 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f
import { Link, useNavigate, useSearchParams } from 'react-router-dom';
import { UserContext } from '../context/User';
import { API, getLogo, showError, showSuccess, showWarning } from '../helpers';
import { onGitHubOAuthClicked } from './utils';
import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils';
import larkIcon from '../images/lark.svg';
const LoginForm = () => {
const [inputs, setInputs] = useState({
@@ -124,7 +125,7 @@ const LoginForm = () => {
点击注册
</Link>
</Message>
{status.github_oauth || status.wechat_login ? (
{status.github_oauth || status.wechat_login || status.lark_client_id ? (
<>
<Divider horizontal>Or</Divider>
{status.github_oauth ? (
@@ -137,6 +138,18 @@ const LoginForm = () => {
) : (
<></>
)}
{status.lark_client_id ? (
<Button
// circular
color=''
onClick={() => onLarkOAuthClicked(status.lark_client_id)}
style={{ padding: 0, width: 36, height: 36 }}
>
<img src={larkIcon} width={36} height={36} />
</Button>
) : (
<></>
)}
{status.wechat_login ? (
<Button
circular

View File

@@ -4,7 +4,7 @@ import { Link, useNavigate } from 'react-router-dom';
import { API, copy, showError, showInfo, showNotice, showSuccess } from '../helpers';
import Turnstile from 'react-turnstile';
import { UserContext } from '../context/User';
import { onGitHubOAuthClicked } from './utils';
import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils';
const PersonalSetting = () => {
const [userState, userDispatch] = useContext(UserContext);
@@ -247,6 +247,11 @@ const PersonalSetting = () => {
<Button onClick={()=>{onGitHubOAuthClicked(status.github_client_id)}}>绑定 GitHub 账号</Button>
)
}
{
status.lark_client_id && (
<Button onClick={()=>{onLarkOAuthClicked(status.lark_client_id)}}>绑定飞书账号</Button>
)
}
<Button
onClick={() => {
setShowEmailBindModal(true);

View File

@@ -10,6 +10,8 @@ const SystemSetting = () => {
GitHubOAuthEnabled: '',
GitHubClientId: '',
GitHubClientSecret: '',
LarkClientId: '',
LarkClientSecret: '',
Notice: '',
SMTPServer: '',
SMTPPort: '',
@@ -109,6 +111,8 @@ const SystemSetting = () => {
name === 'ServerAddress' ||
name === 'GitHubClientId' ||
name === 'GitHubClientSecret' ||
name === 'LarkClientId' ||
name === 'LarkClientSecret' ||
name === 'WeChatServerAddress' ||
name === 'WeChatServerToken' ||
name === 'WeChatAccountQRCodeImageURL' ||
@@ -212,6 +216,18 @@ const SystemSetting = () => {
}
};
const submitLarkOAuth = async () => {
if (originInputs['LarkClientId'] !== inputs.LarkClientId) {
await updateOption('LarkClientId', inputs.LarkClientId);
}
if (
originInputs['LarkClientSecret'] !== inputs.LarkClientSecret &&
inputs.LarkClientSecret !== ''
) {
await updateOption('LarkClientSecret', inputs.LarkClientSecret);
}
};
const submitTurnstile = async () => {
if (originInputs['TurnstileSiteKey'] !== inputs.TurnstileSiteKey) {
await updateOption('TurnstileSiteKey', inputs.TurnstileSiteKey);
@@ -469,6 +485,44 @@ const SystemSetting = () => {
保存 GitHub OAuth 设置
</Form.Button>
<Divider />
<Header as='h3'>
配置飞书授权登录
<Header.Subheader>
用以支持通过飞书进行登录注册
<a href='https://open.feishu.cn/app' target='_blank'>
点击此处
</a>
管理你的飞书应用
</Header.Subheader>
</Header>
<Message>
主页链接填 <code>{inputs.ServerAddress}</code>
重定向 URL {' '}
<code>{`${inputs.ServerAddress}/oauth/lark`}</code>
</Message>
<Form.Group widths={3}>
<Form.Input
label='App ID'
name='LarkClientId'
onChange={handleInputChange}
autoComplete='new-password'
value={inputs.LarkClientId}
placeholder='输入 App ID'
/>
<Form.Input
label='App Secret'
name='LarkClientSecret'
onChange={handleInputChange}
type='password'
autoComplete='new-password'
value={inputs.LarkClientSecret}
placeholder='敏感信息不会发送到前端显示'
/>
</Form.Group>
<Form.Button onClick={submitLarkOAuth}>
保存飞书 OAuth 设置
</Form.Button>
<Divider />
<Header as='h3'>
配置 WeChat Server
<Header.Subheader>

View File

@@ -17,4 +17,13 @@ export async function onGitHubOAuthClicked(github_client_id) {
window.open(
`https://github.com/login/oauth/authorize?client_id=${github_client_id}&state=${state}&scope=user:email`
);
}
export async function onLarkOAuthClicked(lark_client_id) {
const state = await getOAuthState();
if (!state) return;
let redirect_uri = `${window.location.origin}/oauth/lark`;
window.open(
`https://open.feishu.cn/open-apis/authen/v1/index?redirect_uri=${redirect_uri}&app_id=${lark_client_id}&state=${state}`
);
}

View File

@@ -17,6 +17,7 @@ export const CHANNEL_OPTIONS = [
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
{ key: 30, text: 'Ollama', value: 30, color: 'black' },
{ key: 31, text: '零一万物', value: 31, color: 'green' },
{ key: 31, text: '阶跃星辰', value: 32, color: 'blue' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 5.4 KiB

View File

@@ -1,19 +1,22 @@
import React, { useEffect, useState } from 'react';
import { Button, Form, Header, Message, Segment } from 'semantic-ui-react';
import { useParams, useNavigate } from 'react-router-dom';
import { API, showError, showSuccess, timestamp2string } from '../../helpers';
import { renderQuota, renderQuotaWithPrompt } from '../../helpers/render';
import { useNavigate, useParams } from 'react-router-dom';
import { API, copy, showError, showSuccess, timestamp2string } from '../../helpers';
import { renderQuotaWithPrompt } from '../../helpers/render';
const EditToken = () => {
const params = useParams();
const tokenId = params.id;
const isEdit = tokenId !== undefined;
const [loading, setLoading] = useState(isEdit);
const [modelOptions, setModelOptions] = useState([]);
const originInputs = {
name: '',
remain_quota: isEdit ? 0 : 500000,
expired_time: -1,
unlimited_quota: false
unlimited_quota: false,
models: [],
subnet: "",
};
const [inputs, setInputs] = useState(originInputs);
const { name, remain_quota, expired_time, unlimited_quota } = inputs;
@@ -22,8 +25,8 @@ const EditToken = () => {
setInputs((inputs) => ({ ...inputs, [name]: value }));
};
const handleCancel = () => {
navigate("/token");
}
navigate('/token');
};
const setExpiredTime = (month, day, hour, minute) => {
let now = new Date();
let timestamp = now.getTime() / 1000;
@@ -50,6 +53,11 @@ const EditToken = () => {
if (data.expired_time !== -1) {
data.expired_time = timestamp2string(data.expired_time);
}
if (data.models === '') {
data.models = [];
} else {
data.models = data.models.split(',');
}
setInputs(data);
} else {
showError(message);
@@ -60,8 +68,26 @@ const EditToken = () => {
if (isEdit) {
loadToken().then();
}
loadAvailableModels().then();
}, []);
const loadAvailableModels = async () => {
let res = await API.get(`/api/user/available_models`);
const { success, message, data } = res.data;
if (success) {
let options = data.map((model) => {
return {
key: model,
text: model,
value: model
};
});
setModelOptions(options);
} else {
showError(message);
}
};
const submit = async () => {
if (!isEdit && inputs.name === '') return;
let localInputs = inputs;
@@ -74,6 +100,7 @@ const EditToken = () => {
}
localInputs.expired_time = Math.ceil(time / 1000);
}
localInputs.models = localInputs.models.join(',');
let res;
if (isEdit) {
res = await API.put(`/api/token/`, { ...localInputs, id: parseInt(tokenId) });
@@ -109,6 +136,34 @@ const EditToken = () => {
required={!isEdit}
/>
</Form.Field>
<Form.Field>
<Form.Dropdown
label='模型范围'
placeholder={'请选择允许使用的模型,留空则不进行限制'}
name='models'
fluid
multiple
search
onLabelClick={(e, { value }) => {
copy(value).then();
}}
selection
onChange={handleInputChange}
value={inputs.models}
autoComplete='new-password'
options={modelOptions}
/>
</Form.Field>
<Form.Field>
<Form.Input
label='IP 限制'
name='subnet'
placeholder={'请输入允许访问的网段例如192.168.0.0/24请使用英文逗号分隔多个网段'}
onChange={handleInputChange}
value={inputs.subnet}
autoComplete='new-password'
/>
</Form.Field>
<Form.Field>
<Form.Input
label='过期时间'

View File

@@ -8,6 +8,7 @@ const TopUp = () => {
const [topUpLink, setTopUpLink] = useState('');
const [userQuota, setUserQuota] = useState(0);
const [isSubmitting, setIsSubmitting] = useState(false);
const [user, setUser] = useState({});
const topUp = async () => {
if (redemptionCode === '') {
@@ -41,7 +42,14 @@ const TopUp = () => {
showError('超级管理员未设置充值链接!');
return;
}
window.open(topUpLink, '_blank');
let url = new URL(topUpLink);
let username = user.username;
let user_id = user.id;
// add username and user_id to the topup link
url.searchParams.append('username', username);
url.searchParams.append('user_id', user_id);
url.searchParams.append('transaction_id', crypto.randomUUID());
window.open(url.toString(), '_blank');
};
const getUserQuota = async ()=>{
@@ -49,6 +57,7 @@ const TopUp = () => {
const {success, message, data} = res.data;
if (success) {
setUserQuota(data.quota);
setUser(data);
} else {
showError(message);
}
@@ -80,7 +89,7 @@ const TopUp = () => {
}}
/>
<Button color='green' onClick={openTopUpLink}>
获取兑换码
充值
</Button>
<Button color='yellow' onClick={topUp} disabled={isSubmitting}>
{isSubmitting ? '兑换中...' : '兑换'}