package usecase import ( "context" "encoding/json" "errors" "fmt" "strings" "sync" "time" redislib "gateway/internal/library/redis" permission "gateway/internal/model/permission/domain" "gateway/internal/model/permission/domain/entity" "gateway/internal/model/permission/domain/enum" domrepo "gateway/internal/model/permission/domain/repository" dom "gateway/internal/model/permission/domain/usecase" "github.com/casbin/casbin/v2" casbinmodel "github.com/casbin/casbin/v2/model" "github.com/zeromicro/go-zero/core/logx" ) // RBACUseCaseParam injects all repos + Redis Pub/Sub client. ModelPath // must point at etc/rbac.conf; CasbinModelText overrides ModelPath when // non-empty (used by tests / embedded resources). type RBACUseCaseParam struct { Roles domrepo.RoleRepository Permissions domrepo.PermissionRepository RolePermissions domrepo.RolePermissionRepository UserRoles domrepo.UserRoleRepository Redis *redislib.Client ModelPath string CasbinModelText string ReloadChannel string } // reloadEvent is the JSON payload published on the reload channel. type reloadEvent struct { TenantID string `json:"tenant_id"` TS int64 `json:"ts"` } type rbacUseCase struct { roles domrepo.RoleRepository perms domrepo.PermissionRepository rolePerms domrepo.RolePermissionRepository userRoles domrepo.UserRoleRepository redis *redislib.Client enforcerMu sync.RWMutex enforcers map[string]*casbin.SyncedEnforcer model casbinmodel.Model modelMu sync.Mutex modelTxt string reloadChannel string stopSubscribe context.CancelFunc stopMu sync.Mutex } // NewRBACUseCase wires the Casbin enforcer with the persistence layer. // Returns ErrCasbinNotConfigured when Redis is missing — Casbin's Redis // adapter and Pub/Sub require Redis to function. 
func NewRBACUseCase(param RBACUseCaseParam) (dom.RBACUseCase, error) {
	// Both policy persistence and cross-pod reload go through Redis.
	if param.Redis == nil || param.Redis.Zero() == nil {
		return nil, permission.ErrCasbinNotConfigured
	}

	channel := strings.TrimSpace(param.ReloadChannel)
	if channel == "" {
		channel = permission.PolicyReloadChannel
	}

	uc := &rbacUseCase{
		roles:         param.Roles,
		perms:         param.Permissions,
		rolePerms:     param.RolePermissions,
		userRoles:     param.UserRoles,
		redis:         param.Redis,
		enforcers:     make(map[string]*casbin.SyncedEnforcer),
		modelTxt:      strings.TrimSpace(param.CasbinModelText),
		reloadChannel: channel,
	}

	switch {
	case uc.modelTxt != "":
		// Inline model text wins over ModelPath. Parse it once here so a
		// malformed model fails construction, consistent with the ModelPath
		// branch below, instead of failing every later Check/LoadPolicy.
		if _, err := casbinmodel.NewModelFromString(uc.modelTxt); err != nil {
			return nil, fmt.Errorf("permission: parse casbin model text: %w", err)
		}
	case param.ModelPath != "":
		mdl, err := casbinmodel.NewModelFromFile(param.ModelPath)
		if err != nil {
			return nil, fmt.Errorf("permission: load casbin model: %w", err)
		}
		uc.model = mdl
	}
	// With neither source configured construction still succeeds; the first
	// enforcer build then reports "casbin model not loaded".
	return uc, nil
}

