From 50bab08496db65d50283af6fcf3a3f132a8beb28 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Sat, 6 Apr 2024 05:18:04 +0000 Subject: [PATCH] Refactor codebase, introduce relaymode package, update constants and improve consistency - Refactor constant definitions and organization - Clean up package level variables and functions - Introduce new `relaymode` and `apitype` packages for constant definitions - Refactor and simplify code in several packages including `openai`, `relay/channel/baidu`, `relay/util`, `relay/controller`, `relay/channeltype` - Add helper functions in `relay/channeltype` package to convert channel type constants to corresponding API type constants - Remove deprecated functions such as `ResponseText2Usage` from `relay/channel/openai/helper.go` - Modify code in `relay/util/validation.go` and related files to use new `validator.ValidateTextRequest` function - Rename `util` package to `relaymode` and update related imports in several packages --- .gitignore | 1 + common/config/config.go | 3 + common/config/key.go | 9 + common/constants.go | 113 - common/helper/helper.go | 89 +- common/helper/time.go | 15 + common/network/ip.go | 52 + common/network/ip_test.go | 19 + common/random.go | 8 - common/random/main.go | 62 + controller/{ => auth}/github.go | 16 +- controller/auth/lark.go | 201 + controller/{ => auth}/wechat.go | 12 +- controller/channel-billing.go | 39 +- controller/channel-test.go | 44 +- controller/group.go | 4 +- controller/misc.go | 1 + controller/model.go | 112 +- controller/redemption.go | 3 +- controller/relay.go | 26 +- controller/token.go | 129 +- controller/user.go | 119 +- docs/API.md | 53 + go.mod | 10 +- go.sum | 24 +- main.go | 2 +- middleware/auth.go | 36 +- middleware/distributor.go | 71 +- middleware/utils.go | 42 + model/ability.go | 20 +- model/cache.go | 25 +- model/channel.go | 14 +- model/log.go | 16 +- model/main.go | 9 +- model/option.go | 20 +- model/redemption.go | 10 +- model/token.go | 50 +- model/user.go | 78 +- monitor/channel.go | 7 +- monitor/manage.go | 62 + relay/adaptor.go | 41 + relay/{channel => adaptor}/ai360/constants.go | 0 relay/{channel => adaptor}/aiproxy/adaptor.go | 29 +- .../{channel => adaptor}/aiproxy/constants.go | 2 +- relay/{channel => adaptor}/aiproxy/main.go | 9 +- relay/{channel => adaptor}/aiproxy/model.go | 0 relay/adaptor/ali/adaptor.go | 105 + relay/{channel => adaptor}/ali/constants.go | 1 + relay/adaptor/ali/image.go | 192 + relay/{channel => adaptor}/ali/main.go | 0 relay/adaptor/ali/model.go | 154 + .../{channel => adaptor}/anthropic/adaptor.go | 31 +- .../anthropic/constants.go | 0 relay/{channel => adaptor}/anthropic/main.go | 2 +- relay/{channel => adaptor}/anthropic/model.go | 0 relay/adaptor/azure/helper.go | 15 + .../baichuan/constants.go | 0 relay/{channel => adaptor}/baidu/adaptor.go | 0 relay/adaptor/baidu/constants.go | 20 + relay/{channel => adaptor}/baidu/main.go | 0 relay/{channel => adaptor}/baidu/model.go | 0 relay/{channel => adaptor}/common.go | 11 +- relay/{channel => adaptor}/gemini/adaptor.go | 25 +- .../{channel => adaptor}/gemini/constants.go | 0 relay/{channel => adaptor}/gemini/main.go | 251 +- relay/{channel => adaptor}/gemini/model.go | 0 relay/{channel => adaptor}/groq/constants.go | 0 relay/adaptor/interface.go | 21 + .../lingyiwanwu/constants.go | 0 .../{channel => adaptor}/minimax/constants.go | 0 relay/adaptor/minimax/main.go | 14 + .../{channel => adaptor}/mistral/constants.go | 0 .../moonshot/constants.go | 0 relay/{channel => adaptor}/ollama/adaptor.go | 35 +- 
.../{channel => adaptor}/ollama/constants.go | 0 relay/{channel => adaptor}/ollama/main.go | 9 +- relay/{channel => adaptor}/ollama/model.go | 0 relay/{channel => adaptor}/openai/adaptor.go | 56 +- relay/adaptor/openai/compatible.go | 50 + .../{channel => adaptor}/openai/constants.go | 0 relay/adaptor/openai/helper.go | 30 + relay/adaptor/openai/image.go | 44 + relay/{channel => adaptor}/openai/main.go | 6 +- relay/{channel => adaptor}/openai/model.go | 13 +- relay/{channel => adaptor}/openai/token.go | 4 +- relay/{channel => adaptor}/openai/util.go | 0 relay/{channel => adaptor}/palm/adaptor.go | 27 +- relay/{channel => adaptor}/palm/constants.go | 0 relay/{channel => adaptor}/palm/model.go | 0 relay/{channel => adaptor}/palm/palm.go | 5 +- relay/adaptor/stepfun/constants.go | 7 + relay/{channel => adaptor}/tencent/adaptor.go | 0 .../{channel => adaptor}/tencent/constants.go | 0 relay/adaptor/tencent/main.go | 238 + relay/{channel => adaptor}/tencent/model.go | 0 relay/{channel => adaptor}/xunfei/adaptor.go | 0 .../{channel => adaptor}/xunfei/constants.go | 0 relay/{channel => adaptor}/xunfei/main.go | 0 relay/{channel => adaptor}/xunfei/model.go | 0 relay/adaptor/zhipu/adaptor.go | 145 + relay/{channel => adaptor}/zhipu/constants.go | 0 relay/{channel => adaptor}/zhipu/main.go | 0 relay/{channel => adaptor}/zhipu/model.go | 0 relay/apitype/define.go | 17 + relay/billing/billing.go | 42 + .../billing/ratio/group.go | 2 +- relay/billing/ratio/image.go | 51 + .../billing/ratio/model.go | 52 +- relay/channel/ali/adaptor.go | 83 - relay/channel/ali/model.go | 71 - relay/channel/baidu/constants.go | 13 - relay/channel/interface.go | 20 - relay/channel/minimax/main.go | 16 - relay/channel/openai/compatible.go | 46 - relay/channel/openai/helper.go | 11 - relay/channel/zhipu/adaptor.go | 62 - relay/channeltype/define.go | 39 + relay/channeltype/helper.go | 30 + relay/channeltype/url.go | 43 + relay/{util => client}/init.go | 2 +- relay/constant/api_type.go | 48 - relay/constant/image.go | 24 - relay/constant/relay_mode.go | 42 - relay/controller/audio.go | 50 +- relay/controller/error.go | 91 + relay/controller/helper.go | 73 +- relay/controller/image.go | 125 +- relay/controller/text.go | 33 +- .../validator}/validation.go | 19 +- relay/helper/main.go | 40 - relay/{util => meta}/relay_meta.go | 26 +- relay/model/image.go | 12 + relay/relaymode/define.go | 14 + relay/relaymode/helper.go | 29 + relay/util/billing.go | 19 - relay/util/common.go | 188 - relay/util/model_mapping.go | 12 - router/{api-router.go => api.go} | 14 +- router/{relay-router.go => relay.go} | 0 router/{web-router.go => web.go} | 0 web/README.md | 3 + web/berry/src/constants/ChannelConstants.js | 6 + web/berry/src/constants/SnackbarConstants.js | 2 +- web/berry/src/utils/common.js | 4 +- .../src/views/Channel/component/EditModal.js | 4 +- .../src/views/Token/component/EditModal.js | 2 +- web/default/src/App.js | 9 + web/default/src/components/LarkOAuth.js | 58 + web/default/src/components/LoginForm.js | 17 +- web/default/src/components/PersonalSetting.js | 7 +- web/default/src/components/SystemSetting.js | 54 + web/default/src/components/utils.js | 9 + .../src/constants/channel.constants.js | 1 + web/default/src/images/lark.svg | 1 + web/default/src/pages/Token/EditToken.js | 67 +- web/default/src/pages/TopUp/index.js | 13 +- web/package-lock.json | 17443 +--------------- 157 files changed, 3217 insertions(+), 19160 deletions(-) create mode 100644 common/config/key.go create mode 100644 common/helper/time.go create mode 100644 
common/network/ip.go create mode 100644 common/network/ip_test.go delete mode 100644 common/random.go create mode 100644 common/random/main.go rename controller/{ => auth}/github.go (94%) create mode 100644 controller/auth/lark.go rename controller/{ => auth}/wechat.go (93%) create mode 100644 docs/API.md create mode 100644 monitor/manage.go create mode 100644 relay/adaptor.go rename relay/{channel => adaptor}/ai360/constants.go (100%) rename relay/{channel => adaptor}/aiproxy/adaptor.go (54%) rename relay/{channel => adaptor}/aiproxy/constants.go (60%) rename relay/{channel => adaptor}/aiproxy/main.go (95%) rename relay/{channel => adaptor}/aiproxy/model.go (100%) create mode 100644 relay/adaptor/ali/adaptor.go rename relay/{channel => adaptor}/ali/constants.go (65%) create mode 100644 relay/adaptor/ali/image.go rename relay/{channel => adaptor}/ali/main.go (100%) create mode 100644 relay/adaptor/ali/model.go rename relay/{channel => adaptor}/anthropic/adaptor.go (60%) rename relay/{channel => adaptor}/anthropic/constants.go (100%) rename relay/{channel => adaptor}/anthropic/main.go (99%) rename relay/{channel => adaptor}/anthropic/model.go (100%) create mode 100644 relay/adaptor/azure/helper.go rename relay/{channel => adaptor}/baichuan/constants.go (100%) rename relay/{channel => adaptor}/baidu/adaptor.go (100%) create mode 100644 relay/adaptor/baidu/constants.go rename relay/{channel => adaptor}/baidu/main.go (100%) rename relay/{channel => adaptor}/baidu/model.go (100%) rename relay/{channel => adaptor}/common.go (81%) rename relay/{channel => adaptor}/gemini/adaptor.go (66%) rename relay/{channel => adaptor}/gemini/constants.go (100%) rename relay/{channel => adaptor}/gemini/main.go (56%) rename relay/{channel => adaptor}/gemini/model.go (100%) rename relay/{channel => adaptor}/groq/constants.go (100%) create mode 100644 relay/adaptor/interface.go rename relay/{channel => adaptor}/lingyiwanwu/constants.go (100%) rename relay/{channel => adaptor}/minimax/constants.go (100%) create mode 100644 relay/adaptor/minimax/main.go rename relay/{channel => adaptor}/mistral/constants.go (100%) rename relay/{channel => adaptor}/moonshot/constants.go (100%) rename relay/{channel => adaptor}/ollama/adaptor.go (58%) rename relay/{channel => adaptor}/ollama/constants.go (100%) rename relay/{channel => adaptor}/ollama/main.go (97%) rename relay/{channel => adaptor}/ollama/model.go (100%) rename relay/{channel => adaptor}/openai/adaptor.go (52%) create mode 100644 relay/adaptor/openai/compatible.go rename relay/{channel => adaptor}/openai/constants.go (100%) create mode 100644 relay/adaptor/openai/helper.go create mode 100644 relay/adaptor/openai/image.go rename relay/{channel => adaptor}/openai/main.go (97%) rename relay/{channel => adaptor}/openai/model.go (93%) rename relay/{channel => adaptor}/openai/token.go (98%) rename relay/{channel => adaptor}/openai/util.go (100%) rename relay/{channel => adaptor}/palm/adaptor.go (59%) rename relay/{channel => adaptor}/palm/constants.go (100%) rename relay/{channel => adaptor}/palm/model.go (100%) rename relay/{channel => adaptor}/palm/palm.go (97%) create mode 100644 relay/adaptor/stepfun/constants.go rename relay/{channel => adaptor}/tencent/adaptor.go (100%) rename relay/{channel => adaptor}/tencent/constants.go (100%) create mode 100644 relay/adaptor/tencent/main.go rename relay/{channel => adaptor}/tencent/model.go (100%) rename relay/{channel => adaptor}/xunfei/adaptor.go (100%) rename relay/{channel => adaptor}/xunfei/constants.go (100%) rename 
relay/{channel => adaptor}/xunfei/main.go (100%) rename relay/{channel => adaptor}/xunfei/model.go (100%) create mode 100644 relay/adaptor/zhipu/adaptor.go rename relay/{channel => adaptor}/zhipu/constants.go (100%) rename relay/{channel => adaptor}/zhipu/main.go (100%) rename relay/{channel => adaptor}/zhipu/model.go (100%) create mode 100644 relay/apitype/define.go create mode 100644 relay/billing/billing.go rename common/group-ratio.go => relay/billing/ratio/group.go (97%) create mode 100644 relay/billing/ratio/image.go rename common/model-ratio.go => relay/billing/ratio/model.go (85%) delete mode 100644 relay/channel/ali/adaptor.go delete mode 100644 relay/channel/ali/model.go delete mode 100644 relay/channel/baidu/constants.go delete mode 100644 relay/channel/interface.go delete mode 100644 relay/channel/minimax/main.go delete mode 100644 relay/channel/openai/compatible.go delete mode 100644 relay/channel/openai/helper.go delete mode 100644 relay/channel/zhipu/adaptor.go create mode 100644 relay/channeltype/define.go create mode 100644 relay/channeltype/helper.go create mode 100644 relay/channeltype/url.go rename relay/{util => client}/init.go (97%) delete mode 100644 relay/constant/api_type.go delete mode 100644 relay/constant/image.go delete mode 100644 relay/constant/relay_mode.go create mode 100644 relay/controller/error.go rename relay/{util => controller/validator}/validation.go (76%) delete mode 100644 relay/helper/main.go rename relay/{util => meta}/relay_meta.go (64%) create mode 100644 relay/model/image.go create mode 100644 relay/relaymode/define.go create mode 100644 relay/relaymode/helper.go delete mode 100644 relay/util/billing.go delete mode 100644 relay/util/common.go delete mode 100644 relay/util/model_mapping.go rename router/{api-router.go => api.go} (92%) rename router/{relay-router.go => relay.go} (100%) rename router/{web-router.go => web.go} (100%) create mode 100644 web/default/src/components/LarkOAuth.js create mode 100644 web/default/src/images/lark.svg diff --git a/.gitignore b/.gitignore index 1cfa1e7f..be4abc52 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ build logs data node_modules +cmd.md diff --git a/common/config/config.go b/common/config/config.go index 66cfee06..59fd77df 100644 --- a/common/config/config.go +++ b/common/config/config.go @@ -80,6 +80,9 @@ var SMTPToken = "" var GitHubClientId = "" var GitHubClientSecret = "" +var LarkClientId = "" +var LarkClientSecret = "" + var WeChatServerAddress = "" var WeChatServerToken = "" var WeChatAccountQRCodeImageURL = "" diff --git a/common/config/key.go b/common/config/key.go new file mode 100644 index 00000000..4b503c2d --- /dev/null +++ b/common/config/key.go @@ -0,0 +1,9 @@ +package config + +const ( + KeyPrefix = "cfg_" + + KeyAPIVersion = KeyPrefix + "api_version" + KeyLibraryID = KeyPrefix + "library_id" + KeyPlugin = KeyPrefix + "plugin" +) diff --git a/common/constants.go b/common/constants.go index 849bdce7..87221b61 100644 --- a/common/constants.go +++ b/common/constants.go @@ -4,116 +4,3 @@ import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change - -const ( - RoleGuestUser = 0 - RoleCommonUser = 1 - RoleAdminUser = 10 - RoleRootUser = 100 -) - -const ( - UserStatusEnabled = 1 // don't use 0, 0 is the default value! - UserStatusDisabled = 2 // also don't use 0 - UserStatusDeleted = 3 -) - -const ( - TokenStatusEnabled = 1 // don't use 0, 0 is the default value! 
- TokenStatusDisabled = 2 // also don't use 0 - TokenStatusExpired = 3 - TokenStatusExhausted = 4 -) - -const ( - RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! - RedemptionCodeStatusDisabled = 2 // also don't use 0 - RedemptionCodeStatusUsed = 3 // also don't use 0 -) - -const ( - ChannelStatusUnknown = 0 - ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! - ChannelStatusManuallyDisabled = 2 // also don't use 0 - ChannelStatusAutoDisabled = 3 -) - -const ( - ChannelTypeUnknown = iota - ChannelTypeOpenAI - ChannelTypeAPI2D - ChannelTypeAzure - ChannelTypeCloseAI - ChannelTypeOpenAISB - ChannelTypeOpenAIMax - ChannelTypeOhMyGPT - ChannelTypeCustom - ChannelTypeAILS - ChannelTypeAIProxy - ChannelTypePaLM - ChannelTypeAPI2GPT - ChannelTypeAIGC2D - ChannelTypeAnthropic - ChannelTypeBaidu - ChannelTypeZhipu - ChannelTypeAli - ChannelTypeXunfei - ChannelType360 - ChannelTypeOpenRouter - ChannelTypeAIProxyLibrary - ChannelTypeFastGPT - ChannelTypeTencent - ChannelTypeGemini - ChannelTypeMoonshot - ChannelTypeBaichuan - ChannelTypeMinimax - ChannelTypeMistral - ChannelTypeGroq - ChannelTypeOllama - ChannelTypeLingYiWanWu - - ChannelTypeDummy -) - -var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "https://generativelanguage.googleapis.com", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.cloud.tencent.com", // 23 - "https://generativelanguage.googleapis.com", // 24 - "https://api.moonshot.cn", // 25 - "https://api.baichuan-ai.com", // 26 - "https://api.minimax.chat", // 27 - "https://api.mistral.ai", // 28 - "https://api.groq.com/openai", // 29 - "http://localhost:11434", // 30 - "https://api.lingyiwanwu.com", // 31 -} - -const ( - ConfigKeyPrefix = "cfg_" - - ConfigKeyAPIVersion = ConfigKeyPrefix + "api_version" - ConfigKeyLibraryID = ConfigKeyPrefix + "library_id" - ConfigKeyPlugin = ConfigKeyPrefix + "plugin" -) diff --git a/common/helper/helper.go b/common/helper/helper.go index db41ac74..35d075bc 100644 --- a/common/helper/helper.go +++ b/common/helper/helper.go @@ -2,16 +2,15 @@ package helper import ( "fmt" - "github.com/google/uuid" "html/template" "log" - "math/rand" "net" "os/exec" "runtime" "strconv" "strings" - "time" + + "github.com/songquanpeng/one-api/common/random" ) func OpenBrowser(url string) { @@ -79,31 +78,6 @@ func Bytes2Size(num int64) string { return numStr + " " + unit } -func Seconds2Time(num int) (time string) { - if num/31104000 > 0 { - time += strconv.Itoa(num/31104000) + " 年 " - num %= 31104000 - } - if num/2592000 > 0 { - time += strconv.Itoa(num/2592000) + " 个月 " - num %= 2592000 - } - if num/86400 > 0 { - time += strconv.Itoa(num/86400) + " 天 " - num %= 86400 - } - if num/3600 > 0 { - time += strconv.Itoa(num/3600) + " 小时 " - num %= 3600 - } - if num/60 > 0 { - time += strconv.Itoa(num/60) + " 分钟 " - num %= 60 - } - time += strconv.Itoa(num) + " 秒" - return -} - func Interface2String(inter 
interface{}) string { switch inter := inter.(type) { case string: @@ -128,65 +102,8 @@ func IntMax(a int, b int) int { } } -func GetUUID() string { - code := uuid.New().String() - code = strings.Replace(code, "-", "", -1) - return code -} - -const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -const keyNumbers = "0123456789" - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, 48) - for i := 0; i < 16; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - uuid_ := GetUUID() - for i := 0; i < 32; i++ { - c := uuid_[i] - if i%2 == 0 && c >= 'a' && c <= 'z' { - c = c - 'a' + 'A' - } - key[i+16] = c - } - return string(key) -} - -func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - -func GetRandomNumberString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyNumbers[rand.Intn(len(keyNumbers))] - } - return string(key) -} - -func GetTimestamp() int64 { - return time.Now().Unix() -} - -func GetTimeString() string { - now := time.Now() - return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) -} - func GenRequestID() string { - return GetTimeString() + GetRandomNumberString(8) + return GetTimeString() + random.GetRandomNumberString(8) } func Max(a int, b int) int { diff --git a/common/helper/time.go b/common/helper/time.go new file mode 100644 index 00000000..302746db --- /dev/null +++ b/common/helper/time.go @@ -0,0 +1,15 @@ +package helper + +import ( + "fmt" + "time" +) + +func GetTimestamp() int64 { + return time.Now().Unix() +} + +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} diff --git a/common/network/ip.go b/common/network/ip.go new file mode 100644 index 00000000..0fbe5e6f --- /dev/null +++ b/common/network/ip.go @@ -0,0 +1,52 @@ +package network + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "net" + "strings" +) + +func splitSubnets(subnets string) []string { + res := strings.Split(subnets, ",") + for i := 0; i < len(res); i++ { + res[i] = strings.TrimSpace(res[i]) + } + return res +} + +func isValidSubnet(subnet string) error { + _, _, err := net.ParseCIDR(subnet) + if err != nil { + return fmt.Errorf("failed to parse subnet: %w", err) + } + return nil +} + +func isIpInSubnet(ctx context.Context, ip string, subnet string) bool { + _, ipNet, err := net.ParseCIDR(subnet) + if err != nil { + logger.Errorf(ctx, "failed to parse subnet: %s", err.Error()) + return false + } + return ipNet.Contains(net.ParseIP(ip)) +} + +func IsValidSubnets(subnets string) error { + for _, subnet := range splitSubnets(subnets) { + if err := isValidSubnet(subnet); err != nil { + return err + } + } + return nil +} + +func IsIpInSubnets(ctx context.Context, ip string, subnets string) bool { + for _, subnet := range splitSubnets(subnets) { + if isIpInSubnet(ctx, ip, subnet) { + return true + } + } + return false +} diff --git a/common/network/ip_test.go b/common/network/ip_test.go new file mode 100644 index 00000000..6c593458 --- /dev/null +++ b/common/network/ip_test.go @@ -0,0 +1,19 @@ +package network + +import ( + "context" + "testing" + + . 
"github.com/smartystreets/goconvey/convey" +) + +func TestIsIpInSubnet(t *testing.T) { + ctx := context.Background() + ip1 := "192.168.0.5" + ip2 := "125.216.250.89" + subnet := "192.168.0.0/24" + Convey("TestIsIpInSubnet", t, func() { + So(isIpInSubnet(ctx, ip1, subnet), ShouldBeTrue) + So(isIpInSubnet(ctx, ip2, subnet), ShouldBeFalse) + }) +} diff --git a/common/random.go b/common/random.go deleted file mode 100644 index 44bd2856..00000000 --- a/common/random.go +++ /dev/null @@ -1,8 +0,0 @@ -package common - -import "math/rand" - -// RandRange returns a random number between min and max (max is not included) -func RandRange(min, max int) int { - return min + rand.Intn(max-min) -} diff --git a/common/random/main.go b/common/random/main.go new file mode 100644 index 00000000..c3c69488 --- /dev/null +++ b/common/random/main.go @@ -0,0 +1,62 @@ +package random + +import ( + "math/rand" + "strings" + "time" + + "github.com/google/uuid" +) + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} + +const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" +const keyNumbers = "0123456789" + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func GenerateKey() string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, 48) + for i := 0; i < 16; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + uuid_ := GetUUID() + for i := 0; i < 32; i++ { + c := uuid_[i] + if i%2 == 0 && c >= 'a' && c <= 'z' { + c = c - 'a' + 'A' + } + key[i+16] = c + } + return string(key) +} + +func GetRandomString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func GetRandomNumberString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyNumbers[rand.Intn(len(keyNumbers))] + } + return string(key) +} + +// RandRange returns a random number between min and max (max is not included) +func RandRange(min, max int) int { + return min + rand.Intn(max-min) +} diff --git a/controller/github.go b/controller/auth/github.go similarity index 94% rename from controller/github.go rename to controller/auth/github.go index fc674852..22b48976 100644 --- a/controller/github.go +++ b/controller/auth/github.go @@ -1,4 +1,4 @@ -package controller +package auth import ( "bytes" @@ -7,10 +7,10 @@ import ( "github.com/Laisky/errors/v2" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -133,8 +133,8 @@ func GitHubOAuth(c *gin.Context) { user.DisplayName = "GitHub User" } user.Email = githubUser.Email - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -152,14 +152,14 @@ func GitHubOAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, }) return } - setupLogin(&user, c) 
+ controller.SetupLogin(&user, c) } func GitHubBind(c *gin.Context) { @@ -219,7 +219,7 @@ func GitHubBind(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) { session := sessions.Default(c) - state := helper.GetRandomString(12) + state := random.GetRandomString(12) session.Set("oauth_state", state) err := session.Save() if err != nil { diff --git a/controller/auth/lark.go b/controller/auth/lark.go new file mode 100644 index 00000000..a1dd8e84 --- /dev/null +++ b/controller/auth/lark.go @@ -0,0 +1,201 @@ +package auth + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "strconv" + "time" + + "github.com/Laisky/errors/v2" + "github.com/gin-contrib/sessions" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/model" +) + +type LarkOAuthResponse struct { + AccessToken string `json:"access_token"` +} + +type LarkUser struct { + Name string `json:"name"` + OpenID string `json:"open_id"` +} + +func getLarkUserInfoByCode(code string) (*LarkUser, error) { + if code == "" { + return nil, errors.New("无效的参数") + } + values := map[string]string{ + "client_id": config.LarkClientId, + "client_secret": config.LarkClientSecret, + "code": code, + "grant_type": "authorization_code", + "redirect_uri": fmt.Sprintf("%s/oauth/lark", config.ServerAddress), + } + jsonData, err := json.Marshal(values) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", "https://passport.feishu.cn/suite/passport/oauth/token", bytes.NewBuffer(jsonData)) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + client := http.Client{ + Timeout: 5 * time.Second, + } + res, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + defer res.Body.Close() + var oAuthResponse LarkOAuthResponse + err = json.NewDecoder(res.Body).Decode(&oAuthResponse) + if err != nil { + return nil, err + } + req, err = http.NewRequest("GET", "https://passport.feishu.cn/suite/passport/oauth/userinfo", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) + res2, err := client.Do(req) + if err != nil { + logger.SysLog(err.Error()) + return nil, errors.New("无法连接至飞书服务器,请稍后重试!") + } + var larkUser LarkUser + err = json.NewDecoder(res2.Body).Decode(&larkUser) + if err != nil { + return nil, err + } + return &larkUser, nil +} + +func LarkOAuth(c *gin.Context) { + session := sessions.Default(c) + state := c.Query("state") + if state == "" || session.Get("oauth_state") == nil || state != session.Get("oauth_state").(string) { + c.JSON(http.StatusForbidden, gin.H{ + "success": false, + "message": "state is empty or not same", + }) + return + } + username := session.Get("username") + if username != nil { + LarkBind(c) + return + } + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + err := user.FillUserByLarkId() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + if config.RegisterEnabled { + user.Username = "lark_" + 
strconv.Itoa(model.GetMaxUserId()+1) + if larkUser.Name != "" { + user.DisplayName = larkUser.Name + } else { + user.DisplayName = "Lark User" + } + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled + + if err := user.Insert(0); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + } else { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "管理员关闭了新用户注册", + }) + return + } + } + + if user.Status != model.UserStatusEnabled { + c.JSON(http.StatusOK, gin.H{ + "message": "用户已被封禁", + "success": false, + }) + return + } + controller.SetupLogin(&user, c) +} + +func LarkBind(c *gin.Context) { + code := c.Query("code") + larkUser, err := getLarkUserInfoByCode(code) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user := model.User{ + LarkId: larkUser.OpenID, + } + if model.IsLarkIdAlreadyTaken(user.LarkId) { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": "该飞书账户已被绑定", + }) + return + } + session := sessions.Default(c) + id := session.Get("id") + // id := c.GetInt("id") // critical bug! + user.Id = id.(int) + err = user.FillUserById() + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + user.LarkId = larkUser.OpenID + err = user.Update(false) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "bind", + }) + return +} diff --git a/controller/wechat.go b/controller/auth/wechat.go similarity index 93% rename from controller/wechat.go rename to controller/auth/wechat.go index 8f997bfb..da1b513b 100644 --- a/controller/wechat.go +++ b/controller/auth/wechat.go @@ -1,12 +1,12 @@ -package controller +package auth import ( "encoding/json" "fmt" "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -83,8 +83,8 @@ func WeChatAuth(c *gin.Context) { if config.RegisterEnabled { user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = "WeChat User" - user.Role = common.RoleCommonUser - user.Status = common.UserStatusEnabled + user.Role = model.RoleCommonUser + user.Status = model.UserStatusEnabled if err := user.Insert(0); err != nil { c.JSON(http.StatusOK, gin.H{ @@ -102,14 +102,14 @@ func WeChatAuth(c *gin.Context) { } } - if user.Status != common.UserStatusEnabled { + if user.Status != model.UserStatusEnabled { c.JSON(http.StatusOK, gin.H{ "message": "用户已被封禁", "success": false, }) return } - setupLogin(&user, c) + controller.SetupLogin(&user, c) } func WeChatBind(c *gin.Context) { diff --git a/controller/channel-billing.go b/controller/channel-billing.go index b9a3908e..79ef322a 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -3,18 +3,19 @@ package controller import ( "encoding/json" "fmt" - "github.com/Laisky/errors/v2" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "strconv" "time" + "github.com/Laisky/errors/v2" + 
"github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/client" + "github.com/gin-gonic/gin" ) @@ -96,7 +97,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He for k := range headers { req.Header.Add(k, headers.Get(k)) } - res, err := util.HTTPClient.Do(req) + res, err := client.HTTPClient.Do(req) if err != nil { return nil, err } @@ -204,28 +205,28 @@ func updateChannelAIGC2DBalance(channel *model.Channel) (float64, error) { } func updateChannelBalance(channel *model.Channel) (float64, error) { - baseURL := common.ChannelBaseURLs[channel.Type] + baseURL := channeltype.ChannelBaseURLs[channel.Type] if channel.GetBaseURL() == "" { channel.BaseURL = &baseURL } switch channel.Type { - case common.ChannelTypeOpenAI: + case channeltype.OpenAI: if channel.GetBaseURL() != "" { baseURL = channel.GetBaseURL() } - case common.ChannelTypeAzure: + case channeltype.Azure: return 0, errors.New("尚未实现") - case common.ChannelTypeCustom: + case channeltype.Custom: baseURL = channel.GetBaseURL() - case common.ChannelTypeCloseAI: + case channeltype.CloseAI: return updateChannelCloseAIBalance(channel) - case common.ChannelTypeOpenAISB: + case channeltype.OpenAISB: return updateChannelOpenAISBBalance(channel) - case common.ChannelTypeAIProxy: + case channeltype.AIProxy: return updateChannelAIProxyBalance(channel) - case common.ChannelTypeAPI2GPT: + case channeltype.API2GPT: return updateChannelAPI2GPTBalance(channel) - case common.ChannelTypeAIGC2D: + case channeltype.AIGC2D: return updateChannelAIGC2DBalance(channel) default: return 0, errors.New("尚未实现") @@ -301,11 +302,11 @@ func updateAllChannelsBalance() error { return err } for _, channel := range channels { - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != model.ChannelStatusEnabled { continue } // TODO: support Azure - if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeCustom { + if channel.Type != channeltype.OpenAI && channel.Type != channeltype.Custom { continue } balance, err := updateChannelBalance(channel) diff --git a/controller/channel-test.go b/controller/channel-test.go index 57138f49..535b21bd 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -4,18 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/Laisky/errors/v2" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/common/message" - "github.com/songquanpeng/one-api/middleware" - "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" - relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" "net/http/httptest" @@ -25,6 +13,20 @@ import ( "sync" "time" + "github.com/Laisky/errors/v2" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/message" + "github.com/songquanpeng/one-api/middleware" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/monitor" + relay "github.com/songquanpeng/one-api/relay" + 
"github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" + "github.com/gin-gonic/gin" ) @@ -56,9 +58,9 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error c.Set("channel", channel.Type) c.Set("base_url", channel.GetBaseURL()) middleware.SetupContextForSelectedChannel(c, channel, "") - meta := util.GetRelayMeta(c) - apiType := constant.ChannelType2APIType(channel.Type) - adaptor := helper.GetAdaptor(apiType) + meta := meta.GetByContext(c) + apiType := channeltype.ToAPIType(channel.Type) + adaptor := relay.GetAdaptor(apiType) if adaptor == nil { return errors.Errorf("invalid api type: %d, adaptor is nil", apiType), nil } @@ -73,7 +75,7 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error request := buildTestRequest() request.Model = modelName meta.OriginModelName, meta.ActualModelName = modelName, modelName - convertedRequest, err := adaptor.ConvertRequest(c, constant.RelayModeChatCompletions, request) + convertedRequest, err := adaptor.ConvertRequest(c, relaymode.ChatCompletions, request) if err != nil { return err, nil } @@ -88,8 +90,8 @@ func testChannel(channel *model.Channel) (err error, openaiErr *relaymodel.Error return err, nil } if resp.StatusCode != http.StatusOK { - err := util.RelayErrorHandler(resp) - return errors.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error + err := controller.RelayErrorHandler(resp) + return fmt.Errorf("status code %d: %s", resp.StatusCode, err.Error.Message), &err.Error } usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { @@ -171,7 +173,7 @@ func testChannels(notify bool, scope string) error { } go func() { for _, channel := range channels { - isChannelEnabled := channel.Status == common.ChannelStatusEnabled + isChannelEnabled := channel.Status == model.ChannelStatusEnabled tik := time.Now() err, openaiErr := testChannel(channel) tok := time.Now() @@ -184,10 +186,10 @@ func testChannels(notify bool, scope string) error { _ = message.Notify(message.ByAll, fmt.Sprintf("渠道 %s (%d)测试超时", channel.Name, channel.Id), "", err.Error()) } } - if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { + if isChannelEnabled && monitor.ShouldDisableChannel(openaiErr, -1) { monitor.DisableChannel(channel.Id, channel.Name, err.Error()) } - if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { + if !isChannelEnabled && monitor.ShouldEnableChannel(err, openaiErr) { monitor.EnableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) diff --git a/controller/group.go b/controller/group.go index 128a3527..6f02394f 100644 --- a/controller/group.go +++ b/controller/group.go @@ -2,13 +2,13 @@ package controller import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "net/http" ) func GetGroups(c *gin.Context) { groupNames := make([]string, 0) - for groupName := range common.GroupRatio { + for groupName := range billingratio.GroupRatio { groupNames = append(groupNames, groupName) } c.JSON(http.StatusOK, gin.H{ diff --git a/controller/misc.go b/controller/misc.go index f27fdb12..2928b8fb 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -23,6 +23,7 @@ func GetStatus(c *gin.Context) { "email_verification": 
config.EmailVerificationEnabled, "github_oauth": config.GitHubOAuthEnabled, "github_client_id": config.GitHubClientId, + "lark_client_id": config.LarkClientId, "system_name": config.SystemName, "logo": config.Logo, "footer_html": config.Footer, diff --git a/controller/model.go b/controller/model.go index 1be352f2..01d01bf0 100644 --- a/controller/model.go +++ b/controller/model.go @@ -3,13 +3,15 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/model" + relay "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/apitype" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "net/http" + "strings" ) // https://platform.openai.com/docs/api-reference/models/list @@ -39,8 +41,8 @@ type OpenAIModels struct { Parent *string `json:"parent"` } -var openAIModels []OpenAIModels -var openAIModelsMap map[string]OpenAIModels +var models []OpenAIModels +var modelsMap map[string]OpenAIModels var channelId2Models map[int][]string func init() { @@ -60,11 +62,11 @@ func init() { IsBlocking: false, }) // https://platform.openai.com/docs/models/model-endpoint-compatibility - for i := 0; i < constant.APITypeDummy; i++ { - if i == constant.APITypeAIProxyLibrary { + for i := 0; i < apitype.Dummy; i++ { + if i == apitype.AIProxyLibrary { continue } - adaptor := helper.GetAdaptor(i) + adaptor := relay.GetAdaptor(i) if adaptor == nil { continue } @@ -72,7 +74,7 @@ func init() { channelName := adaptor.GetChannelName() modelNames := adaptor.GetModelList() for _, modelName := range modelNames { - openAIModels = append(openAIModels, OpenAIModels{ + models = append(models, OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -84,12 +86,12 @@ func init() { } } for _, channelType := range openai.CompatibleChannels { - if channelType == common.ChannelTypeAzure { + if channelType == channeltype.Azure { continue } channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType) for _, modelName := range channelModelList { - openAIModels = append(openAIModels, OpenAIModels{ + models = append(models, OpenAIModels{ Id: modelName, Object: "model", Created: 1626777600, @@ -100,18 +102,18 @@ func init() { }) } } - openAIModelsMap = make(map[string]OpenAIModels) - for _, model := range openAIModels { - openAIModelsMap[model.Id] = model + modelsMap = make(map[string]OpenAIModels) + for _, model := range models { + modelsMap[model.Id] = model } channelId2Models = make(map[int][]string) - for i := 1; i < common.ChannelTypeDummy; i++ { - adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i)) + for i := 1; i < channeltype.Dummy; i++ { + adaptor := relay.GetAdaptor(channeltype.ToAPIType(i)) if adaptor == nil { continue } - meta := &util.RelayMeta{ + meta := &meta.Meta{ ChannelType: i, } adaptor.Init(meta) @@ -127,16 +129,55 @@ func DashboardListModels(c *gin.Context) { }) } -func ListModels(c *gin.Context) { +func ListAllModels(c *gin.Context) { c.JSON(200, gin.H{ "object": "list", - "data": openAIModels, + "data": models, + }) +} + +func ListModels(c *gin.Context) { + ctx := c.Request.Context() + var availableModels []string 
+ if c.GetString("available_models") != "" { + availableModels = strings.Split(c.GetString("available_models"), ",") + } else { + userId := c.GetInt("id") + userGroup, _ := model.CacheGetUserGroup(userId) + availableModels, _ = model.CacheGetGroupModels(ctx, userGroup) + } + modelSet := make(map[string]bool) + for _, availableModel := range availableModels { + modelSet[availableModel] = true + } + availableOpenAIModels := make([]OpenAIModels, 0) + for _, model := range models { + if _, ok := modelSet[model.Id]; ok { + modelSet[model.Id] = false + availableOpenAIModels = append(availableOpenAIModels, model) + } + } + for modelName, ok := range modelSet { + if ok { + availableOpenAIModels = append(availableOpenAIModels, OpenAIModels{ + Id: modelName, + Object: "model", + Created: 1626777600, + OwnedBy: "custom", + Root: modelName, + Parent: nil, + }) + } + } + c.JSON(200, gin.H{ + "object": "list", + "data": availableOpenAIModels, }) } func RetrieveModel(c *gin.Context) { modelId := c.Param("model") - if model, ok := openAIModelsMap[modelId]; ok { + if model, ok := modelsMap[modelId]; ok { c.JSON(200, model) } else { Error := relaymodel.Error{ @@ -150,3 +191,30 @@ func RetrieveModel(c *gin.Context) { }) } } + +func GetUserAvailableModels(c *gin.Context) { + ctx := c.Request.Context() + id := c.GetInt("id") + userGroup, err := model.CacheGetUserGroup(id) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + models, err := model.CacheGetGroupModels(ctx, userGroup) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": models, + }) + return +} diff --git a/controller/redemption.go b/controller/redemption.go index 31c9348d..8d2b3f38 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -4,6 +4,7 @@ import ( "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -106,7 +107,7 @@ func AddRedemption(c *gin.Context) { } var keys []string for i := 0; i < redemption.Count; i++ { - key := helper.GetUUID() + key := random.GetUUID() cleanRedemption := model.Redemption{ UserId: c.GetInt("id"), Name: redemption.Name, diff --git a/controller/relay.go b/controller/relay.go index 85932368..8278bf23 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -4,9 +4,6 @@ import ( "bytes" "context" "fmt" - "io" - "net/http" - "github.com/gin-gonic/gin" "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" @@ -16,24 +13,25 @@ import ( "github.com/songquanpeng/one-api/middleware" dbmodel "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/monitor" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/controller" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" ) // https://platform.openai.com/docs/api-reference/chat -func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { +func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { - case constant.RelayModeImagesGenerations: + case relaymode.ImagesGenerations: err = controller.RelayImageHelper(c, relayMode) 
- case constant.RelayModeAudioSpeech: + case relaymode.AudioSpeech: fallthrough - case constant.RelayModeAudioTranslation: + case relaymode.AudioTranslation: fallthrough - case constant.RelayModeAudioTranscription: + case relaymode.AudioTranscription: err = controller.RelayAudioHelper(c, relayMode) default: err = controller.RelayTextHelper(c) @@ -43,13 +41,13 @@ func relay(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() - relayMode := constant.Path2RelayMode(c.Request.URL.Path) + relayMode := relaymode.GetByPath(c.Request.URL.Path) if config.DebugEnabled { requestBody, _ := common.GetRequestBody(c) logger.Debugf(ctx, "request body: %s", string(requestBody)) } channelId := c.GetInt("channel_id") - bizErr := relay(c, relayMode) + bizErr := relayHelper(c, relayMode) if bizErr == nil { monitor.Emit(channelId, true) return @@ -78,7 +76,7 @@ func Relay(c *gin.Context) { middleware.SetupContextForSelectedChannel(c, channel, originalModel) requestBody, err := common.GetRequestBody(c) c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - bizErr = relay(c, relayMode) + bizErr = relayHelper(c, relayMode) if bizErr == nil { return } @@ -117,7 +115,7 @@ func shouldRetry(c *gin.Context, statusCode int) error { func processChannelRelayError(ctx context.Context, channelId int, channelName string, err *model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel #%d): %s", channelId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors - if util.ShouldDisableChannel(&err.Error, err.StatusCode) { + if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { monitor.DisableChannel(channelId, channelName, err.Message) } else { monitor.Emit(channelId, false) diff --git a/controller/token.go b/controller/token.go index 9b52b053..74d31547 100644 --- a/controller/token.go +++ b/controller/token.go @@ -6,9 +6,12 @@ import ( "strconv" "github.com/gin-gonic/gin" + "github.com/jinzhu/copier" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/network" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" ) @@ -106,9 +109,24 @@ func GetTokenStatus(c *gin.Context) { }) } +func validateToken(c *gin.Context, token *model.Token) error { + if len(token.Name) > 30 { + return fmt.Errorf("令牌名称过长") + } + + if token.Subnet != nil && *token.Subnet != "" { + err := network.IsValidSubnets(*token.Subnet) + if err != nil { + return fmt.Errorf("无效的网段:%s", err.Error()) + } + } + + return nil +} + func AddToken(c *gin.Context) { - token := model.Token{} - err := c.ShouldBindJSON(&token) + token := new(model.Token) + err := c.ShouldBindJSON(token) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -116,22 +134,27 @@ func AddToken(c *gin.Context) { }) return } - if len(token.Name) > 30 { + + err = validateToken(c, token) + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "令牌名称过长", + "message": fmt.Sprintf("参数错误:%s", err.Error()), }) return } + cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, - Key: helper.GenerateKey(), + Key: random.GenerateKey(), CreatedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(), ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, + Models: token.Models, + Subnet: token.Subnet, } err = cleanToken.Insert() if err != 
nil { @@ -168,12 +191,7 @@ func DeleteToken(c *gin.Context) { } type updateTokenDto struct { - Id int `json:"id"` - Status int `json:"status" gorm:"default:1"` - Name *string `json:"name" gorm:"index" ` - ExpiredTime *int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota *int `json:"remain_quota" gorm:"default:0"` - UnlimitedQuota *bool `json:"unlimited_quota" gorm:"default:false"` + model.Token // AddUsedQuota add or subtract used quota AddUsedQuota int `json:"add_used_quota" gorm:"-"` AddReason string `json:"add_reason" gorm:"-"` @@ -183,43 +201,51 @@ func UpdateToken(c *gin.Context) { userId := c.GetInt("id") statusOnly := c.Query("status_only") tokenPatch := new(updateTokenDto) - if err := c.ShouldBindJSON(tokenPatch); err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "parse request join: " + err.Error(), - }) - return - } - - if tokenPatch.Name != nil && - (len(*tokenPatch.Name) > 30 || len(*tokenPatch.Name) == 0) { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": "令牌名称错误,长度应在 1-30 之间", - }) - return - } - - tokenInDB, err := model.GetTokenByIds(tokenPatch.Id, userId) + err := c.ShouldBindJSON(tokenPatch) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": fmt.Sprintf("get token by id %d: %s", tokenPatch.Id, err.Error()), + "message": err.Error(), }) return } - if tokenPatch.Status == common.TokenStatusEnabled { - if tokenInDB.Status == common.TokenStatusExpired && tokenInDB.ExpiredTime <= helper.GetTimestamp() && tokenInDB.ExpiredTime != -1 { + token := new(model.Token) + if err = copier.Copy(token, tokenPatch); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + err = validateToken(c, token) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("参数错误:%s", err.Error()), + }) + return + } + + cleanToken, err := model.GetTokenByIds(token.Id, userId) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + if token.Status == model.TokenStatusEnabled { + if cleanToken.Status == model.TokenStatusExpired && cleanToken.ExpiredTime <= helper.GetTimestamp() && cleanToken.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", }) return } - if tokenInDB.Status == common.TokenStatusExhausted && - tokenInDB.RemainQuota <= 0 && - !tokenInDB.UnlimitedQuota { + if cleanToken.Status == model.TokenStatusExhausted && cleanToken.RemainQuota <= 0 && !cleanToken.UnlimitedQuota { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", @@ -228,42 +254,33 @@ func UpdateToken(c *gin.Context) { } } if statusOnly != "" { - tokenInDB.Status = tokenPatch.Status + cleanToken.Status = token.Status } else { - // If you add more fields, please also update tokenPatch.Update() - if tokenPatch.Name != nil { - tokenInDB.Name = *tokenPatch.Name - } - if tokenPatch.ExpiredTime != nil { - tokenInDB.ExpiredTime = *tokenPatch.ExpiredTime - } - if tokenPatch.RemainQuota != nil { - tokenInDB.RemainQuota = int64(*tokenPatch.RemainQuota) - } - if tokenPatch.UnlimitedQuota != nil { - tokenInDB.UnlimitedQuota = *tokenPatch.UnlimitedQuota - } + // If you add more fields, please also update token.Update() + cleanToken.Name = token.Name + cleanToken.ExpiredTime = token.ExpiredTime + cleanToken.RemainQuota = token.RemainQuota + cleanToken.UnlimitedQuota = token.UnlimitedQuota + 
cleanToken.Models = token.Models + cleanToken.Subnet = token.Subnet } - tokenInDB.RemainQuota -= int64(tokenPatch.AddUsedQuota) - tokenInDB.UsedQuota += int64(tokenPatch.AddUsedQuota) - if tokenPatch.AddUsedQuota != 0 { model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(int64(tokenPatch.AddUsedQuota)))) } - if err = tokenInDB.Update(); err != nil { + err = cleanToken.Update() + if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, - "message": "update token: " + err.Error(), + "message": err.Error(), }) return } - c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": tokenInDB, + "data": cleanToken, }) return } diff --git a/controller/user.go b/controller/user.go index 691378ec..bd31c034 100644 --- a/controller/user.go +++ b/controller/user.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" "github.com/songquanpeng/one-api/model" "net/http" "strconv" @@ -58,11 +58,11 @@ func Login(c *gin.Context) { }) return } - setupLogin(&user, c) + SetupLogin(&user, c) } // setup session & cookies and then return user info -func setupLogin(user *model.User, c *gin.Context) { +func SetupLogin(user *model.User, c *gin.Context) { session := sessions.Default(c) session.Set("id", user.Id) session.Set("username", user.Username) @@ -186,27 +186,27 @@ func Register(c *gin.Context) { } func GetAllUsers(c *gin.Context) { - p, _ := strconv.Atoi(c.Query("p")) - if p < 0 { - p = 0 - } - - order := c.DefaultQuery("order", "") - users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) - - if err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": err.Error(), - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "", - "data": users, - }) + p, _ := strconv.Atoi(c.Query("p")) + if p < 0 { + p = 0 + } + + order := c.DefaultQuery("order", "") + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage, order) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + "data": users, + }) } func SearchUsers(c *gin.Context) { @@ -245,7 +245,7 @@ func GetUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权获取同级或更高等级用户的信息", @@ -293,7 +293,7 @@ func GenerateAccessToken(c *gin.Context) { }) return } - user.AccessToken = helper.GetUUID() + user.AccessToken = random.GetUUID() if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { c.JSON(http.StatusOK, gin.H{ @@ -330,7 +330,7 @@ func GetAffCode(c *gin.Context) { return } if user.AffCode == "" { - user.AffCode = helper.GetRandomString(4) + user.AffCode = random.GetRandomString(4) if err := user.Update(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -404,14 +404,14 @@ func UpdateUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= originUser.Role && myRole != common.RoleRootUser { + if myRole <= originUser.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", }) return } - if myRole <= 
updatedUser.Role && myRole != common.RoleRootUser { + if myRole <= updatedUser.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权将其他用户权限等级提升到大于等于自己的权限等级", @@ -525,7 +525,7 @@ func DeleteSelf(c *gin.Context) { id := c.GetInt("id") user, _ := model.GetUserById(id, false) - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "不能删除超级管理员账户", @@ -627,7 +627,7 @@ func ManageUser(c *gin.Context) { return } myRole := c.GetInt("role") - if myRole <= user.Role && myRole != common.RoleRootUser { + if myRole <= user.Role && myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无权更新同权限等级或更高权限等级的用户信息", @@ -636,8 +636,8 @@ func ManageUser(c *gin.Context) { } switch req.Action { case "disable": - user.Status = common.UserStatusDisabled - if user.Role == common.RoleRootUser { + user.Status = model.UserStatusDisabled + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法禁用超级管理员用户", @@ -645,9 +645,9 @@ func ManageUser(c *gin.Context) { return } case "enable": - user.Status = common.UserStatusEnabled + user.Status = model.UserStatusEnabled case "delete": - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法删除超级管理员用户", @@ -662,37 +662,37 @@ func ManageUser(c *gin.Context) { return } case "promote": - if myRole != common.RoleRootUser { + if myRole != model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "普通管理员用户无法提升其他用户为管理员", }) return } - if user.Role >= common.RoleAdminUser { + if user.Role >= model.RoleAdminUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是管理员", }) return } - user.Role = common.RoleAdminUser + user.Role = model.RoleAdminUser case "demote": - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法降级超级管理员用户", }) return } - if user.Role == common.RoleCommonUser { + if user.Role == model.RoleCommonUser { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "该用户已经是普通用户", }) return } - user.Role = common.RoleCommonUser + user.Role = model.RoleCommonUser } if err := user.Update(false); err != nil { @@ -746,7 +746,7 @@ func EmailBind(c *gin.Context) { }) return } - if user.Role == common.RoleRootUser { + if user.Role == model.RoleRootUser { config.RootUserEmail = email } c.JSON(http.StatusOK, gin.H{ @@ -786,3 +786,38 @@ func TopUp(c *gin.Context) { }) return } + +type adminTopUpRequest struct { + UserId int `json:"user_id"` + Quota int `json:"quota"` + Remark string `json:"remark"` +} + +func AdminTopUp(c *gin.Context) { + req := adminTopUpRequest{} + err := c.ShouldBindJSON(&req) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + err = model.IncreaseUserQuota(req.UserId, int64(req.Quota)) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": err.Error(), + }) + return + } + if req.Remark == "" { + req.Remark = fmt.Sprintf("通过 API 充值 %s", common.LogQuota(int64(req.Quota))) + } + model.RecordTopupLog(req.UserId, req.Remark, req.Quota) + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "", + }) + return +} diff --git a/docs/API.md b/docs/API.md new file mode 100644 index 00000000..0b7ddf5a --- /dev/null +++ b/docs/API.md @@ -0,0 +1,53 @@ +# 使用 API 操控 & 扩展 One API +> 欢迎提交 
PR 在此放上你的拓展项目。 + +例如,虽然 One API 本身没有直接支持支付,但是你可以通过系统扩展的 API 来实现支付功能。 + +又或者你想自定义渠道管理策略,也可以通过 API 来实现渠道的禁用与启用。 + +## 鉴权 +One API 支持两种鉴权方式:Cookie 和 Token,对于 Token,参照下图获取: + +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/c15281a7-83ed-47cb-a1f6-913cb6bf4a7c) + +之后,将 Token 作为请求头的 Authorization 字段的值即可,例如下面使用 Token 调用测试渠道的 API: +![image](https://github.com/songquanpeng/songquanpeng.github.io/assets/39998050/1273b7ae-cb60-4c0d-93a6-b1cbc039c4f8) + +## 请求格式与响应格式 +One API 使用 JSON 格式进行请求和响应。 + +对于响应体,一般格式如下: +```json +{ + "message": "请求信息", + "success": true, + "data": {} +} +``` + +## API 列表 +> 当前 API 列表不全,请自行通过浏览器抓取前端请求 + +如果现有的 API 没有办法满足你的需求,欢迎提交 issue 讨论。 + +### 获取当前登录用户信息 +**GET** `/api/user/self` + +### 为给定用户充值额度 +**POST** `/api/topup` +```json +{ + "user_id": 1, + "quota": 100000, + "remark": "充值 100000 额度" +} +``` + +## 其他 +### 充值链接上的附加参数 +One API 会在用户点击充值按钮的时候,将用户的信息和充值信息附加在链接上,例如: +`https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837` + +你可以通过解析链接上的参数来获取用户信息和充值信息,然后调用 API 来为用户充值。 + +注意,不是所有主题都支持该功能,欢迎 PR 补齐。 \ No newline at end of file diff --git a/go.mod b/go.mod index 5c50c39d..39f9e295 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,10 @@ require ( github.com/go-playground/validator/v10 v10.19.0 github.com/go-redis/redis/v8 v8.11.5 github.com/google/uuid v1.6.0 + github.com/jinzhu/copier v0.4.0 github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.6 + github.com/smartystreets/goconvey v1.8.1 github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.21.0 golang.org/x/image v0.15.0 @@ -48,16 +50,17 @@ require ( github.com/go-sql-driver/mysql v1.7.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932 // indirect + github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/securecookie v1.1.1 // indirect github.com/gorilla/sessions v1.2.1 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.4 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -66,6 +69,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/smarty/assertions v1.15.0 // indirect github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect @@ -78,7 +82,7 @@ require ( golang.org/x/sys v0.18.0 // indirect golang.org/x/term v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect - golang.org/x/tools v0.6.0 // indirect + golang.org/x/tools v0.7.0 // indirect google.golang.org/protobuf v1.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d518f25f..bd2decbc 100644 --- a/go.sum +++ b/go.sum @@ -96,6 +96,8 @@ github.com/google/go-cpy v0.0.0-20211218193943-a9c933c06932/go.mod h1:cC6EdPbj/1 github.com/google/gofuzz v1.0.0/go.mod 
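The top-up flow documented above (redirect parameters plus `POST /api/topup`, handled by the new `AdminTopUp` controller) is meant to be driven by an external payment service. Below is a minimal client-side sketch, not part of this patch: the base URL, admin token, and quota amount are placeholders, and error handling is kept to a minimum.

```go
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strconv"
)

func main() {
	// 1. Parse the parameters One API appends to the top-up link.
	link := "https://example.com?username=root&user_id=1&transaction_id=4b3eed80-55d5-443f-bd44-fb18c648c837"
	u, err := url.Parse(link)
	if err != nil {
		panic(err)
	}
	q := u.Query()
	userID, _ := strconv.Atoi(q.Get("user_id"))
	fmt.Println("user:", q.Get("username"), "transaction:", q.Get("transaction_id"))

	// 2. After the payment succeeds, credit the user via POST /api/topup.
	payload, _ := json.Marshal(map[string]any{
		"user_id": userID,
		"quota":   100000,
		"remark":  "充值 100000 额度",
	})
	req, err := http.NewRequest(http.MethodPost, "https://one-api.example.com/api/topup", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	// Per the docs above, the token from the console goes in the Authorization header.
	req.Header.Set("Authorization", "<admin token>")
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// 3. Responses use the common {success, message, data} envelope.
	var result struct {
		Success bool   `json:"success"`
		Message string `json:"message"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		panic(err)
	}
	fmt.Printf("success=%v message=%q\n", result.Success, result.Message)
}
```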
h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gopherjs/gopherjs v1.17.2 h1:fQnZVsXk8uxXIStYb0N4bGk7jeyTalG/wsZjQ25dO0g= +github.com/gopherjs/gopherjs v1.17.2/go.mod h1:pRRIvn/QzFLrKfvEz3qUuEhtE/zLCWfreZ6J5gM2i+k= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= @@ -107,16 +109,18 @@ github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsI github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.4 h1:Xp2aQS8uXButQdnCMWNmvx6UysWQQC+u1EoizjguY+8= -github.com/jackc/pgx/v5 v5.5.4/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/pgx/v5 v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= +github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= +github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= @@ -164,6 +168,10 @@ github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/smarty/assertions v1.15.0 h1:cR//PqUBUiQRakZWqBiFFQ9wb8emQGDb0HeGdqGByCY= +github.com/smarty/assertions v1.15.0/go.mod h1:yABtdzeQs6l1brC900WlRNwj6ZR55d7B+E8C6HtKdec= +github.com/smartystreets/goconvey v1.8.1 h1:qGjIddxOk4grTu9JPOU31tVfq3cNdBlNa5sSznIX1xY= +github.com/smartystreets/goconvey v1.8.1/go.mod h1:+/u4qLyY6x1jReYOp7GOM2FSt8aP9CzCZL03bI28W60= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 
h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -206,8 +214,8 @@ golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE golang.org/x/lint v0.0.0-20210508222113-6edffad5e616 h1:VLliZ0d+/avPrXXH+OakdXhpJuEoBZuwh1m2j7U6Iug= golang.org/x/lint v0.0.0-20210508222113-6edffad5e616/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.9.0 h1:KENHtAZL2y3NLMYZeHY9DW8HW8V+kQyJsY/V9JlKvCs= +golang.org/x/mod v0.9.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -238,8 +246,8 @@ golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= -golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.7.0 h1:W4OVu8VVOaIO0yzWMNdepAulS7YfoS3Zabrm8DOXXU4= +golang.org/x/tools v0.7.0/go.mod h1:4pg6aUX35JBAogB10C9AtvVL+qowtN4pT3CGSQex14s= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= diff --git a/main.go b/main.go index 9c727152..3ee1dc94 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,7 @@ import ( "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/middleware" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/router" ) diff --git a/middleware/auth.go b/middleware/auth.go index 95ae5700..d55820f7 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -1,15 +1,15 @@ package middleware import ( - "net/http" - "strings" - + "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/network" "github.com/songquanpeng/one-api/model" + "net/http" + "strings" ) func authHelper(c *gin.Context, minRole int) { @@ -47,7 +47,7 @@ func authHelper(c *gin.Context, minRole int) { return } } - if status.(int) == common.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) { + if status.(int) == model.UserStatusDisabled || blacklist.IsUserBanned(id.(int)) 
{ c.JSON(http.StatusOK, gin.H{ "success": false, "message": "用户已被封禁", @@ -74,24 +74,25 @@ func authHelper(c *gin.Context, minRole int) { func UserAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleCommonUser) + authHelper(c, model.RoleCommonUser) } } func AdminAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleAdminUser) + authHelper(c, model.RoleAdminUser) } } func RootAuth() func(c *gin.Context) { return func(c *gin.Context) { - authHelper(c, common.RoleRootUser) + authHelper(c, model.RoleRootUser) } } func TokenAuth() func(c *gin.Context) { return func(c *gin.Context) { + ctx := c.Request.Context() key := c.Request.Header.Get("Authorization") key = strings.TrimPrefix(key, "Bearer ") key = strings.TrimPrefix(strings.TrimPrefix(key, "sk-"), "laisky-") @@ -102,6 +103,12 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusUnauthorized, err.Error()) return } + if token.Subnet != nil && *token.Subnet != "" { + if !network.IsIpInSubnets(ctx, c.ClientIP(), *token.Subnet) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌只能在指定网段使用:%s,当前 ip:%s", *token.Subnet, c.ClientIP())) + return + } + } userEnabled, err := model.CacheIsUserEnabled(token.UserId) if err != nil { abortWithMessage(c, http.StatusInternalServerError, err.Error()) @@ -111,6 +118,19 @@ func TokenAuth() func(c *gin.Context) { abortWithMessage(c, http.StatusForbidden, "用户已被封禁") return } + requestModel, err := getRequestModel(c) + if err != nil && !strings.HasPrefix(c.Request.URL.Path, "/v1/models") { + abortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + c.Set("request_model", requestModel) + if token.Models != nil && *token.Models != "" { + c.Set("available_models", *token.Models) + if requestModel != "" && !isModelInList(requestModel, *token.Models) { + abortWithMessage(c, http.StatusForbidden, fmt.Sprintf("该令牌无权使用模型:%s", requestModel)) + return + } + } c.Set("id", token.UserId) c.Set("token_id", token.Id) c.Set("token_name", token.Name) diff --git a/middleware/distributor.go b/middleware/distributor.go index f538f250..3eff734d 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,14 +2,16 @@ package middleware import ( "fmt" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" "net/http" "strconv" "strings" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" ) type ModelRequest struct { @@ -35,42 +37,16 @@ func Distribute() func(c *gin.Context) { abortWithMessage(c, http.StatusBadRequest, "无效的渠道 Id") return } - if channel.Status != common.ChannelStatusEnabled { + if channel.Status != model.ChannelStatusEnabled { abortWithMessage(c, http.StatusForbidden, "该渠道已被禁用") return } } else { - // Select a channel for the user - var modelRequest ModelRequest - err := common.UnmarshalBodyReusable(c, &modelRequest) + requestModel := c.GetString("request_model") + var err error + channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, requestModel, false) if err != nil { - abortWithMessage(c, http.StatusBadRequest, fmt.Sprintf("无效的请求: %+v", err)) - return - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - if modelRequest.Model == "" { - modelRequest.Model = 
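The `TokenAuth` middleware above rejects requests whose client IP falls outside the token's `Subnet` allowlist via `network.IsIpInSubnets`, whose implementation is not shown in this excerpt. As a rough sketch of that kind of check, assuming `Subnet` is a comma-separated CIDR list (an assumption, not the actual helper), the standard `net` package is sufficient:

```go
package main

import (
	"fmt"
	"net"
	"strings"
)

// ipInSubnets reports whether ip falls inside any CIDR in the
// comma-separated allowlist, e.g. "10.0.0.0/8,192.168.1.0/24".
// This is only a sketch of the idea; the real helper may differ.
func ipInSubnets(ipStr string, subnets string) bool {
	ip := net.ParseIP(ipStr)
	if ip == nil {
		return false
	}
	for _, cidr := range strings.Split(subnets, ",") {
		_, ipNet, err := net.ParseCIDR(strings.TrimSpace(cidr))
		if err != nil {
			continue // skip malformed entries
		}
		if ipNet.Contains(ip) {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(ipInSubnets("192.168.1.10", "10.0.0.0/8,192.168.1.0/24")) // true
	fmt.Println(ipInSubnets("8.8.8.8", "10.0.0.0/8,192.168.1.0/24"))      // false
}
```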
"text-moderation-stable" - } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - if modelRequest.Model == "" { - modelRequest.Model = c.Param("model") - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - if modelRequest.Model == "" { - modelRequest.Model = "dall-e-2" - } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - if modelRequest.Model == "" { - modelRequest.Model = "whisper-1" - } - } - requestModel = modelRequest.Model - channel, err = model.CacheGetRandomSatisfiedChannel(userGroup, modelRequest.Model, false) - if err != nil { - message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) + message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, requestModel) if channel != nil { logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" @@ -85,17 +61,18 @@ func Distribute() func(c *gin.Context) { } func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, modelName string) { + // one channel could relates to multiple groups, + // and each groud has individual ratio, // set minimal group ratio as channel_ratio var minimalRatio float64 = -1 for _, grp := range strings.Split(channel.Group, ",") { - v := common.GetGroupRatio(grp) + v := ratio.GetGroupRatio(grp) if minimalRatio < 0 || v < minimalRatio { minimalRatio = v } } logger.Info(c.Request.Context(), fmt.Sprintf("set channel %s ratio to %f", channel.Name, minimalRatio)) c.Set("channel_ratio", minimalRatio) - c.Set("channel", channel.Type) c.Set("channel_id", channel.Id) c.Set("channel_name", channel.Name) @@ -105,19 +82,19 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode c.Set("base_url", channel.GetBaseURL()) // this is for backward compatibility switch channel.Type { - case common.ChannelTypeAzure: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeXunfei: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeGemini: - c.Set(common.ConfigKeyAPIVersion, channel.Other) - case common.ChannelTypeAIProxyLibrary: - c.Set(common.ConfigKeyLibraryID, channel.Other) - case common.ChannelTypeAli: - c.Set(common.ConfigKeyPlugin, channel.Other) + case channeltype.Azure: + c.Set(config.KeyAPIVersion, channel.Other) + case channeltype.Xunfei: + c.Set(config.KeyAPIVersion, channel.Other) + case channeltype.Gemini: + c.Set(config.KeyAPIVersion, channel.Other) + case channeltype.AIProxyLibrary: + c.Set(config.KeyLibraryID, channel.Other) + case channeltype.Ali: + c.Set(config.KeyPlugin, channel.Other) } cfg, _ := channel.LoadConfig() for k, v := range cfg { - c.Set(common.ConfigKeyPrefix+k, v) + c.Set(config.KeyPrefix+k, v) } } diff --git a/middleware/utils.go b/middleware/utils.go index bc14c367..b65b018b 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,9 +1,12 @@ package middleware import ( + "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { @@ -16,3 +19,42 @@ func abortWithMessage(c *gin.Context, statusCode int, message string) { c.Abort() logger.Error(c.Request.Context(), message) } + +func getRequestModel(c *gin.Context) (string, error) { + var modelRequest ModelRequest + err := common.UnmarshalBodyReusable(c, &modelRequest) + if err != nil { 
+ return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + if modelRequest.Model == "" { + modelRequest.Model = "text-moderation-stable" + } + } + if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + if modelRequest.Model == "" { + modelRequest.Model = c.Param("model") + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + if modelRequest.Model == "" { + modelRequest.Model = "dall-e-2" + } + } + if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + if modelRequest.Model == "" { + modelRequest.Model = "whisper-1" + } + } + return modelRequest.Model, nil +} + +func isModelInList(modelName string, models string) bool { + modelList := strings.Split(models, ",") + for _, model := range modelList { + if modelName == model { + return true + } + } + return false +} diff --git a/model/ability.go b/model/ability.go index 48b856a2..2db72518 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,8 +1,10 @@ package model import ( + "context" "github.com/songquanpeng/one-api/common" "gorm.io/gorm" + "sort" "strings" ) @@ -55,7 +57,7 @@ func (channel *Channel) AddAbilities() error { Group: group, Model: model, ChannelId: channel.Id, - Enabled: channel.Status == common.ChannelStatusEnabled, + Enabled: channel.Status == ChannelStatusEnabled, Priority: channel.Priority, } abilities = append(abilities, ability) @@ -88,3 +90,19 @@ func (channel *Channel) UpdateAbilities() error { func UpdateAbilityStatus(channelId int, status bool) error { return DB.Model(&Ability{}).Where("channel_id = ?", channelId).Select("enabled").Update("enabled", status).Error } + +func GetGroupModels(ctx context.Context, group string) ([]string, error) { + groupCol := "`group`" + trueVal := "1" + if common.UsingPostgreSQL { + groupCol = `"group"` + trueVal = "true" + } + var models []string + err := DB.Model(&Ability{}).Distinct("model").Where(groupCol+" = ? 
and enabled = "+trueVal, group).Pluck("model", &models).Error + if err != nil { + return nil, err + } + sort.Strings(models) + return models, err +} diff --git a/model/cache.go b/model/cache.go index 50946bd6..a05cec19 100644 --- a/model/cache.go +++ b/model/cache.go @@ -8,6 +8,7 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "math/rand" "sort" "strconv" @@ -21,6 +22,7 @@ var ( UserId2GroupCacheSeconds = config.SyncFrequency UserId2QuotaCacheSeconds = config.SyncFrequency UserId2StatusCacheSeconds = config.SyncFrequency + GroupModelsCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -147,13 +149,32 @@ func CacheIsUserEnabled(userId int) (bool, error) { return userEnabled, err } +func CacheGetGroupModels(ctx context.Context, group string) ([]string, error) { + if !common.RedisEnabled { + return GetGroupModels(ctx, group) + } + modelsStr, err := common.RedisGet(fmt.Sprintf("group_models:%s", group)) + if err == nil { + return strings.Split(modelsStr, ","), nil + } + models, err := GetGroupModels(ctx, group) + if err != nil { + return nil, err + } + err = common.RedisSet(fmt.Sprintf("group_models:%s", group), strings.Join(models, ","), time.Duration(GroupModelsCacheSeconds)*time.Second) + if err != nil { + logger.SysError("Redis set group models error: " + err.Error()) + } + return models, nil +} + var group2model2channels map[string]map[string][]*Channel var channelSyncLock sync.RWMutex func InitChannelCache() { newChannelId2channel := make(map[int]*Channel) var channels []*Channel - DB.Where("status = ?", common.ChannelStatusEnabled).Find(&channels) + DB.Where("status = ?", ChannelStatusEnabled).Find(&channels) for _, channel := range channels { newChannelId2channel[channel.Id] = channel } @@ -228,7 +249,7 @@ func CacheGetRandomSatisfiedChannel(group string, model string, ignoreFirstPrior idx := rand.Intn(endIdx) if ignoreFirstPriority { if endIdx < len(channels) { // which means there are more than one priority - idx = common.RandRange(endIdx, len(channels)) + idx = random.RandRange(endIdx, len(channels)) } } return channels[idx], nil diff --git a/model/channel.go b/model/channel.go index fc4905b1..e667f7e7 100644 --- a/model/channel.go +++ b/model/channel.go @@ -3,13 +3,19 @@ package model import ( "encoding/json" "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" ) +const ( + ChannelStatusUnknown = 0 + ChannelStatusEnabled = 1 // don't use 0, 0 is the default value! + ChannelStatusManuallyDisabled = 2 // also don't use 0 + ChannelStatusAutoDisabled = 3 +) + type Channel struct { Id int `json:"id"` Type int `json:"type" gorm:"default:0"` @@ -39,7 +45,7 @@ func GetAllChannels(startIdx int, num int, scope string) ([]*Channel, error) { case "all": err = DB.Order("id desc").Find(&channels).Error case "disabled": - err = DB.Order("id desc").Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Find(&channels).Error + err = DB.Order("id desc").Where("status = ? 
or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Find(&channels).Error default: err = DB.Order("id desc").Limit(num).Offset(startIdx).Omit("key").Find(&channels).Error } @@ -168,7 +174,7 @@ func (channel *Channel) LoadConfig() (map[string]string, error) { } func UpdateChannelStatusById(id int, status int) { - err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) + err := UpdateAbilityStatus(id, status == ChannelStatusEnabled) if err != nil { logger.SysError("failed to update ability status: " + err.Error()) } @@ -199,6 +205,6 @@ func DeleteChannelByStatus(status int64) (int64, error) { } func DeleteDisabledChannel() (int64, error) { - result := DB.Where("status = ? or status = ?", common.ChannelStatusAutoDisabled, common.ChannelStatusManuallyDisabled).Delete(&Channel{}) + result := DB.Where("status = ? or status = ?", ChannelStatusAutoDisabled, ChannelStatusManuallyDisabled).Delete(&Channel{}) return result.RowsAffected, result.Error } diff --git a/model/log.go b/model/log.go index 4409f73e..6fba776a 100644 --- a/model/log.go +++ b/model/log.go @@ -7,7 +7,6 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "gorm.io/gorm" ) @@ -51,6 +50,21 @@ func RecordLog(userId int, logType int, content string) { } } +func RecordTopupLog(userId int, content string, quota int) { + log := &Log{ + UserId: userId, + Username: GetUsernameById(userId), + CreatedAt: helper.GetTimestamp(), + Type: LogTypeTopup, + Content: content, + Quota: quota, + } + err := LOG_DB.Create(log).Error + if err != nil { + logger.SysError("failed to record log: " + err.Error()) + } +} + func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) { logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) if !config.LogConsumeEnabled { diff --git a/model/main.go b/model/main.go index 4bbfde27..e5124a4c 100644 --- a/model/main.go +++ b/model/main.go @@ -12,6 +12,7 @@ import ( "github.com/songquanpeng/one-api/common/env" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" @@ -33,10 +34,10 @@ func CreateRootAccountIfNeed() error { rootUser := User{ Username: "root", Password: hashedPassword, - Role: common.RoleRootUser, - Status: common.UserStatusEnabled, + Role: RoleRootUser, + Status: UserStatusEnabled, DisplayName: "Root User", - AccessToken: helper.GetUUID(), + AccessToken: random.GetUUID(), Quota: 500000000000000, } DB.Create(&rootUser) @@ -46,7 +47,7 @@ func CreateRootAccountIfNeed() error { Id: 1, UserId: rootUser.Id, Key: config.InitialRootToken, - Status: common.TokenStatusEnabled, + Status: TokenStatusEnabled, Name: "Initial Root Token", CreatedTime: helper.GetTimestamp(), AccessedTime: helper.GetTimestamp(), diff --git a/model/option.go b/model/option.go index 1d1c28b4..bed8d4c3 100644 --- a/model/option.go +++ b/model/option.go @@ -1,9 +1,9 @@ package model import ( - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" + 
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "strconv" "strings" "time" @@ -66,9 +66,9 @@ func InitOptionMap() { config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10) config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10) config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10) - config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() - config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() - config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() + config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString() + config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString() + config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString() config.OptionMap["TopUpLink"] = config.TopUpLink config.OptionMap["ChatLink"] = config.ChatLink config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) @@ -82,7 +82,7 @@ func loadOptionsFromDatabase() { options, _ := AllOption() for _, option := range options { if option.Key == "ModelRatio" { - option.Value = common.AddNewMissingRatio(option.Value) + option.Value = billingratio.AddNewMissingRatio(option.Value) } err := updateOptionMap(option.Key, option.Value) if err != nil { @@ -172,6 +172,10 @@ func updateOptionMap(key string, value string) (err error) { config.GitHubClientId = value case "GitHubClientSecret": config.GitHubClientSecret = value + case "LarkClientId": + config.LarkClientId = value + case "LarkClientSecret": + config.LarkClientSecret = value case "Footer": config.Footer = value case "SystemName": @@ -205,11 +209,11 @@ func updateOptionMap(key string, value string) (err error) { case "RetryTimes": config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": - err = common.UpdateModelRatioByJSONString(value) + err = billingratio.UpdateModelRatioByJSONString(value) case "GroupRatio": - err = common.UpdateGroupRatioByJSONString(value) + err = billingratio.UpdateGroupRatioByJSONString(value) case "CompletionRatio": - err = common.UpdateCompletionRatioByJSONString(value) + err = billingratio.UpdateCompletionRatioByJSONString(value) case "TopUpLink": config.TopUpLink = value case "ChatLink": diff --git a/model/redemption.go b/model/redemption.go index c3ed2576..62428d35 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -8,6 +8,12 @@ import ( "gorm.io/gorm" ) +const ( + RedemptionCodeStatusEnabled = 1 // don't use 0, 0 is the default value! 
+ RedemptionCodeStatusDisabled = 2 // also don't use 0 + RedemptionCodeStatusUsed = 3 // also don't use 0 +) + type Redemption struct { Id int `json:"id"` UserId int `json:"user_id"` @@ -61,7 +67,7 @@ func Redeem(key string, userId int) (quota int64, err error) { if err != nil { return errors.New("无效的兑换码") } - if redemption.Status != common.RedemptionCodeStatusEnabled { + if redemption.Status != RedemptionCodeStatusEnabled { return errors.New("该兑换码已被使用") } err = tx.Model(&User{}).Where("id = ?", userId).Update("quota", gorm.Expr("quota + ?", redemption.Quota)).Error @@ -69,7 +75,7 @@ func Redeem(key string, userId int) (quota int64, err error) { return err } redemption.RedeemedTime = helper.GetTimestamp() - redemption.Status = common.RedemptionCodeStatusUsed + redemption.Status = RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err }) diff --git a/model/token.go b/model/token.go index c5491d06..10fd0d78 100644 --- a/model/token.go +++ b/model/token.go @@ -12,25 +12,34 @@ import ( "gorm.io/gorm" ) +const ( + TokenStatusEnabled = 1 // don't use 0, 0 is the default value! + TokenStatusDisabled = 2 // also don't use 0 + TokenStatusExpired = 3 + TokenStatusExhausted = 4 +) + type Token struct { - Id int `json:"id"` - UserId int `json:"user_id"` - Key string `json:"key" gorm:"type:char(48);uniqueIndex"` - Status int `json:"status" gorm:"default:1"` - Name string `json:"name" gorm:"index" ` - CreatedTime int64 `json:"created_time" gorm:"bigint"` - AccessedTime int64 `json:"accessed_time" gorm:"bigint"` - ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired - RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` - UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` - UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Id int `json:"id"` + UserId int `json:"user_id"` + Key string `json:"key" gorm:"type:char(48);uniqueIndex"` + Status int `json:"status" gorm:"default:1"` + Name string `json:"name" gorm:"index" ` + CreatedTime int64 `json:"created_time" gorm:"bigint"` + AccessedTime int64 `json:"accessed_time" gorm:"bigint"` + ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired + RemainQuota int64 `json:"remain_quota" gorm:"bigint;default:0"` + UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"` + UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0"` // used quota + Models *string `json:"models" gorm:"default:''"` // allowed models + Subnet *string `json:"subnet" gorm:"default:''"` // allowed subnet } func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token, error) { var tokens []*Token var err error query := DB.Where("user_id = ?", userId) - + switch order { case "remain_quota": query = query.Order("unlimited_quota desc, remain_quota desc") @@ -39,7 +48,7 @@ func GetAllUserTokens(userId int, startIdx int, num int, order string) ([]*Token default: query = query.Order("id desc") } - + err = query.Limit(num).Offset(startIdx).Find(&tokens).Error return tokens, err } @@ -62,18 +71,17 @@ func ValidateUserToken(key string) (token *Token, err error) { return nil, errors.Wrap(err, "failed to get token by key") } - - if token.Status == common.TokenStatusExhausted { - return nil, errors.New("该令牌额度已用尽") - } else if token.Status == common.TokenStatusExpired { + if token.Status == TokenStatusExhausted { + return nil, fmt.Errorf("令牌 %s(#%d)额度已用尽", token.Name, token.Id) + } else if token.Status == TokenStatusExpired { return nil, 
errors.New("该令牌已过期") } - if token.Status != common.TokenStatusEnabled { + if token.Status != TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { if !common.RedisEnabled { - token.Status = common.TokenStatusExpired + token.Status = TokenStatusExpired err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -84,7 +92,7 @@ func ValidateUserToken(key string) (token *Token, err error) { if !token.UnlimitedQuota && token.RemainQuota <= 0 { if !common.RedisEnabled { // in this case, we can make sure the token is exhausted - token.Status = common.TokenStatusExhausted + token.Status = TokenStatusExhausted err := token.SelectUpdate() if err != nil { logger.SysError("failed to update token status" + err.Error()) @@ -124,7 +132,7 @@ func (token *Token) Insert() error { // Update Make sure your token's fields is completed, because this will update non-zero values func (token *Token) Update() error { var err error - err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota").Updates(token).Error + err = DB.Model(token).Select("name", "status", "expired_time", "remain_quota", "unlimited_quota", "models", "subnet").Updates(token).Error return err } diff --git a/model/user.go b/model/user.go index e1244e3c..3cc1f9c0 100644 --- a/model/user.go +++ b/model/user.go @@ -8,11 +8,24 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/blacklist" "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/common/random" "gorm.io/gorm" ) +const ( + RoleGuestUser = 0 + RoleCommonUser = 1 + RoleAdminUser = 10 + RoleRootUser = 100 +) + +const ( + UserStatusEnabled = 1 // don't use 0, 0 is the default value! + UserStatusDisabled = 2 // also don't use 0 + UserStatusDeleted = 3 +) + // User if you add sensitive fields, don't forget to clean them in setupLogin function. // Otherwise, the sensitive information will be saved on local storage in plain text! type User struct { @@ -25,6 +38,7 @@ type User struct { Email string `json:"email" gorm:"index" validate:"max=50"` GitHubId string `json:"github_id" gorm:"column:github_id;index"` WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"` + LarkId string `json:"lark_id" gorm:"column:lark_id;index"` VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database! 
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management Quota int64 `json:"quota" gorm:"bigint;default:0"` @@ -42,21 +56,21 @@ func GetMaxUserId() int { } func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) { - query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", common.UserStatusDeleted) - - switch order { - case "quota": - query = query.Order("quota desc") - case "used_quota": - query = query.Order("used_quota desc") - case "request_count": - query = query.Order("request_count desc") - default: - query = query.Order("id desc") - } - - err = query.Find(&users).Error - return users, err + query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted) + + switch order { + case "quota": + query = query.Order("quota desc") + case "used_quota": + query = query.Order("used_quota desc") + case "request_count": + query = query.Order("request_count desc") + default: + query = query.Order("id desc") + } + + err = query.Find(&users).Error + return users, err } func SearchUsers(keyword string) (users []*User, err error) { @@ -108,8 +122,8 @@ func (user *User) Insert(inviterId int) error { } } user.Quota = config.QuotaForNewUser - user.AccessToken = helper.GetUUID() - user.AffCode = helper.GetRandomString(4) + user.AccessToken = random.GetUUID() + user.AffCode = random.GetRandomString(4) result := DB.Create(user) if result.Error != nil { return result.Error @@ -138,9 +152,9 @@ func (user *User) Update(updatePassword bool) error { return err } } - if user.Status == common.UserStatusDisabled { + if user.Status == UserStatusDisabled { blacklist.BanUser(user.Id) - } else if user.Status == common.UserStatusEnabled { + } else if user.Status == UserStatusEnabled { blacklist.UnbanUser(user.Id) } err = DB.Model(user).Updates(user).Error @@ -152,8 +166,8 @@ func (user *User) Delete() error { return errors.New("id 为空!") } blacklist.BanUser(user.Id) - user.Username = fmt.Sprintf("deleted_%s", helper.GetUUID()) - user.Status = common.UserStatusDeleted + user.Username = fmt.Sprintf("deleted_%s", random.GetUUID()) + user.Status = UserStatusDeleted err := DB.Model(user).Updates(user).Error return err } @@ -177,7 +191,7 @@ func (user *User) ValidateAndFill() (err error) { } } okay := common.ValidatePasswordAndHash(password, user.Password) - if !okay || user.Status != common.UserStatusEnabled { + if !okay || user.Status != UserStatusEnabled { return errors.New("用户名或密码错误,或用户已被封禁") } return nil @@ -207,6 +221,14 @@ func (user *User) FillUserByGitHubId() error { return nil } +func (user *User) FillUserByLarkId() error { + if user.LarkId == "" { + return errors.New("lark id 为空!") + } + DB.Where(User{LarkId: user.LarkId}).First(user) + return nil +} + func (user *User) FillUserByWeChatId() error { if user.WeChatId == "" { return errors.New("WeChat id 为空!") @@ -235,6 +257,10 @@ func IsGitHubIdAlreadyTaken(githubId string) bool { return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1 } +func IsLarkIdAlreadyTaken(githubId string) bool { + return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1 +} + func IsUsernameAlreadyTaken(username string) bool { return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1 } @@ -261,7 +287,7 @@ func IsAdmin(userId int) bool { logger.SysError("no such user " + err.Error()) return false } - return user.Role >= common.RoleAdminUser + return user.Role >= RoleAdminUser } func 
IsUserEnabled(userId int) (bool, error) { @@ -273,7 +299,7 @@ func IsUserEnabled(userId int) (bool, error) { if err != nil { return false, err } - return user.Status == common.UserStatusEnabled, nil + return user.Status == UserStatusEnabled, nil } func ValidateAccessToken(token string) (user *User) { @@ -346,7 +372,7 @@ func decreaseUserQuota(id int, quota int64) (err error) { } func GetRootUserEmail() (email string) { - DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email) + DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email) return email } diff --git a/monitor/channel.go b/monitor/channel.go index ad82d2f5..7e5dc58a 100644 --- a/monitor/channel.go +++ b/monitor/channel.go @@ -2,7 +2,6 @@ package monitor import ( "fmt" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/common/message" @@ -29,7 +28,7 @@ func notifyRootUser(subject string, content string) { // DisableChannel disable & notify func DisableChannel(channelId int, channelName string, reason string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) logger.SysLog(fmt.Sprintf("channel #%d has been disabled: %s", channelId, reason)) subject := fmt.Sprintf("渠道「%s」(#%d)已被禁用", channelName, channelId) content := fmt.Sprintf("渠道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, reason) @@ -37,7 +36,7 @@ func DisableChannel(channelId int, channelName string, reason string) { } func MetricDisableChannel(channelId int, successRate float64) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusAutoDisabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusAutoDisabled) logger.SysLog(fmt.Sprintf("channel #%d has been disabled due to low success rate: %.2f", channelId, successRate*100)) subject := fmt.Sprintf("渠道 #%d 已被禁用", channelId) content := fmt.Sprintf("该渠道(#%d)在最近 %d 次调用中成功率为 %.2f%%,低于阈值 %.2f%%,因此被系统自动禁用。", @@ -47,7 +46,7 @@ func MetricDisableChannel(channelId int, successRate float64) { // EnableChannel enable & notify func EnableChannel(channelId int, channelName string) { - model.UpdateChannelStatusById(channelId, common.ChannelStatusEnabled) + model.UpdateChannelStatusById(channelId, model.ChannelStatusEnabled) logger.SysLog(fmt.Sprintf("channel #%d has been enabled", channelId)) subject := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) content := fmt.Sprintf("渠道「%s」(#%d)已被启用", channelName, channelId) diff --git a/monitor/manage.go b/monitor/manage.go new file mode 100644 index 00000000..946e78af --- /dev/null +++ b/monitor/manage.go @@ -0,0 +1,62 @@ +package monitor + +import ( + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/model" + "net/http" + "strings" +) + +func ShouldDisableChannel(err *model.Error, statusCode int) bool { + if !config.AutomaticDisableChannelEnabled { + return false + } + if err == nil { + return false + } + if statusCode == http.StatusUnauthorized { + return true + } + switch err.Type { + case "insufficient_quota": + return true + // https://docs.anthropic.com/claude/reference/errors + case "authentication_error": + return true + case "permission_error": + return true + case "forbidden": + return true + } + if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { + return true + } + if strings.HasPrefix(err.Message, "Your credit balance is too 
low") { // anthropic + return true + } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { + return true + } + //if strings.Contains(err.Message, "quota") { + // return true + //} + if strings.Contains(err.Message, "credit") { + return true + } + if strings.Contains(err.Message, "balance") { + return true + } + return false +} + +func ShouldEnableChannel(err error, openAIErr *model.Error) bool { + if !config.AutomaticEnableChannelEnabled { + return false + } + if err != nil { + return false + } + if openAIErr != nil { + return false + } + return true +} diff --git a/relay/adaptor.go b/relay/adaptor.go new file mode 100644 index 00000000..ef549b5b --- /dev/null +++ b/relay/adaptor.go @@ -0,0 +1,41 @@ +package relay + +import ( + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/aiproxy" + "github.com/songquanpeng/one-api/relay/adaptor/anthropic" + "github.com/songquanpeng/one-api/relay/adaptor/gemini" + "github.com/songquanpeng/one-api/relay/adaptor/ollama" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/adaptor/palm" + "github.com/songquanpeng/one-api/relay/apitype" +) + +func GetAdaptor(apiType int) adaptor.Adaptor { + switch apiType { + case apitype.AIProxyLibrary: + return &aiproxy.Adaptor{} + // case apitype.Ali: + // return &ali.Adaptor{} + case apitype.Anthropic: + return &anthropic.Adaptor{} + // case apitype.Baidu: + // return &baidu.Adaptor{} + case apitype.Gemini: + return &gemini.Adaptor{} + case apitype.OpenAI: + return &openai.Adaptor{} + case apitype.PaLM: + return &palm.Adaptor{} + // case apitype.Tencent: + // return &tencent.Adaptor{} + // case apitype.Xunfei: + // return &xunfei.Adaptor{} + // case apitype.Zhipu: + // return &zhipu.Adaptor{} + case apitype.Ollama: + return &ollama.Adaptor{} + } + + return nil +} diff --git a/relay/channel/ai360/constants.go b/relay/adaptor/ai360/constants.go similarity index 100% rename from relay/channel/ai360/constants.go rename to relay/adaptor/ai360/constants.go diff --git a/relay/channel/aiproxy/adaptor.go b/relay/adaptor/aiproxy/adaptor.go similarity index 54% rename from relay/channel/aiproxy/adaptor.go rename to relay/adaptor/aiproxy/adaptor.go index 6f5d289f..31865698 100644 --- a/relay/channel/aiproxy/adaptor.go +++ b/relay/adaptor/aiproxy/adaptor.go @@ -4,10 +4,10 @@ import ( "fmt" "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -15,16 +15,16 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/api/library/ask", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil } @@ -34,15 +34,22 @@ 
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } aiProxyLibraryRequest := ConvertRequest(*request) - aiProxyLibraryRequest.LibraryId = c.GetString(common.ConfigKeyLibraryID) + aiProxyLibraryRequest.LibraryId = c.GetString(config.KeyLibraryID) return aiProxyLibraryRequest, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/aiproxy/constants.go b/relay/adaptor/aiproxy/constants.go similarity index 60% rename from relay/channel/aiproxy/constants.go rename to relay/adaptor/aiproxy/constants.go index c4df51c4..818d2709 100644 --- a/relay/channel/aiproxy/constants.go +++ b/relay/adaptor/aiproxy/constants.go @@ -1,6 +1,6 @@ package aiproxy -import "github.com/songquanpeng/one-api/relay/channel/openai" +import "github.com/songquanpeng/one-api/relay/adaptor/openai" var ModelList = []string{""} diff --git a/relay/channel/aiproxy/main.go b/relay/adaptor/aiproxy/main.go similarity index 95% rename from relay/channel/aiproxy/main.go rename to relay/adaptor/aiproxy/main.go index 7b146828..961260de 100644 --- a/relay/channel/aiproxy/main.go +++ b/relay/adaptor/aiproxy/main.go @@ -13,7 +13,8 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" ) @@ -54,7 +55,7 @@ func responseAIProxyLibrary2OpenAI(response *LibraryResponse) *openai.TextRespon FinishReason: "stop", } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -67,7 +68,7 @@ func documentsAIProxyLibrary(documents []LibraryDocument) *openai.ChatCompletion choice.Delta.Content = aiProxyDocuments2Markdown(documents) choice.FinishReason = &constant.StopFinishReason return &openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: "", @@ -79,7 +80,7 @@ func streamResponseAIProxyLibrary2OpenAI(response *LibraryStreamResponse) *opena var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = response.Content return &openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: 
fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: response.Model, diff --git a/relay/channel/aiproxy/model.go b/relay/adaptor/aiproxy/model.go similarity index 100% rename from relay/channel/aiproxy/model.go rename to relay/adaptor/aiproxy/model.go diff --git a/relay/adaptor/ali/adaptor.go b/relay/adaptor/ali/adaptor.go new file mode 100644 index 00000000..e004211e --- /dev/null +++ b/relay/adaptor/ali/adaptor.go @@ -0,0 +1,105 @@ +package ali + +// import ( +// "github.com/Laisky/errors/v2" +// "fmt" +// "github.com/gin-gonic/gin" +// "github.com/songquanpeng/one-api/common/config" +// "github.com/songquanpeng/one-api/relay/adaptor" +// "github.com/songquanpeng/one-api/relay/meta" +// "github.com/songquanpeng/one-api/relay/model" +// "github.com/songquanpeng/one-api/relay/relaymode" +// "io" +// "net/http" +// ) + +// // https://help.aliyun.com/zh/dashscope/developer-reference/api-details + +// type Adaptor struct { +// } + +// func (a *Adaptor) Init(meta *meta.Meta) { + +// } + +// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { +// fullRequestURL := "" +// switch meta.Mode { +// case relaymode.Embeddings: +// fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) +// case relaymode.ImagesGenerations: +// fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", meta.BaseURL) +// default: +// fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) +// } + +// return fullRequestURL, nil +// } + +// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { +// adaptor.SetupCommonRequestHeader(c, req, meta) +// if meta.IsStream { +// req.Header.Set("Accept", "text/event-stream") +// req.Header.Set("X-DashScope-SSE", "enable") +// } +// req.Header.Set("Authorization", "Bearer "+meta.APIKey) + +// if meta.Mode == relaymode.ImagesGenerations { +// req.Header.Set("X-DashScope-Async", "enable") +// } +// if c.GetString(config.KeyPlugin) != "" { +// req.Header.Set("X-DashScope-Plugin", c.GetString(config.KeyPlugin)) +// } +// return nil +// } + +// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { +// if request == nil { +// return nil, errors.New("request is nil") +// } +// switch relayMode { +// case relaymode.Embeddings: +// aliEmbeddingRequest := ConvertEmbeddingRequest(*request) +// return aliEmbeddingRequest, nil +// default: +// aliRequest := ConvertRequest(*request) +// return aliRequest, nil +// } +// } + +// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +// if request == nil { +// return nil, errors.New("request is nil") +// } + +// aliRequest := ConvertImageRequest(*request) +// return aliRequest, nil +// } + +// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { +// return adaptor.DoRequestHelper(a, c, meta, requestBody) +// } + +// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +// if meta.IsStream { +// err, usage = StreamHandler(c, resp) +// } else { +// switch meta.Mode { +// case relaymode.Embeddings: +// err, usage = EmbeddingHandler(c, resp) +// case relaymode.ImagesGenerations: +// err, usage = ImageHandler(c, resp) +// default: +// err, usage = Handler(c, resp) +// } +// } +// return +// } 
+ +// func (a *Adaptor) GetModelList() []string { +// return ModelList +// } + +// func (a *Adaptor) GetChannelName() string { +// return "ali" +// } diff --git a/relay/channel/ali/constants.go b/relay/adaptor/ali/constants.go similarity index 65% rename from relay/channel/ali/constants.go rename to relay/adaptor/ali/constants.go index 16bcfca4..3f24ce2e 100644 --- a/relay/channel/ali/constants.go +++ b/relay/adaptor/ali/constants.go @@ -3,4 +3,5 @@ package ali var ModelList = []string{ "qwen-turbo", "qwen-plus", "qwen-max", "qwen-max-longcontext", "text-embedding-v1", + "ali-stable-diffusion-xl", "ali-stable-diffusion-v1.5", "wanx-v1", } diff --git a/relay/adaptor/ali/image.go b/relay/adaptor/ali/image.go new file mode 100644 index 00000000..cef509e2 --- /dev/null +++ b/relay/adaptor/ali/image.go @@ -0,0 +1,192 @@ +package ali + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "github.com/Laisky/errors/v2" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strings" + "time" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + apiKey := c.Request.Header.Get("Authorization") + apiKey = strings.TrimPrefix(apiKey, "Bearer ") + responseFormat := c.GetString("response_format") + + var aliTaskResponse TaskResponse + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &aliTaskResponse) + if err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + if aliTaskResponse.Message != "" { + logger.SysError("aliAsyncTask err: " + string(responseBody)) + return openai.ErrorWrapper(errors.New(aliTaskResponse.Message), "ali_async_task_failed", http.StatusInternalServerError), nil + } + + aliResponse, _, err := asyncTaskWait(aliTaskResponse.Output.TaskId, apiKey) + if err != nil { + return openai.ErrorWrapper(err, "ali_async_task_wait_failed", http.StatusInternalServerError), nil + } + + if aliResponse.Output.TaskStatus != "SUCCEEDED" { + return &model.ErrorWithStatusCode{ + Error: model.Error{ + Message: aliResponse.Output.Message, + Type: "ali_error", + Param: "", + Code: aliResponse.Output.Code, + }, + StatusCode: resp.StatusCode, + }, nil + } + + fullTextResponse := responseAli2OpenAIImage(aliResponse, responseFormat) + jsonResponse, err := json.Marshal(fullTextResponse) + if err != nil { + return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil + } + c.Writer.Header().Set("Content-Type", "application/json") + c.Writer.WriteHeader(resp.StatusCode) + _, err = c.Writer.Write(jsonResponse) + return nil, nil +} + +func asyncTask(taskID string, key string) (*TaskResponse, error, []byte) { + url := fmt.Sprintf("https://dashscope.aliyuncs.com/api/v1/tasks/%s", taskID) + + var aliResponse TaskResponse + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return &aliResponse, err, nil + } + + req.Header.Set("Authorization", "Bearer "+key) + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + 
logger.SysError("aliAsyncTask client.Do err: " + err.Error()) + return &aliResponse, err, nil + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) + + var response TaskResponse + err = json.Unmarshal(responseBody, &response) + if err != nil { + logger.SysError("aliAsyncTask NewDecoder err: " + err.Error()) + return &aliResponse, err, nil + } + + return &response, nil, responseBody +} + +func asyncTaskWait(taskID string, key string) (*TaskResponse, []byte, error) { + waitSeconds := 2 + step := 0 + maxStep := 20 + + var taskResponse TaskResponse + var responseBody []byte + + for { + step++ + rsp, err, body := asyncTask(taskID, key) + responseBody = body + if err != nil { + return &taskResponse, responseBody, err + } + + if rsp.Output.TaskStatus == "" { + return &taskResponse, responseBody, nil + } + + switch rsp.Output.TaskStatus { + case "FAILED": + fallthrough + case "CANCELED": + fallthrough + case "SUCCEEDED": + fallthrough + case "UNKNOWN": + return rsp, responseBody, nil + } + if step >= maxStep { + break + } + time.Sleep(time.Duration(waitSeconds) * time.Second) + } + + return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout") +} + +func responseAli2OpenAIImage(response *TaskResponse, responseFormat string) *openai.ImageResponse { + imageResponse := openai.ImageResponse{ + Created: helper.GetTimestamp(), + } + + for _, data := range response.Output.Results { + var b64Json string + if responseFormat == "b64_json" { + // 读取 data.Url 的图片数据并转存到 b64Json + imageData, err := getImageData(data.Url) + if err != nil { + // 处理获取图片数据失败的情况 + logger.SysError("getImageData Error getting image data: " + err.Error()) + continue + } + + // 将图片数据转为 Base64 编码的字符串 + b64Json = Base64Encode(imageData) + } else { + // 如果 responseFormat 不是 "b64_json",则直接使用 data.B64Image + b64Json = data.B64Image + } + + imageResponse.Data = append(imageResponse.Data, openai.ImageData{ + Url: data.Url, + B64Json: b64Json, + RevisedPrompt: "", + }) + } + return &imageResponse +} + +func getImageData(url string) ([]byte, error) { + response, err := http.Get(url) + if err != nil { + return nil, err + } + defer response.Body.Close() + + imageData, err := io.ReadAll(response.Body) + if err != nil { + return nil, err + } + + return imageData, nil +} + +func Base64Encode(data []byte) string { + b64Json := base64.StdEncoding.EncodeToString(data) + return b64Json +} diff --git a/relay/channel/ali/main.go b/relay/adaptor/ali/main.go similarity index 100% rename from relay/channel/ali/main.go rename to relay/adaptor/ali/main.go diff --git a/relay/adaptor/ali/model.go b/relay/adaptor/ali/model.go new file mode 100644 index 00000000..450b5f52 --- /dev/null +++ b/relay/adaptor/ali/model.go @@ -0,0 +1,154 @@ +package ali + +import ( + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/model" +) + +type Message struct { + Content string `json:"content"` + Role string `json:"role"` +} + +type Input struct { + //Prompt string `json:"prompt"` + Messages []Message `json:"messages"` +} + +type Parameters struct { + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + Seed uint64 `json:"seed,omitempty"` + EnableSearch bool `json:"enable_search,omitempty"` + IncrementalOutput bool `json:"incremental_output,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + ResultFormat string `json:"result_format,omitempty"` + Tools []model.Tool `json:"tools,omitempty"` +} + +type ChatRequest struct { + Model string 
`json:"model"` + Input Input `json:"input"` + Parameters Parameters `json:"parameters,omitempty"` +} + +type ImageRequest struct { + Model string `json:"model"` + Input struct { + Prompt string `json:"prompt"` + NegativePrompt string `json:"negative_prompt,omitempty"` + } `json:"input"` + Parameters struct { + Size string `json:"size,omitempty"` + N int `json:"n,omitempty"` + Steps string `json:"steps,omitempty"` + Scale string `json:"scale,omitempty"` + } `json:"parameters,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` +} + +type TaskResponse struct { + StatusCode int `json:"status_code,omitempty"` + RequestId string `json:"request_id,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Output struct { + TaskId string `json:"task_id,omitempty"` + TaskStatus string `json:"task_status,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + Results []struct { + B64Image string `json:"b64_image,omitempty"` + Url string `json:"url,omitempty"` + Code string `json:"code,omitempty"` + Message string `json:"message,omitempty"` + } `json:"results,omitempty"` + TaskMetrics struct { + Total int `json:"TOTAL,omitempty"` + Succeeded int `json:"SUCCEEDED,omitempty"` + Failed int `json:"FAILED,omitempty"` + } `json:"task_metrics,omitempty"` + } `json:"output,omitempty"` + Usage Usage `json:"usage"` +} + +type Header struct { + Action string `json:"action,omitempty"` + Streaming string `json:"streaming,omitempty"` + TaskID string `json:"task_id,omitempty"` + Event string `json:"event,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + Attributes any `json:"attributes,omitempty"` +} + +type Payload struct { + Model string `json:"model,omitempty"` + Task string `json:"task,omitempty"` + TaskGroup string `json:"task_group,omitempty"` + Function string `json:"function,omitempty"` + Parameters struct { + SampleRate int `json:"sample_rate,omitempty"` + Rate float64 `json:"rate,omitempty"` + Format string `json:"format,omitempty"` + } `json:"parameters,omitempty"` + Input struct { + Text string `json:"text,omitempty"` + } `json:"input,omitempty"` + Usage struct { + Characters int `json:"characters,omitempty"` + } `json:"usage,omitempty"` +} + +type WSSMessage struct { + Header Header `json:"header,omitempty"` + Payload Payload `json:"payload,omitempty"` +} + +type EmbeddingRequest struct { + Model string `json:"model"` + Input struct { + Texts []string `json:"texts"` + } `json:"input"` + Parameters *struct { + TextType string `json:"text_type,omitempty"` + } `json:"parameters,omitempty"` +} + +type Embedding struct { + Embedding []float64 `json:"embedding"` + TextIndex int `json:"text_index"` +} + +type EmbeddingResponse struct { + Output struct { + Embeddings []Embedding `json:"embeddings"` + } `json:"output"` + Usage Usage `json:"usage"` + Error +} + +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + RequestId string `json:"request_id"` +} + +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +type Output struct { + //Text string `json:"text"` + //FinishReason string `json:"finish_reason"` + Choices []openai.TextResponseChoice `json:"choices"` +} + +type ChatResponse struct { + Output Output `json:"output"` + Usage Usage `json:"usage"` + Error +} diff --git a/relay/channel/anthropic/adaptor.go 
b/relay/adaptor/anthropic/adaptor.go similarity index 60% rename from relay/channel/anthropic/adaptor.go rename to relay/adaptor/anthropic/adaptor.go index 9f1adb9a..07efb3c7 100644 --- a/relay/channel/anthropic/adaptor.go +++ b/relay/adaptor/anthropic/adaptor.go @@ -2,31 +2,31 @@ package anthropic import ( "fmt" - "github.com/Laisky/errors/v2" "io" "net/http" + "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" ) type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - // https://docs.anthropic.com/claude/reference/messages_post - // anthopic migrate to Message API +// https://docs.anthropic.com/claude/reference/messages_post +// anthopic migrate to Message API +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/v1/messages", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-api-key", meta.APIKey) anthropicVersion := c.Request.Header.Get("anthropic-version") if anthropicVersion == "" { @@ -46,11 +46,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { diff --git a/relay/channel/anthropic/constants.go b/relay/adaptor/anthropic/constants.go similarity index 100% rename from relay/channel/anthropic/constants.go rename to relay/adaptor/anthropic/constants.go diff --git a/relay/channel/anthropic/main.go b/relay/adaptor/anthropic/main.go similarity index 99% rename from relay/channel/anthropic/main.go rename to relay/adaptor/anthropic/main.go index 198b66ad..aec327fe 100644 --- a/relay/channel/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -13,7 +13,7 @@ import ( "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/model" ) diff --git a/relay/channel/anthropic/model.go b/relay/adaptor/anthropic/model.go similarity index 100% 
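For orientation, here is a minimal sketch (not part of this patch) of how a caller could drive the refactored Anthropic adaptor now that it takes the new `meta.Meta` type; the `relayAnthropic` helper name and the omitted error handling are illustrative assumptions.

// Illustrative sketch only; not code from this commit.
package sketch

import (
	"bytes"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/adaptor/anthropic"
	"github.com/songquanpeng/one-api/relay/meta"
)

// relayAnthropic is a hypothetical caller; the real controllers live in relay/controller.
func relayAnthropic(c *gin.Context, m *meta.Meta, body []byte) (*http.Response, error) {
	a := &anthropic.Adaptor{}
	a.Init(m) // currently a no-op, kept for interface symmetry
	// DoRequest delegates to adaptor.DoRequestHelper, which resolves
	// {BaseURL}/v1/messages via GetRequestURL and sets the x-api-key and
	// anthropic-version headers via SetupRequestHeader.
	return a.DoRequest(c, m, bytes.NewReader(body))
}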
rename from relay/channel/anthropic/model.go rename to relay/adaptor/anthropic/model.go diff --git a/relay/adaptor/azure/helper.go b/relay/adaptor/azure/helper.go new file mode 100644 index 00000000..dd207f37 --- /dev/null +++ b/relay/adaptor/azure/helper.go @@ -0,0 +1,15 @@ +package azure + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" +) + +func GetAPIVersion(c *gin.Context) string { + query := c.Request.URL.Query() + apiVersion := query.Get("api-version") + if apiVersion == "" { + apiVersion = c.GetString(config.KeyAPIVersion) + } + return apiVersion +} diff --git a/relay/channel/baichuan/constants.go b/relay/adaptor/baichuan/constants.go similarity index 100% rename from relay/channel/baichuan/constants.go rename to relay/adaptor/baichuan/constants.go diff --git a/relay/channel/baidu/adaptor.go b/relay/adaptor/baidu/adaptor.go similarity index 100% rename from relay/channel/baidu/adaptor.go rename to relay/adaptor/baidu/adaptor.go diff --git a/relay/adaptor/baidu/constants.go b/relay/adaptor/baidu/constants.go new file mode 100644 index 00000000..f952adc6 --- /dev/null +++ b/relay/adaptor/baidu/constants.go @@ -0,0 +1,20 @@ +package baidu + +var ModelList = []string{ + "ERNIE-4.0-8K", + "ERNIE-3.5-8K", + "ERNIE-3.5-8K-0205", + "ERNIE-3.5-8K-1222", + "ERNIE-Bot-8K", + "ERNIE-3.5-4K-0205", + "ERNIE-Speed-8K", + "ERNIE-Speed-128K", + "ERNIE-Lite-8K-0922", + "ERNIE-Lite-8K-0308", + "ERNIE-Tiny-8K", + "BLOOMZ-7B", + "Embedding-V1", + "bge-large-zh", + "bge-large-en", + "tao-8k", +} diff --git a/relay/channel/baidu/main.go b/relay/adaptor/baidu/main.go similarity index 100% rename from relay/channel/baidu/main.go rename to relay/adaptor/baidu/main.go diff --git a/relay/channel/baidu/model.go b/relay/adaptor/baidu/model.go similarity index 100% rename from relay/channel/baidu/model.go rename to relay/adaptor/baidu/model.go diff --git a/relay/channel/common.go b/relay/adaptor/common.go similarity index 81% rename from relay/channel/common.go rename to relay/adaptor/common.go index 2c4fb37c..13f57132 100644 --- a/relay/channel/common.go +++ b/relay/adaptor/common.go @@ -1,4 +1,4 @@ -package channel +package adaptor import ( "io" @@ -6,10 +6,11 @@ import ( "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/client" + "github.com/songquanpeng/one-api/relay/meta" ) -func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) { +func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) if meta.IsStream && c.Request.Header.Get("Accept") == "" { @@ -17,7 +18,7 @@ func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *util.Rela } } -func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { fullRequestURL, err := a.GetRequestURL(meta) if err != nil { return nil, errors.Wrap(err, "get request url failed") @@ -43,7 +44,7 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *util.RelayMeta, requestBod } func DoRequest(c *gin.Context, req *http.Request) (*http.Response, error) { - resp, err := util.HTTPClient.Do(req) + resp, err := client.HTTPClient.Do(req) if err != nil { return nil, err } diff --git 
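The new `azure.GetAPIVersion` helper above prefers the `api-version` query parameter and falls back to the value stored under `config.KeyAPIVersion` in the gin context. A minimal usage sketch follows; the route and port are made up for illustration.

// Illustrative sketch only; not code from this commit.
package main

import (
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/songquanpeng/one-api/relay/adaptor/azure"
)

func main() {
	r := gin.Default()
	r.GET("/probe", func(c *gin.Context) {
		// "?api-version=..." wins; otherwise the per-channel value stored
		// under config.KeyAPIVersion in the gin context is returned.
		c.String(http.StatusOK, azure.GetAPIVersion(c))
	})
	_ = r.Run(":8080") // hypothetical port
}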
a/relay/channel/gemini/adaptor.go b/relay/adaptor/gemini/adaptor.go similarity index 66% rename from relay/channel/gemini/adaptor.go rename to relay/adaptor/gemini/adaptor.go index 240ca28d..ecb72221 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/adaptor/gemini/adaptor.go @@ -8,21 +8,21 @@ import ( "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" - channelhelper "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" ) type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { - version := helper.AssignOrDefault(meta.APIVersion, "v1beta") +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + version := helper.AssignOrDefault(meta.APIVersion, "v1") action := "generateContent" if meta.IsStream { action = "streamGenerateContent" @@ -30,7 +30,7 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { return fmt.Sprintf("%s/%s/models/%s:%s?key=%s", meta.BaseURL, version, meta.ActualModelName, action, meta.APIKey), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { channelhelper.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) req.URL.Query().Add("key", meta.APIKey) @@ -44,11 +44,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { return channelhelper.DoRequestHelper(a, c, meta, requestBody) } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff --git a/relay/channel/gemini/constants.go b/relay/adaptor/gemini/constants.go similarity index 100% rename from relay/channel/gemini/constants.go rename to relay/adaptor/gemini/constants.go diff --git a/relay/channel/gemini/main.go b/relay/adaptor/gemini/main.go similarity index 56% rename from relay/channel/gemini/main.go rename to relay/adaptor/gemini/main.go index b66f2d5e..27a9c023 100644 --- a/relay/channel/gemini/main.go +++ b/relay/adaptor/gemini/main.go @@ -1,21 +1,24 @@ package gemini import ( - "context" + "bufio" "encoding/json" "fmt" "io" "net/http" + "strings" - "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/helper" 
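As a worked example of the request URL assembled by the Gemini `GetRequestURL` above; all values below are placeholders, and the base URL is only an assumed default.

// Illustrative sketch only; not code from this commit.
package main

import "fmt"

func main() {
	baseURL := "https://generativelanguage.googleapis.com" // assumed default base URL
	version := "v1"                                        // helper.AssignOrDefault(meta.APIVersion, "v1")
	model := "gemini-pro"                                  // meta.ActualModelName
	action := "generateContent"                            // "streamGenerateContent" when meta.IsStream is true
	key := "YOUR_API_KEY"                                  // meta.APIKey
	fmt.Printf("%s/%s/models/%s:%s?key=%s\n", baseURL, version, model, action, key)
	// https://generativelanguage.googleapis.com/v1/models/gemini-pro:generateContent?key=YOUR_API_KEY
}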
"github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + + "github.com/gin-gonic/gin" ) // https://ai.google.dev/docs/gemini_api_overview?hl=zh-cn @@ -82,13 +85,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { if imageNum > VisionMaxImageNum { continue } - mimeType, data, err := image.GetImageFromUrl(part.ImageURL.Url) - if err != nil { - logger.Warn(context.TODO(), - fmt.Sprintf("get image from url %s got %+v", part.ImageURL.Url, err)) - continue - } - + mimeType, data, _ := image.GetImageFromUrl(part.ImageURL.Url) parts = append(parts, Part{ InlineData: &InlineData{ MimeType: mimeType, @@ -97,9 +94,6 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *ChatRequest { }) } } - - logger.Info(context.TODO(), - fmt.Sprintf("send %d messages to gemini with %d images", len(parts), imageNum)) content.Parts = parts // there's no assistant role in gemini and API shall vomit if Role is not user or model @@ -163,7 +157,7 @@ type ChatPromptFeedback struct { func responseGeminiChat2OpenAI(response *ChatResponse) *openai.TextResponse { fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: make([]openai.TextResponseChoice, 0, len(response.Candidates)), @@ -196,182 +190,73 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *ChatResponse) *openai.ChatC return &response } -// [{ -// "candidates": [ -// { -// "content": { -// "parts": [ -// { -// "text": "```go \n\n// Package ratelimit implements tokens bucket algorithm.\npackage rate" -// } -// ], -// "role": "model" -// }, -// "finishReason": "STOP", -// "index": 0, -// "safetyRatings": [ -// { -// "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", -// "probability": "NEGLIGIBLE" -// }, -// { -// "category": "HARM_CATEGORY_HATE_SPEECH", -// "probability": "NEGLIGIBLE" -// }, -// { -// "category": "HARM_CATEGORY_HARASSMENT", -// "probability": "NEGLIGIBLE" -// }, -// { -// "category": "HARM_CATEGORY_DANGEROUS_CONTENT", -// "probability": "NEGLIGIBLE" -// } -// ] -// } -// ], -// "promptFeedback": { -// "safetyRatings": [ -// { -// "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", -// "probability": "NEGLIGIBLE" -// }, -// { -// "category": "HARM_CATEGORY_HATE_SPEECH", -// "probability": "NEGLIGIBLE" -// }, -// { -// "category": "HARM_CATEGORY_HARASSMENT", -// "probability": "NEGLIGIBLE" -// }, -// { -// "category": "HARM_CATEGORY_DANGEROUS_CONTENT", -// "probability": "NEGLIGIBLE" -// } -// ] -// } -// }] -type GeminiStreamResp struct { - Candidates []struct { - Content struct { - Parts []struct { - Text string `json:"text"` - } `json:"parts"` - Role string `json:"role"` - } `json:"content"` - FinishReason string `json:"finishReason"` - Index int64 `json:"index"` - } `json:"candidates"` -} - func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "read upstream's body", http.StatusInternalServerError), responseText - } - - var respData []GeminiStreamResp - if err = json.Unmarshal(respBody, &respData); err != nil { - 
return openai.ErrorWrapper(err, "unmarshal upstream's body", http.StatusInternalServerError), responseText - } - - for _, chunk := range respData { - for _, cad := range chunk.Candidates { - for _, part := range cad.Content.Parts { - responseText += part.Text - } + dataChan := make(chan string) + stopChan := make(chan bool) + scanner := bufio.NewScanner(resp.Body) + scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil } - } - - var choice openai.ChatCompletionsStreamResponseChoice - choice.Delta.Content = responseText - resp2cli, err := json.Marshal(&openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), - Object: "chat.completion.chunk", - Created: helper.GetTimestamp(), - Model: "gemini-pro", - Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + if i := strings.Index(string(data), "\n"); i >= 0 { + return i + 1, data[0:i], nil + } + if atEOF { + return len(data), data, nil + } + return 0, nil, nil }) + go func() { + for scanner.Scan() { + data := scanner.Text() + data = strings.TrimSpace(data) + if !strings.HasPrefix(data, "\"text\": \"") { + continue + } + data = strings.TrimPrefix(data, "\"text\": \"") + data = strings.TrimSuffix(data, "\"") + dataChan <- data + } + stopChan <- true + }() + common.SetEventStreamHeaders(c) + c.Stream(func(w io.Writer) bool { + select { + case data := <-dataChan: + // this is used to prevent annoying \ related format bug + data = fmt.Sprintf("{\"content\": \"%s\"}", data) + type dummyStruct struct { + Content string `json:"content"` + } + var dummy dummyStruct + err := json.Unmarshal([]byte(data), &dummy) + responseText += dummy.Content + var choice openai.ChatCompletionsStreamResponseChoice + choice.Delta.Content = dummy.Content + response := openai.ChatCompletionsStreamResponse{ + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), + Object: "chat.completion.chunk", + Created: helper.GetTimestamp(), + Model: "gemini-pro", + Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, + } + jsonResponse, err := json.Marshal(response) + if err != nil { + logger.SysError("error marshalling stream response: " + err.Error()) + return true + } + c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) + return true + case <-stopChan: + c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) + return false + } + }) + err := resp.Body.Close() if err != nil { - return openai.ErrorWrapper(err, "marshal upstream's body", http.StatusInternalServerError), responseText - } - - c.Render(-1, common.CustomEvent{Data: "data: " + string(resp2cli)}) - c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - - // dataChan := make(chan string) - // stopChan := make(chan bool) - // scanner := bufio.NewScanner(resp.Body) - // scanner.Split(bufio.ScanLines) - // // scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { - // // if atEOF && len(data) == 0 { - // // return 0, nil, nil - // // } - // // if i := strings.Index(string(data), "\n"); i >= 0 { - // // return i + 1, data[0:i], nil - // // } - // // if atEOF { - // // return len(data), data, nil - // // } - // // return 0, nil, nil - // // }) - // go func() { - // var content string - // for scanner.Scan() { - // line := strings.TrimSpace(scanner.Text()) - // fmt.Printf("> gemini got line: %s\n", line) - // content += line - // // if !strings.HasPrefix(data, "\"text\": \"") { - // // continue - // // } - - // // data = strings.TrimPrefix(data, 
"\"text\": \"") - // // data = strings.TrimSuffix(data, "\"") - // // dataChan <- data - // } - - // dataChan <- content - // stopChan <- true - // }() - // common.SetEventStreamHeaders(c) - // c.Stream(func(w io.Writer) bool { - // select { - // case data := <-dataChan: - // // this is used to prevent annoying \ related format bug - // data = fmt.Sprintf("{\"content\": \"%s\"}", data) - // type dummyStruct struct { - // Content string `json:"content"` - // } - // var dummy dummyStruct - // err := json.Unmarshal([]byte(data), &dummy) - // responseText += dummy.Content - // var choice openai.ChatCompletionsStreamResponseChoice - // choice.Delta.Content = dummy.Content - // response := openai.ChatCompletionsStreamResponse{ - // Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), - // Object: "chat.completion.chunk", - // Created: helper.GetTimestamp(), - // Model: "gemini-pro", - // Choices: []openai.ChatCompletionsStreamResponseChoice{choice}, - // } - // jsonResponse, err := json.Marshal(response) - // if err != nil { - // logger.SysError("error marshalling stream response: " + err.Error()) - // return true - // } - // c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) - // return true - // case <-stopChan: - // c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) - // return false - // } - // }) - - if err := resp.Body.Close(); err != nil { return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" } - return nil, responseText } diff --git a/relay/channel/gemini/model.go b/relay/adaptor/gemini/model.go similarity index 100% rename from relay/channel/gemini/model.go rename to relay/adaptor/gemini/model.go diff --git a/relay/channel/groq/constants.go b/relay/adaptor/groq/constants.go similarity index 100% rename from relay/channel/groq/constants.go rename to relay/adaptor/groq/constants.go diff --git a/relay/adaptor/interface.go b/relay/adaptor/interface.go new file mode 100644 index 00000000..01b2e2cb --- /dev/null +++ b/relay/adaptor/interface.go @@ -0,0 +1,21 @@ +package adaptor + +import ( + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +type Adaptor interface { + Init(meta *meta.Meta) + GetRequestURL(meta *meta.Meta) (string, error) + SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error + ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) + ConvertImageRequest(request *model.ImageRequest) (any, error) + DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) + DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) + GetModelList() []string + GetChannelName() string +} diff --git a/relay/channel/lingyiwanwu/constants.go b/relay/adaptor/lingyiwanwu/constants.go similarity index 100% rename from relay/channel/lingyiwanwu/constants.go rename to relay/adaptor/lingyiwanwu/constants.go diff --git a/relay/channel/minimax/constants.go b/relay/adaptor/minimax/constants.go similarity index 100% rename from relay/channel/minimax/constants.go rename to relay/adaptor/minimax/constants.go diff --git a/relay/adaptor/minimax/main.go b/relay/adaptor/minimax/main.go new file mode 100644 index 00000000..fc9b5d26 --- /dev/null +++ b/relay/adaptor/minimax/main.go @@ -0,0 +1,14 @@ +package minimax + +import ( + "fmt" + "github.com/songquanpeng/one-api/relay/meta" + 
"github.com/songquanpeng/one-api/relay/relaymode" +) + +func GetRequestURL(meta *meta.Meta) (string, error) { + if meta.Mode == relaymode.ChatCompletions { + return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil + } + return "", fmt.Errorf("unsupported relay mode %d for minimax", meta.Mode) +} diff --git a/relay/channel/mistral/constants.go b/relay/adaptor/mistral/constants.go similarity index 100% rename from relay/channel/mistral/constants.go rename to relay/adaptor/mistral/constants.go diff --git a/relay/channel/moonshot/constants.go b/relay/adaptor/moonshot/constants.go similarity index 100% rename from relay/channel/moonshot/constants.go rename to relay/adaptor/moonshot/constants.go diff --git a/relay/channel/ollama/adaptor.go b/relay/adaptor/ollama/adaptor.go similarity index 58% rename from relay/channel/ollama/adaptor.go rename to relay/adaptor/ollama/adaptor.go index e2ae7d2b..ec1b0c40 100644 --- a/relay/channel/ollama/adaptor.go +++ b/relay/adaptor/ollama/adaptor.go @@ -1,36 +1,36 @@ package ollama import ( - "errors" "fmt" "io" "net/http" + "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" ) type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { // https://github.com/ollama/ollama/blob/main/docs/api.md fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL) - if meta.Mode == constant.RelayModeEmbeddings { + if meta.Mode == relaymode.Embeddings { fullRequestURL = fmt.Sprintf("%s/api/embeddings", meta.BaseURL) } return fullRequestURL, nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("Authorization", "Bearer "+meta.APIKey) return nil } @@ -40,7 +40,7 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return nil, errors.New("request is nil") } switch relayMode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: ollamaEmbeddingRequest := ConvertEmbeddingRequest(*request) return ollamaEmbeddingRequest, nil default: @@ -48,16 +48,23 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G } } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c 
*gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { err, usage = StreamHandler(c, resp) } else { switch meta.Mode { - case constant.RelayModeEmbeddings: + case relaymode.Embeddings: err, usage = EmbeddingHandler(c, resp) default: err, usage = Handler(c, resp) diff --git a/relay/channel/ollama/constants.go b/relay/adaptor/ollama/constants.go similarity index 100% rename from relay/channel/ollama/constants.go rename to relay/adaptor/ollama/constants.go diff --git a/relay/channel/ollama/main.go b/relay/adaptor/ollama/main.go similarity index 97% rename from relay/channel/ollama/main.go rename to relay/adaptor/ollama/main.go index 821a335b..a7e4c058 100644 --- a/relay/channel/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -5,15 +5,16 @@ import ( "context" "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/random" "io" "net/http" "strings" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" ) @@ -51,7 +52,7 @@ func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse { choice.FinishReason = "stop" } fullTextResponse := openai.TextResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion", Created: helper.GetTimestamp(), Choices: []openai.TextResponseChoice{choice}, @@ -72,7 +73,7 @@ func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompl choice.FinishReason = &constant.StopFinishReason } response := openai.ChatCompletionsStreamResponse{ - Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()), + Id: fmt.Sprintf("chatcmpl-%s", random.GetUUID()), Object: "chat.completion.chunk", Created: helper.GetTimestamp(), Model: ollamaResponse.Model, diff --git a/relay/channel/ollama/model.go b/relay/adaptor/ollama/model.go similarity index 100% rename from relay/channel/ollama/model.go rename to relay/adaptor/ollama/model.go diff --git a/relay/channel/openai/adaptor.go b/relay/adaptor/openai/adaptor.go similarity index 52% rename from relay/channel/openai/adaptor.go rename to relay/adaptor/openai/adaptor.go index 260a35ae..24cf718f 100644 --- a/relay/channel/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -4,11 +4,12 @@ import ( "fmt" "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/minimax" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/minimax" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" @@ -18,13 +19,20 @@ type Adaptor struct { ChannelType int } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { a.ChannelType = meta.ChannelType } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, 
error) { switch meta.ChannelType { - case common.ChannelTypeAzure: + case channeltype.Azure: + if meta.Mode == relaymode.ImagesGenerations { + // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api + // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview + fullRequestURL := fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, meta.ActualModelName, meta.APIVersion) + return fullRequestURL, nil + } + // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api requestURL := strings.Split(meta.RequestURLPath, "?")[0] requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, meta.APIVersion) @@ -34,22 +42,22 @@ func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { //https://github.com/songquanpeng/one-api/issues/1191 // {your endpoint}/openai/deployments/{your azure_model}/chat/completions?api-version={api_version} requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - return util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil - case common.ChannelTypeMinimax: + return GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType), nil + case channeltype.Minimax: return minimax.GetRequestURL(meta) default: - return util.GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil + return GetFullRequestURL(meta.BaseURL, meta.RequestURLPath, meta.ChannelType), nil } } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) - if meta.ChannelType == common.ChannelTypeAzure { +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + if meta.ChannelType == channeltype.Azure { req.Header.Set("api-key", meta.APIKey) return nil } req.Header.Set("Authorization", "Bearer "+meta.APIKey) - if meta.ChannelType == common.ChannelTypeOpenRouter { + if meta.ChannelType == channeltype.OpenRouter { req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") req.Header.Set("X-Title", "One API") } @@ -63,11 +71,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return request, nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText, usage = StreamHandler(c, resp, meta.Mode) @@ -75,7 +90,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel usage = ResponseText2Usage(responseText, meta.ActualModelName, meta.PromptTokens) } } else { - err, usage = 
Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + switch meta.Mode { + case relaymode.ImagesGenerations: + err, _ = ImageHandler(c, resp) + default: + err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) + } } return } diff --git a/relay/adaptor/openai/compatible.go b/relay/adaptor/openai/compatible.go new file mode 100644 index 00000000..200eac44 --- /dev/null +++ b/relay/adaptor/openai/compatible.go @@ -0,0 +1,50 @@ +package openai + +import ( + "github.com/songquanpeng/one-api/relay/adaptor/ai360" + "github.com/songquanpeng/one-api/relay/adaptor/baichuan" + "github.com/songquanpeng/one-api/relay/adaptor/groq" + "github.com/songquanpeng/one-api/relay/adaptor/lingyiwanwu" + "github.com/songquanpeng/one-api/relay/adaptor/minimax" + "github.com/songquanpeng/one-api/relay/adaptor/mistral" + "github.com/songquanpeng/one-api/relay/adaptor/moonshot" + "github.com/songquanpeng/one-api/relay/adaptor/stepfun" + "github.com/songquanpeng/one-api/relay/channeltype" +) + +var CompatibleChannels = []int{ + channeltype.Azure, + channeltype.AI360, + channeltype.Moonshot, + channeltype.Baichuan, + channeltype.Minimax, + channeltype.Mistral, + channeltype.Groq, + channeltype.LingYiWanWu, + channeltype.StepFun, +} + +func GetCompatibleChannelMeta(channelType int) (string, []string) { + switch channelType { + case channeltype.Azure: + return "azure", ModelList + case channeltype.AI360: + return "360", ai360.ModelList + case channeltype.Moonshot: + return "moonshot", moonshot.ModelList + case channeltype.Baichuan: + return "baichuan", baichuan.ModelList + case channeltype.Minimax: + return "minimax", minimax.ModelList + case channeltype.Mistral: + return "mistralai", mistral.ModelList + case channeltype.Groq: + return "groq", groq.ModelList + case channeltype.LingYiWanWu: + return "lingyiwanwu", lingyiwanwu.ModelList + case channeltype.StepFun: + return "stepfun", stepfun.ModelList + default: + return "openai", ModelList + } +} diff --git a/relay/channel/openai/constants.go b/relay/adaptor/openai/constants.go similarity index 100% rename from relay/channel/openai/constants.go rename to relay/adaptor/openai/constants.go diff --git a/relay/adaptor/openai/helper.go b/relay/adaptor/openai/helper.go new file mode 100644 index 00000000..7d73303b --- /dev/null +++ b/relay/adaptor/openai/helper.go @@ -0,0 +1,30 @@ +package openai + +import ( + "fmt" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/model" + "strings" +) + +func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { + usage := &model.Usage{} + usage.PromptTokens = promptTokens + usage.CompletionTokens = CountTokenText(responseText, modeName) + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + return usage +} + +func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { + fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) + + if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { + switch channelType { + case channeltype.OpenAI: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) + case channeltype.Azure: + fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) + } + } + return fullRequestURL +} diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go new file mode 100644 index 00000000..0f89618a --- /dev/null +++ b/relay/adaptor/openai/image.go @@ -0,0 +1,44 @@ +package openai + +import 
( + "bytes" + "encoding/json" + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" +) + +func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + var imageResponse ImageResponse + responseBody, err := io.ReadAll(resp.Body) + + if err != nil { + return ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + err = json.Unmarshal(responseBody, &imageResponse) + if err != nil { + return ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + c.Writer.WriteHeader(resp.StatusCode) + + _, err = io.Copy(c.Writer, resp.Body) + if err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + err = resp.Body.Close() + if err != nil { + return ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil + } + return nil, nil +} diff --git a/relay/channel/openai/main.go b/relay/adaptor/openai/main.go similarity index 97% rename from relay/channel/openai/main.go rename to relay/adaptor/openai/main.go index 63cb9ae8..68d8f48f 100644 --- a/relay/channel/openai/main.go +++ b/relay/adaptor/openai/main.go @@ -8,8 +8,8 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/conv" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" "io" "net/http" "strings" @@ -46,7 +46,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E data = data[6:] if !strings.HasPrefix(data, "[DONE]") { switch relayMode { - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: var streamResponse ChatCompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { @@ -59,7 +59,7 @@ func StreamHandler(c *gin.Context, resp *http.Response, relayMode int) (*model.E if streamResponse.Usage != nil { usage = streamResponse.Usage } - case constant.RelayModeCompletions: + case relaymode.Completions: var streamResponse CompletionsStreamResponse err := json.Unmarshal([]byte(data), &streamResponse) if err != nil { diff --git a/relay/channel/openai/model.go b/relay/adaptor/openai/model.go similarity index 93% rename from relay/channel/openai/model.go rename to relay/adaptor/openai/model.go index 30d77739..ce252ff6 100644 --- a/relay/channel/openai/model.go +++ b/relay/adaptor/openai/model.go @@ -110,11 +110,16 @@ type EmbeddingResponse struct { model.Usage `json:"usage"` } +type ImageData struct { + Url string `json:"url,omitempty"` + B64Json string `json:"b64_json,omitempty"` + RevisedPrompt string `json:"revised_prompt,omitempty"` +} + type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - } + Created int64 `json:"created"` + Data []ImageData `json:"data"` + //model.Usage `json:"usage"` } type ChatCompletionsStreamResponseChoice struct { diff --git a/relay/channel/openai/token.go b/relay/adaptor/openai/token.go similarity index 98% rename from relay/channel/openai/token.go rename to relay/adaptor/openai/token.go index d18ce0df..1e61d255 
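To make the Cloudflare AI Gateway branch of `openai.GetFullRequestURL` (added in relay/adaptor/openai/helper.go above) easier to follow, a small illustrative sketch; the ACCOUNT/GATEWAY path segments are placeholders.

// Illustrative sketch only; not code from this commit.
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/relay/adaptor/openai"
	"github.com/songquanpeng/one-api/relay/channeltype"
)

func main() {
	cf := "https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai"
	// Cloudflare AI Gateway base + OpenAI channel: the "/v1" prefix is stripped.
	fmt.Println(openai.GetFullRequestURL(cf, "/v1/chat/completions", channeltype.OpenAI))
	// https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai/chat/completions

	// Any other base URL: plain concatenation.
	fmt.Println(openai.GetFullRequestURL("https://api.openai.com", "/v1/chat/completions", channeltype.OpenAI))
	// https://api.openai.com/v1/chat/completions
}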
100644 --- a/relay/channel/openai/token.go +++ b/relay/adaptor/openai/token.go @@ -4,10 +4,10 @@ import ( "fmt" "github.com/Laisky/errors/v2" "github.com/pkoukk/tiktoken-go" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/image" "github.com/songquanpeng/one-api/common/logger" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/model" "math" "strings" @@ -28,7 +28,7 @@ func InitTokenEncoders() { if err != nil { logger.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) } - for model := range common.ModelRatio { + for model := range billingratio.ModelRatio { if strings.HasPrefix(model, "gpt-3.5") { tokenEncoderMap[model] = gpt35TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { diff --git a/relay/channel/openai/util.go b/relay/adaptor/openai/util.go similarity index 100% rename from relay/channel/openai/util.go rename to relay/adaptor/openai/util.go diff --git a/relay/channel/palm/adaptor.go b/relay/adaptor/palm/adaptor.go similarity index 59% rename from relay/channel/palm/adaptor.go rename to relay/adaptor/palm/adaptor.go index 15ee010d..fa73dd30 100644 --- a/relay/channel/palm/adaptor.go +++ b/relay/adaptor/palm/adaptor.go @@ -4,10 +4,10 @@ import ( "fmt" "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" "io" "net/http" ) @@ -15,16 +15,16 @@ import ( type Adaptor struct { } -func (a *Adaptor) Init(meta *util.RelayMeta) { +func (a *Adaptor) Init(meta *meta.Meta) { } -func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", meta.BaseURL), nil } -func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { - channel.SetupCommonRequestHeader(c, req, meta) +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) req.Header.Set("x-goog-api-key", meta.APIKey) return nil } @@ -36,11 +36,18 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G return ConvertRequest(*request), nil } -func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { - return channel.DoRequestHelper(a, c, meta, requestBody) +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + if request == nil { + return nil, errors.New("request is nil") + } + return request, nil } -func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { if meta.IsStream { var responseText string err, responseText = StreamHandler(c, resp) diff 
--git a/relay/channel/palm/constants.go b/relay/adaptor/palm/constants.go similarity index 100% rename from relay/channel/palm/constants.go rename to relay/adaptor/palm/constants.go diff --git a/relay/channel/palm/model.go b/relay/adaptor/palm/model.go similarity index 100% rename from relay/channel/palm/model.go rename to relay/adaptor/palm/model.go diff --git a/relay/channel/palm/palm.go b/relay/adaptor/palm/palm.go similarity index 97% rename from relay/channel/palm/palm.go rename to relay/adaptor/palm/palm.go index 56738544..1e60e7cd 100644 --- a/relay/channel/palm/palm.go +++ b/relay/adaptor/palm/palm.go @@ -7,7 +7,8 @@ import ( "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/common/random" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/constant" "github.com/songquanpeng/one-api/relay/model" "io" @@ -74,7 +75,7 @@ func streamResponsePaLM2OpenAI(palmResponse *ChatResponse) *openai.ChatCompletio func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { responseText := "" - responseId := fmt.Sprintf("chatcmpl-%s", helper.GetUUID()) + responseId := fmt.Sprintf("chatcmpl-%s", random.GetUUID()) createdTime := helper.GetTimestamp() dataChan := make(chan string) stopChan := make(chan bool) diff --git a/relay/adaptor/stepfun/constants.go b/relay/adaptor/stepfun/constants.go new file mode 100644 index 00000000..a82e562b --- /dev/null +++ b/relay/adaptor/stepfun/constants.go @@ -0,0 +1,7 @@ +package stepfun + +var ModelList = []string{ + "step-1-32k", + "step-1v-32k", + "step-1-200k", +} diff --git a/relay/channel/tencent/adaptor.go b/relay/adaptor/tencent/adaptor.go similarity index 100% rename from relay/channel/tencent/adaptor.go rename to relay/adaptor/tencent/adaptor.go diff --git a/relay/channel/tencent/constants.go b/relay/adaptor/tencent/constants.go similarity index 100% rename from relay/channel/tencent/constants.go rename to relay/adaptor/tencent/constants.go diff --git a/relay/adaptor/tencent/main.go b/relay/adaptor/tencent/main.go new file mode 100644 index 00000000..aa87e9ce --- /dev/null +++ b/relay/adaptor/tencent/main.go @@ -0,0 +1,238 @@ +package tencent + +// import ( +// "bufio" +// "crypto/hmac" +// "crypto/sha1" +// "encoding/base64" +// "encoding/json" +// "github.com/Laisky/errors/v2" +// "fmt" +// "github.com/gin-gonic/gin" +// "github.com/songquanpeng/one-api/common" +// "github.com/songquanpeng/one-api/common/helper" +// "github.com/songquanpeng/one-api/common/logger" +// "github.com/songquanpeng/one-api/relay/channel/openai" +// "github.com/songquanpeng/one-api/relay/constant" +// "github.com/songquanpeng/one-api/relay/model" +// "io" +// "net/http" +// "sort" +// "strconv" +// "strings" +// ) + +// // https://cloud.tencent.com/document/product/1729/97732 + +// func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { +// messages := make([]Message, 0, len(request.Messages)) +// for i := 0; i < len(request.Messages); i++ { +// message := request.Messages[i] +// if message.Role == "system" { +// messages = append(messages, Message{ +// Role: "user", +// Content: message.StringContent(), +// }) +// messages = append(messages, Message{ +// Role: "assistant", +// Content: "Okay", +// }) +// continue +// } +// messages = append(messages, Message{ +// Content: message.StringContent(), +// Role: 
message.Role, +// }) +// } +// stream := 0 +// if request.Stream { +// stream = 1 +// } +// return &ChatRequest{ +// Timestamp: helper.GetTimestamp(), +// Expired: helper.GetTimestamp() + 24*60*60, +// QueryID: helper.GetUUID(), +// Temperature: request.Temperature, +// TopP: request.TopP, +// Stream: stream, +// Messages: messages, +// } +// } + +// func responseTencent2OpenAI(response *ChatResponse) *openai.TextResponse { +// fullTextResponse := openai.TextResponse{ +// Object: "chat.completion", +// Created: helper.GetTimestamp(), +// Usage: response.Usage, +// } +// if len(response.Choices) > 0 { +// choice := openai.TextResponseChoice{ +// Index: 0, +// Message: model.Message{ +// Role: "assistant", +// Content: response.Choices[0].Messages.Content, +// }, +// FinishReason: response.Choices[0].FinishReason, +// } +// fullTextResponse.Choices = append(fullTextResponse.Choices, choice) +// } +// return &fullTextResponse +// } + +// func streamResponseTencent2OpenAI(TencentResponse *ChatResponse) *openai.ChatCompletionsStreamResponse { +// response := openai.ChatCompletionsStreamResponse{ +// Object: "chat.completion.chunk", +// Created: helper.GetTimestamp(), +// Model: "tencent-hunyuan", +// } +// if len(TencentResponse.Choices) > 0 { +// var choice openai.ChatCompletionsStreamResponseChoice +// choice.Delta.Content = TencentResponse.Choices[0].Delta.Content +// if TencentResponse.Choices[0].FinishReason == "stop" { +// choice.FinishReason = &constant.StopFinishReason +// } +// response.Choices = append(response.Choices, choice) +// } +// return &response +// } + +// func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, string) { +// var responseText string +// scanner := bufio.NewScanner(resp.Body) +// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { +// if atEOF && len(data) == 0 { +// return 0, nil, nil +// } +// if i := strings.Index(string(data), "\n"); i >= 0 { +// return i + 1, data[0:i], nil +// } +// if atEOF { +// return len(data), data, nil +// } +// return 0, nil, nil +// }) +// dataChan := make(chan string) +// stopChan := make(chan bool) +// go func() { +// for scanner.Scan() { +// data := scanner.Text() +// if len(data) < 5 { // ignore blank line or wrong format +// continue +// } +// if data[:5] != "data:" { +// continue +// } +// data = data[5:] +// dataChan <- data +// } +// stopChan <- true +// }() +// common.SetEventStreamHeaders(c) +// c.Stream(func(w io.Writer) bool { +// select { +// case data := <-dataChan: +// var TencentResponse ChatResponse +// err := json.Unmarshal([]byte(data), &TencentResponse) +// if err != nil { +// logger.SysError("error unmarshalling stream response: " + err.Error()) +// return true +// } +// response := streamResponseTencent2OpenAI(&TencentResponse) +// if len(response.Choices) != 0 { +// responseText += response.Choices[0].Delta.Content +// } +// jsonResponse, err := json.Marshal(response) +// if err != nil { +// logger.SysError("error marshalling stream response: " + err.Error()) +// return true +// } +// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) +// return true +// case <-stopChan: +// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) +// return false +// } +// }) +// err := resp.Body.Close() +// if err != nil { +// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" +// } +// return nil, responseText +// } + +// func Handler(c *gin.Context, resp *http.Response) 
(*model.ErrorWithStatusCode, *model.Usage) { +// var TencentResponse ChatResponse +// responseBody, err := io.ReadAll(resp.Body) +// if err != nil { +// return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil +// } +// err = resp.Body.Close() +// if err != nil { +// return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil +// } +// err = json.Unmarshal(responseBody, &TencentResponse) +// if err != nil { +// return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil +// } +// if TencentResponse.Error.Code != 0 { +// return &model.ErrorWithStatusCode{ +// Error: model.Error{ +// Message: TencentResponse.Error.Message, +// Code: TencentResponse.Error.Code, +// }, +// StatusCode: resp.StatusCode, +// }, nil +// } +// fullTextResponse := responseTencent2OpenAI(&TencentResponse) +// fullTextResponse.Model = "hunyuan" +// jsonResponse, err := json.Marshal(fullTextResponse) +// if err != nil { +// return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil +// } +// c.Writer.Header().Set("Content-Type", "application/json") +// c.Writer.WriteHeader(resp.StatusCode) +// _, err = c.Writer.Write(jsonResponse) +// if err != nil { +// return openai.ErrorWrapper(err, "write_response_body_failed", http.StatusInternalServerError), nil +// } +// return nil, &fullTextResponse.Usage +// } + +// func ParseConfig(config string) (appId int64, secretId string, secretKey string, err error) { +// parts := strings.Split(config, "|") +// if len(parts) != 3 { +// err = errors.New("invalid tencent config") +// return +// } +// appId, err = strconv.ParseInt(parts[0], 10, 64) +// secretId = parts[1] +// secretKey = parts[2] +// return +// } + +// func GetSign(req ChatRequest, secretKey string) string { +// params := make([]string, 0) +// params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) +// params = append(params, "secret_id="+req.SecretId) +// params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) +// params = append(params, "query_id="+req.QueryID) +// params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) +// params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) +// params = append(params, "stream="+strconv.Itoa(req.Stream)) +// params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) + +// var messageStr string +// for _, msg := range req.Messages { +// messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) +// } +// messageStr = strings.TrimSuffix(messageStr, ",") +// params = append(params, "messages=["+messageStr+"]") + +// sort.Strings(params) +// url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" 
+ strings.Join(params, "&") +// mac := hmac.New(sha1.New, []byte(secretKey)) +// signURL := url +// mac.Write([]byte(signURL)) +// sign := mac.Sum([]byte(nil)) +// return base64.StdEncoding.EncodeToString(sign) +// } diff --git a/relay/channel/tencent/model.go b/relay/adaptor/tencent/model.go similarity index 100% rename from relay/channel/tencent/model.go rename to relay/adaptor/tencent/model.go diff --git a/relay/channel/xunfei/adaptor.go b/relay/adaptor/xunfei/adaptor.go similarity index 100% rename from relay/channel/xunfei/adaptor.go rename to relay/adaptor/xunfei/adaptor.go diff --git a/relay/channel/xunfei/constants.go b/relay/adaptor/xunfei/constants.go similarity index 100% rename from relay/channel/xunfei/constants.go rename to relay/adaptor/xunfei/constants.go diff --git a/relay/channel/xunfei/main.go b/relay/adaptor/xunfei/main.go similarity index 100% rename from relay/channel/xunfei/main.go rename to relay/adaptor/xunfei/main.go diff --git a/relay/channel/xunfei/model.go b/relay/adaptor/xunfei/model.go similarity index 100% rename from relay/channel/xunfei/model.go rename to relay/adaptor/xunfei/model.go diff --git a/relay/adaptor/zhipu/adaptor.go b/relay/adaptor/zhipu/adaptor.go new file mode 100644 index 00000000..424fabd6 --- /dev/null +++ b/relay/adaptor/zhipu/adaptor.go @@ -0,0 +1,145 @@ +package zhipu + +// import ( +// "github.com/Laisky/errors/v2" +// "fmt" +// "github.com/gin-gonic/gin" +// "github.com/songquanpeng/one-api/relay/adaptor" +// "github.com/songquanpeng/one-api/relay/adaptor/openai" +// "github.com/songquanpeng/one-api/relay/meta" +// "github.com/songquanpeng/one-api/relay/model" +// "github.com/songquanpeng/one-api/relay/relaymode" +// "io" +// "math" +// "net/http" +// "strings" +// ) + +// type Adaptor struct { +// APIVersion string +// } + +// func (a *Adaptor) Init(meta *meta.Meta) { + +// } + +// func (a *Adaptor) SetVersionByModeName(modelName string) { +// if strings.HasPrefix(modelName, "glm-") { +// a.APIVersion = "v4" +// } else { +// a.APIVersion = "v3" +// } +// } + +// func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { +// switch meta.Mode { +// case relaymode.ImagesGenerations: +// return fmt.Sprintf("%s/api/paas/v4/images/generations", meta.BaseURL), nil +// case relaymode.Embeddings: +// return fmt.Sprintf("%s/api/paas/v4/embeddings", meta.BaseURL), nil +// } +// a.SetVersionByModeName(meta.ActualModelName) +// if a.APIVersion == "v4" { +// return fmt.Sprintf("%s/api/paas/v4/chat/completions", meta.BaseURL), nil +// } +// method := "invoke" +// if meta.IsStream { +// method = "sse-invoke" +// } +// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil +// } + +// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { +// adaptor.SetupCommonRequestHeader(c, req, meta) +// token := GetToken(meta.APIKey) +// req.Header.Set("Authorization", token) +// return nil +// } + +// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { +// if request == nil { +// return nil, errors.New("request is nil") +// } +// switch relayMode { +// case relaymode.Embeddings: +// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) +// return baiduEmbeddingRequest, nil +// default: +// // TopP (0.0, 1.0) +// request.TopP = math.Min(0.99, request.TopP) +// request.TopP = math.Max(0.01, request.TopP) + +// // Temperature (0.0, 1.0) +// request.Temperature = math.Min(0.99, request.Temperature) 
+// request.Temperature = math.Max(0.01, request.Temperature) +// a.SetVersionByModeName(request.Model) +// if a.APIVersion == "v4" { +// return request, nil +// } +// return ConvertRequest(*request), nil +// } +// } + +// func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { +// if request == nil { +// return nil, errors.New("request is nil") +// } +// newRequest := ImageRequest{ +// Model: request.Model, +// Prompt: request.Prompt, +// UserId: request.User, +// } +// return newRequest, nil +// } + +// func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { +// return adaptor.DoRequestHelper(a, c, meta, requestBody) +// } + +// func (a *Adaptor) DoResponseV4(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +// if meta.IsStream { +// err, _, usage = openai.StreamHandler(c, resp, meta.Mode) +// } else { +// err, usage = openai.Handler(c, resp, meta.PromptTokens, meta.ActualModelName) +// } +// return +// } + +// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { +// switch meta.Mode { +// case relaymode.Embeddings: +// err, usage = EmbeddingsHandler(c, resp) +// return +// case relaymode.ImagesGenerations: +// err, usage = openai.ImageHandler(c, resp) +// return +// } +// if a.APIVersion == "v4" { +// return a.DoResponseV4(c, resp, meta) +// } +// if meta.IsStream { +// err, usage = StreamHandler(c, resp) +// } else { +// if meta.Mode == relaymode.Embeddings { +// err, usage = EmbeddingsHandler(c, resp) +// } else { +// err, usage = Handler(c, resp) +// } +// } +// return +// } + +// func ConvertEmbeddingRequest(request model.GeneralOpenAIRequest) *EmbeddingRequest { +// return &EmbeddingRequest{ +// Model: "embedding-2", +// Input: request.Input.(string), +// } +// } + +// func (a *Adaptor) GetModelList() []string { +// return ModelList +// } + +// func (a *Adaptor) GetChannelName() string { +// return "zhipu" +// } diff --git a/relay/channel/zhipu/constants.go b/relay/adaptor/zhipu/constants.go similarity index 100% rename from relay/channel/zhipu/constants.go rename to relay/adaptor/zhipu/constants.go diff --git a/relay/channel/zhipu/main.go b/relay/adaptor/zhipu/main.go similarity index 100% rename from relay/channel/zhipu/main.go rename to relay/adaptor/zhipu/main.go diff --git a/relay/channel/zhipu/model.go b/relay/adaptor/zhipu/model.go similarity index 100% rename from relay/channel/zhipu/model.go rename to relay/adaptor/zhipu/model.go diff --git a/relay/apitype/define.go b/relay/apitype/define.go new file mode 100644 index 00000000..82d32a50 --- /dev/null +++ b/relay/apitype/define.go @@ -0,0 +1,17 @@ +package apitype + +const ( + OpenAI = iota + Anthropic + PaLM + Baidu + Zhipu + Ali + Xunfei + AIProxyLibrary + Tencent + Gemini + Ollama + + Dummy // this one is only for count, do not add any channel after this +) diff --git a/relay/billing/billing.go b/relay/billing/billing.go new file mode 100644 index 00000000..a99d37ee --- /dev/null +++ b/relay/billing/billing.go @@ -0,0 +1,42 @@ +package billing + +import ( + "context" + "fmt" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" +) + +func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { + if preConsumedQuota != 0 { + go func(ctx context.Context) { + // return pre-consumed quota + err := model.PostConsumeTokenQuota(tokenId, 
-preConsumedQuota) + if err != nil { + logger.Error(ctx, "error returning pre-consumed quota: "+err.Error()) + } + }(ctx) + } +} + +func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { + // quotaDelta is remaining quota to be consumed + err := model.PostConsumeTokenQuota(tokenId, quotaDelta) + if err != nil { + logger.SysError("error consuming token remaining quota: " + err.Error()) + } + err = model.CacheUpdateUserQuota(ctx, userId) + if err != nil { + logger.SysError("error updating user quota cache: " + err.Error()) + } + // totalQuota is total quota consumed + if totalQuota != 0 { + logContent := fmt.Sprintf("model ratio %.2f, group ratio %.2f", modelRatio, groupRatio) + model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) + model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) + model.UpdateChannelUsedQuota(channelId, totalQuota) + } + if totalQuota <= 0 { + logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) + } +} diff --git a/common/group-ratio.go b/relay/billing/ratio/group.go similarity index 97% rename from common/group-ratio.go rename to relay/billing/ratio/group.go index 2de6e810..8e9c5b73 100644 --- a/common/group-ratio.go +++ b/relay/billing/ratio/group.go @@ -1,4 +1,4 @@ -package common +package ratio import ( "encoding/json" diff --git a/relay/billing/ratio/image.go b/relay/billing/ratio/image.go new file mode 100644 index 00000000..5a29cddc --- /dev/null +++ b/relay/billing/ratio/image.go @@ -0,0 +1,51 @@ +package ratio + +var ImageSizeRatios = map[string]map[string]float64{ + "dall-e-2": { + "256x256": 1, + "512x512": 1.125, + "1024x1024": 1.25, + }, + "dall-e-3": { + "1024x1024": 1, + "1024x1792": 2, + "1792x1024": 2, + }, + "ali-stable-diffusion-xl": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "ali-stable-diffusion-v1.5": { + "512x1024": 1, + "1024x768": 1, + "1024x1024": 1, + "576x1024": 1, + "1024x576": 1, + }, + "wanx-v1": { + "1024x1024": 1, + "720x1280": 1, + "1280x720": 1, + }, +} + +var ImageGenerationAmounts = map[string][2]int{ + "dall-e-2": {1, 10}, + "dall-e-3": {1, 1}, // OpenAI allows n=1 currently.
+ "ali-stable-diffusion-xl": {1, 4}, // Ali + "ali-stable-diffusion-v1.5": {1, 4}, // Ali + "wanx-v1": {1, 4}, // Ali + "cogview-3": {1, 1}, +} + +var ImagePromptLengthLimitations = map[string]int{ + "dall-e-2": 1000, + "dall-e-3": 4000, + "ali-stable-diffusion-xl": 4000, + "ali-stable-diffusion-v1.5": 4000, + "wanx-v1": 4000, + "cogview-3": 833, +} diff --git a/common/model-ratio.go b/relay/billing/ratio/model.go similarity index 85% rename from common/model-ratio.go rename to relay/billing/ratio/model.go index 4f6acb7b..1f7daef6 100644 --- a/common/model-ratio.go +++ b/relay/billing/ratio/model.go @@ -1,4 +1,4 @@ -package common +package ratio import ( "encoding/json" @@ -63,8 +63,8 @@ var ModelRatio = map[string]float64{ "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, - "dall-e-2": 8, // $0.016 - $0.020 / image - "dall-e-3": 20, // $0.040 - $0.120 / image + "dall-e-2": 0.02 * USD, // $0.016 - $0.020 / image + "dall-e-3": 0.04 * USD, // $0.040 - $0.120 / image // https://www.anthropic.com/api#pricing "claude-instant-1.2": 0.8 / 1000 * USD, "claude-2.0": 8.0 / 1000 * USD, @@ -73,14 +73,22 @@ var ModelRatio = map[string]float64{ "claude-3-sonnet-20240229": 3.0 / 1000 * USD, "claude-3-opus-20240229": 15.0 / 1000 * USD, // https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7 - "ERNIE-Bot": 0.8572, // ¥0.012 / 1k tokens - "ERNIE-Bot-turbo": 0.5715, // ¥0.008 / 1k tokens - "ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens - "ERNIE-Bot-8K": 0.024 * RMB, - "Embedding-V1": 0.1429, // ¥0.002 / 1k tokens - "bge-large-zh": 0.002 * RMB, - "bge-large-en": 0.002 * RMB, - "bge-large-8k": 0.002 * RMB, + "ERNIE-4.0-8K": 0.120 * RMB, + "ERNIE-3.5-8K": 0.012 * RMB, + "ERNIE-3.5-8K-0205": 0.024 * RMB, + "ERNIE-3.5-8K-1222": 0.012 * RMB, + "ERNIE-Bot-8K": 0.024 * RMB, + "ERNIE-3.5-4K-0205": 0.012 * RMB, + "ERNIE-Speed-8K": 0.004 * RMB, + "ERNIE-Speed-128K": 0.004 * RMB, + "ERNIE-Lite-8K-0922": 0.008 * RMB, + "ERNIE-Lite-8K-0308": 0.003 * RMB, + "ERNIE-Tiny-8K": 0.001 * RMB, + "BLOOMZ-7B": 0.004 * RMB, + "Embedding-V1": 0.002 * RMB, + "bge-large-zh": 0.002 * RMB, + "bge-large-en": 0.002 * RMB, + "tao-8k": 0.002 * RMB, // https://ai.google.dev/pricing "PaLM-2": 1, "gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens @@ -89,18 +97,24 @@ var ModelRatio = map[string]float64{ "gemini-1.0-pro-001": 1, "gemini-1.5-pro": 1, // https://open.bigmodel.cn/pricing - "glm-4": 0.1 * RMB, - "glm-4v": 0.1 * RMB, - "glm-3-turbo": 0.005 * RMB, - "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens - "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens - "chatglm_std": 0.3572, // ¥0.005 / 1k tokens - "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens - "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "glm-4": 0.1 * RMB, + "glm-4v": 0.1 * RMB, + "glm-3-turbo": 0.005 * RMB, + "embedding-2": 0.0005 * RMB, + "chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens + "chatglm_pro": 0.7143, // ¥0.01 / 1k tokens + "chatglm_std": 0.3572, // ¥0.005 / 1k tokens + "chatglm_lite": 0.1429, // ¥0.002 / 1k tokens + "cogview-3": 0.25 * RMB, + // https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-thousand-questions-metering-and-billing + "qwen-turbo": 0.5715, // ¥0.008 / 1k tokens "qwen-plus": 1.4286, // ¥0.02 / 1k tokens "qwen-max": 1.4286, // ¥0.02 / 1k tokens "qwen-max-longcontext": 1.4286, // ¥0.02 / 1k tokens "text-embedding-v1": 0.05, // ¥0.0007 / 1k tokens + "ali-stable-diffusion-xl": 8, + 
"ali-stable-diffusion-v1.5": 8, + "wanx-v1": 8, "SparkDesk": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v1.1": 1.2858, // ¥0.018 / 1k tokens "SparkDesk-v2.1": 1.2858, // ¥0.018 / 1k tokens diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go deleted file mode 100644 index b312934b..00000000 --- a/relay/channel/ali/adaptor.go +++ /dev/null @@ -1,83 +0,0 @@ -package ali - -// import ( -// "github.com/Laisky/errors/v2" -// "fmt" -// "github.com/gin-gonic/gin" -// "github.com/songquanpeng/one-api/common" -// "github.com/songquanpeng/one-api/relay/channel" -// "github.com/songquanpeng/one-api/relay/constant" -// "github.com/songquanpeng/one-api/relay/model" -// "github.com/songquanpeng/one-api/relay/util" -// "io" -// "net/http" -// ) - -// // https://help.aliyun.com/zh/dashscope/developer-reference/api-details - -// type Adaptor struct { -// } - -// func (a *Adaptor) Init(meta *util.RelayMeta) { - -// } - -// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { -// fullRequestURL := fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", meta.BaseURL) -// if meta.Mode == constant.RelayModeEmbeddings { -// fullRequestURL = fmt.Sprintf("%s/api/v1/services/embeddings/text-embedding/text-embedding", meta.BaseURL) -// } -// return fullRequestURL, nil -// } - -// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { -// channel.SetupCommonRequestHeader(c, req, meta) -// req.Header.Set("Authorization", "Bearer "+meta.APIKey) -// if meta.IsStream { -// req.Header.Set("X-DashScope-SSE", "enable") -// } -// if c.GetString(common.ConfigKeyPlugin) != "" { -// req.Header.Set("X-DashScope-Plugin", c.GetString(common.ConfigKeyPlugin)) -// } -// return nil -// } - -// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { -// if request == nil { -// return nil, errors.New("request is nil") -// } -// switch relayMode { -// case constant.RelayModeEmbeddings: -// baiduEmbeddingRequest := ConvertEmbeddingRequest(*request) -// return baiduEmbeddingRequest, nil -// default: -// baiduRequest := ConvertRequest(*request) -// return baiduRequest, nil -// } -// } - -// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { -// return channel.DoRequestHelper(a, c, meta, requestBody) -// } - -// func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { -// if meta.IsStream { -// err, usage = StreamHandler(c, resp) -// } else { -// switch meta.Mode { -// case constant.RelayModeEmbeddings: -// err, usage = EmbeddingHandler(c, resp) -// default: -// err, usage = Handler(c, resp) -// } -// } -// return -// } - -// func (a *Adaptor) GetModelList() []string { -// return ModelList -// } - -// func (a *Adaptor) GetChannelName() string { -// return "ali" -// } diff --git a/relay/channel/ali/model.go b/relay/channel/ali/model.go deleted file mode 100644 index d0cd341c..00000000 --- a/relay/channel/ali/model.go +++ /dev/null @@ -1,71 +0,0 @@ -package ali - -// type Message struct { -// Content string `json:"content"` -// Role string `json:"role"` -// } - -// type Input struct { -// //Prompt string `json:"prompt"` -// Messages []Message `json:"messages"` -// } - -// type Parameters struct { -// TopP float64 `json:"top_p,omitempty"` -// TopK int `json:"top_k,omitempty"` -// Seed uint64 `json:"seed,omitempty"` -// EnableSearch bool 
`json:"enable_search,omitempty"` -// IncrementalOutput bool `json:"incremental_output,omitempty"` -// } - -// type ChatRequest struct { -// Model string `json:"model"` -// Input Input `json:"input"` -// Parameters Parameters `json:"parameters,omitempty"` -// } - -// type EmbeddingRequest struct { -// Model string `json:"model"` -// Input struct { -// Texts []string `json:"texts"` -// } `json:"input"` -// Parameters *struct { -// TextType string `json:"text_type,omitempty"` -// } `json:"parameters,omitempty"` -// } - -// type Embedding struct { -// Embedding []float64 `json:"embedding"` -// TextIndex int `json:"text_index"` -// } - -// type EmbeddingResponse struct { -// Output struct { -// Embeddings []Embedding `json:"embeddings"` -// } `json:"output"` -// Usage Usage `json:"usage"` -// Error -// } - -// type Error struct { -// Code string `json:"code"` -// Message string `json:"message"` -// RequestId string `json:"request_id"` -// } - -// type Usage struct { -// InputTokens int `json:"input_tokens"` -// OutputTokens int `json:"output_tokens"` -// TotalTokens int `json:"total_tokens"` -// } - -// type Output struct { -// Text string `json:"text"` -// FinishReason string `json:"finish_reason"` -// } - -// type ChatResponse struct { -// Output Output `json:"output"` -// Usage Usage `json:"usage"` -// Error -// } diff --git a/relay/channel/baidu/constants.go b/relay/channel/baidu/constants.go deleted file mode 100644 index 45a4e901..00000000 --- a/relay/channel/baidu/constants.go +++ /dev/null @@ -1,13 +0,0 @@ -package baidu - -var ModelList = []string{ - "ERNIE-Bot-4", - "ERNIE-Bot-8K", - "ERNIE-Bot", - "ERNIE-Speed", - "ERNIE-Bot-turbo", - "Embedding-V1", - "bge-large-zh", - "bge-large-en", - "tao-8k", -} diff --git a/relay/channel/interface.go b/relay/channel/interface.go deleted file mode 100644 index e25db83f..00000000 --- a/relay/channel/interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package channel - -import ( - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - "io" - "net/http" -) - -type Adaptor interface { - Init(meta *util.RelayMeta) - GetRequestURL(meta *util.RelayMeta) (string, error) - SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error - ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) - DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) - DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) - GetModelList() []string - GetChannelName() string -} diff --git a/relay/channel/minimax/main.go b/relay/channel/minimax/main.go deleted file mode 100644 index a3cd0f14..00000000 --- a/relay/channel/minimax/main.go +++ /dev/null @@ -1,16 +0,0 @@ -package minimax - -import ( - "fmt" - - "github.com/Laisky/errors/v2" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/util" -) - -func GetRequestURL(meta *util.RelayMeta) (string, error) { - if meta.Mode == constant.RelayModeChatCompletions { - return fmt.Sprintf("%s/v1/text/chatcompletion_v2", meta.BaseURL), nil - } - return "", errors.Errorf("unsupported relay mode %d for minimax", meta.Mode) -} diff --git a/relay/channel/openai/compatible.go b/relay/channel/openai/compatible.go deleted file mode 100644 index e4951a34..00000000 --- a/relay/channel/openai/compatible.go +++ /dev/null @@ -1,46 +0,0 @@ -package openai - -import ( - 
"github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/channel/ai360" - "github.com/songquanpeng/one-api/relay/channel/baichuan" - "github.com/songquanpeng/one-api/relay/channel/groq" - "github.com/songquanpeng/one-api/relay/channel/lingyiwanwu" - "github.com/songquanpeng/one-api/relay/channel/minimax" - "github.com/songquanpeng/one-api/relay/channel/mistral" - "github.com/songquanpeng/one-api/relay/channel/moonshot" -) - -var CompatibleChannels = []int{ - common.ChannelTypeAzure, - common.ChannelType360, - common.ChannelTypeMoonshot, - common.ChannelTypeBaichuan, - common.ChannelTypeMinimax, - common.ChannelTypeMistral, - common.ChannelTypeGroq, - common.ChannelTypeLingYiWanWu, -} - -func GetCompatibleChannelMeta(channelType int) (string, []string) { - switch channelType { - case common.ChannelTypeAzure: - return "azure", ModelList - case common.ChannelType360: - return "360", ai360.ModelList - case common.ChannelTypeMoonshot: - return "moonshot", moonshot.ModelList - case common.ChannelTypeBaichuan: - return "baichuan", baichuan.ModelList - case common.ChannelTypeMinimax: - return "minimax", minimax.ModelList - case common.ChannelTypeMistral: - return "mistralai", mistral.ModelList - case common.ChannelTypeGroq: - return "groq", groq.ModelList - case common.ChannelTypeLingYiWanWu: - return "lingyiwanwu", lingyiwanwu.ModelList - default: - return "openai", ModelList - } -} diff --git a/relay/channel/openai/helper.go b/relay/channel/openai/helper.go deleted file mode 100644 index 9bca8cab..00000000 --- a/relay/channel/openai/helper.go +++ /dev/null @@ -1,11 +0,0 @@ -package openai - -import "github.com/songquanpeng/one-api/relay/model" - -func ResponseText2Usage(responseText string, modeName string, promptTokens int) *model.Usage { - usage := &model.Usage{} - usage.PromptTokens = promptTokens - usage.CompletionTokens = CountTokenText(responseText, modeName) - usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens - return usage -} diff --git a/relay/channel/zhipu/adaptor.go b/relay/channel/zhipu/adaptor.go deleted file mode 100644 index 9c7bdd36..00000000 --- a/relay/channel/zhipu/adaptor.go +++ /dev/null @@ -1,62 +0,0 @@ -package zhipu - -// import ( -// "github.com/Laisky/errors/v2" -// "fmt" -// "github.com/gin-gonic/gin" -// "github.com/songquanpeng/one-api/relay/channel" -// "github.com/songquanpeng/one-api/relay/model" -// "github.com/songquanpeng/one-api/relay/util" -// "io" -// "net/http" -// ) - -// type Adaptor struct { -// } - -// func (a *Adaptor) Init(meta *util.RelayMeta) { - -// } - -// func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) { -// method := "invoke" -// if meta.IsStream { -// method = "sse-invoke" -// } -// return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", meta.BaseURL, meta.ActualModelName, method), nil -// } - -// func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error { -// channel.SetupCommonRequestHeader(c, req, meta) -// token := GetToken(meta.APIKey) -// req.Header.Set("Authorization", token) -// return nil -// } - -// func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { -// if request == nil { -// return nil, errors.New("request is nil") -// } -// return ConvertRequest(*request), nil -// } - -// func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) { -// return channel.DoRequestHelper(a, c, meta, requestBody) -// } - -// func (a 
*Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) { -// if meta.IsStream { -// err, usage = StreamHandler(c, resp) -// } else { -// err, usage = Handler(c, resp) -// } -// return -// } - -// func (a *Adaptor) GetModelList() []string { -// return ModelList -// } - -// func (a *Adaptor) GetChannelName() string { -// return "zhipu" -// } diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go new file mode 100644 index 00000000..80027a80 --- /dev/null +++ b/relay/channeltype/define.go @@ -0,0 +1,39 @@ +package channeltype + +const ( + Unknown = iota + OpenAI + API2D + Azure + CloseAI + OpenAISB + OpenAIMax + OhMyGPT + Custom + Ails + AIProxy + PaLM + API2GPT + AIGC2D + Anthropic + Baidu + Zhipu + Ali + Xunfei + AI360 + OpenRouter + AIProxyLibrary + FastGPT + Tencent + Gemini + Moonshot + Baichuan + Minimax + Mistral + Groq + Ollama + LingYiWanWu + StepFun + + Dummy +) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go new file mode 100644 index 00000000..01c2918c --- /dev/null +++ b/relay/channeltype/helper.go @@ -0,0 +1,30 @@ +package channeltype + +import "github.com/songquanpeng/one-api/relay/apitype" + +func ToAPIType(channelType int) int { + apiType := apitype.OpenAI + switch channelType { + case Anthropic: + apiType = apitype.Anthropic + case Baidu: + apiType = apitype.Baidu + case PaLM: + apiType = apitype.PaLM + case Zhipu: + apiType = apitype.Zhipu + case Ali: + apiType = apitype.Ali + case Xunfei: + apiType = apitype.Xunfei + case AIProxyLibrary: + apiType = apitype.AIProxyLibrary + case Tencent: + apiType = apitype.Tencent + case Gemini: + apiType = apitype.Gemini + case Ollama: + apiType = apitype.Ollama + } + return apiType +} diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go new file mode 100644 index 00000000..eec59116 --- /dev/null +++ b/relay/channeltype/url.go @@ -0,0 +1,43 @@ +package channeltype + +var ChannelBaseURLs = []string{ + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "https://generativelanguage.googleapis.com", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", // 23 + "https://generativelanguage.googleapis.com", // 24 + "https://api.moonshot.cn", // 25 + "https://api.baichuan-ai.com", // 26 + "https://api.minimax.chat", // 27 + "https://api.mistral.ai", // 28 + "https://api.groq.com/openai", // 29 + "http://localhost:11434", // 30 + "https://api.lingyiwanwu.com", // 31 + "https://api.stepfun.com", // 32 +} + +func init() { + if len(ChannelBaseURLs) != Dummy { + panic("channel base urls length not match") + } +} diff --git a/relay/util/init.go b/relay/client/init.go similarity index 97% rename from relay/util/init.go rename to relay/client/init.go index b303de0c..73108700 100644 --- a/relay/util/init.go +++ b/relay/client/init.go @@ -1,4 +1,4 @@ -package util +package client import ( "net/http" diff --git 
a/relay/constant/api_type.go b/relay/constant/api_type.go deleted file mode 100644 index b249f6a2..00000000 --- a/relay/constant/api_type.go +++ /dev/null @@ -1,48 +0,0 @@ -package constant - -import ( - "github.com/songquanpeng/one-api/common" -) - -const ( - APITypeOpenAI = iota - APITypeAnthropic - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini - APITypeOllama - - APITypeDummy // this one is only for count, do not add any channel after this -) - -func ChannelType2APIType(channelType int) int { - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeAnthropic - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - case common.ChannelTypeOllama: - apiType = APITypeOllama - } - return apiType -} diff --git a/relay/constant/image.go b/relay/constant/image.go deleted file mode 100644 index 5e04895f..00000000 --- a/relay/constant/image.go +++ /dev/null @@ -1,24 +0,0 @@ -package constant - -var DalleSizeRatios = map[string]map[string]float64{ - "dall-e-2": { - "256x256": 1, - "512x512": 1.125, - "1024x1024": 1.25, - }, - "dall-e-3": { - "1024x1024": 1, - "1024x1792": 2, - "1792x1024": 2, - }, -} - -var DalleGenerationImageAmounts = map[string][2]int{ - "dall-e-2": {1, 10}, - "dall-e-3": {1, 1}, // OpenAI allows n=1 currently. 
-} - -var DalleImagePromptLengthLimitations = map[string]int{ - "dall-e-2": 1000, - "dall-e-3": 4000, -} diff --git a/relay/constant/relay_mode.go b/relay/constant/relay_mode.go deleted file mode 100644 index 5e2fe574..00000000 --- a/relay/constant/relay_mode.go +++ /dev/null @@ -1,42 +0,0 @@ -package constant - -import "strings" - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation -) - -func Path2RelayMode(path string) int { - relayMode := RelayModeUnknown - if strings.HasPrefix(path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions - } else if strings.HasPrefix(path, "/v1/completions") { - relayMode = RelayModeCompletions - } else if strings.HasPrefix(path, "/v1/embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasSuffix(path, "embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasPrefix(path, "/v1/moderations") { - relayMode = RelayModeModerations - } else if strings.HasPrefix(path, "/v1/images/generations") { - relayMode = RelayModeImagesGenerations - } else if strings.HasPrefix(path, "/v1/edits") { - relayMode = RelayModeEdits - } else if strings.HasPrefix(path, "/v1/audio/speech") { - relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { - relayMode = RelayModeAudioTranscription - } else if strings.HasPrefix(path, "/v1/audio/translations") { - relayMode = RelayModeAudioTranslation - } - return relayMode -} diff --git a/relay/controller/audio.go b/relay/controller/audio.go index 85599b1f..58094c22 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -6,20 +6,23 @@ import ( "context" "encoding/json" "fmt" - "io" - "net/http" - "strings" - "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/adaptor/azure" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/client" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + "github.com/songquanpeng/one-api/relay/relaymode" + "io" + "net/http" + "strings" ) func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { @@ -34,7 +37,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus tokenName := c.GetString("token_name") var ttsRequest openai.TextToSpeechRequest - if relayMode == constant.RelayModeAudioSpeech { + if relayMode == relaymode.AudioSpeech { // Read JSON err := common.UnmarshalBodyReusable(c, &ttsRequest) // Check if JSON is valid @@ -48,14 +51,15 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - modelRatio := common.GetModelRatio(audioModel) - // groupRatio := common.GetGroupRatio(group) - groupRatio := c.GetFloat64("channel_ratio") + modelRatio := billingratio.GetModelRatio(audioModel) + // groupRatio := 
billingratio.GetGroupRatio(group) + groupRatio := c.GetFloat64("channel_ratio") // get minimal ratio from multiple groups + ratio := modelRatio * groupRatio var quota int64 var preConsumedQuota int64 switch relayMode { - case constant.RelayModeAudioSpeech: + case relaymode.AudioSpeech: preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota default: @@ -117,19 +121,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } } - baseURL := common.ChannelBaseURLs[channelType] + baseURL := channeltype.ChannelBaseURLs[channelType] requestURL := c.Request.URL.String() if c.GetString("base_url") != "" { baseURL = c.GetString("base_url") } - fullRequestURL := util.GetFullRequestURL(baseURL, requestURL, channelType) - if channelType == common.ChannelTypeAzure { - apiVersion := util.GetAzureAPIVersion(c) - if relayMode == constant.RelayModeAudioTranscription { + fullRequestURL := openai.GetFullRequestURL(baseURL, requestURL, channelType) + if channelType == channeltype.Azure { + apiVersion := azure.GetAPIVersion(c) + if relayMode == relaymode.AudioTranscription { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/transcriptions?api-version=%s", baseURL, audioModel, apiVersion) - } else if relayMode == constant.RelayModeAudioSpeech { + } else if relayMode == relaymode.AudioSpeech { // https://learn.microsoft.com/en-us/azure/ai-services/openai/text-to-speech-quickstart?tabs=command-line#rest-api fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/audio/speech?api-version=%s", baseURL, audioModel, apiVersion) } @@ -148,7 +152,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) } - if (relayMode == constant.RelayModeAudioTranscription || relayMode == constant.RelayModeAudioSpeech) && channelType == common.ChannelTypeAzure { + if (relayMode == relaymode.AudioTranscription || relayMode == relaymode.AudioSpeech) && channelType == channeltype.Azure { // https://learn.microsoft.com/en-us/azure/ai-services/openai/whisper-quickstart?tabs=command-line#rest-api apiKey := c.Request.Header.Get("Authorization") apiKey = strings.TrimPrefix(apiKey, "Bearer ") @@ -160,7 +164,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) req.Header.Set("Accept", c.Request.Header.Get("Accept")) - resp, err := util.HTTPClient.Do(req) + resp, err := client.HTTPClient.Do(req) if err != nil { return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } @@ -174,7 +178,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != constant.RelayModeAudioSpeech { + if relayMode != relaymode.AudioSpeech { responseBody, err := io.ReadAll(resp.Body) if err != nil { return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -213,12 +217,12 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } if resp.StatusCode != http.StatusOK { - return util.RelayErrorHandler(resp) + return RelayErrorHandler(resp) } succeed = true quotaDelta := quota - preConsumedQuota defer 
func(ctx context.Context) { - go util.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) + go billing.PostConsumeQuota(ctx, tokenId, quotaDelta, quota, userId, channelId, modelRatio, groupRatio, audioModel, tokenName) }(c.Request.Context()) for k, v := range resp.Header { diff --git a/relay/controller/error.go b/relay/controller/error.go new file mode 100644 index 00000000..69ece3ec --- /dev/null +++ b/relay/controller/error.go @@ -0,0 +1,91 @@ +package controller + +import ( + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/model" + "io" + "net/http" + "strconv" +) + +type GeneralErrorResponse struct { + Error model.Error `json:"error"` + Message string `json:"message"` + Msg string `json:"msg"` + Err string `json:"err"` + ErrorMsg string `json:"error_msg"` + Header struct { + Message string `json:"message"` + } `json:"header"` + Response struct { + Error struct { + Message string `json:"message"` + } `json:"error"` + } `json:"response"` +} + +func (e GeneralErrorResponse) ToMessage() string { + if e.Error.Message != "" { + return e.Error.Message + } + if e.Message != "" { + return e.Message + } + if e.Msg != "" { + return e.Msg + } + if e.Err != "" { + return e.Err + } + if e.ErrorMsg != "" { + return e.ErrorMsg + } + if e.Header.Message != "" { + return e.Header.Message + } + if e.Response.Error.Message != "" { + return e.Response.Error.Message + } + return "" +} + +func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *model.ErrorWithStatusCode) { + ErrorWithStatusCode = &model.ErrorWithStatusCode{ + StatusCode: resp.StatusCode, + Error: model.Error{ + Message: "", + Type: "upstream_error", + Code: "bad_response_status_code", + Param: strconv.Itoa(resp.StatusCode), + }, + } + responseBody, err := io.ReadAll(resp.Body) + if err != nil { + return + } + if config.DebugEnabled { + logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) + } + err = resp.Body.Close() + if err != nil { + return + } + var errResponse GeneralErrorResponse + err = json.Unmarshal(responseBody, &errResponse) + if err != nil { + return + } + if errResponse.Error.Message != "" { + // OpenAI format error, so we override the default one + ErrorWithStatusCode.Error = errResponse.Error + } else { + ErrorWithStatusCode.Error.Message = errResponse.ToMessage() + } + if ErrorWithStatusCode.Error.Message == "" { + ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) + } + return +} diff --git a/relay/controller/helper.go b/relay/controller/helper.go index 71dd653e..e07aba89 100644 --- a/relay/controller/helper.go +++ b/relay/controller/helper.go @@ -9,10 +9,13 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/controller/validator" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" + 
"github.com/songquanpeng/one-api/relay/relaymode" "math" "net/http" ) @@ -23,21 +26,21 @@ func getAndValidateTextRequest(c *gin.Context, relayMode int) (*relaymodel.Gener if err != nil { return nil, err } - if relayMode == constant.RelayModeModerations && textRequest.Model == "" { + if relayMode == relaymode.Moderations && textRequest.Model == "" { textRequest.Model = "text-moderation-latest" } - if relayMode == constant.RelayModeEmbeddings && textRequest.Model == "" { + if relayMode == relaymode.Embeddings && textRequest.Model == "" { textRequest.Model = c.Param("model") } - err = util.ValidateTextRequest(textRequest, relayMode) + err = validator.ValidateTextRequest(textRequest, relayMode) if err != nil { return nil, err } return textRequest, nil } -func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error) { - imageRequest := &openai.ImageRequest{} +func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, error) { + imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { return nil, err @@ -54,9 +57,25 @@ func getImageRequest(c *gin.Context, relayMode int) (*openai.ImageRequest, error return imageRequest, nil } -func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMeta) *relaymodel.ErrorWithStatusCode { +func isValidImageSize(model string, size string) bool { + if model == "cogview-3" { + return true + } + _, ok := billingratio.ImageSizeRatios[model][size] + return ok +} + +func getImageSizeRatio(model string, size string) float64 { + ratio, ok := billingratio.ImageSizeRatios[model][size] + if !ok { + return 1 + } + return ratio +} + +func validateImageRequest(imageRequest *relaymodel.ImageRequest, meta *meta.Meta) *relaymodel.ErrorWithStatusCode { // model validation - _, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] + hasValidSize := isValidImageSize(imageRequest.Model, imageRequest.Size) if !hasValidSize { return openai.ErrorWrapper(errors.New("size not supported for this image model"), "size_not_supported", http.StatusBadRequest) } @@ -64,27 +83,24 @@ func validateImageRequest(imageRequest *openai.ImageRequest, meta *util.RelayMet if imageRequest.Prompt == "" { return openai.ErrorWrapper(errors.New("prompt is required"), "prompt_missing", http.StatusBadRequest) } - if len(imageRequest.Prompt) > constant.DalleImagePromptLengthLimitations[imageRequest.Model] { + if len(imageRequest.Prompt) > billingratio.ImagePromptLengthLimitations[imageRequest.Model] { return openai.ErrorWrapper(errors.New("prompt is too long"), "prompt_too_long", http.StatusBadRequest) } // Number of generated images validation if !isWithinRange(imageRequest.Model, imageRequest.N) { // channel not azure - if meta.ChannelType != common.ChannelTypeAzure { + if meta.ChannelType != channeltype.Azure { return openai.ErrorWrapper(errors.New("invalid value of n"), "n_not_within_range", http.StatusBadRequest) } } return nil } -func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { +func getImageCostRatio(imageRequest *relaymodel.ImageRequest) (float64, error) { if imageRequest == nil { return 0, errors.New("imageRequest is nil") } - imageCostRatio, hasValidSize := constant.DalleSizeRatios[imageRequest.Model][imageRequest.Size] - if !hasValidSize { - return 0, errors.Errorf("size not supported for this image model: %s", imageRequest.Size) - } + imageCostRatio := getImageSizeRatio(imageRequest.Model, imageRequest.Size) if imageRequest.Quality == 
"hd" && imageRequest.Model == "dall-e-3" { if imageRequest.Size == "1024x1024" { imageCostRatio *= 2 @@ -97,11 +113,11 @@ func getImageCostRatio(imageRequest *openai.ImageRequest) (float64, error) { func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int) int { switch relayMode { - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: return openai.CountTokenMessages(textRequest.Messages, textRequest.Model) - case constant.RelayModeCompletions: + case relaymode.Completions: return openai.CountTokenInput(textRequest.Prompt, textRequest.Model) - case constant.RelayModeModerations: + case relaymode.Moderations: return openai.CountTokenInput(textRequest.Input, textRequest.Model) } return 0 @@ -115,7 +131,7 @@ func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTok return int64(float64(preConsumedTokens) * ratio) } -func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) { +func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *meta.Meta) (int64, *relaymodel.ErrorWithStatusCode) { preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio) userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) @@ -144,13 +160,13 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR return preConsumedQuota, nil } -func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { +func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *meta.Meta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) { if usage == nil { logger.Error(ctx, "usage is nil, which is unexpected") return } var quota int64 - completionRatio := common.GetCompletionRatio(textRequest.Model) + completionRatio := billingratio.GetCompletionRatio(textRequest.Model) promptTokens := usage.PromptTokens completionTokens := usage.CompletionTokens quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) @@ -178,3 +194,14 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R model.UpdateUserUsedQuotaAndRequestCount(meta.UserId, quota) model.UpdateChannelUsedQuota(meta.ChannelId, quota) } + +func getMappedModelName(modelName string, mapping map[string]string) (string, bool) { + if mapping == nil { + return modelName, false + } + mappedModelName := mapping[modelName] + if mappedModelName != "" { + return mappedModelName, true + } + return modelName, false +} diff --git a/relay/controller/image.go b/relay/controller/image.go index d88dc271..ea3e32a0 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -5,35 +5,33 @@ import ( "context" "encoding/json" "fmt" - "github.com/Laisky/errors/v2" "io" "net/http" - "strings" - "github.com/songquanpeng/one-api/common" + "github.com/Laisky/errors/v2" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + 
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" relaymodel "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" - - "github.com/gin-gonic/gin" ) func isWithinRange(element string, value int) bool { - if _, ok := constant.DalleGenerationImageAmounts[element]; !ok { + if _, ok := billingratio.ImageGenerationAmounts[element]; !ok { return false } - min := constant.DalleGenerationImageAmounts[element][0] - max := constant.DalleGenerationImageAmounts[element][1] - + min := billingratio.ImageGenerationAmounts[element][0] + max := billingratio.ImageGenerationAmounts[element][1] return value >= min && value <= max } func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() - meta := util.GetRelayMeta(c) + meta := meta.GetByContext(c) imageRequest, err := getImageRequest(c, meta.Mode) if err != nil { logger.Errorf(ctx, "getImageRequest failed: %s", err.Error()) @@ -43,7 +41,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus // map model name var isModelMapped bool meta.OriginModelName = imageRequest.Model - imageRequest.Model, isModelMapped = util.GetMappedModelName(imageRequest.Model, meta.ModelMapping) + imageRequest.Model, isModelMapped = getMappedModelName(imageRequest.Model, meta.ModelMapping) meta.ActualModelName = imageRequest.Model // model validation @@ -57,17 +55,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "get_image_cost_ratio_failed", http.StatusInternalServerError) } - requestURL := c.Request.URL.String() - fullRequestURL := util.GetFullRequestURL(meta.BaseURL, requestURL, meta.ChannelType) - if meta.ChannelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line&pivots=rest-api - apiVersion := util.GetAzureAPIVersion(c) - // https://{resource_name}.openai.azure.com/openai/deployments/dall-e-3/images/generations?api-version=2024-03-01-preview - fullRequestURL = fmt.Sprintf("%s/openai/deployments/%s/images/generations?api-version=%s", meta.BaseURL, imageRequest.Model, apiVersion) - } - var requestBody io.Reader - if isModelMapped || meta.ChannelType == common.ChannelTypeAzure { // make Azure channel request body + if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) @@ -77,9 +66,32 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus requestBody = c.Request.Body } - modelRatio := common.GetModelRatio(imageRequest.Model) - // groupRatio := common.GetGroupRatio(meta.Group) + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + + switch meta.ChannelType { + case channeltype.Ali: + fallthrough + case channeltype.Baidu: + fallthrough + case channeltype.Zhipu: + finalRequest, err := adaptor.ConvertImageRequest(imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return 
openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) + } + + modelRatio := billingratio.GetModelRatio(imageRequest.Model) + // groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio + ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) @@ -89,36 +101,13 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) } - req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return openai.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - token := c.Request.Header.Get("Authorization") - if meta.ChannelType == common.ChannelTypeAzure { // Azure authentication - token = strings.TrimPrefix(token, "Bearer ") - req.Header.Set("api-key", token) - } else { - req.Header.Set("Authorization", token) - } - - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - - resp, err := util.HTTPClient.Do(req) + // do request + resp, err := adaptor.DoRequest(c, meta, requestBody) if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) } - err = req.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - var imageResponse openai.ImageResponse - defer func(ctx context.Context) { if resp.StatusCode != http.StatusOK { return @@ -141,34 +130,12 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } }(c.Request.Context()) - responseBody, err := io.ReadAll(resp.Body) - - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } - err = json.Unmarshal(responseBody, &imageResponse) - if err != nil { - return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError) + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr } - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - - for k, v := range resp.Header { - c.Writer.Header().Set(k, v[0]) - } - c.Writer.WriteHeader(resp.StatusCode) - - _, err = io.Copy(c.Writer, resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } return nil } diff --git a/relay/controller/text.go b/relay/controller/text.go index 282e8f25..beda2822 100644 --- a/relay/controller/text.go +++ b/relay/controller/text.go @@ -10,18 +10,20 @@ import ( "github.com/Laisky/errors/v2" "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/logger" - 
"github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/helper" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/apitype" + "github.com/songquanpeng/one-api/relay/billing" + billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/meta" "github.com/songquanpeng/one-api/relay/model" - "github.com/songquanpeng/one-api/relay/util" ) func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { ctx := c.Request.Context() - meta := util.GetRelayMeta(c) + meta := meta.GetByContext(c) // get & validate textRequest textRequest, err := getAndValidateTextRequest(c, meta.Mode) if err != nil { @@ -33,12 +35,13 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { // map model name var isModelMapped bool meta.OriginModelName = textRequest.Model - textRequest.Model, isModelMapped = util.GetMappedModelName(textRequest.Model, meta.ModelMapping) + textRequest.Model, isModelMapped = getMappedModelName(textRequest.Model, meta.ModelMapping) meta.ActualModelName = textRequest.Model // get model ratio & group ratio - modelRatio := common.GetModelRatio(textRequest.Model) - // groupRatio := common.GetGroupRatio(meta.Group) + modelRatio := billingratio.GetModelRatio(textRequest.Model) + // groupRatio := billingratio.GetGroupRatio(meta.Group) groupRatio := meta.ChannelRatio + ratio := modelRatio * groupRatio // pre-consume quota promptTokens := getPromptTokens(textRequest, meta.Mode) @@ -49,16 +52,16 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { return bizErr } - adaptor := helper.GetAdaptor(meta.APIType) + adaptor := relay.GetAdaptor(meta.APIType) if adaptor == nil { return openai.ErrorWrapper(errors.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) } // get request body var requestBody io.Reader - if meta.APIType == constant.APITypeOpenAI { + if meta.APIType == apitype.OpenAI { // no need to convert request for openai - shouldResetRequestBody := isModelMapped || meta.ChannelType == common.ChannelTypeBaichuan // frequency_penalty 0 is not acceptable for baichuan + shouldResetRequestBody := isModelMapped || meta.ChannelType == channeltype.Baichuan // frequency_penalty 0 is not acceptable for baichuan if shouldResetRequestBody { jsonStr, err := json.Marshal(textRequest) if err != nil { @@ -93,10 +96,10 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { } errorHappened := (resp.StatusCode != http.StatusOK) || (meta.IsStream && resp.Header.Get("Content-Type") == "application/json") if errorHappened { - util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) logger.Error(ctx, fmt.Sprintf("relay text [%d] <- %q %q", resp.StatusCode, resp.Request.URL.String(), string(requestBodyBytes))) - return util.RelayErrorHandler(resp) + return RelayErrorHandler(resp) } meta.IsStream = meta.IsStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") @@ -104,7 +107,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode { usage, respErr := adaptor.DoResponse(c, resp, meta) if respErr != nil { logger.Errorf(ctx, "respErr is not nil: %+v", respErr) - util.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) + 
billing.ReturnPreConsumedQuota(ctx, preConsumedQuota, meta.TokenId) return respErr } // post-consume quota diff --git a/relay/util/validation.go b/relay/controller/validator/validation.go similarity index 76% rename from relay/util/validation.go rename to relay/controller/validator/validation.go index b9d25c2a..3aab6ac8 100644 --- a/relay/util/validation.go +++ b/relay/controller/validator/validation.go @@ -1,10 +1,11 @@ -package util +package validator import ( - "github.com/Laisky/errors/v2" - "github.com/songquanpeng/one-api/relay/constant" - "github.com/songquanpeng/one-api/relay/model" "math" + + "github.com/Laisky/errors/v2" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" ) func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) error { @@ -15,20 +16,20 @@ func ValidateTextRequest(textRequest *model.GeneralOpenAIRequest, relayMode int) return errors.New("model is required") } switch relayMode { - case constant.RelayModeCompletions: + case relaymode.Completions: if textRequest.Prompt == "" { return errors.New("field prompt is required") } - case constant.RelayModeChatCompletions: + case relaymode.ChatCompletions: if textRequest.Messages == nil || len(textRequest.Messages) == 0 { return errors.New("field messages is required") } - case constant.RelayModeEmbeddings: - case constant.RelayModeModerations: + case relaymode.Embeddings: + case relaymode.Moderations: if textRequest.Input == "" { return errors.New("field input is required") } - case constant.RelayModeEdits: + case relaymode.Edits: if textRequest.Instruction == "" { return errors.New("field instruction is required") } diff --git a/relay/helper/main.go b/relay/helper/main.go deleted file mode 100644 index 18bbe51a..00000000 --- a/relay/helper/main.go +++ /dev/null @@ -1,40 +0,0 @@ -package helper - -import ( - "github.com/songquanpeng/one-api/relay/channel" - "github.com/songquanpeng/one-api/relay/channel/aiproxy" - "github.com/songquanpeng/one-api/relay/channel/anthropic" - "github.com/songquanpeng/one-api/relay/channel/gemini" - "github.com/songquanpeng/one-api/relay/channel/ollama" - "github.com/songquanpeng/one-api/relay/channel/openai" - "github.com/songquanpeng/one-api/relay/channel/palm" - "github.com/songquanpeng/one-api/relay/constant" -) - -func GetAdaptor(apiType int) channel.Adaptor { - switch apiType { - case constant.APITypeAIProxyLibrary: - return &aiproxy.Adaptor{} - // case constant.APITypeAli: - // return &ali.Adaptor{} - case constant.APITypeAnthropic: - return &anthropic.Adaptor{} - // case constant.APITypeBaidu: - // return &baidu.Adaptor{} - case constant.APITypeGemini: - return &gemini.Adaptor{} - case constant.APITypeOpenAI: - return &openai.Adaptor{} - case constant.APITypePaLM: - return &palm.Adaptor{} - // case constant.APITypeTencent: - // return &tencent.Adaptor{} - // case constant.APITypeXunfei: - // return &xunfei.Adaptor{} - // case constant.APITypeZhipu: - // return &zhipu.Adaptor{} - case constant.APITypeOllama: - return &ollama.Adaptor{} - } - return nil -} diff --git a/relay/util/relay_meta.go b/relay/meta/relay_meta.go similarity index 64% rename from relay/util/relay_meta.go rename to relay/meta/relay_meta.go index 17135816..a17aa0f0 100644 --- a/relay/util/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -1,13 +1,15 @@ -package util +package meta import ( "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/relay/constant" + 
"github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/relay/adaptor/azure" + "github.com/songquanpeng/one-api/relay/channeltype" + "github.com/songquanpeng/one-api/relay/relaymode" "strings" ) -type RelayMeta struct { +type Meta struct { Mode int ChannelType int ChannelId int @@ -29,9 +31,9 @@ type RelayMeta struct { ChannelRatio float64 } -func GetRelayMeta(c *gin.Context) *RelayMeta { - meta := RelayMeta{ - Mode: constant.Path2RelayMode(c.Request.URL.Path), +func GetByContext(c *gin.Context) *Meta { + meta := Meta{ + Mode: relaymode.GetByPath(c.Request.URL.Path), ChannelType: c.GetInt("channel"), ChannelId: c.GetInt("channel_id"), TokenId: c.GetInt("token_id"), @@ -40,18 +42,18 @@ func GetRelayMeta(c *gin.Context) *RelayMeta { Group: c.GetString("group"), ModelMapping: c.GetStringMapString("model_mapping"), BaseURL: c.GetString("base_url"), - APIVersion: c.GetString(common.ConfigKeyAPIVersion), + APIVersion: c.GetString(config.KeyAPIVersion), APIKey: strings.TrimPrefix(c.Request.Header.Get("Authorization"), "Bearer "), Config: nil, RequestURLPath: c.Request.URL.String(), ChannelRatio: c.GetFloat64("channel_ratio"), } - if meta.ChannelType == common.ChannelTypeAzure { - meta.APIVersion = GetAzureAPIVersion(c) + if meta.ChannelType == channeltype.Azure { + meta.APIVersion = azure.GetAPIVersion(c) } if meta.BaseURL == "" { - meta.BaseURL = common.ChannelBaseURLs[meta.ChannelType] + meta.BaseURL = channeltype.ChannelBaseURLs[meta.ChannelType] } - meta.APIType = constant.ChannelType2APIType(meta.ChannelType) + meta.APIType = channeltype.ToAPIType(meta.ChannelType) return &meta } diff --git a/relay/model/image.go b/relay/model/image.go new file mode 100644 index 00000000..bab84256 --- /dev/null +++ b/relay/model/image.go @@ -0,0 +1,12 @@ +package model + +type ImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt" binding:"required"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Style string `json:"style,omitempty"` + User string `json:"user,omitempty"` +} diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go new file mode 100644 index 00000000..96d09438 --- /dev/null +++ b/relay/relaymode/define.go @@ -0,0 +1,14 @@ +package relaymode + +const ( + Unknown = iota + ChatCompletions + Completions + Embeddings + Moderations + ImagesGenerations + Edits + AudioSpeech + AudioTranscription + AudioTranslation +) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go new file mode 100644 index 00000000..926dd42e --- /dev/null +++ b/relay/relaymode/helper.go @@ -0,0 +1,29 @@ +package relaymode + +import "strings" + +func GetByPath(path string) int { + relayMode := Unknown + if strings.HasPrefix(path, "/v1/chat/completions") { + relayMode = ChatCompletions + } else if strings.HasPrefix(path, "/v1/completions") { + relayMode = Completions + } else if strings.HasPrefix(path, "/v1/embeddings") { + relayMode = Embeddings + } else if strings.HasSuffix(path, "embeddings") { + relayMode = Embeddings + } else if strings.HasPrefix(path, "/v1/moderations") { + relayMode = Moderations + } else if strings.HasPrefix(path, "/v1/images/generations") { + relayMode = ImagesGenerations + } else if strings.HasPrefix(path, "/v1/edits") { + relayMode = Edits + } else if strings.HasPrefix(path, "/v1/audio/speech") { + relayMode = AudioSpeech + } else if strings.HasPrefix(path, "/v1/audio/transcriptions") { + relayMode = 
AudioTranscription + } else if strings.HasPrefix(path, "/v1/audio/translations") { + relayMode = AudioTranslation + } + return relayMode +} diff --git a/relay/util/billing.go b/relay/util/billing.go deleted file mode 100644 index 495d011e..00000000 --- a/relay/util/billing.go +++ /dev/null @@ -1,19 +0,0 @@ -package util - -import ( - "context" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" -) - -func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - logger.Error(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(ctx) - } -} diff --git a/relay/util/common.go b/relay/util/common.go deleted file mode 100644 index 518d0b00..00000000 --- a/relay/util/common.go +++ /dev/null @@ -1,188 +0,0 @@ -package util - -import ( - "context" - "encoding/json" - "fmt" - "github.com/songquanpeng/one-api/common" - "github.com/songquanpeng/one-api/common/config" - "github.com/songquanpeng/one-api/common/logger" - "github.com/songquanpeng/one-api/model" - relaymodel "github.com/songquanpeng/one-api/relay/model" - "io" - "net/http" - "strconv" - "strings" - - "github.com/gin-gonic/gin" -) - -func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool { - if !config.AutomaticDisableChannelEnabled { - return false - } - if err == nil { - return false - } - if statusCode == http.StatusUnauthorized { - return true - } - switch err.Type { - case "insufficient_quota": - return true - // https://docs.anthropic.com/claude/reference/errors - case "authentication_error": - return true - case "permission_error": - return true - case "forbidden": - return true - } - if err.Code == "invalid_api_key" || err.Code == "account_deactivated" { - return true - } - if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic - return true - } else if strings.HasPrefix(err.Message, "This organization has been disabled.") { - return true - } - return false -} - -func ShouldEnableChannel(err error, openAIErr *relaymodel.Error) bool { - if !config.AutomaticEnableChannelEnabled { - return false - } - if err != nil { - return false - } - if openAIErr != nil { - return false - } - return true -} - -type GeneralErrorResponse struct { - Error relaymodel.Error `json:"error"` - Message string `json:"message"` - Msg string `json:"msg"` - Err string `json:"err"` - ErrorMsg string `json:"error_msg"` - Header struct { - Message string `json:"message"` - } `json:"header"` - Response struct { - Error struct { - Message string `json:"message"` - } `json:"error"` - } `json:"response"` -} - -func (e GeneralErrorResponse) ToMessage() string { - if e.Error.Message != "" { - return e.Error.Message - } - if e.Message != "" { - return e.Message - } - if e.Msg != "" { - return e.Msg - } - if e.Err != "" { - return e.Err - } - if e.ErrorMsg != "" { - return e.ErrorMsg - } - if e.Header.Message != "" { - return e.Header.Message - } - if e.Response.Error.Message != "" { - return e.Response.Error.Message - } - return "" -} - -func RelayErrorHandler(resp *http.Response) (ErrorWithStatusCode *relaymodel.ErrorWithStatusCode) { - ErrorWithStatusCode = &relaymodel.ErrorWithStatusCode{ - StatusCode: resp.StatusCode, - Error: relaymodel.Error{ - Message: "", - Type: "upstream_error", - Code: "bad_response_status_code", - Param: strconv.Itoa(resp.StatusCode), - }, - } - 
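
relaymode.GetByPath above maps a request path to a relay mode by prefix alone, which makes it straightforward to pin down with a table test. A small sketch using only the constants and function defined in this patch:

```go
// Table test for relaymode.GetByPath.
package relaymode_test

import (
	"testing"

	"github.com/songquanpeng/one-api/relay/relaymode"
)

func TestGetByPath(t *testing.T) {
	cases := map[string]int{
		"/v1/chat/completions":     relaymode.ChatCompletions,
		"/v1/images/generations":   relaymode.ImagesGenerations,
		"/v1/audio/transcriptions": relaymode.AudioTranscription,
		"/v1/not-a-known-endpoint": relaymode.Unknown,
	}
	for path, want := range cases {
		if got := relaymode.GetByPath(path); got != want {
			t.Errorf("GetByPath(%q) = %d, want %d", path, got, want)
		}
	}
}
```
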
responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return - } - if config.DebugEnabled { - logger.SysLog(fmt.Sprintf("error happened, status code: %d, response: \n%s", resp.StatusCode, string(responseBody))) - } - err = resp.Body.Close() - if err != nil { - return - } - var errResponse GeneralErrorResponse - err = json.Unmarshal(responseBody, &errResponse) - if err != nil { - return - } - if errResponse.Error.Message != "" { - // OpenAI format error, so we override the default one - ErrorWithStatusCode.Error = errResponse.Error - } else { - ErrorWithStatusCode.Error.Message = errResponse.ToMessage() - } - if ErrorWithStatusCode.Error.Message == "" { - ErrorWithStatusCode.Error.Message = fmt.Sprintf("bad response status code %d", resp.StatusCode) - } - return -} - -func GetFullRequestURL(baseURL string, requestURL string, channelType int) string { - fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL) - - if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { - switch channelType { - case common.ChannelTypeOpenAI: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/v1")) - case common.ChannelTypeAzure: - fullRequestURL = fmt.Sprintf("%s%s", baseURL, strings.TrimPrefix(requestURL, "/openai/deployments")) - } - } - return fullRequestURL -} - -func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) { - // quotaDelta is remaining quota to be consumed - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - logger.SysError("error consuming token remain quota: " + err.Error()) - } - err = model.CacheUpdateUserQuota(ctx, userId) - if err != nil { - logger.SysError("error update user quota cache: " + err.Error()) - } - // totalQuota is total quota consumed - if totalQuota >= 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota) - model.UpdateChannelUsedQuota(channelId, totalQuota) - } - - if totalQuota < 0 { - logger.Error(ctx, fmt.Sprintf("totalQuota consumed is %d, something is wrong", totalQuota)) - } -} - -func GetAzureAPIVersion(c *gin.Context) string { - query := c.Request.URL.Query() - apiVersion := query.Get("api-version") - if apiVersion == "" { - apiVersion = c.GetString(common.ConfigKeyAPIVersion) - } - return apiVersion -} diff --git a/relay/util/model_mapping.go b/relay/util/model_mapping.go deleted file mode 100644 index 39e062a1..00000000 --- a/relay/util/model_mapping.go +++ /dev/null @@ -1,12 +0,0 @@ -package util - -func GetMappedModelName(modelName string, mapping map[string]string) (string, bool) { - if mapping == nil { - return modelName, false - } - mappedModelName := mapping[modelName] - if mappedModelName != "" { - return mappedModelName, true - } - return modelName, false -} diff --git a/router/api-router.go b/router/api.go similarity index 92% rename from router/api-router.go rename to router/api.go index 7af4511a..b9e5de38 100644 --- a/router/api-router.go +++ b/router/api.go @@ -2,6 +2,7 @@ package router import ( "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/controller/auth" "github.com/songquanpeng/one-api/middleware" "github.com/gin-contrib/gzip" @@ -22,11 +23,13 @@ func SetApiRouter(router *gin.Engine) { 
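
The deleted relay/util/common.go mixed error handling, billing, and URL building; this patch redistributes those helpers into more focused packages (for example billing.ReturnPreConsumedQuota and the controller-local RelayErrorHandler seen earlier). The Cloudflare AI Gateway rewrite in GetFullRequestURL is the least obvious of them, so here is a standalone sketch of just that logic; the channel-type constants are illustrative stand-ins for the channeltype package, not its real values.

```go
// Standalone sketch of the Cloudflare AI Gateway path rewrite performed by the
// removed GetFullRequestURL.
package main

import (
	"fmt"
	"strings"
)

const (
	channelOpenAI = iota + 1 // illustrative values only
	channelAzure
)

func fullRequestURL(baseURL, requestURL string, channelType int) string {
	if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
		switch channelType {
		case channelOpenAI:
			return baseURL + strings.TrimPrefix(requestURL, "/v1")
		case channelAzure:
			return baseURL + strings.TrimPrefix(requestURL, "/openai/deployments")
		}
	}
	return baseURL + requestURL
}

func main() {
	base := "https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai" // example gateway URL
	fmt.Println(fullRequestURL(base, "/v1/chat/completions", channelOpenAI))
	// https://gateway.ai.cloudflare.com/v1/ACCOUNT/GATEWAY/openai/chat/completions
}
```
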
apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.GET("/user/get-by-token", middleware.TokenAuth(), controller.GetSelfByToken) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword) - apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), controller.GitHubOAuth) - apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), controller.GenerateOAuthCode) - apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), controller.WeChatAuth) - apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.WeChatBind) + apiRouter.GET("/oauth/github", middleware.CriticalRateLimit(), auth.GitHubOAuth) + apiRouter.GET("/oauth/lark", middleware.CriticalRateLimit(), auth.LarkOAuth) + apiRouter.GET("/oauth/state", middleware.CriticalRateLimit(), auth.GenerateOAuthCode) + apiRouter.GET("/oauth/wechat", middleware.CriticalRateLimit(), auth.WeChatAuth) + apiRouter.GET("/oauth/wechat/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), auth.WeChatBind) apiRouter.GET("/oauth/email/bind", middleware.CriticalRateLimit(), middleware.UserAuth(), controller.EmailBind) + apiRouter.POST("/topup", middleware.AdminAuth(), controller.AdminTopUp) userRoute := apiRouter.Group("/user") { @@ -44,6 +47,7 @@ func SetApiRouter(router *gin.Engine) { selfRoute.GET("/token", controller.GenerateAccessToken) selfRoute.GET("/aff", controller.GetAffCode) selfRoute.POST("/topup", controller.TopUp) + selfRoute.GET("/available_models", controller.GetUserAvailableModels) } adminRoute := userRoute.Group("/") @@ -69,7 +73,7 @@ func SetApiRouter(router *gin.Engine) { { channelRoute.GET("/", controller.GetAllChannels) channelRoute.GET("/search", controller.SearchChannels) - channelRoute.GET("/models", controller.ListModels) + channelRoute.GET("/models", controller.ListAllModels) channelRoute.GET("/:id", controller.GetChannel) channelRoute.GET("/test", controller.TestChannels) channelRoute.GET("/test/:id", controller.TestChannel) diff --git a/router/relay-router.go b/router/relay.go similarity index 100% rename from router/relay-router.go rename to router/relay.go diff --git a/router/web-router.go b/router/web.go similarity index 100% rename from router/web-router.go rename to router/web.go diff --git a/web/README.md b/web/README.md index 29f4713e..829271e2 100644 --- a/web/README.md +++ b/web/README.md @@ -2,6 +2,9 @@ > 每个文件夹代表一个主题,欢迎提交你的主题 +> [!WARNING] +> 不是每一个主题都及时同步了所有功能,由于精力有限,优先更新默认主题,其他主题欢迎 & 期待 PR + ## 提交新的主题 > 欢迎在页面底部保留你和 One API 的版权信息以及指向链接 diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index ec049f7d..b74c58c7 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -107,6 +107,12 @@ export const CHANNEL_OPTIONS = { value: 31, color: 'primary' }, + 32: { + key: 32, + text: '阶跃星辰', + value: 32, + color: 'primary' + }, 8: { key: 8, text: '自定义渠道', diff --git a/web/berry/src/constants/SnackbarConstants.js b/web/berry/src/constants/SnackbarConstants.js index a05c6652..19523da1 100644 --- a/web/berry/src/constants/SnackbarConstants.js +++ b/web/berry/src/constants/SnackbarConstants.js @@ -18,7 +18,7 @@ export const snackbarConstants = { }, NOTICE: { variant: 'info', - autoHideDuration: 20000 + autoHideDuration: 7000 } }, Mobile: { diff --git a/web/berry/src/utils/common.js b/web/berry/src/utils/common.js index aa4b8c37..8925e542 100644 --- 
a/web/berry/src/utils/common.js +++ b/web/berry/src/utils/common.js @@ -51,9 +51,9 @@ export function showError(error) { export function showNotice(message, isHTML = false) { if (isHTML) { - enqueueSnackbar(, getSnackbarOptions('INFO')); + enqueueSnackbar(, getSnackbarOptions('NOTICE')); } else { - enqueueSnackbar(message, getSnackbarOptions('INFO')); + enqueueSnackbar(message, getSnackbarOptions('NOTICE')); } } diff --git a/web/berry/src/views/Channel/component/EditModal.js b/web/berry/src/views/Channel/component/EditModal.js index 07111c97..cbf411b9 100644 --- a/web/berry/src/views/Channel/component/EditModal.js +++ b/web/berry/src/views/Channel/component/EditModal.js @@ -340,7 +340,9 @@ const EditModal = ({ open, channelId, onCancel, onOk }) => { }, }} > - {Object.values(CHANNEL_OPTIONS).map((option) => { + {Object.values(CHANNEL_OPTIONS).sort((a, b) => { + return a.text.localeCompare(b.text) + }).map((option) => { return ( {option.text} diff --git a/web/berry/src/views/Token/component/EditModal.js b/web/berry/src/views/Token/component/EditModal.js index cf65a96a..a5e08be1 100644 --- a/web/berry/src/views/Token/component/EditModal.js +++ b/web/berry/src/views/Token/component/EditModal.js @@ -103,7 +103,7 @@ const EditModal = ({ open, tokenId, onCancel, onOk }) => { fontSize: "1.125rem", }} > - {tokenId ? "编辑Token" : "新建Token"} + {tokenId ? "编辑令牌" : "新建令牌"} diff --git a/web/default/src/App.js b/web/default/src/App.js index 13c884dc..4ece4eeb 100644 --- a/web/default/src/App.js +++ b/web/default/src/App.js @@ -24,6 +24,7 @@ import EditRedemption from './pages/Redemption/EditRedemption'; import TopUp from './pages/TopUp'; import Log from './pages/Log'; import Chat from './pages/Chat'; +import LarkOAuth from './components/LarkOAuth'; const Home = lazy(() => import('./pages/Home')); const About = lazy(() => import('./pages/About')); @@ -239,6 +240,14 @@ function App() { } /> + }> + + + } + /> { + const [searchParams, setSearchParams] = useSearchParams(); + + const [userState, userDispatch] = useContext(UserContext); + const [prompt, setPrompt] = useState('处理中...'); + const [processing, setProcessing] = useState(true); + + let navigate = useNavigate(); + + const sendCode = async (code, state, count) => { + const res = await API.get(`/api/oauth/lark?code=${code}&state=${state}`); + const { success, message, data } = res.data; + if (success) { + if (message === 'bind') { + showSuccess('绑定成功!'); + navigate('/setting'); + } else { + userDispatch({ type: 'login', payload: data }); + localStorage.setItem('user', JSON.stringify(data)); + showSuccess('登录成功!'); + navigate('/'); + } + } else { + showError(message); + if (count === 0) { + setPrompt(`操作失败,重定向至登录界面中...`); + navigate('/setting'); // in case this is failed to bind lark + return; + } + count++; + setPrompt(`出现错误,第 ${count} 次重试中...`); + await new Promise((resolve) => setTimeout(resolve, count * 2000)); + await sendCode(code, state, count); + } + }; + + useEffect(() => { + let code = searchParams.get('code'); + let state = searchParams.get('state'); + sendCode(code, state, 0).then(); + }, []); + + return ( + + + {prompt} + + + ); +}; + +export default LarkOAuth; diff --git a/web/default/src/components/LoginForm.js b/web/default/src/components/LoginForm.js index b48f64c4..01408f56 100644 --- a/web/default/src/components/LoginForm.js +++ b/web/default/src/components/LoginForm.js @@ -3,7 +3,8 @@ import { Button, Divider, Form, Grid, Header, Image, Message, Modal, Segment } f import { Link, useNavigate, useSearchParams } from 'react-router-dom'; 
import { UserContext } from '../context/User'; import { API, getLogo, showError, showSuccess, showWarning } from '../helpers'; -import { onGitHubOAuthClicked } from './utils'; +import { onGitHubOAuthClicked, onLarkOAuthClicked } from './utils'; +import larkIcon from '../images/lark.svg'; const LoginForm = () => { const [inputs, setInputs] = useState({ @@ -124,7 +125,7 @@ const LoginForm = () => { 点击注册 - {status.github_oauth || status.wechat_login ? ( + {status.github_oauth || status.wechat_login || status.lark_client_id ? ( <> Or {status.github_oauth ? ( @@ -137,6 +138,18 @@ const LoginForm = () => { ) : ( <> )} + {status.lark_client_id ? ( + + ) : ( + <> + )} {status.wechat_login ? ( ) } + { + status.lark_client_id && ( + + ) + }