Compare commits


30 Commits

Author SHA1 Message Date
JustSong
7bf61f9165 fix: fix retry not working (close #1314) 2024-04-15 23:09:12 +08:00
JustSong
a10232f43a feat: add gpt-4-turbo support (close #1304) 2024-04-13 11:39:31 +08:00
JustSong
af543ab8ec docs: update readme 2024-04-06 20:50:43 +08:00
JustSong
e086da05b1 feat: able to change gemini version (close #1211) 2024-04-06 20:48:22 +08:00
JustSong
3af4649b52 fix: only check model when request path in whitelist 2024-04-06 20:42:35 +08:00
GAI Group
52c32c0b4a chore: resolve the issue of onclick event scope for custom Lark button (#1281)
2024-04-06 20:08:05 +08:00
Buer
3fe2863ff7 feat: berry theme update & bug fix (#1282)
* improve: delete google fonts

* improve: Optimized priority input handling in TableRow component.

* 🔖 chore: channel batch add

* feat: add dark mode

* feat: support token limit ip range and models

* feat: add MessagePusher

* feat: add lark login
2024-04-06 19:44:23 +08:00
JustSong
acf8cb6248 chore: update default nextweb link 2024-04-06 11:47:31 +08:00
JustSong
572fc9ffb8 fix: fix stepfun model ratio & id 2024-04-06 10:43:54 +08:00
GAI Group
569c04acb0 fix: fix Lark icon button style (#1279) 2024-04-06 10:18:59 +08:00
JustSong
961b4108e6 chore: fix refactor caused typo 2024-04-06 02:12:50 +08:00
JustSong
0b8ccb94eb chore: reorganize common package 2024-04-06 02:03:59 +08:00
JustSong
f586ae0ad8 chore: remove helper & util subpackage for relay 2024-04-06 01:50:12 +08:00
JustSong
24ed170e7b chore: reorganize adaptor related package 2024-04-06 01:36:48 +08:00
JustSong
f70506eac1 chore: reorganize relay related package 2024-04-06 01:31:44 +08:00
JustSong
8f4d78e24d chore: reorganize billing related package 2024-04-06 01:26:48 +08:00
JustSong
cd2707692f chore: reorganize billing related package 2024-04-06 01:09:23 +08:00
JustSong
2ab7d25a80 chore: reorganize helper related package 2024-04-06 01:02:35 +08:00
JustSong
f9d914873f chore: reorganize constant related package 2024-04-06 00:44:33 +08:00
JustSong
880e12c855 feat: support cogview-3 2024-04-06 00:30:08 +08:00
JustSong
0cb224e62e chore: fix typo 2024-04-05 23:55:25 +08:00
JustSong
a44fb5d482 fix: fix channel model list is empty 2024-04-05 23:44:57 +08:00
JustSong
eec41849ec chore: fix ali image implementation 2024-04-05 18:25:57 +08:00
Mo
d4347e7a35 feat: support Ali stable-diffusion-xl and wanx-v1 model (#1240)
* Fix ali ConvertRequest function to use baidu keyword

* Support Ali stable-diffusion-xl and wanx-v1 model

* Support Ali stable-diffusion-xl and wanx-v1 model

* Support Ali stable-diffusion-xl and wanx-v1 model

* chore: update ali constants and model ratio

---------

Co-authored-by: JustSong <songquanpeng@foxmail.com>
Co-authored-by: JustSong <39998050+songquanpeng@users.noreply.github.com>
2024-04-05 18:09:54 +08:00
manjieqi
b50b43eb65 feat: update baidu model name & ratio (#1277) 2024-04-05 17:30:48 +08:00
JustSong
348adc2b02 feat: able to set multiple subnets 2024-04-05 17:25:28 +08:00
JustSong
dcf24b98dc chore: update berry copy 2024-04-05 14:28:38 +08:00
JustSong
af679e04f4 chore: sort channel type for berry 2024-04-05 14:23:39 +08:00
JustSong
93cbca6a9f chore: update show notice duration 2024-04-05 14:14:21 +08:00
JustSong
840ef80d94 fix: do not try to parse model when requesting /v1/models (close #1272) 2024-04-05 12:50:31 +08:00
180 changed files with 2939 additions and 1707 deletions

View File

@@ -363,28 +363,29 @@ graph LR
9. `CHANNEL_UPDATE_FREQUENCY`:设置之后将定期更新渠道余额,单位为分钟,未设置则不进行更新。
    + 例子:`CHANNEL_UPDATE_FREQUENCY=1440`
10. `CHANNEL_TEST_FREQUENCY`:设置之后将定期检查渠道,单位为分钟,未设置则不进行检查。
-    + 例子:`CHANNEL_TEST_FREQUENCY=1440`
-11. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
+    11. 例子:`CHANNEL_TEST_FREQUENCY=1440`
+12. `POLLING_INTERVAL`:批量更新渠道余额以及测试可用性时的请求间隔,单位为秒,默认无间隔。
    + 例子:`POLLING_INTERVAL=5`
-12. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
+13. `BATCH_UPDATE_ENABLED`:启用数据库批量更新聚合,会导致用户额度的更新存在一定的延迟可选值为 `true` 和 `false`,未设置则默认为 `false`。
    + 例子:`BATCH_UPDATE_ENABLED=true`
    + 如果你遇到了数据库连接数过多的问题,可以尝试启用该选项。
-13. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
+14. `BATCH_UPDATE_INTERVAL=5`:批量更新聚合的时间间隔,单位为秒,默认为 `5`。
    + 例子:`BATCH_UPDATE_INTERVAL=5`
-14. 请求频率限制:
+15. 请求频率限制:
    + `GLOBAL_API_RATE_LIMIT`:全局 API 速率限制(除中继请求外),单 ip 三分钟内的最大请求数,默认为 `180`。
    + `GLOBAL_WEB_RATE_LIMIT`:全局 Web 速率限制,单 ip 三分钟内的最大请求数,默认为 `60`。
-15. 编码器缓存设置:
+16. 编码器缓存设置:
    + `TIKTOKEN_CACHE_DIR`:默认程序启动时会联网下载一些通用的词元的编码,如:`gpt-3.5-turbo`,在一些网络环境不稳定,或者离线情况,可能会导致启动有问题,可以配置此目录缓存数据,可迁移到离线环境。
    + `DATA_GYM_CACHE_DIR`:目前该配置作用与 `TIKTOKEN_CACHE_DIR` 一致,但是优先级没有它高。
-16. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
-17. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
-18. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
-19. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)
-20. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`
-21. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
-22. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`。
-23. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌
+17. `RELAY_TIMEOUT`:中继超时设置,单位为秒,默认不设置超时时间。
+18. `SQLITE_BUSY_TIMEOUT`SQLite 锁等待超时设置,单位为毫秒,默认 `3000`。
+19. `GEMINI_SAFETY_SETTING`Gemini 的安全设置,默认 `BLOCK_NONE`。
+20. `GEMINI_VERSION`One API 所使用的 Gemini 版本,默认为 `v1`
+21. `THEME`:系统的主题设置,默认为 `default`,具体可选值参考[此处](./web/README.md)
+22. `ENABLE_METRIC`:是否根据请求成功率禁用渠道,默认不开启,可选值为 `true` 和 `false`。
+23. `METRIC_QUEUE_SIZE`:请求成功率统计队列大小,默认为 `10`。
+24. `METRIC_SUCCESS_RATE_THRESHOLD`:请求成功率阈值,默认为 `0.8`
+25. `INITIAL_ROOT_TOKEN`:如果设置了该值,则在系统首次启动时会自动创建一个值为该环境变量值的 root 用户令牌。
### 命令行参数
1. `--port <port_number>`: 指定服务器监听的端口号,默认为 `3000`。
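The variables documented above are read once at startup. As a rough sketch of how a deployment might consume a couple of them (a standalone illustration; the `envInt` helper here is hypothetical and only mimics the project's `env.Int` pattern visible in the next file):

```go
package main

import (
	"fmt"
	"os"
	"strconv"
	"time"
)

// envInt reads an integer setting from the environment,
// falling back to a default when the variable is unset or malformed.
func envInt(name string, defaultValue int) int {
	if v := os.Getenv(name); v != "" {
		if n, err := strconv.Atoi(v); err == nil {
			return n
		}
	}
	return defaultValue
}

func main() {
	// Names and defaults correspond to README items above; values are illustrative.
	batchInterval := envInt("BATCH_UPDATE_INTERVAL", 5)
	apiRateLimit := envInt("GLOBAL_API_RATE_LIMIT", 180)
	fmt.Println("batch update every", time.Duration(batchInterval)*time.Second)
	fmt.Println("global API rate limit:", apiRateLimit, "requests per 3 minutes")
}
```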

View File

@@ -141,3 +141,5 @@ var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)
var InitialRootToken = os.Getenv("INITIAL_ROOT_TOKEN")
+var GeminiVersion = env.String("GEMINI_VERSION", "v1")

common/config/key.go (new file, 9 lines)

@@ -0,0 +1,9 @@
package config
const (
KeyPrefix = "cfg_"
KeyAPIVersion = KeyPrefix + "api_version"
KeyLibraryID = KeyPrefix + "library_id"
KeyPlugin = KeyPrefix + "plugin"
)
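These `cfg_`-prefixed keys replace the `ConfigKey*` constants removed from the constants file further down; they are used as gin context keys when the selected channel's per-channel settings are attached to a request. A minimal sketch of that pattern (hypothetical helper functions; only the key names come from the file above, the usage mirrors the distributor middleware diff later in this compare):

```go
package middleware

import (
	"github.com/gin-gonic/gin"

	"github.com/songquanpeng/one-api/common/config"
)

// attachChannelConfig stashes a channel's API version on the request context
// under the cfg_ key, mirroring how SetupContextForSelectedChannel uses it.
func attachChannelConfig(c *gin.Context, apiVersion string) {
	c.Set(config.KeyAPIVersion, apiVersion)
}

// apiVersionFromContext reads the value back; gin returns "" when it was never set.
func apiVersionFromContext(c *gin.Context) string {
	return c.GetString(config.KeyAPIVersion)
}
```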

View File

@@ -4,118 +4,3 @@ import "time"
var StartTime = time.Now().Unix() // unit: second
var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change
const (
RoleGuestUser = 0
RoleCommonUser = 1
RoleAdminUser = 10
RoleRootUser = 100
)
const (
UserStatusEnabled = 1 // don't use 0, 0 is the default value!
UserStatusDisabled = 2 // also don't use 0
UserStatusDeleted = 3
)
const (
TokenStatusEnabled = 1 // don't use 0, 0 is the default value!
TokenStatusDisabled = 2 // also don't use 0
TokenStatusExpired = 3
TokenStatusExhausted = 4
)
const (
RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value!
RedemptionCodeStatusDisabled = 2 // also don't use 0
RedemptionCodeStatusUsed = 3 // also don't use 0
)
const (
ChannelStatusUnknown = 0
ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
ChannelStatusManuallyDisabled = 2 // also don't use 0
ChannelStatusAutoDisabled = 3
)
const (
ChannelTypeUnknown = iota
ChannelTypeOpenAI
ChannelTypeAPI2D
ChannelTypeAzure
ChannelTypeCloseAI
ChannelTypeOpenAISB
ChannelTypeOpenAIMax
ChannelTypeOhMyGPT
ChannelTypeCustom
ChannelTypeAILS
ChannelTypeAIProxy
ChannelTypePaLM
ChannelTypeAPI2GPT
ChannelTypeAIGC2D
ChannelTypeAnthropic
ChannelTypeBaidu
ChannelTypeZhipu
ChannelTypeAli
ChannelTypeXunfei
ChannelType360
ChannelTypeOpenRouter
ChannelTypeAIProxyLibrary
ChannelTypeFastGPT
ChannelTypeTencent
ChannelTypeGemini
ChannelTypeMoonshot
ChannelTypeBaichuan
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeStepFun
ChannelTypeDummy
)
var ChannelBaseURLs = []string{
"", // 0
"https://api.openai.com", // 1
"https://oa.api2d.net", // 2
"", // 3
"https://api.closeai-proxy.xyz", // 4
"https://api.openai-sb.com", // 5
"https://api.openaimax.com", // 6
"https://api.ohmygpt.com", // 7
"", // 8
"https://api.caipacity.com", // 9
"https://api.aiproxy.io", // 10
"https://generativelanguage.googleapis.com", // 11
"https://api.api2gpt.com", // 12
"https://api.aigc2d.com", // 13
"https://api.anthropic.com", // 14
"https://aip.baidubce.com", // 15
"https://open.bigmodel.cn", // 16
"https://dashscope.aliyuncs.com", // 17
"", // 18
"https://ai.360.cn", // 19
"https://openrouter.ai/api", // 20
"https://api.aiproxy.io", // 21
"https://fastgpt.run/api/openapi", // 22
"https://hunyuan.cloud.tencent.com", // 23
"https://generativelanguage.googleapis.com", // 24
"https://api.moonshot.cn", // 25
"https://api.baichuan-ai.com", // 26
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
"https://api.stepfun.com", // 32
}
const (
ConfigKeyPrefix = "cfg_"
ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version"
ConfigKeyLibraryID = ConfigKeyPrefix + "library_id"
ConfigKeyPlugin = ConfigKeyPrefix + "plugin"
)

View File

@@ -2,16 +2,14 @@ package helper
import (
"fmt"
-"github.com/google/uuid"
+"github.com/songquanpeng/one-api/common/random"
"html/template"
"log"
-"math/rand"
"net"
"os/exec"
"runtime"
"strconv"
"strings"
-"time"
)
func OpenBrowser(url string) {
@@ -79,31 +77,6 @@ func Bytes2Size(num int64) string {
return numStr + " " + unit
}
func Seconds2Time(num int) (time string) {
if num/31104000 > 0 {
time += strconv.Itoa(num/31104000) + " 年 "
num %= 31104000
}
if num/2592000 > 0 {
time += strconv.Itoa(num/2592000) + " 个月 "
num %= 2592000
}
if num/86400 > 0 {
time += strconv.Itoa(num/86400) + " 天 "
num %= 86400
}
if num/3600 > 0 {
time += strconv.Itoa(num/3600) + " 小时 "
num %= 3600
}
if num/60 > 0 {
time += strconv.Itoa(num/60) + " 分钟 "
num %= 60
}
time += strconv.Itoa(num) + " 秒"
return
}
func Interface2String(inter interface{}) string {
switch inter := inter.(type) {
case string:
@@ -128,65 +101,8 @@ func IntMax(a int, b int) int {
}
}
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func GenRequestID() string {
-return GetTimeString() + GetRandomNumberString(8)
+return GetTimeString() + random.GetRandomNumberString(8)
}
func Max(a int, b int) int {

common/helper/time.go (new file, 15 lines)

@@ -0,0 +1,15 @@
package helper
import (
"fmt"
"time"
)
func GetTimestamp() int64 {
return time.Now().Unix()
}
func GetTimeString() string {
now := time.Now()
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}

View File

@@ -5,9 +5,18 @@ import (
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"net"
+"strings"
)
-func IsValidSubnet(subnet string) error {
+func splitSubnets(subnets string) []string {
+res := strings.Split(subnets, ",")
+for i := 0; i < len(res); i++ {
+res[i] = strings.TrimSpace(res[i])
+}
+return res
+}
+func isValidSubnet(subnet string) error {
_, _, err := net.ParseCIDR(subnet)
if err != nil {
return fmt.Errorf("failed to parse subnet: %w", err)
@@ -15,7 +24,7 @@ func IsValidSubnet(subnet string) error {
return nil
}
-func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
+func isIpInSubnet(ctx context.Context, ip string, subnet string) bool {
_, ipNet, err := net.ParseCIDR(subnet)
if err != nil {
logger.Errorf(ctx, "failed to parse subnet: %s", err.Error())
@@ -23,3 +32,21 @@ func IsIpInSubnet(ctx context.Context, ip string, subnet string) bool {
}
return ipNet.Contains(net.ParseIP(ip))
}
+func IsValidSubnets(subnets string) error {
+for _, subnet := range splitSubnets(subnets) {
+if err := isValidSubnet(subnet); err != nil {
+return err
+}
+}
+return nil
+}
+func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool {
+for _, subnet := range splitSubnets(subnets) {
+if isIpInSubnet(ctx, ip, subnet) {
+return true
+}
+}
+return false
+}
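Per the signatures added above (from the "able to set multiple subnets" commit), `IsValidSubnets` and `IsIpInSubnets` accept a comma-separated list of CIDR blocks, with whitespace around each entry trimmed, which is what lets a token be restricted to several subnets at once. A small usage sketch based on those exported functions:

```go
package main

import (
	"context"
	"fmt"

	"github.com/songquanpeng/one-api/common/network"
)

func main() {
	subnets := "192.168.0.0/24, 10.0.0.0/8" // comma-separated, spaces are tolerated
	if err := network.IsValidSubnets(subnets); err != nil {
		fmt.Println("invalid subnet list:", err)
		return
	}
	ctx := context.Background()
	fmt.Println(network.IsIpInSubnets(ctx, "10.1.2.3", subnets))   // true: inside 10.0.0.0/8
	fmt.Println(network.IsIpInSubnets(ctx, "172.16.0.1", subnets)) // false: in neither block
}
```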

View File

@@ -13,7 +13,7 @@ func TestIsIpInSubnet(t *testing.T) {
ip2 := "125.216.250.89"
subnet := "192.168.0.0/24"
Convey("TestIsIpInSubnet", t, func() {
-So(IsIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
-So(IsIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
+So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue)
+So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse)
})
}

View File

@@ -1,8 +0,0 @@
package common
import "math/rand"
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}

common/random/main.go (new file, 61 lines)

@@ -0,0 +1,61 @@
package random
import (
"github.com/google/uuid"
"math/rand"
"strings"
"time"
)
func GetUUID() string {
code := uuid.New().String()
code = strings.Replace(code, "-", "", -1)
return code
}
const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
const keyNumbers = "0123456789"
func init() {
rand.Seed(time.Now().UnixNano())
}
func GenerateKey() string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, 48)
for i := 0; i < 16; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
uuid_ := GetUUID()
for i := 0; i < 32; i++ {
c := uuid_[i]
if i%2 == 0 && c >= 'a' && c <= 'z' {
c = c - 'a' + 'A'
}
key[i+16] = c
}
return string(key)
}
func GetRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyChars[rand.Intn(len(keyChars))]
}
return string(key)
}
func GetRandomNumberString(length int) string {
rand.Seed(time.Now().UnixNano())
key := make([]byte, length)
for i := 0; i < length; i++ {
key[i] = keyNumbers[rand.Intn(len(keyNumbers))]
}
return string(key)
}
// RandRange returns a random number between min and max (max is not included)
func RandRange(min, max int) int {
return min + rand.Intn(max-min)
}
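The random helpers formerly exported from `common/helper` (plus `RandRange` from the deleted file in package `common`) now live in `common/random`. A quick usage sketch of the exported functions in the file above:

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/random"
)

func main() {
	fmt.Println(random.GetUUID())                // 32 hex chars, dashes stripped
	fmt.Println(random.GenerateKey())            // 48-char key: 16 random chars + uppercased UUID
	fmt.Println(random.GetRandomString(6))       // e.g. "a1B9zQ"
	fmt.Println(random.GetRandomNumberString(8)) // e.g. "02948137"
	fmt.Println(random.RandRange(10, 20))        // n with 10 <= n < 20
}
```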

View File

@@ -7,10 +7,9 @@ import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
-"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
-"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
+"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
"net/http"
@@ -134,8 +133,8 @@ func GitHubOAuth(c *gin.Context) {
user.DisplayName = "GitHub User"
}
user.Email = githubUser.Email
-user.Role = common.RoleCommonUser
-user.Status = common.UserStatusEnabled
+user.Role = model.RoleCommonUser
+user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -153,7 +152,7 @@ func GitHubOAuth(c *gin.Context) {
}
}
-if user.Status != common.UserStatusEnabled {
+if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,
@@ -220,7 +219,7 @@ func GitHubBind(c *gin.Context) {
func GenerateOAuthCode(c *gin.Context) {
session := sessions.Default(c)
-state := helper.GetRandomString(12)
+state := random.GetRandomString(12)
session.Set("oauth_state", state)
err := session.Save()
if err != nil {

View File

@@ -7,7 +7,6 @@ import (
"fmt"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
-"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/controller"
@@ -123,8 +122,8 @@ func LarkOAuth(c *gin.Context) {
} else {
user.DisplayName = "Lark User"
}
-user.Role = common.RoleCommonUser
-user.Status = common.UserStatusEnabled
+user.Role = model.RoleCommonUser
+user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -142,7 +141,7 @@ func LarkOAuth(c *gin.Context) {
}
}
-if user.Status != common.UserStatusEnabled {
+if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,

View File

@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
-"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/model"
@@ -84,8 +83,8 @@ func WeChatAuth(c *gin.Context) {
if config.RegisterEnabled {
user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1)
user.DisplayName = "WeChat User"
-user.Role = common.RoleCommonUser
-user.Status = common.UserStatusEnabled
+user.Role = model.RoleCommonUser
+user.Status = model.UserStatusEnabled
if err := user.Insert(0); err != nil {
c.JSON(http.StatusOK, gin.H{
@@ -103,7 +102,7 @@ func WeChatAuth(c *gin.Context) {
}
}
-if user.Status != common.UserStatusEnabled {
+if user.Status != model.UserStatusEnabled {
c.JSON(http.StatusOK, gin.H{
"message": "用户已被封禁",
"success": false,

View File

@@ -4,12 +4,12 @@ import (
"encoding/json"
"errors"
"fmt"
-"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor"
-"github.com/songquanpeng/one-api/relay/util"
+"github.com/songquanpeng/one-api/relay/channeltype"
+"github.com/songquanpeng/one-api/relay/client"
"io"
"net/http"
"strconv"
@@ -96,7 +96,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He
for k := range headers {
req.Header.Add(k, headers.Get(k))
}
-res, err := util.HTTPClient.Do(req)
+res, err := client.HTTPClient.Do(req)
if err != nil {
return nil, err
}
@@ -204,28 +204,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) {
}
func updateChannelBalance(channel *model.Channel) (float64, error) {
-baseURL := common.ChannelBaseURLs[channel.Type]
+baseURL := channeltype.ChannelBaseURLs[channel.Type]
if channel.GetBaseURL() == "" {
channel.BaseURL = &baseURL
}
switch channel.Type {
-case common.ChannelTypeOpenAI:
+case channeltype.OpenAI:
if channel.GetBaseURL() != "" {
baseURL = channel.GetBaseURL()
}
-case common.ChannelTypeAzure:
+case channeltype.Azure:
return 0, errors.New("尚未实现")
-case common.ChannelTypeCustom:
+case channeltype.Custom:
baseURL = channel.GetBaseURL()
-case common.ChannelTypeCloseAI:
+case channeltype.CloseAI:
return updateChannelCloseAIBalance(channel)
-case common.ChannelTypeOpenAISB:
+case channeltype.OpenAISB:
return updateChannelOpenAISBBalance(channel)
-case common.ChannelTypeAIProxy:
+case channeltype.AIProxy:
return updateChannelAIProxyBalance(channel)
-case common.ChannelTypeAPI2GPT:
+case channeltype.API2GPT:
return updateChannelAPI2GPTBalance(channel)
-case common.ChannelTypeAIGC2D:
+case channeltype.AIGC2D:
return updateChannelAIGC2DBalance(channel)
default:
return 0, errors.New("尚未实现")
@@ -301,11 +301,11 @@ func updateAllChannelsBalance() error {
return err
}
for _, channel := range channels {
-if channel.Status != common.ChannelStatusEnabled {
+if channel.Status != model.ChannelStatusEnabled {
continue
}
// TODO: support Azure
-if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom {
+if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom {
continue
}
balance, err := updateChannelBalance(channel)

View File

@@ -5,17 +5,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/message" "github.com/songquanpeng/one-api/common/message"
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant" relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@@ -56,9 +57,9 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
c.Set("channel", channel.Type) c.Set("channel", channel.Type)
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
middleware.SetupContextForSelectedChannel(c, channel, "") middleware.SetupContextForSelectedChannel(c, channel, "")
meta := util.GetRelayMeta(c) meta := meta.GetByContext(c)
apiType := constant.ChannelType2APIType(channel.Type) apiType := channeltype.ToAPIType(channel.Type)
adaptor := helper.GetAdaptor(apiType) adaptor := relay.GetAdaptor(apiType)
if adaptor == nil { if adaptor == nil {
return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil return fmt.Errorf("invalid api type: %d, adaptor is nil", apiType), nil
} }
@@ -73,7 +74,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
request := buildTestRequest() request := buildTestRequest()
request.Model = modelName request.Model = modelName
meta.OriginModelName, meta.ActualModelName = modelName, modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName
convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request)
if err != nil { if err != nil {
return err, nil return err, nil
} }
@@ -88,7 +89,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error
return err, nil return err, nil
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
err := util.RelayErrorHandler(resp) err := controller.RelayErrorHandler(resp)
return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error
} }
usage, respErr := adaptor.DoResponse(c, resp, meta) usage, respErr := adaptor.DoResponse(c, resp, meta)
@@ -171,7 +172,7 @@ func testChannels(notify bool, scope string) error {
} }
go func() { go func() {
for _, channel := range channels { for _, channel := range channels {
isChannelEnabled := channel.Status == common.ChannelStatusEnabled isChannelEnabled := channel.Status == model.ChannelStatusEnabled
tik := time.Now() tik := time.Now()
err, openaiErr := testChannel(channel) err, openaiErr := testChannel(channel)
tok := time.Now() tok := time.Now()
@@ -184,10 +185,10 @@ func testChannels(notify bool, scope string) error {
_ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s %d测试超时", channel.Name, channel.Id), "", err.Error()) _ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s %d测试超时", channel.Name, channel.Id), "", err.Error())
} }
} }
if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) {
monitor.DisableChannel(channel.Id, channel.Name, err.Error()) monitor.DisableChannel(channel.Id, channel.Name, err.Error())
} }
if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) {
monitor.EnableChannel(channel.Id, channel.Name) monitor.EnableChannel(channel.Id, channel.Name)
} }
channel.UpdateResponseTime(milliseconds) channel.UpdateResponseTime(milliseconds)

View File

@@ -2,13 +2,13 @@ package controller
import (
"github.com/gin-gonic/gin"
-"github.com/songquanpeng/one-api/common"
+billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"net/http"
)
func GetGroups(c *gin.Context) {
groupNames := make([]string, 0)
-for groupName := range common.GroupRatio {
+for groupName := range billingratio.GroupRatio {
groupNames = append(groupNames, groupName)
}
c.JSON(http.StatusOK, gin.H{

View File

@@ -3,13 +3,13 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channel/openai" relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/helper" "github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/meta"
relaymodel "github.com/songquanpeng/one-api/relay/model" relaymodel "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"net/http" "net/http"
"strings" "strings"
) )
@@ -41,8 +41,8 @@ type OpenAIModels struct {
Parent *string `json:"parent"` Parent *string `json:"parent"`
} }
var openAIModels []OpenAIModels var models []OpenAIModels
var openAIModelsMap map[string]OpenAIModels var modelsMap map[string]OpenAIModels
var channelId2Models map[int][]string var channelId2Models map[int][]string
func init() { func init() {
@@ -62,15 +62,15 @@ func init() {
IsBlocking: false, IsBlocking: false,
}) })
// https://platform.openai.com/docs/models/model-endpoint-compatibility // https://platform.openai.com/docs/models/model-endpoint-compatibility
for i := 0; i < constant.APITypeDummy; i++ { for i := 0; i < apitype.Dummy; i++ {
if i == constant.APITypeAIProxyLibrary { if i == apitype.AIProxyLibrary {
continue continue
} }
adaptor := helper.GetAdaptor(i) adaptor := relay.GetAdaptor(i)
channelName := adaptor.GetChannelName() channelName := adaptor.GetChannelName()
modelNames := adaptor.GetModelList() modelNames := adaptor.GetModelList()
for _, modelName := range modelNames { for _, modelName := range modelNames {
openAIModels = append(openAIModels, OpenAIModels{ models = append(models, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
@@ -82,12 +82,12 @@ func init() {
} }
} }
for _, channelType := range openai.CompatibleChannels { for _, channelType := range openai.CompatibleChannels {
if channelType == common.ChannelTypeAzure { if channelType == channeltype.Azure {
continue continue
} }
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
for _, modelName := range channelModelList { for _, modelName := range channelModelList {
openAIModels = append(openAIModels, OpenAIModels{ models = append(models, OpenAIModels{
Id: modelName, Id: modelName,
Object: "model", Object: "model",
Created: 1626777600, Created: 1626777600,
@@ -98,14 +98,14 @@ func init() {
}) })
} }
} }
openAIModelsMap = make(map[string]OpenAIModels) modelsMap = make(map[string]OpenAIModels)
for _, model := range openAIModels { for _, model := range models {
openAIModelsMap[model.Id] = model modelsMap[model.Id] = model
} }
channelId2Models = make(map[int][]string) channelId2Models = make(map[int][]string)
for i := 1; i < common.ChannelTypeDummy; i++ { for i := 1; i < channeltype.Dummy; i++ {
adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) adaptor := relay.GetAdaptor(channeltype.ToAPIType(i))
meta := &util.RelayMeta{ meta := &meta.Meta{
ChannelType: i, ChannelType: i,
} }
adaptor.Init(meta) adaptor.Init(meta)
@@ -121,6 +121,13 @@ func DashboardListModels(c *gin.Context) {
}) })
} }
func ListAllModels(c *gin.Context) {
c.JSON(200, gin.H{
"object": "list",
"data": models,
})
}
func ListModels(c *gin.Context) { func ListModels(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
var availableModels []string var availableModels []string
@@ -136,7 +143,7 @@ func ListModels(c *gin.Context) {
modelSet[availableModel] = true modelSet[availableModel] = true
} }
availableOpenAIModels := make([]OpenAIModels, 0) availableOpenAIModels := make([]OpenAIModels, 0)
for _, model := range openAIModels { for _, model := range models {
if _, ok := modelSet[model.Id]; ok { if _, ok := modelSet[model.Id]; ok {
modelSet[model.Id] = false modelSet[model.Id] = false
availableOpenAIModels = append(availableOpenAIModels, model) availableOpenAIModels = append(availableOpenAIModels, model)
@@ -162,7 +169,7 @@ func ListModels(c *gin.Context) {
func RetrieveModel(c *gin.Context) { func RetrieveModel(c *gin.Context) {
modelId := c.Param("model") modelId := c.Param("model")
if model, ok := openAIModelsMap[modelId]; ok { if model, ok := modelsMap[modelId]; ok {
c.JSON(200, model) c.JSON(200, model)
} else { } else {
Error := relaymodel.Error{ Error := relaymodel.Error{

View File

@@ -4,6 +4,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
+"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model"
"net/http"
"strconv"
@@ -106,7 +107,7 @@ func AddRedemption(c *gin.Context) {
}
var keys []string
for i := 0; i < redemption.Count; i++ {
-key := helper.GetUUID()
+key := random.GetUUID()
cleanRedemption := model.Redemption{
UserId: c.GetInt("id"),
Name: redemption.Name,

View File

@@ -12,26 +12,25 @@ import (
"github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/middleware"
dbmodel "github.com/songquanpeng/one-api/model" dbmodel "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/monitor" "github.com/songquanpeng/one-api/monitor"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/controller"
"github.com/songquanpeng/one-api/relay/model" "github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util" "github.com/songquanpeng/one-api/relay/relaymode"
"io" "io"
"net/http" "net/http"
) )
// https://platform.openai.com/docs/api-reference/chat // https://platform.openai.com/docs/api-reference/chat
func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
var err *model.ErrorWithStatusCode var err *model.ErrorWithStatusCode
switch relayMode { switch relayMode {
case constant.RelayModeImagesGenerations: case relaymode.ImagesGenerations:
err = controller.RelayImageHelper(c, relayMode) err = controller.RelayImageHelper(c, relayMode)
case constant.RelayModeAudioSpeech: case relaymode.AudioSpeech:
fallthrough fallthrough
case constant.RelayModeAudioTranslation: case relaymode.AudioTranslation:
fallthrough fallthrough
case constant.RelayModeAudioTranscription: case relaymode.AudioTranscription:
err = controller.RelayAudioHelper(c, relayMode) err = controller.RelayAudioHelper(c, relayMode)
default: default:
err = controller.RelayTextHelper(c) err = controller.RelayTextHelper(c)
@@ -41,13 +40,13 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
func Relay(c *gin.Context) { func Relay(c *gin.Context) {
ctx := c.Request.Context() ctx := c.Request.Context()
relayMode := constant.Path2RelayMode(c.Request.URL.Path) relayMode := relaymode.GetByPath(c.Request.URL.Path)
if config.DebugEnabled { if config.DebugEnabled {
requestBody, _ := common.GetRequestBody(c) requestBody, _ := common.GetRequestBody(c)
logger.Debugf(ctx, "request body: %s", string(requestBody)) logger.Debugf(ctx, "request body: %s", string(requestBody))
} }
channelId := c.GetInt("channel_id") channelId := c.GetInt("channel_id")
bizErr := relay(c, relayMode) bizErr := relayHelper(c, relayMode)
if bizErr == nil { if bizErr == nil {
monitor.Emit(channelId, true) monitor.Emit(channelId, true)
return return
@@ -76,7 +75,7 @@ func Relay(c *gin.Context) {
middleware.SetupContextForSelectedChannel(c, channel, originalModel) middleware.SetupContextForSelectedChannel(c, channel, originalModel)
requestBody, err := common.GetRequestBody(c) requestBody, err := common.GetRequestBody(c)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
bizErr = relay(c, relayMode) bizErr = relayHelper(c, relayMode)
if bizErr == nil { if bizErr == nil {
return return
} }
@@ -118,7 +117,7 @@ func shouldRetry(c *gin.Context, statusCode int) bool {
func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) {
logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message)
// https://platform.openai.com/docs/guides/error-codes/api-errors // https://platform.openai.com/docs/guides/error-codes/api-errors
if util.ShouldDisableChannel(&err.Error, err.StatusCode) { if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) {
monitor.DisableChannel(channelId, channelName, err.Message) monitor.DisableChannel(channelId, channelName, err.Message)
} else { } else {
monitor.Emit(channelId, false) monitor.Emit(channelId, false)

View File

@@ -3,10 +3,10 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"strconv" "strconv"
@@ -111,7 +111,7 @@ func validateToken(c *gin.Context, token model.Token) error {
return fmt.Errorf("令牌名称过长") return fmt.Errorf("令牌名称过长")
} }
if token.Subnet != nil && *token.Subnet != "" { if token.Subnet != nil && *token.Subnet != "" {
err := network.IsValidSubnet(*token.Subnet) err := network.IsValidSubnets(*token.Subnet)
if err != nil { if err != nil {
return fmt.Errorf("无效的网段:%s", err.Error()) return fmt.Errorf("无效的网段:%s", err.Error())
} }
@@ -141,7 +141,7 @@ func AddToken(c *gin.Context) {
cleanToken := model.Token{ cleanToken := model.Token{
UserId: c.GetInt("id"), UserId: c.GetInt("id"),
Name: token.Name, Name: token.Name,
Key: helper.GenerateKey(), Key: random.GenerateKey(),
CreatedTime: helper.GetTimestamp(), CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(),
ExpiredTime: token.ExpiredTime, ExpiredTime: token.ExpiredTime,
@@ -212,15 +212,15 @@ func UpdateToken(c *gin.Context) {
}) })
return return
} }
if token.Status == common.TokenStatusEnabled { if token.Status == model.TokenStatusEnabled {
if cleanToken.Status == common.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期",
}) })
return return
} }
if cleanToken.Status == common.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度",

View File

@@ -5,7 +5,7 @@ import (
"fmt" "fmt"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/random"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"net/http" "net/http"
"strconv" "strconv"
@@ -239,7 +239,7 @@ func GetUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser { if myRole <= user.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权获取同级或更高等级用户的信息", "message": "无权获取同级或更高等级用户的信息",
@@ -287,7 +287,7 @@ func GenerateAccessToken(c *gin.Context) {
}) })
return return
} }
user.AccessToken = helper.GetUUID() user.AccessToken = random.GetUUID()
if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
@@ -324,7 +324,7 @@ func GetAffCode(c *gin.Context) {
return return
} }
if user.AffCode == "" { if user.AffCode == "" {
user.AffCode = helper.GetRandomString(4) user.AffCode = random.GetRandomString(4)
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
@@ -388,14 +388,14 @@ func UpdateUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= originUser.Role && myRole != common.RoleRootUser { if myRole <= originUser.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息", "message": "无权更新同权限等级或更高权限等级的用户信息",
}) })
return return
} }
if myRole <= updatedUser.Role && myRole != common.RoleRootUser { if myRole <= updatedUser.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权将其他用户权限等级提升到大于等于自己的权限等级", "message": "无权将其他用户权限等级提升到大于等于自己的权限等级",
@@ -509,7 +509,7 @@ func DeleteSelf(c *gin.Context) {
id := c.GetInt("id") id := c.GetInt("id")
user, _ := model.GetUserById(id, false) user, _ := model.GetUserById(id, false)
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "不能删除超级管理员账户", "message": "不能删除超级管理员账户",
@@ -611,7 +611,7 @@ func ManageUser(c *gin.Context) {
return return
} }
myRole := c.GetInt("role") myRole := c.GetInt("role")
if myRole <= user.Role && myRole != common.RoleRootUser { if myRole <= user.Role && myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无权更新同权限等级或更高权限等级的用户信息", "message": "无权更新同权限等级或更高权限等级的用户信息",
@@ -620,8 +620,8 @@ func ManageUser(c *gin.Context) {
} }
switch req.Action { switch req.Action {
case "disable": case "disable":
user.Status = common.UserStatusDisabled user.Status = model.UserStatusDisabled
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法禁用超级管理员用户", "message": "无法禁用超级管理员用户",
@@ -629,9 +629,9 @@ func ManageUser(c *gin.Context) {
return return
} }
case "enable": case "enable":
user.Status = common.UserStatusEnabled user.Status = model.UserStatusEnabled
case "delete": case "delete":
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法删除超级管理员用户", "message": "无法删除超级管理员用户",
@@ -646,37 +646,37 @@ func ManageUser(c *gin.Context) {
return return
} }
case "promote": case "promote":
if myRole != common.RoleRootUser { if myRole != model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "普通管理员用户无法提升其他用户为管理员", "message": "普通管理员用户无法提升其他用户为管理员",
}) })
return return
} }
if user.Role >= common.RoleAdminUser { if user.Role >= model.RoleAdminUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "该用户已经是管理员", "message": "该用户已经是管理员",
}) })
return return
} }
user.Role = common.RoleAdminUser user.Role = model.RoleAdminUser
case "demote": case "demote":
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "无法降级超级管理员用户", "message": "无法降级超级管理员用户",
}) })
return return
} }
if user.Role == common.RoleCommonUser { if user.Role == model.RoleCommonUser {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "该用户已经是普通用户", "message": "该用户已经是普通用户",
}) })
return return
} }
user.Role = common.RoleCommonUser user.Role = model.RoleCommonUser
} }
if err := user.Update(false); err != nil { if err := user.Update(false); err != nil {
@@ -730,7 +730,7 @@ func EmailBind(c *gin.Context) {
}) })
return return
} }
if user.Role == common.RoleRootUser { if user.Role == model.RoleRootUser {
config.RootUserEmail = email config.RootUserEmail = email
} }
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@@ -12,7 +12,7 @@ import (
"github.com/songquanpeng/one-api/controller"
"github.com/songquanpeng/one-api/middleware"
"github.com/songquanpeng/one-api/model"
-"github.com/songquanpeng/one-api/relay/channel/openai"
+"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/router"
"os"
"strconv"

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/blacklist"
"github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/common/network"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
@@ -45,7 +44,7 @@ func authHelper(c *gin.Context, minRole int) {
return return
} }
} }
if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"success": false, "success": false,
"message": "用户已被封禁", "message": "用户已被封禁",
@@ -72,19 +71,19 @@ func authHelper(c *gin.Context, minRole int) {
func UserAuth() func(c *gin.Context) { func UserAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleCommonUser) authHelper(c, model.RoleCommonUser)
} }
} }
func AdminAuth() func(c *gin.Context) { func AdminAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleAdminUser) authHelper(c, model.RoleAdminUser)
} }
} }
func RootAuth() func(c *gin.Context) { func RootAuth() func(c *gin.Context) {
return func(c *gin.Context) { return func(c *gin.Context) {
authHelper(c, common.RoleRootUser) authHelper(c, model.RoleRootUser)
} }
} }
@@ -102,7 +101,7 @@ func TokenAuth() func(c *gin.Context) {
return return
} }
if token.Subnet != nil && *token.Subnet != "" { if token.Subnet != nil && *token.Subnet != "" {
if !network.IsIpInSubnet(ctx, c.ClientIP(), *token.Subnet) { if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) {
abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s当前 ip%s", *token.Subnet, c.ClientIP())) abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s当前 ip%s", *token.Subnet, c.ClientIP()))
return return
} }
@@ -117,7 +116,7 @@ func TokenAuth() func(c *gin.Context) {
return return
} }
requestModel, err := getRequestModel(c) requestModel, err := getRequestModel(c)
if err != nil { if err != nil && shouldCheckModel(c) {
abortWithMessage(c, http.StatusBadRequest, err.Error()) abortWithMessage(c, http.StatusBadRequest, err.Error())
return return
} }
@@ -143,3 +142,19 @@ func TokenAuth() func(c *gin.Context) {
c.Next() c.Next()
} }
} }
func shouldCheckModel(c *gin.Context) bool {
if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/images") {
return true
}
if strings.HasPrefix(c.Request.URL.Path, "/v1/audio") {
return true
}
return false
}
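With this change (the "only check model when request path in whitelist" commit, which closes #1272), a body that cannot be parsed for a model name only rejects the request when the path is one of the whitelisted relay endpoints above; metadata routes such as /v1/models pass through. A hypothetical test in the middleware package (not part of the diff) illustrating the intent:

```go
package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/gin-gonic/gin"
)

func TestShouldCheckModel(t *testing.T) {
	cases := map[string]bool{
		"/v1/chat/completions":   true,  // relay endpoints must carry a model
		"/v1/completions":        true,
		"/v1/images/generations": true,
		"/v1/audio/speech":       true,
		"/v1/models":             false, // listing models needs no model field
	}
	for path, want := range cases {
		c, _ := gin.CreateTestContext(httptest.NewRecorder())
		c.Request = httptest.NewRequest(http.MethodPost, path, nil)
		if got := shouldCheckModel(c); got != want {
			t.Errorf("%s: got %v, want %v", path, got, want)
		}
	}
}
```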

View File

@@ -3,9 +3,10 @@ package middleware
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/model"
"github.com/songquanpeng/one-api/relay/channeltype"
"net/http" "net/http"
"strconv" "strconv"
) )
@@ -33,12 +34,12 @@ func Distribute() func(c *gin.Context) {
abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id")
return return
} }
if channel.Status != common.ChannelStatusEnabled { if channel.Status != model.ChannelStatusEnabled {
abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用")
return return
} }
} else { } else {
requestModel := c.GetString("request_model") requestModel = c.GetString("request_model")
var err error var err error
channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false)
if err != nil { if err != nil {
@@ -66,19 +67,19 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
c.Set("base_url", channel.GetBaseURL()) c.Set("base_url", channel.GetBaseURL())
// this is for backward compatibility // this is for backward compatibility
switch channel.Type { switch channel.Type {
case common.ChannelTypeAzure: case channeltype.Azure:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(config.KeyAPIVersion, channel.Other)
case common.ChannelTypeXunfei: case channeltype.Xunfei:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(config.KeyAPIVersion, channel.Other)
case common.ChannelTypeGemini: case channeltype.Gemini:
c.Set(common.ConfigKeyAPIVersion, channel.Other) c.Set(config.KeyAPIVersion, channel.Other)
case common.ChannelTypeAIProxyLibrary: case channeltype.AIProxyLibrary:
c.Set(common.ConfigKeyLibraryID, channel.Other) c.Set(config.KeyLibraryID, channel.Other)
case common.ChannelTypeAli: case channeltype.Ali:
c.Set(common.ConfigKeyPlugin, channel.Other) c.Set(config.KeyPlugin, channel.Other)
} }
cfg, _ := channel.LoadConfig() cfg, _ := channel.LoadConfig()
for k, v := range cfg { for k, v := range cfg {
c.Set(common.ConfigKeyPrefix+k, v) c.Set(config.KeyPrefix+k, v)
} }
} }

View File

@@ -57,7 +57,7 @@ func (channel *Channel) AddAbilities() error {
Group: group,
Model: model,
ChannelId: channel.Id,
-Enabled: channel.Status == common.ChannelStatusEnabled,
+Enabled: channel.Status == ChannelStatusEnabled,
Priority: channel.Priority,
}
abilities = append(abilities, ability)

View File

@@ -8,6 +8,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
+"github.com/songquanpeng/one-api/common/random"
"math/rand"
"sort"
"strconv"
@@ -172,7 +173,7 @@ var channelSyncLock sync.RWMutex
func InitChannelCache() {
newChannelId2channel := make(map[int]*Channel)
var channels []*Channel
-DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels)
+DB.Where("status = ?", ChannelStatusEnabled).Find(&channels)
for _, channel := range channels {
newChannelId2channel[channel.Id] = channel
}
@@ -247,7 +248,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPrior
idx := rand.Intn(endIdx)
if ignoreFirstPriority {
if endIdx < len(channels) { // which means there are more than one priority
-idx = common.RandRange(endIdx, len(channels))
+idx = random.RandRange(endIdx, len(channels))
}
}
return channels[idx], nil

View File

@@ -3,13 +3,19 @@ package model
import (
"encoding/json"
"fmt"
-"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
)
+const (
+ChannelStatusUnknown = 0
+ChannelStatusEnabled = 1 // don't use 0, 0 is the default value!
+ChannelStatusManuallyDisabled = 2 // also don't use 0
+ChannelStatusAutoDisabled = 3
+)
type Channel struct {
Id int `json:"id"`
Type int `json:"type" gorm:"default:0"`
@@ -39,7 +45,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) {
case "all":
err = DB.Order("id desc").Find(&channels).Error
case "disabled":
-err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error
+err = DB.Order("id desc").Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error
default:
err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error
}
@@ -168,7 +174,7 @@ func (channel *Channel) LoadConfig() (map[string]string, error) {
}
func UpdateChannelStatusById(id int, status int) {
-err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled)
+err := UpdateAbilityStatus(id, status == ChannelStatusEnabled)
if err != nil {
logger.SysError("failed to update ability status: " + err.Error())
}
@@ -199,6 +205,6 @@ func DeleteChannelByStatus(status int64) (int64, error) {
}
func DeleteDisabledChannel() (int64, error) {
-result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{})
+result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{})
return result.RowsAffected, result.Error
}


@@ -7,7 +7,6 @@ import (
    "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
    "gorm.io/gorm"
)


@@ -7,6 +7,7 @@ import (
"github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/common/random"
"gorm.io/driver/mysql" "gorm.io/driver/mysql"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/driver/sqlite" "gorm.io/driver/sqlite"
@@ -31,10 +32,10 @@ func CreateRootAccountIfNeed() error {
rootUser := User{ rootUser := User{
Username: "root", Username: "root",
Password: hashedPassword, Password: hashedPassword,
Role: common.RoleRootUser, Role: RoleRootUser,
Status: common.UserStatusEnabled, Status: UserStatusEnabled,
DisplayName: "Root User", DisplayName: "Root User",
AccessToken: helper.GetUUID(), AccessToken: random.GetUUID(),
Quota: 500000000000000, Quota: 500000000000000,
} }
DB.Create(&rootUser) DB.Create(&rootUser)
@@ -44,7 +45,7 @@ func CreateRootAccountIfNeed() error {
Id: 1, Id: 1,
UserId: rootUser.Id, UserId: rootUser.Id,
Key: config.InitialRootToken, Key: config.InitialRootToken,
Status: common.TokenStatusEnabled, Status: TokenStatusEnabled,
Name: "Initial Root Token", Name: "Initial Root Token",
CreatedTime: helper.GetTimestamp(), CreatedTime: helper.GetTimestamp(),
AccessedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(),


@@ -1,9 +1,9 @@
package model

import (
-   "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/logger"
+   billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
    "strconv"
    "strings"
    "time"
@@ -66,9 +66,9 @@ func InitOptionMap() {
    config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
    config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
    config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
-   config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
-   config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
-   config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
+   config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
+   config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
+   config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
    config.OptionMap["TopUpLink"] = config.TopUpLink
    config.OptionMap["ChatLink"] = config.ChatLink
    config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
@@ -82,7 +82,7 @@ func loadOptionsFromDatabase() {
    options, _ := AllOption()
    for _, option := range options {
        if option.Key == "ModelRatio" {
-           option.Value = common.AddNewMissingRatio(option.Value)
+           option.Value = billingratio.AddNewMissingRatio(option.Value)
        }
        err := updateOptionMap(option.Key, option.Value)
        if err != nil {
@@ -209,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) {
    case "RetryTimes":
        config.RetryTimes, _ = strconv.Atoi(value)
    case "ModelRatio":
-       err = common.UpdateModelRatioByJSONString(value)
+       err = billingratio.UpdateModelRatioByJSONString(value)
    case "GroupRatio":
-       err = common.UpdateGroupRatioByJSONString(value)
+       err = billingratio.UpdateGroupRatioByJSONString(value)
    case "CompletionRatio":
-       err = common.UpdateCompletionRatioByJSONString(value)
+       err = billingratio.UpdateCompletionRatioByJSONString(value)
    case "TopUpLink":
        config.TopUpLink = value
    case "ChatLink":


@@ -8,6 +8,12 @@ import (
    "gorm.io/gorm"
)

+const (
+   RedemptionCodeStatusEnabled  = 1 // don't use 0, 0 is the default value!
+   RedemptionCodeStatusDisabled = 2 // also don't use 0
+   RedemptionCodeStatusUsed     = 3 // also don't use 0
+)
+
type Redemption struct {
    Id     int `json:"id"`
    UserId int `json:"user_id"`
@@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
        if err != nil {
            return errors.New("无效的兑换码")
        }
-       if redemption.Status != common.RedemptionCodeStatusEnabled {
+       if redemption.Status != RedemptionCodeStatusEnabled {
            return errors.New("该兑换码已被使用")
        }
        err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error
@@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) {
            return err
        }
        redemption.RedeemedTime = helper.GetTimestamp()
-       redemption.Status = common.RedemptionCodeStatusUsed
+       redemption.Status = RedemptionCodeStatusUsed
        err = tx.Save(redemption).Error
        return err
    })


@@ -11,6 +11,13 @@ import (
    "gorm.io/gorm"
)

+const (
+   TokenStatusEnabled   = 1 // don't use 0, 0 is the default value!
+   TokenStatusDisabled  = 2 // also don't use 0
+   TokenStatusExpired   = 3
+   TokenStatusExhausted = 4
+)
+
type Token struct {
    Id     int `json:"id"`
    UserId int `json:"user_id"`
@@ -62,17 +69,17 @@ func ValidateUserToken(key string) (token *Token, err error) {
        }
        return nil, errors.New("令牌验证失败")
    }
-   if token.Status == common.TokenStatusExhausted {
+   if token.Status == TokenStatusExhausted {
        return nil, fmt.Errorf("令牌 %s#%d额度已用尽", token.Name, token.Id)
-   } else if token.Status == common.TokenStatusExpired {
+   } else if token.Status == TokenStatusExpired {
        return nil, errors.New("该令牌已过期")
    }
-   if token.Status != common.TokenStatusEnabled {
+   if token.Status != TokenStatusEnabled {
        return nil, errors.New("该令牌状态不可用")
    }
    if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() {
        if !common.RedisEnabled {
-           token.Status = common.TokenStatusExpired
+           token.Status = TokenStatusExpired
            err := token.SelectUpdate()
            if err != nil {
                logger.SysError("failed to update token status" + err.Error())
@@ -83,7 +90,7 @@ func ValidateUserToken(key string) (token *Token, err error) {
    if !token.UnlimitedQuota && token.RemainQuota <= 0 {
        if !common.RedisEnabled {
            // in this case, we can make sure the token is exhausted
-           token.Status = common.TokenStatusExhausted
+           token.Status = TokenStatusExhausted
            err := token.SelectUpdate()
            if err != nil {
                logger.SysError("failed to update token status" + err.Error())


@@ -6,12 +6,25 @@ import (
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/blacklist"
    "github.com/songquanpeng/one-api/common/config"
-   "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
+   "github.com/songquanpeng/one-api/common/random"
    "gorm.io/gorm"
    "strings"
)

+const (
+   RoleGuestUser  = 0
+   RoleCommonUser = 1
+   RoleAdminUser  = 10
+   RoleRootUser   = 100
+)
+
+const (
+   UserStatusEnabled  = 1 // don't use 0, 0 is the default value!
+   UserStatusDisabled = 2 // also don't use 0
+   UserStatusDeleted  = 3
+)
+
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
// Otherwise, the sensitive information will be saved on local storage in plain text!
type User struct {
@@ -42,7 +55,7 @@ func GetMaxUserId() int {
}

func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
-   query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted)
+   query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)

    switch order {
    case "quota":
@@ -108,8 +121,8 @@ func (user *User) Insert(inviterId int) error {
        }
    }
    user.Quota = config.QuotaForNewUser
-   user.AccessToken = helper.GetUUID()
-   user.AffCode = helper.GetRandomString(4)
+   user.AccessToken = random.GetUUID()
+   user.AffCode = random.GetRandomString(4)
    result := DB.Create(user)
    if result.Error != nil {
        return result.Error
@@ -138,9 +151,9 @@ func (user *User) Update(updatePassword bool) error {
            return err
        }
    }
-   if user.Status == common.UserStatusDisabled {
+   if user.Status == UserStatusDisabled {
        blacklist.BanUser(user.Id)
-   } else if user.Status == common.UserStatusEnabled {
+   } else if user.Status == UserStatusEnabled {
        blacklist.UnbanUser(user.Id)
    }
    err = DB.Model(user).Updates(user).Error
@@ -152,8 +165,8 @@ func (user *User) Delete() error {
        return errors.New("id 为空!")
    }
    blacklist.BanUser(user.Id)
-   user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID())
-   user.Status = common.UserStatusDeleted
+   user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
+   user.Status = UserStatusDeleted
    err := DB.Model(user).Updates(user).Error
    return err
}
@@ -177,7 +190,7 @@ func (user *User) ValidateAndFill() (err error) {
        }
    }
    okay := common.ValidatePasswordAndHash(password, user.Password)
-   if !okay || user.Status != common.UserStatusEnabled {
+   if !okay || user.Status != UserStatusEnabled {
        return errors.New("用户名或密码错误,或用户已被封禁")
    }
    return nil
@@ -273,7 +286,7 @@ func IsAdmin(userId int) bool {
        logger.SysError("no such user " + err.Error())
        return false
    }
-   return user.Role >= common.RoleAdminUser
+   return user.Role >= RoleAdminUser
}

func IsUserEnabled(userId int) (bool, error) {
@@ -285,7 +298,7 @@ func IsUserEnabled(userId int) (bool, error) {
    if err != nil {
        return false, err
    }
-   return user.Status == common.UserStatusEnabled, nil
+   return user.Status == UserStatusEnabled, nil
}

func ValidateAccessToken(token string) (user *User) {
@@ -358,7 +371,7 @@ func decreaseUserQuota(id int, quota int64) (err error) {
}

func GetRootUserEmail() (email string) {
-   DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
+   DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
    return email
}


@@ -2,7 +2,6 @@ package monitor

import (
    "fmt"
-   "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/logger"
    "github.com/songquanpeng/one-api/common/message"
@@ -29,7 +28,7 @@ func notifyRootUser(subject string, content string) {

// DisableChannel disable & notify
func DisableChannel(channelId int, channelName string, reason string) {
-   model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+   model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
    logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason))
    subject := fmt.Sprintf("渠道「%s」#%d已被禁用", channelName, channelId)
    content := fmt.Sprintf("渠道「%s」#%d已被禁用原因%s", channelName, channelId, reason)
@@ -37,7 +36,7 @@ func DisableChannel(channelId int, channelName string, reason string) {
}

func MetricDisableChannel(channelId int, successRate float64) {
-   model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled)
+   model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled)
    logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100))
    subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId)
    content := fmt.Sprintf("该渠道(#%d在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。",
@@ -47,7 +46,7 @@ func MetricDisableChannel(channelId int, successRate float64) {

// EnableChannel enable & notify
func EnableChannel(channelId int, channelName string) {
-   model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled)
+   model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled)
    logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId))
    subject := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)
    content := fmt.Sprintf("渠道「%s」#%d已被启用", channelName, channelId)

monitor/manage.go (new file, 62 lines)

@@ -0,0 +1,62 @@
package monitor
import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/model"
"net/http"
"strings"
)
func ShouldDisableChannel(err *model.Error, statusCode int) bool {
if !config.AutomaticDisableChannelEnabled {
return false
}
if err == nil {
return false
}
if statusCode == http.StatusUnauthorized {
return true
}
switch err.Type {
case "insufficient_quota":
return true
// https://docs.anthropic.com/claude/reference/errors
case "authentication_error":
return true
case "permission_error":
return true
case "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
//if strings.Contains(err.Message, "quota") {
// return true
//}
if strings.Contains(err.Message, "credit") {
return true
}
if strings.Contains(err.Message, "balance") {
return true
}
return false
}
func ShouldEnableChannel(err error, openAIErr *model.Error) bool {
if !config.AutomaticEnableChannelEnabled {
return false
}
if err != nil {
return false
}
if openAIErr != nil {
return false
}
return true
}
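
The two helpers above are pure predicates; the relay layer decides what to do with the verdict. A minimal sketch, assuming the monitor and relay/model import paths shown in this changeset, of how a caller might combine ShouldDisableChannel with the notification helpers in monitor's channel code (the package name, maybeToggleChannel, relayErr, channelId and channelName are hypothetical, not part of these commits):

package example // illustrative only

import (
    "github.com/songquanpeng/one-api/monitor"
    "github.com/songquanpeng/one-api/relay/model"
)

// maybeToggleChannel disables a channel when the upstream error looks permanent
// (401, exhausted quota, deactivated account, ...), as judged by ShouldDisableChannel.
func maybeToggleChannel(relayErr *model.ErrorWithStatusCode, channelId int, channelName string) {
    if relayErr != nil && monitor.ShouldDisableChannel(&relayErr.Error, relayErr.StatusCode) {
        monitor.DisableChannel(channelId, channelName, relayErr.Error.Message)
    }
}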

relay/adaptor.go (new file, 45 lines)

@@ -0,0 +1,45 @@
package relay
import (
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/aiproxy"
"github.com/songquanpeng/one-api/relay/adaptor/ali"
"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
"github.com/songquanpeng/one-api/relay/adaptor/baidu"
"github.com/songquanpeng/one-api/relay/adaptor/gemini"
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/adaptor/palm"
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
"github.com/songquanpeng/one-api/relay/adaptor/zhipu"
"github.com/songquanpeng/one-api/relay/apitype"
)
func GetAdaptor(apiType int) adaptor.Adaptor {
switch apiType {
case apitype.AIProxyLibrary:
return &aiproxy.Adaptor{}
case apitype.Ali:
return &ali.Adaptor{}
case apitype.Anthropic:
return &anthropic.Adaptor{}
case apitype.Baidu:
return &baidu.Adaptor{}
case apitype.Gemini:
return &gemini.Adaptor{}
case apitype.OpenAI:
return &openai.Adaptor{}
case apitype.PaLM:
return &palm.Adaptor{}
case apitype.Tencent:
return &tencent.Adaptor{}
case apitype.Xunfei:
return &xunfei.Adaptor{}
case apitype.Zhipu:
return &zhipu.Adaptor{}
case apitype.Ollama:
return &ollama.Adaptor{}
}
return nil
}
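
GetAdaptor returns nil for an unrecognized API type, so every call site needs a guard before using the result. A minimal sketch assuming the packages listed above (the example package and pickAdaptor name are hypothetical):

package example // illustrative only

import (
    "fmt"

    "github.com/songquanpeng/one-api/relay"
    "github.com/songquanpeng/one-api/relay/adaptor"
)

// pickAdaptor wraps relay.GetAdaptor with the nil check every caller needs,
// e.g. pickAdaptor(apitype.OpenAI) would yield the openai adaptor.
func pickAdaptor(apiType int) (adaptor.Adaptor, error) {
    a := relay.GetAdaptor(apiType)
    if a == nil {
        return nil, fmt.Errorf("invalid api type: %d", apiType)
    }
    return a, nil
}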


@@ -4,10 +4,10 @@ import (
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
-   "github.com/songquanpeng/one-api/common"
-   "github.com/songquanpeng/one-api/relay/channel"
+   "github.com/songquanpeng/one-api/common/config"
+   "github.com/songquanpeng/one-api/relay/adaptor"
+   "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
-   "github.com/songquanpeng/one-api/relay/util"
    "io"
    "net/http"
)
@@ -15,16 +15,16 @@ import (

type Adaptor struct {
}

-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}

-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil
}

-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-   channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+   adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return nil
}
@@ -34,15 +34,22 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
        return nil, errors.New("request is nil")
    }
    aiProxyLibraryRequest := ConvertRequest(*request)
-   aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID)
+   aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID)
    return aiProxyLibraryRequest, nil
}

-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-   return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+   if request == nil {
+       return nil, errors.New("request is nil")
+   }
+   return request, nil
}

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+   return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {


@@ -1,6 +1,6 @@
package aiproxy

-import "github.com/songquanpeng/one-api/relay/channel/openai"
+import "github.com/songquanpeng/one-api/relay/adaptor/openai"

var ModelList = []string{""}


@@ -8,7 +8,8 @@ import (
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
-   "github.com/songquanpeng/one-api/relay/channel/openai"
+   "github.com/songquanpeng/one-api/common/random"
+   "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
@@ -53,7 +54,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon
        FinishReason: "stop",
    }
    fullTextResponse := openai.TextResponse{
-       Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+       Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
        Object:  "chat.completion",
        Created: helper.GetTimestamp(),
        Choices: []openai.TextResponseChoice{choice},
@@ -66,7 +67,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion
    choice.Delta.Content = aiProxyDocuments2Markdown(documents)
    choice.FinishReason = &constant.StopFinishReason
    return &openai.ChatCompletionsStreamResponse{
-       Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+       Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
        Object:  "chat.completion.chunk",
        Created: helper.GetTimestamp(),
        Model:   "",
@@ -78,7 +79,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena
    var choice openai.ChatCompletionsStreamResponseChoice
    choice.Delta.Content = response.Content
    return &openai.ChatCompletionsStreamResponse{
-       Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+       Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
        Object:  "chat.completion.chunk",
        Created: helper.GetTimestamp(),
        Model:   response.Model,


@@ -0,0 +1,105 @@
package ali
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
)
// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
type Adaptor struct {
}
func (a *Adaptor) Init(meta *meta.Meta) {
}
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
fullRequestURL := ""
switch meta.Mode {
case relaymode.Embeddings:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL)
case relaymode.ImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL)
default:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
adaptor.SetupCommonRequestHeader(c, req, meta)
if meta.IsStream {
req.Header.Set("Accept", "text/event-stream")
req.Header.Set("X-DashScope-SSE", "enable")
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
if meta.Mode == relaymode.ImagesGenerations {
req.Header.Set("X-DashScope-Async", "enable")
}
if c.GetString(config.KeyPlugin) != "" {
req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin))
}
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case relaymode.Embeddings:
aliEmbeddingRequest := ConvertEmbeddingRequest(*request)
return aliEmbeddingRequest, nil
default:
aliRequest := ConvertRequest(*request)
return aliRequest, nil
}
}
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
aliRequest := ConvertImageRequest(*request)
return aliRequest, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
return adaptor.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
switch meta.Mode {
case relaymode.Embeddings:
err, usage = EmbeddingHandler(c, resp)
case relaymode.ImagesGenerations:
err, usage = ImageHandler(c, resp)
default:
err, usage = Handler(c, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "ali"
}


@@ -3,4 +3,5 @@
var ModelList = []string{
    "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext",
    "text-embedding-v1",
+   "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1",
}

relay/adaptor/ali/image.go (new file, 192 lines)

@@ -0,0 +1,192 @@
package ali
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
"time"
)
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
apiKey := c.Request.Header.Get("Authorization")
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
responseFormat := c.GetString("response_format")
var aliTaskResponse TaskResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {
logger.SysError("aliAsyncTask err: " + string(responseBody))
return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil
}
aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey)
if err != nil {
return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, nil
}
func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) {
url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID)
var aliResponse TaskResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+key)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.SysError("aliAsyncTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response TaskResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
logger.SysError("aliAsyncTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) {
waitSeconds := 2
step := 0
maxStep := 20
var taskResponse TaskResponse
var responseBody []byte
for {
step++
rsp, err, body := asyncTask(taskID, key)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse {
imageResponse := openai.ImageResponse{
Created: helper.GetTimestamp(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
// Read the image data from data.Url and store it in b64Json
imageData, err := getImageData(data.Url)
if err != nil {
// Handle failure to fetch the image data
logger.SysError("getImageData Error getting image data: " + err.Error())
continue
}
// Convert the image data to a Base64-encoded string
b64Json = Base64Encode(imageData)
} else {
// If responseFormat is not "b64_json", use data.B64Image directly
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, openai.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func getImageData(url string) ([]byte, error) {
response, err := http.Get(url)
if err != nil {
return nil, err
}
defer response.Body.Close()
imageData, err := io.ReadAll(response.Body)
if err != nil {
return nil, err
}
return imageData, nil
}
func Base64Encode(data []byte) string {
b64Json := base64.StdEncoding.EncodeToString(data)
return b64Json
}
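
Two behaviors of this file are easy to miss when reading the listing. ImageHandler reads the desired output format from the gin context key "response_format" rather than from the upstream response, so whatever builds the relay request has to stash it there first; and asyncTaskWait polls the DashScope task endpoint every 2 seconds for at most 20 attempts, i.e. roughly 40 seconds before returning a timeout. A one-line sketch of the expected setup (imageRequest is a hypothetical variable name, not part of this changeset):

// Hypothetical caller: stash the requested format so ImageHandler can read it later
// via c.GetString("response_format").
c.Set("response_format", imageRequest.ResponseFormat)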


@@ -7,7 +7,7 @@ import (
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/logger"
-   "github.com/songquanpeng/one-api/relay/channel/openai"
+   "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
    "net/http"
@@ -66,6 +66,17 @@ func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingReque
    }
}

+func ConvertImageRequest(request model.ImageRequest) *ImageRequest {
+   var imageRequest ImageRequest
+   imageRequest.Input.Prompt = request.Prompt
+   imageRequest.Model = request.Model
+   imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
+   imageRequest.Parameters.N = request.N
+   imageRequest.ResponseFormat = request.ResponseFormat
+   return &imageRequest
+}
+
func EmbeddingHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
    var aliResponse EmbeddingResponse
    err := json.NewDecoder(resp.Body).Decode(&aliResponse)

relay/adaptor/ali/model.go (new file, 154 lines)

@@ -0,0 +1,154 @@
package ali
import (
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/model"
)
type Message struct {
Content string `json:"content"`
Role string `json:"role"`
}
type Input struct {
//Prompt string `json:"prompt"`
Messages []Message `json:"messages"`
}
type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
type ChatRequest struct {
Model string `json:"model"`
Input Input `json:"input"`
Parameters Parameters `json:"parameters,omitempty"`
}
type ImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type TaskResponse struct {
StatusCode int `json:"status_code,omitempty"`
RequestId string `json:"request_id,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Output struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
Results []struct {
B64Image string `json:"b64_image,omitempty"`
Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
} `json:"results,omitempty"`
TaskMetrics struct {
Total int `json:"TOTAL,omitempty"`
Succeeded int `json:"SUCCEEDED,omitempty"`
Failed int `json:"FAILED,omitempty"`
} `json:"task_metrics,omitempty"`
} `json:"output,omitempty"`
Usage Usage `json:"usage"`
}
type Header struct {
Action string `json:"action,omitempty"`
Streaming string `json:"streaming,omitempty"`
TaskID string `json:"task_id,omitempty"`
Event string `json:"event,omitempty"`
ErrorCode string `json:"error_code,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
Attributes any `json:"attributes,omitempty"`
}
type Payload struct {
Model string `json:"model,omitempty"`
Task string `json:"task,omitempty"`
TaskGroup string `json:"task_group,omitempty"`
Function string `json:"function,omitempty"`
Parameters struct {
SampleRate int `json:"sample_rate,omitempty"`
Rate float64 `json:"rate,omitempty"`
Format string `json:"format,omitempty"`
} `json:"parameters,omitempty"`
Input struct {
Text string `json:"text,omitempty"`
} `json:"input,omitempty"`
Usage struct {
Characters int `json:"characters,omitempty"`
} `json:"usage,omitempty"`
}
type WSSMessage struct {
Header Header `json:"header,omitempty"`
Payload Payload `json:"payload,omitempty"`
}
type EmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type Embedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type EmbeddingResponse struct {
Output struct {
Embeddings []Embedding `json:"embeddings"`
} `json:"output"`
Usage Usage `json:"usage"`
Error
}
type Error struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type Output struct {
//Text string `json:"text"`
//FinishReason string `json:"finish_reason"`
Choices []openai.TextResponseChoice `json:"choices"`
}
type ChatResponse struct {
Output Output `json:"output"`
Usage Usage `json:"usage"`
Error
}


@@ -4,9 +4,9 @@ import (
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
-   "github.com/songquanpeng/one-api/relay/channel"
+   "github.com/songquanpeng/one-api/relay/adaptor"
+   "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
-   "github.com/songquanpeng/one-api/relay/util"
    "io"
    "net/http"
)
@@ -14,16 +14,16 @@ import (

type Adaptor struct {
}

-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}

-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil
}

-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-   channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+   adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("x-api-key", meta.APIKey)
    anthropicVersion := c.Request.Header.Get("anthropic-version")
    if anthropicVersion == "" {
@@ -41,11 +41,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
    return ConvertRequest(*request), nil
}

-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-   return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+   if request == nil {
+       return nil, errors.New("request is nil")
+   }
+   return request, nil
}

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+   return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {


@@ -9,7 +9,7 @@ import (
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/image"
    "github.com/songquanpeng/one-api/common/logger"
-   "github.com/songquanpeng/one-api/relay/channel/openai"
+   "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
    "net/http"


@@ -0,0 +1,15 @@
package azure
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
)
func GetAPIVersion(c *gin.Context) string {
query := c.Request.URL.Query()
apiVersion := query.Get("api-version")
if apiVersion == "" {
apiVersion = c.GetString(config.KeyAPIVersion)
}
return apiVersion
}
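
The query parameter wins; the per-channel value stored under config.KeyAPIVersion is only a fallback. Illustratively (the version string and path are examples, not something these commits fix):

// GET /v1/chat/completions?api-version=2024-02-01  -> "2024-02-01" (query wins)
// GET /v1/chat/completions                         -> value of c.GetString(config.KeyAPIVersion)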


@@ -3,25 +3,25 @@ package baidu
import (
    "errors"
    "fmt"
+   "github.com/songquanpeng/one-api/relay/meta"
+   "github.com/songquanpeng/one-api/relay/relaymode"
    "io"
    "net/http"
    "strings"

    "github.com/gin-gonic/gin"
-   "github.com/songquanpeng/one-api/relay/channel"
-   "github.com/songquanpeng/one-api/relay/constant"
+   "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/model"
-   "github.com/songquanpeng/one-api/relay/util"
)

type Adaptor struct {
}

-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}

-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
    suffix := "chat/"
    if strings.HasPrefix(meta.ActualModelName, "Embedding") {
@@ -44,17 +44,25 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
        suffix += "eb-instant"
    case "ERNIE-Speed":
        suffix += "ernie_speed"
-   case "ERNIE-Bot-8K":
-       suffix += "ernie_bot_8k"
    case "ERNIE-4.0-8K":
        suffix += "completions_pro"
    case "ERNIE-3.5-8K":
        suffix += "completions"
+   case "ERNIE-3.5-8K-0205":
+       suffix += "ernie-3.5-8k-0205"
+   case "ERNIE-3.5-8K-1222":
+       suffix += "ernie-3.5-8k-1222"
+   case "ERNIE-Bot-8K":
+       suffix += "ernie_bot_8k"
+   case "ERNIE-3.5-4K-0205":
+       suffix += "ernie-3.5-4k-0205"
    case "ERNIE-Speed-8K":
        suffix += "ernie_speed"
    case "ERNIE-Speed-128K":
        suffix += "ernie-speed-128k"
-   case "ERNIE-Lite-8K":
+   case "ERNIE-Lite-8K-0922":
+       suffix += "eb-instant"
+   case "ERNIE-Lite-8K-0308":
        suffix += "ernie-lite-8k"
    case "ERNIE-Tiny-8K":
        suffix += "ernie-tiny-8k"
@@ -81,8 +89,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
    return fullRequestURL, nil
}

-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-   channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+   adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return nil
}
@@ -92,7 +100,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
        return nil, errors.New("request is nil")
    }
    switch relayMode {
-   case constant.RelayModeEmbeddings:
+   case relaymode.Embeddings:
        baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
        return baiduEmbeddingRequest, nil
    default:
@@ -101,16 +109,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
    }
}

-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-   return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+   if request == nil {
+       return nil, errors.New("request is nil")
+   }
+   return request, nil
}

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+   return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        switch meta.Mode {
-       case constant.RelayModeEmbeddings:
+       case relaymode.Embeddings:
            err, usage = EmbeddingHandler(c, resp)
        default:
            err, usage = Handler(c, resp)


@@ -2,15 +2,15 @@ package baidu
var ModelList = []string{
    "ERNIE-4.0-8K",
-   "ERNIE-Bot-8K-0922",
    "ERNIE-3.5-8K",
-   "ERNIE-Lite-8K-0922",
-   "ERNIE-Speed-8K",
-   "ERNIE-3.5-4K-0205",
    "ERNIE-3.5-8K-0205",
    "ERNIE-3.5-8K-1222",
-   "ERNIE-Lite-8K",
+   "ERNIE-Bot-8K",
+   "ERNIE-3.5-4K-0205",
+   "ERNIE-Speed-8K",
    "ERNIE-Speed-128K",
+   "ERNIE-Lite-8K-0922",
+   "ERNIE-Lite-8K-0308",
    "ERNIE-Tiny-8K",
    "BLOOMZ-7B",
    "Embedding-V1",


@@ -8,10 +8,10 @@ import (
    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/common"
    "github.com/songquanpeng/one-api/common/logger"
-   "github.com/songquanpeng/one-api/relay/channel/openai"
+   "github.com/songquanpeng/one-api/relay/adaptor/openai"
+   "github.com/songquanpeng/one-api/relay/client"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
-   "github.com/songquanpeng/one-api/relay/util"
    "io"
    "net/http"
    "strings"
@@ -305,7 +305,7 @@ func getBaiduAccessTokenHelper(apiKey string) (*AccessToken, error) {
    }
    req.Header.Add("Content-Type", "application/json")
    req.Header.Add("Accept", "application/json")
-   res, err := util.ImpatientHTTPClient.Do(req)
+   res, err := client.ImpatientHTTPClient.Do(req)
    if err != nil {
        return nil, err
    }


@@ -1,15 +1,16 @@
-package channel
+package adaptor

import (
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
-   "github.com/songquanpeng/one-api/relay/util"
+   "github.com/songquanpeng/one-api/relay/client"
+   "github.com/songquanpeng/one-api/relay/meta"
    "io"
    "net/http"
)

-func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) {
+func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) {
    req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
    req.Header.Set("Accept", c.Request.Header.Get("Accept"))
    if meta.IsStream && c.Request.Header.Get("Accept") == "" {
@@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela
    }
}

-func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    fullRequestURL, err := a.GetRequestURL(meta)
    if err != nil {
        return nil, fmt.Errorf("get request url failed: %w", err)
@@ -38,7 +39,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod
}

func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) {
-   resp, err := util.HTTPClient.Do(req)
+   resp, err := client.HTTPClient.Do(req)
    if err != nil {
        return nil, err
    }


@@ -4,11 +4,12 @@ import (
    "errors"
    "fmt"
    "github.com/gin-gonic/gin"
+   "github.com/songquanpeng/one-api/common/config"
    "github.com/songquanpeng/one-api/common/helper"
-   channelhelper "github.com/songquanpeng/one-api/relay/channel"
-   "github.com/songquanpeng/one-api/relay/channel/openai"
+   channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
+   "github.com/songquanpeng/one-api/relay/adaptor/openai"
+   "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
-   "github.com/songquanpeng/one-api/relay/util"
    "io"
    "net/http"
)
@@ -16,12 +17,12 @@ import (

type Adaptor struct {
}

-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}

-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
-   version := helper.AssignOrDefault(meta.APIVersion, "v1")
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+   version := helper.AssignOrDefault(meta.APIVersion, config.GeminiVersion)
    action := "generateContent"
    if meta.IsStream {
        action = "streamGenerateContent"
@@ -29,7 +30,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
    return fmt.Sprintf("%s/%s/models/%s:%s", meta.BaseURL, version, meta.ActualModelName, action), nil
}

-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    channelhelper.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("x-goog-api-key", meta.APIKey)
    return nil
@@ -42,11 +43,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
    return ConvertRequest(*request), nil
}

-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+   if request == nil {
+       return nil, errors.New("request is nil")
+   }
+   return request, nil
+}
+
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return channelhelper.DoRequestHelper(a, c, meta, requestBody)
}

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        var responseText string
        err, responseText = StreamHandler(c, resp)


@@ -9,7 +9,8 @@ import (
    "github.com/songquanpeng/one-api/common/helper"
    "github.com/songquanpeng/one-api/common/image"
    "github.com/songquanpeng/one-api/common/logger"
-   "github.com/songquanpeng/one-api/relay/channel/openai"
+   "github.com/songquanpeng/one-api/common/random"
+   "github.com/songquanpeng/one-api/relay/adaptor/openai"
    "github.com/songquanpeng/one-api/relay/constant"
    "github.com/songquanpeng/one-api/relay/model"
    "io"
@@ -155,7 +156,7 @@ type ChatPromptFeedback struct {

func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse {
    fullTextResponse := openai.TextResponse{
-       Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+       Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
        Object:  "chat.completion",
        Created: helper.GetTimestamp(),
        Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)),
@@ -233,7 +234,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC
            var choice openai.ChatCompletionsStreamResponseChoice
            choice.Delta.Content = dummy.Content
            response := openai.ChatCompletionsStreamResponse{
-               Id:      fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+               Id:      fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
                Object:  "chat.completion.chunk",
                Created: helper.GetTimestamp(),
                Model:   "gemini-pro",


@@ -0,0 +1,21 @@
package adaptor
import (
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
type Adaptor interface {
Init(meta *meta.Meta)
GetRequestURL(meta *meta.Meta) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error
ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error)
ConvertImageRequest(request *model.ImageRequest) (any, error)
DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode)
GetModelList() []string
GetChannelName() string
}
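
Every channel implementation under relay/adaptor satisfies this one interface. What follows is a minimal skeleton of a hypothetical adaptor, sketched under the assumption that the shared helpers shown elsewhere in this changeset (SetupCommonRequestHeader, DoRequestHelper) are used; the package name, URL path, model name and channel name are made up for illustration:

package example

import (
    "errors"
    "fmt"
    "io"
    "net/http"

    "github.com/gin-gonic/gin"
    "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/meta"
    "github.com/songquanpeng/one-api/relay/model"
)

type Adaptor struct{}

func (a *Adaptor) Init(meta *meta.Meta) {}

func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    return fmt.Sprintf("%s/v1/chat/completions", meta.BaseURL), nil
}

func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
    adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return nil
}

func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return request, nil // pass the OpenAI-shaped request through unchanged
}

func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
    if request == nil {
        return nil, errors.New("request is nil")
    }
    return request, nil
}

func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
    return adaptor.DoRequestHelper(a, c, meta, requestBody)
}

func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    return nil, nil // a real adaptor parses resp here and reports token usage
}

func (a *Adaptor) GetModelList() []string { return []string{"example-model"} }

func (a *Adaptor) GetChannelName() string { return "example" }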


@@ -0,0 +1,14 @@
package minimax
import (
"fmt"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/relaymode"
)
func GetRequestURL(meta *meta.Meta) (string, error) {
if meta.Mode == relaymode.ChatCompletions {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil
}
return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode)
}
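
Only the ChatCompletions relay mode is wired up for MiniMax; any other mode is rejected. For example, assuming a base URL of https://api.minimax.chat (an assumption for illustration, not something fixed by this commit):

// meta.Mode == relaymode.ChatCompletions -> "https://api.minimax.chat/v1/text/chatcompletion_v2"
// any other meta.Mode                    -> error "unsupported relay mode <n> for minimax"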


@@ -3,34 +3,34 @@ package ollama
import (
    "errors"
    "fmt"
+   "github.com/songquanpeng/one-api/relay/meta"
+   "github.com/songquanpeng/one-api/relay/relaymode"
    "io"
    "net/http"

    "github.com/gin-gonic/gin"
-   "github.com/songquanpeng/one-api/relay/channel"
-   "github.com/songquanpeng/one-api/relay/constant"
+   "github.com/songquanpeng/one-api/relay/adaptor"
    "github.com/songquanpeng/one-api/relay/model"
-   "github.com/songquanpeng/one-api/relay/util"
)

type Adaptor struct {
}

-func (a *Adaptor) Init(meta *util.RelayMeta) {
+func (a *Adaptor) Init(meta *meta.Meta) {
}

-func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
    // https://github.com/ollama/ollama/blob/main/docs/api.md
    fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
-   if meta.Mode == constant.RelayModeEmbeddings {
+   if meta.Mode == relaymode.Embeddings {
        fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL)
    }
    return fullRequestURL, nil
}

-func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
-   channel.SetupCommonRequestHeader(c, req, meta)
+func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
+   adaptor.SetupCommonRequestHeader(c, req, meta)
    req.Header.Set("Authorization", "Bearer "+meta.APIKey)
    return nil
}
@@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
        return nil, errors.New("request is nil")
    }
    switch relayMode {
-   case constant.RelayModeEmbeddings:
+   case relaymode.Embeddings:
        ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request)
        return ollamaEmbeddingRequest, nil
    default:
@@ -48,16 +48,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
    }
}

-func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
-   return channel.DoRequestHelper(a, c, meta, requestBody)
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+   if request == nil {
+       return nil, errors.New("request is nil")
+   }
+   return request, nil
}

-func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+   return adaptor.DoRequestHelper(a, c, meta, requestBody)
+}
+
+func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
    if meta.IsStream {
        err, usage = StreamHandler(c, resp)
    } else {
        switch meta.Mode {
-       case constant.RelayModeEmbeddings:
+       case relaymode.Embeddings:
            err, usage = EmbeddingHandler(c, resp)
        default:
            err, usage = Handler(c, resp)


@@ -5,15 +5,16 @@ import (
"context"
"encoding/json"
"fmt"
+ "github.com/songquanpeng/one-api/common/helper"
+ "github.com/songquanpeng/one-api/common/random"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
- "github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
)
@@ -51,7 +52,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
choice.FinishReason = "stop"
}
fullTextResponse := openai.TextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -72,7 +73,7 @@ func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompl
choice.FinishReason = &constant.StopFinishReason
}
response := openai.ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: ollamaResponse.Model,


@@ -4,11 +4,12 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/relay/adaptor"
- "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/adaptor/minimax"
- "github.com/songquanpeng/one-api/relay/channel/minimax"
+ "github.com/songquanpeng/one-api/relay/channeltype"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
+ "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@@ -18,13 +19,20 @@ type Adaptor struct {
ChannelType int
}
- func (a *Adaptor) Init(meta *util.RelayMeta) {
+ func (a *Adaptor) Init(meta *meta.Meta) {
a.ChannelType = meta.ChannelType
}
- func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
switch meta.ChannelType {
- case common.ChannelTypeAzure:
+ case channeltype.Azure:
+ if meta.Mode == relaymode.ImagesGenerations {
+ // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api
+ // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
+ fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion)
+ return fullRequestURL, nil
+ }
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(meta.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion)
@@ -34,22 +42,22 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
//https://github.com/songquanpeng/one-api/issues/1191
// {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version}
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
- return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
+ return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil
- case common.ChannelTypeMinimax:
+ case channeltype.Minimax:
return minimax.GetRequestURL(meta)
default:
- return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
+ return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil
}
}
- func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+ adaptor.SetupCommonRequestHeader(c, req, meta)
- if meta.ChannelType == common.ChannelTypeAzure {
+ if meta.ChannelType == channeltype.Azure {
req.Header.Set("api-key", meta.APIKey)
return nil
}
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
- if meta.ChannelType == common.ChannelTypeOpenRouter {
+ if meta.ChannelType == channeltype.OpenRouter {
req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api")
req.Header.Set("X-Title", "One API")
}
@@ -63,11 +71,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return request, nil
}
- func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
- func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+ }
+ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText, usage = StreamHandler(c, resp, meta.Mode)
@@ -75,7 +90,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens)
}
} else {
- err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ switch meta.Mode {
+ case relaymode.ImagesGenerations:
+ err, _ = ImageHandler(c, resp)
+ default:
+ err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName)
+ }
}
return
}

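For reference, the Azure images branch added above produces URLs of the documented shape; a rough sketch with made-up resource and deployment names (not taken from the repository):

package openai

import "fmt"

// azureImageURLExample mirrors the fmt.Sprintf in GetRequestURL's new
// ImagesGenerations branch; every value below is a placeholder.
func azureImageURLExample() string {
	baseURL := "https://my-resource.openai.azure.com" // meta.BaseURL
	deployment := "dall-e-3"                          // meta.ActualModelName, i.e. the Azure deployment
	apiVersion := "2024-03-01-preview"                // meta.APIVersion
	return fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", baseURL, deployment, apiVersion)
	// => https://my-resource.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview
}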

@@ -0,0 +1,50 @@
package openai
import (
"github.com/songquanpeng/one-api/relay/adaptor/ai360"
"github.com/songquanpeng/one-api/relay/adaptor/baichuan"
"github.com/songquanpeng/one-api/relay/adaptor/groq"
"github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu"
"github.com/songquanpeng/one-api/relay/adaptor/minimax"
"github.com/songquanpeng/one-api/relay/adaptor/mistral"
"github.com/songquanpeng/one-api/relay/adaptor/moonshot"
"github.com/songquanpeng/one-api/relay/adaptor/stepfun"
"github.com/songquanpeng/one-api/relay/channeltype"
)
var CompatibleChannels = []int{
channeltype.Azure,
channeltype.AI360,
channeltype.Moonshot,
channeltype.Baichuan,
channeltype.Minimax,
channeltype.Mistral,
channeltype.Groq,
channeltype.LingYiWanWu,
channeltype.StepFun,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
switch channelType {
case channeltype.Azure:
return "azure", ModelList
case channeltype.AI360:
return "360", ai360.ModelList
case channeltype.Moonshot:
return "moonshot", moonshot.ModelList
case channeltype.Baichuan:
return "baichuan", baichuan.ModelList
case channeltype.Minimax:
return "minimax", minimax.ModelList
case channeltype.Mistral:
return "mistralai", mistral.ModelList
case channeltype.Groq:
return "groq", groq.ModelList
case channeltype.LingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
case channeltype.StepFun:
return "stepfun", stepfun.ModelList
default:
return "openai", ModelList
}
}
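A small usage sketch for the helper above; the loop is illustrative and not taken from the project's controllers:

package openai

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/channeltype"
)

// listCompatibleChannelMeta prints the sub-channel name and model count for
// every OpenAI-compatible channel type, plus the plain-OpenAI default branch.
func listCompatibleChannelMeta() {
	for _, ct := range CompatibleChannels {
		name, models := GetCompatibleChannelMeta(ct)
		fmt.Printf("channel type %d -> %s (%d models)\n", ct, name, len(models))
	}
	name, models := GetCompatibleChannelMeta(channeltype.OpenAI)
	fmt.Printf("default -> %s (%d models)\n", name, len(models))
}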


@@ -6,7 +6,7 @@ var ModelList = []string{
"gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0314", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
- "gpt-4-turbo-preview",
+ "gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001", "text-davinci-002", "text-davinci-003",


@@ -0,0 +1,30 @@
package openai
import (
"fmt"
"github.com/songquanpeng/one-api/relay/channeltype"
"github.com/songquanpeng/one-api/relay/model"
"strings"
)
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage {
usage := &model.Usage{}
usage.PromptTokens = promptTokens
usage.CompletionTokens = CountTokenText(responseText, modeName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return usage
}
func GetFullRequestURL(baseURL string, requestURL string, channelType int) string {
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
switch channelType {
case channeltype.OpenAI:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1"))
case channeltype.Azure:
fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments"))
}
}
return fullRequestURL
}
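To make the Cloudflare AI Gateway rewrite above concrete, a testable example with placeholder account and gateway segments in the base URL:

package openai

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/channeltype"
)

func ExampleGetFullRequestURL() {
	// Ordinary base URL: the request path is appended unchanged.
	fmt.Println(GetFullRequestURL("https://api.openai.com", "/v1/chat/completions", channeltype.OpenAI))
	// Cloudflare AI Gateway base URL (ACCOUNT/GATEWAY are placeholders): the
	// leading "/v1" is trimmed because the gateway path already ends in /openai.
	fmt.Println(GetFullRequestURL("https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai", "/v1/chat/completions", channeltype.OpenAI))
	// Output:
	// https://api.openai.com/v1/chat/completions
	// https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai/chat/completions
}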


@@ -0,0 +1,44 @@
package openai
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)
func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var imageResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &imageResponse)
if err != nil {
return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil
}
err = resp.Body.Close()
if err != nil {
return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, nil
}


@@ -8,8 +8,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
+ "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
"strings"
@@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
data = data[6:]
if !strings.HasPrefix(data, "[DONE]") {
switch relayMode {
- case constant.RelayModeChatCompletions:
+ case relaymode.ChatCompletions:
var streamResponse ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {
@@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E
if streamResponse.Usage != nil {
usage = streamResponse.Usage
}
- case constant.RelayModeCompletions:
+ case relaymode.Completions:
var streamResponse CompletionsStreamResponse
err := json.Unmarshal([]byte(data), &streamResponse)
if err != nil {


@@ -110,11 +110,16 @@ type EmbeddingResponse struct {
model.Usage `json:"usage"`
}
+ type ImageData struct {
+ Url string `json:"url,omitempty"`
+ B64Json string `json:"b64_json,omitempty"`
+ RevisedPrompt string `json:"revised_prompt,omitempty"`
+ }
type ImageResponse struct {
- Created int `json:"created"`
+ Created int64 `json:"created"`
- Data []struct {
+ Data []ImageData `json:"data"`
- Url string `json:"url"`
+ //model.Usage `json:"usage"`
- }
}
type ChatCompletionsStreamResponseChoice struct {

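For reference, a fabricated images/generations payload that the widened ImageResponse/ImageData types above can now decode (the URL, base64 data, and revised prompt are invented for illustration):

package openai

import (
	"encoding/json"
	"fmt"
)

// decodeImageResponseExample unmarshals a made-up response body into the new
// structs; all values are placeholders.
func decodeImageResponseExample() error {
	body := []byte(`{
	  "created": 1712345678,
	  "data": [
	    {"url": "https://example.com/image-1.png", "revised_prompt": "a watercolor fox"},
	    {"b64_json": "aGVsbG8="}
	  ]
	}`)
	var resp ImageResponse
	if err := json.Unmarshal(body, &resp); err != nil {
		return err
	}
	fmt.Println(resp.Created, len(resp.Data), resp.Data[0].Url)
	return nil
}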

@@ -4,10 +4,10 @@ import (
"errors"
"fmt"
"github.com/pkoukk/tiktoken-go"
- "github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/image"
"github.com/songquanpeng/one-api/common/logger"
+ billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/model"
"math"
"strings"
@@ -28,7 +28,7 @@ func InitTokenEncoders() {
if err != nil {
logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
- for model := range common.ModelRatio {
+ for model := range billingratio.ModelRatio {
if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") {


@@ -4,10 +4,10 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/adaptor"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
@@ -15,16 +15,16 @@ import (
type Adaptor struct {
}
- func (a *Adaptor) Init(meta *util.RelayMeta) {
+ func (a *Adaptor) Init(meta *meta.Meta) {
}
- func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil
}
- func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+ adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("x-goog-api-key", meta.APIKey)
return nil
}
@@ -36,11 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return ConvertRequest(*request), nil
}
- func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
- func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+ }
+ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)


@@ -7,7 +7,8 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
@@ -74,7 +75,7 @@ func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletio
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) {
responseText := ""
- responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID())
+ responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID())
createdTime := helper.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)


@@ -4,10 +4,10 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/adaptor"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
@@ -19,16 +19,16 @@ type Adaptor struct {
Sign string
}
- func (a *Adaptor) Init(meta *util.RelayMeta) {
+ func (a *Adaptor) Init(meta *meta.Meta) {
}
- func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return fmt.Sprintf("%s/hyllm/v1/chat/completions", meta.BaseURL), nil
}
- func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+ adaptor.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", a.Sign)
req.Header.Set("X-TC-Action", meta.ActualModelName)
return nil
@@ -52,11 +52,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return tencentRequest, nil
}
- func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
}
- func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+ }
+ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
var responseText string
err, responseText = StreamHandler(c, resp)


@@ -13,7 +13,8 @@ import (
"github.com/songquanpeng/one-api/common/conv"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
@@ -41,7 +42,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
return &ChatRequest{
Timestamp: helper.GetTimestamp(),
Expired: helper.GetTimestamp() + 24*60*60,
- QueryID: helper.GetUUID(),
+ QueryID: random.GetUUID(),
Temperature: request.Temperature,
TopP: request.TopP,
Stream: stream,
@@ -71,7 +72,7 @@ func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse {
func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
response := openai.ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "tencent-hunyuan",


@@ -3,10 +3,10 @@ package xunfei
import (
"errors"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/adaptor"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
"strings"
@@ -16,16 +16,16 @@ type Adaptor struct {
request *model.GeneralOpenAIRequest
}
- func (a *Adaptor) Init(meta *util.RelayMeta) {
+ func (a *Adaptor) Init(meta *meta.Meta) {
}
- func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
return "", nil
}
- func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+ adaptor.SetupCommonRequestHeader(c, req, meta)
// check DoResponse for auth part
return nil
}
@@ -38,14 +38,21 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, nil
}
- func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ return request, nil
+ }
+ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
// xunfei's request is not http request, so we don't need to do anything here
dummyResp := &http.Response{}
dummyResp.StatusCode = http.StatusOK
return dummyResp, nil
}
- func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
splits := strings.Split(meta.APIKey, "|")
if len(splits) != 3 {
return nil, openai.ErrorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest)


@@ -9,9 +9,11 @@ import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/songquanpeng/one-api/common"
+ "github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/common/random"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
@@ -62,7 +64,7 @@ func getToolCalls(response *ChatResponse) []model.Tool {
return toolCalls
}
toolCall := model.Tool{
- Id: fmt.Sprintf("call_%s", helper.GetUUID()),
+ Id: fmt.Sprintf("call_%s", random.GetUUID()),
Type: "function",
Function: *item.FunctionCall,
}
@@ -88,7 +90,7 @@ func responseXunfei2OpenAI(response *ChatResponse) *openai.TextResponse {
FinishReason: constant.StopFinishReason,
}
fullTextResponse := openai.TextResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
@@ -112,7 +114,7 @@ func streamResponseXunfei2OpenAI(xunfeiResponse *ChatResponse) *openai.ChatCompl
choice.FinishReason = &constant.StopFinishReason
}
response := openai.ChatCompletionsStreamResponse{
- Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
+ Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: "SparkDesk",
@@ -278,7 +280,7 @@ func getAPIVersion(c *gin.Context, modelName string) string {
return apiVersion
}
- apiVersion = c.GetString(common.ConfigKeyAPIVersion)
+ apiVersion = c.GetString(config.KeyAPIVersion)
if apiVersion != "" {
return apiVersion
}


@@ -4,11 +4,11 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
- "github.com/songquanpeng/one-api/relay/channel"
+ "github.com/songquanpeng/one-api/relay/adaptor"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
- "github.com/songquanpeng/one-api/relay/constant"
+ "github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
- "github.com/songquanpeng/one-api/relay/util"
+ "github.com/songquanpeng/one-api/relay/relaymode"
"io"
"math"
"net/http"
@@ -19,7 +19,7 @@ type Adaptor struct {
APIVersion string
}
- func (a *Adaptor) Init(meta *util.RelayMeta) {
+ func (a *Adaptor) Init(meta *meta.Meta) {
}
@@ -31,14 +31,17 @@ func (a *Adaptor) SetVersionByModeName(modelName string) {
}
}
- func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
+ func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
+ switch meta.Mode {
+ case relaymode.ImagesGenerations:
+ return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil
+ case relaymode.Embeddings:
+ return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
+ }
a.SetVersionByModeName(meta.ActualModelName)
if a.APIVersion == "v4" {
return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil
}
- if meta.Mode == constant.RelayModeEmbeddings {
- return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil
- }
method := "invoke"
if meta.IsStream {
method = "sse-invoke"
@@ -46,8 +49,8 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil
}
- func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
+ func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
- channel.SetupCommonRequestHeader(c, req, meta)
+ adaptor.SetupCommonRequestHeader(c, req, meta)
token := GetToken(meta.APIKey)
req.Header.Set("Authorization", token)
return nil
@@ -58,7 +61,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
return nil, errors.New("request is nil")
}
switch relayMode {
- case constant.RelayModeEmbeddings:
+ case relaymode.Embeddings:
baiduEmbeddingRequest := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, nil
default:
@@ -77,11 +80,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
}
}
- func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
+ func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
- return channel.DoRequestHelper(a, c, meta, requestBody)
+ if request == nil {
+ return nil, errors.New("request is nil")
+ }
+ newRequest := ImageRequest{
+ Model: request.Model,
+ Prompt: request.Prompt,
+ UserId: request.User,
+ }
+ return newRequest, nil
}
- func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
+ return adaptor.DoRequestHelper(a, c, meta, requestBody)
+ }
+ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, _, usage = openai.StreamHandler(c, resp, meta.Mode)
} else {
@@ -90,15 +105,22 @@ func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *util.R
return
}
- func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
+ switch meta.Mode {
+ case relaymode.Embeddings:
+ err, usage = EmbeddingsHandler(c, resp)
+ return
+ case relaymode.ImagesGenerations:
+ err, usage = openai.ImageHandler(c, resp)
+ return
+ }
if a.APIVersion == "v4" {
return a.DoResponseV4(c, resp, meta)
}
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
- if meta.Mode == constant.RelayModeEmbeddings {
+ if meta.Mode == relaymode.Embeddings {
err, usage = EmbeddingsHandler(c, resp)
} else {
err, usage = Handler(c, resp)


@@ -3,4 +3,5 @@ package zhipu
var ModelList = []string{
"chatglm_turbo", "chatglm_pro", "chatglm_std", "chatglm_lite",
"glm-4", "glm-4v", "glm-3-turbo", "embedding-2",
+ "cogview-3",
}


@@ -8,7 +8,7 @@ import (
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
- "github.com/songquanpeng/one-api/relay/channel/openai"
+ "github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
@@ -256,7 +256,7 @@ func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *
}
func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
- var zhipuResponse EmbeddingRespone
+ var zhipuResponse EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
@@ -280,7 +280,7 @@ func EmbeddingsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithSta
return nil, &fullTextResponse.Usage
}
- func embeddingResponseZhipu2OpenAI(response *EmbeddingRespone) *openai.EmbeddingResponse {
+ func embeddingResponseZhipu2OpenAI(response *EmbeddingResponse) *openai.EmbeddingResponse {
openAIEmbeddingResponse := openai.EmbeddingResponse{
Object: "list",
Data: make([]openai.EmbeddingResponseItem, 0, len(response.Embeddings)),


@@ -50,7 +50,7 @@ type EmbeddingRequest struct {
Input string `json:"input"`
}
- type EmbeddingRespone struct {
+ type EmbeddingResponse struct {
Model string `json:"model"`
Object string `json:"object"`
Embeddings []EmbeddingData `json:"data"`
@@ -62,3 +62,9 @@ type EmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
}
+ type ImageRequest struct {
+ Model string `json:"model"`
+ Prompt string `json:"prompt"`
+ UserId string `json:"user_id,omitempty"`
+ }

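Putting the new ImageRequest struct together with the ConvertImageRequest change in this package's adaptor, an OpenAI-style image request is reshaped roughly as follows (all values invented):

package zhipu

import (
	"encoding/json"
	"fmt"

	"github.com/songquanpeng/one-api/relay/model"
)

// convertImageRequestExample mirrors what Adaptor.ConvertImageRequest does in
// the diff above: the OpenAI "user" field becomes zhipu's "user_id".
func convertImageRequestExample() {
	in := model.ImageRequest{Model: "cogview-3", Prompt: "a lighthouse at dusk", User: "user-123"}
	out := ImageRequest{Model: in.Model, Prompt: in.Prompt, UserId: in.User}
	b, _ := json.Marshal(out)
	fmt.Println(string(b)) // {"model":"cogview-3","prompt":"a lighthouse at dusk","user_id":"user-123"}
}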
relay/apitype/define.go (new file)

@@ -0,0 +1,17 @@
package apitype
const (
OpenAI = iota
Anthropic
PaLM
Baidu
Zhipu
Ali
Xunfei
AIProxyLibrary
Tencent
Gemini
Ollama
Dummy // this one is only for count, do not add any channel after this
)

relay/billing/billing.go (new file)

@@ -0,0 +1,42 @@
package billing
import (
"context"
"fmt"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/model"
)
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota
err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota)
if err != nil {
logger.Error(ctx, "error return pre-consumed quota: "+err.Error())
}
}(ctx)
}
}
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(ctx, userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota != 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}
if totalQuota <= 0 {
logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota))
}
}
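A sketch of how these two helpers are meant to bracket a relayed request: reserve quota up front, return it on failure, otherwise settle the difference. The pre-consume step and the surrounding handler are assumptions for illustration, not the project's exact flow.

package controller

import (
	"context"

	"github.com/songquanpeng/one-api/relay/billing"
)

// settleQuota is hypothetical: preConsumedQuota was reserved before calling the
// upstream channel, totalQuota is what the request actually cost.
func settleQuota(ctx context.Context, tokenId, userId, channelId int, preConsumedQuota, totalQuota int64,
	modelRatio, groupRatio float64, modelName, tokenName string, upstreamFailed bool) {
	if upstreamFailed {
		// Give the reservation back; PostConsumeTokenQuota is called with a negative delta.
		billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, tokenId)
		return
	}
	// Charge only the remaining part, then record the consume log and counters.
	quotaDelta := totalQuota - preConsumedQuota
	billing.PostConsumeQuota(ctx, tokenId, quotaDelta, totalQuota, userId, channelId, modelRatio, groupRatio, modelName, tokenName)
}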

Some files were not shown because too many files have changed in this diff.