hotgo/server/internal/library/casbin/adapter.go
2024-03-15 14:42:40 +08:00

330 lines
8.0 KiB
Go

// Package casbin implements a Casbin policy adapter that stores rules in a database table through GoFrame's gdb.
// @Link https://github.com/bufanyun/hotgo
// @Copyright Copyright (c) 2023 HotGo CLI
// @Author Ms <133814250@qq.com>
// @License https://github.com/bufanyun/hotgo/blob/master/LICENSE
package casbin
import (
"context"
"errors"
"fmt"
"math"
"strings"
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"github.com/gogf/gf/v2/database/gdb"
)
const (
	// defaultTableName is the policy table name that NewAdapter assigns to
	// every adapter it creates.
	defaultTableName = "hg_admin_role_casbin"

	// dropPolicyTableSql is a format string; %s receives the table name.
	dropPolicyTableSql = `DROP TABLE IF EXISTS %s`

	// createPolicyTableSql is the CREATE TABLE format string (%s receives the
	// table name) used by createPolicyTable for every database type other than
	// "sqlite". It relies on MySQL-style clauses such as ENGINE = InnoDB and
	// AUTO_INCREMENT. The table COMMENT is Chinese for "administrator casbin
	// permission table".
	createPolicyTableSql = `
CREATE TABLE IF NOT EXISTS %s (
id bigint(20) NOT NULL AUTO_INCREMENT,
p_type varchar(64) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
v0 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
v1 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
v2 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
v3 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
v4 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
v5 varchar(256) CHARACTER SET utf8 COLLATE utf8_general_ci NULL DEFAULT NULL,
PRIMARY KEY (id) USING BTREE
) ENGINE = InnoDB AUTO_INCREMENT = 1 CHARACTER SET = utf8 COLLATE = utf8_general_ci COMMENT = '管理员_casbin权限表' ROW_FORMAT = Dynamic;
`

	// createPolicyTableSql_sqlite is the CREATE TABLE format string (%s receives
	// the table name) used by createPolicyTable when the database type is
	// "sqlite". It declares the same columns as createPolicyTableSql, with
	// INTEGER/TEXT column types.
	createPolicyTableSql_sqlite = `
CREATE TABLE IF NOT EXISTS %s (
id INTEGER NOT NULL ,
p_type TEXT DEFAULT NULL,
v0 TEXT DEFAULT NULL,
v1 TEXT DEFAULT NULL,
v2 TEXT DEFAULT NULL,
v3 TEXT DEFAULT NULL,
v4 TEXT DEFAULT NULL,
v5 TEXT DEFAULT NULL,
PRIMARY KEY (id)
);
`
)
// adapter stores and retrieves Casbin policy rules through a gdb database
// handle, using a single policy table.
type adapter struct {
	db    gdb.DB // database handle every statement is issued through
	table string // name of the policy table
}

// policyColumns lists the column names of the policy table, one field per
// column.
type policyColumns struct {
	ID    string // name of the primary key column
	PType string // name of the policy type column
	V0    string // name of the column holding rule value 0
	V1    string // name of the column holding rule value 1
	V2    string // name of the column holding rule value 2
	V3    string // name of the column holding rule value 3
	V4    string // name of the column holding rule value 4
	V5    string // name of the column holding rule value 5
}

// policyRule is one row of the policy table: a policy type plus up to six
// positional values. The orm tags carry the column names.
type policyRule struct {
	ID    int64  `orm:"id" json:"id"`
	PType string `orm:"p_type" json:"p_type"`
	V0    string `orm:"v0" json:"v0"`
	V1    string `orm:"v1" json:"v1"`
	V2    string `orm:"v2" json:"v2"`
	V3    string `orm:"v3" json:"v3"`
	V4    string `orm:"v4" json:"v4"`
	V5    string `orm:"v5" json:"v5"`
}
// errInvalidDatabaseLink is returned by NewAdapter when the link contains no
// ":" separating the database type from the rest of the link.
var errInvalidDatabaseLink = errors.New("invalid database link")

