guard/internal/logic/tokenservice/refresh_token_logic.go

134 lines
3.6 KiB
Go
Raw Normal View History

2024-08-10 01:52:23 +00:00
package tokenservicelogic
import (
"ark-permission/internal/domain"
"ark-permission/internal/entity"
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"time"
"ark-permission/gen_result/pb/permission"
"ark-permission/internal/svc"
"github.com/zeromicro/go-zero/core/logx"
)
// RefreshTokenLogic carries the per-request state used to serve a token
// refresh: the request context, the shared service dependencies, and an
// embedded logger.
type RefreshTokenLogic struct {
	ctx    context.Context     // request context, passed to repository calls
	svcCtx *svc.ServiceContext // shared dependencies (config, validator, token repository)
	logx.Logger                // embedded so logging methods can be called on the logic value
}
// NewRefreshTokenLogic returns a RefreshTokenLogic for a single request.
// The embedded logger is bound to ctx via logx.WithContext.
func NewRefreshTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RefreshTokenLogic {
	logic := &RefreshTokenLogic{
		Logger: logx.WithContext(ctx),
		svcCtx: svcCtx,
		ctx:    ctx,
	}
	return logic
}
// refreshReq mirrors the fields of a RefreshTokenReq that are checked by the
// service validator before a token rotation is attempted.
type refreshReq struct {
	// RefreshToken is the refresh (one-time) token presented by the client.
	RefreshToken string `json:"refresh_token" validate:"required"`
	DeviceID     string `json:"device_id" validate:"required"`
	Scope        string `json:"scope" validate:"required"`
	// Expires is optional: RefreshToken substitutes the configured access-token
	// lifetime when it is <= 0, so a zero (omitted) value must not be rejected
	// here. A "required" rule would make that fallback unreachable.
	Expires int64 `json:"expires"`
}
// RefreshToken rotates a token pair. It looks up the stored token record that
// owns in.Token (the refresh / one-time token), issues a new access token and
// refresh token under the same token ID and UID, deletes the old record and
// stores the new one.
//
// Expiry values are Unix timestamps in seconds:
//   - the refresh token expires Config.Token.RefreshExpires after now;
//   - the access token expires at in.Expires, or Config.Token.Expired after
//     now when in.Expires <= 0.
//
// Validation failures return ers.InvalidFormat. Errors from the token
// repository or from access-token generation are logged and returned as-is.
func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permission.RefreshTokenResp, error) {
	// Validate the request before touching storage.
	if err := l.svcCtx.Validate.ValidateAll(&refreshReq{
		RefreshToken: in.GetToken(),
		Scope:        in.GetScope(),
		DeviceID:     in.GetDeviceId(),
		Expires:      in.GetExpires(),
	}); err != nil {
		return nil, ers.InvalidFormat(err.Error())
	}

	// Step 1: resolve the stored token record that owns this refresh token.
	token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.GetToken())
	if err != nil {
		l.Logger.WithCallerSkip(1).WithFields(
			logx.Field("func", "TokenRedisRepo.GetByRefresh"),
			// The refresh token is a credential, so only non-secret request
			// attributes are logged.
			logx.Field("device_id", in.GetDeviceId()),
			logx.Field("scope", in.GetScope()),
		).Error(err.Error())
		return nil, err
	}

	// Step 2: compute the new expiry timestamps.
	now := time.Now().UTC()
	// unixAfter converts a lifetime in seconds to an absolute Unix timestamp.
	// The float-to-Duration conversion drops fractional seconds.
	unixAfter := func(seconds float64) int {
		return int(now.Add(time.Duration(seconds) * time.Second).Unix())
	}
	refreshExpires := unixAfter(l.svcCtx.Config.Token.RefreshExpires.Seconds())

	// NOTE(review): a positive in.Expires is used verbatim as the access-token
	// expiry, so the caller picks its own lifetime (possibly beyond the refresh
	// token's). Consider clamping it to a server-side maximum.
	expires := int(in.GetExpires())
	if expires <= 0 {
		expires = unixAfter(l.svcCtx.Config.Token.Expired.Seconds())
	}

	newToken := entity.Token{
		ID:               token.ID,
		UID:              token.UID,
		DeviceID:         in.GetDeviceId(),
		ExpiresIn:        expires,
		RefreshExpiresIn: refreshExpires,
		AccessCreateAt:   now,
		RefreshCreateAt:  now,
	}

	// Step 3: build the access-token claims and sign the new token pair.
	tokenClaims := claims(map[string]string{
		"uid": token.UID,
	})
	tokenClaims.SetRole(domain.DefaultRole)
	tokenClaims.SetID(token.ID)
	tokenClaims.SetScope(in.GetScope())
	if in.GetDeviceId() != "" {
		tokenClaims.SetDeviceID(in.GetDeviceId())
	}

	newToken.AccessToken, err = generateAccessTokenFunc(newToken, tokenClaims, l.svcCtx.Config.Token.Secret)
	if err != nil {
		l.Logger.WithCallerSkip(1).WithFields(
			logx.Field("func", "generateAccessTokenFunc"),
			logx.Field("claims", tokenClaims),
		).Error(err.Error())
		return nil, err
	}
	newToken.RefreshToken = generateRefreshTokenFunc(newToken.AccessToken)

	// Step 4: replace the stored record. The old record is deleted first. The
	// new record reuses token.ID, so deleting afterwards could presumably
	// remove the new record too (depends on how Delete keys entries; verify).
	// NOTE(review): Delete and Create are not atomic. If Create fails after a
	// successful Delete, the user is left without a valid token.
	if err = l.svcCtx.TokenRedisRepo.Delete(l.ctx, token); err != nil {
		l.Logger.WithCallerSkip(1).WithFields(
			logx.Field("func", "TokenRedisRepo.Delete"),
			// Log identifiers only; the record also holds the token secrets.
			logx.Field("token_id", token.ID),
			logx.Field("uid", token.UID),
		).Error(err.Error())
		return nil, err
	}

	if err = l.svcCtx.TokenRedisRepo.Create(l.ctx, newToken); err != nil {
		l.Logger.WithCallerSkip(1).WithFields(
			logx.Field("func", "TokenRedisRepo.Create"),
			// Identify the record that failed to be written (the new one), not
			// the old one. Secrets are again left out.
			logx.Field("token_id", newToken.ID),
			logx.Field("uid", newToken.UID),
			logx.Field("device_id", newToken.DeviceID),
		).Error(err.Error())
		return nil, err
	}

	return &permission.RefreshTokenResp{
		Token:        newToken.AccessToken,
		OneTimeToken: newToken.RefreshToken,
		ExpiresIn:    int64(expires),
		TokenType:    domain.TokenTypeBearer,
	}, nil
}