Merge remote-tracking branch 'origin/upstream/main'

This commit is contained in:
Laisky.Cai 2024-03-15 09:49:49 +00:00
commit 41afad713e
49 changed files with 623 additions and 205 deletions

View File

@ -20,6 +20,12 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 0
fi
- uses: actions/setup-node@v3
with:
node-version: 16

View File

@ -20,6 +20,12 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 0
fi
- uses: actions/setup-node@v3
with:
node-version: 16

View File

@ -23,6 +23,12 @@ jobs:
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Check repository URL
run: |
REPO_URL=$(git config --get remote.origin.url)
if [[ $REPO_URL == *"pro" ]]; then
exit 0
fi
- uses: actions/setup-node@v3
with:
node-version: 16

View File

@ -9,7 +9,7 @@ import (
"sync"
"time"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/env"
)
func init() {
@ -90,14 +90,14 @@ var MessagePusherToken = ""
var TurnstileSiteKey = ""
var TurnstileSecretKey = ""
var QuotaForNewUser = 0
var QuotaForInviter = 0
var QuotaForInvitee = 0
var QuotaForNewUser int64 = 0
var QuotaForInviter int64 = 0
var QuotaForInvitee int64 = 0
var ChannelDisableThreshold = 5.0
var AutomaticDisableChannelEnabled = false
var AutomaticEnableChannelEnabled = false
var QuotaRemindThreshold = 1000
var PreConsumedQuota = 500
var QuotaRemindThreshold int64 = 1000
var PreConsumedQuota int64 = 500
var ApproximateTokenEnabled = false
var RetryTimes = 0
@ -108,17 +108,17 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave"
var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL"))
var RequestInterval = time.Duration(requestInterval) * time.Second
var SyncFrequency = helper.GetOrDefaultEnvInt("SYNC_FREQUENCY", 10*60) // unit is second
var SyncFrequency = env.Int("SYNC_FREQUENCY", 10*60) // unit is second
var BatchUpdateEnabled = false
var BatchUpdateInterval = helper.GetOrDefaultEnvInt("BATCH_UPDATE_INTERVAL", 5)
var BatchUpdateInterval = env.Int("BATCH_UPDATE_INTERVAL", 5)
var RelayTimeout = helper.GetOrDefaultEnvInt("RELAY_TIMEOUT", 0) // unit is second
var IdleTimeout = helper.GetOrDefaultEnvInt("IDLE_TIMEOUT", 30) // unit is second
var RelayTimeout = env.Int("RELAY_TIMEOUT", 0) // unit is second
var IdleTimeout = env.Int("IDLE_TIMEOUT", 30) // unit is second
var GeminiSafetySetting = helper.GetOrDefaultEnvString("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var GeminiSafetySetting = env.String("GEMINI_SAFETY_SETTING", "BLOCK_NONE")
var Theme = helper.GetOrDefaultEnvString("THEME", "default")
var Theme = env.String("THEME", "default")
var ValidThemes = map[string]bool{
"default": true,
"berry": true,
@ -127,10 +127,10 @@ var ValidThemes = map[string]bool{
// All duration's unit is seconds
// Shouldn't larger then RateLimitKeyExpirationDuration
var (
GlobalApiRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitNum = env.Int("GLOBAL_API_RATE_LIMIT", 180)
GlobalApiRateLimitDuration int64 = 3 * 60
GlobalWebRateLimitNum = helper.GetOrDefaultEnvInt("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitNum = env.Int("GLOBAL_WEB_RATE_LIMIT", 60)
GlobalWebRateLimitDuration int64 = 3 * 60
UploadRateLimitNum = 10
@ -145,8 +145,8 @@ var (
var RateLimitKeyExpirationDuration = 20 * time.Minute
var EnableMetric = helper.GetOrDefaultEnvBool("ENABLE_METRIC", false)
var MetricQueueSize = helper.GetOrDefaultEnvInt("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = helper.GetOrDefaultEnvFloat64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = helper.GetOrDefaultEnvInt("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = helper.GetOrDefaultEnvInt("METRIC_FAIL_CHAN_SIZE", 128)
var EnableMetric = env.Bool("ENABLE_METRIC", false)
var MetricQueueSize = env.Int("METRIC_QUEUE_SIZE", 10)
var MetricSuccessRateThreshold = env.Float64("METRIC_SUCCESS_RATE_THRESHOLD", 0.8)
var MetricSuccessChanSize = env.Int("METRIC_SUCCESS_CHAN_SIZE", 1024)
var MetricFailChanSize = env.Int("METRIC_FAIL_CHAN_SIZE", 128)

View File

@ -69,6 +69,8 @@ const (
ChannelTypeMinimax
ChannelTypeMistral
ChannelTypeGroq
ChannelTypeOllama
ChannelTypeLingYiWanWu
ChannelTypeDummy
)
@ -104,6 +106,8 @@ var ChannelBaseURLs = []string{
"https://api.minimax.chat", // 27
"https://api.mistral.ai", // 28
"https://api.groq.com/openai", // 29
"http://localhost:11434", // 30
"https://api.lingyiwanwu.com", // 31
}
const (

View File

@ -1,10 +1,12 @@
package common
import "github.com/songquanpeng/one-api/common/helper"
import (
"github.com/songquanpeng/one-api/common/env"
)
var UsingSQLite = false
var UsingPostgreSQL = false
var UsingMySQL = false
var SQLitePath = "one-api.db"
var SQLiteBusyTimeout = helper.GetOrDefaultEnvInt("SQLITE_BUSY_TIMEOUT", 3000)
var SQLiteBusyTimeout = env.Int("SQLITE_BUSY_TIMEOUT", 3000)

42
common/env/helper.go vendored Normal file
View File

@ -0,0 +1,42 @@
package env
import (
"os"
"strconv"
)
func Bool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func Int(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
return defaultValue
}
return num
}
func Float64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
return defaultValue
}
return num
}
func String(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}

View File

@ -3,12 +3,10 @@ package helper
import (
"fmt"
"github.com/google/uuid"
"github.com/songquanpeng/one-api/common/logger"
"html/template"
"log"
"math/rand"
"net"
"os"
"os/exec"
"runtime"
"strconv"
@ -187,6 +185,10 @@ func GetTimeString() string {
return fmt.Sprintf("%s%d", now.Format("20060102150405"), now.UnixNano()%1e9)
}
func GenRequestID() string {
return GetTimeString() + GetRandomNumberString(8)
}
func Max(a int, b int) int {
if a >= b {
return a
@ -195,44 +197,6 @@ func Max(a int, b int) int {
}
}
func GetOrDefaultEnvBool(env string, defaultValue bool) bool {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env) == "true"
}
func GetOrDefaultEnvInt(env string, defaultValue int) int {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.Atoi(os.Getenv(env))
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultEnvFloat64(env string, defaultValue float64) float64 {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
num, err := strconv.ParseFloat(os.Getenv(env), 64)
if err != nil {
logger.SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %f", env, err.Error(), defaultValue))
return defaultValue
}
return num
}
func GetOrDefaultEnvString(env string, defaultValue string) string {
if env == "" || os.Getenv(env) == "" {
return defaultValue
}
return os.Getenv(env)
}
func AssignOrDefault(value string, defaultValue string) string {
if len(value) != 0 {
return value

View File

@ -4,6 +4,8 @@ import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/helper"
"io"
"log"
"os"
@ -54,7 +56,9 @@ func SysError(s string) {
}
func Debug(ctx context.Context, msg string) {
logHelper(ctx, loggerDEBUG, msg)
if config.DebugEnabled {
logHelper(ctx, loggerDEBUG, msg)
}
}
func Info(ctx context.Context, msg string) {
@ -91,6 +95,9 @@ func logHelper(ctx context.Context, level string, msg string) {
writer = gin.DefaultWriter
}
id := ctx.Value(RequestIdKey)
if id == nil {
id = helper.GenRequestID()
}
now := time.Now()
_, _ = fmt.Fprintf(writer, "[%s] %v | %s | %s \n", level, now.Format("2006/01/02 - 15:04:05"), id, msg)
if !setupLogWorking {

View File

@ -69,7 +69,7 @@ var ModelRatio = map[string]float64{
"claude-instant-1.2": 0.8 / 1000 * USD,
"claude-2.0": 8.0 / 1000 * USD,
"claude-2.1": 8.0 / 1000 * USD,
"claude-3-haiku-20240229": 0.25 / 1000 * USD,
"claude-3-haiku-20240307": 0.25 / 1000 * USD,
"claude-3-sonnet-20240229": 3.0 / 1000 * USD,
"claude-3-opus-20240229": 15.0 / 1000 * USD,
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/hlrk4akp7
@ -78,6 +78,9 @@ var ModelRatio = map[string]float64{
"ERNIE-Bot-4": 0.12 * RMB, // ¥0.12 / 1k tokens
"ERNIE-Bot-8k": 0.024 * RMB,
"Embedding-V1": 0.1429, // ¥0.002 / 1k tokens
"bge-large-zh": 0.002 * RMB,
"bge-large-en": 0.002 * RMB,
"bge-large-8k": 0.002 * RMB,
"PaLM-2": 1,
"gemini-pro": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
"gemini-pro-vision": 1, // $0.00025 / 1k characters -> $0.001 / 1k tokens
@ -130,6 +133,10 @@ var ModelRatio = map[string]float64{
"llama2-7b-2048": 0.1 / 1000 * USD,
"mixtral-8x7b-32768": 0.27 / 1000 * USD,
"gemma-7b-it": 0.1 / 1000 * USD,
// https://platform.lingyiwanwu.com/docs#-计费单元
"yi-34b-chat-0205": 2.5 / 1000000 * RMB,
"yi-34b-chat-200k": 12.0 / 1000000 * RMB,
"yi-vl-plus": 6.0 / 1000000 * RMB,
}
var CompletionRatio = map[string]float64{}

View File

@ -5,7 +5,7 @@ import (
"github.com/songquanpeng/one-api/common/config"
)
func LogQuota(quota int) string {
func LogQuota(quota int64) string {
if config.DisplayInCurrencyEnabled {
return fmt.Sprintf("%.6f 额度", float64(quota)/config.QuotaPerUnit)
} else {

View File

@ -8,8 +8,8 @@ import (
)
func GetSubscription(c *gin.Context) {
var remainQuota int
var usedQuota int
var remainQuota int64
var usedQuota int64
var err error
var token *model.Token
var expiredTime int64
@ -60,7 +60,7 @@ func GetSubscription(c *gin.Context) {
}
func GetUsage(c *gin.Context) {
var quota int
var quota int64
var err error
var token *model.Token
if config.DisplayTokenStatEnabled {

View File

@ -30,7 +30,7 @@ import (
func buildTestRequest() *relaymodel.GeneralOpenAIRequest {
testRequest := &relaymodel.GeneralOpenAIRequest{
MaxTokens: 1,
MaxTokens: 2,
Stream: false,
Model: "gpt-3.5-turbo",
}

View File

@ -234,18 +234,18 @@ func UpdateToken(c *gin.Context) {
tokenInDB.ExpiredTime = *tokenPatch.ExpiredTime
}
if tokenPatch.RemainQuota != nil {
tokenInDB.RemainQuota = *tokenPatch.RemainQuota
tokenInDB.RemainQuota = int64(*tokenPatch.RemainQuota)
}
if tokenPatch.UnlimitedQuota != nil {
tokenInDB.UnlimitedQuota = *tokenPatch.UnlimitedQuota
}
}
tokenInDB.RemainQuota -= tokenPatch.AddUsedQuota
tokenInDB.UsedQuota += tokenPatch.AddUsedQuota
tokenInDB.RemainQuota -= int64(tokenPatch.AddUsedQuota)
tokenInDB.UsedQuota += int64(tokenPatch.AddUsedQuota)
if tokenPatch.AddUsedQuota != 0 {
model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(tokenPatch.AddUsedQuota)))
model.RecordLog(userId, model.LogTypeConsume, fmt.Sprintf("外部(%s)消耗 %s", tokenPatch.AddReason, common.LogQuota(int64(tokenPatch.AddUsedQuota))))
}
if err = tokenInDB.Update(); err != nil {

View File

@ -2,7 +2,7 @@ version: '3.4'
services:
one-api:
image: justsong/one-api:latest
image: "${REGISTRY:-docker.io}/justsong/one-api:latest"
container_name: one-api
restart: always
command: --log-dir /app/logs
@ -29,12 +29,12 @@ services:
retries: 3
redis:
image: redis:latest
image: "${REGISTRY:-docker.io}/redis:latest"
container_name: redis
restart: always
db:
image: mysql:8.2.0
image: "${REGISTRY:-docker.io}/mysql:8.2.0"
restart: always
container_name: mysql
volumes:

16
main.go
View File

@ -32,11 +32,25 @@ func main() {
if config.DebugEnabled {
logger.SysLog("running in debug mode")
}
var err error
// Initialize SQL Database
err := model.InitDB()
model.DB, err = model.InitDB("SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize database: " + err.Error())
}
if os.Getenv("LOG_SQL_DSN") != "" {
logger.SysLog("using secondary database for table logs")
model.LOG_DB, err = model.InitDB("LOG_SQL_DSN")
if err != nil {
logger.FatalLog("failed to initialize secondary database: " + err.Error())
}
} else {
model.LOG_DB = model.DB
}
err = model.CreateRootAccountIfNeed()
if err != nil {
logger.FatalLog("database init error: " + err.Error())
}
defer func() {
err := model.CloseDB()
if err != nil {

View File

@ -9,7 +9,7 @@ import (
func RequestId() func(c *gin.Context) {
return func(c *gin.Context) {
id := helper.GetTimeString() + helper.GetRandomNumberString(8)
id := helper.GenRequestID()
c.Set(logger.RequestIdKey, id)
ctx := context.WithValue(c.Request.Context(), logger.RequestIdKey, id)
c.Request = c.Request.WithContext(ctx)

View File

@ -1,6 +1,7 @@
package model
import (
"context"
"encoding/json"
"fmt"
"github.com/Laisky/errors/v2"
@ -71,31 +72,42 @@ func CacheGetUserGroup(id int) (group string, err error) {
return group, err
}
func CacheGetUserQuota(id int) (quota int, err error) {
func fetchAndUpdateUserQuota(ctx context.Context, id int) (quota int64, err error) {
quota, err = GetUserQuota(id)
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
logger.Error(ctx, "Redis set user quota error: "+err.Error())
}
return
}
func CacheGetUserQuota(ctx context.Context, id int) (quota int64, err error) {
if !common.RedisEnabled {
return GetUserQuota(id)
}
quotaString, err := common.RedisGet(fmt.Sprintf("user_quota:%d", id))
if err != nil {
quota, err = GetUserQuota(id)
if err != nil {
return 0, err
}
err = common.RedisSet(fmt.Sprintf("user_quota:%d", id), fmt.Sprintf("%d", quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
if err != nil {
logger.SysError("Redis set user quota error: " + err.Error())
}
return quota, err
return fetchAndUpdateUserQuota(ctx, id)
}
quota, err = strconv.Atoi(quotaString)
return quota, err
quota, err = strconv.ParseInt(quotaString, 10, 64)
if err != nil {
return 0, nil
}
if quota <= config.PreConsumedQuota { // when user's quota is less than pre-consumed quota, we need to fetch from db
logger.Infof(ctx, "user %d's cached quota is too low: %d, refreshing from db", quota, id)
return fetchAndUpdateUserQuota(ctx, id)
}
return quota, nil
}
func CacheUpdateUserQuota(id int) error {
func CacheUpdateUserQuota(ctx context.Context, id int) error {
if !common.RedisEnabled {
return nil
}
quota, err := CacheGetUserQuota(id)
quota, err := CacheGetUserQuota(ctx, id)
if err != nil {
return err
}
@ -103,7 +115,7 @@ func CacheUpdateUserQuota(id int) error {
return err
}
func CacheDecreaseUserQuota(id int, quota int) error {
func CacheDecreaseUserQuota(id int, quota int64) error {
if !common.RedisEnabled {
return nil
}

View File

@ -178,7 +178,7 @@ func UpdateChannelStatusById(id int, status int) {
}
}
func UpdateChannelUsedQuota(id int, quota int) {
func UpdateChannelUsedQuota(id int, quota int64) {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeChannelUsedQuota, id, quota)
return
@ -186,7 +186,7 @@ func UpdateChannelUsedQuota(id int, quota int) {
updateChannelUsedQuota(id, quota)
}
func updateChannelUsedQuota(id int, quota int) {
func updateChannelUsedQuota(id int, quota int64) {
err := DB.Model(&Channel{}).Where("id = ?", id).Update("used_quota", gorm.Expr("used_quota + ?", quota)).Error
if err != nil {
logger.SysError("failed to update channel used quota: " + err.Error())

View File

@ -45,13 +45,13 @@ func RecordLog(userId int, logType int, content string) {
Type: logType,
Content: content,
}
err := DB.Create(log).Error
err := LOG_DB.Create(log).Error
if err != nil {
logger.SysError("failed to record log: " + err.Error())
}
}
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int, content string) {
func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptTokens int, completionTokens int, modelName string, tokenName string, quota int64, content string) {
logger.Info(ctx, fmt.Sprintf("record consume log: userId=%d, channelId=%d, promptTokens=%d, completionTokens=%d, modelName=%s, tokenName=%s, quota=%d, content=%s", userId, channelId, promptTokens, completionTokens, modelName, tokenName, quota, content))
if !config.LogConsumeEnabled {
return
@ -66,10 +66,10 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
CompletionTokens: completionTokens,
TokenName: tokenName,
ModelName: modelName,
Quota: quota,
Quota: int(quota),
ChannelId: channelId,
}
err := DB.Create(log).Error
err := LOG_DB.Create(log).Error
if err != nil {
logger.Error(ctx, "failed to record log: "+err.Error())
}
@ -78,9 +78,9 @@ func RecordConsumeLog(ctx context.Context, userId int, channelId int, promptToke
func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, startIdx int, num int, channel int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB
tx = LOG_DB
} else {
tx = DB.Where("type = ?", logType)
tx = LOG_DB.Where("type = ?", logType)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
@ -107,9 +107,9 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName
func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int64, modelName string, tokenName string, startIdx int, num int) (logs []*Log, err error) {
var tx *gorm.DB
if logType == LogTypeUnknown {
tx = DB.Where("user_id = ?", userId)
tx = LOG_DB.Where("user_id = ?", userId)
} else {
tx = DB.Where("user_id = ? and type = ?", userId, logType)
tx = LOG_DB.Where("user_id = ? and type = ?", userId, logType)
}
if modelName != "" {
tx = tx.Where("model_name = ?", modelName)
@ -128,17 +128,17 @@ func GetUserLogs(userId int, logType int, startTimestamp int64, endTimestamp int
}
func SearchAllLogs(keyword string) (logs []*Log, err error) {
err = DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
err = LOG_DB.Where("type = ? or content LIKE ?", keyword, keyword+"%").Order("id desc").Limit(config.MaxRecentItems).Find(&logs).Error
return logs, err
}
func SearchUserLogs(userId int, keyword string) (logs []*Log, err error) {
err = DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
err = LOG_DB.Where("user_id = ? and type = ?", userId, keyword).Order("id desc").Limit(config.MaxRecentItems).Omit("id").Find(&logs).Error
return logs, err
}
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int) {
tx := DB.Table("logs").Select("ifnull(sum(quota),0)")
func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string, channel int) (quota int64) {
tx := LOG_DB.Table("logs").Select("ifnull(sum(quota),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@ -162,7 +162,7 @@ func SumUsedQuota(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelName string, username string, tokenName string) (token int) {
tx := DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
tx := LOG_DB.Table("logs").Select("ifnull(sum(prompt_tokens),0) + ifnull(sum(completion_tokens),0)")
if username != "" {
tx = tx.Where("username = ?", username)
}
@ -183,7 +183,7 @@ func SumUsedToken(logType int, startTimestamp int64, endTimestamp int64, modelNa
}
func DeleteOldLog(targetTimestamp int64) (int64, error) {
result := DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
result := LOG_DB.Where("created_at < ?", targetTimestamp).Delete(&Log{})
return result.RowsAffected, result.Error
}
@ -207,7 +207,7 @@ func SearchLogsByDayAndModel(userId, start, end int) (LogStatistics []*LogStatis
groupSelect = "strftime('%Y-%m-%d', datetime(created_at, 'unixepoch')) as day"
}
err = DB.Raw(`
err = LOG_DB.Raw(`
SELECT `+groupSelect+`,
model_name, count(1) as request_count,
sum(quota) as quota,

View File

@ -9,6 +9,7 @@ import (
"github.com/pkg/errors"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/env"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/driver/mysql"
@ -18,8 +19,9 @@ import (
)
var DB *gorm.DB
var LOG_DB *gorm.DB
func createRootAccountIfNeed() error {
func CreateRootAccountIfNeed() error {
var user User
//if user.Status != util.UserStatusEnabled {
if err := DB.First(&user).Error; err != nil {
@ -42,9 +44,9 @@ func createRootAccountIfNeed() error {
return nil
}
func chooseDB() (*gorm.DB, error) {
if os.Getenv("SQL_DSN") != "" {
dsn := os.Getenv("SQL_DSN")
func chooseDB(envName string) (*gorm.DB, error) {
if os.Getenv(envName) != "" {
dsn := os.Getenv(envName)
if strings.HasPrefix(dsn, "postgres://") {
// Use PostgreSQL
logger.SysLog("using PostgreSQL as database")
@ -72,23 +74,22 @@ func chooseDB() (*gorm.DB, error) {
})
}
func InitDB() (err error) {
db, err := chooseDB()
func InitDB(envName string) (db *gorm.DB, err error) {
db, err = chooseDB(envName)
if err == nil {
if config.DebugSQLEnabled {
db = db.Debug()
}
DB = db
sqlDB, err := DB.DB()
sqlDB, err := db.DB()
if err != nil {
return errors.WithStack(err)
return nil, err
}
sqlDB.SetMaxIdleConns(helper.GetOrDefaultEnvInt("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(helper.GetOrDefaultEnvInt("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(helper.GetOrDefaultEnvInt("SQL_MAX_LIFETIME", 60)))
sqlDB.SetMaxIdleConns(env.Int("SQL_MAX_IDLE_CONNS", 100))
sqlDB.SetMaxOpenConns(env.Int("SQL_MAX_OPEN_CONNS", 1000))
sqlDB.SetConnMaxLifetime(time.Second * time.Duration(env.Int("SQL_MAX_LIFETIME", 60)))
if !config.IsMasterNode {
return nil
return db, err
}
if common.UsingMySQL {
_, _ = sqlDB.Exec("DROP INDEX idx_channels_key ON channels;") // TODO: delete this line when most users have upgraded
@ -96,46 +97,55 @@ func InitDB() (err error) {
logger.SysLog("database migration started")
err = db.AutoMigrate(&Channel{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Token{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&User{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Option{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Redemption{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Ability{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
err = db.AutoMigrate(&Log{})
if err != nil {
return errors.WithStack(err)
return nil, err
}
logger.SysLog("database migrated")
err = createRootAccountIfNeed()
return errors.WithStack(err)
return db, err
} else {
logger.FatalLog(err)
}
return errors.WithStack(err)
return db, err
}
func CloseDB() error {
sqlDB, err := DB.DB()
func closeDB(db *gorm.DB) error {
sqlDB, err := db.DB()
if err != nil {
return errors.WithStack(err)
}
err = sqlDB.Close()
return errors.WithStack(err)
}
func CloseDB() error {
if LOG_DB != DB {
err := closeDB(LOG_DB)
if err != nil {
return err
}
}
return closeDB(DB)
}

View File

@ -61,11 +61,11 @@ func InitOptionMap() {
config.OptionMap["MessagePusherToken"] = ""
config.OptionMap["TurnstileSiteKey"] = ""
config.OptionMap["TurnstileSecretKey"] = ""
config.OptionMap["QuotaForNewUser"] = strconv.Itoa(config.QuotaForNewUser)
config.OptionMap["QuotaForInviter"] = strconv.Itoa(config.QuotaForInviter)
config.OptionMap["QuotaForInvitee"] = strconv.Itoa(config.QuotaForInvitee)
config.OptionMap["QuotaRemindThreshold"] = strconv.Itoa(config.QuotaRemindThreshold)
config.OptionMap["PreConsumedQuota"] = strconv.Itoa(config.PreConsumedQuota)
config.OptionMap["QuotaForNewUser"] = strconv.FormatInt(config.QuotaForNewUser, 10)
config.OptionMap["QuotaForInviter"] = strconv.FormatInt(config.QuotaForInviter, 10)
config.OptionMap["QuotaForInvitee"] = strconv.FormatInt(config.QuotaForInvitee, 10)
config.OptionMap["QuotaRemindThreshold"] = strconv.FormatInt(config.QuotaRemindThreshold, 10)
config.OptionMap["PreConsumedQuota"] = strconv.FormatInt(config.PreConsumedQuota, 10)
config.OptionMap["ModelRatio"] = common.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = common.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = common.CompletionRatio2JSONString()
@ -193,15 +193,15 @@ func updateOptionMap(key string, value string) (err error) {
case "TurnstileSecretKey":
config.TurnstileSecretKey = value
case "QuotaForNewUser":
config.QuotaForNewUser, _ = strconv.Atoi(value)
config.QuotaForNewUser, _ = strconv.ParseInt(value, 10, 64)
case "QuotaForInviter":
config.QuotaForInviter, _ = strconv.Atoi(value)
config.QuotaForInviter, _ = strconv.ParseInt(value, 10, 64)
case "QuotaForInvitee":
config.QuotaForInvitee, _ = strconv.Atoi(value)
config.QuotaForInvitee, _ = strconv.ParseInt(value, 10, 64)
case "QuotaRemindThreshold":
config.QuotaRemindThreshold, _ = strconv.Atoi(value)
config.QuotaRemindThreshold, _ = strconv.ParseInt(value, 10, 64)
case "PreConsumedQuota":
config.PreConsumedQuota, _ = strconv.Atoi(value)
config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
case "RetryTimes":
config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":

View File

@ -14,7 +14,7 @@ type Redemption struct {
Key string `json:"key" gorm:"type:char(32);uniqueIndex"`
Status int `json:"status" gorm:"default:1"`
Name string `json:"name" gorm:"index"`
Quota int `json:"quota" gorm:"default:100"`
Quota int64 `json:"quota" gorm:"default:100"`
CreatedTime int64 `json:"created_time" gorm:"bigint"`
RedeemedTime int64 `json:"redeemed_time" gorm:"bigint"`
Count int `json:"count" gorm:"-:all"` // only for api request
@ -42,7 +42,7 @@ func GetRedemptionById(id int) (*Redemption, error) {
return &redemption, err
}
func Redeem(key string, userId int) (quota int, err error) {
func Redeem(key string, userId int) (quota int64, err error) {
if key == "" {
return 0, errors.New("未提供兑换码")
}

View File

@ -21,9 +21,9 @@ type Token struct {
CreatedTime int64 `json:"created_time" gorm:"bigint"`
AccessedTime int64 `json:"accessed_time" gorm:"bigint"`
ExpiredTime int64 `json:"expired_time" gorm:"bigint;default:-1"` // -1 means never expired
RemainQuota int `json:"remain_quota" gorm:"default:0"`
RemainQuota int64 `json:"remain_quota" gorm:"default:0"`
UnlimitedQuota bool `json:"unlimited_quota" gorm:"default:false"`
UsedQuota int `json:"used_quota" gorm:"default:0"` // used quota
UsedQuota int64 `json:"used_quota" gorm:"default:0"` // used quota
}
func GetAllUserTokens(userId int, startIdx int, num int) ([]*Token, error) {
@ -141,7 +141,7 @@ func DeleteTokenById(id int, userId int) (err error) {
return token.Delete()
}
func IncreaseTokenQuota(id int, quota int) (err error) {
func IncreaseTokenQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@ -152,7 +152,7 @@ func IncreaseTokenQuota(id int, quota int) (err error) {
return increaseTokenQuota(id, quota)
}
func increaseTokenQuota(id int, quota int) (err error) {
func increaseTokenQuota(id int, quota int64) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota + ?", quota),
@ -163,7 +163,7 @@ func increaseTokenQuota(id int, quota int) (err error) {
return err
}
func DecreaseTokenQuota(id int, quota int) (err error) {
func DecreaseTokenQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@ -174,7 +174,7 @@ func DecreaseTokenQuota(id int, quota int) (err error) {
return decreaseTokenQuota(id, quota)
}
func decreaseTokenQuota(id int, quota int) (err error) {
func decreaseTokenQuota(id int, quota int64) (err error) {
err = DB.Model(&Token{}).Where("id = ?", id).Updates(
map[string]interface{}{
"remain_quota": gorm.Expr("remain_quota - ?", quota),
@ -185,7 +185,7 @@ func decreaseTokenQuota(id int, quota int) (err error) {
return err
}
func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
func PreConsumeTokenQuota(tokenId int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@ -235,7 +235,7 @@ func PreConsumeTokenQuota(tokenId int, quota int) (err error) {
return err
}
func PostConsumeTokenQuota(tokenId int, quota int) (err error) {
func PostConsumeTokenQuota(tokenId int, quota int64) (err error) {
token, err := GetTokenById(tokenId)
if quota > 0 {
err = DecreaseUserQuota(token.UserId, quota)

View File

@ -2,6 +2,8 @@ package model
import (
"fmt"
"strings"
"github.com/Laisky/errors/v2"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/blacklist"
@ -9,7 +11,6 @@ import (
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"gorm.io/gorm"
"strings"
)
// User if you add sensitive fields, don't forget to clean them in setupLogin function.
@ -26,10 +27,10 @@ type User struct {
WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
VerificationCode string `json:"verification_code" gorm:"-:all"` // this field is only for Email verification, don't save it to database!
AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
Quota int `json:"quota" gorm:"column:quota;type:int;default:0"`
UsedQuota int `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"column:request_count;type:int;default:0;"` // request number
Group string `json:"group" gorm:"column:group;type:varchar(32);default:'default'"`
Quota int64 `json:"quota" gorm:"type:int;default:0"`
UsedQuota int64 `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
RequestCount int `json:"request_count" gorm:"type:int;default:0;"` // request number
Group string `json:"group" gorm:"type:varchar(32);default:'default'"`
AffCode string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
InviterId int `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
}
@ -274,12 +275,12 @@ func ValidateAccessToken(token string) (user *User) {
return nil
}
func GetUserQuota(id int) (quota int, err error) {
func GetUserQuota(id int) (quota int64, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
return quota, err
}
func GetUserUsedQuota(id int) (quota int, err error) {
func GetUserUsedQuota(id int) (quota int64, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
return quota, err
}
@ -299,7 +300,7 @@ func GetUserGroup(id int) (group string, err error) {
return group, err
}
func IncreaseUserQuota(id int, quota int) (err error) {
func IncreaseUserQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@ -310,12 +311,12 @@ func IncreaseUserQuota(id int, quota int) (err error) {
return increaseUserQuota(id, quota)
}
func increaseUserQuota(id int, quota int) (err error) {
func increaseUserQuota(id int, quota int64) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
return err
}
func DecreaseUserQuota(id int, quota int) (err error) {
func DecreaseUserQuota(id int, quota int64) (err error) {
if quota < 0 {
return errors.New("quota 不能为负数!")
}
@ -326,7 +327,7 @@ func DecreaseUserQuota(id int, quota int) (err error) {
return decreaseUserQuota(id, quota)
}
func decreaseUserQuota(id int, quota int) (err error) {
func decreaseUserQuota(id int, quota int64) (err error) {
err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
return err
}
@ -336,7 +337,7 @@ func GetRootUserEmail() (email string) {
return email
}
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
if config.BatchUpdateEnabled {
addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
addNewRecord(BatchUpdateTypeRequestCount, id, 1)
@ -345,7 +346,7 @@ func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
updateUserUsedQuotaAndRequestCount(id, quota, 1)
}
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),
@ -357,7 +358,7 @@ func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
}
}
func updateUserUsedQuota(id int, quota int) {
func updateUserUsedQuota(id int, quota int64) {
err := DB.Model(&User{}).Where("id = ?", id).Updates(
map[string]interface{}{
"used_quota": gorm.Expr("used_quota + ?", quota),

View File

@ -16,12 +16,12 @@ const (
BatchUpdateTypeCount // if you add a new type, you need to add a new map and a new lock
)
var batchUpdateStores []map[int]int
var batchUpdateStores []map[int]int64
var batchUpdateLocks []sync.Mutex
func init() {
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateStores = append(batchUpdateStores, make(map[int]int))
batchUpdateStores = append(batchUpdateStores, make(map[int]int64))
batchUpdateLocks = append(batchUpdateLocks, sync.Mutex{})
}
}
@ -35,7 +35,7 @@ func InitBatchUpdater() {
}()
}
func addNewRecord(type_ int, id int, value int) {
func addNewRecord(type_ int, id int, value int64) {
batchUpdateLocks[type_].Lock()
defer batchUpdateLocks[type_].Unlock()
if _, ok := batchUpdateStores[type_][id]; !ok {
@ -50,7 +50,7 @@ func batchUpdate() {
for i := 0; i < BatchUpdateTypeCount; i++ {
batchUpdateLocks[i].Lock()
store := batchUpdateStores[i]
batchUpdateStores[i] = make(map[int]int)
batchUpdateStores[i] = make(map[int]int64)
batchUpdateLocks[i].Unlock()
// TODO: maybe we can combine updates with same key?
for key, value := range store {
@ -68,7 +68,7 @@ func batchUpdate() {
case BatchUpdateTypeUsedQuota:
updateUserUsedQuota(key, value)
case BatchUpdateTypeRequestCount:
updateUserRequestCount(key, value)
updateUserRequestCount(key, int(value))
case BatchUpdateTypeChannelUsedQuota:
updateChannelUsedQuota(key, value)
}

View File

@ -2,7 +2,7 @@ package anthropic
var ModelList = []string{
"claude-instant-1.2", "claude-2.0", "claude-2.1",
"claude-3-haiku-20240229",
"claude-3-haiku-20240307",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
}

View File

@ -7,4 +7,7 @@ var ModelList = []string{
"ERNIE-Speed",
"ERNIE-Bot-turbo",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}

View File

@ -0,0 +1,9 @@
package lingyiwanwu
// https://platform.lingyiwanwu.com/docs
var ModelList = []string{
"yi-34b-chat-0205",
"yi-34b-chat-200k",
"yi-vl-plus",
}

View File

@ -0,0 +1,65 @@
package ollama
import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/channel"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/util"
"io"
"net/http"
)
type Adaptor struct {
}
func (a *Adaptor) Init(meta *util.RelayMeta) {
}
func (a *Adaptor) GetRequestURL(meta *util.RelayMeta) (string, error) {
// https://github.com/ollama/ollama/blob/main/docs/api.md
fullRequestURL := fmt.Sprintf("%s/api/chat", meta.BaseURL)
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *util.RelayMeta) error {
channel.SetupCommonRequestHeader(c, req, meta)
req.Header.Set("Authorization", "Bearer "+meta.APIKey)
return nil
}
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch relayMode {
case constant.RelayModeEmbeddings:
return nil, errors.New("not supported")
default:
return ConvertRequest(*request), nil
}
}
func (a *Adaptor) DoRequest(c *gin.Context, meta *util.RelayMeta, requestBody io.Reader) (*http.Response, error) {
return channel.DoRequestHelper(a, c, meta, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.RelayMeta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
if meta.IsStream {
err, usage = StreamHandler(c, resp)
} else {
err, usage = Handler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return "ollama"
}

View File

@ -0,0 +1,5 @@
package ollama
var ModelList = []string{
"qwen:0.5b-chat",
}

View File

@ -0,0 +1,178 @@
package ollama
import (
"bufio"
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/common/logger"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/constant"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
"strings"
)
func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
ollamaRequest := ChatRequest{
Model: request.Model,
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
Stream: request.Stream,
}
for _, message := range request.Messages {
ollamaRequest.Messages = append(ollamaRequest.Messages, Message{
Role: message.Role,
Content: message.StringContent(),
})
}
return &ollamaRequest
}
func responseOllama2OpenAI(response *ChatResponse) *openai.TextResponse {
choice := openai.TextResponseChoice{
Index: 0,
Message: model.Message{
Role: response.Message.Role,
Content: response.Message.Content,
},
}
if response.Done {
choice.FinishReason = "stop"
}
fullTextResponse := openai.TextResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion",
Created: helper.GetTimestamp(),
Choices: []openai.TextResponseChoice{choice},
Usage: model.Usage{
PromptTokens: response.PromptEvalCount,
CompletionTokens: response.EvalCount,
TotalTokens: response.PromptEvalCount + response.EvalCount,
},
}
return &fullTextResponse
}
func streamResponseOllama2OpenAI(ollamaResponse *ChatResponse) *openai.ChatCompletionsStreamResponse {
var choice openai.ChatCompletionsStreamResponseChoice
choice.Delta.Role = ollamaResponse.Message.Role
choice.Delta.Content = ollamaResponse.Message.Content
if ollamaResponse.Done {
choice.FinishReason = &constant.StopFinishReason
}
response := openai.ChatCompletionsStreamResponse{
Id: fmt.Sprintf("chatcmpl-%s", helper.GetUUID()),
Object: "chat.completion.chunk",
Created: helper.GetTimestamp(),
Model: ollamaResponse.Model,
Choices: []openai.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
var usage model.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "}\n"); i >= 0 {
return i + 2, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := strings.TrimPrefix(scanner.Text(), "}")
dataChan <- data + "}"
}
stopChan <- true
}()
common.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var ollamaResponse ChatResponse
err := json.Unmarshal([]byte(data), &ollamaResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if ollamaResponse.EvalCount != 0 {
usage.PromptTokens = ollamaResponse.PromptEvalCount
usage.CompletionTokens = ollamaResponse.EvalCount
usage.TotalTokens = ollamaResponse.PromptEvalCount + ollamaResponse.EvalCount
}
response := streamResponseOllama2OpenAI(&ollamaResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
err := resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
return nil, &usage
}
func Handler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) {
ctx := context.TODO()
var ollamaResponse ChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
}
logger.Debugf(ctx, "ollama response: %s", string(responseBody))
err = resp.Body.Close()
if err != nil {
return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
}
err = json.Unmarshal(responseBody, &ollamaResponse)
if err != nil {
return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
}
if ollamaResponse.Error != "" {
return &model.ErrorWithStatusCode{
Error: model.Error{
Message: ollamaResponse.Error,
Type: "ollama_error",
Param: "",
Code: "ollama_error",
},
StatusCode: resp.StatusCode,
}, nil
}
fullTextResponse := responseOllama2OpenAI(&ollamaResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return openai.ErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@ -0,0 +1,37 @@
package ollama
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
}
type Message struct {
Role string `json:"role,omitempty"`
Content string `json:"content,omitempty"`
Images []string `json:"images,omitempty"`
}
type ChatRequest struct {
Model string `json:"model,omitempty"`
Messages []Message `json:"messages,omitempty"`
Stream bool `json:"stream"`
Options *Options `json:"options,omitempty"`
}
type ChatResponse struct {
Model string `json:"model,omitempty"`
CreatedAt string `json:"created_at,omitempty"`
Message Message `json:"message,omitempty"`
Response string `json:"response,omitempty"` // for stream response
Done bool `json:"done,omitempty"`
TotalDuration int `json:"total_duration,omitempty"`
LoadDuration int `json:"load_duration,omitempty"`
PromptEvalCount int `json:"prompt_eval_count,omitempty"`
EvalCount int `json:"eval_count,omitempty"`
EvalDuration int `json:"eval_duration,omitempty"`
Error string `json:"error,omitempty"`
}

View File

@ -5,6 +5,7 @@ import (
"github.com/songquanpeng/one-api/relay/channel/ai360"
"github.com/songquanpeng/one-api/relay/channel/baichuan"
"github.com/songquanpeng/one-api/relay/channel/groq"
"github.com/songquanpeng/one-api/relay/channel/lingyiwanwu"
"github.com/songquanpeng/one-api/relay/channel/minimax"
"github.com/songquanpeng/one-api/relay/channel/mistral"
"github.com/songquanpeng/one-api/relay/channel/moonshot"
@ -18,6 +19,7 @@ var CompatibleChannels = []int{
common.ChannelTypeMinimax,
common.ChannelTypeMistral,
common.ChannelTypeGroq,
common.ChannelTypeLingYiWanWu,
}
func GetCompatibleChannelMeta(channelType int) (string, []string) {
@ -36,6 +38,8 @@ func GetCompatibleChannelMeta(channelType int) (string, []string) {
return "mistralai", mistral.ModelList
case common.ChannelTypeGroq:
return "groq", groq.ModelList
case common.ChannelTypeLingYiWanWu:
return "lingyiwanwu", lingyiwanwu.ModelList
default:
return "openai", ModelList
}

View File

@ -15,6 +15,7 @@ const (
APITypeAIProxyLibrary
APITypeTencent
APITypeGemini
APITypeOllama
APITypeDummy // this one is only for count, do not add any channel after this
)
@ -40,6 +41,8 @@ func ChannelType2APIType(channelType int) int {
apiType = APITypeTencent
case common.ChannelTypeGemini:
apiType = APITypeGemini
case common.ChannelTypeOllama:
apiType = APITypeOllama
}
return apiType
}

View File

@ -23,6 +23,7 @@ import (
)
func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
ctx := c.Request.Context()
audioModel := "whisper-1"
tokenId := c.GetInt("token_id")
@ -51,16 +52,16 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
// groupRatio := common.GetGroupRatio(group)
groupRatio := c.GetFloat64("channel_ratio")
ratio := modelRatio * groupRatio
var quota int
var preConsumedQuota int
var quota int64
var preConsumedQuota int64
switch relayMode {
case constant.RelayModeAudioSpeech:
preConsumedQuota = int(float64(len(ttsRequest.Input)) * ratio)
preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio)
quota = preConsumedQuota
default:
preConsumedQuota = int(float64(config.PreConsumedQuota) * ratio)
preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio)
}
userQuota, err := model.CacheGetUserQuota(userId)
userQuota, err := model.CacheGetUserQuota(ctx, userId)
if err != nil {
return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
@ -185,7 +186,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if err != nil {
return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError)
}
quota = openai.CountTokenText(text, audioModel)
quota = int64(openai.CountTokenText(text, audioModel))
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
}
if resp.StatusCode != http.StatusOK {

View File

@ -107,18 +107,18 @@ func getPromptTokens(textRequest *relaymodel.GeneralOpenAIRequest, relayMode int
return 0
}
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int {
func getPreConsumedQuota(textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64) int64 {
preConsumedTokens := config.PreConsumedQuota
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
preConsumedTokens = int64(promptTokens) + int64(textRequest.MaxTokens)
}
return int(float64(preConsumedTokens) * ratio)
return int64(float64(preConsumedTokens) * ratio)
}
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int, *relaymodel.ErrorWithStatusCode) {
func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIRequest, promptTokens int, ratio float64, meta *util.RelayMeta) (int64, *relaymodel.ErrorWithStatusCode) {
preConsumedQuota := getPreConsumedQuota(textRequest, promptTokens, ratio)
userQuota, err := model.CacheGetUserQuota(meta.UserId)
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
if err != nil {
return preConsumedQuota, openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
@ -144,16 +144,16 @@ func preConsumeQuota(ctx context.Context, textRequest *relaymodel.GeneralOpenAIR
return preConsumedQuota, nil
}
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int, modelRatio float64, groupRatio float64) {
func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.RelayMeta, textRequest *relaymodel.GeneralOpenAIRequest, ratio float64, preConsumedQuota int64, modelRatio float64, groupRatio float64) {
if usage == nil {
logger.Error(ctx, "usage is nil, which is unexpected")
return
}
quota := 0
var quota int64
completionRatio := common.GetCompletionRatio(textRequest.Model)
promptTokens := usage.PromptTokens
completionTokens := usage.CompletionTokens
quota = int(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
quota = int64(math.Ceil((float64(promptTokens) + float64(completionTokens)*completionRatio) * ratio))
if ratio != 0 && quota <= 0 {
quota = 1
}
@ -168,7 +168,7 @@ func postConsumeQuota(ctx context.Context, usage *relaymodel.Usage, meta *util.R
if err != nil {
logger.Error(ctx, "error consuming token remain quota: "+err.Error())
}
err = model.CacheUpdateUserQuota(meta.UserId)
err = model.CacheUpdateUserQuota(ctx, meta.UserId)
if err != nil {
logger.Error(ctx, "error update user quota cache: "+err.Error())
}

View File

@ -81,9 +81,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
// groupRatio := common.GetGroupRatio(meta.Group)
groupRatio := c.GetFloat64("channel_ratio") // pre-selected cheapest channel ratio
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(meta.UserId)
userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId)
quota := int(ratio*imageCostRatio*1000) * imageRequest.N
quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N)
if userQuota-quota < 0 {
return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden)
@ -127,7 +127,7 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(meta.UserId)
err = model.CacheUpdateUserQuota(ctx, meta.UserId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}

View File

@ -77,6 +77,7 @@ func RelayTextHelper(c *gin.Context) *model.ErrorWithStatusCode {
if err != nil {
return openai.ErrorWrapper(err, "json_marshal_failed", http.StatusInternalServerError)
}
logger.Debugf(ctx, "converted request: \n%s", string(jsonData))
requestBody = bytes.NewBuffer(jsonData)
}

View File

@ -5,6 +5,7 @@ import (
"github.com/songquanpeng/one-api/relay/channel/aiproxy"
"github.com/songquanpeng/one-api/relay/channel/anthropic"
"github.com/songquanpeng/one-api/relay/channel/gemini"
"github.com/songquanpeng/one-api/relay/channel/ollama"
"github.com/songquanpeng/one-api/relay/channel/openai"
"github.com/songquanpeng/one-api/relay/channel/palm"
"github.com/songquanpeng/one-api/relay/constant"
@ -26,12 +27,14 @@ func GetAdaptor(apiType int) channel.Adaptor {
return &openai.Adaptor{}
case constant.APITypePaLM:
return &palm.Adaptor{}
// case constant.APITypeTencent:
// return &tencent.Adaptor{}
// case constant.APITypeXunfei:
// return &xunfei.Adaptor{}
// case constant.APITypeZhipu:
// return &zhipu.Adaptor{}
// case constant.APITypeTencent:
// return &tencent.Adaptor{}
// case constant.APITypeXunfei:
// return &xunfei.Adaptor{}
// case constant.APITypeZhipu:
// return &zhipu.Adaptor{}
case constant.APITypeOllama:
return &ollama.Adaptor{}
}
return nil
}

View File

@ -6,7 +6,7 @@ import (
"github.com/songquanpeng/one-api/model"
)
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int, tokenId int) {
func ReturnPreConsumedQuota(ctx context.Context, preConsumedQuota int64, tokenId int) {
if preConsumedQuota != 0 {
go func(ctx context.Context) {
// return pre-consumed quota

View File

@ -35,10 +35,17 @@ func ShouldDisableChannel(err *relaymodel.Error, statusCode int) bool {
return true
case "permission_error":
return true
case "forbidden":
return true
}
if err.Code == "invalid_api_key" || err.Code == "account_deactivated" {
return true
}
if strings.HasPrefix(err.Message, "Your credit balance is too low") { // anthropic
return true
} else if strings.HasPrefix(err.Message, "This organization has been disabled.") {
return true
}
return false
}
@ -148,20 +155,20 @@ func GetFullRequestURL(baseURL string, requestURL string, channelType int) strin
return fullRequestURL
}
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int, totalQuota int, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
func PostConsumeQuota(ctx context.Context, tokenId int, quotaDelta int64, totalQuota int64, userId int, channelId int, modelRatio float64, groupRatio float64, modelName string, tokenName string) {
// quotaDelta is remaining quota to be consumed
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
err = model.CacheUpdateUserQuota(ctx, userId)
if err != nil {
logger.SysError("error update user quota cache: " + err.Error())
}
// totalQuota is total quota consumed
if totalQuota >= 0 {
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(ctx, userId, channelId, totalQuota, 0, modelName, tokenName, totalQuota, logContent)
model.RecordConsumeLog(ctx, userId, channelId, int(totalQuota), 0, modelName, tokenName, totalQuota, logContent)
model.UpdateUserUsedQuotaAndRequestCount(userId, totalQuota)
model.UpdateChannelUsedQuota(channelId, totalQuota)
}

View File

@ -9,6 +9,7 @@ import (
func SetDashboardRouter(router *gin.Engine) {
apiRouter := router.Group("/")
apiRouter.Use(middleware.CORS())
apiRouter.Use(gzip.Gzip(gzip.DefaultCompression))
apiRouter.Use(middleware.GlobalAPIRateLimit())
apiRouter.Use(middleware.TokenAuth())

View File

@ -95,6 +95,18 @@ export const CHANNEL_OPTIONS = {
value: 29,
color: 'default'
},
30: {
key: 30,
text: 'Ollama',
value: 30,
color: 'default'
},
31: {
key: 31,
text: '零一万物',
value: 31,
color: 'default'
},
8: {
key: 8,
text: '自定义渠道',

View File

@ -166,6 +166,12 @@ const typeConfig = {
29: {
modelGroup: "groq",
},
30: {
modelGroup: "ollama",
},
31: {
modelGroup: "lingyiwanwu",
},
};
export { defaultConfig, typeConfig };

View File

@ -265,7 +265,7 @@ const OtherSetting = () => {
multiline
maxRows={15}
id="Footer"
label="公告"
label="页脚"
value={inputs.Footer}
name="Footer"
onChange={handleInputChange}

View File

@ -31,7 +31,7 @@ const COPY_OPTIONS = [
url: 'https://chat.oneapi.pro/#/?settings={"key":"sk-{key}","url":"{serverAddress}"}',
encode: false
},
{ key: 'ama', text: 'AMA 问天', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
{ key: 'ama', text: 'BotGem', url: 'ama://set-api-key?server={serverAddress}&key=sk-{key}', encode: true },
{ key: 'opencat', text: 'OpenCat', url: 'opencat://team/join?domain={serverAddress}&token=sk-{key}', encode: true }
];

View File

@ -12,7 +12,7 @@ const COPY_OPTIONS = [
];
const OPEN_LINK_OPTIONS = [
{ key: 'ama', text: 'AMA 问天', value: 'ama' },
{ key: 'ama', text: 'BotGem', value: 'ama' },
{ key: 'opencat', text: 'OpenCat', value: 'opencat' },
];

View File

@ -15,6 +15,8 @@ export const CHANNEL_OPTIONS = [
{ key: 26, text: '百川大模型', value: 26, color: 'orange' },
{ key: 27, text: 'MiniMax', value: 27, color: 'red' },
{ key: 29, text: 'Groq', value: 29, color: 'orange' },
{ key: 30, text: 'Ollama', value: 30, color: 'black' },
{ key: 31, text: '零一万物', value: 31, color: 'green' },
{ key: 8, text: '自定义渠道', value: 8, color: 'pink' },
{ key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
{ key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },