fix: complete token service

This commit is contained in:
daniel.w 2024-08-11 20:21:42 +08:00
parent 0d13b0c5f0
commit 2d485cf095
18 changed files with 500 additions and 407 deletions

View File

@ -37,12 +37,12 @@ type (
RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error)
// CancelToken 取消 Token也包含他裡面的 One Time Toke // CancelToken 取消 Token也包含他裡面的 One Time Toke
CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error)
// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke
CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error)
// CancelTokenByDeviceId 取消 Token
CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error)
// ValidationToken 驗證這個 Token 有沒有效 // ValidationToken 驗證這個 Token 有沒有效
ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error)
// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出
CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error)
// CancelTokenByDeviceId 取消 Token 從 Device 視角出發可以選登出這個Device 下所有 token 登出這個Device 下指定token
CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error)
// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens
GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error)
// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens
@ -82,24 +82,24 @@ func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenRe
return client.CancelToken(ctx, in, opts...) return client.CancelToken(ctx, in, opts...)
} }
// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke
func (m *defaultTokenService) CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.CancelTokenByUid(ctx, in, opts...)
}
// CancelTokenByDeviceId 取消 Token
func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.CancelTokenByDeviceId(ctx, in, opts...)
}
// ValidationToken 驗證這個 Token 有沒有效 // ValidationToken 驗證這個 Token 有沒有效
func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) { func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn()) client := permission.NewTokenServiceClient(m.cli.Conn())
return client.ValidationToken(ctx, in, opts...) return client.ValidationToken(ctx, in, opts...)
} }
// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出
func (m *defaultTokenService) CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.CancelTokens(ctx, in, opts...)
}
// CancelTokenByDeviceId 取消 Token 從 Device 視角出發可以選登出這個Device 下所有 token 登出這個Device 下指定token
func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.CancelTokenByDeviceId(ctx, in, opts...)
}
// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens
func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) {
client := permission.NewTokenServiceClient(m.cli.Conn()) client := permission.NewTokenServiceClient(m.cli.Conn())

View File

@ -114,7 +114,7 @@ message Token {
// DoTokenByDeviceIDReq DeviceID // DoTokenByDeviceIDReq DeviceID
message DoTokenByDeviceIDReq { message DoTokenByDeviceIDReq {
repeated string device_id = 1; string device_id = 1;
} }
message Tokens{ message Tokens{

View File

@ -18,6 +18,7 @@ const (
DeviceTokenRedisKey RedisKey = "device_token" DeviceTokenRedisKey RedisKey = "device_token"
UIDTokenRedisKey RedisKey = "uid_token" UIDTokenRedisKey RedisKey = "uid_token"
TicketRedisKey RedisKey = "ticket" TicketRedisKey RedisKey = "ticket"
DeviceUIDRedisKey RedisKey = "device_uid"
) )
func (key RedisKey) ToString() string { func (key RedisKey) ToString() string {

View File

@ -19,11 +19,9 @@ type TokenRepository interface {
GetAccessTokenCountByDeviceID(deviceID string) (int, error) GetAccessTokenCountByDeviceID(deviceID string) (int, error)
Delete(ctx context.Context, token entity.Token) error Delete(ctx context.Context, token entity.Token) error
DeleteAccessTokenByID(ctx context.Context, id string) error DeleteAccessTokenByID(ctx context.Context, ids []string) error
DeleteAccessTokensByUID(ctx context.Context, uid string) error DeleteAccessTokensByUID(ctx context.Context, uid string) error
DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error
DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error
DeleteUIDToken(ctx context.Context, uid string, ids []string) error
} }
type DeviceToken struct { type DeviceToken struct {

View File

@ -30,6 +30,18 @@ func (t *Token) IsExpires() bool {
return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now()) return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now())
} }
func (t *Token) RedisExpiredSec() int64 {
sec := time.Unix(int64(t.ExpiresIn), 0).Sub(time.Now().UTC())
return int64(sec.Seconds())
}
func (t *Token) RedisRefreshExpiredSec() int64 {
sec := time.Unix(int64(t.RefreshExpiresIn), 0).Sub(time.Now().UTC())
return int64(sec.Seconds())
}
type UIDToken map[string]int64 type UIDToken map[string]int64
type Ticket struct { type Ticket struct {

View File

@ -1,6 +1,7 @@
package tokenservicelogic package tokenservicelogic
import ( import (
ers "code.30cm.net/wanderland/library-go/errors"
"context" "context"
"ark-permission/gen_result/pb/permission" "ark-permission/gen_result/pb/permission"
@ -25,7 +26,19 @@ func NewCancelTokenByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceConte
// CancelTokenByDeviceId 取消 Token // CancelTokenByDeviceId 取消 Token
func (l *CancelTokenByDeviceIdLogic) CancelTokenByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { func (l *CancelTokenByDeviceIdLogic) CancelTokenByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) {
// todo: add your logic here and delete this line if err := l.svcCtx.Validate.ValidateAll(&getUserTokensByDeviceIdReq{
DeviceID: in.GetDeviceId(),
}); err != nil {
return nil, ers.InvalidFormat(err.Error())
}
err := l.svcCtx.TokenRedisRepo.DeleteAccessTokensByDeviceID(l.ctx, in.GetDeviceId())
if err != nil {
logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.DeleteAccessTokensByDeviceID"),
logx.Field("DeviceID", in.GetDeviceId()),
).Error(err.Error())
return nil, err
}
return &permission.OKResp{}, nil return &permission.OKResp{}, nil
} }

View File

@ -1,48 +0,0 @@
package tokenservicelogic
import (
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"ark-permission/gen_result/pb/permission"
"ark-permission/internal/svc"
"github.com/zeromicro/go-zero/core/logx"
)
type CancelTokenByUidLogic struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func NewCancelTokenByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByUidLogic {
return &CancelTokenByUidLogic{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
type deleteByTokenIDs struct {
UID string `json:"uid" binding:"required"`
IDs []string `json:"ids" binding:"required"`
}
// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke
func (l *CancelTokenByUidLogic) CancelTokenByUid(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) {
// 驗證所需
if err := l.svcCtx.Validate.ValidateAll(&deleteByTokenIDs{
UID: in.GetUid(),
IDs: in.GetIds(),
}); err != nil {
return nil, ers.InvalidFormat(err.Error())
}
err := l.svcCtx.TokenRedisRepo.DeleteUIDToken(l.ctx, in.GetUid(), in.GetIds())
if err != nil {
return nil, err
}
return &permission.OKResp{}, nil
}

View File

@ -37,7 +37,7 @@ func (l *CancelTokenLogic) CancelToken(in *permission.CancelTokenReq) (*permissi
return nil, ers.InvalidFormat(err.Error()) return nil, ers.InvalidFormat(err.Error())
} }
claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, false)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "parseClaims"), logx.Field("func", "parseClaims"),
@ -45,7 +45,7 @@ func (l *CancelTokenLogic) CancelToken(in *permission.CancelTokenReq) (*permissi
return nil, err return nil, err
} }
token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID())
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.GetByAccess"), logx.Field("func", "TokenRedisRepo.GetByAccess"),

