template-monorepo/internal/model/member/totp/totp.go

248 lines
6.7 KiB
Go
Raw Normal View History

2026-05-20 13:03:59 +00:00
// Package totp implements RFC 6238 Time-based One-Time Password generation
// and verification helpers used by the member step-up MFA flow.
//
// Algorithm constraints (compatible with Google Authenticator / Authy /
// 1Password / Microsoft Authenticator):
// - HMAC-SHA1
// - 30-second period
// - 6 digits
//
// Callers may verify across a small window (typically plus or minus one step)
// to tolerate device clock drift.
package totp
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha1" //nolint:gosec // RFC 6238 requires HMAC-SHA1
"crypto/subtle"
"encoding/base32"
"encoding/binary"
"encoding/hex"
"fmt"
"net/url"
"strings"
"time"
)
// Defaults match the configuration documented in internal/model/member/SDD.md
// §3.5 / etc/gateway.yaml TOTP block.
2026-05-20 13:03:59 +00:00
const (
	// DefaultDigits is the number of decimal digits in a generated code.
	DefaultDigits = 6
	// DefaultPeriod is the length of one TOTP time step.
	DefaultPeriod = time.Second * 30
	// DefaultWindow is the number of time steps on either side of the
	// current one that callers typically accept to tolerate clock drift.
	DefaultWindow = 1
	// SecretBytes is the size of a generated shared secret (160 bits).
	SecretBytes = 20
)

const (
	// BackupCodeSize is the default length of a backup code, in hex characters.
	BackupCodeSize = 12
	// BackupCodeNum is the default number of backup codes issued at once.
	BackupCodeNum = 10
)
// Sentinel errors returned by this package. Some are wrapped with extra
// context, so callers should match them with errors.Is.

// ErrInvalidSecret reports an empty or undecodable shared secret.
var ErrInvalidSecret = fmt.Errorf("totp: invalid secret")

// ErrInvalidCode reports a code that failed validation.
var ErrInvalidCode = fmt.Errorf("totp: invalid code")

// ErrInvalidDigits reports an unsupported code length.
var ErrInvalidDigits = fmt.Errorf("totp: invalid digit length")

// ErrInvalidIssuer reports a missing issuer when building an otpauth URL.
var ErrInvalidIssuer = fmt.Errorf("totp: issuer must be non-empty")

