138 lines
3.6 KiB
Go
138 lines
3.6 KiB
Go
package logic
|
||
|
||
import (
|
||
"ark-permission/gen_result/pb/permission"
|
||
"ark-permission/internal/domain"
|
||
"ark-permission/internal/entity"
|
||
ers "ark-permission/internal/lib/error"
|
||
"ark-permission/internal/svc"
|
||
"bytes"
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"fmt"
|
||
"github.com/golang-jwt/jwt/v4"
|
||
"github.com/google/uuid"
|
||
"time"
|
||
|
||
"github.com/zeromicro/go-zero/core/logx"
|
||
)
|
||
|
||
type NewTokenLogic struct {
|
||
ctx context.Context
|
||
svcCtx *svc.ServiceContext
|
||
logx.Logger
|
||
}
|
||
|
||
// NewNewTokenLogic builds a NewTokenLogic bound to ctx and svcCtx, with a
// logger derived from ctx.
func NewNewTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *NewTokenLogic {
	logic := new(NewTokenLogic)
	logic.ctx = ctx
	logic.svcCtx = svcCtx
	logic.Logger = logx.WithContext(ctx)
	return logic
}
|
||
|
||
// https://datatracker.ietf.org/doc/html/rfc6749#section-3.3
|
||
type authorizationReq struct {
|
||
GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"`
|
||
DeviceID string `json:"device_id"`
|
||
Scope string `json:"scope" validate:"required"`
|
||
Data map[string]any `json:"data"`
|
||
Expires int `json:"expires"`
|
||
IsRefreshToken bool `json:"is_refresh_token"`
|
||
}
|
||
|
||
// NewToken 建立一個新的 Token,例如:AccessToken
|
||
func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) {
|
||
// 驗證所需
|
||
if err := l.svcCtx.Validate.ValidateAll(&authorizationReq{
|
||
GrantType: domain.GrantType(in.GetGrantType()),
|
||
Scope: in.GetScope(),
|
||
}); err != nil {
|
||
return nil, ers.InvalidFormat(err.Error())
|
||
}
|
||
|
||
// 準備建立 Token 所需
|
||
now := time.Now().UTC()
|
||
expires := int(in.GetExpires())
|
||
refreshExpires := int(in.GetExpires())
|
||
if expires <= 0 {
|
||
expires = int(l.svcCtx.Config.Token.Expired.Seconds())
|
||
refreshExpires = expires
|
||
}
|
||
|
||
// 如果這是一個 Refresh Token 過期時間要比普通的Token 長
|
||
if in.GetIsRefreshToken() {
|
||
refreshExpires = int(l.svcCtx.Config.Token.RefreshExpires.Seconds())
|
||
}
|
||
|
||
token := entity.Token{
|
||
ID: uuid.Must(uuid.NewRandom()).String(),
|
||
DeviceID: in.GetDeviceId(),
|
||
ExpiresIn: expires,
|
||
RefreshExpiresIn: refreshExpires,
|
||
AccessCreateAt: now,
|
||
RefreshCreateAt: now,
|
||
}
|
||
|
||
claims := claims(in.GetData())
|
||
claims.SetRole(domain.DefaultRole)
|
||
claims.SetID(token.ID)
|
||
claims.SetScope(in.GetScope())
|
||
|
||
token.UID = claims.UID()
|
||
|
||
if in.GetDeviceId() != "" {
|
||
claims.SetDeviceID(in.GetDeviceId())
|
||
}
|
||
|
||
var err error
|
||
token.AccessToken, err = generateAccessToken(token, claims, l.svcCtx.Config.Token.Secret)
|
||
if err != nil {
|
||
return nil, ers.ArkInternal(fmt.Errorf("accessGenerate token error: %w", err).Error())
|
||
}
|
||
|
||
if in.GetIsRefreshToken() {
|
||
token.RefreshToken = generateRefreshToken(token.AccessToken)
|
||
}
|
||
|
||
err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token)
|
||
if err != nil {
|
||
return nil, ers.ArkInternal(fmt.Errorf("tokenRepository.Create error: %w", err).Error())
|
||
}
|
||
|
||
return &permission.TokenResp{
|
||
AccessToken: token.AccessToken,
|
||
TokenType: domain.TokenTypeBearer,
|
||
ExpiresIn: int32(token.ExpiresIn),
|
||
RefreshToken: token.RefreshToken,
|
||
}, nil
|
||
}
|
||
|
||
// generateAccessToken signs an HS256 JWT for token using sign as the HMAC secret.
//
// The registered claims carry:
//   - token.ID as the JWT ID;
//   - "permission" as the issuer;
//   - an expiry token.ExpiresIn seconds after token.AccessCreateAt, or after the
//     current time when AccessCreateAt is unset.
//
// data is embedded as the custom Data claim. Any signing error is returned
// unchanged together with an empty string.
func generateAccessToken(token entity.Token, data any, sign string) (string, error) {
	issuedAt := token.AccessCreateAt
	if issuedAt.IsZero() {
		// Callers that do not stamp a creation time still get a future expiry.
		issuedAt = time.Now().UTC()
	}
	// ExpiresIn is a lifetime in seconds, not a Unix timestamp, so it must be
	// offset from the creation time; using it as an absolute time puts "exp" in 1970.
	expiresAt := issuedAt.Add(time.Duration(token.ExpiresIn) * time.Second)

	claim := entity.Claims{
		Data: data,
		RegisteredClaims: jwt.RegisteredClaims{
			ID:        token.ID,
			ExpiresAt: jwt.NewNumericDate(expiresAt),
			Issuer:    "permission",
		},
	}

	accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim).
		SignedString([]byte(sign))
	if err != nil {
		return "", err
	}

	return accessToken, nil
}
|
||
|
||
// generateRefreshToken returns the lowercase hex SHA-256 digest of accessToken.
//
// NOTE(review): the refresh token is a deterministic function of the access
// token, so anyone holding the access token can compute it — confirm this is
// intended rather than using an independent random value.
func generateRefreshToken(accessToken string) string {
	digest := sha256.New()
	// Writing into a hash.Hash never fails, so the byte count and error are discarded.
	_, _ = bytes.NewBufferString(accessToken).WriteTo(digest)
	return hex.EncodeToString(digest.Sum(nil))
}
|