// Check enforces (tenant, uid → role keys) ∩ policy. Multiple roles use
// any-allow semantics: the first matching role short-circuits with
// allow=true. The `r.role == p.role` matcher means we must call EnforceEx
// once per role; that is acceptable because a member typically has 1–3
// roles and the call is in-memory.
func (uc *rbacUseCase) Check(ctx context.Context, req *dom.CheckRequest) (*dom.CheckResult, error) {
	if req == nil || req.TenantID == "" || req.UID == "" || req.Path == "" || req.Method == "" {
		return nil, permission.ErrInvalidCheckRequest
	}

	enforcer, err := uc.enforcerFor(ctx, req.TenantID)
	if err != nil {
		return nil, err
	}

	roleKeys, err := uc.roleKeysOf(ctx, req.TenantID, req.UID)
	if err != nil {
		return nil, err
	}

	// Any-allow: the first role that grants access wins. A user with no
	// active roles never enters the loop and falls through to deny.
	for _, key := range roleKeys {
		ok, matched, err := enforcer.EnforceEx(req.TenantID, key, req.Path, req.Method)
		if err != nil {
			return nil, fmt.Errorf("permission: enforce: %w", err)
		}
		if ok {
			return &dom.CheckResult{
				Allow:            true,
				MatchedRoleKey:   key,
				MatchedPolicyRow: append([]string{permission.CasbinPolicyType}, matched...),
			}, nil
		}
	}
	return &dom.CheckResult{Allow: false}, nil
}

// LoadPolicy materialises role_permissions for a single tenant into Casbin
// policy rules and publishes them by swapping a freshly seeded enforcer into
// the tenant cache. Concurrent Check calls therefore see either the previous
// rule set or the new one, never the empty policy that clearing and
// re-adding rules on the shared enforcer would briefly expose.
//
// The rules are then saved through the Redis adapter. A save failure is
// logged rather than returned, because the in-memory policy is already
// current.
func (uc *rbacUseCase) LoadPolicy(ctx context.Context, tenantID string) error {
	rules, err := uc.buildRules(ctx, tenantID)
	if err != nil {
		return err
	}

	mdl, err := uc.cloneModel()
	if err != nil {
		return err
	}
	enforcer, err := casbin.NewSyncedEnforcer(mdl)
	if err != nil {
		return fmt.Errorf("permission: new enforcer: %w", err)
	}
	enforcer.EnableAutoSave(false)
	if len(rules) > 0 {
		if _, err := enforcer.AddPolicies(rules); err != nil {
			return fmt.Errorf("permission: add policies: %w", err)
		}
	}

	// Publish only the fully seeded enforcer.
	uc.enforcerMu.Lock()
	uc.enforcers[tenantID] = enforcer
	uc.enforcerMu.Unlock()

	if err := uc.saveAdapter(ctx, tenantID, rules); err != nil {
		logx.WithContext(ctx).Errorf("permission: save adapter tenant=%s: %v", tenantID, err)
	}
	return nil
}

// LoadAllPolicies refreshes policies for every known tenant. It is used by
// the 5-minute cron fallback (see plan §6.11).
//
// A tenant is "known" when it has persisted Casbin rules in Redis (found via
// allTenantsWithRoles) or already has an enforcer cached in this process. The
// second source covers tenants whose rules were never persisted.
//
// Reload is best-effort: per-tenant failures are logged and the loop
// continues. Only a failure to enumerate the Redis tenants is returned.
func (uc *rbacUseCase) LoadAllPolicies(ctx context.Context) error {
	persisted, err := uc.allTenantsWithRoles(ctx)
	if err != nil {
		return err
	}

	seen := make(map[string]struct{}, len(persisted))
	tenantIDs := make([]string, 0, len(persisted))
	remember := func(id string) {
		if _, dup := seen[id]; dup {
			return
		}
		seen[id] = struct{}{}
		tenantIDs = append(tenantIDs, id)
	}
	for _, id := range persisted {
		remember(id)
	}
	// Snapshot under the read lock; LoadPolicy takes the write lock below.
	uc.enforcerMu.RLock()
	for id := range uc.enforcers {
		remember(id)
	}
	uc.enforcerMu.RUnlock()

	for _, tenantID := range tenantIDs {
		if err := uc.LoadPolicy(ctx, tenantID); err != nil {
			logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", tenantID, err)
		}
	}
	return nil
}

