diff --git a/.gitignore b/.gitignore
index 1cfa1e7f..be4abc52 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,3 +8,4 @@ build
logs
data
node_modules
+cmd.md
diff --git a/common/config/config.go b/common/config/config.go
index 66cfee06..59fd77df 100644
--- a/common/config/config.go
+++ b/common/config/config.go
@@ -80,6 +80,9 @@ var SMTPToken = ""
var GitHubClientId = ""
var GitHubClientSecret = ""
+var LarkClientId = ""
+var LarkClientSecret = ""
+
var WeChatServerAddress = ""
var WeChatServerToken = ""
var WeChatAccountQRCodeImageURL = ""
diff --git a/common/config/key.go b/common/config/key.go
new file mode 100644
index 00000000..4b503c2d
--- /dev/null
+++ b/common/config/key.go
@@ -0,0 +1,9 @@
+package config
+
+const (
+ KeyPrefix = "cfg_"
+
+ KeyAPIVersion = KeyPrefix + "api_version"
+ KeyLibraryID = KeyPrefix + "library_id"
+ KeyPlugin = KeyPrefix + "plugin"
+)
diff --git a/common/constants.go b/common/constants.go
index 849bdce7..87221b61 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -4,116 +4,3 @@ import "time"
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
-
-const (
- RoleGuestUser = 0
- RoleCommonUser = 1
- RoleAdminUser = 10
- RoleRootUser = 100
-)
-
-const (
- UserStatusEnabled = 1 // don't use 0, 0 is the default value!
- UserStatusDisabled = 2 // also don't use 0
- UserStatusDeleted = 3
-)
-
-const (
- TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
- TokenStatusDisabled = 2 // also don't use 0
- TokenStatusExpired = 3
- TokenStatusExhausted = 4
-)
-
-const (
- RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
- RedemptionCodeStatusDisabled = 2 // also don't use 0
- RedemptionCodeStatusUsed = 3 // also don't use 0
-)
-
-const (
- ChannelStatusUnknown = 0
- ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
- ChannelStatusManuallyDisabled = 2 // also don't use 0
- ChannelStatusAutoDisabled = 3
-)
-
-const (
- ChannelTypeUnknown = iota
- ChannelTypeOpenAI
- ChannelTypeAPI2D
- ChannelTypeAzure
- ChannelTypeCloseAI
- ChannelTypeOpenAISB
- ChannelTypeOpenAIMax
- ChannelTypeOhMyGPT
- ChannelTypeCustom
- ChannelTypeAILS
- ChannelTypeAIProxy
- ChannelTypePaLM
- ChannelTypeAPI2GPT
- ChannelTypeAIGC2D
- ChannelTypeAnthropic
- ChannelTypeBaidu
- ChannelTypeZhipu
- ChannelTypeAli
- ChannelTypeXunfei
- ChannelType360
- ChannelTypeOpenRouter
- ChannelTypeAIProxyLibrary
- ChannelTypeFastGPT
- ChannelTypeTencent
- ChannelTypeGemini
- ChannelTypeMoonshot
- ChannelTypeBaichuan
- ChannelTypeMinimax
- ChannelTypeMistral
- ChannelTypeGroq
- ChannelTypeOllama
- ChannelTypeLingYiWanWu
-
- ChannelTypeDummy
-)
-
-var ChannelBaseURLs = []string{
- "", // 0
- "https://api.openai.com", // 1
- "https://oa.api2d.net", // 2
- "", // 3
- "https://api.closeai-proxy.xyz", // 4
- "https://api.openai-sb.com", // 5
- "https://api.openaimax.com", // 6
- "https://api.ohmygpt.com", // 7
- "", // 8
- "https://api.caipacity.com", // 9
- "https://api.aiproxy.io", // 10
- "https://generativelanguage.googleapis.com", // 11
- "https://api.api2gpt.com", // 12
- "https://api.aigc2d.com", // 13
- "https://api.anthropic.com", // 14
- "https://aip.baidubce.com", // 15
- "https://open.bigmodel.cn", // 16
- "https://dashscope.aliyuncs.com", // 17
- "", // 18
- "https://ai.360.cn", // 19
- "https://openrouter.ai/api", // 20
- "https://api.aiproxy.io", // 21
- "https://fastgpt.run/api/openapi", // 22
- "https://hunyuan.cloud.tencent.com", // 23
- "https://generativelanguage.googleapis.com", // 24
- "https://api.moonshot.cn", // 25
- "https://api.baichuan-ai.com", // 26
- "https://api.minimax.chat", // 27
- "https://api.mistral.ai", // 28
- "https://api.groq.com/openai", // 29
- "http://localhost:11434", // 30
- "https://api.lingyiwanwu.com", // 31
-}
-
-const (
- ConfigKeyPrefix = "cfg_"
-
- ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
- ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
- ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
-)
diff --git a/common/helper/helper.go b/common/helper/helper.go
index db41ac74..35d075bc 100644
--- a/common/helper/helper.go
+++ b/common/helper/helper.go
@@ -2,16 +2,15 @@ package helper
import (
"fmt"
- "github.com/google/uuid"
"html/template"
"log"
- "math/rand"
"net"
"os/exec"
"runtime"
"strconv"
"strings"
- "time"
+
+ "github.com/songquanpeng/one-api/common/random"
)
func OpenBrowser(url string) {
@@ -79,31 +78,6 @@ func Bytes2Size(num int64) string {
return numStr + " " + unit
}
-func Seconds2Time(num int) (time string) {
- if num/31104000 > 0 {
- time += strconv.Itoa(num/31104000) + " 年 "
- num %= 31104000
- }
- if num/2592000 > 0 {
- time += strconv.Itoa(num/2592000) + " 个月 "
- num %= 2592000
- }
- if num/86400 > 0 {
- time += strconv.Itoa(num/86400) + " 天 "
- num %= 86400
- }
- if num/3600 > 0 {
- time += strconv.Itoa(num/3600) + " 小时 "
- num %= 3600
- }
- if num/60 > 0 {
- time += strconv.Itoa(num/60) + " 分钟 "
- num %= 60
- }
- time += strconv.Itoa(num) + " 秒"
- return
-}
-
func Interface2String(inter interface{}) string {
switch inter := inter.(type) {
case string:
@@ -128,65 +102,8 @@ func IntMax(a int, b int) int {
}
}
-func GetUUID() string {
- code := uuid.New().String()
- code = strings.Replace(code, "-", "", -1)
- return code
-}
-
-const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
-const keyNumbers = "0123456789"
-
-func init() {
- rand.Seed(time.Now().UnixNano())
-}
-
-func GenerateKey() string {
- rand.Seed(time.Now().UnixNano())
- key := make([]byte, 48)
- for i := 0; i < 16; i++ {
- key[i] = keyChars[rand.Intn(len(keyChars))]
- }
- uuid_ := GetUUID()
- for i := 0; i < 32; i++ {
- c := uuid_[i]
- if i%2 == 0 && c >= 'a' && c <= 'z' {
- c = c - 'a' + 'A'
- }
- key[i+16] = c
- }
- return string(key)
-}
-
-func GetRandomString(length int) string {
- rand.Seed(time.Now().UnixNano())
- key := make([]byte, length)
- for i := 0; i < length; i++ {
- key[i] = keyChars[rand.Intn(len(keyChars))]
- }
- return string(key)
-}
-
-func GetRandomNumberString(length int) string {
- rand.Seed(time.Now().UnixNano())
- key := make([]byte, length)
- for i := 0; i < length; i++ {
- key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
- }
- return string(key)
-}
-
-func GetTimestamp() int64 {
- return time.Now().Unix()
-}
-
-func GetTimeString() string {
- now := time.Now()
- return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
-}
-
func GenRequestID() string {
- return GetTimeString() + GetRandomNumberString(8)
+ return GetTimeString() + random.GetRandomNumberString(8)
}
func Max(a int, b int) int {
diff --git a/common/helper/time.go b/common/helper/time.go
new file mode 100644
index 00000000..302746db
--- /dev/null
+++ b/common/helper/time.go
@@ -0,0 +1,15 @@
+package helper
+
+import (
+ "fmt"
+ "time"
+)
+
+func GetTimestamp() int64 {
+ return time.Now().Unix()
+}
+
+func GetTimeString() string {
+ now := time.Now()
+ return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
+}
diff --git a/common/network/ip.go b/common/network/ip.go
new file mode 100644
index 00000000..0fbe5e6f
--- /dev/null
+++ b/common/network/ip.go
@@ -0,0 +1,52 @@
+package network
+
+import (
+ "context"
+ "fmt"
+ "github.com/songquanpeng/one-api/common/logger"
+ "net"
+ "strings"
+)
+
+func splitSubnets(subnets string) []string {
+ res := strings.Split(subnets, ",")
+ for i := 0; i < len(res); i++ {
+ res[i] = strings.TrimSpace(res[i])
+ }
+ return res
+}
+
+func isValidSubnet(subnet string) error {
+ _, _, err := net.ParseCIDR(subnet)
+ if err != nil {
+ return fmt.Errorf("failed to parse subnet: %w", err)
+ }
+ return nil
+}
+
+func isIpInSubnet(ctx context.Context, ip string, subnet string) bool {
+ _, ipNet, err := net.ParseCIDR(subnet)
+ if err != nil {
+ logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
+ return false
+ }
+ return ipNet.Contains(net.ParseIP(ip))
+}
+
+func IsValidSubnets(subnets string) error {
+ for _, subnet := range splitSubnets(subnets) {
+ if err := isValidSubnet(subnet); err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
+ for _, subnet := range splitSubnets(subnets) {
+ if isIpInSubnet(ctx, ip, subnet) {
+ return true
+ }
+ }
+ return false
+}
diff --git a/common/network/ip_test.go b/common/network/ip_test.go
new file mode 100644
index 00000000..6c593458
--- /dev/null
+++ b/common/network/ip_test.go
@@ -0,0 +1,19 @@
+package network
+
+import (
+ "context"
+ "testing"
+
+ . "github.com/smartystreets/goconvey/convey"
+)
+
+func TestIsIpInSubnet(t *testing.T) {
+ ctx := context.Background()
+ ip1 := "192.168.0.5"
+ ip2 := "125.216.250.89"
+ subnet := "192.168.0.0/24"
+ Convey("TestIsIpInSubnet", t, func() {
+ So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
+ So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
+ })
+}
diff --git a/common/random.go b/common/random.go
deleted file mode 100644
index 44bd2856..00000000
--- a/common/random.go
+++ /dev/null
@@ -1,8 +0,0 @@
-package common
-
-import "math/rand"
-
-// RandRange returns a random number between min and max (max is not included)
-func RandRange(min, max int) int {
- return min + rand.Intn(max-min)
-}
diff --git a/common/random/main.go b/common/random/main.go
new file mode 100644
index 00000000..c3c69488
--- /dev/null
+++ b/common/random/main.go
@@ -0,0 +1,62 @@
+package random
+
+import (
+ "math/rand"
+ "strings"
+ "time"
+
+ "github.com/google/uuid"
+)
+
+func GetUUID() string {
+ code := uuid.New().String()
+ code = strings.Replace(code, "-", "", -1)
+ return code
+}
+
+const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+const keyNumbers = "0123456789"
+
+func init() {
+ rand.Seed(time.Now().UnixNano())
+}
+
+func GenerateKey() string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, 48)
+ for i := 0; i < 16; i++ {
+ key[i] = keyChars[rand.Intn(len(keyChars))]
+ }
+ uuid_ := GetUUID()
+ for i := 0; i < 32; i++ {
+ c := uuid_[i]
+ if i%2 == 0 && c >= 'a' && c <= 'z' {
+ c = c - 'a' + 'A'
+ }
+ key[i+16] = c
+ }
+ return string(key)
+}
+
+func GetRandomString(length int) string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, length)
+ for i := 0; i < length; i++ {
+ key[i] = keyChars[rand.Intn(len(keyChars))]
+ }
+ return string(key)
+}
+
+func GetRandomNumberString(length int) string {
+ rand.Seed(time.Now().UnixNano())
+ key := make([]byte, length)
+ for i := 0; i < length; i++ {
+ key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
+ }
+ return string(key)
+}
+
+// RandRange returns a random number between min and max (max is not included)
+func RandRange(min, max int) int {
+ return min + rand.Intn(max-min)
+}
diff --git a/controller/github.go b/controller/auth/github.go
similarity index 94%
rename from controller/github.go
rename to controller/auth/github.go
index fc674852..22b48976 100644
--- a/controller/github.go
+++ b/controller/auth/github.go
@@ -1,4 +1,4 @@
-package controller
+package auth
import (
"bytes"
@@ -7,10 +7,10 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
- "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
user.DisplayName = "GitHub User"
}
user.Email = githubUser.Email
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
+ user.Role = model.RoleCommonUser
+ user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -152,14 +152,14 @@ func GitHubOAuth(c *gin.Context) {
}
}
- if user.Status != common.UserStatusEnabled {
+ if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
- setupLogin(&user, c)
+ controller.SetupLogin(&user, c)
}
func GitHubBind(c *gin.Context) {
@@ -219,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
- state := helper.GetRandomString(12)
+ state := random.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {
diff --git a/controller/auth/lark.go b/controller/auth/lark.go
new file mode 100644
index 00000000..a1dd8e84
--- /dev/null
+++ b/controller/auth/lark.go
@@ -0,0 +1,201 @@
+package auth
+
+import (
+ "bytes"
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "strconv"
+ "time"
+
+ "github.com/Laisky/errors/v2"
+ "github.com/gin-contrib/sessions"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/controller"
+ "github.com/songquanpeng/one-api/model"
+)
+
+type LarkOAuthResponse struct {
+ AccessToken string `json:"access_token"`
+}
+
+type LarkUser struct {
+ Name string `json:"name"`
+ OpenID string `json:"open_id"`
+}
+
+func getLarkUserInfoByCode(code string) (*LarkUser, error) {
+ if code == "" {
+ return nil, errors.New("无效的参数")
+ }
+ values := map[string]string{
+ "client_id": config.LarkClientId,
+ "client_secret": config.LarkClientSecret,
+ "code": code,
+ "grant_type": "authorization_code",
+ "redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress),
+ }
+ jsonData, err := json.Marshal(values)
+ if err != nil {
+ return nil, err
+ }
+ req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData))
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Content-Type", "application/json")
+ req.Header.Set("Accept", "application/json")
+ client := http.Client{
+ Timeout: 5 * time.Second,
+ }
+ res, err := client.Do(req)
+ if err != nil {
+ logger.SysLog(err.Error())
+ return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
+ }
+ defer res.Body.Close()
+ var oAuthResponse LarkOAuthResponse
+ err = json.NewDecoder(res.Body).Decode(&oAuthResponse)
+ if err != nil {
+ return nil, err
+ }
+ req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil)
+ if err != nil {
+ return nil, err
+ }
+ req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken))
+ res2, err := client.Do(req)
+ if err != nil {
+ logger.SysLog(err.Error())
+ return nil, errors.New("无法连接至飞书服务器,请稍后重试!")
+ }
+ var larkUser LarkUser
+ err = json.NewDecoder(res2.Body).Decode(&larkUser)
+ if err != nil {
+ return nil, err
+ }
+ return &larkUser, nil
+}
+
+func LarkOAuth(c *gin.Context) {
+ session := sessions.Default(c)
+ state := c.Query("state")
+ if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) {
+ c.JSON(http.StatusForbidden, gin.H{
+ "success": false,
+ "message": "state is empty or not same",
+ })
+ return
+ }
+ username := session.Get("username")
+ if username != nil {
+ LarkBind(c)
+ return
+ }
+ code := c.Query("code")
+ larkUser, err := getLarkUserInfoByCode(code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user := model.User{
+ LarkId: larkUser.OpenID,
+ }
+ if model.IsLarkIdAlreadyTaken(user.LarkId) {
+ err := user.FillUserByLarkId()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ if config.RegisterEnabled {
+ user.Username = "lark_" + strconv.Itoa(model.GetMaxUserId()+1)
+ if larkUser.Name != "" {
+ user.DisplayName = larkUser.Name
+ } else {
+ user.DisplayName = "Lark User"
+ }
+ user.Role = model.RoleCommonUser
+ user.Status = model.UserStatusEnabled
+
+ if err := user.Insert(0); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ } else {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "管理员关闭了新用户注册",
+ })
+ return
+ }
+ }
+
+ if user.Status != model.UserStatusEnabled {
+ c.JSON(http.StatusOK, gin.H{
+ "message": "用户已被封禁",
+ "success": false,
+ })
+ return
+ }
+ controller.SetupLogin(&user, c)
+}
+
+func LarkBind(c *gin.Context) {
+ code := c.Query("code")
+ larkUser, err := getLarkUserInfoByCode(code)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user := model.User{
+ LarkId: larkUser.OpenID,
+ }
+ if model.IsLarkIdAlreadyTaken(user.LarkId) {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": "该飞书账户已被绑定",
+ })
+ return
+ }
+ session := sessions.Default(c)
+ id := session.Get("id")
+ // id := c.GetInt("id") // critical bug!
+ user.Id = id.(int)
+ err = user.FillUserById()
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ user.LarkId = larkUser.OpenID
+ err = user.Update(false)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "bind",
+ })
+ return
+}
diff --git a/controller/wechat.go b/controller/auth/wechat.go
similarity index 93%
rename from controller/wechat.go
rename to controller/auth/wechat.go
index 8f997bfb..da1b513b 100644
--- a/controller/wechat.go
+++ b/controller/auth/wechat.go
@@ -1,12 +1,12 @@
-package controller
+package auth
import (
"encoding/json"
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -83,8 +83,8 @@ func WeChatAuth(c *gin.Context) {
if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User"
- user.Role = common.RoleCommonUser
- user.Status = common.UserStatusEnabled
+ user.Role = model.RoleCommonUser
+ user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -102,14 +102,14 @@ func WeChatAuth(c *gin.Context) {
}
}
- if user.Status != common.UserStatusEnabled {
+ if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
})
return
}
- setupLogin(&user, c)
+ controller.SetupLogin(&user, c)
}
func WeChatBind(c *gin.Context) {
diff --git a/controller/channel-billing.go b/controller/channel-billing.go
index b9a3908e..79ef322a 100644
--- a/controller/channel-billing.go
+++ b/controller/channel-billing.go
@@ -3,18 +3,19 @@ package controller
import (
"encoding/json"
"fmt"
- "github.com/Laisky/errors/v2"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/common/config"
- "github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/monitor"
- "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strconv"
"time"
+ "github.com/Laisky/errors/v2"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/monitor"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/client"
+
"github.com/gin-gonic/gin"
)
@@ -96,7 +97,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
- res, err := util.HTTPClient.Do(req)
+ res, err := client.HTTPClient.Do(req)
if err != nil {
return nil, err
}
@@ -204,28 +205,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
- baseURL := common.ChannelBaseURLs[channel.Type]
+ baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
switch channel.Type {
- case common.ChannelTypeOpenAI:
+ case channeltype.OpenAI:
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
- case common.ChannelTypeAzure:
+ case channeltype.Azure:
return 0, errors.New("尚未实现")
- case common.ChannelTypeCustom:
+ case channeltype.Custom:
baseURL = channel.GetBaseURL()
- case common.ChannelTypeCloseAI:
+ case channeltype.CloseAI:
return updateChannelCloseAIBalance(channel)
- case common.ChannelTypeOpenAISB:
+ case channeltype.OpenAISB:
return updateChannelOpenAISBBalance(channel)
- case common.ChannelTypeAIProxy:
+ case channeltype.AIProxy:
return updateChannelAIProxyBalance(channel)
- case common.ChannelTypeAPI2GPT:
+ case channeltype.API2GPT:
return updateChannelAPI2GPTBalance(channel)
- case common.ChannelTypeAIGC2D:
+ case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel)
default:
return 0, errors.New("尚未实现")
@@ -301,11 +302,11 @@ func updateAllChannelsBalance() error {
return err
}
for _, channel := range channels {
- if channel.Status != common.ChannelStatusEnabled {
+ if channel.Status != model.ChannelStatusEnabled {
continue
}
// TODO: support Azure
- if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
+ if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
continue
}
balance, err := updateChannelBalance(channel)
diff --git a/controller/channel-test.go b/controller/channel-test.go
index 57138f49..535b21bd 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -4,18 +4,6 @@ import (
"bytes"
"encoding/json"
"fmt"
- "github.com/Laisky/errors/v2"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/common/config"
- "github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/common/message"
- "github.com/songquanpeng/one-api/middleware"
- "github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/monitor"
- "github.com/songquanpeng/one-api/relay/constant"
- "github.com/songquanpeng/one-api/relay/helper"
- relaymodel "github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"net/http/httptest"
@@ -25,6 +13,20 @@ import (
"sync"
"time"
+ "github.com/Laisky/errors/v2"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/message"
+ "github.com/songquanpeng/one-api/middleware"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/monitor"
+ relay "github.com/songquanpeng/one-api/relay"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/controller"
+ "github.com/songquanpeng/one-api/relay/meta"
+ relaymodel "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/relaymode"
+
"github.com/gin-gonic/gin"
)
@@ -56,9 +58,9 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "")
- meta := util.GetRelayMeta(c)
- apiType := constant.ChannelType2APIType(channel.Type)
- adaptor := helper.GetAdaptor(apiType)
+ meta := meta.GetByContext(c)
+ apiType := channeltype.ToAPIType(channel.Type)
+ adaptor := relay.GetAdaptor(apiType)
if adaptor == nil {
return errors.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
}
@@ -73,7 +75,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
request := buildTestRequest()
request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName
- convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request)
+ convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil {
return err, nil
}
@@ -88,8 +90,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
return err, nil
}
if resp.StatusCode != http.StatusOK {
- err := util.RelayErrorHandler(resp)
- return errors.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
+ err := controller.RelayErrorHandler(resp)
+ return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
}
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
@@ -171,7 +173,7 @@ func testChannels(notify bool, scope string) error {
}
go func() {
for _, channel := range channels {
- isChannelEnabled := channel.Status == common.ChannelStatusEnabled
+ isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now()
err, openaiErr := testChannel(channel)
tok := time.Now()
@@ -184,10 +186,10 @@ func testChannels(notify bool, scope string) error {
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error())
}
}
- if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) {
+ if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
monitor.DisableChannel(channel.Id, channel.Name, err.Error())
}
- if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) {
+ if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
monitor.EnableChannel(channel.Id, channel.Name)
}
channel.UpdateResponseTime(milliseconds)
diff --git a/controller/group.go b/controller/group.go
index 128a3527..6f02394f 100644
--- a/controller/group.go
+++ b/controller/group.go
@@ -2,13 +2,13 @@ package controller
import (
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"net/http"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
- for groupName := range common.GroupRatio {
+ for groupName := range billingratio.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{
diff --git a/controller/misc.go b/controller/misc.go
index f27fdb12..2928b8fb 100644
--- a/controller/misc.go
+++ b/controller/misc.go
@@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) {
"email_verification": config.EmailVerificationEnabled,
"github_oauth": config.GitHubOAuthEnabled,
"github_client_id": config.GitHubClientId,
+ "lark_client_id": config.LarkClientId,
"system_name": config.SystemName,
"logo": config.Logo,
"footer_html": config.Footer,
diff --git a/controller/model.go b/controller/model.go
index 1be352f2..01d01bf0 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -3,13 +3,15 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/relay/channel/openai"
- "github.com/songquanpeng/one-api/relay/constant"
- "github.com/songquanpeng/one-api/relay/helper"
+ "github.com/songquanpeng/one-api/model"
+ relay "github.com/songquanpeng/one-api/relay"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/apitype"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
"net/http"
+ "strings"
)
// https://platform.openai.com/docs/api-reference/models/list
@@ -39,8 +41,8 @@ type OpenAIModels struct {
Parent *string `json:"parent"`
}
-var openAIModels []OpenAIModels
-var openAIModelsMap map[string]OpenAIModels
+var models []OpenAIModels
+var modelsMap map[string]OpenAIModels
var channelId2Models map[int][]string
func init() {
@@ -60,11 +62,11 @@ func init() {
IsBlocking: false,
})
// https://platform.openai.com/docs/models/model-endpoint-compatibility
- for i := 0; i < constant.APITypeDummy; i++ {
- if i == constant.APITypeAIProxyLibrary {
+ for i := 0; i < apitype.Dummy; i++ {
+ if i == apitype.AIProxyLibrary {
continue
}
- adaptor := helper.GetAdaptor(i)
+ adaptor := relay.GetAdaptor(i)
if adaptor == nil {
continue
}
@@ -72,7 +74,7 @@ func init() {
channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList()
for _, modelName := range modelNames {
- openAIModels = append(openAIModels, OpenAIModels{
+ models = append(models, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
@@ -84,12 +86,12 @@ func init() {
}
}
for _, channelType := range openai.CompatibleChannels {
- if channelType == common.ChannelTypeAzure {
+ if channelType == channeltype.Azure {
continue
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList {
- openAIModels = append(openAIModels, OpenAIModels{
+ models = append(models, OpenAIModels{
Id: modelName,
Object: "model",
Created: 1626777600,
@@ -100,18 +102,18 @@ func init() {
})
}
}
- openAIModelsMap = make(map[string]OpenAIModels)
- for _, model := range openAIModels {
- openAIModelsMap[model.Id] = model
+ modelsMap = make(map[string]OpenAIModels)
+ for _, model := range models {
+ modelsMap[model.Id] = model
}
channelId2Models = make(map[int][]string)
- for i := 1; i < common.ChannelTypeDummy; i++ {
- adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
+ for i := 1; i < channeltype.Dummy; i++ {
+ adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
if adaptor == nil {
continue
}
- meta := &util.RelayMeta{
+ meta := &meta.Meta{
ChannelType: i,
}
adaptor.Init(meta)
@@ -127,16 +129,55 @@ func DashboardListModels(c *gin.Context) {
})
}
-func ListModels(c *gin.Context) {
+func ListAllModels(c *gin.Context) {
c.JSON(200, gin.H{
"object": "list",
- "data": openAIModels,
+ "data": models,
+ })
+}
+
+func ListModels(c *gin.Context) {
+ ctx := c.Request.Context()
+ var availableModels []string
+ if c.GetString("available_models") != "" {
+ availableModels = strings.Split(c.GetString("available_models"), ",")
+ } else {
+ userId := c.GetInt("id")
+ userGroup, _ := model.CacheGetUserGroup(userId)
+ availableModels, _ = model.CacheGetGroupModels(ctx, userGroup)
+ }
+ modelSet := make(map[string]bool)
+ for _, availableModel := range availableModels {
+ modelSet[availableModel] = true
+ }
+ availableOpenAIModels := make([]OpenAIModels, 0)
+ for _, model := range models {
+ if _, ok := modelSet[model.Id]; ok {
+ modelSet[model.Id] = false
+ availableOpenAIModels = append(availableOpenAIModels, model)
+ }
+ }
+ for modelName, ok := range modelSet {
+ if ok {
+ availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{
+ Id: modelName,
+ Object: "model",
+ Created: 1626777600,
+ OwnedBy: "custom",
+ Root: modelName,
+ Parent: nil,
+ })
+ }
+ }
+ c.JSON(200, gin.H{
+ "object": "list",
+ "data": availableOpenAIModels,
})
}
func RetrieveModel(c *gin.Context) {
modelId := c.Param("model")
- if model, ok := openAIModelsMap[modelId]; ok {
+ if model, ok := modelsMap[modelId]; ok {
c.JSON(200, model)
} else {
Error := relaymodel.Error{
@@ -150,3 +191,30 @@ func RetrieveModel(c *gin.Context) {
})
}
}
+
+func GetUserAvailableModels(c *gin.Context) {
+ ctx := c.Request.Context()
+ id := c.GetInt("id")
+ userGroup, err := model.CacheGetUserGroup(id)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ models, err := model.CacheGetGroupModels(ctx, userGroup)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": models,
+ })
+ return
+}
diff --git a/controller/redemption.go b/controller/redemption.go
index 31c9348d..8d2b3f38 100644
--- a/controller/redemption.go
+++ b/controller/redemption.go
@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -106,7 +107,7 @@ func AddRedemption(c *gin.Context) {
}
var keys []string
for i := 0; i < redemption.Count; i++ {
- key := helper.GetUUID()
+ key := random.GetUUID()
cleanRedemption := model.Redemption{
UserId: c.GetInt("id"),
Name: redemption.Name,
diff --git a/controller/relay.go b/controller/relay.go
index 85932368..8278bf23 100644
--- a/controller/relay.go
+++ b/controller/relay.go
@@ -4,9 +4,6 @@ import (
"bytes"
"context"
"fmt"
- "io"
- "net/http"
-
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
@@ -16,24 +13,25 @@ import (
"github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
- "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
+ "github.com/songquanpeng/one-api/relay/relaymode"
+ "io"
+ "net/http"
)
// https://platform.openai.com/docs/api-reference/chat
-func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
+func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode
switch relayMode {
- case constant.RelayModeImagesGenerations:
+ case relaymode.ImagesGenerations:
err = controller.RelayImageHelper(c, relayMode)
- case constant.RelayModeAudioSpeech:
+ case relaymode.AudioSpeech:
fallthrough
- case constant.RelayModeAudioTranslation:
+ case relaymode.AudioTranslation:
fallthrough
- case constant.RelayModeAudioTranscription:
+ case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode)
default:
err = controller.RelayTextHelper(c)
@@ -43,13 +41,13 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func Relay(c *gin.Context) {
ctx := c.Request.Context()
- relayMode := constant.Path2RelayMode(c.Request.URL.Path)
+ relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody))
}
channelId := c.GetInt("channel_id")
- bizErr := relay(c, relayMode)
+ bizErr := relayHelper(c, relayMode)
if bizErr == nil {
monitor.Emit(channelId, true)
return
@@ -78,7 +76,7 @@ func Relay(c *gin.Context) {
middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
- bizErr = relay(c, relayMode)
+ bizErr = relayHelper(c, relayMode)
if bizErr == nil {
return
}
@@ -117,7 +115,7 @@ func shouldRetry(c *gin.Context, statusCode int) error {
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors
- if util.ShouldDisableChannel(&err.Error, err.StatusCode) {
+ if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message)
} else {
monitor.Emit(channelId, false)
diff --git a/controller/token.go b/controller/token.go
index 9b52b053..74d31547 100644
--- a/controller/token.go
+++ b/controller/token.go
@@ -6,9 +6,12 @@ import (
"strconv"
"github.com/gin-gonic/gin"
+ "github.com/jinzhu/copier"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/network"
+ "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
)
@@ -106,9 +109,24 @@ func GetTokenStatus(c *gin.Context) {
})
}
+func validateToken(c *gin.Context, token *model.Token) error {
+ if len(token.Name) > 30 {
+ return fmt.Errorf("令牌名称过长")
+ }
+
+ if token.Subnet != nil && *token.Subnet != "" {
+ err := network.IsValidSubnets(*token.Subnet)
+ if err != nil {
+ return fmt.Errorf("无效的网段:%s", err.Error())
+ }
+ }
+
+ return nil
+}
+
func AddToken(c *gin.Context) {
- token := model.Token{}
- err := c.ShouldBindJSON(&token)
+ token := new(model.Token)
+ err := c.ShouldBindJSON(token)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -116,22 +134,27 @@ func AddToken(c *gin.Context) {
})
return
}
- if len(token.Name) > 30 {
+
+ err = validateToken(c, token)
+ if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": "令牌名称过长",
+ "message": fmt.Sprintf("参数错误:%s", err.Error()),
})
return
}
+
cleanToken := model.Token{
UserId: c.GetInt("id"),
Name: token.Name,
- Key: helper.GenerateKey(),
+ Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime,
RemainQuota: token.RemainQuota,
UnlimitedQuota: token.UnlimitedQuota,
+ Models: token.Models,
+ Subnet: token.Subnet,
}
err = cleanToken.Insert()
if err != nil {
@@ -168,12 +191,7 @@ func DeleteToken(c *gin.Context) {
}
type updateTokenDto struct {
- Id int `json:"id"`
- Status int `json:"status" gorm:"default:1"`
- Name *string `json:"name" gorm:"index" `
- ExpiredTime *int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
- RemainQuota *int `json:"remain_quota" gorm:"default:0"`
- UnlimitedQuota *bool `json:"unlimited_quota" gorm:"default:false"`
+ model.Token
// AddUsedQuota add or subtract used quota
AddUsedQuota int `json:"add_used_quota" gorm:"-"`
AddReason string `json:"add_reason" gorm:"-"`
@@ -183,43 +201,51 @@ func UpdateToken(c *gin.Context) {
userId := c.GetInt("id")
statusOnly := c.Query("status_only")
tokenPatch := new(updateTokenDto)
- if err := c.ShouldBindJSON(tokenPatch); err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "parse request join: " + err.Error(),
- })
- return
- }
-
- if tokenPatch.Name != nil &&
- (len(*tokenPatch.Name) > 30 || len(*tokenPatch.Name) == 0) {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": "令牌名称错误,长度应在 1-30 之间",
- })
- return
- }
-
- tokenInDB, err := model.GetTokenByIds(tokenPatch.Id, userId)
+ err := c.ShouldBindJSON(tokenPatch)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": fmt.Sprintf("get token by id %d: %s", tokenPatch.Id, err.Error()),
+ "message": err.Error(),
})
return
}
- if tokenPatch.Status == common.TokenStatusEnabled {
- if tokenInDB.Status == common.TokenStatusExpired && tokenInDB.ExpiredTime <= helper.GetTimestamp() && tokenInDB.ExpiredTime != -1 {
+ token := new(model.Token)
+ if err = copier.Copy(token, tokenPatch); err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ err = validateToken(c, token)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": fmt.Sprintf("参数错误:%s", err.Error()),
+ })
+ return
+ }
+
+ cleanToken, err := model.GetTokenByIds(token.Id, userId)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ if token.Status == model.TokenStatusEnabled {
+ if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
})
return
}
- if tokenInDB.Status == common.TokenStatusExhausted &&
- tokenInDB.RemainQuota <= 0 &&
- !tokenInDB.UnlimitedQuota {
+ if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",
@@ -228,42 +254,33 @@ func UpdateToken(c *gin.Context) {
}
}
if statusOnly != "" {
- tokenInDB.Status = tokenPatch.Status
+ cleanToken.Status = token.Status
} else {
- // If you add more fields, please also update tokenPatch.Update()
- if tokenPatch.Name != nil {
- tokenInDB.Name = *tokenPatch.Name
- }
- if tokenPatch.ExpiredTime != nil {
- tokenInDB.ExpiredTime = *tokenPatch.ExpiredTime
- }
- if tokenPatch.RemainQuota != nil {
- tokenInDB.RemainQuota = int64(*tokenPatch.RemainQuota)
- }
- if tokenPatch.UnlimitedQuota != nil {
- tokenInDB.UnlimitedQuota = *tokenPatch.UnlimitedQuota
- }
+ // If you add more fields, please also update token.Update()
+ cleanToken.Name = token.Name
+ cleanToken.ExpiredTime = token.ExpiredTime
+ cleanToken.RemainQuota = token.RemainQuota
+ cleanToken.UnlimitedQuota = token.UnlimitedQuota
+ cleanToken.Models = token.Models
+ cleanToken.Subnet = token.Subnet
}
- tokenInDB.RemainQuota -= int64(tokenPatch.AddUsedQuota)
- tokenInDB.UsedQuota += int64(tokenPatch.AddUsedQuota)
-
if tokenPatch.AddUsedQuota != 0 {
model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(int64(tokenPatch.AddUsedQuota))))
}
- if err = tokenInDB.Update(); err != nil {
+ err = cleanToken.Update()
+ if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
- "message": "update token: " + err.Error(),
+ "message": err.Error(),
})
return
}
-
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
- "data": tokenInDB,
+ "data": cleanToken,
})
return
}
diff --git a/controller/user.go b/controller/user.go
index 691378ec..bd31c034 100644
--- a/controller/user.go
+++ b/controller/user.go
@@ -5,7 +5,7 @@ import (
"fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
- "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -58,11 +58,11 @@ func Login(c *gin.Context) {
})
return
}
- setupLogin(&user, c)
+ SetupLogin(&user, c)
}
// setup session & cookies and then return user info
-func setupLogin(user *model.User, c *gin.Context) {
+func SetupLogin(user *model.User, c *gin.Context) {
session := sessions.Default(c)
session.Set("id", user.Id)
session.Set("username", user.Username)
@@ -186,27 +186,27 @@ func Register(c *gin.Context) {
}
func GetAllUsers(c *gin.Context) {
- p, _ := strconv.Atoi(c.Query("p"))
- if p < 0 {
- p = 0
- }
-
- order := c.DefaultQuery("order", "")
- users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
-
- if err != nil {
- c.JSON(http.StatusOK, gin.H{
- "success": false,
- "message": err.Error(),
- })
- return
- }
-
- c.JSON(http.StatusOK, gin.H{
- "success": true,
- "message": "",
- "data": users,
- })
+ p, _ := strconv.Atoi(c.Query("p"))
+ if p < 0 {
+ p = 0
+ }
+
+ order := c.DefaultQuery("order", "")
+ users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order)
+
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ "data": users,
+ })
}
func SearchUsers(c *gin.Context) {
@@ -245,7 +245,7 @@ func GetUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
- if myRole <= user.Role && myRole != common.RoleRootUser {
+ if myRole <= user.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权获取同级或更高等级用户的信息",
@@ -293,7 +293,7 @@ func GenerateAccessToken(c *gin.Context) {
})
return
}
- user.AccessToken = helper.GetUUID()
+ user.AccessToken = random.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{
@@ -330,7 +330,7 @@ func GetAffCode(c *gin.Context) {
return
}
if user.AffCode == "" {
- user.AffCode = helper.GetRandomString(4)
+ user.AffCode = random.GetRandomString(4)
if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
@@ -404,14 +404,14 @@ func UpdateUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
- if myRole <= originUser.Role && myRole != common.RoleRootUser {
+ if myRole <= originUser.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
})
return
}
- if myRole <= updatedUser.Role && myRole != common.RoleRootUser {
+ if myRole <= updatedUser.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
@@ -525,7 +525,7 @@ func DeleteSelf(c *gin.Context) {
id := c.GetInt("id")
user, _ := model.GetUserById(id, false)
- if user.Role == common.RoleRootUser {
+ if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "不能删除超级管理员账户",
@@ -627,7 +627,7 @@ func ManageUser(c *gin.Context) {
return
}
myRole := c.GetInt("role")
- if myRole <= user.Role && myRole != common.RoleRootUser {
+ if myRole <= user.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息",
@@ -636,8 +636,8 @@ func ManageUser(c *gin.Context) {
}
switch req.Action {
case "disable":
- user.Status = common.UserStatusDisabled
- if user.Role == common.RoleRootUser {
+ user.Status = model.UserStatusDisabled
+ if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法禁用超级管理员用户",
@@ -645,9 +645,9 @@ func ManageUser(c *gin.Context) {
return
}
case "enable":
- user.Status = common.UserStatusEnabled
+ user.Status = model.UserStatusEnabled
case "delete":
- if user.Role == common.RoleRootUser {
+ if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法删除超级管理员用户",
@@ -662,37 +662,37 @@ func ManageUser(c *gin.Context) {
return
}
case "promote":
- if myRole != common.RoleRootUser {
+ if myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "普通管理员用户无法提升其他用户为管理员",
})
return
}
- if user.Role >= common.RoleAdminUser {
+ if user.Role >= model.RoleAdminUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是管理员",
})
return
}
- user.Role = common.RoleAdminUser
+ user.Role = model.RoleAdminUser
case "demote":
- if user.Role == common.RoleRootUser {
+ if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "无法降级超级管理员用户",
})
return
}
- if user.Role == common.RoleCommonUser {
+ if user.Role == model.RoleCommonUser {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "该用户已经是普通用户",
})
return
}
- user.Role = common.RoleCommonUser
+ user.Role = model.RoleCommonUser
}
if err := user.Update(false); err != nil {
@@ -746,7 +746,7 @@ func EmailBind(c *gin.Context) {
})
return
}
- if user.Role == common.RoleRootUser {
+ if user.Role == model.RoleRootUser {
config.RootUserEmail = email
}
c.JSON(http.StatusOK, gin.H{
@@ -786,3 +786,38 @@ func TopUp(c *gin.Context) {
})
return
}
+
+type adminTopUpRequest struct {
+ UserId int `json:"user_id"`
+ Quota int `json:"quota"`
+ Remark string `json:"remark"`
+}
+
+func AdminTopUp(c *gin.Context) {
+ req := adminTopUpRequest{}
+ err := c.ShouldBindJSON(&req)
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ err = model.IncreaseUserQuota(req.UserId, int64(req.Quota))
+ if err != nil {
+ c.JSON(http.StatusOK, gin.H{
+ "success": false,
+ "message": err.Error(),
+ })
+ return
+ }
+ if req.Remark == "" {
+ req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota)))
+ }
+ model.RecordTopupLog(req.UserId, req.Remark, req.Quota)
+ c.JSON(http.StatusOK, gin.H{
+ "success": true,
+ "message": "",
+ })
+ return
+}
diff --git a/docs/API.md b/docs/API.md
new file mode 100644
index 00000000..0b7ddf5a
--- /dev/null
+++ b/docs/API.md
@@ -0,0 +1,53 @@
+# 使用 API 操控 & 扩展 One API
+> 欢迎提交 PR 在此放上你的拓展项目。
+
+例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。
+
+又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。
+
+## 鉴权
+One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取:
+
+
+
+之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API:
+
+
+## 请求格式与响应格式
+One API 使用 JSON 格式进行请求和响应。
+
+对于响应体,一般格式如下:
+```json
+{
+ "message": "请求信息",
+ "success": true,
+ "data": {}
+}
+```
+
+## API 列表
+> 当前 API 列表不全,请自行通过浏览器抓取前端请求
+
+如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。
+
+### 获取当前登录用户信息
+**GET** `/api/user/self`
+
+### 为给定用户充值额度
+**POST** `/api/topup`
+```json
+{
+ "user_id": 1,
+ "quota": 100000,
+ "remark": "充值 100000 额度"
+}
+```
+
+## 其他
+### 充值链接上的附加参数
+One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如:
+`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837`
+
+你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。
+
+注意,不是所有主题都支持该功能,欢迎 PR 补齐。
\ No newline at end of file
diff --git a/go.mod b/go.mod
index 5c50c39d..39f9e295 100644
--- a/go.mod
+++ b/go.mod
@@ -13,8 +13,10 @@ require (
github.com/go-playground/validator/v10 v10.19.0
github.com/go-redis/redis/v8 v8.11.5
github.com/google/uuid v1.6.0
+ github.com/jinzhu/copier v0.4.0
github.com/pkg/errors v0.9.1
github.com/pkoukk/tiktoken-go v0.1.6
+ github.com/smartystreets/goconvey v1.8.1
github.com/stretchr/testify v1.8.4
golang.org/x/crypto v0.21.0
golang.org/x/image v0.15.0
@@ -48,16 +50,17 @@ require (
github.com/go-sql-driver/mysql v1.7.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932 // indirect
+ github.com/gopherjs/gopherjs v1.17.2 // indirect
github.com/gorilla/context v1.1.1 // indirect
github.com/gorilla/securecookie v1.1.1 // indirect
github.com/gorilla/sessions v1.2.1 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
- github.com/jackc/pgx/v5 v5.5.4 // indirect
- github.com/jackc/puddle/v2 v2.2.1 // indirect
+ github.com/jackc/pgx/v5 v5.4.3 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect
+ github.com/jtolds/gls v4.20.0+incompatible // indirect
github.com/klauspost/cpuid/v2 v2.2.7 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
@@ -66,6 +69,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
+ github.com/smarty/assertions v1.15.0 // indirect
github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
@@ -78,7 +82,7 @@ require (
golang.org/x/sys v0.18.0 // indirect
golang.org/x/term v0.18.0 // indirect
golang.org/x/text v0.14.0 // indirect
- golang.org/x/tools v0.6.0 // indirect
+ golang.org/x/tools v0.7.0 // indirect
google.golang.org/protobuf v1.33.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
diff --git a/go.sum b/go.sum
index d518f25f..bd2decbc 100644
--- a/go.sum
+++ b/go.sum
@@ -96,6 +96,8 @@ github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932/go.mod h1:cC6EdPbj/1
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g=
+github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
@@ -107,16 +109,18 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
-github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8=
-github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
-github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
-github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
+github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY=
+github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA=
+github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8=
+github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo=
+github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM=
github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
@@ -164,6 +168,10 @@ github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
+github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY=
+github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec=
+github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY=
+github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -206,8 +214,8 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug=
golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
-golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
-golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
+golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs=
+golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
@@ -238,8 +246,8 @@ golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
-golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
-golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
+golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4=
+golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
diff --git a/main.go b/main.go
index 9c727152..3ee1dc94 100644
--- a/main.go
+++ b/main.go
@@ -16,7 +16,7 @@ import (
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/router"
)
diff --git a/middleware/auth.go b/middleware/auth.go
index 95ae5700..d55820f7 100644
--- a/middleware/auth.go
+++ b/middleware/auth.go
@@ -1,15 +1,15 @@
package middleware
import (
- "net/http"
- "strings"
-
+ "fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model"
+ "net/http"
+ "strings"
)
func authHelper(c *gin.Context, minRole int) {
@@ -47,7 +47,7 @@ func authHelper(c *gin.Context, minRole int) {
return
}
}
- if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
+ if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": "用户已被封禁",
@@ -74,24 +74,25 @@ func authHelper(c *gin.Context, minRole int) {
func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) {
- authHelper(c, common.RoleCommonUser)
+ authHelper(c, model.RoleCommonUser)
}
}
func AdminAuth() func(c *gin.Context) {
return func(c *gin.Context) {
- authHelper(c, common.RoleAdminUser)
+ authHelper(c, model.RoleAdminUser)
}
}
func RootAuth() func(c *gin.Context) {
return func(c *gin.Context) {
- authHelper(c, common.RoleRootUser)
+ authHelper(c, model.RoleRootUser)
}
}
func TokenAuth() func(c *gin.Context) {
return func(c *gin.Context) {
+ ctx := c.Request.Context()
key := c.Request.Header.Get("Authorization")
key = strings.TrimPrefix(key, "Bearer ")
key = strings.TrimPrefix(strings.TrimPrefix(key, "sk-"), "laisky-")
@@ -102,6 +103,12 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusUnauthorized, err.Error())
return
}
+ if token.Subnet != nil && *token.Subnet != "" {
+ if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
+ abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP()))
+ return
+ }
+ }
userEnabled, err := model.CacheIsUserEnabled(token.UserId)
if err != nil {
abortWithMessage(c, http.StatusInternalServerError, err.Error())
@@ -111,6 +118,19 @@ func TokenAuth() func(c *gin.Context) {
abortWithMessage(c, http.StatusForbidden, "用户已被封禁")
return
}
+ requestModel, err := getRequestModel(c)
+ if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") {
+ abortWithMessage(c, http.StatusBadRequest, err.Error())
+ return
+ }
+ c.Set("request_model", requestModel)
+ if token.Models != nil && *token.Models != "" {
+ c.Set("available_models", *token.Models)
+ if requestModel != "" && !isModelInList(requestModel, *token.Models) {
+ abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel))
+ return
+ }
+ }
c.Set("id", token.UserId)
c.Set("token_id", token.Id)
c.Set("token_name", token.Name)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index f538f250..3eff734d 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -2,14 +2,16 @@ package middleware
import (
"fmt"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+ "github.com/songquanpeng/one-api/relay/billing/ratio"
+ "github.com/songquanpeng/one-api/relay/channeltype"
)
type ModelRequest struct {
@@ -35,42 +37,16 @@ func Distribute() func(c *gin.Context) {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return
}
- if channel.Status != common.ChannelStatusEnabled {
+ if channel.Status != model.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return
}
} else {
- // Select a channel for the user
- var modelRequest ModelRequest
- err := common.UnmarshalBodyReusable(c, &modelRequest)
+ requestModel := c.GetString("request_model")
+ var err error
+ channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if err != nil {
- abortWithMessage(c, http.StatusBadRequest, fmt.Sprintf("无效的请求: %+v", err))
- return
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "text-moderation-stable"
- }
- }
- if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
- if modelRequest.Model == "" {
- modelRequest.Model = c.Param("model")
- }
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "dall-e-2"
- }
- }
- if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
- if modelRequest.Model == "" {
- modelRequest.Model = "whisper-1"
- }
- }
- requestModel = modelRequest.Model
- channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false)
- if err != nil {
- message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model)
+ message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel)
if channel != nil {
logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
message = "数据库一致性已被破坏,请联系管理员"
@@ -85,17 +61,18 @@ func Distribute() func(c *gin.Context) {
}
func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) {
+ // one channel could relates to multiple groups,
+ // and each groud has individual ratio,
// set minimal group ratio as channel_ratio
var minimalRatio float64 = -1
for _, grp := range strings.Split(channel.Group, ",") {
- v := common.GetGroupRatio(grp)
+ v := ratio.GetGroupRatio(grp)
if minimalRatio < 0 || v < minimalRatio {
minimalRatio = v
}
}
logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio))
c.Set("channel_ratio", minimalRatio)
-
c.Set("channel", channel.Type)
c.Set("channel_id", channel.Id)
c.Set("channel_name", channel.Name)
@@ -105,19 +82,19 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility
switch channel.Type {
- case common.ChannelTypeAzure:
- c.Set(common.ConfigKeyAPIVersion, channel.Other)
- case common.ChannelTypeXunfei:
- c.Set(common.ConfigKeyAPIVersion, channel.Other)
- case common.ChannelTypeGemini:
- c.Set(common.ConfigKeyAPIVersion, channel.Other)
- case common.ChannelTypeAIProxyLibrary:
- c.Set(common.ConfigKeyLibraryID, channel.Other)
- case common.ChannelTypeAli:
- c.Set(common.ConfigKeyPlugin, channel.Other)
+ case channeltype.Azure:
+ c.Set(config.KeyAPIVersion, channel.Other)
+ case channeltype.Xunfei:
+ c.Set(config.KeyAPIVersion, channel.Other)
+ case channeltype.Gemini:
+ c.Set(config.KeyAPIVersion, channel.Other)
+ case channeltype.AIProxyLibrary:
+ c.Set(config.KeyLibraryID, channel.Other)
+ case channeltype.Ali:
+ c.Set(config.KeyPlugin, channel.Other)
}
cfg, _ := channel.LoadConfig()
for k, v := range cfg {
- c.Set(common.ConfigKeyPrefix+k, v)
+ c.Set(config.KeyPrefix+k, v)
}
}
diff --git a/middleware/utils.go b/middleware/utils.go
index bc14c367..b65b018b 100644
--- a/middleware/utils.go
+++ b/middleware/utils.go
@@ -1,9 +1,12 @@
package middleware
import (
+ "fmt"
"github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+ "strings"
)
func abortWithMessage(c *gin.Context, statusCode int, message string) {
@@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) {
c.Abort()
logger.Error(c.Request.Context(), message)
}
+
+func getRequestModel(c *gin.Context) (string, error) {
+ var modelRequest ModelRequest
+ err := common.UnmarshalBodyReusable(c, &modelRequest)
+ if err != nil {
+ return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err)
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "text-moderation-stable"
+ }
+ }
+ if strings.HasSuffix(c.Request.URL.Path, "embeddings") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = c.Param("model")
+ }
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "dall-e-2"
+ }
+ }
+ if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") {
+ if modelRequest.Model == "" {
+ modelRequest.Model = "whisper-1"
+ }
+ }
+ return modelRequest.Model, nil
+}
+
+func isModelInList(modelName string, models string) bool {
+ modelList := strings.Split(models, ",")
+ for _, model := range modelList {
+ if modelName == model {
+ return true
+ }
+ }
+ return false
+}
diff --git a/model/ability.go b/model/ability.go
index 48b856a2..2db72518 100644
--- a/model/ability.go
+++ b/model/ability.go
@@ -1,8 +1,10 @@
package model
import (
+ "context"
"github.com/songquanpeng/one-api/common"
"gorm.io/gorm"
+ "sort"
"strings"
)
@@ -55,7 +57,7 @@ func (channel *Channel) AddAbilities() error {
Group: group,
Model: model,
ChannelId: channel.Id,
- Enabled: channel.Status == common.ChannelStatusEnabled,
+ Enabled: channel.Status == ChannelStatusEnabled,
Priority: channel.Priority,
}
abilities = append(abilities, ability)
@@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error {
func UpdateAbilityStatus(channelId int, status bool) error {
return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error
}
+
+func GetGroupModels(ctx context.Context, group string) ([]string, error) {
+ groupCol := "`group`"
+ trueVal := "1"
+ if common.UsingPostgreSQL {
+ groupCol = `"group"`
+ trueVal = "true"
+ }
+ var models []string
+ err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? and enabled = "+trueVal, group).Pluck("model", &models).Error
+ if err != nil {
+ return nil, err
+ }
+ sort.Strings(models)
+ return models, err
+}
diff --git a/model/cache.go b/model/cache.go
index 50946bd6..a05cec19 100644
--- a/model/cache.go
+++ b/model/cache.go
@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/random"
"math/rand"
"sort"
"strconv"
@@ -21,6 +22,7 @@ var (
UserId2GroupCacheSeconds = config.SyncFrequency
UserId2QuotaCacheSeconds = config.SyncFrequency
UserId2StatusCacheSeconds = config.SyncFrequency
+ GroupModelsCacheSeconds = config.SyncFrequency
)
func CacheGetTokenByKey(key string) (*Token, error) {
@@ -147,13 +149,32 @@ func CacheIsUserEnabled(userId int) (bool, error) {
return userEnabled, err
}
+func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) {
+ if !common.RedisEnabled {
+ return GetGroupModels(ctx, group)
+ }
+ modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group))
+ if err == nil {
+ return strings.Split(modelsStr, ","), nil
+ }
+ models, err := GetGroupModels(ctx, group)
+ if err != nil {
+ return nil, err
+ }
+ err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second)
+ if err != nil {
+ logger.SysError("Redis set group models error: " + err.Error())
+ }
+ return models, nil
+}
+
var group2model2channels map[string]map[string][]*Channel
var channelSyncLock sync.RWMutex
func InitChannelCache() {
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
- DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
+ DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
}
@@ -228,7 +249,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPrior
idx := rand.Intn(endIdx)
if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority
- idx = common.RandRange(endIdx, len(channels))
+ idx = random.RandRange(endIdx, len(channels))
}
}
return channels[idx], nil
diff --git a/model/channel.go b/model/channel.go
index fc4905b1..e667f7e7 100644
--- a/model/channel.go
+++ b/model/channel.go
@@ -3,13 +3,19 @@ package model
import (
"encoding/json"
"fmt"
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
)
+const (
+ ChannelStatusUnknown = 0
+ ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
+ ChannelStatusManuallyDisabled = 2 // also don't use 0
+ ChannelStatusAutoDisabled = 3
+)
+
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
@@ -39,7 +45,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
case "all":
err = DB.Order("id desc").Find(&channels).Error
case "disabled":
- err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error
+ err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error
default:
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
@@ -168,7 +174,7 @@ func (channel *Channel) LoadConfig() (map[string]string, error) {
}
func UpdateChannelStatusById(id int, status int) {
- err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
+ err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
if err != nil {
logger.SysError("failed to update ability status: " + err.Error())
}
@@ -199,6 +205,6 @@ func DeleteChannelByStatus(status int64) (int64, error) {
}
func DeleteDisabledChannel() (int64, error) {
- result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
+ result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}
diff --git a/model/log.go b/model/log.go
index 4409f73e..6fba776a 100644
--- a/model/log.go
+++ b/model/log.go
@@ -7,7 +7,6 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
-
"gorm.io/gorm"
)
@@ -51,6 +50,21 @@ func RecordLog(userId int, logType int, content string) {
}
}
+func RecordTopupLog(userId int, content string, quota int) {
+ log := &Log{
+ UserId: userId,
+ Username: GetUsernameById(userId),
+ CreatedAt: helper.GetTimestamp(),
+ Type: LogTypeTopup,
+ Content: content,
+ Quota: quota,
+ }
+ err := LOG_DB.Create(log).Error
+ if err != nil {
+ logger.SysError("failed to record log: " + err.Error())
+ }
+}
+
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {
diff --git a/model/main.go b/model/main.go
index 4bbfde27..e5124a4c 100644
--- a/model/main.go
+++ b/model/main.go
@@ -12,6 +12,7 @@ import (
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/random"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
@@ -33,10 +34,10 @@ func CreateRootAccountIfNeed() error {
rootUser := User{
Username: "root",
Password: hashedPassword,
- Role: common.RoleRootUser,
- Status: common.UserStatusEnabled,
+ Role: RoleRootUser,
+ Status: UserStatusEnabled,
DisplayName: "Root User",
- AccessToken: helper.GetUUID(),
+ AccessToken: random.GetUUID(),
Quota: 500000000000000,
}
DB.Create(&rootUser)
@@ -46,7 +47,7 @@ func CreateRootAccountIfNeed() error {
Id: 1,
UserId: rootUser.Id,
Key: config.InitialRootToken,
- Status: common.TokenStatusEnabled,
+ Status: TokenStatusEnabled,
Name: "Initial Root Token",
CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(),
diff --git a/model/option.go b/model/option.go
index 1d1c28b4..bed8d4c3 100644
--- a/model/option.go
+++ b/model/option.go
@@ -1,9 +1,9 @@
package model
import (
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"strconv"
"strings"
"time"
@@ -66,9 +66,9 @@ func InitOptionMap() {
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
- config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
- config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
- config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
+ config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
+ config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
+ config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
config.OptionMap["TopUpLink"] = config.TopUpLink
config.OptionMap["ChatLink"] = config.ChatLink
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
@@ -82,7 +82,7 @@ func loadOptionsFromDatabase() {
options, _ := AllOption()
for _, option := range options {
if option.Key == "ModelRatio" {
- option.Value = common.AddNewMissingRatio(option.Value)
+ option.Value = billingratio.AddNewMissingRatio(option.Value)
}
err := updateOptionMap(option.Key, option.Value)
if err != nil {
@@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) {
config.GitHubClientId = value
case "GitHubClientSecret":
config.GitHubClientSecret = value
+ case "LarkClientId":
+ config.LarkClientId = value
+ case "LarkClientSecret":
+ config.LarkClientSecret = value
case "Footer":
config.Footer = value
case "SystemName":
@@ -205,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) {
case "RetryTimes":
config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
- err = common.UpdateModelRatioByJSONString(value)
+ err = billingratio.UpdateModelRatioByJSONString(value)
case "GroupRatio":
- err = common.UpdateGroupRatioByJSONString(value)
+ err = billingratio.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
- err = common.UpdateCompletionRatioByJSONString(value)
+ err = billingratio.UpdateCompletionRatioByJSONString(value)
case "TopUpLink":
config.TopUpLink = value
case "ChatLink":
diff --git a/model/redemption.go b/model/redemption.go
index c3ed2576..62428d35 100644
--- a/model/redemption.go
+++ b/model/redemption.go
@@ -8,6 +8,12 @@ import (
"gorm.io/gorm"
)
+const (
+ RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
+ RedemptionCodeStatusDisabled = 2 // also don't use 0
+ RedemptionCodeStatusUsed = 3 // also don't use 0
+)
+
type Redemption struct {
Id int `json:"id"`
UserId int `json:"user_id"`
@@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
if err != nil {
return errors.New("无效的兑换码")
}
- if redemption.Status != common.RedemptionCodeStatusEnabled {
+ if redemption.Status != RedemptionCodeStatusEnabled {
return errors.New("该兑换码已被使用")
}
err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
@@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
return err
}
redemption.RedeemedTime = helper.GetTimestamp()
- redemption.Status = common.RedemptionCodeStatusUsed
+ redemption.Status = RedemptionCodeStatusUsed
err = tx.Save(redemption).Error
return err
})
diff --git a/model/token.go b/model/token.go
index c5491d06..10fd0d78 100644
--- a/model/token.go
+++ b/model/token.go
@@ -12,25 +12,34 @@ import (
"gorm.io/gorm"
)
+const (
+ TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
+ TokenStatusDisabled = 2 // also don't use 0
+ TokenStatusExpired = 3
+ TokenStatusExhausted = 4
+)
+
type Token struct {
- Id int `json:"id"`
- UserId int `json:"user_id"`
- Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
- Status int `json:"status" gorm:"default:1"`
- Name string `json:"name" gorm:"index" `
- CreatedTime int64 `json:"created_time" gorm:"bigint"`
- AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
- ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
- RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
- UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
- UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
+ Id int `json:"id"`
+ UserId int `json:"user_id"`
+ Key string `json:"key" gorm:"type:char(48);uniqueIndex"`
+ Status int `json:"status" gorm:"default:1"`
+ Name string `json:"name" gorm:"index" `
+ CreatedTime int64 `json:"created_time" gorm:"bigint"`
+ AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
+ ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
+ RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"`
+ UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
+ UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota
+ Models *string `json:"models" gorm:"default:''"` // allowed models
+ Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet
}
func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) {
var tokens []*Token
var err error
query := DB.Where("user_id = ?", userId)
-
+
switch order {
case "remain_quota":
query = query.Order("unlimited_quota desc, remain_quota desc")
@@ -39,7 +48,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token
default:
query = query.Order("id desc")
}
-
+
err = query.Limit(num).Offset(startIdx).Find(&tokens).Error
return tokens, err
}
@@ -62,18 +71,17 @@ func ValidateUserToken(key string) (token *Token, err error) {
return nil, errors.Wrap(err, "failed to get token by key")
}
-
- if token.Status == common.TokenStatusExhausted {
- return nil, errors.New("该令牌额度已用尽")
- } else if token.Status == common.TokenStatusExpired {
+ if token.Status == TokenStatusExhausted {
+ return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id)
+ } else if token.Status == TokenStatusExpired {
return nil, errors.New("该令牌已过期")
}
- if token.Status != common.TokenStatusEnabled {
+ if token.Status != TokenStatusEnabled {
return nil, errors.New("该令牌状态不可用")
}
if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
if !common.RedisEnabled {
- token.Status = common.TokenStatusExpired
+ token.Status = TokenStatusExpired
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
@@ -84,7 +92,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
if !token.UnlimitedQuota && token.RemainQuota <= 0 {
if !common.RedisEnabled {
// in this case, we can make sure the token is exhausted
- token.Status = common.TokenStatusExhausted
+ token.Status = TokenStatusExhausted
err := token.SelectUpdate()
if err != nil {
logger.SysError("failed to update token status" + err.Error())
@@ -124,7 +132,7 @@ func (token *Token) Insert() error {
// Update Make sure your token's fields is completed, because this will update non-zero values
func (token *Token) Update() error {
var err error
- err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error
+ err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error
return err
}
diff --git a/model/user.go b/model/user.go
index e1244e3c..3cc1f9c0 100644
--- a/model/user.go
+++ b/model/user.go
@@ -8,11 +8,24 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/config"
- "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/common/random"
"gorm.io/gorm"
)
+const (
+ RoleGuestUser = 0
+ RoleCommonUser = 1
+ RoleAdminUser = 10
+ RoleRootUser = 100
+)
+
+const (
+ UserStatusEnabled = 1 // don't use 0, 0 is the default value!
+ UserStatusDisabled = 2 // also don't use 0
+ UserStatusDeleted = 3
+)
+
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
@@ -25,6 +38,7 @@ type User struct {
Email string `json:"email" gorm:"index" validate:"max=50"`
GitHubId string `json:"github_id" gorm:"column:github_id;index"`
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
+ LarkId string `json:"lark_id" gorm:"column:lark_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int64 `json:"quota" gorm:"bigint;default:0"`
@@ -42,21 +56,21 @@ func GetMaxUserId() int {
}
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
- query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
-
- switch order {
- case "quota":
- query = query.Order("quota desc")
- case "used_quota":
- query = query.Order("used_quota desc")
- case "request_count":
- query = query.Order("request_count desc")
- default:
- query = query.Order("id desc")
- }
-
- err = query.Find(&users).Error
- return users, err
+ query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)
+
+ switch order {
+ case "quota":
+ query = query.Order("quota desc")
+ case "used_quota":
+ query = query.Order("used_quota desc")
+ case "request_count":
+ query = query.Order("request_count desc")
+ default:
+ query = query.Order("id desc")
+ }
+
+ err = query.Find(&users).Error
+ return users, err
}
func SearchUsers(keyword string) (users []*User, err error) {
@@ -108,8 +122,8 @@ func (user *User) Insert(inviterId int) error {
}
}
user.Quota = config.QuotaForNewUser
- user.AccessToken = helper.GetUUID()
- user.AffCode = helper.GetRandomString(4)
+ user.AccessToken = random.GetUUID()
+ user.AffCode = random.GetRandomString(4)
result := DB.Create(user)
if result.Error != nil {
return result.Error
@@ -138,9 +152,9 @@ func (user *User) Update(updatePassword bool) error {
return err
}
}
- if user.Status == common.UserStatusDisabled {
+ if user.Status == UserStatusDisabled {
blacklist.BanUser(user.Id)
- } else if user.Status == common.UserStatusEnabled {
+ } else if user.Status == UserStatusEnabled {
blacklist.UnbanUser(user.Id)
}
err = DB.Model(user).Updates(user).Error
@@ -152,8 +166,8 @@ func (user *User) Delete() error {
return errors.New("id 为空!")
}
blacklist.BanUser(user.Id)
- user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
- user.Status = common.UserStatusDeleted
+ user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
+ user.Status = UserStatusDeleted
err := DB.Model(user).Updates(user).Error
return err
}
@@ -177,7 +191,7 @@ func (user *User) ValidateAndFill() (err error) {
}
}
okay := common.ValidatePasswordAndHash(password, user.Password)
- if !okay || user.Status != common.UserStatusEnabled {
+ if !okay || user.Status != UserStatusEnabled {
return errors.New("用户名或密码错误,或用户已被封禁")
}
return nil
@@ -207,6 +221,14 @@ func (user *User) FillUserByGitHubId() error {
return nil
}
+func (user *User) FillUserByLarkId() error {
+ if user.LarkId == "" {
+ return errors.New("lark id 为空!")
+ }
+ DB.Where(User{LarkId: user.LarkId}).First(user)
+ return nil
+}
+
func (user *User) FillUserByWeChatId() error {
if user.WeChatId == "" {
return errors.New("WeChat id 为空!")
@@ -235,6 +257,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool {
return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
}
+func IsLarkIdAlreadyTaken(githubId string) bool {
+ return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
+}
+
func IsUsernameAlreadyTaken(username string) bool {
return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
}
@@ -261,7 +287,7 @@ func IsAdmin(userId int) bool {
logger.SysError("no such user " + err.Error())
return false
}
- return user.Role >= common.RoleAdminUser
+ return user.Role >= RoleAdminUser
}
func IsUserEnabled(userId int) (bool, error) {
@@ -273,7 +299,7 @@ func IsUserEnabled(userId int) (bool, error) {
if err != nil {
return false, err
}
- return user.Status == common.UserStatusEnabled, nil
+ return user.Status == UserStatusEnabled, nil
}
func ValidateAccessToken(token string) (user *User) {
@@ -346,7 +372,7 @@ func decreaseUserQuota(id int, quota int64) (err error) {
}
func GetRootUserEmail() (email string) {
- DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+ DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
return email
}
diff --git a/monitor/channel.go b/monitor/channel.go
index ad82d2f5..7e5dc58a 100644
--- a/monitor/channel.go
+++ b/monitor/channel.go
@@ -2,7 +2,6 @@ package monitor
import (
"fmt"
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message"
@@ -29,7 +28,7 @@ func notifyRootUser(subject string, content string) {
// DisableChannel disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
- model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+ model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason)
@@ -37,7 +36,7 @@ func DisableChannel(channelId int, channelName string, reason string) {
}
func MetricDisableChannel(channelId int, successRate float64) {
- model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+ model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId)
content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
@@ -47,7 +46,7 @@ func MetricDisableChannel(channelId int, successRate float64) {
// EnableChannel enable & notify
func EnableChannel(channelId int, channelName string) {
- model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
+ model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled)
logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId)
content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId)
diff --git a/monitor/manage.go b/monitor/manage.go
new file mode 100644
index 00000000..946e78af
--- /dev/null
+++ b/monitor/manage.go
@@ -0,0 +1,62 @@
+package monitor
+
+import (
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/relay/model"
+ "net/http"
+ "strings"
+)
+
+func ShouldDisableChannel(err *model.Error, statusCode int) bool {
+ if !config.AutomaticDisableChannelEnabled {
+ return false
+ }
+ if err == nil {
+ return false
+ }
+ if statusCode == http.StatusUnauthorized {
+ return true
+ }
+ switch err.Type {
+ case "insufficient_quota":
+ return true
+ // https://docs.anthropic.com/claude/reference/errors
+ case "authentication_error":
+ return true
+ case "permission_error":
+ return true
+ case "forbidden":
+ return true
+ }
+ if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
+ return true
+ }
+ if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
+ return true
+ } else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
+ return true
+ }
+ //if strings.Contains(err.Message, "quota") {
+ // return true
+ //}
+ if strings.Contains(err.Message, "credit") {
+ return true
+ }
+ if strings.Contains(err.Message, "balance") {
+ return true
+ }
+ return false
+}
+
+func ShouldEnableChannel(err error, openAIErr *model.Error) bool {
+ if !config.AutomaticEnableChannelEnabled {
+ return false
+ }
+ if err != nil {
+ return false
+ }
+ if openAIErr != nil {
+ return false
+ }
+ return true
+}
diff --git a/relay/adaptor.go b/relay/adaptor.go
new file mode 100644
index 00000000..ef549b5b
--- /dev/null
+++ b/relay/adaptor.go
@@ -0,0 +1,41 @@
+package relay
+
+import (
+ "github.com/songquanpeng/one-api/relay/adaptor"
+ "github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
+ "github.com/songquanpeng/one-api/relay/adaptor/anthropic"
+ "github.com/songquanpeng/one-api/relay/adaptor/gemini"
+ "github.com/songquanpeng/one-api/relay/adaptor/ollama"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/palm"
+ "github.com/songquanpeng/one-api/relay/apitype"
+)
+
+func GetAdaptor(apiType int) adaptor.Adaptor {
+ switch apiType {
+ case apitype.AIProxyLibrary:
+ return &aiproxy.Adaptor{}
+ // case apitype.Ali:
+ // return &ali.Adaptor{}
+ case apitype.Anthropic:
+ return &anthropic.Adaptor{}
+ // case apitype.Baidu:
+ // return &baidu.Adaptor{}
+ case apitype.Gemini:
+ return &gemini.Adaptor{}
+ case apitype.OpenAI:
+ return &openai.Adaptor{}
+ case apitype.PaLM:
+ return &palm.Adaptor{}
+ // case apitype.Tencent:
+ // return &tencent.Adaptor{}
+ // case apitype.Xunfei:
+ // return &xunfei.Adaptor{}
+ // case apitype.Zhipu:
+ // return &zhipu.Adaptor{}
+ case apitype.Ollama:
+ return &ollama.Adaptor{}
+ }
+
+ return nil
+}
diff --git a/relay/channel/ai360/constants.go b/relay/adaptor/ai360/constants.go
similarity index 100%
rename from relay/channel/ai360/constants.go
rename to relay/adaptor/ai360/constants.go
diff --git a/relay/channel/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go
similarity index 54%
rename from relay/channel/aiproxy/adaptor.go
rename to relay/adaptor/aiproxy/adaptor.go
index 6f5d289f..31865698 100644
--- a/relay/channel/aiproxy/adaptor.go
+++ b/relay/adaptor/aiproxy/adaptor.go
@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/relay/adaptor"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -15,16 +15,16 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+ adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
@@ -34,15 +34,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
aiProxyLibraryRequest := ConvertRequest(*request)
- aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
+ aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID)
return aiProxyLibraryRequest, nil
}
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
diff --git a/relay/channel/aiproxy/constants.go b/relay/adaptor/aiproxy/constants.go
similarity index 60%
rename from relay/channel/aiproxy/constants.go
rename to relay/adaptor/aiproxy/constants.go
index c4df51c4..818d2709 100644
--- a/relay/channel/aiproxy/constants.go
+++ b/relay/adaptor/aiproxy/constants.go
@@ -1,6 +1,6 @@
package aiproxy
-import "github.com/songquanpeng/one-api/relay/channel/openai"
+import "github.com/songquanpeng/one-api/relay/adaptor/openai"
var ModelList = []string{""}
diff --git a/relay/channel/aiproxy/main.go b/relay/adaptor/aiproxy/main.go
similarity index 95%
rename from relay/channel/aiproxy/main.go
rename to relay/adaptor/aiproxy/main.go
index 7b146828..961260de 100644
--- a/relay/channel/aiproxy/main.go
+++ b/relay/adaptor/aiproxy/main.go
@@ -13,7 +13,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -54,7 +55,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
FinishReason: "stop",
}
fullTextResponse := openai.TextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -67,7 +68,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
choice.Delta.Content = aiProxyDocuments2Markdown(documents)
choice.FinishReason = &constant.StopFinishReason
return &openai.ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "",
@@ -79,7 +80,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Content = response.Content
return &openai.ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: response.Model,
diff --git a/relay/channel/aiproxy/model.go b/relay/adaptor/aiproxy/model.go
similarity index 100%
rename from relay/channel/aiproxy/model.go
rename to relay/adaptor/aiproxy/model.go
diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go
new file mode 100644
index 00000000..e004211e
--- /dev/null
+++ b/relay/adaptor/ali/adaptor.go
@@ -0,0 +1,105 @@
+package ali
+
+// import (
+// "github.com/Laisky/errors/v2"
+// "fmt"
+// "github.com/gin-gonic/gin"
+// "github.com/songquanpeng/one-api/common/config"
+// "github.com/songquanpeng/one-api/relay/adaptor"
+// "github.com/songquanpeng/one-api/relay/meta"
+// "github.com/songquanpeng/one-api/relay/model"
+// "github.com/songquanpeng/one-api/relay/relaymode"
+// "io"
+// "net/http"
+// )
+
+// // https://help.aliyun.com/zh/dashscope/developer-reference/api-details
+
+// type Adaptor struct {
+// }
+
+// func (a *Adaptor) Init(meta *meta.Meta) {
+
+// }
+
+// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+// fullRequestURL := ""
+// switch meta.Mode {
+// case relaymode.Embeddings:
+// fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
+// case relaymode.ImagesGenerations:
+// fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
+// default:
+// fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
+// }
+
+// return fullRequestURL, nil
+// }
+
+// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+// adaptor.SetupCommonRequestHeader(c, req, meta)
+// if meta.IsStream {
+// req.Header.Set("Accept", "text/event-stream")
+// req.Header.Set("X-DashScope-SSE", "enable")
+// }
+// req.Header.Set("Authorization", "Bearer "+meta.APIKey)
+
+// if meta.Mode == relaymode.ImagesGenerations {
+// req.Header.Set("X-DashScope-Async", "enable")
+// }
+// if c.GetString(config.KeyPlugin) != "" {
+// req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin))
+// }
+// return nil
+// }
+
+// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+// if request == nil {
+// return nil, errors.New("request is nil")
+// }
+// switch relayMode {
+// case relaymode.Embeddings:
+// aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
+// return aliEmbeddingRequest, nil
+// default:
+// aliRequest := ConvertRequest(*request)
+// return aliRequest, nil
+// }
+// }
+
+// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+// if request == nil {
+// return nil, errors.New("request is nil")
+// }
+
+// aliRequest := ConvertImageRequest(*request)
+// return aliRequest, nil
+// }
+
+// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+// return adaptor.DoRequestHelper(a, c, meta, requestBody)
+// }
+
+// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+// if meta.IsStream {
+// err, usage = StreamHandler(c, resp)
+// } else {
+// switch meta.Mode {
+// case relaymode.Embeddings:
+// err, usage = EmbeddingHandler(c, resp)
+// case relaymode.ImagesGenerations:
+// err, usage = ImageHandler(c, resp)
+// default:
+// err, usage = Handler(c, resp)
+// }
+// }
+// return
+// }
+
+// func (a *Adaptor) GetModelList() []string {
+// return ModelList
+// }
+
+// func (a *Adaptor) GetChannelName() string {
+// return "ali"
+// }
diff --git a/relay/channel/ali/constants.go b/relay/adaptor/ali/constants.go
similarity index 65%
rename from relay/channel/ali/constants.go
rename to relay/adaptor/ali/constants.go
index 16bcfca4..3f24ce2e 100644
--- a/relay/channel/ali/constants.go
+++ b/relay/adaptor/ali/constants.go
@@ -3,4 +3,5 @@ package ali
var ModelList = []string{
"qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
"text-embedding-v1",
+ "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
}
diff --git a/relay/adaptor/ali/image.go b/relay/adaptor/ali/image.go
new file mode 100644
index 00000000..cef509e2
--- /dev/null
+++ b/relay/adaptor/ali/image.go
@@ -0,0 +1,192 @@
+package ali
+
+import (
+ "encoding/base64"
+ "encoding/json"
+ "fmt"
+ "github.com/Laisky/errors/v2"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+ "io"
+ "net/http"
+ "strings"
+ "time"
+)
+
+func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ apiKey := c.Request.Header.Get("Authorization")
+ apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+ responseFormat := c.GetString("response_format")
+
+ var aliTaskResponse TaskResponse
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = json.Unmarshal(responseBody, &aliTaskResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+
+ if aliTaskResponse.Message != "" {
+ logger.SysError("aliAsyncTask err: " + string(responseBody))
+ return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
+ }
+
+ aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey)
+ if err != nil {
+ return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
+ }
+
+ if aliResponse.Output.TaskStatus != "SUCCEEDED" {
+ return &model.ErrorWithStatusCode{
+ Error: model.Error{
+ Message: aliResponse.Output.Message,
+ Type: "ali_error",
+ Param: "",
+ Code: aliResponse.Output.Code,
+ },
+ StatusCode: resp.StatusCode,
+ }, nil
+ }
+
+ fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat)
+ jsonResponse, err := json.Marshal(fullTextResponse)
+ if err != nil {
+ return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+ c.Writer.Header().Set("Content-Type", "application/json")
+ c.Writer.WriteHeader(resp.StatusCode)
+ _, err = c.Writer.Write(jsonResponse)
+ return nil, nil
+}
+
+func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) {
+ url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID)
+
+ var aliResponse TaskResponse
+
+ req, err := http.NewRequest("GET", url, nil)
+ if err != nil {
+ return &aliResponse, err, nil
+ }
+
+ req.Header.Set("Authorization", "Bearer "+key)
+
+ client := &http.Client{}
+ resp, err := client.Do(req)
+ if err != nil {
+ logger.SysError("aliAsyncTask client.Do err: " + err.Error())
+ return &aliResponse, err, nil
+ }
+ defer resp.Body.Close()
+
+ responseBody, err := io.ReadAll(resp.Body)
+
+ var response TaskResponse
+ err = json.Unmarshal(responseBody, &response)
+ if err != nil {
+ logger.SysError("aliAsyncTask NewDecoder err: " + err.Error())
+ return &aliResponse, err, nil
+ }
+
+ return &response, nil, responseBody
+}
+
+func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) {
+ waitSeconds := 2
+ step := 0
+ maxStep := 20
+
+ var taskResponse TaskResponse
+ var responseBody []byte
+
+ for {
+ step++
+ rsp, err, body := asyncTask(taskID, key)
+ responseBody = body
+ if err != nil {
+ return &taskResponse, responseBody, err
+ }
+
+ if rsp.Output.TaskStatus == "" {
+ return &taskResponse, responseBody, nil
+ }
+
+ switch rsp.Output.TaskStatus {
+ case "FAILED":
+ fallthrough
+ case "CANCELED":
+ fallthrough
+ case "SUCCEEDED":
+ fallthrough
+ case "UNKNOWN":
+ return rsp, responseBody, nil
+ }
+ if step >= maxStep {
+ break
+ }
+ time.Sleep(time.Duration(waitSeconds) * time.Second)
+ }
+
+ return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
+}
+
+func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse {
+ imageResponse := openai.ImageResponse{
+ Created: helper.GetTimestamp(),
+ }
+
+ for _, data := range response.Output.Results {
+ var b64Json string
+ if responseFormat == "b64_json" {
+ // 读取 data.Url 的图片数据并转存到 b64Json
+ imageData, err := getImageData(data.Url)
+ if err != nil {
+ // 处理获取图片数据失败的情况
+ logger.SysError("getImageData Error getting image data: " + err.Error())
+ continue
+ }
+
+ // 将图片数据转为 Base64 编码的字符串
+ b64Json = Base64Encode(imageData)
+ } else {
+ // 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image
+ b64Json = data.B64Image
+ }
+
+ imageResponse.Data = append(imageResponse.Data, openai.ImageData{
+ Url: data.Url,
+ B64Json: b64Json,
+ RevisedPrompt: "",
+ })
+ }
+ return &imageResponse
+}
+
+func getImageData(url string) ([]byte, error) {
+ response, err := http.Get(url)
+ if err != nil {
+ return nil, err
+ }
+ defer response.Body.Close()
+
+ imageData, err := io.ReadAll(response.Body)
+ if err != nil {
+ return nil, err
+ }
+
+ return imageData, nil
+}
+
+func Base64Encode(data []byte) string {
+ b64Json := base64.StdEncoding.EncodeToString(data)
+ return b64Json
+}
diff --git a/relay/channel/ali/main.go b/relay/adaptor/ali/main.go
similarity index 100%
rename from relay/channel/ali/main.go
rename to relay/adaptor/ali/main.go
diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go
new file mode 100644
index 00000000..450b5f52
--- /dev/null
+++ b/relay/adaptor/ali/model.go
@@ -0,0 +1,154 @@
+package ali
+
+import (
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/model"
+)
+
+type Message struct {
+ Content string `json:"content"`
+ Role string `json:"role"`
+}
+
+type Input struct {
+ //Prompt string `json:"prompt"`
+ Messages []Message `json:"messages"`
+}
+
+type Parameters struct {
+ TopP float64 `json:"top_p,omitempty"`
+ TopK int `json:"top_k,omitempty"`
+ Seed uint64 `json:"seed,omitempty"`
+ EnableSearch bool `json:"enable_search,omitempty"`
+ IncrementalOutput bool `json:"incremental_output,omitempty"`
+ MaxTokens int `json:"max_tokens,omitempty"`
+ Temperature float64 `json:"temperature,omitempty"`
+ ResultFormat string `json:"result_format,omitempty"`
+ Tools []model.Tool `json:"tools,omitempty"`
+}
+
+type ChatRequest struct {
+ Model string `json:"model"`
+ Input Input `json:"input"`
+ Parameters Parameters `json:"parameters,omitempty"`
+}
+
+type ImageRequest struct {
+ Model string `json:"model"`
+ Input struct {
+ Prompt string `json:"prompt"`
+ NegativePrompt string `json:"negative_prompt,omitempty"`
+ } `json:"input"`
+ Parameters struct {
+ Size string `json:"size,omitempty"`
+ N int `json:"n,omitempty"`
+ Steps string `json:"steps,omitempty"`
+ Scale string `json:"scale,omitempty"`
+ } `json:"parameters,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+}
+
+type TaskResponse struct {
+ StatusCode int `json:"status_code,omitempty"`
+ RequestId string `json:"request_id,omitempty"`
+ Code string `json:"code,omitempty"`
+ Message string `json:"message,omitempty"`
+ Output struct {
+ TaskId string `json:"task_id,omitempty"`
+ TaskStatus string `json:"task_status,omitempty"`
+ Code string `json:"code,omitempty"`
+ Message string `json:"message,omitempty"`
+ Results []struct {
+ B64Image string `json:"b64_image,omitempty"`
+ Url string `json:"url,omitempty"`
+ Code string `json:"code,omitempty"`
+ Message string `json:"message,omitempty"`
+ } `json:"results,omitempty"`
+ TaskMetrics struct {
+ Total int `json:"TOTAL,omitempty"`
+ Succeeded int `json:"SUCCEEDED,omitempty"`
+ Failed int `json:"FAILED,omitempty"`
+ } `json:"task_metrics,omitempty"`
+ } `json:"output,omitempty"`
+ Usage Usage `json:"usage"`
+}
+
+type Header struct {
+ Action string `json:"action,omitempty"`
+ Streaming string `json:"streaming,omitempty"`
+ TaskID string `json:"task_id,omitempty"`
+ Event string `json:"event,omitempty"`
+ ErrorCode string `json:"error_code,omitempty"`
+ ErrorMessage string `json:"error_message,omitempty"`
+ Attributes any `json:"attributes,omitempty"`
+}
+
+type Payload struct {
+ Model string `json:"model,omitempty"`
+ Task string `json:"task,omitempty"`
+ TaskGroup string `json:"task_group,omitempty"`
+ Function string `json:"function,omitempty"`
+ Parameters struct {
+ SampleRate int `json:"sample_rate,omitempty"`
+ Rate float64 `json:"rate,omitempty"`
+ Format string `json:"format,omitempty"`
+ } `json:"parameters,omitempty"`
+ Input struct {
+ Text string `json:"text,omitempty"`
+ } `json:"input,omitempty"`
+ Usage struct {
+ Characters int `json:"characters,omitempty"`
+ } `json:"usage,omitempty"`
+}
+
+type WSSMessage struct {
+ Header Header `json:"header,omitempty"`
+ Payload Payload `json:"payload,omitempty"`
+}
+
+type EmbeddingRequest struct {
+ Model string `json:"model"`
+ Input struct {
+ Texts []string `json:"texts"`
+ } `json:"input"`
+ Parameters *struct {
+ TextType string `json:"text_type,omitempty"`
+ } `json:"parameters,omitempty"`
+}
+
+type Embedding struct {
+ Embedding []float64 `json:"embedding"`
+ TextIndex int `json:"text_index"`
+}
+
+type EmbeddingResponse struct {
+ Output struct {
+ Embeddings []Embedding `json:"embeddings"`
+ } `json:"output"`
+ Usage Usage `json:"usage"`
+ Error
+}
+
+type Error struct {
+ Code string `json:"code"`
+ Message string `json:"message"`
+ RequestId string `json:"request_id"`
+}
+
+type Usage struct {
+ InputTokens int `json:"input_tokens"`
+ OutputTokens int `json:"output_tokens"`
+ TotalTokens int `json:"total_tokens"`
+}
+
+type Output struct {
+ //Text string `json:"text"`
+ //FinishReason string `json:"finish_reason"`
+ Choices []openai.TextResponseChoice `json:"choices"`
+}
+
+type ChatResponse struct {
+ Output Output `json:"output"`
+ Usage Usage `json:"usage"`
+ Error
+}
diff --git a/relay/channel/anthropic/adaptor.go b/relay/adaptor/anthropic/adaptor.go
similarity index 60%
rename from relay/channel/anthropic/adaptor.go
rename to relay/adaptor/anthropic/adaptor.go
index 9f1adb9a..07efb3c7 100644
--- a/relay/channel/anthropic/adaptor.go
+++ b/relay/adaptor/anthropic/adaptor.go
@@ -2,31 +2,31 @@ package anthropic
import (
"fmt"
- "github.com/Laisky/errors/v2"
"io"
"net/http"
+ "github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/adaptor"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
)
type Adaptor struct {
}
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
- // https://docs.anthropic.com/claude/reference/messages_post
- // anthopic migrate to Message API
+// https://docs.anthropic.com/claude/reference/messages_post
+// anthopic migrate to Message API
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+ adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-api-key", meta.APIKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
@@ -46,11 +46,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
diff --git a/relay/channel/anthropic/constants.go b/relay/adaptor/anthropic/constants.go
similarity index 100%
rename from relay/channel/anthropic/constants.go
rename to relay/adaptor/anthropic/constants.go
diff --git a/relay/channel/anthropic/main.go b/relay/adaptor/anthropic/main.go
similarity index 99%
rename from relay/channel/anthropic/main.go
rename to relay/adaptor/anthropic/main.go
index 198b66ad..aec327fe 100644
--- a/relay/channel/anthropic/main.go
+++ b/relay/adaptor/anthropic/main.go
@@ -13,7 +13,7 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
diff --git a/relay/channel/anthropic/model.go b/relay/adaptor/anthropic/model.go
similarity index 100%
rename from relay/channel/anthropic/model.go
rename to relay/adaptor/anthropic/model.go
diff --git a/relay/adaptor/azure/helper.go b/relay/adaptor/azure/helper.go
new file mode 100644
index 00000000..dd207f37
--- /dev/null
+++ b/relay/adaptor/azure/helper.go
@@ -0,0 +1,15 @@
+package azure
+
+import (
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/common/config"
+)
+
+func GetAPIVersion(c *gin.Context) string {
+ query := c.Request.URL.Query()
+ apiVersion := query.Get("api-version")
+ if apiVersion == "" {
+ apiVersion = c.GetString(config.KeyAPIVersion)
+ }
+ return apiVersion
+}
diff --git a/relay/channel/baichuan/constants.go b/relay/adaptor/baichuan/constants.go
similarity index 100%
rename from relay/channel/baichuan/constants.go
rename to relay/adaptor/baichuan/constants.go
diff --git a/relay/channel/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go
similarity index 100%
rename from relay/channel/baidu/adaptor.go
rename to relay/adaptor/baidu/adaptor.go
diff --git a/relay/adaptor/baidu/constants.go b/relay/adaptor/baidu/constants.go
new file mode 100644
index 00000000..f952adc6
--- /dev/null
+++ b/relay/adaptor/baidu/constants.go
@@ -0,0 +1,20 @@
+package baidu
+
+var ModelList = []string{
+ "ERNIE-4.0-8K",
+ "ERNIE-3.5-8K",
+ "ERNIE-3.5-8K-0205",
+ "ERNIE-3.5-8K-1222",
+ "ERNIE-Bot-8K",
+ "ERNIE-3.5-4K-0205",
+ "ERNIE-Speed-8K",
+ "ERNIE-Speed-128K",
+ "ERNIE-Lite-8K-0922",
+ "ERNIE-Lite-8K-0308",
+ "ERNIE-Tiny-8K",
+ "BLOOMZ-7B",
+ "Embedding-V1",
+ "bge-large-zh",
+ "bge-large-en",
+ "tao-8k",
+}
diff --git a/relay/channel/baidu/main.go b/relay/adaptor/baidu/main.go
similarity index 100%
rename from relay/channel/baidu/main.go
rename to relay/adaptor/baidu/main.go
diff --git a/relay/channel/baidu/model.go b/relay/adaptor/baidu/model.go
similarity index 100%
rename from relay/channel/baidu/model.go
rename to relay/adaptor/baidu/model.go
diff --git a/relay/channel/common.go b/relay/adaptor/common.go
similarity index 81%
rename from relay/channel/common.go
rename to relay/adaptor/common.go
index 2c4fb37c..13f57132 100644
--- a/relay/channel/common.go
+++ b/relay/adaptor/common.go
@@ -1,4 +1,4 @@
-package channel
+package adaptor
import (
"io"
@@ -6,10 +6,11 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/util"
+ "github.com/songquanpeng/one-api/relay/client"
+ "github.com/songquanpeng/one-api/relay/meta"
)
-func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
+func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
if meta.IsStream && c.Request.Header.Get("Accept") == "" {
@@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela
}
}
-func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(meta)
if err != nil {
return nil, errors.Wrap(err, "get request url failed")
@@ -43,7 +44,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod
}
func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
- resp, err := util.HTTPClient.Do(req)
+ resp, err := client.HTTPClient.Do(req)
if err != nil {
return nil, err
}
diff --git a/relay/channel/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go
similarity index 66%
rename from relay/channel/gemini/adaptor.go
rename to relay/adaptor/gemini/adaptor.go
index 240ca28d..ecb72221 100644
--- a/relay/channel/gemini/adaptor.go
+++ b/relay/adaptor/gemini/adaptor.go
@@ -8,21 +8,21 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
- channelhelper "github.com/songquanpeng/one-api/relay/channel"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
)
type Adaptor struct {
}
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
- version := helper.AssignOrDefault(meta.APIVersion, "v1beta")
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+ version := helper.AssignOrDefault(meta.APIVersion, "v1")
action := "generateContent"
if meta.IsStream {
action = "streamGenerateContent"
@@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", meta.BaseURL, version, meta.ActualModelName, action, meta.APIKey), nil
}
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
channelhelper.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
req.URL.Query().Add("key", meta.APIKey)
@@ -44,11 +44,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
diff --git a/relay/channel/gemini/constants.go b/relay/adaptor/gemini/constants.go
similarity index 100%
rename from relay/channel/gemini/constants.go
rename to relay/adaptor/gemini/constants.go
diff --git a/relay/channel/gemini/main.go b/relay/adaptor/gemini/main.go
similarity index 56%
rename from relay/channel/gemini/main.go
rename to relay/adaptor/gemini/main.go
index b66f2d5e..27a9c023 100644
--- a/relay/channel/gemini/main.go
+++ b/relay/adaptor/gemini/main.go
@@ -1,21 +1,24 @@
package gemini
import (
- "context"
+ "bufio"
"encoding/json"
"fmt"
"io"
"net/http"
+ "strings"
- "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
+
+ "github.com/gin-gonic/gin"
)
// https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn
@@ -82,13 +85,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
if imageNum > VisionMaxImageNum {
continue
}
- mimeType, data, err := image.GetImageFromUrl(part.ImageURL.Url)
- if err != nil {
- logger.Warn(context.TODO(),
- fmt.Sprintf("get image from url %s got %+v", part.ImageURL.Url, err))
- continue
- }
-
+ mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url)
parts = append(parts, Part{
InlineData: &InlineData{
MimeType: mimeType,
@@ -97,9 +94,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest {
})
}
}
-
- logger.Info(context.TODO(),
- fmt.Sprintf("send %d messages to gemini with %d images", len(parts), imageNum))
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
@@ -163,7 +157,7 @@ type ChatPromptFeedback struct {
func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
fullTextResponse := openai.TextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
@@ -196,182 +190,73 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC
return &response
}
-// [{
-// "candidates": [
-// {
-// "content": {
-// "parts": [
-// {
-// "text": "```go \n\n// Package ratelimit implements tokens bucket algorithm.\npackage rate"
-// }
-// ],
-// "role": "model"
-// },
-// "finishReason": "STOP",
-// "index": 0,
-// "safetyRatings": [
-// {
-// "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-// "probability": "NEGLIGIBLE"
-// },
-// {
-// "category": "HARM_CATEGORY_HATE_SPEECH",
-// "probability": "NEGLIGIBLE"
-// },
-// {
-// "category": "HARM_CATEGORY_HARASSMENT",
-// "probability": "NEGLIGIBLE"
-// },
-// {
-// "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-// "probability": "NEGLIGIBLE"
-// }
-// ]
-// }
-// ],
-// "promptFeedback": {
-// "safetyRatings": [
-// {
-// "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-// "probability": "NEGLIGIBLE"
-// },
-// {
-// "category": "HARM_CATEGORY_HATE_SPEECH",
-// "probability": "NEGLIGIBLE"
-// },
-// {
-// "category": "HARM_CATEGORY_HARASSMENT",
-// "probability": "NEGLIGIBLE"
-// },
-// {
-// "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-// "probability": "NEGLIGIBLE"
-// }
-// ]
-// }
-// }]
-type GeminiStreamResp struct {
- Candidates []struct {
- Content struct {
- Parts []struct {
- Text string `json:"text"`
- } `json:"parts"`
- Role string `json:"role"`
- } `json:"content"`
- FinishReason string `json:"finishReason"`
- Index int64 `json:"index"`
- } `json:"candidates"`
-}
-
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
-
- respBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return openai.ErrorWrapper(err, "read upstream's body", http.StatusInternalServerError), responseText
- }
-
- var respData []GeminiStreamResp
- if err = json.Unmarshal(respBody, &respData); err != nil {
- return openai.ErrorWrapper(err, "unmarshal upstream's body", http.StatusInternalServerError), responseText
- }
-
- for _, chunk := range respData {
- for _, cad := range chunk.Candidates {
- for _, part := range cad.Content.Parts {
- responseText += part.Text
- }
+ dataChan := make(chan string)
+ stopChan := make(chan bool)
+ scanner := bufio.NewScanner(resp.Body)
+ scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+ if atEOF && len(data) == 0 {
+ return 0, nil, nil
}
- }
-
- var choice openai.ChatCompletionsStreamResponseChoice
- choice.Delta.Content = responseText
- resp2cli, err := json.Marshal(&openai.ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
- Object: "chat.completion.chunk",
- Created: helper.GetTimestamp(),
- Model: "gemini-pro",
- Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
+ if i := strings.Index(string(data), "\n"); i >= 0 {
+ return i + 1, data[0:i], nil
+ }
+ if atEOF {
+ return len(data), data, nil
+ }
+ return 0, nil, nil
})
+ go func() {
+ for scanner.Scan() {
+ data := scanner.Text()
+ data = strings.TrimSpace(data)
+ if !strings.HasPrefix(data, "\"text\": \"") {
+ continue
+ }
+ data = strings.TrimPrefix(data, "\"text\": \"")
+ data = strings.TrimSuffix(data, "\"")
+ dataChan <- data
+ }
+ stopChan <- true
+ }()
+ common.SetEventStreamHeaders(c)
+ c.Stream(func(w io.Writer) bool {
+ select {
+ case data := <-dataChan:
+ // this is used to prevent annoying \ related format bug
+ data = fmt.Sprintf("{\"content\": \"%s\"}", data)
+ type dummyStruct struct {
+ Content string `json:"content"`
+ }
+ var dummy dummyStruct
+ err := json.Unmarshal([]byte(data), &dummy)
+ responseText += dummy.Content
+ var choice openai.ChatCompletionsStreamResponseChoice
+ choice.Delta.Content = dummy.Content
+ response := openai.ChatCompletionsStreamResponse{
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
+ Object: "chat.completion.chunk",
+ Created: helper.GetTimestamp(),
+ Model: "gemini-pro",
+ Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
+ }
+ jsonResponse, err := json.Marshal(response)
+ if err != nil {
+ logger.SysError("error marshalling stream response: " + err.Error())
+ return true
+ }
+ c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+ return true
+ case <-stopChan:
+ c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+ return false
+ }
+ })
+ err := resp.Body.Close()
if err != nil {
- return openai.ErrorWrapper(err, "marshal upstream's body", http.StatusInternalServerError), responseText
- }
-
- c.Render(-1, common.CustomEvent{Data: "data: " + string(resp2cli)})
- c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
-
- // dataChan := make(chan string)
- // stopChan := make(chan bool)
- // scanner := bufio.NewScanner(resp.Body)
- // scanner.Split(bufio.ScanLines)
- // // scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
- // // if atEOF && len(data) == 0 {
- // // return 0, nil, nil
- // // }
- // // if i := strings.Index(string(data), "\n"); i >= 0 {
- // // return i + 1, data[0:i], nil
- // // }
- // // if atEOF {
- // // return len(data), data, nil
- // // }
- // // return 0, nil, nil
- // // })
- // go func() {
- // var content string
- // for scanner.Scan() {
- // line := strings.TrimSpace(scanner.Text())
- // fmt.Printf("> gemini got line: %s\n", line)
- // content += line
- // // if !strings.HasPrefix(data, "\"text\": \"") {
- // // continue
- // // }
-
- // // data = strings.TrimPrefix(data, "\"text\": \"")
- // // data = strings.TrimSuffix(data, "\"")
- // // dataChan <- data
- // }
-
- // dataChan <- content
- // stopChan <- true
- // }()
- // common.SetEventStreamHeaders(c)
- // c.Stream(func(w io.Writer) bool {
- // select {
- // case data := <-dataChan:
- // // this is used to prevent annoying \ related format bug
- // data = fmt.Sprintf("{\"content\": \"%s\"}", data)
- // type dummyStruct struct {
- // Content string `json:"content"`
- // }
- // var dummy dummyStruct
- // err := json.Unmarshal([]byte(data), &dummy)
- // responseText += dummy.Content
- // var choice openai.ChatCompletionsStreamResponseChoice
- // choice.Delta.Content = dummy.Content
- // response := openai.ChatCompletionsStreamResponse{
- // Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
- // Object: "chat.completion.chunk",
- // Created: helper.GetTimestamp(),
- // Model: "gemini-pro",
- // Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
- // }
- // jsonResponse, err := json.Marshal(response)
- // if err != nil {
- // logger.SysError("error marshalling stream response: " + err.Error())
- // return true
- // }
- // c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
- // return true
- // case <-stopChan:
- // c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
- // return false
- // }
- // })
-
- if err := resp.Body.Close(); err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
}
-
return nil, responseText
}
diff --git a/relay/channel/gemini/model.go b/relay/adaptor/gemini/model.go
similarity index 100%
rename from relay/channel/gemini/model.go
rename to relay/adaptor/gemini/model.go
diff --git a/relay/channel/groq/constants.go b/relay/adaptor/groq/constants.go
similarity index 100%
rename from relay/channel/groq/constants.go
rename to relay/adaptor/groq/constants.go
diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go
new file mode 100644
index 00000000..01b2e2cb
--- /dev/null
+++ b/relay/adaptor/interface.go
@@ -0,0 +1,21 @@
+package adaptor
+
+import (
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/meta"
+ "github.com/songquanpeng/one-api/relay/model"
+ "io"
+ "net/http"
+)
+
+type Adaptor interface {
+ Init(meta *meta.Meta)
+ GetRequestURL(meta *meta.Meta) (string, error)
+ SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
+ ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
+ ConvertImageRequest(request *model.ImageRequest) (any, error)
+ DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
+ DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
+ GetModelList() []string
+ GetChannelName() string
+}
diff --git a/relay/channel/lingyiwanwu/constants.go b/relay/adaptor/lingyiwanwu/constants.go
similarity index 100%
rename from relay/channel/lingyiwanwu/constants.go
rename to relay/adaptor/lingyiwanwu/constants.go
diff --git a/relay/channel/minimax/constants.go b/relay/adaptor/minimax/constants.go
similarity index 100%
rename from relay/channel/minimax/constants.go
rename to relay/adaptor/minimax/constants.go
diff --git a/relay/adaptor/minimax/main.go b/relay/adaptor/minimax/main.go
new file mode 100644
index 00000000..fc9b5d26
--- /dev/null
+++ b/relay/adaptor/minimax/main.go
@@ -0,0 +1,14 @@
+package minimax
+
+import (
+ "fmt"
+ "github.com/songquanpeng/one-api/relay/meta"
+ "github.com/songquanpeng/one-api/relay/relaymode"
+)
+
+func GetRequestURL(meta *meta.Meta) (string, error) {
+ if meta.Mode == relaymode.ChatCompletions {
+ return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
+ }
+ return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode)
+}
diff --git a/relay/channel/mistral/constants.go b/relay/adaptor/mistral/constants.go
similarity index 100%
rename from relay/channel/mistral/constants.go
rename to relay/adaptor/mistral/constants.go
diff --git a/relay/channel/moonshot/constants.go b/relay/adaptor/moonshot/constants.go
similarity index 100%
rename from relay/channel/moonshot/constants.go
rename to relay/adaptor/moonshot/constants.go
diff --git a/relay/channel/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go
similarity index 58%
rename from relay/channel/ollama/adaptor.go
rename to relay/adaptor/ollama/adaptor.go
index e2ae7d2b..ec1b0c40 100644
--- a/relay/channel/ollama/adaptor.go
+++ b/relay/adaptor/ollama/adaptor.go
@@ -1,36 +1,36 @@
package ollama
import (
- "errors"
"fmt"
"io"
"net/http"
+ "github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel"
- "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/adaptor"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
+ "github.com/songquanpeng/one-api/relay/relaymode"
)
type Adaptor struct {
}
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
- if meta.Mode == constant.RelayModeEmbeddings {
+ if meta.Mode == relaymode.Embeddings {
fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
}
return fullRequestURL, nil
}
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+ adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
@@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
switch relayMode {
- case constant.RelayModeEmbeddings:
+ case relaymode.Embeddings:
ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
return ollamaEmbeddingRequest, nil
default:
@@ -48,16 +48,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
}
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
- case constant.RelayModeEmbeddings:
+ case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
default:
err, usage = Handler(c, resp)
diff --git a/relay/channel/ollama/constants.go b/relay/adaptor/ollama/constants.go
similarity index 100%
rename from relay/channel/ollama/constants.go
rename to relay/adaptor/ollama/constants.go
diff --git a/relay/channel/ollama/main.go b/relay/adaptor/ollama/main.go
similarity index 97%
rename from relay/channel/ollama/main.go
rename to relay/adaptor/ollama/main.go
index 821a335b..a7e4c058 100644
--- a/relay/channel/ollama/main.go
+++ b/relay/adaptor/ollama/main.go
@@ -5,15 +5,16 @@ import (
"context"
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/random"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -51,7 +52,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
choice.FinishReason = "stop"
}
fullTextResponse := openai.TextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -72,7 +73,7 @@ func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompl
choice.FinishReason = &constant.StopFinishReason
}
response := openai.ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: ollamaResponse.Model,
diff --git a/relay/channel/ollama/model.go b/relay/adaptor/ollama/model.go
similarity index 100%
rename from relay/channel/ollama/model.go
rename to relay/adaptor/ollama/model.go
diff --git a/relay/channel/openai/adaptor.go b/relay/adaptor/openai/adaptor.go
similarity index 52%
rename from relay/channel/openai/adaptor.go
rename to relay/adaptor/openai/adaptor.go
index 260a35ae..24cf718f 100644
--- a/relay/channel/openai/adaptor.go
+++ b/relay/adaptor/openai/adaptor.go
@@ -4,11 +4,12 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/relay/channel"
- "github.com/songquanpeng/one-api/relay/channel/minimax"
+ "github.com/songquanpeng/one-api/relay/adaptor"
+ "github.com/songquanpeng/one-api/relay/adaptor/minimax"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
+ "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@@ -18,13 +19,20 @@ type Adaptor struct {
ChannelType int
}
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
a.ChannelType = meta.ChannelType
}
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.ChannelType {
- case common.ChannelTypeAzure:
+ case channeltype.Azure:
+ if meta.Mode == relaymode.ImagesGenerations {
+ // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
+ // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
+ fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
+ return fullRequestURL, nil
+ }
+
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
@@ -34,22 +42,22 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
//https://github.com/songquanpeng/one-api/issues/1191
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
- return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
- case common.ChannelTypeMinimax:
+ return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
+ case channeltype.Minimax:
return minimax.GetRequestURL(meta)
default:
- return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
+ return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
}
}
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
- if meta.ChannelType == common.ChannelTypeAzure {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+ adaptor.SetupCommonRequestHeader(c, req, meta)
+ if meta.ChannelType == channeltype.Azure {
req.Header.Set("api-key", meta.APIKey)
return nil
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
- if meta.ChannelType == common.ChannelTypeOpenRouter {
+ if meta.ChannelType == channeltype.OpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
@@ -63,11 +71,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return request, nil
}
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText, usage = StreamHandler(c, resp, meta.Mode)
@@ -75,7 +90,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
}
} else {
- err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ switch meta.Mode {
+ case relaymode.ImagesGenerations:
+ err, _ = ImageHandler(c, resp)
+ default:
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
}
return
}
diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go
new file mode 100644
index 00000000..200eac44
--- /dev/null
+++ b/relay/adaptor/openai/compatible.go
@@ -0,0 +1,50 @@
+package openai
+
+import (
+ "github.com/songquanpeng/one-api/relay/adaptor/ai360"
+ "github.com/songquanpeng/one-api/relay/adaptor/baichuan"
+ "github.com/songquanpeng/one-api/relay/adaptor/groq"
+ "github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
+ "github.com/songquanpeng/one-api/relay/adaptor/minimax"
+ "github.com/songquanpeng/one-api/relay/adaptor/mistral"
+ "github.com/songquanpeng/one-api/relay/adaptor/moonshot"
+ "github.com/songquanpeng/one-api/relay/adaptor/stepfun"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+)
+
+var CompatibleChannels = []int{
+ channeltype.Azure,
+ channeltype.AI360,
+ channeltype.Moonshot,
+ channeltype.Baichuan,
+ channeltype.Minimax,
+ channeltype.Mistral,
+ channeltype.Groq,
+ channeltype.LingYiWanWu,
+ channeltype.StepFun,
+}
+
+func GetCompatibleChannelMeta(channelType int) (string, []string) {
+ switch channelType {
+ case channeltype.Azure:
+ return "azure", ModelList
+ case channeltype.AI360:
+ return "360", ai360.ModelList
+ case channeltype.Moonshot:
+ return "moonshot", moonshot.ModelList
+ case channeltype.Baichuan:
+ return "baichuan", baichuan.ModelList
+ case channeltype.Minimax:
+ return "minimax", minimax.ModelList
+ case channeltype.Mistral:
+ return "mistralai", mistral.ModelList
+ case channeltype.Groq:
+ return "groq", groq.ModelList
+ case channeltype.LingYiWanWu:
+ return "lingyiwanwu", lingyiwanwu.ModelList
+ case channeltype.StepFun:
+ return "stepfun", stepfun.ModelList
+ default:
+ return "openai", ModelList
+ }
+}
diff --git a/relay/channel/openai/constants.go b/relay/adaptor/openai/constants.go
similarity index 100%
rename from relay/channel/openai/constants.go
rename to relay/adaptor/openai/constants.go
diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go
new file mode 100644
index 00000000..7d73303b
--- /dev/null
+++ b/relay/adaptor/openai/helper.go
@@ -0,0 +1,30 @@
+package openai
+
+import (
+ "fmt"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/model"
+ "strings"
+)
+
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
+ usage := &model.Usage{}
+ usage.PromptTokens = promptTokens
+ usage.CompletionTokens = CountTokenText(responseText, modeName)
+ usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+ return usage
+}
+
+func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
+ fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
+
+ if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
+ switch channelType {
+ case channeltype.OpenAI:
+ fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
+ case channeltype.Azure:
+ fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
+ }
+ }
+ return fullRequestURL
+}
diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go
new file mode 100644
index 00000000..0f89618a
--- /dev/null
+++ b/relay/adaptor/openai/image.go
@@ -0,0 +1,44 @@
+package openai
+
+import (
+ "bytes"
+ "encoding/json"
+ "github.com/gin-gonic/gin"
+ "github.com/songquanpeng/one-api/relay/model"
+ "io"
+ "net/http"
+)
+
+func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+ var imageResponse ImageResponse
+ responseBody, err := io.ReadAll(resp.Body)
+
+ if err != nil {
+ return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = json.Unmarshal(responseBody, &imageResponse)
+ if err != nil {
+ return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+ }
+
+ resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
+
+ for k, v := range resp.Header {
+ c.Writer.Header().Set(k, v[0])
+ }
+ c.Writer.WriteHeader(resp.StatusCode)
+
+ _, err = io.Copy(c.Writer, resp.Body)
+ if err != nil {
+ return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+ }
+ return nil, nil
+}
diff --git a/relay/channel/openai/main.go b/relay/adaptor/openai/main.go
similarity index 97%
rename from relay/channel/openai/main.go
rename to relay/adaptor/openai/main.go
index 63cb9ae8..68d8f48f 100644
--- a/relay/channel/openai/main.go
+++ b/relay/adaptor/openai/main.go
@@ -8,8 +8,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
- case constant.RelayModeChatCompletions:
+ case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
@@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
- case constant.RelayModeCompletions:
+ case relaymode.Completions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
diff --git a/relay/channel/openai/model.go b/relay/adaptor/openai/model.go
similarity index 93%
rename from relay/channel/openai/model.go
rename to relay/adaptor/openai/model.go
index 30d77739..ce252ff6 100644
--- a/relay/channel/openai/model.go
+++ b/relay/adaptor/openai/model.go
@@ -110,11 +110,16 @@ type EmbeddingResponse struct {
model.Usage `json:"usage"`
}
+type ImageData struct {
+ Url string `json:"url,omitempty"`
+ B64Json string `json:"b64_json,omitempty"`
+ RevisedPrompt string `json:"revised_prompt,omitempty"`
+}
+
type ImageResponse struct {
- Created int `json:"created"`
- Data []struct {
- Url string `json:"url"`
- }
+ Created int64 `json:"created"`
+ Data []ImageData `json:"data"`
+ //model.Usage `json:"usage"`
}
type ChatCompletionsStreamResponseChoice struct {
diff --git a/relay/channel/openai/token.go b/relay/adaptor/openai/token.go
similarity index 98%
rename from relay/channel/openai/token.go
rename to relay/adaptor/openai/token.go
index d18ce0df..1e61d255 100644
--- a/relay/channel/openai/token.go
+++ b/relay/adaptor/openai/token.go
@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/pkoukk/tiktoken-go"
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
"math"
"strings"
@@ -28,7 +28,7 @@ func InitTokenEncoders() {
if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
- for model := range common.ModelRatio {
+ for model := range billingratio.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {
diff --git a/relay/channel/openai/util.go b/relay/adaptor/openai/util.go
similarity index 100%
rename from relay/channel/openai/util.go
rename to relay/adaptor/openai/util.go
diff --git a/relay/channel/palm/adaptor.go b/relay/adaptor/palm/adaptor.go
similarity index 59%
rename from relay/channel/palm/adaptor.go
rename to relay/adaptor/palm/adaptor.go
index 15ee010d..fa73dd30 100644
--- a/relay/channel/palm/adaptor.go
+++ b/relay/adaptor/palm/adaptor.go
@@ -4,10 +4,10 @@ import (
"fmt"
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -15,16 +15,16 @@ import (
type Adaptor struct {
}
-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}
-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
}
-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+ adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
@@ -36,11 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)
diff --git a/relay/channel/palm/constants.go b/relay/adaptor/palm/constants.go
similarity index 100%
rename from relay/channel/palm/constants.go
rename to relay/adaptor/palm/constants.go
diff --git a/relay/channel/palm/model.go b/relay/adaptor/palm/model.go
similarity index 100%
rename from relay/channel/palm/model.go
rename to relay/adaptor/palm/model.go
diff --git a/relay/channel/palm/palm.go b/relay/adaptor/palm/palm.go
similarity index 97%
rename from relay/channel/palm/palm.go
rename to relay/adaptor/palm/palm.go
index 56738544..1e60e7cd 100644
--- a/relay/channel/palm/palm.go
+++ b/relay/adaptor/palm/palm.go
@@ -7,7 +7,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
@@ -74,7 +75,7 @@ func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletio
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+ responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
diff --git a/relay/adaptor/stepfun/constants.go b/relay/adaptor/stepfun/constants.go
new file mode 100644
index 00000000..a82e562b
--- /dev/null
+++ b/relay/adaptor/stepfun/constants.go
@@ -0,0 +1,7 @@
+package stepfun
+
+var ModelList = []string{
+ "step-1-32k",
+ "step-1v-32k",
+ "step-1-200k",
+}
diff --git a/relay/channel/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go
similarity index 100%
rename from relay/channel/tencent/adaptor.go
rename to relay/adaptor/tencent/adaptor.go
diff --git a/relay/channel/tencent/constants.go b/relay/adaptor/tencent/constants.go
similarity index 100%
rename from relay/channel/tencent/constants.go
rename to relay/adaptor/tencent/constants.go
diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go
new file mode 100644
index 00000000..aa87e9ce
--- /dev/null
+++ b/relay/adaptor/tencent/main.go
@@ -0,0 +1,238 @@
+package tencent
+
+// import (
+// "bufio"
+// "crypto/hmac"
+// "crypto/sha1"
+// "encoding/base64"
+// "encoding/json"
+// "github.com/Laisky/errors/v2"
+// "fmt"
+// "github.com/gin-gonic/gin"
+// "github.com/songquanpeng/one-api/common"
+// "github.com/songquanpeng/one-api/common/helper"
+// "github.com/songquanpeng/one-api/common/logger"
+// "github.com/songquanpeng/one-api/relay/channel/openai"
+// "github.com/songquanpeng/one-api/relay/constant"
+// "github.com/songquanpeng/one-api/relay/model"
+// "io"
+// "net/http"
+// "sort"
+// "strconv"
+// "strings"
+// )
+
+// // https://cloud.tencent.com/document/product/1729/97732
+
+// func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
+// messages := make([]Message, 0, len(request.Messages))
+// for i := 0; i < len(request.Messages); i++ {
+// message := request.Messages[i]
+// if message.Role == "system" {
+// messages = append(messages, Message{
+// Role: "user",
+// Content: message.StringContent(),
+// })
+// messages = append(messages, Message{
+// Role: "assistant",
+// Content: "Okay",
+// })
+// continue
+// }
+// messages = append(messages, Message{
+// Content: message.StringContent(),
+// Role: message.Role,
+// })
+// }
+// stream := 0
+// if request.Stream {
+// stream = 1
+// }
+// return &ChatRequest{
+// Timestamp: helper.GetTimestamp(),
+// Expired: helper.GetTimestamp() + 24*60*60,
+// QueryID: helper.GetUUID(),
+// Temperature: request.Temperature,
+// TopP: request.TopP,
+// Stream: stream,
+// Messages: messages,
+// }
+// }
+
+// func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
+// fullTextResponse := openai.TextResponse{
+// Object: "chat.completion",
+// Created: helper.GetTimestamp(),
+// Usage: response.Usage,
+// }
+// if len(response.Choices) > 0 {
+// choice := openai.TextResponseChoice{
+// Index: 0,
+// Message: model.Message{
+// Role: "assistant",
+// Content: response.Choices[0].Messages.Content,
+// },
+// FinishReason: response.Choices[0].FinishReason,
+// }
+// fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+// }
+// return &fullTextResponse
+// }
+
+// func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
+// response := openai.ChatCompletionsStreamResponse{
+// Object: "chat.completion.chunk",
+// Created: helper.GetTimestamp(),
+// Model: "tencent-hunyuan",
+// }
+// if len(TencentResponse.Choices) > 0 {
+// var choice openai.ChatCompletionsStreamResponseChoice
+// choice.Delta.Content = TencentResponse.Choices[0].Delta.Content
+// if TencentResponse.Choices[0].FinishReason == "stop" {
+// choice.FinishReason = &constant.StopFinishReason
+// }
+// response.Choices = append(response.Choices, choice)
+// }
+// return &response
+// }
+
+// func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
+// var responseText string
+// scanner := bufio.NewScanner(resp.Body)
+// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
+// if atEOF && len(data) == 0 {
+// return 0, nil, nil
+// }
+// if i := strings.Index(string(data), "\n"); i >= 0 {
+// return i + 1, data[0:i], nil
+// }
+// if atEOF {
+// return len(data), data, nil
+// }
+// return 0, nil, nil
+// })
+// dataChan := make(chan string)
+// stopChan := make(chan bool)
+// go func() {
+// for scanner.Scan() {
+// data := scanner.Text()
+// if len(data) < 5 { // ignore blank line or wrong format
+// continue
+// }
+// if data[:5] != "data:" {
+// continue
+// }
+// data = data[5:]
+// dataChan <- data
+// }
+// stopChan <- true
+// }()
+// common.SetEventStreamHeaders(c)
+// c.Stream(func(w io.Writer) bool {
+// select {
+// case data := <-dataChan:
+// var TencentResponse ChatResponse
+// err := json.Unmarshal([]byte(data), &TencentResponse)
+// if err != nil {
+// logger.SysError("error unmarshalling stream response: " + err.Error())
+// return true
+// }
+// response := streamResponseTencent2OpenAI(&TencentResponse)
+// if len(response.Choices) != 0 {
+// responseText += response.Choices[0].Delta.Content
+// }
+// jsonResponse, err := json.Marshal(response)
+// if err != nil {
+// logger.SysError("error marshalling stream response: " + err.Error())
+// return true
+// }
+// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+// return true
+// case <-stopChan:
+// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+// return false
+// }
+// })
+// err := resp.Body.Close()
+// if err != nil {
+// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+// }
+// return nil, responseText
+// }
+
+// func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
+// var TencentResponse ChatResponse
+// responseBody, err := io.ReadAll(resp.Body)
+// if err != nil {
+// return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+// }
+// err = resp.Body.Close()
+// if err != nil {
+// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+// }
+// err = json.Unmarshal(responseBody, &TencentResponse)
+// if err != nil {
+// return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+// }
+// if TencentResponse.Error.Code != 0 {
+// return &model.ErrorWithStatusCode{
+// Error: model.Error{
+// Message: TencentResponse.Error.Message,
+// Code: TencentResponse.Error.Code,
+// },
+// StatusCode: resp.StatusCode,
+// }, nil
+// }
+// fullTextResponse := responseTencent2OpenAI(&TencentResponse)
+// fullTextResponse.Model = "hunyuan"
+// jsonResponse, err := json.Marshal(fullTextResponse)
+// if err != nil {
+// return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+// }
+// c.Writer.Header().Set("Content-Type", "application/json")
+// c.Writer.WriteHeader(resp.StatusCode)
+// _, err = c.Writer.Write(jsonResponse)
+// if err != nil {
+// return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil
+// }
+// return nil, &fullTextResponse.Usage
+// }
+
+// func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) {
+// parts := strings.Split(config, "|")
+// if len(parts) != 3 {
+// err = errors.New("invalid tencent config")
+// return
+// }
+// appId, err = strconv.ParseInt(parts[0], 10, 64)
+// secretId = parts[1]
+// secretKey = parts[2]
+// return
+// }
+
+// func GetSign(req ChatRequest, secretKey string) string {
+// params := make([]string, 0)
+// params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10))
+// params = append(params, "secret_id="+req.SecretId)
+// params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10))
+// params = append(params, "query_id="+req.QueryID)
+// params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64))
+// params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64))
+// params = append(params, "stream="+strconv.Itoa(req.Stream))
+// params = append(params, "expired="+strconv.FormatInt(req.Expired, 10))
+
+// var messageStr string
+// for _, msg := range req.Messages {
+// messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content)
+// }
+// messageStr = strings.TrimSuffix(messageStr, ",")
+// params = append(params, "messages=["+messageStr+"]")
+
+// sort.Strings(params)
+// url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&")
+// mac := hmac.New(sha1.New, []byte(secretKey))
+// signURL := url
+// mac.Write([]byte(signURL))
+// sign := mac.Sum([]byte(nil))
+// return base64.StdEncoding.EncodeToString(sign)
+// }
diff --git a/relay/channel/tencent/model.go b/relay/adaptor/tencent/model.go
similarity index 100%
rename from relay/channel/tencent/model.go
rename to relay/adaptor/tencent/model.go
diff --git a/relay/channel/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go
similarity index 100%
rename from relay/channel/xunfei/adaptor.go
rename to relay/adaptor/xunfei/adaptor.go
diff --git a/relay/channel/xunfei/constants.go b/relay/adaptor/xunfei/constants.go
similarity index 100%
rename from relay/channel/xunfei/constants.go
rename to relay/adaptor/xunfei/constants.go
diff --git a/relay/channel/xunfei/main.go b/relay/adaptor/xunfei/main.go
similarity index 100%
rename from relay/channel/xunfei/main.go
rename to relay/adaptor/xunfei/main.go
diff --git a/relay/channel/xunfei/model.go b/relay/adaptor/xunfei/model.go
similarity index 100%
rename from relay/channel/xunfei/model.go
rename to relay/adaptor/xunfei/model.go
diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go
new file mode 100644
index 00000000..424fabd6
--- /dev/null
+++ b/relay/adaptor/zhipu/adaptor.go
@@ -0,0 +1,145 @@
+package zhipu
+
+// import (
+// "github.com/Laisky/errors/v2"
+// "fmt"
+// "github.com/gin-gonic/gin"
+// "github.com/songquanpeng/one-api/relay/adaptor"
+// "github.com/songquanpeng/one-api/relay/adaptor/openai"
+// "github.com/songquanpeng/one-api/relay/meta"
+// "github.com/songquanpeng/one-api/relay/model"
+// "github.com/songquanpeng/one-api/relay/relaymode"
+// "io"
+// "math"
+// "net/http"
+// "strings"
+// )
+
+// type Adaptor struct {
+// APIVersion string
+// }
+
+// func (a *Adaptor) Init(meta *meta.Meta) {
+
+// }
+
+// func (a *Adaptor) SetVersionByModeName(modelName string) {
+// if strings.HasPrefix(modelName, "glm-") {
+// a.APIVersion = "v4"
+// } else {
+// a.APIVersion = "v3"
+// }
+// }
+
+// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+// switch meta.Mode {
+// case relaymode.ImagesGenerations:
+// return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
+// case relaymode.Embeddings:
+// return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
+// }
+// a.SetVersionByModeName(meta.ActualModelName)
+// if a.APIVersion == "v4" {
+// return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
+// }
+// method := "invoke"
+// if meta.IsStream {
+// method = "sse-invoke"
+// }
+// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
+// }
+
+// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+// adaptor.SetupCommonRequestHeader(c, req, meta)
+// token := GetToken(meta.APIKey)
+// req.Header.Set("Authorization", token)
+// return nil
+// }
+
+// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
+// if request == nil {
+// return nil, errors.New("request is nil")
+// }
+// switch relayMode {
+// case relaymode.Embeddings:
+// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
+// return baiduEmbeddingRequest, nil
+// default:
+// // TopP (0.0, 1.0)
+// request.TopP = math.Min(0.99, request.TopP)
+// request.TopP = math.Max(0.01, request.TopP)
+
+// // Temperature (0.0, 1.0)
+// request.Temperature = math.Min(0.99, request.Temperature)
+// request.Temperature = math.Max(0.01, request.Temperature)
+// a.SetVersionByModeName(request.Model)
+// if a.APIVersion == "v4" {
+// return request, nil
+// }
+// return ConvertRequest(*request), nil
+// }
+// }
+
+// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+// if request == nil {
+// return nil, errors.New("request is nil")
+// }
+// newRequest := ImageRequest{
+// Model: request.Model,
+// Prompt: request.Prompt,
+// UserId: request.User,
+// }
+// return newRequest, nil
+// }
+
+// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+// return adaptor.DoRequestHelper(a, c, meta, requestBody)
+// }
+
+// func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+// if meta.IsStream {
+// err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
+// } else {
+// err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+// }
+// return
+// }
+
+// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+// switch meta.Mode {
+// case relaymode.Embeddings:
+// err, usage = EmbeddingsHandler(c, resp)
+// return
+// case relaymode.ImagesGenerations:
+// err, usage = openai.ImageHandler(c, resp)
+// return
+// }
+// if a.APIVersion == "v4" {
+// return a.DoResponseV4(c, resp, meta)
+// }
+// if meta.IsStream {
+// err, usage = StreamHandler(c, resp)
+// } else {
+// if meta.Mode == relaymode.Embeddings {
+// err, usage = EmbeddingsHandler(c, resp)
+// } else {
+// err, usage = Handler(c, resp)
+// }
+// }
+// return
+// }
+
+// func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest {
+// return &EmbeddingRequest{
+// Model: "embedding-2",
+// Input: request.Input.(string),
+// }
+// }
+
+// func (a *Adaptor) GetModelList() []string {
+// return ModelList
+// }
+
+// func (a *Adaptor) GetChannelName() string {
+// return "zhipu"
+// }
diff --git a/relay/channel/zhipu/constants.go b/relay/adaptor/zhipu/constants.go
similarity index 100%
rename from relay/channel/zhipu/constants.go
rename to relay/adaptor/zhipu/constants.go
diff --git a/relay/channel/zhipu/main.go b/relay/adaptor/zhipu/main.go
similarity index 100%
rename from relay/channel/zhipu/main.go
rename to relay/adaptor/zhipu/main.go
diff --git a/relay/channel/zhipu/model.go b/relay/adaptor/zhipu/model.go
similarity index 100%
rename from relay/channel/zhipu/model.go
rename to relay/adaptor/zhipu/model.go
diff --git a/relay/apitype/define.go b/relay/apitype/define.go
new file mode 100644
index 00000000..82d32a50
--- /dev/null
+++ b/relay/apitype/define.go
@@ -0,0 +1,17 @@
+package apitype
+
+const (
+ OpenAI = iota
+ Anthropic
+ PaLM
+ Baidu
+ Zhipu
+ Ali
+ Xunfei
+ AIProxyLibrary
+ Tencent
+ Gemini
+ Ollama
+
+ Dummy // this one is only for count, do not add any channel after this
+)
diff --git a/relay/billing/billing.go b/relay/billing/billing.go
new file mode 100644
index 00000000..a99d37ee
--- /dev/null
+++ b/relay/billing/billing.go
@@ -0,0 +1,42 @@
+package billing
+
+import (
+ "context"
+ "fmt"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/model"
+)
+
+func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
+ if preConsumedQuota != 0 {
+ go func(ctx context.Context) {
+ // return pre-consumed quota
+ err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
+ if err != nil {
+ logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
+ }
+ }(ctx)
+ }
+}
+
+func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
+ // quotaDelta is remaining quota to be consumed
+ err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
+ if err != nil {
+ logger.SysError("error consuming token remain quota: " + err.Error())
+ }
+ err = model.CacheUpdateUserQuota(ctx, userId)
+ if err != nil {
+ logger.SysError("error update user quota cache: " + err.Error())
+ }
+ // totalQuota is total quota consumed
+ if totalQuota != 0 {
+ logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
+ model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
+ model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
+ model.UpdateChannelUsedQuota(channelId, totalQuota)
+ }
+ if totalQuota <= 0 {
+ logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
+ }
+}
diff --git a/common/group-ratio.go b/relay/billing/ratio/group.go
similarity index 97%
rename from common/group-ratio.go
rename to relay/billing/ratio/group.go
index 2de6e810..8e9c5b73 100644
--- a/common/group-ratio.go
+++ b/relay/billing/ratio/group.go
@@ -1,4 +1,4 @@
-package common
+package ratio
import (
"encoding/json"
diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go
new file mode 100644
index 00000000..5a29cddc
--- /dev/null
+++ b/relay/billing/ratio/image.go
@@ -0,0 +1,51 @@
+package ratio
+
+var ImageSizeRatios = map[string]map[string]float64{
+ "dall-e-2": {
+ "256x256": 1,
+ "512x512": 1.125,
+ "1024x1024": 1.25,
+ },
+ "dall-e-3": {
+ "1024x1024": 1,
+ "1024x1792": 2,
+ "1792x1024": 2,
+ },
+ "ali-stable-diffusion-xl": {
+ "512x1024": 1,
+ "1024x768": 1,
+ "1024x1024": 1,
+ "576x1024": 1,
+ "1024x576": 1,
+ },
+ "ali-stable-diffusion-v1.5": {
+ "512x1024": 1,
+ "1024x768": 1,
+ "1024x1024": 1,
+ "576x1024": 1,
+ "1024x576": 1,
+ },
+ "wanx-v1": {
+ "1024x1024": 1,
+ "720x1280": 1,
+ "1280x720": 1,
+ },
+}
+
+var ImageGenerationAmounts = map[string][2]int{
+ "dall-e-2": {1, 10},
+ "dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
+ "ali-stable-diffusion-xl": {1, 4}, // Ali
+ "ali-stable-diffusion-v1.5": {1, 4}, // Ali
+ "wanx-v1": {1, 4}, // Ali
+ "cogview-3": {1, 1},
+}
+
+var ImagePromptLengthLimitations = map[string]int{
+ "dall-e-2": 1000,
+ "dall-e-3": 4000,
+ "ali-stable-diffusion-xl": 4000,
+ "ali-stable-diffusion-v1.5": 4000,
+ "wanx-v1": 4000,
+ "cogview-3": 833,
+}
diff --git a/common/model-ratio.go b/relay/billing/ratio/model.go
similarity index 85%
rename from common/model-ratio.go
rename to relay/billing/ratio/model.go
index 4f6acb7b..1f7daef6 100644
--- a/common/model-ratio.go
+++ b/relay/billing/ratio/model.go
@@ -1,4 +1,4 @@
-package common
+package ratio
import (
"encoding/json"
@@ -63,8 +63,8 @@ var ModelRatio = map[string]float64{
"text-search-ada-doc-001": 10,
"text-moderation-stable": 0.1,
"text-moderation-latest": 0.1,
- "dall-e-2": 8, // $0.016 - $0.020 / image
- "dall-e-3": 20, // $0.040 - $0.120 / image
+ "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image
+ "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image
// https://www.anthropic.com/api#pricing
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
@@ -73,14 +73,22 @@ var ModelRatio = map[string]float64{
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
- "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens
- "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens
- "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
- "ERNIE-Bot-8K": 0.024 * RMB,
- "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
- "bge-large-zh": 0.002 * RMB,
- "bge-large-en": 0.002 * RMB,
- "bge-large-8k": 0.002 * RMB,
+ "ERNIE-4.0-8K": 0.120 * RMB,
+ "ERNIE-3.5-8K": 0.012 * RMB,
+ "ERNIE-3.5-8K-0205": 0.024 * RMB,
+ "ERNIE-3.5-8K-1222": 0.012 * RMB,
+ "ERNIE-Bot-8K": 0.024 * RMB,
+ "ERNIE-3.5-4K-0205": 0.012 * RMB,
+ "ERNIE-Speed-8K": 0.004 * RMB,
+ "ERNIE-Speed-128K": 0.004 * RMB,
+ "ERNIE-Lite-8K-0922": 0.008 * RMB,
+ "ERNIE-Lite-8K-0308": 0.003 * RMB,
+ "ERNIE-Tiny-8K": 0.001 * RMB,
+ "BLOOMZ-7B": 0.004 * RMB,
+ "Embedding-V1": 0.002 * RMB,
+ "bge-large-zh": 0.002 * RMB,
+ "bge-large-en": 0.002 * RMB,
+ "tao-8k": 0.002 * RMB,
// https://ai.google.dev/pricing
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@@ -89,18 +97,24 @@ var ModelRatio = map[string]float64{
"gemini-1.0-pro-001": 1,
"gemini-1.5-pro": 1,
// https://open.bigmodel.cn/pricing
- "glm-4": 0.1 * RMB,
- "glm-4v": 0.1 * RMB,
- "glm-3-turbo": 0.005 * RMB,
- "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
- "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
- "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
- "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
- "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
+ "glm-4": 0.1 * RMB,
+ "glm-4v": 0.1 * RMB,
+ "glm-3-turbo": 0.005 * RMB,
+ "embedding-2": 0.0005 * RMB,
+ "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
+ "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens
+ "chatglm_std": 0.3572, // ¥0.005 / 1k tokens
+ "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens
+ "cogview-3": 0.25 * RMB,
+ // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing
+ "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens
"qwen-plus": 1.4286, // ¥0.02 / 1k tokens
"qwen-max": 1.4286, // ¥0.02 / 1k tokens
"qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens
"text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens
+ "ali-stable-diffusion-xl": 8,
+ "ali-stable-diffusion-v1.5": 8,
+ "wanx-v1": 8,
"SparkDesk": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens
"SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens
diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go
deleted file mode 100644
index b312934b..00000000
--- a/relay/channel/ali/adaptor.go
+++ /dev/null
@@ -1,83 +0,0 @@
-package ali
-
-// import (
-// "github.com/Laisky/errors/v2"
-// "fmt"
-// "github.com/gin-gonic/gin"
-// "github.com/songquanpeng/one-api/common"
-// "github.com/songquanpeng/one-api/relay/channel"
-// "github.com/songquanpeng/one-api/relay/constant"
-// "github.com/songquanpeng/one-api/relay/model"
-// "github.com/songquanpeng/one-api/relay/util"
-// "io"
-// "net/http"
-// )
-
-// // https://help.aliyun.com/zh/dashscope/developer-reference/api-details
-
-// type Adaptor struct {
-// }
-
-// func (a *Adaptor) Init(meta *util.RelayMeta) {
-
-// }
-
-// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
-// fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
-// if meta.Mode == constant.RelayModeEmbeddings {
-// fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
-// }
-// return fullRequestURL, nil
-// }
-
-// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-// channel.SetupCommonRequestHeader(c, req, meta)
-// req.Header.Set("Authorization", "Bearer "+meta.APIKey)
-// if meta.IsStream {
-// req.Header.Set("X-DashScope-SSE", "enable")
-// }
-// if c.GetString(common.ConfigKeyPlugin) != "" {
-// req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin))
-// }
-// return nil
-// }
-
-// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
-// if request == nil {
-// return nil, errors.New("request is nil")
-// }
-// switch relayMode {
-// case constant.RelayModeEmbeddings:
-// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
-// return baiduEmbeddingRequest, nil
-// default:
-// baiduRequest := ConvertRequest(*request)
-// return baiduRequest, nil
-// }
-// }
-
-// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-// return channel.DoRequestHelper(a, c, meta, requestBody)
-// }
-
-// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
-// if meta.IsStream {
-// err, usage = StreamHandler(c, resp)
-// } else {
-// switch meta.Mode {
-// case constant.RelayModeEmbeddings:
-// err, usage = EmbeddingHandler(c, resp)
-// default:
-// err, usage = Handler(c, resp)
-// }
-// }
-// return
-// }
-
-// func (a *Adaptor) GetModelList() []string {
-// return ModelList
-// }
-
-// func (a *Adaptor) GetChannelName() string {
-// return "ali"
-// }
diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go
deleted file mode 100644
index d0cd341c..00000000
--- a/relay/channel/ali/model.go
+++ /dev/null
@@ -1,71 +0,0 @@
-package ali
-
-// type Message struct {
-// Content string `json:"content"`
-// Role string `json:"role"`
-// }
-
-// type Input struct {
-// //Prompt string `json:"prompt"`
-// Messages []Message `json:"messages"`
-// }
-
-// type Parameters struct {
-// TopP float64 `json:"top_p,omitempty"`
-// TopK int `json:"top_k,omitempty"`
-// Seed uint64 `json:"seed,omitempty"`
-// EnableSearch bool `json:"enable_search,omitempty"`
-// IncrementalOutput bool `json:"incremental_output,omitempty"`
-// }
-
-// type ChatRequest struct {
-// Model string `json:"model"`
-// Input Input `json:"input"`
-// Parameters Parameters `json:"parameters,omitempty"`
-// }
-
-// type EmbeddingRequest struct {
-// Model string `json:"model"`
-// Input struct {
-// Texts []string `json:"texts"`
-// } `json:"input"`
-// Parameters *struct {
-// TextType string `json:"text_type,omitempty"`
-// } `json:"parameters,omitempty"`
-// }
-
-// type Embedding struct {
-// Embedding []float64 `json:"embedding"`
-// TextIndex int `json:"text_index"`
-// }
-
-// type EmbeddingResponse struct {
-// Output struct {
-// Embeddings []Embedding `json:"embeddings"`
-// } `json:"output"`
-// Usage Usage `json:"usage"`
-// Error
-// }
-
-// type Error struct {
-// Code string `json:"code"`
-// Message string `json:"message"`
-// RequestId string `json:"request_id"`
-// }
-
-// type Usage struct {
-// InputTokens int `json:"input_tokens"`
-// OutputTokens int `json:"output_tokens"`
-// TotalTokens int `json:"total_tokens"`
-// }
-
-// type Output struct {
-// Text string `json:"text"`
-// FinishReason string `json:"finish_reason"`
-// }
-
-// type ChatResponse struct {
-// Output Output `json:"output"`
-// Usage Usage `json:"usage"`
-// Error
-// }
diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go
deleted file mode 100644
index 45a4e901..00000000
--- a/relay/channel/baidu/constants.go
+++ /dev/null
@@ -1,13 +0,0 @@
-package baidu
-
-var ModelList = []string{
- "ERNIE-Bot-4",
- "ERNIE-Bot-8K",
- "ERNIE-Bot",
- "ERNIE-Speed",
- "ERNIE-Bot-turbo",
- "Embedding-V1",
- "bge-large-zh",
- "bge-large-en",
- "tao-8k",
-}
diff --git a/relay/channel/interface.go b/relay/channel/interface.go
deleted file mode 100644
index e25db83f..00000000
--- a/relay/channel/interface.go
+++ /dev/null
@@ -1,20 +0,0 @@
-package channel
-
-import (
- "github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
- "io"
- "net/http"
-)
-
-type Adaptor interface {
- Init(meta *util.RelayMeta)
- GetRequestURL(meta *util.RelayMeta) (string, error)
- SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error
- ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
- DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error)
- DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode)
- GetModelList() []string
- GetChannelName() string
-}
diff --git a/relay/channel/minimax/main.go b/relay/channel/minimax/main.go
deleted file mode 100644
index a3cd0f14..00000000
--- a/relay/channel/minimax/main.go
+++ /dev/null
@@ -1,16 +0,0 @@
-package minimax
-
-import (
- "fmt"
-
- "github.com/Laisky/errors/v2"
- "github.com/songquanpeng/one-api/relay/constant"
- "github.com/songquanpeng/one-api/relay/util"
-)
-
-func GetRequestURL(meta *util.RelayMeta) (string, error) {
- if meta.Mode == constant.RelayModeChatCompletions {
- return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
- }
- return "", errors.Errorf("unsupported relay mode %d for minimax", meta.Mode)
-}
diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go
deleted file mode 100644
index e4951a34..00000000
--- a/relay/channel/openai/compatible.go
+++ /dev/null
@@ -1,46 +0,0 @@
-package openai
-
-import (
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/relay/channel/ai360"
- "github.com/songquanpeng/one-api/relay/channel/baichuan"
- "github.com/songquanpeng/one-api/relay/channel/groq"
- "github.com/songquanpeng/one-api/relay/channel/lingyiwanwu"
- "github.com/songquanpeng/one-api/relay/channel/minimax"
- "github.com/songquanpeng/one-api/relay/channel/mistral"
- "github.com/songquanpeng/one-api/relay/channel/moonshot"
-)
-
-var CompatibleChannels = []int{
- common.ChannelTypeAzure,
- common.ChannelType360,
- common.ChannelTypeMoonshot,
- common.ChannelTypeBaichuan,
- common.ChannelTypeMinimax,
- common.ChannelTypeMistral,
- common.ChannelTypeGroq,
- common.ChannelTypeLingYiWanWu,
-}
-
-func GetCompatibleChannelMeta(channelType int) (string, []string) {
- switch channelType {
- case common.ChannelTypeAzure:
- return "azure", ModelList
- case common.ChannelType360:
- return "360", ai360.ModelList
- case common.ChannelTypeMoonshot:
- return "moonshot", moonshot.ModelList
- case common.ChannelTypeBaichuan:
- return "baichuan", baichuan.ModelList
- case common.ChannelTypeMinimax:
- return "minimax", minimax.ModelList
- case common.ChannelTypeMistral:
- return "mistralai", mistral.ModelList
- case common.ChannelTypeGroq:
- return "groq", groq.ModelList
- case common.ChannelTypeLingYiWanWu:
- return "lingyiwanwu", lingyiwanwu.ModelList
- default:
- return "openai", ModelList
- }
-}
diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go
deleted file mode 100644
index 9bca8cab..00000000
--- a/relay/channel/openai/helper.go
+++ /dev/null
@@ -1,11 +0,0 @@
-package openai
-
-import "github.com/songquanpeng/one-api/relay/model"
-
-func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
- usage := &model.Usage{}
- usage.PromptTokens = promptTokens
- usage.CompletionTokens = CountTokenText(responseText, modeName)
- usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
- return usage
-}
diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go
deleted file mode 100644
index 9c7bdd36..00000000
--- a/relay/channel/zhipu/adaptor.go
+++ /dev/null
@@ -1,62 +0,0 @@
-package zhipu
-
-// import (
-// "github.com/Laisky/errors/v2"
-// "fmt"
-// "github.com/gin-gonic/gin"
-// "github.com/songquanpeng/one-api/relay/channel"
-// "github.com/songquanpeng/one-api/relay/model"
-// "github.com/songquanpeng/one-api/relay/util"
-// "io"
-// "net/http"
-// )
-
-// type Adaptor struct {
-// }
-
-// func (a *Adaptor) Init(meta *util.RelayMeta) {
-
-// }
-
-// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
-// method := "invoke"
-// if meta.IsStream {
-// method = "sse-invoke"
-// }
-// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
-// }
-
-// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-// channel.SetupCommonRequestHeader(c, req, meta)
-// token := GetToken(meta.APIKey)
-// req.Header.Set("Authorization", token)
-// return nil
-// }
-
-// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
-// if request == nil {
-// return nil, errors.New("request is nil")
-// }
-// return ConvertRequest(*request), nil
-// }
-
-// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-// return channel.DoRequestHelper(a, c, meta, requestBody)
-// }
-
-// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
-// if meta.IsStream {
-// err, usage = StreamHandler(c, resp)
-// } else {
-// err, usage = Handler(c, resp)
-// }
-// return
-// }
-
-// func (a *Adaptor) GetModelList() []string {
-// return ModelList
-// }
-
-// func (a *Adaptor) GetChannelName() string {
-// return "zhipu"
-// }
diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go
new file mode 100644
index 00000000..80027a80
--- /dev/null
+++ b/relay/channeltype/define.go
@@ -0,0 +1,39 @@
+package channeltype
+
+const (
+ Unknown = iota
+ OpenAI
+ API2D
+ Azure
+ CloseAI
+ OpenAISB
+ OpenAIMax
+ OhMyGPT
+ Custom
+ Ails
+ AIProxy
+ PaLM
+ API2GPT
+ AIGC2D
+ Anthropic
+ Baidu
+ Zhipu
+ Ali
+ Xunfei
+ AI360
+ OpenRouter
+ AIProxyLibrary
+ FastGPT
+ Tencent
+ Gemini
+ Moonshot
+ Baichuan
+ Minimax
+ Mistral
+ Groq
+ Ollama
+ LingYiWanWu
+ StepFun
+
+ Dummy
+)
diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go
new file mode 100644
index 00000000..01c2918c
--- /dev/null
+++ b/relay/channeltype/helper.go
@@ -0,0 +1,30 @@
+package channeltype
+
+import "github.com/songquanpeng/one-api/relay/apitype"
+
+func ToAPIType(channelType int) int {
+ apiType := apitype.OpenAI
+ switch channelType {
+ case Anthropic:
+ apiType = apitype.Anthropic
+ case Baidu:
+ apiType = apitype.Baidu
+ case PaLM:
+ apiType = apitype.PaLM
+ case Zhipu:
+ apiType = apitype.Zhipu
+ case Ali:
+ apiType = apitype.Ali
+ case Xunfei:
+ apiType = apitype.Xunfei
+ case AIProxyLibrary:
+ apiType = apitype.AIProxyLibrary
+ case Tencent:
+ apiType = apitype.Tencent
+ case Gemini:
+ apiType = apitype.Gemini
+ case Ollama:
+ apiType = apitype.Ollama
+ }
+ return apiType
+}
diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go
new file mode 100644
index 00000000..eec59116
--- /dev/null
+++ b/relay/channeltype/url.go
@@ -0,0 +1,43 @@
+package channeltype
+
+var ChannelBaseURLs = []string{
+ "", // 0
+ "https://api.openai.com", // 1
+ "https://oa.api2d.net", // 2
+ "", // 3
+ "https://api.closeai-proxy.xyz", // 4
+ "https://api.openai-sb.com", // 5
+ "https://api.openaimax.com", // 6
+ "https://api.ohmygpt.com", // 7
+ "", // 8
+ "https://api.caipacity.com", // 9
+ "https://api.aiproxy.io", // 10
+ "https://generativelanguage.googleapis.com", // 11
+ "https://api.api2gpt.com", // 12
+ "https://api.aigc2d.com", // 13
+ "https://api.anthropic.com", // 14
+ "https://aip.baidubce.com", // 15
+ "https://open.bigmodel.cn", // 16
+ "https://dashscope.aliyuncs.com", // 17
+ "", // 18
+ "https://ai.360.cn", // 19
+ "https://openrouter.ai/api", // 20
+ "https://api.aiproxy.io", // 21
+ "https://fastgpt.run/api/openapi", // 22
+ "https://hunyuan.cloud.tencent.com", // 23
+ "https://generativelanguage.googleapis.com", // 24
+ "https://api.moonshot.cn", // 25
+ "https://api.baichuan-ai.com", // 26
+ "https://api.minimax.chat", // 27
+ "https://api.mistral.ai", // 28
+ "https://api.groq.com/openai", // 29
+ "http://localhost:11434", // 30
+ "https://api.lingyiwanwu.com", // 31
+ "https://api.stepfun.com", // 32
+}
+
+func init() {
+ if len(ChannelBaseURLs) != Dummy {
+ panic("channel base urls length not match")
+ }
+}
diff --git a/relay/util/init.go b/relay/client/init.go
similarity index 97%
rename from relay/util/init.go
rename to relay/client/init.go
index b303de0c..73108700 100644
--- a/relay/util/init.go
+++ b/relay/client/init.go
@@ -1,4 +1,4 @@
-package util
+package client
import (
"net/http"
diff --git a/relay/constant/api_type.go b/relay/constant/api_type.go
deleted file mode 100644
index b249f6a2..00000000
--- a/relay/constant/api_type.go
+++ /dev/null
@@ -1,48 +0,0 @@
-package constant
-
-import (
- "github.com/songquanpeng/one-api/common"
-)
-
-const (
- APITypeOpenAI = iota
- APITypeAnthropic
- APITypePaLM
- APITypeBaidu
- APITypeZhipu
- APITypeAli
- APITypeXunfei
- APITypeAIProxyLibrary
- APITypeTencent
- APITypeGemini
- APITypeOllama
-
- APITypeDummy // this one is only for count, do not add any channel after this
-)
-
-func ChannelType2APIType(channelType int) int {
- apiType := APITypeOpenAI
- switch channelType {
- case common.ChannelTypeAnthropic:
- apiType = APITypeAnthropic
- case common.ChannelTypeBaidu:
- apiType = APITypeBaidu
- case common.ChannelTypePaLM:
- apiType = APITypePaLM
- case common.ChannelTypeZhipu:
- apiType = APITypeZhipu
- case common.ChannelTypeAli:
- apiType = APITypeAli
- case common.ChannelTypeXunfei:
- apiType = APITypeXunfei
- case common.ChannelTypeAIProxyLibrary:
- apiType = APITypeAIProxyLibrary
- case common.ChannelTypeTencent:
- apiType = APITypeTencent
- case common.ChannelTypeGemini:
- apiType = APITypeGemini
- case common.ChannelTypeOllama:
- apiType = APITypeOllama
- }
- return apiType
-}
diff --git a/relay/constant/image.go b/relay/constant/image.go
deleted file mode 100644
index 5e04895f..00000000
--- a/relay/constant/image.go
+++ /dev/null
@@ -1,24 +0,0 @@
-package constant
-
-var DalleSizeRatios = map[string]map[string]float64{
- "dall-e-2": {
- "256x256": 1,
- "512x512": 1.125,
- "1024x1024": 1.25,
- },
- "dall-e-3": {
- "1024x1024": 1,
- "1024x1792": 2,
- "1792x1024": 2,
- },
-}
-
-var DalleGenerationImageAmounts = map[string][2]int{
- "dall-e-2": {1, 10},
- "dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
-}
-
-var DalleImagePromptLengthLimitations = map[string]int{
- "dall-e-2": 1000,
- "dall-e-3": 4000,
-}
diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go
deleted file mode 100644
index 5e2fe574..00000000
--- a/relay/constant/relay_mode.go
+++ /dev/null
@@ -1,42 +0,0 @@
-package constant
-
-import "strings"
-
-const (
- RelayModeUnknown = iota
- RelayModeChatCompletions
- RelayModeCompletions
- RelayModeEmbeddings
- RelayModeModerations
- RelayModeImagesGenerations
- RelayModeEdits
- RelayModeAudioSpeech
- RelayModeAudioTranscription
- RelayModeAudioTranslation
-)
-
-func Path2RelayMode(path string) int {
- relayMode := RelayModeUnknown
- if strings.HasPrefix(path, "/v1/chat/completions") {
- relayMode = RelayModeChatCompletions
- } else if strings.HasPrefix(path, "/v1/completions") {
- relayMode = RelayModeCompletions
- } else if strings.HasPrefix(path, "/v1/embeddings") {
- relayMode = RelayModeEmbeddings
- } else if strings.HasSuffix(path, "embeddings") {
- relayMode = RelayModeEmbeddings
- } else if strings.HasPrefix(path, "/v1/moderations") {
- relayMode = RelayModeModerations
- } else if strings.HasPrefix(path, "/v1/images/generations") {
- relayMode = RelayModeImagesGenerations
- } else if strings.HasPrefix(path, "/v1/edits") {
- relayMode = RelayModeEdits
- } else if strings.HasPrefix(path, "/v1/audio/speech") {
- relayMode = RelayModeAudioSpeech
- } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
- relayMode = RelayModeAudioTranscription
- } else if strings.HasPrefix(path, "/v1/audio/translations") {
- relayMode = RelayModeAudioTranslation
- }
- return relayMode
-}
diff --git a/relay/controller/audio.go b/relay/controller/audio.go
index 85599b1f..58094c22 100644
--- a/relay/controller/audio.go
+++ b/relay/controller/audio.go
@@ -6,20 +6,23 @@ import (
"context"
"encoding/json"
"fmt"
- "io"
- "net/http"
- "strings"
-
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/relay/channel/openai"
- "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/adaptor/azure"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/billing"
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/client"
relaymodel "github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
+ "github.com/songquanpeng/one-api/relay/relaymode"
+ "io"
+ "net/http"
+ "strings"
)
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
@@ -34,7 +37,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
tokenName := c.GetString("token_name")
var ttsRequest openai.TextToSpeechRequest
- if relayMode == constant.RelayModeAudioSpeech {
+ if relayMode == relaymode.AudioSpeech {
// Read JSON
err := common.UnmarshalBodyReusable(c, &ttsRequest)
// Check if JSON is valid
@@ -48,14 +51,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
- modelRatio := common.GetModelRatio(audioModel)
- // groupRatio := common.GetGroupRatio(group)
- groupRatio := c.GetFloat64("channel_ratio")
+ modelRatio := billingratio.GetModelRatio(audioModel)
+ // groupRatio := billingratio.GetGroupRatio(group)
+ groupRatio := c.GetFloat64("channel_ratio") // get minimal ratio from multiple groups
+
ratio := modelRatio * groupRatio
var quota int64
var preConsumedQuota int64
switch relayMode {
- case constant.RelayModeAudioSpeech:
+ case relaymode.AudioSpeech:
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
@@ -117,19 +121,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}
- baseURL := common.ChannelBaseURLs[channelType]
+ baseURL := channeltype.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
- fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType)
- if channelType == common.ChannelTypeAzure {
- apiVersion := util.GetAzureAPIVersion(c)
- if relayMode == constant.RelayModeAudioTranscription {
+ fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType)
+ if channelType == channeltype.Azure {
+ apiVersion := azure.GetAPIVersion(c)
+ if relayMode == relaymode.AudioTranscription {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion)
- } else if relayMode == constant.RelayModeAudioSpeech {
+ } else if relayMode == relaymode.AudioSpeech {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api
fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion)
}
@@ -148,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
- if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure {
+ if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure {
// https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
@@ -160,7 +164,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
- resp, err := util.HTTPClient.Do(req)
+ resp, err := client.HTTPClient.Do(req)
if err != nil {
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
@@ -174,7 +178,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
}
- if relayMode != constant.RelayModeAudioSpeech {
+ if relayMode != relaymode.AudioSpeech {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
@@ -213,12 +217,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {
- return util.RelayErrorHandler(resp)
+ return RelayErrorHandler(resp)
}
succeed = true
quotaDelta := quota - preConsumedQuota
defer func(ctx context.Context) {
- go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
+ go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName)
}(c.Request.Context())
for k, v := range resp.Header {
diff --git a/relay/controller/error.go b/relay/controller/error.go
new file mode 100644
index 00000000..69ece3ec
--- /dev/null
+++ b/relay/controller/error.go
@@ -0,0 +1,91 @@
+package controller
+
+import (
+ "encoding/json"
+ "fmt"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/common/logger"
+ "github.com/songquanpeng/one-api/relay/model"
+ "io"
+ "net/http"
+ "strconv"
+)
+
+type GeneralErrorResponse struct {
+ Error model.Error `json:"error"`
+ Message string `json:"message"`
+ Msg string `json:"msg"`
+ Err string `json:"err"`
+ ErrorMsg string `json:"error_msg"`
+ Header struct {
+ Message string `json:"message"`
+ } `json:"header"`
+ Response struct {
+ Error struct {
+ Message string `json:"message"`
+ } `json:"error"`
+ } `json:"response"`
+}
+
+func (e GeneralErrorResponse) ToMessage() string {
+ if e.Error.Message != "" {
+ return e.Error.Message
+ }
+ if e.Message != "" {
+ return e.Message
+ }
+ if e.Msg != "" {
+ return e.Msg
+ }
+ if e.Err != "" {
+ return e.Err
+ }
+ if e.ErrorMsg != "" {
+ return e.ErrorMsg
+ }
+ if e.Header.Message != "" {
+ return e.Header.Message
+ }
+ if e.Response.Error.Message != "" {
+ return e.Response.Error.Message
+ }
+ return ""
+}
+
+func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *model.ErrorWithStatusCode) {
+ ErrorWithStatusCode = &model.ErrorWithStatusCode{
+ StatusCode: resp.StatusCode,
+ Error: model.Error{
+ Message: "",
+ Type: "upstream_error",
+ Code: "bad_response_status_code",
+ Param: strconv.Itoa(resp.StatusCode),
+ },
+ }
+ responseBody, err := io.ReadAll(resp.Body)
+ if err != nil {
+ return
+ }
+ if config.DebugEnabled {
+ logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
+ }
+ err = resp.Body.Close()
+ if err != nil {
+ return
+ }
+ var errResponse GeneralErrorResponse
+ err = json.Unmarshal(responseBody, &errResponse)
+ if err != nil {
+ return
+ }
+ if errResponse.Error.Message != "" {
+ // OpenAI format error, so we override the default one
+ ErrorWithStatusCode.Error = errResponse.Error
+ } else {
+ ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
+ }
+ if ErrorWithStatusCode.Error.Message == "" {
+ ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
+ }
+ return
+}
diff --git a/relay/controller/helper.go b/relay/controller/helper.go
index 71dd653e..e07aba89 100644
--- a/relay/controller/helper.go
+++ b/relay/controller/helper.go
@@ -9,10 +9,13 @@ import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/relay/channel/openai"
- "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/controller/validator"
+ "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
+ "github.com/songquanpeng/one-api/relay/relaymode"
"math"
"net/http"
)
@@ -23,21 +26,21 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener
if err != nil {
return nil, err
}
- if relayMode == constant.RelayModeModerations && textRequest.Model == "" {
+ if relayMode == relaymode.Moderations && textRequest.Model == "" {
textRequest.Model = "text-moderation-latest"
}
- if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" {
+ if relayMode == relaymode.Embeddings && textRequest.Model == "" {
textRequest.Model = c.Param("model")
}
- err = util.ValidateTextRequest(textRequest, relayMode)
+ err = validator.ValidateTextRequest(textRequest, relayMode)
if err != nil {
return nil, err
}
return textRequest, nil
}
-func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) {
- imageRequest := &openai.ImageRequest{}
+func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) {
+ imageRequest := &relaymodel.ImageRequest{}
err := common.UnmarshalBodyReusable(c, imageRequest)
if err != nil {
return nil, err
@@ -54,9 +57,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error
return imageRequest, nil
}
-func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode {
+func isValidImageSize(model string, size string) bool {
+ if model == "cogview-3" {
+ return true
+ }
+ _, ok := billingratio.ImageSizeRatios[model][size]
+ return ok
+}
+
+func getImageSizeRatio(model string, size string) float64 {
+ ratio, ok := billingratio.ImageSizeRatios[model][size]
+ if !ok {
+ return 1
+ }
+ return ratio
+}
+
+func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode {
// model validation
- _, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
+ hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size)
if !hasValidSize {
return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest)
}
@@ -64,27 +83,24 @@ func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMet
if imageRequest.Prompt == "" {
return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest)
}
- if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] {
+ if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] {
return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest)
}
// Number of generated images validation
if !isWithinRange(imageRequest.Model, imageRequest.N) {
// channel not azure
- if meta.ChannelType != common.ChannelTypeAzure {
+ if meta.ChannelType != channeltype.Azure {
return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest)
}
}
return nil
}
-func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
+func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) {
if imageRequest == nil {
return 0, errors.New("imageRequest is nil")
}
- imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size]
- if !hasValidSize {
- return 0, errors.Errorf("size not supported for this image model: %s", imageRequest.Size)
- }
+ imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size)
if imageRequest.Quality == "hd" && imageRequest.Model == "dall-e-3" {
if imageRequest.Size == "1024x1024" {
imageCostRatio *= 2
@@ -97,11 +113,11 @@ func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) {
func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int {
switch relayMode {
- case constant.RelayModeChatCompletions:
+ case relaymode.ChatCompletions:
return openai.CountTokenMessages(textRequest.Messages, textRequest.Model)
- case constant.RelayModeCompletions:
+ case relaymode.Completions:
return openai.CountTokenInput(textRequest.Prompt, textRequest.Model)
- case constant.RelayModeModerations:
+ case relaymode.Moderations:
return openai.CountTokenInput(textRequest.Input, textRequest.Model)
}
return 0
@@ -115,7 +131,7 @@ func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTok
return int64(float64(preConsumedTokens) * ratio)
}
-func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) {
+func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) {
preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
@@ -144,13 +160,13 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil
}
-func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
+func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected")
return
}
var quota int64
- completionRatio := common.GetCompletionRatio(textRequest.Model)
+ completionRatio := billingratio.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
@@ -178,3 +194,14 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota)
model.UpdateChannelUsedQuota(meta.ChannelId, quota)
}
+
+func getMappedModelName(modelName string, mapping map[string]string) (string, bool) {
+ if mapping == nil {
+ return modelName, false
+ }
+ mappedModelName := mapping[modelName]
+ if mappedModelName != "" {
+ return mappedModelName, true
+ }
+ return modelName, false
+}
diff --git a/relay/controller/image.go b/relay/controller/image.go
index d88dc271..ea3e32a0 100644
--- a/relay/controller/image.go
+++ b/relay/controller/image.go
@@ -5,35 +5,33 @@ import (
"context"
"encoding/json"
"fmt"
- "github.com/Laisky/errors/v2"
"io"
"net/http"
- "strings"
- "github.com/songquanpeng/one-api/common"
+ "github.com/Laisky/errors/v2"
+ "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
- "github.com/songquanpeng/one-api/relay/channel/openai"
- "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
-
- "github.com/gin-gonic/gin"
)
func isWithinRange(element string, value int) bool {
- if _, ok := constant.DalleGenerationImageAmounts[element]; !ok {
+ if _, ok := billingratio.ImageGenerationAmounts[element]; !ok {
return false
}
- min := constant.DalleGenerationImageAmounts[element][0]
- max := constant.DalleGenerationImageAmounts[element][1]
-
+ min := billingratio.ImageGenerationAmounts[element][0]
+ max := billingratio.ImageGenerationAmounts[element][1]
return value >= min && value <= max
}
func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
- meta := util.GetRelayMeta(c)
+ meta := meta.GetByContext(c)
imageRequest, err := getImageRequest(c, meta.Mode)
if err != nil {
logger.Errorf(ctx, "getImageRequest failed: %s", err.Error())
@@ -43,7 +41,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
// map model name
var isModelMapped bool
meta.OriginModelName = imageRequest.Model
- imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping)
+ imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping)
meta.ActualModelName = imageRequest.Model
// model validation
@@ -57,17 +55,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError)
}
- requestURL := c.Request.URL.String()
- fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType)
- if meta.ChannelType == common.ChannelTypeAzure {
- // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
- apiVersion := util.GetAzureAPIVersion(c)
- // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
- fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion)
- }
-
var requestBody io.Reader
- if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body
+ if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body
jsonStr, err := json.Marshal(imageRequest)
if err != nil {
return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
@@ -77,9 +66,32 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
requestBody = c.Request.Body
}
- modelRatio := common.GetModelRatio(imageRequest.Model)
- // groupRatio := common.GetGroupRatio(meta.Group)
+ adaptor := relay.GetAdaptor(meta.APIType)
+ if adaptor == nil {
+ return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
+ }
+
+ switch meta.ChannelType {
+ case channeltype.Ali:
+ fallthrough
+ case channeltype.Baidu:
+ fallthrough
+ case channeltype.Zhipu:
+ finalRequest, err := adaptor.ConvertImageRequest(imageRequest)
+ if err != nil {
+ return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError)
+ }
+ jsonStr, err := json.Marshal(finalRequest)
+ if err != nil {
+ return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError)
+ }
+ requestBody = bytes.NewBuffer(jsonStr)
+ }
+
+ modelRatio := billingratio.GetModelRatio(imageRequest.Model)
+ // groupRatio := billingratio.GetGroupRatio(meta.Group)
groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio
+
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
@@ -89,36 +101,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
}
- req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
- if err != nil {
- return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
- }
- token := c.Request.Header.Get("Authorization")
- if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication
- token = strings.TrimPrefix(token, "Bearer ")
- req.Header.Set("api-key", token)
- } else {
- req.Header.Set("Authorization", token)
- }
-
- req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
- req.Header.Set("Accept", c.Request.Header.Get("Accept"))
-
- resp, err := util.HTTPClient.Do(req)
+ // do request
+ resp, err := adaptor.DoRequest(c, meta, requestBody)
if err != nil {
+ logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
}
- err = req.Body.Close()
- if err != nil {
- return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
- }
- err = c.Request.Body.Close()
- if err != nil {
- return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError)
- }
- var imageResponse openai.ImageResponse
-
defer func(ctx context.Context) {
if resp.StatusCode != http.StatusOK {
return
@@ -141,34 +130,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
}
}(c.Request.Context())
- responseBody, err := io.ReadAll(resp.Body)
-
- if err != nil {
- return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
- }
- err = resp.Body.Close()
- if err != nil {
- return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
- }
- err = json.Unmarshal(responseBody, &imageResponse)
- if err != nil {
- return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+ // do response
+ _, respErr := adaptor.DoResponse(c, resp, meta)
+ if respErr != nil {
+ logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
+ return respErr
}
- resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
-
- for k, v := range resp.Header {
- c.Writer.Header().Set(k, v[0])
- }
- c.Writer.WriteHeader(resp.StatusCode)
-
- _, err = io.Copy(c.Writer, resp.Body)
- if err != nil {
- return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
- }
- err = resp.Body.Close()
- if err != nil {
- return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
- }
return nil
}
diff --git a/relay/controller/text.go b/relay/controller/text.go
index 282e8f25..beda2822 100644
--- a/relay/controller/text.go
+++ b/relay/controller/text.go
@@ -10,18 +10,20 @@ import (
"github.com/Laisky/errors/v2"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
- "github.com/songquanpeng/one-api/relay/constant"
- "github.com/songquanpeng/one-api/relay/helper"
+ "github.com/songquanpeng/one-api/relay"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/apitype"
+ "github.com/songquanpeng/one-api/relay/billing"
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
)
func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
ctx := c.Request.Context()
- meta := util.GetRelayMeta(c)
+ meta := meta.GetByContext(c)
// get & validate textRequest
textRequest, err := getAndValidateTextRequest(c, meta.Mode)
if err != nil {
@@ -33,12 +35,13 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
// map model name
var isModelMapped bool
meta.OriginModelName = textRequest.Model
- textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping)
+ textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping)
meta.ActualModelName = textRequest.Model
// get model ratio & group ratio
- modelRatio := common.GetModelRatio(textRequest.Model)
- // groupRatio := common.GetGroupRatio(meta.Group)
+ modelRatio := billingratio.GetModelRatio(textRequest.Model)
+ // groupRatio := billingratio.GetGroupRatio(meta.Group)
groupRatio := meta.ChannelRatio
+
ratio := modelRatio * groupRatio
// pre-consume quota
promptTokens := getPromptTokens(textRequest, meta.Mode)
@@ -49,16 +52,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
return bizErr
}
- adaptor := helper.GetAdaptor(meta.APIType)
+ adaptor := relay.GetAdaptor(meta.APIType)
if adaptor == nil {
return openai.ErrorWrapper(errors.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
}
// get request body
var requestBody io.Reader
- if meta.APIType == constant.APITypeOpenAI {
+ if meta.APIType == apitype.OpenAI {
// no need to convert request for openai
- shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan
+ shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan
if shouldResetRequestBody {
jsonStr, err := json.Marshal(textRequest)
if err != nil {
@@ -93,10 +96,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
}
errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json")
if errorHappened {
- util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+ billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
logger.Error(ctx, fmt.Sprintf("relay text [%d] <- %q %q",
resp.StatusCode, resp.Request.URL.String(), string(requestBodyBytes)))
- return util.RelayErrorHandler(resp)
+ return RelayErrorHandler(resp)
}
meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream")
@@ -104,7 +107,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
usage, respErr := adaptor.DoResponse(c, resp, meta)
if respErr != nil {
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
- util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
+ billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId)
return respErr
}
// post-consume quota
diff --git a/relay/util/validation.go b/relay/controller/validator/validation.go
similarity index 76%
rename from relay/util/validation.go
rename to relay/controller/validator/validation.go
index b9d25c2a..3aab6ac8 100644
--- a/relay/util/validation.go
+++ b/relay/controller/validator/validation.go
@@ -1,10 +1,11 @@
-package util
+package validator
import (
- "github.com/Laisky/errors/v2"
- "github.com/songquanpeng/one-api/relay/constant"
- "github.com/songquanpeng/one-api/relay/model"
"math"
+
+ "github.com/Laisky/errors/v2"
+ "github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/relaymode"
)
func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error {
@@ -15,20 +16,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int)
return errors.New("model is required")
}
switch relayMode {
- case constant.RelayModeCompletions:
+ case relaymode.Completions:
if textRequest.Prompt == "" {
return errors.New("field prompt is required")
}
- case constant.RelayModeChatCompletions:
+ case relaymode.ChatCompletions:
if textRequest.Messages == nil || len(textRequest.Messages) == 0 {
return errors.New("field messages is required")
}
- case constant.RelayModeEmbeddings:
- case constant.RelayModeModerations:
+ case relaymode.Embeddings:
+ case relaymode.Moderations:
if textRequest.Input == "" {
return errors.New("field input is required")
}
- case constant.RelayModeEdits:
+ case relaymode.Edits:
if textRequest.Instruction == "" {
return errors.New("field instruction is required")
}
diff --git a/relay/helper/main.go b/relay/helper/main.go
deleted file mode 100644
index 18bbe51a..00000000
--- a/relay/helper/main.go
+++ /dev/null
@@ -1,40 +0,0 @@
-package helper
-
-import (
- "github.com/songquanpeng/one-api/relay/channel"
- "github.com/songquanpeng/one-api/relay/channel/aiproxy"
- "github.com/songquanpeng/one-api/relay/channel/anthropic"
- "github.com/songquanpeng/one-api/relay/channel/gemini"
- "github.com/songquanpeng/one-api/relay/channel/ollama"
- "github.com/songquanpeng/one-api/relay/channel/openai"
- "github.com/songquanpeng/one-api/relay/channel/palm"
- "github.com/songquanpeng/one-api/relay/constant"
-)
-
-func GetAdaptor(apiType int) channel.Adaptor {
- switch apiType {
- case constant.APITypeAIProxyLibrary:
- return &aiproxy.Adaptor{}
- // case constant.APITypeAli:
- // return &ali.Adaptor{}
- case constant.APITypeAnthropic:
- return &anthropic.Adaptor{}
- // case constant.APITypeBaidu:
- // return &baidu.Adaptor{}
- case constant.APITypeGemini:
- return &gemini.Adaptor{}
- case constant.APITypeOpenAI:
- return &openai.Adaptor{}
- case constant.APITypePaLM:
- return &palm.Adaptor{}
- // case constant.APITypeTencent:
- // return &tencent.Adaptor{}
- // case constant.APITypeXunfei:
- // return &xunfei.Adaptor{}
- // case constant.APITypeZhipu:
- // return &zhipu.Adaptor{}
- case constant.APITypeOllama:
- return &ollama.Adaptor{}
- }
- return nil
-}
diff --git a/relay/util/relay_meta.go b/relay/meta/relay_meta.go
similarity index 64%
rename from relay/util/relay_meta.go
rename to relay/meta/relay_meta.go
index 17135816..a17aa0f0 100644
--- a/relay/util/relay_meta.go
+++ b/relay/meta/relay_meta.go
@@ -1,13 +1,15 @@
-package util
+package meta
import (
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/common/config"
+ "github.com/songquanpeng/one-api/relay/adaptor/azure"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/relaymode"
"strings"
)
-type RelayMeta struct {
+type Meta struct {
Mode int
ChannelType int
ChannelId int
@@ -29,9 +31,9 @@ type RelayMeta struct {
ChannelRatio float64
}
-func GetRelayMeta(c *gin.Context) *RelayMeta {
- meta := RelayMeta{
- Mode: constant.Path2RelayMode(c.Request.URL.Path),
+func GetByContext(c *gin.Context) *Meta {
+ meta := Meta{
+ Mode: relaymode.GetByPath(c.Request.URL.Path),
ChannelType: c.GetInt("channel"),
ChannelId: c.GetInt("channel_id"),
TokenId: c.GetInt("token_id"),
@@ -40,18 +42,18 @@ func GetRelayMeta(c *gin.Context) *RelayMeta {
Group: c.GetString("group"),
ModelMapping: c.GetStringMapString("model_mapping"),
BaseURL: c.GetString("base_url"),
- APIVersion: c.GetString(common.ConfigKeyAPIVersion),
+ APIVersion: c.GetString(config.KeyAPIVersion),
APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "),
Config: nil,
RequestURLPath: c.Request.URL.String(),
ChannelRatio: c.GetFloat64("channel_ratio"),
}
- if meta.ChannelType == common.ChannelTypeAzure {
- meta.APIVersion = GetAzureAPIVersion(c)
+ if meta.ChannelType == channeltype.Azure {
+ meta.APIVersion = azure.GetAPIVersion(c)
}
if meta.BaseURL == "" {
- meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType]
+ meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType]
}
- meta.APIType = constant.ChannelType2APIType(meta.ChannelType)
+ meta.APIType = channeltype.ToAPIType(meta.ChannelType)
return &meta
}
diff --git a/relay/model/image.go b/relay/model/image.go
new file mode 100644
index 00000000..bab84256
--- /dev/null
+++ b/relay/model/image.go
@@ -0,0 +1,12 @@
+package model
+
+type ImageRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt" binding:"required"`
+ N int `json:"n,omitempty"`
+ Size string `json:"size,omitempty"`
+ Quality string `json:"quality,omitempty"`
+ ResponseFormat string `json:"response_format,omitempty"`
+ Style string `json:"style,omitempty"`
+ User string `json:"user,omitempty"`
+}
diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go
new file mode 100644
index 00000000..96d09438
--- /dev/null
+++ b/relay/relaymode/define.go
@@ -0,0 +1,14 @@
+package relaymode
+
+const (
+ Unknown = iota
+ ChatCompletions
+ Completions
+ Embeddings
+ Moderations
+ ImagesGenerations
+ Edits
+ AudioSpeech
+ AudioTranscription
+ AudioTranslation
+)
diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go
new file mode 100644
index 00000000..926dd42e
--- /dev/null
+++ b/relay/relaymode/helper.go
@@ -0,0 +1,29 @@
+package relaymode
+
+import "strings"
+
+func GetByPath(path string) int {
+ relayMode := Unknown
+ if strings.HasPrefix(path, "/v1/chat/completions") {
+ relayMode = ChatCompletions
+ } else if strings.HasPrefix(path, "/v1/completions") {
+ relayMode = Completions
+ } else if strings.HasPrefix(path, "/v1/embeddings") {
+ relayMode = Embeddings
+ } else if strings.HasSuffix(path, "embeddings") {
+ relayMode = Embeddings
+ } else if strings.HasPrefix(path, "/v1/moderations") {
+ relayMode = Moderations
+ } else if strings.HasPrefix(path, "/v1/images/generations") {
+ relayMode = ImagesGenerations
+ } else if strings.HasPrefix(path, "/v1/edits") {
+ relayMode = Edits
+ } else if strings.HasPrefix(path, "/v1/audio/speech") {
+ relayMode = AudioSpeech
+ } else if strings.HasPrefix(path, "/v1/audio/transcriptions") {
+ relayMode = AudioTranscription
+ } else if strings.HasPrefix(path, "/v1/audio/translations") {
+ relayMode = AudioTranslation
+ }
+ return relayMode
+}
diff --git a/relay/util/billing.go b/relay/util/billing.go
deleted file mode 100644
index 495d011e..00000000
--- a/relay/util/billing.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package util
-
-import (
- "context"
- "github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/model"
-)
-
-func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
- if preConsumedQuota != 0 {
- go func(ctx context.Context) {
- // return pre-consumed quota
- err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
- if err != nil {
- logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
- }
- }(ctx)
- }
-}
diff --git a/relay/util/common.go b/relay/util/common.go
deleted file mode 100644
index 518d0b00..00000000
--- a/relay/util/common.go
+++ /dev/null
@@ -1,188 +0,0 @@
-package util
-
-import (
- "context"
- "encoding/json"
- "fmt"
- "github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/common/config"
- "github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/model"
- relaymodel "github.com/songquanpeng/one-api/relay/model"
- "io"
- "net/http"
- "strconv"
- "strings"
-
- "github.com/gin-gonic/gin"
-)
-
-func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
- if !config.AutomaticDisableChannelEnabled {
- return false
- }
- if err == nil {
- return false
- }
- if statusCode == http.StatusUnauthorized {
- return true
- }
- switch err.Type {
- case "insufficient_quota":
- return true
- // https://docs.anthropic.com/claude/reference/errors
- case "authentication_error":
- return true
- case "permission_error":
- return true
- case "forbidden":
- return true
- }
- if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
- return true
- }
- if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
- return true
- } else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
- return true
- }
- return false
-}
-
-func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool {
- if !config.AutomaticEnableChannelEnabled {
- return false
- }
- if err != nil {
- return false
- }
- if openAIErr != nil {
- return false
- }
- return true
-}
-
-type GeneralErrorResponse struct {
- Error relaymodel.Error `json:"error"`
- Message string `json:"message"`
- Msg string `json:"msg"`
- Err string `json:"err"`
- ErrorMsg string `json:"error_msg"`
- Header struct {
- Message string `json:"message"`
- } `json:"header"`
- Response struct {
- Error struct {
- Message string `json:"message"`
- } `json:"error"`
- } `json:"response"`
-}
-
-func (e GeneralErrorResponse) ToMessage() string {
- if e.Error.Message != "" {
- return e.Error.Message
- }
- if e.Message != "" {
- return e.Message
- }
- if e.Msg != "" {
- return e.Msg
- }
- if e.Err != "" {
- return e.Err
- }
- if e.ErrorMsg != "" {
- return e.ErrorMsg
- }
- if e.Header.Message != "" {
- return e.Header.Message
- }
- if e.Response.Error.Message != "" {
- return e.Response.Error.Message
- }
- return ""
-}
-
-func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) {
- ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{
- StatusCode: resp.StatusCode,
- Error: relaymodel.Error{
- Message: "",
- Type: "upstream_error",
- Code: "bad_response_status_code",
- Param: strconv.Itoa(resp.StatusCode),
- },
- }
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return
- }
- if config.DebugEnabled {
- logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody)))
- }
- err = resp.Body.Close()
- if err != nil {
- return
- }
- var errResponse GeneralErrorResponse
- err = json.Unmarshal(responseBody, &errResponse)
- if err != nil {
- return
- }
- if errResponse.Error.Message != "" {
- // OpenAI format error, so we override the default one
- ErrorWithStatusCode.Error = errResponse.Error
- } else {
- ErrorWithStatusCode.Error.Message = errResponse.ToMessage()
- }
- if ErrorWithStatusCode.Error.Message == "" {
- ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode)
- }
- return
-}
-
-func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
- fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
-
- if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
- switch channelType {
- case common.ChannelTypeOpenAI:
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
- case common.ChannelTypeAzure:
- fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
- }
- }
- return fullRequestURL
-}
-
-func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
- // quotaDelta is remaining quota to be consumed
- err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
- if err != nil {
- logger.SysError("error consuming token remain quota: " + err.Error())
- }
- err = model.CacheUpdateUserQuota(ctx, userId)
- if err != nil {
- logger.SysError("error update user quota cache: " + err.Error())
- }
- // totalQuota is total quota consumed
- if totalQuota >= 0 {
- logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
- model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
- model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
- model.UpdateChannelUsedQuota(channelId, totalQuota)
- }
-
- if totalQuota < 0 {
- logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
- }
-}
-
-func GetAzureAPIVersion(c *gin.Context) string {
- query := c.Request.URL.Query()
- apiVersion := query.Get("api-version")
- if apiVersion == "" {
- apiVersion = c.GetString(common.ConfigKeyAPIVersion)
- }
- return apiVersion
-}
diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go
deleted file mode 100644
index 39e062a1..00000000
--- a/relay/util/model_mapping.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package util
-
-func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) {
- if mapping == nil {
- return modelName, false
- }
- mappedModelName := mapping[modelName]
- if mappedModelName != "" {
- return mappedModelName, true
- }
- return modelName, false
-}
diff --git a/router/api-router.go b/router/api.go
similarity index 92%
rename from router/api-router.go
rename to router/api.go
index 7af4511a..b9e5de38 100644
--- a/router/api-router.go
+++ b/router/api.go
@@ -2,6 +2,7 @@ package router
import (
"github.com/songquanpeng/one-api/controller"
+ "github.com/songquanpeng/one-api/controller/auth"
"github.com/songquanpeng/one-api/middleware"
"github.com/gin-contrib/gzip"
@@ -22,11 +23,13 @@ func SetApiRouter(router *gin.Engine) {
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail)
apiRouter.GET("/user/get-by-token", middleware.TokenAuth(), controller.GetSelfByToken)
apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)
- apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth)
- apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode)
- apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth)
- apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind)
+ apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth)
+ apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth)
+ apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode)
+ apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth)
+ apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind)
apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind)
+ apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp)
userRoute := apiRouter.Group("/user")
{
@@ -44,6 +47,7 @@ func SetApiRouter(router *gin.Engine) {
selfRoute.GET("/token", controller.GenerateAccessToken)
selfRoute.GET("/aff", controller.GetAffCode)
selfRoute.POST("/topup", controller.TopUp)
+ selfRoute.GET("/available_models", controller.GetUserAvailableModels)
}
adminRoute := userRoute.Group("/")
@@ -69,7 +73,7 @@ func SetApiRouter(router *gin.Engine) {
{
channelRoute.GET("/", controller.GetAllChannels)
channelRoute.GET("/search", controller.SearchChannels)
- channelRoute.GET("/models", controller.ListModels)
+ channelRoute.GET("/models", controller.ListAllModels)
channelRoute.GET("/:id", controller.GetChannel)
channelRoute.GET("/test", controller.TestChannels)
channelRoute.GET("/test/:id", controller.TestChannel)
diff --git a/router/relay-router.go b/router/relay.go
similarity index 100%
rename from router/relay-router.go
rename to router/relay.go
diff --git a/router/web-router.go b/router/web.go
similarity index 100%
rename from router/web-router.go
rename to router/web.go
diff --git a/web/README.md b/web/README.md
index 29f4713e..829271e2 100644
--- a/web/README.md
+++ b/web/README.md
@@ -2,6 +2,9 @@
> 每个文件夹代表一个主题,欢迎提交你的主题
+> [!WARNING]
+> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR
+
## 提交新的主题
> 欢迎在页面底部保留你和 One API 的版权信息以及指向链接
diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js
index ec049f7d..b74c58c7 100644
--- a/web/berry/src/constants/ChannelConstants.js
+++ b/web/berry/src/constants/ChannelConstants.js
@@ -107,6 +107,12 @@ export const CHANNEL_OPTIONS = {
value: 31,
color: 'primary'
},
+ 32: {
+ key: 32,
+ text: '阶跃星辰',
+ value: 32,
+ color: 'primary'
+ },
8: {
key: 8,
text: '自定义渠道',
diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js
index a05c6652..19523da1 100644
--- a/web/berry/src/constants/SnackbarConstants.js
+++ b/web/berry/src/constants/SnackbarConstants.js
@@ -18,7 +18,7 @@ export const snackbarConstants = {
},
NOTICE: {
variant: 'info',
- autoHideDuration: 20000
+ autoHideDuration: 7000
}
},
Mobile: {
diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js
index aa4b8c37..8925e542 100644
--- a/web/berry/src/utils/common.js
+++ b/web/berry/src/utils/common.js
@@ -51,9 +51,9 @@ export function showError(error) {
export function showNotice(message, isHTML = false) {
if (isHTML) {
- enqueueSnackbar(, getSnackbarOptions('INFO'));
+ enqueueSnackbar(, getSnackbarOptions('NOTICE'));
} else {
- enqueueSnackbar(message, getSnackbarOptions('INFO'));
+ enqueueSnackbar(message, getSnackbarOptions('NOTICE'));
}
}
diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js
index 07111c97..cbf411b9 100644
--- a/web/berry/src/views/Channel/component/EditModal.js
+++ b/web/berry/src/views/Channel/component/EditModal.js
@@ -340,7 +340,9 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => {
},
}}
>
- {Object.values(CHANNEL_OPTIONS).map((option) => {
+ {Object.values(CHANNEL_OPTIONS).sort((a, b) => {
+ return a.text.localeCompare(b.text)
+ }).map((option) => {
return (