diff --git a/controller/model.go b/controller/model.go index de86ca3..5e1aa7d 100644 --- a/controller/model.go +++ b/controller/model.go @@ -108,8 +108,8 @@ func init() { }) } openAIModelsMap = make(map[string]dto.OpenAIModels) - for _, model := range openAIModels { - openAIModelsMap[model.Id] = model + for _, aiModel := range openAIModels { + openAIModelsMap[aiModel.Id] = aiModel } channelId2Models = make(map[int][]string) for i := 1; i <= common.ChannelTypeDummy; i++ { @@ -174,8 +174,8 @@ func DashboardListModels(c *gin.Context) { func RetrieveModel(c *gin.Context) { modelId := c.Param("model") - if model, ok := openAIModelsMap[modelId]; ok { - c.JSON(200, model) + if aiModel, ok := openAIModelsMap[modelId]; ok { + c.JSON(200, aiModel) } else { openAIError := dto.OpenAIError{ Message: fmt.Sprintf("The model '%s' does not exist", modelId), @@ -191,12 +191,12 @@ func RetrieveModel(c *gin.Context) { func GetPricing(c *gin.Context) { userId := c.GetInt("id") - user, _ := model.GetUserById(userId, true) + group, err := model.CacheGetUserGroup(userId) groupRatio := common.GetGroupRatio("default") - if user != nil { - groupRatio = common.GetGroupRatio(user.Group) + if err != nil { + groupRatio = common.GetGroupRatio(group) } - pricing := model.GetPricing(user, openAIModels) + pricing := model.GetPricing(group) c.JSON(200, gin.H{ "success": true, "data": pricing, diff --git a/model/pricing.go b/model/pricing.go index 237227c..90d8bc7 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -13,16 +13,16 @@ var ( updatePricingLock sync.Mutex ) -func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing { +func GetPricing(group string) []dto.ModelPricing { updatePricingLock.Lock() defer updatePricingLock.Unlock() if time.Since(lastGetPricingTime) > time.Minute*1 || len(pricingMap) == 0 { - updatePricing(openAIModels) + updatePricing() } - if user != nil { + if group != "" { userPricingMap := make([]dto.ModelPricing, 0) - models := GetGroupModels(user.Group) + models := GetGroupModels(group) for _, pricing := range pricingMap { if !common.StringsContains(models, pricing.ModelName) { pricing.Available = false @@ -34,7 +34,7 @@ func GetPricing(user *User, openAIModels []dto.OpenAIModels) []dto.ModelPricing return pricingMap } -func updatePricing(openAIModels []dto.OpenAIModels) { +func updatePricing() { //modelRatios := common.GetModelRatios() enabledModels := GetEnabledModels() allModels := make(map[string]int) diff --git a/router/api-router.go b/router/api-router.go index add5c5f..b98a94e 100644 --- a/router/api-router.go +++ b/router/api-router.go @@ -20,7 +20,7 @@ func SetApiRouter(router *gin.Engine) { apiRouter.GET("/about", controller.GetAbout) //apiRouter.GET("/midjourney", controller.GetMidjourney) apiRouter.GET("/home_page_content", controller.GetHomePageContent) - apiRouter.GET("/pricing", middleware.CriticalRateLimit(), middleware.TryUserAuth(), controller.GetPricing) + apiRouter.GET("/pricing", middleware.TryUserAuth(), controller.GetPricing) apiRouter.GET("/verification", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendEmailVerification) apiRouter.GET("/reset_password", middleware.CriticalRateLimit(), middleware.TurnstileCheck(), controller.SendPasswordResetEmail) apiRouter.POST("/user/reset", middleware.CriticalRateLimit(), controller.ResetPassword)