// policyColumnsName holds the actual column name for every policyColumns
// field.
var policyColumnsName = policyColumns{
	ID:    "id",
	PType: "p_type",
	V0:    "v0", V1: "v1", V2: "v2",
	V3: "v3", V4: "v4", V5: "v5",
}
// NewAdapter creates a Casbin adapter for the database described by link and
// makes sure the policy table exists.
//
// link must contain a ":"; the text before the first ":" is the database type
// and the text after it is handed to gdb as the connection link. For the
// "sqlite" type the complete link, type prefix included, is handed to gdb
// instead. A link without ":" yields errInvalidDatabaseLink.
//
// The returned adapter is non-nil even when err is non-nil; callers must check
// err before using it.
func NewAdapter(link string) (adp *adapter, err error) {
	adp = &adapter{table: defaultTableName}

	sep := strings.Index(link, ":")
	if sep < 0 {
		return adp, errInvalidDatabaseLink
	}
	dbType, rest := link[:sep], link[sep+1:]

	node := gdb.ConfigNode{Type: dbType, Link: rest}
	if dbType == "sqlite" {
		// sqlite receives the untouched link rather than the part after ":".
		node.Link = link
	}

	adp.db, err = gdb.New(node)
	if err != nil {
		return adp, err
	}
	return adp, adp.createPolicyTable()
}
// model returns a new gdb model bound to the adapter's policy table, flagged
// with Safe() and carrying a context.TODO() context.
func (a *adapter) model() *gdb.Model {
	query := a.db.Model(a.table).Safe()
	return query.Ctx(context.TODO())
}
// createPolicyTable creates the policy table if it does not exist yet. The
// sqlite statement is used when the configured database type is "sqlite";
// every other type gets the default statement.
func (a *adapter) createPolicyTable() (err error) {
	template := createPolicyTableSql
	if a.db.GetConfig().Type == "sqlite" {
		template = createPolicyTableSql_sqlite
	}
	_, err = a.db.Exec(context.TODO(), fmt.Sprintf(template, a.table))
	return err
}
// dropPolicyTable drops the policy table from the storage if it exists.
func (a *adapter) dropPolicyTable() (err error) {
	statement := fmt.Sprintf(dropPolicyTableSql, a.table)
	if _, execErr := a.db.Exec(context.TODO(), statement); execErr != nil {
		return execErr
	}
	return nil
}
// LoadPolicy reads every row of the policy table and loads it into model.
// It returns the scan error, if any. A row that cannot be loaded makes
// loadPolicyRule panic rather than return an error.
func (a *adapter) LoadPolicy(model model.Model) (err error) {
	var rows []policyRule
	err = a.model().Scan(&rows)
	if err != nil {
		return err
	}
	for i := range rows {
		a.loadPolicyRule(rows[i], model)
	}
	return nil
}
// SavePolicy replaces the stored policy with the rules held in model.
//
// The policy table is dropped and recreated first, then all rules of the "p"
// and "g" sections are written with a single insert. These steps are issued
// one after another without a surrounding transaction. Nothing is inserted
// when model holds no rules.
func (a *adapter) SavePolicy(model model.Model) (err error) {
	if err = a.dropPolicyTable(); err != nil {
		return err
	}
	if err = a.createPolicyTable(); err != nil {
		return err
	}

	var rows []policyRule
	for _, section := range [...]string{"p", "g"} {
		for ptype, assertion := range model[section] {
			for _, rule := range assertion.Policy {
				rows = append(rows, a.buildPolicyRule(ptype, rule))
			}
		}
	}
	if len(rows) == 0 {
		return nil
	}

	_, err = a.model().Insert(rows)
	return err
}
// AddPolicy inserts one policy rule of type ptype into the storage.
// sec is not used by this adapter.
func (a *adapter) AddPolicy(sec string, ptype string, rule []string) (err error) {
	row := a.buildPolicyRule(ptype, rule)
	_, err = a.model().Insert(row)
	return err
}
// AddPolicies inserts several policy rules of type ptype with one insert.
// An empty rules slice is a no-op. sec is not used by this adapter.
func (a *adapter) AddPolicies(sec string, ptype string, rules [][]string) (err error) {
	if len(rules) == 0 {
		return nil
	}

	rows := make([]policyRule, len(rules))
	for i, rule := range rules {
		rows[i] = a.buildPolicyRule(ptype, rule)
	}

	_, err = a.model().Insert(rows)
	return err
}
// RemovePolicy deletes the rows whose policy type equals ptype and whose
// columns v0, v1, ... equal the corresponding entries of rule. Columns beyond
// len(rule) are not part of the condition. sec is not used by this adapter.
func (a *adapter) RemovePolicy(sec string, ptype string, rule []string) (err error) {
	query := a.model().Where(policyColumnsName.PType, ptype)
	for position, value := range rule {
		query = query.Where(fmt.Sprintf("v%d", position), value)
	}
	_, err = query.Delete()
	return err
}
// RemoveFilteredPolicy deletes the rows of policy type ptype that match the
// given field filter.
//
// fieldValues[i] is compared with column v(fieldIndex+i). Positions that fall
// outside v0..v5 are ignored. An empty field value is a wildcard: Casbin's
// in-memory model treats "" as "match any value", so no condition is added for
// it. Adding `vN = ''` instead would make the storage delete fewer rows than
// the model removes. As a consequence, a filter made only of empty values
// deletes every row of ptype.
//
// sec is not used by this adapter.
func (a *adapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) (err error) {
	db := a.model().Where(policyColumnsName.PType, ptype)
	for offset, value := range fieldValues {
		column := fieldIndex + offset
		if column < 0 || column > 5 {
			// The table only has columns v0..v5.
			continue
		}
		if value == "" {
			// Wildcard position: do not constrain this column.
			continue
		}
		db = db.Where(fmt.Sprintf("v%d", column), value)
	}
	_, err = db.Delete()
	return err
}
// RemovePolicies removes policy rules from the storage (implements the
// persist.BatchAdapter interface).
//
// Every rule becomes one condition group (policy type plus the values of
// v0..v5 that the rule provides); the groups are combined with OR and removed
// by a single delete. An empty rules slice is a no-op, matching AddPolicies:
// without this guard no condition would be added and an unconditioned Delete
// would be issued for a request that asks to remove nothing.
//
// sec is not used by this adapter.
func (a *adapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) {
	if len(rules) == 0 {
		return nil
	}

	db := a.model()
	for _, rule := range rules {
		where := map[string]interface{}{policyColumnsName.PType: ptype}
		for i, value := range rule {
			if i > 5 {
				// The table only has columns v0..v5.
				break
			}
			where[fmt.Sprintf("v%d", i)] = value
		}
		db = db.WhereOr(where)
	}
	_, err = db.Delete()
	return err
}
// UpdatePolicy replaces the stored rule oldRule of type ptype with newRule.
//
// The row is located the same way RemovePolicy locates it: by policy type and
// by the columns v0, v1, ... that oldRule provides. The condition and the new
// values are spelled out column by column on purpose. Passing policyRule
// structs as data and condition would also carry their zero-valued ID field
// into the statement, which would put `id = 0` into the WHERE clause and
// prevent the update from matching the stored row.
//
// All six value columns are written, so values that newRule does not provide
// are stored as empty strings. sec is not used by this adapter.
func (a *adapter) UpdatePolicy(sec string, ptype string, oldRule, newRule []string) (err error) {
	db := a.model().Where(policyColumnsName.PType, ptype)
	for i, value := range oldRule {
		if i > 5 {
			// The table only has columns v0..v5.
			break
		}
		db = db.Where(fmt.Sprintf("v%d", i), value)
	}

	next := a.buildPolicyRule(ptype, newRule)
	_, err = db.Data(map[string]interface{}{
		policyColumnsName.PType: next.PType,
		policyColumnsName.V0:    next.V0,
		policyColumnsName.V1:    next.V1,
		policyColumnsName.V2:    next.V2,
		policyColumnsName.V3:    next.V3,
		policyColumnsName.V4:    next.V4,
		policyColumnsName.V5:    next.V5,
	}).Update()
	return err
}
// UpdatePolicies replaces oldRules[i] with newRules[i] for every index present
// in both slices, inside a single transaction. The first failing update aborts
// the loop and its error is returned through the transaction. Nothing happens
// when either slice is empty.
//
// Each row is located by policy type and by the columns v0, v1, ... that the
// old rule provides. The condition and the new values are spelled out column
// by column on purpose. Passing policyRule structs as data and condition would
// also carry their zero-valued ID field into the statement, which would put
// `id = 0` into the WHERE clause and prevent the update from matching the
// stored row.
//
// sec is not used by this adapter.
func (a *adapter) UpdatePolicies(sec string, ptype string, oldRules, newRules [][]string) (err error) {
	if len(oldRules) == 0 || len(newRules) == 0 {
		return nil
	}

	// Only pairs that exist in both slices can be updated. The bound does not
	// change while the loop runs, so it is computed once.
	pairs := int(math.Min(float64(len(oldRules)), float64(len(newRules))))

	return a.db.Transaction(context.TODO(), func(ctx context.Context, tx gdb.TX) error {
		for i := 0; i < pairs; i++ {
			query := tx.Model(a.table).Where(policyColumnsName.PType, ptype)
			for column, value := range oldRules[i] {
				if column > 5 {
					// The table only has columns v0..v5.
					break
				}
				query = query.Where(fmt.Sprintf("v%d", column), value)
			}

			next := a.buildPolicyRule(ptype, newRules[i])
			values := map[string]interface{}{
				policyColumnsName.PType: next.PType,
				policyColumnsName.V0:    next.V0,
				policyColumnsName.V1:    next.V1,
				policyColumnsName.V2:    next.V2,
				policyColumnsName.V3:    next.V3,
				policyColumnsName.V4:    next.V4,
				policyColumnsName.V5:    next.V5,
			}
			if _, err := query.Data(values).Update(); err != nil {
				return err
			}
		}
		return nil
	})
}
// loadPolicyRule turns one stored row into a policy line of the form
// "ptype, v0, v1, ..." and loads it into model via persist.LoadPolicyLine.
//
// Only trailing empty columns are dropped; they are the padding that
// buildPolicyRule leaves for rules with fewer than six values. An empty value
// in the middle of a rule is kept in place. Skipping it would shift every
// later value one position to the left, and the loaded rule would no longer
// match the stored one.
//
// It panics if the line cannot be loaded into model.
func (a *adapter) loadPolicyRule(rule policyRule, model model.Model) {
	fields := [...]string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5}

	// Find the end of the rule: everything after the last non-empty value.
	used := len(fields)
	for used > 0 && fields[used-1] == "" {
		used--
	}

	ruleText := rule.PType
	for _, field := range fields[:used] {
		ruleText += ", " + field
	}

	if err := persist.LoadPolicyLine(ruleText, model); err != nil {
		panic(err)
	}
}
// buildPolicyRule builds the table row for a rule of type ptype. data[0]
// through data[5] are copied into V0 through V5; values that data does not
// provide stay empty, values beyond the sixth are ignored, and ID is left at
// zero.
func (a *adapter) buildPolicyRule(ptype string, data []string) policyRule {
	rule := policyRule{PType: ptype}

	// Value columns in positional order, so data[i] lands in slots[i].
	slots := [...]*string{&rule.V0, &rule.V1, &rule.V2, &rule.V3, &rule.V4, &rule.V5}
	for i, slot := range slots {
		if i >= len(data) {
			break
		}
		*slot = data[i]
	}
	return rule
}