feat: support configuration file (#117)

* ♻️ refactor: move file directory

* ♻️ refactor: move file directory

* ♻️ refactor: support multiple config methods

* 🔥 del: remove unused code

* 💩 refactor: Refactor channel management and synchronization

* 💄 improve: add channel website

*  feat: allow recording 0 consumption
This commit is contained in:
Buer
2024-03-20 14:12:47 +08:00
committed by GitHub
parent 0409de0ea9
commit 71171c63f5
50 changed files with 581 additions and 481 deletions

54
common/config/config.go Normal file
View File

@@ -0,0 +1,54 @@
package config
import (
"strings"
"time"
"one-api/common"
"github.com/spf13/viper"
)
func InitConf() {
flagConfig()
defaultConfig()
setConfigFile()
setEnv()
if viper.GetBool("debug") {
common.SysLog("running in debug mode")
}
common.IsMasterNode = viper.GetString("NODE_TYPE") != "slave"
common.RequestInterval = time.Duration(viper.GetInt("POLLING_INTERVAL")) * time.Second
common.SessionSecret = common.GetOrDefault("SESSION_SECRET", common.SessionSecret)
}
func setConfigFile() {
if !common.IsFileExist(*config) {
return
}
viper.SetConfigFile(*config)
if err := viper.ReadInConfig(); err != nil {
panic(err)
}
}
func setEnv() {
viper.AutomaticEnv()
viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
}
func defaultConfig() {
viper.SetDefault("port", "3000")
viper.SetDefault("gin_mode", "release")
viper.SetDefault("log_dir", "./logs")
viper.SetDefault("sqlite_path", "one-api.db")
viper.SetDefault("sqlite_busy_timeout", 3000)
viper.SetDefault("sync_frequency", 600)
viper.SetDefault("batch_update_interval", 5)
viper.SetDefault("global.api_rate_limit", 180)
viper.SetDefault("global.web_rate_limit", 100)
viper.SetDefault("connect_timeout", 5)
}

49
common/config/flag.go Normal file
View File

@@ -0,0 +1,49 @@
package config
import (
"flag"
"fmt"
"one-api/common"
"os"
"github.com/spf13/viper"
)
var (
port = flag.Int("port", 0, "the listening port")
printVersion = flag.Bool("version", false, "print version and exit")
printHelp = flag.Bool("help", false, "print help and exit")
logDir = flag.String("log-dir", "", "specify the log directory")
config = flag.String("config", "config.yaml", "specify the config.yaml path")
)
func flagConfig() {
flag.Parse()
if *printVersion {
fmt.Println(common.Version)
os.Exit(0)
}
if *printHelp {
help()
os.Exit(0)
}
if *port != 0 {
viper.Set("port", *port)
}
if *logDir != "" {
viper.Set("log_dir", *logDir)
}
}
func help() {
fmt.Println("One API " + common.Version + " - All in one API service for OpenAI API.")
fmt.Println("Copyright (C) 2024 MartialBE. All rights reserved.")
fmt.Println("Original copyright holder: JustSong")
fmt.Println("GitHub: https://github.com/MartialBE/one-api")
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--config <config.yaml path>] [--version] [--help]")
}

View File

