From d82e9e1a54585a425ae52f83d760ee61e61ab5cd Mon Sep 17 00:00:00 2001 From: "daniel.w" Date: Tue, 6 Aug 2024 13:59:24 +0800 Subject: [PATCH] feat: create new token --- etc/permission_example.yaml | 15 + generate/protobuf/permission.proto | 8 +- go.mod | 13 +- internal/config/config.go | 12 +- internal/domain/const.go | 18 + internal/domain/redis.go | 43 + internal/domain/repository/token.go | 10 + internal/entity/claims.go | 8 + internal/entity/token.go | 38 + internal/lib/error/code/define.go | 98 ++ internal/lib/error/code/messsage.go | 13 + internal/lib/error/easy_func.go | 442 +++++++ internal/lib/error/easy_func_test.go | 1031 +++++++++++++++++ internal/lib/error/errors.go | 197 ++++ internal/lib/error/errors_test.go | 297 +++++ internal/lib/middleware/with_context.go | 28 + internal/lib/required/validate.go | 51 + internal/lib/required/validate_option.go | 29 + ....go => cancel_token_by_device_id_logic.go} | 8 +- ..._logic.go => cancel_token_by_uid_logic.go} | 8 +- internal/logic/claims.go | 55 + ... => get_user_tokens_by_device_id_logic.go} | 8 +- ...gic.go => get_user_tokens_by_uid_logic.go} | 8 +- internal/logic/new_token_logic.go | 114 +- internal/repository/token.go | 136 +++ internal/server/token_service_server.go | 24 +- internal/svc/service_context.go | 24 +- permission.go | 3 + tokenservice/token_service.go | 24 +- 29 files changed, 2710 insertions(+), 53 deletions(-) create mode 100644 etc/permission_example.yaml create mode 100644 internal/domain/const.go create mode 100644 internal/domain/redis.go create mode 100644 internal/domain/repository/token.go create mode 100644 internal/entity/claims.go create mode 100644 internal/entity/token.go create mode 100644 internal/lib/error/code/define.go create mode 100644 internal/lib/error/code/messsage.go create mode 100644 internal/lib/error/easy_func.go create mode 100644 internal/lib/error/easy_func_test.go create mode 100644 internal/lib/error/errors.go create mode 100644 internal/lib/error/errors_test.go create mode 100644 internal/lib/middleware/with_context.go create mode 100644 internal/lib/required/validate.go create mode 100644 internal/lib/required/validate_option.go rename internal/logic/{cancel_token_by_device_i_d_logic.go => cancel_token_by_device_id_logic.go} (65%) rename internal/logic/{cancel_token_by_u_i_d_logic.go => cancel_token_by_uid_logic.go} (70%) create mode 100644 internal/logic/claims.go rename internal/logic/{get_user_tokens_by_device_i_d_logic.go => get_user_tokens_by_device_id_logic.go} (66%) rename internal/logic/{get_user_tokens_by_u_i_d_logic.go => get_user_tokens_by_uid_logic.go} (67%) create mode 100644 internal/repository/token.go diff --git a/etc/permission_example.yaml b/etc/permission_example.yaml new file mode 100644 index 0000000..b1f1d42 --- /dev/null +++ b/etc/permission_example.yaml @@ -0,0 +1,15 @@ +Name: permission.rpc +ListenOn: 0.0.0.0:8080 +Etcd: + Hosts: + - 127.0.0.1:2379 + Key: permission.rpc + +RedisCluster: + Host: 127.0.0.1:7001 + Type: cluster + +Token: + Expired: 300 + RefreshExpires: 86500 + Secret: gg88g88 \ No newline at end of file diff --git a/generate/protobuf/permission.proto b/generate/protobuf/permission.proto index 83db18c..717f753 100644 --- a/generate/protobuf/permission.proto +++ b/generate/protobuf/permission.proto @@ -126,15 +126,15 @@ service TokenService { // CancelToken 取消 Token,也包含他裡面的 One Time Toke rpc CancelToken(CancelTokenReq) returns(OKResp); // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke - rpc CancelTokenByUID(DoTokenByUIDReq) returns(OKResp); + rpc CancelTokenByUid(DoTokenByUIDReq) returns(OKResp); // CancelTokenByDeviceID 取消 Token - rpc CancelTokenByDeviceID(DoTokenByDeviceIDReq) returns(OKResp); + rpc CancelTokenByDeviceId(DoTokenByDeviceIDReq) returns(OKResp); // ValidationToken 驗證這個 Token 有沒有效 rpc ValidationToken(ValidationTokenReq) returns(ValidationTokenResp); // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens - rpc GetUserTokensByDeviceID(DoTokenByDeviceIDReq) returns(Tokens); + rpc GetUserTokensByDeviceId(DoTokenByDeviceIDReq) returns(Tokens); // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens - rpc GetUserTokensByUID(DoTokenByUIDReq) returns(Tokens); + rpc GetUserTokensByUid(DoTokenByUIDReq) returns(Tokens); // NewOneTimeToken 建立一次性使用,例如:RefreshToken rpc NewOneTimeToken(CreateOneTimeTokenReq) returns(CreateOneTimeTokenResp); // CancelOneTimeToken 取消一次性使用 diff --git a/go.mod b/go.mod index 357d382..12ccb99 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,11 @@ module ark-permission go 1.22.3 require ( + github.com/go-playground/validator/v10 v10.22.0 + github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/golang/mock v1.6.0 + github.com/google/uuid v1.6.0 + github.com/stretchr/testify v1.9.0 github.com/zeromicro/go-zero v1.7.0 google.golang.org/grpc v1.65.0 google.golang.org/protobuf v1.34.2 @@ -18,21 +23,23 @@ require ( github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/emicklei/go-restful/v3 v3.11.0 // indirect github.com/fatih/color v1.17.0 // indirect + github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-openapi/jsonpointer v0.19.6 // indirect github.com/go-openapi/jsonreference v0.20.2 // indirect github.com/go-openapi/swag v0.22.4 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang/mock v1.6.0 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/gnostic-models v0.6.8 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/gofuzz v1.2.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect + github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect @@ -41,6 +48,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/openzipkin/zipkin-go v0.4.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.19.1 // indirect github.com/prometheus/client_model v0.5.0 // indirect github.com/prometheus/common v0.48.0 // indirect @@ -65,6 +73,7 @@ require ( go.uber.org/automaxprocs v1.5.3 // indirect go.uber.org/multierr v1.9.0 // indirect go.uber.org/zap v1.24.0 // indirect + golang.org/x/crypto v0.25.0 // indirect golang.org/x/net v0.27.0 // indirect golang.org/x/oauth2 v0.20.0 // indirect golang.org/x/sys v0.22.0 // indirect diff --git a/internal/config/config.go b/internal/config/config.go index c1f85b9..6cd58e1 100755 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,7 +1,17 @@ package config -import "github.com/zeromicro/go-zero/zrpc" +import ( + "github.com/zeromicro/go-zero/core/stores/redis" + "github.com/zeromicro/go-zero/zrpc" + "time" +) type Config struct { zrpc.RpcServerConf + RedisCluster redis.RedisConf + Token struct { + RefreshExpires time.Duration + Expired time.Duration + Secret string + } } diff --git a/internal/domain/const.go b/internal/domain/const.go new file mode 100644 index 0000000..82ac1b8 --- /dev/null +++ b/internal/domain/const.go @@ -0,0 +1,18 @@ +package domain + +type GrantType string + +const ( + PasswordCredentials GrantType = "password" + ClientCredentials GrantType = "client_credentials" + Refreshing GrantType = "refresh_token" +) + +const ( + // DefaultRole 預設role + DefaultRole = "user" +) + +const ( + TokenTypeBearer = "Bearer" +) diff --git a/internal/domain/redis.go b/internal/domain/redis.go new file mode 100644 index 0000000..2fb3b69 --- /dev/null +++ b/internal/domain/redis.go @@ -0,0 +1,43 @@ +package domain + +import "strings" + +const ( + TicketKeyPrefix = "tic/" +) + +const ( + ClientDataKey = "permission:clients" +) + +type RedisKey string + +const ( + AccessTokenRedisKey RedisKey = "access_token" + RefreshTokenRedisKey RedisKey = "refresh_token" + DeviceTokenRedisKey RedisKey = "device_token" + UIDTokenRedisKey RedisKey = "uid_token" + TicketRedisKey RedisKey = "ticket" +) + +func (key RedisKey) ToString() string { + return "permission:" + string(key) +} + +func (key RedisKey) With(s ...string) RedisKey { + parts := append([]string{string(key)}, s...) + + return RedisKey(strings.Join(parts, ":")) +} + +func GetAccessTokenRedisKey(id string) string { + return AccessTokenRedisKey.With(id).ToString() +} + +func GetUIDTokenRedisKey(uid string) string { + return UIDTokenRedisKey.With(uid).ToString() +} + +func GetTicketRedisKey(ticket string) string { + return TicketRedisKey.With(ticket).ToString() +} diff --git a/internal/domain/repository/token.go b/internal/domain/repository/token.go new file mode 100644 index 0000000..b5f512a --- /dev/null +++ b/internal/domain/repository/token.go @@ -0,0 +1,10 @@ +package repository + +import ( + "ark-permission/internal/entity" + "context" +) + +type TokenRepository interface { + Create(ctx context.Context, token entity.Token) error +} diff --git a/internal/entity/claims.go b/internal/entity/claims.go new file mode 100644 index 0000000..9271795 --- /dev/null +++ b/internal/entity/claims.go @@ -0,0 +1,8 @@ +package entity + +import "github.com/golang-jwt/jwt/v4" + +type Claims struct { + jwt.RegisteredClaims + Data interface{} `json:"data"` +} diff --git a/internal/entity/token.go b/internal/entity/token.go new file mode 100644 index 0000000..eeadd26 --- /dev/null +++ b/internal/entity/token.go @@ -0,0 +1,38 @@ +package entity + +import "time" + +type Token struct { + ID string `json:"id"` + UID string `json:"uid"` + DeviceID string `json:"device_id"` + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + AccessCreateAt time.Time `json:"access_create_at"` + RefreshToken string `json:"refresh_token"` + RefreshExpiresIn int `json:"refresh_expires_in"` + RefreshCreateAt time.Time `json:"refresh_create_at"` +} + +func (t *Token) AccessTokenExpires() time.Duration { + return time.Duration(t.ExpiresIn) * time.Second +} + +func (t *Token) RefreshTokenExpires() time.Duration { + return time.Duration(t.RefreshExpiresIn) * time.Second +} + +func (t *Token) RefreshTokenExpiresUnix() int64 { + return time.Now().Add(t.RefreshTokenExpires()).Unix() +} + +func (t *Token) IsExpires() bool { + return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now()) +} + +type UIDToken map[string]int64 + +type Ticket struct { + Data interface{} `json:"data"` + Token Token `json:"token"` +} diff --git a/internal/lib/error/code/define.go b/internal/lib/error/code/define.go new file mode 100644 index 0000000..49715a2 --- /dev/null +++ b/internal/lib/error/code/define.go @@ -0,0 +1,98 @@ +package code + +const ( + OK uint32 = 0 +) + +// Scope +const ( + Unset uint32 = iota + CloudEPPortalGW + CloudEPMember +) + +// Category for general operations: 100 - 4900 +const ( + _ = iota + CatInput uint32 = iota * 100 + CatDB + CatResource + CatGRPC + CatAuth + CatSystem + CatPubSub +) + +// CatArk Category for specific app/service: 5000 - 9900 +const ( + CatArk uint32 = (iota + 50) * 100 +) + +// Detail - Input 1xx +const ( + _ = iota + CatInput + InvalidFormat + NotValidImplementation + InvalidRange +) + +// Detail - Database 2xx +const ( + _ = iota + CatDB + DBError // general error + DBDataConvert + DBDuplicate +) + +// Detail - Resource 3xx +const ( + _ = iota + CatResource + ResourceNotFound + InvalidResourceFormat + ResourceAlreadyExist + ResourceInsufficient + InsufficientPermission + InvalidMeasurementID + ResourceExpired + ResourceMigrated + InvalidResourceState + InsufficientQuota + ResourceHasMultiOwner +) + +/* Detail - GRPC */ +// The GRPC detail code uses Go GRPC's built-in codes. +// Refer to "google.golang.org/grpc/codes" for more detail. + +// Detail - Auth 5xx +const ( + _ = iota + CatAuth + Unauthorized + AuthExpired + InvalidPosixTime + SigAndPayloadNotMatched + Forbidden +) + +// Detail - System 6xx +const ( + _ = iota + CatSystem + SystemInternalError + SystemMaintainError + SystemTimeoutError +) + +// Detail - PubSub 7xx +const ( + _ = iota + CatPubSub + Publish + Consume + MsgSizeTooLarge +) + +// Detail - Ark 5xxx +const ( + _ = iota + CatArk + ArkInternal + ArkHttp400 +) diff --git a/internal/lib/error/code/messsage.go b/internal/lib/error/code/messsage.go new file mode 100644 index 0000000..18a4d4f --- /dev/null +++ b/internal/lib/error/code/messsage.go @@ -0,0 +1,13 @@ +package code + +// CatToStr collects general error messages for each Category +// It is used to send back to API caller +var CatToStr = map[uint32]string{ + CatInput: "Invalid Input Data", + CatDB: "Database Error", + CatResource: "Resource Error", + CatGRPC: "Internal Service Communication Error", + CatAuth: "Authentication Error", + CatArk: "Internal Service Communication Error", + CatSystem: "System Error", +} diff --git a/internal/lib/error/easy_func.go b/internal/lib/error/easy_func.go new file mode 100644 index 0000000..2e13bd8 --- /dev/null +++ b/internal/lib/error/easy_func.go @@ -0,0 +1,442 @@ +package error + +import ( + "ark-permission/internal/lib/error/code" + "errors" + "fmt" + "strings" + + "github.com/zeromicro/go-zero/core/logx" + _ "github.com/zeromicro/go-zero/core/logx" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func newErr(scope, detail uint32, msg string) *Err { + cat := detail / 100 * 100 + return &Err{ + category: cat, + code: detail, + scope: scope, + msg: msg, + } +} + +func newBuiltinGRPCErr(scope, detail uint32, msg string) *Err { + return &Err{ + category: code.CatGRPC, + code: detail, + scope: scope, + msg: msg, + } +} + +// FromError tries to let error as Err +// it supports to unwrap error that has Err +// return nil if failed to transfer +func FromError(err error) *Err { + if err == nil { + return nil + } + + var e *Err + if errors.As(err, &e) { + return e + } + + return nil +} + +// FromCode parses code as following +// Decimal: 120314 +// 12 represents Scope +// 03 represents Category +// 14 represents Detail error code +func FromCode(code uint32) *Err { + scope := code / 10000 + detail := code % 10000 + return &Err{ + category: detail / 100 * 100, + code: detail, + scope: scope, + msg: "", + } +} + +// FromGRPCError transfer error to Err +// useful for gRPC client +func FromGRPCError(err error) *Err { + s, _ := status.FromError(err) + e := FromCode(uint32(s.Code())) + e.msg = s.Message() + + // For GRPC built-in code + if e.Scope() == code.Unset && e.Category() == 0 && e.Code() != code.OK { + e = newBuiltinGRPCErr(Scope, e.Code(), s.Message()) + } + + return e +} + +// Deprecated: check GRPCStatus() in Errs struct +// ToGRPCError returns the status.Status +// Useful to return error in gRPC server +func ToGRPCError(e *Err) error { + return status.New(codes.Code(e.FullCode()), e.Error()).Err() +} + +/*** System ***/ + +// SystemTimeoutError returns Err +func SystemTimeoutError(s ...string) *Err { + return newErr(Scope, code.SystemTimeoutError, fmt.Sprintf("system timeout: %s", strings.Join(s, " "))) +} + +// SystemTimeoutErrorL logs error message and returns Err +func SystemTimeoutErrorL(l logx.Logger, s ...string) *Err { + e := SystemTimeoutError(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// SystemInternalError returns Err struct +func SystemInternalError(s ...string) *Err { + return newErr(Scope, code.SystemInternalError, fmt.Sprintf("internal error: %s", strings.Join(s, " "))) +} + +// SystemInternalErrorL logs error message and returns Err +func SystemInternalErrorL(l logx.Logger, s ...string) *Err { + e := SystemInternalError(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// SystemMaintainErrorL logs error message and returns Err +func SystemMaintainErrorL(l logx.Logger, s ...string) *Err { + e := SystemMaintainError(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// SystemMaintainError returns Err struct +func SystemMaintainError(s ...string) *Err { + return newErr(Scope, code.SystemMaintainError, fmt.Sprintf("service under maintenance: %s", strings.Join(s, " "))) +} + +/*** CatInput ***/ + +// InvalidFormat returns Err struct +func InvalidFormat(s ...string) *Err { + return newErr(Scope, code.InvalidFormat, fmt.Sprintf("invalid format: %s", strings.Join(s, " "))) +} + +// InvalidFormatL logs error message and returns Err +func InvalidFormatL(l logx.Logger, s ...string) *Err { + e := InvalidFormat(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidRange returns Err struct +func InvalidRange(s ...string) *Err { + return newErr(Scope, code.InvalidRange, fmt.Sprintf("invalid range: %s", strings.Join(s, " "))) +} + +// InvalidRangeL logs error message and returns Err +func InvalidRangeL(l logx.Logger, s ...string) *Err { + e := InvalidRange(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// NotValidImplementation returns Err struct +func NotValidImplementation(s ...string) *Err { + return newErr(Scope, code.NotValidImplementation, fmt.Sprintf("not valid implementation: %s", strings.Join(s, " "))) +} + +// NotValidImplementationL logs error message and returns Err +func NotValidImplementationL(l logx.Logger, s ...string) *Err { + e := NotValidImplementation(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +/*** CatDB ***/ + +// DBError returns Err +func DBError(s ...string) *Err { + return newErr(Scope, code.DBError, fmt.Sprintf("db error: %s", strings.Join(s, " "))) +} + +// DBErrorL logs error message and returns Err +func DBErrorL(l logx.Logger, s ...string) *Err { + e := DBError(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// DBDataConvert returns Err +func DBDataConvert(s ...string) *Err { + return newErr(Scope, code.DBDataConvert, fmt.Sprintf("data from db convert error: %s", strings.Join(s, " "))) +} + +// DBDataConvertL logs error message and returns Err +func DBDataConvertL(l logx.Logger, s ...string) *Err { + e := DBDataConvert(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// DBDuplicate returns Err +func DBDuplicate(s ...string) *Err { + return newErr(Scope, code.DBDuplicate, fmt.Sprintf("data Duplicate key error: %s", strings.Join(s, " "))) +} + +// DBDuplicateL logs error message and returns Err +func DBDuplicateL(l logx.Logger, s ...string) *Err { + e := DBDuplicate(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +/*** CatResource ***/ + +// ResourceNotFound returns Err and logging +func ResourceNotFound(s ...string) *Err { + return newErr(Scope, code.ResourceNotFound, fmt.Sprintf("resource not found: %s", strings.Join(s, " "))) +} + +// ResourceNotFoundL logs error message and returns Err +func ResourceNotFoundL(l logx.Logger, s ...string) *Err { + e := ResourceNotFound(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidResourceFormat returns Err +func InvalidResourceFormat(s ...string) *Err { + return newErr(Scope, code.InvalidResourceFormat, fmt.Sprintf("invalid resource format: %s", strings.Join(s, " "))) +} + +// InvalidResourceFormatL logs error message and returns Err +func InvalidResourceFormatL(l logx.Logger, s ...string) *Err { + e := InvalidResourceFormat(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidResourceState returns status not correct. +// for example: company should be destroy, agent should be no-sensor/fail-install ... +func InvalidResourceState(s ...string) *Err { + return newErr(Scope, code.InvalidResourceState, fmt.Sprintf("invalid resource state: %s", strings.Join(s, " "))) +} + +// InvalidResourceStateL logs error message and returns status not correct. +func InvalidResourceStateL(l logx.Logger, s ...string) *Err { + e := InvalidResourceState(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +func ResourceInsufficient(s ...string) *Err { + return newErr(Scope, code.ResourceInsufficient, + fmt.Sprintf("insufficient resource: %s", strings.Join(s, " "))) +} + +func ResourceInsufficientL(l logx.Logger, s ...string) *Err { + e := ResourceInsufficient(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InsufficientPermission returns Err +func InsufficientPermission(s ...string) *Err { + return newErr(Scope, code.InsufficientPermission, + fmt.Sprintf("insufficient permission: %s", strings.Join(s, " "))) +} + +// InsufficientPermissionL returns Err and log +func InsufficientPermissionL(l logx.Logger, s ...string) *Err { + e := InsufficientPermission(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// ResourceAlreadyExist returns Err +func ResourceAlreadyExist(s ...string) *Err { + return newErr(Scope, code.ResourceAlreadyExist, fmt.Sprintf("resource already exist: %s", strings.Join(s, " "))) +} + +// ResourceAlreadyExistL logs error message and returns Err +func ResourceAlreadyExistL(l logx.Logger, s ...string) *Err { + e := ResourceAlreadyExist(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidMeasurementID returns Err +func InvalidMeasurementID(s ...string) *Err { + return newErr(Scope, code.InvalidMeasurementID, fmt.Sprintf("missing measurement id: %s", strings.Join(s, " "))) +} + +// InvalidMeasurementIDL logs error message and returns Err +func InvalidMeasurementIDL(l logx.Logger, s ...string) *Err { + e := InvalidMeasurementID(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// ResourceExpired returns Err +func ResourceExpired(s ...string) *Err { + return newErr(Scope, code.ResourceExpired, fmt.Sprintf("resource expired: %s", strings.Join(s, " "))) +} + +// ResourceExpiredL logs error message and returns Err +func ResourceExpiredL(l logx.Logger, s ...string) *Err { + e := ResourceExpired(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// ResourceMigrated returns Err +func ResourceMigrated(s ...string) *Err { + return newErr(Scope, code.ResourceMigrated, fmt.Sprintf("resource migrated: %s", strings.Join(s, " "))) +} + +// ResourceMigratedL logs error message and returns Err +func ResourceMigratedL(l logx.Logger, s ...string) *Err { + e := ResourceMigrated(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InsufficientQuota returns Err +func InsufficientQuota(s ...string) *Err { + return newErr(Scope, code.InsufficientQuota, fmt.Sprintf("insufficient quota: %s", strings.Join(s, " "))) +} + +// InsufficientQuotaL logs error message and returns Err +func InsufficientQuotaL(l logx.Logger, s ...string) *Err { + e := InsufficientQuota(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +/*** CatAuth ***/ + +// Unauthorized returns Err +func Unauthorized(s ...string) *Err { + return newErr(Scope, code.Unauthorized, fmt.Sprintf("unauthorized: %s", strings.Join(s, " "))) +} + +// UnauthorizedL logs error message and returns Err +func UnauthorizedL(l logx.Logger, s ...string) *Err { + e := Unauthorized(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// AuthExpired returns Err +func AuthExpired(s ...string) *Err { + return newErr(Scope, code.AuthExpired, fmt.Sprintf("expired: %s", strings.Join(s, " "))) +} + +// AuthExpiredL logs error message and returns Err +func AuthExpiredL(l logx.Logger, s ...string) *Err { + e := AuthExpired(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// InvalidPosixTime returns Err +func InvalidPosixTime(s ...string) *Err { + return newErr(Scope, code.InvalidPosixTime, fmt.Sprintf("invalid posix time: %s", strings.Join(s, " "))) +} + +// InvalidPosixTimeL logs error message and returns Err +func InvalidPosixTimeL(l logx.Logger, s ...string) *Err { + e := InvalidPosixTime(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// SigAndPayloadNotMatched returns Err +func SigAndPayloadNotMatched(s ...string) *Err { + return newErr(Scope, code.SigAndPayloadNotMatched, fmt.Sprintf("signature and the payload are not match: %s", strings.Join(s, " "))) +} + +// SigAndPayloadNotMatchedL logs error message and returns Err +func SigAndPayloadNotMatchedL(l logx.Logger, s ...string) *Err { + e := SigAndPayloadNotMatched(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// Forbidden returns Err +func Forbidden(s ...string) *Err { + return newErr(Scope, code.Forbidden, fmt.Sprintf("forbidden: %s", strings.Join(s, " "))) +} + +// ForbiddenL logs error message and returns Err +func ForbiddenL(l logx.Logger, s ...string) *Err { + e := Forbidden(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// IsAuthUnauthorizedError check the err is unauthorized error +func IsAuthUnauthorizedError(err *Err) bool { + switch err.Code() { + case code.Unauthorized, code.AuthExpired, code.InvalidPosixTime, + code.SigAndPayloadNotMatched, code.Forbidden, + code.InvalidFormat, code.ResourceNotFound: + return true + default: + return false + } +} + +/*** CatXBC ***/ + +// ArkInternal returns Err +func ArkInternal(s ...string) *Err { + return newErr(Scope, code.ArkInternal, fmt.Sprintf("ark internal error: %s", strings.Join(s, " "))) +} + +// ArkInternalL logs error message and returns Err +func ArkInternalL(l logx.Logger, s ...string) *Err { + e := ArkInternal(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +/*** CatPubSub ***/ + +// Publish returns Err +func Publish(s ...string) *Err { + return newErr(Scope, code.Publish, fmt.Sprintf("publish: %s", strings.Join(s, " "))) +} + +// PublishL logs error message and returns Err +func PublishL(l logx.Logger, s ...string) *Err { + e := Publish(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} + +// Consume returns Err +func Consume(s ...string) *Err { + return newErr(Scope, code.Consume, fmt.Sprintf("consume: %s", strings.Join(s, " "))) +} + +// MsgSizeTooLarge returns Err +func MsgSizeTooLarge(s ...string) *Err { + return newErr(Scope, code.MsgSizeTooLarge, fmt.Sprintf("kafka error: %s", strings.Join(s, " "))) +} + +// MsgSizeTooLargeL logs error message and returns Err +func MsgSizeTooLargeL(l logx.Logger, s ...string) *Err { + e := MsgSizeTooLarge(s...) + l.WithCallerSkip(1).Error(e.Error()) + return e +} diff --git a/internal/lib/error/easy_func_test.go b/internal/lib/error/easy_func_test.go new file mode 100644 index 0000000..5ff951c --- /dev/null +++ b/internal/lib/error/easy_func_test.go @@ -0,0 +1,1031 @@ +package error + +import ( + "context" + "errors" + "fmt" + "member/internal/lib/error/code" + "reflect" + "strconv" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/logx" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestFromGRPCError_GivenStatusWithCodeAndMessage_ShouldReturnErr(t *testing.T) { + // setup + s := status.Error(codes.Code(102399), "FAKE ERROR") + + // act + e := FromGRPCError(s) + + // assert + assert.Equal(t, uint32(10), e.Scope()) + assert.Equal(t, uint32(2300), e.Category()) + assert.Equal(t, uint32(2399), e.Code()) + assert.Equal(t, "FAKE ERROR", e.Error()) +} + +func TestFromGRPCError_GivenNilError_ShouldReturnErr_Scope0_Cat0_Detail0(t *testing.T) { + // setup + var nilError error = nil + + // act + e := FromGRPCError(nilError) + + // assert + assert.Equal(t, uint32(0), e.Scope()) + assert.Equal(t, uint32(0), e.Category()) + assert.Equal(t, uint32(0), e.Code()) + assert.Equal(t, "", e.Error()) +} + +func TestFromGRPCError_GivenGRPCNativeError_ShouldReturnErr_Scope0_CatGRPC_DetailGRPCUnavailable(t *testing.T) { + // setup + msg := "GRPC Unavailable ERROR" + s := status.Error(codes.Code(codes.Unavailable), msg) + + // act + e := FromGRPCError(s) + + // assert + assert.Equal(t, code.Unset, e.Scope()) + assert.Equal(t, code.CatGRPC, e.Category()) + assert.Equal(t, uint32(codes.Unavailable), e.Code()) + assert.Equal(t, msg, e.Error()) +} + +func TestFromGRPCError_GivenGeneralError_ShouldReturnErr_Scope0_CatGRPC_DetailGRPCUnknown(t *testing.T) { + // setup + generalErr := errors.New("general error") + + // act + e := FromGRPCError(generalErr) + + // assert + assert.Equal(t, code.Unset, e.Scope()) + assert.Equal(t, code.CatGRPC, e.Category()) + assert.Equal(t, uint32(codes.Unknown), e.Code()) +} + +func TestToGRPCError_GivenErr_StatusShouldHave_Code112233(t *testing.T) { + // setup + e := Err{scope: 11, code: 2233, msg: "FAKE MSG"} + + // act + err := ToGRPCError(&e) + s, _ := status.FromError(err) + + // assert + assert.Equal(t, 112233, int(s.Code())) + assert.Equal(t, "FAKE MSG", s.Message()) +} + +func TestInvalidFormat_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidFormat("field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Equal(t, e.Error(), "invalid format: field A Error description") +} + +func TestInvalidFormatL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + // act + e := InvalidFormatL(logx.WithContext(ctx), "field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidRange_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidRange("field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidRange, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Equal(t, e.Error(), "invalid range: field A Error description") +} + +func TestInvalidRangeL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + // act + e := InvalidRangeL(logx.WithContext(ctx), "field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidRange, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestNotValidImplementation_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := NotValidImplementation("field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.NotValidImplementation, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Equal(t, e.Error(), "not valid implementation: field A Error description") +} + +func TestNotValidImplementationL_WithStrings_ShouldHasCatInputAndDetailCode(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + // act + e := NotValidImplementationL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.NotValidImplementation, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestDBError_WithStrings_ShouldHasCatDBAndDetailCodeDBError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := DBError("field A", "Error description") + + // assert + assert.Equal(t, code.CatDB, e.Category()) + assert.Equal(t, code.DBError, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestDBDataConvert_WithStrings_ShouldHasCatDBAndDetailCodeDBDataConvert(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := DBDataConvert("field A", "Error description") + + // assert + assert.Equal(t, code.CatDB, e.Category()) + assert.Equal(t, code.DBDataConvert, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceNotFound_WithStrings_ShouldHasCatResource_DetailCodeResourceNotFound(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceNotFound("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceNotFound, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidResourceFormat_WithStrings_ShouldHasCatResource_DetailCodeInvalidResourceFormat(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidResourceFormat("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidResourceFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidResourceState_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidResourceState("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidResourceState, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.EqualError(t, e, "invalid resource state: field A Error description") +} + +func TestInvalidResourceStateL_LogError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := InvalidResourceStateL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidResourceState, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.EqualError(t, e, "invalid resource state: field A Error description") +} + +func TestAuthExpired_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := AuthExpired("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.AuthExpired, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestUnauthorized_WithStrings_ShouldHasCatAuth_DetailCodeUnauthorized(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := Unauthorized("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.Unauthorized, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidPosixTime_WithStrings_ShouldHasCatAuth_DetailCodeInvalidPosixTime(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidPosixTime("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.InvalidPosixTime, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestSigAndPayloadNotMatched_WithStrings_ShouldHasCatAuth_DetailCodeSigAndPayloadNotMatched(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := SigAndPayloadNotMatched("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.SigAndPayloadNotMatched, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestForbidden_WithStrings_ShouldHasCatAuth_DetailCodeForbidden(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := Forbidden("field A", "Error description") + + // assert + assert.Equal(t, code.CatAuth, e.Category()) + assert.Equal(t, code.Forbidden, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestXBCInternal_WithStrings_ShouldHasCatResource_DetailCodeXBCInternal(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ArkInternal("field A", "Error description") + + // assert + assert.Equal(t, code.CatArk, e.Category()) + assert.Equal(t, code.ArkInternal, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestGeneralInternalError_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := SystemInternalError("field A", "Error description") + + // assert + assert.Equal(t, code.CatSystem, e.Category()) + assert.Equal(t, code.SystemInternalError, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestGeneralInternalErrorL_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := SystemInternalErrorL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatSystem, e.Category()) + assert.Equal(t, code.SystemInternalError, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestSystemMaintainError_WithStrings_DetailSystemMaintainError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := SystemMaintainErrorL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatSystem, e.Category()) + assert.Equal(t, code.SystemMaintainError, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceAlreadyExist_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceAlreadyExist("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceAlreadyExist, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceAlreadyExistL_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := ResourceAlreadyExistL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceAlreadyExist, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceInsufficient_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceInsufficient("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceInsufficient, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceInsufficientL_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := ResourceInsufficientL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceInsufficient, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInsufficientPermission_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InsufficientPermission("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InsufficientPermission, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInsufficientPermissionL_WithStrings_DetailInternalError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := InsufficientPermissionL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InsufficientPermission, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidMeasurementID_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InvalidMeasurementID("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidMeasurementID, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInvalidMeasurementIDL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := InvalidMeasurementIDL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InvalidMeasurementID, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceExpired_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceExpired("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceExpired, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceExpiredL_LogError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := ResourceExpiredL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceExpired, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceMigrated_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := ResourceMigrated("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceMigrated, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestResourceMigratedL_LogError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := ResourceMigratedL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.ResourceMigrated, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInsufficientQuota_OK(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := InsufficientQuota("field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InsufficientQuota, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestInsufficientQuotaL_LogError(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := InsufficientQuotaL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatResource, e.Category()) + assert.Equal(t, code.InsufficientQuota, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestPublish_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := Publish("field A", "Error description") + + // assert + assert.Equal(t, code.CatPubSub, e.Category()) + assert.Equal(t, code.Publish, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestPublishL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := PublishL(l, "field A", "Error description") + + // assert + assert.Equal(t, code.CatPubSub, e.Category()) + assert.Equal(t, code.Publish, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") +} + +func TestMsgSizeTooLarge_WithErrorStrings_ShouldReturnCorrectCodeAndErrorString(t *testing.T) { + // setup + Scope = 99 + defer func() { + Scope = code.Unset + }() + + // act + e := MsgSizeTooLarge("Error description") + + // assert + assert.Equal(t, code.CatPubSub, e.Category()) + assert.Equal(t, code.MsgSizeTooLarge, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "kafka error: Error description") +} + +func TestMsgSizeTooLargeL_WithErrorStrings_ShouldReturnCorrectCodeAndErrorStringAndCallLogger(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + l := logx.WithContext(context.Background()) + + // act + e := MsgSizeTooLargeL(l, "Error description") + + // assert + assert.Equal(t, code.CatPubSub, e.Category()) + assert.Equal(t, code.MsgSizeTooLarge, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "kafka error: Error description") +} + +func TestStructErr_WithInternalErr_ShouldIsFuncReportCorrectly(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + // arrange 2 layers err + layer1Err := fmt.Errorf("layer 1 error") + layer2Err := fmt.Errorf("layer 2: %w", layer1Err) + + // act with error chain: InvalidFormat -> layer 2 err -> layer 1 err + e := InvalidFormat("field A", "Error description") + err := e.Wrap(layer2Err) + if err != nil { + t.Fatalf("Failed to wrap error: %v", err) + } + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") + + // errors.Is should report correctly + assert.True(t, errors.Is(e, layer1Err)) + assert.True(t, errors.Is(e, layer2Err)) +} + +func TestStructErr_WithInternalErr_ShouldErrorOutputChainErrMessage(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + + // arrange 2 layers err + layer1Err := fmt.Errorf("layer 1 error") + // act with error chain: InvalidFormat -> layer 1 err + e := InvalidFormat("field A", "Error description") + err := e.Wrap(layer1Err) + if err != nil { + t.Fatalf("Failed to wrap error: %v", err) + } + + // assert + assert.Equal(t, "invalid format: field A Error description: layer 1 error", e.Error()) +} + +// arrange a specific err type just for UT +type testErr struct { + code int +} + +func (e *testErr) Error() string { + return strconv.Itoa(e.code) +} + +func TestStructErr_WithInternalErr_ShouldAsFuncReportCorrectly(t *testing.T) { + // setup + Scope = 99 + defer func() { Scope = code.Unset }() + + testE := &testErr{code: 123} + layer2Err := fmt.Errorf("layer 2: %w", testE) + + // act with error chain: InvalidFormat -> layer 2 err -> testErr + e := InvalidFormat("field A", "Error description") + err := e.Wrap(layer2Err) + if err != nil { + t.Fatalf("Failed to wrap error: %v", err) + } + + // assert + assert.Equal(t, code.CatInput, e.Category()) + assert.Equal(t, code.InvalidFormat, e.Code()) + assert.Equal(t, uint32(99), e.Scope()) + assert.Contains(t, e.Error(), "field A") + assert.Contains(t, e.Error(), "Error description") + + // errors.As should report correctly + var internalErr *testErr + assert.True(t, errors.As(e, &internalErr)) + assert.Equal(t, testE, internalErr) +} + +/* +benchmark run for 1 second: +Benchmark_ErrorsIs_OneLayerError-4 148281332 8.68 ns/op 0 B/op 0 allocs/op +Benchmark_ErrorsIs_TwoLayerError-4 35048202 32.4 ns/op 0 B/op 0 allocs/op +Benchmark_ErrorsIs_FourLayerError-4 15309349 81.7 ns/op 0 B/op 0 allocs/op + +Benchmark_ErrorsAs_OneLayerError-4 16893205 70.4 ns/op 0 B/op 0 allocs/op +Benchmark_ErrorsAs_TwoLayerError-4 10568083 112 ns/op 0 B/op 0 allocs/op +Benchmark_ErrorsAs_FourLayerError-4 6307729 188 ns/op 0 B/op 0 allocs/op +*/ +func Benchmark_ErrorsIs_OneLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + var err error = layer1Err + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errors.Is(err, layer1Err) + } +} + +func Benchmark_ErrorsIs_TwoLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + + // act with error chain: InvalidFormat(layer 2) -> testErr(layer 1) + layer2Err := InvalidFormat("field A", "Error description") + err := layer2Err.Wrap(layer1Err) + if err != nil { + b.Fatalf("Failed to wrap error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errors.Is(layer2Err, layer1Err) + } +} + +func Benchmark_ErrorsIs_FourLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + layer2Err := fmt.Errorf("layer 2: %w", layer1Err) + layer3Err := fmt.Errorf("layer 3: %w", layer2Err) + // act with error chain: InvalidFormat(layer 4) -> Error(layer 3) -> Error(layer 2) -> testErr(layer 1) + layer4Err := InvalidFormat("field A", "Error description") + err := layer4Err.Wrap(layer3Err) + if err != nil { + b.Fatalf("Failed to wrap error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errors.Is(layer4Err, layer1Err) + } +} + +func Benchmark_ErrorsAs_OneLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + var err error = layer1Err + + b.ReportAllocs() + b.ResetTimer() + var internalErr *testErr + for i := 0; i < b.N; i++ { + errors.As(err, &internalErr) + } +} + +func Benchmark_ErrorsAs_TwoLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + + // act with error chain: InvalidFormat(layer 2) -> testErr(layer 1) + layer2Err := InvalidFormat("field A", "Error description") + err := layer2Err.Wrap(layer1Err) + if err != nil { + b.Fatalf("Failed to wrap error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + var internalErr *testErr + for i := 0; i < b.N; i++ { + errors.As(layer2Err, &internalErr) + } +} + +func Benchmark_ErrorsAs_FourLayerError(b *testing.B) { + layer1Err := &testErr{code: 123} + layer2Err := fmt.Errorf("layer 2: %w", layer1Err) + layer3Err := fmt.Errorf("layer 3: %w", layer2Err) + // act with error chain: InvalidFormat(layer 4) -> Error(layer 3) -> Error(layer 2) -> testErr(layer 1) + layer4Err := InvalidFormat("field A", "Error description") + err := layer4Err.Wrap(layer3Err) + if err != nil { + b.Fatalf("Failed to wrap error: %v", err) + } + + b.ReportAllocs() + b.ResetTimer() + var internalErr *testErr + for i := 0; i < b.N; i++ { + errors.As(layer4Err, &internalErr) + } +} + +func TestFromError(t *testing.T) { + tests := []struct { + name string + givenError error + want *Err + }{ + { + "given nil error should return nil", + nil, + nil, + }, + { + "given normal error should return nil", + errors.New("normal error"), + nil, + }, + { + "given Err should return Err", + ResourceNotFound("fake error"), + ResourceNotFound("fake error"), + }, + { + "given error wraps Err should return Err", + fmt.Errorf("outter error wraps %w", ResourceNotFound("fake error")), + ResourceNotFound("fake error"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := FromError(tt.givenError); !reflect.DeepEqual(got, tt.want) { + t.Errorf("FromError() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/lib/error/errors.go b/internal/lib/error/errors.go new file mode 100644 index 0000000..fb16d5c --- /dev/null +++ b/internal/lib/error/errors.go @@ -0,0 +1,197 @@ +package error + +import ( + "ark-permission/internal/lib/error/code" + "errors" + "fmt" + "net/http" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// TODO Error要移到common 包 + +// Scope global variable should be set by service or module +var Scope = code.Unset + +type Err struct { + category uint32 + code uint32 + scope uint32 + msg string + internalErr error +} + +// Error is the interface of error +// Getter function of private property "msg" +func (e *Err) Error() string { + if e == nil { + return "" + } + + // chain the error string if the internal err exists + var internalErrStr string + if e.internalErr != nil { + internalErrStr = e.internalErr.Error() + } + + if e.msg != "" { + if internalErrStr != "" { + return fmt.Sprintf("%s: %s", e.msg, internalErrStr) + } + return e.msg + } + + generalErrStr := e.GeneralError() + if internalErrStr != "" { + return fmt.Sprintf("%s: %s", generalErrStr, internalErrStr) + } + return generalErrStr +} + +// Category getter function of private property "category" +func (e *Err) Category() uint32 { + if e == nil { + return 0 + } + return e.category +} + +// Scope getter function of private property "scope" +func (e *Err) Scope() uint32 { + if e == nil { + return code.Unset + } + + return e.scope +} + +// CodeStr returns the string of error code with zero padding +func (e *Err) CodeStr() string { + if e == nil { + return "00000" + } + + if e.Category() == code.CatGRPC { + return fmt.Sprintf("%d%04d", e.Scope(), e.Category()+e.Code()) + } + + return fmt.Sprintf("%d%04d", e.Scope(), e.Code()) +} + +// Code getter function of private property "code" +func (e *Err) Code() uint32 { + if e == nil { + return code.OK + } + + return e.code +} + +func (e *Err) FullCode() uint32 { + if e == nil { + return 0 + } + + if e.Category() == code.CatGRPC { + return e.Scope()*10000 + e.Category() + e.Code() + } + + return e.Scope()*10000 + e.Code() +} + +// HTTPStatus returns corresponding HTTP status code +func (e *Err) HTTPStatus() int { + if e == nil || e.Code() == code.OK { + return http.StatusOK + } + // determine status code by code + switch e.Code() { + case code.ResourceInsufficient: + // 400 + return http.StatusBadRequest + case code.Unauthorized, code.InsufficientPermission: + // 401 + return http.StatusUnauthorized + case code.InsufficientQuota: + // 402 + return http.StatusPaymentRequired + case code.InvalidPosixTime, code.Forbidden: + // 403 + return http.StatusForbidden + case code.ResourceNotFound: + // 404 + return http.StatusNotFound + case code.ResourceAlreadyExist, code.InvalidResourceState: + // 409 + return http.StatusConflict + case code.NotValidImplementation: + // 501 + return http.StatusNotImplemented + default: + } + + // determine status code by category + switch e.Category() { + case code.CatInput: + return http.StatusBadRequest + default: + // return status code 500 if none of the condition is met + return http.StatusInternalServerError + } +} + +// GeneralError transform category level error message +// It's the general error message for customer/API caller +func (e *Err) GeneralError() string { + if e == nil { + return "" + } + + errStr, ok := code.CatToStr[e.Category()] + if !ok { + return "" + } + + return errStr +} + +// Is called when performing errors.Is(). +// DO NOT USE THIS FUNCTION DIRECTLY unless you are very certain about what you're doing. +// Use errors.Is instead. +// This function compares if two error variables are both *Err, and have the same code (without checking the wrapped internal error) +func (e *Err) Is(f error) bool { + var err *Err + ok := errors.As(f, &err) + if !ok { + return false + } + return e.Code() == err.Code() +} + +// Unwrap returns the underlying error +// The result of unwrapping an error may itself have an Unwrap method; +// we call the sequence of errors produced by repeated unwrapping the error chain. +func (e *Err) Unwrap() error { + if e == nil { + return nil + } + return e.internalErr +} + +// Wrap sets the internal error to Err struct +func (e *Err) Wrap(internalErr error) *Err { + if e != nil { + e.internalErr = internalErr + } + return e +} + +func (e *Err) GRPCStatus() *status.Status { + if e == nil { + return status.New(codes.OK, "") + } + + return status.New(codes.Code(e.FullCode()), e.Error()) +} diff --git a/internal/lib/error/errors_test.go b/internal/lib/error/errors_test.go new file mode 100644 index 0000000..a0f5325 --- /dev/null +++ b/internal/lib/error/errors_test.go @@ -0,0 +1,297 @@ +package error + +import ( + "errors" + "fmt" + "member/internal/lib/error/code" + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestCode_GivenNilReceiver_CodeReturnOK_CodeStrReturns00000(t *testing.T) { + // setup + var e *Err = nil + + // act & assert + assert.Equal(t, code.OK, e.Code()) + assert.Equal(t, "00000", e.CodeStr()) + assert.Equal(t, "", e.Error()) +} + +func TestCode_GivenScope99DetailCode6687_ShouldReturn996687(t *testing.T) { + // setup + e := Err{scope: 99, code: 6687} + + // act & assert + assert.Equal(t, uint32(6687), e.Code()) + assert.Equal(t, "996687", e.CodeStr()) +} + +func TestCode_GivenScope0DetailCode87_ShouldReturn87(t *testing.T) { + // setup + e := Err{scope: 0, code: 87} + + // act & assert + assert.Equal(t, uint32(87), e.Code()) + assert.Equal(t, "00087", e.CodeStr()) +} + +func TestFromCode_Given870005_ShouldHasScope87_Cat0_Detail5(t *testing.T) { + // setup + e := FromCode(870005) + + // assert + assert.Equal(t, uint32(87), e.Scope()) + assert.Equal(t, uint32(0), e.Category()) + assert.Equal(t, uint32(5), e.Code()) + assert.Equal(t, "", e.Error()) +} + +func TestFromCode_Given0_ShouldHasScope0_Cat0_Detail0(t *testing.T) { + // setup + e := FromCode(0) + + // assert + assert.Equal(t, uint32(0), e.Scope()) + assert.Equal(t, uint32(0), e.Category()) + assert.Equal(t, uint32(0), e.Code()) + assert.Equal(t, "", e.Error()) +} + +func TestFromCode_Given9105_ShouldHasScope0_Cat9100_Detail9105(t *testing.T) { + // setup + e := FromCode(9105) + + // assert + assert.Equal(t, uint32(0), e.Scope()) + assert.Equal(t, uint32(9100), e.Category()) + assert.Equal(t, uint32(9105), e.Code()) + assert.Equal(t, "", e.Error()) +} + +func TestErr_ShouldImplementErrorFunction(t *testing.T) { + // setup a func return error + f := func() error { return InvalidFormat("fake field") } + + // act + err := f() + + // assert + assert.NotNil(t, err) + assert.Contains(t, fmt.Sprint(err), "fake field") // can be printed +} + +func TestGeneralError_GivenNilErr_ShouldReturnEmptyString(t *testing.T) { + // setup + var e *Err = nil + + // act & assert + assert.Equal(t, "", e.GeneralError()) +} + +func TestGeneralError_GivenNotExistCat_ShouldReturnEmptyString(t *testing.T) { + // setup + e := Err{category: 123456} + + // act & assert + assert.Equal(t, "", e.GeneralError()) +} + +func TestGeneralError_GivenCatDB_ShouldReturnDBError(t *testing.T) { + // setup + e := Err{category: code.CatDB} + catErrStr := code.CatToStr[code.CatDB] + + // act & assert + assert.Equal(t, catErrStr, e.GeneralError()) +} + +func TestError_GivenEmptyMsg_ShouldReturnCatGeneralErrorMessage(t *testing.T) { + // setup + e := Err{category: code.CatDB, msg: ""} + + // act + errMsg := e.Error() + + // assert + assert.Equal(t, code.CatToStr[code.CatDB], errMsg) +} + +func TestError_GivenMsg_ShouldReturnGiveMsg(t *testing.T) { + // setup + e := Err{msg: "FAKE"} + + // act + errMsg := e.Error() + + // assert + assert.Equal(t, "FAKE", errMsg) +} + +func TestIs_GivenNilErr_ShouldReturnFalse(t *testing.T) { + var nilErrs *Err + // act + result := errors.Is(nilErrs, DBError()) + result2 := errors.Is(DBError(), nilErrs) + + // assert + assert.False(t, result) + assert.False(t, result2) +} + +func TestIs_GivenNil_ShouldReturnFalse(t *testing.T) { + // act + result := errors.Is(nil, DBError()) + result2 := errors.Is(DBError(), nil) + + // assert + assert.False(t, result) + assert.False(t, result2) +} + +func TestIs_GivenNilReceiver_ShouldReturnCorrectResult(t *testing.T) { + var nilErr *Err = nil + + // test 1: nilErr != DBError + var dbErr error = DBError("fake db error") + assert.False(t, nilErr.Is(dbErr)) + + // test 2: nilErr != nil error + var nilError error + assert.False(t, nilErr.Is(nilError)) + + // test 3: nilErr == another nilErr + var nilErr2 *Err = nil + assert.True(t, nilErr.Is(nilErr2)) +} + +func TestIs_GivenDBError_ShouldReturnTrue(t *testing.T) { + // setup + dbErr := DBError("fake db error") + + // act + result := errors.Is(dbErr, DBError("not care")) + result2 := errors.Is(DBError(), dbErr) + + // assert + assert.True(t, result) + assert.True(t, result2) +} + +func TestIs_GivenDBErrorAssignToErrorType_ShouldReturnTrue(t *testing.T) { + // setup + var dbErr error = DBError("fake db error") + + // act + result := errors.Is(dbErr, DBError("not care")) + result2 := errors.Is(DBError(), dbErr) + + // assert + assert.True(t, result) + assert.True(t, result2) +} + +func TestWrap_GivenNilErr_ShouldNoPanic(t *testing.T) { + // act & assert + assert.NotPanics(t, func() { + var e *Err = nil + _ = e.Wrap(fmt.Errorf("test")) + }) +} + +func TestWrap_GivenErrorToWrap_ShouldReturnErrorWithWrappedError(t *testing.T) { + // act & assert + wrappedErr := fmt.Errorf("test") + wrappingErr := SystemInternalError("WrappingError").Wrap(wrappedErr) + unWrappedErr := wrappingErr.Unwrap() + + assert.Equal(t, wrappedErr, unWrappedErr) +} + +func TestUnwrap_GivenNilErr_ShouldReturnNil(t *testing.T) { + var e *Err = nil + internalErr := e.Unwrap() + assert.Nil(t, internalErr) +} + +func TestErrorsIs_GivenNilErr_ShouldReturnFalse(t *testing.T) { + var e *Err = nil + assert.False(t, errors.Is(e, fmt.Errorf("test"))) +} + +func TestErrorsAs_GivenNilErr_ShouldReturnFalse(t *testing.T) { + var internalErr *testErr + var e *Err = nil + assert.False(t, errors.As(e, &internalErr)) +} + +func TestGRPCStatus(t *testing.T) { + // setup table driven tests + tests := []struct { + name string + given *Err + expect *status.Status + expectConvert error + }{ + { + "nil errs.Err", + nil, + status.New(codes.OK, ""), + nil, + }, + { + "InvalidFormat Err", + InvalidFormat("fake"), + status.New(codes.Code(101), "invalid format: fake"), + status.New(codes.Code(101), "invalid format: fake").Err(), + }, + } + + // act & assert + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := test.given.GRPCStatus() + assert.Equal(t, test.expect.Code(), s.Code()) + assert.Equal(t, test.expect.Message(), s.Message()) + assert.Equal(t, test.expectConvert, status.Convert(test.given).Err()) + }) + } +} + +func TestErr_HTTPStatus(t *testing.T) { + tests := []struct { + name string + err *Err + want int + }{ + {name: "nil error", err: nil, want: http.StatusOK}, + {name: "invalid measurement id", err: &Err{category: code.CatResource, code: code.InvalidMeasurementID}, want: http.StatusInternalServerError}, + {name: "resource already exists", err: &Err{category: code.CatResource, code: code.ResourceAlreadyExist}, want: http.StatusConflict}, + {name: "invalid resource state", err: &Err{category: code.CatResource, code: code.InvalidResourceState}, want: http.StatusConflict}, + {name: "invalid posix time", err: &Err{category: code.CatAuth, code: code.InvalidPosixTime}, want: http.StatusForbidden}, + {name: "unauthorized", err: &Err{category: code.CatAuth, code: code.Unauthorized}, want: http.StatusUnauthorized}, + {name: "db error", err: &Err{category: code.CatDB, code: code.DBError}, want: http.StatusInternalServerError}, + {name: "insufficient permission", err: &Err{category: code.CatResource, code: code.InsufficientPermission}, want: http.StatusUnauthorized}, + {name: "resource insufficient", err: &Err{category: code.CatResource, code: code.ResourceInsufficient}, want: http.StatusBadRequest}, + {name: "invalid format", err: &Err{category: code.CatInput, code: code.InvalidFormat}, want: http.StatusBadRequest}, + {name: "resource not found", err: &Err{code: code.ResourceNotFound}, want: http.StatusNotFound}, + {name: "ok", err: &Err{code: code.OK}, want: http.StatusOK}, + {name: "not valid implementation", err: &Err{category: code.CatInput, code: code.NotValidImplementation}, want: http.StatusNotImplemented}, + {name: "forbidden", err: &Err{category: code.CatAuth, code: code.Forbidden}, want: http.StatusForbidden}, + {name: "insufficient quota", err: &Err{category: code.CatResource, code: code.InsufficientQuota}, want: http.StatusPaymentRequired}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + // act + got := tt.err.HTTPStatus() + + // assert + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/lib/middleware/with_context.go b/internal/lib/middleware/with_context.go new file mode 100644 index 0000000..df06757 --- /dev/null +++ b/internal/lib/middleware/with_context.go @@ -0,0 +1,28 @@ +package middleware + +import ( + ers "ark-permission/internal/lib/error" + "context" + "errors" + "time" + + "github.com/zeromicro/go-zero/core/logx" + "google.golang.org/grpc" +) + +const defaultTimeout = 30 * time.Second + +func TimeoutMiddleware(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { + + newCtx, cancelCtx := context.WithTimeout(ctx, defaultTimeout) + defer func() { + cancelCtx() + + if errors.Is(newCtx.Err(), context.DeadlineExceeded) { + err = ers.SystemTimeoutError(info.FullMethod) + logx.Errorf("Method: %s, request %v, timeout: %d", info.FullMethod, req, defaultTimeout) + } + }() + + return handler(ctx, req) +} diff --git a/internal/lib/required/validate.go b/internal/lib/required/validate.go new file mode 100644 index 0000000..6fe9a11 --- /dev/null +++ b/internal/lib/required/validate.go @@ -0,0 +1,51 @@ +package required + +import ( + "fmt" + + "github.com/zeromicro/go-zero/core/logx" + + "github.com/go-playground/validator/v10" +) + +type Validate interface { + ValidateAll(obj any) error + BindToValidator(opts ...Option) error +} + +type Validator struct { + V *validator.Validate +} + +// ValidateAll TODO 要移到common 包 +func (v *Validator) ValidateAll(obj any) error { + err := v.V.Struct(obj) + if err != nil { + return err + } + + return nil +} + +func (v *Validator) BindToValidator(opts ...Option) error { + for _, item := range opts { + err := v.V.RegisterValidation(item.ValidatorName, item.ValidatorFunc) + if err != nil { + return fmt.Errorf("failed to register validator : %w", err) + } + } + + return nil +} + +func MustValidator(option ...Option) Validate { + v := &Validator{ + V: validator.New(), + } + + if err := v.BindToValidator(option...); err != nil { + logx.Error("failed to bind validator") + } + + return v +} diff --git a/internal/lib/required/validate_option.go b/internal/lib/required/validate_option.go new file mode 100644 index 0000000..8a896cb --- /dev/null +++ b/internal/lib/required/validate_option.go @@ -0,0 +1,29 @@ +package required + +import ( + "regexp" + + "github.com/go-playground/validator/v10" +) + +type Option struct { + ValidatorName string + ValidatorFunc func(fl validator.FieldLevel) bool +} + +// WithAccount 創建一個新的 Option 結構,包含自定義的驗證函數,用於驗證 email 和台灣的手機號碼格式 +func WithAccount(tagName string) Option { + return Option{ + ValidatorName: tagName, + ValidatorFunc: func(fl validator.FieldLevel) bool { + value := fl.Field().String() + emailRegex := `^[a-z0-9._%+\-]+@[a-z0-9.\-]+\.[a-z]{2,}$` + phoneRegex := `^(\+886|0)?9\d{8}$` + + emailMatch, _ := regexp.MatchString(emailRegex, value) + phoneMatch, _ := regexp.MatchString(phoneRegex, value) + + return emailMatch || phoneMatch + }, + } +} diff --git a/internal/logic/cancel_token_by_device_i_d_logic.go b/internal/logic/cancel_token_by_device_id_logic.go similarity index 65% rename from internal/logic/cancel_token_by_device_i_d_logic.go rename to internal/logic/cancel_token_by_device_id_logic.go index 222a7bd..862a8d5 100644 --- a/internal/logic/cancel_token_by_device_i_d_logic.go +++ b/internal/logic/cancel_token_by_device_id_logic.go @@ -9,14 +9,14 @@ import ( "github.com/zeromicro/go-zero/core/logx" ) -type CancelTokenByDeviceIDLogic struct { +type CancelTokenByDeviceIdLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } -func NewCancelTokenByDeviceIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByDeviceIDLogic { - return &CancelTokenByDeviceIDLogic{ +func NewCancelTokenByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByDeviceIdLogic { + return &CancelTokenByDeviceIdLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), @@ -24,7 +24,7 @@ func NewCancelTokenByDeviceIDLogic(ctx context.Context, svcCtx *svc.ServiceConte } // CancelTokenByDeviceID 取消 Token -func (l *CancelTokenByDeviceIDLogic) CancelTokenByDeviceID(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { +func (l *CancelTokenByDeviceIdLogic) CancelTokenByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { // todo: add your logic here and delete this line return &permission.OKResp{}, nil diff --git a/internal/logic/cancel_token_by_u_i_d_logic.go b/internal/logic/cancel_token_by_uid_logic.go similarity index 70% rename from internal/logic/cancel_token_by_u_i_d_logic.go rename to internal/logic/cancel_token_by_uid_logic.go index 8d6a001..05b0566 100644 --- a/internal/logic/cancel_token_by_u_i_d_logic.go +++ b/internal/logic/cancel_token_by_uid_logic.go @@ -9,14 +9,14 @@ import ( "github.com/zeromicro/go-zero/core/logx" ) -type CancelTokenByUIDLogic struct { +type CancelTokenByUidLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } -func NewCancelTokenByUIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByUIDLogic { - return &CancelTokenByUIDLogic{ +func NewCancelTokenByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *CancelTokenByUidLogic { + return &CancelTokenByUidLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), @@ -24,7 +24,7 @@ func NewCancelTokenByUIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) * } // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (l *CancelTokenByUIDLogic) CancelTokenByUID(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { +func (l *CancelTokenByUidLogic) CancelTokenByUid(in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { // todo: add your logic here and delete this line return &permission.OKResp{}, nil diff --git a/internal/logic/claims.go b/internal/logic/claims.go new file mode 100644 index 0000000..937953d --- /dev/null +++ b/internal/logic/claims.go @@ -0,0 +1,55 @@ +package logic + +type claims map[string]string + +func (c claims) SetID(id string) { + c["id"] = id +} + +func (c claims) SetRole(role string) { + c["role"] = role +} + +func (c claims) SetDeviceID(deviceID string) { + c["device_id"] = deviceID +} + +func (c claims) SetScope(scope string) { + c["scope"] = scope +} + +func (c claims) Role() string { + role, ok := c["role"] + if !ok { + return "" + } + + return role +} + +func (c claims) ID() string { + id, ok := c["id"] + if !ok { + return "" + } + + return id +} + +func (c claims) DeviceID() string { + deviceID, ok := c["device_id"] + if !ok { + return "" + } + + return deviceID +} + +func (c claims) UID() string { + uid, ok := c["uid"] + if !ok { + return "" + } + + return uid +} diff --git a/internal/logic/get_user_tokens_by_device_i_d_logic.go b/internal/logic/get_user_tokens_by_device_id_logic.go similarity index 66% rename from internal/logic/get_user_tokens_by_device_i_d_logic.go rename to internal/logic/get_user_tokens_by_device_id_logic.go index ea6cd3d..d633995 100644 --- a/internal/logic/get_user_tokens_by_device_i_d_logic.go +++ b/internal/logic/get_user_tokens_by_device_id_logic.go @@ -9,14 +9,14 @@ import ( "github.com/zeromicro/go-zero/core/logx" ) -type GetUserTokensByDeviceIDLogic struct { +type GetUserTokensByDeviceIdLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } -func NewGetUserTokensByDeviceIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByDeviceIDLogic { - return &GetUserTokensByDeviceIDLogic{ +func NewGetUserTokensByDeviceIdLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByDeviceIdLogic { + return &GetUserTokensByDeviceIdLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), @@ -24,7 +24,7 @@ func NewGetUserTokensByDeviceIDLogic(ctx context.Context, svcCtx *svc.ServiceCon } // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens -func (l *GetUserTokensByDeviceIDLogic) GetUserTokensByDeviceID(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { +func (l *GetUserTokensByDeviceIdLogic) GetUserTokensByDeviceId(in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { // todo: add your logic here and delete this line return &permission.Tokens{}, nil diff --git a/internal/logic/get_user_tokens_by_u_i_d_logic.go b/internal/logic/get_user_tokens_by_uid_logic.go similarity index 67% rename from internal/logic/get_user_tokens_by_u_i_d_logic.go rename to internal/logic/get_user_tokens_by_uid_logic.go index 339eeb7..ee70ab1 100644 --- a/internal/logic/get_user_tokens_by_u_i_d_logic.go +++ b/internal/logic/get_user_tokens_by_uid_logic.go @@ -9,14 +9,14 @@ import ( "github.com/zeromicro/go-zero/core/logx" ) -type GetUserTokensByUIDLogic struct { +type GetUserTokensByUidLogic struct { ctx context.Context svcCtx *svc.ServiceContext logx.Logger } -func NewGetUserTokensByUIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByUIDLogic { - return &GetUserTokensByUIDLogic{ +func NewGetUserTokensByUidLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserTokensByUidLogic { + return &GetUserTokensByUidLogic{ ctx: ctx, svcCtx: svcCtx, Logger: logx.WithContext(ctx), @@ -24,7 +24,7 @@ func NewGetUserTokensByUIDLogic(ctx context.Context, svcCtx *svc.ServiceContext) } // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (l *GetUserTokensByUIDLogic) GetUserTokensByUID(in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { +func (l *GetUserTokensByUidLogic) GetUserTokensByUid(in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { // todo: add your logic here and delete this line return &permission.Tokens{}, nil diff --git a/internal/logic/new_token_logic.go b/internal/logic/new_token_logic.go index 180eca1..07fb080 100644 --- a/internal/logic/new_token_logic.go +++ b/internal/logic/new_token_logic.go @@ -1,10 +1,19 @@ package logic import ( - "context" - "ark-permission/gen_result/pb/permission" + "ark-permission/internal/domain" + "ark-permission/internal/entity" + ers "ark-permission/internal/lib/error" "ark-permission/internal/svc" + "bytes" + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "github.com/golang-jwt/jwt/v4" + "github.com/google/uuid" + "time" "github.com/zeromicro/go-zero/core/logx" ) @@ -23,9 +32,106 @@ func NewNewTokenLogic(ctx context.Context, svcCtx *svc.ServiceContext) *NewToken } } +// https://datatracker.ietf.org/doc/html/rfc6749#section-3.3 +type authorizationReq struct { + GrantType domain.GrantType `json:"grant_type" validate:"required,oneof=password client_credentials refresh_token"` + DeviceID string `json:"device_id"` + Scope string `json:"scope" validate:"required"` + Data map[string]any `json:"data"` + Expires int `json:"expires"` + IsRefreshToken bool `json:"is_refresh_token"` +} + // NewToken 建立一個新的 Token,例如:AccessToken func (l *NewTokenLogic) NewToken(in *permission.AuthorizationReq) (*permission.TokenResp, error) { - // todo: add your logic here and delete this line + // 驗證所需 + if err := l.svcCtx.Validate.ValidateAll(&authorizationReq{ + GrantType: domain.GrantType(in.GetGrantType()), + Scope: in.GetScope(), + }); err != nil { + return nil, ers.InvalidFormat(err.Error()) + } - return &permission.TokenResp{}, nil + // 準備建立 Token 所需 + now := time.Now().UTC() + expires := int(in.GetExpires()) + refreshExpires := int(in.GetExpires()) + if expires <= 0 { + expires = int(l.svcCtx.Config.Token.Expired.Seconds()) + refreshExpires = expires + } + + // 如果這是一個 Refresh Token 過期時間要比普通的Token 長 + if in.GetIsRefreshToken() { + refreshExpires = int(l.svcCtx.Config.Token.RefreshExpires.Seconds()) + } + + token := entity.Token{ + ID: uuid.Must(uuid.NewRandom()).String(), + DeviceID: in.GetDeviceId(), + ExpiresIn: expires, + RefreshExpiresIn: refreshExpires, + AccessCreateAt: now, + RefreshCreateAt: now, + } + + claims := claims(in.GetData()) + claims.SetRole(domain.DefaultRole) + claims.SetID(token.ID) + claims.SetScope(in.GetScope()) + + token.UID = claims.UID() + + if in.GetDeviceId() != "" { + claims.SetDeviceID(in.GetDeviceId()) + } + + var err error + token.AccessToken, err = generateAccessToken(token, claims, l.svcCtx.Config.Token.Secret) + if err != nil { + return nil, ers.ArkInternal(fmt.Errorf("accessGenerate token error: %w", err).Error()) + } + + if in.GetIsRefreshToken() { + token.RefreshToken = generateRefreshToken(token.AccessToken) + } + + err = l.svcCtx.TokenRedisRepo.Create(l.ctx, token) + if err != nil { + return nil, ers.ArkInternal(fmt.Errorf("tokenRepository.Create error: %w", err).Error()) + } + + return &permission.TokenResp{ + AccessToken: token.AccessToken, + TokenType: domain.TokenTypeBearer, + ExpiresIn: int32(token.ExpiresIn), + RefreshToken: token.RefreshToken, + }, nil +} + +func generateAccessToken(token entity.Token, data any, sign string) (string, error) { + claim := entity.Claims{ + Data: data, + RegisteredClaims: jwt.RegisteredClaims{ + ID: token.ID, + ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)), + Issuer: "permission", + }, + } + + accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claim). + SignedString([]byte(sign)) + if err != nil { + return "", err + } + + return accessToken, nil +} + +func generateRefreshToken(accessToken string) string { + buf := bytes.NewBufferString(accessToken) + h := sha256.New() + _, _ = h.Write(buf.Bytes()) + + return hex.EncodeToString(h.Sum(nil)) } diff --git a/internal/repository/token.go b/internal/repository/token.go new file mode 100644 index 0000000..2f00c71 --- /dev/null +++ b/internal/repository/token.go @@ -0,0 +1,136 @@ +package repository + +import ( + "ark-permission/internal/domain" + "ark-permission/internal/domain/repository" + "ark-permission/internal/entity" + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/zeromicro/go-zero/core/stores/redis" +) + +type TokenRepositoryParam struct { + Store *redis.Redis `name:"redis"` +} + +type tokenRepository struct { + store *redis.Redis +} + +func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository { + return &tokenRepository{ + store: param.Store, + } +} + +func (t *tokenRepository) Create(ctx context.Context, token entity.Token) error { + body, err := json.Marshal(token) + if err != nil { + return wrapError("json.Marshal token error", err) + } + + err = t.store.Pipelined(func(tx redis.Pipeliner) error { + rTTL := token.RefreshTokenExpires() + + if err := t.setToken(ctx, tx, token, body, rTTL); err != nil { + return err + } + + if err := t.setRefreshToken(ctx, tx, token, rTTL); err != nil { + return err + } + + if err := t.setDeviceToken(ctx, tx, token, rTTL); err != nil { + return err + } + + return nil + }) + if err != nil { + return wrapError("store.Pipelined error", err) + } + + if err := t.SetUIDToken(token); err != nil { + return wrapError("SetUIDToken error", err) + } + + return nil +} + +func (t *tokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, body []byte, rTTL time.Duration) error { + err := tx.Set(ctx, domain.GetAccessTokenRedisKey(token.ID), body, rTTL).Err() + if err != nil { + return wrapError("tx.Set GetAccessTokenRedisKey error", err) + } + return nil +} + +func (t *tokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { + if token.RefreshToken != "" { + err := tx.Set(ctx, domain.RefreshTokenRedisKey.With(token.RefreshToken).ToString(), token.ID, rTTL).Err() + if err != nil { + return wrapError("tx.Set RefreshToken error", err) + } + } + return nil +} + +func (t *tokenRepository) setDeviceToken(ctx context.Context, tx redis.Pipeliner, token entity.Token, rTTL time.Duration) error { + if token.DeviceID != "" { + key := domain.DeviceTokenRedisKey.With(token.UID).ToString() + value := fmt.Sprintf("%s-%d", token.ID, token.AccessCreateAt.Add(rTTL).Unix()) + err := tx.HSet(ctx, key, token.DeviceID, value).Err() + if err != nil { + return wrapError("tx.HSet Device Token error", err) + } + err = tx.Expire(ctx, key, rTTL).Err() + if err != nil { + return wrapError("tx.Expire Device Token error", err) + } + } + return nil +} + +// SetUIDToken 將 token 資料放進 uid key中 +func (t *tokenRepository) SetUIDToken(token entity.Token) error { + uidTokens := make(entity.UIDToken) + b, err := t.store.Get(domain.GetUIDTokenRedisKey(token.UID)) + if err != nil && !errors.Is(err, redis.Nil) { + return wrapError("t.store.Get GetUIDTokenRedisKey error", err) + } + + if b != "" { + err = json.Unmarshal([]byte(b), &uidTokens) + if err != nil { + return wrapError("json.Unmarshal GetUIDTokenRedisKey error", err) + } + } + + now := time.Now().Unix() + for k, t := range uidTokens { + if t < now { + delete(uidTokens, k) + } + } + + uidTokens[token.ID] = token.RefreshTokenExpiresUnix() + s, err := json.Marshal(uidTokens) + if err != nil { + return wrapError("json.Marshal UIDToken error", err) + } + + err = t.store.Setex(domain.GetUIDTokenRedisKey(token.UID), string(s), 86400*30) + if err != nil { + return wrapError("t.store.Setex GetUIDTokenRedisKey error", err) + } + + return nil +} + +func wrapError(message string, err error) error { + return fmt.Errorf("%s: %w", message, err) +} diff --git a/internal/server/token_service_server.go b/internal/server/token_service_server.go index 81e5554..5e11a23 100644 --- a/internal/server/token_service_server.go +++ b/internal/server/token_service_server.go @@ -41,15 +41,15 @@ func (s *TokenServiceServer) CancelToken(ctx context.Context, in *permission.Can } // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (s *TokenServiceServer) CancelTokenByUID(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { - l := logic.NewCancelTokenByUIDLogic(ctx, s.svcCtx) - return l.CancelTokenByUID(in) +func (s *TokenServiceServer) CancelTokenByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.OKResp, error) { + l := logic.NewCancelTokenByUidLogic(ctx, s.svcCtx) + return l.CancelTokenByUid(in) } // CancelTokenByDeviceID 取消 Token -func (s *TokenServiceServer) CancelTokenByDeviceID(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { - l := logic.NewCancelTokenByDeviceIDLogic(ctx, s.svcCtx) - return l.CancelTokenByDeviceID(in) +func (s *TokenServiceServer) CancelTokenByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.OKResp, error) { + l := logic.NewCancelTokenByDeviceIdLogic(ctx, s.svcCtx) + return l.CancelTokenByDeviceId(in) } // ValidationToken 驗證這個 Token 有沒有效 @@ -59,15 +59,15 @@ func (s *TokenServiceServer) ValidationToken(ctx context.Context, in *permission } // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens -func (s *TokenServiceServer) GetUserTokensByDeviceID(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { - l := logic.NewGetUserTokensByDeviceIDLogic(ctx, s.svcCtx) - return l.GetUserTokensByDeviceID(in) +func (s *TokenServiceServer) GetUserTokensByDeviceId(ctx context.Context, in *permission.DoTokenByDeviceIDReq) (*permission.Tokens, error) { + l := logic.NewGetUserTokensByDeviceIdLogic(ctx, s.svcCtx) + return l.GetUserTokensByDeviceId(in) } // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (s *TokenServiceServer) GetUserTokensByUID(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { - l := logic.NewGetUserTokensByUIDLogic(ctx, s.svcCtx) - return l.GetUserTokensByUID(in) +func (s *TokenServiceServer) GetUserTokensByUid(ctx context.Context, in *permission.DoTokenByUIDReq) (*permission.Tokens, error) { + l := logic.NewGetUserTokensByUidLogic(ctx, s.svcCtx) + return l.GetUserTokensByUid(in) } // NewOneTimeToken 建立一次性使用,例如:RefreshToken diff --git a/internal/svc/service_context.go b/internal/svc/service_context.go index 9541b6b..5a21250 100644 --- a/internal/svc/service_context.go +++ b/internal/svc/service_context.go @@ -1,13 +1,33 @@ package svc -import "ark-permission/internal/config" +import ( + "ark-permission/internal/config" + "ark-permission/internal/domain/repository" + "ark-permission/internal/lib/required" + repo "ark-permission/internal/repository" + "github.com/zeromicro/go-zero/core/stores/redis" +) type ServiceContext struct { Config config.Config + + Validate required.Validate + Redis redis.Redis + TokenRedisRepo repository.TokenRepository } func NewServiceContext(c config.Config) *ServiceContext { + newRedis, err := redis.NewRedis(c.RedisCluster, redis.Cluster()) + if err != nil { + panic(err) + } + return &ServiceContext{ - Config: c, + Config: c, + Validate: required.MustValidator(), + Redis: *newRedis, + TokenRedisRepo: repo.NewTokenRepository(repo.TokenRepositoryParam{ + Store: newRedis, + }), } } diff --git a/permission.go b/permission.go index 7309554..d7f2864 100644 --- a/permission.go +++ b/permission.go @@ -34,6 +34,9 @@ func main() { }) defer s.Stop() + // // 加入中間件 + // s.AddUnaryInterceptors(middleware.TimeoutMiddleware) + fmt.Printf("Starting rpc server at %s...\n", c.ListenOn) s.Start() } diff --git a/tokenservice/token_service.go b/tokenservice/token_service.go index 074dcad..28be4ee 100644 --- a/tokenservice/token_service.go +++ b/tokenservice/token_service.go @@ -36,15 +36,15 @@ type ( // CancelToken 取消 Token,也包含他裡面的 One Time Toke CancelToken(ctx context.Context, in *CancelTokenReq, opts ...grpc.CallOption) (*OKResp, error) // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke - CancelTokenByUID(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) + CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) // CancelTokenByDeviceID 取消 Token - CancelTokenByDeviceID(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) + CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) // ValidationToken 驗證這個 Token 有沒有效 ValidationToken(ctx context.Context, in *ValidationTokenReq, opts ...grpc.CallOption) (*ValidationTokenResp, error) // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens - GetUserTokensByDeviceID(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) + GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens - GetUserTokensByUID(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) + GetUserTokensByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) // NewOneTimeToken 建立一次性使用,例如:RefreshToken NewOneTimeToken(ctx context.Context, in *CreateOneTimeTokenReq, opts ...grpc.CallOption) (*CreateOneTimeTokenResp, error) // CancelOneTimeToken 取消一次性使用 @@ -81,15 +81,15 @@ func (m *defaultTokenService) CancelToken(ctx context.Context, in *CancelTokenRe } // CancelTokenByUID 取消 Token (取消這個用戶從不同 Device 登入的所有 Token),也包含他裡面的 One Time Toke -func (m *defaultTokenService) CancelTokenByUID(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { +func (m *defaultTokenService) CancelTokenByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*OKResp, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByUID(ctx, in, opts...) + return client.CancelTokenByUid(ctx, in, opts...) } // CancelTokenByDeviceID 取消 Token -func (m *defaultTokenService) CancelTokenByDeviceID(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { +func (m *defaultTokenService) CancelTokenByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*OKResp, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.CancelTokenByDeviceID(ctx, in, opts...) + return client.CancelTokenByDeviceId(ctx, in, opts...) } // ValidationToken 驗證這個 Token 有沒有效 @@ -99,15 +99,15 @@ func (m *defaultTokenService) ValidationToken(ctx context.Context, in *Validatio } // GetUserTokensByDeviceIDs 取得目前所對應的 DeviceID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByDeviceID(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { +func (m *defaultTokenService) GetUserTokensByDeviceId(ctx context.Context, in *DoTokenByDeviceIDReq, opts ...grpc.CallOption) (*Tokens, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.GetUserTokensByDeviceID(ctx, in, opts...) + return client.GetUserTokensByDeviceId(ctx, in, opts...) } // GetUserTokensByUID 取得目前所對應的 UID 所存在的 Tokens -func (m *defaultTokenService) GetUserTokensByUID(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { +func (m *defaultTokenService) GetUserTokensByUid(ctx context.Context, in *DoTokenByUIDReq, opts ...grpc.CallOption) (*Tokens, error) { client := permission.NewTokenServiceClient(m.cli.Conn()) - return client.GetUserTokensByUID(ctx, in, opts...) + return client.GetUserTokensByUid(ctx, in, opts...) } // NewOneTimeToken 建立一次性使用,例如:RefreshToken