backend/pkg/permission/repository/casbin_adapter.go

266 lines
5.6 KiB
Go
Raw Normal View History

2025-10-03 08:38:12 +00:00
package repository
import (
	"context"

	"github.com/casbin/casbin/v2/model"
	"github.com/casbin/casbin/v2/persist"
	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/mon"
	"go.mongodb.org/mongo-driver/v2/bson"
	mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
	"go.mongodb.org/mongo-driver/v2/mongo/options"

	"backend/pkg/library/errs"
	"backend/pkg/library/mongo"
)
// CasbinRule is the MongoDB document shape for one stored casbin rule.
// PType holds the policy type (the first token of a casbin policy line) and
// V0..V5 hold up to six rule values by position. The field order is kept
// unchanged because it determines the key order of the encoded BSON document.
type CasbinRule struct {
// ID is the document _id; omitempty drops it from the encoded document when zero.
ID bson.ObjectID `bson:"_id,omitempty"`
// PType is the policy type, e.g. a "p" or "g" section key.
PType string `bson:"ptype"`
// V0..V5 are the rule values in order. They have no omitempty, so empty
// values are still written as "" and every document carries all six keys.
V0 string `bson:"v0"`
V1 string `bson:"v1"`
V2 string `bson:"v2"`
V3 string `bson:"v3"`
V4 string `bson:"v4"`
V5 string `bson:"v5"`
}
// CasbinAdapterParam holds the settings NewCasbinAdapter passes through to
// mongo.MustDocumentDBWithCache when opening the rule collection.
type CasbinAdapterParam struct {
// Conf is the project mongo.Conf for the underlying store.
Conf *mongo.Conf
// CacheConf is the go-zero cache configuration for the cache layer.
CacheConf cache.CacheConf
// DBOpts are extra go-zero mon options for the collection.
DBOpts []mon.Option
// CacheOpts are extra go-zero cache options.
CacheOpts []cache.Option
}
// CasbinAdapter stores casbin policy rules through a cached MongoDB document
// store; NewCasbinAdapter returns it as a casbin persist.Adapter.
type CasbinAdapter struct {
// DB is the rule collection handle; NewCasbinAdapter opens it on "casbin_rules".
DB mongo.DocumentDBWithCacheUseCase
}
// NewCasbinAdapter creates a casbin persist.Adapter whose rules live in the
// "casbin_rules" collection, opened with the connection, cache and option
// settings carried by param.
//
// It panics if the underlying store cannot be created. The constructor has no
// error return, and the helper it wraps already follows the Must* convention,
// so a failure here is treated as a fatal setup error rather than dropped.
func NewCasbinAdapter(param CasbinAdapterParam) persist.Adapter {
	db, err := mongo.MustDocumentDBWithCache(
		"casbin_rules",
		param.Conf,
		param.CacheConf,
		param.CacheOpts,
		param.DBOpts,
	)
	if err != nil {
		panic(err)
	}
	return &CasbinAdapter{DB: db}
}
// LoadPolicy reads every stored rule document through a.DB and feeds each
// one into model as a policy line via loadPolicyLine.
//
// A failed read is reported as an errs.DatabaseErr; otherwise it returns nil.
func (a *CasbinAdapter) LoadPolicy(model model.Model) error {
	var stored []CasbinRule
	if findErr := a.DB.Find(context.Background(), bson.M{}, &stored); findErr != nil {
		return errs.DatabaseErr(findErr.Error())
	}

	for idx := range stored {
		a.loadPolicyLine(&stored[idx], model)
	}
	return nil
}
// SavePolicy replaces the stored rules with the full contents of model.
//
// Every stored document is deleted first. Then all rules from the "p" and "g"
// sections are bulk-inserted; the insert is skipped when there are none.
//
// NOTE(review): the delete and the insert are two separate operations, so a
// failed insert leaves the collection empty. Consider a transaction if the
// store supports one.
func (a *CasbinAdapter) SavePolicy(model model.Model) error {
	ctx := context.Background()

	if delErr := a.DB.DeleteMany(ctx, bson.M{}); delErr != nil {
		return errs.DatabaseErr(delErr.Error())
	}

	var docs []interface{}
	for _, section := range []string{"p", "g"} {
		for ptype, assertion := range model[section] {
			for _, line := range assertion.Policy {
				docs = append(docs, a.savePolicyLine(ptype, line))
			}
		}
	}

	if len(docs) == 0 {
		return nil
	}
	if _, insErr := a.DB.InsertMany(ctx, docs); insErr != nil {
		return errs.DatabaseErr(insErr.Error())
	}
	return nil
}
// AddPolicy stores one rule of the given ptype as a new document.
//
// sec is not used; the rule is keyed by ptype only. The rule is converted by
// savePolicyLine, which keeps at most six values. Insert failures are
// reported as errs.DatabaseErr.
func (a *CasbinAdapter) AddPolicy(sec string, ptype string, rule []string) error {
	doc := a.savePolicyLine(ptype, rule)
	if _, insErr := a.DB.InsertOne(context.Background(), doc); insErr != nil {
		return errs.DatabaseErr(insErr.Error())
	}
	return nil
}
// RemovePolicy deletes the stored rule(s) of the given ptype whose v0..v5
// exactly equal rule.
//
// All six columns are matched, and positions beyond len(rule) must be empty.
// This means removing a short rule does not also delete longer rules that
// merely share its prefix. Elements past the sixth are ignored, mirroring how
// savePolicyLine stores rules; they would otherwise map to an empty field
// name. sec is not used.
//
// NOTE(review): a document that lacks a vN key entirely (e.g. inserted outside
// this adapter) will not match vN == "". Documents written from CasbinRule
// always carry all six keys because V0..V5 have no omitempty.
func (a *CasbinAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
	ctx := context.Background()

	filter := bson.M{"ptype": ptype}
	for i := 0; i <= 5; i++ {
		value := ""
		if i < len(rule) {
			value = rule[i]
		}
		filter[getFieldName(i)] = value
	}

	if err := a.DB.DeleteMany(ctx, filter); err != nil {
		return errs.DatabaseErr(err.Error())
	}
	return nil
}
// RemoveFilteredPolicy deletes every stored rule of the given ptype that
// matches the filter.
//
// fieldValues[i] constrains column v(fieldIndex+i), and an empty value acts
// as a wildcard for that column. Stored rules only have columns v0..v5, so a
// non-empty constraint outside that range can never match. In that case
// nothing is deleted, rather than dropping the constraint and widening the
// delete. sec is not used.
func (a *CasbinAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
	ctx := context.Background()

	filter := bson.M{"ptype": ptype}
	for i, value := range fieldValues {
		if value == "" {
			continue // wildcard
		}
		idx := fieldIndex + i
		if idx < 0 || idx > 5 {
			// No stored rule has this column, so nothing can match.
			return nil
		}
		filter[getFieldName(idx)] = value
	}

	if err := a.DB.DeleteMany(ctx, filter); err != nil {
		return errs.DatabaseErr(err.Error())
	}
	return nil
}
// loadPolicyLine converts one stored rule into a casbin policy line
// ("ptype, v0, v1, ...") and passes it to persist.LoadPolicyLine, which adds
// it to model.
//
// Only trailing empty columns are dropped. An empty column followed by a
// non-empty one is kept as an empty token, so later values keep their
// positions. For example, v0="alice", v1="", v2="read" loads as
// "p, alice, , read" instead of shifting "read" into the v1 slot.
//
// NOTE(review): this helper has no error return, so anything
// persist.LoadPolicyLine reports (in casbin versions where it returns an
// error) is not surfaced; consider propagating it to LoadPolicy.
// persist.LoadPolicyLine re-parses this text, so a value containing a comma
// may be split. Confirm this against the casbin version in use.
func (a *CasbinAdapter) loadPolicyLine(rule *CasbinRule, model model.Model) {
	tokens := []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5}

	// Trim trailing empty columns only; interior empties are positional.
	end := len(tokens)
	for end > 1 && tokens[end-1] == "" {
		end--
	}

	lineText := tokens[0]
	for _, tok := range tokens[1:end] {
		lineText += ", " + tok
	}
	persist.LoadPolicyLine(lineText, model)
}
// savePolicyLine builds the CasbinRule document for one policy rule.
//
// rule values are copied into V0..V5 by position. Columns without a value
// stay "", and values beyond the sixth are silently dropped.
func (a *CasbinAdapter) savePolicyLine(ptype string, rule []string) *CasbinRule {
	doc := &CasbinRule{PType: ptype}
	columns := [...]*string{&doc.V0, &doc.V1, &doc.V2, &doc.V3, &doc.V4, &doc.V5}

	for pos, value := range rule {
		if pos >= len(columns) {
			break
		}
		*columns[pos] = value
	}
	return doc
}
// getFieldName maps a zero-based rule position to its CasbinRule BSON key,
// "v0" through "v5". Any index outside 0..5 yields "".
func getFieldName(index int) string {
	keys := [...]string{"v0", "v1", "v2", "v3", "v4", "v5"}
	if index < 0 || index >= len(keys) {
		return ""
	}
	return keys[index]
}
// Index20241226001UP creates the lookup indexes on the casbin_rules
// collection: idx_ptype {ptype}, idx_ptype_v0 {ptype, v0} and
// idx_ptype_v0_v1 {ptype, v0, v1}.
//
// a.DB is a project wrapper whose index API is not visible here. The indexes
// are therefore created through a driver-style Indexes() accessor when the
// wrapped value exposes one. If it does not, nothing is created and an error
// is returned rather than reporting success.
//
// Returns true once all indexes have been created.
//
// NOTE(review): idx_ptype is a prefix of idx_ptype_v0, so MongoDB can already
// serve ptype-only queries from the compound index. It is kept here to
// preserve this migration's declared index set.
func (a *CasbinAdapter) Index20241226001UP(ctx context.Context) (bool, error) {
	indexes := []mongodriver.IndexModel{
		{
			Keys:    bson.D{{Key: "ptype", Value: 1}},
			Options: options.Index().SetName("idx_ptype"),
		},
		{
			Keys: bson.D{
				{Key: "ptype", Value: 1},
				{Key: "v0", Value: 1},
			},
			Options: options.Index().SetName("idx_ptype_v0"),
		},
		{
			Keys: bson.D{
				{Key: "ptype", Value: 1},
				{Key: "v0", Value: 1},
				{Key: "v1", Value: 1},
			},
			Options: options.Index().SetName("idx_ptype_v0_v1"),
		},
	}

	// Probe the dynamic value (via any) so this compiles whether
	// DocumentDBWithCacheUseCase is an interface or a concrete type.
	indexer, ok := any(a.DB).(interface {
		Indexes() mongodriver.IndexView
	})
	if !ok {
		return false, errs.DatabaseErr("casbin_rules: store does not expose Indexes(), indexes not created")
	}
	if _, err := indexer.Indexes().CreateMany(ctx, indexes); err != nil {
		return false, errs.DatabaseErr(err.Error())
	}
	return true, nil
}