@@ -1,8 +1,6 @@
package common
import (
"os"
"strconv"
"sync"
"time"
@@ -52,8 +50,7 @@ var EmailDomainWhitelist = []string{
"foxmail.com",
}
var DebugEnabled = os.Getenv("DEBUG") == "true"
var MemoryCacheEnabled = os.Getenv("MEMORY_CACHE_ENABLED") == "true"
var MemoryCacheEnabled = false
var LogConsumeEnabled = true
@@ -88,22 +85,12 @@ var RetryCooldownSeconds = 5
var RootUserEmail = ""
var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var IsMasterNode = true
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 10*60) // unit is second
var RequestInterval time.Duration
var BatchUpdateEnabled = false
var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 600) // unit is second
var ConnectTimeout = GetOrDefault("CONNECT_TIMEOUT", 5) // unit is second
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
var BatchUpdateInterval = 5
const (
RoleGuestUser = 0
@@ -112,32 +99,6 @@ const (
RoleRootUser = 100
)
var (
FileUploadPermission = RoleGuestUser
FileDownloadPermission = RoleGuestUser
ImageUploadPermission = RoleGuestUser
ImageDownloadPermission = RoleGuestUser
)
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 100)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
UploadRateLimitDuration int64 = 60
DownloadRateLimitNum = 10
DownloadRateLimitDuration int64 = 60
CriticalRateLimitNum = 20
CriticalRateLimitDuration int64 = 20 * 60
)
var RateLimitKeyExpirationDuration = 20 * time.Minute
const (

View File

@@ -1,82 +0,0 @@
// Copyright 2014 Manu Martinez-Almeida. All rights reserved.
// Use of this source code is governed by a MIT style
// license that can be found in the LICENSE file.
package common
import (
"fmt"
"io"
"net/http"
"strings"
)
type stringWriter interface {
io.Writer
writeString(string) (int, error)
}
type stringWrapper struct {
io.Writer
}
func (w stringWrapper) writeString(str string) (int, error) {
return w.Writer.Write([]byte(str))
}
func checkWriter(writer io.Writer) stringWriter {
if w, ok := writer.(stringWriter); ok {
return w
} else {
return stringWrapper{writer}
}
}
// Server-Sent Events
// W3C Working Draft 29 October 2009
// http://www.w3.org/TR/2009/WD-eventsource-20091029/
var contentType = []string{"text/event-stream"}
var noCache = []string{"no-cache"}
var fieldReplacer = strings.NewReplacer(
"\n", "\\n",
"\r", "\\r")
var dataReplacer = strings.NewReplacer(
"\n", "\ndata:",
"\r", "\\r")
type CustomEvent struct {
Event string
Id string
Retry uint
Data interface{}
}
func encode(writer io.Writer, event CustomEvent) error {
w := checkWriter(writer)
return writeData(w, event.Data)
}
func writeData(w stringWriter, data interface{}) error {
dataReplacer.WriteString(w, fmt.Sprint(data))
if strings.HasPrefix(data.(string), "data") {
w.writeString("\n\n")
}
return nil
}
func (r CustomEvent) Render(w http.ResponseWriter) error {
r.WriteContentType(w)
return encode(w, r)
}
func (r CustomEvent) WriteContentType(w http.ResponseWriter) {
header := w.Header()
header["Content-Type"] = contentType
if _, exist := header["Cache-Control"]; !exist {
header["Cache-Control"] = noCache
}
}

View File

@@ -2,6 +2,3 @@ package common
var UsingSQLite = false
var UsingPostgreSQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = GetOrDefault("SQLITE_BUSY_TIMEOUT", 3000)

View File

@@ -1,32 +0,0 @@
package common
import (
"embed"
"github.com/gin-contrib/static"
"io/fs"
"net/http"
)
// Credit: https://github.com/gin-contrib/static/issues/19
type embedFileSystem struct {
http.FileSystem
}
func (e embedFileSystem) Exists(prefix string, path string) bool {
_, err := e.Open(path)
if err != nil {
return false
}
return true
}
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
efs, err := fs.Sub(fsEmbed, targetPath)
if err != nil {
panic(err)
}
return embedFileSystem{
FileSystem: http.FS(efs),
}
}

View File

@@ -1,69 +0,0 @@
package common
import (
"flag"
"fmt"
"log"
"os"
"path/filepath"
"github.com/joho/godotenv"
)
var (
Port = flag.Int("port", 3000, "the listening port")
PrintVersion = flag.Bool("version", false, "print version and exit")
PrintHelp = flag.Bool("help", false, "print help and exit")
LogDir = flag.String("log-dir", "./logs", "specify the log directory")
)
func printHelp() {
fmt.Println("One API " + Version + " - All in one API service for OpenAI API.")
fmt.Println("Copyright (C) 2023 MartialBE. All rights reserved.")
fmt.Println("Original copyright holder: JustSong")
fmt.Println("GitHub: https://github.com/MartialBE/one-api")
fmt.Println("Usage: one-api [--port <port>] [--log-dir <log directory>] [--version] [--help]")
}
func init() {
// 加载.env文件
err := godotenv.Load()
if err != nil {
SysLog("failed to load .env file: " + err.Error())
}
flag.Parse()
if *PrintVersion {
fmt.Println(Version)
os.Exit(0)
}
if *PrintHelp {
printHelp()
os.Exit(0)
}
if os.Getenv("SESSION_SECRET") != "" {
if os.Getenv("SESSION_SECRET") == "random_string" {
SysError("SESSION_SECRET is set to an example value, please change it to a random string.")
} else {
SessionSecret = os.Getenv("SESSION_SECRET")
}
}
if os.Getenv("SQLITE_PATH") != "" {
SQLitePath = os.Getenv("SQLITE_PATH")
}
if *LogDir != "" {
var err error
*LogDir, err = filepath.Abs(*LogDir)
if err != nil {
log.Fatal(err)
}
if _, err := os.Stat(*LogDir); os.IsNotExist(err) {
err = os.Mkdir(*LogDir, 0777)
if err != nil {
log.Fatal(err)
}
}
}
}

