78 lines
2.5 KiB
Go
78 lines
2.5 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
|
||
|
|
"net/http"
|
||
|
|
"strings"
|
||
|
|
|
||
|
|
"gateway/internal/library/actor"
|
||
|
|
errs "gateway/internal/library/errors"
|
||
|
|
"gateway/internal/library/errors/code"
|
||
|
|
domauth "gateway/internal/model/auth/domain/usecase"
|
||
|
|
"gateway/internal/response"
|
||
|
|
)
|
||
|
|
|
||
|
|
// AuthJWTMiddleware enforces Bearer access tokens on protected routes.
|
||
|
|
//
|
||
|
|
// Mounted via @server(middleware: AuthJWT) in the .api file. Missing or
|
||
|
|
// invalid tokens return 28501000 (Auth scope, unauthorized) — public
|
||
|
|
// routes (register / login / token refresh / health) must NOT mount
|
||
|
|
// this middleware.
|
||
|
|
//
|
||
|
|
// On success the parsed (tenant, uid) is injected via library/actor
|
||
|
|
// so downstream logic can read it through actor.ActorFromContext (or
|
||
|
|
// the package-local member.ActorFromContext / permission.ActorFromContext
|
||
|
|
// aliases).
|
||
|
|
//
|
||
|
|
// File name follows goctl's stringx convention (`authjwt_middleware.go`)
|
||
|
|
// so `make gen-api` sees it as already-generated and never overwrites
|
||
|
|
// the implementation.
|
||
|
|
type AuthJWTMiddleware struct {
|
||
|
|
tokens domauth.TokenUseCase
|
||
|
|
}
|
||
|
|
|
||
|
|
// NewAuthJWTMiddleware wires the middleware with the auth module's
|
||
|
|
// TokenUseCase (set up in ServiceContext.NewServiceContext).
|
||
|
|
func NewAuthJWTMiddleware(tokens domauth.TokenUseCase) *AuthJWTMiddleware {
|
||
|
|
return &AuthJWTMiddleware{tokens: tokens}
|
||
|
|
}
|
||
|
|
|
||
|
|
// Handle implements the go-zero rest.Middleware signature.
|
||
|
|
func (m *AuthJWTMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||
|
|
bld := errs.For(code.Auth)
|
||
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
||
|
|
if m.tokens == nil {
|
||
|
|
response.Write(r.Context(), w, nil,
|
||
|
|
bld.SysNotImplemented("auth middleware: token usecase not configured"))
|
||
|
|
return
|
||
|
|
}
|
||
|
|
raw := bearerToken(r.Header.Get("Authorization"))
|
||
|
|
if raw == "" {
|
||
|
|
response.Write(r.Context(), w, nil,
|
||
|
|
bld.AuthUnauthorized("missing bearer token"))
|
||
|
|
return
|
||
|
|
}
|
||
|
|
claims, err := m.tokens.ParseAccessToken(r.Context(), raw)
|
||
|
|
if err != nil {
|
||
|
|
// surface already-typed Auth errors as-is so the biz code
|
||
|
|
// (e.g. expired vs invalid) is preserved.
|
||
|
|
if e := errs.FromError(err); e != nil && e.Category() == code.AuthUnauthorized {
|
||
|
|
response.Write(r.Context(), w, nil, err)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
response.Write(r.Context(), w, nil,
|
||
|
|
bld.AuthUnauthorized("invalid bearer token").WithCause(err))
|
||
|
|
return
|
||
|
|
}
|
||
|
|
ctx := actor.WithActor(r.Context(), claims.TenantID, claims.UID)
|
||
|
|
next(w, r.WithContext(ctx))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// bearerToken extracts the credential from an Authorization header value of
// the form "Bearer <token>". It returns "" when the header is empty, uses a
// different scheme, or carries no token after the scheme.
//
// The scheme name is matched case-insensitively, as HTTP authentication
// schemes are case-insensitive (RFC 7235 §2.1). "bearer x" and "BEARER x"
// are therefore accepted as well as "Bearer x". Surrounding whitespace
// around the token is trimmed.
func bearerToken(header string) string {
	const prefix = "Bearer "
	if len(header) < len(prefix) || !strings.EqualFold(header[:len(prefix)], prefix) {
		return ""
	}
	return strings.TrimSpace(header[len(prefix):])
}
|