// BroadcastReload publishes a tenant-scoped reload event over Redis Pub/Sub.
// Other pods (and this pod itself) consume it to re-run LoadPolicy. An empty
// tenantID is sent as permission.PolicyReloadAllToken (reload every tenant).
// Without a usable Redis client it is a no-op.
func (uc *rbacUseCase) BroadcastReload(ctx context.Context, tenantID string) error {
	if uc.redis == nil || uc.redis.Zero() == nil {
		return nil
	}
	if tenantID == "" {
		tenantID = permission.PolicyReloadAllToken
	}
	payload, err := json.Marshal(reloadEvent{TenantID: tenantID, TS: time.Now().UnixMilli()})
	if err != nil {
		return err
	}
	_, err = uc.redis.Zero().PublishCtx(ctx, uc.reloadChannel, string(payload))
	return err
}

// StartReloadSubscriber spins a goroutine that reads from the Redis Pub/Sub
// reload channel and calls handleReload for each message. It is idempotent:
// a second call cancels and replaces the prior subscription. When the Redis
// wrapper exposes no Pub/Sub client it returns nil without subscribing.
func (uc *rbacUseCase) StartReloadSubscriber(ctx context.Context) error {
	uc.StopReloadSubscriber()

	pubsub := uc.redis.PubSubClient()
	if pubsub == nil {
		return nil
	}

	subCtx, cancel := context.WithCancel(ctx)
	uc.stopMu.Lock()
	if uc.stopSubscribe != nil {
		// A concurrent Start registered itself after our Stop above; retire it
		// instead of silently overwriting (and leaking) its subscription.
		uc.stopSubscribe()
	}
	uc.stopSubscribe = cancel
	uc.stopMu.Unlock()

	sub := pubsub.Subscribe(subCtx, uc.reloadChannel)
	// Wait for the subscription confirmation so failures surface to the caller.
	if _, err := sub.Receive(subCtx); err != nil {
		cancel()
		// The consuming goroutine (which owns Close) never starts on this
		// path, so release the Pub/Sub connection here.
		_ = sub.Close()
		return fmt.Errorf("permission: subscribe reload channel: %w", err)
	}

	ch := sub.Channel()
	go func() {
		defer func() { _ = sub.Close() }()
		for {
			select {
			case <-subCtx.Done():
				return
			case msg, ok := <-ch:
				if !ok {
					return
				}
				uc.handleReload(subCtx, msg.Payload)
			}
		}
	}()
	return nil
}

// StopReloadSubscriber cancels the active reload subscriber, if any. It does
// not wait for the subscriber goroutine to exit (best-effort).
func (uc *rbacUseCase) StopReloadSubscriber() {
	uc.stopMu.Lock()
	defer uc.stopMu.Unlock()
	if uc.stopSubscribe != nil {
		uc.stopSubscribe()
		uc.stopSubscribe = nil
	}
}

// handleReload decodes one Pub/Sub payload and reloads the named tenant, or
// every tenant when TenantID is empty or permission.PolicyReloadAllToken.
// Errors are logged, never returned, so one bad message cannot stop the
// subscriber loop.
func (uc *rbacUseCase) handleReload(ctx context.Context, payload string) {
	var ev reloadEvent
	if err := json.Unmarshal([]byte(payload), &ev); err != nil {
		// %q: the payload comes off the wire; quote it so control characters
		// cannot forge or break log lines.
		logx.WithContext(ctx).Errorf("permission: invalid reload payload: %q", payload)
		return
	}
	if ev.TenantID == permission.PolicyReloadAllToken || ev.TenantID == "" {
		if err := uc.LoadAllPolicies(ctx); err != nil {
			logx.WithContext(ctx).Errorf("permission: reload all: %v", err)
		}
		return
	}
	if err := uc.LoadPolicy(ctx, ev.TenantID); err != nil {
		logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", ev.TenantID, err)
	}
}

