428 lines
12 KiB
Go
428 lines
12 KiB
Go
package usecase
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"strings"
|
||
"sync"
|
||
"time"
|
||
|
||
redislib "gateway/internal/library/redis"
|
||
permission "gateway/internal/model/permission/domain"
|
||
"gateway/internal/model/permission/domain/entity"
|
||
"gateway/internal/model/permission/domain/enum"
|
||
domrepo "gateway/internal/model/permission/domain/repository"
|
||
dom "gateway/internal/model/permission/domain/usecase"
|
||
|
||
"github.com/casbin/casbin/v2"
|
||
casbinmodel "github.com/casbin/casbin/v2/model"
|
||
"github.com/zeromicro/go-zero/core/logx"
|
||
)
|
||
|
||
// RBACUseCaseParam injects all repos + Redis Pub/Sub client. ModelPath
|
||
// must point at etc/rbac.conf; CasbinModelText overrides ModelPath when
|
||
// non-empty (used by tests / embedded resources).
|
||
type RBACUseCaseParam struct {
|
||
Roles domrepo.RoleRepository
|
||
Permissions domrepo.PermissionRepository
|
||
RolePermissions domrepo.RolePermissionRepository
|
||
UserRoles domrepo.UserRoleRepository
|
||
Redis *redislib.Client
|
||
ModelPath string
|
||
CasbinModelText string
|
||
ReloadChannel string
|
||
}
|
||
|
||
// reloadEvent is the message body exchanged on the policy reload channel.
// TenantID selects the tenant to reload (the all-tenants token or an empty
// value means every tenant); TS is the publish time in Unix milliseconds.
type reloadEvent struct {
	TenantID string `json:"tenant_id"` // tenant to reload
	TS       int64  `json:"ts"`        // publish time, Unix ms
}
|
||
|
||
type rbacUseCase struct {
|
||
roles domrepo.RoleRepository
|
||
perms domrepo.PermissionRepository
|
||
rolePerms domrepo.RolePermissionRepository
|
||
userRoles domrepo.UserRoleRepository
|
||
redis *redislib.Client
|
||
|
||
enforcerMu sync.RWMutex
|
||
enforcers map[string]*casbin.SyncedEnforcer
|
||
|
||
model casbinmodel.Model
|
||
modelMu sync.Mutex
|
||
modelTxt string
|
||
|
||
reloadChannel string
|
||
stopSubscribe context.CancelFunc
|
||
stopMu sync.Mutex
|
||
}
|
||
|
||
// NewRBACUseCase wires the Casbin enforcer with the persistence layer.
|
||
// Returns ErrCasbinNotConfigured when Redis is missing — Casbin's Redis
|
||
// adapter and Pub/Sub require Redis to function.
|
||
func NewRBACUseCase(param RBACUseCaseParam) (dom.RBACUseCase, error) {
|
||
if param.Redis == nil || param.Redis.Zero() == nil {
|
||
return nil, permission.ErrCasbinNotConfigured
|
||
}
|
||
channel := strings.TrimSpace(param.ReloadChannel)
|
||
if channel == "" {
|
||
channel = permission.PolicyReloadChannel
|
||
}
|
||
uc := &rbacUseCase{
|
||
roles: param.Roles,
|
||
perms: param.Permissions,
|
||
rolePerms: param.RolePermissions,
|
||
userRoles: param.UserRoles,
|
||
redis: param.Redis,
|
||
enforcers: make(map[string]*casbin.SyncedEnforcer),
|
||
modelTxt: strings.TrimSpace(param.CasbinModelText),
|
||
reloadChannel: channel,
|
||
}
|
||
if uc.modelTxt == "" && param.ModelPath != "" {
|
||
mdl, err := casbinmodel.NewModelFromFile(param.ModelPath)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("permission: load casbin model: %w", err)
|
||
}
|
||
uc.model = mdl
|
||
}
|
||
return uc, nil
|
||
}
|
||
|
||
// Check enforces (tenant, uid → role keys) ∩ policy. Multiple roles use
|
||
// any-allow semantics: the first matching role short-circuits with
|
||
// allow=true. The `r.role == p.role` matcher means we must call EnforceEx
|
||
// once per role; that is acceptable because a member typically has 1–3
|
||
// roles and the call is in-memory.
|
||
func (uc *rbacUseCase) Check(ctx context.Context, req *dom.CheckRequest) (*dom.CheckResult, error) {
|
||
if req == nil || req.TenantID == "" || req.UID == "" || req.Path == "" || req.Method == "" {
|
||
return nil, permission.ErrInvalidCheckRequest
|
||
}
|
||
enforcer, err := uc.enforcerFor(ctx, req.TenantID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
roleKeys, err := uc.roleKeysOf(ctx, req.TenantID, req.UID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(roleKeys) == 0 {
|
||
return &dom.CheckResult{Allow: false}, nil
|
||
}
|
||
for _, key := range roleKeys {
|
||
ok, matched, err := enforcer.EnforceEx(req.TenantID, key, req.Path, req.Method)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("permission: enforce: %w", err)
|
||
}
|
||
if ok {
|
||
return &dom.CheckResult{
|
||
Allow: true,
|
||
MatchedRoleKey: key,
|
||
MatchedPolicyRow: append([]string{permission.CasbinPolicyType}, matched...),
|
||
}, nil
|
||
}
|
||
}
|
||
return &dom.CheckResult{Allow: false}, nil
|
||
}
|
||
|
||
// LoadPolicy materialises role_permissions for a single tenant into
|
||
// Casbin policy rules and atomically saves them via the Redis adapter.
|
||
func (uc *rbacUseCase) LoadPolicy(ctx context.Context, tenantID string) error {
|
||
rules, err := uc.buildRules(ctx, tenantID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
enforcer, err := uc.enforcerFor(ctx, tenantID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
enforcer.ClearPolicy()
|
||
if len(rules) > 0 {
|
||
if _, err := enforcer.AddPolicies(rules); err != nil {
|
||
return fmt.Errorf("permission: add policies: %w", err)
|
||
}
|
||
}
|
||
if err := uc.saveAdapter(ctx, tenantID, rules); err != nil {
|
||
logx.WithContext(ctx).Errorf("permission: save adapter tenant=%s: %v", tenantID, err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// LoadAllPolicies refreshes policies for every tenant. Used by the
|
||
// 5-minute cron fallback (see plan §6.11).
|
||
func (uc *rbacUseCase) LoadAllPolicies(ctx context.Context) error {
|
||
// Tenant list comes from the member module via Casbin keys; here we
|
||
// scan the role collection's distinct tenant_id. For simplicity we
|
||
// reload only tenants that have at least one role.
|
||
roles, err := uc.allTenantsWithRoles(ctx)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
for _, tenantID := range roles {
|
||
if err := uc.LoadPolicy(ctx, tenantID); err != nil {
|
||
logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", tenantID, err)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// BroadcastReload publishes a tenant-scoped reload event over Redis
|
||
// Pub/Sub. Other pods (and this pod itself) consume it to re-LoadPolicy.
|
||
func (uc *rbacUseCase) BroadcastReload(ctx context.Context, tenantID string) error {
|
||
if uc.redis == nil || uc.redis.Zero() == nil {
|
||
return nil
|
||
}
|
||
if tenantID == "" {
|
||
tenantID = permission.PolicyReloadAllToken
|
||
}
|
||
payload, err := json.Marshal(reloadEvent{TenantID: tenantID, TS: time.Now().UnixMilli()})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
_, err = uc.redis.Zero().PublishCtx(ctx, uc.reloadChannel, string(payload))
|
||
return err
|
||
}
|
||
|
||
// StartReloadSubscriber spins a goroutine that reads from the Redis
|
||
// Pub/Sub channel and calls LoadPolicy for each event. Idempotent: a
|
||
// second call replaces the prior subscription.
|
||
func (uc *rbacUseCase) StartReloadSubscriber(ctx context.Context) error {
|
||
uc.StopReloadSubscriber()
|
||
pubsub := uc.redis.PubSubClient()
|
||
if pubsub == nil {
|
||
return nil
|
||
}
|
||
subCtx, cancel := context.WithCancel(ctx)
|
||
uc.stopMu.Lock()
|
||
uc.stopSubscribe = cancel
|
||
uc.stopMu.Unlock()
|
||
|
||
sub := pubsub.Subscribe(subCtx, uc.reloadChannel)
|
||
if _, err := sub.Receive(subCtx); err != nil {
|
||
cancel()
|
||
return fmt.Errorf("permission: subscribe reload channel: %w", err)
|
||
}
|
||
ch := sub.Channel()
|
||
go func() {
|
||
defer func() { _ = sub.Close() }()
|
||
for {
|
||
select {
|
||
case <-subCtx.Done():
|
||
return
|
||
case msg, ok := <-ch:
|
||
if !ok {
|
||
return
|
||
}
|
||
uc.handleReload(subCtx, msg.Payload)
|
||
}
|
||
}
|
||
}()
|
||
return nil
|
||
}
|
||
|
||
// StopReloadSubscriber cancels the subscriber goroutine (best-effort).
|
||
func (uc *rbacUseCase) StopReloadSubscriber() {
|
||
uc.stopMu.Lock()
|
||
defer uc.stopMu.Unlock()
|
||
if uc.stopSubscribe != nil {
|
||
uc.stopSubscribe()
|
||
uc.stopSubscribe = nil
|
||
}
|
||
}
|
||
|
||
func (uc *rbacUseCase) handleReload(ctx context.Context, payload string) {
|
||
var ev reloadEvent
|
||
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
|
||
logx.WithContext(ctx).Errorf("permission: invalid reload payload: %s", payload)
|
||
return
|
||
}
|
||
if ev.TenantID == permission.PolicyReloadAllToken || ev.TenantID == "" {
|
||
if err := uc.LoadAllPolicies(ctx); err != nil {
|
||
logx.WithContext(ctx).Errorf("permission: reload all: %v", err)
|
||
}
|
||
return
|
||
}
|
||
if err := uc.LoadPolicy(ctx, ev.TenantID); err != nil {
|
||
logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", ev.TenantID, err)
|
||
}
|
||
}
|
||
|
||
func (uc *rbacUseCase) enforcerFor(ctx context.Context, tenantID string) (*casbin.SyncedEnforcer, error) {
|
||
uc.enforcerMu.RLock()
|
||
if e, ok := uc.enforcers[tenantID]; ok {
|
||
uc.enforcerMu.RUnlock()
|
||
return e, nil
|
||
}
|
||
uc.enforcerMu.RUnlock()
|
||
|
||
uc.enforcerMu.Lock()
|
||
defer uc.enforcerMu.Unlock()
|
||
if e, ok := uc.enforcers[tenantID]; ok {
|
||
return e, nil
|
||
}
|
||
mdl, err := uc.cloneModel()
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
enforcer, err := casbin.NewSyncedEnforcer(mdl)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("permission: new enforcer: %w", err)
|
||
}
|
||
enforcer.EnableAutoSave(false)
|
||
uc.enforcers[tenantID] = enforcer
|
||
|
||
rules, err := uc.buildRules(ctx, tenantID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(rules) > 0 {
|
||
if _, err := enforcer.AddPolicies(rules); err != nil {
|
||
return nil, fmt.Errorf("permission: seed policies: %w", err)
|
||
}
|
||
}
|
||
return enforcer, nil
|
||
}
|
||
|
||
func (uc *rbacUseCase) cloneModel() (casbinmodel.Model, error) {
|
||
uc.modelMu.Lock()
|
||
defer uc.modelMu.Unlock()
|
||
if uc.modelTxt != "" {
|
||
return casbinmodel.NewModelFromString(uc.modelTxt)
|
||
}
|
||
if uc.model == nil {
|
||
return nil, errors.New("permission: casbin model not loaded")
|
||
}
|
||
// casbin/model is not safe for concurrent enforcers in some versions;
|
||
// dump+parse keeps each enforcer isolated.
|
||
return casbinmodel.NewModelFromString(uc.model.ToText())
|
||
}
|
||
|
||
func (uc *rbacUseCase) buildRules(ctx context.Context, tenantID string) ([][]string, error) {
|
||
roles, err := uc.roles.ListByTenant(ctx, tenantID)
|
||
if err != nil {
|
||
return nil, wrapRepoErr(err)
|
||
}
|
||
if len(roles) == 0 {
|
||
return nil, nil
|
||
}
|
||
roleByID := make(map[string]*entity.Role, len(roles))
|
||
roleIDs := make([]string, 0, len(roles))
|
||
for _, role := range roles {
|
||
if role.Status != enum.StatusOpen {
|
||
continue
|
||
}
|
||
roleByID[role.ID.Hex()] = role
|
||
roleIDs = append(roleIDs, role.ID.Hex())
|
||
}
|
||
rps, err := uc.rolePerms.ListByRoles(ctx, tenantID, roleIDs)
|
||
if err != nil {
|
||
return nil, wrapRepoErr(err)
|
||
}
|
||
if len(rps) == 0 {
|
||
return nil, nil
|
||
}
|
||
permIDSet := make(map[string]struct{}, len(rps))
|
||
for _, rp := range rps {
|
||
permIDSet[rp.PermissionID] = struct{}{}
|
||
}
|
||
ids := make([]string, 0, len(permIDSet))
|
||
for id := range permIDSet {
|
||
ids = append(ids, id)
|
||
}
|
||
perms, err := uc.perms.GetByIDs(ctx, ids)
|
||
if err != nil {
|
||
return nil, wrapRepoErr(err)
|
||
}
|
||
permByID := make(map[string]*entity.Permission, len(perms))
|
||
for _, perm := range perms {
|
||
permByID[perm.ID.Hex()] = perm
|
||
}
|
||
rules := make([][]string, 0, len(rps))
|
||
for _, rp := range rps {
|
||
role, ok := roleByID[rp.RoleID]
|
||
if !ok {
|
||
continue
|
||
}
|
||
perm, ok := permByID[rp.PermissionID]
|
||
if !ok || !perm.IsLeaf() || perm.Status != enum.StatusOpen {
|
||
continue
|
||
}
|
||
rules = append(rules, []string{
|
||
tenantID,
|
||
role.Key,
|
||
perm.HTTPPath,
|
||
perm.HTTPMethods,
|
||
perm.Name,
|
||
})
|
||
}
|
||
return rules, nil
|
||
}
|
||
|
||
func (uc *rbacUseCase) allTenantsWithRoles(ctx context.Context) ([]string, error) {
|
||
// Casbin reload is best-effort across pods; we use the Redis cluster
|
||
// to remember which tenant keys exist. Empty set ⇒ nothing to do.
|
||
if uc.redis == nil || uc.redis.Zero() == nil {
|
||
return nil, nil
|
||
}
|
||
keys, err := uc.redis.Zero().KeysCtx(ctx, permission.CasbinRulesRedisKey.String()+":*")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
prefix := permission.CasbinRulesRedisKey.String() + ":"
|
||
tenantIDs := make([]string, 0, len(keys))
|
||
for _, key := range keys {
|
||
tenantIDs = append(tenantIDs, strings.TrimPrefix(key, prefix))
|
||
}
|
||
return tenantIDs, nil
|
||
}
|
||
|
||
func (uc *rbacUseCase) saveAdapter(ctx context.Context, tenantID string, rules [][]string) error {
|
||
adapter, err := newRedisAdapterFromClient(uc.redis)
|
||
if err != nil || adapter == nil {
|
||
return err
|
||
}
|
||
return adapter.SaveAll(ctx, tenantID, rules)
|
||
}
|
||
|
||
// newRedisAdapterFromClient is implemented in casbin_adapter_bridge.go to
|
||
// keep the import surface narrow (avoid pulling repository into usecase).
|
||
func newRedisAdapterFromClient(client *redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
|
||
return RedisAdapterFactory(client)
|
||
}
|
||
|
||
// RedisAdapterFactory is plugged in by module.go (DI seam). Tests can
|
||
// override by assigning a stub.
|
||
var RedisAdapterFactory = func(_ *redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
|
||
return nil, nil
|
||
}
|
||
|
||
func (uc *rbacUseCase) roleKeysOf(ctx context.Context, tenantID, uid string) ([]string, error) {
|
||
urs, err := uc.userRoles.ListByUser(ctx, tenantID, uid)
|
||
if err != nil {
|
||
return nil, wrapRepoErr(err)
|
||
}
|
||
if len(urs) == 0 {
|
||
return nil, nil
|
||
}
|
||
roleIDs := make([]string, 0, len(urs))
|
||
for _, ur := range urs {
|
||
roleIDs = append(roleIDs, ur.RoleID)
|
||
}
|
||
roles, err := uc.roles.ListByTenantAndIDs(ctx, tenantID, roleIDs)
|
||
if err != nil {
|
||
return nil, wrapRepoErr(err)
|
||
}
|
||
out := make([]string, 0, len(roles))
|
||
for _, role := range roles {
|
||
if role.Status != enum.StatusOpen {
|
||
continue
|
||
}
|
||
out = append(out, role.Key)
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
// Compile-time check that *rbacUseCase satisfies dom.RBACUseCase.
var _ dom.RBACUseCase = (*rbacUseCase)(nil)
|