package model

import (
	"errors"
	"fmt"
	"one-api/common"
	"strconv"
	"strings"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
)

// User is a registered account together with its quota and invitation
// bookkeeping.
//
// If you add sensitive fields, don't forget to clean them in the setupLogin
// function. Otherwise, the sensitive information will be saved on local
// storage in plain text!
type User struct {
	Id               int            `json:"id"`
	Username         string         `json:"username" gorm:"unique;index" validate:"max=12"`
	Password         string         `json:"password" gorm:"not null;" validate:"min=8,max=20"`
	DisplayName      string         `json:"display_name" gorm:"index" validate:"max=20"`
	Role             int            `json:"role" gorm:"type:int;default:1"`   // admin, common
	Status           int            `json:"status" gorm:"type:int;default:1"` // enabled, disabled
	Email            string         `json:"email" gorm:"index" validate:"max=50"`
	GitHubId         string         `json:"github_id" gorm:"column:github_id;index"`
	LinuxDoId        string         `json:"linuxdo_id" gorm:"column:linuxdo_id;index"`
	LinuxDoLevel     int            `json:"linuxdo_level" gorm:"column:linuxdo_level;type:int;default:0"`
	WeChatId         string         `json:"wechat_id" gorm:"column:wechat_id;index"`
	TelegramId       string         `json:"telegram_id" gorm:"column:telegram_id;index"`
	VerificationCode string         `json:"verification_code" gorm:"-:all"`                                    // this field is only for Email verification, don't save it to database!
	AccessToken      *string        `json:"access_token" gorm:"type:char(32);column:access_token;uniqueIndex"` // this token is for system management
	Quota            int            `json:"quota" gorm:"type:int;default:0"`
	UsedQuota        int            `json:"used_quota" gorm:"type:int;default:0;column:used_quota"` // used quota
	RequestCount     int            `json:"request_count" gorm:"type:int;default:0;"`               // request number
	Group            string         `json:"group" gorm:"type:varchar(64);default:'default'"`
	AffCode          string         `json:"aff_code" gorm:"type:varchar(32);column:aff_code;uniqueIndex"`
	AffCount         int            `json:"aff_count" gorm:"type:int;default:0;column:aff_count"`
	AffQuota         int            `json:"aff_quota" gorm:"type:int;default:0;column:aff_quota"`           // remaining invitation quota
	AffHistoryQuota  int            `json:"aff_history_quota" gorm:"type:int;default:0;column:aff_history"` // cumulative invitation quota
	InviterId        int            `json:"inviter_id" gorm:"type:int;column:inviter_id;index"`
	StripeCustomer   string         `json:"stripe_customer" gorm:"column:stripe_customer;index"`
	DeletedAt        gorm.DeletedAt `gorm:"index"`
}

// GetAccessToken returns the user's system-management access token, or ""
// when none has been assigned.
func (user *User) GetAccessToken() string {
	if user.AccessToken == nil {
		return ""
	}
	return *user.AccessToken
}

// SetAccessToken stores token as the user's access token.
func (user *User) SetAccessToken(token string) {
	user.AccessToken = &token
}

// CheckUserExistOrDeleted reports whether a user with the given username (or,
// when email is non-empty, the given email) exists, soft-deleted users
// included. It returns (false, nil) when no such user exists, (true, nil) when
// one exists or was deleted, and (false, err) on any other database error.
func CheckUserExistOrDeleted(username string, email string) (bool, error) {
	var user User
	var err error
	// An empty email is not used as a condition: `email = ''` would match any
	// user whose email is empty.
	if email == "" {
		err = DB.Unscoped().First(&user, "username = ?", username).Error
	} else {
		err = DB.Unscoped().First(&user, "username = ? or email = ?", username, email).Error
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return false, nil
	}
	if err != nil {
		return false, err
	}
	return true, nil
}

// GetMaxUserId returns the id of the last user row (soft-deleted rows
// included). Query errors are ignored and yield 0.
func GetMaxUserId() int {
	var user User
	DB.Unscoped().Last(&user)
	return user.Id
}

// GetAllUsers returns up to num users ordered by id descending, skipping the
// first startIdx. Soft-deleted users are included; the password column is
// omitted.
func GetAllUsers(startIdx int, num int) (users []*User, err error) {
	err = DB.Unscoped().Order("id desc").Limit(num).Offset(startIdx).Omit("password").Find(&users).Error
	return users, err
}

