From 2d485cf0951414e4c51ec8719c759c87106cba08 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Sun, 11 Aug 2024 20:21:42 +0800 Subject: [PATCH] fix: complete token service --- client/tokenservice/token_service.go | 32 +- generate/protobuf/permission.proto | 2 +- internal/domain/redis.go | 1 + internal/domain/repository/token.go | 4 +- internal/entity/token.go | 12 + .../cancel_token_by_device_id_logic.go | 15 +- .../tokenservice/cancel_token_by_uid_logic.go | 48 --- .../logic/tokenservice/cancel_token_logic.go | 4 +- .../logic/tokenservice/cancel_tokens_logic.go | 52 +++ .../get_user_tokens_by_device_id_logic.go | 41 +- .../get_user_tokens_by_uid_logic.go | 1 - .../tokenservice/new_one_time_token_logic.go | 7 +- .../logic/tokenservice/new_token_logic.go | 146 +++---- .../logic/tokenservice/refresh_token_logic.go | 81 ++-- internal/logic/tokenservice/utils_jwt.go | 63 +-- .../tokenservice/validation_token_logic.go | 4 +- internal/repository/token.go | 370 ++++++++++-------- .../tokenservice/token_service_server.go | 24 +- 18 files changed, 500 insertions(+), 407 deletions(-) delete mode 100644 internal/logic/tokenservice/cancel_token_by_uid_logic.go create mode 100644 internal/logic/tokenservice/cancel_tokens_logic.go diff --git a/client/tokenservice/token_service.go b/client/tokenservice/token_service.go index 2bc1af2..617dc38 100644 --- a/client/tokenservice/token_service.go +++ b/client/tokenservice/token_service.go @@ -37,12 +37,12 @@ type ( RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) // CancelToken 取消 Token,也包含他裡面的 One Time Toke CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke - CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByDeviceId 取消 Token - CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) // ValidationToken 驗證這個 Token 有沒有效 ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) + // CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 + CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) + // CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token + CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens @@ -82,24 +82,24 @@ func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenRe return client.CancelToken(ctx, in, opts...) } -// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (m *defaultTokenService) CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByUid(ctx, in, opts...) -} - -// CancelTokenByDeviceId 取消 Token -func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByDeviceId(ctx, in, opts...) -} - // ValidationToken 驗證這個 Token 有沒有效 func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) return client.ValidationToken(ctx, in, opts...) } +// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 +func (m *defaultTokenService) CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokens(ctx, in, opts...) +} + +// CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token +func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokenByDeviceId(ctx, in, opts...) +} + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) diff --git a/generate/protobuf/permission.proto b/generate/protobuf/permission.proto index 38bfdc0..e6ed508 100644 --- a/generate/protobuf/permission.proto +++ b/generate/protobuf/permission.proto @@ -114,7 +114,7 @@ message Token { // DoTokenByDeviceIDReq 用DeviceID 來做事的 message DoTokenByDeviceIDReq { - repeated string device_id = 1; + string device_id = 1; } message Tokens{ diff --git a/internal/domain/redis.go b/internal/domain/redis.go index 2fb3b69..2e31a91 100644 --- a/internal/domain/redis.go +++ b/internal/domain/redis.go @@ -18,6 +18,7 @@ const ( DeviceTokenRedisKey RedisKey = "device_token" UIDTokenRedisKey RedisKey = "uid_token" TicketRedisKey RedisKey = "ticket" + DeviceUIDRedisKey RedisKey = "device_uid" ) func (key RedisKey) ToString() string { diff --git a/internal/domain/repository/token.go b/internal/domain/repository/token.go index 57bd4f6..be4c207 100644 --- a/internal/domain/repository/token.go +++ b/internal/domain/repository/token.go @@ -19,11 +19,9 @@ type TokenRepository interface { GetAccessTokenCountByDeviceID(deviceID string) (int, error) Delete(ctx context.Context, token entity.Token) error - DeleteAccessTokenByID(ctx context.Context, id string) error + DeleteAccessTokenByID(ctx context.Context, ids []string) error DeleteAccessTokensByUID(ctx context.Context, uid string) error DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error - DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error - DeleteUIDToken(ctx context.Context, uid string, ids []string) error } type DeviceToken struct { diff --git a/internal/entity/token.go b/internal/entity/token.go index c2cd7fc..791395a 100644 --- a/internal/entity/token.go +++ b/internal/entity/token.go @@ -30,6 +30,18 @@ func (t *Token) IsExpires() bool { return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now()) } +func (t *Token) RedisExpiredSec() int64 { + sec := time.Unix(int64(t.ExpiresIn), 0).Sub(time.Now().UTC()) + + return int64(sec.Seconds()) +} + +func (t *Token) RedisRefreshExpiredSec() int64 { + sec := time.Unix(int64(t.RefreshExpiresIn), 0).Sub(time.Now().UTC()) + + return int64(sec.Seconds()) +} + type UIDToken map[string]int64 type Ticket struct { diff --git a/internal/logic/tokenservice/cancel_token_by_device_id_logic.go b/internal/logic/tokenservice/cancel_token_by_device_id_logic.go index ed59824..e726a26 100644 --- a/internal/logic/tokenservice/cancel_token_by_device_id_logic.go +++ b/internal/logic/tokenservice/cancel_token_by_device_id_logic.go @@ -1,6 +1,7 @@ package tokenservicelogic import ( + ers "code.30cm.net/wanderland/library-go/errors" "context" "ark-permission/gen_result/pb/permission" @@ -25,7 +26,19 @@ func NewCancelTokenByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceConte // CancelTokenByDeviceId 取消 Token func (l *CancelTokenByDeviceIdLogic) CancelTokenByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { - // todo: add your logic here and delete this line + if err := l.svcCtx.Validate.ValidateAll(&getUserTokensByDeviceIdReq{ + DeviceID: in.GetDeviceId(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } + err := l.svcCtx.TokenRedisRepo.DeleteAccessTokensByDeviceID(l.ctx, in.GetDeviceId()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.DeleteAccessTokensByDeviceID"), + logx.Field("DeviceID", in.GetDeviceId()), + ).Error(err.Error()) + return nil, err + } return &permission.OKResp{}, nil } diff --git a/internal/logic/tokenservice/cancel_token_by_uid_logic.go b/internal/logic/tokenservice/cancel_token_by_uid_logic.go deleted file mode 100644 index f266185..0000000 --- a/internal/logic/tokenservice/cancel_token_by_uid_logic.go +++ /dev/null @@ -1,48 +0,0 @@ -package tokenservicelogic - -import ( - ers "code.30cm.net/wanderland/library-go/errors" - "context" - - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" - - "github.com/zeromicro/go-zero/core/logx" -) - -type CancelTokenByUidLogic struct { - ctx context.Context - svcCtx *svc.ServiceContext - logx.Logger -} - -func NewCancelTokenByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByUidLogic { - return &CancelTokenByUidLogic{ - ctx: ctx, - svcCtx: svcCtx, - Logger: logx.WithContext(ctx), - } -} - -type deleteByTokenIDs struct { - UID string `json:"uid" binding:"required"` - IDs []string `json:"ids" binding:"required"` -} - -// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (l *CancelTokenByUidLogic) CancelTokenByUid(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { - // 驗證所需 - if err := l.svcCtx.Validate.ValidateAll(&deleteByTokenIDs{ - UID: in.GetUid(), - IDs: in.GetIds(), - }); err != nil { - return nil, ers.InvalidFormat(err.Error()) - } - - err := l.svcCtx.TokenRedisRepo.DeleteUIDToken(l.ctx, in.GetUid(), in.GetIds()) - if err != nil { - return nil, err - } - - return &permission.OKResp{}, nil -} diff --git a/internal/logic/tokenservice/cancel_token_logic.go b/internal/logic/tokenservice/cancel_token_logic.go index 4f615bd..bf89e59 100644 --- a/internal/logic/tokenservice/cancel_token_logic.go +++ b/internal/logic/tokenservice/cancel_token_logic.go @@ -37,7 +37,7 @@ func (l *CancelTokenLogic) CancelToken(in *permission.CancelTokenReq) (*permissi return nil, ers.InvalidFormat(err.Error()) } - claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, false) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "parseClaims"), @@ -45,7 +45,7 @@ func (l *CancelTokenLogic) CancelToken(in *permission.CancelTokenReq) (*permissi return nil, err } - token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.GetByAccess"), diff --git a/internal/logic/tokenservice/cancel_tokens_logic.go b/internal/logic/tokenservice/cancel_tokens_logic.go new file mode 100644 index 0000000..01fac22 --- /dev/null +++ b/internal/logic/tokenservice/cancel_tokens_logic.go @@ -0,0 +1,52 @@ +package tokenservicelogic + +import ( + ers "code.30cm.net/wanderland/library-go/errors" + "context" + + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" + + "github.com/zeromicro/go-zero/core/logx" +) + +type CancelTokensLogic struct { + ctx context.Context + svcCtx *svc.ServiceContext + logx.Logger +} + +func NewCancelTokensLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokensLogic { + return &CancelTokensLogic{ + ctx: ctx, + svcCtx: svcCtx, + Logger: logx.WithContext(ctx), + } +} + +// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 +func (l *CancelTokensLogic) CancelTokens(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { + if in.GetUid() != "" { + err := l.svcCtx.TokenRedisRepo.DeleteAccessTokensByUID(l.ctx, in.GetUid()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.DeleteAccessTokensByUID"), + logx.Field("uid", in.GetUid()), + ).Error(err.Error()) + return nil, ers.ResourceInsufficient(err.Error()) + } + } + + if len(in.GetIds()) > 0 { + err := l.svcCtx.TokenRedisRepo.DeleteAccessTokenByID(l.ctx, in.GetIds()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.DeleteAccessTokenByID"), + logx.Field("ids", in.GetIds()), + ).Error(err.Error()) + return nil, ers.ResourceInsufficient(err.Error()) + } + } + + return &permission.OKResp{}, nil +} diff --git a/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go b/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go index 93836ea..fbc773d 100644 --- a/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go +++ b/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go @@ -2,7 +2,9 @@ package tokenservicelogic import ( "ark-permission/gen_result/pb/permission" + "ark-permission/internal/domain" "ark-permission/internal/svc" + ers "code.30cm.net/wanderland/library-go/errors" "context" "github.com/zeromicro/go-zero/core/logx" ) @@ -21,23 +23,34 @@ func NewGetUserTokensByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceCon } } +type getUserTokensByDeviceIdReq struct { + DeviceID string `json:"device_id" validate:"required"` +} + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (l *GetUserTokensByDeviceIdLogic) GetUserTokensByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { + if err := l.svcCtx.Validate.ValidateAll(&getUserTokensByDeviceIdReq{ + DeviceID: in.GetDeviceId(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } - // ids, err := l.svcCtx.TokenRedisRepo.GetAccessTokensByDeviceID(l.ctx, "") - // if err != nil { - // return nil, error - // } + uidTokens, err := l.svcCtx.TokenRedisRepo.GetAccessTokensByDeviceID(l.ctx, in.GetDeviceId()) + if err != nil { + return nil, err + } - // tokenIDs := make([]usecase.DeviceToken, 0, len(ids)) - // for _, v := range ids { - // tokenIDs = append(tokenIDs, usecase.DeviceToken{ - // DeviceID: v.DeviceID, - // TokenID: v.TokenID, - // }) - // } - // - // return tokenIDs, nil + tokens := make([]*permission.TokenResp, 0, len(uidTokens)) + for _, v := range uidTokens { + tokens = append(tokens, &permission.TokenResp{ + AccessToken: v.AccessToken, + TokenType: domain.TokenTypeBearer, + ExpiresIn: int32(v.ExpiresIn), + RefreshToken: v.RefreshToken, + }) + } - return &permission.Tokens{}, nil + return &permission.Tokens{ + Token: tokens, + }, nil } diff --git a/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go b/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go index eb3a1f0..9d8616e 100644 --- a/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go +++ b/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go @@ -6,7 +6,6 @@ import ( "ark-permission/internal/svc" ers "code.30cm.net/wanderland/library-go/errors" "context" - "github.com/zeromicro/go-zero/core/logx" ) diff --git a/internal/logic/tokenservice/new_one_time_token_logic.go b/internal/logic/tokenservice/new_one_time_token_logic.go index bcb02d5..79db981 100644 --- a/internal/logic/tokenservice/new_one_time_token_logic.go +++ b/internal/logic/tokenservice/new_one_time_token_logic.go @@ -5,6 +5,7 @@ import ( "ark-permission/internal/entity" ers "code.30cm.net/wanderland/library-go/errors" "context" + "github.com/google/uuid" "time" "ark-permission/gen_result/pb/permission" @@ -37,7 +38,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken } // 驗證Token - claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, false) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "parseClaims"), @@ -45,7 +46,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken return nil, err } - token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.GetByAccess"), @@ -54,7 +55,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken return nil, err } - oneTimeToken := generateRefreshToken(in.GetToken()) + oneTimeToken := generateRefreshToken(uuid.Must(uuid.NewRandom()).String()) key := domain.TicketKeyPrefix + oneTimeToken if err = l.svcCtx.TokenRedisRepo.CreateOneTimeToken(l.ctx, key, entity.Ticket{ Data: claims, diff --git a/internal/logic/tokenservice/new_token_logic.go b/internal/logic/tokenservice/new_token_logic.go index 487e3f9..fdccf8f 100644 --- a/internal/logic/tokenservice/new_token_logic.go +++ b/internal/logic/tokenservice/new_token_logic.go @@ -1,6 +1,7 @@ package tokenservicelogic import ( + "ark-permission/internal/config" "ark-permission/internal/domain" "ark-permission/internal/entity" ers "code.30cm.net/wanderland/library-go/errors" @@ -30,83 +31,34 @@ func NewNewTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *NewToken // https://datatracker.ietf.org/doc/html/rfc6749#section-3.3 type authorizationReq struct { - GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"` - DeviceID string `json:"device_id"` - Scope string `json:"scope" validate:"required"` - Data map[string]any `json:"data"` - Expires int `json:"expires"` - IsRefreshToken bool `json:"is_refresh_token"` + GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"` + DeviceID string `json:"device_id"` + Scope string `json:"scope" validate:"required"` + Data map[string]string `json:"data"` + Expires int `json:"expires"` + IsRefreshToken bool `json:"is_refresh_token"` } // NewToken 建立一個新的 Token,例如:AccessToken func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) { + data := authorizationReq{ + GrantType: domain.GrantType(in.GetGrantType()), + Scope: in.GetScope(), + DeviceID: in.GetDeviceId(), + Data: in.GetData(), + Expires: int(in.GetExpires()), + IsRefreshToken: in.GetIsRefreshToken(), + } // 驗證所需 - if err := l.svcCtx.Validate.ValidateAll(&authorizationReq{ - GrantType: domain.GrantType(in.GetGrantType()), - Scope: in.GetScope(), - }); err != nil { + if err := l.svcCtx.Validate.ValidateAll(&data); err != nil { return nil, ers.InvalidFormat(err.Error()) } - - // 準備建立 Token 所需 - now := time.Now().UTC() - expires := int(in.GetExpires()) - refreshExpires := int(in.GetExpires()) - if expires <= 0 { - // 將時間加上 300 秒 - sec := time.Duration(l.svcCtx.Config.Token.Expired.Seconds()) * time.Second - newTime := now.Add(sec) - // 獲取 Unix 時間戳 - timestamp := newTime.Unix() - expires = int(timestamp) - refreshExpires = expires - } - - // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 - if in.GetIsRefreshToken() { - // 將時間加上 300 秒 - sec := time.Duration(l.svcCtx.Config.Token.RefreshExpires.Seconds()) * time.Second - newTime := now.Add(sec) - // 獲取 Unix 時間戳 - timestamp := newTime.Unix() - refreshExpires = int(timestamp) - } - - token := entity.Token{ - ID: uuid.Must(uuid.NewRandom()).String(), - DeviceID: in.GetDeviceId(), - ExpiresIn: expires, - RefreshExpiresIn: refreshExpires, - AccessCreateAt: now, - RefreshCreateAt: now, - } - - claims := claims(in.GetData()) - claims.SetRole(domain.DefaultRole) - claims.SetID(token.ID) - claims.SetScope(in.GetScope()) - - token.UID = claims.UID() - - if in.GetDeviceId() != "" { - claims.SetDeviceID(in.GetDeviceId()) - } - - var err error - token.AccessToken, err = generateAccessTokenFunc(token, claims, l.svcCtx.Config.Token.Secret) + token, err := newToken(data, l.svcCtx.Config) if err != nil { - logx.WithCallerSkip(1).WithFields( - logx.Field("func", "generateAccessTokenFunc"), - logx.Field("claims", claims), - ).Error(err.Error()) return nil, err } - if in.GetIsRefreshToken() { - token.RefreshToken = generateRefreshTokenFunc(token.AccessToken) - } - - err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token) + err = l.svcCtx.TokenRedisRepo.Create(l.ctx, *token) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.Create"), @@ -122,3 +74,65 @@ func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.T RefreshToken: token.RefreshToken, }, nil } + +func newToken(authReq authorizationReq, cfg config.Config) (*entity.Token, error) { + // 準備建立 Token 所需 + now := time.Now().UTC() + expires := authReq.Expires + refreshExpires := authReq.Expires + if expires <= 0 { + // 將時間加上 300 秒 + sec := time.Duration(cfg.Token.Expired.Seconds()) * time.Second + newTime := now.Add(sec) + // 獲取 Unix 時間戳 + timestamp := newTime.Unix() + expires = int(timestamp) + refreshExpires = expires + } + + // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 + if authReq.IsRefreshToken { + // 將時間加上 300 秒 + sec := time.Duration(cfg.Token.RefreshExpires.Seconds()) * time.Second + newTime := now.Add(sec) + // 獲取 Unix 時間戳 + timestamp := newTime.Unix() + refreshExpires = int(timestamp) + } + + token := entity.Token{ + ID: uuid.Must(uuid.NewRandom()).String(), + DeviceID: authReq.DeviceID, + ExpiresIn: expires, + RefreshExpiresIn: refreshExpires, + AccessCreateAt: now, + RefreshCreateAt: now, + } + + claims := claims(authReq.Data) + claims.SetRole(domain.DefaultRole) + claims.SetID(token.ID) + claims.SetScope(authReq.Scope) + + token.UID = claims.UID() + + if authReq.DeviceID != "" { + claims.SetDeviceID(authReq.DeviceID) + } + + var err error + token.AccessToken, err = generateAccessTokenFunc(token, claims, cfg.Token.Secret) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "generateAccessTokenFunc"), + logx.Field("claims", claims), + ).Error(err.Error()) + return nil, err + } + + if authReq.IsRefreshToken { + token.RefreshToken = generateRefreshTokenFunc(token.AccessToken) + } + + return &token, nil +} diff --git a/internal/logic/tokenservice/refresh_token_logic.go b/internal/logic/tokenservice/refresh_token_logic.go index dee0484..607fd28 100644 --- a/internal/logic/tokenservice/refresh_token_logic.go +++ b/internal/logic/tokenservice/refresh_token_logic.go @@ -1,14 +1,11 @@ package tokenservicelogic import ( + "ark-permission/gen_result/pb/permission" "ark-permission/internal/domain" - "ark-permission/internal/entity" + "ark-permission/internal/svc" ers "code.30cm.net/wanderland/library-go/errors" "context" - "time" - - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" "github.com/zeromicro/go-zero/core/logx" ) @@ -31,7 +28,6 @@ type refreshReq struct { RefreshToken string `json:"grant_type" validate:"required"` DeviceID string `json:"device_id" validate:"required"` Scope string `json:"scope" validate:"required"` - Expires int64 `json:"expires" validate:"required"` } // RefreshToken 更新目前的token 以及裡面包含的一次性 Token @@ -41,10 +37,10 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi RefreshToken: in.GetToken(), Scope: in.GetScope(), DeviceID: in.GetDeviceId(), - Expires: in.GetExpires(), }); err != nil { return nil, ers.InvalidFormat(err.Error()) } + // step 1 拿看看有沒有這個 refresh token token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.Token) if err != nil { @@ -54,56 +50,33 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi ).Error(err.Error()) return nil, err } - // 拿到之後替換掉時間以及 refresh token - // refreshToken 建立 - now := time.Now().UTC() - sec := time.Duration(l.svcCtx.Config.Token.RefreshExpires.Seconds()) * time.Second - newTime := now.Add(sec) - // 獲取 Unix 時間戳 - timestamp := newTime.Unix() - refreshExpires := int(timestamp) - expires := int(in.GetExpires()) - if expires <= 0 { - // 將時間加上 300 秒 - sec := time.Duration(l.svcCtx.Config.Token.Expired.Seconds()) * time.Second - newTime := now.Add(sec) - // 獲取 Unix 時間戳 - timestamp := newTime.Unix() - expires = int(timestamp) - } - newToken := entity.Token{ - ID: token.ID, - UID: token.UID, - DeviceID: in.GetDeviceId(), - ExpiresIn: expires, - RefreshExpiresIn: refreshExpires, - AccessCreateAt: now, - RefreshCreateAt: now, - } - - claims := claims(map[string]string{ - "uid": token.UID, - }) - claims.SetRole(domain.DefaultRole) - claims.SetID(token.ID) - claims.SetScope(in.GetScope()) - claims.UID() - - if in.GetDeviceId() != "" { - claims.SetDeviceID(in.GetDeviceId()) - } - - newToken.AccessToken, err = generateAccessTokenFunc(newToken, claims, l.svcCtx.Config.Token.Secret) + // 取得 Data + c, err := parseClaims(token.AccessToken, l.svcCtx.Config.Token.Secret, false) if err != nil { logx.WithCallerSkip(1).WithFields( - logx.Field("func", "generateAccessTokenFunc"), - logx.Field("claims", claims), + logx.Field("func", "parseClaims"), + logx.Field("token", token), ).Error(err.Error()) return nil, err } - newToken.RefreshToken = generateRefreshTokenFunc(newToken.AccessToken) + // step 2 建立新 token + nt, err := newToken(authorizationReq{ + GrantType: domain.ClientCredentials, + Scope: in.GetScope(), + DeviceID: in.GetDeviceId(), + Data: c, + Expires: int(in.GetExpires()), + IsRefreshToken: true, + }, l.svcCtx.Config) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "newToken"), + logx.Field("req", in), + ).Error(err.Error()) + return nil, err + } // 刪除掉舊的 token err = l.svcCtx.TokenRedisRepo.Delete(l.ctx, token) @@ -115,7 +88,7 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi return nil, err } - err = l.svcCtx.TokenRedisRepo.Create(l.ctx, newToken) + err = l.svcCtx.TokenRedisRepo.Create(l.ctx, *nt) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.Create"), @@ -125,9 +98,9 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi } return &permission.RefreshTokenResp{ - Token: newToken.AccessToken, - OneTimeToken: newToken.RefreshToken, - ExpiresIn: int64(expires), + Token: nt.AccessToken, + OneTimeToken: nt.RefreshToken, + ExpiresIn: int64(nt.ExpiresIn), TokenType: domain.TokenTypeBearer, }, nil } diff --git a/internal/logic/tokenservice/utils_jwt.go b/internal/logic/tokenservice/utils_jwt.go index c6550b9..3ef34fd 100644 --- a/internal/logic/tokenservice/utils_jwt.go +++ b/internal/logic/tokenservice/utils_jwt.go @@ -4,7 +4,6 @@ import ( "ark-permission/internal/domain" "ark-permission/internal/entity" "bytes" - "context" "crypto/sha256" "encoding/hex" "fmt" @@ -42,43 +41,53 @@ func generateRefreshToken(accessToken string) string { return hex.EncodeToString(h.Sum(nil)) } -func parseClaims(ctx context.Context, accessToken string, secret string) (claims, error) { - claimMap, err := parseToken(accessToken, secret) - if err != nil { - return claims{}, err - } +func parseToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) { + // 跳過驗證的解析 + var token *jwt.Token + var err error - claims, ok := claimMap["data"].(map[string]any) - if ok { - - return convertMap(claims), nil - } - - return nil, domain.TokenClaimError("get data from claim map error") -} - -func parseToken(accessToken string, secret string) (jwt.MapClaims, error) { - token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) + if validate { + token, err = jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) + } + return []byte(secret), nil + }) + if err != nil { + return jwt.MapClaims{}, err + } + } else { + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + token, err = parser.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }) + if err != nil { + return jwt.MapClaims{}, err } - - return []byte(secret), nil - }) - - if err != nil { - return jwt.MapClaims{}, err } claims, ok := token.Claims.(jwt.MapClaims) - - if !(ok && token.Valid) { + if !ok && token.Valid { return jwt.MapClaims{}, domain.TokenTokenValidateErr("token valid error") } return claims, nil } +func parseClaims(accessToken string, secret string, validate bool) (claims, error) { + claimMap, err := parseToken(accessToken, secret, validate) + if err != nil { + return claims{}, err + } + + claimsData, ok := claimMap["data"].(map[string]any) + if ok { + return convertMap(claimsData), nil + } + + return claims{}, domain.TokenClaimError("get data from claim map error") +} + func convertMap(input map[string]interface{}) map[string]string { output := make(map[string]string) for key, value := range input { diff --git a/internal/logic/tokenservice/validation_token_logic.go b/internal/logic/tokenservice/validation_token_logic.go index fd00492..f3baf9e 100644 --- a/internal/logic/tokenservice/validation_token_logic.go +++ b/internal/logic/tokenservice/validation_token_logic.go @@ -37,7 +37,7 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq return nil, ers.InvalidFormat(err.Error()) } - claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, true) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "parseClaims"), @@ -45,7 +45,7 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq return nil, err } - token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.GetByAccess"), diff --git a/internal/repository/token.go b/internal/repository/token.go index e434c3b..17a325f 100644 --- a/internal/repository/token.go +++ b/internal/repository/token.go @@ -22,41 +22,6 @@ type tokenRepository struct { store *redis.Redis } -func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, id string) error { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error { - // TODO implement me - panic("implement me") -} - func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { return &tokenRepository{ store: param.Store, @@ -70,17 +35,19 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error } err = t.store.Pipelined(func(tx redis.Pipeliner) error { - rTTL := token.RefreshTokenExpires() + // rTTL := token.RedisExpiredSec() + refreshTTL := token.RedisRefreshExpiredSec() - if err := t.setToken(ctx, tx, token, body, rTTL); err != nil { + if err := t.setToken(ctx, tx, token, body, time.Duration(refreshTTL)*time.Second); err != nil { return err } - if err := t.setRefreshToken(ctx, tx, token, rTTL); err != nil { + if err := t.setRefreshToken(ctx, tx, token, time.Duration(refreshTTL)*time.Second); err != nil { return err } - if err := t.setDeviceToken(ctx, tx, token, rTTL); err != nil { + err := t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, time.Duration(refreshTTL)*time.Second) + if err != nil { return err } @@ -90,40 +57,103 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error return domain.RedisPipLineError(err.Error()) } - if err := t.SetUIDToken(token); err != nil { - return ers.ArkInternal("SetUIDToken error", err.Error()) + return nil +} + +func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { + err := t.store.Pipelined(func(tx redis.Pipeliner) error { + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), + domain.UIDTokenRedisKey.With(token.UID).ToString(), + } + + for _, key := range keys { + if err := tx.Del(ctx, key).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } + } + + if token.DeviceID != "" { + key := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString() + _, err := t.store.Del(key) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) + } + } + + return nil + }) + + if err != nil { + return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) } return nil } -// // GetAccessTokensByDeviceID 透過 Device ID 得到目前未過期的token -// func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, uid string) ([]repository.DeviceToken, error) { -// data, err := t.store.Hgetall(domain.DeviceTokenRedisKey.With(uid).ToString()) -// if err != nil { -// if errors.Is(err, redis.Nil) { -// return nil, nil -// } -// -// return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.HGetAll Device Token error: %v", err.Error())) -// } -// -// ids := make([]repository.DeviceToken, 0, len(data)) -// for deviceID, id := range data { -// ids = append(ids, repository.DeviceToken{ -// DeviceID: deviceID, -// -// // e0a4f824-41db-4eb2-8e5a-d96966ea1d56-1698083859 -// // -11是因為id組成最後11位數是-跟時間戳記 -// TokenID: id[:len(id)-11], -// }) -// } -// return ids, nil -// } +func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { + return t.get(domain.GetAccessTokenRedisKey(id)) +} + +func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { + tokens, err := t.GetAccessTokensByUID(ctx, uid) + if err != nil { + return err + } + for _, item := range tokens { + err := t.Delete(ctx, item) + if err != nil { + return err + } + } + + return nil +} + +// DeleteAccessTokenByID TODO 要做錯誤處理 +func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error { + for _, tokenID := range ids { + token, err := t.GetAccessTokenByID(ctx, tokenID) + if err != nil { + continue + } + + err = t.store.Pipelined(func(tx redis.Pipeliner) error { + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), + } + + for _, key := range keys { + if err := tx.Del(ctx, key).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } + } + + _, err = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) + } + + _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) + } + + return nil + }) + if err != nil { + continue + } + } + + return nil +} // GetAccessTokensByUID 透過 uid 得到目前未過期的 token func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { - utKeys, err := t.store.Get(domain.GetUIDTokenRedisKey(uid)) + utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid)) if err != nil { // 沒有就視為回空 if errors.Is(err, redis.Nil) { @@ -133,90 +163,39 @@ func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error())) } - uidTokens := make(entity.UIDToken) - err = json.Unmarshal([]byte(utKeys), &uidTokens) - if err != nil { - return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) - } - - now := time.Now().Unix() + now := time.Now().UTC() var tokens []entity.Token var deleteToken []string - for id, token := range uidTokens { - if token < now { - deleteToken = append(deleteToken, id) - - continue - } - + for _, id := range utKeys { + item := &entity.Token{} tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) if err == nil { - item := entity.Token{} - err = json.Unmarshal([]byte(tk), &item) + err = json.Unmarshal([]byte(tk), item) if err != nil { return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) } - tokens = append(tokens, item) + tokens = append(tokens, *item) } if errors.Is(err, redis.Nil) { deleteToken = append(deleteToken, id) } - } + if int64(item.ExpiresIn) < now.Unix() { + deleteToken = append(deleteToken, id) + + continue + } + + } if len(deleteToken) > 0 { // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 - _ = t.DeleteUIDToken(ctx, uid, deleteToken) + _ = t.DeleteAccessTokenByID(ctx, deleteToken) } return tokens, nil } -func (t *tokenRepository) DeleteUIDToken(ctx context.Context, uid string, ids []string) error { - uidTokens := make(entity.UIDToken) - tokenKeys, err := t.store.Get(domain.GetUIDTokenRedisKey(uid)) - if err != nil { - if !errors.Is(err, redis.Nil) { - return fmt.Errorf("tx.get GetDeviceTokenRedisKey error: %w", err) - } - } - - if tokenKeys != "" { - err = json.Unmarshal([]byte(tokenKeys), &uidTokens) - if err != nil { - return fmt.Errorf("json.Unmarshal GetDeviceTokenRedisKey error: %w", err) - } - } - - now := time.Now().Unix() - for k, t := range uidTokens { - // 到期就刪除 - if t < now { - delete(uidTokens, k) - } - } - - for _, id := range ids { - delete(uidTokens, id) - } - - b, err := json.Marshal(uidTokens) - if err != nil { - return fmt.Errorf("json.Marshal UIDToken error: %w", err) - } - - _, err = t.store.SetnxEx(domain.GetUIDTokenRedisKey(uid), string(b), 86400*30) - if err != nil { - return fmt.Errorf("tx.set GetUIDTokenRedisKey error: %w", err) - } - - return nil -} - -func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { - return t.get(domain.GetAccessTokenRedisKey(id)) -} - func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) { id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString()) if err != nil { @@ -262,13 +241,13 @@ func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, return nil } -func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error { +func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error { body, err := json.Marshal(ticket) if err != nil { return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error()) } - _, err = t.store.SetnxEx(domain.GetTicketRedisKey(key), string(body), int(expires.Seconds())) + _, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds())) if err != nil { return ers.DBError("CreateOneTimeToken store.set error:", err.Error()) } @@ -276,37 +255,106 @@ func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ti return nil } -func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { - err := t.store.Pipelined(func(tx redis.Pipeliner) error { - keys := []string{ - domain.GetAccessTokenRedisKey(token.ID), - domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), +func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { + utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString()) + if err != nil { + // 沒有就視為回空 + if errors.Is(err, redis.Nil) { + return nil, nil } - for _, key := range keys { - if err := tx.Del(ctx, key).Err(); err != nil { + return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error())) + } + + now := time.Now().UTC() + var tokens []entity.Token + var deleteToken []string + for _, id := range utKeys { + item := &entity.Token{} + tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) + if err == nil { + err = json.Unmarshal([]byte(tk), item) + if err != nil { + return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) + } + tokens = append(tokens, *item) + } + + if errors.Is(err, redis.Nil) { + deleteToken = append(deleteToken, id) + } + + if int64(item.ExpiresIn) < now.Unix() { + deleteToken = append(deleteToken, id) + + continue + } + + } + if len(deleteToken) > 0 { + // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 + _ = t.DeleteAccessTokenByID(ctx, deleteToken) + } + + return tokens, nil +} + +func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { + tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err)) + } + + err = t.store.Pipelined(func(tx redis.Pipeliner) error { + for _, token := range tokens { + if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil { return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } + + if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } + _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) + } } - if token.DeviceID != "" { - key := domain.DeviceTokenRedisKey.With(token.UID).ToString() - _, err := t.store.Hdel(key, token.DeviceID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) - } + _, err := t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString()) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) } return nil }) if err != nil { - return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) + return err } return nil } +func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { + count, err := t.store.Scard(domain.DeviceTokenRedisKey.With(deviceID).ToString()) + if err != nil { + return 0, err + } + + return int(count), nil +} + +func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { + count, err := t.store.Scard(domain.UIDTokenRedisKey.With(uid).ToString()) + if err != nil { + return 0, err + } + + return int(count), nil +} + +// -------------------- Private area -------------------- + func (t *tokenRepository) get(key string) (entity.Token, error) { body, err := t.store.Get(key) if errors.Is(err, redis.Nil) || body == "" { @@ -343,19 +391,27 @@ func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeline return nil } -func (t *tokenRepository) setDeviceToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { - if token.DeviceID != "" { - key := domain.DeviceTokenRedisKey.With(token.UID).ToString() - value := fmt.Sprintf("%s-%d", token.ID, token.AccessCreateAt.Add(rTTL).Unix()) - err := tx.HSet(ctx, key, token.DeviceID, value).Err() - if err != nil { - return wrapError("tx.HSet Device Token error", err) - } - err = tx.Expire(ctx, key, rTTL).Err() - if err != nil { - return wrapError("tx.Expire Device Token error", err) - } +func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error { + uidKey := domain.UIDTokenRedisKey.With(uid).ToString() + err := tx.SAdd(ctx, uidKey, tokenID).Err() + if err != nil { + return err } + err = tx.Expire(ctx, uidKey, rttl).Err() + if err != nil { + return err + } + + deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString() + err = tx.SAdd(ctx, deviceKey, tokenID).Err() + if err != nil { + return err + } + err = tx.Expire(ctx, deviceKey, rttl).Err() + if err != nil { + return err + } + return nil } diff --git a/internal/server/tokenservice/token_service_server.go b/internal/server/tokenservice/token_service_server.go index 0187058..c5adb92 100644 --- a/internal/server/tokenservice/token_service_server.go +++ b/internal/server/tokenservice/token_service_server.go @@ -40,24 +40,24 @@ func (s *TokenServiceServer) CancelToken(ctx context.Context, in *permission.Can return l.CancelToken(in) } -// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (s *TokenServiceServer) CancelTokenByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { - l := tokenservicelogic.NewCancelTokenByUidLogic(ctx, s.svcCtx) - return l.CancelTokenByUid(in) -} - -// CancelTokenByDeviceId 取消 Token -func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { - l := tokenservicelogic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx) - return l.CancelTokenByDeviceId(in) -} - // ValidationToken 驗證這個 Token 有沒有效 func (s *TokenServiceServer) ValidationToken(ctx context.Context, in *permission.ValidationTokenReq) (*permission.ValidationTokenResp, error) { l := tokenservicelogic.NewValidationTokenLogic(ctx, s.svcCtx) return l.ValidationToken(in) } +// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 +func (s *TokenServiceServer) CancelTokens(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { + l := tokenservicelogic.NewCancelTokensLogic(ctx, s.svcCtx) + return l.CancelTokens(in) +} + +// CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token +func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { + l := tokenservicelogic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx) + return l.CancelTokenByDeviceId(in) +} + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (s *TokenServiceServer) GetUserTokensByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { l := tokenservicelogic.NewGetUserTokensByDeviceIdLogic(ctx, s.svcCtx)