mirror of
				https://github.com/songquanpeng/one-api.git
				synced 2025-11-04 15:53:42 +08:00 
			
		
		
		
	* 修复git、微信等用户注册不会创建默认令牌问题 修复git、微信等用户注册不会创建默认令牌问题 * 修复git、微信等用户注册不会创建默认令牌问题 删除普通用户注册代码 * fix: do not block if error happened --------- Co-authored-by: JustSong <songquanpeng@foxmail.com>
		
			
				
	
	
		
			438 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			438 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package model
 | 
						||
 | 
						||
import (
 | 
						||
	"errors"
 | 
						||
	"fmt"
 | 
						||
	"github.com/songquanpeng/one-api/common"
 | 
						||
	"github.com/songquanpeng/one-api/common/blacklist"
 | 
						||
	"github.com/songquanpeng/one-api/common/config"
 | 
						||
	"github.com/songquanpeng/one-api/common/helper"
 | 
						||
	"github.com/songquanpeng/one-api/common/logger"
 | 
						||
	"github.com/songquanpeng/one-api/common/random"
 | 
						||
	"gorm.io/gorm"
 | 
						||
	"strings"
 | 
						||
)
 | 
						||
 | 
						||
// User roles stored in User.Role, ordered by increasing privilege.
// IsAdmin treats any role at or above RoleAdminUser as an administrator.
const (
	RoleGuestUser  = 0
	RoleCommonUser = 1
	RoleAdminUser  = 10
	RoleRootUser   = 100
)
 | 
						||
 | 
						||
// Account states stored in User.Status. Zero is deliberately unused because
// it is the Go zero value of the field.
const (
	UserStatusEnabled  = 1 // don't use 0, 0 is the default value!
	UserStatusDisabled = 2 // also don't use 0
	UserStatusDeleted  = 3
)
 | 
						||
 | 
						||
// User is the persisted account record.
//
// If you add sensitive fields, don't forget to clean them in the setupLogin
// function. Otherwise, the sensitive information will be saved on local
// storage in plain text!
type User struct {
	Id          int    `json:"id"`
	Username    string `json:"username" gorm:"unique;index" validate:"max=12"`
	Password    string `json:"password" gorm:"not null;" validate:"min=8,max=20"`
	DisplayName string `json:"display_name" gorm:"index" validate:"max=20"`
	// Role holds one of the Role* constants; the column defaults to 1.
	Role int `json:"role" gorm:"type:int;default:1"`
	// Status holds one of the UserStatus* constants; the column defaults to 1.
	Status   int    `json:"status" gorm:"type:int;default:1"`
	Email    string `json:"email" gorm:"index" validate:"max=50"`
	GitHubId string `json:"github_id" gorm:"column:github_id;index"`
	WeChatId string `json:"wechat_id" gorm:"column:wechat_id;index"`
	LarkId   string `json:"lark_id" gorm:"column:lark_id;index"`
	// VerificationCode is only for Email verification, don't save it to database!
	VerificationCode string `json:"verification_code" gorm:"-:all"`
	// AccessToken is the token used for system management.
	AccessToken string `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"`
	// NOTE(review): unlike the sibling fields, the gorm tags of Quota and
	// UsedQuota have no `type:` key in front of "bigint" — confirm this is intended.
	Quota int64 `json:"quota" gorm:"bigint;default:0"`
	// UsedQuota is the quota consumed so far.
	UsedQuota int64 `json:"used_quota" gorm:"bigint;default:0;column:used_quota"`
	// RequestCount is the number of requests made.
	RequestCount int    `json:"request_count" gorm:"type:int;default:0;"`
	Group        string `json:"group" gorm:"type:varchar(32);default:'default'"`
	AffCode      string `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
	InviterId    int    `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
}
 | 
						||
 | 
						||
func GetMaxUserId() int {
 | 
						||
	var user User
 | 
						||
	DB.Last(&user)
 | 
						||
	return user.Id
 | 
						||
}
 | 
						||
 | 
						||