// ErrInvalidAccount reports a missing account when building an otpauth URL.
var ErrInvalidAccount = fmt.Errorf("totp: account must be non-empty")
// GenerateSecret draws SecretBytes (160 bits) from crypto/rand, the shared
// secret length RFC 4226 recommends for HMAC-SHA1. The raw bytes are
// returned; EncodeSecret produces the base32 form shown to users.
func GenerateSecret() ([]byte, error) {
	secret := make([]byte, SecretBytes)
	_, err := rand.Read(secret)
	if err != nil {
		return nil, fmt.Errorf("totp: read random: %w", err)
	}
	return secret, nil
}
// EncodeSecret renders secret as RFC 4648 base32 (standard alphabet) with
// the trailing '=' padding removed, which is the form embedded in otpauth
// URLs.
func EncodeSecret(secret []byte) string {
	// '=' only ever appears as trailing padding, so everything before the
	// first '=' is the payload.
	payload, _, _ := strings.Cut(base32.StdEncoding.EncodeToString(secret), "=")
	return payload
}
// DecodeSecret parses a base32-encoded secret back to raw bytes.
//
// The input is normalised before decoding:
//   - letters are upper-cased;
//   - all whitespace (leading, trailing and interior) is removed, because
//     authenticator apps often display secrets in space-separated groups;
//   - trailing '=' padding is optional and is re-added as needed.
//
// An input that is empty after normalisation (for example "", "   " or
// "====") yields ErrInvalidSecret, instead of silently decoding to an empty
// secret. Malformed base32 yields an error that wraps both ErrInvalidSecret
// and the underlying decoder error.
func DecodeSecret(encoded string) ([]byte, error) {
	clean := strings.ToUpper(strings.Join(strings.Fields(encoded), ""))
	clean = strings.TrimRight(clean, "=")
	if clean == "" {
		return nil, ErrInvalidSecret
	}
	// StdEncoding requires padding to a multiple of 8 characters.
	if rem := len(clean) % 8; rem != 0 {
		clean += strings.Repeat("=", 8-rem)
	}
	out, err := base32.StdEncoding.DecodeString(clean)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrInvalidSecret, err)
	}
	return out, nil
}
// Generate returns the TOTP code for ts. This is the HOTP value of the
// time-step counter derived from ts and period, rendered as a zero-padded
// decimal string of exactly digits characters.
//
// It returns ErrInvalidSecret for an empty secret and ErrInvalidDigits when
// digits is outside 1..10. A non-positive period selects DefaultPeriod.
func Generate(secret []byte, ts time.Time, period time.Duration, digits int) (string, error) {
	switch {
	case len(secret) == 0:
		return "", ErrInvalidSecret
	case digits < 1 || digits > 10:
		return "", ErrInvalidDigits
	}
	step := period
	if step <= 0 {
		step = DefaultPeriod
	}
	counter := unixCounter(ts, step)
	return computeHOTP(secret, counter, digits), nil
}
// Verify reports whether code is a valid TOTP for secret at ts. It accepts
// any time step within ±window of the step that contains ts, to tolerate
// device clock drift. On success it returns the counter that matched, so
// callers can persist it for replay protection.
//
// Candidate steps are tried from oldest to newest, and the first match wins.
// Steps that would fall below zero are skipped. Each candidate is compared
// in constant time.
//
// Verify returns (0, false) when:
//   - secret or code is empty;
//   - digits is outside 1..10, the range Generate accepts;
//   - len(code) differs from digits.
//
// A non-positive period selects DefaultPeriod, and a negative window is
// treated as 0.
func Verify(secret []byte, code string, ts time.Time, period time.Duration, digits, window int) (uint64, bool) {
	// Mirror Generate's digit range so both functions agree on which
	// configurations are valid.
	if len(secret) == 0 || code == "" || digits <= 0 || digits > 10 {
		return 0, false
	}
	// The expected length is public (it is the configured digit count), so
	// rejecting a wrong-length code early reveals nothing secret.
	if len(code) != digits {
		return 0, false
	}
	if period <= 0 {
		period = DefaultPeriod
	}
	if window < 0 {
		window = 0
	}
	base := unixCounter(ts, period)
	want := []byte(code)
	for i := -window; i <= window; i++ {
		counter, ok := shiftCounter(base, i)
		if !ok {
			continue
		}
		got := []byte(computeHOTP(secret, counter, digits))
		if subtle.ConstantTimeCompare(want, got) == 1 {
			return counter, true
		}
	}
	return 0, false
}
// shiftCounter returns base+offset. When a negative offset would move the
// counter below zero, it reports (0, false) instead of wrapping around.
func shiftCounter(base uint64, offset int) (uint64, bool) {
	if offset < 0 {
		back := uint64(-offset) //nolint:gosec // offset < 0, so -offset is a positive magnitude that fits uint64.
		if back > base {
			return 0, false
		}
		return base - back, true
	}
	return base + uint64(offset), true
}
func computeHOTP(secret []byte, counter uint64, digits int) string {
var buf [8]byte
binary.BigEndian.PutUint64(buf[:], counter)
mac := hmac.New(sha1.New, secret)
mac.Write(buf[:])
sum := mac.Sum(nil)
offset := sum[len(sum)-1] & 0x0F
bin := (uint32(sum[offset]&0x7F) << 24) |
(uint32(sum[offset+1]) << 16) |
(uint32(sum[offset+2]) << 8) |
uint32(sum[offset+3])
mod := uint32(1)
for range digits {
mod *= 10
}
num := bin % mod
return fmt.Sprintf("%0*d", digits, num)
}
// OtpauthURLInput carries the values BuildOtpauthURL renders into an
// otpauth:// key URI shown to users.
type OtpauthURLInput struct {
	// Issuer and Account form the URI label; both are required.
	Issuer, Account string
	// Secret is the raw shared secret; it is base32-encoded into the URI.
	Secret []byte
	// Algorithm is the advertised HMAC algorithm; empty means "SHA1".
	Algorithm string
	// Digits is the code length; zero or negative means DefaultDigits.
	Digits int
	// Period is the time-step length; zero or negative means DefaultPeriod.
	Period time.Duration
}
// BuildOtpauthURL renders the otpauth://totp/ key URI that authenticator apps
// import.
//
// Issuer, Account and Secret are required. An empty value yields
// ErrInvalidIssuer, ErrInvalidAccount or ErrInvalidSecret respectively.
//
// Defaults:
//   - Algorithm defaults to "SHA1".
//   - Digits defaults to DefaultDigits.
//   - Period defaults to DefaultPeriod. It is rendered in whole seconds, and
//     a Period that truncates to zero seconds (including sub-second values)
//     also falls back to DefaultPeriod. This matches how time-step counters
//     are derived, so the URI never advertises period=0.
//
// Escaping:
//   - Issuer and Account are percent-encoded in the label, including any ':'
//     (the label's own issuer/account separator).
//   - Spaces in query values are encoded as %20 rather than '+'.
//
// NOTE(review): Algorithm is emitted verbatim, but codes in this package are
// always computed with HMAC-SHA1. Advertising another algorithm would make
// the authenticator produce codes that verification rejects.
func BuildOtpauthURL(in OtpauthURLInput) (string, error) {
	if in.Issuer == "" {
		return "", ErrInvalidIssuer
	}
	if in.Account == "" {
		return "", ErrInvalidAccount
	}
	if len(in.Secret) == 0 {
		return "", ErrInvalidSecret
	}
	algo := in.Algorithm
	if algo == "" {
		algo = "SHA1"
	}
	digits := in.Digits
	if digits <= 0 {
		digits = DefaultDigits
	}
	// Whole seconds, truncated as in counter derivation; a non-positive
	// result (negative, zero or sub-second Period) falls back to the default.
	periodSecs := int(in.Period.Seconds())
	if periodSecs <= 0 {
		periodSecs = int(DefaultPeriod.Seconds())
	}
	q := url.Values{}
	q.Set("secret", EncodeSecret(in.Secret))
	q.Set("issuer", in.Issuer)
	q.Set("algorithm", algo)
	q.Set("digits", fmt.Sprintf("%d", digits))
	q.Set("period", fmt.Sprintf("%d", periodSecs))
	// url.Values.Encode uses form encoding (space -> '+'). Any literal '+'
	// in a value is already escaped to %2B, so every remaining '+' is a
	// space and can safely be rewritten to the RFC 3986 form.
	query := strings.ReplaceAll(q.Encode(), "+", "%20")

	// url.PathEscape leaves ':' intact, but ':' separates issuer from
	// account in the label, so it must be escaped inside either part.
	escapeLabel := func(s string) string {
		return strings.ReplaceAll(url.PathEscape(s), ":", "%3A")
	}
	label := escapeLabel(in.Issuer) + ":" + escapeLabel(in.Account)
	return "otpauth://totp/" + label + "?" + query, nil
}
// GenerateBackupCodes produces n backup codes. Each code is a string of
// lower-case hexadecimal characters derived from crypto/rand.
//
// n <= 0 selects BackupCodeNum, and length <= 0 selects BackupCodeSize.
// Every random byte renders as two hex characters, so an odd length is
// rounded up to the next even value.
//
// The codes are plaintext; callers must hash them before storage. If the
// random source fails, no codes are returned.
func GenerateBackupCodes(n, length int) ([]string, error) {
	count := n
	if count <= 0 {
		count = BackupCodeNum
	}
	size := length
	if size <= 0 {
		size = BackupCodeSize
	}
	// (size+1)/2 is ceil(size/2): enough bytes for size hex characters,
	// rounding an odd size up.
	raw := make([]byte, (size+1)/2)
	codes := make([]string, 0, count)
	for len(codes) < count {
		if _, err := rand.Read(raw); err != nil {
			return nil, fmt.Errorf("totp: backup code rand: %w", err)
		}
		codes = append(codes, hex.EncodeToString(raw))
	}
	return codes, nil
}
// TimeStep reports the RFC 6238 time-step counter for ts. This is the same
// counter Verify centres its search window on, so callers can use it to
// build replay-protection keys. A non-positive period selects DefaultPeriod.
func TimeStep(ts time.Time, period time.Duration) uint64 {
	step := period
	if step <= 0 {
		step = DefaultPeriod
	}
	return unixCounter(ts, step)
}
// unixCounter maps ts onto the RFC 6238 time-step counter: Unix seconds
// divided by the period in seconds.
//
// Timestamps before the Unix epoch are clamped to 0, so the signed-to-unsigned
// conversion cannot wrap. The period is truncated to whole seconds. If that
// leaves a non-positive value (a non-positive or sub-second period),
// DefaultPeriod is used instead.
func unixCounter(ts time.Time, period time.Duration) uint64 {
	step := int64(period.Seconds())
	if step <= 0 {
		step = int64(DefaultPeriod.Seconds())
	}
	secs := max(ts.Unix(), 0)
	return uint64(secs / step) //nolint:gosec // secs >= 0 and step > 0, so the quotient is non-negative.
}