diff --git a/README.en.md b/README.en.md index e7f254f7..eec0047b 100644 --- a/README.en.md +++ b/README.en.md @@ -134,12 +134,12 @@ The initial account username is `root` and password is `123456`. git clone https://github.com/songquanpeng/one-api.git # Build the frontend - cd one-api/web + cd one-api/web/default npm install npm run build # Build the backend - cd .. + cd ../.. go mod download go build -ldflags "-s -w" -o one-api ``` diff --git a/README.ja.md b/README.ja.md index edfd2a28..e9149d71 100644 --- a/README.ja.md +++ b/README.ja.md @@ -135,12 +135,12 @@ sudo service nginx restart git clone https://github.com/songquanpeng/one-api.git # フロントエンドのビルド - cd one-api/web + cd one-api/web/default npm install npm run build # バックエンドのビルド - cd .. + cd ../.. go mod download go build -ldflags "-s -w" -o one-api ``` diff --git a/common/config/config.go b/common/config/config.go new file mode 100644 index 00000000..dd0236b4 --- /dev/null +++ b/common/config/config.go @@ -0,0 +1,127 @@ +package config + +import ( + "github.com/songquanpeng/one-api/common/helper" + "os" + "strconv" + "sync" + "time" + + "github.com/google/uuid" +) + +var SystemName = "One API" +var ServerAddress = "http://localhost:3000" +var Footer = "" +var Logo = "" +var TopUpLink = "" +var ChatLink = "" +var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens +var DisplayInCurrencyEnabled = true +var DisplayTokenStatEnabled = true + +// Any options with "Secret", "Token" in its key won't be return by GetOptions + +var SessionSecret = uuid.New().String() + +var OptionMap map[string]string +var OptionMapRWMutex sync.RWMutex + +var ItemsPerPage = 10 +var MaxRecentItems = 100 + +var PasswordLoginEnabled = true +var PasswordRegisterEnabled = true +var EmailVerificationEnabled = false +var GitHubOAuthEnabled = false +var WeChatAuthEnabled = false +var TurnstileCheckEnabled = false +var RegisterEnabled = true + +var EmailDomainRestrictionEnabled = false +var EmailDomainWhitelist = []string{ + "gmail.com", + "163.com", + "126.com", + "qq.com", + "outlook.com", + "hotmail.com", + "icloud.com", + "yahoo.com", + "foxmail.com", +} + +var DebugEnabled = os.Getenv("DEBUG") == "true" +var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" + +var LogConsumeEnabled = true + +var SMTPServer = "" +var SMTPPort = 587 +var SMTPAccount = "" +var SMTPFrom = "" +var SMTPToken = "" + +var GitHubClientId = "" +var GitHubClientSecret = "" + +var WeChatServerAddress = "" +var WeChatServerToken = "" +var WeChatAccountQRCodeImageURL = "" + +var TurnstileSiteKey = "" +var TurnstileSecretKey = "" + +var QuotaForNewUser = 0 +var QuotaForInviter = 0 +var QuotaForInvitee = 0 +var ChannelDisableThreshold = 5.0 +var AutomaticDisableChannelEnabled = false +var AutomaticEnableChannelEnabled = false +var QuotaRemindThreshold = 1000 +var PreConsumedQuota = 500 +var ApproximateTokenEnabled = false +var RetryTimes = 0 + +var RootUserEmail = "" + +var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" + +var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) +var RequestInterval = time.Duration(requestInterval) * time.Second + +var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second + +var BatchUpdateEnabled = false +var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5) + +var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second + +var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") + +var Theme = helper.GetOrDefaultEnvString("THEME", "default") +var ValidThemes = map[string]bool{ + "default": true, + "berry": true, +} + +// All duration's unit is seconds +// Shouldn't larger then RateLimitKeyExpirationDuration +var ( + GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitDuration int64 = 3 * 60 + + GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitDuration int64 = 3 * 60 + + UploadRateLimitNum = 10 + UploadRateLimitDuration int64 = 60 + + DownloadRateLimitNum = 10 + DownloadRateLimitDuration int64 = 60 + + CriticalRateLimitNum = 20 + CriticalRateLimitDuration int64 = 20 * 60 +) + +var RateLimitKeyExpirationDuration = 20 * time.Minute diff --git a/common/constants.go b/common/constants.go index 9ee791df..325454d4 100644 --- a/common/constants.go +++ b/common/constants.go @@ -1,114 +1,9 @@ package common -import ( - "os" - "strconv" - "sync" - "time" - - "github.com/google/uuid" -) +import "time" var StartTime = time.Now().Unix() // unit: second var Version = "v0.0.0" // this hard coding will be replaced automatically when building, no need to manually change -var SystemName = "One API" -var ServerAddress = "http://localhost:3000" -var Footer = "" -var Logo = "" -var TopUpLink = "" -var ChatLink = "" -var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens -var DisplayInCurrencyEnabled = true -var DisplayTokenStatEnabled = true - -// Any options with "Secret", "Token" in its key won't be return by GetOptions - -var SessionSecret = uuid.New().String() - -var OptionMap map[string]string -var OptionMapRWMutex sync.RWMutex - -var ItemsPerPage = 10 -var MaxRecentItems = 100 - -var PasswordLoginEnabled = true -var PasswordRegisterEnabled = true -var EmailVerificationEnabled = false -var GitHubOAuthEnabled = false -var WeChatAuthEnabled = false -var TurnstileCheckEnabled = false -var RegisterEnabled = true - -var EmailDomainRestrictionEnabled = false -var EmailDomainWhitelist = []string{ - "gmail.com", - "163.com", - "126.com", - "qq.com", - "outlook.com", - "hotmail.com", - "icloud.com", - "yahoo.com", - "foxmail.com", -} - -var DebugEnabled = os.Getenv("DEBUG") == "true" -var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true" - -var LogConsumeEnabled = true - -var SMTPServer = "" -var SMTPPort = 587 -var SMTPAccount = "" -var SMTPFrom = "" -var SMTPToken = "" - -var GitHubClientId = "" -var GitHubClientSecret = "" - -var WeChatServerAddress = "" -var WeChatServerToken = "" -var WeChatAccountQRCodeImageURL = "" - -var TurnstileSiteKey = "" -var TurnstileSecretKey = "" - -var QuotaForNewUser = 0 -var QuotaForInviter = 0 -var QuotaForInvitee = 0 -var ChannelDisableThreshold = 5.0 -var AutomaticDisableChannelEnabled = false -var AutomaticEnableChannelEnabled = false -var QuotaRemindThreshold = 1000 -var PreConsumedQuota = 500 -var ApproximateTokenEnabled = false -var RetryTimes = 0 - -var RootUserEmail = "" - -var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" - -var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) -var RequestInterval = time.Duration(requestInterval) * time.Second - -var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second - -var BatchUpdateEnabled = false -var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) - -var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second - -var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") - -var Theme = GetOrDefaultString("THEME", "default") -var ValidThemes = map[string]bool{ - "default": true, - "berry": true, -} - -const ( - RequestIdKey = "X-Oneapi-Request-Id" -) const ( RoleGuestUser = 0 @@ -117,34 +12,6 @@ const ( RoleRootUser = 100 ) -var ( - FileUploadPermission = RoleGuestUser - FileDownloadPermission = RoleGuestUser - ImageUploadPermission = RoleGuestUser - ImageDownloadPermission = RoleGuestUser -) - -// All duration's unit is seconds -// Shouldn't larger then RateLimitKeyExpirationDuration -var ( - GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) - GlobalApiRateLimitDuration int64 = 3 * 60 - - GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) - GlobalWebRateLimitDuration int64 = 3 * 60 - - UploadRateLimitNum = 10 - UploadRateLimitDuration int64 = 60 - - DownloadRateLimitNum = 10 - DownloadRateLimitDuration int64 = 60 - - CriticalRateLimitNum = 20 - CriticalRateLimitDuration int64 = 20 * 60 -) - -var RateLimitKeyExpirationDuration = 20 * time.Minute - const ( UserStatusEnabled = 1 // don't use 0, 0 is the default value! UserStatusDisabled = 2 // also don't use 0 @@ -199,29 +66,29 @@ const ( ) var ChannelBaseURLs = []string{ - "", // 0 - "https://api.openai.com", // 1 - "https://oa.api2d.net", // 2 - "", // 3 - "https://api.closeai-proxy.xyz", // 4 - "https://api.openai-sb.com", // 5 - "https://api.openaimax.com", // 6 - "https://api.ohmygpt.com", // 7 - "", // 8 - "https://api.caipacity.com", // 9 - "https://api.aiproxy.io", // 10 - "", // 11 - "https://api.api2gpt.com", // 12 - "https://api.aigc2d.com", // 13 - "https://api.anthropic.com", // 14 - "https://aip.baidubce.com", // 15 - "https://open.bigmodel.cn", // 16 - "https://dashscope.aliyuncs.com", // 17 - "", // 18 - "https://ai.360.cn", // 19 - "https://openrouter.ai/api", // 20 - "https://api.aiproxy.io", // 21 - "https://fastgpt.run/api/openapi", // 22 - "https://hunyuan.cloud.tencent.com", //23 - "", //24 + "", // 0 + "https://api.openai.com", // 1 + "https://oa.api2d.net", // 2 + "", // 3 + "https://api.closeai-proxy.xyz", // 4 + "https://api.openai-sb.com", // 5 + "https://api.openaimax.com", // 6 + "https://api.ohmygpt.com", // 7 + "", // 8 + "https://api.caipacity.com", // 9 + "https://api.aiproxy.io", // 10 + "https://generativelanguage.googleapis.com", // 11 + "https://api.api2gpt.com", // 12 + "https://api.aigc2d.com", // 13 + "https://api.anthropic.com", // 14 + "https://aip.baidubce.com", // 15 + "https://open.bigmodel.cn", // 16 + "https://dashscope.aliyuncs.com", // 17 + "", // 18 + "https://ai.360.cn", // 19 + "https://openrouter.ai/api", // 20 + "https://api.aiproxy.io", // 21 + "https://fastgpt.run/api/openapi", // 22 + "https://hunyuan.cloud.tencent.com", // 23 + "https://generativelanguage.googleapis.com", // 24 } diff --git a/common/database.go b/common/database.go index 76f2cd55..9b52a0d5 100644 --- a/common/database.go +++ b/common/database.go @@ -1,7 +1,9 @@ package common +import "github.com/songquanpeng/one-api/common/helper" + var UsingSQLite = false var UsingPostgreSQL = false var SQLitePath = "one-api.db" -var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000) +var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000) diff --git a/common/email.go b/common/email.go index 9120bea6..5657a852 100644 --- a/common/email.go +++ b/common/email.go @@ -5,19 +5,20 @@ import ( "crypto/tls" "encoding/base64" "fmt" + "github.com/songquanpeng/one-api/common/config" "net/smtp" "strings" "time" ) func SendEmail(subject string, receiver string, content string) error { - if SMTPFrom == "" { // for compatibility - SMTPFrom = SMTPAccount + if config.SMTPFrom == "" { // for compatibility + config.SMTPFrom = config.SMTPAccount } encodedSubject := fmt.Sprintf("=?UTF-8?B?%s?=", base64.StdEncoding.EncodeToString([]byte(subject))) // Extract domain from SMTPFrom - parts := strings.Split(SMTPFrom, "@") + parts := strings.Split(config.SMTPFrom, "@") var domain string if len(parts) > 1 { domain = parts[1] @@ -36,21 +37,21 @@ func SendEmail(subject string, receiver string, content string) error { "Message-ID: %s\r\n"+ // add Message-ID header to avoid being treated as spam, RFC 5322 "Date: %s\r\n"+ "Content-Type: text/html; charset=UTF-8\r\n\r\n%s\r\n", - receiver, SystemName, SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) - auth := smtp.PlainAuth("", SMTPAccount, SMTPToken, SMTPServer) - addr := fmt.Sprintf("%s:%d", SMTPServer, SMTPPort) + receiver, config.SystemName, config.SMTPFrom, encodedSubject, messageId, time.Now().Format(time.RFC1123Z), content)) + auth := smtp.PlainAuth("", config.SMTPAccount, config.SMTPToken, config.SMTPServer) + addr := fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort) to := strings.Split(receiver, ";") - if SMTPPort == 465 { + if config.SMTPPort == 465 { tlsConfig := &tls.Config{ - // InsecureSkipVerify: true, - ServerName: SMTPServer, + InsecureSkipVerify: false, + ServerName: config.SMTPServer, } - conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", SMTPServer, SMTPPort), tlsConfig) + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", config.SMTPServer, config.SMTPPort), tlsConfig) if err != nil { return err } - client, err := smtp.NewClient(conn, SMTPServer) + client, err := smtp.NewClient(conn, config.SMTPServer) if err != nil { return err } @@ -58,7 +59,7 @@ func SendEmail(subject string, receiver string, content string) error { if err = client.Auth(auth); err != nil { return err } - if err = client.Mail(SMTPFrom); err != nil { + if err = client.Mail(config.SMTPFrom); err != nil { return err } receiverEmails := strings.Split(receiver, ";") @@ -80,7 +81,7 @@ func SendEmail(subject string, receiver string, content string) error { return err } } else { - err = smtp.SendMail(addr, auth, SMTPAccount, to, mail) + err = smtp.SendMail(addr, auth, config.SMTPAccount, to, mail) } return err } diff --git a/common/gin.go b/common/gin.go index f5012688..bed2c2b1 100644 --- a/common/gin.go +++ b/common/gin.go @@ -31,3 +31,11 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) return nil } + +func SetEventStreamHeaders(c *gin.Context) { + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("Transfer-Encoding", "chunked") + c.Writer.Header().Set("X-Accel-Buffering", "no") +} diff --git a/common/group-ratio.go b/common/group-ratio.go index 1ec73c78..2de6e810 100644 --- a/common/group-ratio.go +++ b/common/group-ratio.go @@ -1,6 +1,9 @@ package common -import "encoding/json" +import ( + "encoding/json" + "github.com/songquanpeng/one-api/common/logger" +) var GroupRatio = map[string]float64{ "default": 1, @@ -11,7 +14,7 @@ var GroupRatio = map[string]float64{ func GroupRatio2JSONString() string { jsonBytes, err := json.Marshal(GroupRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -24,7 +27,7 @@ func UpdateGroupRatioByJSONString(jsonStr string) error { func GetGroupRatio(name string) float64 { ratio, ok := GroupRatio[name] if !ok { - SysError("group ratio not found: " + name) + logger.SysError("group ratio not found: " + name) return 1 } return ratio diff --git a/common/helper/helper.go b/common/helper/helper.go new file mode 100644 index 00000000..12c66d18 --- /dev/null +++ b/common/helper/helper.go @@ -0,0 +1,224 @@ +package helper + +import ( + "fmt" + "github.com/google/uuid" + "github.com/songquanpeng/one-api/common/logger" + "html/template" + "log" + "math/rand" + "net" + "os" + "os/exec" + "runtime" + "strconv" + "strings" + "time" +) + +func OpenBrowser(url string) { + var err error + + switch runtime.GOOS { + case "linux": + err = exec.Command("xdg-open", url).Start() + case "windows": + err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() + case "darwin": + err = exec.Command("open", url).Start() + } + if err != nil { + log.Println(err) + } +} + +func GetIp() (ip string) { + ips, err := net.InterfaceAddrs() + if err != nil { + log.Println(err) + return ip + } + + for _, a := range ips { + if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + ip = ipNet.IP.String() + if strings.HasPrefix(ip, "10") { + return + } + if strings.HasPrefix(ip, "172") { + return + } + if strings.HasPrefix(ip, "192.168") { + return + } + ip = "" + } + } + } + return +} + +var sizeKB = 1024 +var sizeMB = sizeKB * 1024 +var sizeGB = sizeMB * 1024 + +func Bytes2Size(num int64) string { + numStr := "" + unit := "B" + if num/int64(sizeGB) > 1 { + numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) + unit = "GB" + } else if num/int64(sizeMB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) + unit = "MB" + } else if num/int64(sizeKB) > 1 { + numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) + unit = "KB" + } else { + numStr = fmt.Sprintf("%d", num) + } + return numStr + " " + unit +} + +func Seconds2Time(num int) (time string) { + if num/31104000 > 0 { + time += strconv.Itoa(num/31104000) + " 年 " + num %= 31104000 + } + if num/2592000 > 0 { + time += strconv.Itoa(num/2592000) + " 个月 " + num %= 2592000 + } + if num/86400 > 0 { + time += strconv.Itoa(num/86400) + " 天 " + num %= 86400 + } + if num/3600 > 0 { + time += strconv.Itoa(num/3600) + " 小时 " + num %= 3600 + } + if num/60 > 0 { + time += strconv.Itoa(num/60) + " 分钟 " + num %= 60 + } + time += strconv.Itoa(num) + " 秒" + return +} + +func Interface2String(inter interface{}) string { + switch inter.(type) { + case string: + return inter.(string) + case int: + return fmt.Sprintf("%d", inter.(int)) + case float64: + return fmt.Sprintf("%f", inter.(float64)) + } + return "Not Implemented" +} + +func UnescapeHTML(x string) interface{} { + return template.HTML(x) +} + +func IntMax(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GetUUID() string { + code := uuid.New().String() + code = strings.Replace(code, "-", "", -1) + return code +} + +const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func GenerateKey() string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, 48) + for i := 0; i < 16; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + uuid_ := GetUUID() + for i := 0; i < 32; i++ { + c := uuid_[i] + if i%2 == 0 && c >= 'a' && c <= 'z' { + c = c - 'a' + 'A' + } + key[i+16] = c + } + return string(key) +} + +func GetRandomString(length int) string { + rand.Seed(time.Now().UnixNano()) + key := make([]byte, length) + for i := 0; i < length; i++ { + key[i] = keyChars[rand.Intn(len(keyChars))] + } + return string(key) +} + +func GetTimestamp() int64 { + return time.Now().Unix() +} + +func GetTimeString() string { + now := time.Now() + return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) +} + +func Max(a int, b int) int { + if a >= b { + return a + } else { + return b + } +} + +func GetOrDefaultEnvInt(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) + return defaultValue + } + return num +} + +func GetOrDefaultEnvString(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} + +func AssignOrDefault(value string, defaultValue string) string { + if len(value) != 0 { + return value + } + return defaultValue +} + +func MessageWithRequestId(message string, id string) string { + return fmt.Sprintf("%s (request id: %s)", message, id) +} + +func String2Int(str string) int { + num, err := strconv.Atoi(str) + if err != nil { + return 0 + } + return num +} diff --git a/common/image/image_test.go b/common/image/image_test.go index 8e47b109..15ed78bc 100644 --- a/common/image/image_test.go +++ b/common/image/image_test.go @@ -12,7 +12,7 @@ import ( "strings" "testing" - img "one-api/common/image" + img "github.com/songquanpeng/one-api/common/image" "github.com/stretchr/testify/assert" _ "golang.org/x/image/webp" diff --git a/common/init.go b/common/init.go index 26dd6086..2c1204d6 100644 --- a/common/init.go +++ b/common/init.go @@ -3,6 +3,8 @@ package common import ( "flag" "fmt" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" "log" "os" "path/filepath" @@ -37,9 +39,9 @@ func init() { if os.Getenv("SESSION_SECRET") != "" { if os.Getenv("SESSION_SECRET") == "random_string" { - SysError("SESSION_SECRET is set to an example value, please change it to a random string.") + logger.SysError("SESSION_SECRET is set to an example value, please change it to a random string.") } else { - SessionSecret = os.Getenv("SESSION_SECRET") + config.SessionSecret = os.Getenv("SESSION_SECRET") } } if os.Getenv("SQLITE_PATH") != "" { @@ -57,5 +59,6 @@ func init() { log.Fatal(err) } } + logger.LogDir = *LogDir } } diff --git a/common/logger/constants.go b/common/logger/constants.go new file mode 100644 index 00000000..78d32062 --- /dev/null +++ b/common/logger/constants.go @@ -0,0 +1,7 @@ +package logger + +const ( + RequestIdKey = "X-Oneapi-Request-Id" +) + +var LogDir string diff --git a/common/logger.go b/common/logger/logger.go similarity index 76% rename from common/logger.go rename to common/logger/logger.go index 61627217..b89dbdb7 100644 --- a/common/logger.go +++ b/common/logger/logger.go @@ -1,4 +1,4 @@ -package common +package logger import ( "context" @@ -25,7 +25,7 @@ var setupLogLock sync.Mutex var setupLogWorking bool func SetupLogger() { - if *LogDir != "" { + if LogDir != "" { ok := setupLogLock.TryLock() if !ok { log.Println("setup log is already working") @@ -35,7 +35,7 @@ func SetupLogger() { setupLogLock.Unlock() setupLogWorking = false }() - logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) + logPath := filepath.Join(LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102"))) fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal("failed to open log file") @@ -55,18 +55,30 @@ func SysError(s string) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[SYS] %v | %s \n", t.Format("2006/01/02 - 15:04:05"), s) } -func LogInfo(ctx context.Context, msg string) { +func Info(ctx context.Context, msg string) { logHelper(ctx, loggerINFO, msg) } -func LogWarn(ctx context.Context, msg string) { +func Warn(ctx context.Context, msg string) { logHelper(ctx, loggerWarn, msg) } -func LogError(ctx context.Context, msg string) { +func Error(ctx context.Context, msg string) { logHelper(ctx, loggerError, msg) } +func Infof(ctx context.Context, format string, a ...any) { + Info(ctx, fmt.Sprintf(format, a)) +} + +func Warnf(ctx context.Context, format string, a ...any) { + Warn(ctx, fmt.Sprintf(format, a)) +} + +func Errorf(ctx context.Context, format string, a ...any) { + Error(ctx, fmt.Sprintf(format, a)) +} + func logHelper(ctx context.Context, level string, msg string) { writer := gin.DefaultErrorWriter if level == loggerINFO { @@ -90,11 +102,3 @@ func FatalLog(v ...any) { _, _ = fmt.Fprintf(gin.DefaultErrorWriter, "[FATAL] %v | %v \n", t.Format("2006/01/02 - 15:04:05"), v) os.Exit(1) } - -func LogQuota(quota int) string { - if DisplayInCurrencyEnabled { - return fmt.Sprintf("$%.6f 额度", float64(quota)/QuotaPerUnit) - } else { - return fmt.Sprintf("%d 点额度", quota) - } -} diff --git a/common/model-ratio.go b/common/model-ratio.go index 97cb060d..08cde8c7 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -2,6 +2,7 @@ package common import ( "encoding/json" + "github.com/songquanpeng/one-api/common/logger" "strings" "time" ) @@ -44,6 +45,8 @@ var ModelRatio = map[string]float64{ "gpt-4-32k-0314": 30, "gpt-4-32k-0613": 30, "gpt-4-1106-preview": 5, // $0.01 / 1K tokens + "gpt-4-0125-preview": 5, // $0.01 / 1K tokens + "gpt-4-turbo-preview": 5, // $0.01 / 1K tokens "gpt-4-vision-preview": 5, // $0.01 / 1K tokens "gpt-3.5-turbo": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-0301": 0.75, @@ -52,6 +55,7 @@ var ModelRatio = map[string]float64{ "gpt-3.5-turbo-16k-0613": 1.5, "gpt-3.5-turbo-instruct": 0.75, // $0.0015 / 1K tokens "gpt-3.5-turbo-1106": 0.5, // $0.001 / 1K tokens + "gpt-3.5-turbo-0125": 0.25, // $0.0005 / 1K tokens "davinci-002": 1, // $0.002 / 1K tokens "babbage-002": 0.2, // $0.0004 / 1K tokens "text-ada-001": 0.2, @@ -71,6 +75,8 @@ var ModelRatio = map[string]float64{ "babbage": 10, "ada": 10, "text-embedding-ada-002": 0.05, + "text-embedding-3-small": 0.01, + "text-embedding-3-large": 0.065, "text-search-ada-doc-001": 10, "text-moderation-stable": 0.1, "text-moderation-latest": 0.1, @@ -107,7 +113,7 @@ var ModelRatio = map[string]float64{ func ModelRatio2JSONString() string { jsonBytes, err := json.Marshal(ModelRatio) if err != nil { - SysError("error marshalling model ratio: " + err.Error()) + logger.SysError("error marshalling model ratio: " + err.Error()) } return string(jsonBytes) } @@ -123,14 +129,37 @@ func GetModelRatio(name string) float64 { } ratio, ok := ModelRatio[name] if !ok { - SysError("model ratio not found: " + name) + logger.SysError("model ratio not found: " + name) return 30 } return ratio } +var CompletionRatio = map[string]float64{} + +func CompletionRatio2JSONString() string { + jsonBytes, err := json.Marshal(CompletionRatio) + if err != nil { + logger.SysError("error marshalling completion ratio: " + err.Error()) + } + return string(jsonBytes) +} + +func UpdateCompletionRatioByJSONString(jsonStr string) error { + CompletionRatio = make(map[string]float64) + return json.Unmarshal([]byte(jsonStr), &CompletionRatio) +} + func GetCompletionRatio(name string) float64 { + if ratio, ok := CompletionRatio[name]; ok { + return ratio + } if strings.HasPrefix(name, "gpt-3.5") { + if strings.HasSuffix(name, "0125") { + // https://openai.com/blog/new-embedding-models-and-api-updates + // Updated GPT-3.5 Turbo model and lower pricing + return 3 + } if strings.HasSuffix(name, "1106") { return 2 } diff --git a/common/redis.go b/common/redis.go index 12c477b8..f3205567 100644 --- a/common/redis.go +++ b/common/redis.go @@ -3,6 +3,7 @@ package common import ( "context" "github.com/go-redis/redis/v8" + "github.com/songquanpeng/one-api/common/logger" "os" "time" ) @@ -14,18 +15,18 @@ var RedisEnabled = true func InitRedisClient() (err error) { if os.Getenv("REDIS_CONN_STRING") == "" { RedisEnabled = false - SysLog("REDIS_CONN_STRING not set, Redis is not enabled") + logger.SysLog("REDIS_CONN_STRING not set, Redis is not enabled") return nil } if os.Getenv("SYNC_FREQUENCY") == "" { RedisEnabled = false - SysLog("SYNC_FREQUENCY not set, Redis is disabled") + logger.SysLog("SYNC_FREQUENCY not set, Redis is disabled") return nil } - SysLog("Redis is enabled") + logger.SysLog("Redis is enabled") opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) } RDB = redis.NewClient(opt) @@ -34,7 +35,7 @@ func InitRedisClient() (err error) { _, err = RDB.Ping(ctx).Result() if err != nil { - FatalLog("Redis ping test failed: " + err.Error()) + logger.FatalLog("Redis ping test failed: " + err.Error()) } return err } @@ -42,7 +43,7 @@ func InitRedisClient() (err error) { func ParseRedisOption() *redis.Options { opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING")) if err != nil { - FatalLog("failed to parse Redis connection string: " + err.Error()) + logger.FatalLog("failed to parse Redis connection string: " + err.Error()) } return opt } diff --git a/common/utils.go b/common/utils.go index 9a7038e2..24615225 100644 --- a/common/utils.go +++ b/common/utils.go @@ -2,215 +2,13 @@ package common import ( "fmt" - "github.com/google/uuid" - "html/template" - "log" - "math/rand" - "net" - "os" - "os/exec" - "runtime" - "strconv" - "strings" - "time" + "github.com/songquanpeng/one-api/common/config" ) -func OpenBrowser(url string) { - var err error - - switch runtime.GOOS { - case "linux": - err = exec.Command("xdg-open", url).Start() - case "windows": - err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() - case "darwin": - err = exec.Command("open", url).Start() - } - if err != nil { - log.Println(err) - } -} - -func GetIp() (ip string) { - ips, err := net.InterfaceAddrs() - if err != nil { - log.Println(err) - return ip - } - - for _, a := range ips { - if ipNet, ok := a.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { - if ipNet.IP.To4() != nil { - ip = ipNet.IP.String() - if strings.HasPrefix(ip, "10") { - return - } - if strings.HasPrefix(ip, "172") { - return - } - if strings.HasPrefix(ip, "192.168") { - return - } - ip = "" - } - } - } - return -} - -var sizeKB = 1024 -var sizeMB = sizeKB * 1024 -var sizeGB = sizeMB * 1024 - -func Bytes2Size(num int64) string { - numStr := "" - unit := "B" - if num/int64(sizeGB) > 1 { - numStr = fmt.Sprintf("%.2f", float64(num)/float64(sizeGB)) - unit = "GB" - } else if num/int64(sizeMB) > 1 { - numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeMB))) - unit = "MB" - } else if num/int64(sizeKB) > 1 { - numStr = fmt.Sprintf("%d", int(float64(num)/float64(sizeKB))) - unit = "KB" +func LogQuota(quota int) string { + if config.DisplayInCurrencyEnabled { + return fmt.Sprintf("$%.6f 额度", float64(quota)/config.QuotaPerUnit) } else { - numStr = fmt.Sprintf("%d", num) - } - return numStr + " " + unit -} - -func Seconds2Time(num int) (time string) { - if num/31104000 > 0 { - time += strconv.Itoa(num/31104000) + " 年 " - num %= 31104000 - } - if num/2592000 > 0 { - time += strconv.Itoa(num/2592000) + " 个月 " - num %= 2592000 - } - if num/86400 > 0 { - time += strconv.Itoa(num/86400) + " 天 " - num %= 86400 - } - if num/3600 > 0 { - time += strconv.Itoa(num/3600) + " 小时 " - num %= 3600 - } - if num/60 > 0 { - time += strconv.Itoa(num/60) + " 分钟 " - num %= 60 - } - time += strconv.Itoa(num) + " 秒" - return -} - -func Interface2String(inter interface{}) string { - switch inter.(type) { - case string: - return inter.(string) - case int: - return fmt.Sprintf("%d", inter.(int)) - case float64: - return fmt.Sprintf("%f", inter.(float64)) - } - return "Not Implemented" -} - -func UnescapeHTML(x string) interface{} { - return template.HTML(x) -} - -func IntMax(a int, b int) int { - if a >= b { - return a - } else { - return b + return fmt.Sprintf("%d 点额度", quota) } } - -func GetUUID() string { - code := uuid.New().String() - code = strings.Replace(code, "-", "", -1) - return code -} - -const keyChars = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func GenerateKey() string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, 48) - for i := 0; i < 16; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - uuid_ := GetUUID() - for i := 0; i < 32; i++ { - c := uuid_[i] - if i%2 == 0 && c >= 'a' && c <= 'z' { - c = c - 'a' + 'A' - } - key[i+16] = c - } - return string(key) -} - -func GetRandomString(length int) string { - rand.Seed(time.Now().UnixNano()) - key := make([]byte, length) - for i := 0; i < length; i++ { - key[i] = keyChars[rand.Intn(len(keyChars))] - } - return string(key) -} - -func GetTimestamp() int64 { - return time.Now().Unix() -} - -func GetTimeString() string { - now := time.Now() - return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9) -} - -func Max(a int, b int) int { - if a >= b { - return a - } else { - return b - } -} - -func GetOrDefault(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - -func MessageWithRequestId(message string, id string) string { - return fmt.Sprintf("%s (request id: %s)", message, id) -} - -func String2Int(str string) int { - num, err := strconv.Atoi(str) - if err != nil { - return 0 - } - return num -} diff --git a/controller/billing.go b/controller/billing.go index 42e86aea..c08f2a2d 100644 --- a/controller/billing.go +++ b/controller/billing.go @@ -2,8 +2,9 @@ package controller import ( "github.com/gin-gonic/gin" - "one-api/common" - "one-api/model" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channel/openai" ) func GetSubscription(c *gin.Context) { @@ -12,7 +13,7 @@ func GetSubscription(c *gin.Context) { var err error var token *model.Token var expiredTime int64 - if common.DisplayTokenStatEnabled { + if config.DisplayTokenStatEnabled { tokenId := c.GetInt("token_id") token, err = model.GetTokenById(tokenId) expiredTime = token.ExpiredTime @@ -27,19 +28,19 @@ func GetSubscription(c *gin.Context) { expiredTime = 0 } if err != nil { - openAIError := OpenAIError{ + Error := openai.Error{ Message: err.Error(), Type: "upstream_error", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) return } quota := remainQuota + usedQuota amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit } if token != nil && token.UnlimitedQuota { amount = 100000000 @@ -60,7 +61,7 @@ func GetUsage(c *gin.Context) { var quota int var err error var token *model.Token - if common.DisplayTokenStatEnabled { + if config.DisplayTokenStatEnabled { tokenId := c.GetInt("token_id") token, err = model.GetTokenById(tokenId) quota = token.UsedQuota @@ -69,18 +70,18 @@ func GetUsage(c *gin.Context) { quota, err = model.GetUserUsedQuota(userId) } if err != nil { - openAIError := OpenAIError{ + Error := openai.Error{ Message: err.Error(), Type: "one_api_error", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) return } amount := float64(quota) - if common.DisplayInCurrencyEnabled { - amount /= common.QuotaPerUnit + if config.DisplayInCurrencyEnabled { + amount /= config.QuotaPerUnit } usage := OpenAIUsageResponse{ Object: "list", diff --git a/controller/channel-billing.go b/controller/channel-billing.go index 6ddad7ea..abeab26a 100644 --- a/controller/channel-billing.go +++ b/controller/channel-billing.go @@ -4,10 +4,13 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/util" "io" "net/http" - "one-api/common" - "one-api/model" "strconv" "time" @@ -92,7 +95,7 @@ func GetResponseBody(method, url string, channel *model.Channel, headers http.He for k := range headers { req.Header.Add(k, headers.Get(k)) } - res, err := httpClient.Do(req) + res, err := util.HTTPClient.Do(req) if err != nil { return nil, err } @@ -313,7 +316,7 @@ func updateAllChannelsBalance() error { disableChannel(channel.Id, channel.Name, "余额不足") } } - time.Sleep(common.RequestInterval) + time.Sleep(config.RequestInterval) } return nil } @@ -338,8 +341,8 @@ func UpdateAllChannelsBalance(c *gin.Context) { func AutomaticallyUpdateChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("updating all channels") + logger.SysLog("updating all channels") _ = updateAllChannelsBalance() - common.SysLog("channels update done") + logger.SysLog("channels update done") } } diff --git a/controller/channel-test.go b/controller/channel-test.go index 3aaa4897..c8007031 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -5,10 +5,14 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/util" "io" "net/http" - "one-api/common" - "one-api/model" "strconv" "sync" "time" @@ -16,7 +20,7 @@ import ( "github.com/gin-gonic/gin" ) -func testChannel(channel *model.Channel, request ChatRequest) (err error, openaiErr *OpenAIError) { +func testChannel(channel *model.Channel, request openai.ChatRequest) (err error, openaiErr *openai.Error) { switch channel.Type { case common.ChannelTypePaLM: fallthrough @@ -46,13 +50,13 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai } requestURL := common.ChannelBaseURLs[channel.Type] if channel.Type == common.ChannelTypeAzure { - requestURL = getFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) + requestURL = util.GetFullRequestURL(channel.GetBaseURL(), fmt.Sprintf("/openai/deployments/%s/chat/completions?api-version=2023-03-15-preview", request.Model), channel.Type) } else { if baseURL := channel.GetBaseURL(); len(baseURL) > 0 { requestURL = baseURL } - requestURL = getFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) + requestURL = util.GetFullRequestURL(requestURL, "/v1/chat/completions", channel.Type) } jsonData, err := json.Marshal(request) if err != nil { @@ -68,12 +72,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai req.Header.Set("Authorization", "Bearer "+channel.Key) } req.Header.Set("Content-Type", "application/json") - resp, err := httpClient.Do(req) + resp, err := util.HTTPClient.Do(req) if err != nil { return err, nil } defer resp.Body.Close() - var response TextResponse + var response openai.SlimTextResponse body, err := io.ReadAll(resp.Body) if err != nil { return err, nil @@ -91,12 +95,12 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai return nil, nil } -func buildTestRequest() *ChatRequest { - testRequest := &ChatRequest{ +func buildTestRequest() *openai.ChatRequest { + testRequest := &openai.ChatRequest{ Model: "", // this will be set later MaxTokens: 1, } - testMessage := Message{ + testMessage := openai.Message{ Role: "user", Content: "hi", } @@ -148,12 +152,12 @@ var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false func notifyRootUser(subject string, content string) { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() } - err := common.SendEmail(subject, common.RootUserEmail, content) + err := common.SendEmail(subject, config.RootUserEmail, content) if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } @@ -174,8 +178,8 @@ func enableChannel(channelId int, channelName string) { } func testAllChannels(notify bool) error { - if common.RootUserEmail == "" { - common.RootUserEmail = model.GetRootUserEmail() + if config.RootUserEmail == "" { + config.RootUserEmail = model.GetRootUserEmail() } testAllChannelsLock.Lock() if testAllChannelsRunning { @@ -189,7 +193,7 @@ func testAllChannels(notify bool) error { return err } testRequest := buildTestRequest() - var disableThreshold = int64(common.ChannelDisableThreshold * 1000) + var disableThreshold = int64(config.ChannelDisableThreshold * 1000) if disableThreshold == 0 { disableThreshold = 10000000 // a impossible value } @@ -204,22 +208,22 @@ func testAllChannels(notify bool) error { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) disableChannel(channel.Id, channel.Name, err.Error()) } - if isChannelEnabled && shouldDisableChannel(openaiErr, -1) { + if isChannelEnabled && util.ShouldDisableChannel(openaiErr, -1) { disableChannel(channel.Id, channel.Name, err.Error()) } - if !isChannelEnabled && shouldEnableChannel(err, openaiErr) { + if !isChannelEnabled && util.ShouldEnableChannel(err, openaiErr) { enableChannel(channel.Id, channel.Name) } channel.UpdateResponseTime(milliseconds) - time.Sleep(common.RequestInterval) + time.Sleep(config.RequestInterval) } testAllChannelsLock.Lock() testAllChannelsRunning = false testAllChannelsLock.Unlock() if notify { - err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + err := common.SendEmail("通道测试完成", config.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") if err != nil { - common.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) + logger.SysError(fmt.Sprintf("failed to send email: %s", err.Error())) } } }() @@ -245,8 +249,8 @@ func TestAllChannels(c *gin.Context) { func AutomaticallyTestChannels(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Minute) - common.SysLog("testing all channels") + logger.SysLog("testing all channels") _ = testAllChannels(false) - common.SysLog("channel test finished") + logger.SysLog("channel test finished") } } diff --git a/controller/channel.go b/controller/channel.go index 904abc23..bdfa00d9 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -2,9 +2,10 @@ package controller import ( "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strconv" "strings" ) @@ -14,7 +15,7 @@ func GetAllChannels(c *gin.Context) { if p < 0 { p = 0 } - channels, err := model.GetAllChannels(p*common.ItemsPerPage, common.ItemsPerPage, false) + channels, err := model.GetAllChannels(p*config.ItemsPerPage, config.ItemsPerPage, false) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -83,7 +84,7 @@ func AddChannel(c *gin.Context) { }) return } - channel.CreatedTime = common.GetTimestamp() + channel.CreatedTime = helper.GetTimestamp() keys := strings.Split(channel.Key, "\n") channels := make([]model.Channel, 0, len(keys)) for _, key := range keys { diff --git a/controller/github.go b/controller/github.go index ee995379..7d7fa106 100644 --- a/controller/github.go +++ b/controller/github.go @@ -7,9 +7,12 @@ import ( "fmt" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strconv" "time" ) @@ -30,7 +33,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { if code == "" { return nil, errors.New("无效的参数") } - values := map[string]string{"client_id": common.GitHubClientId, "client_secret": common.GitHubClientSecret, "code": code} + values := map[string]string{"client_id": config.GitHubClientId, "client_secret": config.GitHubClientSecret, "code": code} jsonData, err := json.Marshal(values) if err != nil { return nil, err @@ -46,7 +49,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { } res, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res.Body.Close() @@ -62,7 +65,7 @@ func getGitHubUserInfoByCode(code string) (*GitHubUser, error) { req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", oAuthResponse.AccessToken)) res2, err := client.Do(req) if err != nil { - common.SysLog(err.Error()) + logger.SysLog(err.Error()) return nil, errors.New("无法连接至 GitHub 服务器,请稍后重试!") } defer res2.Body.Close() @@ -93,7 +96,7 @@ func GitHubOAuth(c *gin.Context) { return } - if !common.GitHubOAuthEnabled { + if !config.GitHubOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 GitHub 登录以及注册", @@ -122,7 +125,7 @@ func GitHubOAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "github_" + strconv.Itoa(model.GetMaxUserId()+1) if githubUser.Name != "" { user.DisplayName = githubUser.Name @@ -160,7 +163,7 @@ func GitHubOAuth(c *gin.Context) { } func GitHubBind(c *gin.Context) { - if !common.GitHubOAuthEnabled { + if !config.GitHubOAuthEnabled { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "管理员未开启通过 GitHub 登录以及注册", @@ -216,7 +219,7 @@ func GitHubBind(c *gin.Context) { func GenerateOAuthCode(c *gin.Context) { session := sessions.Default(c) - state := common.GetRandomString(12) + state := helper.GetRandomString(12) session.Set("oauth_state", state) err := session.Save() if err != nil { diff --git a/controller/group.go b/controller/group.go index d959bd37..128a3527 100644 --- a/controller/group.go +++ b/controller/group.go @@ -2,8 +2,8 @@ package controller import ( "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" "net/http" - "one-api/common" ) func GetGroups(c *gin.Context) { diff --git a/controller/log.go b/controller/log.go index b65867fe..9377b338 100644 --- a/controller/log.go +++ b/controller/log.go @@ -2,9 +2,9 @@ package controller import ( "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strconv" ) @@ -20,7 +20,7 @@ func GetAllLogs(c *gin.Context) { tokenName := c.Query("token_name") modelName := c.Query("model_name") channel, _ := strconv.Atoi(c.Query("channel")) - logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*common.ItemsPerPage, common.ItemsPerPage, channel) + logs, err := model.GetAllLogs(logType, startTimestamp, endTimestamp, modelName, username, tokenName, p*config.ItemsPerPage, config.ItemsPerPage, channel) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -47,7 +47,7 @@ func GetUserLogs(c *gin.Context) { endTimestamp, _ := strconv.ParseInt(c.Query("end_timestamp"), 10, 64) tokenName := c.Query("token_name") modelName := c.Query("model_name") - logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*common.ItemsPerPage, common.ItemsPerPage) + logs, err := model.GetUserLogs(userId, logType, startTimestamp, endTimestamp, modelName, tokenName, p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, diff --git a/controller/misc.go b/controller/misc.go index 2bcbb41f..036bdbd1 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -3,9 +3,10 @@ package controller import ( "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strings" "github.com/gin-gonic/gin" @@ -18,55 +19,55 @@ func GetStatus(c *gin.Context) { "data": gin.H{ "version": common.Version, "start_time": common.StartTime, - "email_verification": common.EmailVerificationEnabled, - "github_oauth": common.GitHubOAuthEnabled, - "github_client_id": common.GitHubClientId, - "system_name": common.SystemName, - "logo": common.Logo, - "footer_html": common.Footer, - "wechat_qrcode": common.WeChatAccountQRCodeImageURL, - "wechat_login": common.WeChatAuthEnabled, - "server_address": common.ServerAddress, - "turnstile_check": common.TurnstileCheckEnabled, - "turnstile_site_key": common.TurnstileSiteKey, - "top_up_link": common.TopUpLink, - "chat_link": common.ChatLink, - "quota_per_unit": common.QuotaPerUnit, - "display_in_currency": common.DisplayInCurrencyEnabled, + "email_verification": config.EmailVerificationEnabled, + "github_oauth": config.GitHubOAuthEnabled, + "github_client_id": config.GitHubClientId, + "system_name": config.SystemName, + "logo": config.Logo, + "footer_html": config.Footer, + "wechat_qrcode": config.WeChatAccountQRCodeImageURL, + "wechat_login": config.WeChatAuthEnabled, + "server_address": config.ServerAddress, + "turnstile_check": config.TurnstileCheckEnabled, + "turnstile_site_key": config.TurnstileSiteKey, + "top_up_link": config.TopUpLink, + "chat_link": config.ChatLink, + "quota_per_unit": config.QuotaPerUnit, + "display_in_currency": config.DisplayInCurrencyEnabled, }, }) return } func GetNotice(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["Notice"], + "data": config.OptionMap["Notice"], }) return } func GetAbout(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["About"], + "data": config.OptionMap["About"], }) return } func GetHomePageContent(c *gin.Context) { - common.OptionMapRWMutex.RLock() - defer common.OptionMapRWMutex.RUnlock() + config.OptionMapRWMutex.RLock() + defer config.OptionMapRWMutex.RUnlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": common.OptionMap["HomePageContent"], + "data": config.OptionMap["HomePageContent"], }) return } @@ -80,9 +81,9 @@ func SendEmailVerification(c *gin.Context) { }) return } - if common.EmailDomainRestrictionEnabled { + if config.EmailDomainRestrictionEnabled { allowed := false - for _, domain := range common.EmailDomainWhitelist { + for _, domain := range config.EmailDomainWhitelist { if strings.HasSuffix(email, "@"+domain) { allowed = true break @@ -105,10 +106,10 @@ func SendEmailVerification(c *gin.Context) { } code := common.GenerateVerificationCode(6) common.RegisterVerificationCodeWithKey(email, code, common.EmailVerificationPurpose) - subject := fmt.Sprintf("%s邮箱验证邮件", common.SystemName) + subject := fmt.Sprintf("%s邮箱验证邮件", config.SystemName) content := fmt.Sprintf("
您好,你正在进行%s邮箱验证。
"+ "您的验证码为: %s
"+ - "验证码 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, code, common.VerificationValidMinutes) + "验证码 %d 分钟内有效,如果不是本人操作,请忽略。
", config.SystemName, code, common.VerificationValidMinutes) err := common.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ @@ -142,12 +143,12 @@ func SendPasswordResetEmail(c *gin.Context) { } code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) - link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", common.ServerAddress, email, code) - subject := fmt.Sprintf("%s密码重置", common.SystemName) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", config.ServerAddress, email, code) + subject := fmt.Sprintf("%s密码重置", config.SystemName) content := fmt.Sprintf("您好,你正在进行%s密码重置。
"+ "点击 此处 进行密码重置。
"+ "如果链接无法点击,请尝试点击下面的链接或将其复制到浏览器中打开:
%s
重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
", common.SystemName, link, link, common.VerificationValidMinutes) + "重置链接 %d 分钟内有效,如果不是本人操作,请忽略。
", config.SystemName, link, link, common.VerificationValidMinutes) err := common.SendEmail(subject, email, content) if err != nil { c.JSON(http.StatusOK, gin.H{ diff --git a/controller/model.go b/controller/model.go index c12ccf34..e3e83fcd 100644 --- a/controller/model.go +++ b/controller/model.go @@ -2,8 +2,8 @@ package controller import ( "fmt" - "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/channel/openai" ) // https://platform.openai.com/docs/api-reference/models/list @@ -171,6 +171,15 @@ func init() { Root: "gpt-3.5-turbo-1106", Parent: nil, }, + { + Id: "gpt-3.5-turbo-0125", + Object: "model", + Created: 1706232090, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-3.5-turbo-0125", + Parent: nil, + }, { Id: "gpt-3.5-turbo-instruct", Object: "model", @@ -243,6 +252,24 @@ func init() { Root: "gpt-4-1106-preview", Parent: nil, }, + { + Id: "gpt-4-0125-preview", + Object: "model", + Created: 1706232090, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-0125-preview", + Parent: nil, + }, + { + Id: "gpt-4-turbo-preview", + Object: "model", + Created: 1706232090, + OwnedBy: "openai", + Permission: permission, + Root: "gpt-4-turbo-preview", + Parent: nil, + }, { Id: "gpt-4-vision-preview", Object: "model", @@ -261,6 +288,24 @@ func init() { Root: "text-embedding-ada-002", Parent: nil, }, + { + Id: "text-embedding-3-small", + Object: "model", + Created: 1706232090, + OwnedBy: "openai", + Permission: permission, + Root: "text-embedding-3-small", + Parent: nil, + }, + { + Id: "text-embedding-3-large", + Object: "model", + Created: 1706232090, + OwnedBy: "openai", + Permission: permission, + Root: "text-embedding-3-large", + Parent: nil, + }, { Id: "text-davinci-003", Object: "model", @@ -613,14 +658,14 @@ func RetrieveModel(c *gin.Context) { if model, ok := openAIModelsMap[modelId]; ok { c.JSON(200, model) } else { - openAIError := OpenAIError{ + Error := openai.Error{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), Type: "invalid_request_error", Param: "model", Code: "model_not_found", } c.JSON(200, gin.H{ - "error": openAIError, + "error": Error, }) } } diff --git a/controller/option.go b/controller/option.go index 3b1cbad2..f86e3a64 100644 --- a/controller/option.go +++ b/controller/option.go @@ -2,9 +2,10 @@ package controller import ( "encoding/json" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strings" "github.com/gin-gonic/gin" @@ -12,17 +13,17 @@ import ( func GetOptions(c *gin.Context) { var options []*model.Option - common.OptionMapRWMutex.Lock() - for k, v := range common.OptionMap { + config.OptionMapRWMutex.Lock() + for k, v := range config.OptionMap { if strings.HasSuffix(k, "Token") || strings.HasSuffix(k, "Secret") { continue } options = append(options, &model.Option{ Key: k, - Value: common.Interface2String(v), + Value: helper.Interface2String(v), }) } - common.OptionMapRWMutex.Unlock() + config.OptionMapRWMutex.Unlock() c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", @@ -43,7 +44,7 @@ func UpdateOption(c *gin.Context) { } switch option.Key { case "Theme": - if !common.ValidThemes[option.Value] { + if !config.ValidThemes[option.Value] { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无效的主题", @@ -51,7 +52,7 @@ func UpdateOption(c *gin.Context) { return } case "GitHubOAuthEnabled": - if option.Value == "true" && common.GitHubClientId == "" { + if option.Value == "true" && config.GitHubClientId == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用 GitHub OAuth,请先填入 GitHub Client Id 以及 GitHub Client Secret!", @@ -59,7 +60,7 @@ func UpdateOption(c *gin.Context) { return } case "EmailDomainRestrictionEnabled": - if option.Value == "true" && len(common.EmailDomainWhitelist) == 0 { + if option.Value == "true" && len(config.EmailDomainWhitelist) == 0 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用邮箱域名限制,请先填入限制的邮箱域名!", @@ -67,7 +68,7 @@ func UpdateOption(c *gin.Context) { return } case "WeChatAuthEnabled": - if option.Value == "true" && common.WeChatServerAddress == "" { + if option.Value == "true" && config.WeChatServerAddress == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用微信登录,请先填入微信登录相关配置信息!", @@ -75,7 +76,7 @@ func UpdateOption(c *gin.Context) { return } case "TurnstileCheckEnabled": - if option.Value == "true" && common.TurnstileSiteKey == "" { + if option.Value == "true" && config.TurnstileSiteKey == "" { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "无法启用 Turnstile 校验,请先填入 Turnstile 校验相关配置信息!", diff --git a/controller/redemption.go b/controller/redemption.go index 0f656be0..31c9348d 100644 --- a/controller/redemption.go +++ b/controller/redemption.go @@ -2,9 +2,10 @@ package controller import ( "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strconv" ) @@ -13,7 +14,7 @@ func GetAllRedemptions(c *gin.Context) { if p < 0 { p = 0 } - redemptions, err := model.GetAllRedemptions(p*common.ItemsPerPage, common.ItemsPerPage) + redemptions, err := model.GetAllRedemptions(p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -105,12 +106,12 @@ func AddRedemption(c *gin.Context) { } var keys []string for i := 0; i < redemption.Count; i++ { - key := common.GetUUID() + key := helper.GetUUID() cleanRedemption := model.Redemption{ UserId: c.GetInt("id"), Name: redemption.Name, Key: key, - CreatedTime: common.GetTimestamp(), + CreatedTime: helper.GetTimestamp(), Quota: redemption.Quota, } err = cleanRedemption.Insert() diff --git a/controller/relay-tencent.go b/controller/relay-tencent.go deleted file mode 100644 index 27540806..00000000 --- a/controller/relay-tencent.go +++ /dev/null @@ -1,287 +0,0 @@ -package controller - -// import ( -// "bufio" -// "crypto/hmac" -// "crypto/sha1" -// "encoding/base64" -// "encoding/json" -// "errors" -// "fmt" -// "github.com/gin-gonic/gin" -// "io" -// "net/http" -// "one-api/common" -// "sort" -// "strconv" -// "strings" -// ) - -// // https://cloud.tencent.com/document/product/1729/97732 - -// type TencentMessage struct { -// Role string `json:"role"` -// Content string `json:"content"` -// } - -// type TencentChatRequest struct { -// AppId int64 `json:"app_id"` // 腾讯云账号的 APPID -// SecretId string `json:"secret_id"` // 官网 SecretId -// // Timestamp当前 UNIX 时间戳,单位为秒,可记录发起 API 请求的时间。 -// // 例如1529223702,如果与当前时间相差过大,会引起签名过期错误 -// Timestamp int64 `json:"timestamp"` -// // Expired 签名的有效期,是一个符合 UNIX Epoch 时间戳规范的数值, -// // 单位为秒;Expired 必须大于 Timestamp 且 Expired-Timestamp 小于90天 -// Expired int64 `json:"expired"` -// QueryID string `json:"query_id"` //请求 Id,用于问题排查 -// // Temperature 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定 -// // 默认 1.0,取值区间为[0.0,2.0],非必要不建议使用,不合理的取值会影响效果 -// // 建议该参数和 top_p 只设置1个,不要同时更改 top_p -// Temperature float64 `json:"temperature"` -// // TopP 影响输出文本的多样性,取值越大,生成文本的多样性越强 -// // 默认1.0,取值区间为[0.0, 1.0],非必要不建议使用, 不合理的取值会影响效果 -// // 建议该参数和 temperature 只设置1个,不要同时更改 -// TopP float64 `json:"top_p"` -// // Stream 0:同步,1:流式 (默认,协议:SSE) -// // 同步请求超时:60s,如果内容较长建议使用流式 -// Stream int `json:"stream"` -// // Messages 会话内容, 长度最多为40, 按对话时间从旧到新在数组中排列 -// // 输入 content 总数最大支持 3000 token。 -// Messages []TencentMessage `json:"messages"` -// } - -// type TencentError struct { -// Code int `json:"code"` -// Message string `json:"message"` -// } - -// type TencentUsage struct { -// InputTokens int `json:"input_tokens"` -// OutputTokens int `json:"output_tokens"` -// TotalTokens int `json:"total_tokens"` -// } - -// type TencentResponseChoices struct { -// FinishReason string `json:"finish_reason,omitempty"` // 流式结束标志位,为 stop 则表示尾包 -// Messages TencentMessage `json:"messages,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。 -// Delta TencentMessage `json:"delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。 -// } - -// type TencentChatResponse struct { -// Choices []TencentResponseChoices `json:"choices,omitempty"` // 结果 -// Created string `json:"created,omitempty"` // unix 时间戳的字符串 -// Id string `json:"id,omitempty"` // 会话 id -// Usage Usage `json:"usage,omitempty"` // token 数量 -// Error TencentError `json:"error,omitempty"` // 错误信息 注意:此字段可能返回 null,表示取不到有效值 -// Note string `json:"note,omitempty"` // 注释 -// ReqID string `json:"req_id,omitempty"` // 唯一请求 Id,每次请求都会返回。用于反馈接口入参 -// } - -// func requestOpenAI2Tencent(request GeneralOpenAIRequest) *TencentChatRequest { -// messages := make([]TencentMessage, 0, len(request.Messages)) -// for i := 0; i < len(request.Messages); i++ { -// message := request.Messages[i] -// if message.Role == "system" { -// messages = append(messages, TencentMessage{ -// Role: "user", -// Content: message.Content, -// }) -// messages = append(messages, TencentMessage{ -// Role: "assistant", -// Content: "Okay", -// }) -// continue -// } -// messages = append(messages, TencentMessage{ -// Content: message.Content, -// Role: message.Role, -// }) -// } -// stream := 0 -// if request.Stream { -// stream = 1 -// } -// return &TencentChatRequest{ -// Timestamp: common.GetTimestamp(), -// Expired: common.GetTimestamp() + 24*60*60, -// QueryID: common.GetUUID(), -// Temperature: request.Temperature, -// TopP: request.TopP, -// Stream: stream, -// Messages: messages, -// } -// } - -// func responseTencent2OpenAI(response *TencentChatResponse) *OpenAITextResponse { -// fullTextResponse := OpenAITextResponse{ -// Object: "chat.completion", -// Created: common.GetTimestamp(), -// Usage: response.Usage, -// } -// if len(response.Choices) > 0 { -// choice := OpenAITextResponseChoice{ -// Index: 0, -// Message: Message{ -// Role: "assistant", -// Content: response.Choices[0].Messages.Content, -// }, -// FinishReason: response.Choices[0].FinishReason, -// } -// fullTextResponse.Choices = append(fullTextResponse.Choices, choice) -// } -// return &fullTextResponse -// } - -// func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *ChatCompletionsStreamResponse { -// response := ChatCompletionsStreamResponse{ -// Object: "chat.completion.chunk", -// Created: common.GetTimestamp(), -// Model: "tencent-hunyuan", -// } -// if len(TencentResponse.Choices) > 0 { -// var choice ChatCompletionsStreamResponseChoice -// choice.Delta.Content = TencentResponse.Choices[0].Delta.Content -// if TencentResponse.Choices[0].FinishReason == "stop" { -// choice.FinishReason = &stopFinishReason -// } -// response.Choices = append(response.Choices, choice) -// } -// return &response -// } - -// func tencentStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) { -// var responseText string -// scanner := bufio.NewScanner(resp.Body) -// scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { -// if atEOF && len(data) == 0 { -// return 0, nil, nil -// } -// if i := strings.Index(string(data), "\n"); i >= 0 { -// return i + 1, data[0:i], nil -// } -// if atEOF { -// return len(data), data, nil -// } -// return 0, nil, nil -// }) -// dataChan := make(chan string) -// stopChan := make(chan bool) -// go func() { -// for scanner.Scan() { -// data := scanner.Text() -// if len(data) < 5 { // ignore blank line or wrong format -// continue -// } -// if data[:5] != "data:" { -// continue -// } -// data = data[5:] -// dataChan <- data -// } -// stopChan <- true -// }() -// setEventStreamHeaders(c) -// c.Stream(func(w io.Writer) bool { -// select { -// case data := <-dataChan: -// var TencentResponse TencentChatResponse -// err := json.Unmarshal([]byte(data), &TencentResponse) -// if err != nil { -// common.SysError("error unmarshalling stream response: " + err.Error()) -// return true -// } -// response := streamResponseTencent2OpenAI(&TencentResponse) -// if len(response.Choices) != 0 { -// responseText += response.Choices[0].Delta.Content -// } -// jsonResponse, err := json.Marshal(response) -// if err != nil { -// common.SysError("error marshalling stream response: " + err.Error()) -// return true -// } -// c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)}) -// return true -// case <-stopChan: -// c.Render(-1, common.CustomEvent{Data: "data: [DONE]"}) -// return false -// } -// }) -// err := resp.Body.Close() -// if err != nil { -// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), "" -// } -// return nil, responseText -// } - -// func tencentHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, *Usage) { -// var TencentResponse TencentChatResponse -// responseBody, err := io.ReadAll(resp.Body) -// if err != nil { -// return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil -// } -// err = resp.Body.Close() -// if err != nil { -// return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil -// } -// err = json.Unmarshal(responseBody, &TencentResponse) -// if err != nil { -// return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil -// } -// if TencentResponse.Error.Code != 0 { -// return &OpenAIErrorWithStatusCode{ -// OpenAIError: OpenAIError{ -// Message: TencentResponse.Error.Message, -// Code: TencentResponse.Error.Code, -// }, -// StatusCode: resp.StatusCode, -// }, nil -// } -// fullTextResponse := responseTencent2OpenAI(&TencentResponse) -// jsonResponse, err := json.Marshal(fullTextResponse) -// if err != nil { -// return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil -// } -// c.Writer.Header().Set("Content-Type", "application/json") -// c.Writer.WriteHeader(resp.StatusCode) -// _, err = c.Writer.Write(jsonResponse) -// return nil, &fullTextResponse.Usage -// } - -// func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) { -// parts := strings.Split(config, "|") -// if len(parts) != 3 { -// err = errors.New("invalid tencent config") -// return -// } -// appId, err = strconv.ParseInt(parts[0], 10, 64) -// secretId = parts[1] -// secretKey = parts[2] -// return -// } - -// func getTencentSign(req TencentChatRequest, secretKey string) string { -// params := make([]string, 0) -// params = append(params, "app_id="+strconv.FormatInt(req.AppId, 10)) -// params = append(params, "secret_id="+req.SecretId) -// params = append(params, "timestamp="+strconv.FormatInt(req.Timestamp, 10)) -// params = append(params, "query_id="+req.QueryID) -// params = append(params, "temperature="+strconv.FormatFloat(req.Temperature, 'f', -1, 64)) -// params = append(params, "top_p="+strconv.FormatFloat(req.TopP, 'f', -1, 64)) -// params = append(params, "stream="+strconv.Itoa(req.Stream)) -// params = append(params, "expired="+strconv.FormatInt(req.Expired, 10)) - -// var messageStr string -// for _, msg := range req.Messages { -// messageStr += fmt.Sprintf(`{"role":"%s","content":"%s"},`, msg.Role, msg.Content) -// } -// messageStr = strings.TrimSuffix(messageStr, ",") -// params = append(params, "messages=["+messageStr+"]") - -// sort.Sort(sort.StringSlice(params)) -// url := "hunyuan.cloud.tencent.com/hyllm/v1/chat/completions?" + strings.Join(params, "&") -// mac := hmac.New(sha1.New, []byte(secretKey)) -// signURL := url -// mac.Write([]byte(signURL)) -// sign := mac.Sum([]byte(nil)) -// return base64.StdEncoding.EncodeToString(sign) -// } diff --git a/controller/relay-text.go b/controller/relay-text.go deleted file mode 100644 index c8d4623f..00000000 --- a/controller/relay-text.go +++ /dev/null @@ -1,759 +0,0 @@ -package controller - -import ( - "bytes" - "context" - "encoding/json" - "errors" - "fmt" - "io" - "math" - "net/http" - "one-api/common" - "one-api/model" - "os" - "strings" - "time" - - "github.com/gin-gonic/gin" -) - -const ( - APITypeOpenAI = iota - APITypeClaude - APITypePaLM - APITypeBaidu - APITypeZhipu - APITypeAli - APITypeXunfei - APITypeAIProxyLibrary - APITypeTencent - APITypeGemini -) - -var httpClient *http.Client -var impatientHTTPClient *http.Client - -func init() { - if common.RelayTimeout == 0 { - httpClient = &http.Client{} - } else { - httpClient = &http.Client{ - Timeout: time.Duration(common.RelayTimeout) * time.Second, - } - } - - impatientHTTPClient = &http.Client{ - Timeout: 5 * time.Second, - } -} - -func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode { - channelType := c.GetInt("channel") - channelId := c.GetInt("channel_id") - tokenId := c.GetInt("token_id") - userId := c.GetInt("id") - group := c.GetString("group") - var textRequest GeneralOpenAIRequest - err := common.UnmarshalBodyReusable(c, &textRequest) - if err != nil { - return errorWrapper(err, "bind_request_body_failed", http.StatusBadRequest) - } - if textRequest.MaxTokens < 0 || textRequest.MaxTokens > math.MaxInt32/2 { - return errorWrapper(errors.New("max_tokens is invalid"), "invalid_max_tokens", http.StatusBadRequest) - } - if relayMode == RelayModeModerations && textRequest.Model == "" { - textRequest.Model = "text-moderation-latest" - } - if relayMode == RelayModeEmbeddings && textRequest.Model == "" { - textRequest.Model = c.Param("model") - } - // request validation - if textRequest.Model == "" { - return errorWrapper(errors.New("model is required"), "required_field_missing", http.StatusBadRequest) - } - switch relayMode { - case RelayModeCompletions: - if textRequest.Prompt == "" { - return errorWrapper(errors.New("field prompt is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeChatCompletions: - if textRequest.Messages == nil || len(textRequest.Messages) == 0 { - return errorWrapper(errors.New("field messages is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeEmbeddings: - case RelayModeModerations: - if textRequest.Input == "" { - return errorWrapper(errors.New("field input is required"), "required_field_missing", http.StatusBadRequest) - } - case RelayModeEdits: - if textRequest.Instruction == "" { - return errorWrapper(errors.New("field instruction is required"), "required_field_missing", http.StatusBadRequest) - } - } - // map model name - modelMapping := c.GetString("model_mapping") - isModelMapped := false - if modelMapping != "" && modelMapping != "{}" { - modelMap := make(map[string]string) - err := json.Unmarshal([]byte(modelMapping), &modelMap) - if err != nil { - return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError) - } - if modelMap[textRequest.Model] != "" { - textRequest.Model = modelMap[textRequest.Model] - isModelMapped = true - } - } - apiType := APITypeOpenAI - switch channelType { - case common.ChannelTypeAnthropic: - apiType = APITypeClaude - case common.ChannelTypeBaidu: - apiType = APITypeBaidu - case common.ChannelTypePaLM: - apiType = APITypePaLM - case common.ChannelTypeZhipu: - apiType = APITypeZhipu - case common.ChannelTypeAli: - apiType = APITypeAli - case common.ChannelTypeXunfei: - apiType = APITypeXunfei - case common.ChannelTypeAIProxyLibrary: - apiType = APITypeAIProxyLibrary - case common.ChannelTypeTencent: - apiType = APITypeTencent - case common.ChannelTypeGemini: - apiType = APITypeGemini - } - baseURL := common.ChannelBaseURLs[channelType] - requestURL := c.Request.URL.String() - if c.GetString("base_url") != "" { - baseURL = c.GetString("base_url") - } - fullRequestURL := getFullRequestURL(baseURL, requestURL, channelType) - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - // https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api - apiVersion := GetAPIVersion(c) - requestURL := strings.Split(requestURL, "?")[0] - requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion) - baseURL = c.GetString("base_url") - task := strings.TrimPrefix(requestURL, "/v1/") - model_ := textRequest.Model - model_ = strings.Replace(model_, ".", "", -1) - // https://github.com/songquanpeng/one-api/issues/67 - model_ = strings.TrimSuffix(model_, "-0301") - model_ = strings.TrimSuffix(model_, "-0314") - model_ = strings.TrimSuffix(model_, "-0613") - - requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task) - fullRequestURL = getFullRequestURL(baseURL, requestURL, channelType) - } - case APITypeClaude: - fullRequestURL = "https://api.anthropic.com/v1/complete" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1/complete", baseURL) - } - // case APITypeBaidu: - // switch textRequest.Model { - // case "ERNIE-Bot": - // fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions" - // case "ERNIE-Bot-turbo": - // fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant" - // case "ERNIE-Bot-4": - // fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro" - // case "BLOOMZ-7B": - // fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/bloomz_7b1" - // case "Embedding-V1": - // fullRequestURL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/embedding-v1" - // } - // apiKey := c.Request.Header.Get("Authorization") - // apiKey = strings.TrimPrefix(apiKey, "Bearer ") - // var err error - // if apiKey, err = getBaiduAccessToken(apiKey); err != nil { - // return errorWrapper(err, "invalid_baidu_config", http.StatusInternalServerError) - // } - // fullRequestURL += "?access_token=" + apiKey - case APITypePaLM: - fullRequestURL = "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" - if baseURL != "" { - fullRequestURL = fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", baseURL) - } - case APITypeGemini: - requestBaseURL := "https://generativelanguage.googleapis.com" - if baseURL != "" { - requestBaseURL = baseURL - } - version := "v1" - if c.GetString("api_version") != "" { - version = c.GetString("api_version") - } - action := "generateContent" - if textRequest.Stream { - action = "streamGenerateContent" - } - fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action) - case APITypeZhipu: - method := "invoke" - if textRequest.Stream { - method = "sse-invoke" - } - fullRequestURL = fmt.Sprintf("https://open.bigmodel.cn/api/paas/v3/model-api/%s/%s", textRequest.Model, method) - case APITypeAli: - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" - if relayMode == RelayModeEmbeddings { - fullRequestURL = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding" - } - case APITypeTencent: - fullRequestURL = "https://hunyuan.cloud.tencent.com/hyllm/v1/chat/completions" - case APITypeAIProxyLibrary: - fullRequestURL = fmt.Sprintf("%s/api/library/ask", baseURL) - } - var promptTokens int - var completionTokens int - switch relayMode { - case RelayModeChatCompletions: - promptTokens = countTokenMessages(textRequest.Messages, textRequest.Model) - case RelayModeCompletions: - promptTokens = countTokenInput(textRequest.Prompt, textRequest.Model) - case RelayModeModerations: - promptTokens = countTokenInput(textRequest.Input, textRequest.Model) - } - preConsumedTokens := common.PreConsumedQuota - if textRequest.MaxTokens != 0 { - preConsumedTokens = promptTokens + textRequest.MaxTokens - } - modelRatio := common.GetModelRatio(textRequest.Model) - groupRatio := common.GetGroupRatio(group) - ratio := modelRatio * groupRatio - preConsumedQuota := int(float64(preConsumedTokens) * ratio) - userQuota, err := model.CacheGetUserQuota(userId) - if err != nil { - return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) - } - if userQuota-preConsumedQuota < 0 { - return errorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) - } - err = model.CacheDecreaseUserQuota(userId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "decrease_user_quota_failed", http.StatusInternalServerError) - } - if userQuota > 100*preConsumedQuota { - // in this case, we do not pre-consume quota - // because the user has enough quota - preConsumedQuota = 0 - common.LogInfo(c.Request.Context(), fmt.Sprintf("user %d has enough quota %d, trusted and no need to pre-consume", userId, userQuota)) - } - if preConsumedQuota > 0 { - err := model.PreConsumeTokenQuota(tokenId, preConsumedQuota) - if err != nil { - return errorWrapper(err, "pre_consume_token_quota_failed", http.StatusForbidden) - } - } - var requestBody io.Reader - if isModelMapped { - jsonStr, err := json.Marshal(textRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } else { - requestBody = c.Request.Body - } - - common.LogInfo(c.Request.Context(), fmt.Sprintf( - "convert to apitype %d, channel_type %d, channel_id %d", - apiType, channelType, channelId)) - switch apiType { - case APITypeClaude: - claudeRequest := requestOpenAI2Claude(textRequest) - jsonStr, err := json.Marshal(claudeRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - // case APITypeBaidu: - // var jsonData []byte - // var err error - // switch relayMode { - // case RelayModeEmbeddings: - // baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(textRequest) - // jsonData, err = json.Marshal(baiduEmbeddingRequest) - // default: - // baiduRequest := requestOpenAI2Baidu(textRequest) - // jsonData, err = json.Marshal(baiduRequest) - // } - // if err != nil { - // return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - // } - // requestBody = bytes.NewBuffer(jsonData) - case APITypePaLM: - palmRequest := requestOpenAI2PaLM(textRequest) - jsonStr, err := json.Marshal(palmRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - case APITypeGemini: - geminiChatRequest := requestOpenAI2Gemini(textRequest) - jsonStr, err := json.Marshal(geminiChatRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - fmt.Println(">> convert request body to gemini: " + string(jsonStr)) // FIXME - // case APITypeZhipu: - // zhipuRequest := requestOpenAI2Zhipu(textRequest) - // jsonStr, err := json.Marshal(zhipuRequest) - // if err != nil { - // return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - // } - // requestBody = bytes.NewBuffer(jsonStr) - // case APITypeAli: - // var jsonStr []byte - // var err error - // switch relayMode { - // case RelayModeEmbeddings: - // aliEmbeddingRequest := embeddingRequestOpenAI2Ali(textRequest) - // jsonStr, err = json.Marshal(aliEmbeddingRequest) - // default: - // aliRequest := requestOpenAI2Ali(textRequest) - // jsonStr, err = json.Marshal(aliRequest) - // } - // if err != nil { - // return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - // } - // requestBody = bytes.NewBuffer(jsonStr) - // case APITypeTencent: - // apiKey := c.Request.Header.Get("Authorization") - // apiKey = strings.TrimPrefix(apiKey, "Bearer ") - // appId, secretId, secretKey, err := parseTencentConfig(apiKey) - // if err != nil { - // return errorWrapper(err, "invalid_tencent_config", http.StatusInternalServerError) - // } - // tencentRequest := requestOpenAI2Tencent(textRequest) - // tencentRequest.AppId = appId - // tencentRequest.SecretId = secretId - // jsonStr, err := json.Marshal(tencentRequest) - // if err != nil { - // return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - // } - // sign := getTencentSign(*tencentRequest, secretKey) - // c.Request.Header.Set("Authorization", sign) - // requestBody = bytes.NewBuffer(jsonStr) - case APITypeAIProxyLibrary: - aiProxyLibraryRequest := requestOpenAI2AIProxyLibrary(textRequest) - aiProxyLibraryRequest.LibraryId = c.GetString("library_id") - jsonStr, err := json.Marshal(aiProxyLibraryRequest) - if err != nil { - return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError) - } - requestBody = bytes.NewBuffer(jsonStr) - } - - var req *http.Request - var resp *http.Response - isStream := textRequest.Stream - - if apiType != APITypeXunfei { // cause xunfei use websocket - req, err = http.NewRequest(c.Request.Method, fullRequestURL, requestBody) - if err != nil { - return errorWrapper(err, "new_request_failed", http.StatusInternalServerError) - } - apiKey := c.Request.Header.Get("Authorization") - apiKey = strings.TrimPrefix(apiKey, "Bearer ") - - switch apiType { - case APITypeOpenAI: - if channelType == common.ChannelTypeAzure { - req.Header.Set("api-key", apiKey) - } else { - req.Header.Set("Authorization", c.Request.Header.Get("Authorization")) - if channelType == common.ChannelTypeOpenRouter { - req.Header.Set("HTTP-Referer", "https://github.com/songquanpeng/one-api") - req.Header.Set("X-Title", "One API") - } - } - case APITypeClaude: - req.Header.Set("x-api-key", apiKey) - anthropicVersion := c.Request.Header.Get("anthropic-version") - if anthropicVersion == "" { - anthropicVersion = "2023-06-01" - } - req.Header.Set("anthropic-version", anthropicVersion) - // case APITypeZhipu: - // token := getZhipuToken(apiKey) - // req.Header.Set("Authorization", token) - // case APITypeAli: - // req.Header.Set("Authorization", "Bearer "+apiKey) - // if textRequest.Stream { - // req.Header.Set("X-DashScope-SSE", "enable") - // } - // if c.GetString("plugin") != "" { - // req.Header.Set("X-DashScope-Plugin", c.GetString("plugin")) - // } - // case APITypeTencent: - // req.Header.Set("Authorization", apiKey) - case APITypePaLM: - req.Header.Set("x-goog-api-key", apiKey) - case APITypeGemini: - req.Header.Set("x-goog-api-key", apiKey) - default: - req.Header.Set("Authorization", "Bearer "+apiKey) - } - req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type")) - req.Header.Set("Accept", c.Request.Header.Get("Accept")) - if isStream && c.Request.Header.Get("Accept") == "" { - req.Header.Set("Accept", "text/event-stream") - } - //req.Header.Set("Connection", c.Request.Header.Get("Connection")) - resp, err = httpClient.Do(req) - if err != nil { - return errorWrapper(err, "do_request_failed", http.StatusInternalServerError) - } - err = req.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - err = c.Request.Body.Close() - if err != nil { - return errorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) - } - isStream = isStream || strings.HasPrefix(resp.Header.Get("Content-Type"), "text/event-stream") - - if resp.StatusCode != http.StatusOK { - if preConsumedQuota != 0 { - go func(ctx context.Context) { - // return pre-consumed quota - err := model.PostConsumeTokenQuota(tokenId, -preConsumedQuota) - if err != nil { - common.LogError(ctx, "error return pre-consumed quota: "+err.Error()) - } - }(c.Request.Context()) - } - - { // more error info - if reqdata, err := json.Marshal(req.Body); err != nil { - fmt.Printf("[ERROR] marshal relay text error: %s\n", err.Error()) - } else { - if respdata, err := io.ReadAll(resp.Body); err != nil { - fmt.Printf("[ERROR] read resp body error: %s\n", err.Error()) - } else { - resp.Body = io.NopCloser(bytes.NewBuffer(respdata)) - - fmt.Printf("[ERROR] send req %q to %s got error [%d]%s\n", - string(reqdata), req.URL.String(), resp.StatusCode, string(respdata)) - } - } - } - - return relayErrorHandler(resp) - } - } - - var textResponse TextResponse - tokenName := c.GetString("token_name") - - defer func(ctx context.Context) { - // c.Writer.Flush() - go func() { - quota := 0 - completionRatio := common.GetCompletionRatio(textRequest.Model) - promptTokens = textResponse.Usage.PromptTokens - completionTokens = textResponse.Usage.CompletionTokens - quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio)) - if ratio != 0 && quota <= 0 { - quota = 1 - } - totalTokens := promptTokens + completionTokens - if totalTokens == 0 { - // in this case, must be some error happened - // we cannot just return, because we may have to return the pre-consumed quota - quota = 0 - } - quotaDelta := quota - preConsumedQuota - err := model.PostConsumeTokenQuota(tokenId, quotaDelta) - if err != nil { - common.LogError(ctx, "error consuming token remain quota: "+err.Error()) - } - err = model.CacheUpdateUserQuota(userId) - if err != nil { - common.LogError(ctx, "error update user quota cache: "+err.Error()) - } - if quota != 0 { - logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio) - model.RecordConsumeLog(ctx, userId, channelId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent) - model.UpdateUserUsedQuotaAndRequestCount(userId, quota) - model.UpdateChannelUsedQuota(channelId, quota) - } - - if os.Getenv("LLM_CONSERVATION_AUDIT") != "" && - textRequest.Model != "" || - textRequest.MaxTokens != 0 || - len(textRequest.Messages) != 0 || - textResponse.Content != "" { - go func() { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - body, err := json.Marshal(map[string]any{ - "model": textRequest.Model, - "max_tokens": textRequest.MaxTokens, - "messages": textRequest.Messages, - "response": textResponse.Content, - }) - if err != nil { - common.LogError(ctx, "error marshal conservation audit: "+err.Error()) - return - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, os.Getenv("LLM_CONSERVATION_AUDIT"), bytes.NewBuffer(body)) - if err != nil { - common.LogError(ctx, "error new request conservation audit: "+err.Error()) - return - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - common.LogError(ctx, "error do conservation audit: "+err.Error()) - return - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - respBody, err := io.ReadAll(resp.Body) - if err != nil { - common.LogError(ctx, "error conservation audit: "+err.Error()) - return - } - - common.LogError(ctx, fmt.Sprintf("error conservation audit: [%d]%s", resp.StatusCode, string(respBody))) - } - }() - } - }() - }(c.Request.Context()) - switch apiType { - case APITypeOpenAI: - if isStream { - err, responseText := openaiStreamHandler(c, resp, relayMode) - if err != nil { - return err - } - textResponse.Content = responseText - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := openaiHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeClaude: - if isStream { - err, responseText := claudeStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Content = responseText - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := claudeHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - // case APITypeBaidu: - // if isStream { - // err, usage := baiduStreamHandler(c, resp) - // if err != nil { - // return err - // } - // if usage != nil { - // textResponse.Usage = *usage - // } - // return nil - // } else { - // var err *OpenAIErrorWithStatusCode - // var usage *Usage - // switch relayMode { - // case RelayModeEmbeddings: - // err, usage = baiduEmbeddingHandler(c, resp) - // default: - // err, usage = baiduHandler(c, resp) - // } - // if err != nil { - // return err - // } - // if usage != nil { - // textResponse.Usage = *usage - // } - // return nil - // } - case APITypePaLM: - if textRequest.Stream { // PaLM2 API does not support stream - err, responseText := palmStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Content = responseText - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := palmHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - case APITypeGemini: - if textRequest.Stream { - err, responseText := geminiChatStreamHandler(c, resp) - if err != nil { - return err - } - textResponse.Content = responseText - textResponse.Usage.PromptTokens = promptTokens - textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - return nil - } else { - err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - // case APITypeZhipu: - // if isStream { - // err, usage := zhipuStreamHandler(c, resp) - // if err != nil { - // return err - // } - // if usage != nil { - // textResponse.Usage = *usage - // } - // // zhipu's API does not return prompt tokens & completion tokens - // textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - // return nil - // } else { - // err, usage := zhipuHandler(c, resp) - // if err != nil { - // return err - // } - // if usage != nil { - // textResponse.Usage = *usage - // } - // // zhipu's API does not return prompt tokens & completion tokens - // textResponse.Usage.PromptTokens = textResponse.Usage.TotalTokens - // return nil - // } - // case APITypeAli: - // if isStream { - // err, usage := aliStreamHandler(c, resp) - // if err != nil { - // return err - // } - // if usage != nil { - // textResponse.Usage = *usage - // } - // return nil - // } else { - // var err *OpenAIErrorWithStatusCode - // var usage *Usage - // switch relayMode { - // case RelayModeEmbeddings: - // err, usage = aliEmbeddingHandler(c, resp) - // default: - // err, usage = aliHandler(c, resp) - // } - // if err != nil { - // return err - // } - // if usage != nil { - // textResponse.Usage = *usage - // } - // return nil - // } - // case APITypeXunfei: - // auth := c.Request.Header.Get("Authorization") - // auth = strings.TrimPrefix(auth, "Bearer ") - // splits := strings.Split(auth, "|") - // if len(splits) != 3 { - // return errorWrapper(errors.New("invalid auth"), "invalid_auth", http.StatusBadRequest) - // } - // var err *OpenAIErrorWithStatusCode - // var usage *Usage - // if isStream { - // err, usage = xunfeiStreamHandler(c, textRequest, splits[0], splits[1], splits[2]) - // } else { - // err, usage = xunfeiHandler(c, textRequest, splits[0], splits[1], splits[2]) - // } - // if err != nil { - // return err - // } - // if usage != nil { - // textResponse.Usage = *usage - // } - // return nil - case APITypeAIProxyLibrary: - if isStream { - err, usage := aiProxyLibraryStreamHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } else { - err, usage := aiProxyLibraryHandler(c, resp) - if err != nil { - return err - } - if usage != nil { - textResponse.Usage = *usage - } - return nil - } - // case APITypeTencent: - // if isStream { - // err, responseText := tencentStreamHandler(c, resp) - // if err != nil { - // return err - // } - // textResponse.Usage.PromptTokens = promptTokens - // textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model) - // return nil - // } else { - // err, usage := tencentHandler(c, resp) - // if err != nil { - // return err - // } - // if usage != nil { - // textResponse.Usage = *usage - // } - // return nil - // } - default: - return errorWrapper(errors.New("unknown api type"), "unknown_api_type", http.StatusInternalServerError) - } -} diff --git a/controller/relay.go b/controller/relay.go index 116ed151..cfe37984 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -2,450 +2,57 @@ package controller import ( "fmt" - "net/http" - "one-api/common" - "strconv" - "strings" - "github.com/gin-gonic/gin" -) - -type Message struct { - Role string `json:"role"` - Content any `json:"content"` - Name *string `json:"name,omitempty"` -} - -type VisionMessage struct { - Role string `json:"role"` - Content []OpenaiVisionMessageContent `json:"content"` - Name *string `json:"name,omitempty"` -} - -// OpenaiVisionMessageContentType vision message content type -type OpenaiVisionMessageContentType string - -const ( - // OpenaiVisionMessageContentTypeText text - OpenaiVisionMessageContentTypeText OpenaiVisionMessageContentType = "text" - // OpenaiVisionMessageContentTypeImageUrl image url - OpenaiVisionMessageContentTypeImageUrl OpenaiVisionMessageContentType = "image_url" -) - -// OpenaiVisionMessageContent vision message content -type OpenaiVisionMessageContent struct { - Type OpenaiVisionMessageContentType `json:"type"` - Text string `json:"text,omitempty"` - ImageUrl OpenaiVisionMessageContentImageUrl `json:"image_url,omitempty"` -} - -// VisionImageResolution image resolution -type VisionImageResolution string - -const ( - // VisionImageResolutionLow low resolution - VisionImageResolutionLow VisionImageResolution = "low" - // VisionImageResolutionHigh high resolution - VisionImageResolutionHigh VisionImageResolution = "high" -) - -type OpenaiVisionMessageContentImageUrl struct { - URL string `json:"url"` - Detail VisionImageResolution `json:"detail,omitempty"` -} - -type ImageURL struct { - Url string `json:"url,omitempty"` - Detail string `json:"detail,omitempty"` -} - -type TextContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text,omitempty"` -} - -type ImageContent struct { - Type string `json:"type,omitempty"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -const ( - ContentTypeText = "text" - ContentTypeImageURL = "image_url" -) - -type OpenAIMessageContent struct { - Type string `json:"type,omitempty"` - Text string `json:"text"` - ImageURL *ImageURL `json:"image_url,omitempty"` -} - -func (m Message) IsStringContent() bool { - _, ok := m.Content.(string) - return ok -} - -func (m Message) StringContent() string { - content, ok := m.Content.(string) - if ok { - return content - } - contentList, ok := m.Content.([]any) - if ok { - var contentStr string - for _, contentItem := range contentList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - if contentMap["type"] == ContentTypeText { - if subStr, ok := contentMap["text"].(string); ok { - contentStr += subStr - } - } - } - return contentStr - } - return "" -} - -func (m Message) ParseContent() []OpenAIMessageContent { - var contentList []OpenAIMessageContent - content, ok := m.Content.(string) - if ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeText, - Text: content, - }) - return contentList - } - anyList, ok := m.Content.([]any) - if ok { - for _, contentItem := range anyList { - contentMap, ok := contentItem.(map[string]any) - if !ok { - continue - } - switch contentMap["type"] { - case ContentTypeText: - if subStr, ok := contentMap["text"].(string); ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeText, - Text: subStr, - }) - } - case ContentTypeImageURL: - if subObj, ok := contentMap["image_url"].(map[string]any); ok { - contentList = append(contentList, OpenAIMessageContent{ - Type: ContentTypeImageURL, - ImageURL: &ImageURL{ - Url: subObj["url"].(string), - }, - }) - } - } - } - return contentList - } - return nil -} - -const ( - RelayModeUnknown = iota - RelayModeChatCompletions - RelayModeCompletions - RelayModeEmbeddings - RelayModeModerations - RelayModeImagesGenerations - RelayModeEdits - RelayModeAudioSpeech - RelayModeAudioTranscription - RelayModeAudioTranslation + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/relay/constant" + "github.com/songquanpeng/one-api/relay/controller" + "github.com/songquanpeng/one-api/relay/util" + "net/http" + "strconv" ) // https://platform.openai.com/docs/api-reference/chat -type ResponseFormat struct { - Type string `json:"type,omitempty"` -} - -type GeneralOpenAIRequest struct { - Model string `json:"model,omitempty"` - Messages []Message `json:"messages,omitempty"` - Prompt any `json:"prompt,omitempty"` - Stream bool `json:"stream,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - N int `json:"n,omitempty"` - Input any `json:"input,omitempty"` - Instruction string `json:"instruction,omitempty"` - Size string `json:"size,omitempty"` - Functions any `json:"functions,omitempty"` - FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` - PresencePenalty float64 `json:"presence_penalty,omitempty"` - ResponseFormat *ResponseFormat `json:"response_format,omitempty"` - Seed float64 `json:"seed,omitempty"` - Tools any `json:"tools,omitempty"` - ToolChoice any `json:"tool_choice,omitempty"` - User string `json:"user,omitempty"` -} - -// func (r *GeneralOpenAIRequest) MessagesLen() int { -// switch msgs := r.Messages.(type) { -// case []any: -// return len(msgs) -// case []Message: -// return len(msgs) -// case []VisionMessage: -// return len(msgs) -// case []map[string]any: -// return len(msgs) -// default: -// return 0 -// } -// } - -// TextMessages returns messages as []Message -// func (r *GeneralOpenAIRequest) TextMessages() (messages []Message, err error) { -// if blob, err := json.Marshal(r.Messages); err != nil { -// return nil, errors.Wrap(err, "marshal messages failed") -// } else if err := json.Unmarshal(blob, &messages); err != nil { -// return nil, errors.Wrapf(err, "unmarshal messages failed %q", string(blob)) -// } else { -// return messages, nil -// } -// } - -// VisionMessages returns messages as []VisionMessage -// func (r *GeneralOpenAIRequest) VisionMessages() (messages []VisionMessage, err error) { -// if blob, err := json.Marshal(r.Messages); err != nil { -// return nil, errors.Wrap(err, "marshal vision messages failed") -// } else if err := json.Unmarshal(blob, &messages); err != nil { -// return nil, errors.Wrapf(err, "unmarshal vision messages failed %q", string(blob)) -// } else { -// return messages, nil -// } -// } - -func (r GeneralOpenAIRequest) ParseInput() []string { - if r.Input == nil { - return nil - } - var input []string - switch r.Input.(type) { - case string: - input = []string{r.Input.(string)} - case []any: - input = make([]string, 0, len(r.Input.([]any))) - for _, item := range r.Input.([]any) { - if str, ok := item.(string); ok { - input = append(input, str) - } - } - } - return input -} - -type ChatRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - MaxTokens int `json:"max_tokens"` -} - -type TextRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Prompt string `json:"prompt"` - MaxTokens int `json:"max_tokens"` - //Stream bool `json:"stream"` -} - -// ImageRequest docs: https://platform.openai.com/docs/api-reference/images/create -type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` -} - -type WhisperJSONResponse struct { - Text string `json:"text,omitempty"` -} - -type WhisperVerboseJSONResponse struct { - Task string `json:"task,omitempty"` - Language string `json:"language,omitempty"` - Duration float64 `json:"duration,omitempty"` - Text string `json:"text,omitempty"` - Segments []Segment `json:"segments,omitempty"` -} - -type Segment struct { - Id int `json:"id"` - Seek int `json:"seek"` - Start float64 `json:"start"` - End float64 `json:"end"` - Text string `json:"text"` - Tokens []int `json:"tokens"` - Temperature float64 `json:"temperature"` - AvgLogprob float64 `json:"avg_logprob"` - CompressionRatio float64 `json:"compression_ratio"` - NoSpeechProb float64 `json:"no_speech_prob"` -} - -type TextToSpeechRequest struct { - Model string `json:"model" binding:"required"` - Input string `json:"input" binding:"required"` - Voice string `json:"voice" binding:"required"` - Speed float64 `json:"speed"` - ResponseFormat string `json:"response_format"` -} - -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -type OpenAIError struct { - Message string `json:"message"` - Type string `json:"type"` - Param string `json:"param"` - Code any `json:"code"` -} - -type OpenAIErrorWithStatusCode struct { - OpenAIError - StatusCode int `json:"status_code"` -} - -type TextResponse struct { - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` - Error OpenAIError `json:"error"` - Content string `json:"-"` -} - -type OpenAITextResponseChoice struct { - Index int `json:"index"` - Message `json:"message"` - FinishReason string `json:"finish_reason"` -} - -type OpenAITextResponse struct { - Id string `json:"id"` - Model string `json:"model,omitempty"` - Object string `json:"object"` - Created int64 `json:"created"` - Choices []OpenAITextResponseChoice `json:"choices"` - Usage `json:"usage"` -} - -type OpenAIEmbeddingResponseItem struct { - Object string `json:"object"` - Index int `json:"index"` - Embedding []float64 `json:"embedding"` -} - -type OpenAIEmbeddingResponse struct { - Object string `json:"object"` - Data []OpenAIEmbeddingResponseItem `json:"data"` - Model string `json:"model"` - Usage `json:"usage"` -} - -type ImageResponse struct { - Created int `json:"created"` - Data []struct { - Url string `json:"url"` - } -} - -type ChatCompletionsStreamResponseChoice struct { - Delta struct { - Content string `json:"content"` - } `json:"delta"` - FinishReason *string `json:"finish_reason,omitempty"` -} - -type ChatCompletionsStreamResponse struct { - Id string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []ChatCompletionsStreamResponseChoice `json:"choices"` -} - -type CompletionsStreamResponse struct { - Choices []struct { - Text string `json:"text"` - FinishReason string `json:"finish_reason"` - } `json:"choices"` -} - func Relay(c *gin.Context) { - relayMode := RelayModeUnknown - if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") { - relayMode = RelayModeChatCompletions - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/completions") { - relayMode = RelayModeCompletions - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasSuffix(c.Request.URL.Path, "embeddings") { - relayMode = RelayModeEmbeddings - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { - relayMode = RelayModeModerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { - relayMode = RelayModeImagesGenerations - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/edits") { - relayMode = RelayModeEdits - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/speech") { - relayMode = RelayModeAudioSpeech - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") { - relayMode = RelayModeAudioTranscription - } else if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { - relayMode = RelayModeAudioTranslation - } - var err *OpenAIErrorWithStatusCode + relayMode := constant.Path2RelayMode(c.Request.URL.Path) + var err *openai.ErrorWithStatusCode switch relayMode { - case RelayModeImagesGenerations: - err = relayImageHelper(c, relayMode) - case RelayModeAudioSpeech: + case constant.RelayModeImagesGenerations: + err = controller.RelayImageHelper(c, relayMode) + case constant.RelayModeAudioSpeech: fallthrough - case RelayModeAudioTranslation: + case constant.RelayModeAudioTranslation: fallthrough - case RelayModeAudioTranscription: - err = relayAudioHelper(c, relayMode) + case constant.RelayModeAudioTranscription: + err = controller.RelayAudioHelper(c, relayMode) default: - err = relayTextHelper(c, relayMode) + err = controller.RelayTextHelper(c) } if err != nil { - requestId := c.GetString(common.RequestIdKey) + requestId := c.GetString(logger.RequestIdKey) retryTimesStr := c.Query("retry") retryTimes, _ := strconv.Atoi(retryTimesStr) if retryTimesStr == "" { - retryTimes = common.RetryTimes + retryTimes = config.RetryTimes } if retryTimes > 0 { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1)) } else { if err.StatusCode == http.StatusTooManyRequests { - err.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" + err.Error.Message = "当前分组上游负载已饱和,请稍后再试" } - err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) + err.Error.Message = helper.MessageWithRequestId(err.Error.Message, requestId) c.JSON(err.StatusCode, gin.H{ - "error": err.OpenAIError, + "error": err.Error, }) } channelId := c.GetInt("channel_id") - common.LogError(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %+v", channelId, err)) + logger.Error(c.Request.Context(), fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Message)) // https://platform.openai.com/docs/guides/error-codes/api-errors - if shouldDisableChannel(&err.OpenAIError, err.StatusCode) { + if util.ShouldDisableChannel(&err.Error, err.StatusCode) { channelId := c.GetInt("channel_id") channelName := c.GetString("channel_name") disableChannel(channelId, channelName, err.Message) @@ -454,7 +61,7 @@ func Relay(c *gin.Context) { } func RelayNotImplemented(c *gin.Context) { - err := OpenAIError{ + err := openai.Error{ Message: "API not implemented", Type: "one_api_error", Param: "", @@ -466,7 +73,7 @@ func RelayNotImplemented(c *gin.Context) { } func RelayNotFound(c *gin.Context) { - err := OpenAIError{ + err := openai.Error{ Message: fmt.Sprintf("Invalid URL (%s %s)", c.Request.Method, c.Request.URL.Path), Type: "invalid_request_error", Param: "", diff --git a/controller/relay_test.go b/controller/relay_test.go deleted file mode 100644 index e150c21a..00000000 --- a/controller/relay_test.go +++ /dev/null @@ -1,46 +0,0 @@ -package controller - -import ( - "encoding/json" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestGeneralOpenAIRequest_TextMessages(t *testing.T) { - tests := []struct { - name string - messages []Message - want []Message - wantErr error - }{ - { - name: "Test with []any messages", - messages: []Message{{}, {}}, - want: []Message{{}, {}}, - wantErr: nil, - }, - { - name: "Test with []Message messages", - messages: []Message{{}, {}}, - want: []Message{{}, {}}, - wantErr: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := &GeneralOpenAIRequest{ - Messages: tt.messages, - } - got := new(GeneralOpenAIRequest) - - blob, err := json.Marshal(r) - require.NoError(t, err) - err = json.Unmarshal(blob, got) - require.NoError(t, err) - - require.Equal(t, tt.want, got.Messages) - }) - } -} diff --git a/controller/token.go b/controller/token.go index e3b5f6f3..1e24d741 100644 --- a/controller/token.go +++ b/controller/token.go @@ -3,11 +3,13 @@ package controller import ( "fmt" "net/http" - "one-api/common" - "one-api/model" "strconv" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/model" ) func GetAllTokens(c *gin.Context) { @@ -16,7 +18,7 @@ func GetAllTokens(c *gin.Context) { if p < 0 { p = 0 } - tokens, err := model.GetAllUserTokens(userId, p*common.ItemsPerPage, common.ItemsPerPage) + tokens, err := model.GetAllUserTokens(userId, p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -121,9 +123,9 @@ func AddToken(c *gin.Context) { cleanToken := model.Token{ UserId: c.GetInt("id"), Name: token.Name, - Key: common.GenerateKey(), - CreatedTime: common.GetTimestamp(), - AccessedTime: common.GetTimestamp(), + Key: helper.GenerateKey(), + CreatedTime: helper.GetTimestamp(), + AccessedTime: helper.GetTimestamp(), ExpiredTime: token.ExpiredTime, RemainQuota: token.RemainQuota, UnlimitedQuota: token.UnlimitedQuota, @@ -194,7 +196,7 @@ func UpdateToken(c *gin.Context) { return } - cleanToken, err := model.GetTokenByIds(tokenPatch.Id, userId) + tokenInDB, err := model.GetTokenByIds(tokenPatch.Id, userId) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -204,18 +206,16 @@ func UpdateToken(c *gin.Context) { } if tokenPatch.Status == common.TokenStatusEnabled { - if cleanToken.Status == common.TokenStatusExpired && - cleanToken.ExpiredTime <= common.GetTimestamp() && - cleanToken.ExpiredTime != -1 { + if tokenInDB.Status == common.TokenStatusExpired && tokenInDB.ExpiredTime <= helper.GetTimestamp() && tokenInDB.ExpiredTime != -1 { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌已过期,无法启用,请先修改令牌过期时间,或者设置为永不过期", }) return } - if cleanToken.Status == common.TokenStatusExhausted && - cleanToken.RemainQuota <= 0 && - !cleanToken.UnlimitedQuota { + if tokenInDB.Status == common.TokenStatusExhausted && + tokenInDB.RemainQuota <= 0 && + !tokenInDB.UnlimitedQuota { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "令牌可用额度已用尽,无法启用,请先修改令牌剩余额度,或者设置为无限额度", @@ -224,31 +224,31 @@ func UpdateToken(c *gin.Context) { } } if statusOnly != "" { - cleanToken.Status = tokenPatch.Status + tokenInDB.Status = tokenPatch.Status } else { // If you add more fields, please also update tokenPatch.Update() if tokenPatch.Name != nil { - cleanToken.Name = *tokenPatch.Name + tokenInDB.Name = *tokenPatch.Name } if tokenPatch.ExpiredTime != nil { - cleanToken.ExpiredTime = *tokenPatch.ExpiredTime + tokenInDB.ExpiredTime = *tokenPatch.ExpiredTime } if tokenPatch.RemainQuota != nil { - cleanToken.RemainQuota = *tokenPatch.RemainQuota + tokenInDB.RemainQuota = *tokenPatch.RemainQuota } if tokenPatch.UnlimitedQuota != nil { - cleanToken.UnlimitedQuota = *tokenPatch.UnlimitedQuota + tokenInDB.UnlimitedQuota = *tokenPatch.UnlimitedQuota } } - cleanToken.RemainQuota -= tokenPatch.AddUsedQuota - cleanToken.UsedQuota += tokenPatch.AddUsedQuota + tokenInDB.RemainQuota -= tokenPatch.AddUsedQuota + tokenInDB.UsedQuota += tokenPatch.AddUsedQuota if tokenPatch.AddUsedQuota != 0 { model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(tokenPatch.AddUsedQuota))) } - if err = cleanToken.Update(); err != nil { + if err = tokenInDB.Update(); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, "message": "update token: " + err.Error(), @@ -259,7 +259,7 @@ func UpdateToken(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", - "data": cleanToken, + "data": tokenInDB, }) return } diff --git a/controller/user.go b/controller/user.go index 3e6a3827..243980e8 100644 --- a/controller/user.go +++ b/controller/user.go @@ -3,9 +3,11 @@ package controller import ( "encoding/json" "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strconv" "time" @@ -19,7 +21,7 @@ type LoginRequest struct { } func Login(c *gin.Context) { - if !common.PasswordLoginEnabled { + if !config.PasswordLoginEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了密码登录", "success": false, @@ -106,14 +108,14 @@ func Logout(c *gin.Context) { } func Register(c *gin.Context) { - if !common.RegisterEnabled { + if !config.RegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了新用户注册", "success": false, }) return } - if !common.PasswordRegisterEnabled { + if !config.PasswordRegisterEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员关闭了通过密码进行注册,请使用第三方账户验证的形式进行注册", "success": false, @@ -136,7 +138,7 @@ func Register(c *gin.Context) { }) return } - if common.EmailVerificationEnabled { + if config.EmailVerificationEnabled { if user.Email == "" || user.VerificationCode == "" { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -160,7 +162,7 @@ func Register(c *gin.Context) { DisplayName: user.Username, InviterId: inviterId, } - if common.EmailVerificationEnabled { + if config.EmailVerificationEnabled { cleanUser.Email = user.Email } if err := cleanUser.Insert(inviterId); err != nil { @@ -182,7 +184,7 @@ func GetAllUsers(c *gin.Context) { if p < 0 { p = 0 } - users, err := model.GetAllUsers(p*common.ItemsPerPage, common.ItemsPerPage) + users, err := model.GetAllUsers(p*config.ItemsPerPage, config.ItemsPerPage) if err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -282,7 +284,7 @@ func GenerateAccessToken(c *gin.Context) { }) return } - user.AccessToken = common.GetUUID() + user.AccessToken = helper.GetUUID() if model.DB.Where("access_token = ?", user.AccessToken).First(user).RowsAffected != 0 { c.JSON(http.StatusOK, gin.H{ @@ -319,7 +321,7 @@ func GetAffCode(c *gin.Context) { return } if user.AffCode == "" { - user.AffCode = common.GetRandomString(4) + user.AffCode = helper.GetRandomString(4) if err := user.Update(false); err != nil { c.JSON(http.StatusOK, gin.H{ "success": false, @@ -736,7 +738,7 @@ func EmailBind(c *gin.Context) { return } if user.Role == common.RoleRootUser { - common.RootUserEmail = email + config.RootUserEmail = email } c.JSON(http.StatusOK, gin.H{ "success": true, diff --git a/controller/wechat.go b/controller/wechat.go index ff4c9fb6..74be5604 100644 --- a/controller/wechat.go +++ b/controller/wechat.go @@ -5,9 +5,10 @@ import ( "errors" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strconv" "time" ) @@ -22,11 +23,11 @@ func getWeChatIdByCode(code string) (string, error) { if code == "" { return "", errors.New("无效的参数") } - req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", common.WeChatServerAddress, code), nil) + req, err := http.NewRequest("GET", fmt.Sprintf("%s/api/wechat/user?code=%s", config.WeChatServerAddress, code), nil) if err != nil { return "", err } - req.Header.Set("Authorization", common.WeChatServerToken) + req.Header.Set("Authorization", config.WeChatServerToken) client := http.Client{ Timeout: 5 * time.Second, } @@ -50,7 +51,7 @@ func getWeChatIdByCode(code string) (string, error) { } func WeChatAuth(c *gin.Context) { - if !common.WeChatAuthEnabled { + if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", "success": false, @@ -79,7 +80,7 @@ func WeChatAuth(c *gin.Context) { return } } else { - if common.RegisterEnabled { + if config.RegisterEnabled { user.Username = "wechat_" + strconv.Itoa(model.GetMaxUserId()+1) user.DisplayName = "WeChat User" user.Role = common.RoleCommonUser @@ -112,7 +113,7 @@ func WeChatAuth(c *gin.Context) { } func WeChatBind(c *gin.Context) { - if !common.WeChatAuthEnabled { + if !config.WeChatAuthEnabled { c.JSON(http.StatusOK, gin.H{ "message": "管理员未开启通过微信登录以及注册", "success": false, diff --git a/go.mod b/go.mod index 493edee1..c0adb67c 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,8 @@ -module one-api +module github.com/songquanpeng/one-api go 1.21 require ( - github.com/Laisky/errors/v2 v2.0.1 github.com/gin-contrib/cors v1.5.0 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -12,6 +11,7 @@ require ( github.com/go-playground/validator/v10 v10.16.0 github.com/go-redis/redis/v8 v8.11.5 github.com/google/uuid v1.5.0 + github.com/pkg/errors v0.9.1 github.com/pkoukk/tiktoken-go v0.1.6 github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.17.0 diff --git a/go.sum b/go.sum index 204f27f2..8ab19e07 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -github.com/Laisky/errors/v2 v2.0.1 h1:yqCBrRzaP012AMB+7fVlXrP34OWRHrSO/hZ38CFdH84= -github.com/Laisky/errors/v2 v2.0.1/go.mod h1:mTn1LHSmKm4CYug0rpYO7rz13dp/DKrtzlSELSrxvT0= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= github.com/bytedance/sonic v1.10.1 h1:7a1wuFXL1cMy7a3f7/VFcEtriuXQnUBhtoVfOZiaysc= @@ -124,6 +122,8 @@ github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZO github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkoukk/tiktoken-go v0.1.6 h1:JF0TlJzhTbrI30wCvFuiw6FzP2+/bR+FIxUdgEAcUsw= github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -161,8 +161,6 @@ golang.org/x/crypto v0.17.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq golang.org/x/image v0.15.0 h1:kOELfmgrmJlw4Cdb7g/QGuB3CvDrXbqEIww/pNtNBm8= golang.org/x/image v0.15.0/go.mod h1:HUYqC05R2ZcZ3ejNQsIHQDQiwWM4JBqmm6MKANTp4LE= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.16.0 h1:7eBu7KsSvFDtSXUIDbh3aqlK4DPsZ1rByC8PFfBThos= -golang.org/x/net v0.16.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/main.go b/main.go index d248e9b4..4fbbf2a2 100644 --- a/main.go +++ b/main.go @@ -3,85 +3,88 @@ package main import ( "embed" "fmt" - "one-api/common" - "one-api/controller" - "one-api/middleware" - "one-api/model" - "one-api/router" "os" "strconv" "github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions/cookie" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/controller" + "github.com/songquanpeng/one-api/middleware" + "github.com/songquanpeng/one-api/model" + "github.com/songquanpeng/one-api/relay/channel/openai" + "github.com/songquanpeng/one-api/router" ) //go:embed web/build/* var buildFS embed.FS func main() { - common.SetupLogger() - common.SysLog(fmt.Sprintf("One API %s started", common.Version)) + logger.SetupLogger() + logger.SysLog(fmt.Sprintf("One API %s started", common.Version)) if os.Getenv("GIN_MODE") != "debug" { gin.SetMode(gin.ReleaseMode) } - if common.DebugEnabled { - common.SysLog("running in debug mode") + if config.DebugEnabled { + logger.SysLog("running in debug mode") } // Initialize SQL Database err := model.InitDB() if err != nil { - common.FatalLog(fmt.Sprintf("failed to initialize database: %+v", err)) + logger.FatalLog("failed to initialize database: " + err.Error()) } defer func() { err := model.CloseDB() if err != nil { - common.FatalLog(fmt.Sprintf("failed to close database: %+v", err)) + logger.FatalLog("failed to close database: " + err.Error()) } }() // Initialize Redis err = common.InitRedisClient() if err != nil { - common.FatalLog(fmt.Sprintf("failed to initialize Redis: %+v", err)) + logger.FatalLog("failed to initialize Redis: " + err.Error()) } // Initialize options model.InitOptionMap() - common.SysLog(fmt.Sprintf("using theme %s", common.Theme)) + logger.SysLog(fmt.Sprintf("using theme %s", config.Theme)) if common.RedisEnabled { // for compatibility with old versions - common.MemoryCacheEnabled = true + config.MemoryCacheEnabled = true } - if common.MemoryCacheEnabled { - common.SysLog("memory cache enabled") - common.SysError(fmt.Sprintf("sync frequency: %d seconds", common.SyncFrequency)) + if config.MemoryCacheEnabled { + logger.SysLog("memory cache enabled") + logger.SysError(fmt.Sprintf("sync frequency: %d seconds", config.SyncFrequency)) model.InitChannelCache() } - if common.MemoryCacheEnabled { - go model.SyncOptions(common.SyncFrequency) - go model.SyncChannelCache(common.SyncFrequency) + if config.MemoryCacheEnabled { + go model.SyncOptions(config.SyncFrequency) + go model.SyncChannelCache(config.SyncFrequency) } if os.Getenv("CHANNEL_UPDATE_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_UPDATE_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_UPDATE_FREQUENCY: " + err.Error()) } go controller.AutomaticallyUpdateChannels(frequency) } if os.Getenv("CHANNEL_TEST_FREQUENCY") != "" { frequency, err := strconv.Atoi(os.Getenv("CHANNEL_TEST_FREQUENCY")) if err != nil { - common.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) + logger.FatalLog("failed to parse CHANNEL_TEST_FREQUENCY: " + err.Error()) } go controller.AutomaticallyTestChannels(frequency) } if os.Getenv("BATCH_UPDATE_ENABLED") == "true" { - common.BatchUpdateEnabled = true - common.SysLog("batch update enabled with interval " + strconv.Itoa(common.BatchUpdateInterval) + "s") + config.BatchUpdateEnabled = true + logger.SysLog("batch update enabled with interval " + strconv.Itoa(config.BatchUpdateInterval) + "s") model.InitBatchUpdater() } - controller.InitTokenEncoders() + openai.InitTokenEncoders() // Initialize HTTP server server := gin.New() @@ -91,7 +94,7 @@ func main() { server.Use(middleware.RequestId()) middleware.SetUpLogger(server) // Initialize session store - store := cookie.NewStore([]byte(common.SessionSecret)) + store := cookie.NewStore([]byte(config.SessionSecret)) server.Use(sessions.Sessions("session", store)) router.SetRouter(server, buildFS) @@ -101,6 +104,6 @@ func main() { } err = server.Run(":" + port) if err != nil { - common.FatalLog("failed to start HTTP server: " + err.Error()) + logger.FatalLog("failed to start HTTP server: " + err.Error()) } } diff --git a/middleware/auth.go b/middleware/auth.go index e95174ec..ef896f09 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -3,9 +3,9 @@ package middleware import ( "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strings" ) diff --git a/middleware/distributor.go b/middleware/distributor.go index 81338130..0ed250fd 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -2,9 +2,10 @@ package middleware import ( "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/model" "net/http" - "one-api/common" - "one-api/model" "strconv" "strings" @@ -69,7 +70,7 @@ func Distribute() func(c *gin.Context) { if err != nil { message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", userGroup, modelRequest.Model) if channel != nil { - common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) + logger.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id)) message = "数据库一致性已被破坏,请联系管理员" } abortWithMessage(c, http.StatusServiceUnavailable, message) diff --git a/middleware/logger.go b/middleware/logger.go index 02f2e0a9..6aae4f23 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -3,14 +3,14 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" - "one-api/common" + "github.com/songquanpeng/one-api/common/logger" ) func SetUpLogger(server *gin.Engine) { server.Use(gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string { var requestID string if param.Keys != nil { - requestID = param.Keys[common.RequestIdKey].(string) + requestID = param.Keys[logger.RequestIdKey].(string) } return fmt.Sprintf("[GIN] %s | %s | %3d | %13v | %15s | %7s %s\n", param.TimeStamp.Format("2006/01/02 - 15:04:05"), diff --git a/middleware/rate-limit.go b/middleware/rate-limit.go index 8e5cff6c..0f300f2b 100644 --- a/middleware/rate-limit.go +++ b/middleware/rate-limit.go @@ -4,8 +4,9 @@ import ( "context" "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" "net/http" - "one-api/common" "time" ) @@ -26,7 +27,7 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st } if listLength < int64(maxRequestNum) { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } else { oldTimeStr, _ := rdb.LIndex(ctx, key, -1).Result() oldTime, err := time.Parse(timeFormat, oldTimeStr) @@ -47,14 +48,14 @@ func redisRateLimiter(c *gin.Context, maxRequestNum int, duration int64, mark st // time.Since will return negative number! // See: https://stackoverflow.com/questions/50970900/why-is-time-since-returning-negative-durations-on-windows if int64(nowTime.Sub(oldTime).Seconds()) < duration { - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) c.Status(http.StatusTooManyRequests) c.Abort() return } else { rdb.LPush(ctx, key, time.Now().Format(timeFormat)) rdb.LTrim(ctx, key, 0, int64(maxRequestNum-1)) - rdb.Expire(ctx, key, common.RateLimitKeyExpirationDuration) + rdb.Expire(ctx, key, config.RateLimitKeyExpirationDuration) } } } @@ -75,7 +76,7 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } } else { // It's safe to call multi times. - inMemoryRateLimiter.Init(common.RateLimitKeyExpirationDuration) + inMemoryRateLimiter.Init(config.RateLimitKeyExpirationDuration) return func(c *gin.Context) { memoryRateLimiter(c, maxRequestNum, duration, mark) } @@ -83,21 +84,21 @@ func rateLimitFactory(maxRequestNum int, duration int64, mark string) func(c *gi } func GlobalWebRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GlobalWebRateLimitNum, common.GlobalWebRateLimitDuration, "GW") + return rateLimitFactory(config.GlobalWebRateLimitNum, config.GlobalWebRateLimitDuration, "GW") } func GlobalAPIRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.GlobalApiRateLimitNum, common.GlobalApiRateLimitDuration, "GA") + return rateLimitFactory(config.GlobalApiRateLimitNum, config.GlobalApiRateLimitDuration, "GA") } func CriticalRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.CriticalRateLimitNum, common.CriticalRateLimitDuration, "CT") + return rateLimitFactory(config.CriticalRateLimitNum, config.CriticalRateLimitDuration, "CT") } func DownloadRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.DownloadRateLimitNum, common.DownloadRateLimitDuration, "DW") + return rateLimitFactory(config.DownloadRateLimitNum, config.DownloadRateLimitDuration, "DW") } func UploadRateLimit() func(c *gin.Context) { - return rateLimitFactory(common.UploadRateLimitNum, common.UploadRateLimitDuration, "UP") + return rateLimitFactory(config.UploadRateLimitNum, config.UploadRateLimitDuration, "UP") } diff --git a/middleware/recover.go b/middleware/recover.go index 8338a514..02e3e3bb 100644 --- a/middleware/recover.go +++ b/middleware/recover.go @@ -3,8 +3,8 @@ package middleware import ( "fmt" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/logger" "net/http" - "one-api/common" "runtime/debug" ) @@ -12,8 +12,8 @@ func RelayPanicRecover() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if err := recover(); err != nil { - common.SysError(fmt.Sprintf("panic detected: %v", err)) - common.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) + logger.SysError(fmt.Sprintf("panic detected: %v", err)) + logger.SysError(fmt.Sprintf("stacktrace from panic: %s", string(debug.Stack()))) c.JSON(http.StatusInternalServerError, gin.H{ "error": gin.H{ "message": fmt.Sprintf("Panic detected, error: %v. Please submit a issue here: https://github.com/songquanpeng/one-api", err), diff --git a/middleware/request-id.go b/middleware/request-id.go index e623be7a..7cb66e93 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -3,16 +3,17 @@ package middleware import ( "context" "github.com/gin-gonic/gin" - "one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" ) func RequestId() func(c *gin.Context) { return func(c *gin.Context) { - id := common.GetTimeString() + common.GetRandomString(8) - c.Set(common.RequestIdKey, id) - ctx := context.WithValue(c.Request.Context(), common.RequestIdKey, id) + id := helper.GetTimeString() + helper.GetRandomString(8) + c.Set(logger.RequestIdKey, id) + ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id) c.Request = c.Request.WithContext(ctx) - c.Header(common.RequestIdKey, id) + c.Header(logger.RequestIdKey, id) c.Next() } } diff --git a/middleware/turnstile-check.go b/middleware/turnstile-check.go index 26688810..403bcb34 100644 --- a/middleware/turnstile-check.go +++ b/middleware/turnstile-check.go @@ -4,9 +4,10 @@ import ( "encoding/json" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" "net/http" "net/url" - "one-api/common" ) type turnstileCheckResponse struct { @@ -15,7 +16,7 @@ type turnstileCheckResponse struct { func TurnstileCheck() gin.HandlerFunc { return func(c *gin.Context) { - if common.TurnstileCheckEnabled { + if config.TurnstileCheckEnabled { session := sessions.Default(c) turnstileChecked := session.Get("turnstile") if turnstileChecked != nil { @@ -32,12 +33,12 @@ func TurnstileCheck() gin.HandlerFunc { return } rawRes, err := http.PostForm("https://challenges.cloudflare.com/turnstile/v0/siteverify", url.Values{ - "secret": {common.TurnstileSecretKey}, + "secret": {config.TurnstileSecretKey}, "response": {response}, "remoteip": {c.ClientIP()}, }) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), @@ -49,7 +50,7 @@ func TurnstileCheck() gin.HandlerFunc { var res turnstileCheckResponse err = json.NewDecoder(rawRes.Body).Decode(&res) if err != nil { - common.SysError(err.Error()) + logger.SysError(err.Error()) c.JSON(http.StatusOK, gin.H{ "success": false, "message": err.Error(), diff --git a/middleware/utils.go b/middleware/utils.go index 536125cc..bc14c367 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -2,16 +2,17 @@ package middleware import ( "github.com/gin-gonic/gin" - "one-api/common" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { c.JSON(statusCode, gin.H{ "error": gin.H{ - "message": common.MessageWithRequestId(message, c.GetString(common.RequestIdKey)), + "message": helper.MessageWithRequestId(message, c.GetString(logger.RequestIdKey)), "type": "one_api_error", }, }) c.Abort() - common.LogError(c.Request.Context(), message) + logger.Error(c.Request.Context(), message) } diff --git a/model/ability.go b/model/ability.go index 3da83be8..7127abc3 100644 --- a/model/ability.go +++ b/model/ability.go @@ -1,7 +1,7 @@ package model import ( - "one-api/common" + "github.com/songquanpeng/one-api/common" "strings" ) diff --git a/model/cache.go b/model/cache.go index c6d0c70a..297df153 100644 --- a/model/cache.go +++ b/model/cache.go @@ -4,8 +4,10 @@ import ( "encoding/json" "errors" "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" "math/rand" - "one-api/common" "sort" "strconv" "strings" @@ -14,10 +16,10 @@ import ( ) var ( - TokenCacheSeconds = common.SyncFrequency - UserId2GroupCacheSeconds = common.SyncFrequency - UserId2QuotaCacheSeconds = common.SyncFrequency - UserId2StatusCacheSeconds = common.SyncFrequency + TokenCacheSeconds = config.SyncFrequency + UserId2GroupCacheSeconds = config.SyncFrequency + UserId2QuotaCacheSeconds = config.SyncFrequency + UserId2StatusCacheSeconds = config.SyncFrequency ) func CacheGetTokenByKey(key string) (*Token, error) { @@ -42,7 +44,7 @@ func CacheGetTokenByKey(key string) (*Token, error) { } err = common.RedisSet(fmt.Sprintf("token:%s", key), string(jsonBytes), time.Duration(TokenCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set token error: " + err.Error()) + logger.SysError("Redis set token error: " + err.Error()) } return &token, nil } @@ -62,7 +64,7 @@ func CacheGetUserGroup(id int) (group string, err error) { } err = common.RedisSet(fmt.Sprintf("user_group:%d", id), group, time.Duration(UserId2GroupCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user group error: " + err.Error()) + logger.SysError("Redis set user group error: " + err.Error()) } } return group, err @@ -80,7 +82,7 @@ func CacheGetUserQuota(id int) (quota int, err error) { } err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user quota error: " + err.Error()) + logger.SysError("Redis set user quota error: " + err.Error()) } return quota, err } @@ -127,7 +129,7 @@ func CacheIsUserEnabled(userId int) (bool, error) { } err = common.RedisSet(fmt.Sprintf("user_enabled:%d", userId), enabled, time.Duration(UserId2StatusCacheSeconds)*time.Second) if err != nil { - common.SysError("Redis set user enabled error: " + err.Error()) + logger.SysError("Redis set user enabled error: " + err.Error()) } return userEnabled, err } @@ -178,19 +180,19 @@ func InitChannelCache() { channelSyncLock.Lock() group2model2channels = newGroup2model2channels channelSyncLock.Unlock() - common.SysLog("channels synced from database") + logger.SysLog("channels synced from database") } func SyncChannelCache(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing channels from database") + logger.SysLog("syncing channels from database") InitChannelCache() } } func CacheGetRandomSatisfiedChannel(group string, model string) (*Channel, error) { - if !common.MemoryCacheEnabled { + if !config.MemoryCacheEnabled { return GetRandomSatisfiedChannel(group, model) } channelSyncLock.RLock() diff --git a/model/channel.go b/model/channel.go index 7e7b42e6..0503a620 100644 --- a/model/channel.go +++ b/model/channel.go @@ -1,8 +1,13 @@ package model import ( + "encoding/json" + "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" - "one-api/common" ) type Channel struct { @@ -42,7 +47,7 @@ func SearchChannels(keyword string) (channels []*Channel, err error) { if common.UsingPostgreSQL { keyCol = `"key"` } - err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", common.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error + err = DB.Omit("key").Where("id = ? or name LIKE ? or "+keyCol+" = ?", helper.String2Int(keyword), keyword+"%", keyword).Find(&channels).Error return channels, err } @@ -86,11 +91,17 @@ func (channel *Channel) GetBaseURL() string { return *channel.BaseURL } -func (channel *Channel) GetModelMapping() string { - if channel.ModelMapping == nil { - return "" +func (channel *Channel) GetModelMapping() map[string]string { + if channel.ModelMapping == nil || *channel.ModelMapping == "" || *channel.ModelMapping == "{}" { + return nil } - return *channel.ModelMapping + modelMapping := make(map[string]string) + err := json.Unmarshal([]byte(*channel.ModelMapping), &modelMapping) + if err != nil { + logger.SysError(fmt.Sprintf("failed to unmarshal model mapping for channel %d, error: %s", channel.Id, err.Error())) + return nil + } + return modelMapping } func (channel *Channel) Insert() error { @@ -116,21 +127,21 @@ func (channel *Channel) Update() error { func (channel *Channel) UpdateResponseTime(responseTime int64) { err := DB.Model(channel).Select("response_time", "test_time").Updates(Channel{ - TestTime: common.GetTimestamp(), + TestTime: helper.GetTimestamp(), ResponseTime: int(responseTime), }).Error if err != nil { - common.SysError("failed to update response time: " + err.Error()) + logger.SysError("failed to update response time: " + err.Error()) } } func (channel *Channel) UpdateBalance(balance float64) { err := DB.Model(channel).Select("balance_updated_time", "balance").Updates(Channel{ - BalanceUpdatedTime: common.GetTimestamp(), + BalanceUpdatedTime: helper.GetTimestamp(), Balance: balance, }).Error if err != nil { - common.SysError("failed to update balance: " + err.Error()) + logger.SysError("failed to update balance: " + err.Error()) } } @@ -147,16 +158,16 @@ func (channel *Channel) Delete() error { func UpdateChannelStatusById(id int, status int) { err := UpdateAbilityStatus(id, status == common.ChannelStatusEnabled) if err != nil { - common.SysError("failed to update ability status: " + err.Error()) + logger.SysError("failed to update ability status: " + err.Error()) } err = DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error if err != nil { - common.SysError("failed to update channel status: " + err.Error()) + logger.SysError("failed to update channel status: " + err.Error()) } } func UpdateChannelUsedQuota(id int, quota int) { - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota) return } @@ -166,7 +177,7 @@ func UpdateChannelUsedQuota(id int, quota int) { func updateChannelUsedQuota(id int, quota int) { err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error if err != nil { - common.SysError("failed to update channel used quota: " + err.Error()) + logger.SysError("failed to update channel used quota: " + err.Error()) } } diff --git a/model/log.go b/model/log.go index 307928c4..9615c237 100644 --- a/model/log.go +++ b/model/log.go @@ -3,15 +3,18 @@ package model import ( "context" "fmt" - "one-api/common" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" ) type Log struct { - Id int `json:"id;index:idx_created_at_id,priority:1"` + Id int `json:"id"` UserId int `json:"user_id" gorm:"index"` - CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_id,priority:2;index:idx_created_at_type"` + CreatedAt int64 `json:"created_at" gorm:"bigint;index:idx_created_at_type"` Type int `json:"type" gorm:"index:idx_created_at_type"` Content string `json:"content"` Username string `json:"username" gorm:"index:index_username_model_name,priority:2;default:''"` @@ -32,31 +35,31 @@ const ( ) func RecordLog(userId int, logType int, content string) { - if logType == LogTypeConsume && !common.LogConsumeEnabled { + if logType == LogTypeConsume && !config.LogConsumeEnabled { return } log := &Log{ UserId: userId, Username: GetUsernameById(userId), - CreatedAt: common.GetTimestamp(), + CreatedAt: helper.GetTimestamp(), Type: logType, Content: content, } err := DB.Create(log).Error if err != nil { - common.SysError("failed to record log: " + err.Error()) + logger.SysError("failed to record log: " + err.Error()) } } func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) { - common.LogInfo(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) - if !common.LogConsumeEnabled { + logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content)) + if !config.LogConsumeEnabled { return } log := &Log{ UserId: userId, Username: GetUsernameById(userId), - CreatedAt: common.GetTimestamp(), + CreatedAt: helper.GetTimestamp(), Type: LogTypeConsume, Content: content, PromptTokens: promptTokens, @@ -68,7 +71,7 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke } err := DB.Create(log).Error if err != nil { - common.LogError(ctx, "failed to record log: "+err.Error()) + logger.Error(ctx, "failed to record log: "+err.Error()) } } @@ -125,12 +128,12 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int } func SearchAllLogs(keyword string) (logs []*Log, err error) { - err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(common.MaxRecentItems).Find(&logs).Error + err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error return logs, err } func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) { - err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(common.MaxRecentItems).Omit("id").Find(&logs).Error + err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error return logs, err } diff --git a/model/main.go b/model/main.go index 01c103a4..38adb8b6 100644 --- a/model/main.go +++ b/model/main.go @@ -2,24 +2,28 @@ package model import ( "fmt" - "github.com/Laisky/errors/v2" + "os" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/gorm" - "one-api/common" - "os" - "strings" - "time" ) var DB *gorm.DB func createRootAccountIfNeed() error { var user User - //if user.Status != common.UserStatusEnabled { + //if user.Status != util.UserStatusEnabled { if err := DB.First(&user).Error; err != nil { - common.SysLog("no user exists, create a root user for you: username is root, password is 123456") + logger.SysLog("no user exists, create a root user for you: username is root, password is 123456") hashedPassword, err := common.Password2Hash("123456") if err != nil { return errors.WithStack(err) @@ -30,7 +34,7 @@ func createRootAccountIfNeed() error { Role: common.RoleRootUser, Status: common.UserStatusEnabled, DisplayName: "Root User", - AccessToken: common.GetUUID(), + AccessToken: helper.GetUUID(), Quota: 100000000, } DB.Create(&rootUser) @@ -43,7 +47,7 @@ func chooseDB() (*gorm.DB, error) { dsn := os.Getenv("SQL_DSN") if strings.HasPrefix(dsn, "postgres://") { // Use PostgreSQL - common.SysLog("using PostgreSQL as database") + logger.SysLog("using PostgreSQL as database") common.UsingPostgreSQL = true return gorm.Open(postgres.New(postgres.Config{ DSN: dsn, @@ -53,13 +57,13 @@ func chooseDB() (*gorm.DB, error) { }) } // Use MySQL - common.SysLog("using MySQL as database") + logger.SysLog("using MySQL as database") return gorm.Open(mysql.Open(dsn), &gorm.Config{ PrepareStmt: true, // precompile SQL }) } // Use SQLite - common.SysLog("SQL_DSN not set, using SQLite as database") + logger.SysLog("SQL_DSN not set, using SQLite as database") common.UsingSQLite = true config := fmt.Sprintf("?_busy_timeout=%d", common.SQLiteBusyTimeout) return gorm.Open(sqlite.Open(common.SQLitePath+config), &gorm.Config{ @@ -70,7 +74,7 @@ func chooseDB() (*gorm.DB, error) { func InitDB() (err error) { db, err := chooseDB() if err == nil { - if common.DebugEnabled { + if config.DebugEnabled { db = db.Debug() } DB = db @@ -78,14 +82,14 @@ func InitDB() (err error) { if err != nil { return errors.WithStack(err) } - sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60))) - if !common.IsMasterNode { + if !config.IsMasterNode { return nil } - common.SysLog("database migration started") + logger.SysLog("database migration started") err = db.AutoMigrate(&Channel{}) if err != nil { return errors.WithStack(err) @@ -114,11 +118,11 @@ func InitDB() (err error) { if err != nil { return errors.WithStack(err) } - common.SysLog("database migrated") + logger.SysLog("database migrated") err = createRootAccountIfNeed() return errors.WithStack(err) } else { - common.FatalLog(err) + logger.FatalLog(err) } return errors.WithStack(err) } diff --git a/model/option.go b/model/option.go index 20575c9a..6002c795 100644 --- a/model/option.go +++ b/model/option.go @@ -1,7 +1,9 @@ package model import ( - "one-api/common" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/logger" "strconv" "strings" "time" @@ -20,60 +22,57 @@ func AllOption() ([]*Option, error) { } func InitOptionMap() { - common.OptionMapRWMutex.Lock() - common.OptionMap = make(map[string]string) - common.OptionMap["FileUploadPermission"] = strconv.Itoa(common.FileUploadPermission) - common.OptionMap["FileDownloadPermission"] = strconv.Itoa(common.FileDownloadPermission) - common.OptionMap["ImageUploadPermission"] = strconv.Itoa(common.ImageUploadPermission) - common.OptionMap["ImageDownloadPermission"] = strconv.Itoa(common.ImageDownloadPermission) - common.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(common.PasswordLoginEnabled) - common.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(common.PasswordRegisterEnabled) - common.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(common.EmailVerificationEnabled) - common.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(common.GitHubOAuthEnabled) - common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) - common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) - common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) - common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) - common.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(common.AutomaticEnableChannelEnabled) - common.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(common.ApproximateTokenEnabled) - common.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(common.LogConsumeEnabled) - common.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(common.DisplayInCurrencyEnabled) - common.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(common.DisplayTokenStatEnabled) - common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) - common.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(common.EmailDomainRestrictionEnabled) - common.OptionMap["EmailDomainWhitelist"] = strings.Join(common.EmailDomainWhitelist, ",") - common.OptionMap["SMTPServer"] = "" - common.OptionMap["SMTPFrom"] = "" - common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) - common.OptionMap["SMTPAccount"] = "" - common.OptionMap["SMTPToken"] = "" - common.OptionMap["Notice"] = "" - common.OptionMap["About"] = "" - common.OptionMap["HomePageContent"] = "" - common.OptionMap["Footer"] = common.Footer - common.OptionMap["SystemName"] = common.SystemName - common.OptionMap["Logo"] = common.Logo - common.OptionMap["ServerAddress"] = "" - common.OptionMap["GitHubClientId"] = "" - common.OptionMap["GitHubClientSecret"] = "" - common.OptionMap["WeChatServerAddress"] = "" - common.OptionMap["WeChatServerToken"] = "" - common.OptionMap["WeChatAccountQRCodeImageURL"] = "" - common.OptionMap["TurnstileSiteKey"] = "" - common.OptionMap["TurnstileSecretKey"] = "" - common.OptionMap["QuotaForNewUser"] = strconv.Itoa(common.QuotaForNewUser) - common.OptionMap["QuotaForInviter"] = strconv.Itoa(common.QuotaForInviter) - common.OptionMap["QuotaForInvitee"] = strconv.Itoa(common.QuotaForInvitee) - common.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(common.QuotaRemindThreshold) - common.OptionMap["PreConsumedQuota"] = strconv.Itoa(common.PreConsumedQuota) - common.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() - common.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() - common.OptionMap["TopUpLink"] = common.TopUpLink - common.OptionMap["ChatLink"] = common.ChatLink - common.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(common.QuotaPerUnit, 'f', -1, 64) - common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes) - common.OptionMap["Theme"] = common.Theme - common.OptionMapRWMutex.Unlock() + config.OptionMapRWMutex.Lock() + config.OptionMap = make(map[string]string) + config.OptionMap["PasswordLoginEnabled"] = strconv.FormatBool(config.PasswordLoginEnabled) + config.OptionMap["PasswordRegisterEnabled"] = strconv.FormatBool(config.PasswordRegisterEnabled) + config.OptionMap["EmailVerificationEnabled"] = strconv.FormatBool(config.EmailVerificationEnabled) + config.OptionMap["GitHubOAuthEnabled"] = strconv.FormatBool(config.GitHubOAuthEnabled) + config.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(config.WeChatAuthEnabled) + config.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(config.TurnstileCheckEnabled) + config.OptionMap["RegisterEnabled"] = strconv.FormatBool(config.RegisterEnabled) + config.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(config.AutomaticDisableChannelEnabled) + config.OptionMap["AutomaticEnableChannelEnabled"] = strconv.FormatBool(config.AutomaticEnableChannelEnabled) + config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled) + config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled) + config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled) + config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled) + config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64) + config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled) + config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",") + config.OptionMap["SMTPServer"] = "" + config.OptionMap["SMTPFrom"] = "" + config.OptionMap["SMTPPort"] = strconv.Itoa(config.SMTPPort) + config.OptionMap["SMTPAccount"] = "" + config.OptionMap["SMTPToken"] = "" + config.OptionMap["Notice"] = "" + config.OptionMap["About"] = "" + config.OptionMap["HomePageContent"] = "" + config.OptionMap["Footer"] = config.Footer + config.OptionMap["SystemName"] = config.SystemName + config.OptionMap["Logo"] = config.Logo + config.OptionMap["ServerAddress"] = "" + config.OptionMap["GitHubClientId"] = "" + config.OptionMap["GitHubClientSecret"] = "" + config.OptionMap["WeChatServerAddress"] = "" + config.OptionMap["WeChatServerToken"] = "" + config.OptionMap["WeChatAccountQRCodeImageURL"] = "" + config.OptionMap["TurnstileSiteKey"] = "" + config.OptionMap["TurnstileSecretKey"] = "" + config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser) + config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter) + config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee) + config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold) + config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota) + config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString() + config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString() + config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString() + config.OptionMap["TopUpLink"] = config.TopUpLink + config.OptionMap["ChatLink"] = config.ChatLink + config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64) + config.OptionMap["RetryTimes"] = strconv.Itoa(config.RetryTimes) + config.OptionMap["Theme"] = config.Theme + config.OptionMapRWMutex.Unlock() loadOptionsFromDatabase() } @@ -82,7 +81,7 @@ func loadOptionsFromDatabase() { for _, option := range options { err := updateOptionMap(option.Key, option.Value) if err != nil { - common.SysError("failed to update option map: " + err.Error()) + logger.SysError("failed to update option map: " + err.Error()) } } } @@ -90,7 +89,7 @@ func loadOptionsFromDatabase() { func SyncOptions(frequency int) { for { time.Sleep(time.Duration(frequency) * time.Second) - common.SysLog("syncing options from database") + logger.SysLog("syncing options from database") loadOptionsFromDatabase() } } @@ -112,117 +111,106 @@ func UpdateOption(key string, value string) error { } func updateOptionMap(key string, value string) (err error) { - common.OptionMapRWMutex.Lock() - defer common.OptionMapRWMutex.Unlock() - common.OptionMap[key] = value - if strings.HasSuffix(key, "Permission") { - intValue, _ := strconv.Atoi(value) - switch key { - case "FileUploadPermission": - common.FileUploadPermission = intValue - case "FileDownloadPermission": - common.FileDownloadPermission = intValue - case "ImageUploadPermission": - common.ImageUploadPermission = intValue - case "ImageDownloadPermission": - common.ImageDownloadPermission = intValue - } - } + config.OptionMapRWMutex.Lock() + defer config.OptionMapRWMutex.Unlock() + config.OptionMap[key] = value if strings.HasSuffix(key, "Enabled") { boolValue := value == "true" switch key { case "PasswordRegisterEnabled": - common.PasswordRegisterEnabled = boolValue + config.PasswordRegisterEnabled = boolValue case "PasswordLoginEnabled": - common.PasswordLoginEnabled = boolValue + config.PasswordLoginEnabled = boolValue case "EmailVerificationEnabled": - common.EmailVerificationEnabled = boolValue + config.EmailVerificationEnabled = boolValue case "GitHubOAuthEnabled": - common.GitHubOAuthEnabled = boolValue + config.GitHubOAuthEnabled = boolValue case "WeChatAuthEnabled": - common.WeChatAuthEnabled = boolValue + config.WeChatAuthEnabled = boolValue case "TurnstileCheckEnabled": - common.TurnstileCheckEnabled = boolValue + config.TurnstileCheckEnabled = boolValue case "RegisterEnabled": - common.RegisterEnabled = boolValue + config.RegisterEnabled = boolValue case "EmailDomainRestrictionEnabled": - common.EmailDomainRestrictionEnabled = boolValue + config.EmailDomainRestrictionEnabled = boolValue case "AutomaticDisableChannelEnabled": - common.AutomaticDisableChannelEnabled = boolValue + config.AutomaticDisableChannelEnabled = boolValue case "AutomaticEnableChannelEnabled": - common.AutomaticEnableChannelEnabled = boolValue + config.AutomaticEnableChannelEnabled = boolValue case "ApproximateTokenEnabled": - common.ApproximateTokenEnabled = boolValue + config.ApproximateTokenEnabled = boolValue case "LogConsumeEnabled": - common.LogConsumeEnabled = boolValue + config.LogConsumeEnabled = boolValue case "DisplayInCurrencyEnabled": - common.DisplayInCurrencyEnabled = boolValue + config.DisplayInCurrencyEnabled = boolValue case "DisplayTokenStatEnabled": - common.DisplayTokenStatEnabled = boolValue + config.DisplayTokenStatEnabled = boolValue } } switch key { case "EmailDomainWhitelist": - common.EmailDomainWhitelist = strings.Split(value, ",") + config.EmailDomainWhitelist = strings.Split(value, ",") case "SMTPServer": - common.SMTPServer = value + config.SMTPServer = value case "SMTPPort": intValue, _ := strconv.Atoi(value) - common.SMTPPort = intValue + config.SMTPPort = intValue case "SMTPAccount": - common.SMTPAccount = value + config.SMTPAccount = value case "SMTPFrom": - common.SMTPFrom = value + config.SMTPFrom = value case "SMTPToken": - common.SMTPToken = value + config.SMTPToken = value case "ServerAddress": - common.ServerAddress = value + config.ServerAddress = value case "GitHubClientId": - common.GitHubClientId = value + config.GitHubClientId = value case "GitHubClientSecret": - common.GitHubClientSecret = value + config.GitHubClientSecret = value case "Footer": - common.Footer = value + config.Footer = value case "SystemName": - common.SystemName = value + config.SystemName = value case "Logo": - common.Logo = value + config.Logo = value case "WeChatServerAddress": - common.WeChatServerAddress = value + config.WeChatServerAddress = value case "WeChatServerToken": - common.WeChatServerToken = value + config.WeChatServerToken = value case "WeChatAccountQRCodeImageURL": - common.WeChatAccountQRCodeImageURL = value + config.WeChatAccountQRCodeImageURL = value case "TurnstileSiteKey": - common.TurnstileSiteKey = value + config.TurnstileSiteKey = value case "TurnstileSecretKey": - common.TurnstileSecretKey = value + config.TurnstileSecretKey = value case "QuotaForNewUser": - common.QuotaForNewUser, _ = strconv.Atoi(value) + config.QuotaForNewUser, _ = strconv.Atoi(value) case "QuotaForInviter": - common.QuotaForInviter, _ = strconv.Atoi(value) + config.QuotaForInviter, _ = strconv.Atoi(value) case "QuotaForInvitee": - common.QuotaForInvitee, _ = strconv.Atoi(value) + config.QuotaForInvitee, _ = strconv.Atoi(value) case "QuotaRemindThreshold": - common.QuotaRemindThreshold, _ = strconv.Atoi(value) + config.QuotaRemindThreshold, _ = strconv.Atoi(value) case "PreConsumedQuota": - common.PreConsumedQuota, _ = strconv.Atoi(value) + config.PreConsumedQuota, _ = strconv.Atoi(value) case "RetryTimes": - common.RetryTimes, _ = strconv.Atoi(value) + config.RetryTimes, _ = strconv.Atoi(value) case "ModelRatio": err = common.UpdateModelRatioByJSONString(value) case "GroupRatio": err = common.UpdateGroupRatioByJSONString(value) + case "CompletionRatio": + err = common.UpdateCompletionRatioByJSONString(value) case "TopUpLink": - common.TopUpLink = value + config.TopUpLink = value case "ChatLink": - common.ChatLink = value + config.ChatLink = value case "ChannelDisableThreshold": - common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) + config.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) case "QuotaPerUnit": - common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) + config.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) case "Theme": - common.Theme = value + config.Theme = value } return err } diff --git a/model/redemption.go b/model/redemption.go index f16412b5..2c5a4141 100644 --- a/model/redemption.go +++ b/model/redemption.go @@ -3,8 +3,9 @@ package model import ( "errors" "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/helper" "gorm.io/gorm" - "one-api/common" ) type Redemption struct { @@ -67,7 +68,7 @@ func Redeem(key string, userId int) (quota int, err error) { if err != nil { return err } - redemption.RedeemedTime = common.GetTimestamp() + redemption.RedeemedTime = helper.GetTimestamp() redemption.Status = common.RedemptionCodeStatusUsed err = tx.Save(redemption).Error return err diff --git a/model/token.go b/model/token.go index 2e53ac0b..d0a0648a 100644 --- a/model/token.go +++ b/model/token.go @@ -3,8 +3,11 @@ package model import ( "errors" "fmt" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/config" + "github.com/songquanpeng/one-api/common/helper" + "github.com/songquanpeng/one-api/common/logger" "gorm.io/gorm" - "one-api/common" ) type Token struct { @@ -39,7 +42,7 @@ func ValidateUserToken(key string) (token *Token, err error) { } token, err = CacheGetTokenByKey(key) if err != nil { - common.SysError("CacheGetTokenByKey failed: " + err.Error()) + logger.SysError("CacheGetTokenByKey failed: " + err.Error()) if errors.Is(err, gorm.ErrRecordNotFound) { return nil, errors.New("无效的令牌") } @@ -53,12 +56,12 @@ func ValidateUserToken(key string) (token *Token, err error) { if token.Status != common.TokenStatusEnabled { return nil, errors.New("该令牌状态不可用") } - if token.ExpiredTime != -1 && token.ExpiredTime < common.GetTimestamp() { + if token.ExpiredTime != -1 && token.ExpiredTime < helper.GetTimestamp() { if !common.RedisEnabled { token.Status = common.TokenStatusExpired err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } return nil, errors.New("该令牌已过期") @@ -69,7 +72,7 @@ func ValidateUserToken(key string) (token *Token, err error) { token.Status = common.TokenStatusExhausted err := token.SelectUpdate() if err != nil { - common.SysError("failed to update token status" + err.Error()) + logger.SysError("failed to update token status" + err.Error()) } } return nil, errors.New("该令牌额度已用尽") @@ -138,7 +141,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, quota) return nil } @@ -150,7 +153,7 @@ func increaseTokenQuota(id int, quota int) (err error) { map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota + ?", quota), "used_quota": gorm.Expr("used_quota - ?", quota), - "accessed_time": common.GetTimestamp(), + "accessed_time": helper.GetTimestamp(), }, ).Error return err @@ -160,7 +163,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) { if quota < 0 { return errors.New("quota 不能为负数!") } - if common.BatchUpdateEnabled { + if config.BatchUpdateEnabled { addNewRecord(BatchUpdateTypeTokenQuota, id, -quota) return nil } @@ -172,7 +175,7 @@ func decreaseTokenQuota(id int, quota int) (err error) { map[string]interface{}{ "remain_quota": gorm.Expr("remain_quota - ?", quota), "used_quota": gorm.Expr("used_quota + ?", quota), - "accessed_time": common.GetTimestamp(), + "accessed_time": helper.GetTimestamp(), }, ).Error return err @@ -196,24 +199,24 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) { if userQuota < quota { return errors.New("用户额度不足") } - quotaTooLow := userQuota >= common.QuotaRemindThreshold && userQuota-quota < common.QuotaRemindThreshold + quotaTooLow := userQuota >= config.QuotaRemindThreshold && userQuota-quota < config.QuotaRemindThreshold noMoreQuota := userQuota-quota <= 0 if quotaTooLow || noMoreQuota { go func() { email, err := GetUserEmail(token.UserId) if err != nil { - common.SysError("failed to fetch user email: " + err.Error()) + logger.SysError("failed to fetch user email: " + err.Error()) } prompt := "您的额度即将用尽" if noMoreQuota { prompt = "您的额度已用尽" } if email != "" { - topUpLink := fmt.Sprintf("%s/topup", common.ServerAddress) + topUpLink := fmt.Sprintf("%s/topup", config.ServerAddress) err = common.SendEmail(prompt, email, fmt.Sprintf("%s,当前剩余额度为 %d,为了不影响您的使用,请及时充值。