package logic import ( "ark-permission/gen_result/pb/permission" "ark-permission/internal/domain" "ark-permission/internal/entity" ers "ark-permission/internal/lib/error" "ark-permission/internal/svc" "bytes" "context" "crypto/sha256" "encoding/hex" "fmt" "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "time" "github.com/zeromicro/go-zero/core/logx" ) type NewTokenLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } func NewNewTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *NewTokenLogic { return &NewTokenLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), } } // https://datatracker.ietf.org/doc/html/rfc6749#section-3.3 type authorizationReq struct { GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"` DeviceID string `json:"device_id"` Scope string `json:"scope" validate:"required"` Data map[string]any `json:"data"` Expires int `json:"expires"` IsRefreshToken bool `json:"is_refresh_token"` } // NewToken 建立一個新的 Token,例如:AccessToken func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) { // 驗證所需 if err := l.svcCtx.Validate.ValidateAll(&authorizationReq{ GrantType: domain.GrantType(in.GetGrantType()), Scope: in.GetScope(), }); err != nil { return nil, ers.InvalidFormat(err.Error()) } // 準備建立 Token 所需 now := time.Now().UTC() expires := int(in.GetExpires()) refreshExpires := int(in.GetExpires()) if expires <= 0 { expires = int(l.svcCtx.Config.Token.Expired.Seconds()) refreshExpires = expires } // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 if in.GetIsRefreshToken() { refreshExpires = int(l.svcCtx.Config.Token.RefreshExpires.Seconds()) } token := entity.Token{ ID: uuid.Must(uuid.NewRandom()).String(), DeviceID: in.GetDeviceId(), ExpiresIn: expires, RefreshExpiresIn: refreshExpires, AccessCreateAt: now, RefreshCreateAt: now, } claims := claims(in.GetData()) claims.SetRole(domain.DefaultRole) claims.SetID(token.ID) claims.SetScope(in.GetScope()) token.UID = claims.UID() if in.GetDeviceId() != "" { claims.SetDeviceID(in.GetDeviceId()) } var err error token.AccessToken, err = generateAccessToken(token, claims, l.svcCtx.Config.Token.Secret) if err != nil { return nil, ers.ArkInternal(fmt.Errorf("accessGenerate token error: %w", err).Error()) } if in.GetIsRefreshToken() { token.RefreshToken = generateRefreshToken(token.AccessToken) } err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token) if err != nil { return nil, ers.ArkInternal(fmt.Errorf("tokenRepository.Create error: %w", err).Error()) } return &permission.TokenResp{ AccessToken: token.AccessToken, TokenType: domain.TokenTypeBearer, ExpiresIn: int32(token.ExpiresIn), RefreshToken: token.RefreshToken, }, nil } func generateAccessToken(token entity.Token, data any, sign string) (string, error) { claim := entity.Claims{ Data: data, RegisteredClaims: jwt.RegisteredClaims{ ID: token.ID, ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)), Issuer: "permission", }, } accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim). SignedString([]byte(sign)) if err != nil { return "", err } return accessToken, nil } func generateRefreshToken(accessToken string) string { buf := bytes.NewBufferString(accessToken) h := sha256.New() _, _ = h.Write(buf.Bytes()) return hex.EncodeToString(h.Sum(nil)) }