// SearchUsers finds users (soft-deleted included, password omitted) matching
// keyword, optionally restricted to group. A numeric keyword is first tried as
// an exact id; if that finds nothing, keyword is matched as a substring of
// username, email or display_name.
func SearchUsers(keyword string, group string) ([]*User, error) {
	var users []*User

	// "group" is an SQL keyword, so the column name is quoted per dialect.
	groupCol := "`group`"
	if common.UsingPostgreSQL {
		groupCol = `"group"`
	}

	// Try the keyword as a user id first.
	if keywordInt, convErr := strconv.Atoi(keyword); convErr == nil {
		query := DB.Unscoped().Omit("password").Where("id = ?", keywordInt)
		if group != "" {
			query = query.Where(groupCol+" = ?", group)
		}
		if err := query.Find(&users).Error; err != nil || len(users) > 0 {
			return users, err
		}
	}

	// Fall back to a substring match on the text columns.
	pattern := "%" + keyword + "%"
	likeCondition := "username LIKE ? OR email LIKE ? OR display_name LIKE ?"
	query := DB.Unscoped().Omit("password")
	if group != "" {
		query = query.Where("("+likeCondition+") AND "+groupCol+" = ?", pattern, pattern, pattern, group)
	} else {
		query = query.Where(likeCondition, pattern, pattern, pattern)
	}
	err := query.Find(&users).Error
	return users, err
}

// GetUserById loads the non-deleted user with the given id; when selectAll is
// false the password column is left out. For id 0 it returns (nil, error);
// otherwise the returned pointer is non-nil even alongside a lookup error.
func GetUserById(id int, selectAll bool) (*User, error) {
	if id == 0 {
		return nil, errors.New("id 为空!")
	}
	user := User{Id: id}
	var err error
	if selectAll {
		err = DB.First(&user, "id = ?", id).Error
	} else {
		err = DB.Omit("password").First(&user, "id = ?", id).Error
	}
	return &user, err
}

// GetUserByIdUnscoped is GetUserById but also finds soft-deleted users.
func GetUserByIdUnscoped(id int, selectAll bool) (*User, error) {
	if id == 0 {
		return nil, errors.New("id 为空!")
	}
	user := User{Id: id}
	var err error
	if selectAll {
		err = DB.Unscoped().First(&user, "id = ?", id).Error
	} else {
		err = DB.Unscoped().Omit("password").First(&user, "id = ?", id).Error
	}
	return &user, err
}

// GetUserIdByAffCode returns the id of the non-deleted user owning affCode.
func GetUserIdByAffCode(affCode string) (int, error) {
	if affCode == "" {
		return 0, errors.New("affCode 为空!")
	}
	var user User
	err := DB.Select("id").First(&user, "aff_code = ?", affCode).Error
	return user.Id, err
}

// DeleteUserById soft-deletes the user with the given id.
func DeleteUserById(id int) (err error) {
	if id == 0 {
		return errors.New("id 为空!")
	}
	user := User{Id: id}
	return user.Delete()
}

// HardDeleteUserById permanently removes the user row with the given id.
func HardDeleteUserById(id int) error {
	if id == 0 {
		return errors.New("id 为空!")
	}
	return DB.Unscoped().Delete(&User{}, "id = ?", id).Error
}

// inviteUser credits the (non-deleted) inviter with one more invitation and
// common.QuotaForInviter of invitation quota, added to both AffQuota and the
// cumulative AffHistoryQuota.
//
// The counters are incremented with SQL expressions rather than a
// read-modify-write followed by Save: Save rewrites every column of the row,
// so changes made to the inviter between the read and the write (quota
// consumption, another invitation) would be silently overwritten.
func inviteUser(inviterId int) (err error) {
	if inviterId == 0 {
		return errors.New("id 为空!")
	}
	result := DB.Model(&User{}).Where("id = ?", inviterId).Updates(map[string]interface{}{
		"aff_count":   gorm.Expr("aff_count + ?", 1),
		"aff_quota":   gorm.Expr("aff_quota + ?", common.QuotaForInviter),
		"aff_history": gorm.Expr("aff_history + ?", common.QuotaForInviter),
	})
	if result.Error != nil {
		return result.Error
	}
	if result.RowsAffected == 0 {
		// Unknown inviter: keep failing like the previous GetUserById lookup did.
		return gorm.ErrRecordNotFound
	}
	return nil
}

