401 lines
14 KiB
Go
401 lines
14 KiB
Go
|
|
// k6-seed-admin bootstraps an admin user for the k6 rbac journey.
|
||
|
|
//
|
||
|
|
// Workflow (no external deps beyond gateway + MailHog + Mongo, plus a best-effort Redis publish):
|
||
|
|
// 1. POST /api/v1/auth/register against the local gateway with a fixed
|
||
|
|
// admin email/password.
|
||
|
|
// 2. Poll MailHog HTTP API for the 6-digit OTP.
|
||
|
|
// 3. POST /api/v1/auth/register/confirm to receive access/refresh tokens, which are printed with the other exports.
|
||
|
|
// 4. Connect to Mongo, seed the permission catalog + default system roles for
|
||
|
|
// the tenant via internal/model/permission/seed.Apply.
|
||
|
|
// 5. Insert a UserRole linking the new admin UID to the tenant_admin role.
|
||
|
|
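// 6. Publish a casbin policy-reload event on Redis (best-effort) so the
// already-running gateway picks up the new role assignment.
// 7. Print ADMIN_EMAIL / ADMIN_PASSWORD / ADMIN_UID (plus ADMIN_TENANT_ID and
// ADMIN_ACCESS_TOKEN / ADMIN_REFRESH_TOKEN when available) as env exports on
// stdout so callers can `eval "$(make k6-seed-admin ...)"` or redirect into a file.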
|
||
|
|
//
|
||
|
|
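// Re-running is safe with the default email, which embeds the current Unix
// time (seconds) so each run registers a new user; re-using a fixed -email /
// ADMIN_EMAIL that is already registered fails at the register step.
// seed.Apply upserts, and a duplicate-key error from the UserRole insert is
// treated as success, so the Mongo steps tolerate repeated runs.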
|
||
|
|
package main
|
||
|
|
|
||
|
|
import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"strconv"
	"strings"
	"time"

	libmongo "gateway/internal/library/mongo"
	memberrepo "gateway/internal/model/member/repository"
	permdomain "gateway/internal/model/permission/domain"
	permentity "gateway/internal/model/permission/domain/entity"
	permrepo "gateway/internal/model/permission/repository"
	permseed "gateway/internal/model/permission/seed"

	"github.com/redis/go-redis/v9"
	"github.com/zeromicro/go-zero/core/logx"
	"go.mongodb.org/mongo-driver/v2/bson"
)
|
||
|
|
|
||
|
|
var (
|
||
|
|
flagBase = flag.String("base", envOr("BASE_URL", "http://localhost:8888"), "Gateway base URL")
|
||
|
|
flagMailhog = flag.String("mailhog", envOr("MAILHOG_URL", "http://localhost:8025"), "MailHog HTTP API URL")
|
||
|
|
flagTenant = flag.String("tenant", envOr("TENANT_SLUG", "k6-tenant"), "Tenant slug")
|
||
|
|
flagInvite = flag.String("invite", envOr("INVITE_CODE", "K6INVITE"), "Invite code")
|
||
|
|
// Default email is rotated per-invocation. Re-running seed-admin against
|
||
|
|
// a stable email would collide with the existing ZITADEL user (28303000
|
||
|
|
// email already registered) since ZITADEL state lives outside docker
|
||
|
|
// volumes that `make k6-down` clears. Override with -email or ADMIN_EMAIL.
|
||
|
|
flagEmail = flag.String("email", envOr("ADMIN_EMAIL", fmt.Sprintf("k6-admin-%d@k6.local", time.Now().Unix())), "Admin email")
|
||
|
|
flagPassword = flag.String("password", envOr("ADMIN_PASSWORD", "K6-Admin-Pass-1!"), "Admin password")
|
||
|
|
flagMongoHost = flag.String("mongo-host", envOr("K6_MONGO_HOST", "127.0.0.1"), "Mongo host")
|
||
|
|
flagMongoPort = flag.Int("mongo-port", envOrInt("K6_MONGO_PORT", 27017), "Mongo port")
|
||
|
|
flagMongoDB = flag.String("mongo-db", envOr("K6_MONGO_DB", "gateway_k6"), "Mongo database")
|
||
|
|
flagTenantID = flag.String("tenant-id", envOr("ADMIN_TENANT_ID", ""), "Override resolved tenant_id (skip lookup)")
|
||
|
|
flagPollSecs = flag.Int("otp-timeout", 10, "MailHog OTP poll timeout (seconds)")
|
||
|
|
flagDryRun = flag.Bool("dry-run", false, "Skip Mongo writes; only test register flow")
|
||
|
|
flagRedisAddr = flag.String("redis-addr", envOr("REDIS_ADDR", "localhost:6379"), "Redis addr (host:port) for casbin reload broadcast")
|
||
|
|
flagReloadChannel = flag.String("reload-channel", envOr("CASBIN_RELOAD_CHANNEL", "casbin:reload:k6"), "Casbin reload Pub/Sub channel (must match gateway Permission.Reload.Channel)")
|
||
|
|
)
|
||
|
|
|
||
|
|
func main() {
|
||
|
|
flag.Parse()
|
||
|
|
// go-zero's mongo helper logs every query via logx; in a CLI that pipes
|
||
|
|
// stdout to k6.env that pollutes the env file with JSON log lines.
|
||
|
|
// Disable logx entirely — we keep our own [k6-seed-admin] stderr logs.
|
||
|
|
logx.Disable()
|
||
|
|
|
||
|
|
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
logf("registering admin %s @ %s", *flagEmail, *flagBase)
|
||
|
|
regResp, err := register(ctx)
|
||
|
|
if err != nil {
|
||
|
|
exitf("register: %v", err)
|
||
|
|
}
|
||
|
|
logf("challenge_id=%s uid=%s", regResp.ChallengeID, regResp.UID)
|
||
|
|
|
||
|
|
code, err := pollOTP(ctx, *flagEmail, time.Duration(*flagPollSecs)*time.Second)
|
||
|
|
if err != nil {
|
||
|
|
exitf("poll OTP: %v", err)
|
||
|
|
}
|
||
|
|
logf("OTP=%s", code)
|
||
|
|
|
||
|
|
tokens, err := confirm(ctx, regResp.ChallengeID, code)
|
||
|
|
if err != nil {
|
||
|
|
exitf("register/confirm: %v", err)
|
||
|
|
}
|
||
|
|
logf("registration confirmed; admin uid=%s access_token=%d chars", regResp.UID, len(tokens.AccessToken))
|
||
|
|
|
||
|
|
if *flagDryRun {
|
||
|
|
writeOutput(*flagEmail, *flagPassword, regResp.UID, "", tokens)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
mongoConf := &libmongo.Conf{
|
||
|
|
Schema: "mongodb",
|
||
|
|
Host: *flagMongoHost,
|
||
|
|
Port: *flagMongoPort,
|
||
|
|
Database: *flagMongoDB,
|
||
|
|
}
|
||
|
|
|
||
|
|
tenantID := *flagTenantID
|
||
|
|
if tenantID == "" {
|
||
|
|
t, err := resolveTenantID(ctx, mongoConf, *flagTenant)
|
||
|
|
if err != nil {
|
||
|
|
exitf("resolve tenant_id for slug=%s: %v", *flagTenant, err)
|
||
|
|
}
|
||
|
|
tenantID = t
|
||
|
|
}
|
||
|
|
logf("tenant_id=%s", tenantID)
|
||
|
|
|
||
|
|
if err := seedRoles(ctx, mongoConf, tenantID); err != nil {
|
||
|
|
exitf("seed roles: %v", err)
|
||
|
|
}
|
||
|
|
roleID, err := assignAdmin(ctx, mongoConf, tenantID, regResp.UID)
|
||
|
|
if err != nil {
|
||
|
|
exitf("assign tenant_admin: %v", err)
|
||
|
|
}
|
||
|
|
logf("tenant_admin role_id=%s assigned", roleID)
|
||
|
|
|
||
|
|
// Casbin lives in process memory inside the gateway and only reloads
|
||
|
|
// from Mongo when it boots or when something publishes on the reload
|
||
|
|
// channel. seed-admin runs AFTER the gateway started, so without this
|
||
|
|
// broadcast the admin's tenant_admin assignment is invisible until a
|
||
|
|
// restart and the rbac journey 403s on the very first /roles call.
|
||
|
|
if err := broadcastReload(ctx, *flagRedisAddr, *flagReloadChannel, tenantID); err != nil {
|
||
|
|
logf("warn: casbin reload broadcast failed (rbac journey may 403 until gateway restart): %v", err)
|
||
|
|
} else {
|
||
|
|
logf("casbin policy reload broadcast on %s channel=%s", *flagRedisAddr, *flagReloadChannel)
|
||
|
|
}
|
||
|
|
// Pub/Sub is best-effort; give the subscriber a beat to LoadPolicy
|
||
|
|
// before callers (e.g. make k6-journey) hit /roles.
|
||
|
|
time.Sleep(500 * time.Millisecond)
|
||
|
|
|
||
|
|
writeOutput(*flagEmail, *flagPassword, regResp.UID, tenantID, tokens)
|
||
|
|
}
|
||
|
|
|
||
|
|
// broadcastReload publishes a casbin reload event on the same Redis channel
|
||
|
|
// the gateway subscribes to (see internal/model/permission/usecase
|
||
|
|
// /rbac_usecase.go::BroadcastReload). Payload shape mirrors that function.
|
||
|
|
func broadcastReload(ctx context.Context, addr, channel, tenantID string) error {
|
||
|
|
if addr == "" {
|
||
|
|
return fmt.Errorf("redis addr empty")
|
||
|
|
}
|
||
|
|
if channel == "" {
|
||
|
|
channel = permdomain.PolicyReloadChannel
|
||
|
|
}
|
||
|
|
if tenantID == "" {
|
||
|
|
tenantID = permdomain.PolicyReloadAllToken
|
||
|
|
}
|
||
|
|
rdb := redis.NewClient(&redis.Options{Addr: addr})
|
||
|
|
defer func() { _ = rdb.Close() }()
|
||
|
|
payload, _ := json.Marshal(map[string]any{
|
||
|
|
"tenant_id": tenantID,
|
||
|
|
"ts": time.Now().UnixMilli(),
|
||
|
|
})
|
||
|
|
pubCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
|
||
|
|
defer cancel()
|
||
|
|
return rdb.Publish(pubCtx, channel, payload).Err()
|
||
|
|
}
|
||
|
|
|
||
|
|
// ---------- HTTP / API helpers ----------
|
||
|
|
|
||
|
|
// Payloads exchanged with the gateway's JSON API.
type (
	// registerResp is the "data" object of POST /api/v1/auth/register.
	registerResp struct {
		ChallengeID string `json:"challenge_id"`
		ExpiresIn   int    `json:"expires_in"`
		UID         string `json:"uid"`
	}

	// confirmResp is the "data" object of POST /api/v1/auth/register/confirm.
	confirmResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int    `json:"expires_in"`
		UID          string `json:"uid"`
		TokenType    string `json:"token_type"`
	}

	// envelope is the outer wrapper doJSON decodes each response into. Data
	// is kept raw so the caller can decode it into the endpoint-specific type.
	envelope struct {
		Code    int             `json:"code"`
		Message string          `json:"message"`
		Data    json.RawMessage `json:"data"`
	}
)
|
||
|
|
|
||
|
|
func register(ctx context.Context) (*registerResp, error) {
|
||
|
|
body, _ := json.Marshal(map[string]any{
|
||
|
|
"tenant_slug": *flagTenant,
|
||
|
|
"invite_code": *flagInvite,
|
||
|
|
"email": *flagEmail,
|
||
|
|
"password": *flagPassword,
|
||
|
|
"display_name": "k6 admin",
|
||
|
|
"language": "zh-TW",
|
||
|
|
"accept_terms_version": "2025-01-01",
|
||
|
|
"marketing_opt_in": false,
|
||
|
|
})
|
||
|
|
env, err := doJSON(ctx, "POST", *flagBase+"/api/v1/auth/register", body)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
var r registerResp
|
||
|
|
if err := json.Unmarshal(env.Data, &r); err != nil {
|
||
|
|
return nil, fmt.Errorf("decode register data: %w", err)
|
||
|
|
}
|
||
|
|
return &r, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func confirm(ctx context.Context, challengeID, code string) (*confirmResp, error) {
|
||
|
|
body, _ := json.Marshal(map[string]any{
|
||
|
|
"tenant_slug": *flagTenant,
|
||
|
|
"challenge_id": challengeID,
|
||
|
|
"code": code,
|
||
|
|
})
|
||
|
|
env, err := doJSON(ctx, "POST", *flagBase+"/api/v1/auth/register/confirm", body)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
var r confirmResp
|
||
|
|
if err := json.Unmarshal(env.Data, &r); err != nil {
|
||
|
|
return nil, fmt.Errorf("decode confirm data: %w", err)
|
||
|
|
}
|
||
|
|
return &r, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func doJSON(ctx context.Context, method, url string, body []byte) (*envelope, error) {
|
||
|
|
req, err := http.NewRequestWithContext(ctx, method, url, strings.NewReader(string(body)))
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
req.Header.Set("Content-Type", "application/json")
|
||
|
|
resp, err := http.DefaultClient.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
defer resp.Body.Close()
|
||
|
|
raw, _ := io.ReadAll(resp.Body)
|
||
|
|
if resp.StatusCode >= 400 {
|
||
|
|
return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
|
||
|
|
}
|
||
|
|
var env envelope
|
||
|
|
if err := json.Unmarshal(raw, &env); err != nil {
|
||
|
|
return nil, fmt.Errorf("decode envelope: %w (body=%s)", err, raw)
|
||
|
|
}
|
||
|
|
if env.Code != 102000 {
|
||
|
|
return nil, fmt.Errorf("non-success code=%d message=%s", env.Code, env.Message)
|
||
|
|
}
|
||
|
|
return &env, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// otpRegex captures a standalone 6-digit number, the OTP candidate.
var otpRegex = regexp.MustCompile(`\b(\d{6})\b`)

// cssHexRe matches 6-digit CSS hex colors such as #059669, so they can be
// stripped before OTP matching.
var cssHexRe = regexp.MustCompile(`#[0-9a-fA-F]{6}\b`)

// qpSoftLine matches a quoted-printable soft line break ("=" at end of line,
// with optional CR).
var qpSoftLine = regexp.MustCompile(`=\r?\n`)
|
||
|
|
|
||
|
|
// extractOTP returns the LAST 6-digit number in the body after stripping
|
||
|
|
// CSS hex colors (e.g. #059669) and quoted-printable soft line breaks.
|
||
|
|
// Email bodies render the OTP in a styled span near the bottom; the naive
|
||
|
|
// "first 6-digit" approach picks up brand colors.
|
||
|
|
func extractOTP(body string) string {
|
||
|
|
cleaned := cssHexRe.ReplaceAllString(qpSoftLine.ReplaceAllString(body, ""), "")
|
||
|
|
matches := otpRegex.FindAllStringSubmatch(cleaned, -1)
|
||
|
|
if len(matches) == 0 {
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
return matches[len(matches)-1][1]
|
||
|
|
}
|
||
|
|
|
||
|
|
// Subset of MailHog's v2 API response shapes; only the fields this tool reads
// are declared, and encoding/json ignores the rest.
type (
	// mailhogItem is one message in a MailHog search result.
	mailhogItem struct {
		Created string `json:"Created"`
		Content struct {
			Body string `json:"Body"`
		} `json:"Content"`
	}

	// mailhogList is the top-level object returned by /api/v2/search.
	mailhogList struct {
		Items []mailhogItem `json:"items"`
	}
)
|
||
|
|
|
||
|
|
func pollOTP(ctx context.Context, email string, timeout time.Duration) (string, error) {
|
||
|
|
deadline := time.Now().Add(timeout)
|
||
|
|
url := fmt.Sprintf("%s/api/v2/search?kind=to&query=%s&start=0&limit=5", *flagMailhog, email)
|
||
|
|
for time.Now().Before(deadline) {
|
||
|
|
req, _ := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||
|
|
resp, err := http.DefaultClient.Do(req)
|
||
|
|
if err == nil && resp.StatusCode == 200 {
|
||
|
|
raw, _ := io.ReadAll(resp.Body)
|
||
|
|
_ = resp.Body.Close()
|
||
|
|
var list mailhogList
|
||
|
|
if json.Unmarshal(raw, &list) == nil {
|
||
|
|
for _, it := range list.Items {
|
||
|
|
if code := extractOTP(it.Content.Body); code != "" {
|
||
|
|
return code, nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
} else if resp != nil {
|
||
|
|
_ = resp.Body.Close()
|
||
|
|
}
|
||
|
|
time.Sleep(300 * time.Millisecond)
|
||
|
|
}
|
||
|
|
return "", fmt.Errorf("OTP not seen within %s", timeout)
|
||
|
|
}
|
||
|
|
|
||
|
|
// ---------- Mongo helpers ----------
|
||
|
|
|
||
|
|
func resolveTenantID(ctx context.Context, conf *libmongo.Conf, slug string) (string, error) {
|
||
|
|
repo := memberrepo.NewTenantRepository(memberrepo.TenantRepositoryParam{Conf: conf})
|
||
|
|
deadline := time.Now().Add(5 * time.Second)
|
||
|
|
for time.Now().Before(deadline) {
|
||
|
|
t, err := repo.GetBySlug(ctx, slug)
|
||
|
|
if err == nil && t != nil && t.TenantID != "" {
|
||
|
|
return t.TenantID, nil
|
||
|
|
}
|
||
|
|
time.Sleep(200 * time.Millisecond)
|
||
|
|
}
|
||
|
|
// Fallback: treat the slug itself as the tenant id (works for gateways
|
||
|
|
// that use slug == tenant_id, e.g. dev seed).
|
||
|
|
return slug, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func seedRoles(ctx context.Context, conf *libmongo.Conf, tenantID string) error {
|
||
|
|
perms := permrepo.NewPermissionRepository(permrepo.PermissionRepositoryParam{Conf: conf})
|
||
|
|
roles := permrepo.NewRoleRepository(permrepo.RoleRepositoryParam{Conf: conf})
|
||
|
|
rolePerms := permrepo.NewRolePermissionRepository(permrepo.RolePermissionRepositoryParam{Conf: conf})
|
||
|
|
rpt, err := permseed.Apply(ctx, perms, roles, rolePerms, permseed.ApplyOptions{
|
||
|
|
TenantIDs: []string{tenantID},
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
logf("seed report: catalog=%d roles=%d role_perms=%d", rpt.CatalogUpserted, rpt.RolesUpserted, rpt.RolePermissionSet)
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func assignAdmin(ctx context.Context, conf *libmongo.Conf, tenantID, uid string) (string, error) {
|
||
|
|
roles := permrepo.NewRoleRepository(permrepo.RoleRepositoryParam{Conf: conf})
|
||
|
|
role, err := roles.GetByKey(ctx, tenantID, "tenant_admin")
|
||
|
|
if err != nil || role == nil {
|
||
|
|
return "", fmt.Errorf("tenant_admin role not found for tenant=%s: %v", tenantID, err)
|
||
|
|
}
|
||
|
|
urRepo := permrepo.NewUserRoleRepository(permrepo.UserRoleRepositoryParam{Conf: conf})
|
||
|
|
if err := urRepo.Insert(ctx, &permentity.UserRole{
|
||
|
|
ID: bson.NewObjectID(),
|
||
|
|
TenantID: tenantID,
|
||
|
|
UID: uid,
|
||
|
|
RoleID: role.ID.Hex(),
|
||
|
|
}); err != nil {
|
||
|
|
// duplicate-key is OK (idempotent re-run)
|
||
|
|
if !strings.Contains(err.Error(), "duplicate") {
|
||
|
|
return "", err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return role.ID.Hex(), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// ---------- util ----------
|
||
|
|
|
||
|
|
func writeOutput(email, password, uid, tenantID string, tokens *confirmResp) {
|
||
|
|
fmt.Printf("export ADMIN_EMAIL=%s\n", email)
|
||
|
|
fmt.Printf("export ADMIN_PASSWORD=%s\n", password)
|
||
|
|
fmt.Printf("export ADMIN_UID=%s\n", uid)
|
||
|
|
if tenantID != "" {
|
||
|
|
fmt.Printf("export ADMIN_TENANT_ID=%s\n", tenantID)
|
||
|
|
}
|
||
|
|
if tokens != nil && tokens.AccessToken != "" {
|
||
|
|
// k6 journeys (rbac_admin.js) prefer these over POST /auth/login,
|
||
|
|
// since ZITADEL v2 disables the OAuth password grant by default and
|
||
|
|
// the gateway's /auth/login → VerifyPassword path then 502s.
|
||
|
|
fmt.Printf("export ADMIN_ACCESS_TOKEN=%s\n", tokens.AccessToken)
|
||
|
|
fmt.Printf("export ADMIN_REFRESH_TOKEN=%s\n", tokens.RefreshToken)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// envOr returns the value of environment variable k. It returns def when k is
// unset or set to the empty string.
func envOr(k, def string) string {
	v, ok := os.LookupEnv(k)
	if !ok || v == "" {
		return def
	}
	return v
}
|
||
|
|
|
||
|
|
// envOrInt returns environment variable k parsed as a base-10 int. It returns
// def when k is unset, empty or not a valid integer. Surrounding whitespace is
// ignored. Trailing garbage such as "27017x" is rejected; fmt.Sscanf("%d")
// would silently truncate it to 27017.
func envOrInt(k string, def int) int {
	v := strings.TrimSpace(os.Getenv(k))
	if v == "" {
		return def
	}
	n, err := strconv.Atoi(v)
	if err != nil {
		return def
	}
	return n
}
|
||
|
|
|
||
|
|
// logf writes a diagnostic line to stderr. The line is prefixed with
// "[k6-seed-admin] " and terminated with a newline. Stderr is used so stdout
// stays clean for the env exports.
func logf(format string, a ...any) {
	const tag = "[k6-seed-admin] "
	line := tag + format + "\n"
	fmt.Fprintf(os.Stderr, line, a...)
}
|
||
|
|
|
||
|
|
func exitf(format string, a ...any) {
|
||
|
|
logf(format, a...)
|
||
|
|
os.Exit(1)
|
||
|
|
}
|