428 lines
12 KiB
Go
428 lines
12 KiB
Go
|
|
package usecase
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"encoding/json"
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
"strings"
|
|||
|
|
"sync"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
redislib "gateway/internal/library/redis"
|
|||
|
|
permission "gateway/internal/model/permission/domain"
|
|||
|
|
"gateway/internal/model/permission/domain/entity"
|
|||
|
|
"gateway/internal/model/permission/domain/enum"
|
|||
|
|
domrepo "gateway/internal/model/permission/domain/repository"
|
|||
|
|
dom "gateway/internal/model/permission/domain/usecase"
|
|||
|
|
|
|||
|
|
"github.com/casbin/casbin/v2"
|
|||
|
|
casbinmodel "github.com/casbin/casbin/v2/model"
|
|||
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// RBACUseCaseParam injects all repos + Redis Pub/Sub client. ModelPath
// must point at etc/rbac.conf; CasbinModelText overrides ModelPath when
// non-empty (used by tests / embedded resources).
//
// NewRBACUseCase rejects a nil Redis (or one whose Zero() client is nil)
// with permission.ErrCasbinNotConfigured.
type RBACUseCaseParam struct {
	Roles           domrepo.RoleRepository
	Permissions     domrepo.PermissionRepository
	RolePermissions domrepo.RolePermissionRepository
	UserRoles       domrepo.UserRoleRepository

	// Redis is required: it carries reload Pub/Sub and the persisted
	// tenant rule keys.
	Redis *redislib.Client

	// ModelPath is parsed once by NewRBACUseCase, and only when
	// CasbinModelText is blank.
	ModelPath string

	// CasbinModelText is whitespace-trimmed. When non-blank it takes
	// precedence over ModelPath.
	CasbinModelText string

	// ReloadChannel is the Pub/Sub channel for reload events. A blank value
	// (after trimming) falls back to permission.PolicyReloadChannel.
	ReloadChannel string
}
|
|||
|
|
|
|||
|
|
// reloadEvent is the JSON payload published on the reload channel.
type reloadEvent struct {
	// TenantID selects the tenant to reload. An empty value or
	// permission.PolicyReloadAllToken makes handleReload reload every tenant.
	TenantID string `json:"tenant_id"`

	// TS is the publish time in Unix milliseconds (time.Now().UnixMilli() in
	// BroadcastReload). It is not read by handleReload.
	TS int64 `json:"ts"`
}
|
|||
|
|
|
|||
|
|
// rbacUseCase is the Casbin-backed implementation of dom.RBACUseCase. It
// keeps one in-memory enforcer per tenant, seeded from the role, permission
// and role-permission repositories.
type rbacUseCase struct {
	roles     domrepo.RoleRepository
	perms     domrepo.PermissionRepository
	rolePerms domrepo.RolePermissionRepository
	userRoles domrepo.UserRoleRepository

	// redis is non-nil (NewRBACUseCase enforces it). It backs reload Pub/Sub,
	// the persisted-tenant key scan and the adapter factory.
	redis *redislib.Client

	// enforcerMu guards enforcers.
	enforcerMu sync.RWMutex

	// enforcers maps tenant ID to that tenant's enforcer. Entries are
	// created lazily by enforcerFor.
	enforcers map[string]*casbin.SyncedEnforcer

	// model is the model parsed from ModelPath at construction. It is used
	// only when modelTxt is empty and may be nil.
	model casbinmodel.Model

	// modelMu serialises cloneModel.
	modelMu sync.Mutex

	// modelTxt is the trimmed CasbinModelText. When non-empty it wins over
	// model.
	modelTxt string

	// reloadChannel is the Pub/Sub channel used by BroadcastReload and
	// StartReloadSubscriber.
	reloadChannel string

	// stopSubscribe cancels the running reload subscriber, or is nil when
	// none is running. It is guarded by stopMu.
	stopSubscribe context.CancelFunc

	stopMu sync.Mutex
}
|
|||
|
|
|
|||
|
|
// NewRBACUseCase wires the Casbin enforcer with the persistence layer.
|
|||
|
|
// Returns ErrCasbinNotConfigured when Redis is missing — Casbin's Redis
|
|||
|
|
// adapter and Pub/Sub require Redis to function.
|
|||
|
|
func NewRBACUseCase(param RBACUseCaseParam) (dom.RBACUseCase, error) {
|
|||
|
|
if param.Redis == nil || param.Redis.Zero() == nil {
|
|||
|
|
return nil, permission.ErrCasbinNotConfigured
|
|||
|
|
}
|
|||
|
|
channel := strings.TrimSpace(param.ReloadChannel)
|
|||
|
|
if channel == "" {
|
|||
|
|
channel = permission.PolicyReloadChannel
|
|||
|
|
}
|
|||
|
|
uc := &rbacUseCase{
|
|||
|
|
roles: param.Roles,
|
|||
|
|
perms: param.Permissions,
|
|||
|
|
rolePerms: param.RolePermissions,
|
|||
|
|
userRoles: param.UserRoles,
|
|||
|
|
redis: param.Redis,
|
|||
|
|
enforcers: make(map[string]*casbin.SyncedEnforcer),
|
|||
|
|
modelTxt: strings.TrimSpace(param.CasbinModelText),
|
|||
|
|
reloadChannel: channel,
|
|||
|
|
}
|
|||
|
|
if uc.modelTxt == "" && param.ModelPath != "" {
|
|||
|
|
mdl, err := casbinmodel.NewModelFromFile(param.ModelPath)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("permission: load casbin model: %w", err)
|
|||
|
|
}
|
|||
|
|
uc.model = mdl
|
|||
|
|
}
|
|||
|
|
return uc, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Check enforces (tenant, uid → role keys) ∩ policy. Multiple roles use
|
|||
|
|
// any-allow semantics: the first matching role short-circuits with
|
|||
|
|
// allow=true. The `r.role == p.role` matcher means we must call EnforceEx
|
|||
|
|
// once per role; that is acceptable because a member typically has 1–3
|
|||
|
|
// roles and the call is in-memory.
|
|||
|
|
func (uc *rbacUseCase) Check(ctx context.Context, req *dom.CheckRequest) (*dom.CheckResult, error) {
|
|||
|
|
if req == nil || req.TenantID == "" || req.UID == "" || req.Path == "" || req.Method == "" {
|
|||
|
|
return nil, permission.ErrInvalidCheckRequest
|
|||
|
|
}
|
|||
|
|
enforcer, err := uc.enforcerFor(ctx, req.TenantID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
roleKeys, err := uc.roleKeysOf(ctx, req.TenantID, req.UID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
if len(roleKeys) == 0 {
|
|||
|
|
return &dom.CheckResult{Allow: false}, nil
|
|||
|
|
}
|
|||
|
|
for _, key := range roleKeys {
|
|||
|
|
ok, matched, err := enforcer.EnforceEx(req.TenantID, key, req.Path, req.Method)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("permission: enforce: %w", err)
|
|||
|
|
}
|
|||
|
|
if ok {
|
|||
|
|
return &dom.CheckResult{
|
|||
|
|
Allow: true,
|
|||
|
|
MatchedRoleKey: key,
|
|||
|
|
MatchedPolicyRow: append([]string{permission.CasbinPolicyType}, matched...),
|
|||
|
|
}, nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return &dom.CheckResult{Allow: false}, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// LoadPolicy materialises role_permissions for a single tenant into
|
|||
|
|
// Casbin policy rules and atomically saves them via the Redis adapter.
|
|||
|
|
func (uc *rbacUseCase) LoadPolicy(ctx context.Context, tenantID string) error {
|
|||
|
|
rules, err := uc.buildRules(ctx, tenantID)
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
enforcer, err := uc.enforcerFor(ctx, tenantID)
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
enforcer.ClearPolicy()
|
|||
|
|
if len(rules) > 0 {
|
|||
|
|
if _, err := enforcer.AddPolicies(rules); err != nil {
|
|||
|
|
return fmt.Errorf("permission: add policies: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
if err := uc.saveAdapter(ctx, tenantID, rules); err != nil {
|
|||
|
|
logx.WithContext(ctx).Errorf("permission: save adapter tenant=%s: %v", tenantID, err)
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// LoadAllPolicies refreshes policies for every tenant. Used by the
|
|||
|
|
// 5-minute cron fallback (see plan §6.11).
|
|||
|
|
func (uc *rbacUseCase) LoadAllPolicies(ctx context.Context) error {
|
|||
|
|
// Tenant list comes from the member module via Casbin keys; here we
|
|||
|
|
// scan the role collection's distinct tenant_id. For simplicity we
|
|||
|
|
// reload only tenants that have at least one role.
|
|||
|
|
roles, err := uc.allTenantsWithRoles(ctx)
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
for _, tenantID := range roles {
|
|||
|
|
if err := uc.LoadPolicy(ctx, tenantID); err != nil {
|
|||
|
|
logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", tenantID, err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// BroadcastReload publishes a tenant-scoped reload event over Redis
|
|||
|
|
// Pub/Sub. Other pods (and this pod itself) consume it to re-LoadPolicy.
|
|||
|
|
func (uc *rbacUseCase) BroadcastReload(ctx context.Context, tenantID string) error {
|
|||
|
|
if uc.redis == nil || uc.redis.Zero() == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
if tenantID == "" {
|
|||
|
|
tenantID = permission.PolicyReloadAllToken
|
|||
|
|
}
|
|||
|
|
payload, err := json.Marshal(reloadEvent{TenantID: tenantID, TS: time.Now().UnixMilli()})
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
_, err = uc.redis.Zero().PublishCtx(ctx, uc.reloadChannel, string(payload))
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// StartReloadSubscriber spins a goroutine that reads from the Redis
|
|||
|
|
// Pub/Sub channel and calls LoadPolicy for each event. Idempotent: a
|
|||
|
|
// second call replaces the prior subscription.
|
|||
|
|
func (uc *rbacUseCase) StartReloadSubscriber(ctx context.Context) error {
|
|||
|
|
uc.StopReloadSubscriber()
|
|||
|
|
pubsub := uc.redis.PubSubClient()
|
|||
|
|
if pubsub == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
subCtx, cancel := context.WithCancel(ctx)
|
|||
|
|
uc.stopMu.Lock()
|
|||
|
|
uc.stopSubscribe = cancel
|
|||
|
|
uc.stopMu.Unlock()
|
|||
|
|
|
|||
|
|
sub := pubsub.Subscribe(subCtx, uc.reloadChannel)
|
|||
|
|
if _, err := sub.Receive(subCtx); err != nil {
|
|||
|
|
cancel()
|
|||
|
|
return fmt.Errorf("permission: subscribe reload channel: %w", err)
|
|||
|
|
}
|
|||
|
|
ch := sub.Channel()
|
|||
|
|
go func() {
|
|||
|
|
defer func() { _ = sub.Close() }()
|
|||
|
|
for {
|
|||
|
|
select {
|
|||
|
|
case <-subCtx.Done():
|
|||
|
|
return
|
|||
|
|
case msg, ok := <-ch:
|
|||
|
|
if !ok {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
uc.handleReload(subCtx, msg.Payload)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}()
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// StopReloadSubscriber cancels the subscriber goroutine (best-effort).
|
|||
|
|
func (uc *rbacUseCase) StopReloadSubscriber() {
|
|||
|
|
uc.stopMu.Lock()
|
|||
|
|
defer uc.stopMu.Unlock()
|
|||
|
|
if uc.stopSubscribe != nil {
|
|||
|
|
uc.stopSubscribe()
|
|||
|
|
uc.stopSubscribe = nil
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (uc *rbacUseCase) handleReload(ctx context.Context, payload string) {
|
|||
|
|
var ev reloadEvent
|
|||
|
|
if err := json.Unmarshal([]byte(payload), &ev); err != nil {
|
|||
|
|
logx.WithContext(ctx).Errorf("permission: invalid reload payload: %s", payload)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if ev.TenantID == permission.PolicyReloadAllToken || ev.TenantID == "" {
|
|||
|
|
if err := uc.LoadAllPolicies(ctx); err != nil {
|
|||
|
|
logx.WithContext(ctx).Errorf("permission: reload all: %v", err)
|
|||
|
|
}
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
if err := uc.LoadPolicy(ctx, ev.TenantID); err != nil {
|
|||
|
|
logx.WithContext(ctx).Errorf("permission: reload tenant=%s: %v", ev.TenantID, err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (uc *rbacUseCase) enforcerFor(ctx context.Context, tenantID string) (*casbin.SyncedEnforcer, error) {
|
|||
|
|
uc.enforcerMu.RLock()
|
|||
|
|
if e, ok := uc.enforcers[tenantID]; ok {
|
|||
|
|
uc.enforcerMu.RUnlock()
|
|||
|
|
return e, nil
|
|||
|
|
}
|
|||
|
|
uc.enforcerMu.RUnlock()
|
|||
|
|
|
|||
|
|
uc.enforcerMu.Lock()
|
|||
|
|
defer uc.enforcerMu.Unlock()
|
|||
|
|
if e, ok := uc.enforcers[tenantID]; ok {
|
|||
|
|
return e, nil
|
|||
|
|
}
|
|||
|
|
mdl, err := uc.cloneModel()
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
enforcer, err := casbin.NewSyncedEnforcer(mdl)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, fmt.Errorf("permission: new enforcer: %w", err)
|
|||
|
|
}
|
|||
|
|
enforcer.EnableAutoSave(false)
|
|||
|
|
uc.enforcers[tenantID] = enforcer
|
|||
|
|
|
|||
|
|
rules, err := uc.buildRules(ctx, tenantID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
if len(rules) > 0 {
|
|||
|
|
if _, err := enforcer.AddPolicies(rules); err != nil {
|
|||
|
|
return nil, fmt.Errorf("permission: seed policies: %w", err)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return enforcer, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (uc *rbacUseCase) cloneModel() (casbinmodel.Model, error) {
|
|||
|
|
uc.modelMu.Lock()
|
|||
|
|
defer uc.modelMu.Unlock()
|
|||
|
|
if uc.modelTxt != "" {
|
|||
|
|
return casbinmodel.NewModelFromString(uc.modelTxt)
|
|||
|
|
}
|
|||
|
|
if uc.model == nil {
|
|||
|
|
return nil, errors.New("permission: casbin model not loaded")
|
|||
|
|
}
|
|||
|
|
// casbin/model is not safe for concurrent enforcers in some versions;
|
|||
|
|
// dump+parse keeps each enforcer isolated.
|
|||
|
|
return casbinmodel.NewModelFromString(uc.model.ToText())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (uc *rbacUseCase) buildRules(ctx context.Context, tenantID string) ([][]string, error) {
|
|||
|
|
roles, err := uc.roles.ListByTenant(ctx, tenantID)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, wrapRepoErr(err)
|
|||
|
|
}
|
|||
|
|
if len(roles) == 0 {
|
|||
|
|
return nil, nil
|
|||
|
|
}
|
|||
|
|
roleByID := make(map[string]*entity.Role, len(roles))
|
|||
|
|
roleIDs := make([]string, 0, len(roles))
|
|||
|
|
for _, role := range roles {
|
|||
|
|
if role.Status != enum.StatusOpen {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
roleByID[role.ID.Hex()] = role
|
|||
|
|
roleIDs = append(roleIDs, role.ID.Hex())
|
|||
|
|
}
|
|||
|
|
rps, err := uc.rolePerms.ListByRoles(ctx, tenantID, roleIDs)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, wrapRepoErr(err)
|
|||
|
|
}
|
|||
|
|
if len(rps) == 0 {
|
|||
|
|
return nil, nil
|
|||
|
|
}
|
|||
|
|
permIDSet := make(map[string]struct{}, len(rps))
|
|||
|
|
for _, rp := range rps {
|
|||
|
|
permIDSet[rp.PermissionID] = struct{}{}
|
|||
|
|
}
|
|||
|
|
ids := make([]string, 0, len(permIDSet))
|
|||
|
|
for id := range permIDSet {
|
|||
|
|
ids = append(ids, id)
|
|||
|
|
}
|
|||
|
|
perms, err := uc.perms.GetByIDs(ctx, ids)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, wrapRepoErr(err)
|
|||
|
|
}
|
|||
|
|
permByID := make(map[string]*entity.Permission, len(perms))
|
|||
|
|
for _, perm := range perms {
|
|||
|
|
permByID[perm.ID.Hex()] = perm
|
|||
|
|
}
|
|||
|
|
rules := make([][]string, 0, len(rps))
|
|||
|
|
for _, rp := range rps {
|
|||
|
|
role, ok := roleByID[rp.RoleID]
|
|||
|
|
if !ok {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
perm, ok := permByID[rp.PermissionID]
|
|||
|
|
if !ok || !perm.IsLeaf() || perm.Status != enum.StatusOpen {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
rules = append(rules, []string{
|
|||
|
|
tenantID,
|
|||
|
|
role.Key,
|
|||
|
|
perm.HTTPPath,
|
|||
|
|
perm.HTTPMethods,
|
|||
|
|
perm.Name,
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
return rules, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (uc *rbacUseCase) allTenantsWithRoles(ctx context.Context) ([]string, error) {
|
|||
|
|
// Casbin reload is best-effort across pods; we use the Redis cluster
|
|||
|
|
// to remember which tenant keys exist. Empty set ⇒ nothing to do.
|
|||
|
|
if uc.redis == nil || uc.redis.Zero() == nil {
|
|||
|
|
return nil, nil
|
|||
|
|
}
|
|||
|
|
keys, err := uc.redis.Zero().KeysCtx(ctx, permission.CasbinRulesRedisKey.String()+":*")
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
prefix := permission.CasbinRulesRedisKey.String() + ":"
|
|||
|
|
tenantIDs := make([]string, 0, len(keys))
|
|||
|
|
for _, key := range keys {
|
|||
|
|
tenantIDs = append(tenantIDs, strings.TrimPrefix(key, prefix))
|
|||
|
|
}
|
|||
|
|
return tenantIDs, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func (uc *rbacUseCase) saveAdapter(ctx context.Context, tenantID string, rules [][]string) error {
|
|||
|
|
adapter, err := newRedisAdapterFromClient(uc.redis)
|
|||
|
|
if err != nil || adapter == nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
return adapter.SaveAll(ctx, tenantID, rules)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// newRedisAdapterFromClient is implemented in casbin_adapter_bridge.go to
|
|||
|
|
// keep the import surface narrow (avoid pulling repository into usecase).
|
|||
|
|
func newRedisAdapterFromClient(client *redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
|
|||
|
|
return RedisAdapterFactory(client)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RedisAdapterFactory is plugged in by module.go (DI seam). Tests can
// override by assigning a stub.
//
// It builds the CasbinPolicyAdapter that saveAdapter uses to persist each
// tenant's rules. The default returns (nil, nil). With no adapter,
// saveAdapter skips persistence and returns nil, so no rules are persisted
// until this is replaced.
//
// NOTE(review): this is mutable package-level state with no
// synchronization. Assign it before any use case starts serving.
var RedisAdapterFactory = func(_ *redislib.Client) (domrepo.CasbinPolicyAdapter, error) {
	return nil, nil
}
|
|||
|
|
|
|||
|
|
func (uc *rbacUseCase) roleKeysOf(ctx context.Context, tenantID, uid string) ([]string, error) {
|
|||
|
|
urs, err := uc.userRoles.ListByUser(ctx, tenantID, uid)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, wrapRepoErr(err)
|
|||
|
|
}
|
|||
|
|
if len(urs) == 0 {
|
|||
|
|
return nil, nil
|
|||
|
|
}
|
|||
|
|
roleIDs := make([]string, 0, len(urs))
|
|||
|
|
for _, ur := range urs {
|
|||
|
|
roleIDs = append(roleIDs, ur.RoleID)
|
|||
|
|
}
|
|||
|
|
roles, err := uc.roles.ListByTenantAndIDs(ctx, tenantID, roleIDs)
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, wrapRepoErr(err)
|
|||
|
|
}
|
|||
|
|
out := make([]string, 0, len(roles))
|
|||
|
|
for _, role := range roles {
|
|||
|
|
if role.Status != enum.StatusOpen {
|
|||
|
|
continue
|
|||
|
|
}
|
|||
|
|
out = append(out, role.Key)
|
|||
|
|
}
|
|||
|
|
return out, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Compile-time check that *rbacUseCase satisfies dom.RBACUseCase.
var _ dom.RBACUseCase = (*rbacUseCase)(nil)
|