// Package totp implements RFC 6238 Time-based One-Time Password generation // and verification helpers used by the member step-up MFA flow. // // Algorithm constraints (compatible with Google Authenticator / Authy / // 1Password / Microsoft Authenticator): // - HMAC-SHA1 // - 30-second period // - 6 digits // // Callers may verify across a small window (typically plus or minus one step) // to tolerate device clock drift. package totp import ( "crypto/hmac" "crypto/rand" "crypto/sha1" //nolint:gosec // RFC 6238 requires HMAC-SHA1 "crypto/subtle" "encoding/base32" "encoding/binary" "encoding/hex" "fmt" "net/url" "strings" "time" ) // Defaults match the configuration documented in identity-member-design.md // section 5.8 / etc/gateway.yaml TOTP block. const ( DefaultDigits = 6 DefaultPeriod = 30 * time.Second DefaultWindow = 1 SecretBytes = 20 BackupCodeSize = 12 BackupCodeNum = 10 ) // Sentinel errors so callers can distinguish failure modes. var ( ErrInvalidSecret = fmt.Errorf("totp: invalid secret") ErrInvalidCode = fmt.Errorf("totp: invalid code") ErrInvalidDigits = fmt.Errorf("totp: invalid digit length") ErrInvalidIssuer = fmt.Errorf("totp: issuer must be non-empty") ErrInvalidAccount = fmt.Errorf("totp: account must be non-empty") ) // GenerateSecret returns a random 160-bit secret suitable for RFC 4226 / 6238. func GenerateSecret() ([]byte, error) { buf := make([]byte, SecretBytes) if _, err := rand.Read(buf); err != nil { return nil, fmt.Errorf("totp: read random: %w", err) } return buf, nil } // EncodeSecret returns the base32 (no padding) representation typically // embedded into otpauth URLs. func EncodeSecret(secret []byte) string { return strings.TrimRight(base32.StdEncoding.EncodeToString(secret), "=") } // DecodeSecret parses the base32 representation back to raw bytes. Padding is // optional so the function accepts the canonical form used in otpauth URLs. func DecodeSecret(encoded string) ([]byte, error) { if encoded == "" { return nil, ErrInvalidSecret } clean := strings.ToUpper(strings.TrimSpace(encoded)) clean = strings.TrimRight(clean, "=") if mod := len(clean) % 8; mod != 0 { clean += strings.Repeat("=", 8-mod) } out, err := base32.StdEncoding.DecodeString(clean) if err != nil { return nil, fmt.Errorf("%w: %w", ErrInvalidSecret, err) } return out, nil } // Generate returns the TOTP code for the given timestamp. func Generate(secret []byte, ts time.Time, period time.Duration, digits int) (string, error) { if len(secret) == 0 { return "", ErrInvalidSecret } if digits <= 0 || digits > 10 { return "", ErrInvalidDigits } if period <= 0 { period = DefaultPeriod } return computeHOTP(secret, unixCounter(ts, period), digits), nil } // Verify checks the supplied code against the secret allowing for plus or // minus window time steps. When the code is valid it returns the timestep // counter that matched so callers can persist it for replay protection. func Verify(secret []byte, code string, ts time.Time, period time.Duration, digits, window int) (uint64, bool) { if len(secret) == 0 || code == "" || digits <= 0 { return 0, false } if len(code) != digits { return 0, false } if period <= 0 { period = DefaultPeriod } if window < 0 { window = 0 } base := unixCounter(ts, period) want := []byte(code) for i := -window; i <= window; i++ { counter, ok := shiftCounter(base, i) if !ok { continue } got := []byte(computeHOTP(secret, counter, digits)) if subtle.ConstantTimeCompare(want, got) == 1 { return counter, true } } return 0, false } // shiftCounter applies a signed offset to base without crossing 0. // Returning (0, false) signals that the requested offset is out of range. func shiftCounter(base uint64, offset int) (uint64, bool) { if offset >= 0 { return base + uint64(offset), true } neg := uint64(-offset) //nolint:gosec // offset is bounded by ±Window; magnitude fits uint64. if base < neg { return 0, false } return base - neg, true } func computeHOTP(secret []byte, counter uint64, digits int) string { var buf [8]byte binary.BigEndian.PutUint64(buf[:], counter) mac := hmac.New(sha1.New, secret) mac.Write(buf[:]) sum := mac.Sum(nil) offset := sum[len(sum)-1] & 0x0F bin := (uint32(sum[offset]&0x7F) << 24) | (uint32(sum[offset+1]) << 16) | (uint32(sum[offset+2]) << 8) | uint32(sum[offset+3]) mod := uint32(1) for range digits { mod *= 10 } num := bin % mod return fmt.Sprintf("%0*d", digits, num) } // OtpauthURLInput controls the otpauth URL fields shown to users. type OtpauthURLInput struct { Issuer string Account string Secret []byte Algorithm string Digits int Period time.Duration } // BuildOtpauthURL renders the canonical otpauth://totp/... URL. func BuildOtpauthURL(in OtpauthURLInput) (string, error) { if in.Issuer == "" { return "", ErrInvalidIssuer } if in.Account == "" { return "", ErrInvalidAccount } if len(in.Secret) == 0 { return "", ErrInvalidSecret } algo := in.Algorithm if algo == "" { algo = "SHA1" } digits := in.Digits if digits <= 0 { digits = DefaultDigits } period := in.Period if period <= 0 { period = DefaultPeriod } q := url.Values{} q.Set("secret", EncodeSecret(in.Secret)) q.Set("issuer", in.Issuer) q.Set("algorithm", algo) q.Set("digits", fmt.Sprintf("%d", digits)) q.Set("period", fmt.Sprintf("%d", int(period.Seconds()))) label := url.PathEscape(in.Issuer) + ":" + url.PathEscape(in.Account) return "otpauth://totp/" + label + "?" + q.Encode(), nil } // GenerateBackupCodes returns n random hex-encoded codes of the given length. // They are returned in plaintext; callers must hash before storage. func GenerateBackupCodes(n, length int) ([]string, error) { if n <= 0 { n = BackupCodeNum } if length <= 0 { length = BackupCodeSize } if length%2 != 0 { length++ } out := make([]string, 0, n) buf := make([]byte, length/2) for range n { if _, err := rand.Read(buf); err != nil { return nil, fmt.Errorf("totp: backup code rand: %w", err) } out = append(out, hex.EncodeToString(buf)) } return out, nil } // TimeStep returns the integer time-step counter used by Verify, useful for // replay-protection keys. func TimeStep(ts time.Time, period time.Duration) uint64 { if period <= 0 { period = DefaultPeriod } return unixCounter(ts, period) } // unixCounter converts (ts, period) to the RFC 6238 counter, clamping // negative timestamps to zero so the int64→uint64 conversion is bounded. func unixCounter(ts time.Time, period time.Duration) uint64 { sec := ts.Unix() if sec < 0 { sec = 0 } per := int64(period.Seconds()) if per <= 0 { per = int64(DefaultPeriod.Seconds()) } return uint64(sec / per) //nolint:gosec // sec is non-negative and per is positive; quotient fits uint64. }