diff --git a/internal/usecase/opa.go b/internal/usecase/opa.go new file mode 100644 index 0000000..ecd4260 --- /dev/null +++ b/internal/usecase/opa.go @@ -0,0 +1,163 @@ +package usecase + +import ( + "ark-permission/internal/domain" + "ark-permission/internal/domain/usecase" + ers "code.30cm.net/wanderland/library-go/errors" + "context" + _ "embed" + "fmt" + "github.com/open-policy-agent/opa/rego" + "github.com/zeromicro/go-zero/core/logx" +) + +//go:embed "rule.rego" +var policy []byte + +type OpaUseCaseParam struct{} + +type opaUseCase struct { + // 查詢這個角色是否可用 + allowQuery rego.PreparedEvalQuery + + policies []map[string]any +} + +func (o *opaUseCase) GetPolicy(ctx context.Context) []map[string]any { + return o.policies +} + +func (o *opaUseCase) CheckRBACPermission(ctx context.Context, req usecase.CheckReq) (usecase.CheckOPAResp, error) { + results, err := o.allowQuery.Eval(ctx, rego.EvalInput(map[string]any{ + "roles": req.Roles, + "path": req.Path, + "method": req.Method, + "policies": o.policies, + })) + + if err != nil { + return usecase.CheckOPAResp{}, domain.PermissionGetDataError(fmt.Sprintf("failed to evaluate policy: %v", err)) + } + + if len(results) == 0 { + logx.WithCallerSkip(1).WithFields( + logx.Field("roles", req.Roles), + logx.Field("path", req.Path), + logx.Field("method", req.Method), + logx.Field("policies", o.policies), + ).Error("empty RBAC policy result, possibly due to an incorrect query string or policy") + return usecase.CheckOPAResp{}, domain.PermissionGetDataError("no results returned from policy evaluation") + } + + data, ok := results[0].Expressions[0].Value.(map[string]any) + if !ok { + return usecase.CheckOPAResp{}, domain.PermissionGetDataError("unexpected data format in policy evaluation result") + } + resp, err := convertToCheckOPAResp(data) + if !ok { + return usecase.CheckOPAResp{}, domain.PermissionGetDataError(err.Error()) + } + + return resp, nil +} + +// LoadPolicy 逐一處理 Policy 並且處理超時 +func (o *opaUseCase) LoadPolicy(ctx context.Context, input []usecase.Policy) error { + mapped := make([]map[string]any, 0, len(input)) + + for i, policy := range input { + select { + case <-ctx.Done(): // 監控是否超時或取消 + logx.WithCallerSkip(1).WithFields( + logx.Field("input", input), + ).Error("LoadPolicy context time out") + // TODO 部分完成後處理,記錄日誌並返回成功的部分,或應該要重新 Loading.... + o.policies = append(o.policies, mapped...) + return ers.SystemTimeoutError(fmt.Sprintf("operation timed out after processing %d policies: %v", i, ctx.Err())) + default: + // 繼續處理 + mapped = append(mapped, policiesToMap(policy)) + } + } + + // 完成所有更新後紀錄,整個取代 policies + o.policies = mapped + return nil +} + +func NewOpaUseCase(param OpaUseCaseParam) (usecase.OpaUseCase, error) { + module := rego.Module("policy", string(policy)) + ctx := context.Background() + var allowQueryErr error + uc := &opaUseCase{} + uc.allowQuery, allowQueryErr = rego.New( + rego.Query("data.rbac"), // 要尋找的話 data 必帶, rbac = rego package , allow 是要query 啥 + module, + ).PrepareForEval(ctx) + if allowQueryErr != nil { + return &opaUseCase{}, domain.PermissionGetDataError(allowQueryErr.Error()) + } + + return uc, nil +} + +// 內部使用 +func policiesToMap(policy usecase.Policy) map[string]any { + return map[string]any{ + "methods": policy.Methods, + "name": policy.Name, + "path": policy.Path, + "role": policy.Role, + } +} + +func convertToCheckOPAResp(data map[string]any) (usecase.CheckOPAResp, error) { + var response usecase.CheckOPAResp + + if allow, ok := data["allow"].(bool); ok { + response.Allow = allow + } else { + return usecase.CheckOPAResp{}, fmt.Errorf("missing or invalid 'allow' field") + } + + requestData, ok := data["request"].(map[string]any) + if !ok { + return usecase.CheckOPAResp{}, fmt.Errorf("missing or invalid 'request' field") + } + + response.Request.Method = requestData["method"].(string) + response.Request.Path = requestData["path"].(string) + + policiesData, ok := requestData["policies"].([]any) + if !ok { + return usecase.CheckOPAResp{}, fmt.Errorf("missing or invalid 'policies' field") + } + response.Request.Policies = make([]usecase.Policy, 0, len(policiesData)) + for _, policyData := range policiesData { + p := policyData.(map[string]any) + methodsData := p["methods"].([]any) + methods := make([]string, len(methodsData)) + for i, m := range methodsData { + methods[i] = m.(string) + } + + policy := usecase.Policy{ + Methods: methods, + Name: p["name"].(string), + Path: p["path"].(string), + Role: p["role"].(string), + } + response.Request.Policies = append(response.Request.Policies, policy) + } + + rolesData, ok := requestData["roles"].([]any) + if !ok { + return usecase.CheckOPAResp{}, fmt.Errorf("missing or invalid 'roles' field") + } + response.Request.Roles = make([]string, len(rolesData)) + for i, r := range rolesData { + response.Request.Roles[i] = r.(string) + } + + return response, nil +} diff --git a/internal/usecase/opa_test.go b/internal/usecase/opa_test.go new file mode 100644 index 0000000..e26f4e3 --- /dev/null +++ b/internal/usecase/opa_test.go @@ -0,0 +1,337 @@ +package usecase + +import ( + "ark-permission/internal/domain/usecase" + "context" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zeromicro/go-zero/core/logx" + "testing" + "time" +) + +func TestMustOpaUseCase(t *testing.T) { + // 初始化 OPA UseCase + got, err := NewOpaUseCase(OpaUseCaseParam{}) + assert.NoError(t, err) + + ctx := context.Background() + + // 加载 Policy + err = got.LoadPolicy(ctx, []usecase.Policy{ + { + Role: "admin", + Path: "/admin/.*", + Methods: []string{"GET", "POST"}, + Name: "Admin access", + }, + { + Role: "user", + Path: "/user/.*", + Methods: []string{"GET"}, + Name: "User read access", + }, + { + Role: "editor", + Path: "/editor/.*", + Methods: []string{"PUT", "POST"}, + Name: "Editor access", + }, + }) + assert.NoError(t, err) + + // 定义测试用例表 + tests := []struct { + name string + req usecase.CheckReq + expect bool + expectError bool + }{ + { + name: "單一角色,應允許通過", + req: usecase.CheckReq{ + Roles: []string{"user"}, + Path: "/user/profile", + Method: "GET", + }, + expect: true, + }, + { + name: "多角色其中一個有配到,應允許通過", + req: usecase.CheckReq{ + Roles: []string{"user", "admin"}, + Path: "/user/profile", + Method: "GET", + }, + expect: true, + }, + { + name: "角色不匹配,應拒絕通過", + req: usecase.CheckReq{ + Roles: []string{"editor"}, + Path: "/user/profile", + Method: "GET", + }, + expect: false, + }, + { + name: "路徑不匹配,應拒絕通過", + req: usecase.CheckReq{ + Roles: []string{"user"}, + Path: "/editor/dashboard", + Method: "GET", + }, + expect: false, + }, + { + name: "方法不匹配,應拒絕通過", + req: usecase.CheckReq{ + Roles: []string{"user"}, + Path: "/user/profile", + Method: "POST", + }, + expect: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + check, err := got.CheckRBACPermission(ctx, tt.req) + if tt.expectError { + assert.Error(t, err, "expected an error but got none") + } else { + assert.NoError(t, err, "did not expect an error but got one") + assert.Equal(t, tt.expect, check.Allow) + } + }) + } +} + +func TestLoadPolicy(t *testing.T) { + // 初始化 OPA UseCase + got, err := NewOpaUseCase(OpaUseCaseParam{}) + require.NoError(t, err) + + tests := []struct { + name string + input []usecase.Policy + ctxTimeout time.Duration + expectErr bool + }{ + { + name: "正常加載多個Policy", + input: []usecase.Policy{ + { + Role: "admin", + Path: "/admin/.*", + Methods: []string{"GET", "POST"}, + Name: "Admin access", + }, + { + Role: "user", + Path: "/user/.*", + Methods: []string{"GET"}, + Name: "User read access", + }, + }, + ctxTimeout: 3 * time.Second, // 足夠的時間來執行 + expectErr: false, + }, + { + name: "加載策略超時", + input: []usecase.Policy{ + { + Role: "admin", + Path: "/admin/.*", + Methods: []string{"GET", "POST"}, + Name: "Admin access", + }, + { + Role: "user", + Path: "/user/.*", + Methods: []string{"GET"}, + Name: "User read access", + }, + }, + ctxTimeout: 1 * time.Nanosecond, // 超時 + expectErr: true, + }, + { + name: "空策略加載", + input: []usecase.Policy{}, + ctxTimeout: 3 * time.Second, // 足夠的時間 + expectErr: false, + }, + } + + // 遍歷所有測試用例 + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 設置具有超時的 Context + ctx, cancel := context.WithTimeout(context.Background(), tt.ctxTimeout) + defer cancel() + + // 調用 LoadPolicy + err := got.LoadPolicy(ctx, tt.input) + + // 檢查是否符合預期錯誤 + if tt.expectErr { + assert.Error(t, err, "預期發生錯誤,但沒有發生") + } else { + assert.NoError(t, err, "不預期發生錯誤,但卻發生了") + } + + // 如果沒有錯誤,檢查 policies 是否被正確加載 + if !tt.expectErr { + assert.Equal(t, len(tt.input), len(got.GetPolicy(ctx)), "policies 加載的數量與輸入數量不一致") + } + }) + } +} + +func BenchmarkLoadPolicy(b *testing.B) { + // 初始化 OPA UseCase + got, _ := NewOpaUseCase(OpaUseCaseParam{}) + logx.Disable() + + // 定义不同数量的 Policy 用于基准测试 + policiesSmall := []usecase.Policy{ + { + Role: "admin", + Path: "/admin/.*", + Methods: []string{"GET", "POST"}, + Name: "Admin access", + }, + { + Role: "user", + Path: "/user/.*", + Methods: []string{"GET"}, + Name: "User read access", + }, + } + + policiesLarge := make([]usecase.Policy, 1000) + for i := 0; i < 1000; i++ { + policiesLarge[i] = usecase.Policy{ + Role: "admin", + Path: "/admin/.*", + Methods: []string{"GET", "POST"}, + Name: "Admin access", + } + } + + // 定义基准测试的不同场景 + benchmarks := []struct { + name string + policies []usecase.Policy + ctxTimeout time.Duration + }{ + { + name: "SmallPolicy_NoTimeout", + policies: policiesSmall, + ctxTimeout: 1 * time.Second, // 没有超时 + }, + { + name: "LargePolicy_NoTimeout", + policies: policiesLarge, + ctxTimeout: 1 * time.Second, // 没有超时 + }, + { + name: "SmallPolicy_WithTimeout", + policies: policiesSmall, + ctxTimeout: 1 * time.Millisecond, // 很短的超时 + }, + { + name: "LargePolicy_WithTimeout", + policies: policiesLarge, + ctxTimeout: 1 * time.Millisecond, // 很短的超时 + }, + } + + // 遍历基准测试场景 + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + // 每次运行基准测试时,设置一个带有超时的 Context + ctx, cancel := context.WithTimeout(context.Background(), bm.ctxTimeout) + defer cancel() + + // 调用 LoadPolicy 函数进行基准测试 + _ = got.LoadPolicy(ctx, bm.policies) + } + }) + } +} + +func BenchmarkCheckRBACPermission(b *testing.B) { + got, _ := NewOpaUseCase(OpaUseCaseParam{}) + logx.Disable() + // 定義測試用 Policy + policies := []usecase.Policy{ + { + Role: "admin", + Path: "/admin/.*", + Methods: []string{"GET", "POST"}, + Name: "Admin access", + }, + { + Role: "user", + Path: "/user/.*", + Methods: []string{"GET"}, + Name: "User read access", + }, + { + Role: "editor", + Path: "/editor/.*", + Methods: []string{"PUT", "POST"}, + Name: "Editor access", + }, + } + + // 加載 Policy + _ = got.LoadPolicy(context.Background(), policies) + + // 定義不同測試基準場景 + benchmarks := []struct { + name string + req usecase.CheckReq + }{ + { + name: "SingleRole_SimplePath", + req: usecase.CheckReq{ + Roles: []string{"user"}, + Path: "/user/profile", + Method: "GET", + }, + }, + { + name: "MultipleRoles_ComplexPath", + req: usecase.CheckReq{ + Roles: []string{"admin", "user", "editor"}, + Path: "/editor/dashboard", + Method: "PUT", + }, + }, + { + name: "NoRoles_InvalidPath", + req: usecase.CheckReq{ + Roles: []string{}, + Path: "/invalid/path", + Method: "POST", + }, + }, + } + + // 走訪所有場景 + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + // 設置一個超時的 ctx + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) + defer cancel() + + _, _ = got.CheckRBACPermission(ctx, bm.req) + } + }) + } +} diff --git a/internal/usecase/rule.rego b/internal/usecase/rule.rego new file mode 100644 index 0000000..3aa4aab --- /dev/null +++ b/internal/usecase/rule.rego @@ -0,0 +1,34 @@ +package rbac + +import rego.v1 + +request = { + "roles": input.roles, + "path": input.path, + "method": input.method, + "policies": input.policies, +} + +default allow = false + +key_match(request_path, policy_path) if { + regex.match(policy_path, request_path) +} + +# 方法函數的驗證 +method_match(request_method, policy_methods) if { + policy_methods[_] == request_method +} + +# 檢驗是不是匹配或繼承 +valid_role(user_role, policy_role) if { + user_role[_] == policy_role +} + +# 定義一個策略 +allow if { + policy := input.policies[_] + key_match(input.path, policy.path) + valid_role(input.roles, policy.role) + method_match(input.method, policy.methods) +}