template-monorepo/internal/model/permission/usecase/rbac_usecase.go

428 lines
12 KiB
Go
Raw Permalink Normal View History

package usecase
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
redislib "gateway/internal/library/redis"
permission "gateway/internal/model/permission/domain"
"gateway/internal/model/permission/domain/entity"
"gateway/internal/model/permission/domain/enum"
domrepo "gateway/internal/model/permission/domain/repository"
dom "gateway/internal/model/permission/domain/usecase"
"github.com/casbin/casbin/v2"
casbinmodel "github.com/casbin/casbin/v2/model"
"github.com/zeromicro/go-zero/core/logx"
)
// RBACUseCaseParam carries the dependencies for NewRBACUseCase: the four
// permission repositories plus the Redis client used for policy persistence
// and the reload Pub/Sub channel. Redis is required; NewRBACUseCase returns
// permission.ErrCasbinNotConfigured without it.
//
// Model source: ModelPath must point at etc/rbac.conf; a non-blank
// CasbinModelText overrides ModelPath (used by tests / embedded resources).
// ReloadChannel falls back to permission.PolicyReloadChannel when blank.
type RBACUseCaseParam struct {
	Roles           domrepo.RoleRepository
	Permissions     domrepo.PermissionRepository
	RolePermissions domrepo.RolePermissionRepository
	UserRoles       domrepo.UserRoleRepository
	Redis           *redislib.Client
	ModelPath       string
	CasbinModelText string
	ReloadChannel   string
}
// reloadEvent is the JSON payload published on the reload channel by
// BroadcastReload and decoded by handleReload.
type reloadEvent struct {
	// TenantID is the tenant to reload; permission.PolicyReloadAllToken
	// (or an empty value) means "reload every tenant".
	TenantID string `json:"tenant_id"`
	// TS is the publish time in Unix milliseconds.
	TS int64 `json:"ts"`
}
// rbacUseCase is the Casbin-backed implementation of dom.RBACUseCase. It
// keeps one SyncedEnforcer per tenant, created lazily by enforcerFor and
// refreshed by LoadPolicy (directly, by the cron fallback, or from the
// Redis reload subscriber).
type rbacUseCase struct {
	// Repositories used to materialise policy rules and resolve user roles.
	roles     domrepo.RoleRepository
	perms     domrepo.PermissionRepository
	rolePerms domrepo.RolePermissionRepository
	userRoles domrepo.UserRoleRepository
	// redis is used for reload broadcasts/subscriptions, the tenant key
	// listing and (via RedisAdapterFactory) policy persistence.
	redis *redislib.Client

	// enforcerMu guards enforcers, which maps tenant ID -> enforcer.
	enforcerMu sync.RWMutex
	enforcers  map[string]*casbin.SyncedEnforcer

	// model is the model parsed from ModelPath (left nil when modelTxt is
	// set); modelTxt is the trimmed inline CasbinModelText. In this file
	// both are written only by NewRBACUseCase and read by cloneModel,
	// which takes modelMu.
	model    casbinmodel.Model
	modelMu  sync.Mutex
	modelTxt string

	// reloadChannel is the Pub/Sub channel carrying reloadEvent payloads.
	reloadChannel string

	// stopSubscribe cancels the running reload subscriber; guarded by stopMu.
	stopSubscribe context.CancelFunc
	stopMu        sync.Mutex
}
// NewRBACUseCase builds the RBAC use case on top of per-tenant Casbin
// enforcers. It fails with permission.ErrCasbinNotConfigured when no usable
// Redis client is supplied, because policy persistence and the reload
// Pub/Sub channel both go through Redis.
//
// Model resolution: a non-blank CasbinModelText wins. Otherwise, when
// ModelPath is set, the model file is parsed here so that a bad path is
// reported at construction time. A blank ReloadChannel falls back to
// permission.PolicyReloadChannel.
func NewRBACUseCase(param RBACUseCaseParam) (dom.RBACUseCase, error) {
	missingRedis := param.Redis == nil || param.Redis.Zero() == nil
	if missingRedis {
		return nil, permission.ErrCasbinNotConfigured
	}

	reloadChannel := strings.TrimSpace(param.ReloadChannel)
	if len(reloadChannel) == 0 {
		reloadChannel = permission.PolicyReloadChannel
	}

	uc := &rbacUseCase{
		roles:         param.Roles,
		perms:         param.Permissions,
		rolePerms:     param.RolePermissions,
		userRoles:     param.UserRoles,
		redis:         param.Redis,
		enforcers:     make(map[string]*casbin.SyncedEnforcer),
		modelTxt:      strings.TrimSpace(param.CasbinModelText),
		reloadChannel: reloadChannel,
	}

	// Inline model text takes precedence; with no file path configured
	// either, cloneModel reports the missing model on first use.
	if uc.modelTxt != "" || param.ModelPath == "" {
		return uc, nil
	}
	fileModel, err := casbinmodel.NewModelFromFile(param.ModelPath)
	if err != nil {
		return nil, fmt.Errorf("permission: load casbin model: %w", err)
	}
	uc.model = fileModel
	return uc, nil
}
// Check decides whether req.UID may call req.Method on req.Path inside
// req.TenantID.
//
// The user's enabled role keys are evaluated one by one against the
// tenant's enforcer. The matcher compares r.role == p.role, so every role
// needs its own EnforceEx call. This is cheap because the calls are
// in-memory and a member usually holds only a few roles.
//
// Semantics are any-allow: the first allowed role wins. Its key and the
// matched policy row, prefixed with permission.CasbinPolicyType, are
// reported. A user with no enabled roles, or whose roles all fail, gets
// Allow=false.
//
// Returns permission.ErrInvalidCheckRequest when req is nil or TenantID,
// UID, Path or Method is empty. Enforcer, repository and enforce errors
// are propagated.
func (uc *rbacUseCase) Check(ctx context.Context, req *dom.CheckRequest) (*dom.CheckResult, error) {
	if req == nil || req.TenantID == "" || req.UID == "" {
		return nil, permission.ErrInvalidCheckRequest
	}
	if req.Path == "" || req.Method == "" {
		return nil, permission.ErrInvalidCheckRequest
	}

	enforcer, err := uc.enforcerFor(ctx, req.TenantID)
	if err != nil {
		return nil, err
	}
	roleKeys, err := uc.roleKeysOf(ctx, req.TenantID, req.UID)
	if err != nil {
		return nil, err
	}

	// An empty roleKeys simply skips the loop and falls through to deny.
	for _, roleKey := range roleKeys {
		allowed, row, enforceErr := enforcer.EnforceEx(req.TenantID, roleKey, req.Path, req.Method)
		if enforceErr != nil {
			return nil, fmt.Errorf("permission: enforce: %w", enforceErr)
		}
		if !allowed {
			continue
		}
		matchedRow := append([]string{permission.CasbinPolicyType}, row...)
		return &dom.CheckResult{
			Allow:            true,
			MatchedRoleKey:   roleKey,
			MatchedPolicyRow: matchedRow,
		}, nil
	}
	return &dom.CheckResult{Allow: false}, nil
}
// LoadPolicy rebuilds tenantID's Casbin policy from role_permissions and
// replaces the cached enforcer in one step.
//
// A brand-new enforcer is seeded off to the side. It is published under
// enforcerMu only once seeding has fully succeeded, so:
//   - concurrent Check calls see either the old or the new policy, never an
//     empty or half-populated one (ClearPolicy followed by AddPolicies on the
//     live enforcer left a deny-all window);
//   - a failed reload leaves the previous policy in force.
//
// The rules are then persisted through the adapter from RedisAdapterFactory.
// Persistence is best-effort: a save failure is logged, not returned.
//
// NOTE(review): two concurrent reloads of the same tenant are last-writer-
// wins; a reload that read the repositories earlier can still publish last.
func (uc *rbacUseCase) LoadPolicy(ctx context.Context, tenantID string) error {
	rules, err := uc.buildRules(ctx, tenantID)
	if err != nil {
		return err
	}

	mdl, err := uc.cloneModel()
	if err != nil {
		return err
	}
	enforcer, err := casbin.NewSyncedEnforcer(mdl)
	if err != nil {
		return fmt.Errorf("permission: new enforcer: %w", err)
	}
	enforcer.EnableAutoSave(false)
	if len(rules) > 0 {
		if _, err := enforcer.AddPolicies(rules); err != nil {
			return fmt.Errorf("permission: add policies: %w", err)
		}
	}

	// Swap in the fully seeded enforcer. Check callers already holding the
	// old one finish against a consistent snapshot.
	uc.enforcerMu.Lock()
	uc.enforcers[tenantID] = enforcer
	uc.enforcerMu.Unlock()

	if err := uc.saveAdapter(ctx, tenantID, rules); err != nil {
		logx.WithContext(ctx).Errorf("permission: save adapter tenant=%s: %v", tenantID, err)
	}
	return nil
}
// LoadAllPolicies refreshes the policy of every tenant returned by
// allTenantsWithRoles. It backs the 5-minute cron fallback (see plan §6.11)
// and the "reload all" broadcast.
//
// A failure for one tenant is logged and the loop moves on to the next.
// The error from listing tenants is returned as-is. Once ctx is cancelled
// (e.g. the reload subscriber was stopped), the loop stops and returns
// ctx.Err(); otherwise it would keep querying the repositories and log one
// error per remaining tenant.
func (uc *rbacUseCase) LoadAllPolicies(ctx context.Context) error {
	tenantIDs, err := uc.allTenantsWithRoles(ctx)
	if err != nil {
		return err
	}
	for _, tenantID := range tenantIDs {
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := uc.LoadPolicy(ctx, tenantID); err != nil {
			logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", tenantID, err)
		}
	}
	return nil
}
// BroadcastReload announces a policy change on the reload channel. Every
// pod subscribed via StartReloadSubscriber, this one included, then
// re-reads the affected policy.
//
// An empty tenantID is sent as permission.PolicyReloadAllToken, meaning
// "reload every tenant". The call is a silent no-op when no Redis client is
// available. Marshal and publish errors are returned.
func (uc *rbacUseCase) BroadcastReload(ctx context.Context, tenantID string) error {
	redisMissing := uc.redis == nil || uc.redis.Zero() == nil
	if redisMissing {
		return nil
	}

	target := tenantID
	if target == "" {
		target = permission.PolicyReloadAllToken
	}
	event := reloadEvent{
		TenantID: target,
		TS:       time.Now().UnixMilli(),
	}
	body, err := json.Marshal(event)
	if err != nil {
		return err
	}

	if _, err := uc.redis.Zero().PublishCtx(ctx, uc.reloadChannel, string(body)); err != nil {
		return err
	}
	return nil
}
// StartReloadSubscriber subscribes to the reload channel and starts a
// goroutine that passes each message to handleReload. The goroutine runs
// until ctx is cancelled, StopReloadSubscriber is called, or the
// subscription channel closes. Calling it again cancels the previous
// subscription and replaces it.
//
// Return values:
//   - nil, without subscribing, when the Redis client exposes no Pub/Sub
//     client; any previous subscription is still stopped.
//   - an error when the initial subscription handshake fails. The half-open
//     subscription is closed in that case so it does not leak.
func (uc *rbacUseCase) StartReloadSubscriber(ctx context.Context) error {
	pubsub := uc.redis.PubSubClient()
	if pubsub == nil {
		uc.StopReloadSubscriber()
		return nil
	}

	subCtx, cancel := context.WithCancel(ctx)
	// Swap the cancel func in a single critical section. Two concurrent
	// Start calls then cannot both install a subscriber while one of their
	// cancel funcs is lost, which would leak that goroutine.
	uc.stopMu.Lock()
	previous := uc.stopSubscribe
	uc.stopSubscribe = cancel
	uc.stopMu.Unlock()
	if previous != nil {
		previous()
	}

	sub := pubsub.Subscribe(subCtx, uc.reloadChannel)
	if _, err := sub.Receive(subCtx); err != nil {
		cancel()
		// Best-effort cleanup; the Receive error is the one worth reporting.
		_ = sub.Close()
		return fmt.Errorf("permission: subscribe reload channel: %w", err)
	}

	ch := sub.Channel()
	go func() {
		defer func() { _ = sub.Close() }()
		for {
			select {
			case <-subCtx.Done():
				return
			case msg, ok := <-ch:
				if !ok {
					return
				}
				uc.handleReload(subCtx, msg.Payload)
			}
		}
	}()
	return nil
}
// StopReloadSubscriber cancels the active reload subscription, if there is
// one, and forgets its cancel func. It does not wait for the subscriber
// goroutine to exit (best-effort). Calling it with nothing running is a
// no-op.
func (uc *rbacUseCase) StopReloadSubscriber() {
	uc.stopMu.Lock()
	defer uc.stopMu.Unlock()

	if cancel := uc.stopSubscribe; cancel != nil {
		uc.stopSubscribe = nil
		cancel()
	}
}
// handleReload decodes one reload-channel payload and applies it. An empty
// tenant ID or permission.PolicyReloadAllToken triggers LoadAllPolicies;
// any other tenant ID reloads just that tenant.
//
// Malformed payloads and reload failures are logged and otherwise ignored,
// so one bad message cannot stop the subscriber loop.
func (uc *rbacUseCase) handleReload(ctx context.Context, payload string) {
	logger := logx.WithContext(ctx)

	var event reloadEvent
	if decodeErr := json.Unmarshal([]byte(payload), &event); decodeErr != nil {
		logger.Errorf("permission: invalid reload payload: %s", payload)
		return
	}

	reloadAll := event.TenantID == permission.PolicyReloadAllToken || event.TenantID == ""
	if !reloadAll {
		if err := uc.LoadPolicy(ctx, event.TenantID); err != nil {
			logger.Errorf("permission: reload tenant=%s: %v", event.TenantID, err)
		}
		return
	}
	if err := uc.LoadAllPolicies(ctx); err != nil {
		logger.Errorf("permission: reload all: %v", err)
	}
}
// enforcerFor returns tenantID's cached enforcer. On first use it creates
// one and seeds it from the repositories.
//
// The new enforcer is stored in the cache only after its policies have
// loaded successfully. Previously it was stored before seeding, so a
// transient buildRules/AddPolicies failure (including a cancelled request
// ctx) left an empty, deny-all enforcer cached until the next reload. Now
// nothing is cached on error and the next call retries.
//
// Creation happens under the write lock (with a re-check after acquiring
// it), so concurrent first calls for one tenant build it only once.
// NOTE(review): the repository queries also run under that lock, which
// blocks Check for every other tenant during a cold load; consider a
// per-tenant singleflight if that latency matters.
func (uc *rbacUseCase) enforcerFor(ctx context.Context, tenantID string) (*casbin.SyncedEnforcer, error) {
	uc.enforcerMu.RLock()
	cached, ok := uc.enforcers[tenantID]
	uc.enforcerMu.RUnlock()
	if ok {
		return cached, nil
	}

	uc.enforcerMu.Lock()
	defer uc.enforcerMu.Unlock()
	// Another goroutine may have built it while we waited for the lock.
	if e, ok := uc.enforcers[tenantID]; ok {
		return e, nil
	}

	mdl, err := uc.cloneModel()
	if err != nil {
		return nil, err
	}
	enforcer, err := casbin.NewSyncedEnforcer(mdl)
	if err != nil {
		return nil, fmt.Errorf("permission: new enforcer: %w", err)
	}
	enforcer.EnableAutoSave(false)

	rules, err := uc.buildRules(ctx, tenantID)
	if err != nil {
		return nil, err
	}
	if len(rules) > 0 {
		if _, err := enforcer.AddPolicies(rules); err != nil {
			return nil, fmt.Errorf("permission: seed policies: %w", err)
		}
	}

	// Publish only a fully seeded enforcer.
	uc.enforcers[tenantID] = enforcer
	return enforcer, nil
}
// cloneModel returns a freshly parsed casbin model for a new enforcer.
//
// The inline model text (CasbinModelText) takes precedence. Otherwise the
// model loaded from ModelPath is dumped with ToText and parsed again, so
// that no two enforcers share one model instance. It errors when neither
// source was configured.
func (uc *rbacUseCase) cloneModel() (casbinmodel.Model, error) {
	uc.modelMu.Lock()
	defer uc.modelMu.Unlock()

	var text string
	switch {
	case uc.modelTxt != "":
		text = uc.modelTxt
	case uc.model != nil:
		// casbin/model is not safe for concurrent enforcers in some
		// versions; a dump+parse round trip keeps each enforcer isolated.
		text = uc.model.ToText()
	default:
		return nil, errors.New("permission: casbin model not loaded")
	}
	return casbinmodel.NewModelFromString(text)
}
// buildRules materialises tenantID's Casbin "p" rows from the
// repositories. Each row is [tenantID, role.Key, perm.HTTPPath,
// perm.HTTPMethods, perm.Name].
//
// Only roles and permissions whose status is enum.StatusOpen are used, and
// only leaf permissions (perm.IsLeaf()) become rules. Role-permission links
// whose role or permission cannot be resolved are skipped.
//
// Identical rows (e.g. duplicated role_permission links) are emitted only
// once. Returns (nil, nil) when the tenant has no enabled role or no links.
// Repository errors are wrapped with wrapRepoErr.
func (uc *rbacUseCase) buildRules(ctx context.Context, tenantID string) ([][]string, error) {
	roles, err := uc.roles.ListByTenant(ctx, tenantID)
	if err != nil {
		return nil, wrapRepoErr(err)
	}

	roleByID := make(map[string]*entity.Role, len(roles))
	roleIDs := make([]string, 0, len(roles))
	for _, role := range roles {
		if role.Status != enum.StatusOpen {
			continue
		}
		id := role.ID.Hex()
		roleByID[id] = role
		roleIDs = append(roleIDs, id)
	}
	// No enabled roles means no policy. Skip the role_permissions query
	// instead of issuing it with an empty ID list.
	if len(roleIDs) == 0 {
		return nil, nil
	}

	rps, err := uc.rolePerms.ListByRoles(ctx, tenantID, roleIDs)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	if len(rps) == 0 {
		return nil, nil
	}

	permIDSet := make(map[string]struct{}, len(rps))
	for _, rp := range rps {
		permIDSet[rp.PermissionID] = struct{}{}
	}
	ids := make([]string, 0, len(permIDSet))
	for id := range permIDSet {
		ids = append(ids, id)
	}
	perms, err := uc.perms.GetByIDs(ctx, ids)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	permByID := make(map[string]*entity.Permission, len(perms))
	for _, perm := range perms {
		permByID[perm.ID.Hex()] = perm
	}

	// tenantID is constant for every row, so the remaining columns identify
	// a row uniquely.
	type ruleKey struct{ role, path, methods, name string }
	seen := make(map[ruleKey]struct{}, len(rps))
	rules := make([][]string, 0, len(rps))
	for _, rp := range rps {
		role, ok := roleByID[rp.RoleID]
		if !ok {
			continue
		}
		perm, ok := permByID[rp.PermissionID]
		if !ok || !perm.IsLeaf() || perm.Status != enum.StatusOpen {
			continue
		}
		key := ruleKey{role.Key, perm.HTTPPath, perm.HTTPMethods, perm.Name}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		rules = append(rules, []string{
			tenantID,
			role.Key,
			perm.HTTPPath,
			perm.HTTPMethods,
			perm.Name,
		})
	}
	return rules, nil
}
// allTenantsWithRoles lists the tenants LoadAllPolicies should refresh. It
// returns the union of:
//   - tenants that already have an enforcer cached in this process, and
//   - tenants with a persisted policy key in Redis, named
//     permission.CasbinRulesRedisKey + ":" + tenantID.
//
// Including cached tenants keeps in-memory policies refreshable even when
// nothing was ever persisted, e.g. RedisAdapterFactory is left at its
// no-op default or a save failed. Relying on Redis keys alone made the cron
// fallback and "reload all" broadcasts silently do nothing in that case.
//
// The result is de-duplicated and never contains an empty tenant ID; its
// order is unspecified. An empty result means nothing to do. A Redis
// listing error is returned.
//
// NOTE(review): KEYS is O(N) over the whole keyspace and blocks Redis while
// it runs; a SCAN-based iteration would be safer on large deployments.
func (uc *rbacUseCase) allTenantsWithRoles(ctx context.Context) ([]string, error) {
	seen := make(map[string]struct{})
	var tenantIDs []string
	add := func(id string) {
		if id == "" {
			return
		}
		if _, dup := seen[id]; dup {
			return
		}
		seen[id] = struct{}{}
		tenantIDs = append(tenantIDs, id)
	}

	uc.enforcerMu.RLock()
	for id := range uc.enforcers {
		add(id)
	}
	uc.enforcerMu.RUnlock()

	if uc.redis == nil || uc.redis.Zero() == nil {
		return tenantIDs, nil
	}
	prefix := permission.CasbinRulesRedisKey.String() + ":"
	keys, err := uc.redis.Zero().KeysCtx(ctx, prefix+"*")
	if err != nil {
		return nil, err
	}
	for _, key := range keys {
		add(strings.TrimPrefix(key, prefix))
	}
	return tenantIDs, nil
}
// saveAdapter persists tenantID's rules through the adapter obtained from
// newRedisAdapterFromClient. A factory error is returned. When no adapter
// is produced (the default factory returns nil), persistence is skipped and
// nil is returned. Otherwise the result of the adapter's SaveAll is
// returned.
func (uc *rbacUseCase) saveAdapter(ctx context.Context, tenantID string, rules [][]string) error {
	adapter, err := newRedisAdapterFromClient(uc.redis)
	switch {
	case err != nil:
		return err
	case adapter == nil:
		return nil
	default:
		return adapter.SaveAll(ctx, tenantID, rules)
	}
}
// newRedisAdapterFromClient builds the Casbin policy adapter for client by
// delegating to the RedisAdapterFactory seam. The indirection keeps this
// package from importing the concrete repository implementation.
//
// If the factory has been set to nil (e.g. by a test that wants to disable
// persistence), it returns (nil, nil), meaning "no adapter". Calling the
// nil func would panic.
func newRedisAdapterFromClient(client *redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
	factory := RedisAdapterFactory
	if factory == nil {
		return nil, nil
	}
	return factory(client)
}
// RedisAdapterFactory is the dependency-injection seam that supplies the
// Casbin policy adapter used to persist rules. It is meant to be replaced
// by module.go (NOTE(review): that wiring is not visible here); tests can
// override it with a stub. The default yields no adapter and no error,
// which turns policy persistence into a no-op.
var RedisAdapterFactory = func(*redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
	var noAdapter domrepo.CasbinPolicyAdapter
	return noAdapter, nil
}
// roleKeysOf resolves uid's role assignments within tenantID into role
// keys. Only roles whose status is enum.StatusOpen are kept. It returns
// (nil, nil) when the user has no assignments; repository errors are
// wrapped with wrapRepoErr.
func (uc *rbacUseCase) roleKeysOf(ctx context.Context, tenantID, uid string) ([]string, error) {
	assignments, err := uc.userRoles.ListByUser(ctx, tenantID, uid)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	if len(assignments) == 0 {
		return nil, nil
	}

	ids := make([]string, len(assignments))
	for i, assignment := range assignments {
		ids[i] = assignment.RoleID
	}

	roles, err := uc.roles.ListByTenantAndIDs(ctx, tenantID, ids)
	if err != nil {
		return nil, wrapRepoErr(err)
	}

	keys := make([]string, 0, len(roles))
	for _, r := range roles {
		if r.Status == enum.StatusOpen {
			keys = append(keys, r.Key)
		}
	}
	return keys, nil
}
// Compile-time check that *rbacUseCase implements dom.RBACUseCase.
var _ dom.RBACUseCase = (*rbacUseCase)(nil)