View File

@ -0,0 +1,52 @@
package tokenservicelogic
import (
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"ark-permission/gen_result/pb/permission"
"ark-permission/internal/svc"
"github.com/zeromicro/go-zero/core/logx"
)
type CancelTokensLogic struct {
ctx context.Context
svcCtx *svc.ServiceContext
logx.Logger
}
func NewCancelTokensLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokensLogic {
return &CancelTokensLogic{
ctx: ctx,
svcCtx: svcCtx,
Logger: logx.WithContext(ctx),
}
}
// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出
func (l *CancelTokensLogic) CancelTokens(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) {
if in.GetUid() != "" {
err := l.svcCtx.TokenRedisRepo.DeleteAccessTokensByUID(l.ctx, in.GetUid())
if err != nil {
logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.DeleteAccessTokensByUID"),
logx.Field("uid", in.GetUid()),
).Error(err.Error())
return nil, ers.ResourceInsufficient(err.Error())
}
}
if len(in.GetIds()) > 0 {
err := l.svcCtx.TokenRedisRepo.DeleteAccessTokenByID(l.ctx, in.GetIds())
if err != nil {
logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.DeleteAccessTokenByID"),
logx.Field("ids", in.GetIds()),
).Error(err.Error())
return nil, ers.ResourceInsufficient(err.Error())
}
}
return &permission.OKResp{}, nil
}

View File

@ -2,7 +2,9 @@ package tokenservicelogic
import ( import (
"ark-permission/gen_result/pb/permission" "ark-permission/gen_result/pb/permission"
"ark-permission/internal/domain"
"ark-permission/internal/svc" "ark-permission/internal/svc"
ers "code.30cm.net/wanderland/library-go/errors"
"context" "context"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -21,23 +23,34 @@ func NewGetUserTokensByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceCon
} }
} }
type getUserTokensByDeviceIdReq struct {
DeviceID string `json:"device_id" validate:"required"`
}
// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens
func (l *GetUserTokensByDeviceIdLogic) GetUserTokensByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { func (l *GetUserTokensByDeviceIdLogic) GetUserTokensByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) {
if err := l.svcCtx.Validate.ValidateAll(&getUserTokensByDeviceIdReq{
// ids, err := l.svcCtx.TokenRedisRepo.GetAccessTokensByDeviceID(l.ctx, "") DeviceID: in.GetDeviceId(),
// if err != nil { }); err != nil {
// return nil, error return nil, ers.InvalidFormat(err.Error())
// } }
// tokenIDs := make([]usecase.DeviceToken, 0, len(ids)) uidTokens, err := l.svcCtx.TokenRedisRepo.GetAccessTokensByDeviceID(l.ctx, in.GetDeviceId())
// for _, v := range ids { if err != nil {
// tokenIDs = append(tokenIDs, usecase.DeviceToken{ return nil, err
// DeviceID: v.DeviceID, }
// TokenID: v.TokenID,
// }) tokens := make([]*permission.TokenResp, 0, len(uidTokens))
// } for _, v := range uidTokens {
// tokens = append(tokens, &permission.TokenResp{
// return tokenIDs, nil AccessToken: v.AccessToken,
TokenType: domain.TokenTypeBearer,
return &permission.Tokens{}, nil ExpiresIn: int32(v.ExpiresIn),
RefreshToken: v.RefreshToken,
})
}
return &permission.Tokens{
Token: tokens,
}, nil
} }

