feat/create_new_token #2

Merged
daniel.w merged 10 commits from feat/create_new_token into main 2024-08-12 14:20:15 +00:00
11 changed files with 381 additions and 317 deletions
Showing only changes of commit db0d903518 - Show all commits

View File

@ -0,0 +1,45 @@
// Code generated by goctl. DO NOT EDIT.
// Source: permission.proto
package permissionservice
import (
"context"
"ark-permission/gen_result/pb/permission"
"github.com/zeromicro/go-zero/zrpc"
"google.golang.org/grpc"
)
type (
AuthorizationReq = permission.AuthorizationReq
CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq
CancelTokenReq = permission.CancelTokenReq
CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq
CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp
DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq
DoTokenByUIDReq = permission.DoTokenByUIDReq
OKResp = permission.OKResp
QueryTokenByUIDReq = permission.QueryTokenByUIDReq
RefreshTokenReq = permission.RefreshTokenReq
RefreshTokenResp = permission.RefreshTokenResp
Token = permission.Token
TokenResp = permission.TokenResp
Tokens = permission.Tokens
ValidationTokenReq = permission.ValidationTokenReq
ValidationTokenResp = permission.ValidationTokenResp
PermissionService interface {
}
defaultPermissionService struct {
cli zrpc.Client
}
)
func NewPermissionService(cli zrpc.Client) PermissionService {
return &defaultPermissionService{
cli: cli,
}
}

View File

@ -0,0 +1,51 @@
// Code generated by goctl. DO NOT EDIT.
// Source: permission.proto
package roleservice
import (
"context"
"ark-permission/gen_result/pb/permission"
"github.com/zeromicro/go-zero/zrpc"
"google.golang.org/grpc"
)
type (
AuthorizationReq = permission.AuthorizationReq
CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq
CancelTokenReq = permission.CancelTokenReq
CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq
CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp
DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq
DoTokenByUIDReq = permission.DoTokenByUIDReq
OKResp = permission.OKResp
QueryTokenByUIDReq = permission.QueryTokenByUIDReq
RefreshTokenReq = permission.RefreshTokenReq
RefreshTokenResp = permission.RefreshTokenResp
Token = permission.Token
TokenResp = permission.TokenResp
Tokens = permission.Tokens
ValidationTokenReq = permission.ValidationTokenReq
ValidationTokenResp = permission.ValidationTokenResp
RoleService interface {
Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error)
}
defaultRoleService struct {
cli zrpc.Client
}
)
func NewRoleService(cli zrpc.Client) RoleService {
return &defaultRoleService{
cli: cli,
}
}
func (m *defaultRoleService) Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewRoleServiceClient(m.cli.Conn())
return client.Ping(ctx, in, opts...)
}

View File

@ -0,0 +1,125 @@
// Code generated by goctl. DO NOT EDIT.
// Source: permission.proto
package tokenservice
import (
"context"
"ark-permission/gen_result/pb/permission"
"github.com/zeromicro/go-zero/zrpc"
"google.golang.org/grpc"
)
type (
AuthorizationReq = permission.AuthorizationReq
CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq
CancelTokenReq = permission.CancelTokenReq
CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq
CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp
DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq
DoTokenByUIDReq = permission.DoTokenByUIDReq
OKResp = permission.OKResp
QueryTokenByUIDReq = permission.QueryTokenByUIDReq
RefreshTokenReq = permission.RefreshTokenReq
RefreshTokenResp = permission.RefreshTokenResp
Token = permission.Token
TokenResp = permission.TokenResp
Tokens = permission.Tokens
ValidationTokenReq = permission.ValidationTokenReq
ValidationTokenResp = permission.ValidationTokenResp
TokenService interface {
// NewToken 建立一個新的 Token例如AccessToken
NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error)
// RefreshToken 更新目前的token 以及裡面包含的一次性 Token
RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error)
// CancelToken 取消 Token也包含他裡面的 One Time Toke
CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error)
// ValidationToken 驗證這個 Token 有沒有效
ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error)
// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出
CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error)
// CancelTokenByDeviceId 取消 Token 從 Device 視角出發可以選登出這個Device 下所有 token 登出這個Device 下指定token
CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error)
// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens
GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error)
// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens
GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error)
// NewOneTimeToken 建立一次性使用例如RefreshToken
NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error)
// CancelOneTimeToken 取消一次性使用
CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error)
}
defaultTokenService struct {
cli zrpc.Client
}
)
func NewTokenService(cli zrpc.Client) TokenService {
return &defaultTokenService{
cli: cli,
}
}
// NewToken 建立一個新的 Token例如AccessToken
func (m *defaultTokenService) NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.NewToken(ctx, in, opts...)
}
// RefreshToken 更新目前的token 以及裡面包含的一次性 Token
func (m *defaultTokenService) RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.RefreshToken(ctx, in, opts...)
}
// CancelToken 取消 Token也包含他裡面的 One Time Toke
func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.CancelToken(ctx, in, opts...)
}
// ValidationToken 驗證這個 Token 有沒有效
func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.ValidationToken(ctx, in, opts...)
}
// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出
func (m *defaultTokenService) CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.CancelTokens(ctx, in, opts...)
}
// CancelTokenByDeviceId 取消 Token 從 Device 視角出發可以選登出這個Device 下所有 token 登出這個Device 下指定token
func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.CancelTokenByDeviceId(ctx, in, opts...)
}
// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens
func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.GetUserTokensByDeviceId(ctx, in, opts...)
}
// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens
func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.GetUserTokensByUid(ctx, in, opts...)
}
// NewOneTimeToken 建立一次性使用例如RefreshToken
func (m *defaultTokenService) NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.NewOneTimeToken(ctx, in, opts...)
}
// CancelOneTimeToken 取消一次性使用
func (m *defaultTokenService) CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) {
client := permission.NewTokenServiceClient(m.cli.Conn())
return client.CancelOneTimeToken(ctx, in, opts...)
}

