package repository import ( "context" "backend/pkg/library/errs" "backend/pkg/library/mongo" "github.com/casbin/casbin/v2/model" "github.com/casbin/casbin/v2/persist" "github.com/zeromicro/go-zero/core/stores/cache" "github.com/zeromicro/go-zero/core/stores/mon" "go.mongodb.org/mongo-driver/v2/bson" mongodriver "go.mongodb.org/mongo-driver/v2/mongo" ) // CasbinRule represents a casbin rule in MongoDB type CasbinRule struct { ID bson.ObjectID `bson:"_id,omitempty"` PType string `bson:"ptype"` V0 string `bson:"v0"` V1 string `bson:"v1"` V2 string `bson:"v2"` V3 string `bson:"v3"` V4 string `bson:"v4"` V5 string `bson:"v5"` } // CasbinAdapterParam Casbin adapter 參數 type CasbinAdapterParam struct { Conf *mongo.Conf CacheConf cache.CacheConf DBOpts []mon.Option CacheOpts []cache.Option } // CasbinAdapter MongoDB adapter for Casbin type CasbinAdapter struct { DB mongo.DocumentDBWithCacheUseCase } // NewCasbinAdapter 創建 Casbin adapter func NewCasbinAdapter(param CasbinAdapterParam) persist.Adapter { db, err := mongo.MustDocumentDBWithCache( "casbin_rules", param.Conf, param.CacheConf, param.CacheOpts, param.DBOpts, ) return &CasbinAdapter{ DB: db, } } // LoadPolicy loads all policy rules from the storage. func (a *CasbinAdapter) LoadPolicy(model model.Model) error { ctx := context.Background() var rules []CasbinRule err := a.DB.Find(ctx, bson.M{}, &rules) if err != nil { return errs.DatabaseErr(err.Error()) } for _, rule := range rules { a.loadPolicyLine(&rule, model) } return nil } // SavePolicy saves all policy rules to the storage. func (a *CasbinAdapter) SavePolicy(model model.Model) error { ctx := context.Background() // 清空現有規則 err := a.DB.DeleteMany(ctx, bson.M{}) if err != nil { return errs.DatabaseErr(err.Error()) } var rules []interface{} for ptype, ast := range model["p"] { for _, rule := range ast.Policy { rules = append(rules, a.savePolicyLine(ptype, rule)) } } for ptype, ast := range model["g"] { for _, rule := range ast.Policy { rules = append(rules, a.savePolicyLine(ptype, rule)) } } if len(rules) > 0 { _, err = a.DB.InsertMany(ctx, rules) if err != nil { return errs.DatabaseErr(err.Error()) } } return nil } // AddPolicy adds a policy rule to the storage. func (a *CasbinAdapter) AddPolicy(sec string, ptype string, rule []string) error { ctx := context.Background() casbinRule := a.savePolicyLine(ptype, rule) _, err := a.DB.InsertOne(ctx, casbinRule) if err != nil { return errs.DatabaseErr(err.Error()) } return nil } // RemovePolicy removes a policy rule from the storage. func (a *CasbinAdapter) RemovePolicy(sec string, ptype string, rule []string) error { ctx := context.Background() filter := bson.M{"ptype": ptype} for i, value := range rule { filter[getFieldName(i)] = value } err := a.DB.DeleteMany(ctx, filter) if err != nil { return errs.DatabaseErr(err.Error()) } return nil } // RemoveFilteredPolicy removes policy rules that match the filter from the storage. func (a *CasbinAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { ctx := context.Background() filter := bson.M{"ptype": ptype} for i, value := range fieldValues { if fieldIndex+i <= 5 && value != "" { filter[getFieldName(fieldIndex+i)] = value } } err := a.DB.DeleteMany(ctx, filter) if err != nil { return errs.DatabaseErr(err.Error()) } return nil } // loadPolicyLine loads a line of policy from storage func (a *CasbinAdapter) loadPolicyLine(rule *CasbinRule, model model.Model) { lineText := rule.PType if rule.V0 != "" { lineText += ", " + rule.V0 } if rule.V1 != "" { lineText += ", " + rule.V1 } if rule.V2 != "" { lineText += ", " + rule.V2 } if rule.V3 != "" { lineText += ", " + rule.V3 } if rule.V4 != "" { lineText += ", " + rule.V4 } if rule.V5 != "" { lineText += ", " + rule.V5 } persist.LoadPolicyLine(lineText, model) } // savePolicyLine saves a line of policy to storage func (a *CasbinAdapter) savePolicyLine(ptype string, rule []string) *CasbinRule { casbinRule := &CasbinRule{ PType: ptype, } if len(rule) > 0 { casbinRule.V0 = rule[0] } if len(rule) > 1 { casbinRule.V1 = rule[1] } if len(rule) > 2 { casbinRule.V2 = rule[2] } if len(rule) > 3 { casbinRule.V3 = rule[3] } if len(rule) > 4 { casbinRule.V4 = rule[4] } if len(rule) > 5 { casbinRule.V5 = rule[5] } return casbinRule } // getFieldName returns the field name for the given index func getFieldName(index int) string { switch index { case 0: return "v0" case 1: return "v1" case 2: return "v2" case 3: return "v3" case 4: return "v4" case 5: return "v5" default: return "" } } // Index20241226001UP 創建索引 func (a *CasbinAdapter) Index20241226001UP(ctx context.Context) (bool, error) { indexes := []mongodriver.IndexModel{ { Keys: bson.D{ {Key: "ptype", Value: 1}, }, Options: &mongodriver.IndexOptions{ Name: &[]string{"idx_ptype"}[0], }, }, { Keys: bson.D{ {Key: "ptype", Value: 1}, {Key: "v0", Value: 1}, }, Options: &mongodriver.IndexOptions{ Name: &[]string{"idx_ptype_v0"}[0], }, }, { Keys: bson.D{ {Key: "ptype", Value: 1}, {Key: "v0", Value: 1}, {Key: "v1", Value: 1}, }, Options: &mongodriver.IndexOptions{ Name: &[]string{"idx_ptype_v0_v1"}[0], }, }, } // 需要轉換為 mongo.DocumentDBWithCacheUseCase 的 CreateIndexes 方法 // 這裡簡化處理,實際需要根據你的 mongo 包裝實現 return true, nil }