// enforcerFor returns the cached enforcer for tenantID. If none is cached,
// it builds one from a cloned model and seeds it with the tenant's rules.
//
// The enforcer is cached only after seeding succeeds. Caching it earlier
// would leave an empty (deny-all) enforcer behind whenever buildRules or
// AddPolicies fails, and it would stick until the next reload. The write lock
// is held across the repository reads in buildRules, so concurrent first
// lookups for the same tenant build only once.
func (uc *rbacUseCase) enforcerFor(ctx context.Context, tenantID string) (*casbin.SyncedEnforcer, error) {
	uc.enforcerMu.RLock()
	cached, ok := uc.enforcers[tenantID]
	uc.enforcerMu.RUnlock()
	if ok {
		return cached, nil
	}

	uc.enforcerMu.Lock()
	defer uc.enforcerMu.Unlock()
	// Re-check: another goroutine may have built it while we waited.
	if e, ok := uc.enforcers[tenantID]; ok {
		return e, nil
	}

	mdl, err := uc.cloneModel()
	if err != nil {
		return nil, err
	}
	enforcer, err := casbin.NewSyncedEnforcer(mdl)
	if err != nil {
		return nil, fmt.Errorf("permission: new enforcer: %w", err)
	}
	enforcer.EnableAutoSave(false)

	rules, err := uc.buildRules(ctx, tenantID)
	if err != nil {
		return nil, err
	}
	if len(rules) > 0 {
		if _, err := enforcer.AddPolicies(rules); err != nil {
			return nil, fmt.Errorf("permission: seed policies: %w", err)
		}
	}

	uc.enforcers[tenantID] = enforcer
	return enforcer, nil
}

// cloneModel returns a private copy of the Casbin model for a new enforcer:
// parsed from the inline text when set, otherwise re-parsed from the
// file-loaded template's text form.
func (uc *rbacUseCase) cloneModel() (casbinmodel.Model, error) {
	uc.modelMu.Lock()
	defer uc.modelMu.Unlock()
	if uc.modelTxt != "" {
		return casbinmodel.NewModelFromString(uc.modelTxt)
	}
	if uc.model == nil {
		return nil, errors.New("permission: casbin model not loaded")
	}
	// casbin/model is not safe for concurrent enforcers in some versions;
	// dump+parse keeps each enforcer isolated.
	return casbinmodel.NewModelFromString(uc.model.ToText())
}

// buildRules turns the tenant's open roles and open leaf permissions into
// Casbin "p" rows of the form [tenantID, roleKey, httpPath, httpMethods, name].
// It returns nil (no rules, no error) when the tenant has no active roles or
// no role-permission links.
func (uc *rbacUseCase) buildRules(ctx context.Context, tenantID string) ([][]string, error) {
	roles, err := uc.roles.ListByTenant(ctx, tenantID)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	if len(roles) == 0 {
		return nil, nil
	}

	roleByID := make(map[string]*entity.Role, len(roles))
	roleIDs := make([]string, 0, len(roles))
	for _, role := range roles {
		if role.Status != enum.StatusOpen {
			continue
		}
		roleByID[role.ID.Hex()] = role
		roleIDs = append(roleIDs, role.ID.Hex())
	}
	// Every role is disabled: there is nothing to join, so skip the
	// role_permissions round-trip entirely.
	if len(roleIDs) == 0 {
		return nil, nil
	}

	rps, err := uc.rolePerms.ListByRoles(ctx, tenantID, roleIDs)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	if len(rps) == 0 {
		return nil, nil
	}

	permIDSet := make(map[string]struct{}, len(rps))
	for _, rp := range rps {
		permIDSet[rp.PermissionID] = struct{}{}
	}
	ids := make([]string, 0, len(permIDSet))
	for id := range permIDSet {
		ids = append(ids, id)
	}
	perms, err := uc.perms.GetByIDs(ctx, ids)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	permByID := make(map[string]*entity.Permission, len(perms))
	for _, perm := range perms {
		permByID[perm.ID.Hex()] = perm
	}

	rules := make([][]string, 0, len(rps))
	for _, rp := range rps {
		role, ok := roleByID[rp.RoleID]
		if !ok {
			continue
		}
		perm, ok := permByID[rp.PermissionID]
		if !ok || !perm.IsLeaf() || perm.Status != enum.StatusOpen {
			continue
		}
		rules = append(rules, []string{
			tenantID,
			role.Key,
			perm.HTTPPath,
			perm.HTTPMethods,
			perm.Name,
		})
	}
	return rules, nil
}

