diff --git a/client/permissionservice/permission_service.go b/client/permissionservice/permission_service.go new file mode 100644 index 0000000..fdae9cb --- /dev/null +++ b/client/permissionservice/permission_service.go @@ -0,0 +1,45 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package permissionservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + PermissionService interface { + } + + defaultPermissionService struct { + cli zrpc.Client + } +) + +func NewPermissionService(cli zrpc.Client) PermissionService { + return &defaultPermissionService{ + cli: cli, + } +} diff --git a/client/roleservice/role_service.go b/client/roleservice/role_service.go new file mode 100644 index 0000000..8d7e2e2 --- /dev/null +++ b/client/roleservice/role_service.go @@ -0,0 +1,51 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package roleservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + RoleService interface { + Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) + } + + defaultRoleService struct { + cli zrpc.Client + } +) + +func NewRoleService(cli zrpc.Client) RoleService { + return &defaultRoleService{ + cli: cli, + } +} + +func (m *defaultRoleService) Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewRoleServiceClient(m.cli.Conn()) + return client.Ping(ctx, in, opts...) +} diff --git a/client/tokenservice/token_service.go b/client/tokenservice/token_service.go new file mode 100644 index 0000000..617dc38 --- /dev/null +++ b/client/tokenservice/token_service.go @@ -0,0 +1,125 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package tokenservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + TokenService interface { + // NewToken 建立一個新的 Token,例如:AccessToken + NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) + // RefreshToken 更新目前的token 以及裡面包含的一次性 Token + RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) + // CancelToken 取消 Token,也包含他裡面的 One Time Toke + CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) + // ValidationToken 驗證這個 Token 有沒有效 + ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) + // CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 + CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) + // CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token + CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens + GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) + // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens + GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) + // NewOneTimeToken 建立一次性使用,例如:RefreshToken + NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) + // CancelOneTimeToken 取消一次性使用 + CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) + } + + defaultTokenService struct { + cli zrpc.Client + } +) + +func NewTokenService(cli zrpc.Client) TokenService { + return &defaultTokenService{ + cli: cli, + } +} + +// NewToken 建立一個新的 Token,例如:AccessToken +func (m *defaultTokenService) NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.NewToken(ctx, in, opts...) +} + +// RefreshToken 更新目前的token 以及裡面包含的一次性 Token +func (m *defaultTokenService) RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.RefreshToken(ctx, in, opts...) +} + +// CancelToken 取消 Token,也包含他裡面的 One Time Toke +func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelToken(ctx, in, opts...) +} + +// ValidationToken 驗證這個 Token 有沒有效 +func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.ValidationToken(ctx, in, opts...) +} + +// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 +func (m *defaultTokenService) CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokens(ctx, in, opts...) +} + +// CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token +func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokenByDeviceId(ctx, in, opts...) +} + +// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens +func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.GetUserTokensByDeviceId(ctx, in, opts...) +} + +// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens +func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.GetUserTokensByUid(ctx, in, opts...) +} + +// NewOneTimeToken 建立一次性使用,例如:RefreshToken +func (m *defaultTokenService) NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.NewOneTimeToken(ctx, in, opts...) +} + +// CancelOneTimeToken 取消一次性使用 +func (m *defaultTokenService) CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelOneTimeToken(ctx, in, opts...) +} diff --git a/generate/database/mysql/create/20230529020000_create_schema.down.sql b/generate/database/mysql/create/20230529020000_create_schema.down.sql index e7727a5..dc0bfe4 100644 --- a/generate/database/mysql/create/20230529020000_create_schema.down.sql +++ b/generate/database/mysql/create/20230529020000_create_schema.down.sql @@ -1 +1 @@ -DROP DATABASE IF EXISTS `ark_member`; \ No newline at end of file +DROP DATABASE IF EXISTS `ark_permission`; \ No newline at end of file diff --git a/generate/database/mysql/create/20230529020000_create_schema.up.sql b/generate/database/mysql/create/20230529020000_create_schema.up.sql index d997e04..686ffdf 100644 --- a/generate/database/mysql/create/20230529020000_create_schema.up.sql +++ b/generate/database/mysql/create/20230529020000_create_schema.up.sql @@ -1 +1 @@ -CREATE DATABASE IF NOT EXISTS `ark_member`; \ No newline at end of file +CREATE DATABASE IF NOT EXISTS `ark_permission`; \ No newline at end of file diff --git a/internal/domain/repository/token.go b/internal/domain/repository/token.go index be4c207..5f02b98 100644 --- a/internal/domain/repository/token.go +++ b/internal/domain/repository/token.go @@ -6,12 +6,14 @@ import ( "time" ) +// TokenRepository token 的 redis 操作 type TokenRepository interface { + // Create 建立Token Create(ctx context.Context, token entity.Token) error - DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error + // CreateOneTimeToken 建立臨時 Token CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error - GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) + GetAccessTokenByByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) GetAccessTokenCountByUID(uid string) (int, error) @@ -19,6 +21,7 @@ type TokenRepository interface { GetAccessTokenCountByDeviceID(deviceID string) (int, error) Delete(ctx context.Context, token entity.Token) error + DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error DeleteAccessTokenByID(ctx context.Context, ids []string) error DeleteAccessTokensByUID(ctx context.Context, uid string) error DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error diff --git a/internal/logic/tokenservice/cancel_token_logic.go b/internal/logic/tokenservice/cancel_token_logic.go index bf89e59..368f297 100644 --- a/internal/logic/tokenservice/cancel_token_logic.go +++ b/internal/logic/tokenservice/cancel_token_logic.go @@ -1,12 +1,10 @@ package tokenservicelogic import ( - ers "code.30cm.net/wanderland/library-go/errors" - "context" - "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" - + ers "code.30cm.net/wanderland/library-go/errors" + "context" "github.com/zeromicro/go-zero/core/logx" ) diff --git a/internal/logic/tokenservice/new_one_time_token_logic.go b/internal/logic/tokenservice/new_one_time_token_logic.go index 79db981..389f23a 100644 --- a/internal/logic/tokenservice/new_one_time_token_logic.go +++ b/internal/logic/tokenservice/new_one_time_token_logic.go @@ -28,7 +28,7 @@ func NewNewOneTimeTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *N } } -// NewOneTimeToken 建立一次性使用,例如:RefreshToken +// NewOneTimeToken 建立一次性使用,例如:RefreshToken TODO 目前並無後續操作 func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) { // 驗證所需 if err := l.svcCtx.Validate.ValidateAll(&refreshTokenReq{ diff --git a/internal/logic/tokenservice/refresh_token_logic.go b/internal/logic/tokenservice/refresh_token_logic.go index 607fd28..e059b7a 100644 --- a/internal/logic/tokenservice/refresh_token_logic.go +++ b/internal/logic/tokenservice/refresh_token_logic.go @@ -42,7 +42,7 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi } // step 1 拿看看有沒有這個 refresh token - token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.Token) + token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByByOneTimeToken(l.ctx, in.Token) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.GetByRefresh"), diff --git a/internal/logic/tokenservice/validation_token_logic.go b/internal/logic/tokenservice/validation_token_logic.go index f3baf9e..be6ab1a 100644 --- a/internal/logic/tokenservice/validation_token_logic.go +++ b/internal/logic/tokenservice/validation_token_logic.go @@ -1,11 +1,10 @@ package tokenservicelogic import ( - ers "code.30cm.net/wanderland/library-go/errors" - "context" - "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" + ers "code.30cm.net/wanderland/library-go/errors" + "context" "github.com/zeromicro/go-zero/core/logx" ) @@ -36,15 +35,13 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq }); err != nil { return nil, ers.InvalidFormat(err.Error()) } - claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, true) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "parseClaims"), - ).Error(err.Error()) + ).Info(err.Error()) return nil, err } - token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) if err != nil { logx.WithCallerSkip(1).WithFields( diff --git a/internal/repository/token.go b/internal/repository/token.go index 17a325f..d91f73c 100644 --- a/internal/repository/token.go +++ b/internal/repository/token.go @@ -33,27 +33,19 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error if err != nil { return ers.ArkInternal("json.Marshal token error", err.Error()) } + if err := t.store.Pipelined(func(tx redis.Pipeliner) error { + refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second - err = t.store.Pipelined(func(tx redis.Pipeliner) error { - // rTTL := token.RedisExpiredSec() - refreshTTL := token.RedisRefreshExpiredSec() - - if err := t.setToken(ctx, tx, token, body, time.Duration(refreshTTL)*time.Second); err != nil { + if err := t.setToken(ctx, tx, token, body, refreshTTL); err != nil { return err } - if err := t.setRefreshToken(ctx, tx, token, time.Duration(refreshTTL)*time.Second); err != nil { + if err := t.setRefreshToken(ctx, tx, token, refreshTTL); err != nil { return err } - err := t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, time.Duration(refreshTTL)*time.Second) - if err != nil { - return err - } - - return nil - }) - if err != nil { + return t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL) + }); err != nil { return domain.RedisPipLineError(err.Error()) } @@ -61,39 +53,28 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error } func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { - err := t.store.Pipelined(func(tx redis.Pipeliner) error { - keys := []string{ - domain.GetAccessTokenRedisKey(token.ID), - domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), - domain.UIDTokenRedisKey.With(token.UID).ToString(), - } - - for _, key := range keys { - if err := tx.Del(ctx, key).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - } - - if token.DeviceID != "" { - key := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString() - _, err := t.store.Del(key) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) - } - } - - return nil - }) - - if err != nil { - return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), } + if err := t.deleteKeys(ctx, keys...); err != nil { + return domain.RedisPipLineError(err.Error()) + } + + _, _ = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) + _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) + return nil } -func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { - return t.get(domain.GetAccessTokenRedisKey(id)) +func (t *tokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) { + token, err := t.get(ctx, domain.GetAccessTokenRedisKey(id)) + if err != nil { + return entity.Token{}, err + } + + return token, nil } func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { @@ -101,9 +82,9 @@ func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid strin if err != nil { return err } - for _, item := range tokens { - err := t.Delete(ctx, item) - if err != nil { + + for _, token := range tokens { + if err := t.Delete(ctx, token); err != nil { return err } } @@ -111,7 +92,6 @@ func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid strin return nil } -// DeleteAccessTokenByID TODO 要做錯誤處理 func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error { for _, tokenID := range ids { token, err := t.GetAccessTokenByID(ctx, tokenID) @@ -119,338 +99,203 @@ func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []strin continue } - err = t.store.Pipelined(func(tx redis.Pipeliner) error { - keys := []string{ - domain.GetAccessTokenRedisKey(token.ID), - domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), - } + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), + } - for _, key := range keys { - if err := tx.Del(ctx, key).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - } - - _, err = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) - } - - _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) - } - - return nil - }) - if err != nil { + if err := t.deleteKeys(ctx, keys...); err != nil { continue } + + _, _ = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) + _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) } return nil } -// GetAccessTokensByUID 透過 uid 得到目前未過期的 token func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { - utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid)) - if err != nil { - // 沒有就視為回空 - if errors.Is(err, redis.Nil) { - return nil, nil - } - - return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error())) - } - - now := time.Now().UTC() - var tokens []entity.Token - var deleteToken []string - for _, id := range utKeys { - item := &entity.Token{} - tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) - if err == nil { - err = json.Unmarshal([]byte(tk), item) - if err != nil { - return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) - } - tokens = append(tokens, *item) - } - - if errors.Is(err, redis.Nil) { - deleteToken = append(deleteToken, id) - } - - if int64(item.ExpiresIn) < now.Unix() { - deleteToken = append(deleteToken, id) - - continue - } - - } - if len(deleteToken) > 0 { - // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 - _ = t.DeleteAccessTokenByID(ctx, deleteToken) - } - - return tokens, nil + return t.getTokensBySet(ctx, domain.GetUIDTokenRedisKey(uid)) } -func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) { - id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString()) +func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { + return t.getTokensBySet(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString()) +} + +func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { + + tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID) if err != nil { - return entity.Token{}, err + return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err)) } - if errors.Is(err, redis.Nil) || id == "" { - return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(refreshToken).ToString()) + var keys []string + for _, token := range tokens { + keys = append(keys, domain.GetAccessTokenRedisKey(token.ID)) + keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) + } + err = t.store.Pipelined(func(tx redis.Pipeliner) error { + for _, token := range tokens { + _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) + } + return nil + }) if err != nil { - return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.GetByRefresh refresh token error: %v", err)) + return err + } + + if err := t.deleteKeys(ctx, keys...); err != nil { + return err + } + + _, err = t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString()) + return err +} + +func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { + return t.getCountBySet(domain.DeviceTokenRedisKey.With(deviceID).ToString()) +} + +func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { + return t.getCountBySet(domain.UIDTokenRedisKey.With(uid).ToString()) +} + +func (t *tokenRepository) GetAccessTokenByByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) { + id, err := t.store.Get(domain.RefreshTokenRedisKey.With(oneTimeToken).ToString()) + if err != nil { + return entity.Token{}, domain.RedisError(fmt.Sprintf("GetAccessTokenByByOneTimeToken store.Get error: %s", err.Error())) + } + + if id == "" { + return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(oneTimeToken).ToString()) } return t.GetAccessTokenByID(ctx, id) } func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error { - err := t.store.Pipelined(func(tx redis.Pipeliner) error { - keys := make([]string, 0, len(ids)+len(tokens)) + var keys []string - for _, id := range ids { - keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString()) - } - - for _, token := range tokens { - keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) - } - - for _, key := range keys { - if err := tx.Del(ctx, key).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - } - - return nil - }) - - if err != nil { - return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) + for _, id := range ids { + keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString()) } - return nil + for _, token := range tokens { + keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) + } + + return t.deleteKeys(ctx, keys...) } -func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error { +func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error { body, err := json.Marshal(ticket) if err != nil { - return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error()) + return ers.InvalidFormat("CreateOneTimeToken json.Marshal error", err.Error()) } _, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds())) if err != nil { - return ers.DBError("CreateOneTimeToken store.set error:", err.Error()) + return domain.RedisError(fmt.Sprintf("CreateOneTimeToken store.SetnxEx error: %s", err.Error())) } return nil } -func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { - utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString()) - if err != nil { - // 沒有就視為回空 - if errors.Is(err, redis.Nil) { - return nil, nil - } - - return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error())) - } - - now := time.Now().UTC() - var tokens []entity.Token - var deleteToken []string - for _, id := range utKeys { - item := &entity.Token{} - tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) - if err == nil { - err = json.Unmarshal([]byte(tk), item) - if err != nil { - return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) - } - tokens = append(tokens, *item) - } - - if errors.Is(err, redis.Nil) { - deleteToken = append(deleteToken, id) - } - - if int64(item.ExpiresIn) < now.Unix() { - deleteToken = append(deleteToken, id) - - continue - } - - } - if len(deleteToken) > 0 { - // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 - _ = t.DeleteAccessTokenByID(ctx, deleteToken) - } - - return tokens, nil -} - -func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { - tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err)) - } - - err = t.store.Pipelined(func(tx redis.Pipeliner) error { - for _, token := range tokens { - if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - - if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) - } - } - - _, err := t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString()) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) - } - - return nil - }) - - if err != nil { - return err - } - - return nil -} - -func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { - count, err := t.store.Scard(domain.DeviceTokenRedisKey.With(deviceID).ToString()) - if err != nil { - return 0, err - } - - return int(count), nil -} - -func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { - count, err := t.store.Scard(domain.UIDTokenRedisKey.With(uid).ToString()) - if err != nil { - return 0, err - } - - return int(count), nil -} - // -------------------- Private area -------------------- -func (t *tokenRepository) get(key string) (entity.Token, error) { - body, err := t.store.Get(key) - if errors.Is(err, redis.Nil) || body == "" { - return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key) +func (t *tokenRepository) get(ctx context.Context, key string) (entity.Token, error) { + body, err := t.store.GetCtx(ctx, key) + if err != nil { + return entity.Token{}, domain.RedisError(fmt.Sprintf("token %s not found in redis: %s", key, err.Error())) } - if err != nil { - return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err)) + if body == "" { + return entity.Token{}, ers.ResourceNotFound("this token not found") } var token entity.Token if err := json.Unmarshal([]byte(body), &token); err != nil { - return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %w", err)) + return entity.Token{}, ers.ArkInternal("json.Unmarshal token error", err.Error()) } return token, nil } -func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error { - err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err() - if err != nil { - return wrapError("tx.Set GetAccessTokenRedisKey error", err) - } - return nil +func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, ttl time.Duration) error { + return tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, ttl).Err() } -func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { +func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, ttl time.Duration) error { if token.RefreshToken != "" { - err := tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, rTTL).Err() - if err != nil { - return wrapError("tx.Set RefreshToken error", err) - } + return tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, ttl).Err() } return nil } -func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error { - uidKey := domain.UIDTokenRedisKey.With(uid).ToString() - err := tx.SAdd(ctx, uidKey, tokenID).Err() - if err != nil { - return err - } - err = tx.Expire(ctx, uidKey, rttl).Err() - if err != nil { +func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error { + if err := tx.SAdd(ctx, domain.UIDTokenRedisKey.With(uid).ToString(), tokenID).Err(); err != nil { return err } - deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString() - err = tx.SAdd(ctx, deviceKey, tokenID).Err() - if err != nil { - return err - } - err = tx.Expire(ctx, deviceKey, rttl).Err() - if err != nil { + if err := tx.SAdd(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString(), tokenID).Err(); err != nil { return err } return nil } -// SetUIDToken 將 token 資料放進 uid key中 -func (t *tokenRepository) SetUIDToken(token entity.Token) error { - uidTokens := make(entity.UIDToken) - b, err := t.store.Get(domain.GetUIDTokenRedisKey(token.UID)) - if err != nil && !errors.Is(err, redis.Nil) { - return wrapError("t.store.Get GetUIDTokenRedisKey error", err) - } - - if b != "" { - err = json.Unmarshal([]byte(b), &uidTokens) - if err != nil { - return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err) +func (t *tokenRepository) deleteKeys(ctx context.Context, keys ...string) error { + return t.store.Pipelined(func(tx redis.Pipeliner) error { + for _, key := range keys { + if err := tx.Del(ctx, key).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } } + return nil + }) +} + +func (t *tokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) { + ids, err := t.store.Smembers(setKey) + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, nil + } + return nil, domain.RedisError(fmt.Sprintf("getTokensBySet store.Get %s error: %v", setKey, err.Error())) } + var tokens []entity.Token + var deleteTokens []string now := time.Now().Unix() - for k, t := range uidTokens { - if t < now { - delete(uidTokens, k) + for _, id := range ids { + token, err := t.get(ctx, domain.GetAccessTokenRedisKey(id)) + if err != nil { + deleteTokens = append(deleteTokens, id) + continue } + + if int64(token.ExpiresIn) < now { + deleteTokens = append(deleteTokens, id) + continue + } + + tokens = append(tokens, token) } - uidTokens[token.ID] = token.RefreshTokenExpiresUnix() - s, err := json.Marshal(uidTokens) - if err != nil { - return wrapError("json.Marshal UIDToken error", err) + if len(deleteTokens) > 0 { + _ = t.DeleteAccessTokenByID(ctx, deleteTokens) } - err = t.store.Setex(domain.GetUIDTokenRedisKey(token.UID), string(s), 86400*30) - if err != nil { - return wrapError("t.store.Setex GetUIDTokenRedisKey error", err) - } - - return nil + return tokens, nil } -func wrapError(message string, err error) error { - return fmt.Errorf("%s: %w", message, err) +func (t *tokenRepository) getCountBySet(setKey string) (int, error) { + count, err := t.store.Scard(setKey) + if err != nil { + return 0, err + } + return int(count), nil }