View File

@@ -3,13 +3,15 @@ package common
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"io"
"log"
"os"
"path/filepath"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
)
const (
@@ -17,6 +19,9 @@ const (
loggerWarn = "WARN"
loggerError = "ERR"
)
const (
RequestIdKey = "X-Oneapi-Request-Id"
)
const maxLogCount = 1000000
@@ -24,25 +29,54 @@ var logCount int
var setupLogLock sync.Mutex
var setupLogWorking bool
var defaultLogDir = "./logs"
func SetupLogger() {
if *LogDir != "" {
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
return
}
defer func() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(*LogDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
logDir := getLogDir()
if logDir == "" {
return
}
ok := setupLogLock.TryLock()
if !ok {
log.Println("setup log is already working")
return
}
defer func() {
setupLogLock.Unlock()
setupLogWorking = false
}()
logPath := filepath.Join(logDir, fmt.Sprintf("oneapi-%s.log", time.Now().Format("20060102")))
fd, err := os.OpenFile(logPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
log.Fatal("failed to open log file")
}
gin.DefaultWriter = io.MultiWriter(os.Stdout, fd)
gin.DefaultErrorWriter = io.MultiWriter(os.Stderr, fd)
}
func getLogDir() string {
logDir := viper.GetString("log_dir")
if logDir == "" {
logDir = defaultLogDir
}
var err error
logDir, err = filepath.Abs(viper.GetString("log_dir"))
if err != nil {
log.Fatal(err)
return ""
}
if !IsFileExist(logDir) {
err = os.Mkdir(logDir, 0777)
if err != nil {
log.Fatal(err)
return ""
}
}
return logDir
}
func SysLog(s string) {

View File

@@ -2,30 +2,32 @@ package common
import (
"context"
"github.com/go-redis/redis/v8"
"os"
"time"
"github.com/go-redis/redis/v8"
"github.com/spf13/viper"
)
var RDB *redis.Client
var RedisEnabled = true
var RedisEnabled = false
// InitRedisClient This function is called after init()
func InitRedisClient() (err error) {
if os.Getenv("REDIS_CONN_STRING") == "" {
RedisEnabled = false
redisConn := viper.GetString("REDIS_CONN_STRING")
if redisConn == "" {
SysLog("REDIS_CONN_STRING not set, Redis is not enabled")
return nil
}
if os.Getenv("SYNC_FREQUENCY") == "" {
RedisEnabled = false
if viper.GetInt("SYNC_FREQUENCY") == 0 {
SysLog("SYNC_FREQUENCY not set, Redis is disabled")
return nil
}
SysLog("Redis is enabled")
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
opt, err := redis.ParseURL(redisConn)
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
return
}
RDB = redis.NewClient(opt)
@@ -35,12 +37,17 @@ func InitRedisClient() (err error) {
_, err = RDB.Ping(ctx).Result()
if err != nil {
FatalLog("Redis ping test failed: " + err.Error())
} else {
RedisEnabled = true
// for compatibility with old versions
MemoryCacheEnabled = true
}
return err
}
func ParseRedisOption() *redis.Options {
opt, err := redis.ParseURL(os.Getenv("REDIS_CONN_STRING"))
opt, err := redis.ParseURL(viper.GetString("REDIS_CONN_STRING"))
if err != nil {
FatalLog("failed to parse Redis connection string: " + err.Error())
}

View File

@@ -39,7 +39,7 @@ func proxyFunc(req *http.Request) (*url.URL, error) {
func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error) {
// 设置TCP超时
dialer := &net.Dialer{
Timeout: time.Duration(common.ConnectTimeout) * time.Second,
Timeout: time.Duration(common.GetOrDefault("CONNECT_TIMEOUT", 5)) * time.Second,
KeepAlive: 30 * time.Second,
}
@@ -64,7 +64,7 @@ func socks5ProxyFunc(ctx context.Context, network, addr string) (net.Conn, error
var HTTPClient *http.Client
func init() {
func InitHttpClient() {
trans := &http.Transport{
DialContext: socks5ProxyFunc,
Proxy: proxyFunc,
@@ -74,7 +74,8 @@ func init() {
Transport: trans,
}
if common.RelayTimeout != 0 {
HTTPClient.Timeout = time.Duration(common.RelayTimeout) * time.Second
relayTimeout := common.GetOrDefault("RELAY_TIMEOUT", 600)
if relayTimeout != 0 {
HTTPClient.Timeout = time.Duration(relayTimeout) * time.Second
}
}

View File

@@ -1,4 +1,4 @@
package common
package requester
import (
"encoding/json"

View File

@@ -5,7 +5,6 @@ import (
"context"
"io"
"net/http"
"one-api/common"
)
type RequestBuilder interface {
@@ -13,12 +12,12 @@ type RequestBuilder interface {
}
type HTTPRequestBuilder struct {
marshaller common.Marshaller
marshaller Marshaller
}
func NewRequestBuilder() *HTTPRequestBuilder {
return &HTTPRequestBuilder{
marshaller: &common.JSONMarshaller{},
marshaller: &JSONMarshaller{},
}
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"one-api/common"
"one-api/model"
"os"
"strings"
"time"
@@ -14,6 +13,7 @@ import (
"github.com/PaulSonOfLars/gotgbot/v2/ext/handlers"
"github.com/PaulSonOfLars/gotgbot/v2/ext/handlers/filters/callbackquery"
"github.com/PaulSonOfLars/gotgbot/v2/ext/handlers/filters/message"
"github.com/spf13/viper"
)
var TGupdater *ext.Updater
@@ -28,13 +28,14 @@ func InitTelegramBot() {
return
}
if os.Getenv("TG_BOT_API_KEY") == "" {
botKey := viper.GetString("TG_BOT_API_KEY")
if botKey == "" {
common.SysLog("Telegram bot is not enabled")
return
}
var err error
TGBot, err = gotgbot.NewBot(os.Getenv("TG_BOT_API_KEY"), nil)
TGBot, err = gotgbot.NewBot(botKey, nil)
if err != nil {
common.SysLog("failed to create new telegram bot: " + err.Error())
return
@@ -47,15 +48,16 @@ func InitTelegramBot() {
}
func StartTelegramBot() {
if os.Getenv("TG_WEBHOOK_SECRET") != "" {
botWebhook := viper.GetString("TG_WEBHOOK_SECRET")
if botWebhook != "" {
if common.ServerAddress == "" {
common.SysLog("Telegram bot is not enabled: Server address is not set")
StopTelegramBot()
return
}
TGWebHookSecret = os.Getenv("TG_WEBHOOK_SECRET")
TGWebHookSecret = botWebhook
serverAddress := strings.TrimSuffix(common.ServerAddress, "/")
urlPath := fmt.Sprintf("/api/telegram/%s", os.Getenv("TG_BOT_API_KEY"))
urlPath := fmt.Sprintf("/api/telegram/%s", viper.GetString("TG_BOT_API_KEY"))
webHookOpts := &ext.AddWebhookOpts{
SecretToken: TGWebHookSecret,

View File

@@ -2,7 +2,6 @@ package common
import (
"fmt"
"github.com/google/uuid"
"html/template"
"log"
"math/rand"
@@ -13,6 +12,9 @@ import (
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/spf13/viper"
)
func OpenBrowser(url string) {
@@ -184,16 +186,14 @@ func Max(a int, b int) int {
}
}
func GetOrDefault(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
func GetOrDefault[T any](env string, defaultValue T) T {
if viper.IsSet(env) {
value := viper.Get(env)
if v, ok := value.(T); ok {
return v
}
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
return defaultValue
}
func MessageWithRequestId(message string, id string) string {
@@ -207,3 +207,8 @@ func String2Int(str string) int {
}
return num
}
func IsFileExist(path string) bool {
_, err := os.Stat(path)
return err == nil || os.IsExist(err)
}