View File

@ -1 +1 @@
DROP DATABASE IF EXISTS `ark_member`; DROP DATABASE IF EXISTS `ark_permission`;

View File

@ -1 +1 @@
CREATE DATABASE IF NOT EXISTS `ark_member`; CREATE DATABASE IF NOT EXISTS `ark_permission`;

View File

@ -6,12 +6,14 @@ import (
"time" "time"
) )
// TokenRepository token 的 redis 操作
type TokenRepository interface { type TokenRepository interface {
// Create 建立Token
Create(ctx context.Context, token entity.Token) error Create(ctx context.Context, token entity.Token) error
DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error // CreateOneTimeToken 建立臨時 Token
CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error
GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error)
GetAccessTokenByByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error)
GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error)
GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error)
GetAccessTokenCountByUID(uid string) (int, error) GetAccessTokenCountByUID(uid string) (int, error)
@ -19,6 +21,7 @@ type TokenRepository interface {
GetAccessTokenCountByDeviceID(deviceID string) (int, error) GetAccessTokenCountByDeviceID(deviceID string) (int, error)
Delete(ctx context.Context, token entity.Token) error Delete(ctx context.Context, token entity.Token) error
DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error
DeleteAccessTokenByID(ctx context.Context, ids []string) error DeleteAccessTokenByID(ctx context.Context, ids []string) error
DeleteAccessTokensByUID(ctx context.Context, uid string) error DeleteAccessTokensByUID(ctx context.Context, uid string) error
DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error

View File

@ -1,12 +1,10 @@
package tokenservicelogic package tokenservicelogic
import ( import (
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"ark-permission/gen_result/pb/permission" "ark-permission/gen_result/pb/permission"
"ark-permission/internal/svc" "ark-permission/internal/svc"
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )

View File

@ -28,7 +28,7 @@ func NewNewOneTimeTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *N
} }
} }
// NewOneTimeToken 建立一次性使用例如RefreshToken // NewOneTimeToken 建立一次性使用例如RefreshToken TODO 目前並無後續操作
func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) { func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) {
// 驗證所需 // 驗證所需
if err := l.svcCtx.Validate.ValidateAll(&refreshTokenReq{ if err := l.svcCtx.Validate.ValidateAll(&refreshTokenReq{

View File

@ -42,7 +42,7 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi
} }
// step 1 拿看看有沒有這個 refresh token // step 1 拿看看有沒有這個 refresh token
token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.Token) token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByByOneTimeToken(l.ctx, in.Token)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "TokenRedisRepo.GetByRefresh"), logx.Field("func", "TokenRedisRepo.GetByRefresh"),

View File

@ -1,11 +1,10 @@
package tokenservicelogic package tokenservicelogic
import ( import (
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"ark-permission/gen_result/pb/permission" "ark-permission/gen_result/pb/permission"
"ark-permission/internal/svc" "ark-permission/internal/svc"
ers "code.30cm.net/wanderland/library-go/errors"
"context"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
) )
@ -36,15 +35,13 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq
}); err != nil { }); err != nil {
return nil, ers.InvalidFormat(err.Error()) return nil, ers.InvalidFormat(err.Error())
} }
claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, true) claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, true)
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(
logx.Field("func", "parseClaims"), logx.Field("func", "parseClaims"),
).Error(err.Error()) ).Info(err.Error())
return nil, err return nil, err
} }
token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID())
if err != nil { if err != nil {
logx.WithCallerSkip(1).WithFields( logx.WithCallerSkip(1).WithFields(

View File

@ -33,27 +33,19 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error
if err != nil { if err != nil {
return ers.ArkInternal("json.Marshal token error", err.Error()) return ers.ArkInternal("json.Marshal token error", err.Error())
} }
if err := t.store.Pipelined(func(tx redis.Pipeliner) error {
refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second
err = t.store.Pipelined(func(tx redis.Pipeliner) error { if err := t.setToken(ctx, tx, token, body, refreshTTL); err != nil {
// rTTL := token.RedisExpiredSec()
refreshTTL := token.RedisRefreshExpiredSec()
if err := t.setToken(ctx, tx, token, body, time.Duration(refreshTTL)*time.Second); err != nil {
return err return err
} }
if err := t.setRefreshToken(ctx, tx, token, time.Duration(refreshTTL)*time.Second); err != nil { if err := t.setRefreshToken(ctx, tx, token, refreshTTL); err != nil {
return err return err
} }
err := t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, time.Duration(refreshTTL)*time.Second) return t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL)
if err != nil { }); err != nil {
return err
}
return nil
})
if err != nil {
return domain.RedisPipLineError(err.Error()) return domain.RedisPipLineError(err.Error())
} }
@ -61,39 +53,28 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error
} }
func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error {
err := t.store.Pipelined(func(tx redis.Pipeliner) error { keys := []string{
keys := []string{ domain.GetAccessTokenRedisKey(token.ID),
domain.GetAccessTokenRedisKey(token.ID), domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
domain.UIDTokenRedisKey.With(token.UID).ToString(),
}
for _, key := range keys {
if err := tx.Del(ctx, key).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
}
if token.DeviceID != "" {
key := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString()
_, err := t.store.Del(key)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err))
}
}
return nil
})
if err != nil {
return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
} }
if err := t.deleteKeys(ctx, keys...); err != nil {
return domain.RedisPipLineError(err.Error())
}
_, _ = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID)
_, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
return nil return nil
} }
func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { func (t *tokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
return t.get(domain.GetAccessTokenRedisKey(id)) token, err := t.get(ctx, domain.GetAccessTokenRedisKey(id))
if err != nil {
return entity.Token{}, err
}
return token, nil
} }
func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
@ -101,9 +82,9 @@ func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid strin
if err != nil { if err != nil {
return err return err
} }
for _, item := range tokens {
err := t.Delete(ctx, item) for _, token := range tokens {
if err != nil { if err := t.Delete(ctx, token); err != nil {
return err return err
} }
} }
@ -111,7 +92,6 @@ func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid strin
return nil return nil
} }
// DeleteAccessTokenByID TODO 要做錯誤處理
func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error { func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
for _, tokenID := range ids { for _, tokenID := range ids {
token, err := t.GetAccessTokenByID(ctx, tokenID) token, err := t.GetAccessTokenByID(ctx, tokenID)
@ -119,338 +99,203 @@ func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []strin
continue continue
} }
err = t.store.Pipelined(func(tx redis.Pipeliner) error { keys := []string{
keys := []string{ domain.GetAccessTokenRedisKey(token.ID),
domain.GetAccessTokenRedisKey(token.ID), domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(),
domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), }
}
for _, key := range keys { if err := t.deleteKeys(ctx, keys...); err != nil {
if err := tx.Del(ctx, key).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
}
_, err = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err))
}
_, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
}
return nil
})
if err != nil {
continue continue
} }
_, _ = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID)
_, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
} }
return nil return nil
} }
// GetAccessTokensByUID 透過 uid 得到目前未過期的 token
func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid)) return t.getTokensBySet(ctx, domain.GetUIDTokenRedisKey(uid))
if err != nil {
// 沒有就視為回空
if errors.Is(err, redis.Nil) {
return nil, nil
}
return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error()))
}
now := time.Now().UTC()
var tokens []entity.Token
var deleteToken []string
for _, id := range utKeys {
item := &entity.Token{}
tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
if err == nil {
err = json.Unmarshal([]byte(tk), item)
if err != nil {
return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err))
}
tokens = append(tokens, *item)
}
if errors.Is(err, redis.Nil) {
deleteToken = append(deleteToken, id)
}
if int64(item.ExpiresIn) < now.Unix() {
deleteToken = append(deleteToken, id)
continue
}
}
if len(deleteToken) > 0 {
// 如果失敗也沒關係其他get method撈取時會在判斷是否過期或存在
_ = t.DeleteAccessTokenByID(ctx, deleteToken)
}
return tokens, nil
} }
func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) { func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString()) return t.getTokensBySet(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString())
}
func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID)
if err != nil { if err != nil {
return entity.Token{}, err return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err))
} }
if errors.Is(err, redis.Nil) || id == "" { var keys []string
return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(refreshToken).ToString()) for _, token := range tokens {
keys = append(keys, domain.GetAccessTokenRedisKey(token.ID))
keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString())
} }
err = t.store.Pipelined(func(tx redis.Pipeliner) error {
for _, token := range tokens {
_, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
}
return nil
})
if err != nil { if err != nil {
return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.GetByRefresh refresh token error: %v", err)) return err
}
if err := t.deleteKeys(ctx, keys...); err != nil {
return err
}
_, err = t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString())
return err
}
func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) {
return t.getCountBySet(domain.DeviceTokenRedisKey.With(deviceID).ToString())
}
func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) {
return t.getCountBySet(domain.UIDTokenRedisKey.With(uid).ToString())
}
func (t *tokenRepository) GetAccessTokenByByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
id, err := t.store.Get(domain.RefreshTokenRedisKey.With(oneTimeToken).ToString())
if err != nil {
return entity.Token{}, domain.RedisError(fmt.Sprintf("GetAccessTokenByByOneTimeToken store.Get error: %s", err.Error()))
}
if id == "" {
return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(oneTimeToken).ToString())
} }
return t.GetAccessTokenByID(ctx, id) return t.GetAccessTokenByID(ctx, id)
} }
func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error { func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
err := t.store.Pipelined(func(tx redis.Pipeliner) error { var keys []string
keys := make([]string, 0, len(ids)+len(tokens))
for _, id := range ids { for _, id := range ids {
keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString()) keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString())
}
for _, token := range tokens {
keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString())
}
for _, key := range keys {
if err := tx.Del(ctx, key).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
}
return nil
})
if err != nil {
return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err))
} }
return nil for _, token := range tokens {
keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString())
}
return t.deleteKeys(ctx, keys...)
} }
func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error { func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error {
body, err := json.Marshal(ticket) body, err := json.Marshal(ticket)
if err != nil { if err != nil {
return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error()) return ers.InvalidFormat("CreateOneTimeToken json.Marshal error", err.Error())
} }
_, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds())) _, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds()))
if err != nil { if err != nil {
return ers.DBError("CreateOneTimeToken store.set error:", err.Error()) return domain.RedisError(fmt.Sprintf("CreateOneTimeToken store.SetnxEx error: %s", err.Error()))
} }
return nil return nil
} }
func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString())
if err != nil {
// 沒有就視為回空
if errors.Is(err, redis.Nil) {
return nil, nil
}
return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error()))
}
now := time.Now().UTC()
var tokens []entity.Token
var deleteToken []string
for _, id := range utKeys {
item := &entity.Token{}
tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id))
if err == nil {
err = json.Unmarshal([]byte(tk), item)
if err != nil {
return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err))
}
tokens = append(tokens, *item)
}
if errors.Is(err, redis.Nil) {
deleteToken = append(deleteToken, id)
}
if int64(item.ExpiresIn) < now.Unix() {
deleteToken = append(deleteToken, id)
continue
}
}
if len(deleteToken) > 0 {
// 如果失敗也沒關係其他get method撈取時會在判斷是否過期或存在
_ = t.DeleteAccessTokenByID(ctx, deleteToken)
}
return tokens, nil
}
func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err))
}
err = t.store.Pipelined(func(tx redis.Pipeliner) error {
for _, token := range tokens {
if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
}
_, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID)
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err))
}
}
_, err := t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString())
if err != nil {
return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err))
}
return nil
})
if err != nil {
return err
}
return nil
}
func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) {
count, err := t.store.Scard(domain.DeviceTokenRedisKey.With(deviceID).ToString())
if err != nil {
return 0, err
}
return int(count), nil
}
func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) {
count, err := t.store.Scard(domain.UIDTokenRedisKey.With(uid).ToString())
if err != nil {
return 0, err
}
return int(count), nil
}
// -------------------- Private area -------------------- // -------------------- Private area --------------------
func (t *tokenRepository) get(key string) (entity.Token, error) { func (t *tokenRepository) get(ctx context.Context, key string) (entity.Token, error) {
body, err := t.store.Get(key) body, err := t.store.GetCtx(ctx, key)
if errors.Is(err, redis.Nil) || body == "" { if err != nil {
return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key) return entity.Token{}, domain.RedisError(fmt.Sprintf("token %s not found in redis: %s", key, err.Error()))
} }
if err != nil { if body == "" {
return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err)) return entity.Token{}, ers.ResourceNotFound("this token not found")
} }
var token entity.Token var token entity.Token
if err := json.Unmarshal([]byte(body), &token); err != nil { if err := json.Unmarshal([]byte(body), &token); err != nil {
return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %w", err)) return entity.Token{}, ers.ArkInternal("json.Unmarshal token error", err.Error())
} }
return token, nil return token, nil
} }
func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error { func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, ttl time.Duration) error {
err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err() return tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, ttl).Err()
if err != nil {
return wrapError("tx.Set GetAccessTokenRedisKey error", err)
}
return nil
} }
func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, ttl time.Duration) error {
if token.RefreshToken != "" { if token.RefreshToken != "" {
err := tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, rTTL).Err() return tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, ttl).Err()
if err != nil {
return wrapError("tx.Set RefreshToken error", err)
}
} }
return nil return nil
} }
func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error { func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
uidKey := domain.UIDTokenRedisKey.With(uid).ToString() if err := tx.SAdd(ctx, domain.UIDTokenRedisKey.With(uid).ToString(), tokenID).Err(); err != nil {
err := tx.SAdd(ctx, uidKey, tokenID).Err()
if err != nil {
return err
}
err = tx.Expire(ctx, uidKey, rttl).Err()
if err != nil {
return err return err
} }
deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString() if err := tx.SAdd(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString(), tokenID).Err(); err != nil {
err = tx.SAdd(ctx, deviceKey, tokenID).Err()
if err != nil {
return err
}
err = tx.Expire(ctx, deviceKey, rttl).Err()
if err != nil {
return err return err
} }
return nil return nil
} }
// SetUIDToken 將 token 資料放進 uid key中 func (t *tokenRepository) deleteKeys(ctx context.Context, keys ...string) error {
func (t *tokenRepository) SetUIDToken(token entity.Token) error { return t.store.Pipelined(func(tx redis.Pipeliner) error {
uidTokens := make(entity.UIDToken) for _, key := range keys {
b, err := t.store.Get(domain.GetUIDTokenRedisKey(token.UID)) if err := tx.Del(ctx, key).Err(); err != nil {
if err != nil && !errors.Is(err, redis.Nil) { return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err))
return wrapError("t.store.Get GetUIDTokenRedisKey error", err) }
}
if b != "" {
err = json.Unmarshal([]byte(b), &uidTokens)
if err != nil {
return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err)
} }
return nil
})
}
func (t *tokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
ids, err := t.store.Smembers(setKey)
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, nil
}
return nil, domain.RedisError(fmt.Sprintf("getTokensBySet store.Get %s error: %v", setKey, err.Error()))
} }
var tokens []entity.Token
var deleteTokens []string
now := time.Now().Unix() now := time.Now().Unix()
for k, t := range uidTokens { for _, id := range ids {
if t < now { token, err := t.get(ctx, domain.GetAccessTokenRedisKey(id))
delete(uidTokens, k) if err != nil {
deleteTokens = append(deleteTokens, id)
continue
} }
if int64(token.ExpiresIn) < now {
deleteTokens = append(deleteTokens, id)
continue
}
tokens = append(tokens, token)
} }
uidTokens[token.ID] = token.RefreshTokenExpiresUnix() if len(deleteTokens) > 0 {
s, err := json.Marshal(uidTokens) _ = t.DeleteAccessTokenByID(ctx, deleteTokens)
if err != nil {
return wrapError("json.Marshal UIDToken error", err)
} }
err = t.store.Setex(domain.GetUIDTokenRedisKey(token.UID), string(s), 86400*30) return tokens, nil
if err != nil {
return wrapError("t.store.Setex GetUIDTokenRedisKey error", err)
}
return nil
} }
func wrapError(message string, err error) error { func (t *tokenRepository) getCountBySet(setKey string) (int, error) {
return fmt.Errorf("%s: %w", message, err) count, err := t.store.Scard(setKey)
if err != nil {
return 0, err
}
return int(count), nil
} }