// TransferAffQuotaToQuota moves quota units from the user's invitation quota
// (AffQuota) to the spendable Quota inside a transaction. quota must be at
// least common.QuotaPerUnit and no more than the stored AffQuota.
func (user *User) TransferAffQuotaToQuota(quota int) error {
	// Reject transfers below the minimum amount.
	if float64(quota) < common.QuotaPerUnit {
		return fmt.Errorf("转移额度最小为%s!", common.LogQuota(int(common.QuotaPerUnit)))
	}

	tx := DB.Begin()
	if tx.Error != nil {
		return tx.Error
	}
	// Roll back on every early return; after a successful Commit this is harmless.
	defer tx.Rollback()

	// Re-read the row with SELECT ... FOR UPDATE so concurrent transfers cannot
	// both spend the same AffQuota. GORM v2 ignores the v1-era
	// "gorm:query_option" setting; the lock must be requested via clause.Locking.
	if err := tx.Clauses(clause.Locking{Strength: "UPDATE"}).First(user, user.Id).Error; err != nil {
		return err
	}

	// Check the balance again now that the row is locked.
	if user.AffQuota < quota {
		return errors.New("邀请额度不足!")
	}

	user.AffQuota -= quota
	user.Quota += quota
	if err := tx.Save(user).Error; err != nil {
		return err
	}
	return tx.Commit().Error
}

// Insert creates the user. A non-empty password is hashed first, Quota is set
// to common.QuotaForNewUser and AffCode to common.GetRandomString(4). When
// inviterId is non-zero, the new user and the inviter are rewarded. Only
// hashing and the insert itself can make Insert fail; the reward steps are
// best effort.
func (user *User) Insert(inviterId int) error {
	var err error
	if user.Password != "" {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}
	user.Quota = common.QuotaForNewUser
	user.AffCode = common.GetRandomString(4)
	if result := DB.Create(user); result.Error != nil {
		return result.Error
	}
	if common.QuotaForNewUser > 0 {
		RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("新用户注册赠送 %s", common.LogQuota(common.QuotaForNewUser)))
	}
	if inviterId != 0 {
		if common.QuotaForInvitee > 0 {
			// Best effort: the account row already exists at this point.
			_ = IncreaseUserQuota(user.Id, common.QuotaForInvitee)
			RecordLog(user.Id, LogTypeSystem, fmt.Sprintf("使用邀请码赠送 %s", common.LogQuota(common.QuotaForInvitee)))
		}
		if common.QuotaForInviter > 0 {
			// The inviter's reward goes to AffQuota via inviteUser, not directly to Quota.
			RecordLog(inviterId, LogTypeSystem, fmt.Sprintf("邀请用户赠送 %s", common.LogQuota(common.QuotaForInviter)))
			_ = inviteUser(inviterId)
		}
	}
	return nil
}

// Update writes the non-zero fields of user to the database (GORM struct
// update) and, when Redis is enabled, refreshes the cached group and quota.
// If updatePassword is true, user.Password is hashed before saving.
func (user *User) Update(updatePassword bool) error {
	var err error
	if updatePassword {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}
	newUser := *user
	// Reload the stored row so the fields not being updated reflect the database.
	if err = DB.First(user, user.Id).Error; err != nil {
		return err
	}
	if err = DB.Model(user).Updates(newUser).Error; err != nil {
		return err
	}
	if common.RedisEnabled {
		_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
		_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
	}
	return nil
}

// Edit writes username, display_name, group and quota (plus the hashed
// password when updatePassword is true) from user to the database. A map
// update is used so zero values such as quota 0 are written too. When Redis is
// enabled the cached group and quota are refreshed.
func (user *User) Edit(updatePassword bool) error {
	var err error
	if updatePassword {
		user.Password, err = common.Password2Hash(user.Password)
		if err != nil {
			return err
		}
	}
	newUser := *user
	updates := map[string]interface{}{
		"username":     newUser.Username,
		"display_name": newUser.DisplayName,
		"group":        newUser.Group,
		"quota":        newUser.Quota,
	}
	if updatePassword {
		updates["password"] = newUser.Password
	}
	// Reload the stored row so the fields not being updated reflect the database.
	if err = DB.First(user, user.Id).Error; err != nil {
		return err
	}
	if err = DB.Model(user).Updates(updates).Error; err != nil {
		return err
	}
	if common.RedisEnabled {
		_ = common.RedisSet(fmt.Sprintf("user_group:%d", user.Id), user.Group, time.Duration(UserId2GroupCacheSeconds)*time.Second)
		_ = common.RedisSet(fmt.Sprintf("user_quota:%d", user.Id), strconv.Itoa(user.Quota), time.Duration(UserId2QuotaCacheSeconds)*time.Second)
	}
	return nil
}

