template-monorepo/internal/model/permission/usecase/rbac_usecase.go

428 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package usecase
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
redislib "gateway/internal/library/redis"
permission "gateway/internal/model/permission/domain"
"gateway/internal/model/permission/domain/entity"
"gateway/internal/model/permission/domain/enum"
domrepo "gateway/internal/model/permission/domain/repository"
dom "gateway/internal/model/permission/domain/usecase"
"github.com/casbin/casbin/v2"
casbinmodel "github.com/casbin/casbin/v2/model"
"github.com/zeromicro/go-zero/core/logx"
)
// RBACUseCaseParam carries the dependencies for NewRBACUseCase: the
// permission repositories plus the Redis client used for reload Pub/Sub.
// ModelPath normally points at etc/rbac.conf; CasbinModelText overrides
// ModelPath when non-blank (used by tests / embedded resources).
type RBACUseCaseParam struct {
	// Repositories that policy rules and user role keys are read from.
	Roles           domrepo.RoleRepository
	Permissions     domrepo.PermissionRepository
	RolePermissions domrepo.RolePermissionRepository
	UserRoles       domrepo.UserRoleRepository
	// Redis is required: NewRBACUseCase returns ErrCasbinNotConfigured when
	// it, or the client returned by its Zero method, is nil.
	Redis *redislib.Client
	// ModelPath is the Casbin model file, parsed eagerly by NewRBACUseCase
	// when CasbinModelText is blank.
	ModelPath string
	// CasbinModelText is an inline Casbin model. It is whitespace-trimmed,
	// and when the result is non-empty it takes precedence over ModelPath.
	CasbinModelText string
	// ReloadChannel is the Pub/Sub channel for reload events. A blank value
	// falls back to permission.PolicyReloadChannel.
	ReloadChannel string
}
// reloadEvent is the JSON payload published on the reload channel by
// BroadcastReload and decoded by handleReload.
type reloadEvent struct {
	// TenantID names the tenant to reload; permission.PolicyReloadAllToken
	// or an empty value means "reload every tenant".
	TenantID string `json:"tenant_id"`
	// TS is the publish time in Unix milliseconds; handleReload does not read it.
	TS int64 `json:"ts"`
}
// rbacUseCase is the Casbin-backed implementation of dom.RBACUseCase. It
// keeps one SyncedEnforcer per tenant in memory. Enforcers are built
// lazily from the repositories and rebuilt by LoadPolicy, either directly
// or through reload events received over Redis Pub/Sub.
type rbacUseCase struct {
	// Persistence layer used to materialise policy rules and user role keys.
	roles     domrepo.RoleRepository
	perms     domrepo.PermissionRepository
	rolePerms domrepo.RolePermissionRepository
	userRoles domrepo.UserRoleRepository
	// redis backs reload broadcasts, the reload subscriber, tenant
	// discovery (KEYS scan) and the policy adapter factory.
	redis *redislib.Client
	// enforcerMu guards enforcers, the per-tenant enforcer cache keyed by tenant ID.
	enforcerMu sync.RWMutex
	enforcers  map[string]*casbin.SyncedEnforcer
	// model is the file-loaded Casbin model, set only when modelTxt is empty.
	// modelMu serialises cloneModel. A non-empty modelTxt wins over model.
	model    casbinmodel.Model
	modelMu  sync.Mutex
	modelTxt string
	// reloadChannel is the Pub/Sub channel used by BroadcastReload and the subscriber.
	reloadChannel string
	// stopSubscribe cancels the running reload subscriber; guarded by stopMu.
	stopSubscribe context.CancelFunc
	stopMu        sync.Mutex
}
// NewRBACUseCase builds the Casbin-backed RBAC use case on top of the
// persistence layer.
//
// Redis is mandatory. A nil client, or one whose Zero() is nil, yields
// permission.ErrCasbinNotConfigured, because reload broadcasts and the
// reload subscriber go through it. A blank ReloadChannel falls back to
// permission.PolicyReloadChannel.
//
// Model selection:
//   - When CasbinModelText is blank and ModelPath is set, the model file is
//     parsed here, so a bad path fails at construction.
//   - When both are blank, no model is loaded, and building an enforcer
//     fails later with "casbin model not loaded".
func NewRBACUseCase(param RBACUseCaseParam) (dom.RBACUseCase, error) {
	if param.Redis == nil || param.Redis.Zero() == nil {
		return nil, permission.ErrCasbinNotConfigured
	}

	reloadChannel := permission.PolicyReloadChannel
	if custom := strings.TrimSpace(param.ReloadChannel); custom != "" {
		reloadChannel = custom
	}

	inlineModel := strings.TrimSpace(param.CasbinModelText)
	var fileModel casbinmodel.Model
	if inlineModel == "" && param.ModelPath != "" {
		parsed, err := casbinmodel.NewModelFromFile(param.ModelPath)
		if err != nil {
			return nil, fmt.Errorf("permission: load casbin model: %w", err)
		}
		fileModel = parsed
	}

	return &rbacUseCase{
		roles:         param.Roles,
		perms:         param.Permissions,
		rolePerms:     param.RolePermissions,
		userRoles:     param.UserRoles,
		redis:         param.Redis,
		enforcers:     make(map[string]*casbin.SyncedEnforcer),
		model:         fileModel,
		modelTxt:      inlineModel,
		reloadChannel: reloadChannel,
	}, nil
}
// Check decides whether req.UID may call req.Method on req.Path within
// req.TenantID. The user's open role keys are resolved, then each role is
// enforced in turn with any-allow semantics: the first role that matches
// wins.
//
// Because the model matches on `r.role == p.role`, EnforceEx is called once
// per role. This is cheap: a member normally holds only a few roles and
// enforcement is in-memory.
//
// Returns permission.ErrInvalidCheckRequest when req is nil or any of
// TenantID, UID, Path or Method is empty. A user with no open roles is
// denied without error. On allow, MatchedPolicyRow is the matched policy
// row prefixed with permission.CasbinPolicyType.
func (uc *rbacUseCase) Check(ctx context.Context, req *dom.CheckRequest) (*dom.CheckResult, error) {
	valid := req != nil && req.TenantID != "" && req.UID != "" && req.Path != "" && req.Method != ""
	if !valid {
		return nil, permission.ErrInvalidCheckRequest
	}

	e, err := uc.enforcerFor(ctx, req.TenantID)
	if err != nil {
		return nil, err
	}
	keys, err := uc.roleKeysOf(ctx, req.TenantID, req.UID)
	if err != nil {
		return nil, err
	}

	for _, roleKey := range keys {
		allowed, explain, enforceErr := e.EnforceEx(req.TenantID, roleKey, req.Path, req.Method)
		switch {
		case enforceErr != nil:
			return nil, fmt.Errorf("permission: enforce: %w", enforceErr)
		case allowed:
			row := make([]string, 0, len(explain)+1)
			row = append(row, permission.CasbinPolicyType)
			row = append(row, explain...)
			return &dom.CheckResult{
				Allow:            true,
				MatchedRoleKey:   roleKey,
				MatchedPolicyRow: row,
			}, nil
		}
	}
	// No roles at all, or none of them matched: deny.
	return &dom.CheckResult{Allow: false}, nil
}
// LoadPolicy rebuilds the Casbin policy for one tenant from the
// repositories and installs it in the in-memory enforcer cache. It then
// persists the same rules through the Redis adapter.
//
// A fresh enforcer is built and seeded first, and only then swapped into
// the cache. Concurrent Check calls therefore see either the complete old
// policy or the complete new one. They never see the empty window that
// ClearPolicy-then-AddPolicies on the live enforcer would expose, and a
// failed AddPolicies never leaves a cleared (deny-all) enforcer behind.
//
// Adapter persistence is best-effort: a save failure is logged and does
// not fail the reload.
func (uc *rbacUseCase) LoadPolicy(ctx context.Context, tenantID string) error {
	rules, err := uc.buildRules(ctx, tenantID)
	if err != nil {
		return err
	}

	mdl, err := uc.cloneModel()
	if err != nil {
		return err
	}
	enforcer, err := casbin.NewSyncedEnforcer(mdl)
	if err != nil {
		return fmt.Errorf("permission: new enforcer: %w", err)
	}
	enforcer.EnableAutoSave(false)
	if len(rules) > 0 {
		if _, err := enforcer.AddPolicies(rules); err != nil {
			return fmt.Errorf("permission: add policies: %w", err)
		}
	}

	// Atomic publish: readers holding the previous pointer finish against
	// the old policy; every later lookup gets the new one.
	uc.enforcerMu.Lock()
	uc.enforcers[tenantID] = enforcer
	uc.enforcerMu.Unlock()

	if err := uc.saveAdapter(ctx, tenantID, rules); err != nil {
		logx.WithContext(ctx).Errorf("permission: save adapter tenant=%s: %v", tenantID, err)
	}
	return nil
}
// LoadAllPolicies reloads every known tenant. It backs the 5-minute cron
// fallback (see plan §6.11), which repairs enforcers that missed a Pub/Sub
// reload event.
//
// The tenant set is the union of two sources:
//   - tenants with a persisted rule set in Redis (allTenantsWithRoles);
//   - tenants already cached in memory.
//
// The second source matters because enforcerFor seeds a tenant on first
// Check without writing anything to Redis. Scanning Redis alone would leave
// such tenants stale until the process restarts.
//
// Per-tenant failures are logged and skipped. The loop stops early and
// returns ctx.Err() once ctx is cancelled.
func (uc *rbacUseCase) LoadAllPolicies(ctx context.Context) error {
	persisted, err := uc.allTenantsWithRoles(ctx)
	if err != nil {
		return err
	}

	seen := make(map[string]struct{}, len(persisted))
	tenantIDs := make([]string, 0, len(persisted))
	add := func(id string) {
		if id == "" {
			return
		}
		if _, dup := seen[id]; dup {
			return
		}
		seen[id] = struct{}{}
		tenantIDs = append(tenantIDs, id)
	}
	for _, id := range persisted {
		add(id)
	}
	uc.enforcerMu.RLock()
	for id := range uc.enforcers {
		add(id)
	}
	uc.enforcerMu.RUnlock()

	for _, tenantID := range tenantIDs {
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := uc.LoadPolicy(ctx, tenantID); err != nil {
			logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", tenantID, err)
		}
	}
	return nil
}
// BroadcastReload publishes a reload event on the reload channel. Every
// subscriber, including this process's own StartReloadSubscriber
// goroutine, reacts by calling LoadPolicy.
//
// An empty tenantID is sent as permission.PolicyReloadAllToken, meaning
// "reload all tenants". Without a usable Redis client this is a silent
// no-op.
func (uc *rbacUseCase) BroadcastReload(ctx context.Context, tenantID string) error {
	rds := uc.redis
	if rds == nil || rds.Zero() == nil {
		return nil
	}

	target := tenantID
	if target == "" {
		target = permission.PolicyReloadAllToken
	}
	body, err := json.Marshal(reloadEvent{TenantID: target, TS: time.Now().UnixMilli()})
	if err != nil {
		return err
	}
	if _, err = rds.Zero().PublishCtx(ctx, uc.reloadChannel, string(body)); err != nil {
		return err
	}
	return nil
}
// StartReloadSubscriber subscribes to the reload channel and starts a
// goroutine that passes each message to handleReload. The goroutine stops
// when ctx is cancelled, when StopReloadSubscriber is called, or when the
// message channel closes.
//
// Calling it again cancels and replaces any previous subscription. If the
// Redis client exposes no Pub/Sub client, it returns nil without
// subscribing.
func (uc *rbacUseCase) StartReloadSubscriber(ctx context.Context) error {
	uc.StopReloadSubscriber()
	pubsub := uc.redis.PubSubClient()
	if pubsub == nil {
		return nil
	}

	subCtx, cancel := context.WithCancel(ctx)
	uc.stopMu.Lock()
	// A concurrent StartReloadSubscriber may have registered its cancel func
	// since the Stop above. Cancel it here rather than overwrite it, which
	// would leak that goroutine and its subscription.
	if uc.stopSubscribe != nil {
		uc.stopSubscribe()
	}
	uc.stopSubscribe = cancel
	uc.stopMu.Unlock()

	sub := pubsub.Subscribe(subCtx, uc.reloadChannel)
	// Receive blocks until the subscription is acknowledged, so a failure
	// surfaces to the caller here (go-redis PubSub semantics).
	if _, err := sub.Receive(subCtx); err != nil {
		cancel()
		// Close releases the connection held by the failed subscription. Its
		// own error adds nothing to the one being returned.
		_ = sub.Close()
		return fmt.Errorf("permission: subscribe reload channel: %w", err)
	}

	ch := sub.Channel()
	go func() {
		defer func() { _ = sub.Close() }()
		for {
			select {
			case <-subCtx.Done():
				return
			case msg, ok := <-ch:
				if !ok {
					return
				}
				uc.handleReload(subCtx, msg.Payload)
			}
		}
	}()
	return nil
}
// StopReloadSubscriber cancels the running reload subscriber, if any. It is
// safe to call when nothing is running or to call repeatedly. The
// subscriber goroutine exits on its own once it observes the cancellation.
func (uc *rbacUseCase) StopReloadSubscriber() {
	uc.stopMu.Lock()
	cancel := uc.stopSubscribe
	uc.stopSubscribe = nil
	uc.stopMu.Unlock()

	if cancel != nil {
		cancel()
	}
}
// handleReload decodes one reload event payload and acts on it:
//   - a tenant ID of permission.PolicyReloadAllToken, or an empty one,
//     reloads every tenant via LoadAllPolicies;
//   - any other tenant ID reloads just that tenant via LoadPolicy.
//
// It runs inside the subscriber goroutine, so failures are logged rather
// than returned. Malformed payloads are logged together with the decode
// error.
func (uc *rbacUseCase) handleReload(ctx context.Context, payload string) {
	var ev reloadEvent
	if err := json.Unmarshal([]byte(payload), &ev); err != nil {
		logx.WithContext(ctx).Errorf("permission: invalid reload payload %q: %v", payload, err)
		return
	}
	if ev.TenantID == permission.PolicyReloadAllToken || ev.TenantID == "" {
		if err := uc.LoadAllPolicies(ctx); err != nil {
			logx.WithContext(ctx).Errorf("permission: reload all: %v", err)
		}
		return
	}
	if err := uc.LoadPolicy(ctx, ev.TenantID); err != nil {
		logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", ev.TenantID, err)
	}
}
// enforcerFor returns the cached enforcer for tenantID. On first use it
// builds and seeds one from the repositories.
//
// Two properties of this implementation:
//   - The enforcer is cached only after seeding succeeds. Caching it
//     earlier would let a transient repository error leave an empty
//     (deny-all) enforcer cached until the next reload.
//   - Model cloning and repository reads run without holding enforcerMu,
//     so a slow database call for one tenant does not block Check for
//     every other tenant.
//
// Trade-off: concurrent first-use callers for the same tenant may each
// build an enforcer. The first one stored wins and the rest are discarded.
func (uc *rbacUseCase) enforcerFor(ctx context.Context, tenantID string) (*casbin.SyncedEnforcer, error) {
	uc.enforcerMu.RLock()
	cached, ok := uc.enforcers[tenantID]
	uc.enforcerMu.RUnlock()
	if ok {
		return cached, nil
	}

	mdl, err := uc.cloneModel()
	if err != nil {
		return nil, err
	}
	enforcer, err := casbin.NewSyncedEnforcer(mdl)
	if err != nil {
		return nil, fmt.Errorf("permission: new enforcer: %w", err)
	}
	enforcer.EnableAutoSave(false)

	rules, err := uc.buildRules(ctx, tenantID)
	if err != nil {
		return nil, err
	}
	if len(rules) > 0 {
		if _, err := enforcer.AddPolicies(rules); err != nil {
			return nil, fmt.Errorf("permission: seed policies: %w", err)
		}
	}

	uc.enforcerMu.Lock()
	defer uc.enforcerMu.Unlock()
	// Another goroutine may have stored an enforcer while we were building
	// ours, via a concurrent first use or LoadPolicy. Keep that one.
	if existing, ok := uc.enforcers[tenantID]; ok {
		return existing, nil
	}
	uc.enforcers[tenantID] = enforcer
	return enforcer, nil
}
// cloneModel returns a private copy of the configured Casbin model for a
// new enforcer.
//
// Source precedence:
//   - inline model text, when set;
//   - otherwise the file-loaded model, round-tripped through ToText and
//     re-parsed so enforcers never share one model instance.
//
// Errors when neither source was configured.
func (uc *rbacUseCase) cloneModel() (casbinmodel.Model, error) {
	uc.modelMu.Lock()
	defer uc.modelMu.Unlock()

	var text string
	switch {
	case uc.modelTxt != "":
		text = uc.modelTxt
	case uc.model != nil:
		// casbin/model is not safe for concurrent enforcers in some
		// versions; dump+parse keeps each enforcer isolated.
		text = uc.model.ToText()
	default:
		return nil, errors.New("permission: casbin model not loaded")
	}
	return casbinmodel.NewModelFromString(text)
}
// buildRules materialises the Casbin policy rows for tenantID from the
// repositories.
//
// Filtering:
//   - only roles whose Status is enum.StatusOpen contribute;
//   - only open, leaf permissions (perm.IsLeaf()) are emitted;
//   - role_permission rows whose role or permission cannot be resolved
//     are skipped.
//
// Each row is {tenantID, role.Key, perm.HTTPPath, perm.HTTPMethods,
// perm.Name}. Returns (nil, nil) when the tenant has nothing to enforce.
func (uc *rbacUseCase) buildRules(ctx context.Context, tenantID string) ([][]string, error) {
	roles, err := uc.roles.ListByTenant(ctx, tenantID)
	if err != nil {
		return nil, wrapRepoErr(err)
	}

	roleByID := make(map[string]*entity.Role, len(roles))
	roleIDs := make([]string, 0, len(roles))
	for _, role := range roles {
		if role.Status != enum.StatusOpen {
			continue
		}
		id := role.ID.Hex()
		roleByID[id] = role
		roleIDs = append(roleIDs, id)
	}
	// No roles, or every role disabled: skip the role_permissions round-trip
	// instead of querying with an empty ID list.
	if len(roleIDs) == 0 {
		return nil, nil
	}

	rps, err := uc.rolePerms.ListByRoles(ctx, tenantID, roleIDs)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	if len(rps) == 0 {
		return nil, nil
	}

	// Fetch each referenced permission once, however many roles share it.
	permIDSet := make(map[string]struct{}, len(rps))
	for _, rp := range rps {
		permIDSet[rp.PermissionID] = struct{}{}
	}
	ids := make([]string, 0, len(permIDSet))
	for id := range permIDSet {
		ids = append(ids, id)
	}
	perms, err := uc.perms.GetByIDs(ctx, ids)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	permByID := make(map[string]*entity.Permission, len(perms))
	for _, perm := range perms {
		permByID[perm.ID.Hex()] = perm
	}

	rules := make([][]string, 0, len(rps))
	for _, rp := range rps {
		role, ok := roleByID[rp.RoleID]
		if !ok {
			continue
		}
		perm, ok := permByID[rp.PermissionID]
		if !ok || !perm.IsLeaf() || perm.Status != enum.StatusOpen {
			continue
		}
		rules = append(rules, []string{
			tenantID,
			role.Key,
			perm.HTTPPath,
			perm.HTTPMethods,
			perm.Name,
		})
	}
	return rules, nil
}
// allTenantsWithRoles lists the tenant IDs that have a persisted Casbin
// rule set in Redis. They are derived by stripping the
// "<CasbinRulesRedisKey>:" prefix from matching keys. Returns (nil, nil)
// when Redis is not configured.
//
// NOTE(review): KeysCtx presumably issues Redis KEYS, which walks the whole
// keyspace in one blocking call and, on a cluster, may only see one node.
// Consider SCAN if the keyspace is large.
func (uc *rbacUseCase) allTenantsWithRoles(ctx context.Context) ([]string, error) {
	if uc.redis == nil || uc.redis.Zero() == nil {
		return nil, nil
	}

	prefix := permission.CasbinRulesRedisKey.String() + ":"
	keys, err := uc.redis.Zero().KeysCtx(ctx, prefix+"*")
	if err != nil {
		return nil, err
	}

	tenantIDs := make([]string, len(keys))
	for i, key := range keys {
		tenantIDs[i] = strings.TrimPrefix(key, prefix)
	}
	return tenantIDs, nil
}
// saveAdapter persists tenantID's rules through the policy adapter returned
// by newRedisAdapterFromClient. A nil adapter (none configured) makes this
// a no-op that returns nil. A factory error is returned as-is.
func (uc *rbacUseCase) saveAdapter(ctx context.Context, tenantID string, rules [][]string) error {
	adapter, err := newRedisAdapterFromClient(uc.redis)
	switch {
	case err != nil:
		return err
	case adapter == nil:
		return nil
	}
	return adapter.SaveAll(ctx, tenantID, rules)
}
// newRedisAdapterFromClient resolves the Casbin policy adapter for client
// by delegating to whatever RedisAdapterFactory currently holds. The
// factory seam keeps the repository package out of usecase's imports. A
// (nil, nil) result means no adapter is configured.
func newRedisAdapterFromClient(client *redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
	build := RedisAdapterFactory
	return build(client)
}
// RedisAdapterFactory builds the Casbin policy adapter that saveAdapter
// persists rules through. It is a DI seam: module.go plugs in the real
// factory, and tests may assign a stub. The default returns (nil, nil),
// which saveAdapter treats as "no adapter" and skips saving.
//
// NOTE(review): this is mutable package-level state. Assign it only during
// start-up, before any goroutine can call LoadPolicy.
var RedisAdapterFactory = func(*redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
	return nil, nil
}
// roleKeysOf returns the keys of the open (enum.StatusOpen) roles assigned
// to uid within tenantID. It returns (nil, nil) when the user has no role
// assignments. Assignments whose role is not returned by
// ListByTenantAndIDs for this tenant are silently dropped (presumably
// deleted or foreign roles — verify repository semantics).
func (uc *rbacUseCase) roleKeysOf(ctx context.Context, tenantID, uid string) ([]string, error) {
	assignments, err := uc.userRoles.ListByUser(ctx, tenantID, uid)
	if err != nil {
		return nil, wrapRepoErr(err)
	}
	if len(assignments) == 0 {
		return nil, nil
	}

	ids := make([]string, len(assignments))
	for i, a := range assignments {
		ids[i] = a.RoleID
	}
	roles, err := uc.roles.ListByTenantAndIDs(ctx, tenantID, ids)
	if err != nil {
		return nil, wrapRepoErr(err)
	}

	keys := make([]string, 0, len(roles))
	for _, r := range roles {
		if r.Status == enum.StatusOpen {
			keys = append(keys, r.Key)
		}
	}
	return keys, nil
}
// Compile-time assertion that *rbacUseCase implements dom.RBACUseCase.
var _ dom.RBACUseCase = (*rbacUseCase)(nil)