248 lines
6.7 KiB
Go
248 lines
6.7 KiB
Go
|
|
// Package totp implements RFC 6238 Time-based One-Time Password generation
// and verification helpers used by the member step-up MFA flow.
//
// Algorithm constraints (compatible with Google Authenticator, Authy,
// 1Password, and Microsoft Authenticator):
//   - HMAC-SHA1
//   - 30-second period
//   - 6 digits
//
// Callers may verify across a small window (typically plus or minus one step)
// to tolerate device clock drift.
|
||
|
|
package totp
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/hmac"
|
||
|
|
"crypto/rand"
|
||
|
|
"crypto/sha1" //nolint:gosec // RFC 6238 requires HMAC-SHA1
|
||
|
|
"crypto/subtle"
|
||
|
|
"encoding/base32"
|
||
|
|
"encoding/binary"
|
||
|
|
"encoding/hex"
|
||
|
|
"fmt"
|
||
|
|
"net/url"
|
||
|
|
"strings"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Defaults mirror the TOTP block of etc/gateway.yaml, as documented in
// identity-member-design.md section 5.8.
const (
	// DefaultDigits is the default length of a generated code.
	DefaultDigits = 6
	// DefaultPeriod is the default length of one TOTP time step; it is used
	// whenever a caller passes a non-positive period.
	DefaultPeriod = time.Second * 30
	// DefaultWindow is the default number of time steps tolerated on either
	// side of the current one when verifying.
	DefaultWindow = 1
	// SecretBytes is the size of a freshly generated secret (160 bits).
	SecretBytes = 20
	// BackupCodeSize is the default backup-code length in hex characters.
	BackupCodeSize = 12
	// BackupCodeNum is the default number of backup codes generated at once.
	BackupCodeNum = 10
)
|
||
|
|
|
||
|
|
// Sentinel errors returned by this package. Callers can match them with
// errors.Is to tell failure modes apart.
var (
	// ErrInvalidSecret reports a missing, empty, or undecodable secret.
	ErrInvalidSecret = fmt.Errorf("totp: invalid secret")
	// ErrInvalidCode describes a rejected one-time code. Verify itself
	// signals failure through its bool result instead of an error.
	ErrInvalidCode = fmt.Errorf("totp: invalid code")
	// ErrInvalidDigits reports a code length outside the supported range.
	ErrInvalidDigits = fmt.Errorf("totp: invalid digit length")
	// ErrInvalidIssuer reports an empty issuer when building an otpauth URL.
	ErrInvalidIssuer = fmt.Errorf("totp: issuer must be non-empty")
	// ErrInvalidAccount reports an empty account when building an otpauth URL.
	ErrInvalidAccount = fmt.Errorf("totp: account must be non-empty")
)
|
||
|
|
|
||
|
|
// GenerateSecret returns a random 160-bit secret suitable for RFC 4226 / 6238.
|
||
|
|
func GenerateSecret() ([]byte, error) {
|
||
|
|
buf := make([]byte, SecretBytes)
|
||
|
|
if _, err := rand.Read(buf); err != nil {
|
||
|
|
return nil, fmt.Errorf("totp: read random: %w", err)
|
||
|
|
}
|
||
|
|
return buf, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// EncodeSecret renders secret as upper-case RFC 4648 base32 with the '='
// padding removed, which is the form embedded in otpauth URLs.
func EncodeSecret(secret []byte) string {
	encoded := base32.StdEncoding.EncodeToString(secret)
	// '=' is not part of the base32 alphabet, so its first occurrence marks
	// the start of the trailing padding.
	if pad := strings.IndexByte(encoded, '='); pad >= 0 {
		encoded = encoded[:pad]
	}
	return encoded
}
|
||
|
|
|
||
|
|
// DecodeSecret parses the base32 representation back to raw bytes. Padding is
|
||
|
|
// optional so the function accepts the canonical form used in otpauth URLs.
|
||
|
|
func DecodeSecret(encoded string) ([]byte, error) {
|
||
|
|
if encoded == "" {
|
||
|
|
return nil, ErrInvalidSecret
|
||
|
|
}
|
||
|
|
clean := strings.ToUpper(strings.TrimSpace(encoded))
|
||
|
|
clean = strings.TrimRight(clean, "=")
|
||
|
|
if mod := len(clean) % 8; mod != 0 {
|
||
|
|
clean += strings.Repeat("=", 8-mod)
|
||
|
|
}
|
||
|
|
out, err := base32.StdEncoding.DecodeString(clean)
|
||
|
|
if err != nil {
|
||
|
|
return nil, fmt.Errorf("%w: %w", ErrInvalidSecret, err)
|
||
|
|
}
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Generate returns the TOTP code for the given timestamp.
|
||
|
|
func Generate(secret []byte, ts time.Time, period time.Duration, digits int) (string, error) {
|
||
|
|
if len(secret) == 0 {
|
||
|
|
return "", ErrInvalidSecret
|
||
|
|
}
|
||
|
|
if digits <= 0 || digits > 10 {
|
||
|
|
return "", ErrInvalidDigits
|
||
|
|
}
|
||
|
|
if period <= 0 {
|
||
|
|
period = DefaultPeriod
|
||
|
|
}
|
||
|
|
return computeHOTP(secret, unixCounter(ts, period), digits), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// Verify checks the supplied code against the secret allowing for plus or
|
||
|
|
// minus window time steps. When the code is valid it returns the timestep
|
||
|
|
// counter that matched so callers can persist it for replay protection.
|
||
|
|
func Verify(secret []byte, code string, ts time.Time, period time.Duration, digits, window int) (uint64, bool) {
|
||
|
|
if len(secret) == 0 || code == "" || digits <= 0 {
|
||
|
|
return 0, false
|
||
|
|
}
|
||
|
|
if len(code) != digits {
|
||
|
|
return 0, false
|
||
|
|
}
|
||
|
|
if period <= 0 {
|
||
|
|
period = DefaultPeriod
|
||
|
|
}
|
||
|
|
if window < 0 {
|
||
|
|
window = 0
|
||
|
|
}
|
||
|
|
base := unixCounter(ts, period)
|
||
|
|
want := []byte(code)
|
||
|
|
for i := -window; i <= window; i++ {
|
||
|
|
counter, ok := shiftCounter(base, i)
|
||
|
|
if !ok {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
got := []byte(computeHOTP(secret, counter, digits))
|
||
|
|
if subtle.ConstantTimeCompare(want, got) == 1 {
|
||
|
|
return counter, true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return 0, false
|
||
|
|
}
|
||
|
|
|
||
|
|
// shiftCounter applies a signed offset to base. If the result would fall
// below 0 or exceed the uint64 range, it returns (0, false) instead of
// wrapping around.
func shiftCounter(base uint64, offset int) (uint64, bool) {
	if offset >= 0 {
		sum := base + uint64(offset)
		// Unsigned addition wraps on overflow, which leaves sum below base.
		if sum < base {
			return 0, false
		}
		return sum, true
	}
	// Compute |offset| as -(offset+1)+1. Negating offset directly would
	// overflow for the minimum int.
	neg := uint64(-(offset + 1)) + 1 //nolint:gosec // -(offset+1) is in [0, MaxInt], so the conversion is exact.
	if base < neg {
		return 0, false
	}
	return base - neg, true
}
|
||
|
|
|
||
|
|
func computeHOTP(secret []byte, counter uint64, digits int) string {
|
||
|
|
var buf [8]byte
|
||
|
|
binary.BigEndian.PutUint64(buf[:], counter)
|
||
|
|
mac := hmac.New(sha1.New, secret)
|
||
|
|
mac.Write(buf[:])
|
||
|
|
sum := mac.Sum(nil)
|
||
|
|
offset := sum[len(sum)-1] & 0x0F
|
||
|
|
bin := (uint32(sum[offset]&0x7F) << 24) |
|
||
|
|
(uint32(sum[offset+1]) << 16) |
|
||
|
|
(uint32(sum[offset+2]) << 8) |
|
||
|
|
uint32(sum[offset+3])
|
||
|
|
mod := uint32(1)
|
||
|
|
for range digits {
|
||
|
|
mod *= 10
|
||
|
|
}
|
||
|
|
num := bin % mod
|
||
|
|
return fmt.Sprintf("%0*d", digits, num)
|
||
|
|
}
|
||
|
|
|
||
|
|
// OtpauthURLInput controls the otpauth URL fields shown to users.
// It is consumed by BuildOtpauthURL.
type OtpauthURLInput struct {
	// Issuer names the service. It is used both as the label prefix and as
	// the issuer= query parameter. Required (must be non-empty).
	Issuer string
	// Account identifies the user and forms the label suffix.
	// Required (must be non-empty).
	Account string
	// Secret is the raw shared key; it is base32-encoded into the URL.
	// Required (must be non-empty).
	Secret []byte
	// Algorithm is the advertised HMAC algorithm; empty means "SHA1". Codes
	// computed by this package always use HMAC-SHA1.
	Algorithm string
	// Digits is the advertised code length; non-positive means DefaultDigits.
	Digits int
	// Period is the advertised time-step length; non-positive means
	// DefaultPeriod.
	Period time.Duration
}
|
||
|
|
|
||
|
|
// BuildOtpauthURL renders the canonical otpauth://totp/... URL.
|
||
|
|
func BuildOtpauthURL(in OtpauthURLInput) (string, error) {
|
||
|
|
if in.Issuer == "" {
|
||
|
|
return "", ErrInvalidIssuer
|
||
|
|
}
|
||
|
|
if in.Account == "" {
|
||
|
|
return "", ErrInvalidAccount
|
||
|
|
}
|
||
|
|
if len(in.Secret) == 0 {
|
||
|
|
return "", ErrInvalidSecret
|
||
|
|
}
|
||
|
|
algo := in.Algorithm
|
||
|
|
if algo == "" {
|
||
|
|
algo = "SHA1"
|
||
|
|
}
|
||
|
|
digits := in.Digits
|
||
|
|
if digits <= 0 {
|
||
|
|
digits = DefaultDigits
|
||
|
|
}
|
||
|
|
period := in.Period
|
||
|
|
if period <= 0 {
|
||
|
|
period = DefaultPeriod
|
||
|
|
}
|
||
|
|
|
||
|
|
q := url.Values{}
|
||
|
|
q.Set("secret", EncodeSecret(in.Secret))
|
||
|
|
q.Set("issuer", in.Issuer)
|
||
|
|
q.Set("algorithm", algo)
|
||
|
|
q.Set("digits", fmt.Sprintf("%d", digits))
|
||
|
|
q.Set("period", fmt.Sprintf("%d", int(period.Seconds())))
|
||
|
|
|
||
|
|
label := url.PathEscape(in.Issuer) + ":" + url.PathEscape(in.Account)
|
||
|
|
return "otpauth://totp/" + label + "?" + q.Encode(), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// GenerateBackupCodes returns n random hex-encoded codes of the given length.
|
||
|
|
// They are returned in plaintext; callers must hash before storage.
|
||
|
|
func GenerateBackupCodes(n, length int) ([]string, error) {
|
||
|
|
if n <= 0 {
|
||
|
|
n = BackupCodeNum
|
||
|
|
}
|
||
|
|
if length <= 0 {
|
||
|
|
length = BackupCodeSize
|
||
|
|
}
|
||
|
|
if length%2 != 0 {
|
||
|
|
length++
|
||
|
|
}
|
||
|
|
out := make([]string, 0, n)
|
||
|
|
buf := make([]byte, length/2)
|
||
|
|
for range n {
|
||
|
|
if _, err := rand.Read(buf); err != nil {
|
||
|
|
return nil, fmt.Errorf("totp: backup code rand: %w", err)
|
||
|
|
}
|
||
|
|
out = append(out, hex.EncodeToString(buf))
|
||
|
|
}
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// TimeStep returns the integer time-step counter used by Verify, useful for
|
||
|
|
// replay-protection keys.
|
||
|
|
func TimeStep(ts time.Time, period time.Duration) uint64 {
|
||
|
|
if period <= 0 {
|
||
|
|
period = DefaultPeriod
|
||
|
|
}
|
||
|
|
return unixCounter(ts, period)
|
||
|
|
}
|
||
|
|
|
||
|
|
// unixCounter converts (ts, period) to the RFC 6238 counter, clamping
|
||
|
|
// negative timestamps to zero so the int64→uint64 conversion is bounded.
|
||
|
|
func unixCounter(ts time.Time, period time.Duration) uint64 {
|
||
|
|
sec := ts.Unix()
|
||
|
|
if sec < 0 {
|
||
|
|
sec = 0
|
||
|
|
}
|
||
|
|
per := int64(period.Seconds())
|
||
|
|
if per <= 0 {
|
||
|
|
per = int64(DefaultPeriod.Seconds())
|
||
|
|
}
|
||
|
|
return uint64(sec / per) //nolint:gosec // sec is non-negative and per is positive; quotient fits uint64.
|
||
|
|
}
|