// Delete soft-deletes the user.
func (user *User) Delete() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	return DB.Delete(user).Error
}

// HardDelete permanently removes the user row.
func (user *User) HardDelete() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	return DB.Unscoped().Delete(user).Error
}

// ValidateAndFill checks the credentials in user.Username (matched against
// username or email) and user.Password. On success it fills user with the
// stored row. It fails when either credential is empty, no user matches, the
// password is wrong, or the user is not enabled.
func (user *User) ValidateAndFill() (err error) {
	// When querying with a struct, GORM only uses non-zero fields as
	// conditions, so the lookup below spells out its WHERE clause explicitly.
	password := user.Password
	username := strings.TrimSpace(user.Username)
	if username == "" || password == "" {
		return errors.New("用户名或密码为空")
	}
	// Find the user by username or email.
	if err = DB.Where("username = ? OR email = ?", username, username).First(user).Error; err != nil {
		// Same message as a wrong password, so the response does not reveal
		// whether the account exists.
		return errors.New("用户名或密码错误,或用户已被封禁")
	}
	okay := common.ValidatePasswordAndHash(password, user.Password)
	if !okay || user.Status != common.UserStatusEnabled {
		return errors.New("用户名或密码错误,或用户已被封禁")
	}
	return nil
}

// FillUserById loads the non-deleted user with user.Id into user. A missing
// row is not an error (user is left unchanged); other database errors are
// returned.
func (user *User) FillUserById() error {
	if user.Id == 0 {
		return errors.New("id 为空!")
	}
	err := DB.Where(User{Id: user.Id}).First(user).Error
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return err
	}
	return nil
}

// FillUserByEmail loads the non-deleted user with user.Email into user. A
// missing row is not an error; other database errors are returned.
func (user *User) FillUserByEmail() error {
	if user.Email == "" {
		return errors.New("email 为空!")
	}
	err := DB.Where(User{Email: user.Email}).First(user).Error
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return err
	}
	return nil
}

// FillUserByGitHubId loads the non-deleted user with user.GitHubId into user.
// A missing row is not an error; other database errors are returned.
func (user *User) FillUserByGitHubId() error {
	if user.GitHubId == "" {
		return errors.New("GitHub id 为空!")
	}
	err := DB.Where(User{GitHubId: user.GitHubId}).First(user).Error
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return err
	}
	return nil
}

// FillUserByLinuxDoId loads the non-deleted user with user.LinuxDoId into
// user. A missing row is not an error; other database errors are returned.
func (user *User) FillUserByLinuxDoId() error {
	if user.LinuxDoId == "" {
		return errors.New("LINUX DO id 为空!")
	}
	err := DB.Where(User{LinuxDoId: user.LinuxDoId}).First(user).Error
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return err
	}
	return nil
}

// FillUserByWeChatId loads the non-deleted user with user.WeChatId into user.
// A missing row is not an error; other database errors are returned.
func (user *User) FillUserByWeChatId() error {
	if user.WeChatId == "" {
		return errors.New("WeChat id 为空!")
	}
	err := DB.Where(User{WeChatId: user.WeChatId}).First(user).Error
	if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
		return err
	}
	return nil
}

// FillUserByTelegramId loads the non-deleted user with user.TelegramId into
// user. Unlike the other FillUserBy* helpers, a missing row is reported as an
// "account not bound" error.
func (user *User) FillUserByTelegramId() error {
	if user.TelegramId == "" {
		return errors.New("Telegram id 为空!")
	}
	err := DB.Where(User{TelegramId: user.TelegramId}).First(user).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return errors.New("该 Telegram 账户未绑定")
	}
	return err
}

// IsEmailAlreadyTaken reports whether any user, soft-deleted ones included,
// has the given email.
func IsEmailAlreadyTaken(email string) bool {
	return DB.Unscoped().Where("email = ?", email).Limit(1).Find(&User{}).RowsAffected > 0
}

