From d82e9e1a54585a425ae52f83d760ee61e61ab5cd Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Tue, 6 Aug 2024 13:59:24 +0800 Subject: [PATCH 01/10] feat: create new token --- etc/permission_example.yaml | 15 + generate/protobuf/permission.proto | 8 +- go.mod | 13 +- internal/config/config.go | 12 +- internal/domain/const.go | 18 + internal/domain/redis.go | 43 + internal/domain/repository/token.go | 10 + internal/entity/claims.go | 8 + internal/entity/token.go | 38 + internal/lib/error/code/define.go | 98 ++ internal/lib/error/code/messsage.go | 13 + internal/lib/error/easy_func.go | 442 +++++++ internal/lib/error/easy_func_test.go | 1031 +++++++++++++++++ internal/lib/error/errors.go | 197 ++++ internal/lib/error/errors_test.go | 297 +++++ internal/lib/middleware/with_context.go | 28 + internal/lib/required/validate.go | 51 + internal/lib/required/validate_option.go | 29 + ....go => cancel_token_by_device_id_logic.go} | 8 +- ..._logic.go => cancel_token_by_uid_logic.go} | 8 +- internal/logic/claims.go | 55 + ... => get_user_tokens_by_device_id_logic.go} | 8 +- ...gic.go => get_user_tokens_by_uid_logic.go} | 8 +- internal/logic/new_token_logic.go | 114 +- internal/repository/token.go | 136 +++ internal/server/token_service_server.go | 24 +- internal/svc/service_context.go | 24 +- permission.go | 3 + tokenservice/token_service.go | 24 +- 29 files changed, 2710 insertions(+), 53 deletions(-) create mode 100644 etc/permission_example.yaml create mode 100644 internal/domain/const.go create mode 100644 internal/domain/redis.go create mode 100644 internal/domain/repository/token.go create mode 100644 internal/entity/claims.go create mode 100644 internal/entity/token.go create mode 100644 internal/lib/error/code/define.go create mode 100644 internal/lib/error/code/messsage.go create mode 100644 internal/lib/error/easy_func.go create mode 100644 internal/lib/error/easy_func_test.go create mode 100644 internal/lib/error/errors.go create mode 100644 internal/lib/error/errors_test.go create mode 100644 internal/lib/middleware/with_context.go create mode 100644 internal/lib/required/validate.go create mode 100644 internal/lib/required/validate_option.go rename internal/logic/{cancel_token_by_device_i_d_logic.go => cancel_token_by_device_id_logic.go} (65%) rename internal/logic/{cancel_token_by_u_i_d_logic.go => cancel_token_by_uid_logic.go} (70%) create mode 100644 internal/logic/claims.go rename internal/logic/{get_user_tokens_by_device_i_d_logic.go => get_user_tokens_by_device_id_logic.go} (66%) rename internal/logic/{get_user_tokens_by_u_i_d_logic.go => get_user_tokens_by_uid_logic.go} (67%) create mode 100644 internal/repository/token.go diff --git a/etc/permission_example.yaml b/etc/permission_example.yaml new file mode 100644 index 0000000..b1f1d42 --- /dev/null +++ b/etc/permission_example.yaml @@ -0,0 +1,15 @@ +Name: permission.rpc +ListenOn: 0.0.0.0:8080 +Etcd: + Hosts: + - 127.0.0.1:2379 + Key: permission.rpc + +RedisCluster: + Host: 127.0.0.1:7001 + Type: cluster + +Token: + Expired: 300 + RefreshExpires: 86500 + Secret: gg88g88 \ No newline at end of file diff --git a/generate/protobuf/permission.proto b/generate/protobuf/permission.proto index 83db18c..717f753 100644 --- a/generate/protobuf/permission.proto +++ b/generate/protobuf/permission.proto @@ -126,15 +126,15 @@ service TokenService { // CancelToken 取消 Token,也包含他裡面的 One Time Toke rpc CancelToken(CancelTokenReq) returns(OKResp); // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke - rpc CancelTokenByUID(DoTokenByUIDReq) returns(OKResp); + rpc CancelTokenByUid(DoTokenByUIDReq) returns(OKResp); // CancelTokenByDeviceID 取消 Token - rpc CancelTokenByDeviceID(DoTokenByDeviceIDReq) returns(OKResp); + rpc CancelTokenByDeviceId(DoTokenByDeviceIDReq) returns(OKResp); // ValidationToken 驗證這個 Token 有沒有效 rpc ValidationToken(ValidationTokenReq) returns(ValidationTokenResp); // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens - rpc GetUserTokensByDeviceID(DoTokenByDeviceIDReq) returns(Tokens); + rpc GetUserTokensByDeviceId(DoTokenByDeviceIDReq) returns(Tokens); // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens - rpc GetUserTokensByUID(DoTokenByUIDReq) returns(Tokens); + rpc GetUserTokensByUid(DoTokenByUIDReq) returns(Tokens); // NewOneTimeToken 建立一次性使用,例如:RefreshToken rpc NewOneTimeToken(CreateOneTimeTokenReq) returns(CreateOneTimeTokenResp); // CancelOneTimeToken 取消一次性使用 diff --git a/go.mod b/go.mod index 357d382..12ccb99 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,11 @@ module ark-permission go 1.22.3 require ( + github.com/go-playground/validator/v10 v10.22.0 + github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/golang/mock v1.6.0 + github.com/google/uuid v1.6.0 + github.com/stretchr/testify v1.9.0 github.com/zeromicro/go-zero v1.7.0 google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 @@ -18,21 +23,23 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fatih/color v1.17.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.22.4 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -41,6 +48,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/openzipkin/zipkin-go v0.4.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.19.1 // indirect github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.48.0 // indirect @@ -65,6 +73,7 @@ require ( go.uber.org/automaxprocs v1.5.3 // indirect go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.24.0 // indirect + golang.org/x/crypto v0.25.0 // indirect golang.org/x/net v0.27.0 // indirect golang.org/x/oauth2 v0.20.0 // indirect golang.org/x/sys v0.22.0 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index c1f85b9..6cd58e1 100755 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,17 @@ package config -import "github.com/zeromicro/go-zero/zrpc" +import ( + "github.com/zeromicro/go-zero/core/stores/redis" + "github.com/zeromicro/go-zero/zrpc" + "time" +) type Config struct { zrpc.RpcServerConf + RedisCluster redis.RedisConf + Token struct { + RefreshExpires time.Duration + Expired time.Duration + Secret string + } } diff --git a/internal/domain/const.go b/internal/domain/const.go new file mode 100644 index 0000000..82ac1b8 --- /dev/null +++ b/internal/domain/const.go @@ -0,0 +1,18 @@ +package domain + +type GrantType string + +const ( + PasswordCredentials GrantType = "password" + ClientCredentials GrantType = "client_credentials" + Refreshing GrantType = "refresh_token" +) + +const ( + // DefaultRole 預設role + DefaultRole = "user" +) + +const ( + TokenTypeBearer = "Bearer" +) diff --git a/internal/domain/redis.go b/internal/domain/redis.go new file mode 100644 index 0000000..2fb3b69 --- /dev/null +++ b/internal/domain/redis.go @@ -0,0 +1,43 @@ +package domain + +import "strings" + +const ( + TicketKeyPrefix = "tic/" +) + +const ( + ClientDataKey = "permission:clients" +) + +type RedisKey string + +const ( + AccessTokenRedisKey RedisKey = "access_token" + RefreshTokenRedisKey RedisKey = "refresh_token" + DeviceTokenRedisKey RedisKey = "device_token" + UIDTokenRedisKey RedisKey = "uid_token" + TicketRedisKey RedisKey = "ticket" +) + +func (key RedisKey) ToString() string { + return "permission:" + string(key) +} + +func (key RedisKey) With(s ...string) RedisKey { + parts := append([]string{string(key)}, s...) + + return RedisKey(strings.Join(parts, ":")) +} + +func GetAccessTokenRedisKey(id string) string { + return AccessTokenRedisKey.With(id).ToString() +} + +func GetUIDTokenRedisKey(uid string) string { + return UIDTokenRedisKey.With(uid).ToString() +} + +func GetTicketRedisKey(ticket string) string { + return TicketRedisKey.With(ticket).ToString() +} diff --git a/internal/domain/repository/token.go b/internal/domain/repository/token.go new file mode 100644 index 0000000..b5f512a --- /dev/null +++ b/internal/domain/repository/token.go @@ -0,0 +1,10 @@ +package repository + +import ( + "ark-permission/internal/entity" + "context" +) + +type TokenRepository interface { + Create(ctx context.Context, token entity.Token) error +} diff --git a/internal/entity/claims.go b/internal/entity/claims.go new file mode 100644 index 0000000..9271795 --- /dev/null +++ b/internal/entity/claims.go @@ -0,0 +1,8 @@ +package entity + +import "github.com/golang-jwt/jwt/v4" + +type Claims struct { + jwt.RegisteredClaims + Data interface{} `json:"data"` +} diff --git a/internal/entity/token.go b/internal/entity/token.go new file mode 100644 index 0000000..eeadd26 --- /dev/null +++ b/internal/entity/token.go @@ -0,0 +1,38 @@ +package entity + +import "time" + +type Token struct { + ID string `json:"id"` + UID string `json:"uid"` + DeviceID string `json:"device_id"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + AccessCreateAt time.Time `json:"access_create_at"` + RefreshToken string `json:"refresh_token"` + RefreshExpiresIn int `json:"refresh_expires_in"` + RefreshCreateAt time.Time `json:"refresh_create_at"` +} + +func (t *Token) AccessTokenExpires() time.Duration { + return time.Duration(t.ExpiresIn) * time.Second +} + +func (t *Token) RefreshTokenExpires() time.Duration { + return time.Duration(t.RefreshExpiresIn) * time.Second +} + +func (t *Token) RefreshTokenExpiresUnix() int64 { + return time.Now().Add(t.RefreshTokenExpires()).Unix() +} + +func (t *Token) IsExpires() bool { + return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now()) +} + +type UIDToken map[string]int64 + +type Ticket struct { + Data interface{} `json:"data"` + Token Token `json:"token"` +} diff --git a/internal/lib/error/code/define.go b/internal/lib/error/code/define.go new file mode 100644 index 0000000..49715a2 --- /dev/null +++ b/internal/lib/error/code/define.go @@ -0,0 +1,98 @@ +package code + +const ( + OK uint32 = 0 +) + +// Scope +const ( + Unset uint32 = iota + CloudEPPortalGW + CloudEPMember +) + +// Category for general operations: 100 - 4900 +const ( + _ = iota + CatInput uint32 = iota * 100 + CatDB + CatResource + CatGRPC + CatAuth + CatSystem + CatPubSub +) + +// CatArk Category for specific app/service: 5000 - 9900 +const ( + CatArk uint32 = (iota + 50) * 100 +) + +// Detail - Input 1xx +const ( + _ = iota + CatInput + InvalidFormat + NotValidImplementation + InvalidRange +) + +// Detail - Database 2xx +const ( + _ = iota + CatDB + DBError // general error + DBDataConvert + DBDuplicate +) + +// Detail - Resource 3xx +const ( + _ = iota + CatResource + ResourceNotFound + InvalidResourceFormat + ResourceAlreadyExist + ResourceInsufficient + InsufficientPermission + InvalidMeasurementID + ResourceExpired + ResourceMigrated + InvalidResourceState + InsufficientQuota + ResourceHasMultiOwner +) + +/* Detail - GRPC */ +// The GRPC detail code uses Go GRPC's built-in codes. +// Refer to "google.golang.org/grpc/codes" for more detail. + +// Detail - Auth 5xx +const ( + _ = iota + CatAuth + Unauthorized + AuthExpired + InvalidPosixTime + SigAndPayloadNotMatched + Forbidden +) + +// Detail - System 6xx +const ( + _ = iota + CatSystem + SystemInternalError + SystemMaintainError + SystemTimeoutError +) + +// Detail - PubSub 7xx +const ( + _ = iota + CatPubSub + Publish + Consume + MsgSizeTooLarge +) + +// Detail - Ark 5xxx +const ( + _ = iota + CatArk + ArkInternal + ArkHttp400 +) diff --git a/internal/lib/error/code/messsage.go b/internal/lib/error/code/messsage.go new file mode 100644 index 0000000..18a4d4f --- /dev/null +++ b/internal/lib/error/code/messsage.go @@ -0,0 +1,13 @@ +package code + +// CatToStr collects general error messages for each Category +// It is used to send back to API caller +var CatToStr = map[uint32]string{ + CatInput: "Invalid Input Data", + CatDB: "Database Error", + CatResource: "Resource Error", + CatGRPC: "Internal Service Communication Error", + CatAuth: "Authentication Error", + CatArk: "Internal Service Communication Error", + CatSystem: "System Error", +} diff --git a/internal/lib/error/easy_func.go b/internal/lib/error/easy_func.go new file mode 100644 index 0000000..2e13bd8 --- /dev/null +++ b/internal/lib/error/easy_func.go @@ -0,0 +1,442 @@ +package error + +import ( + "ark-permission/internal/lib/error/code" + "errors" + "fmt" + "strings" + + "github.com/zeromicro/go-zero/core/logx" + _ "github.com/zeromicro/go-zero/core/logx" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func newErr(scope, detail uint32, msg string) *Err { + cat := detail / 100 * 100 + return &Err{ + category: cat, + code: detail, + scope: scope, + msg: msg, + } +} + +func newBuiltinGRPCErr(scope, detail uint32, msg string) *Err { + return &Err{ + category: code.CatGRPC, + code: detail, + scope: scope, + msg: msg, + } +} + +// FromError tries to let error as Err +// it supports to unwrap error that has Err +// return nil if failed to transfer +func FromError(err error) *Err { + if err == nil { + return nil + } + + var e *Err + if errors.As(err, &e) { + return e + } + + return nil +} + +// FromCode parses code as following +// Decimal: 120314 +// 12 represents Scope +// 03 represents Category +// 14 represents Detail error code +func FromCode(code uint32) *Err { + scope := code / 10000 + detail := code % 10000 + return &Err{ + category: detail / 100 * 100, + code: detail, + scope: scope, + msg: "", + } +} + +// FromGRPCError transfer error to Err +// useful for gRPC client +func FromGRPCError(err error) *Err { + s, _ := status.FromError(err) + e := FromCode(uint32(s.Code())) + e.msg = s.Message() + + // For GRPC built-in code + if e.Scope() == code.Unset && e.Category() == 0 && e.Code() != code.OK { + e = newBuiltinGRPCErr(Scope, e.Code(), s.Message()) + } + + return e +} + +// Deprecated: check GRPCStatus() in Errs struct +// ToGRPCError returns the status.Status +// Useful to return error in gRPC server +func ToGRPCError(e *Err) error { + return status.New(codes.Code(e.FullCode()), e.Error()).Err() +} + +/*** System ***/ + +// SystemTimeoutError returns Err +func SystemTimeoutError(s ...string) *Err { + return newErr(Scope, code.SystemTimeoutError, fmt.Sprintf("system timeout: %s", strings.Join(s, " "))) +} + +// SystemTimeoutErrorL logs error message and returns Err +func SystemTimeoutErrorL(l logx.Logger, s ...string) *Err { + e := SystemTimeoutError(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// SystemInternalError returns Err struct +func SystemInternalError(s ...string) *Err { + return newErr(Scope, code.SystemInternalError, fmt.Sprintf("internal error: %s", strings.Join(s, " "))) +} + +// SystemInternalErrorL logs error message and returns Err +func SystemInternalErrorL(l logx.Logger, s ...string) *Err { + e := SystemInternalError(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// SystemMaintainErrorL logs error message and returns Err +func SystemMaintainErrorL(l logx.Logger, s ...string) *Err { + e := SystemMaintainError(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// SystemMaintainError returns Err struct +func SystemMaintainError(s ...string) *Err { + return newErr(Scope, code.SystemMaintainError, fmt.Sprintf("service under maintenance: %s", strings.Join(s, " "))) +} + +/*** CatInput ***/ + +// InvalidFormat returns Err struct +func InvalidFormat(s ...string) *Err { + return newErr(Scope, code.InvalidFormat, fmt.Sprintf("invalid format: %s", strings.Join(s, " "))) +} + +// InvalidFormatL logs error message and returns Err +func InvalidFormatL(l logx.Logger, s ...string) *Err { + e := InvalidFormat(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidRange returns Err struct +func InvalidRange(s ...string) *Err { + return newErr(Scope, code.InvalidRange, fmt.Sprintf("invalid range: %s", strings.Join(s, " "))) +} + +// InvalidRangeL logs error message and returns Err +func InvalidRangeL(l logx.Logger, s ...string) *Err { + e := InvalidRange(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// NotValidImplementation returns Err struct +func NotValidImplementation(s ...string) *Err { + return newErr(Scope, code.NotValidImplementation, fmt.Sprintf("not valid implementation: %s", strings.Join(s, " "))) +} + +// NotValidImplementationL logs error message and returns Err +func NotValidImplementationL(l logx.Logger, s ...string) *Err { + e := NotValidImplementation(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +/*** CatDB ***/ + +// DBError returns Err +func DBError(s ...string) *Err { + return newErr(Scope, code.DBError, fmt.Sprintf("db error: %s", strings.Join(s, " "))) +} + +// DBErrorL logs error message and returns Err +func DBErrorL(l logx.Logger, s ...string) *Err { + e := DBError(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// DBDataConvert returns Err +func DBDataConvert(s ...string) *Err { + return newErr(Scope, code.DBDataConvert, fmt.Sprintf("data from db convert error: %s", strings.Join(s, " "))) +} + +// DBDataConvertL logs error message and returns Err +func DBDataConvertL(l logx.Logger, s ...string) *Err { + e := DBDataConvert(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// DBDuplicate returns Err +func DBDuplicate(s ...string) *Err { + return newErr(Scope, code.DBDuplicate, fmt.Sprintf("data Duplicate key error: %s", strings.Join(s, " "))) +} + +// DBDuplicateL logs error message and returns Err +func DBDuplicateL(l logx.Logger, s ...string) *Err { + e := DBDuplicate(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +/*** CatResource ***/ + +// ResourceNotFound returns Err and logging +func ResourceNotFound(s ...string) *Err { + return newErr(Scope, code.ResourceNotFound, fmt.Sprintf("resource not found: %s", strings.Join(s, " "))) +} + +// ResourceNotFoundL logs error message and returns Err +func ResourceNotFoundL(l logx.Logger, s ...string) *Err { + e := ResourceNotFound(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidResourceFormat returns Err +func InvalidResourceFormat(s ...string) *Err { + return newErr(Scope, code.InvalidResourceFormat, fmt.Sprintf("invalid resource format: %s", strings.Join(s, " "))) +} + +// InvalidResourceFormatL logs error message and returns Err +func InvalidResourceFormatL(l logx.Logger, s ...string) *Err { + e := InvalidResourceFormat(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidResourceState returns status not correct. +// for example: company should be destroy, agent should be no-sensor/fail-install ... +func InvalidResourceState(s ...string) *Err { + return newErr(Scope, code.InvalidResourceState, fmt.Sprintf("invalid resource state: %s", strings.Join(s, " "))) +} + +// InvalidResourceStateL logs error message and returns status not correct. +func InvalidResourceStateL(l logx.Logger, s ...string) *Err { + e := InvalidResourceState(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +func ResourceInsufficient(s ...string) *Err { + return newErr(Scope, code.ResourceInsufficient, + fmt.Sprintf("insufficient resource: %s", strings.Join(s, " "))) +} + +func ResourceInsufficientL(l logx.Logger, s ...string) *Err { + e := ResourceInsufficient(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InsufficientPermission returns Err +func InsufficientPermission(s ...string) *Err { + return newErr(Scope, code.InsufficientPermission, + fmt.Sprintf("insufficient permission: %s", strings.Join(s, " "))) +} + +// InsufficientPermissionL returns Err and log +func InsufficientPermissionL(l logx.Logger, s ...string) *Err { + e := InsufficientPermission(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// ResourceAlreadyExist returns Err +func ResourceAlreadyExist(s ...string) *Err { + return newErr(Scope, code.ResourceAlreadyExist, fmt.Sprintf("resource already exist: %s", strings.Join(s, " "))) +} + +// ResourceAlreadyExistL logs error message and returns Err +func ResourceAlreadyExistL(l logx.Logger, s ...string) *Err { + e := ResourceAlreadyExist(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidMeasurementID returns Err +func InvalidMeasurementID(s ...string) *Err { + return newErr(Scope, code.InvalidMeasurementID, fmt.Sprintf("missing measurement id: %s", strings.Join(s, " "))) +} + +// InvalidMeasurementIDL logs error message and returns Err +func InvalidMeasurementIDL(l logx.Logger, s ...string) *Err { + e := InvalidMeasurementID(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// ResourceExpired returns Err +func ResourceExpired(s ...string) *Err { + return newErr(Scope, code.ResourceExpired, fmt.Sprintf("resource expired: %s", strings.Join(s, " "))) +} + +// ResourceExpiredL logs error message and returns Err +func ResourceExpiredL(l logx.Logger, s ...string) *Err { + e := ResourceExpired(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// ResourceMigrated returns Err +func ResourceMigrated(s ...string) *Err { + return newErr(Scope, code.ResourceMigrated, fmt.Sprintf("resource migrated: %s", strings.Join(s, " "))) +} + +// ResourceMigratedL logs error message and returns Err +func ResourceMigratedL(l logx.Logger, s ...string) *Err { + e := ResourceMigrated(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InsufficientQuota returns Err +func InsufficientQuota(s ...string) *Err { + return newErr(Scope, code.InsufficientQuota, fmt.Sprintf("insufficient quota: %s", strings.Join(s, " "))) +} + +// InsufficientQuotaL logs error message and returns Err +func InsufficientQuotaL(l logx.Logger, s ...string) *Err { + e := InsufficientQuota(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +/*** CatAuth ***/ + +// Unauthorized returns Err +func Unauthorized(s ...string) *Err { + return newErr(Scope, code.Unauthorized, fmt.Sprintf("unauthorized: %s", strings.Join(s, " "))) +} + +// UnauthorizedL logs error message and returns Err +func UnauthorizedL(l logx.Logger, s ...string) *Err { + e := Unauthorized(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// AuthExpired returns Err +func AuthExpired(s ...string) *Err { + return newErr(Scope, code.AuthExpired, fmt.Sprintf("expired: %s", strings.Join(s, " "))) +} + +// AuthExpiredL logs error message and returns Err +func AuthExpiredL(l logx.Logger, s ...string) *Err { + e := AuthExpired(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidPosixTime returns Err +func InvalidPosixTime(s ...string) *Err { + return newErr(Scope, code.InvalidPosixTime, fmt.Sprintf("invalid posix time: %s", strings.Join(s, " "))) +} + +// InvalidPosixTimeL logs error message and returns Err +func InvalidPosixTimeL(l logx.Logger, s ...string) *Err { + e := InvalidPosixTime(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// SigAndPayloadNotMatched returns Err +func SigAndPayloadNotMatched(s ...string) *Err { + return newErr(Scope, code.SigAndPayloadNotMatched, fmt.Sprintf("signature and the payload are not match: %s", strings.Join(s, " "))) +} + +// SigAndPayloadNotMatchedL logs error message and returns Err +func SigAndPayloadNotMatchedL(l logx.Logger, s ...string) *Err { + e := SigAndPayloadNotMatched(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// Forbidden returns Err +func Forbidden(s ...string) *Err { + return newErr(Scope, code.Forbidden, fmt.Sprintf("forbidden: %s", strings.Join(s, " "))) +} + +// ForbiddenL logs error message and returns Err +func ForbiddenL(l logx.Logger, s ...string) *Err { + e := Forbidden(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// IsAuthUnauthorizedError check the err is unauthorized error +func IsAuthUnauthorizedError(err *Err) bool { + switch err.Code() { + case code.Unauthorized, code.AuthExpired, code.InvalidPosixTime, + code.SigAndPayloadNotMatched, code.Forbidden, + code.InvalidFormat, code.ResourceNotFound: + return true + default: + return false + } +} + +/*** CatXBC ***/ + +// ArkInternal returns Err +func ArkInternal(s ...string) *Err { + return newErr(Scope, code.ArkInternal, fmt.Sprintf("ark internal error: %s", strings.Join(s, " "))) +} + +// ArkInternalL logs error message and returns Err +func ArkInternalL(l logx.Logger, s ...string) *Err { + e := ArkInternal(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +/*** CatPubSub ***/ + +// Publish returns Err +func Publish(s ...string) *Err { + return newErr(Scope, code.Publish, fmt.Sprintf("publish: %s", strings.Join(s, " "))) +} + +// PublishL logs error message and returns Err +func PublishL(l logx.Logger, s ...string) *Err { + e := Publish(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// Consume returns Err +func Consume(s ...string) *Err { + return newErr(Scope, code.Consume, fmt.Sprintf("consume: %s", strings.Join(s, " "))) +} + +// MsgSizeTooLarge returns Err +func MsgSizeTooLarge(s ...string) *Err { + return newErr(Scope, code.MsgSizeTooLarge, fmt.Sprintf("kafka error: %s", strings.Join(s, " "))) +} + +// MsgSizeTooLargeL logs error message and returns Err +func MsgSizeTooLargeL(l logx.Logger, s ...string) *Err { + e := MsgSizeTooLarge(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} diff --git a/internal/lib/error/easy_func_test.go b/internal/lib/error/easy_func_test.go new file mode 100644 index 0000000..5ff951c --- /dev/null +++ b/internal/lib/error/easy_func_test.go @@ -0,0 +1,1031 @@ +package error + +import ( + "context" + "errors" + "fmt" + "member/internal/lib/error/code" + "reflect" + "strconv" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/logx" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestFromGRPCError_GivenStatusWithCodeAndMessage_ShouldReturnErr(t *testing.T) { + // setup + s := status.Error(codes.Code(102399), "FAKE ERROR") + + // act + e := FromGRPCError(s) + + // assert + assert.Equal(t, uint32(10), e.Scope()) + assert.Equal(t, uint32(2300), e.Category()) + assert.Equal(t, uint32(2399), e.Code()) + assert.Equal(t, "FAKE ERROR", e.Error()) +} + +func TestFromGRPCError_GivenNilError_ShouldReturnErr_Scope0_Cat0_Detail0(t *testing.T) { + // setup + var nilError error = nil + + // act + e := FromGRPCError(nilError) + + // assert + assert.Equal(t, uint32(0), e.Scope()) + assert.Equal(t, uint32(0), e.Category()) + assert.Equal(t, uint32(0), e.Code()) + assert.Equal(t, "", e.Error()) +} + +func TestFromGRPCError_GivenGRPCNativeError_ShouldReturnErr_Scope0_CatGRPC_DetailGRPCUnavailable(t *testing.T) { + // setup + msg := "GRPC Unavailable ERROR" + s := status.Error(codes.Code(codes.Unavailable), msg) + + // act + e := FromGRPCError(s) + + // assert + assert.Equal(t, code.Unset, e.Scope()) + assert.Equal(t, code.CatGRPC, e.Category()) + assert.Equal(t, uint32(codes.Unavailable), e.Code()) + assert.Equal(t, msg, e.Error()) +} + +func TestFromGRPCError_GivenGeneralError_ShouldReturnErr_Scope0_CatGRPC_DetailGRPCUnknown(t *testing.T) { + // setup + generalErr := errors.New("general error") + + // act + e := FromGRPCError(generalErr) + + // assert + assert.Equal(t, code.Unset, e.Scope()) + assert.Equal(t, code.CatGRPC, e.Category()) + assert.Equal(t, uint32(codes.Unknown), e.Code()) +} + +func TestToGRPCError_GivenErr_StatusShouldHave_Code112233(t *testing.T) { + // setup + e := Err{scope: 11, code: 2233, msg: "FAKE MSG"} + + // act + err := ToGRPCError(&e) + s, _ := status.FromError(err) + + // assert + assert.Equal(t, 112233, int(s.Code())) + assert.Equal(t, "FAKE MSG", s.Message()) +} + +func TestInvalidFormat_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidFormat("field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Equal(t, e.Error(), "invalid format: field A Error description") +} + +func TestInvalidFormatL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + // act + e := InvalidFormatL(logx.WithContext(ctx), "field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidRange_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidRange("field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidRange, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Equal(t, e.Error(), "invalid range: field A Error description") +} + +func TestInvalidRangeL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + // act + e := InvalidRangeL(logx.WithContext(ctx), "field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidRange, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestNotValidImplementation_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := NotValidImplementation("field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.NotValidImplementation, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Equal(t, e.Error(), "not valid implementation: field A Error description") +} + +func TestNotValidImplementationL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + // act + e := NotValidImplementationL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.NotValidImplementation, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestDBError_WithStrings_ShouldHasCatDBAndDetailCodeDBError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := DBError("field A", "Error description") + + // assert + assert.Equal(t, code.CatDB, e.Category()) + assert.Equal(t, code.DBError, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestDBDataConvert_WithStrings_ShouldHasCatDBAndDetailCodeDBDataConvert(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := DBDataConvert("field A", "Error description") + + // assert + assert.Equal(t, code.CatDB, e.Category()) + assert.Equal(t, code.DBDataConvert, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceNotFound_WithStrings_ShouldHasCatResource_DetailCodeResourceNotFound(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceNotFound("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceNotFound, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidResourceFormat_WithStrings_ShouldHasCatResource_DetailCodeInvalidResourceFormat(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidResourceFormat("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidResourceFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidResourceState_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidResourceState("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidResourceState, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.EqualError(t, e, "invalid resource state: field A Error description") +} + +func TestInvalidResourceStateL_LogError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := InvalidResourceStateL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidResourceState, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.EqualError(t, e, "invalid resource state: field A Error description") +} + +func TestAuthExpired_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := AuthExpired("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.AuthExpired, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestUnauthorized_WithStrings_ShouldHasCatAuth_DetailCodeUnauthorized(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := Unauthorized("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.Unauthorized, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidPosixTime_WithStrings_ShouldHasCatAuth_DetailCodeInvalidPosixTime(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidPosixTime("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.InvalidPosixTime, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestSigAndPayloadNotMatched_WithStrings_ShouldHasCatAuth_DetailCodeSigAndPayloadNotMatched(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := SigAndPayloadNotMatched("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.SigAndPayloadNotMatched, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestForbidden_WithStrings_ShouldHasCatAuth_DetailCodeForbidden(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := Forbidden("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.Forbidden, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestXBCInternal_WithStrings_ShouldHasCatResource_DetailCodeXBCInternal(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ArkInternal("field A", "Error description") + + // assert + assert.Equal(t, code.CatArk, e.Category()) + assert.Equal(t, code.ArkInternal, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestGeneralInternalError_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := SystemInternalError("field A", "Error description") + + // assert + assert.Equal(t, code.CatSystem, e.Category()) + assert.Equal(t, code.SystemInternalError, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestGeneralInternalErrorL_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := SystemInternalErrorL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatSystem, e.Category()) + assert.Equal(t, code.SystemInternalError, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestSystemMaintainError_WithStrings_DetailSystemMaintainError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := SystemMaintainErrorL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatSystem, e.Category()) + assert.Equal(t, code.SystemMaintainError, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceAlreadyExist_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceAlreadyExist("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceAlreadyExist, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceAlreadyExistL_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := ResourceAlreadyExistL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceAlreadyExist, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceInsufficient_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceInsufficient("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceInsufficient, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceInsufficientL_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := ResourceInsufficientL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceInsufficient, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInsufficientPermission_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InsufficientPermission("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InsufficientPermission, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInsufficientPermissionL_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := InsufficientPermissionL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InsufficientPermission, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidMeasurementID_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidMeasurementID("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidMeasurementID, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidMeasurementIDL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := InvalidMeasurementIDL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidMeasurementID, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceExpired_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceExpired("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceExpired, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceExpiredL_LogError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := ResourceExpiredL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceExpired, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceMigrated_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceMigrated("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceMigrated, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceMigratedL_LogError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := ResourceMigratedL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceMigrated, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInsufficientQuota_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InsufficientQuota("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InsufficientQuota, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInsufficientQuotaL_LogError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := InsufficientQuotaL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InsufficientQuota, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestPublish_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := Publish("field A", "Error description") + + // assert + assert.Equal(t, code.CatPubSub, e.Category()) + assert.Equal(t, code.Publish, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestPublishL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := PublishL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatPubSub, e.Category()) + assert.Equal(t, code.Publish, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestMsgSizeTooLarge_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := MsgSizeTooLarge("Error description") + + // assert + assert.Equal(t, code.CatPubSub, e.Category()) + assert.Equal(t, code.MsgSizeTooLarge, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "kafka error: Error description") +} + +func TestMsgSizeTooLargeL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := MsgSizeTooLargeL(l, "Error description") + + // assert + assert.Equal(t, code.CatPubSub, e.Category()) + assert.Equal(t, code.MsgSizeTooLarge, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "kafka error: Error description") +} + +func TestStructErr_WithInternalErr_ShouldIsFuncReportCorrectly(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + // arrange 2 layers err + layer1Err := fmt.Errorf("layer 1 error") + layer2Err := fmt.Errorf("layer 2: %w", layer1Err) + + // act with error chain: InvalidFormat -> layer 2 err -> layer 1 err + e := InvalidFormat("field A", "Error description") + err := e.Wrap(layer2Err) + if err != nil { + t.Fatalf("Failed to wrap error: %v", err) + } + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") + + // errors.Is should report correctly + assert.True(t, errors.Is(e, layer1Err)) + assert.True(t, errors.Is(e, layer2Err)) +} + +func TestStructErr_WithInternalErr_ShouldErrorOutputChainErrMessage(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + + // arrange 2 layers err + layer1Err := fmt.Errorf("layer 1 error") + // act with error chain: InvalidFormat -> layer 1 err + e := InvalidFormat("field A", "Error description") + err := e.Wrap(layer1Err) + if err != nil { + t.Fatalf("Failed to wrap error: %v", err) + } + + // assert + assert.Equal(t, "invalid format: field A Error description: layer 1 error", e.Error()) +} + +// arrange a specific err type just for UT +type testErr struct { + code int +} + +func (e *testErr) Error() string { + return strconv.Itoa(e.code) +} + +func TestStructErr_WithInternalErr_ShouldAsFuncReportCorrectly(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + + testE := &testErr{code: 123} + layer2Err := fmt.Errorf("layer 2: %w", testE) + + // act with error chain: InvalidFormat -> layer 2 err -> testErr + e := InvalidFormat("field A", "Error description") + err := e.Wrap(layer2Err) + if err != nil { + t.Fatalf("Failed to wrap error: %v", err) + } + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") + + // errors.As should report correctly + var internalErr *testErr + assert.True(t, errors.As(e, &internalErr)) + assert.Equal(t, testE, internalErr) +} + +/* +benchmark run for 1 second: +Benchmark_ErrorsIs_OneLayerError-4 148281332 8.68 ns/op 0 B/op 0 allocs/op +Benchmark_ErrorsIs_TwoLayerError-4 35048202 32.4 ns/op 0 B/op 0 allocs/op +Benchmark_ErrorsIs_FourLayerError-4 15309349 81.7 ns/op 0 B/op 0 allocs/op + +Benchmark_ErrorsAs_OneLayerError-4 16893205 70.4 ns/op 0 B/op 0 allocs/op +Benchmark_ErrorsAs_TwoLayerError-4 10568083 112 ns/op 0 B/op 0 allocs/op +Benchmark_ErrorsAs_FourLayerError-4 6307729 188 ns/op 0 B/op 0 allocs/op +*/ +func Benchmark_ErrorsIs_OneLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + var err error = layer1Err + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errors.Is(err, layer1Err) + } +} + +func Benchmark_ErrorsIs_TwoLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + + // act with error chain: InvalidFormat(layer 2) -> testErr(layer 1) + layer2Err := InvalidFormat("field A", "Error description") + err := layer2Err.Wrap(layer1Err) + if err != nil { + b.Fatalf("Failed to wrap error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errors.Is(layer2Err, layer1Err) + } +} + +func Benchmark_ErrorsIs_FourLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + layer2Err := fmt.Errorf("layer 2: %w", layer1Err) + layer3Err := fmt.Errorf("layer 3: %w", layer2Err) + // act with error chain: InvalidFormat(layer 4) -> Error(layer 3) -> Error(layer 2) -> testErr(layer 1) + layer4Err := InvalidFormat("field A", "Error description") + err := layer4Err.Wrap(layer3Err) + if err != nil { + b.Fatalf("Failed to wrap error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errors.Is(layer4Err, layer1Err) + } +} + +func Benchmark_ErrorsAs_OneLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + var err error = layer1Err + + b.ReportAllocs() + b.ResetTimer() + var internalErr *testErr + for i := 0; i < b.N; i++ { + errors.As(err, &internalErr) + } +} + +func Benchmark_ErrorsAs_TwoLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + + // act with error chain: InvalidFormat(layer 2) -> testErr(layer 1) + layer2Err := InvalidFormat("field A", "Error description") + err := layer2Err.Wrap(layer1Err) + if err != nil { + b.Fatalf("Failed to wrap error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + var internalErr *testErr + for i := 0; i < b.N; i++ { + errors.As(layer2Err, &internalErr) + } +} + +func Benchmark_ErrorsAs_FourLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + layer2Err := fmt.Errorf("layer 2: %w", layer1Err) + layer3Err := fmt.Errorf("layer 3: %w", layer2Err) + // act with error chain: InvalidFormat(layer 4) -> Error(layer 3) -> Error(layer 2) -> testErr(layer 1) + layer4Err := InvalidFormat("field A", "Error description") + err := layer4Err.Wrap(layer3Err) + if err != nil { + b.Fatalf("Failed to wrap error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + var internalErr *testErr + for i := 0; i < b.N; i++ { + errors.As(layer4Err, &internalErr) + } +} + +func TestFromError(t *testing.T) { + tests := []struct { + name string + givenError error + want *Err + }{ + { + "given nil error should return nil", + nil, + nil, + }, + { + "given normal error should return nil", + errors.New("normal error"), + nil, + }, + { + "given Err should return Err", + ResourceNotFound("fake error"), + ResourceNotFound("fake error"), + }, + { + "given error wraps Err should return Err", + fmt.Errorf("outter error wraps %w", ResourceNotFound("fake error")), + ResourceNotFound("fake error"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FromError(tt.givenError); !reflect.DeepEqual(got, tt.want) { + t.Errorf("FromError() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/lib/error/errors.go b/internal/lib/error/errors.go new file mode 100644 index 0000000..fb16d5c --- /dev/null +++ b/internal/lib/error/errors.go @@ -0,0 +1,197 @@ +package error + +import ( + "ark-permission/internal/lib/error/code" + "errors" + "fmt" + "net/http" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// TODO Error要移到common 包 + +// Scope global variable should be set by service or module +var Scope = code.Unset + +type Err struct { + category uint32 + code uint32 + scope uint32 + msg string + internalErr error +} + +// Error is the interface of error +// Getter function of private property "msg" +func (e *Err) Error() string { + if e == nil { + return "" + } + + // chain the error string if the internal err exists + var internalErrStr string + if e.internalErr != nil { + internalErrStr = e.internalErr.Error() + } + + if e.msg != "" { + if internalErrStr != "" { + return fmt.Sprintf("%s: %s", e.msg, internalErrStr) + } + return e.msg + } + + generalErrStr := e.GeneralError() + if internalErrStr != "" { + return fmt.Sprintf("%s: %s", generalErrStr, internalErrStr) + } + return generalErrStr +} + +// Category getter function of private property "category" +func (e *Err) Category() uint32 { + if e == nil { + return 0 + } + return e.category +} + +// Scope getter function of private property "scope" +func (e *Err) Scope() uint32 { + if e == nil { + return code.Unset + } + + return e.scope +} + +// CodeStr returns the string of error code with zero padding +func (e *Err) CodeStr() string { + if e == nil { + return "00000" + } + + if e.Category() == code.CatGRPC { + return fmt.Sprintf("%d%04d", e.Scope(), e.Category()+e.Code()) + } + + return fmt.Sprintf("%d%04d", e.Scope(), e.Code()) +} + +// Code getter function of private property "code" +func (e *Err) Code() uint32 { + if e == nil { + return code.OK + } + + return e.code +} + +func (e *Err) FullCode() uint32 { + if e == nil { + return 0 + } + + if e.Category() == code.CatGRPC { + return e.Scope()*10000 + e.Category() + e.Code() + } + + return e.Scope()*10000 + e.Code() +} + +// HTTPStatus returns corresponding HTTP status code +func (e *Err) HTTPStatus() int { + if e == nil || e.Code() == code.OK { + return http.StatusOK + } + // determine status code by code + switch e.Code() { + case code.ResourceInsufficient: + // 400 + return http.StatusBadRequest + case code.Unauthorized, code.InsufficientPermission: + // 401 + return http.StatusUnauthorized + case code.InsufficientQuota: + // 402 + return http.StatusPaymentRequired + case code.InvalidPosixTime, code.Forbidden: + // 403 + return http.StatusForbidden + case code.ResourceNotFound: + // 404 + return http.StatusNotFound + case code.ResourceAlreadyExist, code.InvalidResourceState: + // 409 + return http.StatusConflict + case code.NotValidImplementation: + // 501 + return http.StatusNotImplemented + default: + } + + // determine status code by category + switch e.Category() { + case code.CatInput: + return http.StatusBadRequest + default: + // return status code 500 if none of the condition is met + return http.StatusInternalServerError + } +} + +// GeneralError transform category level error message +// It's the general error message for customer/API caller +func (e *Err) GeneralError() string { + if e == nil { + return "" + } + + errStr, ok := code.CatToStr[e.Category()] + if !ok { + return "" + } + + return errStr +} + +// Is called when performing errors.Is(). +// DO NOT USE THIS FUNCTION DIRECTLY unless you are very certain about what you're doing. +// Use errors.Is instead. +// This function compares if two error variables are both *Err, and have the same code (without checking the wrapped internal error) +func (e *Err) Is(f error) bool { + var err *Err + ok := errors.As(f, &err) + if !ok { + return false + } + return e.Code() == err.Code() +} + +// Unwrap returns the underlying error +// The result of unwrapping an error may itself have an Unwrap method; +// we call the sequence of errors produced by repeated unwrapping the error chain. +func (e *Err) Unwrap() error { + if e == nil { + return nil + } + return e.internalErr +} + +// Wrap sets the internal error to Err struct +func (e *Err) Wrap(internalErr error) *Err { + if e != nil { + e.internalErr = internalErr + } + return e +} + +func (e *Err) GRPCStatus() *status.Status { + if e == nil { + return status.New(codes.OK, "") + } + + return status.New(codes.Code(e.FullCode()), e.Error()) +} diff --git a/internal/lib/error/errors_test.go b/internal/lib/error/errors_test.go new file mode 100644 index 0000000..a0f5325 --- /dev/null +++ b/internal/lib/error/errors_test.go @@ -0,0 +1,297 @@ +package error + +import ( + "errors" + "fmt" + "member/internal/lib/error/code" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestCode_GivenNilReceiver_CodeReturnOK_CodeStrReturns00000(t *testing.T) { + // setup + var e *Err = nil + + // act & assert + assert.Equal(t, code.OK, e.Code()) + assert.Equal(t, "00000", e.CodeStr()) + assert.Equal(t, "", e.Error()) +} + +func TestCode_GivenScope99DetailCode6687_ShouldReturn996687(t *testing.T) { + // setup + e := Err{scope: 99, code: 6687} + + // act & assert + assert.Equal(t, uint32(6687), e.Code()) + assert.Equal(t, "996687", e.CodeStr()) +} + +func TestCode_GivenScope0DetailCode87_ShouldReturn87(t *testing.T) { + // setup + e := Err{scope: 0, code: 87} + + // act & assert + assert.Equal(t, uint32(87), e.Code()) + assert.Equal(t, "00087", e.CodeStr()) +} + +func TestFromCode_Given870005_ShouldHasScope87_Cat0_Detail5(t *testing.T) { + // setup + e := FromCode(870005) + + // assert + assert.Equal(t, uint32(87), e.Scope()) + assert.Equal(t, uint32(0), e.Category()) + assert.Equal(t, uint32(5), e.Code()) + assert.Equal(t, "", e.Error()) +} + +func TestFromCode_Given0_ShouldHasScope0_Cat0_Detail0(t *testing.T) { + // setup + e := FromCode(0) + + // assert + assert.Equal(t, uint32(0), e.Scope()) + assert.Equal(t, uint32(0), e.Category()) + assert.Equal(t, uint32(0), e.Code()) + assert.Equal(t, "", e.Error()) +} + +func TestFromCode_Given9105_ShouldHasScope0_Cat9100_Detail9105(t *testing.T) { + // setup + e := FromCode(9105) + + // assert + assert.Equal(t, uint32(0), e.Scope()) + assert.Equal(t, uint32(9100), e.Category()) + assert.Equal(t, uint32(9105), e.Code()) + assert.Equal(t, "", e.Error()) +} + +func TestErr_ShouldImplementErrorFunction(t *testing.T) { + // setup a func return error + f := func() error { return InvalidFormat("fake field") } + + // act + err := f() + + // assert + assert.NotNil(t, err) + assert.Contains(t, fmt.Sprint(err), "fake field") // can be printed +} + +func TestGeneralError_GivenNilErr_ShouldReturnEmptyString(t *testing.T) { + // setup + var e *Err = nil + + // act & assert + assert.Equal(t, "", e.GeneralError()) +} + +func TestGeneralError_GivenNotExistCat_ShouldReturnEmptyString(t *testing.T) { + // setup + e := Err{category: 123456} + + // act & assert + assert.Equal(t, "", e.GeneralError()) +} + +func TestGeneralError_GivenCatDB_ShouldReturnDBError(t *testing.T) { + // setup + e := Err{category: code.CatDB} + catErrStr := code.CatToStr[code.CatDB] + + // act & assert + assert.Equal(t, catErrStr, e.GeneralError()) +} + +func TestError_GivenEmptyMsg_ShouldReturnCatGeneralErrorMessage(t *testing.T) { + // setup + e := Err{category: code.CatDB, msg: ""} + + // act + errMsg := e.Error() + + // assert + assert.Equal(t, code.CatToStr[code.CatDB], errMsg) +} + +func TestError_GivenMsg_ShouldReturnGiveMsg(t *testing.T) { + // setup + e := Err{msg: "FAKE"} + + // act + errMsg := e.Error() + + // assert + assert.Equal(t, "FAKE", errMsg) +} + +func TestIs_GivenNilErr_ShouldReturnFalse(t *testing.T) { + var nilErrs *Err + // act + result := errors.Is(nilErrs, DBError()) + result2 := errors.Is(DBError(), nilErrs) + + // assert + assert.False(t, result) + assert.False(t, result2) +} + +func TestIs_GivenNil_ShouldReturnFalse(t *testing.T) { + // act + result := errors.Is(nil, DBError()) + result2 := errors.Is(DBError(), nil) + + // assert + assert.False(t, result) + assert.False(t, result2) +} + +func TestIs_GivenNilReceiver_ShouldReturnCorrectResult(t *testing.T) { + var nilErr *Err = nil + + // test 1: nilErr != DBError + var dbErr error = DBError("fake db error") + assert.False(t, nilErr.Is(dbErr)) + + // test 2: nilErr != nil error + var nilError error + assert.False(t, nilErr.Is(nilError)) + + // test 3: nilErr == another nilErr + var nilErr2 *Err = nil + assert.True(t, nilErr.Is(nilErr2)) +} + +func TestIs_GivenDBError_ShouldReturnTrue(t *testing.T) { + // setup + dbErr := DBError("fake db error") + + // act + result := errors.Is(dbErr, DBError("not care")) + result2 := errors.Is(DBError(), dbErr) + + // assert + assert.True(t, result) + assert.True(t, result2) +} + +func TestIs_GivenDBErrorAssignToErrorType_ShouldReturnTrue(t *testing.T) { + // setup + var dbErr error = DBError("fake db error") + + // act + result := errors.Is(dbErr, DBError("not care")) + result2 := errors.Is(DBError(), dbErr) + + // assert + assert.True(t, result) + assert.True(t, result2) +} + +func TestWrap_GivenNilErr_ShouldNoPanic(t *testing.T) { + // act & assert + assert.NotPanics(t, func() { + var e *Err = nil + _ = e.Wrap(fmt.Errorf("test")) + }) +} + +func TestWrap_GivenErrorToWrap_ShouldReturnErrorWithWrappedError(t *testing.T) { + // act & assert + wrappedErr := fmt.Errorf("test") + wrappingErr := SystemInternalError("WrappingError").Wrap(wrappedErr) + unWrappedErr := wrappingErr.Unwrap() + + assert.Equal(t, wrappedErr, unWrappedErr) +} + +func TestUnwrap_GivenNilErr_ShouldReturnNil(t *testing.T) { + var e *Err = nil + internalErr := e.Unwrap() + assert.Nil(t, internalErr) +} + +func TestErrorsIs_GivenNilErr_ShouldReturnFalse(t *testing.T) { + var e *Err = nil + assert.False(t, errors.Is(e, fmt.Errorf("test"))) +} + +func TestErrorsAs_GivenNilErr_ShouldReturnFalse(t *testing.T) { + var internalErr *testErr + var e *Err = nil + assert.False(t, errors.As(e, &internalErr)) +} + +func TestGRPCStatus(t *testing.T) { + // setup table driven tests + tests := []struct { + name string + given *Err + expect *status.Status + expectConvert error + }{ + { + "nil errs.Err", + nil, + status.New(codes.OK, ""), + nil, + }, + { + "InvalidFormat Err", + InvalidFormat("fake"), + status.New(codes.Code(101), "invalid format: fake"), + status.New(codes.Code(101), "invalid format: fake").Err(), + }, + } + + // act & assert + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := test.given.GRPCStatus() + assert.Equal(t, test.expect.Code(), s.Code()) + assert.Equal(t, test.expect.Message(), s.Message()) + assert.Equal(t, test.expectConvert, status.Convert(test.given).Err()) + }) + } +} + +func TestErr_HTTPStatus(t *testing.T) { + tests := []struct { + name string + err *Err + want int + }{ + {name: "nil error", err: nil, want: http.StatusOK}, + {name: "invalid measurement id", err: &Err{category: code.CatResource, code: code.InvalidMeasurementID}, want: http.StatusInternalServerError}, + {name: "resource already exists", err: &Err{category: code.CatResource, code: code.ResourceAlreadyExist}, want: http.StatusConflict}, + {name: "invalid resource state", err: &Err{category: code.CatResource, code: code.InvalidResourceState}, want: http.StatusConflict}, + {name: "invalid posix time", err: &Err{category: code.CatAuth, code: code.InvalidPosixTime}, want: http.StatusForbidden}, + {name: "unauthorized", err: &Err{category: code.CatAuth, code: code.Unauthorized}, want: http.StatusUnauthorized}, + {name: "db error", err: &Err{category: code.CatDB, code: code.DBError}, want: http.StatusInternalServerError}, + {name: "insufficient permission", err: &Err{category: code.CatResource, code: code.InsufficientPermission}, want: http.StatusUnauthorized}, + {name: "resource insufficient", err: &Err{category: code.CatResource, code: code.ResourceInsufficient}, want: http.StatusBadRequest}, + {name: "invalid format", err: &Err{category: code.CatInput, code: code.InvalidFormat}, want: http.StatusBadRequest}, + {name: "resource not found", err: &Err{code: code.ResourceNotFound}, want: http.StatusNotFound}, + {name: "ok", err: &Err{code: code.OK}, want: http.StatusOK}, + {name: "not valid implementation", err: &Err{category: code.CatInput, code: code.NotValidImplementation}, want: http.StatusNotImplemented}, + {name: "forbidden", err: &Err{category: code.CatAuth, code: code.Forbidden}, want: http.StatusForbidden}, + {name: "insufficient quota", err: &Err{category: code.CatResource, code: code.InsufficientQuota}, want: http.StatusPaymentRequired}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // act + got := tt.err.HTTPStatus() + + // assert + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/lib/middleware/with_context.go b/internal/lib/middleware/with_context.go new file mode 100644 index 0000000..df06757 --- /dev/null +++ b/internal/lib/middleware/with_context.go @@ -0,0 +1,28 @@ +package middleware + +import ( + ers "ark-permission/internal/lib/error" + "context" + "errors" + "time" + + "github.com/zeromicro/go-zero/core/logx" + "google.golang.org/grpc" +) + +const defaultTimeout = 30 * time.Second + +func TimeoutMiddleware(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + + newCtx, cancelCtx := context.WithTimeout(ctx, defaultTimeout) + defer func() { + cancelCtx() + + if errors.Is(newCtx.Err(), context.DeadlineExceeded) { + err = ers.SystemTimeoutError(info.FullMethod) + logx.Errorf("Method: %s, request %v, timeout: %d", info.FullMethod, req, defaultTimeout) + } + }() + + return handler(ctx, req) +} diff --git a/internal/lib/required/validate.go b/internal/lib/required/validate.go new file mode 100644 index 0000000..6fe9a11 --- /dev/null +++ b/internal/lib/required/validate.go @@ -0,0 +1,51 @@ +package required + +import ( + "fmt" + + "github.com/zeromicro/go-zero/core/logx" + + "github.com/go-playground/validator/v10" +) + +type Validate interface { + ValidateAll(obj any) error + BindToValidator(opts ...Option) error +} + +type Validator struct { + V *validator.Validate +} + +// ValidateAll TODO 要移到common 包 +func (v *Validator) ValidateAll(obj any) error { + err := v.V.Struct(obj) + if err != nil { + return err + } + + return nil +} + +func (v *Validator) BindToValidator(opts ...Option) error { + for _, item := range opts { + err := v.V.RegisterValidation(item.ValidatorName, item.ValidatorFunc) + if err != nil { + return fmt.Errorf("failed to register validator : %w", err) + } + } + + return nil +} + +func MustValidator(option ...Option) Validate { + v := &Validator{ + V: validator.New(), + } + + if err := v.BindToValidator(option...); err != nil { + logx.Error("failed to bind validator") + } + + return v +} diff --git a/internal/lib/required/validate_option.go b/internal/lib/required/validate_option.go new file mode 100644 index 0000000..8a896cb --- /dev/null +++ b/internal/lib/required/validate_option.go @@ -0,0 +1,29 @@ +package required + +import ( + "regexp" + + "github.com/go-playground/validator/v10" +) + +type Option struct { + ValidatorName string + ValidatorFunc func(fl validator.FieldLevel) bool +} + +// WithAccount 創建一個新的 Option 結構,包含自定義的驗證函數,用於驗證 email 和台灣的手機號碼格式 +func WithAccount(tagName string) Option { + return Option{ + ValidatorName: tagName, + ValidatorFunc: func(fl validator.FieldLevel) bool { + value := fl.Field().String() + emailRegex := `^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}$` + phoneRegex := `^(\+886|0)?9\d{8}$` + + emailMatch, _ := regexp.MatchString(emailRegex, value) + phoneMatch, _ := regexp.MatchString(phoneRegex, value) + + return emailMatch || phoneMatch + }, + } +} diff --git a/internal/logic/cancel_token_by_device_i_d_logic.go b/internal/logic/cancel_token_by_device_id_logic.go similarity index 65% rename from internal/logic/cancel_token_by_device_i_d_logic.go rename to internal/logic/cancel_token_by_device_id_logic.go index 222a7bd..862a8d5 100644 --- a/internal/logic/cancel_token_by_device_i_d_logic.go +++ b/internal/logic/cancel_token_by_device_id_logic.go @@ -9,14 +9,14 @@ import ( "github.com/zeromicro/go-zero/core/logx" ) -type CancelTokenByDeviceIDLogic struct { +type CancelTokenByDeviceIdLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } -func NewCancelTokenByDeviceIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByDeviceIDLogic { - return &CancelTokenByDeviceIDLogic{ +func NewCancelTokenByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByDeviceIdLogic { + return &CancelTokenByDeviceIdLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), @@ -24,7 +24,7 @@ func NewCancelTokenByDeviceIDLogic(ctx context.Context, svcCtx *svc.ServiceConte } // CancelTokenByDeviceID 取消 Token -func (l *CancelTokenByDeviceIDLogic) CancelTokenByDeviceID(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { +func (l *CancelTokenByDeviceIdLogic) CancelTokenByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { // todo: add your logic here and delete this line return &permission.OKResp{}, nil diff --git a/internal/logic/cancel_token_by_u_i_d_logic.go b/internal/logic/cancel_token_by_uid_logic.go similarity index 70% rename from internal/logic/cancel_token_by_u_i_d_logic.go rename to internal/logic/cancel_token_by_uid_logic.go index 8d6a001..05b0566 100644 --- a/internal/logic/cancel_token_by_u_i_d_logic.go +++ b/internal/logic/cancel_token_by_uid_logic.go @@ -9,14 +9,14 @@ import ( "github.com/zeromicro/go-zero/core/logx" ) -type CancelTokenByUIDLogic struct { +type CancelTokenByUidLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } -func NewCancelTokenByUIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByUIDLogic { - return &CancelTokenByUIDLogic{ +func NewCancelTokenByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByUidLogic { + return &CancelTokenByUidLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), @@ -24,7 +24,7 @@ func NewCancelTokenByUIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) * } // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (l *CancelTokenByUIDLogic) CancelTokenByUID(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { +func (l *CancelTokenByUidLogic) CancelTokenByUid(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { // todo: add your logic here and delete this line return &permission.OKResp{}, nil diff --git a/internal/logic/claims.go b/internal/logic/claims.go new file mode 100644 index 0000000..937953d --- /dev/null +++ b/internal/logic/claims.go @@ -0,0 +1,55 @@ +package logic + +type claims map[string]string + +func (c claims) SetID(id string) { + c["id"] = id +} + +func (c claims) SetRole(role string) { + c["role"] = role +} + +func (c claims) SetDeviceID(deviceID string) { + c["device_id"] = deviceID +} + +func (c claims) SetScope(scope string) { + c["scope"] = scope +} + +func (c claims) Role() string { + role, ok := c["role"] + if !ok { + return "" + } + + return role +} + +func (c claims) ID() string { + id, ok := c["id"] + if !ok { + return "" + } + + return id +} + +func (c claims) DeviceID() string { + deviceID, ok := c["device_id"] + if !ok { + return "" + } + + return deviceID +} + +func (c claims) UID() string { + uid, ok := c["uid"] + if !ok { + return "" + } + + return uid +} diff --git a/internal/logic/get_user_tokens_by_device_i_d_logic.go b/internal/logic/get_user_tokens_by_device_id_logic.go similarity index 66% rename from internal/logic/get_user_tokens_by_device_i_d_logic.go rename to internal/logic/get_user_tokens_by_device_id_logic.go index ea6cd3d..d633995 100644 --- a/internal/logic/get_user_tokens_by_device_i_d_logic.go +++ b/internal/logic/get_user_tokens_by_device_id_logic.go @@ -9,14 +9,14 @@ import ( "github.com/zeromicro/go-zero/core/logx" ) -type GetUserTokensByDeviceIDLogic struct { +type GetUserTokensByDeviceIdLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } -func NewGetUserTokensByDeviceIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByDeviceIDLogic { - return &GetUserTokensByDeviceIDLogic{ +func NewGetUserTokensByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByDeviceIdLogic { + return &GetUserTokensByDeviceIdLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), @@ -24,7 +24,7 @@ func NewGetUserTokensByDeviceIDLogic(ctx context.Context, svcCtx *svc.ServiceCon } // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens -func (l *GetUserTokensByDeviceIDLogic) GetUserTokensByDeviceID(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { +func (l *GetUserTokensByDeviceIdLogic) GetUserTokensByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { // todo: add your logic here and delete this line return &permission.Tokens{}, nil diff --git a/internal/logic/get_user_tokens_by_u_i_d_logic.go b/internal/logic/get_user_tokens_by_uid_logic.go similarity index 67% rename from internal/logic/get_user_tokens_by_u_i_d_logic.go rename to internal/logic/get_user_tokens_by_uid_logic.go index 339eeb7..ee70ab1 100644 --- a/internal/logic/get_user_tokens_by_u_i_d_logic.go +++ b/internal/logic/get_user_tokens_by_uid_logic.go @@ -9,14 +9,14 @@ import ( "github.com/zeromicro/go-zero/core/logx" ) -type GetUserTokensByUIDLogic struct { +type GetUserTokensByUidLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } -func NewGetUserTokensByUIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByUIDLogic { - return &GetUserTokensByUIDLogic{ +func NewGetUserTokensByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByUidLogic { + return &GetUserTokensByUidLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), @@ -24,7 +24,7 @@ func NewGetUserTokensByUIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) } // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (l *GetUserTokensByUIDLogic) GetUserTokensByUID(in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { +func (l *GetUserTokensByUidLogic) GetUserTokensByUid(in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { // todo: add your logic here and delete this line return &permission.Tokens{}, nil diff --git a/internal/logic/new_token_logic.go b/internal/logic/new_token_logic.go index 180eca1..07fb080 100644 --- a/internal/logic/new_token_logic.go +++ b/internal/logic/new_token_logic.go @@ -1,10 +1,19 @@ package logic import ( - "context" - "ark-permission/gen_result/pb/permission" + "ark-permission/internal/domain" + "ark-permission/internal/entity" + ers "ark-permission/internal/lib/error" "ark-permission/internal/svc" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "time" "github.com/zeromicro/go-zero/core/logx" ) @@ -23,9 +32,106 @@ func NewNewTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *NewToken } } +// https://datatracker.ietf.org/doc/html/rfc6749#section-3.3 +type authorizationReq struct { + GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"` + DeviceID string `json:"device_id"` + Scope string `json:"scope" validate:"required"` + Data map[string]any `json:"data"` + Expires int `json:"expires"` + IsRefreshToken bool `json:"is_refresh_token"` +} + // NewToken 建立一個新的 Token,例如:AccessToken func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) { - // todo: add your logic here and delete this line + // 驗證所需 + if err := l.svcCtx.Validate.ValidateAll(&authorizationReq{ + GrantType: domain.GrantType(in.GetGrantType()), + Scope: in.GetScope(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } - return &permission.TokenResp{}, nil + // 準備建立 Token 所需 + now := time.Now().UTC() + expires := int(in.GetExpires()) + refreshExpires := int(in.GetExpires()) + if expires <= 0 { + expires = int(l.svcCtx.Config.Token.Expired.Seconds()) + refreshExpires = expires + } + + // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 + if in.GetIsRefreshToken() { + refreshExpires = int(l.svcCtx.Config.Token.RefreshExpires.Seconds()) + } + + token := entity.Token{ + ID: uuid.Must(uuid.NewRandom()).String(), + DeviceID: in.GetDeviceId(), + ExpiresIn: expires, + RefreshExpiresIn: refreshExpires, + AccessCreateAt: now, + RefreshCreateAt: now, + } + + claims := claims(in.GetData()) + claims.SetRole(domain.DefaultRole) + claims.SetID(token.ID) + claims.SetScope(in.GetScope()) + + token.UID = claims.UID() + + if in.GetDeviceId() != "" { + claims.SetDeviceID(in.GetDeviceId()) + } + + var err error + token.AccessToken, err = generateAccessToken(token, claims, l.svcCtx.Config.Token.Secret) + if err != nil { + return nil, ers.ArkInternal(fmt.Errorf("accessGenerate token error: %w", err).Error()) + } + + if in.GetIsRefreshToken() { + token.RefreshToken = generateRefreshToken(token.AccessToken) + } + + err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token) + if err != nil { + return nil, ers.ArkInternal(fmt.Errorf("tokenRepository.Create error: %w", err).Error()) + } + + return &permission.TokenResp{ + AccessToken: token.AccessToken, + TokenType: domain.TokenTypeBearer, + ExpiresIn: int32(token.ExpiresIn), + RefreshToken: token.RefreshToken, + }, nil +} + +func generateAccessToken(token entity.Token, data any, sign string) (string, error) { + claim := entity.Claims{ + Data: data, + RegisteredClaims: jwt.RegisteredClaims{ + ID: token.ID, + ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)), + Issuer: "permission", + }, + } + + accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim). + SignedString([]byte(sign)) + if err != nil { + return "", err + } + + return accessToken, nil +} + +func generateRefreshToken(accessToken string) string { + buf := bytes.NewBufferString(accessToken) + h := sha256.New() + _, _ = h.Write(buf.Bytes()) + + return hex.EncodeToString(h.Sum(nil)) } diff --git a/internal/repository/token.go b/internal/repository/token.go new file mode 100644 index 0000000..2f00c71 --- /dev/null +++ b/internal/repository/token.go @@ -0,0 +1,136 @@ +package repository + +import ( + "ark-permission/internal/domain" + "ark-permission/internal/domain/repository" + "ark-permission/internal/entity" + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/zeromicro/go-zero/core/stores/redis" +) + +type TokenRepositoryParam struct { + Store *redis.Redis `name:"redis"` +} + +type tokenRepository struct { + store *redis.Redis +} + +func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { + return &tokenRepository{ + store: param.Store, + } +} + +func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error { + body, err := json.Marshal(token) + if err != nil { + return wrapError("json.Marshal token error", err) + } + + err = t.store.Pipelined(func(tx redis.Pipeliner) error { + rTTL := token.RefreshTokenExpires() + + if err := t.setToken(ctx, tx, token, body, rTTL); err != nil { + return err + } + + if err := t.setRefreshToken(ctx, tx, token, rTTL); err != nil { + return err + } + + if err := t.setDeviceToken(ctx, tx, token, rTTL); err != nil { + return err + } + + return nil + }) + if err != nil { + return wrapError("store.Pipelined error", err) + } + + if err := t.SetUIDToken(token); err != nil { + return wrapError("SetUIDToken error", err) + } + + return nil +} + +func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error { + err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err() + if err != nil { + return wrapError("tx.Set GetAccessTokenRedisKey error", err) + } + return nil +} + +func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { + if token.RefreshToken != "" { + err := tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, rTTL).Err() + if err != nil { + return wrapError("tx.Set RefreshToken error", err) + } + } + return nil +} + +func (t *tokenRepository) setDeviceToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { + if token.DeviceID != "" { + key := domain.DeviceTokenRedisKey.With(token.UID).ToString() + value := fmt.Sprintf("%s-%d", token.ID, token.AccessCreateAt.Add(rTTL).Unix()) + err := tx.HSet(ctx, key, token.DeviceID, value).Err() + if err != nil { + return wrapError("tx.HSet Device Token error", err) + } + err = tx.Expire(ctx, key, rTTL).Err() + if err != nil { + return wrapError("tx.Expire Device Token error", err) + } + } + return nil +} + +// SetUIDToken 將 token 資料放進 uid key中 +func (t *tokenRepository) SetUIDToken(token entity.Token) error { + uidTokens := make(entity.UIDToken) + b, err := t.store.Get(domain.GetUIDTokenRedisKey(token.UID)) + if err != nil && !errors.Is(err, redis.Nil) { + return wrapError("t.store.Get GetUIDTokenRedisKey error", err) + } + + if b != "" { + err = json.Unmarshal([]byte(b), &uidTokens) + if err != nil { + return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err) + } + } + + now := time.Now().Unix() + for k, t := range uidTokens { + if t < now { + delete(uidTokens, k) + } + } + + uidTokens[token.ID] = token.RefreshTokenExpiresUnix() + s, err := json.Marshal(uidTokens) + if err != nil { + return wrapError("json.Marshal UIDToken error", err) + } + + err = t.store.Setex(domain.GetUIDTokenRedisKey(token.UID), string(s), 86400*30) + if err != nil { + return wrapError("t.store.Setex GetUIDTokenRedisKey error", err) + } + + return nil +} + +func wrapError(message string, err error) error { + return fmt.Errorf("%s: %w", message, err) +} diff --git a/internal/server/token_service_server.go b/internal/server/token_service_server.go index 81e5554..5e11a23 100644 --- a/internal/server/token_service_server.go +++ b/internal/server/token_service_server.go @@ -41,15 +41,15 @@ func (s *TokenServiceServer) CancelToken(ctx context.Context, in *permission.Can } // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (s *TokenServiceServer) CancelTokenByUID(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { - l := logic.NewCancelTokenByUIDLogic(ctx, s.svcCtx) - return l.CancelTokenByUID(in) +func (s *TokenServiceServer) CancelTokenByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { + l := logic.NewCancelTokenByUidLogic(ctx, s.svcCtx) + return l.CancelTokenByUid(in) } // CancelTokenByDeviceID 取消 Token -func (s *TokenServiceServer) CancelTokenByDeviceID(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { - l := logic.NewCancelTokenByDeviceIDLogic(ctx, s.svcCtx) - return l.CancelTokenByDeviceID(in) +func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { + l := logic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx) + return l.CancelTokenByDeviceId(in) } // ValidationToken 驗證這個 Token 有沒有效 @@ -59,15 +59,15 @@ func (s *TokenServiceServer) ValidationToken(ctx context.Context, in *permission } // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens -func (s *TokenServiceServer) GetUserTokensByDeviceID(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { - l := logic.NewGetUserTokensByDeviceIDLogic(ctx, s.svcCtx) - return l.GetUserTokensByDeviceID(in) +func (s *TokenServiceServer) GetUserTokensByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { + l := logic.NewGetUserTokensByDeviceIdLogic(ctx, s.svcCtx) + return l.GetUserTokensByDeviceId(in) } // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (s *TokenServiceServer) GetUserTokensByUID(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { - l := logic.NewGetUserTokensByUIDLogic(ctx, s.svcCtx) - return l.GetUserTokensByUID(in) +func (s *TokenServiceServer) GetUserTokensByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { + l := logic.NewGetUserTokensByUidLogic(ctx, s.svcCtx) + return l.GetUserTokensByUid(in) } // NewOneTimeToken 建立一次性使用,例如:RefreshToken diff --git a/internal/svc/service_context.go b/internal/svc/service_context.go index 9541b6b..5a21250 100644 --- a/internal/svc/service_context.go +++ b/internal/svc/service_context.go @@ -1,13 +1,33 @@ package svc -import "ark-permission/internal/config" +import ( + "ark-permission/internal/config" + "ark-permission/internal/domain/repository" + "ark-permission/internal/lib/required" + repo "ark-permission/internal/repository" + "github.com/zeromicro/go-zero/core/stores/redis" +) type ServiceContext struct { Config config.Config + + Validate required.Validate + Redis redis.Redis + TokenRedisRepo repository.TokenRepository } func NewServiceContext(c config.Config) *ServiceContext { + newRedis, err := redis.NewRedis(c.RedisCluster, redis.Cluster()) + if err != nil { + panic(err) + } + return &ServiceContext{ - Config: c, + Config: c, + Validate: required.MustValidator(), + Redis: *newRedis, + TokenRedisRepo: repo.NewTokenRepository(repo.TokenRepositoryParam{ + Store: newRedis, + }), } } diff --git a/permission.go b/permission.go index 7309554..d7f2864 100644 --- a/permission.go +++ b/permission.go @@ -34,6 +34,9 @@ func main() { }) defer s.Stop() + // // 加入中間件 + // s.AddUnaryInterceptors(middleware.TimeoutMiddleware) + fmt.Printf("Starting rpc server at %s...\n", c.ListenOn) s.Start() } diff --git a/tokenservice/token_service.go b/tokenservice/token_service.go index 074dcad..28be4ee 100644 --- a/tokenservice/token_service.go +++ b/tokenservice/token_service.go @@ -36,15 +36,15 @@ type ( // CancelToken 取消 Token,也包含他裡面的 One Time Toke CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke - CancelTokenByUID(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) + CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) // CancelTokenByDeviceID 取消 Token - CancelTokenByDeviceID(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) + CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) // ValidationToken 驗證這個 Token 有沒有效 ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens - GetUserTokensByDeviceID(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) + GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens - GetUserTokensByUID(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) + GetUserTokensByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) // NewOneTimeToken 建立一次性使用,例如:RefreshToken NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) // CancelOneTimeToken 取消一次性使用 @@ -81,15 +81,15 @@ func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenRe } // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (m *defaultTokenService) CancelTokenByUID(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { +func (m *defaultTokenService) CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByUID(ctx, in, opts...) + return client.CancelTokenByUid(ctx, in, opts...) } // CancelTokenByDeviceID 取消 Token -func (m *defaultTokenService) CancelTokenByDeviceID(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { +func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByDeviceID(ctx, in, opts...) + return client.CancelTokenByDeviceId(ctx, in, opts...) } // ValidationToken 驗證這個 Token 有沒有效 @@ -99,15 +99,15 @@ func (m *defaultTokenService) ValidationToken(ctx context.Context, in *Validatio } // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByDeviceID(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { +func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.GetUserTokensByDeviceID(ctx, in, opts...) + return client.GetUserTokensByDeviceId(ctx, in, opts...) } // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByUID(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { +func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.GetUserTokensByUID(ctx, in, opts...) + return client.GetUserTokensByUid(ctx, in, opts...) } // NewOneTimeToken 建立一次性使用,例如:RefreshToken -- 2.40.1 From bb2459f751d75ae5c26f2e462ebc11d3fffd5305 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Tue, 6 Aug 2024 15:52:42 +0800 Subject: [PATCH 02/10] feat: create new token ut --- go.mod | 1 + internal/logic/new_token_logic.go | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index 12ccb99..040f160 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/google/uuid v1.6.0 github.com/stretchr/testify v1.9.0 github.com/zeromicro/go-zero v1.7.0 + go.uber.org/mock v0.4.0 google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 ) diff --git a/internal/logic/new_token_logic.go b/internal/logic/new_token_logic.go index 07fb080..2dae534 100644 --- a/internal/logic/new_token_logic.go +++ b/internal/logic/new_token_logic.go @@ -42,6 +42,9 @@ type authorizationReq struct { IsRefreshToken bool `json:"is_refresh_token"` } +var generateAccessTokenFunc = generateAccessToken +var generateRefreshTokenFunc = generateRefreshToken + // NewToken 建立一個新的 Token,例如:AccessToken func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) { // 驗證所需 @@ -87,13 +90,13 @@ func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.T } var err error - token.AccessToken, err = generateAccessToken(token, claims, l.svcCtx.Config.Token.Secret) + token.AccessToken, err = generateAccessTokenFunc(token, claims, l.svcCtx.Config.Token.Secret) if err != nil { return nil, ers.ArkInternal(fmt.Errorf("accessGenerate token error: %w", err).Error()) } if in.GetIsRefreshToken() { - token.RefreshToken = generateRefreshToken(token.AccessToken) + token.RefreshToken = generateRefreshTokenFunc(token.AccessToken) } err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token) -- 2.40.1 From 45c3486f5f01e15faa77d71d4d72813a57aa6864 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Tue, 6 Aug 2024 15:53:34 +0800 Subject: [PATCH 03/10] feat: create new token ut --- internal/logic/new_token_logic_test.go | 209 +++++++++++++++++++++++++ internal/mock/lib/validate.go | 72 +++++++++ internal/mock/repository/token.go | 55 +++++++ 3 files changed, 336 insertions(+) create mode 100644 internal/logic/new_token_logic_test.go create mode 100644 internal/mock/lib/validate.go create mode 100644 internal/mock/repository/token.go diff --git a/internal/logic/new_token_logic_test.go b/internal/logic/new_token_logic_test.go new file mode 100644 index 0000000..35ab56a --- /dev/null +++ b/internal/logic/new_token_logic_test.go @@ -0,0 +1,209 @@ +package logic + +import ( + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/domain" + "ark-permission/internal/entity" + libMock "ark-permission/internal/mock/lib" + repoMock "ark-permission/internal/mock/repository" + "ark-permission/internal/svc" + "errors" + "github.com/stretchr/testify/assert" + + "context" + "github.com/golang-jwt/jwt/v4" + "go.uber.org/mock/gomock" + "testing" + "time" +) + +func TestNewTokenLogic_NewToken(t *testing.T) { + // mock + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + tokenMockRepo := repoMock.NewMockTokenRepository(ctrl) + mockValidate := libMock.NewMockValidate(ctrl) + + sc := svc.ServiceContext{ + TokenRedisRepo: tokenMockRepo, + Validate: mockValidate, + } + + l := NewNewTokenLogic(context.Background(), &sc) + + tests := []struct { + name string + input *permission.AuthorizationReq + setupMocks func() + expectError bool + expected *permission.TokenResp + }{ + { + name: "Valid token request", + input: &permission.AuthorizationReq{ + GrantType: "authorization_code", + DeviceId: "device123", + Scope: "read", + Expires: 3600, + IsRefreshToken: false, + Data: map[string]string{ + "uid": "user123", + }, + }, + setupMocks: func() { + mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(nil) + tokenMockRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, token entity.Token) { + token.AccessToken = "access_token" + }) + generateAccessTokenFunc = func(token entity.Token, data any, sign string) (string, error) { + return "access_token", nil + } + generateRefreshTokenFunc = func(accessToken string) string { + return "refresh_token" + } + }, + expectError: false, + expected: &permission.TokenResp{ + AccessToken: "access_token", + TokenType: domain.TokenTypeBearer, + ExpiresIn: 3600, + RefreshToken: "", + }, + }, + { + name: "Validation error", + input: &permission.AuthorizationReq{ + GrantType: "invalid_grant", + DeviceId: "device123", + Scope: "read", + Expires: 3600, + IsRefreshToken: false, + Data: map[string]string{ + "uid": "user123", + }, + }, + setupMocks: func() { + mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(errors.New("invalid grant type")) + }, + expectError: true, + expected: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setupMocks() + + resp, err := l.NewToken(tt.input) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, resp) + } + }) + } +} + +// 測試 generateAccessToken 函數 +func TestGenerateAccessToken(t *testing.T) { + // 定義測試用例 + tests := []struct { + name string + token entity.Token + data any + sign string + shouldFail bool + shouldVerify bool + }{ + { + name: "Valid token with admin role", + token: entity.Token{ + ID: "123", + ExpiresIn: int(time.Now().Add(time.Hour * 24).Unix()), + }, + data: map[string]string{"role": "admin"}, + sign: "secret", + shouldFail: false, + shouldVerify: true, + }, + { + name: "Expired token", + token: entity.Token{ + ID: "456", + ExpiresIn: int(time.Now().Add(-time.Hour * 24).Unix()), // 過期時間 + }, + data: map[string]string{"role": "user"}, + sign: "secret", + shouldFail: false, // 這個測試不會失敗,因為過期檢查通常在驗證時進行 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tokenString, err := generateAccessToken(tt.token, tt.data, tt.sign) + if (err != nil) != tt.shouldFail { + t.Errorf("generateAccessToken() error = %v, shouldFail %v", err, tt.shouldFail) + return + } + + if tt.shouldVerify { + // 驗證生成的 token + parsedToken, err := jwt.ParseWithClaims(tokenString, &entity.Claims{}, func(token *jwt.Token) (interface{}, error) { + return []byte(tt.sign), nil + }) + if err != nil { + t.Errorf("Error parsing token: %v", err) + return + } + + if claims, ok := parsedToken.Claims.(*entity.Claims); ok && parsedToken.Valid { + if claims.ID != tt.token.ID { + t.Errorf("Expected ID %v, got %v", tt.token.ID, claims.ID) + } + if claims.Issuer != "permission" { + t.Errorf("Expected Issuer 'permission', got %v", claims.Issuer) + } + for k, v := range tt.data.(map[string]string) { + if claims.Data.(map[string]any)[k] != v { + t.Errorf("Expected data %v, got %v", v, claims.Data.(map[string]string)[k]) + } + } + } else { + t.Errorf("Invalid token claims") + } + } + }) + } +} + +// 測試 generateRefreshToken 函數 +func TestGenerateRefreshToken(t *testing.T) { + // 定義測試用例 + tests := []struct { + accessToken string + expected string + }{ + { + accessToken: "test_access_token", + expected: "4993552f2cc6c4e57fa5738f9b161a1a4051c8370cddb32514c8f6f4c797801f", + }, + { + accessToken: "another_test_access_token", + expected: "8361833e9a11f829f2be9a00f1939b5a72408ff829451169f3b223c41768cfa2", + }, + { + accessToken: "", + expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + }, + } + + for _, tt := range tests { + t.Run(tt.accessToken, func(t *testing.T) { + got := generateRefreshToken(tt.accessToken) + if got != tt.expected { + t.Errorf("generateRefreshToken(%s) = %s; want %s", tt.accessToken, got, tt.expected) + } + }) + } +} diff --git a/internal/mock/lib/validate.go b/internal/mock/lib/validate.go new file mode 100644 index 0000000..123fe29 --- /dev/null +++ b/internal/mock/lib/validate.go @@ -0,0 +1,72 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./validate.go +// +// Generated by this command: +// +// mockgen -source=./validate.go -destination=../../mock/lib/validate.go -package=lib +// + +// Package lib is a generated GoMock package. +package lib + +import ( + required "ark-permission/internal/lib/required" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockValidate is a mock of Validate interface. +type MockValidate struct { + ctrl *gomock.Controller + recorder *MockValidateMockRecorder +} + +// MockValidateMockRecorder is the mock recorder for MockValidate. +type MockValidateMockRecorder struct { + mock *MockValidate +} + +// NewMockValidate creates a new mock instance. +func NewMockValidate(ctrl *gomock.Controller) *MockValidate { + mock := &MockValidate{ctrl: ctrl} + mock.recorder = &MockValidateMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockValidate) EXPECT() *MockValidateMockRecorder { + return m.recorder +} + +// BindToValidator mocks base method. +func (m *MockValidate) BindToValidator(opts ...required.Option) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BindToValidator", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// BindToValidator indicates an expected call of BindToValidator. +func (mr *MockValidateMockRecorder) BindToValidator(opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BindToValidator", reflect.TypeOf((*MockValidate)(nil).BindToValidator), opts...) +} + +// ValidateAll mocks base method. +func (m *MockValidate) ValidateAll(obj any) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ValidateAll", obj) + ret0, _ := ret[0].(error) + return ret0 +} + +// ValidateAll indicates an expected call of ValidateAll. +func (mr *MockValidateMockRecorder) ValidateAll(obj any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateAll", reflect.TypeOf((*MockValidate)(nil).ValidateAll), obj) +} diff --git a/internal/mock/repository/token.go b/internal/mock/repository/token.go new file mode 100644 index 0000000..6ef99eb --- /dev/null +++ b/internal/mock/repository/token.go @@ -0,0 +1,55 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: ./token.go +// +// Generated by this command: +// +// mockgen -source=./token.go -destination=../../mock/repository/token.go -package=repository +// + +// Package repository is a generated GoMock package. +package repository + +import ( + entity "ark-permission/internal/entity" + context "context" + reflect "reflect" + + gomock "go.uber.org/mock/gomock" +) + +// MockTokenRepository is a mock of TokenRepository interface. +type MockTokenRepository struct { + ctrl *gomock.Controller + recorder *MockTokenRepositoryMockRecorder +} + +// MockTokenRepositoryMockRecorder is the mock recorder for MockTokenRepository. +type MockTokenRepositoryMockRecorder struct { + mock *MockTokenRepository +} + +// NewMockTokenRepository creates a new mock instance. +func NewMockTokenRepository(ctrl *gomock.Controller) *MockTokenRepository { + mock := &MockTokenRepository{ctrl: ctrl} + mock.recorder = &MockTokenRepositoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTokenRepository) EXPECT() *MockTokenRepositoryMockRecorder { + return m.recorder +} + +// Create mocks base method. +func (m *MockTokenRepository) Create(ctx context.Context, token entity.Token) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Create", ctx, token) + ret0, _ := ret[0].(error) + return ret0 +} + +// Create indicates an expected call of Create. +func (mr *MockTokenRepositoryMockRecorder) Create(ctx, token any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockTokenRepository)(nil).Create), ctx, token) +} -- 2.40.1 From df0f4e426afeb9e53a63e537c1331ebb93725a8d Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Thu, 8 Aug 2024 11:02:13 +0800 Subject: [PATCH 04/10] update errors from lib to command --- go.mod | 5 +- internal/domain/errors.go | 20 + internal/domain/repository/token.go | 2 + internal/lib/error/code/define.go | 98 -- internal/lib/error/code/messsage.go | 13 - internal/lib/error/easy_func.go | 442 ------- internal/lib/error/easy_func_test.go | 1031 ----------------- internal/lib/error/errors.go | 197 ---- internal/lib/error/errors_test.go | 297 ----- internal/logic/cancel_token_logic.go | 37 +- internal/logic/new_token_logic.go | 33 +- internal/logic/new_token_logic_test.go | 184 ++- internal/logic/refresh_token_logic.go | 5 + internal/logic/{claims.go => utils_claims.go} | 0 internal/logic/utils_jwt.go | 76 ++ internal/repository/token.go | 55 + internal/svc/service_context.go | 3 + permission.go | 2 +- 18 files changed, 285 insertions(+), 2215 deletions(-) create mode 100644 internal/domain/errors.go delete mode 100644 internal/lib/error/code/define.go delete mode 100644 internal/lib/error/code/messsage.go delete mode 100644 internal/lib/error/easy_func.go delete mode 100644 internal/lib/error/easy_func_test.go delete mode 100644 internal/lib/error/errors.go delete mode 100644 internal/lib/error/errors_test.go rename internal/logic/{claims.go => utils_claims.go} (100%) create mode 100644 internal/logic/utils_jwt.go diff --git a/go.mod b/go.mod index 040f160..25a35bc 100644 --- a/go.mod +++ b/go.mod @@ -3,11 +3,10 @@ module ark-permission go 1.22.3 require ( + code.30cm.net/wanderland/library-go/errors v1.0.1 github.com/go-playground/validator/v10 v10.22.0 github.com/golang-jwt/jwt/v4 v4.5.0 - github.com/golang/mock v1.6.0 github.com/google/uuid v1.6.0 - github.com/stretchr/testify v1.9.0 github.com/zeromicro/go-zero v1.7.0 go.uber.org/mock v0.4.0 google.golang.org/grpc v1.65.0 @@ -33,6 +32,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/go-cmp v0.6.0 // indirect @@ -49,7 +49,6 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/openzipkin/zipkin-go v0.4.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.19.1 // indirect github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.48.0 // indirect diff --git a/internal/domain/errors.go b/internal/domain/errors.go new file mode 100644 index 0000000..a0b6a03 --- /dev/null +++ b/internal/domain/errors.go @@ -0,0 +1,20 @@ +package domain + +import ( + ers "code.30cm.net/wanderland/library-go/errors" + "code.30cm.net/wanderland/library-go/errors/code" +) + +// Decimal: 120314 +// 12 represents Scope +// 03 represents Category +// 14 represents Detail error code + +const ( + TokenUnexpectedSigning = 1 +) + +// TokenUnexpectedSigningErr 031011 +func TokenUnexpectedSigningErr(msg string) *ers.Err { + return ers.NewErr(code.CloudEPPermission, code.CatInput, code.InvalidFormat, msg) +} diff --git a/internal/domain/repository/token.go b/internal/domain/repository/token.go index b5f512a..8d7de3e 100644 --- a/internal/domain/repository/token.go +++ b/internal/domain/repository/token.go @@ -7,4 +7,6 @@ import ( type TokenRepository interface { Create(ctx context.Context, token entity.Token) error + GetByAccess(ctx context.Context, id string) (entity.Token, error) + Delete(ctx context.Context, token entity.Token) error } diff --git a/internal/lib/error/code/define.go b/internal/lib/error/code/define.go deleted file mode 100644 index 49715a2..0000000 --- a/internal/lib/error/code/define.go +++ /dev/null @@ -1,98 +0,0 @@ -package code - -const ( - OK uint32 = 0 -) - -// Scope -const ( - Unset uint32 = iota - CloudEPPortalGW - CloudEPMember -) - -// Category for general operations: 100 - 4900 -const ( - _ = iota - CatInput uint32 = iota * 100 - CatDB - CatResource - CatGRPC - CatAuth - CatSystem - CatPubSub -) - -// CatArk Category for specific app/service: 5000 - 9900 -const ( - CatArk uint32 = (iota + 50) * 100 -) - -// Detail - Input 1xx -const ( - _ = iota + CatInput - InvalidFormat - NotValidImplementation - InvalidRange -) - -// Detail - Database 2xx -const ( - _ = iota + CatDB - DBError // general error - DBDataConvert - DBDuplicate -) - -// Detail - Resource 3xx -const ( - _ = iota + CatResource - ResourceNotFound - InvalidResourceFormat - ResourceAlreadyExist - ResourceInsufficient - InsufficientPermission - InvalidMeasurementID - ResourceExpired - ResourceMigrated - InvalidResourceState - InsufficientQuota - ResourceHasMultiOwner -) - -/* Detail - GRPC */ -// The GRPC detail code uses Go GRPC's built-in codes. -// Refer to "google.golang.org/grpc/codes" for more detail. - -// Detail - Auth 5xx -const ( - _ = iota + CatAuth - Unauthorized - AuthExpired - InvalidPosixTime - SigAndPayloadNotMatched - Forbidden -) - -// Detail - System 6xx -const ( - _ = iota + CatSystem - SystemInternalError - SystemMaintainError - SystemTimeoutError -) - -// Detail - PubSub 7xx -const ( - _ = iota + CatPubSub - Publish - Consume - MsgSizeTooLarge -) - -// Detail - Ark 5xxx -const ( - _ = iota + CatArk - ArkInternal - ArkHttp400 -) diff --git a/internal/lib/error/code/messsage.go b/internal/lib/error/code/messsage.go deleted file mode 100644 index 18a4d4f..0000000 --- a/internal/lib/error/code/messsage.go +++ /dev/null @@ -1,13 +0,0 @@ -package code - -// CatToStr collects general error messages for each Category -// It is used to send back to API caller -var CatToStr = map[uint32]string{ - CatInput: "Invalid Input Data", - CatDB: "Database Error", - CatResource: "Resource Error", - CatGRPC: "Internal Service Communication Error", - CatAuth: "Authentication Error", - CatArk: "Internal Service Communication Error", - CatSystem: "System Error", -} diff --git a/internal/lib/error/easy_func.go b/internal/lib/error/easy_func.go deleted file mode 100644 index 2e13bd8..0000000 --- a/internal/lib/error/easy_func.go +++ /dev/null @@ -1,442 +0,0 @@ -package error - -import ( - "ark-permission/internal/lib/error/code" - "errors" - "fmt" - "strings" - - "github.com/zeromicro/go-zero/core/logx" - _ "github.com/zeromicro/go-zero/core/logx" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func newErr(scope, detail uint32, msg string) *Err { - cat := detail / 100 * 100 - return &Err{ - category: cat, - code: detail, - scope: scope, - msg: msg, - } -} - -func newBuiltinGRPCErr(scope, detail uint32, msg string) *Err { - return &Err{ - category: code.CatGRPC, - code: detail, - scope: scope, - msg: msg, - } -} - -// FromError tries to let error as Err -// it supports to unwrap error that has Err -// return nil if failed to transfer -func FromError(err error) *Err { - if err == nil { - return nil - } - - var e *Err - if errors.As(err, &e) { - return e - } - - return nil -} - -// FromCode parses code as following -// Decimal: 120314 -// 12 represents Scope -// 03 represents Category -// 14 represents Detail error code -func FromCode(code uint32) *Err { - scope := code / 10000 - detail := code % 10000 - return &Err{ - category: detail / 100 * 100, - code: detail, - scope: scope, - msg: "", - } -} - -// FromGRPCError transfer error to Err -// useful for gRPC client -func FromGRPCError(err error) *Err { - s, _ := status.FromError(err) - e := FromCode(uint32(s.Code())) - e.msg = s.Message() - - // For GRPC built-in code - if e.Scope() == code.Unset && e.Category() == 0 && e.Code() != code.OK { - e = newBuiltinGRPCErr(Scope, e.Code(), s.Message()) - } - - return e -} - -// Deprecated: check GRPCStatus() in Errs struct -// ToGRPCError returns the status.Status -// Useful to return error in gRPC server -func ToGRPCError(e *Err) error { - return status.New(codes.Code(e.FullCode()), e.Error()).Err() -} - -/*** System ***/ - -// SystemTimeoutError returns Err -func SystemTimeoutError(s ...string) *Err { - return newErr(Scope, code.SystemTimeoutError, fmt.Sprintf("system timeout: %s", strings.Join(s, " "))) -} - -// SystemTimeoutErrorL logs error message and returns Err -func SystemTimeoutErrorL(l logx.Logger, s ...string) *Err { - e := SystemTimeoutError(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// SystemInternalError returns Err struct -func SystemInternalError(s ...string) *Err { - return newErr(Scope, code.SystemInternalError, fmt.Sprintf("internal error: %s", strings.Join(s, " "))) -} - -// SystemInternalErrorL logs error message and returns Err -func SystemInternalErrorL(l logx.Logger, s ...string) *Err { - e := SystemInternalError(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// SystemMaintainErrorL logs error message and returns Err -func SystemMaintainErrorL(l logx.Logger, s ...string) *Err { - e := SystemMaintainError(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// SystemMaintainError returns Err struct -func SystemMaintainError(s ...string) *Err { - return newErr(Scope, code.SystemMaintainError, fmt.Sprintf("service under maintenance: %s", strings.Join(s, " "))) -} - -/*** CatInput ***/ - -// InvalidFormat returns Err struct -func InvalidFormat(s ...string) *Err { - return newErr(Scope, code.InvalidFormat, fmt.Sprintf("invalid format: %s", strings.Join(s, " "))) -} - -// InvalidFormatL logs error message and returns Err -func InvalidFormatL(l logx.Logger, s ...string) *Err { - e := InvalidFormat(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// InvalidRange returns Err struct -func InvalidRange(s ...string) *Err { - return newErr(Scope, code.InvalidRange, fmt.Sprintf("invalid range: %s", strings.Join(s, " "))) -} - -// InvalidRangeL logs error message and returns Err -func InvalidRangeL(l logx.Logger, s ...string) *Err { - e := InvalidRange(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// NotValidImplementation returns Err struct -func NotValidImplementation(s ...string) *Err { - return newErr(Scope, code.NotValidImplementation, fmt.Sprintf("not valid implementation: %s", strings.Join(s, " "))) -} - -// NotValidImplementationL logs error message and returns Err -func NotValidImplementationL(l logx.Logger, s ...string) *Err { - e := NotValidImplementation(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -/*** CatDB ***/ - -// DBError returns Err -func DBError(s ...string) *Err { - return newErr(Scope, code.DBError, fmt.Sprintf("db error: %s", strings.Join(s, " "))) -} - -// DBErrorL logs error message and returns Err -func DBErrorL(l logx.Logger, s ...string) *Err { - e := DBError(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// DBDataConvert returns Err -func DBDataConvert(s ...string) *Err { - return newErr(Scope, code.DBDataConvert, fmt.Sprintf("data from db convert error: %s", strings.Join(s, " "))) -} - -// DBDataConvertL logs error message and returns Err -func DBDataConvertL(l logx.Logger, s ...string) *Err { - e := DBDataConvert(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// DBDuplicate returns Err -func DBDuplicate(s ...string) *Err { - return newErr(Scope, code.DBDuplicate, fmt.Sprintf("data Duplicate key error: %s", strings.Join(s, " "))) -} - -// DBDuplicateL logs error message and returns Err -func DBDuplicateL(l logx.Logger, s ...string) *Err { - e := DBDuplicate(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -/*** CatResource ***/ - -// ResourceNotFound returns Err and logging -func ResourceNotFound(s ...string) *Err { - return newErr(Scope, code.ResourceNotFound, fmt.Sprintf("resource not found: %s", strings.Join(s, " "))) -} - -// ResourceNotFoundL logs error message and returns Err -func ResourceNotFoundL(l logx.Logger, s ...string) *Err { - e := ResourceNotFound(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// InvalidResourceFormat returns Err -func InvalidResourceFormat(s ...string) *Err { - return newErr(Scope, code.InvalidResourceFormat, fmt.Sprintf("invalid resource format: %s", strings.Join(s, " "))) -} - -// InvalidResourceFormatL logs error message and returns Err -func InvalidResourceFormatL(l logx.Logger, s ...string) *Err { - e := InvalidResourceFormat(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// InvalidResourceState returns status not correct. -// for example: company should be destroy, agent should be no-sensor/fail-install ... -func InvalidResourceState(s ...string) *Err { - return newErr(Scope, code.InvalidResourceState, fmt.Sprintf("invalid resource state: %s", strings.Join(s, " "))) -} - -// InvalidResourceStateL logs error message and returns status not correct. -func InvalidResourceStateL(l logx.Logger, s ...string) *Err { - e := InvalidResourceState(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -func ResourceInsufficient(s ...string) *Err { - return newErr(Scope, code.ResourceInsufficient, - fmt.Sprintf("insufficient resource: %s", strings.Join(s, " "))) -} - -func ResourceInsufficientL(l logx.Logger, s ...string) *Err { - e := ResourceInsufficient(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// InsufficientPermission returns Err -func InsufficientPermission(s ...string) *Err { - return newErr(Scope, code.InsufficientPermission, - fmt.Sprintf("insufficient permission: %s", strings.Join(s, " "))) -} - -// InsufficientPermissionL returns Err and log -func InsufficientPermissionL(l logx.Logger, s ...string) *Err { - e := InsufficientPermission(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// ResourceAlreadyExist returns Err -func ResourceAlreadyExist(s ...string) *Err { - return newErr(Scope, code.ResourceAlreadyExist, fmt.Sprintf("resource already exist: %s", strings.Join(s, " "))) -} - -// ResourceAlreadyExistL logs error message and returns Err -func ResourceAlreadyExistL(l logx.Logger, s ...string) *Err { - e := ResourceAlreadyExist(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// InvalidMeasurementID returns Err -func InvalidMeasurementID(s ...string) *Err { - return newErr(Scope, code.InvalidMeasurementID, fmt.Sprintf("missing measurement id: %s", strings.Join(s, " "))) -} - -// InvalidMeasurementIDL logs error message and returns Err -func InvalidMeasurementIDL(l logx.Logger, s ...string) *Err { - e := InvalidMeasurementID(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// ResourceExpired returns Err -func ResourceExpired(s ...string) *Err { - return newErr(Scope, code.ResourceExpired, fmt.Sprintf("resource expired: %s", strings.Join(s, " "))) -} - -// ResourceExpiredL logs error message and returns Err -func ResourceExpiredL(l logx.Logger, s ...string) *Err { - e := ResourceExpired(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// ResourceMigrated returns Err -func ResourceMigrated(s ...string) *Err { - return newErr(Scope, code.ResourceMigrated, fmt.Sprintf("resource migrated: %s", strings.Join(s, " "))) -} - -// ResourceMigratedL logs error message and returns Err -func ResourceMigratedL(l logx.Logger, s ...string) *Err { - e := ResourceMigrated(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// InsufficientQuota returns Err -func InsufficientQuota(s ...string) *Err { - return newErr(Scope, code.InsufficientQuota, fmt.Sprintf("insufficient quota: %s", strings.Join(s, " "))) -} - -// InsufficientQuotaL logs error message and returns Err -func InsufficientQuotaL(l logx.Logger, s ...string) *Err { - e := InsufficientQuota(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -/*** CatAuth ***/ - -// Unauthorized returns Err -func Unauthorized(s ...string) *Err { - return newErr(Scope, code.Unauthorized, fmt.Sprintf("unauthorized: %s", strings.Join(s, " "))) -} - -// UnauthorizedL logs error message and returns Err -func UnauthorizedL(l logx.Logger, s ...string) *Err { - e := Unauthorized(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// AuthExpired returns Err -func AuthExpired(s ...string) *Err { - return newErr(Scope, code.AuthExpired, fmt.Sprintf("expired: %s", strings.Join(s, " "))) -} - -// AuthExpiredL logs error message and returns Err -func AuthExpiredL(l logx.Logger, s ...string) *Err { - e := AuthExpired(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// InvalidPosixTime returns Err -func InvalidPosixTime(s ...string) *Err { - return newErr(Scope, code.InvalidPosixTime, fmt.Sprintf("invalid posix time: %s", strings.Join(s, " "))) -} - -// InvalidPosixTimeL logs error message and returns Err -func InvalidPosixTimeL(l logx.Logger, s ...string) *Err { - e := InvalidPosixTime(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// SigAndPayloadNotMatched returns Err -func SigAndPayloadNotMatched(s ...string) *Err { - return newErr(Scope, code.SigAndPayloadNotMatched, fmt.Sprintf("signature and the payload are not match: %s", strings.Join(s, " "))) -} - -// SigAndPayloadNotMatchedL logs error message and returns Err -func SigAndPayloadNotMatchedL(l logx.Logger, s ...string) *Err { - e := SigAndPayloadNotMatched(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// Forbidden returns Err -func Forbidden(s ...string) *Err { - return newErr(Scope, code.Forbidden, fmt.Sprintf("forbidden: %s", strings.Join(s, " "))) -} - -// ForbiddenL logs error message and returns Err -func ForbiddenL(l logx.Logger, s ...string) *Err { - e := Forbidden(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// IsAuthUnauthorizedError check the err is unauthorized error -func IsAuthUnauthorizedError(err *Err) bool { - switch err.Code() { - case code.Unauthorized, code.AuthExpired, code.InvalidPosixTime, - code.SigAndPayloadNotMatched, code.Forbidden, - code.InvalidFormat, code.ResourceNotFound: - return true - default: - return false - } -} - -/*** CatXBC ***/ - -// ArkInternal returns Err -func ArkInternal(s ...string) *Err { - return newErr(Scope, code.ArkInternal, fmt.Sprintf("ark internal error: %s", strings.Join(s, " "))) -} - -// ArkInternalL logs error message and returns Err -func ArkInternalL(l logx.Logger, s ...string) *Err { - e := ArkInternal(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -/*** CatPubSub ***/ - -// Publish returns Err -func Publish(s ...string) *Err { - return newErr(Scope, code.Publish, fmt.Sprintf("publish: %s", strings.Join(s, " "))) -} - -// PublishL logs error message and returns Err -func PublishL(l logx.Logger, s ...string) *Err { - e := Publish(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} - -// Consume returns Err -func Consume(s ...string) *Err { - return newErr(Scope, code.Consume, fmt.Sprintf("consume: %s", strings.Join(s, " "))) -} - -// MsgSizeTooLarge returns Err -func MsgSizeTooLarge(s ...string) *Err { - return newErr(Scope, code.MsgSizeTooLarge, fmt.Sprintf("kafka error: %s", strings.Join(s, " "))) -} - -// MsgSizeTooLargeL logs error message and returns Err -func MsgSizeTooLargeL(l logx.Logger, s ...string) *Err { - e := MsgSizeTooLarge(s...) - l.WithCallerSkip(1).Error(e.Error()) - return e -} diff --git a/internal/lib/error/easy_func_test.go b/internal/lib/error/easy_func_test.go deleted file mode 100644 index 5ff951c..0000000 --- a/internal/lib/error/easy_func_test.go +++ /dev/null @@ -1,1031 +0,0 @@ -package error - -import ( - "context" - "errors" - "fmt" - "member/internal/lib/error/code" - "reflect" - "strconv" - "testing" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/logx" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func TestFromGRPCError_GivenStatusWithCodeAndMessage_ShouldReturnErr(t *testing.T) { - // setup - s := status.Error(codes.Code(102399), "FAKE ERROR") - - // act - e := FromGRPCError(s) - - // assert - assert.Equal(t, uint32(10), e.Scope()) - assert.Equal(t, uint32(2300), e.Category()) - assert.Equal(t, uint32(2399), e.Code()) - assert.Equal(t, "FAKE ERROR", e.Error()) -} - -func TestFromGRPCError_GivenNilError_ShouldReturnErr_Scope0_Cat0_Detail0(t *testing.T) { - // setup - var nilError error = nil - - // act - e := FromGRPCError(nilError) - - // assert - assert.Equal(t, uint32(0), e.Scope()) - assert.Equal(t, uint32(0), e.Category()) - assert.Equal(t, uint32(0), e.Code()) - assert.Equal(t, "", e.Error()) -} - -func TestFromGRPCError_GivenGRPCNativeError_ShouldReturnErr_Scope0_CatGRPC_DetailGRPCUnavailable(t *testing.T) { - // setup - msg := "GRPC Unavailable ERROR" - s := status.Error(codes.Code(codes.Unavailable), msg) - - // act - e := FromGRPCError(s) - - // assert - assert.Equal(t, code.Unset, e.Scope()) - assert.Equal(t, code.CatGRPC, e.Category()) - assert.Equal(t, uint32(codes.Unavailable), e.Code()) - assert.Equal(t, msg, e.Error()) -} - -func TestFromGRPCError_GivenGeneralError_ShouldReturnErr_Scope0_CatGRPC_DetailGRPCUnknown(t *testing.T) { - // setup - generalErr := errors.New("general error") - - // act - e := FromGRPCError(generalErr) - - // assert - assert.Equal(t, code.Unset, e.Scope()) - assert.Equal(t, code.CatGRPC, e.Category()) - assert.Equal(t, uint32(codes.Unknown), e.Code()) -} - -func TestToGRPCError_GivenErr_StatusShouldHave_Code112233(t *testing.T) { - // setup - e := Err{scope: 11, code: 2233, msg: "FAKE MSG"} - - // act - err := ToGRPCError(&e) - s, _ := status.FromError(err) - - // assert - assert.Equal(t, 112233, int(s.Code())) - assert.Equal(t, "FAKE MSG", s.Message()) -} - -func TestInvalidFormat_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := InvalidFormat("field A", "Error description") - - // assert - assert.Equal(t, code.CatInput, e.Category()) - assert.Equal(t, code.InvalidFormat, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Equal(t, e.Error(), "invalid format: field A Error description") -} - -func TestInvalidFormatL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ctx := context.Background() - // act - e := InvalidFormatL(logx.WithContext(ctx), "field A", "Error description") - - // assert - assert.Equal(t, code.CatInput, e.Category()) - assert.Equal(t, code.InvalidFormat, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInvalidRange_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := InvalidRange("field A", "Error description") - - // assert - assert.Equal(t, code.CatInput, e.Category()) - assert.Equal(t, code.InvalidRange, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Equal(t, e.Error(), "invalid range: field A Error description") -} - -func TestInvalidRangeL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ctx := context.Background() - // act - e := InvalidRangeL(logx.WithContext(ctx), "field A", "Error description") - - // assert - assert.Equal(t, code.CatInput, e.Category()) - assert.Equal(t, code.InvalidRange, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestNotValidImplementation_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := NotValidImplementation("field A", "Error description") - - // assert - assert.Equal(t, code.CatInput, e.Category()) - assert.Equal(t, code.NotValidImplementation, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Equal(t, e.Error(), "not valid implementation: field A Error description") -} - -func TestNotValidImplementationL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - // act - e := NotValidImplementationL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatInput, e.Category()) - assert.Equal(t, code.NotValidImplementation, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestDBError_WithStrings_ShouldHasCatDBAndDetailCodeDBError(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := DBError("field A", "Error description") - - // assert - assert.Equal(t, code.CatDB, e.Category()) - assert.Equal(t, code.DBError, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestDBDataConvert_WithStrings_ShouldHasCatDBAndDetailCodeDBDataConvert(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := DBDataConvert("field A", "Error description") - - // assert - assert.Equal(t, code.CatDB, e.Category()) - assert.Equal(t, code.DBDataConvert, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceNotFound_WithStrings_ShouldHasCatResource_DetailCodeResourceNotFound(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := ResourceNotFound("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceNotFound, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInvalidResourceFormat_WithStrings_ShouldHasCatResource_DetailCodeInvalidResourceFormat(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := InvalidResourceFormat("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InvalidResourceFormat, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInvalidResourceState_OK(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := InvalidResourceState("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InvalidResourceState, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.EqualError(t, e, "invalid resource state: field A Error description") -} - -func TestInvalidResourceStateL_LogError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := InvalidResourceStateL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InvalidResourceState, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.EqualError(t, e, "invalid resource state: field A Error description") -} - -func TestAuthExpired_OK(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := AuthExpired("field A", "Error description") - - // assert - assert.Equal(t, code.CatAuth, e.Category()) - assert.Equal(t, code.AuthExpired, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestUnauthorized_WithStrings_ShouldHasCatAuth_DetailCodeUnauthorized(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := Unauthorized("field A", "Error description") - - // assert - assert.Equal(t, code.CatAuth, e.Category()) - assert.Equal(t, code.Unauthorized, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInvalidPosixTime_WithStrings_ShouldHasCatAuth_DetailCodeInvalidPosixTime(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := InvalidPosixTime("field A", "Error description") - - // assert - assert.Equal(t, code.CatAuth, e.Category()) - assert.Equal(t, code.InvalidPosixTime, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestSigAndPayloadNotMatched_WithStrings_ShouldHasCatAuth_DetailCodeSigAndPayloadNotMatched(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := SigAndPayloadNotMatched("field A", "Error description") - - // assert - assert.Equal(t, code.CatAuth, e.Category()) - assert.Equal(t, code.SigAndPayloadNotMatched, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestForbidden_WithStrings_ShouldHasCatAuth_DetailCodeForbidden(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := Forbidden("field A", "Error description") - - // assert - assert.Equal(t, code.CatAuth, e.Category()) - assert.Equal(t, code.Forbidden, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestXBCInternal_WithStrings_ShouldHasCatResource_DetailCodeXBCInternal(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := ArkInternal("field A", "Error description") - - // assert - assert.Equal(t, code.CatArk, e.Category()) - assert.Equal(t, code.ArkInternal, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestGeneralInternalError_WithStrings_DetailInternalError(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := SystemInternalError("field A", "Error description") - - // assert - assert.Equal(t, code.CatSystem, e.Category()) - assert.Equal(t, code.SystemInternalError, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestGeneralInternalErrorL_WithStrings_DetailInternalError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := SystemInternalErrorL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatSystem, e.Category()) - assert.Equal(t, code.SystemInternalError, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestSystemMaintainError_WithStrings_DetailSystemMaintainError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := SystemMaintainErrorL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatSystem, e.Category()) - assert.Equal(t, code.SystemMaintainError, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceAlreadyExist_WithStrings_DetailInternalError(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := ResourceAlreadyExist("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceAlreadyExist, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceAlreadyExistL_WithStrings_DetailInternalError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := ResourceAlreadyExistL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceAlreadyExist, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceInsufficient_WithStrings_DetailInternalError(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := ResourceInsufficient("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceInsufficient, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceInsufficientL_WithStrings_DetailInternalError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := ResourceInsufficientL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceInsufficient, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInsufficientPermission_WithStrings_DetailInternalError(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := InsufficientPermission("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InsufficientPermission, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInsufficientPermissionL_WithStrings_DetailInternalError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := InsufficientPermissionL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InsufficientPermission, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInvalidMeasurementID_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := InvalidMeasurementID("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InvalidMeasurementID, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInvalidMeasurementIDL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := InvalidMeasurementIDL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InvalidMeasurementID, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceExpired_OK(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := ResourceExpired("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceExpired, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceExpiredL_LogError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := ResourceExpiredL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceExpired, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceMigrated_OK(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := ResourceMigrated("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceMigrated, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestResourceMigratedL_LogError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := ResourceMigratedL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.ResourceMigrated, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInsufficientQuota_OK(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := InsufficientQuota("field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InsufficientQuota, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestInsufficientQuotaL_LogError(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := InsufficientQuotaL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatResource, e.Category()) - assert.Equal(t, code.InsufficientQuota, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestPublish_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := Publish("field A", "Error description") - - // assert - assert.Equal(t, code.CatPubSub, e.Category()) - assert.Equal(t, code.Publish, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestPublishL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := PublishL(l, "field A", "Error description") - - // assert - assert.Equal(t, code.CatPubSub, e.Category()) - assert.Equal(t, code.Publish, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") -} - -func TestMsgSizeTooLarge_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { - // setup - Scope = 99 - defer func() { - Scope = code.Unset - }() - - // act - e := MsgSizeTooLarge("Error description") - - // assert - assert.Equal(t, code.CatPubSub, e.Category()) - assert.Equal(t, code.MsgSizeTooLarge, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "kafka error: Error description") -} - -func TestMsgSizeTooLargeL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - l := logx.WithContext(context.Background()) - - // act - e := MsgSizeTooLargeL(l, "Error description") - - // assert - assert.Equal(t, code.CatPubSub, e.Category()) - assert.Equal(t, code.MsgSizeTooLarge, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "kafka error: Error description") -} - -func TestStructErr_WithInternalErr_ShouldIsFuncReportCorrectly(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - // arrange 2 layers err - layer1Err := fmt.Errorf("layer 1 error") - layer2Err := fmt.Errorf("layer 2: %w", layer1Err) - - // act with error chain: InvalidFormat -> layer 2 err -> layer 1 err - e := InvalidFormat("field A", "Error description") - err := e.Wrap(layer2Err) - if err != nil { - t.Fatalf("Failed to wrap error: %v", err) - } - - // assert - assert.Equal(t, code.CatInput, e.Category()) - assert.Equal(t, code.InvalidFormat, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") - - // errors.Is should report correctly - assert.True(t, errors.Is(e, layer1Err)) - assert.True(t, errors.Is(e, layer2Err)) -} - -func TestStructErr_WithInternalErr_ShouldErrorOutputChainErrMessage(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - - // arrange 2 layers err - layer1Err := fmt.Errorf("layer 1 error") - // act with error chain: InvalidFormat -> layer 1 err - e := InvalidFormat("field A", "Error description") - err := e.Wrap(layer1Err) - if err != nil { - t.Fatalf("Failed to wrap error: %v", err) - } - - // assert - assert.Equal(t, "invalid format: field A Error description: layer 1 error", e.Error()) -} - -// arrange a specific err type just for UT -type testErr struct { - code int -} - -func (e *testErr) Error() string { - return strconv.Itoa(e.code) -} - -func TestStructErr_WithInternalErr_ShouldAsFuncReportCorrectly(t *testing.T) { - // setup - Scope = 99 - defer func() { Scope = code.Unset }() - - testE := &testErr{code: 123} - layer2Err := fmt.Errorf("layer 2: %w", testE) - - // act with error chain: InvalidFormat -> layer 2 err -> testErr - e := InvalidFormat("field A", "Error description") - err := e.Wrap(layer2Err) - if err != nil { - t.Fatalf("Failed to wrap error: %v", err) - } - - // assert - assert.Equal(t, code.CatInput, e.Category()) - assert.Equal(t, code.InvalidFormat, e.Code()) - assert.Equal(t, uint32(99), e.Scope()) - assert.Contains(t, e.Error(), "field A") - assert.Contains(t, e.Error(), "Error description") - - // errors.As should report correctly - var internalErr *testErr - assert.True(t, errors.As(e, &internalErr)) - assert.Equal(t, testE, internalErr) -} - -/* -benchmark run for 1 second: -Benchmark_ErrorsIs_OneLayerError-4 148281332 8.68 ns/op 0 B/op 0 allocs/op -Benchmark_ErrorsIs_TwoLayerError-4 35048202 32.4 ns/op 0 B/op 0 allocs/op -Benchmark_ErrorsIs_FourLayerError-4 15309349 81.7 ns/op 0 B/op 0 allocs/op - -Benchmark_ErrorsAs_OneLayerError-4 16893205 70.4 ns/op 0 B/op 0 allocs/op -Benchmark_ErrorsAs_TwoLayerError-4 10568083 112 ns/op 0 B/op 0 allocs/op -Benchmark_ErrorsAs_FourLayerError-4 6307729 188 ns/op 0 B/op 0 allocs/op -*/ -func Benchmark_ErrorsIs_OneLayerError(b *testing.B) { - layer1Err := &testErr{code: 123} - var err error = layer1Err - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - errors.Is(err, layer1Err) - } -} - -func Benchmark_ErrorsIs_TwoLayerError(b *testing.B) { - layer1Err := &testErr{code: 123} - - // act with error chain: InvalidFormat(layer 2) -> testErr(layer 1) - layer2Err := InvalidFormat("field A", "Error description") - err := layer2Err.Wrap(layer1Err) - if err != nil { - b.Fatalf("Failed to wrap error: %v", err) - } - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - errors.Is(layer2Err, layer1Err) - } -} - -func Benchmark_ErrorsIs_FourLayerError(b *testing.B) { - layer1Err := &testErr{code: 123} - layer2Err := fmt.Errorf("layer 2: %w", layer1Err) - layer3Err := fmt.Errorf("layer 3: %w", layer2Err) - // act with error chain: InvalidFormat(layer 4) -> Error(layer 3) -> Error(layer 2) -> testErr(layer 1) - layer4Err := InvalidFormat("field A", "Error description") - err := layer4Err.Wrap(layer3Err) - if err != nil { - b.Fatalf("Failed to wrap error: %v", err) - } - - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - errors.Is(layer4Err, layer1Err) - } -} - -func Benchmark_ErrorsAs_OneLayerError(b *testing.B) { - layer1Err := &testErr{code: 123} - var err error = layer1Err - - b.ReportAllocs() - b.ResetTimer() - var internalErr *testErr - for i := 0; i < b.N; i++ { - errors.As(err, &internalErr) - } -} - -func Benchmark_ErrorsAs_TwoLayerError(b *testing.B) { - layer1Err := &testErr{code: 123} - - // act with error chain: InvalidFormat(layer 2) -> testErr(layer 1) - layer2Err := InvalidFormat("field A", "Error description") - err := layer2Err.Wrap(layer1Err) - if err != nil { - b.Fatalf("Failed to wrap error: %v", err) - } - - b.ReportAllocs() - b.ResetTimer() - var internalErr *testErr - for i := 0; i < b.N; i++ { - errors.As(layer2Err, &internalErr) - } -} - -func Benchmark_ErrorsAs_FourLayerError(b *testing.B) { - layer1Err := &testErr{code: 123} - layer2Err := fmt.Errorf("layer 2: %w", layer1Err) - layer3Err := fmt.Errorf("layer 3: %w", layer2Err) - // act with error chain: InvalidFormat(layer 4) -> Error(layer 3) -> Error(layer 2) -> testErr(layer 1) - layer4Err := InvalidFormat("field A", "Error description") - err := layer4Err.Wrap(layer3Err) - if err != nil { - b.Fatalf("Failed to wrap error: %v", err) - } - - b.ReportAllocs() - b.ResetTimer() - var internalErr *testErr - for i := 0; i < b.N; i++ { - errors.As(layer4Err, &internalErr) - } -} - -func TestFromError(t *testing.T) { - tests := []struct { - name string - givenError error - want *Err - }{ - { - "given nil error should return nil", - nil, - nil, - }, - { - "given normal error should return nil", - errors.New("normal error"), - nil, - }, - { - "given Err should return Err", - ResourceNotFound("fake error"), - ResourceNotFound("fake error"), - }, - { - "given error wraps Err should return Err", - fmt.Errorf("outter error wraps %w", ResourceNotFound("fake error")), - ResourceNotFound("fake error"), - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := FromError(tt.givenError); !reflect.DeepEqual(got, tt.want) { - t.Errorf("FromError() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/internal/lib/error/errors.go b/internal/lib/error/errors.go deleted file mode 100644 index fb16d5c..0000000 --- a/internal/lib/error/errors.go +++ /dev/null @@ -1,197 +0,0 @@ -package error - -import ( - "ark-permission/internal/lib/error/code" - "errors" - "fmt" - "net/http" - - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -// TODO Error要移到common 包 - -// Scope global variable should be set by service or module -var Scope = code.Unset - -type Err struct { - category uint32 - code uint32 - scope uint32 - msg string - internalErr error -} - -// Error is the interface of error -// Getter function of private property "msg" -func (e *Err) Error() string { - if e == nil { - return "" - } - - // chain the error string if the internal err exists - var internalErrStr string - if e.internalErr != nil { - internalErrStr = e.internalErr.Error() - } - - if e.msg != "" { - if internalErrStr != "" { - return fmt.Sprintf("%s: %s", e.msg, internalErrStr) - } - return e.msg - } - - generalErrStr := e.GeneralError() - if internalErrStr != "" { - return fmt.Sprintf("%s: %s", generalErrStr, internalErrStr) - } - return generalErrStr -} - -// Category getter function of private property "category" -func (e *Err) Category() uint32 { - if e == nil { - return 0 - } - return e.category -} - -// Scope getter function of private property "scope" -func (e *Err) Scope() uint32 { - if e == nil { - return code.Unset - } - - return e.scope -} - -// CodeStr returns the string of error code with zero padding -func (e *Err) CodeStr() string { - if e == nil { - return "00000" - } - - if e.Category() == code.CatGRPC { - return fmt.Sprintf("%d%04d", e.Scope(), e.Category()+e.Code()) - } - - return fmt.Sprintf("%d%04d", e.Scope(), e.Code()) -} - -// Code getter function of private property "code" -func (e *Err) Code() uint32 { - if e == nil { - return code.OK - } - - return e.code -} - -func (e *Err) FullCode() uint32 { - if e == nil { - return 0 - } - - if e.Category() == code.CatGRPC { - return e.Scope()*10000 + e.Category() + e.Code() - } - - return e.Scope()*10000 + e.Code() -} - -// HTTPStatus returns corresponding HTTP status code -func (e *Err) HTTPStatus() int { - if e == nil || e.Code() == code.OK { - return http.StatusOK - } - // determine status code by code - switch e.Code() { - case code.ResourceInsufficient: - // 400 - return http.StatusBadRequest - case code.Unauthorized, code.InsufficientPermission: - // 401 - return http.StatusUnauthorized - case code.InsufficientQuota: - // 402 - return http.StatusPaymentRequired - case code.InvalidPosixTime, code.Forbidden: - // 403 - return http.StatusForbidden - case code.ResourceNotFound: - // 404 - return http.StatusNotFound - case code.ResourceAlreadyExist, code.InvalidResourceState: - // 409 - return http.StatusConflict - case code.NotValidImplementation: - // 501 - return http.StatusNotImplemented - default: - } - - // determine status code by category - switch e.Category() { - case code.CatInput: - return http.StatusBadRequest - default: - // return status code 500 if none of the condition is met - return http.StatusInternalServerError - } -} - -// GeneralError transform category level error message -// It's the general error message for customer/API caller -func (e *Err) GeneralError() string { - if e == nil { - return "" - } - - errStr, ok := code.CatToStr[e.Category()] - if !ok { - return "" - } - - return errStr -} - -// Is called when performing errors.Is(). -// DO NOT USE THIS FUNCTION DIRECTLY unless you are very certain about what you're doing. -// Use errors.Is instead. -// This function compares if two error variables are both *Err, and have the same code (without checking the wrapped internal error) -func (e *Err) Is(f error) bool { - var err *Err - ok := errors.As(f, &err) - if !ok { - return false - } - return e.Code() == err.Code() -} - -// Unwrap returns the underlying error -// The result of unwrapping an error may itself have an Unwrap method; -// we call the sequence of errors produced by repeated unwrapping the error chain. -func (e *Err) Unwrap() error { - if e == nil { - return nil - } - return e.internalErr -} - -// Wrap sets the internal error to Err struct -func (e *Err) Wrap(internalErr error) *Err { - if e != nil { - e.internalErr = internalErr - } - return e -} - -func (e *Err) GRPCStatus() *status.Status { - if e == nil { - return status.New(codes.OK, "") - } - - return status.New(codes.Code(e.FullCode()), e.Error()) -} diff --git a/internal/lib/error/errors_test.go b/internal/lib/error/errors_test.go deleted file mode 100644 index a0f5325..0000000 --- a/internal/lib/error/errors_test.go +++ /dev/null @@ -1,297 +0,0 @@ -package error - -import ( - "errors" - "fmt" - "member/internal/lib/error/code" - "net/http" - "testing" - - "github.com/stretchr/testify/assert" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func TestCode_GivenNilReceiver_CodeReturnOK_CodeStrReturns00000(t *testing.T) { - // setup - var e *Err = nil - - // act & assert - assert.Equal(t, code.OK, e.Code()) - assert.Equal(t, "00000", e.CodeStr()) - assert.Equal(t, "", e.Error()) -} - -func TestCode_GivenScope99DetailCode6687_ShouldReturn996687(t *testing.T) { - // setup - e := Err{scope: 99, code: 6687} - - // act & assert - assert.Equal(t, uint32(6687), e.Code()) - assert.Equal(t, "996687", e.CodeStr()) -} - -func TestCode_GivenScope0DetailCode87_ShouldReturn87(t *testing.T) { - // setup - e := Err{scope: 0, code: 87} - - // act & assert - assert.Equal(t, uint32(87), e.Code()) - assert.Equal(t, "00087", e.CodeStr()) -} - -func TestFromCode_Given870005_ShouldHasScope87_Cat0_Detail5(t *testing.T) { - // setup - e := FromCode(870005) - - // assert - assert.Equal(t, uint32(87), e.Scope()) - assert.Equal(t, uint32(0), e.Category()) - assert.Equal(t, uint32(5), e.Code()) - assert.Equal(t, "", e.Error()) -} - -func TestFromCode_Given0_ShouldHasScope0_Cat0_Detail0(t *testing.T) { - // setup - e := FromCode(0) - - // assert - assert.Equal(t, uint32(0), e.Scope()) - assert.Equal(t, uint32(0), e.Category()) - assert.Equal(t, uint32(0), e.Code()) - assert.Equal(t, "", e.Error()) -} - -func TestFromCode_Given9105_ShouldHasScope0_Cat9100_Detail9105(t *testing.T) { - // setup - e := FromCode(9105) - - // assert - assert.Equal(t, uint32(0), e.Scope()) - assert.Equal(t, uint32(9100), e.Category()) - assert.Equal(t, uint32(9105), e.Code()) - assert.Equal(t, "", e.Error()) -} - -func TestErr_ShouldImplementErrorFunction(t *testing.T) { - // setup a func return error - f := func() error { return InvalidFormat("fake field") } - - // act - err := f() - - // assert - assert.NotNil(t, err) - assert.Contains(t, fmt.Sprint(err), "fake field") // can be printed -} - -func TestGeneralError_GivenNilErr_ShouldReturnEmptyString(t *testing.T) { - // setup - var e *Err = nil - - // act & assert - assert.Equal(t, "", e.GeneralError()) -} - -func TestGeneralError_GivenNotExistCat_ShouldReturnEmptyString(t *testing.T) { - // setup - e := Err{category: 123456} - - // act & assert - assert.Equal(t, "", e.GeneralError()) -} - -func TestGeneralError_GivenCatDB_ShouldReturnDBError(t *testing.T) { - // setup - e := Err{category: code.CatDB} - catErrStr := code.CatToStr[code.CatDB] - - // act & assert - assert.Equal(t, catErrStr, e.GeneralError()) -} - -func TestError_GivenEmptyMsg_ShouldReturnCatGeneralErrorMessage(t *testing.T) { - // setup - e := Err{category: code.CatDB, msg: ""} - - // act - errMsg := e.Error() - - // assert - assert.Equal(t, code.CatToStr[code.CatDB], errMsg) -} - -func TestError_GivenMsg_ShouldReturnGiveMsg(t *testing.T) { - // setup - e := Err{msg: "FAKE"} - - // act - errMsg := e.Error() - - // assert - assert.Equal(t, "FAKE", errMsg) -} - -func TestIs_GivenNilErr_ShouldReturnFalse(t *testing.T) { - var nilErrs *Err - // act - result := errors.Is(nilErrs, DBError()) - result2 := errors.Is(DBError(), nilErrs) - - // assert - assert.False(t, result) - assert.False(t, result2) -} - -func TestIs_GivenNil_ShouldReturnFalse(t *testing.T) { - // act - result := errors.Is(nil, DBError()) - result2 := errors.Is(DBError(), nil) - - // assert - assert.False(t, result) - assert.False(t, result2) -} - -func TestIs_GivenNilReceiver_ShouldReturnCorrectResult(t *testing.T) { - var nilErr *Err = nil - - // test 1: nilErr != DBError - var dbErr error = DBError("fake db error") - assert.False(t, nilErr.Is(dbErr)) - - // test 2: nilErr != nil error - var nilError error - assert.False(t, nilErr.Is(nilError)) - - // test 3: nilErr == another nilErr - var nilErr2 *Err = nil - assert.True(t, nilErr.Is(nilErr2)) -} - -func TestIs_GivenDBError_ShouldReturnTrue(t *testing.T) { - // setup - dbErr := DBError("fake db error") - - // act - result := errors.Is(dbErr, DBError("not care")) - result2 := errors.Is(DBError(), dbErr) - - // assert - assert.True(t, result) - assert.True(t, result2) -} - -func TestIs_GivenDBErrorAssignToErrorType_ShouldReturnTrue(t *testing.T) { - // setup - var dbErr error = DBError("fake db error") - - // act - result := errors.Is(dbErr, DBError("not care")) - result2 := errors.Is(DBError(), dbErr) - - // assert - assert.True(t, result) - assert.True(t, result2) -} - -func TestWrap_GivenNilErr_ShouldNoPanic(t *testing.T) { - // act & assert - assert.NotPanics(t, func() { - var e *Err = nil - _ = e.Wrap(fmt.Errorf("test")) - }) -} - -func TestWrap_GivenErrorToWrap_ShouldReturnErrorWithWrappedError(t *testing.T) { - // act & assert - wrappedErr := fmt.Errorf("test") - wrappingErr := SystemInternalError("WrappingError").Wrap(wrappedErr) - unWrappedErr := wrappingErr.Unwrap() - - assert.Equal(t, wrappedErr, unWrappedErr) -} - -func TestUnwrap_GivenNilErr_ShouldReturnNil(t *testing.T) { - var e *Err = nil - internalErr := e.Unwrap() - assert.Nil(t, internalErr) -} - -func TestErrorsIs_GivenNilErr_ShouldReturnFalse(t *testing.T) { - var e *Err = nil - assert.False(t, errors.Is(e, fmt.Errorf("test"))) -} - -func TestErrorsAs_GivenNilErr_ShouldReturnFalse(t *testing.T) { - var internalErr *testErr - var e *Err = nil - assert.False(t, errors.As(e, &internalErr)) -} - -func TestGRPCStatus(t *testing.T) { - // setup table driven tests - tests := []struct { - name string - given *Err - expect *status.Status - expectConvert error - }{ - { - "nil errs.Err", - nil, - status.New(codes.OK, ""), - nil, - }, - { - "InvalidFormat Err", - InvalidFormat("fake"), - status.New(codes.Code(101), "invalid format: fake"), - status.New(codes.Code(101), "invalid format: fake").Err(), - }, - } - - // act & assert - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - s := test.given.GRPCStatus() - assert.Equal(t, test.expect.Code(), s.Code()) - assert.Equal(t, test.expect.Message(), s.Message()) - assert.Equal(t, test.expectConvert, status.Convert(test.given).Err()) - }) - } -} - -func TestErr_HTTPStatus(t *testing.T) { - tests := []struct { - name string - err *Err - want int - }{ - {name: "nil error", err: nil, want: http.StatusOK}, - {name: "invalid measurement id", err: &Err{category: code.CatResource, code: code.InvalidMeasurementID}, want: http.StatusInternalServerError}, - {name: "resource already exists", err: &Err{category: code.CatResource, code: code.ResourceAlreadyExist}, want: http.StatusConflict}, - {name: "invalid resource state", err: &Err{category: code.CatResource, code: code.InvalidResourceState}, want: http.StatusConflict}, - {name: "invalid posix time", err: &Err{category: code.CatAuth, code: code.InvalidPosixTime}, want: http.StatusForbidden}, - {name: "unauthorized", err: &Err{category: code.CatAuth, code: code.Unauthorized}, want: http.StatusUnauthorized}, - {name: "db error", err: &Err{category: code.CatDB, code: code.DBError}, want: http.StatusInternalServerError}, - {name: "insufficient permission", err: &Err{category: code.CatResource, code: code.InsufficientPermission}, want: http.StatusUnauthorized}, - {name: "resource insufficient", err: &Err{category: code.CatResource, code: code.ResourceInsufficient}, want: http.StatusBadRequest}, - {name: "invalid format", err: &Err{category: code.CatInput, code: code.InvalidFormat}, want: http.StatusBadRequest}, - {name: "resource not found", err: &Err{code: code.ResourceNotFound}, want: http.StatusNotFound}, - {name: "ok", err: &Err{code: code.OK}, want: http.StatusOK}, - {name: "not valid implementation", err: &Err{category: code.CatInput, code: code.NotValidImplementation}, want: http.StatusNotImplemented}, - {name: "forbidden", err: &Err{category: code.CatAuth, code: code.Forbidden}, want: http.StatusForbidden}, - {name: "insufficient quota", err: &Err{category: code.CatResource, code: code.InsufficientQuota}, want: http.StatusPaymentRequired}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - // act - got := tt.err.HTTPStatus() - - // assert - assert.Equal(t, tt.want, got) - }) - } -} diff --git a/internal/logic/cancel_token_logic.go b/internal/logic/cancel_token_logic.go index 0601e5a..8a32d58 100644 --- a/internal/logic/cancel_token_logic.go +++ b/internal/logic/cancel_token_logic.go @@ -1,11 +1,9 @@ package logic import ( - "context" - "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" - + "context" "github.com/zeromicro/go-zero/core/logx" ) @@ -23,9 +21,40 @@ func NewCancelTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Cance } } +type cancelTokenReq struct { + Token string `json:"token" validate:"required"` +} + // CancelToken 取消 Token,也包含他裡面的 One Time Toke func (l *CancelTokenLogic) CancelToken(in *permission.CancelTokenReq) (*permission.OKResp, error) { - // todo: add your logic here and delete this line + // // 驗證所需 + // if err := l.svcCtx.Validate.ValidateAll(&cancelTokenReq{ + // Token: in.GetToken(), + // }); err != nil { + // return nil, ers.InvalidFormat(err.Error()) + // } + + // claims, err := uc.parseClaims(accessToken) + // if err != nil { + // return err + // } + // + // token, err := uc.TokenRepository.GetByAccess(ctx, claims.ID()) + // if err != nil { + // if errors.Is(err, repository.ErrRecordNotFound) { + // return usecase.TokenError{Msg: "token not found"} + // } + // + // return usecase.InternalError{Err: fmt.Errorf("tokenRepository.GetByAccess error: %w", err)} + // } + // + // if err := uc.TokenRepository.Delete(ctx, token); err != nil { + // if errors.Is(err, repository.ErrRecordNotFound) { + // return nil, usecase.TokenError{Msg: "token not found"} + // } + // + // return nil, err + // } return &permission.OKResp{}, nil } diff --git a/internal/logic/new_token_logic.go b/internal/logic/new_token_logic.go index 2dae534..c9b961d 100644 --- a/internal/logic/new_token_logic.go +++ b/internal/logic/new_token_logic.go @@ -4,14 +4,10 @@ import ( "ark-permission/gen_result/pb/permission" "ark-permission/internal/domain" "ark-permission/internal/entity" - ers "ark-permission/internal/lib/error" "ark-permission/internal/svc" - "bytes" + ers "code.30cm.net/wanderland/library-go/errors" "context" - "crypto/sha256" - "encoding/hex" "fmt" - "github.com/golang-jwt/jwt/v4" "github.com/google/uuid" "time" @@ -111,30 +107,3 @@ func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.T RefreshToken: token.RefreshToken, }, nil } - -func generateAccessToken(token entity.Token, data any, sign string) (string, error) { - claim := entity.Claims{ - Data: data, - RegisteredClaims: jwt.RegisteredClaims{ - ID: token.ID, - ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)), - Issuer: "permission", - }, - } - - accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim). - SignedString([]byte(sign)) - if err != nil { - return "", err - } - - return accessToken, nil -} - -func generateRefreshToken(accessToken string) string { - buf := bytes.NewBufferString(accessToken) - h := sha256.New() - _, _ = h.Write(buf.Bytes()) - - return hex.EncodeToString(h.Sum(nil)) -} diff --git a/internal/logic/new_token_logic_test.go b/internal/logic/new_token_logic_test.go index 35ab56a..e00c816 100644 --- a/internal/logic/new_token_logic_test.go +++ b/internal/logic/new_token_logic_test.go @@ -1,109 +1,99 @@ package logic import ( - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/domain" "ark-permission/internal/entity" - libMock "ark-permission/internal/mock/lib" - repoMock "ark-permission/internal/mock/repository" - "ark-permission/internal/svc" - "errors" - "github.com/stretchr/testify/assert" - - "context" "github.com/golang-jwt/jwt/v4" - "go.uber.org/mock/gomock" "testing" "time" ) -func TestNewTokenLogic_NewToken(t *testing.T) { - // mock - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - tokenMockRepo := repoMock.NewMockTokenRepository(ctrl) - mockValidate := libMock.NewMockValidate(ctrl) - - sc := svc.ServiceContext{ - TokenRedisRepo: tokenMockRepo, - Validate: mockValidate, - } - - l := NewNewTokenLogic(context.Background(), &sc) - - tests := []struct { - name string - input *permission.AuthorizationReq - setupMocks func() - expectError bool - expected *permission.TokenResp - }{ - { - name: "Valid token request", - input: &permission.AuthorizationReq{ - GrantType: "authorization_code", - DeviceId: "device123", - Scope: "read", - Expires: 3600, - IsRefreshToken: false, - Data: map[string]string{ - "uid": "user123", - }, - }, - setupMocks: func() { - mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(nil) - tokenMockRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, token entity.Token) { - token.AccessToken = "access_token" - }) - generateAccessTokenFunc = func(token entity.Token, data any, sign string) (string, error) { - return "access_token", nil - } - generateRefreshTokenFunc = func(accessToken string) string { - return "refresh_token" - } - }, - expectError: false, - expected: &permission.TokenResp{ - AccessToken: "access_token", - TokenType: domain.TokenTypeBearer, - ExpiresIn: 3600, - RefreshToken: "", - }, - }, - { - name: "Validation error", - input: &permission.AuthorizationReq{ - GrantType: "invalid_grant", - DeviceId: "device123", - Scope: "read", - Expires: 3600, - IsRefreshToken: false, - Data: map[string]string{ - "uid": "user123", - }, - }, - setupMocks: func() { - mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(errors.New("invalid grant type")) - }, - expectError: true, - expected: nil, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tt.setupMocks() - - resp, err := l.NewToken(tt.input) - if tt.expectError { - assert.Error(t, err) - } else { - assert.NoError(t, err) - assert.Equal(t, tt.expected, resp) - } - }) - } -} +// func TestNewTokenLogic_NewToken(t *testing.T) { +// // mock +// ctrl := gomock.NewController(t) +// defer ctrl.Finish() +// +// tokenMockRepo := repoMock.NewMockTokenRepository(ctrl) +// mockValidate := libMock.NewMockValidate(ctrl) +// +// sc := svc.ServiceContext{ +// TokenRedisRepo: tokenMockRepo, +// Validate: mockValidate, +// } +// +// l := NewNewTokenLogic(context.Background(), &sc) +// +// tests := []struct { +// name string +// input *permission.AuthorizationReq +// setupMocks func() +// expectError bool +// expected *permission.TokenResp +// }{ +// { +// name: "Valid token request", +// input: &permission.AuthorizationReq{ +// GrantType: "authorization_code", +// DeviceId: "device123", +// Scope: "read", +// Expires: 3600, +// IsRefreshToken: false, +// Data: map[string]string{ +// "uid": "user123", +// }, +// }, +// setupMocks: func() { +// mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(nil) +// tokenMockRepo.EXPECT().Create(gomock.Any(), gomock.Any()).Return(nil).Do(func(ctx context.Context, token entity.Token) { +// token.AccessToken = "access_token" +// }) +// generateAccessTokenFunc = func(token entity.Token, data any, sign string) (string, error) { +// return "access_token", nil +// } +// generateRefreshTokenFunc = func(accessToken string) string { +// return "refresh_token" +// } +// }, +// expectError: false, +// expected: &permission.TokenResp{ +// AccessToken: "access_token", +// TokenType: domain.TokenTypeBearer, +// ExpiresIn: 3600, +// RefreshToken: "", +// }, +// }, +// { +// name: "Validation error", +// input: &permission.AuthorizationReq{ +// GrantType: "invalid_grant", +// DeviceId: "device123", +// Scope: "read", +// Expires: 3600, +// IsRefreshToken: false, +// Data: map[string]string{ +// "uid": "user123", +// }, +// }, +// setupMocks: func() { +// mockValidate.EXPECT().ValidateAll(gomock.Any()).Return(errors.New("invalid grant type")) +// }, +// expectError: true, +// expected: nil, +// }, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// tt.setupMocks() +// +// resp, err := l.NewToken(tt.input) +// if tt.expectError { +// assert.Error(t, err) +// } else { +// assert.NoError(t, err) +// assert.Equal(t, tt.expected, resp) +// } +// }) +// } +// } // 測試 generateAccessToken 函數 func TestGenerateAccessToken(t *testing.T) { diff --git a/internal/logic/refresh_token_logic.go b/internal/logic/refresh_token_logic.go index 4caef16..44295c9 100644 --- a/internal/logic/refresh_token_logic.go +++ b/internal/logic/refresh_token_logic.go @@ -1,7 +1,10 @@ package logic import ( + "ark-permission/internal/domain" "context" + "fmt" + "strconv" "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" @@ -26,6 +29,8 @@ func NewRefreshTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Refr // RefreshToken 更新目前的token 以及裡面包含的一次性 Token func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permission.RefreshTokenResp, error) { // todo: add your logic here and delete this line + e := domain.TokenUnexpectedSigningErr("gg88g88") + fmt.Printf(strconv.Itoa(int(e.Code())), e.Category(), e.Scope(), e.FullCode(), e.Error()) return &permission.RefreshTokenResp{}, nil } diff --git a/internal/logic/claims.go b/internal/logic/utils_claims.go similarity index 100% rename from internal/logic/claims.go rename to internal/logic/utils_claims.go diff --git a/internal/logic/utils_jwt.go b/internal/logic/utils_jwt.go new file mode 100644 index 0000000..4b5712d --- /dev/null +++ b/internal/logic/utils_jwt.go @@ -0,0 +1,76 @@ +package logic + +import ( + "ark-permission/internal/entity" + "bytes" + "crypto/sha256" + "encoding/hex" + "fmt" + "github.com/golang-jwt/jwt/v4" + "time" +) + +func generateAccessToken(token entity.Token, data any, sign string) (string, error) { + claim := entity.Claims{ + Data: data, + RegisteredClaims: jwt.RegisteredClaims{ + ID: token.ID, + ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)), + Issuer: "permission", + }, + } + + accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim). + SignedString([]byte(sign)) + if err != nil { + return "", err + } + + return accessToken, nil +} + +func generateRefreshToken(accessToken string) string { + buf := bytes.NewBufferString(accessToken) + h := sha256.New() + _, _ = h.Write(buf.Bytes()) + + return hex.EncodeToString(h.Sum(nil)) +} + +func parseClaims(accessToken string) (claims, error) { + claimMap, err := parseToken(accessToken) + if err != nil { + return claims{}, err + } + + claims, ok := claimMap["data"].(map[string]string) + if ok { + return claims, nil + } + + return nil, fmt.Errorf("get data from claim map error") +} + +func parseToken(accessToken string) (jwt.MapClaims, error) { + // token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { + // if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + // return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) + // } + // + // return []byte(uc.Config.CustomConfig.Token.Secret), nil + // }) + // + // if err != nil { + // ers.FromCode() + // return jwt.MapClaims{}, usecase.TokenError{Msg: fmt.Sprintf("parse token error: %s token: %s", err.Error(), accessToken)} + // } + // + // claims, ok := token.Claims.(jwt.MapClaims) + // + // if !(ok && token.Valid) { + // return jwt.MapClaims{}, usecase.TokenError{Msg: "token valid error"} + // } + // + // return claims, nil + return nil, nil +} diff --git a/internal/repository/token.go b/internal/repository/token.go index 2f00c71..c228665 100644 --- a/internal/repository/token.go +++ b/internal/repository/token.go @@ -4,6 +4,8 @@ import ( "ark-permission/internal/domain" "ark-permission/internal/domain/repository" "ark-permission/internal/entity" + ers "code.30cm.net/wanderland/library-go/errors" + "context" "encoding/json" "errors" @@ -61,6 +63,59 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error return nil } +func (t *tokenRepository) GetByAccess(_ context.Context, id string) (entity.Token, error) { + return t.get(domain.GetAccessTokenRedisKey(id)) +} + +func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { + err := t.store.Pipelined(func(tx redis.Pipeliner) error { + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), + } + + for _, key := range keys { + if err := tx.Del(ctx, key).Err(); err != nil { + return fmt.Errorf("store.Del key error: %w", err) + } + } + + if token.DeviceID != "" { + key := domain.DeviceTokenRedisKey.With(token.UID).ToString() + _, err := t.store.Hdel(key, token.DeviceID) + if err != nil { + return fmt.Errorf("store.HDel deviceKey error: %w", err) + } + } + + return nil + }) + + if err != nil { + return fmt.Errorf("store.Pipelined error: %w", err) + } + + return nil +} + +func (t *tokenRepository) get(key string) (entity.Token, error) { + body, err := t.store.Get(key) + if errors.Is(err, redis.Nil) { + return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key) + } + + if err != nil { + return entity.Token{}, fmt.Errorf("store.Get tokenTag error: %w", err) + } + + var token entity.Token + if err := json.Unmarshal([]byte(body), &token); err != nil { + return entity.Token{}, fmt.Errorf("json.Unmarshal token error: %w", err) + } + + return token, nil +} + func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error { err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err() if err != nil { diff --git a/internal/svc/service_context.go b/internal/svc/service_context.go index 5a21250..2784204 100644 --- a/internal/svc/service_context.go +++ b/internal/svc/service_context.go @@ -5,6 +5,8 @@ import ( "ark-permission/internal/domain/repository" "ark-permission/internal/lib/required" repo "ark-permission/internal/repository" + ers "code.30cm.net/wanderland/library-go/errors" + "code.30cm.net/wanderland/library-go/errors/code" "github.com/zeromicro/go-zero/core/stores/redis" ) @@ -21,6 +23,7 @@ func NewServiceContext(c config.Config) *ServiceContext { if err != nil { panic(err) } + ers.Scope = code.CloudEPPermission return &ServiceContext{ Config: c, diff --git a/permission.go b/permission.go index d7f2864..89037e5 100644 --- a/permission.go +++ b/permission.go @@ -34,7 +34,7 @@ func main() { }) defer s.Stop() - // // 加入中間件 + // 加入中間件 // s.AddUnaryInterceptors(middleware.TimeoutMiddleware) fmt.Printf("Starting rpc server at %s...\n", c.ListenOn) -- 2.40.1 From b111cc12106c62d2b03fedeab2f80dd362e03be2 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Thu, 8 Aug 2024 16:10:38 +0800 Subject: [PATCH 05/10] update errors from lib to command --- internal/domain/errors.go | 49 +++++++++++++++--- internal/lib/metric/app.go | 30 +++++++++++ internal/lib/metric/db.go | 1 + internal/logic/cancel_token_logic.go | 59 ++++++++++++---------- internal/logic/new_token_logic.go | 13 +++-- internal/logic/refresh_token_logic.go | 8 +-- internal/logic/utils_jwt.go | 72 ++++++++++++++++----------- internal/repository/token.go | 19 ++++--- 8 files changed, 170 insertions(+), 81 deletions(-) create mode 100644 internal/lib/metric/app.go create mode 100644 internal/lib/metric/db.go diff --git a/internal/domain/errors.go b/internal/domain/errors.go index a0b6a03..0ccbdfd 100644 --- a/internal/domain/errors.go +++ b/internal/domain/errors.go @@ -1,20 +1,57 @@ package domain import ( + mts "ark-permission/internal/lib/metric" + ers "code.30cm.net/wanderland/library-go/errors" "code.30cm.net/wanderland/library-go/errors/code" ) -// Decimal: 120314 // 12 represents Scope -// 03 represents Category -// 14 represents Detail error code +// 100 represents Category +// 9 represents Detail error code +// full code 12009 只會有 系統以及錯誤碼,category 是給系統判定用的 +// 目前 Scope 以及分類要系統共用,係向的錯誤各自服務實作就好 const ( - TokenUnexpectedSigning = 1 + TokenUnexpectedSigningErrorCode = iota + 1 + TokenValidateErrorCode + TokenClaimErrorCode ) -// TokenUnexpectedSigningErr 031011 +const ( + RedisDelErrorCode = iota + 20 + RedisPipLineErrorCode +) + +// TokenUnexpectedSigningErr 30001 Token 簽名錯誤 func TokenUnexpectedSigningErr(msg string) *ers.Err { - return ers.NewErr(code.CloudEPPermission, code.CatInput, code.InvalidFormat, msg) + mts.AppErrorMetrics.AddFailure("token", "token_unexpected_sign") + return ers.NewErr(code.CloudEPPermission, code.CatInput, TokenUnexpectedSigningErrorCode, msg) +} + +// TokenTokenValidateErr 30002 Token 驗證錯誤 +func TokenTokenValidateErr(msg string) *ers.Err { + mts.AppErrorMetrics.AddFailure("token", "token_validate_ilegal") + return ers.NewErr(code.CloudEPPermission, code.CatInput, TokenValidateErrorCode, msg) +} + +// TokenClaimError 30003 Token 驗證錯誤 +func TokenClaimError(msg string) *ers.Err { + mts.AppErrorMetrics.AddFailure("token", "token_claim_error") + return ers.NewErr(code.CloudEPPermission, code.CatInput, TokenClaimErrorCode, msg) +} + +// RedisDelError 30020 Redis 刪除錯誤 +func RedisDelError(msg string) *ers.Err { + // 看需要建立哪些 Metrics + mts.AppErrorMetrics.AddFailure("redis", "del_error") + return ers.NewErr(code.CloudEPPermission, code.CatDB, RedisDelErrorCode, msg) +} + +// RedisPipLineError 30021 Redis PipLine 錯誤 +func RedisPipLineError(msg string) *ers.Err { + // 看需要建立哪些 Metrics + mts.AppErrorMetrics.AddFailure("redis", "pip_line_error") + return ers.NewErr(code.CloudEPPermission, code.CatInput, TokenClaimErrorCode, msg) } diff --git a/internal/lib/metric/app.go b/internal/lib/metric/app.go new file mode 100644 index 0000000..59da7ef --- /dev/null +++ b/internal/lib/metric/app.go @@ -0,0 +1,30 @@ +package metric + +import ( + "github.com/zeromicro/go-zero/core/metric" +) + +var AppErrorMetrics = NewAppErrMetrics() + +type appErrMetrics struct { + metric.CounterVec +} + +type Metrics interface { + AddFailure(source, reason string) +} + +// NewAppErrMetrics initiate metrics and register to prometheus +func NewAppErrMetrics() Metrics { + return &appErrMetrics{metric.NewCounterVec(&metric.CounterVecOpts{ + Namespace: "ark", + Subsystem: "permission", + Name: "permission_app_error_total", + Help: "App defined failure total.", + Labels: []string{"source", "reason"}, + })} +} + +func (m *appErrMetrics) AddFailure(source, reason string) { + m.Inc(source, reason) +} diff --git a/internal/lib/metric/db.go b/internal/lib/metric/db.go new file mode 100644 index 0000000..0bad30a --- /dev/null +++ b/internal/lib/metric/db.go @@ -0,0 +1 @@ +package metric diff --git a/internal/logic/cancel_token_logic.go b/internal/logic/cancel_token_logic.go index 8a32d58..d7f94a2 100644 --- a/internal/logic/cancel_token_logic.go +++ b/internal/logic/cancel_token_logic.go @@ -3,6 +3,7 @@ package logic import ( "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" + ers "code.30cm.net/wanderland/library-go/errors" "context" "github.com/zeromicro/go-zero/core/logx" ) @@ -27,34 +28,38 @@ type cancelTokenReq struct { // CancelToken 取消 Token,也包含他裡面的 One Time Toke func (l *CancelTokenLogic) CancelToken(in *permission.CancelTokenReq) (*permission.OKResp, error) { - // // 驗證所需 - // if err := l.svcCtx.Validate.ValidateAll(&cancelTokenReq{ - // Token: in.GetToken(), - // }); err != nil { - // return nil, ers.InvalidFormat(err.Error()) - // } + // 驗證所需 + if err := l.svcCtx.Validate.ValidateAll(&cancelTokenReq{ + Token: in.GetToken(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } - // claims, err := uc.parseClaims(accessToken) - // if err != nil { - // return err - // } - // - // token, err := uc.TokenRepository.GetByAccess(ctx, claims.ID()) - // if err != nil { - // if errors.Is(err, repository.ErrRecordNotFound) { - // return usecase.TokenError{Msg: "token not found"} - // } - // - // return usecase.InternalError{Err: fmt.Errorf("tokenRepository.GetByAccess error: %w", err)} - // } - // - // if err := uc.TokenRepository.Delete(ctx, token); err != nil { - // if errors.Is(err, repository.ErrRecordNotFound) { - // return nil, usecase.TokenError{Msg: "token not found"} - // } - // - // return nil, err - // } + claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "parseClaims"), + ).Error(err.Error()) + return nil, err + } + + token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.GetByAccess"), + logx.Field("claims", claims), + ).Error(err.Error()) + return nil, err + } + + err = l.svcCtx.TokenRedisRepo.Delete(l.ctx, token) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.Delete"), + logx.Field("req", token), + ).Error(err.Error()) + return nil, err + } return &permission.OKResp{}, nil } diff --git a/internal/logic/new_token_logic.go b/internal/logic/new_token_logic.go index c9b961d..078bb33 100644 --- a/internal/logic/new_token_logic.go +++ b/internal/logic/new_token_logic.go @@ -7,7 +7,6 @@ import ( "ark-permission/internal/svc" ers "code.30cm.net/wanderland/library-go/errors" "context" - "fmt" "github.com/google/uuid" "time" @@ -88,7 +87,11 @@ func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.T var err error token.AccessToken, err = generateAccessTokenFunc(token, claims, l.svcCtx.Config.Token.Secret) if err != nil { - return nil, ers.ArkInternal(fmt.Errorf("accessGenerate token error: %w", err).Error()) + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "generateAccessTokenFunc"), + logx.Field("claims", claims), + ).Error(err.Error()) + return nil, err } if in.GetIsRefreshToken() { @@ -97,7 +100,11 @@ func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.T err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token) if err != nil { - return nil, ers.ArkInternal(fmt.Errorf("tokenRepository.Create error: %w", err).Error()) + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.Create"), + logx.Field("token", token), + ).Error(err.Error()) + return nil, err } return &permission.TokenResp{ diff --git a/internal/logic/refresh_token_logic.go b/internal/logic/refresh_token_logic.go index 44295c9..c09cee8 100644 --- a/internal/logic/refresh_token_logic.go +++ b/internal/logic/refresh_token_logic.go @@ -1,13 +1,9 @@ package logic import ( - "ark-permission/internal/domain" - "context" - "fmt" - "strconv" - "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" + "context" "github.com/zeromicro/go-zero/core/logx" ) @@ -29,8 +25,6 @@ func NewRefreshTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Refr // RefreshToken 更新目前的token 以及裡面包含的一次性 Token func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permission.RefreshTokenResp, error) { // todo: add your logic here and delete this line - e := domain.TokenUnexpectedSigningErr("gg88g88") - fmt.Printf(strconv.Itoa(int(e.Code())), e.Category(), e.Scope(), e.FullCode(), e.Error()) return &permission.RefreshTokenResp{}, nil } diff --git a/internal/logic/utils_jwt.go b/internal/logic/utils_jwt.go index 4b5712d..f34d9ab 100644 --- a/internal/logic/utils_jwt.go +++ b/internal/logic/utils_jwt.go @@ -1,8 +1,10 @@ package logic import ( + "ark-permission/internal/domain" "ark-permission/internal/entity" "bytes" + "context" "crypto/sha256" "encoding/hex" "fmt" @@ -23,7 +25,7 @@ func generateAccessToken(token entity.Token, data any, sign string) (string, err accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim). SignedString([]byte(sign)) if err != nil { - return "", err + return "", domain.TokenClaimError(err.Error()) } return accessToken, nil @@ -37,40 +39,54 @@ func generateRefreshToken(accessToken string) string { return hex.EncodeToString(h.Sum(nil)) } -func parseClaims(accessToken string) (claims, error) { - claimMap, err := parseToken(accessToken) +func parseClaims(ctx context.Context, accessToken string, secret string) (claims, error) { + claimMap, err := parseToken(ctx, accessToken, secret) if err != nil { return claims{}, err } - claims, ok := claimMap["data"].(map[string]string) + claims, ok := claimMap["data"].(map[string]any) if ok { - return claims, nil + + return convertMap(claims), nil } - return nil, fmt.Errorf("get data from claim map error") + return nil, domain.TokenClaimError("get data from claim map error") } -func parseToken(accessToken string) (jwt.MapClaims, error) { - // token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { - // if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - // return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) - // } - // - // return []byte(uc.Config.CustomConfig.Token.Secret), nil - // }) - // - // if err != nil { - // ers.FromCode() - // return jwt.MapClaims{}, usecase.TokenError{Msg: fmt.Sprintf("parse token error: %s token: %s", err.Error(), accessToken)} - // } - // - // claims, ok := token.Claims.(jwt.MapClaims) - // - // if !(ok && token.Valid) { - // return jwt.MapClaims{}, usecase.TokenError{Msg: "token valid error"} - // } - // - // return claims, nil - return nil, nil +func parseToken(ctx context.Context, accessToken string, secret string) (jwt.MapClaims, error) { + token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) + } + + return []byte(secret), nil + }) + + if err != nil { + return jwt.MapClaims{}, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + + if !(ok && token.Valid) { + return jwt.MapClaims{}, domain.TokenTokenValidateErr("token valid error") + } + + return claims, nil +} + +func convertMap(input map[string]interface{}) map[string]string { + output := make(map[string]string) + for key, value := range input { + switch v := value.(type) { + case string: + output[key] = v + case fmt.Stringer: + output[key] = v.String() + default: + output[key] = fmt.Sprintf("%v", value) + } + } + return output } diff --git a/internal/repository/token.go b/internal/repository/token.go index c228665..105d794 100644 --- a/internal/repository/token.go +++ b/internal/repository/token.go @@ -5,7 +5,6 @@ import ( "ark-permission/internal/domain/repository" "ark-permission/internal/entity" ers "code.30cm.net/wanderland/library-go/errors" - "context" "encoding/json" "errors" @@ -32,7 +31,7 @@ func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error { body, err := json.Marshal(token) if err != nil { - return wrapError("json.Marshal token error", err) + return ers.ArkInternal("json.Marshal token error", err.Error()) } err = t.store.Pipelined(func(tx redis.Pipeliner) error { @@ -53,11 +52,11 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error return nil }) if err != nil { - return wrapError("store.Pipelined error", err) + return domain.RedisPipLineError(err.Error()) } if err := t.SetUIDToken(token); err != nil { - return wrapError("SetUIDToken error", err) + return ers.ArkInternal("SetUIDToken error", err.Error()) } return nil @@ -76,7 +75,7 @@ func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error for _, key := range keys { if err := tx.Del(ctx, key).Err(); err != nil { - return fmt.Errorf("store.Del key error: %w", err) + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } } @@ -84,7 +83,7 @@ func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error key := domain.DeviceTokenRedisKey.With(token.UID).ToString() _, err := t.store.Hdel(key, token.DeviceID) if err != nil { - return fmt.Errorf("store.HDel deviceKey error: %w", err) + return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) } } @@ -92,7 +91,7 @@ func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error }) if err != nil { - return fmt.Errorf("store.Pipelined error: %w", err) + return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) } return nil @@ -100,17 +99,17 @@ func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error func (t *tokenRepository) get(key string) (entity.Token, error) { body, err := t.store.Get(key) - if errors.Is(err, redis.Nil) { + if errors.Is(err, redis.Nil) || body == "" { return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key) } if err != nil { - return entity.Token{}, fmt.Errorf("store.Get tokenTag error: %w", err) + return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err)) } var token entity.Token if err := json.Unmarshal([]byte(body), &token); err != nil { - return entity.Token{}, fmt.Errorf("json.Unmarshal token error: %w", err) + return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %w", err)) } return token, nil -- 2.40.1 From 0d13b0c5f0cc894f0aa68621b2e181186bc292c6 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Sat, 10 Aug 2024 09:52:23 +0800 Subject: [PATCH 06/10] fix: update repostitory --- .../permissionservice/permission_service.go | 45 ++++ client/roleservice/role_service.go | 51 +++++ client/tokenservice/token_service.go | 125 +++++++++++ generate/protobuf/permission.proto | 40 +++- internal/domain/errors.go | 10 +- internal/domain/repository/token.go | 22 +- internal/entity/token.go | 4 +- .../logic/get_user_tokens_by_uid_logic.go | 31 --- internal/logic/new_one_time_token_logic.go | 31 --- internal/logic/refresh_token_logic.go | 30 --- internal/logic/roleservice/ping_logic.go | 30 +++ .../cancel_one_time_token_logic.go | 27 ++- .../cancel_token_by_device_id_logic.go | 4 +- .../cancel_token_by_uid_logic.go | 23 +- .../{ => tokenservice}/cancel_token_logic.go | 8 +- .../get_user_tokens_by_device_id_logic.go | 24 +- .../get_user_tokens_by_uid_logic.go | 57 +++++ .../tokenservice/new_one_time_token_logic.go | 69 ++++++ .../{ => tokenservice}/new_token_logic.go | 24 +- .../new_token_logic_test.go | 2 +- .../logic/tokenservice/refresh_token_logic.go | 133 +++++++++++ .../logic/{ => tokenservice}/utils_claims.go | 2 +- .../logic/{ => tokenservice}/utils_jwt.go | 9 +- .../tokenservice/validation_token_logic.go | 71 ++++++ internal/logic/validation_token_logic.go | 31 --- internal/repository/token.go | 212 +++++++++++++++++- .../permission_service_server.go | 20 ++ .../server/roleservice/role_service_server.go | 28 +++ .../token_service_server.go | 34 +-- permission.go | 8 +- tokenservice/token_service.go | 21 +- 31 files changed, 1024 insertions(+), 202 deletions(-) create mode 100644 client/permissionservice/permission_service.go create mode 100644 client/roleservice/role_service.go create mode 100644 client/tokenservice/token_service.go delete mode 100644 internal/logic/get_user_tokens_by_uid_logic.go delete mode 100644 internal/logic/new_one_time_token_logic.go delete mode 100644 internal/logic/refresh_token_logic.go create mode 100644 internal/logic/roleservice/ping_logic.go rename internal/logic/{ => tokenservice}/cancel_one_time_token_logic.go (51%) rename internal/logic/{ => tokenservice}/cancel_token_by_device_id_logic.go (91%) rename internal/logic/{ => tokenservice}/cancel_token_by_uid_logic.go (54%) rename internal/logic/{ => tokenservice}/cancel_token_logic.go (98%) rename internal/logic/{ => tokenservice}/get_user_tokens_by_device_id_logic.go (57%) create mode 100644 internal/logic/tokenservice/get_user_tokens_by_uid_logic.go create mode 100644 internal/logic/tokenservice/new_one_time_token_logic.go rename internal/logic/{ => tokenservice}/new_token_logic.go (86%) rename internal/logic/{ => tokenservice}/new_token_logic_test.go (99%) create mode 100644 internal/logic/tokenservice/refresh_token_logic.go rename internal/logic/{ => tokenservice}/utils_claims.go (96%) rename internal/logic/{ => tokenservice}/utils_jwt.go (88%) create mode 100644 internal/logic/tokenservice/validation_token_logic.go delete mode 100644 internal/logic/validation_token_logic.go create mode 100644 internal/server/permissionservice/permission_service_server.go create mode 100644 internal/server/roleservice/role_service_server.go rename internal/server/{ => tokenservice}/token_service_server.go (69%) diff --git a/client/permissionservice/permission_service.go b/client/permissionservice/permission_service.go new file mode 100644 index 0000000..fdae9cb --- /dev/null +++ b/client/permissionservice/permission_service.go @@ -0,0 +1,45 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package permissionservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + PermissionService interface { + } + + defaultPermissionService struct { + cli zrpc.Client + } +) + +func NewPermissionService(cli zrpc.Client) PermissionService { + return &defaultPermissionService{ + cli: cli, + } +} diff --git a/client/roleservice/role_service.go b/client/roleservice/role_service.go new file mode 100644 index 0000000..8d7e2e2 --- /dev/null +++ b/client/roleservice/role_service.go @@ -0,0 +1,51 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package roleservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + RoleService interface { + Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) + } + + defaultRoleService struct { + cli zrpc.Client + } +) + +func NewRoleService(cli zrpc.Client) RoleService { + return &defaultRoleService{ + cli: cli, + } +} + +func (m *defaultRoleService) Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewRoleServiceClient(m.cli.Conn()) + return client.Ping(ctx, in, opts...) +} diff --git a/client/tokenservice/token_service.go b/client/tokenservice/token_service.go new file mode 100644 index 0000000..2bc1af2 --- /dev/null +++ b/client/tokenservice/token_service.go @@ -0,0 +1,125 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package tokenservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + TokenService interface { + // NewToken 建立一個新的 Token,例如:AccessToken + NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) + // RefreshToken 更新目前的token 以及裡面包含的一次性 Token + RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) + // CancelToken 取消 Token,也包含他裡面的 One Time Toke + CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) + // CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke + CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) + // CancelTokenByDeviceId 取消 Token + CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) + // ValidationToken 驗證這個 Token 有沒有效 + ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens + GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) + // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens + GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) + // NewOneTimeToken 建立一次性使用,例如:RefreshToken + NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) + // CancelOneTimeToken 取消一次性使用 + CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) + } + + defaultTokenService struct { + cli zrpc.Client + } +) + +func NewTokenService(cli zrpc.Client) TokenService { + return &defaultTokenService{ + cli: cli, + } +} + +// NewToken 建立一個新的 Token,例如:AccessToken +func (m *defaultTokenService) NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.NewToken(ctx, in, opts...) +} + +// RefreshToken 更新目前的token 以及裡面包含的一次性 Token +func (m *defaultTokenService) RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.RefreshToken(ctx, in, opts...) +} + +// CancelToken 取消 Token,也包含他裡面的 One Time Toke +func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelToken(ctx, in, opts...) +} + +// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke +func (m *defaultTokenService) CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokenByUid(ctx, in, opts...) +} + +// CancelTokenByDeviceId 取消 Token +func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokenByDeviceId(ctx, in, opts...) +} + +// ValidationToken 驗證這個 Token 有沒有效 +func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.ValidationToken(ctx, in, opts...) +} + +// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens +func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.GetUserTokensByDeviceId(ctx, in, opts...) +} + +// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens +func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.GetUserTokensByUid(ctx, in, opts...) +} + +// NewOneTimeToken 建立一次性使用,例如:RefreshToken +func (m *defaultTokenService) NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.NewOneTimeToken(ctx, in, opts...) +} + +// CancelOneTimeToken 取消一次性使用 +func (m *defaultTokenService) CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelOneTimeToken(ctx, in, opts...) +} diff --git a/generate/protobuf/permission.proto b/generate/protobuf/permission.proto index 717f753..38bfdc0 100644 --- a/generate/protobuf/permission.proto +++ b/generate/protobuf/permission.proto @@ -68,7 +68,13 @@ message CancelTokenReq { // CancelTokenReq 註銷這個 Token message DoTokenByUIDReq { - repeated string uid = 1; + repeated string ids = 1; + string uid = 2; +} + +// QueryTokenByUIDReq 拿這個UID 找 Token +message QueryTokenByUIDReq { + string uid = 1; } // ValidationTokenReq 驗證這個 Token @@ -112,6 +118,10 @@ message DoTokenByDeviceIDReq { } message Tokens{ + repeated TokenResp token = 1; +} + +message CancelOneTimeTokenReq { repeated string token = 1; } @@ -125,23 +135,29 @@ service TokenService { rpc RefreshToken(RefreshTokenReq) returns(RefreshTokenResp); // CancelToken 取消 Token,也包含他裡面的 One Time Toke rpc CancelToken(CancelTokenReq) returns(OKResp); - // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke - rpc CancelTokenByUid(DoTokenByUIDReq) returns(OKResp); - // CancelTokenByDeviceID 取消 Token - rpc CancelTokenByDeviceId(DoTokenByDeviceIDReq) returns(OKResp); // ValidationToken 驗證這個 Token 有沒有效 rpc ValidationToken(ValidationTokenReq) returns(ValidationTokenResp); - // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens + // CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 + rpc CancelTokens(DoTokenByUIDReq) returns(OKResp); + + // CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token + rpc CancelTokenByDeviceId(DoTokenByDeviceIDReq) returns(OKResp); + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens rpc GetUserTokensByDeviceId(DoTokenByDeviceIDReq) returns(Tokens); - // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens - rpc GetUserTokensByUid(DoTokenByUIDReq) returns(Tokens); + + + // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens + rpc GetUserTokensByUid(QueryTokenByUIDReq) returns(Tokens); + // NewOneTimeToken 建立一次性使用,例如:RefreshToken rpc NewOneTimeToken(CreateOneTimeTokenReq) returns(CreateOneTimeTokenResp); // CancelOneTimeToken 取消一次性使用 - rpc CancelOneTimeToken(CreateOneTimeTokenReq) returns(CreateOneTimeTokenResp); + rpc CancelOneTimeToken(CancelOneTimeTokenReq) returns(OKResp); } -//service Role_Service {} -// -//service Permission_Service {} \ No newline at end of file +service RoleService { + rpc Ping(OKResp) returns(OKResp); +} + +service PermissionService {} \ No newline at end of file diff --git a/internal/domain/errors.go b/internal/domain/errors.go index 0ccbdfd..3182133 100644 --- a/internal/domain/errors.go +++ b/internal/domain/errors.go @@ -22,6 +22,7 @@ const ( const ( RedisDelErrorCode = iota + 20 RedisPipLineErrorCode + RedisErrorCode ) // TokenUnexpectedSigningErr 30001 Token 簽名錯誤 @@ -53,5 +54,12 @@ func RedisDelError(msg string) *ers.Err { func RedisPipLineError(msg string) *ers.Err { // 看需要建立哪些 Metrics mts.AppErrorMetrics.AddFailure("redis", "pip_line_error") - return ers.NewErr(code.CloudEPPermission, code.CatInput, TokenClaimErrorCode, msg) + return ers.NewErr(code.CloudEPPermission, code.CatInput, RedisPipLineErrorCode, msg) +} + +// RedisError 30022 Redis 錯誤 +func RedisError(msg string) *ers.Err { + // 看需要建立哪些 Metrics + mts.AppErrorMetrics.AddFailure("redis", "error") + return ers.NewErr(code.CloudEPPermission, code.CatInput, RedisErrorCode, msg) } diff --git a/internal/domain/repository/token.go b/internal/domain/repository/token.go index 8d7de3e..57bd4f6 100644 --- a/internal/domain/repository/token.go +++ b/internal/domain/repository/token.go @@ -3,10 +3,30 @@ package repository import ( "ark-permission/internal/entity" "context" + "time" ) type TokenRepository interface { Create(ctx context.Context, token entity.Token) error - GetByAccess(ctx context.Context, id string) (entity.Token, error) + DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error + CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error + GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) + + GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) + GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) + GetAccessTokenCountByUID(uid string) (int, error) + GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) + GetAccessTokenCountByDeviceID(deviceID string) (int, error) + Delete(ctx context.Context, token entity.Token) error + DeleteAccessTokenByID(ctx context.Context, id string) error + DeleteAccessTokensByUID(ctx context.Context, uid string) error + DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error + DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error + DeleteUIDToken(ctx context.Context, uid string, ids []string) error +} + +type DeviceToken struct { + DeviceID string + TokenID string } diff --git a/internal/entity/token.go b/internal/entity/token.go index eeadd26..c2cd7fc 100644 --- a/internal/entity/token.go +++ b/internal/entity/token.go @@ -33,6 +33,6 @@ func (t *Token) IsExpires() bool { type UIDToken map[string]int64 type Ticket struct { - Data interface{} `json:"data"` - Token Token `json:"token"` + Data any `json:"data"` + Token Token `json:"token"` } diff --git a/internal/logic/get_user_tokens_by_uid_logic.go b/internal/logic/get_user_tokens_by_uid_logic.go deleted file mode 100644 index ee70ab1..0000000 --- a/internal/logic/get_user_tokens_by_uid_logic.go +++ /dev/null @@ -1,31 +0,0 @@ -package logic - -import ( - "context" - - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" - - "github.com/zeromicro/go-zero/core/logx" -) - -type GetUserTokensByUidLogic struct { - ctx context.Context - svcCtx *svc.ServiceContext - logx.Logger -} - -func NewGetUserTokensByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByUidLogic { - return &GetUserTokensByUidLogic{ - ctx: ctx, - svcCtx: svcCtx, - Logger: logx.WithContext(ctx), - } -} - -// GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (l *GetUserTokensByUidLogic) GetUserTokensByUid(in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { - // todo: add your logic here and delete this line - - return &permission.Tokens{}, nil -} diff --git a/internal/logic/new_one_time_token_logic.go b/internal/logic/new_one_time_token_logic.go deleted file mode 100644 index 7183ba4..0000000 --- a/internal/logic/new_one_time_token_logic.go +++ /dev/null @@ -1,31 +0,0 @@ -package logic - -import ( - "context" - - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" - - "github.com/zeromicro/go-zero/core/logx" -) - -type NewOneTimeTokenLogic struct { - ctx context.Context - svcCtx *svc.ServiceContext - logx.Logger -} - -func NewNewOneTimeTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *NewOneTimeTokenLogic { - return &NewOneTimeTokenLogic{ - ctx: ctx, - svcCtx: svcCtx, - Logger: logx.WithContext(ctx), - } -} - -// NewOneTimeToken 建立一次性使用,例如:RefreshToken -func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) { - // todo: add your logic here and delete this line - - return &permission.CreateOneTimeTokenResp{}, nil -} diff --git a/internal/logic/refresh_token_logic.go b/internal/logic/refresh_token_logic.go deleted file mode 100644 index c09cee8..0000000 --- a/internal/logic/refresh_token_logic.go +++ /dev/null @@ -1,30 +0,0 @@ -package logic - -import ( - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" - "context" - - "github.com/zeromicro/go-zero/core/logx" -) - -type RefreshTokenLogic struct { - ctx context.Context - svcCtx *svc.ServiceContext - logx.Logger -} - -func NewRefreshTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RefreshTokenLogic { - return &RefreshTokenLogic{ - ctx: ctx, - svcCtx: svcCtx, - Logger: logx.WithContext(ctx), - } -} - -// RefreshToken 更新目前的token 以及裡面包含的一次性 Token -func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permission.RefreshTokenResp, error) { - // todo: add your logic here and delete this line - - return &permission.RefreshTokenResp{}, nil -} diff --git a/internal/logic/roleservice/ping_logic.go b/internal/logic/roleservice/ping_logic.go new file mode 100644 index 0000000..bbe1ccb --- /dev/null +++ b/internal/logic/roleservice/ping_logic.go @@ -0,0 +1,30 @@ +package roleservicelogic + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" + + "github.com/zeromicro/go-zero/core/logx" +) + +type PingLogic struct { + ctx context.Context + svcCtx *svc.ServiceContext + logx.Logger +} + +func NewPingLogic(ctx context.Context, svcCtx *svc.ServiceContext) *PingLogic { + return &PingLogic{ + ctx: ctx, + svcCtx: svcCtx, + Logger: logx.WithContext(ctx), + } +} + +func (l *PingLogic) Ping(in *permission.OKResp) (*permission.OKResp, error) { + // todo: add your logic here and delete this line + + return &permission.OKResp{}, nil +} diff --git a/internal/logic/cancel_one_time_token_logic.go b/internal/logic/tokenservice/cancel_one_time_token_logic.go similarity index 51% rename from internal/logic/cancel_one_time_token_logic.go rename to internal/logic/tokenservice/cancel_one_time_token_logic.go index 86e6d5e..86092eb 100644 --- a/internal/logic/cancel_one_time_token_logic.go +++ b/internal/logic/tokenservice/cancel_one_time_token_logic.go @@ -1,6 +1,7 @@ -package logic +package tokenservicelogic import ( + ers "code.30cm.net/wanderland/library-go/errors" "context" "ark-permission/gen_result/pb/permission" @@ -23,9 +24,23 @@ func NewCancelOneTimeTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) } } -// CancelOneTimeToken 取消一次性使用 -func (l *CancelOneTimeTokenLogic) CancelOneTimeToken(in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) { - // todo: add your logic here and delete this line - - return &permission.CreateOneTimeTokenResp{}, nil +type cancelOneTimeTokenReq struct { + Token []string `json:"token" validate:"required"` +} + +// CancelOneTimeToken 取消一次性使用 +func (l *CancelOneTimeTokenLogic) CancelOneTimeToken(in *permission.CancelOneTimeTokenReq) (*permission.OKResp, error) { + // 驗證所需 + if err := l.svcCtx.Validate.ValidateAll(&cancelOneTimeTokenReq{ + Token: in.GetToken(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } + + err := l.svcCtx.TokenRedisRepo.DeleteOneTimeToken(l.ctx, in.GetToken(), nil) + if err != nil { + return nil, err + } + + return &permission.OKResp{}, nil } diff --git a/internal/logic/cancel_token_by_device_id_logic.go b/internal/logic/tokenservice/cancel_token_by_device_id_logic.go similarity index 91% rename from internal/logic/cancel_token_by_device_id_logic.go rename to internal/logic/tokenservice/cancel_token_by_device_id_logic.go index 862a8d5..ed59824 100644 --- a/internal/logic/cancel_token_by_device_id_logic.go +++ b/internal/logic/tokenservice/cancel_token_by_device_id_logic.go @@ -1,4 +1,4 @@ -package logic +package tokenservicelogic import ( "context" @@ -23,7 +23,7 @@ func NewCancelTokenByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceConte } } -// CancelTokenByDeviceID 取消 Token +// CancelTokenByDeviceId 取消 Token func (l *CancelTokenByDeviceIdLogic) CancelTokenByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { // todo: add your logic here and delete this line diff --git a/internal/logic/cancel_token_by_uid_logic.go b/internal/logic/tokenservice/cancel_token_by_uid_logic.go similarity index 54% rename from internal/logic/cancel_token_by_uid_logic.go rename to internal/logic/tokenservice/cancel_token_by_uid_logic.go index 05b0566..f266185 100644 --- a/internal/logic/cancel_token_by_uid_logic.go +++ b/internal/logic/tokenservice/cancel_token_by_uid_logic.go @@ -1,6 +1,7 @@ -package logic +package tokenservicelogic import ( + ers "code.30cm.net/wanderland/library-go/errors" "context" "ark-permission/gen_result/pb/permission" @@ -23,9 +24,25 @@ func NewCancelTokenByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) * } } -// CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke +type deleteByTokenIDs struct { + UID string `json:"uid" binding:"required"` + IDs []string `json:"ids" binding:"required"` +} + +// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke func (l *CancelTokenByUidLogic) CancelTokenByUid(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { - // todo: add your logic here and delete this line + // 驗證所需 + if err := l.svcCtx.Validate.ValidateAll(&deleteByTokenIDs{ + UID: in.GetUid(), + IDs: in.GetIds(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } + + err := l.svcCtx.TokenRedisRepo.DeleteUIDToken(l.ctx, in.GetUid(), in.GetIds()) + if err != nil { + return nil, err + } return &permission.OKResp{}, nil } diff --git a/internal/logic/cancel_token_logic.go b/internal/logic/tokenservice/cancel_token_logic.go similarity index 98% rename from internal/logic/cancel_token_logic.go rename to internal/logic/tokenservice/cancel_token_logic.go index d7f94a2..4f615bd 100644 --- a/internal/logic/cancel_token_logic.go +++ b/internal/logic/tokenservice/cancel_token_logic.go @@ -1,10 +1,12 @@ -package logic +package tokenservicelogic import ( - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" ers "code.30cm.net/wanderland/library-go/errors" "context" + + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" + "github.com/zeromicro/go-zero/core/logx" ) diff --git a/internal/logic/get_user_tokens_by_device_id_logic.go b/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go similarity index 57% rename from internal/logic/get_user_tokens_by_device_id_logic.go rename to internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go index d633995..93836ea 100644 --- a/internal/logic/get_user_tokens_by_device_id_logic.go +++ b/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go @@ -1,11 +1,9 @@ -package logic +package tokenservicelogic import ( - "context" - "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" - + "context" "github.com/zeromicro/go-zero/core/logx" ) @@ -23,9 +21,23 @@ func NewGetUserTokensByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceCon } } -// GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens +// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (l *GetUserTokensByDeviceIdLogic) GetUserTokensByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { - // todo: add your logic here and delete this line + + // ids, err := l.svcCtx.TokenRedisRepo.GetAccessTokensByDeviceID(l.ctx, "") + // if err != nil { + // return nil, error + // } + + // tokenIDs := make([]usecase.DeviceToken, 0, len(ids)) + // for _, v := range ids { + // tokenIDs = append(tokenIDs, usecase.DeviceToken{ + // DeviceID: v.DeviceID, + // TokenID: v.TokenID, + // }) + // } + // + // return tokenIDs, nil return &permission.Tokens{}, nil } diff --git a/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go b/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go new file mode 100644 index 0000000..eb3a1f0 --- /dev/null +++ b/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go @@ -0,0 +1,57 @@ +package tokenservicelogic + +import ( + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/domain" + "ark-permission/internal/svc" + ers "code.30cm.net/wanderland/library-go/errors" + "context" + + "github.com/zeromicro/go-zero/core/logx" +) + +type GetUserTokensByUidLogic struct { + ctx context.Context + svcCtx *svc.ServiceContext + logx.Logger +} + +func NewGetUserTokensByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByUidLogic { + return &GetUserTokensByUidLogic{ + ctx: ctx, + svcCtx: svcCtx, + Logger: logx.WithContext(ctx), + } +} + +type getUserTokensByUidReq struct { + UID string `json:"uid" validate:"required"` +} + +// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens +func (l *GetUserTokensByUidLogic) GetUserTokensByUid(in *permission.QueryTokenByUIDReq) (*permission.Tokens, error) { + if err := l.svcCtx.Validate.ValidateAll(&getUserTokensByUidReq{ + UID: in.GetUid(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } + + uidTokens, err := l.svcCtx.TokenRedisRepo.GetAccessTokensByUID(l.ctx, in.GetUid()) + if err != nil { + return nil, err + } + + tokens := make([]*permission.TokenResp, 0, len(uidTokens)) + for _, v := range uidTokens { + tokens = append(tokens, &permission.TokenResp{ + AccessToken: v.AccessToken, + TokenType: domain.TokenTypeBearer, + ExpiresIn: int32(v.ExpiresIn), + RefreshToken: v.RefreshToken, + }) + } + + return &permission.Tokens{ + Token: tokens, + }, nil +} diff --git a/internal/logic/tokenservice/new_one_time_token_logic.go b/internal/logic/tokenservice/new_one_time_token_logic.go new file mode 100644 index 0000000..bcb02d5 --- /dev/null +++ b/internal/logic/tokenservice/new_one_time_token_logic.go @@ -0,0 +1,69 @@ +package tokenservicelogic + +import ( + "ark-permission/internal/domain" + "ark-permission/internal/entity" + ers "code.30cm.net/wanderland/library-go/errors" + "context" + "time" + + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" + + "github.com/zeromicro/go-zero/core/logx" +) + +type NewOneTimeTokenLogic struct { + ctx context.Context + svcCtx *svc.ServiceContext + logx.Logger +} + +func NewNewOneTimeTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *NewOneTimeTokenLogic { + return &NewOneTimeTokenLogic{ + ctx: ctx, + svcCtx: svcCtx, + Logger: logx.WithContext(ctx), + } +} + +// NewOneTimeToken 建立一次性使用,例如:RefreshToken +func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) { + // 驗證所需 + if err := l.svcCtx.Validate.ValidateAll(&refreshTokenReq{ + Token: in.GetToken(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } + + // 驗證Token + claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "parseClaims"), + ).Error(err.Error()) + return nil, err + } + + token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.GetByAccess"), + logx.Field("claims", claims), + ).Error(err.Error()) + return nil, err + } + + oneTimeToken := generateRefreshToken(in.GetToken()) + key := domain.TicketKeyPrefix + oneTimeToken + if err = l.svcCtx.TokenRedisRepo.CreateOneTimeToken(l.ctx, key, entity.Ticket{ + Data: claims, + Token: token, + }, time.Minute); err != nil { + return &permission.CreateOneTimeTokenResp{}, err + } + + return &permission.CreateOneTimeTokenResp{ + OneTimeToken: oneTimeToken, + }, nil +} diff --git a/internal/logic/new_token_logic.go b/internal/logic/tokenservice/new_token_logic.go similarity index 86% rename from internal/logic/new_token_logic.go rename to internal/logic/tokenservice/new_token_logic.go index 078bb33..487e3f9 100644 --- a/internal/logic/new_token_logic.go +++ b/internal/logic/tokenservice/new_token_logic.go @@ -1,15 +1,16 @@ -package logic +package tokenservicelogic import ( - "ark-permission/gen_result/pb/permission" "ark-permission/internal/domain" "ark-permission/internal/entity" - "ark-permission/internal/svc" ers "code.30cm.net/wanderland/library-go/errors" "context" "github.com/google/uuid" "time" + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" + "github.com/zeromicro/go-zero/core/logx" ) @@ -37,9 +38,6 @@ type authorizationReq struct { IsRefreshToken bool `json:"is_refresh_token"` } -var generateAccessTokenFunc = generateAccessToken -var generateRefreshTokenFunc = generateRefreshToken - // NewToken 建立一個新的 Token,例如:AccessToken func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) { // 驗證所需 @@ -55,13 +53,23 @@ func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.T expires := int(in.GetExpires()) refreshExpires := int(in.GetExpires()) if expires <= 0 { - expires = int(l.svcCtx.Config.Token.Expired.Seconds()) + // 將時間加上 300 秒 + sec := time.Duration(l.svcCtx.Config.Token.Expired.Seconds()) * time.Second + newTime := now.Add(sec) + // 獲取 Unix 時間戳 + timestamp := newTime.Unix() + expires = int(timestamp) refreshExpires = expires } // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 if in.GetIsRefreshToken() { - refreshExpires = int(l.svcCtx.Config.Token.RefreshExpires.Seconds()) + // 將時間加上 300 秒 + sec := time.Duration(l.svcCtx.Config.Token.RefreshExpires.Seconds()) * time.Second + newTime := now.Add(sec) + // 獲取 Unix 時間戳 + timestamp := newTime.Unix() + refreshExpires = int(timestamp) } token := entity.Token{ diff --git a/internal/logic/new_token_logic_test.go b/internal/logic/tokenservice/new_token_logic_test.go similarity index 99% rename from internal/logic/new_token_logic_test.go rename to internal/logic/tokenservice/new_token_logic_test.go index e00c816..76263c9 100644 --- a/internal/logic/new_token_logic_test.go +++ b/internal/logic/tokenservice/new_token_logic_test.go @@ -1,4 +1,4 @@ -package logic +package tokenservicelogic import ( "ark-permission/internal/entity" diff --git a/internal/logic/tokenservice/refresh_token_logic.go b/internal/logic/tokenservice/refresh_token_logic.go new file mode 100644 index 0000000..dee0484 --- /dev/null +++ b/internal/logic/tokenservice/refresh_token_logic.go @@ -0,0 +1,133 @@ +package tokenservicelogic + +import ( + "ark-permission/internal/domain" + "ark-permission/internal/entity" + ers "code.30cm.net/wanderland/library-go/errors" + "context" + "time" + + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" + + "github.com/zeromicro/go-zero/core/logx" +) + +type RefreshTokenLogic struct { + ctx context.Context + svcCtx *svc.ServiceContext + logx.Logger +} + +func NewRefreshTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RefreshTokenLogic { + return &RefreshTokenLogic{ + ctx: ctx, + svcCtx: svcCtx, + Logger: logx.WithContext(ctx), + } +} + +type refreshReq struct { + RefreshToken string `json:"grant_type" validate:"required"` + DeviceID string `json:"device_id" validate:"required"` + Scope string `json:"scope" validate:"required"` + Expires int64 `json:"expires" validate:"required"` +} + +// RefreshToken 更新目前的token 以及裡面包含的一次性 Token +func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permission.RefreshTokenResp, error) { + // 驗證所需 + if err := l.svcCtx.Validate.ValidateAll(&refreshReq{ + RefreshToken: in.GetToken(), + Scope: in.GetScope(), + DeviceID: in.GetDeviceId(), + Expires: in.GetExpires(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } + // step 1 拿看看有沒有這個 refresh token + token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.Token) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.GetByRefresh"), + logx.Field("req", in), + ).Error(err.Error()) + return nil, err + } + // 拿到之後替換掉時間以及 refresh token + // refreshToken 建立 + now := time.Now().UTC() + sec := time.Duration(l.svcCtx.Config.Token.RefreshExpires.Seconds()) * time.Second + newTime := now.Add(sec) + // 獲取 Unix 時間戳 + timestamp := newTime.Unix() + refreshExpires := int(timestamp) + expires := int(in.GetExpires()) + if expires <= 0 { + // 將時間加上 300 秒 + sec := time.Duration(l.svcCtx.Config.Token.Expired.Seconds()) * time.Second + newTime := now.Add(sec) + // 獲取 Unix 時間戳 + timestamp := newTime.Unix() + expires = int(timestamp) + } + + newToken := entity.Token{ + ID: token.ID, + UID: token.UID, + DeviceID: in.GetDeviceId(), + ExpiresIn: expires, + RefreshExpiresIn: refreshExpires, + AccessCreateAt: now, + RefreshCreateAt: now, + } + + claims := claims(map[string]string{ + "uid": token.UID, + }) + claims.SetRole(domain.DefaultRole) + claims.SetID(token.ID) + claims.SetScope(in.GetScope()) + claims.UID() + + if in.GetDeviceId() != "" { + claims.SetDeviceID(in.GetDeviceId()) + } + + newToken.AccessToken, err = generateAccessTokenFunc(newToken, claims, l.svcCtx.Config.Token.Secret) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "generateAccessTokenFunc"), + logx.Field("claims", claims), + ).Error(err.Error()) + return nil, err + } + + newToken.RefreshToken = generateRefreshTokenFunc(newToken.AccessToken) + + // 刪除掉舊的 token + err = l.svcCtx.TokenRedisRepo.Delete(l.ctx, token) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.Delete"), + logx.Field("req", token), + ).Error(err.Error()) + return nil, err + } + + err = l.svcCtx.TokenRedisRepo.Create(l.ctx, newToken) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.Create"), + logx.Field("token", token), + ).Error(err.Error()) + return nil, err + } + + return &permission.RefreshTokenResp{ + Token: newToken.AccessToken, + OneTimeToken: newToken.RefreshToken, + ExpiresIn: int64(expires), + TokenType: domain.TokenTypeBearer, + }, nil +} diff --git a/internal/logic/utils_claims.go b/internal/logic/tokenservice/utils_claims.go similarity index 96% rename from internal/logic/utils_claims.go rename to internal/logic/tokenservice/utils_claims.go index 937953d..2d59b66 100644 --- a/internal/logic/utils_claims.go +++ b/internal/logic/tokenservice/utils_claims.go @@ -1,4 +1,4 @@ -package logic +package tokenservicelogic type claims map[string]string diff --git a/internal/logic/utils_jwt.go b/internal/logic/tokenservice/utils_jwt.go similarity index 88% rename from internal/logic/utils_jwt.go rename to internal/logic/tokenservice/utils_jwt.go index f34d9ab..c6550b9 100644 --- a/internal/logic/utils_jwt.go +++ b/internal/logic/tokenservice/utils_jwt.go @@ -1,4 +1,4 @@ -package logic +package tokenservicelogic import ( "ark-permission/internal/domain" @@ -12,6 +12,9 @@ import ( "time" ) +var generateAccessTokenFunc = generateAccessToken +var generateRefreshTokenFunc = generateRefreshToken + func generateAccessToken(token entity.Token, data any, sign string) (string, error) { claim := entity.Claims{ Data: data, @@ -40,7 +43,7 @@ func generateRefreshToken(accessToken string) string { } func parseClaims(ctx context.Context, accessToken string, secret string) (claims, error) { - claimMap, err := parseToken(ctx, accessToken, secret) + claimMap, err := parseToken(accessToken, secret) if err != nil { return claims{}, err } @@ -54,7 +57,7 @@ func parseClaims(ctx context.Context, accessToken string, secret string) (claims return nil, domain.TokenClaimError("get data from claim map error") } -func parseToken(ctx context.Context, accessToken string, secret string) (jwt.MapClaims, error) { +func parseToken(accessToken string, secret string) (jwt.MapClaims, error) { token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) diff --git a/internal/logic/tokenservice/validation_token_logic.go b/internal/logic/tokenservice/validation_token_logic.go new file mode 100644 index 0000000..fd00492 --- /dev/null +++ b/internal/logic/tokenservice/validation_token_logic.go @@ -0,0 +1,71 @@ +package tokenservicelogic + +import ( + ers "code.30cm.net/wanderland/library-go/errors" + "context" + + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" + + "github.com/zeromicro/go-zero/core/logx" +) + +type ValidationTokenLogic struct { + ctx context.Context + svcCtx *svc.ServiceContext + logx.Logger +} + +func NewValidationTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ValidationTokenLogic { + return &ValidationTokenLogic{ + ctx: ctx, + svcCtx: svcCtx, + Logger: logx.WithContext(ctx), + } +} + +type refreshTokenReq struct { + Token string `json:"token" validate:"required"` +} + +// ValidationToken 驗證這個 Token 有沒有效 +func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq) (*permission.ValidationTokenResp, error) { + // 驗證所需 + if err := l.svcCtx.Validate.ValidateAll(&refreshTokenReq{ + Token: in.GetToken(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } + + claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "parseClaims"), + ).Error(err.Error()) + return nil, err + } + + token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.GetByAccess"), + logx.Field("claims", claims), + ).Error(err.Error()) + return nil, err + } + + return &permission.ValidationTokenResp{ + Token: &permission.Token{ + Id: token.ID, + Uid: token.UID, + DeviceId: token.DeviceID, + AccessCreateAt: token.AccessCreateAt.Unix(), + AccessToken: token.AccessToken, + ExpiresIn: int32(token.ExpiresIn), + RefreshToken: token.RefreshToken, + RefreshExpiresIn: int32(token.RefreshExpiresIn), + RefreshCreateAt: token.RefreshCreateAt.Unix(), + }, + Data: claims, + }, nil +} diff --git a/internal/logic/validation_token_logic.go b/internal/logic/validation_token_logic.go deleted file mode 100644 index 40588d8..0000000 --- a/internal/logic/validation_token_logic.go +++ /dev/null @@ -1,31 +0,0 @@ -package logic - -import ( - "context" - - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" - - "github.com/zeromicro/go-zero/core/logx" -) - -type ValidationTokenLogic struct { - ctx context.Context - svcCtx *svc.ServiceContext - logx.Logger -} - -func NewValidationTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ValidationTokenLogic { - return &ValidationTokenLogic{ - ctx: ctx, - svcCtx: svcCtx, - Logger: logx.WithContext(ctx), - } -} - -// ValidationToken 驗證這個 Token 有沒有效 -func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq) (*permission.ValidationTokenResp, error) { - // todo: add your logic here and delete this line - - return &permission.ValidationTokenResp{}, nil -} diff --git a/internal/repository/token.go b/internal/repository/token.go index 105d794..e434c3b 100644 --- a/internal/repository/token.go +++ b/internal/repository/token.go @@ -22,6 +22,41 @@ type tokenRepository struct { store *redis.Redis } +func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { + // TODO implement me + panic("implement me") +} + +func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { + // TODO implement me + panic("implement me") +} + +func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { + // TODO implement me + panic("implement me") +} + +func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, id string) error { + // TODO implement me + panic("implement me") +} + +func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { + // TODO implement me + panic("implement me") +} + +func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { + // TODO implement me + panic("implement me") +} + +func (t *tokenRepository) DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error { + // TODO implement me + panic("implement me") +} + func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { return &tokenRepository{ store: param.Store, @@ -62,10 +97,185 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error return nil } -func (t *tokenRepository) GetByAccess(_ context.Context, id string) (entity.Token, error) { +// // GetAccessTokensByDeviceID 透過 Device ID 得到目前未過期的token +// func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, uid string) ([]repository.DeviceToken, error) { +// data, err := t.store.Hgetall(domain.DeviceTokenRedisKey.With(uid).ToString()) +// if err != nil { +// if errors.Is(err, redis.Nil) { +// return nil, nil +// } +// +// return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.HGetAll Device Token error: %v", err.Error())) +// } +// +// ids := make([]repository.DeviceToken, 0, len(data)) +// for deviceID, id := range data { +// ids = append(ids, repository.DeviceToken{ +// DeviceID: deviceID, +// +// // e0a4f824-41db-4eb2-8e5a-d96966ea1d56-1698083859 +// // -11是因為id組成最後11位數是-跟時間戳記 +// TokenID: id[:len(id)-11], +// }) +// } +// return ids, nil +// } + +// GetAccessTokensByUID 透過 uid 得到目前未過期的 token +func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { + utKeys, err := t.store.Get(domain.GetUIDTokenRedisKey(uid)) + if err != nil { + // 沒有就視為回空 + if errors.Is(err, redis.Nil) { + return nil, nil + } + + return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error())) + } + + uidTokens := make(entity.UIDToken) + err = json.Unmarshal([]byte(utKeys), &uidTokens) + if err != nil { + return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) + } + + now := time.Now().Unix() + var tokens []entity.Token + var deleteToken []string + for id, token := range uidTokens { + if token < now { + deleteToken = append(deleteToken, id) + + continue + } + + tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) + if err == nil { + item := entity.Token{} + err = json.Unmarshal([]byte(tk), &item) + if err != nil { + return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) + } + tokens = append(tokens, item) + } + + if errors.Is(err, redis.Nil) { + deleteToken = append(deleteToken, id) + } + } + + if len(deleteToken) > 0 { + // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 + _ = t.DeleteUIDToken(ctx, uid, deleteToken) + } + + return tokens, nil +} + +func (t *tokenRepository) DeleteUIDToken(ctx context.Context, uid string, ids []string) error { + uidTokens := make(entity.UIDToken) + tokenKeys, err := t.store.Get(domain.GetUIDTokenRedisKey(uid)) + if err != nil { + if !errors.Is(err, redis.Nil) { + return fmt.Errorf("tx.get GetDeviceTokenRedisKey error: %w", err) + } + } + + if tokenKeys != "" { + err = json.Unmarshal([]byte(tokenKeys), &uidTokens) + if err != nil { + return fmt.Errorf("json.Unmarshal GetDeviceTokenRedisKey error: %w", err) + } + } + + now := time.Now().Unix() + for k, t := range uidTokens { + // 到期就刪除 + if t < now { + delete(uidTokens, k) + } + } + + for _, id := range ids { + delete(uidTokens, id) + } + + b, err := json.Marshal(uidTokens) + if err != nil { + return fmt.Errorf("json.Marshal UIDToken error: %w", err) + } + + _, err = t.store.SetnxEx(domain.GetUIDTokenRedisKey(uid), string(b), 86400*30) + if err != nil { + return fmt.Errorf("tx.set GetUIDTokenRedisKey error: %w", err) + } + + return nil +} + +func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { return t.get(domain.GetAccessTokenRedisKey(id)) } +func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) { + id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString()) + if err != nil { + return entity.Token{}, err + } + + if errors.Is(err, redis.Nil) || id == "" { + return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(refreshToken).ToString()) + } + + if err != nil { + return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.GetByRefresh refresh token error: %v", err)) + } + + return t.GetAccessTokenByID(ctx, id) +} + +func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error { + err := t.store.Pipelined(func(tx redis.Pipeliner) error { + keys := make([]string, 0, len(ids)+len(tokens)) + + for _, id := range ids { + keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString()) + } + + for _, token := range tokens { + keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) + } + + for _, key := range keys { + if err := tx.Del(ctx, key).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } + } + + return nil + }) + + if err != nil { + return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) + } + + return nil +} + +func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error { + body, err := json.Marshal(ticket) + if err != nil { + return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error()) + } + + _, err = t.store.SetnxEx(domain.GetTicketRedisKey(key), string(body), int(expires.Seconds())) + if err != nil { + return ers.DBError("CreateOneTimeToken store.set error:", err.Error()) + } + + return nil +} + func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { err := t.store.Pipelined(func(tx redis.Pipeliner) error { keys := []string{ diff --git a/internal/server/permissionservice/permission_service_server.go b/internal/server/permissionservice/permission_service_server.go new file mode 100644 index 0000000..d645530 --- /dev/null +++ b/internal/server/permissionservice/permission_service_server.go @@ -0,0 +1,20 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package server + +import ( + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" +) + +type PermissionServiceServer struct { + svcCtx *svc.ServiceContext + permission.UnimplementedPermissionServiceServer +} + +func NewPermissionServiceServer(svcCtx *svc.ServiceContext) *PermissionServiceServer { + return &PermissionServiceServer{ + svcCtx: svcCtx, + } +} diff --git a/internal/server/roleservice/role_service_server.go b/internal/server/roleservice/role_service_server.go new file mode 100644 index 0000000..0175dc7 --- /dev/null +++ b/internal/server/roleservice/role_service_server.go @@ -0,0 +1,28 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package server + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/logic/roleservice" + "ark-permission/internal/svc" +) + +type RoleServiceServer struct { + svcCtx *svc.ServiceContext + permission.UnimplementedRoleServiceServer +} + +func NewRoleServiceServer(svcCtx *svc.ServiceContext) *RoleServiceServer { + return &RoleServiceServer{ + svcCtx: svcCtx, + } +} + +func (s *RoleServiceServer) Ping(ctx context.Context, in *permission.OKResp) (*permission.OKResp, error) { + l := roleservicelogic.NewPingLogic(ctx, s.svcCtx) + return l.Ping(in) +} diff --git a/internal/server/token_service_server.go b/internal/server/tokenservice/token_service_server.go similarity index 69% rename from internal/server/token_service_server.go rename to internal/server/tokenservice/token_service_server.go index 5e11a23..0187058 100644 --- a/internal/server/token_service_server.go +++ b/internal/server/tokenservice/token_service_server.go @@ -7,7 +7,7 @@ import ( "context" "ark-permission/gen_result/pb/permission" - "ark-permission/internal/logic" + "ark-permission/internal/logic/tokenservice" "ark-permission/internal/svc" ) @@ -24,60 +24,60 @@ func NewTokenServiceServer(svcCtx *svc.ServiceContext) *TokenServiceServer { // NewToken 建立一個新的 Token,例如:AccessToken func (s *TokenServiceServer) NewToken(ctx context.Context, in *permission.AuthorizationReq) (*permission.TokenResp, error) { - l := logic.NewNewTokenLogic(ctx, s.svcCtx) + l := tokenservicelogic.NewNewTokenLogic(ctx, s.svcCtx) return l.NewToken(in) } // RefreshToken 更新目前的token 以及裡面包含的一次性 Token func (s *TokenServiceServer) RefreshToken(ctx context.Context, in *permission.RefreshTokenReq) (*permission.RefreshTokenResp, error) { - l := logic.NewRefreshTokenLogic(ctx, s.svcCtx) + l := tokenservicelogic.NewRefreshTokenLogic(ctx, s.svcCtx) return l.RefreshToken(in) } // CancelToken 取消 Token,也包含他裡面的 One Time Toke func (s *TokenServiceServer) CancelToken(ctx context.Context, in *permission.CancelTokenReq) (*permission.OKResp, error) { - l := logic.NewCancelTokenLogic(ctx, s.svcCtx) + l := tokenservicelogic.NewCancelTokenLogic(ctx, s.svcCtx) return l.CancelToken(in) } -// CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke +// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke func (s *TokenServiceServer) CancelTokenByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { - l := logic.NewCancelTokenByUidLogic(ctx, s.svcCtx) + l := tokenservicelogic.NewCancelTokenByUidLogic(ctx, s.svcCtx) return l.CancelTokenByUid(in) } -// CancelTokenByDeviceID 取消 Token +// CancelTokenByDeviceId 取消 Token func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { - l := logic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx) + l := tokenservicelogic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx) return l.CancelTokenByDeviceId(in) } // ValidationToken 驗證這個 Token 有沒有效 func (s *TokenServiceServer) ValidationToken(ctx context.Context, in *permission.ValidationTokenReq) (*permission.ValidationTokenResp, error) { - l := logic.NewValidationTokenLogic(ctx, s.svcCtx) + l := tokenservicelogic.NewValidationTokenLogic(ctx, s.svcCtx) return l.ValidationToken(in) } -// GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens +// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (s *TokenServiceServer) GetUserTokensByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { - l := logic.NewGetUserTokensByDeviceIdLogic(ctx, s.svcCtx) + l := tokenservicelogic.NewGetUserTokensByDeviceIdLogic(ctx, s.svcCtx) return l.GetUserTokensByDeviceId(in) } -// GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (s *TokenServiceServer) GetUserTokensByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { - l := logic.NewGetUserTokensByUidLogic(ctx, s.svcCtx) +// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens +func (s *TokenServiceServer) GetUserTokensByUid(ctx context.Context, in *permission.QueryTokenByUIDReq) (*permission.Tokens, error) { + l := tokenservicelogic.NewGetUserTokensByUidLogic(ctx, s.svcCtx) return l.GetUserTokensByUid(in) } // NewOneTimeToken 建立一次性使用,例如:RefreshToken func (s *TokenServiceServer) NewOneTimeToken(ctx context.Context, in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) { - l := logic.NewNewOneTimeTokenLogic(ctx, s.svcCtx) + l := tokenservicelogic.NewNewOneTimeTokenLogic(ctx, s.svcCtx) return l.NewOneTimeToken(in) } // CancelOneTimeToken 取消一次性使用 -func (s *TokenServiceServer) CancelOneTimeToken(ctx context.Context, in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) { - l := logic.NewCancelOneTimeTokenLogic(ctx, s.svcCtx) +func (s *TokenServiceServer) CancelOneTimeToken(ctx context.Context, in *permission.CancelOneTimeTokenReq) (*permission.OKResp, error) { + l := tokenservicelogic.NewCancelOneTimeTokenLogic(ctx, s.svcCtx) return l.CancelOneTimeToken(in) } diff --git a/permission.go b/permission.go index 89037e5..9fe85f9 100644 --- a/permission.go +++ b/permission.go @@ -1,12 +1,14 @@ package main import ( + permissionservice "ark-permission/internal/server/permissionservice" + roleservice "ark-permission/internal/server/roleservice" + tokenservice "ark-permission/internal/server/tokenservice" "flag" "fmt" "ark-permission/gen_result/pb/permission" "ark-permission/internal/config" - "ark-permission/internal/server" "ark-permission/internal/svc" "github.com/zeromicro/go-zero/core/conf" @@ -26,7 +28,9 @@ func main() { ctx := svc.NewServiceContext(c) s := zrpc.MustNewServer(c.RpcServerConf, func(grpcServer *grpc.Server) { - permission.RegisterTokenServiceServer(grpcServer, server.NewTokenServiceServer(ctx)) + permission.RegisterTokenServiceServer(grpcServer, tokenservice.NewTokenServiceServer(ctx)) + permission.RegisterRoleServiceServer(grpcServer, roleservice.NewRoleServiceServer(ctx)) + permission.RegisterPermissionServiceServer(grpcServer, permissionservice.NewPermissionServiceServer(ctx)) if c.Mode == service.DevMode || c.Mode == service.TestMode { reflection.Register(grpcServer) diff --git a/tokenservice/token_service.go b/tokenservice/token_service.go index 28be4ee..d8ee889 100644 --- a/tokenservice/token_service.go +++ b/tokenservice/token_service.go @@ -20,6 +20,7 @@ type ( DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq DoTokenByUIDReq = permission.DoTokenByUIDReq OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq RefreshTokenReq = permission.RefreshTokenReq RefreshTokenResp = permission.RefreshTokenResp Token = permission.Token @@ -35,16 +36,16 @@ type ( RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) // CancelToken 取消 Token,也包含他裡面的 One Time Toke CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke + // CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByDeviceID 取消 Token + // CancelTokenByDeviceId 取消 Token CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) // ValidationToken 驗證這個 Token 有沒有效 ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) - // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) - // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens - GetUserTokensByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) + // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens + GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) // NewOneTimeToken 建立一次性使用,例如:RefreshToken NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) // CancelOneTimeToken 取消一次性使用 @@ -80,13 +81,13 @@ func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenRe return client.CancelToken(ctx, in, opts...) } -// CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke +// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke func (m *defaultTokenService) CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) return client.CancelTokenByUid(ctx, in, opts...) } -// CancelTokenByDeviceID 取消 Token +// CancelTokenByDeviceId 取消 Token func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) return client.CancelTokenByDeviceId(ctx, in, opts...) @@ -98,14 +99,14 @@ func (m *defaultTokenService) ValidationToken(ctx context.Context, in *Validatio return client.ValidationToken(ctx, in, opts...) } -// GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens +// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) return client.GetUserTokensByDeviceId(ctx, in, opts...) } -// GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { +// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens +func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) return client.GetUserTokensByUid(ctx, in, opts...) } -- 2.40.1 From 2d485cf0951414e4c51ec8719c759c87106cba08 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Sun, 11 Aug 2024 20:21:42 +0800 Subject: [PATCH 07/10] fix: complete token service --- client/tokenservice/token_service.go | 32 +- generate/protobuf/permission.proto | 2 +- internal/domain/redis.go | 1 + internal/domain/repository/token.go | 4 +- internal/entity/token.go | 12 + .../cancel_token_by_device_id_logic.go | 15 +- .../tokenservice/cancel_token_by_uid_logic.go | 48 --- .../logic/tokenservice/cancel_token_logic.go | 4 +- .../logic/tokenservice/cancel_tokens_logic.go | 52 +++ .../get_user_tokens_by_device_id_logic.go | 41 +- .../get_user_tokens_by_uid_logic.go | 1 - .../tokenservice/new_one_time_token_logic.go | 7 +- .../logic/tokenservice/new_token_logic.go | 146 +++---- .../logic/tokenservice/refresh_token_logic.go | 81 ++-- internal/logic/tokenservice/utils_jwt.go | 63 +-- .../tokenservice/validation_token_logic.go | 4 +- internal/repository/token.go | 370 ++++++++++-------- .../tokenservice/token_service_server.go | 24 +- 18 files changed, 500 insertions(+), 407 deletions(-) delete mode 100644 internal/logic/tokenservice/cancel_token_by_uid_logic.go create mode 100644 internal/logic/tokenservice/cancel_tokens_logic.go diff --git a/client/tokenservice/token_service.go b/client/tokenservice/token_service.go index 2bc1af2..617dc38 100644 --- a/client/tokenservice/token_service.go +++ b/client/tokenservice/token_service.go @@ -37,12 +37,12 @@ type ( RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) // CancelToken 取消 Token,也包含他裡面的 One Time Toke CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke - CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByDeviceId 取消 Token - CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) // ValidationToken 驗證這個 Token 有沒有效 ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) + // CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 + CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) + // CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token + CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens @@ -82,24 +82,24 @@ func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenRe return client.CancelToken(ctx, in, opts...) } -// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (m *defaultTokenService) CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByUid(ctx, in, opts...) -} - -// CancelTokenByDeviceId 取消 Token -func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByDeviceId(ctx, in, opts...) -} - // ValidationToken 驗證這個 Token 有沒有效 func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) return client.ValidationToken(ctx, in, opts...) } +// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 +func (m *defaultTokenService) CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokens(ctx, in, opts...) +} + +// CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token +func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokenByDeviceId(ctx, in, opts...) +} + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) diff --git a/generate/protobuf/permission.proto b/generate/protobuf/permission.proto index 38bfdc0..e6ed508 100644 --- a/generate/protobuf/permission.proto +++ b/generate/protobuf/permission.proto @@ -114,7 +114,7 @@ message Token { // DoTokenByDeviceIDReq 用DeviceID 來做事的 message DoTokenByDeviceIDReq { - repeated string device_id = 1; + string device_id = 1; } message Tokens{ diff --git a/internal/domain/redis.go b/internal/domain/redis.go index 2fb3b69..2e31a91 100644 --- a/internal/domain/redis.go +++ b/internal/domain/redis.go @@ -18,6 +18,7 @@ const ( DeviceTokenRedisKey RedisKey = "device_token" UIDTokenRedisKey RedisKey = "uid_token" TicketRedisKey RedisKey = "ticket" + DeviceUIDRedisKey RedisKey = "device_uid" ) func (key RedisKey) ToString() string { diff --git a/internal/domain/repository/token.go b/internal/domain/repository/token.go index 57bd4f6..be4c207 100644 --- a/internal/domain/repository/token.go +++ b/internal/domain/repository/token.go @@ -19,11 +19,9 @@ type TokenRepository interface { GetAccessTokenCountByDeviceID(deviceID string) (int, error) Delete(ctx context.Context, token entity.Token) error - DeleteAccessTokenByID(ctx context.Context, id string) error + DeleteAccessTokenByID(ctx context.Context, ids []string) error DeleteAccessTokensByUID(ctx context.Context, uid string) error DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error - DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error - DeleteUIDToken(ctx context.Context, uid string, ids []string) error } type DeviceToken struct { diff --git a/internal/entity/token.go b/internal/entity/token.go index c2cd7fc..791395a 100644 --- a/internal/entity/token.go +++ b/internal/entity/token.go @@ -30,6 +30,18 @@ func (t *Token) IsExpires() bool { return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now()) } +func (t *Token) RedisExpiredSec() int64 { + sec := time.Unix(int64(t.ExpiresIn), 0).Sub(time.Now().UTC()) + + return int64(sec.Seconds()) +} + +func (t *Token) RedisRefreshExpiredSec() int64 { + sec := time.Unix(int64(t.RefreshExpiresIn), 0).Sub(time.Now().UTC()) + + return int64(sec.Seconds()) +} + type UIDToken map[string]int64 type Ticket struct { diff --git a/internal/logic/tokenservice/cancel_token_by_device_id_logic.go b/internal/logic/tokenservice/cancel_token_by_device_id_logic.go index ed59824..e726a26 100644 --- a/internal/logic/tokenservice/cancel_token_by_device_id_logic.go +++ b/internal/logic/tokenservice/cancel_token_by_device_id_logic.go @@ -1,6 +1,7 @@ package tokenservicelogic import ( + ers "code.30cm.net/wanderland/library-go/errors" "context" "ark-permission/gen_result/pb/permission" @@ -25,7 +26,19 @@ func NewCancelTokenByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceConte // CancelTokenByDeviceId 取消 Token func (l *CancelTokenByDeviceIdLogic) CancelTokenByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { - // todo: add your logic here and delete this line + if err := l.svcCtx.Validate.ValidateAll(&getUserTokensByDeviceIdReq{ + DeviceID: in.GetDeviceId(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } + err := l.svcCtx.TokenRedisRepo.DeleteAccessTokensByDeviceID(l.ctx, in.GetDeviceId()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.DeleteAccessTokensByDeviceID"), + logx.Field("DeviceID", in.GetDeviceId()), + ).Error(err.Error()) + return nil, err + } return &permission.OKResp{}, nil } diff --git a/internal/logic/tokenservice/cancel_token_by_uid_logic.go b/internal/logic/tokenservice/cancel_token_by_uid_logic.go deleted file mode 100644 index f266185..0000000 --- a/internal/logic/tokenservice/cancel_token_by_uid_logic.go +++ /dev/null @@ -1,48 +0,0 @@ -package tokenservicelogic - -import ( - ers "code.30cm.net/wanderland/library-go/errors" - "context" - - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" - - "github.com/zeromicro/go-zero/core/logx" -) - -type CancelTokenByUidLogic struct { - ctx context.Context - svcCtx *svc.ServiceContext - logx.Logger -} - -func NewCancelTokenByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByUidLogic { - return &CancelTokenByUidLogic{ - ctx: ctx, - svcCtx: svcCtx, - Logger: logx.WithContext(ctx), - } -} - -type deleteByTokenIDs struct { - UID string `json:"uid" binding:"required"` - IDs []string `json:"ids" binding:"required"` -} - -// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (l *CancelTokenByUidLogic) CancelTokenByUid(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { - // 驗證所需 - if err := l.svcCtx.Validate.ValidateAll(&deleteByTokenIDs{ - UID: in.GetUid(), - IDs: in.GetIds(), - }); err != nil { - return nil, ers.InvalidFormat(err.Error()) - } - - err := l.svcCtx.TokenRedisRepo.DeleteUIDToken(l.ctx, in.GetUid(), in.GetIds()) - if err != nil { - return nil, err - } - - return &permission.OKResp{}, nil -} diff --git a/internal/logic/tokenservice/cancel_token_logic.go b/internal/logic/tokenservice/cancel_token_logic.go index 4f615bd..bf89e59 100644 --- a/internal/logic/tokenservice/cancel_token_logic.go +++ b/internal/logic/tokenservice/cancel_token_logic.go @@ -37,7 +37,7 @@ func (l *CancelTokenLogic) CancelToken(in *permission.CancelTokenReq) (*permissi return nil, ers.InvalidFormat(err.Error()) } - claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, false) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "parseClaims"), @@ -45,7 +45,7 @@ func (l *CancelTokenLogic) CancelToken(in *permission.CancelTokenReq) (*permissi return nil, err } - token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.GetByAccess"), diff --git a/internal/logic/tokenservice/cancel_tokens_logic.go b/internal/logic/tokenservice/cancel_tokens_logic.go new file mode 100644 index 0000000..01fac22 --- /dev/null +++ b/internal/logic/tokenservice/cancel_tokens_logic.go @@ -0,0 +1,52 @@ +package tokenservicelogic + +import ( + ers "code.30cm.net/wanderland/library-go/errors" + "context" + + "ark-permission/gen_result/pb/permission" + "ark-permission/internal/svc" + + "github.com/zeromicro/go-zero/core/logx" +) + +type CancelTokensLogic struct { + ctx context.Context + svcCtx *svc.ServiceContext + logx.Logger +} + +func NewCancelTokensLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokensLogic { + return &CancelTokensLogic{ + ctx: ctx, + svcCtx: svcCtx, + Logger: logx.WithContext(ctx), + } +} + +// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 +func (l *CancelTokensLogic) CancelTokens(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { + if in.GetUid() != "" { + err := l.svcCtx.TokenRedisRepo.DeleteAccessTokensByUID(l.ctx, in.GetUid()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.DeleteAccessTokensByUID"), + logx.Field("uid", in.GetUid()), + ).Error(err.Error()) + return nil, ers.ResourceInsufficient(err.Error()) + } + } + + if len(in.GetIds()) > 0 { + err := l.svcCtx.TokenRedisRepo.DeleteAccessTokenByID(l.ctx, in.GetIds()) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "TokenRedisRepo.DeleteAccessTokenByID"), + logx.Field("ids", in.GetIds()), + ).Error(err.Error()) + return nil, ers.ResourceInsufficient(err.Error()) + } + } + + return &permission.OKResp{}, nil +} diff --git a/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go b/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go index 93836ea..fbc773d 100644 --- a/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go +++ b/internal/logic/tokenservice/get_user_tokens_by_device_id_logic.go @@ -2,7 +2,9 @@ package tokenservicelogic import ( "ark-permission/gen_result/pb/permission" + "ark-permission/internal/domain" "ark-permission/internal/svc" + ers "code.30cm.net/wanderland/library-go/errors" "context" "github.com/zeromicro/go-zero/core/logx" ) @@ -21,23 +23,34 @@ func NewGetUserTokensByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceCon } } +type getUserTokensByDeviceIdReq struct { + DeviceID string `json:"device_id" validate:"required"` +} + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (l *GetUserTokensByDeviceIdLogic) GetUserTokensByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { + if err := l.svcCtx.Validate.ValidateAll(&getUserTokensByDeviceIdReq{ + DeviceID: in.GetDeviceId(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } - // ids, err := l.svcCtx.TokenRedisRepo.GetAccessTokensByDeviceID(l.ctx, "") - // if err != nil { - // return nil, error - // } + uidTokens, err := l.svcCtx.TokenRedisRepo.GetAccessTokensByDeviceID(l.ctx, in.GetDeviceId()) + if err != nil { + return nil, err + } - // tokenIDs := make([]usecase.DeviceToken, 0, len(ids)) - // for _, v := range ids { - // tokenIDs = append(tokenIDs, usecase.DeviceToken{ - // DeviceID: v.DeviceID, - // TokenID: v.TokenID, - // }) - // } - // - // return tokenIDs, nil + tokens := make([]*permission.TokenResp, 0, len(uidTokens)) + for _, v := range uidTokens { + tokens = append(tokens, &permission.TokenResp{ + AccessToken: v.AccessToken, + TokenType: domain.TokenTypeBearer, + ExpiresIn: int32(v.ExpiresIn), + RefreshToken: v.RefreshToken, + }) + } - return &permission.Tokens{}, nil + return &permission.Tokens{ + Token: tokens, + }, nil } diff --git a/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go b/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go index eb3a1f0..9d8616e 100644 --- a/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go +++ b/internal/logic/tokenservice/get_user_tokens_by_uid_logic.go @@ -6,7 +6,6 @@ import ( "ark-permission/internal/svc" ers "code.30cm.net/wanderland/library-go/errors" "context" - "github.com/zeromicro/go-zero/core/logx" ) diff --git a/internal/logic/tokenservice/new_one_time_token_logic.go b/internal/logic/tokenservice/new_one_time_token_logic.go index bcb02d5..79db981 100644 --- a/internal/logic/tokenservice/new_one_time_token_logic.go +++ b/internal/logic/tokenservice/new_one_time_token_logic.go @@ -5,6 +5,7 @@ import ( "ark-permission/internal/entity" ers "code.30cm.net/wanderland/library-go/errors" "context" + "github.com/google/uuid" "time" "ark-permission/gen_result/pb/permission" @@ -37,7 +38,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken } // 驗證Token - claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, false) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "parseClaims"), @@ -45,7 +46,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken return nil, err } - token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.GetByAccess"), @@ -54,7 +55,7 @@ func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeToken return nil, err } - oneTimeToken := generateRefreshToken(in.GetToken()) + oneTimeToken := generateRefreshToken(uuid.Must(uuid.NewRandom()).String()) key := domain.TicketKeyPrefix + oneTimeToken if err = l.svcCtx.TokenRedisRepo.CreateOneTimeToken(l.ctx, key, entity.Ticket{ Data: claims, diff --git a/internal/logic/tokenservice/new_token_logic.go b/internal/logic/tokenservice/new_token_logic.go index 487e3f9..fdccf8f 100644 --- a/internal/logic/tokenservice/new_token_logic.go +++ b/internal/logic/tokenservice/new_token_logic.go @@ -1,6 +1,7 @@ package tokenservicelogic import ( + "ark-permission/internal/config" "ark-permission/internal/domain" "ark-permission/internal/entity" ers "code.30cm.net/wanderland/library-go/errors" @@ -30,83 +31,34 @@ func NewNewTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *NewToken // https://datatracker.ietf.org/doc/html/rfc6749#section-3.3 type authorizationReq struct { - GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"` - DeviceID string `json:"device_id"` - Scope string `json:"scope" validate:"required"` - Data map[string]any `json:"data"` - Expires int `json:"expires"` - IsRefreshToken bool `json:"is_refresh_token"` + GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"` + DeviceID string `json:"device_id"` + Scope string `json:"scope" validate:"required"` + Data map[string]string `json:"data"` + Expires int `json:"expires"` + IsRefreshToken bool `json:"is_refresh_token"` } // NewToken 建立一個新的 Token,例如:AccessToken func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) { + data := authorizationReq{ + GrantType: domain.GrantType(in.GetGrantType()), + Scope: in.GetScope(), + DeviceID: in.GetDeviceId(), + Data: in.GetData(), + Expires: int(in.GetExpires()), + IsRefreshToken: in.GetIsRefreshToken(), + } // 驗證所需 - if err := l.svcCtx.Validate.ValidateAll(&authorizationReq{ - GrantType: domain.GrantType(in.GetGrantType()), - Scope: in.GetScope(), - }); err != nil { + if err := l.svcCtx.Validate.ValidateAll(&data); err != nil { return nil, ers.InvalidFormat(err.Error()) } - - // 準備建立 Token 所需 - now := time.Now().UTC() - expires := int(in.GetExpires()) - refreshExpires := int(in.GetExpires()) - if expires <= 0 { - // 將時間加上 300 秒 - sec := time.Duration(l.svcCtx.Config.Token.Expired.Seconds()) * time.Second - newTime := now.Add(sec) - // 獲取 Unix 時間戳 - timestamp := newTime.Unix() - expires = int(timestamp) - refreshExpires = expires - } - - // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 - if in.GetIsRefreshToken() { - // 將時間加上 300 秒 - sec := time.Duration(l.svcCtx.Config.Token.RefreshExpires.Seconds()) * time.Second - newTime := now.Add(sec) - // 獲取 Unix 時間戳 - timestamp := newTime.Unix() - refreshExpires = int(timestamp) - } - - token := entity.Token{ - ID: uuid.Must(uuid.NewRandom()).String(), - DeviceID: in.GetDeviceId(), - ExpiresIn: expires, - RefreshExpiresIn: refreshExpires, - AccessCreateAt: now, - RefreshCreateAt: now, - } - - claims := claims(in.GetData()) - claims.SetRole(domain.DefaultRole) - claims.SetID(token.ID) - claims.SetScope(in.GetScope()) - - token.UID = claims.UID() - - if in.GetDeviceId() != "" { - claims.SetDeviceID(in.GetDeviceId()) - } - - var err error - token.AccessToken, err = generateAccessTokenFunc(token, claims, l.svcCtx.Config.Token.Secret) + token, err := newToken(data, l.svcCtx.Config) if err != nil { - logx.WithCallerSkip(1).WithFields( - logx.Field("func", "generateAccessTokenFunc"), - logx.Field("claims", claims), - ).Error(err.Error()) return nil, err } - if in.GetIsRefreshToken() { - token.RefreshToken = generateRefreshTokenFunc(token.AccessToken) - } - - err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token) + err = l.svcCtx.TokenRedisRepo.Create(l.ctx, *token) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.Create"), @@ -122,3 +74,65 @@ func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.T RefreshToken: token.RefreshToken, }, nil } + +func newToken(authReq authorizationReq, cfg config.Config) (*entity.Token, error) { + // 準備建立 Token 所需 + now := time.Now().UTC() + expires := authReq.Expires + refreshExpires := authReq.Expires + if expires <= 0 { + // 將時間加上 300 秒 + sec := time.Duration(cfg.Token.Expired.Seconds()) * time.Second + newTime := now.Add(sec) + // 獲取 Unix 時間戳 + timestamp := newTime.Unix() + expires = int(timestamp) + refreshExpires = expires + } + + // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 + if authReq.IsRefreshToken { + // 將時間加上 300 秒 + sec := time.Duration(cfg.Token.RefreshExpires.Seconds()) * time.Second + newTime := now.Add(sec) + // 獲取 Unix 時間戳 + timestamp := newTime.Unix() + refreshExpires = int(timestamp) + } + + token := entity.Token{ + ID: uuid.Must(uuid.NewRandom()).String(), + DeviceID: authReq.DeviceID, + ExpiresIn: expires, + RefreshExpiresIn: refreshExpires, + AccessCreateAt: now, + RefreshCreateAt: now, + } + + claims := claims(authReq.Data) + claims.SetRole(domain.DefaultRole) + claims.SetID(token.ID) + claims.SetScope(authReq.Scope) + + token.UID = claims.UID() + + if authReq.DeviceID != "" { + claims.SetDeviceID(authReq.DeviceID) + } + + var err error + token.AccessToken, err = generateAccessTokenFunc(token, claims, cfg.Token.Secret) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "generateAccessTokenFunc"), + logx.Field("claims", claims), + ).Error(err.Error()) + return nil, err + } + + if authReq.IsRefreshToken { + token.RefreshToken = generateRefreshTokenFunc(token.AccessToken) + } + + return &token, nil +} diff --git a/internal/logic/tokenservice/refresh_token_logic.go b/internal/logic/tokenservice/refresh_token_logic.go index dee0484..607fd28 100644 --- a/internal/logic/tokenservice/refresh_token_logic.go +++ b/internal/logic/tokenservice/refresh_token_logic.go @@ -1,14 +1,11 @@ package tokenservicelogic import ( + "ark-permission/gen_result/pb/permission" "ark-permission/internal/domain" - "ark-permission/internal/entity" + "ark-permission/internal/svc" ers "code.30cm.net/wanderland/library-go/errors" "context" - "time" - - "ark-permission/gen_result/pb/permission" - "ark-permission/internal/svc" "github.com/zeromicro/go-zero/core/logx" ) @@ -31,7 +28,6 @@ type refreshReq struct { RefreshToken string `json:"grant_type" validate:"required"` DeviceID string `json:"device_id" validate:"required"` Scope string `json:"scope" validate:"required"` - Expires int64 `json:"expires" validate:"required"` } // RefreshToken 更新目前的token 以及裡面包含的一次性 Token @@ -41,10 +37,10 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi RefreshToken: in.GetToken(), Scope: in.GetScope(), DeviceID: in.GetDeviceId(), - Expires: in.GetExpires(), }); err != nil { return nil, ers.InvalidFormat(err.Error()) } + // step 1 拿看看有沒有這個 refresh token token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.Token) if err != nil { @@ -54,56 +50,33 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi ).Error(err.Error()) return nil, err } - // 拿到之後替換掉時間以及 refresh token - // refreshToken 建立 - now := time.Now().UTC() - sec := time.Duration(l.svcCtx.Config.Token.RefreshExpires.Seconds()) * time.Second - newTime := now.Add(sec) - // 獲取 Unix 時間戳 - timestamp := newTime.Unix() - refreshExpires := int(timestamp) - expires := int(in.GetExpires()) - if expires <= 0 { - // 將時間加上 300 秒 - sec := time.Duration(l.svcCtx.Config.Token.Expired.Seconds()) * time.Second - newTime := now.Add(sec) - // 獲取 Unix 時間戳 - timestamp := newTime.Unix() - expires = int(timestamp) - } - newToken := entity.Token{ - ID: token.ID, - UID: token.UID, - DeviceID: in.GetDeviceId(), - ExpiresIn: expires, - RefreshExpiresIn: refreshExpires, - AccessCreateAt: now, - RefreshCreateAt: now, - } - - claims := claims(map[string]string{ - "uid": token.UID, - }) - claims.SetRole(domain.DefaultRole) - claims.SetID(token.ID) - claims.SetScope(in.GetScope()) - claims.UID() - - if in.GetDeviceId() != "" { - claims.SetDeviceID(in.GetDeviceId()) - } - - newToken.AccessToken, err = generateAccessTokenFunc(newToken, claims, l.svcCtx.Config.Token.Secret) + // 取得 Data + c, err := parseClaims(token.AccessToken, l.svcCtx.Config.Token.Secret, false) if err != nil { logx.WithCallerSkip(1).WithFields( - logx.Field("func", "generateAccessTokenFunc"), - logx.Field("claims", claims), + logx.Field("func", "parseClaims"), + logx.Field("token", token), ).Error(err.Error()) return nil, err } - newToken.RefreshToken = generateRefreshTokenFunc(newToken.AccessToken) + // step 2 建立新 token + nt, err := newToken(authorizationReq{ + GrantType: domain.ClientCredentials, + Scope: in.GetScope(), + DeviceID: in.GetDeviceId(), + Data: c, + Expires: int(in.GetExpires()), + IsRefreshToken: true, + }, l.svcCtx.Config) + if err != nil { + logx.WithCallerSkip(1).WithFields( + logx.Field("func", "newToken"), + logx.Field("req", in), + ).Error(err.Error()) + return nil, err + } // 刪除掉舊的 token err = l.svcCtx.TokenRedisRepo.Delete(l.ctx, token) @@ -115,7 +88,7 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi return nil, err } - err = l.svcCtx.TokenRedisRepo.Create(l.ctx, newToken) + err = l.svcCtx.TokenRedisRepo.Create(l.ctx, *nt) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.Create"), @@ -125,9 +98,9 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi } return &permission.RefreshTokenResp{ - Token: newToken.AccessToken, - OneTimeToken: newToken.RefreshToken, - ExpiresIn: int64(expires), + Token: nt.AccessToken, + OneTimeToken: nt.RefreshToken, + ExpiresIn: int64(nt.ExpiresIn), TokenType: domain.TokenTypeBearer, }, nil } diff --git a/internal/logic/tokenservice/utils_jwt.go b/internal/logic/tokenservice/utils_jwt.go index c6550b9..3ef34fd 100644 --- a/internal/logic/tokenservice/utils_jwt.go +++ b/internal/logic/tokenservice/utils_jwt.go @@ -4,7 +4,6 @@ import ( "ark-permission/internal/domain" "ark-permission/internal/entity" "bytes" - "context" "crypto/sha256" "encoding/hex" "fmt" @@ -42,43 +41,53 @@ func generateRefreshToken(accessToken string) string { return hex.EncodeToString(h.Sum(nil)) } -func parseClaims(ctx context.Context, accessToken string, secret string) (claims, error) { - claimMap, err := parseToken(accessToken, secret) - if err != nil { - return claims{}, err - } +func parseToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) { + // 跳過驗證的解析 + var token *jwt.Token + var err error - claims, ok := claimMap["data"].(map[string]any) - if ok { - - return convertMap(claims), nil - } - - return nil, domain.TokenClaimError("get data from claim map error") -} - -func parseToken(accessToken string, secret string) (jwt.MapClaims, error) { - token, err := jwt.Parse(accessToken, func(token *jwt.Token) (any, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) + if validate { + token, err = jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, domain.TokenUnexpectedSigningErr(fmt.Sprintf("token unexpected signing method: %v", token.Header["alg"])) + } + return []byte(secret), nil + }) + if err != nil { + return jwt.MapClaims{}, err + } + } else { + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) + token, err = parser.Parse(accessToken, func(token *jwt.Token) (interface{}, error) { + return []byte(secret), nil + }) + if err != nil { + return jwt.MapClaims{}, err } - - return []byte(secret), nil - }) - - if err != nil { - return jwt.MapClaims{}, err } claims, ok := token.Claims.(jwt.MapClaims) - - if !(ok && token.Valid) { + if !ok && token.Valid { return jwt.MapClaims{}, domain.TokenTokenValidateErr("token valid error") } return claims, nil } +func parseClaims(accessToken string, secret string, validate bool) (claims, error) { + claimMap, err := parseToken(accessToken, secret, validate) + if err != nil { + return claims{}, err + } + + claimsData, ok := claimMap["data"].(map[string]any) + if ok { + return convertMap(claimsData), nil + } + + return claims{}, domain.TokenClaimError("get data from claim map error") +} + func convertMap(input map[string]interface{}) map[string]string { output := make(map[string]string) for key, value := range input { diff --git a/internal/logic/tokenservice/validation_token_logic.go b/internal/logic/tokenservice/validation_token_logic.go index fd00492..f3baf9e 100644 --- a/internal/logic/tokenservice/validation_token_logic.go +++ b/internal/logic/tokenservice/validation_token_logic.go @@ -37,7 +37,7 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq return nil, ers.InvalidFormat(err.Error()) } - claims, err := parseClaims(l.ctx, in.GetToken(), l.svcCtx.Config.Token.Secret) + claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, true) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "parseClaims"), @@ -45,7 +45,7 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq return nil, err } - token, err := l.svcCtx.TokenRedisRepo.GetByAccess(l.ctx, claims.ID()) + token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.GetByAccess"), diff --git a/internal/repository/token.go b/internal/repository/token.go index e434c3b..17a325f 100644 --- a/internal/repository/token.go +++ b/internal/repository/token.go @@ -22,41 +22,6 @@ type tokenRepository struct { store *redis.Redis } -func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, id string) error { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { - // TODO implement me - panic("implement me") -} - -func (t *tokenRepository) DeleteAccessTokenByDeviceIDAndUID(ctx context.Context, deviceID, uid string) error { - // TODO implement me - panic("implement me") -} - func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { return &tokenRepository{ store: param.Store, @@ -70,17 +35,19 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error } err = t.store.Pipelined(func(tx redis.Pipeliner) error { - rTTL := token.RefreshTokenExpires() + // rTTL := token.RedisExpiredSec() + refreshTTL := token.RedisRefreshExpiredSec() - if err := t.setToken(ctx, tx, token, body, rTTL); err != nil { + if err := t.setToken(ctx, tx, token, body, time.Duration(refreshTTL)*time.Second); err != nil { return err } - if err := t.setRefreshToken(ctx, tx, token, rTTL); err != nil { + if err := t.setRefreshToken(ctx, tx, token, time.Duration(refreshTTL)*time.Second); err != nil { return err } - if err := t.setDeviceToken(ctx, tx, token, rTTL); err != nil { + err := t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, time.Duration(refreshTTL)*time.Second) + if err != nil { return err } @@ -90,40 +57,103 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error return domain.RedisPipLineError(err.Error()) } - if err := t.SetUIDToken(token); err != nil { - return ers.ArkInternal("SetUIDToken error", err.Error()) + return nil +} + +func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { + err := t.store.Pipelined(func(tx redis.Pipeliner) error { + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), + domain.UIDTokenRedisKey.With(token.UID).ToString(), + } + + for _, key := range keys { + if err := tx.Del(ctx, key).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } + } + + if token.DeviceID != "" { + key := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString() + _, err := t.store.Del(key) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) + } + } + + return nil + }) + + if err != nil { + return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) } return nil } -// // GetAccessTokensByDeviceID 透過 Device ID 得到目前未過期的token -// func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, uid string) ([]repository.DeviceToken, error) { -// data, err := t.store.Hgetall(domain.DeviceTokenRedisKey.With(uid).ToString()) -// if err != nil { -// if errors.Is(err, redis.Nil) { -// return nil, nil -// } -// -// return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.HGetAll Device Token error: %v", err.Error())) -// } -// -// ids := make([]repository.DeviceToken, 0, len(data)) -// for deviceID, id := range data { -// ids = append(ids, repository.DeviceToken{ -// DeviceID: deviceID, -// -// // e0a4f824-41db-4eb2-8e5a-d96966ea1d56-1698083859 -// // -11是因為id組成最後11位數是-跟時間戳記 -// TokenID: id[:len(id)-11], -// }) -// } -// return ids, nil -// } +func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { + return t.get(domain.GetAccessTokenRedisKey(id)) +} + +func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { + tokens, err := t.GetAccessTokensByUID(ctx, uid) + if err != nil { + return err + } + for _, item := range tokens { + err := t.Delete(ctx, item) + if err != nil { + return err + } + } + + return nil +} + +// DeleteAccessTokenByID TODO 要做錯誤處理 +func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error { + for _, tokenID := range ids { + token, err := t.GetAccessTokenByID(ctx, tokenID) + if err != nil { + continue + } + + err = t.store.Pipelined(func(tx redis.Pipeliner) error { + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), + } + + for _, key := range keys { + if err := tx.Del(ctx, key).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } + } + + _, err = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) + } + + _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) + } + + return nil + }) + if err != nil { + continue + } + } + + return nil +} // GetAccessTokensByUID 透過 uid 得到目前未過期的 token func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { - utKeys, err := t.store.Get(domain.GetUIDTokenRedisKey(uid)) + utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid)) if err != nil { // 沒有就視為回空 if errors.Is(err, redis.Nil) { @@ -133,90 +163,39 @@ func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error())) } - uidTokens := make(entity.UIDToken) - err = json.Unmarshal([]byte(utKeys), &uidTokens) - if err != nil { - return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) - } - - now := time.Now().Unix() + now := time.Now().UTC() var tokens []entity.Token var deleteToken []string - for id, token := range uidTokens { - if token < now { - deleteToken = append(deleteToken, id) - - continue - } - + for _, id := range utKeys { + item := &entity.Token{} tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) if err == nil { - item := entity.Token{} - err = json.Unmarshal([]byte(tk), &item) + err = json.Unmarshal([]byte(tk), item) if err != nil { return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) } - tokens = append(tokens, item) + tokens = append(tokens, *item) } if errors.Is(err, redis.Nil) { deleteToken = append(deleteToken, id) } - } + if int64(item.ExpiresIn) < now.Unix() { + deleteToken = append(deleteToken, id) + + continue + } + + } if len(deleteToken) > 0 { // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 - _ = t.DeleteUIDToken(ctx, uid, deleteToken) + _ = t.DeleteAccessTokenByID(ctx, deleteToken) } return tokens, nil } -func (t *tokenRepository) DeleteUIDToken(ctx context.Context, uid string, ids []string) error { - uidTokens := make(entity.UIDToken) - tokenKeys, err := t.store.Get(domain.GetUIDTokenRedisKey(uid)) - if err != nil { - if !errors.Is(err, redis.Nil) { - return fmt.Errorf("tx.get GetDeviceTokenRedisKey error: %w", err) - } - } - - if tokenKeys != "" { - err = json.Unmarshal([]byte(tokenKeys), &uidTokens) - if err != nil { - return fmt.Errorf("json.Unmarshal GetDeviceTokenRedisKey error: %w", err) - } - } - - now := time.Now().Unix() - for k, t := range uidTokens { - // 到期就刪除 - if t < now { - delete(uidTokens, k) - } - } - - for _, id := range ids { - delete(uidTokens, id) - } - - b, err := json.Marshal(uidTokens) - if err != nil { - return fmt.Errorf("json.Marshal UIDToken error: %w", err) - } - - _, err = t.store.SetnxEx(domain.GetUIDTokenRedisKey(uid), string(b), 86400*30) - if err != nil { - return fmt.Errorf("tx.set GetUIDTokenRedisKey error: %w", err) - } - - return nil -} - -func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { - return t.get(domain.GetAccessTokenRedisKey(id)) -} - func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) { id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString()) if err != nil { @@ -262,13 +241,13 @@ func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, return nil } -func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error { +func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error { body, err := json.Marshal(ticket) if err != nil { return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error()) } - _, err = t.store.SetnxEx(domain.GetTicketRedisKey(key), string(body), int(expires.Seconds())) + _, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds())) if err != nil { return ers.DBError("CreateOneTimeToken store.set error:", err.Error()) } @@ -276,37 +255,106 @@ func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ti return nil } -func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { - err := t.store.Pipelined(func(tx redis.Pipeliner) error { - keys := []string{ - domain.GetAccessTokenRedisKey(token.ID), - domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), +func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { + utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString()) + if err != nil { + // 沒有就視為回空 + if errors.Is(err, redis.Nil) { + return nil, nil } - for _, key := range keys { - if err := tx.Del(ctx, key).Err(); err != nil { + return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error())) + } + + now := time.Now().UTC() + var tokens []entity.Token + var deleteToken []string + for _, id := range utKeys { + item := &entity.Token{} + tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) + if err == nil { + err = json.Unmarshal([]byte(tk), item) + if err != nil { + return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) + } + tokens = append(tokens, *item) + } + + if errors.Is(err, redis.Nil) { + deleteToken = append(deleteToken, id) + } + + if int64(item.ExpiresIn) < now.Unix() { + deleteToken = append(deleteToken, id) + + continue + } + + } + if len(deleteToken) > 0 { + // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 + _ = t.DeleteAccessTokenByID(ctx, deleteToken) + } + + return tokens, nil +} + +func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { + tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err)) + } + + err = t.store.Pipelined(func(tx redis.Pipeliner) error { + for _, token := range tokens { + if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil { return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) } + + if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } + _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) + } } - if token.DeviceID != "" { - key := domain.DeviceTokenRedisKey.With(token.UID).ToString() - _, err := t.store.Hdel(key, token.DeviceID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) - } + _, err := t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString()) + if err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) } return nil }) if err != nil { - return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) + return err } return nil } +func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { + count, err := t.store.Scard(domain.DeviceTokenRedisKey.With(deviceID).ToString()) + if err != nil { + return 0, err + } + + return int(count), nil +} + +func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { + count, err := t.store.Scard(domain.UIDTokenRedisKey.With(uid).ToString()) + if err != nil { + return 0, err + } + + return int(count), nil +} + +// -------------------- Private area -------------------- + func (t *tokenRepository) get(key string) (entity.Token, error) { body, err := t.store.Get(key) if errors.Is(err, redis.Nil) || body == "" { @@ -343,19 +391,27 @@ func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeline return nil } -func (t *tokenRepository) setDeviceToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { - if token.DeviceID != "" { - key := domain.DeviceTokenRedisKey.With(token.UID).ToString() - value := fmt.Sprintf("%s-%d", token.ID, token.AccessCreateAt.Add(rTTL).Unix()) - err := tx.HSet(ctx, key, token.DeviceID, value).Err() - if err != nil { - return wrapError("tx.HSet Device Token error", err) - } - err = tx.Expire(ctx, key, rTTL).Err() - if err != nil { - return wrapError("tx.Expire Device Token error", err) - } +func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error { + uidKey := domain.UIDTokenRedisKey.With(uid).ToString() + err := tx.SAdd(ctx, uidKey, tokenID).Err() + if err != nil { + return err } + err = tx.Expire(ctx, uidKey, rttl).Err() + if err != nil { + return err + } + + deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString() + err = tx.SAdd(ctx, deviceKey, tokenID).Err() + if err != nil { + return err + } + err = tx.Expire(ctx, deviceKey, rttl).Err() + if err != nil { + return err + } + return nil } diff --git a/internal/server/tokenservice/token_service_server.go b/internal/server/tokenservice/token_service_server.go index 0187058..c5adb92 100644 --- a/internal/server/tokenservice/token_service_server.go +++ b/internal/server/tokenservice/token_service_server.go @@ -40,24 +40,24 @@ func (s *TokenServiceServer) CancelToken(ctx context.Context, in *permission.Can return l.CancelToken(in) } -// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (s *TokenServiceServer) CancelTokenByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { - l := tokenservicelogic.NewCancelTokenByUidLogic(ctx, s.svcCtx) - return l.CancelTokenByUid(in) -} - -// CancelTokenByDeviceId 取消 Token -func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { - l := tokenservicelogic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx) - return l.CancelTokenByDeviceId(in) -} - // ValidationToken 驗證這個 Token 有沒有效 func (s *TokenServiceServer) ValidationToken(ctx context.Context, in *permission.ValidationTokenReq) (*permission.ValidationTokenResp, error) { l := tokenservicelogic.NewValidationTokenLogic(ctx, s.svcCtx) return l.ValidationToken(in) } +// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 +func (s *TokenServiceServer) CancelTokens(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { + l := tokenservicelogic.NewCancelTokensLogic(ctx, s.svcCtx) + return l.CancelTokens(in) +} + +// CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token +func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { + l := tokenservicelogic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx) + return l.CancelTokenByDeviceId(in) +} + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens func (s *TokenServiceServer) GetUserTokensByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { l := tokenservicelogic.NewGetUserTokensByDeviceIdLogic(ctx, s.svcCtx) -- 2.40.1 From 2f3d169d10b3c60ed8d464b25886973287a46942 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Sun, 11 Aug 2024 20:23:45 +0800 Subject: [PATCH 08/10] fix: delete --- tokenservice/token_service.go | 124 ---------------------------------- 1 file changed, 124 deletions(-) delete mode 100644 tokenservice/token_service.go diff --git a/tokenservice/token_service.go b/tokenservice/token_service.go deleted file mode 100644 index d8ee889..0000000 --- a/tokenservice/token_service.go +++ /dev/null @@ -1,124 +0,0 @@ -// Code generated by goctl. DO NOT EDIT. -// Source: permission.proto - -package tokenservice - -import ( - "context" - - "ark-permission/gen_result/pb/permission" - - "github.com/zeromicro/go-zero/zrpc" - "google.golang.org/grpc" -) - -type ( - AuthorizationReq = permission.AuthorizationReq - CancelTokenReq = permission.CancelTokenReq - CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq - CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp - DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq - DoTokenByUIDReq = permission.DoTokenByUIDReq - OKResp = permission.OKResp - QueryTokenByUIDReq = permission.QueryTokenByUIDReq - RefreshTokenReq = permission.RefreshTokenReq - RefreshTokenResp = permission.RefreshTokenResp - Token = permission.Token - TokenResp = permission.TokenResp - Tokens = permission.Tokens - ValidationTokenReq = permission.ValidationTokenReq - ValidationTokenResp = permission.ValidationTokenResp - - TokenService interface { - // NewToken 建立一個新的 Token,例如:AccessToken - NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) - // RefreshToken 更新目前的token 以及裡面包含的一次性 Token - RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) - // CancelToken 取消 Token,也包含他裡面的 One Time Toke - CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke - CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByDeviceId 取消 Token - CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) - // ValidationToken 驗證這個 Token 有沒有效 - ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) - // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens - GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) - // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens - GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) - // NewOneTimeToken 建立一次性使用,例如:RefreshToken - NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) - // CancelOneTimeToken 取消一次性使用 - CancelOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) - } - - defaultTokenService struct { - cli zrpc.Client - } -) - -func NewTokenService(cli zrpc.Client) TokenService { - return &defaultTokenService{ - cli: cli, - } -} - -// NewToken 建立一個新的 Token,例如:AccessToken -func (m *defaultTokenService) NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.NewToken(ctx, in, opts...) -} - -// RefreshToken 更新目前的token 以及裡面包含的一次性 Token -func (m *defaultTokenService) RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.RefreshToken(ctx, in, opts...) -} - -// CancelToken 取消 Token,也包含他裡面的 One Time Toke -func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelToken(ctx, in, opts...) -} - -// CancelTokenByUid 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (m *defaultTokenService) CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByUid(ctx, in, opts...) -} - -// CancelTokenByDeviceId 取消 Token -func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByDeviceId(ctx, in, opts...) -} - -// ValidationToken 驗證這個 Token 有沒有效 -func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.ValidationToken(ctx, in, opts...) -} - -// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.GetUserTokensByDeviceId(ctx, in, opts...) -} - -// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.GetUserTokensByUid(ctx, in, opts...) -} - -// NewOneTimeToken 建立一次性使用,例如:RefreshToken -func (m *defaultTokenService) NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.NewOneTimeToken(ctx, in, opts...) -} - -// CancelOneTimeToken 取消一次性使用 -func (m *defaultTokenService) CancelOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelOneTimeToken(ctx, in, opts...) -} -- 2.40.1 From 31fd7ad226ba579b3fec8b1a5f55deac84f29b9e Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Sun, 11 Aug 2024 20:24:51 +0800 Subject: [PATCH 09/10] fix: delete --- .gitignore | 3 +- .../permissionservice/permission_service.go | 45 ------- client/roleservice/role_service.go | 51 ------- client/tokenservice/token_service.go | 125 ------------------ 4 files changed, 2 insertions(+), 222 deletions(-) delete mode 100644 client/permissionservice/permission_service.go delete mode 100644 client/roleservice/role_service.go delete mode 100644 client/tokenservice/token_service.go diff --git a/.gitignore b/.gitignore index d7d485f..59ba36c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ go.sum account/ gen_result/ -etc/permission.yaml \ No newline at end of file +etc/permission.yaml +./client \ No newline at end of file diff --git a/client/permissionservice/permission_service.go b/client/permissionservice/permission_service.go deleted file mode 100644 index fdae9cb..0000000 --- a/client/permissionservice/permission_service.go +++ /dev/null @@ -1,45 +0,0 @@ -// Code generated by goctl. DO NOT EDIT. -// Source: permission.proto - -package permissionservice - -import ( - "context" - - "ark-permission/gen_result/pb/permission" - - "github.com/zeromicro/go-zero/zrpc" - "google.golang.org/grpc" -) - -type ( - AuthorizationReq = permission.AuthorizationReq - CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq - CancelTokenReq = permission.CancelTokenReq - CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq - CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp - DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq - DoTokenByUIDReq = permission.DoTokenByUIDReq - OKResp = permission.OKResp - QueryTokenByUIDReq = permission.QueryTokenByUIDReq - RefreshTokenReq = permission.RefreshTokenReq - RefreshTokenResp = permission.RefreshTokenResp - Token = permission.Token - TokenResp = permission.TokenResp - Tokens = permission.Tokens - ValidationTokenReq = permission.ValidationTokenReq - ValidationTokenResp = permission.ValidationTokenResp - - PermissionService interface { - } - - defaultPermissionService struct { - cli zrpc.Client - } -) - -func NewPermissionService(cli zrpc.Client) PermissionService { - return &defaultPermissionService{ - cli: cli, - } -} diff --git a/client/roleservice/role_service.go b/client/roleservice/role_service.go deleted file mode 100644 index 8d7e2e2..0000000 --- a/client/roleservice/role_service.go +++ /dev/null @@ -1,51 +0,0 @@ -// Code generated by goctl. DO NOT EDIT. -// Source: permission.proto - -package roleservice - -import ( - "context" - - "ark-permission/gen_result/pb/permission" - - "github.com/zeromicro/go-zero/zrpc" - "google.golang.org/grpc" -) - -type ( - AuthorizationReq = permission.AuthorizationReq - CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq - CancelTokenReq = permission.CancelTokenReq - CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq - CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp - DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq - DoTokenByUIDReq = permission.DoTokenByUIDReq - OKResp = permission.OKResp - QueryTokenByUIDReq = permission.QueryTokenByUIDReq - RefreshTokenReq = permission.RefreshTokenReq - RefreshTokenResp = permission.RefreshTokenResp - Token = permission.Token - TokenResp = permission.TokenResp - Tokens = permission.Tokens - ValidationTokenReq = permission.ValidationTokenReq - ValidationTokenResp = permission.ValidationTokenResp - - RoleService interface { - Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) - } - - defaultRoleService struct { - cli zrpc.Client - } -) - -func NewRoleService(cli zrpc.Client) RoleService { - return &defaultRoleService{ - cli: cli, - } -} - -func (m *defaultRoleService) Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewRoleServiceClient(m.cli.Conn()) - return client.Ping(ctx, in, opts...) -} diff --git a/client/tokenservice/token_service.go b/client/tokenservice/token_service.go deleted file mode 100644 index 617dc38..0000000 --- a/client/tokenservice/token_service.go +++ /dev/null @@ -1,125 +0,0 @@ -// Code generated by goctl. DO NOT EDIT. -// Source: permission.proto - -package tokenservice - -import ( - "context" - - "ark-permission/gen_result/pb/permission" - - "github.com/zeromicro/go-zero/zrpc" - "google.golang.org/grpc" -) - -type ( - AuthorizationReq = permission.AuthorizationReq - CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq - CancelTokenReq = permission.CancelTokenReq - CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq - CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp - DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq - DoTokenByUIDReq = permission.DoTokenByUIDReq - OKResp = permission.OKResp - QueryTokenByUIDReq = permission.QueryTokenByUIDReq - RefreshTokenReq = permission.RefreshTokenReq - RefreshTokenResp = permission.RefreshTokenResp - Token = permission.Token - TokenResp = permission.TokenResp - Tokens = permission.Tokens - ValidationTokenReq = permission.ValidationTokenReq - ValidationTokenResp = permission.ValidationTokenResp - - TokenService interface { - // NewToken 建立一個新的 Token,例如:AccessToken - NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) - // RefreshToken 更新目前的token 以及裡面包含的一次性 Token - RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) - // CancelToken 取消 Token,也包含他裡面的 One Time Toke - CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) - // ValidationToken 驗證這個 Token 有沒有效 - ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) - // CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 - CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) - // CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token - CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) - // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens - GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) - // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens - GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) - // NewOneTimeToken 建立一次性使用,例如:RefreshToken - NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) - // CancelOneTimeToken 取消一次性使用 - CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) - } - - defaultTokenService struct { - cli zrpc.Client - } -) - -func NewTokenService(cli zrpc.Client) TokenService { - return &defaultTokenService{ - cli: cli, - } -} - -// NewToken 建立一個新的 Token,例如:AccessToken -func (m *defaultTokenService) NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.NewToken(ctx, in, opts...) -} - -// RefreshToken 更新目前的token 以及裡面包含的一次性 Token -func (m *defaultTokenService) RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.RefreshToken(ctx, in, opts...) -} - -// CancelToken 取消 Token,也包含他裡面的 One Time Toke -func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelToken(ctx, in, opts...) -} - -// ValidationToken 驗證這個 Token 有沒有效 -func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.ValidationToken(ctx, in, opts...) -} - -// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 -func (m *defaultTokenService) CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokens(ctx, in, opts...) -} - -// CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token -func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByDeviceId(ctx, in, opts...) -} - -// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.GetUserTokensByDeviceId(ctx, in, opts...) -} - -// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.GetUserTokensByUid(ctx, in, opts...) -} - -// NewOneTimeToken 建立一次性使用,例如:RefreshToken -func (m *defaultTokenService) NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.NewOneTimeToken(ctx, in, opts...) -} - -// CancelOneTimeToken 取消一次性使用 -func (m *defaultTokenService) CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) { - client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelOneTimeToken(ctx, in, opts...) -} -- 2.40.1 From db0d90351870b3cd4e330835503867636770a798 Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Mon, 12 Aug 2024 22:19:34 +0800 Subject: [PATCH 10/10] update pre and online gateway version --- .../permissionservice/permission_service.go | 45 ++ client/roleservice/role_service.go | 51 ++ client/tokenservice/token_service.go | 125 +++++ .../20230529020000_create_schema.down.sql | 2 +- .../20230529020000_create_schema.up.sql | 2 +- internal/domain/repository/token.go | 7 +- .../logic/tokenservice/cancel_token_logic.go | 6 +- .../tokenservice/new_one_time_token_logic.go | 2 +- .../logic/tokenservice/refresh_token_logic.go | 2 +- .../tokenservice/validation_token_logic.go | 9 +- internal/repository/token.go | 447 ++++++------------ 11 files changed, 381 insertions(+), 317 deletions(-) create mode 100644 client/permissionservice/permission_service.go create mode 100644 client/roleservice/role_service.go create mode 100644 client/tokenservice/token_service.go diff --git a/client/permissionservice/permission_service.go b/client/permissionservice/permission_service.go new file mode 100644 index 0000000..fdae9cb --- /dev/null +++ b/client/permissionservice/permission_service.go @@ -0,0 +1,45 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package permissionservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + PermissionService interface { + } + + defaultPermissionService struct { + cli zrpc.Client + } +) + +func NewPermissionService(cli zrpc.Client) PermissionService { + return &defaultPermissionService{ + cli: cli, + } +} diff --git a/client/roleservice/role_service.go b/client/roleservice/role_service.go new file mode 100644 index 0000000..8d7e2e2 --- /dev/null +++ b/client/roleservice/role_service.go @@ -0,0 +1,51 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package roleservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + RoleService interface { + Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) + } + + defaultRoleService struct { + cli zrpc.Client + } +) + +func NewRoleService(cli zrpc.Client) RoleService { + return &defaultRoleService{ + cli: cli, + } +} + +func (m *defaultRoleService) Ping(ctx context.Context, in *OKResp, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewRoleServiceClient(m.cli.Conn()) + return client.Ping(ctx, in, opts...) +} diff --git a/client/tokenservice/token_service.go b/client/tokenservice/token_service.go new file mode 100644 index 0000000..617dc38 --- /dev/null +++ b/client/tokenservice/token_service.go @@ -0,0 +1,125 @@ +// Code generated by goctl. DO NOT EDIT. +// Source: permission.proto + +package tokenservice + +import ( + "context" + + "ark-permission/gen_result/pb/permission" + + "github.com/zeromicro/go-zero/zrpc" + "google.golang.org/grpc" +) + +type ( + AuthorizationReq = permission.AuthorizationReq + CancelOneTimeTokenReq = permission.CancelOneTimeTokenReq + CancelTokenReq = permission.CancelTokenReq + CreateOneTimeTokenReq = permission.CreateOneTimeTokenReq + CreateOneTimeTokenResp = permission.CreateOneTimeTokenResp + DoTokenByDeviceIDReq = permission.DoTokenByDeviceIDReq + DoTokenByUIDReq = permission.DoTokenByUIDReq + OKResp = permission.OKResp + QueryTokenByUIDReq = permission.QueryTokenByUIDReq + RefreshTokenReq = permission.RefreshTokenReq + RefreshTokenResp = permission.RefreshTokenResp + Token = permission.Token + TokenResp = permission.TokenResp + Tokens = permission.Tokens + ValidationTokenReq = permission.ValidationTokenReq + ValidationTokenResp = permission.ValidationTokenResp + + TokenService interface { + // NewToken 建立一個新的 Token,例如:AccessToken + NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) + // RefreshToken 更新目前的token 以及裡面包含的一次性 Token + RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) + // CancelToken 取消 Token,也包含他裡面的 One Time Toke + CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) + // ValidationToken 驗證這個 Token 有沒有效 + ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) + // CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 + CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) + // CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token + CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) + // GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens + GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) + // GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens + GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) + // NewOneTimeToken 建立一次性使用,例如:RefreshToken + NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) + // CancelOneTimeToken 取消一次性使用 + CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) + } + + defaultTokenService struct { + cli zrpc.Client + } +) + +func NewTokenService(cli zrpc.Client) TokenService { + return &defaultTokenService{ + cli: cli, + } +} + +// NewToken 建立一個新的 Token,例如:AccessToken +func (m *defaultTokenService) NewToken(ctx context.Context, in *AuthorizationReq, opts ...grpc.CallOption) (*TokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.NewToken(ctx, in, opts...) +} + +// RefreshToken 更新目前的token 以及裡面包含的一次性 Token +func (m *defaultTokenService) RefreshToken(ctx context.Context, in *RefreshTokenReq, opts ...grpc.CallOption) (*RefreshTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.RefreshToken(ctx, in, opts...) +} + +// CancelToken 取消 Token,也包含他裡面的 One Time Toke +func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelToken(ctx, in, opts...) +} + +// ValidationToken 驗證這個 Token 有沒有效 +func (m *defaultTokenService) ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.ValidationToken(ctx, in, opts...) +} + +// CancelTokens 取消 Token 從UID 視角,以及 token id 視角出發, UID 登出,底下所有 Device ID 也要登出, Token ID 登出, 所有 UID + Device 都要登出 +func (m *defaultTokenService) CancelTokens(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokens(ctx, in, opts...) +} + +// CancelTokenByDeviceId 取消 Token, 從 Device 視角出發,可以選,登出這個Device 下所有 token ,登出這個Device 下指定token +func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelTokenByDeviceId(ctx, in, opts...) +} + +// GetUserTokensByDeviceId 取得目前所對應的 DeviceID 所存在的 Tokens +func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.GetUserTokensByDeviceId(ctx, in, opts...) +} + +// GetUserTokensByUid 取得目前所對應的 UID 所存在的 Tokens +func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *QueryTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.GetUserTokensByUid(ctx, in, opts...) +} + +// NewOneTimeToken 建立一次性使用,例如:RefreshToken +func (m *defaultTokenService) NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.NewOneTimeToken(ctx, in, opts...) +} + +// CancelOneTimeToken 取消一次性使用 +func (m *defaultTokenService) CancelOneTimeToken(ctx context.Context, in *CancelOneTimeTokenReq, opts ...grpc.CallOption) (*OKResp, error) { + client := permission.NewTokenServiceClient(m.cli.Conn()) + return client.CancelOneTimeToken(ctx, in, opts...) +} diff --git a/generate/database/mysql/create/20230529020000_create_schema.down.sql b/generate/database/mysql/create/20230529020000_create_schema.down.sql index e7727a5..dc0bfe4 100644 --- a/generate/database/mysql/create/20230529020000_create_schema.down.sql +++ b/generate/database/mysql/create/20230529020000_create_schema.down.sql @@ -1 +1 @@ -DROP DATABASE IF EXISTS `ark_member`; \ No newline at end of file +DROP DATABASE IF EXISTS `ark_permission`; \ No newline at end of file diff --git a/generate/database/mysql/create/20230529020000_create_schema.up.sql b/generate/database/mysql/create/20230529020000_create_schema.up.sql index d997e04..686ffdf 100644 --- a/generate/database/mysql/create/20230529020000_create_schema.up.sql +++ b/generate/database/mysql/create/20230529020000_create_schema.up.sql @@ -1 +1 @@ -CREATE DATABASE IF NOT EXISTS `ark_member`; \ No newline at end of file +CREATE DATABASE IF NOT EXISTS `ark_permission`; \ No newline at end of file diff --git a/internal/domain/repository/token.go b/internal/domain/repository/token.go index be4c207..5f02b98 100644 --- a/internal/domain/repository/token.go +++ b/internal/domain/repository/token.go @@ -6,12 +6,14 @@ import ( "time" ) +// TokenRepository token 的 redis 操作 type TokenRepository interface { + // Create 建立Token Create(ctx context.Context, token entity.Token) error - DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error + // CreateOneTimeToken 建立臨時 Token CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error - GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) + GetAccessTokenByByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) GetAccessTokenCountByUID(uid string) (int, error) @@ -19,6 +21,7 @@ type TokenRepository interface { GetAccessTokenCountByDeviceID(deviceID string) (int, error) Delete(ctx context.Context, token entity.Token) error + DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error DeleteAccessTokenByID(ctx context.Context, ids []string) error DeleteAccessTokensByUID(ctx context.Context, uid string) error DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error diff --git a/internal/logic/tokenservice/cancel_token_logic.go b/internal/logic/tokenservice/cancel_token_logic.go index bf89e59..368f297 100644 --- a/internal/logic/tokenservice/cancel_token_logic.go +++ b/internal/logic/tokenservice/cancel_token_logic.go @@ -1,12 +1,10 @@ package tokenservicelogic import ( - ers "code.30cm.net/wanderland/library-go/errors" - "context" - "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" - + ers "code.30cm.net/wanderland/library-go/errors" + "context" "github.com/zeromicro/go-zero/core/logx" ) diff --git a/internal/logic/tokenservice/new_one_time_token_logic.go b/internal/logic/tokenservice/new_one_time_token_logic.go index 79db981..389f23a 100644 --- a/internal/logic/tokenservice/new_one_time_token_logic.go +++ b/internal/logic/tokenservice/new_one_time_token_logic.go @@ -28,7 +28,7 @@ func NewNewOneTimeTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *N } } -// NewOneTimeToken 建立一次性使用,例如:RefreshToken +// NewOneTimeToken 建立一次性使用,例如:RefreshToken TODO 目前並無後續操作 func (l *NewOneTimeTokenLogic) NewOneTimeToken(in *permission.CreateOneTimeTokenReq) (*permission.CreateOneTimeTokenResp, error) { // 驗證所需 if err := l.svcCtx.Validate.ValidateAll(&refreshTokenReq{ diff --git a/internal/logic/tokenservice/refresh_token_logic.go b/internal/logic/tokenservice/refresh_token_logic.go index 607fd28..e059b7a 100644 --- a/internal/logic/tokenservice/refresh_token_logic.go +++ b/internal/logic/tokenservice/refresh_token_logic.go @@ -42,7 +42,7 @@ func (l *RefreshTokenLogic) RefreshToken(in *permission.RefreshTokenReq) (*permi } // step 1 拿看看有沒有這個 refresh token - token, err := l.svcCtx.TokenRedisRepo.GetByRefresh(l.ctx, in.Token) + token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByByOneTimeToken(l.ctx, in.Token) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "TokenRedisRepo.GetByRefresh"), diff --git a/internal/logic/tokenservice/validation_token_logic.go b/internal/logic/tokenservice/validation_token_logic.go index f3baf9e..be6ab1a 100644 --- a/internal/logic/tokenservice/validation_token_logic.go +++ b/internal/logic/tokenservice/validation_token_logic.go @@ -1,11 +1,10 @@ package tokenservicelogic import ( - ers "code.30cm.net/wanderland/library-go/errors" - "context" - "ark-permission/gen_result/pb/permission" "ark-permission/internal/svc" + ers "code.30cm.net/wanderland/library-go/errors" + "context" "github.com/zeromicro/go-zero/core/logx" ) @@ -36,15 +35,13 @@ func (l *ValidationTokenLogic) ValidationToken(in *permission.ValidationTokenReq }); err != nil { return nil, ers.InvalidFormat(err.Error()) } - claims, err := parseClaims(in.GetToken(), l.svcCtx.Config.Token.Secret, true) if err != nil { logx.WithCallerSkip(1).WithFields( logx.Field("func", "parseClaims"), - ).Error(err.Error()) + ).Info(err.Error()) return nil, err } - token, err := l.svcCtx.TokenRedisRepo.GetAccessTokenByID(l.ctx, claims.ID()) if err != nil { logx.WithCallerSkip(1).WithFields( diff --git a/internal/repository/token.go b/internal/repository/token.go index 17a325f..d91f73c 100644 --- a/internal/repository/token.go +++ b/internal/repository/token.go @@ -33,27 +33,19 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error if err != nil { return ers.ArkInternal("json.Marshal token error", err.Error()) } + if err := t.store.Pipelined(func(tx redis.Pipeliner) error { + refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second - err = t.store.Pipelined(func(tx redis.Pipeliner) error { - // rTTL := token.RedisExpiredSec() - refreshTTL := token.RedisRefreshExpiredSec() - - if err := t.setToken(ctx, tx, token, body, time.Duration(refreshTTL)*time.Second); err != nil { + if err := t.setToken(ctx, tx, token, body, refreshTTL); err != nil { return err } - if err := t.setRefreshToken(ctx, tx, token, time.Duration(refreshTTL)*time.Second); err != nil { + if err := t.setRefreshToken(ctx, tx, token, refreshTTL); err != nil { return err } - err := t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, time.Duration(refreshTTL)*time.Second) - if err != nil { - return err - } - - return nil - }) - if err != nil { + return t.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL) + }); err != nil { return domain.RedisPipLineError(err.Error()) } @@ -61,39 +53,28 @@ func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error } func (t *tokenRepository) Delete(ctx context.Context, token entity.Token) error { - err := t.store.Pipelined(func(tx redis.Pipeliner) error { - keys := []string{ - domain.GetAccessTokenRedisKey(token.ID), - domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), - domain.UIDTokenRedisKey.With(token.UID).ToString(), - } - - for _, key := range keys { - if err := tx.Del(ctx, key).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - } - - if token.DeviceID != "" { - key := domain.DeviceTokenRedisKey.With(token.DeviceID).ToString() - _, err := t.store.Del(key) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.HDel deviceKey error: %v", err)) - } - } - - return nil - }) - - if err != nil { - return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), } + if err := t.deleteKeys(ctx, keys...); err != nil { + return domain.RedisPipLineError(err.Error()) + } + + _, _ = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) + _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) + return nil } -func (t *tokenRepository) GetAccessTokenByID(_ context.Context, id string) (entity.Token, error) { - return t.get(domain.GetAccessTokenRedisKey(id)) +func (t *tokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) { + token, err := t.get(ctx, domain.GetAccessTokenRedisKey(id)) + if err != nil { + return entity.Token{}, err + } + + return token, nil } func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error { @@ -101,9 +82,9 @@ func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid strin if err != nil { return err } - for _, item := range tokens { - err := t.Delete(ctx, item) - if err != nil { + + for _, token := range tokens { + if err := t.Delete(ctx, token); err != nil { return err } } @@ -111,7 +92,6 @@ func (t *tokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid strin return nil } -// DeleteAccessTokenByID TODO 要做錯誤處理 func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error { for _, tokenID := range ids { token, err := t.GetAccessTokenByID(ctx, tokenID) @@ -119,338 +99,203 @@ func (t *tokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []strin continue } - err = t.store.Pipelined(func(tx redis.Pipeliner) error { - keys := []string{ - domain.GetAccessTokenRedisKey(token.ID), - domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), - } + keys := []string{ + domain.GetAccessTokenRedisKey(token.ID), + domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), + } - for _, key := range keys { - if err := tx.Del(ctx, key).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - } - - _, err = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) - } - - _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) - } - - return nil - }) - if err != nil { + if err := t.deleteKeys(ctx, keys...); err != nil { continue } + + _, _ = t.store.Srem(domain.DeviceTokenRedisKey.With(token.DeviceID).ToString(), token.ID) + _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) } return nil } -// GetAccessTokensByUID 透過 uid 得到目前未過期的 token func (t *tokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) { - utKeys, err := t.store.Smembers(domain.GetUIDTokenRedisKey(uid)) - if err != nil { - // 沒有就視為回空 - if errors.Is(err, redis.Nil) { - return nil, nil - } - - return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByUID store.Get GetUIDTokenRedisKey error: %v", err.Error())) - } - - now := time.Now().UTC() - var tokens []entity.Token - var deleteToken []string - for _, id := range utKeys { - item := &entity.Token{} - tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) - if err == nil { - err = json.Unmarshal([]byte(tk), item) - if err != nil { - return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) - } - tokens = append(tokens, *item) - } - - if errors.Is(err, redis.Nil) { - deleteToken = append(deleteToken, id) - } - - if int64(item.ExpiresIn) < now.Unix() { - deleteToken = append(deleteToken, id) - - continue - } - - } - if len(deleteToken) > 0 { - // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 - _ = t.DeleteAccessTokenByID(ctx, deleteToken) - } - - return tokens, nil + return t.getTokensBySet(ctx, domain.GetUIDTokenRedisKey(uid)) } -func (t *tokenRepository) GetByRefresh(ctx context.Context, refreshToken string) (entity.Token, error) { - id, err := t.store.Get(domain.RefreshTokenRedisKey.With(refreshToken).ToString()) +func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { + return t.getTokensBySet(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString()) +} + +func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { + + tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID) if err != nil { - return entity.Token{}, err + return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err)) } - if errors.Is(err, redis.Nil) || id == "" { - return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(refreshToken).ToString()) + var keys []string + for _, token := range tokens { + keys = append(keys, domain.GetAccessTokenRedisKey(token.ID)) + keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) + } + err = t.store.Pipelined(func(tx redis.Pipeliner) error { + for _, token := range tokens { + _, _ = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) + } + return nil + }) if err != nil { - return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.GetByRefresh refresh token error: %v", err)) + return err + } + + if err := t.deleteKeys(ctx, keys...); err != nil { + return err + } + + _, err = t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString()) + return err +} + +func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { + return t.getCountBySet(domain.DeviceTokenRedisKey.With(deviceID).ToString()) +} + +func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { + return t.getCountBySet(domain.UIDTokenRedisKey.With(uid).ToString()) +} + +func (t *tokenRepository) GetAccessTokenByByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) { + id, err := t.store.Get(domain.RefreshTokenRedisKey.With(oneTimeToken).ToString()) + if err != nil { + return entity.Token{}, domain.RedisError(fmt.Sprintf("GetAccessTokenByByOneTimeToken store.Get error: %s", err.Error())) + } + + if id == "" { + return entity.Token{}, ers.ResourceNotFound("token key not found in redis", domain.RefreshTokenRedisKey.With(oneTimeToken).ToString()) } return t.GetAccessTokenByID(ctx, id) } func (t *tokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error { - err := t.store.Pipelined(func(tx redis.Pipeliner) error { - keys := make([]string, 0, len(ids)+len(tokens)) + var keys []string - for _, id := range ids { - keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString()) - } - - for _, token := range tokens { - keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) - } - - for _, key := range keys { - if err := tx.Del(ctx, key).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - } - - return nil - }) - - if err != nil { - return domain.RedisPipLineError(fmt.Sprintf("store.Pipelined error: %v", err)) + for _, id := range ids { + keys = append(keys, domain.RefreshTokenRedisKey.With(id).ToString()) } - return nil + for _, token := range tokens { + keys = append(keys, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()) + } + + return t.deleteKeys(ctx, keys...) } -func (t *tokenRepository) CreateOneTimeToken(_ context.Context, key string, ticket entity.Ticket, expires time.Duration) error { +func (t *tokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, expires time.Duration) error { body, err := json.Marshal(ticket) if err != nil { - return ers.InvalidFormat("CreateOneTimeToken json.Marshal error:", err.Error()) + return ers.InvalidFormat("CreateOneTimeToken json.Marshal error", err.Error()) } _, err = t.store.SetnxEx(domain.RefreshTokenRedisKey.With(key).ToString(), string(body), int(expires.Seconds())) if err != nil { - return ers.DBError("CreateOneTimeToken store.set error:", err.Error()) + return domain.RedisError(fmt.Sprintf("CreateOneTimeToken store.SetnxEx error: %s", err.Error())) } return nil } -func (t *tokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) { - utKeys, err := t.store.Smembers(domain.DeviceTokenRedisKey.With(deviceID).ToString()) - if err != nil { - // 沒有就視為回空 - if errors.Is(err, redis.Nil) { - return nil, nil - } - - return nil, domain.RedisError(fmt.Sprintf("tokenRepository.GetAccessTokensByDeviceID store.Get DeviceTokenRedisKey error: %v", err.Error())) - } - - now := time.Now().UTC() - var tokens []entity.Token - var deleteToken []string - for _, id := range utKeys { - item := &entity.Token{} - tk, err := t.store.Get(domain.GetAccessTokenRedisKey(id)) - if err == nil { - err = json.Unmarshal([]byte(tk), item) - if err != nil { - return nil, ers.ArkInternal(fmt.Sprintf("tokenRepository.GetAccessTokensByUID json.Unmarshal GetUIDTokenRedisKey error: %v", err)) - } - tokens = append(tokens, *item) - } - - if errors.Is(err, redis.Nil) { - deleteToken = append(deleteToken, id) - } - - if int64(item.ExpiresIn) < now.Unix() { - deleteToken = append(deleteToken, id) - - continue - } - - } - if len(deleteToken) > 0 { - // 如果失敗也沒關係,其他get method撈取時會在判斷是否過期或存在 - _ = t.DeleteAccessTokenByID(ctx, deleteToken) - } - - return tokens, nil -} - -func (t *tokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error { - tokens, err := t.GetAccessTokensByDeviceID(ctx, deviceID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("GetAccessTokensByDeviceID error: %v", err)) - } - - err = t.store.Pipelined(func(tx redis.Pipeliner) error { - for _, token := range tokens { - if err := tx.Del(ctx, domain.GetAccessTokenRedisKey(token.ID)).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - - if err := tx.Del(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString()).Err(); err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) - } - _, err = t.store.Srem(domain.UIDTokenRedisKey.With(token.UID).ToString(), token.ID) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Srem UIDTokenRedisKey error: %v", err)) - } - } - - _, err := t.store.Del(domain.DeviceTokenRedisKey.With(deviceID).ToString()) - if err != nil { - return domain.RedisDelError(fmt.Sprintf("store.Srem DeviceTokenRedisKey error: %v", err)) - } - - return nil - }) - - if err != nil { - return err - } - - return nil -} - -func (t *tokenRepository) GetAccessTokenCountByDeviceID(deviceID string) (int, error) { - count, err := t.store.Scard(domain.DeviceTokenRedisKey.With(deviceID).ToString()) - if err != nil { - return 0, err - } - - return int(count), nil -} - -func (t *tokenRepository) GetAccessTokenCountByUID(uid string) (int, error) { - count, err := t.store.Scard(domain.UIDTokenRedisKey.With(uid).ToString()) - if err != nil { - return 0, err - } - - return int(count), nil -} - // -------------------- Private area -------------------- -func (t *tokenRepository) get(key string) (entity.Token, error) { - body, err := t.store.Get(key) - if errors.Is(err, redis.Nil) || body == "" { - return entity.Token{}, ers.ResourceNotFound("token key not found in redis", key) +func (t *tokenRepository) get(ctx context.Context, key string) (entity.Token, error) { + body, err := t.store.GetCtx(ctx, key) + if err != nil { + return entity.Token{}, domain.RedisError(fmt.Sprintf("token %s not found in redis: %s", key, err.Error())) } - if err != nil { - return entity.Token{}, ers.ArkInternal(fmt.Sprintf("store.Get tokenTag error: %v", err)) + if body == "" { + return entity.Token{}, ers.ResourceNotFound("this token not found") } var token entity.Token if err := json.Unmarshal([]byte(body), &token); err != nil { - return entity.Token{}, ers.ArkInternal(fmt.Sprintf("json.Unmarshal token error: %w", err)) + return entity.Token{}, ers.ArkInternal("json.Unmarshal token error", err.Error()) } return token, nil } -func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error { - err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err() - if err != nil { - return wrapError("tx.Set GetAccessTokenRedisKey error", err) - } - return nil +func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, ttl time.Duration) error { + return tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, ttl).Err() } -func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { +func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, ttl time.Duration) error { if token.RefreshToken != "" { - err := tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, rTTL).Err() - if err != nil { - return wrapError("tx.Set RefreshToken error", err) - } + return tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, ttl).Err() } return nil } -func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, rttl time.Duration) error { - uidKey := domain.UIDTokenRedisKey.With(uid).ToString() - err := tx.SAdd(ctx, uidKey, tokenID).Err() - if err != nil { - return err - } - err = tx.Expire(ctx, uidKey, rttl).Err() - if err != nil { +func (t *tokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error { + if err := tx.SAdd(ctx, domain.UIDTokenRedisKey.With(uid).ToString(), tokenID).Err(); err != nil { return err } - deviceKey := domain.DeviceTokenRedisKey.With(deviceID).ToString() - err = tx.SAdd(ctx, deviceKey, tokenID).Err() - if err != nil { - return err - } - err = tx.Expire(ctx, deviceKey, rttl).Err() - if err != nil { + if err := tx.SAdd(ctx, domain.DeviceTokenRedisKey.With(deviceID).ToString(), tokenID).Err(); err != nil { return err } return nil } -// SetUIDToken 將 token 資料放進 uid key中 -func (t *tokenRepository) SetUIDToken(token entity.Token) error { - uidTokens := make(entity.UIDToken) - b, err := t.store.Get(domain.GetUIDTokenRedisKey(token.UID)) - if err != nil && !errors.Is(err, redis.Nil) { - return wrapError("t.store.Get GetUIDTokenRedisKey error", err) - } - - if b != "" { - err = json.Unmarshal([]byte(b), &uidTokens) - if err != nil { - return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err) +func (t *tokenRepository) deleteKeys(ctx context.Context, keys ...string) error { + return t.store.Pipelined(func(tx redis.Pipeliner) error { + for _, key := range keys { + if err := tx.Del(ctx, key).Err(); err != nil { + return domain.RedisDelError(fmt.Sprintf("store.Del key error: %v", err)) + } } + return nil + }) +} + +func (t *tokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) { + ids, err := t.store.Smembers(setKey) + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, nil + } + return nil, domain.RedisError(fmt.Sprintf("getTokensBySet store.Get %s error: %v", setKey, err.Error())) } + var tokens []entity.Token + var deleteTokens []string now := time.Now().Unix() - for k, t := range uidTokens { - if t < now { - delete(uidTokens, k) + for _, id := range ids { + token, err := t.get(ctx, domain.GetAccessTokenRedisKey(id)) + if err != nil { + deleteTokens = append(deleteTokens, id) + continue } + + if int64(token.ExpiresIn) < now { + deleteTokens = append(deleteTokens, id) + continue + } + + tokens = append(tokens, token) } - uidTokens[token.ID] = token.RefreshTokenExpiresUnix() - s, err := json.Marshal(uidTokens) - if err != nil { - return wrapError("json.Marshal UIDToken error", err) + if len(deleteTokens) > 0 { + _ = t.DeleteAccessTokenByID(ctx, deleteTokens) } - err = t.store.Setex(domain.GetUIDTokenRedisKey(token.UID), string(s), 86400*30) - if err != nil { - return wrapError("t.store.Setex GetUIDTokenRedisKey error", err) - } - - return nil + return tokens, nil } -func wrapError(message string, err error) error { - return fmt.Errorf("%s: %w", message, err) +func (t *tokenRepository) getCountBySet(setKey string) (int, error) { + count, err := t.store.Scard(setKey) + if err != nil { + return 0, err + } + return int(count), nil } -- 2.40.1