// tenantScanBatch is the COUNT hint passed to each SCAN step when listing
// persisted tenant rule keys.
const tenantScanBatch int64 = 200

// allTenantsWithRoles lists tenant IDs that have persisted Casbin rules in
// Redis, derived from keys "<CasbinRulesRedisKey>:<tenantID>". It uses
// cursor-based SCAN instead of KEYS so the periodic reload does not block the
// Redis server on large keyspaces. SCAN may repeat keys, so results are
// de-duplicated. Without a usable Redis client the result is empty.
func (uc *rbacUseCase) allTenantsWithRoles(ctx context.Context) ([]string, error) {
	if uc.redis == nil || uc.redis.Zero() == nil {
		return nil, nil
	}
	store := uc.redis.Zero()
	prefix := permission.CasbinRulesRedisKey.String() + ":"
	pattern := prefix + "*"

	var (
		tenantIDs []string
		seen      = make(map[string]struct{})
		cursor    uint64
	)
	for {
		keys, next, err := store.ScanCtx(ctx, cursor, pattern, tenantScanBatch)
		if err != nil {
			return nil, err
		}
		for _, key := range keys {
			id := strings.TrimPrefix(key, prefix)
			if _, dup := seen[id]; dup {
				continue
			}
			seen[id] = struct{}{}
			tenantIDs = append(tenantIDs, id)
		}
		if next == 0 {
			return tenantIDs, nil
		}
		cursor = next
	}
}

// saveAdapter persists the tenant's rules through the Redis policy adapter.
// When no adapter is wired (the factory returns nil, nil) it is a no-op.
func (uc *rbacUseCase) saveAdapter(ctx context.Context, tenantID string, rules [][]string) error {
	adapter, err := newRedisAdapterFromClient(uc.redis)
	if err != nil || adapter == nil {
		return err
	}
	return adapter.SaveAll(ctx, tenantID, rules)
}

// newRedisAdapterFromClient is implemented in casbin_adapter_bridge.go to
// keep the import surface narrow (avoid pulling repository into usecase).
func newRedisAdapterFromClient(client *redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
	return RedisAdapterFactory(client)
}

// RedisAdapterFactory is plugged in by module.go (DI seam). Tests can
// override by assigning a stub. The default returns no adapter, which makes
// policy persistence a no-op.
var RedisAdapterFactory = func(_ *redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
	return nil, nil
}

// roleKeysOf returns the keys of the open roles assigned to uid in tenantID.
// It returns nil when the user has no role assignments.
func (uc *rbacUseCase) roleKeysOf(ctx context.Context, tenantID, uid string) ([]string, error) {
	urs, err := uc.userRoles.ListByUser(ctx, tenantID, uid)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	if len(urs) == 0 {
		return nil, nil
	}
	roleIDs := make([]string, 0, len(urs))
	for _, ur := range urs {
		roleIDs = append(roleIDs, ur.RoleID)
	}
	roles, err := uc.roles.ListByTenantAndIDs(ctx, tenantID, roleIDs)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	out := make([]string, 0, len(roles))
	for _, role := range roles {
		if role.Status != enum.StatusOpen {
			continue
		}
		out = append(out, role.Key)
	}
	return out, nil
}

var _ dom.RBACUseCase = (*rbacUseCase)(nil)