func GetAllUsers(startIdx int, num int, order string) (users []*User, err error) {
 | 
						||
	query := DB.Limit(num).Offset(startIdx).Omit("password").Where("status != ?", UserStatusDeleted)
 | 
						||
 | 
						||
	switch order {
 | 
						||
	case "quota":
 | 
						||
		query = query.Order("quota desc")
 | 
						||
	case "used_quota":
 | 
						||
		query = query.Order("used_quota desc")
 | 
						||
	case "request_count":
 | 
						||
		query = query.Order("request_count desc")
 | 
						||
	default:
 | 
						||
		query = query.Order("id desc")
 | 
						||
	}
 | 
						||
 | 
						||
	err = query.Find(&users).Error
 | 
						||
	return users, err
 | 
						||
}
 | 
						||
 | 
						||
func SearchUsers(keyword string) (users []*User, err error) {
 | 
						||
	if !common.UsingPostgreSQL {
 | 
						||
		err = DB.Omit("password").Where("id = ? or username LIKE ? or email LIKE ? or display_name LIKE ?", keyword, keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
 | 
						||
	} else {
 | 
						||
		err = DB.Omit("password").Where("username LIKE ? or email LIKE ? or display_name LIKE ?", keyword+"%", keyword+"%", keyword+"%").Find(&users).Error
 | 
						||
	}
 | 
						||
	return users, err
 | 
						||
}
 | 
						||
 | 
						||
func GetUserById(id int, selectAll bool) (*User, error) {
 | 
						||
	if id == 0 {
 | 
						||
		return nil, errors.New("id 为空!")
 | 
						||
	}
 | 
						||
	user := User{Id: id}
 | 
						||
	var err error = nil
 | 
						||
	if selectAll {
 | 
						||
		err = DB.First(&user, "id = ?", id).Error
 | 
						||
	} else {
 | 
						||
		err = DB.Omit("password").First(&user, "id = ?", id).Error
 | 
						||
	}
 | 
						||
	return &user, err
 | 
						||
}
 | 
						||
 | 
						||
func GetUserIdByAffCode(affCode string) (int, error) {
 | 
						||
	if affCode == "" {
 | 
						||
		return 0, errors.New("affCode 为空!")
 | 
						||
	}
 | 
						||
	var user User
 | 
						||
	err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
 | 
						||
	return user.Id, err
 | 
						||
}
 | 
						||
 | 
						||
func DeleteUserById(id int) (err error) {
 | 
						||
	if id == 0 {
 | 
						||
		return errors.New("id 为空!")
 | 
						||
	}
 | 
						||
	user := User{Id: id}
 | 
						||
	return user.Delete()
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) Insert(inviterId int) error {
 | 
						||
	var err error
 | 
						||
	if user.Password != "" {
 | 
						||
		user.Password, err = common.Password2Hash(user.Password)
 | 
						||
		if err != nil {
 | 
						||
			return err
 | 
						||
		}
 | 
						||
	}
 | 
						||
	user.Quota = config.QuotaForNewUser
 | 
						||
	user.AccessToken = random.GetUUID()
 | 
						||
	user.AffCode = random.GetRandomString(4)
 | 
						||
	result := DB.Create(user)
 | 
						||
	if result.Error != nil {
 | 
						||
		return result.Error
 | 
						||
	}
 | 
						||
	if config.QuotaForNewUser > 0 {
 | 
						||
		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(config.QuotaForNewUser)))
 | 
						||
	}
 | 
						||
	if inviterId != 0 {
 | 
						||
		if config.QuotaForInvitee > 0 {
 | 
						||
			_ = IncreaseUserQuota(user.Id, config.QuotaForInvitee)
 | 
						||
			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(config.QuotaForInvitee)))
 | 
						||
		}
 | 
						||
		if config.QuotaForInviter > 0 {
 | 
						||
			_ = IncreaseUserQuota(inviterId, config.QuotaForInviter)
 | 
						||
			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(config.QuotaForInviter)))
 | 
						||
		}
 | 
						||
	}
 | 
						||
	// create default token
 | 
						||
	cleanToken := Token{
 | 
						||
		UserId:         user.Id,
 | 
						||
		Name:           "default",
 | 
						||
		Key:            random.GenerateKey(),
 | 
						||
		CreatedTime:    helper.GetTimestamp(),
 | 
						||
		AccessedTime:   helper.GetTimestamp(),
 | 
						||
		ExpiredTime:    -1,
 | 
						||
		RemainQuota:    -1,
 | 
						||
		UnlimitedQuota: true,
 | 
						||
	}
 | 
						||
	result.Error = cleanToken.Insert()
 | 
						||
	if result.Error != nil {
 | 
						||
		// do not block
 | 
						||
		logger.SysError(fmt.Sprintf("create default token for user %d failed: %s", user.Id, result.Error.Error()))
 | 
						||
	}
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) Update(updatePassword bool) error {
 | 
						||
	var err error
 | 
						||
	if updatePassword {
 | 
						||
		user.Password, err = common.Password2Hash(user.Password)
 | 
						||
		if err != nil {
 | 
						||
			return err
 | 
						||
		}
 | 
						||
	}
 | 
						||
	if user.Status == UserStatusDisabled {
 | 
						||
		blacklist.BanUser(user.Id)
 | 
						||
	} else if user.Status == UserStatusEnabled {
 | 
						||
		blacklist.UnbanUser(user.Id)
 | 
						||
	}
 | 
						||
	err = DB.Model(user).Updates(user).Error
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) Delete() error {
 | 
						||
	if user.Id == 0 {
 | 
						||
		return errors.New("id 为空!")
 | 
						||
	}
 | 
						||
	blacklist.BanUser(user.Id)
 | 
						||
	user.Username = fmt.Sprintf("deleted_%s", random.GetUUID())
 | 
						||
	user.Status = UserStatusDeleted
 | 
						||
	err := DB.Model(user).Updates(user).Error
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
// ValidateAndFill check password & user status
 | 
						||
func (user *User) ValidateAndFill() (err error) {
 | 
						||
	// When querying with struct, GORM will only query with non-zero fields,
 | 
						||
	// that means if your field’s value is 0, '', false or other zero values,
 | 
						||
	// it won’t be used to build query conditions
 | 
						||
	password := user.Password
 | 
						||
	if user.Username == "" || password == "" {
 | 
						||
		return errors.New("用户名或密码为空")
 | 
						||
	}
 | 
						||
	err = DB.Where("username = ?", user.Username).First(user).Error
 | 
						||
	if err != nil {
 | 
						||
		// we must make sure check username firstly
 | 
						||
		// consider this case: a malicious user set his username as other's email
 | 
						||
		err := DB.Where("email = ?", user.Username).First(user).Error
 | 
						||
		if err != nil {
 | 
						||
			return errors.New("用户名或密码错误,或用户已被封禁")
 | 
						||
		}
 | 
						||
	}
 | 
						||
	okay := common.ValidatePasswordAndHash(password, user.Password)
 | 
						||
	if !okay || user.Status != UserStatusEnabled {
 | 
						||
		return errors.New("用户名或密码错误,或用户已被封禁")
 | 
						||
	}
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) FillUserById() error {
 | 
						||
	if user.Id == 0 {
 | 
						||
		return errors.New("id 为空!")
 | 
						||
	}
 | 
						||
	DB.Where(User{Id: user.Id}).First(user)
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) FillUserByEmail() error {
 | 
						||
	if user.Email == "" {
 | 
						||
		return errors.New("email 为空!")
 | 
						||
	}
 | 
						||
	DB.Where(User{Email: user.Email}).First(user)
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) FillUserByGitHubId() error {
 | 
						||
	if user.GitHubId == "" {
 | 
						||
		return errors.New("GitHub id 为空!")
 | 
						||
	}
 | 
						||
	DB.Where(User{GitHubId: user.GitHubId}).First(user)
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) FillUserByLarkId() error {
 | 
						||
	if user.LarkId == "" {
 | 
						||
		return errors.New("lark id 为空!")
 | 
						||
	}
 | 
						||
	DB.Where(User{LarkId: user.LarkId}).First(user)
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) FillUserByWeChatId() error {
 | 
						||
	if user.WeChatId == "" {
 | 
						||
		return errors.New("WeChat id 为空!")
 | 
						||
	}
 | 
						||
	DB.Where(User{WeChatId: user.WeChatId}).First(user)
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (user *User) FillUserByUsername() error {
 | 
						||
	if user.Username == "" {
 | 
						||
		return errors.New("username 为空!")
 | 
						||
	}
 | 
						||
	DB.Where(User{Username: user.Username}).First(user)
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func IsEmailAlreadyTaken(email string) bool {
 | 
						||
	return DB.Where("email = ?", email).Find(&User{}).RowsAffected == 1
 | 
						||
}
 | 
						||
 | 
						||
func IsWeChatIdAlreadyTaken(wechatId string) bool {
 | 
						||
	return DB.Where("wechat_id = ?", wechatId).Find(&User{}).RowsAffected == 1
 | 
						||
}
 | 
						||
 | 
						||
func IsGitHubIdAlreadyTaken(githubId string) bool {
 | 
						||
	return DB.Where("github_id = ?", githubId).Find(&User{}).RowsAffected == 1
 | 
						||
}
 | 
						||
 | 
						||
func IsLarkIdAlreadyTaken(githubId string) bool {
 | 
						||
	return DB.Where("lark_id = ?", githubId).Find(&User{}).RowsAffected == 1
 | 
						||
}
 | 
						||
 | 
						||
func IsUsernameAlreadyTaken(username string) bool {
 | 
						||
	return DB.Where("username = ?", username).Find(&User{}).RowsAffected == 1
 | 
						||
}
 | 
						||
 | 
						||
func ResetUserPasswordByEmail(email string, password string) error {
 | 
						||
	if email == "" || password == "" {
 | 
						||
		return errors.New("邮箱地址或密码为空!")
 | 
						||
	}
 | 
						||
	hashedPassword, err := common.Password2Hash(password)
 | 
						||
	if err != nil {
 | 
						||
		return err
 | 
						||
	}
 | 
						||
	err = DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
func IsAdmin(userId int) bool {
 | 
						||
	if userId == 0 {
 | 
						||
		return false
 | 
						||
	}
 | 
						||
	var user User
 | 
						||
	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
 | 
						||
	if err != nil {
 | 
						||
		logger.SysError("no such user " + err.Error())
 | 
						||
		return false
 | 
						||
	}
 | 
						||
	return user.Role >= RoleAdminUser
 | 
						||
}
 | 
						||
 | 
						||
func IsUserEnabled(userId int) (bool, error) {
 | 
						||
	if userId == 0 {
 | 
						||
		return false, errors.New("user id is empty")
 | 
						||
	}
 | 
						||
	var user User
 | 
						||
	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
 | 
						||
	if err != nil {
 | 
						||
		return false, err
 | 
						||
	}
 | 
						||
	return user.Status == UserStatusEnabled, nil
 | 
						||
}
 | 
						||
 | 
						||
func ValidateAccessToken(token string) (user *User) {
 | 
						||
	if token == "" {
 | 
						||
		return nil
 | 
						||
	}
 | 
						||
	token = strings.Replace(token, "Bearer ", "", 1)
 | 
						||
	user = &User{}
 | 
						||
	if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
 | 
						||
		return user
 | 
						||
	}
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func GetUserQuota(id int) (quota int64, err error) {
 | 
						||
	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find("a).Error
 | 
						||
	return quota, err
 | 
						||
}
 | 
						||
 | 
						||
func GetUserUsedQuota(id int) (quota int64, err error) {
 | 
						||
	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find("a).Error
 | 
						||
	return quota, err
 | 
						||
}
 | 
						||
 | 
						||
func GetUserEmail(id int) (email string, err error) {
 | 
						||
	err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
 | 
						||
	return email, err
 | 
						||
}
 | 
						||
 | 
						||
func GetUserGroup(id int) (group string, err error) {
 | 
						||
	groupCol := "`group`"
 | 
						||
	if common.UsingPostgreSQL {
 | 
						||
		groupCol = `"group"`
 | 
						||
	}
 | 
						||
 | 
						||
	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
 | 
						||
	return group, err
 | 
						||
}
 | 
						||
 | 
						||
func IncreaseUserQuota(id int, quota int64) (err error) {
 | 
						||
	if quota < 0 {
 | 
						||
		return errors.New("quota 不能为负数!")
 | 
						||
	}
 | 
						||
	if config.BatchUpdateEnabled {
 | 
						||
		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
 | 
						||
		return nil
 | 
						||
	}
 | 
						||
	return increaseUserQuota(id, quota)
 | 
						||
}
 | 
						||
 | 
						||
func increaseUserQuota(id int, quota int64) (err error) {
 | 
						||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
func DecreaseUserQuota(id int, quota int64) (err error) {
 | 
						||
	if quota < 0 {
 | 
						||
		return errors.New("quota 不能为负数!")
 | 
						||
	}
 | 
						||
	if config.BatchUpdateEnabled {
 | 
						||
		addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
 | 
						||
		return nil
 | 
						||
	}
 | 
						||
	return decreaseUserQuota(id, quota)
 | 
						||
}
 | 
						||
 | 
						||
func decreaseUserQuota(id int, quota int64) (err error) {
 | 
						||
	err = DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
func GetRootUserEmail() (email string) {
 | 
						||
	DB.Model(&User{}).Where("role = ?", RoleRootUser).Select("email").Find(&email)
 | 
						||
	return email
 | 
						||
}
 | 
						||
 | 
						||
func UpdateUserUsedQuotaAndRequestCount(id int, quota int64) {
 | 
						||
	if config.BatchUpdateEnabled {
 | 
						||
		addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
 | 
						||
		addNewRecord(BatchUpdateTypeRequestCount, id, 1)
 | 
						||
		return
 | 
						||
	}
 | 
						||
	updateUserUsedQuotaAndRequestCount(id, quota, 1)
 | 
						||
}
 | 
						||
 | 
						||
func updateUserUsedQuotaAndRequestCount(id int, quota int64, count int) {
 | 
						||
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
						||
		map[string]interface{}{
 | 
						||
			"used_quota":    gorm.Expr("used_quota + ?", quota),
 | 
						||
			"request_count": gorm.Expr("request_count + ?", count),
 | 
						||
		},
 | 
						||
	).Error
 | 
						||
	if err != nil {
 | 
						||
		logger.SysError("failed to update user used quota and request count: " + err.Error())
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
func updateUserUsedQuota(id int, quota int64) {
 | 
						||
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
 | 
						||
		map[string]interface{}{
 | 
						||
			"used_quota": gorm.Expr("used_quota + ?", quota),
 | 
						||
		},
 | 
						||
	).Error
 | 
						||
	if err != nil {
 | 
						||
		logger.SysError("failed to update user used quota: " + err.Error())
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
func updateUserRequestCount(id int, count int) {
 | 
						||
	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
 | 
						||
	if err != nil {
 | 
						||
		logger.SysError("failed to update user request count: " + err.Error())
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
func GetUsernameById(id int) (username string) {
 | 
						||
	DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username)
 | 
						||
	return username
 | 
						||
}
 |