View File

@ -6,7 +6,6 @@ import (
"ark-permission/internal/svc" "ark-permission/internal/svc"
ers "code.30cm.net/wanderland/library-go/errors" ers "code.30cm.net/wanderland/library-go/errors"
"context" "context"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )

View File

@ -5,6 +5,7 @@ import (
"ark-permission/internal/entity" "ark-permission/internal/entity"
ers "code.30cm.net/wanderland/library-go/errors" ers "code.30cm.net/wanderland/library-go/errors"
"context" "context"
"github.com/google/uuid"
"time" "time"
"ark-permission/gen_result/pb/permission" "ark-permission/gen_result/pb/permission"
@ -37,7 +38,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken
} }
// 驗證Token // 驗證Token
claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, false)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "parseClaims"), logx.Field("func", "parseClaims"),
@ -45,7 +46,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken
return nil, err return nil, err
} }
token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID())
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.GetByAccess"), logx.Field("func", "TokenRedisRepo.GetByAccess"),
@ -54,7 +55,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken
return nil, err return nil, err
} }
oneTimeToken := generateRefreshToken(in.GetToken()) oneTimeToken := generateRefreshToken(uuid.Must(uuid.NewRandom()).String())
key := domain.TicketKeyPrefix + oneTimeToken key := domain.TicketKeyPrefix + oneTimeToken
if err = l.svcCtx.TokenRedisRepo.CreateOneTimeToken(l.ctx, key, entity.Ticket{ if err = l.svcCtx.TokenRedisRepo.CreateOneTimeToken(l.ctx, key, entity.Ticket{
Data: claims, Data: claims,

View File

@ -1,6 +1,7 @@
package tokenservicelogic package tokenservicelogic
import ( import (
"ark-permission/internal/config"
"ark-permission/internal/domain" "ark-permission/internal/domain"
"ark-permission/internal/entity" "ark-permission/internal/entity"
ers "code.30cm.net/wanderland/library-go/errors" ers "code.30cm.net/wanderland/library-go/errors"
@ -33,80 +34,31 @@ type authorizationReq struct {
GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"` GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"`
DeviceID string `json:"device_id"` DeviceID string `json:"device_id"`
Scope string `json:"scope" validate:"required"` Scope string `json:"scope" validate:"required"`
Data map[string]any `json:"data"` Data map[string]string `json:"data"`
Expires int `json:"expires"` Expires int `json:"expires"`
IsRefreshToken bool `json:"is_refresh_token"` IsRefreshToken bool `json:"is_refresh_token"`
} }
// NewToken 建立一個新的 Token例如AccessToken // NewToken 建立一個新的 Token例如AccessToken
func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) { func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) {
// 驗證所需 data := authorizationReq{
if err := l.svcCtx.Validate.ValidateAll(&authorizationReq{
GrantType: domain.GrantType(in.GetGrantType()), GrantType: domain.GrantType(in.GetGrantType()),
Scope: in.GetScope(), Scope: in.GetScope(),
}); err != nil { DeviceID: in.GetDeviceId(),
Data: in.GetData(),
Expires: int(in.GetExpires()),
IsRefreshToken: in.GetIsRefreshToken(),
}
// 驗證所需
if err := l.svcCtx.Validate.ValidateAll(&data); err != nil {
return nil, ers.InvalidFormat(err.Error()) return nil, ers.InvalidFormat(err.Error())
} }
token, err := newToken(data, l.svcCtx.Config)
// 準備建立 Token 所需
now := time.Now().UTC()
expires := int(in.GetExpires())
refreshExpires := int(in.GetExpires())
if expires <= 0 {
// 將時間加上 300 秒
sec := time.Duration(l.svcCtx.Config.Token.Expired.Seconds()) * time.Second
newTime := now.Add(sec)
// 獲取 Unix 時間戳
timestamp := newTime.Unix()
expires = int(timestamp)
refreshExpires = expires
}
// 如果這是一個 Refresh Token 過期時間要比普通的Token 長
if in.GetIsRefreshToken() {
// 將時間加上 300 秒
sec := time.Duration(l.svcCtx.Config.Token.RefreshExpires.Seconds()) * time.Second
newTime := now.Add(sec)
// 獲取 Unix 時間戳
timestamp := newTime.Unix()
refreshExpires = int(timestamp)
}
token := entity.Token{
ID: uuid.Must(uuid.NewRandom()).String(),
DeviceID: in.GetDeviceId(),
ExpiresIn: expires,
RefreshExpiresIn: refreshExpires,
AccessCreateAt: now,
RefreshCreateAt: now,
}
claims := claims(in.GetData())
claims.SetRole(domain.DefaultRole)
claims.SetID(token.ID)
claims.SetScope(in.GetScope())
token.UID = claims.UID()
if in.GetDeviceId() != "" {
claims.SetDeviceID(in.GetDeviceId())
}
var err error
token.AccessToken, err = generateAccessTokenFunc(token, claims, l.svcCtx.Config.Token.Secret)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields(
logx.Field("func", "generateAccessTokenFunc"),
logx.Field("claims", claims),
).Error(err.Error())
return nil, err return nil, err
} }
if in.GetIsRefreshToken() { err = l.svcCtx.TokenRedisRepo.Create(l.ctx, *token)
token.RefreshToken = generateRefreshTokenFunc(token.AccessToken)
}
err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.Create"), logx.Field("func", "TokenRedisRepo.Create"),
@ -122,3 +74,65 @@ func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.T
RefreshToken: token.RefreshToken, RefreshToken: token.RefreshToken,
}, nil }, nil
} }
func newToken(authReq authorizationReq, cfg config.Config) (*entity.Token, error) {
// 準備建立 Token 所需
now := time.Now().UTC()
expires := authReq.Expires
refreshExpires := authReq.Expires
if expires <= 0 {
// 將時間加上 300 秒
sec := time.Duration(cfg.Token.Expired.Seconds()) * time.Second
newTime := now.Add(sec)
// 獲取 Unix 時間戳
timestamp := newTime.Unix()
expires = int(timestamp)
refreshExpires = expires
}
// 如果這是一個 Refresh Token 過期時間要比普通的Token 長
if authReq.IsRefreshToken {
// 將時間加上 300 秒
sec := time.Duration(cfg.Token.RefreshExpires.Seconds()) * time.Second
newTime := now.Add(sec)
// 獲取 Unix 時間戳
timestamp := newTime.Unix()
refreshExpires = int(timestamp)
}
token := entity.Token{
ID: uuid.Must(uuid.NewRandom()).String(),
DeviceID: authReq.DeviceID,
ExpiresIn: expires,
RefreshExpiresIn: refreshExpires,
AccessCreateAt: now,
RefreshCreateAt: now,
}
claims := claims(authReq.Data)
claims.SetRole(domain.DefaultRole)
claims.SetID(token.ID)
claims.SetScope(authReq.Scope)
token.UID = claims.UID()
if authReq.DeviceID != "" {
claims.SetDeviceID(authReq.DeviceID)
}
var err error
token.AccessToken, err = generateAccessTokenFunc(token, claims, cfg.Token.Secret)
if err != nil {
logx.WithCallerSkip(1).WithFields(
logx.Field("func", "generateAccessTokenFunc"),
logx.Field("claims", claims),
).Error(err.Error())
return nil, err
}
if authReq.IsRefreshToken {
token.RefreshToken = generateRefreshTokenFunc(token.AccessToken)
}
return &token, nil
}

View File

@ -1,14 +1,11 @@
package tokenservicelogic package tokenservicelogic
import ( import (
"ark-permission/gen_result/pb/permission"
"ark-permission/internal/domain" "ark-permission/internal/domain"
"ark-permission/internal/entity" "ark-permission/internal/svc"
ers "code.30cm.net/wanderland/library-go/errors" ers "code.30cm.net/wanderland/library-go/errors"
"context" "context"
"time"
"ark-permission/gen_result/pb/permission"
"ark-permission/internal/svc"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -31,7 +28,6 @@ type refreshReq struct {
RefreshToken string `json:"grant_type" validate:"required"` RefreshToken string `json:"grant_type" validate:"required"`
DeviceID string `json:"device_id" validate:"required"` DeviceID string `json:"device_id" validate:"required"`
Scope string `json:"scope" validate:"required"` Scope string `json:"scope" validate:"required"`
Expires int64 `json:"expires" validate:"required"`
} }
// RefreshToken 更新目前的token 以及裡面包含的一次性 Token // RefreshToken 更新目前的token 以及裡面包含的一次性 Token
@ -41,10 +37,10 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi
RefreshToken: in.GetToken(), RefreshToken: in.GetToken(),
Scope: in.GetScope(), Scope: in.GetScope(),
DeviceID: in.GetDeviceId(), DeviceID: in.GetDeviceId(),
Expires: in.GetExpires(),
}); err != nil { }); err != nil {
return nil, ers.InvalidFormat(err.Error()) return nil, ers.InvalidFormat(err.Error())
} }
// step 1 拿看看有沒有這個 refresh token // step 1 拿看看有沒有這個 refresh token
token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.Token) token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.Token)
if err != nil { if err != nil {
@ -54,56 +50,33 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi
).Error(err.Error()) ).Error(err.Error())
return nil, err return nil, err
} }
// 拿到之後替換掉時間以及 refresh token
// refreshToken 建立
now := time.Now().UTC()
sec := time.Duration(l.svcCtx.Config.Token.RefreshExpires.Seconds()) * time.Second
newTime := now.Add(sec)
// 獲取 Unix 時間戳
timestamp := newTime.Unix()
refreshExpires := int(timestamp)
expires := int(in.GetExpires())
if expires <= 0 {
// 將時間加上 300 秒
sec := time.Duration(l.svcCtx.Config.Token.Expired.Seconds()) * time.Second
newTime := now.Add(sec)
// 獲取 Unix 時間戳
timestamp := newTime.Unix()
expires = int(timestamp)
}
newToken := entity.Token{ // 取得 Data
ID: token.ID, c, err := parseClaims(token.AccessToken, l.svcCtx.Config.Token.Secret, false)
UID: token.UID,
DeviceID: in.GetDeviceId(),
ExpiresIn: expires,
RefreshExpiresIn: refreshExpires,
AccessCreateAt: now,
RefreshCreateAt: now,
}
claims := claims(map[string]string{
"uid": token.UID,
})
claims.SetRole(domain.DefaultRole)
claims.SetID(token.ID)
claims.SetScope(in.GetScope())
claims.UID()
if in.GetDeviceId() != "" {
claims.SetDeviceID(in.GetDeviceId())
}
newToken.AccessToken, err = generateAccessTokenFunc(newToken, claims, l.svcCtx.Config.Token.Secret)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "generateAccessTokenFunc"), logx.Field("func", "parseClaims"),
logx.Field("claims", claims), logx.Field("token", token),
).Error(err.Error()) ).Error(err.Error())
return nil, err return nil, err
} }
newToken.RefreshToken = generateRefreshTokenFunc(newToken.AccessToken) // step 2 建立新 token
nt, err := newToken(authorizationReq{
GrantType: domain.ClientCredentials,
Scope: in.GetScope(),
DeviceID: in.GetDeviceId(),
Data: c,
Expires: int(in.GetExpires()),
IsRefreshToken: true,
}, l.svcCtx.Config)
if err != nil {
logx.WithCallerSkip(1).WithFields(
logx.Field("func", "newToken"),
logx.Field("req", in),
).Error(err.Error())
return nil, err
}
// 刪除掉舊的 token // 刪除掉舊的 token
err = l.svcCtx.TokenRedisRepo.Delete(l.ctx, token) err = l.svcCtx.TokenRedisRepo.Delete(l.ctx, token)
@ -115,7 +88,7 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi
return nil, err return nil, err
} }
err = l.svcCtx.TokenRedisRepo.Create(l.ctx, newToken) err = l.svcCtx.TokenRedisRepo.Create(l.ctx, *nt)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.Create"), logx.Field("func", "TokenRedisRepo.Create"),
@ -125,9 +98,9 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi
} }
return &permission.RefreshTokenResp{ return &permission.RefreshTokenResp{
Token: newToken.AccessToken, Token: nt.AccessToken,
OneTimeToken: newToken.RefreshToken, OneTimeToken: nt.RefreshToken,
ExpiresIn: int64(expires), ExpiresIn: int64(nt.ExpiresIn),
TokenType: domain.TokenTypeBearer, TokenType: domain.TokenTypeBearer,
}, nil }, nil
} }

View File

@ -4,7 +4,6 @@ import (
"ark-permission/internal/domain" "ark-permission/internal/domain"
"ark-permission/internal/entity" "ark-permission/internal/entity"
"bytes" "bytes"
"context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
@ -42,43 +41,53 @@ func generateRefreshToken(accessToken string) string {
return hex.EncodeToString(h.Sum(nil)) return hex.EncodeToString(h.Sum(nil))
} }
func parseClaims(ctx context.Context, accessToken string, secret string) (claims, error) { func parseToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) {
claimMap, err := parseToken(accessToken, secret) // 跳過驗證的解析
if err != nil { var token *jwt.Token
return claims{}, err var err error
}
claims, ok := claimMap["data"].(map[string]any) if validate {
if ok { token, err = jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
return convertMap(claims), nil
}
return nil, domain.TokenClaimError("get data from claim map error")
}
func parseToken(accessToken string, secret string) (jwt.MapClaims, error) {
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"]))
} }
return []byte(secret), nil return []byte(secret), nil
}) })
if err != nil { if err != nil {
return jwt.MapClaims{}, err return jwt.MapClaims{}, err
} }
} else {
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
token, err = parser.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
if err != nil {
return jwt.MapClaims{}, err
}
}
claims, ok := token.Claims.(jwt.MapClaims) claims, ok := token.Claims.(jwt.MapClaims)
if !ok && token.Valid {
if !(ok && token.Valid) {
return jwt.MapClaims{}, domain.TokenTokenValidateErr("token valid error") return jwt.MapClaims{}, domain.TokenTokenValidateErr("token valid error")
} }
return claims, nil return claims, nil
} }
func parseClaims(accessToken string, secret string, validate bool) (claims, error) {
claimMap, err := parseToken(accessToken, secret, validate)
if err != nil {
return claims{}, err
}
claimsData, ok := claimMap["data"].(map[string]any)
if ok {
return convertMap(claimsData), nil
}
return claims{}, domain.TokenClaimError("get data from claim map error")
}
func convertMap(input map[string]interface{}) map[string]string { func convertMap(input map[string]interface{}) map[string]string {
output := make(map[string]string) output := make(map[string]string)
for key, value := range input { for key, value := range input {

View File

@ -37,7 +37,7 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq
return nil, ers.InvalidFormat(err.Error()) return nil, ers.InvalidFormat(err.Error())
} }
claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, true)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "parseClaims"), logx.Field("func", "parseClaims"),
@ -45,7 +45,7 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq
return nil, err return nil, err
} }
token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID())
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.GetByAccess"), logx.Field("func", "TokenRedisRepo.GetByAccess"),

View File

@ -22,41 +22,6 @@ type tokenRepository struct {
store *redis.Redis store *redis.Redis
} }
func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) {
// TODO implement me
panic("implement me")
}
func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
// TODO implement me
panic("implement me")
}
func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) {
// TODO implement me
panic("implement me")
}
func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, id string) error {
// TODO implement me
panic("implement me")
}
func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
// TODO implement me
panic("implement me")
}
func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
// TODO implement me
panic("implement me")
}
func (t *tokenRepository) DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error {
// TODO implement me
panic("implement me")
}
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
return &tokenRepository{ return &tokenRepository{
store: param.Store, store: param.Store,
@ -70,17 +35,19 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error
} }
err = t.store.Pipelined(func(tx redis.Pipeliner) error { err = t.store.Pipelined(func(tx redis.Pipeliner) error {
rTTL := token.RefreshTokenExpires() // rTTL := token.RedisExpiredSec()
refreshTTL := token.RedisRefreshExpiredSec()
if err := t.setToken(ctx, tx, token, body, rTTL); err != nil { if err := t.setToken(ctx, tx, token, body, time.Duration(refreshTTL)*time.Second); err != nil {
return err return err
} }
if err := t.setRefreshToken(ctx, tx, token, rTTL); err != nil { if err := t.setRefreshToken(ctx, tx, token, time.Duration(refreshTTL)*time.Second); err != nil {
return err return err
} }
if err := t.setDeviceToken(ctx, tx, token, rTTL); err != nil { err := t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, time.Duration(refreshTTL)*time.Second)
if err != nil {
return err return err
} }
@ -90,40 +57,103 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error
return domain.RedisPipLineError(err.Error()) return domain.RedisPipLineError(err.Error())
} }
if err := t.SetUIDToken(token); err != nil { return nil
return ers.ArkInternal("SetUIDToken error", err.Error()) }
func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error {
err := t.store.Pipelined(func(tx redis.Pipeliner) error {
keys := []string{
domain.GetAccessTokenRedisKey(token.ID),
domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
domain.UIDTokenRedisKey.With(token.UID).ToString(),
}
for _, key := range keys {
if err := tx.Del(ctx, key).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
}
if token.DeviceID != "" {
key := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString()
_, err := t.store.Del(key)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err))
}
}
return nil
})
if err != nil {
return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
} }
return nil return nil
} }
// // GetAccessTokensByDeviceID 透過 Device ID 得到目前未過期的token func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) {
// func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, uid string) ([]repository.DeviceToken, error) { return t.get(domain.GetAccessTokenRedisKey(id))
// data, err := t.store.Hgetall(domain.DeviceTokenRedisKey.With(uid).ToString()) }
// if err != nil {
// if errors.Is(err, redis.Nil) { func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
// return nil, nil tokens, err := t.GetAccessTokensByUID(ctx, uid)
// } if err != nil {
// return err
// return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.HGetAll Device Token error: %v", err.Error())) }
// } for _, item := range tokens {
// err := t.Delete(ctx, item)
// ids := make([]repository.DeviceToken, 0, len(data)) if err != nil {
// for deviceID, id := range data { return err
// ids = append(ids, repository.DeviceToken{ }
// DeviceID: deviceID, }
//
// // e0a4f824-41db-4eb2-8e5a-d96966ea1d56-1698083859 return nil
// // -11是因為id組成最後11位數是-跟時間戳記 }
// TokenID: id[:len(id)-11],
// }) // DeleteAccessTokenByID TODO 要做錯誤處理
// } func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
// return ids, nil for _, tokenID := range ids {
// } token, err := t.GetAccessTokenByID(ctx, tokenID)
if err != nil {
continue
}
err = t.store.Pipelined(func(tx redis.Pipeliner) error {
keys := []string{
domain.GetAccessTokenRedisKey(token.ID),
domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
}
for _, key := range keys {
if err := tx.Del(ctx, key).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
}
_, err = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err))
}
_, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
}
return nil
})
if err != nil {
continue
}
}
return nil
}
// GetAccessTokensByUID 透過 uid 得到目前未過期的 token // GetAccessTokensByUID 透過 uid 得到目前未過期的 token
func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
utKeys, err := t.store.Get(domain.GetUIDTokenRedisKey(uid)) utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid))
if err != nil { if err != nil {
// 沒有就視為回空 // 沒有就視為回空
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
@ -133,90 +163,39 @@ func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string)
return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error())) return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error()))
} }
uidTokens := make(entity.UIDToken) now := time.Now().UTC()
err = json.Unmarshal([]byte(utKeys), &uidTokens)
if err != nil {
return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err))
}
now := time.Now().Unix()
var tokens []entity.Token var tokens []entity.Token
var deleteToken []string var deleteToken []string
for id, token := range uidTokens { for _, id := range utKeys {
if token < now { item := &entity.Token{}
deleteToken = append(deleteToken, id)
continue
}
tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
if err == nil { if err == nil {
item := entity.Token{} err = json.Unmarshal([]byte(tk), item)
err = json.Unmarshal([]byte(tk), &item)
if err != nil { if err != nil {
return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err))
} }
tokens = append(tokens, item) tokens = append(tokens, *item)
} }
if errors.Is(err, redis.Nil) { if errors.Is(err, redis.Nil) {
deleteToken = append(deleteToken, id) deleteToken = append(deleteToken, id)
} }
if int64(item.ExpiresIn) < now.Unix() {
deleteToken = append(deleteToken, id)
continue
} }
}
if len(deleteToken) > 0 { if len(deleteToken) > 0 {
// 如果失敗也沒關係其他get method撈取時會在判斷是否過期或存在 // 如果失敗也沒關係其他get method撈取時會在判斷是否過期或存在
_ = t.DeleteUIDToken(ctx, uid, deleteToken) _ = t.DeleteAccessTokenByID(ctx, deleteToken)
} }
return tokens, nil return tokens, nil
} }
func (t *tokenRepository) DeleteUIDToken(ctx context.Context, uid string, ids []string) error {
uidTokens := make(entity.UIDToken)
tokenKeys, err := t.store.Get(domain.GetUIDTokenRedisKey(uid))
if err != nil {
if !errors.Is(err, redis.Nil) {
return fmt.Errorf("tx.get GetDeviceTokenRedisKey error: %w", err)
}
}
if tokenKeys != "" {
err = json.Unmarshal([]byte(tokenKeys), &uidTokens)
if err != nil {
return fmt.Errorf("json.Unmarshal GetDeviceTokenRedisKey error: %w", err)
}
}
now := time.Now().Unix()
for k, t := range uidTokens {
// 到期就刪除
if t < now {
delete(uidTokens, k)
}
}
for _, id := range ids {
delete(uidTokens, id)
}
b, err := json.Marshal(uidTokens)
if err != nil {
return fmt.Errorf("json.Marshal UIDToken error: %w", err)
}
_, err = t.store.SetnxEx(domain.GetUIDTokenRedisKey(uid), string(b), 86400*30)
if err != nil {
return fmt.Errorf("tx.set GetUIDTokenRedisKey error: %w", err)
}
return nil
}
func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) {
return t.get(domain.GetAccessTokenRedisKey(id))
}
func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) { func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) {
id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString()) id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString())
if err != nil { if err != nil {
@ -262,13 +241,13 @@ func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string,
return nil return nil
} }
func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error { func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error {
body, err := json.Marshal(ticket) body, err := json.Marshal(ticket)
if err != nil { if err != nil {
return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error()) return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error())
} }
_, err = t.store.SetnxEx(domain.GetTicketRedisKey(key), string(body), int(expires.Seconds())) _, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds()))
if err != nil { if err != nil {
return ers.DBError("CreateOneTimeToken store.set error:", err.Error()) return ers.DBError("CreateOneTimeToken store.set error:", err.Error())
} }
@ -276,37 +255,106 @@ func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ti
return nil return nil
} }
func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
err := t.store.Pipelined(func(tx redis.Pipeliner) error { utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString())
keys := []string{ if err != nil {
domain.GetAccessTokenRedisKey(token.ID), // 沒有就視為回空
domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), if errors.Is(err, redis.Nil) {
return nil, nil
} }
for _, key := range keys { return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error()))
if err := tx.Del(ctx, key).Err(); err != nil { }
now := time.Now().UTC()
var tokens []entity.Token
var deleteToken []string
for _, id := range utKeys {
item := &entity.Token{}
tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
if err == nil {
err = json.Unmarshal([]byte(tk), item)
if err != nil {
return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err))
}
tokens = append(tokens, *item)
}
if errors.Is(err, redis.Nil) {
deleteToken = append(deleteToken, id)
}
if int64(item.ExpiresIn) < now.Unix() {
deleteToken = append(deleteToken, id)
continue
}
}
if len(deleteToken) > 0 {
// 如果失敗也沒關係其他get method撈取時會在判斷是否過期或存在
_ = t.DeleteAccessTokenByID(ctx, deleteToken)
}
return tokens, nil
}
func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err))
}
err = t.store.Pipelined(func(tx redis.Pipeliner) error {
for _, token := range tokens {
if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
} }
if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
_, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
}
} }
if token.DeviceID != "" { _, err := t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString())
key := domain.DeviceTokenRedisKey.With(token.UID).ToString()
_, err := t.store.Hdel(key, token.DeviceID)
if err != nil { if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err))
}
} }
return nil return nil
}) })
if err != nil { if err != nil {
return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) return err
} }
return nil return nil
} }
func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) {
count, err := t.store.Scard(domain.DeviceTokenRedisKey.With(deviceID).ToString())
if err != nil {
return 0, err
}
return int(count), nil
}
func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) {
count, err := t.store.Scard(domain.UIDTokenRedisKey.With(uid).ToString())
if err != nil {
return 0, err
}
return int(count), nil
}
// -------------------- Private area --------------------
func (t *tokenRepository) get(key string) (entity.Token, error) { func (t *tokenRepository) get(key string) (entity.Token, error) {
body, err := t.store.Get(key) body, err := t.store.Get(key)
if errors.Is(err, redis.Nil) || body == "" { if errors.Is(err, redis.Nil) || body == "" {
@ -343,19 +391,27 @@ func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeline
return nil return nil
} }
func (t *tokenRepository) setDeviceToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error {
if token.DeviceID != "" { uidKey := domain.UIDTokenRedisKey.With(uid).ToString()
key := domain.DeviceTokenRedisKey.With(token.UID).ToString() err := tx.SAdd(ctx, uidKey, tokenID).Err()
value := fmt.Sprintf("%s-%d", token.ID, token.AccessCreateAt.Add(rTTL).Unix())
err := tx.HSet(ctx, key, token.DeviceID, value).Err()
if err != nil { if err != nil {
return wrapError("tx.HSet Device Token error", err) return err
} }
err = tx.Expire(ctx, key, rTTL).Err() err = tx.Expire(ctx, uidKey, rttl).Err()
if err != nil { if err != nil {
return wrapError("tx.Expire Device Token error", err) return err
} }
deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString()
err = tx.SAdd(ctx, deviceKey, tokenID).Err()
if err != nil {
return err
} }
err = tx.Expire(ctx, deviceKey, rttl).Err()
if err != nil {
return err
}
return nil return nil
} }

View File

@ -40,24 +40,24 @@ func (s *TokenServiceServer) CancelToken(ctx context.Context, in *permission.Can
return l.CancelToken(in) return l.CancelToken(in)
} }
// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke
func (s *TokenServiceServer) CancelTokenByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) {
l := tokenservicelogic.NewCancelTokenByUidLogic(ctx, s.svcCtx)
return l.CancelTokenByUid(in)
}
// CancelTokenByDeviceId 取消 Token
func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) {
l := tokenservicelogic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx)
return l.CancelTokenByDeviceId(in)
}
// ValidationToken 驗證這個 Token 有沒有效 // ValidationToken 驗證這個 Token 有沒有效
func (s *TokenServiceServer) ValidationToken(ctx context.Context, in *permission.ValidationTokenReq) (*permission.ValidationTokenResp, error) { func (s *TokenServiceServer) ValidationToken(ctx context.Context, in *permission.ValidationTokenReq) (*permission.ValidationTokenResp, error) {
l := tokenservicelogic.NewValidationTokenLogic(ctx, s.svcCtx) l := tokenservicelogic.NewValidationTokenLogic(ctx, s.svcCtx)
return l.ValidationToken(in) return l.ValidationToken(in)
} }
// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出
func (s *TokenServiceServer) CancelTokens(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) {
l := tokenservicelogic.NewCancelTokensLogic(ctx, s.svcCtx)
return l.CancelTokens(in)
}
// CancelTokenByDeviceId 取消 Token 從 Device 視角出發可以選登出這個Device 下所有 token 登出這個Device 下指定token
func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) {
l := tokenservicelogic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx)
return l.CancelTokenByDeviceId(in)
}
// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens
func (s *TokenServiceServer) GetUserTokensByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { func (s *TokenServiceServer) GetUserTokensByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) {
l := tokenservicelogic.NewGetUserTokensByDeviceIdLogic(ctx, s.svcCtx) l := tokenservicelogic.NewGetUserTokensByDeviceIdLogic(ctx, s.svcCtx)