// IsWeChatIdAlreadyTaken reports whether any user, soft-deleted ones included,
// has the given WeChat id.
func IsWeChatIdAlreadyTaken(wechatId string) bool {
	return DB.Unscoped().Where("wechat_id = ?", wechatId).Limit(1).Find(&User{}).RowsAffected > 0
}

// IsGitHubIdAlreadyTaken reports whether any user, soft-deleted ones included,
// has the given GitHub id.
func IsGitHubIdAlreadyTaken(githubId string) bool {
	return DB.Unscoped().Where("github_id = ?", githubId).Limit(1).Find(&User{}).RowsAffected > 0
}

// IsLinuxDoIdAlreadyTaken reports whether any user, soft-deleted ones
// included, has the given LINUX DO id.
func IsLinuxDoIdAlreadyTaken(linuxdoId string) bool {
	return DB.Unscoped().Where("linuxdo_id = ?", linuxdoId).Limit(1).Find(&User{}).RowsAffected > 0
}

// IsUsernameAlreadyTaken reports whether any user, soft-deleted ones included,
// has the given username.
func IsUsernameAlreadyTaken(username string) bool {
	return DB.Unscoped().Where("username = ?", username).Limit(1).Find(&User{}).RowsAffected > 0
}

// IsTelegramIdAlreadyTaken reports whether any user, soft-deleted ones
// included, has the given Telegram id.
func IsTelegramIdAlreadyTaken(telegramId string) bool {
	return DB.Unscoped().Where("telegram_id = ?", telegramId).Limit(1).Find(&User{}).RowsAffected > 0
}

// ResetUserPasswordByEmail hashes password and stores it for every
// non-deleted user with the given email.
func ResetUserPasswordByEmail(email string, password string) error {
	if email == "" || password == "" {
		return errors.New("邮箱地址或密码为空!")
	}
	hashedPassword, err := common.Password2Hash(password)
	if err != nil {
		return err
	}
	return DB.Model(&User{}).Where("email = ?", email).Update("password", hashedPassword).Error
}

// IsAdmin reports whether the user's role is at least common.RoleAdminUser.
// Unknown users and lookup errors (which are logged) yield false.
func IsAdmin(userId int) bool {
	if userId == 0 {
		return false
	}
	var user User
	err := DB.Where("id = ?", userId).Select("role").Find(&user).Error
	if err != nil {
		common.SysError("no such user " + err.Error())
		return false
	}
	return user.Role >= common.RoleAdminUser
}

// IsUserEnabled reports whether the user's status is common.UserStatusEnabled.
// An unknown user yields (false, nil).
func IsUserEnabled(userId int) (bool, error) {
	if userId == 0 {
		return false, errors.New("user id is empty")
	}
	var user User
	err := DB.Where("id = ?", userId).Select("status").Find(&user).Error
	if err != nil {
		return false, err
	}
	return user.Status == common.UserStatusEnabled, nil
}

// IsLinuxDoEnabled returns true when the user has no LINUX DO id, or when
// their LinuxDoLevel is at least common.LinuxDoMinLevel.
func IsLinuxDoEnabled(userId int) (bool, error) {
	if userId == 0 {
		return false, errors.New("user id is empty")
	}
	var user User
	err := DB.Where("id = ?", userId).Select("linuxdo_id, linuxdo_level").Find(&user).Error
	if err != nil {
		return false, err
	}
	return user.LinuxDoId == "" || user.LinuxDoLevel >= common.LinuxDoMinLevel, nil
}

// ValidateAccessToken returns the non-deleted user owning token, or nil when
// the token is empty or matches nobody. A leading "Bearer " is stripped.
func ValidateAccessToken(token string) (user *User) {
	// Strip only a leading prefix, and re-check afterwards so that a bare
	// "Bearer " is rejected instead of querying with an empty token.
	token = strings.TrimPrefix(token, "Bearer ")
	if token == "" {
		return nil
	}
	user = &User{}
	if DB.Where("access_token = ?", token).First(user).RowsAffected == 1 {
		return user
	}
	return nil
}

// GetUserQuota reads the user's quota from the database. When Redis is
// enabled and the read succeeded, the cached value is refreshed
// asynchronously.
func GetUserQuota(id int) (quota int, err error) {
	err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
	// Only cache a value that was actually read; on error quota is just 0.
	if err == nil && common.RedisEnabled {
		go cacheSetUserQuota(id, quota)
	}
	return quota, err
}

