mirror of
https://github.com/bufanyun/hotgo.git
synced 2025-10-12 04:53:47 +08:00
v2.0
This commit is contained in:
324
server/internal/library/casbin/adapter.go
Normal file
324
server/internal/library/casbin/adapter.go
Normal file
@@ -0,0 +1,324 @@
|
||||
// Package casbin
|
||||
// @Link https://github.com/bufanyun/hotgo
|
||||
// @Copyright Copyright (c) 2022 HotGo CLI
|
||||
// @Author Ms <133814250@qq.com>
|
||||
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
|
||||
//
|
||||
package casbin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/casbin/casbin/v2/model"
|
||||
"github.com/casbin/casbin/v2/persist"
|
||||
"github.com/gogf/gf/v2/database/gdb"
|
||||
"log"
|
||||
"math"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultTableName = "hg_admin_role_casbin"
|
||||
dropPolicyTableSql = `DROP TABLE IF EXISTS %s`
|
||||
createPolicyTableSql = `
|
||||
CREATE TABLE IF NOT EXISTS %s (
|
||||
id bigint(20) NOT NULL AUTO_INCREMENT,
|
||||
p_type varchar(64) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
|
||||
v0 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
|
||||
v1 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
|
||||
v2 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
|
||||
v3 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
|
||||
v4 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
|
||||
v5 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
|
||||
PRIMARY KEY (id) USING BTREE
|
||||
) ENGINE = InnoDB AUTO_INCREMENT = 1 CHARACTER SET = utf8 COLLATE = utf8_general_ci COMMENT = 'casbin权限表' ROW_FORMAT = Dynamic;
|
||||
`
|
||||
)
|
||||
|
||||
type (
|
||||
adapter struct {
|
||||
db gdb.DB
|
||||
table string
|
||||
}
|
||||
|
||||
policyColumns struct {
|
||||
ID string // ID
|
||||
PType string // PType
|
||||
V0 string // V0
|
||||
V1 string // V1
|
||||
V2 string // V2
|
||||
V3 string // V3
|
||||
V4 string // V4
|
||||
V5 string // V5
|
||||
}
|
||||
|
||||
// policy rule entity
|
||||
policyRule struct {
|
||||
ID int64 `orm:"id" json:"id"`
|
||||
PType string `orm:"p_type" json:"p_type"`
|
||||
V0 string `orm:"v0" json:"v0"`
|
||||
V1 string `orm:"v1" json:"v1"`
|
||||
V2 string `orm:"v2" json:"v2"`
|
||||
V3 string `orm:"v3" json:"v3"`
|
||||
V4 string `orm:"v4" json:"v4"`
|
||||
V5 string `orm:"v5" json:"v5"`
|
||||
}
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidDatabaseLink = errors.New("invalid database link")
|
||||
policyColumnsName = policyColumns{
|
||||
ID: "id",
|
||||
PType: "p_type",
|
||||
V0: "v0",
|
||||
V1: "v1",
|
||||
V2: "v2",
|
||||
V3: "v3",
|
||||
V4: "v4",
|
||||
V5: "v5",
|
||||
}
|
||||
)
|
||||
|
||||
// NewAdapter Create a casbin adapter
|
||||
func NewAdapter(link string) (adp *adapter, err error) {
|
||||
adp = &adapter{table: defaultTableName}
|
||||
config := strings.SplitN(link, ":", 2)
|
||||
|
||||
if len(config) != 2 {
|
||||
err = errInvalidDatabaseLink
|
||||
return
|
||||
}
|
||||
|
||||
if adp.db, err = gdb.New(gdb.ConfigNode{Type: config[0], Link: config[1]}); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err = adp.createPolicyTable()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *adapter) model() *gdb.Model {
|
||||
return a.db.Model(a.table).Safe().Ctx(context.TODO())
|
||||
}
|
||||
|
||||
// create a policy table when it's not exists.
|
||||
func (a *adapter) createPolicyTable() (err error) {
|
||||
_, err = a.db.Exec(context.TODO(), fmt.Sprintf(createPolicyTableSql, a.table))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// drop policy table from the storage.
|
||||
func (a *adapter) dropPolicyTable() (err error) {
|
||||
_, err = a.db.Exec(context.TODO(), fmt.Sprintf(dropPolicyTableSql, a.table))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// LoadPolicy loads all policy rules from the storage.
|
||||
func (a *adapter) LoadPolicy(model model.Model) (err error) {
|
||||
log.Println("LoadPolicy...")
|
||||
var rules []policyRule
|
||||
|
||||
if err = a.model().Scan(&rules); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
a.loadPolicyRule(rule, model)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// SavePolicy Saves all policy rules to the storage.
|
||||
func (a *adapter) SavePolicy(model model.Model) (err error) {
|
||||
if err = a.dropPolicyTable(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err = a.createPolicyTable(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
policyRules := make([]policyRule, 0)
|
||||
|
||||
for ptype, ast := range model["p"] {
|
||||
for _, rule := range ast.Policy {
|
||||
policyRules = append(policyRules, a.buildPolicyRule(ptype, rule))
|
||||
}
|
||||
}
|
||||
|
||||
for ptype, ast := range model["g"] {
|
||||
for _, rule := range ast.Policy {
|
||||
policyRules = append(policyRules, a.buildPolicyRule(ptype, rule))
|
||||
}
|
||||
}
|
||||
|
||||
if count := len(policyRules); count > 0 {
|
||||
if _, err = a.model().Insert(policyRules); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// AddPolicy adds a policy rule to the storage.
|
||||
func (a *adapter) AddPolicy(sec string, ptype string, rule []string) (err error) {
|
||||
_, err = a.model().Insert(a.buildPolicyRule(ptype, rule))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// AddPolicies adds policy rules to the storage.
|
||||
func (a *adapter) AddPolicies(sec string, ptype string, rules [][]string) (err error) {
|
||||
if len(rules) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
policyRules := make([]policyRule, 0, len(rules))
|
||||
|
||||
for _, rule := range rules {
|
||||
policyRules = append(policyRules, a.buildPolicyRule(ptype, rule))
|
||||
}
|
||||
|
||||
_, err = a.model().Insert(policyRules)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// RemovePolicy removes a policy rule from the storage.
|
||||
func (a *adapter) RemovePolicy(sec string, ptype string, rule []string) (err error) {
|
||||
db := a.model()
|
||||
db = db.Where(policyColumnsName.PType, ptype)
|
||||
for index := 0; index < len(rule); index++ {
|
||||
db = db.Where(fmt.Sprintf("v%d", index), rule[index])
|
||||
}
|
||||
_, err = db.Delete()
|
||||
return err
|
||||
}
|
||||
|
||||
// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
|
||||
func (a *adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (err error) {
|
||||
db := a.model()
|
||||
db = db.Where(policyColumnsName.PType, ptype)
|
||||
for index := 0; index <= 5; index++ {
|
||||
if fieldIndex <= index && index < fieldIndex+len(fieldValues) {
|
||||
db = db.Where(fmt.Sprintf("v%d", index), fieldValues[index-fieldIndex])
|
||||
}
|
||||
}
|
||||
_, err = db.Delete()
|
||||
return
|
||||
}
|
||||
|
||||
// RemovePolicies removes policy rules from the storage (implements the persist.BatchAdapter interface).
|
||||
func (a *adapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) {
|
||||
db := a.model()
|
||||
|
||||
for _, rule := range rules {
|
||||
where := map[string]interface{}{policyColumnsName.PType: ptype}
|
||||
|
||||
for i := 0; i <= 5; i++ {
|
||||
if len(rule) > i {
|
||||
where[fmt.Sprintf("v%d", i)] = rule[i]
|
||||
}
|
||||
}
|
||||
|
||||
db = db.WhereOr(where)
|
||||
}
|
||||
|
||||
_, err = db.Delete()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// UpdatePolicy updates a policy rule from storage.
|
||||
func (a *adapter) UpdatePolicy(sec string, ptype string, oldRule, newRule []string) (err error) {
|
||||
_, err = a.model().Update(a.buildPolicyRule(ptype, newRule), a.buildPolicyRule(ptype, oldRule))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// UpdatePolicies updates some policy rules to storage, like db, redis.
|
||||
func (a *adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) (err error) {
|
||||
if len(oldRules) == 0 || len(newRules) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
err = a.db.Transaction(context.TODO(), func(ctx context.Context, tx *gdb.TX) error {
|
||||
for i := 0; i < int(math.Min(float64(len(oldRules)), float64(len(newRules)))); i++ {
|
||||
if _, err = tx.Model(a.table).Update(a.buildPolicyRule(ptype, newRules[i]), a.buildPolicyRule(ptype, oldRules[i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 加载策略规则
|
||||
func (a *adapter) loadPolicyRule(rule policyRule, model model.Model) {
|
||||
ruleText := rule.PType
|
||||
|
||||
if rule.V0 != "" {
|
||||
ruleText += ", " + rule.V0
|
||||
}
|
||||
|
||||
if rule.V1 != "" {
|
||||
ruleText += ", " + rule.V1
|
||||
}
|
||||
|
||||
if rule.V2 != "" {
|
||||
ruleText += ", " + rule.V2
|
||||
}
|
||||
|
||||
if rule.V3 != "" {
|
||||
ruleText += ", " + rule.V3
|
||||
}
|
||||
|
||||
if rule.V4 != "" {
|
||||
ruleText += ", " + rule.V4
|
||||
}
|
||||
|
||||
if rule.V5 != "" {
|
||||
ruleText += ", " + rule.V5
|
||||
}
|
||||
|
||||
persist.LoadPolicyLine(ruleText, model)
|
||||
}
|
||||
|
||||
// 构建策略规则
|
||||
func (a *adapter) buildPolicyRule(ptype string, data []string) policyRule {
|
||||
rule := policyRule{PType: ptype}
|
||||
|
||||
if len(data) > 0 {
|
||||
rule.V0 = data[0]
|
||||
}
|
||||
|
||||
if len(data) > 1 {
|
||||
rule.V1 = data[1]
|
||||
}
|
||||
|
||||
if len(data) > 2 {
|
||||
rule.V2 = data[2]
|
||||
}
|
||||
|
||||
if len(data) > 3 {
|
||||
rule.V3 = data[3]
|
||||
}
|
||||
|
||||
if len(data) > 4 {
|
||||
rule.V4 = data[4]
|
||||
}
|
||||
|
||||
if len(data) > 5 {
|
||||
rule.V5 = data[5]
|
||||
}
|
||||
|
||||
return rule
|
||||
}
|
Reference in New Issue
Block a user