// GetUserUsedQuota reads the user's used_quota from the database.
func GetUserUsedQuota(id int) (quota int, err error) {
	err = DB.Model(&User{}).Where("id = ?", id).Select("used_quota").Find(&quota).Error
	return quota, err
}

// GetUserEmail reads the user's email from the database.
func GetUserEmail(id int) (email string, err error) {
	err = DB.Model(&User{}).Where("id = ?", id).Select("email").Find(&email).Error
	return email, err
}

// GetUserGroup reads the user's group from the database.
func GetUserGroup(id int) (group string, err error) {
	// "group" is an SQL keyword, so the column name is quoted per dialect.
	groupCol := "`group`"
	if common.UsingPostgreSQL {
		groupCol = `"group"`
	}
	err = DB.Model(&User{}).Where("id = ?", id).Select(groupCol).Find(&group).Error
	return group, err
}

// IncreaseUserQuota adds a non-negative quota to the user's balance, either
// immediately or through the batch updater when common.BatchUpdateEnabled.
func IncreaseUserQuota(id int, quota int) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUserQuota, id, quota)
		return nil
	}
	return increaseUserQuota(id, quota)
}

// increaseUserQuota atomically adds quota to the user's quota column.
func increaseUserQuota(id int, quota int) (err error) {
	return DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota + ?", quota)).Error
}

// DecreaseUserQuota subtracts a non-negative quota from the user's balance,
// either immediately or through the batch updater when
// common.BatchUpdateEnabled.
func DecreaseUserQuota(id int, quota int) (err error) {
	if quota < 0 {
		return errors.New("quota 不能为负数!")
	}
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUserQuota, id, -quota)
		return nil
	}
	return decreaseUserQuota(id, quota)
}

// decreaseUserQuota atomically subtracts quota from the user's quota column.
func decreaseUserQuota(id int, quota int) (err error) {
	return DB.Model(&User{}).Where("id = ?", id).Update("quota", gorm.Expr("quota - ?", quota)).Error
}

// GetRootUserEmail returns the email of a user with role common.RoleRootUser,
// or "" when none is found. Errors are ignored.
func GetRootUserEmail() (email string) {
	DB.Model(&User{}).Where("role = ?", common.RoleRootUser).Select("email").Find(&email)
	return email
}

// UpdateUserUsedQuotaAndRequestCount adds quota to used_quota and 1 to
// request_count, either immediately or through the batch updater.
func UpdateUserUsedQuotaAndRequestCount(id int, quota int) {
	if common.BatchUpdateEnabled {
		addNewRecord(BatchUpdateTypeUsedQuota, id, quota)
		addNewRecord(BatchUpdateTypeRequestCount, id, 1)
		return
	}
	updateUserUsedQuotaAndRequestCount(id, quota, 1)
}

// updateUserUsedQuotaAndRequestCount atomically increments used_quota and
// request_count in one statement; failures are logged, not returned.
func updateUserUsedQuotaAndRequestCount(id int, quota int, count int) {
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
		map[string]interface{}{
			"used_quota":    gorm.Expr("used_quota + ?", quota),
			"request_count": gorm.Expr("request_count + ?", count),
		},
	).Error
	if err != nil {
		common.SysError("failed to update user used quota and request count: " + err.Error())
	}
}

// updateUserUsedQuota atomically increments used_quota; failures are logged.
func updateUserUsedQuota(id int, quota int) {
	err := DB.Model(&User{}).Where("id = ?", id).Updates(
		map[string]interface{}{
			"used_quota": gorm.Expr("used_quota + ?", quota),
		},
	).Error
	if err != nil {
		common.SysError("failed to update user used quota: " + err.Error())
	}
}

// updateUserRequestCount atomically increments request_count; failures are
// logged.
func updateUserRequestCount(id int, count int) {
	err := DB.Model(&User{}).Where("id = ?", id).Update("request_count", gorm.Expr("request_count + ?", count)).Error
	if err != nil {
		common.SysError("failed to update user request count: " + err.Error())
	}
}

// GetUsernameById reads the user's username from the database.
func GetUsernameById(id int) (username string, err error) {
	err = DB.Model(&User{}).Where("id = ?", id).Select("username").Find(&username).Error
	return username, err
}