From d71ffea7501157ef72c981a3e520909e0d153bda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=A7=E9=A9=8A?= Date: Wed, 22 Oct 2025 21:40:31 +0800 Subject: [PATCH] test push --- etc/gateway.yaml | 8 + generate/api/common.api | 11 +- generate/api/member.api | 29 +- .../mongo/2025101200000001_roles.down.txt | 4 + .../mongo/2025101200000001_roles.up.txt | 32 ++ internal/config/config.go | 32 +- internal/domain/redis.go | 19 + internal/handler/auth/login_handler.go | 36 +- .../handler/auth/refresh_token_handler.go | 32 +- internal/handler/auth/register_handler.go | 21 +- .../auth/request_password_reset_handler.go | 32 +- .../handler/auth/reset_password_handler.go | 32 +- .../verify_password_reset_code_handler.go | 32 +- .../handler/user/get_user_info_handler.go | 33 +- internal/logic/auth/generate_token.go | 5 +- internal/logic/auth/login_logic.go | 7 +- internal/logic/auth/register_logic.go | 12 +- .../auth/request_password_reset_logic.go | 113 +++++- internal/logic/auth/reset_password_logic.go | 66 ++- .../auth/verify_password_reset_code_logic.go | 16 +- internal/logic/user/get_user_info_logic.go | 93 ++++- internal/middleware/auth_middleware.go | 80 +++- internal/svc/service_context.go | 22 +- internal/svc/token.go | 102 +++++ internal/types/types.go | 35 +- internal/utils/format.go | 36 ++ pkg/library/errs/code/code.go | 19 +- pkg/library/errs/easy_func.go | 5 + pkg/library/errs/errors.go | 10 +- pkg/library/errs/errors_test.go | 12 +- pkg/member/domain/member/default.go | 32 ++ pkg/member/domain/repository/account_uid.go | 1 + pkg/member/domain/usecase/account.go | 2 + pkg/member/repository/account_uid.go | 14 + pkg/member/repository/user.go | 2 +- pkg/member/usecase/account.go | 13 + pkg/notification/config/config.go | 162 +++++++- pkg/notification/config/config_test.go | 314 +++++++++++++++ pkg/notification/domain/error.go | 1 + pkg/notification/domain/usecase/template.go | 14 + pkg/notification/repository/aws_ses_mailer.go | 91 ++--- .../repository/mitake_sms_sender.go | 42 +- pkg/notification/repository/smtp_mailer.go | 53 ++- pkg/notification/usecase/delivery.go | 247 ++++++------ pkg/notification/usecase/delivery_test.go | 377 ++++++++++++++++++ pkg/notification/usecase/template_test.go | 255 ++++++++++++ pkg/permission/domain/config/permission.go | 2 +- pkg/permission/domain/permission/role.go | 5 + pkg/permission/domain/token/context.go | 50 +++ pkg/permission/repository/permission.go | 2 +- pkg/permission/repository/permission_test.go | 2 +- pkg/permission/usecase/token.go | 24 +- pkg/permission/usecase/token_claims.go | 28 +- pkg/permission/usecase/token_claims_test.go | 27 +- pkg/permission/usecase/token_jwt.go | 7 +- pkg/permission/usecase/token_jwt_test.go | 18 +- 56 files changed, 2390 insertions(+), 381 deletions(-) create mode 100644 generate/database/mongo/2025101200000001_roles.down.txt create mode 100644 generate/database/mongo/2025101200000001_roles.up.txt create mode 100644 internal/domain/redis.go create mode 100644 internal/utils/format.go create mode 100644 pkg/notification/config/config_test.go create mode 100644 pkg/notification/usecase/delivery_test.go create mode 100644 pkg/notification/usecase/template_test.go create mode 100644 pkg/permission/domain/permission/role.go create mode 100644 pkg/permission/domain/token/context.go diff --git a/etc/gateway.yaml b/etc/gateway.yaml index 6b5e961..f09c57d 100644 --- a/etc/gateway.yaml +++ b/etc/gateway.yaml @@ -51,3 +51,11 @@ Token: OneTimeTokenExpiry : 600s MaxTokensPerUser : 2 MaxTokensPerDevice : 2 + + +RoleConfig: + UIDPrefix: "AM" + UIDLength: 6 + AdminRoleUID: "AM000000" + AdminUserUID: "B000000" + DefaultRoleName: "USER" \ No newline at end of file diff --git a/generate/api/common.api b/generate/api/common.api index fdfa841..0194bc0 100755 --- a/generate/api/common.api +++ b/generate/api/common.api @@ -3,11 +3,7 @@ syntax = "v1" // ================ 通用響應 ================ type ( // 成功響應 - RespOK { - Code int `json:"code"` - Msg string `json:"msg"` - Data interface{} `json:"data,omitempty"` - } + RespOK {} // 分頁響應 PagerResp { @@ -29,4 +25,9 @@ type ( Authorization { Authorization string `header:"Authorization" validate:"required"` } + Status { + Code int64 `json:"code"` // 狀態碼 + Message string `json:"message"` // 訊息 + Data interface{} `json:"data,omitempty"` // 可選的資料,當有返回時才出現 + } ) diff --git a/generate/api/member.api b/generate/api/member.api index 34b5742..1ce53bb 100644 --- a/generate/api/member.api +++ b/generate/api/member.api @@ -38,7 +38,7 @@ type ( // RequestPasswordResetReq 請求發送「忘記密碼」的驗證碼 RequestPasswordResetReq { - Identifier string `json:"identifier" validate:"required,email|phone"` // 使用者帳號 (信箱或手機) + Identifier string `json:"identifier" validate:"required"` // 使用者帳號 (信箱或手機) AccountType string `json:"account_type" validate:"required,oneof=email phone"` } @@ -141,6 +141,31 @@ type ( VerifyCode string `json:"verify_code" validate:"required,len=6"` Authorization } + + // MyInfo 用於獲取會員資訊的標準響應結構 + MyInfo { + Platform string `json:"platform"` // 註冊平台 + UID string `json:"uid"` // 用戶 UID + AvatarURL *string `json:"avatar_url,omitempty"` // 頭像 URL + FullName *string `json:"full_name,omitempty"` // 用戶全名 + Nickname *string `json:"nickname,omitempty"` // 暱稱 + GenderCode *string `json:"gender_code,omitempty"` // 性別代碼 + Birthdate *string `json:"birthdate,omitempty"` // 生日 (格式: 1993-04-17) + PhoneNumber *string `json:"phone_number,omitempty"` // 電話 + IsPhoneVerified *bool `json:"is_phone_verified,omitempty"` // 手機是否已驗證 + Email *string `json:"email,omitempty"` // 信箱 + IsEmailVerified *bool `json:"is_email_verified,omitempty"` // 信箱是否已驗證 + Address *string `json:"address,omitempty"` // 地址 + UserStatus string `json:"user_status,omitempty"` // 用戶狀態 + PreferredLanguage string `json:"preferred_language,omitempty"` // 偏好語言 + Currency string `json:"currency,omitempty"` // 偏好幣種 + AlarmCategory string `json:"alarm_category,omitempty"` // 告警狀態 + PostCode *string `json:"post_code,omitempty"` // 郵遞區號 + Carrier *string `json:"carrier,omitempty"` // 載具 + Role string `json:"role"` // 角色 + UpdateAt string `json:"update_at"` + CreateAt string `json:"create_at"` + } ) // ================================================================= @@ -251,7 +276,7 @@ service gateway { @respdoc-500 (ErrorResp) // 伺服器內部錯誤 */ @handler getUserInfo - get /me (Authorization) returns (UserInfoResp) + get /me (Authorization) returns (MyInfo) @doc( summary: "更新當前登入的會員資訊" diff --git a/generate/database/mongo/2025101200000001_roles.down.txt b/generate/database/mongo/2025101200000001_roles.down.txt new file mode 100644 index 0000000..fcd9846 --- /dev/null +++ b/generate/database/mongo/2025101200000001_roles.down.txt @@ -0,0 +1,4 @@ +db.role.deleteMany({ + "uid": { "$in": ["ADMIN", "OPERATOR", "USER"] } +}); + diff --git a/generate/database/mongo/2025101200000001_roles.up.txt b/generate/database/mongo/2025101200000001_roles.up.txt new file mode 100644 index 0000000..2e15e49 --- /dev/null +++ b/generate/database/mongo/2025101200000001_roles.up.txt @@ -0,0 +1,32 @@ +db.role.insertMany([ + { + "client_id": 1, + "uid": "ADMIN", + "name": "管理員", + "status": 1, + "create_time": NumberLong(1728745200), + "update_time": NumberLong(1728745200) + }, + { + "client_id": 1, + "uid": "OPERATOR", + "name": "操作員", + "status": 1, + "create_time": NumberLong(1728745200), + "update_time": NumberLong(1728745200) + }, + { + "client_id": 1, + "uid": "USER", + "name": "一般使用者", + "status": 1, + "create_time": NumberLong(1728745200), + "update_time": NumberLong(1728745200) + } +]); + +// 建立索引 +db.role.createIndex({ "uid": 1 }, { unique: true }); +db.role.createIndex({ "client_id": 1 }); +db.role.createIndex({ "status": 1 }); + diff --git a/internal/config/config.go b/internal/config/config.go index 8c21333..e52fb9f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -52,12 +52,30 @@ type Config struct { // JWT Token 配置 Token struct { - AccessSecret string - RefreshSecret string - AccessTokenExpiry time.Duration - RefreshTokenExpiry time.Duration - OneTimeTokenExpiry time.Duration - MaxTokensPerUser int - MaxTokensPerDevice int + AccessSecret string + RefreshSecret string + AccessTokenExpiry time.Duration + RefreshTokenExpiry time.Duration + OneTimeTokenExpiry time.Duration + MaxTokensPerUser int + MaxTokensPerDevice int + } + + // RoleConfig 角色配置 + RoleConfig struct { + // UID 前綴 (例如: AM, RL) + UIDPrefix string + + // UID 數字長度 + UIDLength int + + // 管理員角色 UID + AdminRoleUID string + + // 管理員用戶 UID + AdminUserUID string + + // 預設角色名稱 + DefaultRoleName string } } diff --git a/internal/domain/redis.go b/internal/domain/redis.go new file mode 100644 index 0000000..67b9c8c --- /dev/null +++ b/internal/domain/redis.go @@ -0,0 +1,19 @@ +package domain + +import "strings" + +type RedisKey string + +const ( + GenerateVerifyCodeRedisKey RedisKey = "rf_code" +) + +func (key RedisKey) ToString() string { + return string(key) +} + +func (key RedisKey) With(s ...string) RedisKey { + parts := append([]string{string(key)}, s...) + + return RedisKey(strings.Join(parts, ":")) +} diff --git a/internal/handler/auth/login_handler.go b/internal/handler/auth/login_handler.go index 3d048f1..bf9f198 100644 --- a/internal/handler/auth/login_handler.go +++ b/internal/handler/auth/login_handler.go @@ -6,7 +6,6 @@ import ( "backend/internal/svc" "backend/internal/types" "backend/pkg/library/errs" - ers "backend/pkg/library/errs" "net/http" "github.com/zeromicro/go-zero/rest/httpx" @@ -18,38 +17,39 @@ func LoginHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { var req types.LoginReq if err := httpx.Parse(r, &req); err != nil { e := errs.InvalidFormat(err.Error()) - httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{ - Code: int(e.FullCode()), - Msg: err.Error(), + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), }) return } - //if err := svcCtx.Validate.ValidateAll(req); err != nil { - // e := errs.InvalidFormat(err.Error()) - // httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{ - // Code: int(e.FullCode()), - // Msg: err.Error(), - // }) - // - // return - //} + if err := svcCtx.Validate.ValidateAll(req); err != nil { + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + + return + } l := auth.NewLoginLogic(r.Context(), svcCtx) resp, err := l.Login(&req) + if err != nil { - e := ers.FromError(err) + e := errs.FromError(err) httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{ Code: int(e.FullCode()), Msg: e.Error(), Error: e, }) } else { - httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.RespOK{ - Code: domain.SuccessCode, - Msg: domain.SuccessMessage, - Data: resp, + httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{ + Code: domain.SuccessCode, + Message: domain.SuccessMessage, + Data: resp, }) } } diff --git a/internal/handler/auth/refresh_token_handler.go b/internal/handler/auth/refresh_token_handler.go index e7f5854..8758288 100644 --- a/internal/handler/auth/refresh_token_handler.go +++ b/internal/handler/auth/refresh_token_handler.go @@ -1,6 +1,8 @@ package auth import ( + "backend/internal/domain" + "backend/pkg/library/errs" "net/http" "backend/internal/logic/auth" @@ -15,16 +17,40 @@ func RefreshTokenHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.RefreshTokenReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + + return + } + + if err := svcCtx.Validate.ValidateAll(req); err != nil { + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + return } l := auth.NewRefreshTokenLogic(r.Context(), svcCtx) resp, err := l.RefreshToken(&req) if err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.FromError(err) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{ + Code: int(e.FullCode()), + Msg: e.Error(), + Error: e, + }) } else { - httpx.OkJsonCtx(r.Context(), w, resp) + httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{ + Code: domain.SuccessCode, + Message: domain.SuccessMessage, + Data: resp, + }) } } } diff --git a/internal/handler/auth/register_handler.go b/internal/handler/auth/register_handler.go index d2b93ae..7ec6a92 100644 --- a/internal/handler/auth/register_handler.go +++ b/internal/handler/auth/register_handler.go @@ -12,15 +12,14 @@ import ( "github.com/zeromicro/go-zero/rest/httpx" ) -// 註冊新帳號 func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.LoginReq if err := httpx.Parse(r, &req); err != nil { e := errs.InvalidFormat(err.Error()) - httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{ - Code: int(e.FullCode()), - Msg: err.Error(), + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), }) return @@ -28,9 +27,9 @@ func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { if err := svcCtx.Validate.ValidateAll(req); err != nil { e := errs.InvalidFormat(err.Error()) - httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{ - Code: int(e.FullCode()), - Msg: err.Error(), + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), }) return @@ -46,10 +45,10 @@ func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { Error: e, }) } else { - httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.RespOK{ - Code: domain.SuccessCode, - Msg: domain.SuccessMessage, - Data: resp, + httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{ + Code: domain.SuccessCode, + Message: domain.SuccessMessage, + Data: resp, }) } } diff --git a/internal/handler/auth/request_password_reset_handler.go b/internal/handler/auth/request_password_reset_handler.go index 902ac75..68ee9f0 100644 --- a/internal/handler/auth/request_password_reset_handler.go +++ b/internal/handler/auth/request_password_reset_handler.go @@ -1,6 +1,8 @@ package auth import ( + "backend/internal/domain" + "backend/pkg/library/errs" "net/http" "backend/internal/logic/auth" @@ -15,16 +17,40 @@ func RequestPasswordResetHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.RequestPasswordResetReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + + return + } + + if err := svcCtx.Validate.ValidateAll(req); err != nil { + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + return } l := auth.NewRequestPasswordResetLogic(r.Context(), svcCtx) resp, err := l.RequestPasswordReset(&req) if err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.FromError(err) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{ + Code: int(e.FullCode()), + Msg: e.Error(), + Error: e, + }) } else { - httpx.OkJsonCtx(r.Context(), w, resp) + httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{ + Code: domain.SuccessCode, + Message: domain.SuccessMessage, + Data: resp, + }) } } } diff --git a/internal/handler/auth/reset_password_handler.go b/internal/handler/auth/reset_password_handler.go index 6100b43..670a788 100644 --- a/internal/handler/auth/reset_password_handler.go +++ b/internal/handler/auth/reset_password_handler.go @@ -1,6 +1,8 @@ package auth import ( + "backend/internal/domain" + "backend/pkg/library/errs" "net/http" "backend/internal/logic/auth" @@ -15,16 +17,40 @@ func ResetPasswordHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.ResetPasswordReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + + return + } + + if err := svcCtx.Validate.ValidateAll(req); err != nil { + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + return } l := auth.NewResetPasswordLogic(r.Context(), svcCtx) resp, err := l.ResetPassword(&req) if err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.FromError(err) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{ + Code: int(e.FullCode()), + Msg: e.Error(), + Error: e, + }) } else { - httpx.OkJsonCtx(r.Context(), w, resp) + httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{ + Code: domain.SuccessCode, + Message: domain.SuccessMessage, + Data: resp, + }) } } } diff --git a/internal/handler/auth/verify_password_reset_code_handler.go b/internal/handler/auth/verify_password_reset_code_handler.go index 534388a..8946d36 100644 --- a/internal/handler/auth/verify_password_reset_code_handler.go +++ b/internal/handler/auth/verify_password_reset_code_handler.go @@ -1,6 +1,8 @@ package auth import ( + "backend/internal/domain" + "backend/pkg/library/errs" "net/http" "backend/internal/logic/auth" @@ -15,16 +17,40 @@ func VerifyPasswordResetCodeHandler(svcCtx *svc.ServiceContext) http.HandlerFunc return func(w http.ResponseWriter, r *http.Request) { var req types.VerifyCodeReq if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + + return + } + + if err := svcCtx.Validate.ValidateAll(req); err != nil { + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + return } l := auth.NewVerifyPasswordResetCodeLogic(r.Context(), svcCtx) resp, err := l.VerifyPasswordResetCode(&req) if err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.FromError(err) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{ + Code: int(e.FullCode()), + Msg: e.Error(), + Error: e, + }) } else { - httpx.OkJsonCtx(r.Context(), w, resp) + httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{ + Code: domain.SuccessCode, + Message: domain.SuccessMessage, + Data: resp, + }) } } } diff --git a/internal/handler/user/get_user_info_handler.go b/internal/handler/user/get_user_info_handler.go index d0e0b44..6dde980 100644 --- a/internal/handler/user/get_user_info_handler.go +++ b/internal/handler/user/get_user_info_handler.go @@ -1,12 +1,13 @@ package user import ( + "backend/internal/domain" + "backend/pkg/library/errs" "net/http" "backend/internal/logic/user" "backend/internal/svc" "backend/internal/types" - "github.com/zeromicro/go-zero/rest/httpx" ) @@ -15,16 +16,40 @@ func GetUserInfoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var req types.Authorization if err := httpx.Parse(r, &req); err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + + return + } + + if err := svcCtx.Validate.ValidateAll(req); err != nil { + e := errs.InvalidFormat(err.Error()) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{ + Code: int64(e.FullCode()), + Message: err.Error(), + }) + return } l := user.NewGetUserInfoLogic(r.Context(), svcCtx) resp, err := l.GetUserInfo(&req) if err != nil { - httpx.ErrorCtx(r.Context(), w, err) + e := errs.FromError(err) + httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{ + Code: int(e.FullCode()), + Msg: e.Error(), + Error: e, + }) } else { - httpx.OkJsonCtx(r.Context(), w, resp) + httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{ + Code: domain.SuccessCode, + Message: domain.SuccessMessage, + Data: resp, + }) } } } diff --git a/internal/logic/auth/generate_token.go b/internal/logic/auth/generate_token.go index a3cfea0..7c5a81c 100644 --- a/internal/logic/auth/generate_token.go +++ b/internal/logic/auth/generate_token.go @@ -10,10 +10,7 @@ import ( ) // 生成 Token -func generateToken(svc *svc.ServiceContext, ctx context.Context, req *types.LoginReq, uid string) (entity.TokenResp, error) { - // scope role 要修改,refresh tl - role := "user" - +func generateToken(svc *svc.ServiceContext, ctx context.Context, req *types.LoginReq, uid string, role string) (entity.TokenResp, error) { tk, err := svc.TokenUC.NewToken(ctx, entity.AuthorizationReq{ GrantType: token.ClientCredentials.ToString(), DeviceID: uid, // TODO 沒傳暫時先用UID 替代 diff --git a/internal/logic/auth/login_logic.go b/internal/logic/auth/login_logic.go index e0c421b..469c950 100644 --- a/internal/logic/auth/login_logic.go +++ b/internal/logic/auth/login_logic.go @@ -79,7 +79,12 @@ func (l *LoginLogic) Login(req *types.LoginReq) (resp *types.LoginResp, err erro return nil, err } - tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID) + userRole, err := l.svcCtx.UserRoleUC.Get(l.ctx, account.UID) + if err != nil { + return nil, err + } + + tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID, userRole.RoleUID) if err != nil { return nil, err } diff --git a/internal/logic/auth/register_logic.go b/internal/logic/auth/register_logic.go index 16f0d21..94ebca3 100644 --- a/internal/logic/auth/register_logic.go +++ b/internal/logic/auth/register_logic.go @@ -7,6 +7,7 @@ import ( "backend/pkg/library/errs/code" mb "backend/pkg/member/domain/member" member "backend/pkg/member/domain/usecase" + "backend/pkg/permission/domain/usecase" "context" "google.golang.org/protobuf/proto" @@ -76,9 +77,18 @@ func (l *RegisterLogic) Register(req *types.LoginReq) (resp *types.LoginResp, er return nil, err } + _, err = l.svcCtx.UserRoleUC.Assign(l.ctx, usecase.AssignRoleRequest{ + RoleUID: l.svcCtx.Config.RoleConfig.DefaultRoleName, + UserUID: account.UID, + Brand: "digimon", + }) + if err != nil { + return nil, err + } + // Step 5: 生成 Token req.LoginID = bd.CreateAccountReq.LoginID - tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID) + tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID, l.svcCtx.Config.RoleConfig.DefaultRoleName) if err != nil { return nil, err } diff --git a/internal/logic/auth/request_password_reset_logic.go b/internal/logic/auth/request_password_reset_logic.go index 9de0048..81a2a1e 100644 --- a/internal/logic/auth/request_password_reset_logic.go +++ b/internal/logic/auth/request_password_reset_logic.go @@ -1,7 +1,14 @@ package auth import ( + "backend/internal/domain" + "backend/internal/utils" + "backend/pkg/library/errs" + "backend/pkg/library/errs/code" + "backend/pkg/member/domain/member" + "backend/pkg/member/domain/usecase" "context" + "fmt" "backend/internal/svc" "backend/internal/types" @@ -25,7 +32,109 @@ func NewRequestPasswordResetLogic(ctx context.Context, svcCtx *svc.ServiceContex // RequestPasswordReset 請求發送密碼重設驗證碼 aka 忘記密碼 func (l *RequestPasswordResetLogic) RequestPasswordReset(req *types.RequestPasswordResetReq) (resp *types.RespOK, err error) { - // todo: add your logic here and delete this line + // 驗證並標準化帳號 + acc, err := l.validateAndNormalizeAccount(req.AccountType, req.Identifier) + if err != nil { + return nil, err + } - return + // 檢查發送冷卻時間 + rk := domain.GenerateVerifyCodeRedisKey.With(fmt.Sprintf("%s:%d", acc, member.GenerateCodeTypeForgetPassword)).ToString() + if err := l.checkVerifyCodeCooldown(rk); err != nil { + return nil, err + } + + // 確認帳號是否註冊並檢查平台限制 + if err := l.checkAccountAndPlatform(acc); err != nil { + return nil, err + } + + // 生成驗證碼 + vcode, err := l.svcCtx.AccountUC.GenerateRefreshCode(l.ctx, usecase.GenerateRefreshCodeRequest{ + LoginID: acc, + CodeType: member.GenerateCodeTypeForgetPassword, + }) + if err != nil { + return nil, err + } + + // 獲取用戶資訊並確認綁定帳號 + account, err := l.svcCtx.AccountUC.GetUIDByAccount(l.ctx, usecase.GetUIDByAccountRequest{Account: acc}) + if err != nil { + return nil, errs.ResourceNotFoundWithScope(code.CloudEPMember, 0, fmt.Sprintf("account not found:%s", acc)) + } + info, err := l.svcCtx.AccountUC.GetUserInfo(l.ctx, usecase.GetUserInfoRequest{UID: account.UID}) + if err != nil { + return nil, err + } + + // 發送驗證碼 + fmt.Println("======= send", vcode.Data.VerifyCode, &info) + + //nickname := getEmailShowName(&info) + //if err := l.sendVerificationCode(req.AccountType, acc, &info, vcode.Data.VerifyCode, nickname); err != nil { + // return nil, err + //} + + // 設置 Redis 鍵 + l.setRedisKeyWithExpiry(rk, vcode.Data.VerifyCode, 60) + + return &types.RespOK{}, nil +} + +// validateAndNormalizeAccount 驗證並標準化帳號 +func (l *RequestPasswordResetLogic) validateAndNormalizeAccount(accountType, account string) (string, error) { + switch member.GetAccountTypeByCode(accountType) { + case member.AccountTypePhone: + phone, isPhone := utils.NormalizeTaiwanMobile(account) + if !isPhone { + return "", errs.InvalidFormatWithScope(code.CloudEPMember, "phone number is invalid") + } + + return phone, nil + case member.AccountTypeMail: + if !utils.IsValidEmail(account) { + return "", errs.InvalidFormatWithScope(code.CloudEPMember, "email is invalid") + } + + return account, nil + case member.AccountTypeNone, member.AccountTypeDefine: + default: + } + + return "", errs.InvalidFormatWithScope(code.CloudEPMember, "unsupported account type") +} + +// checkVerifyCodeCooldown 檢查是否已在限制時間內發送過驗證碼 +func (l *RequestPasswordResetLogic) checkVerifyCodeCooldown(rk string) error { + if cachedCode, err := l.svcCtx.Redis.GetCtx(l.ctx, rk); err != nil || cachedCode != "" { + return errs.TooManyWithScope(code.CloudEPMember, "verification code already sent, please wait 3min for system to send again") + } + + return nil +} + +// checkAccountAndPlatform 檢查帳號是否註冊及平台限制 +func (l *RequestPasswordResetLogic) checkAccountAndPlatform(acc string) error { + accountInfo, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{Account: acc}) + if err != nil { + return err + } + + if accountInfo.Data.Platform != member.Digimon { + return errs.InvalidFormatWithScope(code.CloudEPMember, + "failed to send verify code since platform not correct") + } + + return nil +} + +// setRedisKeyWithExpiry 設置 Redis 鍵 +func (l *RequestPasswordResetLogic) setRedisKeyWithExpiry(rk, verifyCode string, expiry int) { + if status, err := l.svcCtx.Redis.SetnxExCtx(l.ctx, rk, verifyCode, expiry); err != nil || !status { + _ = errs.DatabaseErrorWithScopeL(code.CloudEPMember, 0, logx.WithContext(l.ctx), []logx.LogField{ + {Key: "redisKey", Value: rk}, + {Key: "error", Value: err.Error()}, + }, "failed to set redis expire").Wrap(err) + } } diff --git a/internal/logic/auth/reset_password_logic.go b/internal/logic/auth/reset_password_logic.go index 3eca800..86bb438 100644 --- a/internal/logic/auth/reset_password_logic.go +++ b/internal/logic/auth/reset_password_logic.go @@ -1,7 +1,15 @@ package auth import ( + "backend/internal/domain" + "backend/pkg/library/errs" + "backend/pkg/library/errs/code" + "backend/pkg/member/domain/member" + "backend/pkg/member/domain/usecase" + "backend/pkg/permission/domain/entity" + "context" + "fmt" "backend/internal/svc" "backend/internal/types" @@ -15,7 +23,7 @@ type ResetPasswordLogic struct { svcCtx *svc.ServiceContext } -// 執行密碼重設 +// NewResetPasswordLogic 執行密碼重設 func NewResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ResetPasswordLogic { return &ResetPasswordLogic{ Logger: logx.WithContext(ctx), @@ -24,8 +32,58 @@ func NewResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Res } } -func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordReq) (resp *types.RespOK, err error) { - // todo: add your logic here and delete this line +func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordReq) (*types.RespOK, error) { + // 驗證密碼,兩次密碼要一致 + if req.Password != req.PasswordConfirm { + return nil, errs.InvalidFormatWithScope(code.CloudEPMember, "password confirmation does not match") + } - return + // 驗證碼 + err := l.svcCtx.AccountUC.VerifyRefreshCode(l.ctx, usecase.VerifyRefreshCodeRequest{ + LoginID: req.Identifier, + CodeType: member.GenerateCodeTypeForgetPassword, + VerifyCode: req.VerifyCode, + }) + if err != nil { + // 表使沒有這驗證碼 + return nil, errs.ForbiddenWithScope(code.CloudEPMember, 0, "failed to get verify code") + } + + info, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{Account: req.Identifier}) + if err != nil { + return nil, err + } + + if info.Data.Platform != member.Digimon { + return nil, errs.ForbiddenWithScope(code.CloudEPMember, 0, "invalid platform") + } + + // 更新 + err = l.svcCtx.AccountUC.UpdateUserToken(l.ctx, usecase.UpdateTokenRequest{ + Account: req.Identifier, + Token: req.Password, + Platform: member.Digimon.ToInt64(), + }) + if err != nil { + return nil, err + } + + rk := domain.GenerateVerifyCodeRedisKey.With( + fmt.Sprintf("%s-%d", req.Identifier, member.GenerateCodeTypeForgetPassword), + ).ToString() + + _, _ = l.svcCtx.Redis.Del(rk) + + ac, err := l.svcCtx.AccountUC.GetUIDByAccount(l.ctx, usecase.GetUIDByAccountRequest{Account: req.Identifier}) + if err != nil { + return nil, err + } + + err = l.svcCtx.TokenUC.CancelTokens(l.ctx, entity.DoTokenByUIDReq{UID: ac.UID}) + if err != nil { + return nil, err + } + + // 返回成功響應 + return &types.RespOK{}, nil } diff --git a/internal/logic/auth/verify_password_reset_code_logic.go b/internal/logic/auth/verify_password_reset_code_logic.go index fe5f6a9..9980057 100644 --- a/internal/logic/auth/verify_password_reset_code_logic.go +++ b/internal/logic/auth/verify_password_reset_code_logic.go @@ -1,6 +1,9 @@ package auth import ( + "backend/pkg/library/errs" + "backend/pkg/member/domain/member" + "backend/pkg/member/domain/usecase" "context" "backend/internal/svc" @@ -25,7 +28,16 @@ func NewVerifyPasswordResetCodeLogic(ctx context.Context, svcCtx *svc.ServiceCon // VerifyPasswordResetCode 校驗密碼重設驗證碼(頁面需求,預先檢查看看, 顯示表演用) func (l *VerifyPasswordResetCodeLogic) VerifyPasswordResetCode(req *types.VerifyCodeReq) (resp *types.RespOK, err error) { - // todo: add your logic here and delete this line + // 先驗證,不刪除 + if err := l.svcCtx.AccountUC.CheckRefreshCode(l.ctx, usecase.VerifyRefreshCodeRequest{ + VerifyCode: req.VerifyCode, + LoginID: req.Identifier, + CodeType: member.GenerateCodeTypeForgetPassword, + }); err != nil { + e := errs.Forbidden("failed to get verify code").Wrap(err) - return + return nil, e + } + + return &types.RespOK{}, nil } diff --git a/internal/logic/user/get_user_info_logic.go b/internal/logic/user/get_user_info_logic.go index 5c945cd..4b1ac5e 100644 --- a/internal/logic/user/get_user_info_logic.go +++ b/internal/logic/user/get_user_info_logic.go @@ -1,7 +1,12 @@ package user import ( + "backend/pkg/member/domain/member" + "backend/pkg/member/domain/usecase" + "backend/pkg/permission/domain/token" "context" + "google.golang.org/protobuf/proto" + "time" "backend/internal/svc" "backend/internal/types" @@ -15,7 +20,7 @@ type GetUserInfoLogic struct { svcCtx *svc.ServiceContext } -// 取得當前登入的會員資訊(自己) +// NewGetUserInfoLogic 取得當前登入的會員資訊(自己) func NewGetUserInfoLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserInfoLogic { return &GetUserInfoLogic{ Logger: logx.WithContext(ctx), @@ -24,8 +29,88 @@ func NewGetUserInfoLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUs } } -func (l *GetUserInfoLogic) GetUserInfo(req *types.Authorization) (resp *types.UserInfoResp, err error) { - // todo: add your logic here and delete this line +func (l *GetUserInfoLogic) GetUserInfo(req *types.Authorization) (*types.MyInfo, error) { + uid := token.UID(l.ctx) + info, err := l.svcCtx.AccountUC.GetUserInfo(l.ctx, usecase.GetUserInfoRequest{ + UID: uid, + }) + if err != nil { + return nil, err + } - return + byUID, err := l.svcCtx.AccountUC.FindLoginIDByUID(l.ctx, uid) + if err != nil { + return nil, err + } + + accountInfo, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{ + Account: byUID.LoginID, + }) + if err != nil { + return nil, err + } + + userRole, err := l.svcCtx.UserRoleUC.Get(l.ctx, uid) + if err != nil { + return nil, err + } + role := userRole.RoleUID + res := &types.MyInfo{ + Platform: accountInfo.Data.Platform.ToString(), + UID: info.UID, + UpdateAt: time.Unix(0, info.CreateTime).UTC().Format(time.RFC3339), + CreateAt: time.Unix(0, info.UpdateTime).UTC().Format(time.RFC3339), + Role: role, + UserStatus: info.UserStatus.CodeToString(), + PreferredLanguage: info.PreferredLanguage, + Currency: info.Currency, + AlarmCategory: info.AlarmCategory.CodeToString(), + } + if info.Address != nil { + res.Address = info.Address + } + if info.AvatarURL != nil { + res.AvatarURL = info.AvatarURL + } + if info.FullName != nil { + res.FullName = info.FullName + } + + if info.Birthdate != nil { + b := ToDate(info.Birthdate) + res.Birthdate = b + } + + if info.Address != nil { + res.Address = info.Address + } + + if info.Nickname != nil { + res.Nickname = info.Nickname + } + + if info.Email != nil { + res.Email = info.Email + res.IsEmailVerified = proto.Bool(true) + } + + if info.PhoneNumber != nil { + res.PhoneNumber = info.PhoneNumber + res.IsPhoneVerified = proto.Bool(true) + } + if info.GenderCode != nil { + gc := member.GetGenderByCode(*info.GenderCode) + res.GenderCode = &gc + } + + return res, nil +} + +func ToDate(n *int64) *string { + result := "" + if n != nil { + result = time.Unix(*n, 0).UTC().Format(time.DateOnly) + } + + return &result } diff --git a/internal/middleware/auth_middleware.go b/internal/middleware/auth_middleware.go index e9ed61e..64157f5 100644 --- a/internal/middleware/auth_middleware.go +++ b/internal/middleware/auth_middleware.go @@ -1,19 +1,85 @@ package middleware -import "net/http" +import ( + "backend/internal/types" + "backend/pkg/library/errs" + "backend/pkg/permission/domain/entity" + "backend/pkg/permission/domain/token" + "context" + "github.com/zeromicro/go-zero/rest/httpx" -type AuthMiddleware struct { + "backend/pkg/permission/domain/usecase" + uc "backend/pkg/permission/usecase" + "net/http" +) + +type AuthMiddlewareParam struct { + TokenSec string + TokenUseCase usecase.TokenUseCase } -func NewAuthMiddleware() *AuthMiddleware { - return &AuthMiddleware{} +type AuthMiddleware struct { + TokenSec string + TokenUseCase usecase.TokenUseCase +} + +func NewAuthMiddleware(param AuthMiddlewareParam) *AuthMiddleware { + return &AuthMiddleware{ + TokenSec: param.TokenSec, + TokenUseCase: param.TokenUseCase, + } } func (m *AuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - // TODO generate middleware implement function, delete after code implementation + // 解析 Header + header := types.Authorization{} + if err := httpx.ParseHeaders(r, &header); err != nil { + m.writeErrorResponse(w, r, http.StatusBadRequest, "Failed to parse headers", int64(errs.InvalidFormat("").FullCode())) - // Passthrough to next handler if need - next(w, r) + return + } + + // 驗證 Token + claim, err := uc.ParseClaims(header.Authorization, m.TokenSec, true) + if err != nil { + // 是否需要紀錄錯誤,是不是只要紀錄除了驗證失敗或過期之外的真錯誤 + m.writeErrorResponse(w, r, + http.StatusUnauthorized, "failed to verify toke", + int64(100400)) + + return + } + + // 驗證 Token 是否在黑名單中 + if _, err := m.TokenUseCase.ValidationToken(r.Context(), entity.ValidationTokenReq{Token: header.Authorization}); err != nil { + m.writeErrorResponse(w, r, http.StatusForbidden, + "failed to get toke", + int64(100400)) + + return + } + + // 設置 context 並傳遞給下一個處理器 + ctx := SetContext(r, claim) + next(w, r.WithContext(ctx)) } } + +func SetContext(r *http.Request, claim uc.TokenClaims) context.Context { + ctx := context.WithValue(r.Context(), token.KeyRole, claim.Role()) + ctx = context.WithValue(ctx, token.KeyUID, claim.UID()) + ctx = context.WithValue(ctx, token.KeyDeviceID, claim.DeviceID()) + ctx = context.WithValue(ctx, token.KeyScope, claim.Scope()) + ctx = context.WithValue(ctx, token.KeyLoginID, claim.LoginID()) + + return ctx +} + +// writeErrorResponse 用於處理錯誤回應 +func (m *AuthMiddleware) writeErrorResponse(w http.ResponseWriter, r *http.Request, statusCode int, message string, code int64) { + httpx.WriteJsonCtx(r.Context(), w, statusCode, types.ErrorResp{ + Code: int(code), + Msg: message, + }) +} diff --git a/internal/svc/service_context.go b/internal/svc/service_context.go index b80c20a..3c11f77 100644 --- a/internal/svc/service_context.go +++ b/internal/svc/service_context.go @@ -19,6 +19,11 @@ type ServiceContext struct { AccountUC memberUC.AccountUseCase Validate vi.Validate TokenUC tokenUC.TokenUseCase + PermissionUC tokenUC.PermissionUseCase + RoleUC tokenUC.RoleUseCase + RolePermission tokenUC.RolePermissionUseCase + UserRoleUC tokenUC.UserRoleUseCase + Redis *redis.Redis } func NewServiceContext(c config.Config) *ServiceContext { @@ -28,11 +33,22 @@ func NewServiceContext(c config.Config) *ServiceContext { } errs.Scope = code.CloudEPPortalGW + rp := NewPermissionUC(&c) + tkUC := NewTokenUC(&c, rds) + return &ServiceContext{ - Config: c, - AuthMiddleware: middleware.NewAuthMiddleware().Handle, + Config: c, + AuthMiddleware: middleware.NewAuthMiddleware(middleware.AuthMiddlewareParam{ + TokenSec: c.Token.AccessSecret, + TokenUseCase: tkUC, + }).Handle, AccountUC: NewAccountUC(&c, rds), Validate: vi.MustValidator(), - TokenUC: NewTokenUC(&c, rds), + TokenUC: tkUC, + PermissionUC: rp.PermissionUC, + RoleUC: rp.RoleUC, + RolePermission: rp.RolePermission, + UserRoleUC: rp.UserRole, + Redis: rds, } } diff --git a/internal/svc/token.go b/internal/svc/token.go index c1aaf8a..b000170 100644 --- a/internal/svc/token.go +++ b/internal/svc/token.go @@ -2,9 +2,12 @@ package svc import ( "backend/internal/config" + mgo "backend/pkg/library/mongo" "backend/pkg/permission/domain/usecase" "backend/pkg/permission/repository" uc "backend/pkg/permission/usecase" + "github.com/zeromicro/go-zero/core/stores/cache" + "github.com/zeromicro/go-zero/core/stores/mon" "github.com/zeromicro/go-zero/core/stores/redis" ) @@ -16,3 +19,102 @@ func NewTokenUC(c *config.Config, rds *redis.Redis) usecase.TokenUseCase { Config: c, }) } + +type PermissionUC struct { + PermissionUC usecase.PermissionUseCase + RoleUC usecase.RoleUseCase + RolePermission usecase.RolePermissionUseCase + UserRole usecase.UserRoleUseCase +} + +func NewPermissionUC(c *config.Config) PermissionUC { + // 準備Mongo Config + conf := &mgo.Conf{ + Schema: c.Mongo.Schema, + Host: c.Mongo.Host, + Database: c.Mongo.Database, + MaxStaleness: c.Mongo.MaxStaleness, + MaxPoolSize: c.Mongo.MaxPoolSize, + MinPoolSize: c.Mongo.MinPoolSize, + MaxConnIdleTime: c.Mongo.MaxConnIdleTime, + Compressors: c.Mongo.Compressors, + EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode, + ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs, + } + if c.Mongo.User != "" { + conf.User = c.Mongo.User + conf.Password = c.Mongo.Password + } + + // 快取選項 + cacheOpts := []cache.Option{ + cache.WithExpiry(c.CacheExpireTime), + cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry), + } + dbOpts := []mon.Option{ + mgo.SetCustomDecimalType(), + mgo.InitMongoOptions(*conf), + } + permRepo := repository.NewPermissionRepository(repository.PermissionRepositoryParam{ + Conf: conf, + CacheConf: c.Cache, + CacheOpts: cacheOpts, + DBOpts: dbOpts, + }) + + rolePermRepo := repository.NewRolePermissionRepository(repository.RolePermissionRepositoryParam{ + Conf: conf, + CacheConf: c.Cache, + CacheOpts: cacheOpts, + DBOpts: dbOpts, + }) + + roleRepo := repository.NewRoleRepository(repository.RoleRepositoryParam{ + Conf: conf, + CacheConf: c.Cache, + CacheOpts: cacheOpts, + DBOpts: dbOpts, + }) + + userRoleRepo := repository.NewUserRoleRepository(repository.UserRoleRepositoryParam{ + Conf: conf, + CacheConf: c.Cache, + CacheOpts: cacheOpts, + DBOpts: dbOpts, + }) + + puc := uc.NewPermissionUseCase(uc.PermissionUseCaseParam{ + RoleRepo: roleRepo, + RolePermRepo: rolePermRepo, + UserRoleRepo: userRoleRepo, + PermRepo: permRepo, + }) + rpuc := uc.NewRolePermissionUseCase(uc.RolePermissionUseCaseParam{ + RoleRepo: roleRepo, + RolePermRepo: rolePermRepo, + UserRoleRepo: userRoleRepo, + PermRepo: permRepo, + PermUseCase: puc, + AdminRoleUID: c.RoleConfig.AdminRoleUID, + }) + ruc := uc.NewRoleUseCase(uc.RoleUseCaseParam{ + RoleRepo: roleRepo, + UserRoleRepo: userRoleRepo, + Config: uc.RoleUseCaseConfig{ + AdminRoleUID: c.RoleConfig.AdminRoleUID, + UIDPrefix: c.RoleConfig.UIDPrefix, + UIDLength: c.RoleConfig.UIDLength, + }, + RolePermUseCase: rpuc, + }) + + return PermissionUC{ + PermissionUC: puc, + RolePermission: rpuc, + RoleUC: ruc, + UserRole: uc.NewUserRoleUseCase(uc.UserRoleUseCaseParam{ + UserRoleRepo: userRoleRepo, + RoleRepo: roleRepo, + }), + } +} diff --git a/internal/types/types.go b/internal/types/types.go index b2e9bc8..5f25016 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -37,6 +37,30 @@ type LoginResp struct { TokenType string `json:"token_type"` // 通常固定為 "Bearer" } +type MyInfo struct { + Platform string `json:"platform"` // 註冊平台 + UID string `json:"uid"` // 用戶 UID + AvatarURL *string `json:"avatar_url,omitempty"` // 頭像 URL + FullName *string `json:"full_name,omitempty"` // 用戶全名 + Nickname *string `json:"nickname,omitempty"` // 暱稱 + GenderCode *string `json:"gender_code,omitempty"` // 性別代碼 + Birthdate *string `json:"birthdate,omitempty"` // 生日 (格式: 1993-04-17) + PhoneNumber *string `json:"phone_number,omitempty"` // 電話 + IsPhoneVerified *bool `json:"is_phone_verified,omitempty"` // 手機是否已驗證 + Email *string `json:"email,omitempty"` // 信箱 + IsEmailVerified *bool `json:"is_email_verified,omitempty"` // 信箱是否已驗證 + Address *string `json:"address,omitempty"` // 地址 + UserStatus string `json:"user_status,omitempty"` // 用戶狀態 + PreferredLanguage string `json:"preferred_language,omitempty"` // 偏好語言 + Currency string `json:"currency,omitempty"` // 偏好幣種 + AlarmCategory string `json:"alarm_category,omitempty"` // 告警狀態 + PostCode *string `json:"post_code,omitempty"` // 郵遞區號 + Carrier *string `json:"carrier,omitempty"` // 載具 + Role string `json:"role"` // 角色 + UpdateAt string `json:"update_at"` + CreateAt string `json:"create_at"` +} + type PagerResp struct { Total int64 `json:"total"` Size int64 `json:"size"` @@ -60,7 +84,7 @@ type RefreshTokenResp struct { } type RequestPasswordResetReq struct { - Identifier string `json:"identifier" validate:"required,email|phone"` // 使用者帳號 (信箱或手機) + Identifier string `json:"identifier" validate:"required"` // 使用者帳號 (信箱或手機) AccountType string `json:"account_type" validate:"required,oneof=email phone"` } @@ -77,9 +101,12 @@ type ResetPasswordReq struct { } type RespOK struct { - Code int `json:"code"` - Msg string `json:"msg"` - Data interface{} `json:"data,omitempty"` +} + +type Status struct { + Code int64 `json:"code"` // 狀態碼 + Message string `json:"message"` // 訊息 + Data interface{} `json:"data,omitempty"` // 可選的資料,當有返回時才出現 } type SubmitVerificationCodeReq struct { diff --git a/internal/utils/format.go b/internal/utils/format.go new file mode 100644 index 0000000..5c14300 --- /dev/null +++ b/internal/utils/format.go @@ -0,0 +1,36 @@ +package utils + +import ( + "regexp" + "strings" +) + +// NormalizeTaiwanMobile 標準化號碼並驗證是否為合法台灣手機號碼 +func NormalizeTaiwanMobile(phone string) (string, bool) { + // 移除空格 + phone = strings.ReplaceAll(phone, " ", "") + + // 移除 "+886" 並將剩餘部分標準化 + if strings.HasPrefix(phone, "+886") { + phone = strings.TrimPrefix(phone, "+886") + if !strings.HasPrefix(phone, "0") { + phone = "0" + phone + } + } + + // 正則表達式驗證標準化後的號碼 + regex := regexp.MustCompile(`^(09\d{8})$`) + if regex.MatchString(phone) { + return phone, true + } + + return "", false +} + +// IsValidEmail 驗證 Email 格式的函數 +func IsValidEmail(email string) bool { + // 定義正則表達式 + regex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`) + + return regex.MatchString(email) +} diff --git a/pkg/library/errs/code/code.go b/pkg/library/errs/code/code.go index 28f28b5..eae1400 100644 --- a/pkg/library/errs/code/code.go +++ b/pkg/library/errs/code/code.go @@ -35,6 +35,7 @@ const ( InsufficientQuota // 配額不足 ResourceHasMultiOwner // 資源有多個所有者 UserSuspended // 沒有權限使用該資源 + TooManyRequest // 單位時間內請求太多次 ) /* 詳細代碼 - GRPC */ @@ -77,13 +78,13 @@ const ( // 詳細代碼 - Token 類 09x const ( - _ = iota + CatToken - TokenCreateError // Token 創建錯誤 - TokenValidateError // Token 驗證錯誤 - TokenExpired // Token 過期 - TokenNotFound // Token 未找到 - TokenBlacklisted // Token 已被列入黑名單 - InvalidJWT // 無效的 JWT - RefreshTokenError // Refresh Token 錯誤 - OneTimeTokenError // 一次性 Token 錯誤 + _ = iota + CatToken + TokenCreateError // Token 創建錯誤 + TokenValidateError // Token 驗證錯誤 + TokenExpired // Token 過期 + TokenNotFound // Token 未找到 + TokenBlacklisted // Token 已被列入黑名單 + InvalidJWT // 無效的 JWT + RefreshTokenError // Refresh Token 錯誤 + OneTimeTokenError // 一次性 Token 錯誤 ) diff --git a/pkg/library/errs/easy_func.go b/pkg/library/errs/easy_func.go index dd53784..e299e11 100644 --- a/pkg/library/errs/easy_func.go +++ b/pkg/library/errs/easy_func.go @@ -556,3 +556,8 @@ func MsgSizeTooLargeL(l logx.Logger, filed []logx.LogField, s ...string) *LibErr return e } + +func TooManyWithScope(scope uint32, s ...string) *LibError { + return NewError(scope, code.TooManyRequest, defaultDetailCode, + fmt.Sprintf("%s", strings.Join(s, " "))) +} diff --git a/pkg/library/errs/errors.go b/pkg/library/errs/errors.go index e65cebe..95cf8c2 100644 --- a/pkg/library/errs/errors.go +++ b/pkg/library/errs/errors.go @@ -154,8 +154,13 @@ func (e *LibError) HTTPStatus() int { if e == nil || e.Code() == code.OK { return http.StatusOK } + + // 將 code 轉換為與常量定義相同的格式 (category + detail) + // 例如:code=3004 -> (3004%100) + 30 = 4 + 30 = 34 + codeValue := (e.Code() % 100) + e.Category() + // 根據錯誤碼判斷對應的 HTTP 狀態碼 - switch e.Code() / 100 { + switch codeValue { case code.ResourceInsufficient, code.InvalidFormat: // 如果資源不足,返回 400 狀態碼 return http.StatusBadRequest @@ -177,6 +182,9 @@ func (e *LibError) HTTPStatus() int { case code.NotValidImplementation: // 如果實現無效,返回 501 狀態碼 return http.StatusNotImplemented + case code.TooManyRequest: + // 如果實現無效,返回 501 狀態碼 + return http.StatusTooManyRequests default: // 如果沒有匹配的錯誤碼,則繼續下一步 } diff --git a/pkg/library/errs/errors_test.go b/pkg/library/errs/errors_test.go index f690ad6..dbbff07 100644 --- a/pkg/library/errs/errors_test.go +++ b/pkg/library/errs/errors_test.go @@ -71,12 +71,12 @@ func TestLibError_HTTPStatus(t *testing.T) { err *LibError expected int }{ - {"bad request", NewError(1, code.CatService, code.ResourceInsufficient, "bad request"), http.StatusBadRequest}, - {"unauthorized", NewError(1, code.CatAuth, code.Unauthorized, "unauthorized"), http.StatusUnauthorized}, - {"forbidden", NewError(1, code.CatAuth, code.Forbidden, "forbidden"), http.StatusForbidden}, - {"not found", NewError(1, code.CatResource, code.ResourceNotFound, "not found"), http.StatusNotFound}, - {"internal server error", NewError(1, code.CatDB, 1095, "not found"), http.StatusInternalServerError}, - {"input err", NewError(1, code.CatInput, 1095, "not found"), http.StatusBadRequest}, + {"bad request - ResourceInsufficient", NewError(1, code.CatResource, 4, "bad request"), http.StatusBadRequest}, + {"unauthorized", NewError(1, code.CatAuth, 1, "unauthorized"), http.StatusUnauthorized}, + {"forbidden", NewError(1, code.CatAuth, 5, "forbidden"), http.StatusForbidden}, + {"not found", NewError(1, code.CatResource, 1, "not found"), http.StatusNotFound}, + {"internal server error", NewError(1, code.CatDB, 95, "db error"), http.StatusInternalServerError}, + {"input err", NewError(1, code.CatInput, 1, "input error"), http.StatusBadRequest}, } for _, tt := range tests { diff --git a/pkg/member/domain/member/default.go b/pkg/member/domain/member/default.go index 458bafb..920ed2e 100644 --- a/pkg/member/domain/member/default.go +++ b/pkg/member/domain/member/default.go @@ -19,3 +19,35 @@ const ( const ( CurrencyTWD Currency = "TWD" ) + +var genderMap = map[int64]string{ + 0: "", + 1: "male", + 2: "female", + 3: "secret", +} + +func GetGenderByCode(g int64) string { + r, ok := genderMap[g] + if !ok { + return genderMap[0] + } + + return r +} + +var genderCodeMap = map[string]int64{ + "": 0, + "male": 1, + "female": 2, + "secret": 3, +} + +func GetGenderCodeByStr(g string) int64 { + r, ok := genderCodeMap[g] + if !ok { + return genderCodeMap[""] + } + + return r +} diff --git a/pkg/member/domain/repository/account_uid.go b/pkg/member/domain/repository/account_uid.go index 9f099f2..cd165f6 100644 --- a/pkg/member/domain/repository/account_uid.go +++ b/pkg/member/domain/repository/account_uid.go @@ -14,6 +14,7 @@ type AccountUIDRepository interface { Update(ctx context.Context, data *entity.AccountUID) (*mongo.UpdateResult, error) Delete(ctx context.Context, id string) (int64, error) FindUIDByLoginID(ctx context.Context, loginID string) (*entity.AccountUID, error) + FindOneByUID(ctx context.Context, uid string) (*entity.AccountUID, error) AccountUIDIndexUP } diff --git a/pkg/member/domain/usecase/account.go b/pkg/member/domain/usecase/account.go index c5f56dd..34f1a6f 100644 --- a/pkg/member/domain/usecase/account.go +++ b/pkg/member/domain/usecase/account.go @@ -31,6 +31,8 @@ type MemberUseCase interface { GetUserInfo(ctx context.Context, req GetUserInfoRequest) (UserInfo, error) // ListMember 取得會員列表 ListMember(ctx context.Context, req ListUserInfoRequest) (ListUserInfoResponse, error) + // FindLoginIDByUID 取得login id + FindLoginIDByUID(ctx context.Context, uid string) (BindingUser, error) } type BindingMemberUseCase interface { diff --git a/pkg/member/repository/account_uid.go b/pkg/member/repository/account_uid.go index d70cbe4..321cf05 100644 --- a/pkg/member/repository/account_uid.go +++ b/pkg/member/repository/account_uid.go @@ -112,6 +112,20 @@ func (repo *AccountUIDRepository) FindUIDByLoginID(ctx context.Context, loginID } } +func (repo *AccountUIDRepository) FindOneByUID(ctx context.Context, uid string) (*entity.AccountUID, error) { + var data entity.AccountUID + + err := repo.DB.GetClient().FindOne(ctx, &data, bson.M{"uid": uid}) + switch { + case err == nil: + return &data, nil + case errors.Is(err, mon.ErrNotFound): + return nil, ErrNotFound + default: + return nil, err + } +} + func (repo *AccountUIDRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) { // 等價於 db.account_uid_binding.createIndex({"login_id": 1}, {unique: true}) repo.DB.PopulateIndex(ctx, "login_id", 1, true) diff --git a/pkg/member/repository/user.go b/pkg/member/repository/user.go index 2b11f97..95ffd6a 100644 --- a/pkg/member/repository/user.go +++ b/pkg/member/repository/user.go @@ -205,7 +205,7 @@ func (repo *UserRepository) FindOneByUID(ctx context.Context, uid string) (*enti // 不常寫,再找一次可接受 id := repo.UIDToID(ctx, uid) if id == "" { - return nil, errors.New("invalid uid") + return nil, ErrNotFound } rk := domain.GetUserRedisKey(id) diff --git a/pkg/member/usecase/account.go b/pkg/member/usecase/account.go index 9575d45..07f1e14 100644 --- a/pkg/member/usecase/account.go +++ b/pkg/member/usecase/account.go @@ -4,6 +4,7 @@ import ( "backend/pkg/member/domain/config" "backend/pkg/member/domain/repository" "backend/pkg/member/domain/usecase" + "context" ) type MemberUseCaseParam struct { @@ -24,3 +25,15 @@ func MustMemberUseCase(param MemberUseCaseParam) usecase.AccountUseCase { param, } } + +func (use *MemberUseCase) FindLoginIDByUID(ctx context.Context, uid string) (usecase.BindingUser, error) { + data, err := use.AccountUID.FindOneByUID(ctx, uid) + if err != nil { + return usecase.BindingUser{}, err + } + + return usecase.BindingUser{ + UID: data.UID, + LoginID: data.LoginID, + }, nil +} diff --git a/pkg/notification/config/config.go b/pkg/notification/config/config.go index efcdec4..1988c0c 100644 --- a/pkg/notification/config/config.go +++ b/pkg/notification/config/config.go @@ -1,6 +1,10 @@ package config -import "time" +import ( + "errors" + "fmt" + "time" +) type SMTPConfig struct { Enable bool @@ -13,6 +17,35 @@ type SMTPConfig struct { Password string } +// Validate 驗證 SMTP 配置 +func (c *SMTPConfig) Validate() error { + if !c.Enable { + return nil // 未啟用則不驗證 + } + + if c.Host == "" { + return errors.New("smtp host is required") + } + + if c.Port <= 0 || c.Port > 65535 { + return fmt.Errorf("smtp port must be between 1 and 65535, got %d", c.Port) + } + + if c.Username == "" { + return errors.New("smtp username is required") + } + + if c.Password == "" { + return errors.New("smtp password is required") + } + + if c.Sort < 0 { + return fmt.Errorf("smtp sort must be >= 0, got %d", c.Sort) + } + + return nil +} + type AmazonSesSettings struct { Enable bool Sort int @@ -26,6 +59,39 @@ type AmazonSesSettings struct { Token string } +// Validate 驗證 AWS SES 配置 +func (c *AmazonSesSettings) Validate() error { + if !c.Enable { + return nil // 未啟用則不驗證 + } + + if c.Region == "" { + return errors.New("aws ses region is required") + } + + if c.Sender == "" { + return errors.New("aws ses sender is required") + } + + if c.AccessKey == "" { + return errors.New("aws ses access key is required") + } + + if c.SecretKey == "" { + return errors.New("aws ses secret key is required") + } + + if c.Sort < 0 { + return fmt.Errorf("aws ses sort must be >= 0, got %d", c.Sort) + } + + if c.PoolSize < 0 { + return fmt.Errorf("aws ses pool size must be >= 0, got %d", c.PoolSize) + } + + return nil +} + type MitakeSMSSender struct { Enable bool Sort int @@ -35,6 +101,31 @@ type MitakeSMSSender struct { Password string } +// Validate 驗證 Mitake SMS 配置 +func (c *MitakeSMSSender) Validate() error { + if !c.Enable { + return nil // 未啟用則不驗證 + } + + if c.User == "" { + return errors.New("mitake user is required") + } + + if c.Password == "" { + return errors.New("mitake password is required") + } + + if c.Sort < 0 { + return fmt.Errorf("mitake sort must be >= 0, got %d", c.Sort) + } + + if c.PoolSize < 0 { + return fmt.Errorf("mitake pool size must be >= 0, got %d", c.PoolSize) + } + + return nil +} + // DeliveryConfig 傳送重試配置 type DeliveryConfig struct { MaxRetries int `json:"max_retries"` // 最大重試次數 @@ -44,3 +135,72 @@ type DeliveryConfig struct { Timeout time.Duration `json:"timeout"` // 單次發送超時時間 EnableHistory bool `json:"enable_history"` // 是否啟用歷史記錄 } + +// Validate 驗證 DeliveryConfig 配置 +func (c *DeliveryConfig) Validate() error { + if c.MaxRetries < 0 { + return fmt.Errorf("max_retries must be >= 0, got %d", c.MaxRetries) + } + + if c.MaxRetries > 10 { + return fmt.Errorf("max_retries should not exceed 10, got %d", c.MaxRetries) + } + + if c.InitialDelay < 0 { + return fmt.Errorf("initial_delay must be >= 0, got %v", c.InitialDelay) + } + + if c.InitialDelay > 10*time.Second { + return fmt.Errorf("initial_delay is too large (> 10s), got %v", c.InitialDelay) + } + + if c.BackoffFactor < 1.0 { + return fmt.Errorf("backoff_factor must be >= 1.0, got %v", c.BackoffFactor) + } + + if c.BackoffFactor > 10.0 { + return fmt.Errorf("backoff_factor is too large (> 10.0), got %v", c.BackoffFactor) + } + + if c.MaxDelay < 0 { + return fmt.Errorf("max_delay must be >= 0, got %v", c.MaxDelay) + } + + if c.MaxDelay > 5*time.Minute { + return fmt.Errorf("max_delay is too large (> 5m), got %v", c.MaxDelay) + } + + if c.Timeout <= 0 { + return fmt.Errorf("timeout must be > 0, got %v", c.Timeout) + } + + if c.Timeout > 5*time.Minute { + return fmt.Errorf("timeout is too large (> 5m), got %v", c.Timeout) + } + + // 檢查 InitialDelay 和 MaxDelay 的關係 + if c.InitialDelay > c.MaxDelay && c.MaxDelay > 0 { + return fmt.Errorf("initial_delay (%v) should not exceed max_delay (%v)", c.InitialDelay, c.MaxDelay) + } + + return nil +} + +// SetDefaults 設置默認值 +func (c *DeliveryConfig) SetDefaults() { + if c.MaxRetries == 0 { + c.MaxRetries = 3 + } + if c.InitialDelay == 0 { + c.InitialDelay = 100 * time.Millisecond + } + if c.BackoffFactor == 0 { + c.BackoffFactor = 2.0 + } + if c.MaxDelay == 0 { + c.MaxDelay = 30 * time.Second + } + if c.Timeout == 0 { + c.Timeout = 30 * time.Second + } +} diff --git a/pkg/notification/config/config_test.go b/pkg/notification/config/config_test.go new file mode 100644 index 0000000..f5ab901 --- /dev/null +++ b/pkg/notification/config/config_test.go @@ -0,0 +1,314 @@ +package config + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestSMTPConfig_Validate(t *testing.T) { + tests := []struct { + name string + config SMTPConfig + wantErr bool + errMsg string + }{ + { + name: "有效的 SMTP 配置", + config: SMTPConfig{ + Enable: true, + Sort: 1, + Host: "smtp.gmail.com", + Port: 587, + Username: "test@gmail.com", + Password: "password", + }, + wantErr: false, + }, + { + name: "未啟用的配置(不驗證)", + config: SMTPConfig{ + Enable: false, + }, + wantErr: false, + }, + { + name: "缺少 Host", + config: SMTPConfig{ + Enable: true, + Port: 587, + Username: "test@gmail.com", + Password: "password", + }, + wantErr: true, + errMsg: "smtp host is required", + }, + { + name: "無效的 Port", + config: SMTPConfig{ + Enable: true, + Host: "smtp.gmail.com", + Port: 99999, + Username: "test@gmail.com", + Password: "password", + }, + wantErr: true, + errMsg: "smtp port must be between 1 and 65535", + }, + { + name: "缺少 Username", + config: SMTPConfig{ + Enable: true, + Host: "smtp.gmail.com", + Port: 587, + Password: "password", + }, + wantErr: true, + errMsg: "smtp username is required", + }, + { + name: "負數的 Sort", + config: SMTPConfig{ + Enable: true, + Sort: -1, + Host: "smtp.gmail.com", + Port: 587, + Username: "test@gmail.com", + Password: "password", + }, + wantErr: true, + errMsg: "smtp sort must be >= 0", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestAmazonSesSettings_Validate(t *testing.T) { + tests := []struct { + name string + config AmazonSesSettings + wantErr bool + errMsg string + }{ + { + name: "有效的 AWS SES 配置", + config: AmazonSesSettings{ + Enable: true, + Sort: 1, + PoolSize: 10, + Region: "us-west-2", + Sender: "noreply@example.com", + AccessKey: "AKIAIOSFODNN7EXAMPLE", + SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }, + wantErr: false, + }, + { + name: "未啟用的配置", + config: AmazonSesSettings{ + Enable: false, + }, + wantErr: false, + }, + { + name: "缺少 Region", + config: AmazonSesSettings{ + Enable: true, + Sender: "noreply@example.com", + AccessKey: "AKIAIOSFODNN7EXAMPLE", + SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }, + wantErr: true, + errMsg: "aws ses region is required", + }, + { + name: "缺少 AccessKey", + config: AmazonSesSettings{ + Enable: true, + Region: "us-west-2", + Sender: "noreply@example.com", + SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }, + wantErr: true, + errMsg: "aws ses access key is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestMitakeSMSSender_Validate(t *testing.T) { + tests := []struct { + name string + config MitakeSMSSender + wantErr bool + errMsg string + }{ + { + name: "有效的 Mitake 配置", + config: MitakeSMSSender{ + Enable: true, + Sort: 1, + PoolSize: 5, + User: "testuser", + Password: "testpass", + }, + wantErr: false, + }, + { + name: "缺少 User", + config: MitakeSMSSender{ + Enable: true, + Password: "testpass", + }, + wantErr: true, + errMsg: "mitake user is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDeliveryConfig_Validate(t *testing.T) { + tests := []struct { + name string + config DeliveryConfig + wantErr bool + errMsg string + }{ + { + name: "有效的配置", + config: DeliveryConfig{ + MaxRetries: 3, + InitialDelay: 100 * time.Millisecond, + BackoffFactor: 2.0, + MaxDelay: 30 * time.Second, + Timeout: 30 * time.Second, + }, + wantErr: false, + }, + { + name: "MaxRetries 為負數", + config: DeliveryConfig{ + MaxRetries: -1, + Timeout: 30 * time.Second, + }, + wantErr: true, + errMsg: "max_retries must be >= 0", + }, + { + name: "MaxRetries 過大", + config: DeliveryConfig{ + MaxRetries: 20, + Timeout: 30 * time.Second, + }, + wantErr: true, + errMsg: "max_retries should not exceed 10", + }, + { + name: "BackoffFactor 小於 1.0", + config: DeliveryConfig{ + MaxRetries: 3, + BackoffFactor: 0.5, + Timeout: 30 * time.Second, + }, + wantErr: true, + errMsg: "backoff_factor must be >= 1.0", + }, + { + name: "Timeout 為 0", + config: DeliveryConfig{ + MaxRetries: 3, + BackoffFactor: 2.0, + Timeout: 0, + }, + wantErr: true, + errMsg: "timeout must be > 0", + }, + { + name: "InitialDelay 大於 MaxDelay", + config: DeliveryConfig{ + MaxRetries: 3, + InitialDelay: 1 * time.Minute, + MaxDelay: 10 * time.Second, + Timeout: 30 * time.Second, + }, + wantErr: true, + errMsg: "initial_delay", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if tt.wantErr { + assert.Error(t, err) + if tt.errMsg != "" { + assert.Contains(t, err.Error(), tt.errMsg) + } + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestDeliveryConfig_SetDefaults(t *testing.T) { + config := DeliveryConfig{} + config.SetDefaults() + + assert.Equal(t, 3, config.MaxRetries) + assert.Equal(t, 100*time.Millisecond, config.InitialDelay) + assert.Equal(t, 2.0, config.BackoffFactor) + assert.Equal(t, 30*time.Second, config.MaxDelay) + assert.Equal(t, 30*time.Second, config.Timeout) + + // 測試不覆蓋已設置的值 + config2 := DeliveryConfig{ + MaxRetries: 5, + Timeout: 60 * time.Second, + } + config2.SetDefaults() + + assert.Equal(t, 5, config2.MaxRetries) // 保持原值 + assert.Equal(t, 60*time.Second, config2.Timeout) // 保持原值 + assert.Equal(t, 100*time.Millisecond, config2.InitialDelay) // 設置默認值 +} diff --git a/pkg/notification/domain/error.go b/pkg/notification/domain/error.go index d0c942d..d5bbad4 100644 --- a/pkg/notification/domain/error.go +++ b/pkg/notification/domain/error.go @@ -8,6 +8,7 @@ const ( FailedToSendEmailErrorCode FailedToSendSMSErrorCode FailedToGetTemplateErrorCode + FailedToRenderTemplateErrorCode FailedToSaveHistoryErrorCode FailedToRetryDeliveryErrorCode ) diff --git a/pkg/notification/domain/usecase/template.go b/pkg/notification/domain/usecase/template.go index a896ed7..6e65870 100644 --- a/pkg/notification/domain/usecase/template.go +++ b/pkg/notification/domain/usecase/template.go @@ -1,10 +1,24 @@ package usecase import ( + "backend/pkg/notification/domain/entity" "backend/pkg/notification/domain/template" "context" ) type TemplateUseCase interface { + // GetEmailTemplateByStatic 從靜態模板獲取郵件模板 GetEmailTemplateByStatic(ctx context.Context, language template.Language, templateID template.Type) (template.EmailTemplate, error) + + // GetEmailTemplate 獲取郵件模板(優先從資料庫,回退到靜態模板) + GetEmailTemplate(ctx context.Context, language template.Language, templateID template.Type) (template.EmailTemplate, error) + + // GetSMSTemplate 獲取 SMS 模板(優先從資料庫,回退到靜態模板) + GetSMSTemplate(ctx context.Context, language template.Language, templateID template.Type) (SMSTemplateResp, error) + + // RenderEmailTemplate 渲染郵件模板(替換變數) + RenderEmailTemplate(ctx context.Context, tmpl template.EmailTemplate, params entity.TemplateParams) (EmailTemplateResp, error) + + // RenderSMSTemplate 渲染 SMS 模板(替換變數) + RenderSMSTemplate(ctx context.Context, tmpl SMSTemplateResp, params entity.TemplateParams) (SMSTemplateResp, error) } diff --git a/pkg/notification/repository/aws_ses_mailer.go b/pkg/notification/repository/aws_ses_mailer.go index be072ba..1ccfc53 100644 --- a/pkg/notification/repository/aws_ses_mailer.go +++ b/pkg/notification/repository/aws_ses_mailer.go @@ -5,11 +5,10 @@ import ( "backend/pkg/notification/domain" "backend/pkg/notification/domain/repository" "context" - "time" + "fmt" "backend/pkg/library/errs" "backend/pkg/library/errs/code" - pool "backend/pkg/library/worker_pool" "github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/service/ses/types" @@ -25,8 +24,8 @@ type AwsEmailDeliveryParam struct { } type AwsEmailDeliveryRepository struct { - Client *ses.Client - Pool pool.WorkerPool + Client *ses.Client + Timeout int // 超時時間(秒),預設 30 } func MustAwsSesMailRepository(param AwsEmailDeliveryParam) repository.MailRepository { @@ -42,70 +41,62 @@ func MustAwsSesMailRepository(param AwsEmailDeliveryParam) repository.MailReposi // 創建 SES 客戶端 sesClient := ses.NewFromConfig(cfg) + // 設置默認超時時間 + timeout := 30 + if param.Conf.PoolSize > 0 { + timeout = param.Conf.PoolSize // 可以復用這個配置項,或新增專門的 Timeout 配置 + } + return &AwsEmailDeliveryRepository{ - Client: sesClient, - Pool: pool.NewWorkerPool(param.Conf.PoolSize), + Client: sesClient, + Timeout: timeout, } } -func (use *AwsEmailDeliveryRepository) SendMail(ctx context.Context, req repository.MailReq) error { - err := use.Pool.Submit(func() { - // 設置郵件參數 - to := make([]string, 0, len(req.To)) - to = append(to, req.To...) +func (repo *AwsEmailDeliveryRepository) SendMail(ctx context.Context, req repository.MailReq) error { + // 檢查 context 是否已取消 + if ctx.Err() != nil { + return ctx.Err() + } - input := &ses.SendEmailInput{ - Destination: &types.Destination{ - ToAddresses: to, - }, - Message: &types.Message{ - Body: &types.Body{ - Html: &types.Content{ - Charset: aws.String("UTF-8"), - Data: aws.String(req.Body), - }, - }, - Subject: &types.Content{ + // 設置郵件參數 + to := make([]string, 0, len(req.To)) + to = append(to, req.To...) + + input := &ses.SendEmailInput{ + Destination: &types.Destination{ + ToAddresses: to, + }, + Message: &types.Message{ + Body: &types.Body{ + Html: &types.Content{ Charset: aws.String("UTF-8"), - Data: aws.String(req.Subject), + Data: aws.String(req.Body), }, }, - Source: aws.String(req.From), - } + Subject: &types.Content{ + Charset: aws.String("UTF-8"), + Data: aws.String(req.Subject), + }, + }, + Source: aws.String(req.From), + } - // 發送郵件 - // TODO 不明原因送不出去,會被 context cancel 這裡先把它手動加到100sec - newCtx, cancel := context.WithTimeout(context.Background(), 100*time.Second) - defer cancel() - - //nolint:contextcheck - if _, err := use.Client.SendEmail(newCtx, input); err != nil { - _ = errs.ThirdPartyErrorL( - code.CloudEPNotification, - domain.FailedToSendEmailErrorCode, - logx.WithContext(ctx), - []logx.LogField{ - {Key: "req", Value: req}, - {Key: "func", Value: "AwsEmailDeliveryU.SendEmail"}, - {Key: "err", Value: err.Error()}, - }, - "failed to send mail by aws ses") - } - }) + // 發送郵件(直接使用傳入的 context,不創建新的 context) + _, err := repo.Client.SendEmail(ctx, input) if err != nil { - e := errs.ThirdPartyErrorL( + return errs.ThirdPartyErrorL( code.CloudEPNotification, domain.FailedToSendEmailErrorCode, logx.WithContext(ctx), []logx.LogField{ {Key: "req", Value: req}, - {Key: "func", Value: "AwsEmailDeliveryU.SendEmail"}, + {Key: "func", Value: "AwsEmailDeliveryRepository.SendEmail"}, {Key: "err", Value: err.Error()}, }, - "failed to send mail by aws ses") - - return e + fmt.Sprintf("failed to send mail by aws ses: %v", err)).Wrap(err) } + logx.WithContext(ctx).Infof("Email sent successfully via AWS SES to %v", req.To) return nil } diff --git a/pkg/notification/repository/mitake_sms_sender.go b/pkg/notification/repository/mitake_sms_sender.go index a8d1fdc..ca93173 100644 --- a/pkg/notification/repository/mitake_sms_sender.go +++ b/pkg/notification/repository/mitake_sms_sender.go @@ -5,10 +5,10 @@ import ( "backend/pkg/notification/domain" "backend/pkg/notification/domain/repository" "context" + "fmt" "backend/pkg/library/errs" "backend/pkg/library/errs/code" - pool "backend/pkg/library/worker_pool" "github.com/minchao/go-mitake" "github.com/zeromicro/go-zero/core/logx" @@ -21,45 +21,43 @@ type MitakeSMSDeliveryParam struct { type MitakeSMSDeliveryRepository struct { Client *mitake.Client - Pool pool.WorkerPool } -func (use *MitakeSMSDeliveryRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error { - // 用 goroutine pool 送,否則會超時 - err := use.Pool.Submit(func() { - message := mitake.Message{ - Dstaddr: req.PhoneNumber, - Destname: req.RecipientName, - Smbody: req.MessageContent, - } - _, err := use.Client.Send(message) - if err != nil { - logx.Error("failed to send sms via mitake") - } - }) +func (repo *MitakeSMSDeliveryRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error { + // 檢查 context 是否已取消 + if ctx.Err() != nil { + return ctx.Err() + } + // 構建簡訊訊息 + message := mitake.Message{ + Dstaddr: req.PhoneNumber, + Destname: req.RecipientName, + Smbody: req.MessageContent, + } + + // 直接發送,不使用 goroutine pool + // 讓 delivery usecase 統一管理重試和超時 + _, err := repo.Client.Send(message) if err != nil { - // 錯誤代碼 20-201-04 - e := errs.ThirdPartyErrorL( + return errs.ThirdPartyErrorL( code.CloudEPNotification, domain.FailedToSendSMSErrorCode, logx.WithContext(ctx), []logx.LogField{ {Key: "req", Value: req}, - {Key: "func", Value: "MitakeSMSDeliveryRepository.Client.Send"}, + {Key: "func", Value: "MitakeSMSDeliveryRepository.Send"}, {Key: "err", Value: err.Error()}, }, - "failed to send sns by mitake").Wrap(err) - - return e + fmt.Sprintf("failed to send sms by mitake: %v", err)).Wrap(err) } + logx.WithContext(ctx).Infof("SMS sent successfully via Mitake to %s", req.PhoneNumber) return nil } func MustMitakeRepository(param MitakeSMSDeliveryParam) repository.SMSClientRepository { return &MitakeSMSDeliveryRepository{ Client: mitake.NewClient(param.Conf.User, param.Conf.Password, nil), - Pool: pool.NewWorkerPool(param.Conf.PoolSize), } } diff --git a/pkg/notification/repository/smtp_mailer.go b/pkg/notification/repository/smtp_mailer.go index 2e778d6..edb24bc 100644 --- a/pkg/notification/repository/smtp_mailer.go +++ b/pkg/notification/repository/smtp_mailer.go @@ -2,10 +2,13 @@ package repository import ( "backend/pkg/notification/config" + "backend/pkg/notification/domain" "backend/pkg/notification/domain/repository" "context" + "fmt" - pool "backend/pkg/library/worker_pool" + "backend/pkg/library/errs" + "backend/pkg/library/errs/code" "github.com/zeromicro/go-zero/core/logx" "gopkg.in/gomail.v2" @@ -17,7 +20,6 @@ type SMTPMailUseCaseParam struct { type SMTPMailRepository struct { Client *gomail.Dialer - Pool pool.WorkerPool } func MustSMTPUseCase(param SMTPMailUseCaseParam) repository.MailRepository { @@ -28,26 +30,37 @@ func MustSMTPUseCase(param SMTPMailUseCaseParam) repository.MailRepository { param.Conf.Username, param.Conf.Password, ), - Pool: pool.NewWorkerPool(param.Conf.GoroutinePoolNum), } } -func (repo *SMTPMailRepository) SendMail(_ context.Context, req repository.MailReq) error { - // 用 goroutine pool 送,否則會超時 - err := repo.Pool.Submit(func() { - m := gomail.NewMessage() - m.SetHeader("From", req.From) - m.SetHeader("To", req.To...) - m.SetHeader("Subject", req.Subject) - m.SetBody("text/html", req.Body) - if err := repo.Client.DialAndSend(m); err != nil { - logx.WithCallerSkip(1).WithFields( - logx.Field("func", "MailUseCase.SendMail"), - logx.Field("req", req), - logx.Field("err", err), - ).Error("failed to send mail by mailgun") - } - }) +func (repo *SMTPMailRepository) SendMail(ctx context.Context, req repository.MailReq) error { + // 檢查 context 是否已取消 + if ctx.Err() != nil { + return ctx.Err() + } - return err + // 構建郵件 + m := gomail.NewMessage() + m.SetHeader("From", req.From) + m.SetHeader("To", req.To...) + m.SetHeader("Subject", req.Subject) + m.SetBody("text/html", req.Body) + + // 直接發送,不使用 goroutine pool + // 讓 delivery usecase 統一管理重試和超時 + if err := repo.Client.DialAndSend(m); err != nil { + return errs.ThirdPartyErrorL( + code.CloudEPNotification, + domain.FailedToSendEmailErrorCode, + logx.WithContext(ctx), + []logx.LogField{ + {Key: "func", Value: "SMTPMailRepository.SendMail"}, + {Key: "req", Value: req}, + {Key: "err", Value: err.Error()}, + }, + fmt.Sprintf("failed to send mail by smtp: %v", err)).Wrap(err) + } + + logx.WithContext(ctx).Infof("Email sent successfully via SMTP to %v", req.To) + return nil } diff --git a/pkg/notification/usecase/delivery.go b/pkg/notification/usecase/delivery.go index 4c2b830..997d5c4 100644 --- a/pkg/notification/usecase/delivery.go +++ b/pkg/notification/usecase/delivery.go @@ -74,7 +74,10 @@ func (use *DeliveryUseCase) SendMessage(ctx context.Context, req usecase.SMSMess } // 執行發送邏輯 - return use.sendSMSWithRetry(ctx, req, history) + return use.sendWithRetry(ctx, history, &smsProviderAdapter{ + providers: use.param.SMSProviders, + request: req, + }) } func (use *DeliveryUseCase) SendEmail(ctx context.Context, req usecase.MailReq) error { @@ -97,30 +100,130 @@ func (use *DeliveryUseCase) SendEmail(ctx context.Context, req usecase.MailReq) } // 執行發送邏輯 - return use.sendEmailWithRetry(ctx, req, history) + return use.sendWithRetry(ctx, history, &emailProviderAdapter{ + providers: use.param.EmailProviders, + request: req, + }) } -// sendSMSWithRetry 發送 SMS 並實現重試機制 -func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SMSMessageRequest, history *entity.DeliveryHistory) error { - // 根據 Sort 欄位對 SMSProviders 進行排序 - providers := make([]usecase.SMSProvider, len(use.param.SMSProviders)) - copy(providers, use.param.SMSProviders) - sort.Slice(providers, func(i, j int) bool { - return providers[i].Sort < providers[j].Sort +// providerAdapter 統一的供應商適配器接口 +type providerAdapter interface { + getProviderCount() int + getProviderName(index int) string + getProviderSort(index int) int64 + send(ctx context.Context, providerIndex int) error + getErrorCode() errs.ErrorCode + getType() string +} + +// smsProviderAdapter SMS 供應商適配器 +type smsProviderAdapter struct { + providers []usecase.SMSProvider + request usecase.SMSMessageRequest +} + +func (a *smsProviderAdapter) getProviderCount() int { + return len(a.providers) +} + +func (a *smsProviderAdapter) getProviderName(index int) string { + return fmt.Sprintf("sms_provider_%d", index) +} + +func (a *smsProviderAdapter) getProviderSort(index int) int64 { + return a.providers[index].Sort +} + +func (a *smsProviderAdapter) send(ctx context.Context, providerIndex int) error { + return a.providers[providerIndex].Repo.SendSMS(ctx, repository.SMSMessageRequest{ + PhoneNumber: a.request.PhoneNumber, + RecipientName: a.request.RecipientName, + MessageContent: a.request.MessageContent, + }) +} + +func (a *smsProviderAdapter) getErrorCode() errs.ErrorCode { + return domain.FailedToSendSMSErrorCode +} + +func (a *smsProviderAdapter) getType() string { + return "SMS" +} + +// emailProviderAdapter Email 供應商適配器 +type emailProviderAdapter struct { + providers []usecase.EmailProvider + request usecase.MailReq +} + +func (a *emailProviderAdapter) getProviderCount() int { + return len(a.providers) +} + +func (a *emailProviderAdapter) getProviderName(index int) string { + return fmt.Sprintf("email_provider_%d", index) +} + +func (a *emailProviderAdapter) getProviderSort(index int) int64 { + return a.providers[index].Sort +} + +func (a *emailProviderAdapter) send(ctx context.Context, providerIndex int) error { + return a.providers[providerIndex].Repo.SendMail(ctx, repository.MailReq{ + From: a.request.From, + To: a.request.To, + Subject: a.request.Subject, + Body: a.request.Body, + }) +} + +func (a *emailProviderAdapter) getErrorCode() errs.ErrorCode { + return domain.FailedToSendEmailErrorCode +} + +func (a *emailProviderAdapter) getType() string { + return "Email" +} + +// providerWithIndex 用於排序的結構 +type providerWithIndex struct { + index int + sort int64 +} + +// sendWithRetry 統一的發送重試邏輯 +func (use *DeliveryUseCase) sendWithRetry( + ctx context.Context, + history *entity.DeliveryHistory, + adapter providerAdapter, +) error { + // 按 Sort 欄位對供應商進行排序 + providerCount := adapter.getProviderCount() + sortedProviders := make([]providerWithIndex, providerCount) + for i := 0; i < providerCount; i++ { + sortedProviders[i] = providerWithIndex{ + index: i, + sort: adapter.getProviderSort(i), + } + } + sort.Slice(sortedProviders, func(i, j int) bool { + return sortedProviders[i].sort < sortedProviders[j].sort }) var lastErr error totalAttempts := 0 // 嘗試所有 providers - for providerIndex, provider := range providers { + for _, provider := range sortedProviders { + providerIndex := provider.index + // 為每個 provider 嘗試發送 for attempt := 0; attempt < use.param.DeliveryConfig.MaxRetries; attempt++ { totalAttempts++ // 更新歷史記錄狀態 history.Status = entity.DeliveryStatusSending - history.Provider = fmt.Sprintf("sms_provider_%d", providerIndex) + history.Provider = adapter.getProviderName(providerIndex) history.AttemptCount = totalAttempts history.UpdatedAt = time.Now() use.updateHistory(ctx, history) @@ -131,11 +234,7 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM // 創建帶超時的 context sendCtx, cancel := context.WithTimeout(ctx, use.param.DeliveryConfig.Timeout) - err := provider.Repo.SendSMS(sendCtx, repository.SMSMessageRequest{ - PhoneNumber: req.PhoneNumber, - RecipientName: req.RecipientName, - MessageContent: req.MessageContent, - }) + err := adapter.send(sendCtx, providerIndex) cancel() @@ -153,8 +252,8 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM attemptRecord.ErrorMessage = err.Error() lastErr = err - logx.WithContext(ctx).Errorf("SMS send attempt %d failed for provider %d: %v", - attempt+1, providerIndex, err) + logx.WithContext(ctx).Errorf("%s send attempt %d failed for provider %d: %v", + adapter.getType(), attempt+1, providerIndex, err) // 如果不是最後一次嘗試,等待後重試 if attempt < use.param.DeliveryConfig.MaxRetries-1 { @@ -179,7 +278,8 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM use.updateHistory(ctx, history) use.addAttemptRecord(ctx, history.ID, attemptRecord) - logx.WithContext(ctx).Infof("SMS sent successfully after %d attempts", totalAttempts) + logx.WithContext(ctx).Infof("%s sent successfully after %d attempts", + adapter.getType(), totalAttempts) return nil } @@ -197,112 +297,9 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM return errs.ThirdPartyError( code.CloudEPNotification, - domain.FailedToSendSMSErrorCode, - fmt.Sprintf("Failed to send SMS after %d attempts across %d providers", - totalAttempts, len(providers))) -} - -// sendEmailWithRetry 發送 Email 並實現重試機制 -func (use *DeliveryUseCase) sendEmailWithRetry(ctx context.Context, req usecase.MailReq, history *entity.DeliveryHistory) error { - // 根據 Sort 欄位對 EmailProviders 進行排序 - providers := make([]usecase.EmailProvider, len(use.param.EmailProviders)) - copy(providers, use.param.EmailProviders) - sort.Slice(providers, func(i, j int) bool { - return providers[i].Sort < providers[j].Sort - }) - - var lastErr error - totalAttempts := 0 - - // 嘗試所有 providers - for providerIndex, provider := range providers { - // 為每個 provider 嘗試發送 - for attempt := 0; attempt < use.param.DeliveryConfig.MaxRetries; attempt++ { - totalAttempts++ - - // 更新歷史記錄狀態 - history.Status = entity.DeliveryStatusSending - history.Provider = fmt.Sprintf("email_provider_%d", providerIndex) - history.AttemptCount = totalAttempts - history.UpdatedAt = time.Now() - use.updateHistory(ctx, history) - - // 記錄發送嘗試 - attemptStart := time.Now() - - // 創建帶超時的 context - sendCtx, cancel := context.WithTimeout(ctx, use.param.DeliveryConfig.Timeout) - - err := provider.Repo.SendMail(sendCtx, repository.MailReq{ - From: req.From, - To: req.To, - Subject: req.Subject, - Body: req.Body, - }) - - cancel() - - // 記錄嘗試結果 - attemptDuration := time.Since(attemptStart) - attemptRecord := entity.DeliveryAttempt{ - Provider: history.Provider, - AttemptAt: attemptStart, - Success: err == nil, - ErrorMessage: "", - Duration: attemptDuration.Milliseconds(), - } - - if err != nil { - attemptRecord.ErrorMessage = err.Error() - lastErr = err - - logx.WithContext(ctx).Errorf("Email send attempt %d failed for provider %d: %v", - attempt+1, providerIndex, err) - - // 如果不是最後一次嘗試,等待後重試 - if attempt < use.param.DeliveryConfig.MaxRetries-1 { - delay := use.calculateDelay(attempt) - history.Status = entity.DeliveryStatusRetrying - use.updateHistory(ctx, history) - use.addAttemptRecord(ctx, history.ID, attemptRecord) - - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(delay): - continue - } - } - } else { - // 發送成功 - history.Status = entity.DeliveryStatusSuccess - history.UpdatedAt = time.Now() - now := time.Now() - history.CompletedAt = &now - use.updateHistory(ctx, history) - use.addAttemptRecord(ctx, history.ID, attemptRecord) - - logx.WithContext(ctx).Infof("Email sent successfully after %d attempts", totalAttempts) - return nil - } - - use.addAttemptRecord(ctx, history.ID, attemptRecord) - } - } - - // 所有 providers 都失敗了 - history.Status = entity.DeliveryStatusFailed - history.ErrorMessage = fmt.Sprintf("All providers failed. Last error: %v", lastErr) - history.UpdatedAt = time.Now() - now := time.Now() - history.CompletedAt = &now - use.updateHistory(ctx, history) - - return errs.ThirdPartyError( - code.CloudEPNotification, - domain.FailedToSendEmailErrorCode, - fmt.Sprintf("Failed to send email after %d attempts across %d providers", - totalAttempts, len(providers))) + adapter.getErrorCode(), + fmt.Sprintf("Failed to send %s after %d attempts across %d providers", + adapter.getType(), totalAttempts, providerCount)) } // calculateDelay 計算指數退避延遲 diff --git a/pkg/notification/usecase/delivery_test.go b/pkg/notification/usecase/delivery_test.go new file mode 100644 index 0000000..ab170b6 --- /dev/null +++ b/pkg/notification/usecase/delivery_test.go @@ -0,0 +1,377 @@ +package usecase + +import ( + "backend/pkg/notification/config" + "backend/pkg/notification/domain/entity" + "backend/pkg/notification/domain/repository" + "backend/pkg/notification/domain/usecase" + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// mockSMSRepository 模擬 SMS Repository +type mockSMSRepository struct { + sendFunc func(ctx context.Context, req repository.SMSMessageRequest) error +} + +func (m *mockSMSRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error { + if m.sendFunc != nil { + return m.sendFunc(ctx, req) + } + return nil +} + +// mockMailRepository 模擬 Mail Repository +type mockMailRepository struct { + sendFunc func(ctx context.Context, req repository.MailReq) error +} + +func (m *mockMailRepository) SendMail(ctx context.Context, req repository.MailReq) error { + if m.sendFunc != nil { + return m.sendFunc(ctx, req) + } + return nil +} + +// mockHistoryRepository 模擬 History Repository +type mockHistoryRepository struct { + histories []entity.DeliveryHistory + attempts map[string][]entity.DeliveryAttempt +} + +func (m *mockHistoryRepository) CreateHistory(ctx context.Context, history *entity.DeliveryHistory) error { + m.histories = append(m.histories, *history) + return nil +} + +func (m *mockHistoryRepository) UpdateHistory(ctx context.Context, history *entity.DeliveryHistory) error { + for i := range m.histories { + if m.histories[i].ID == history.ID { + m.histories[i] = *history + return nil + } + } + return nil +} + +func (m *mockHistoryRepository) GetHistory(ctx context.Context, id string) (*entity.DeliveryHistory, error) { + for i := range m.histories { + if m.histories[i].ID == id { + return &m.histories[i], nil + } + } + return nil, errors.New("not found") +} + +func (m *mockHistoryRepository) AddAttempt(ctx context.Context, historyID string, attempt entity.DeliveryAttempt) error { + if m.attempts == nil { + m.attempts = make(map[string][]entity.DeliveryAttempt) + } + m.attempts[historyID] = append(m.attempts[historyID], attempt) + return nil +} + +func (m *mockHistoryRepository) ListHistory(ctx context.Context, filter repository.HistoryFilter) ([]*entity.DeliveryHistory, error) { + var result []*entity.DeliveryHistory + for i := range m.histories { + result = append(result, &m.histories[i]) + } + return result, nil +} + +func TestDeliveryUseCase_SendEmail_Success(t *testing.T) { + mockMail := &mockMailRepository{ + sendFunc: func(ctx context.Context, req repository.MailReq) error { + return nil // 成功 + }, + } + + mockHistory := &mockHistoryRepository{} + + uc := MustDeliveryUseCase(DeliveryUseCaseParam{ + EmailProviders: []usecase.EmailProvider{ + {Sort: 1, Repo: mockMail}, + }, + DeliveryConfig: config.DeliveryConfig{ + MaxRetries: 3, + InitialDelay: 10 * time.Millisecond, + BackoffFactor: 2.0, + MaxDelay: 100 * time.Millisecond, + Timeout: 5 * time.Second, + EnableHistory: true, + }, + HistoryRepo: mockHistory, + }) + + ctx := context.Background() + err := uc.SendEmail(ctx, usecase.MailReq{ + From: "test@example.com", + To: []string{"user@example.com"}, + Subject: "Test", + Body: "

Test email

", + }) + + assert.NoError(t, err) + // 驗證歷史記錄 + assert.Equal(t, 1, len(mockHistory.histories)) + assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status) +} + +func TestDeliveryUseCase_SendEmail_RetryAndSuccess(t *testing.T) { + attemptCount := 0 + mockMail := &mockMailRepository{ + sendFunc: func(ctx context.Context, req repository.MailReq) error { + attemptCount++ + if attemptCount < 3 { + return errors.New("temporary error") + } + return nil // 第三次成功 + }, + } + + mockHistory := &mockHistoryRepository{} + + uc := MustDeliveryUseCase(DeliveryUseCaseParam{ + EmailProviders: []usecase.EmailProvider{ + {Sort: 1, Repo: mockMail}, + }, + DeliveryConfig: config.DeliveryConfig{ + MaxRetries: 3, + InitialDelay: 10 * time.Millisecond, + BackoffFactor: 2.0, + MaxDelay: 100 * time.Millisecond, + Timeout: 5 * time.Second, + EnableHistory: true, + }, + HistoryRepo: mockHistory, + }) + + ctx := context.Background() + err := uc.SendEmail(ctx, usecase.MailReq{ + From: "test@example.com", + To: []string{"user@example.com"}, + Subject: "Test", + Body: "

Test

", + }) + + assert.NoError(t, err) + assert.Equal(t, 3, attemptCount) // 重試了 3 次 + assert.Equal(t, 1, len(mockHistory.histories)) + assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status) + assert.Equal(t, 3, mockHistory.histories[0].AttemptCount) +} + +func TestDeliveryUseCase_SendEmail_AllRetries_Failed(t *testing.T) { + mockMail := &mockMailRepository{ + sendFunc: func(ctx context.Context, req repository.MailReq) error { + return errors.New("persistent error") + }, + } + + mockHistory := &mockHistoryRepository{} + + uc := MustDeliveryUseCase(DeliveryUseCaseParam{ + EmailProviders: []usecase.EmailProvider{ + {Sort: 1, Repo: mockMail}, + }, + DeliveryConfig: config.DeliveryConfig{ + MaxRetries: 3, + InitialDelay: 10 * time.Millisecond, + BackoffFactor: 2.0, + MaxDelay: 100 * time.Millisecond, + Timeout: 5 * time.Second, + EnableHistory: true, + }, + HistoryRepo: mockHistory, + }) + + ctx := context.Background() + err := uc.SendEmail(ctx, usecase.MailReq{ + From: "test@example.com", + To: []string{"user@example.com"}, + Subject: "Test", + Body: "

Test

", + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "Failed to send Email") + assert.Equal(t, 1, len(mockHistory.histories)) + assert.Equal(t, entity.DeliveryStatusFailed, mockHistory.histories[0].Status) + assert.Equal(t, 3, mockHistory.histories[0].AttemptCount) // 嘗試了 3 次 +} + +func TestDeliveryUseCase_SendEmail_Failover(t *testing.T) { + mockMail1 := &mockMailRepository{ + sendFunc: func(ctx context.Context, req repository.MailReq) error { + return errors.New("provider 1 failed") + }, + } + + mockMail2 := &mockMailRepository{ + sendFunc: func(ctx context.Context, req repository.MailReq) error { + return nil // 備援成功 + }, + } + + mockHistory := &mockHistoryRepository{} + + uc := MustDeliveryUseCase(DeliveryUseCaseParam{ + EmailProviders: []usecase.EmailProvider{ + {Sort: 1, Repo: mockMail1}, // 主要供應商 + {Sort: 2, Repo: mockMail2}, // 備援供應商 + }, + DeliveryConfig: config.DeliveryConfig{ + MaxRetries: 2, + InitialDelay: 10 * time.Millisecond, + BackoffFactor: 2.0, + MaxDelay: 100 * time.Millisecond, + Timeout: 5 * time.Second, + EnableHistory: true, + }, + HistoryRepo: mockHistory, + }) + + ctx := context.Background() + err := uc.SendEmail(ctx, usecase.MailReq{ + From: "test@example.com", + To: []string{"user@example.com"}, + Subject: "Test", + Body: "

Test

", + }) + + assert.NoError(t, err) + // 驗證使用了備援供應商 + assert.Equal(t, 1, len(mockHistory.histories)) + assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status) + // 總共嘗試次數:provider1 重試 2 次 + provider2 成功 1 次 = 3 次 + assert.Equal(t, 3, mockHistory.histories[0].AttemptCount) +} + +func TestDeliveryUseCase_SendSMS_Success(t *testing.T) { + mockSMS := &mockSMSRepository{ + sendFunc: func(ctx context.Context, req repository.SMSMessageRequest) error { + return nil + }, + } + + mockHistory := &mockHistoryRepository{} + + uc := MustDeliveryUseCase(DeliveryUseCaseParam{ + SMSProviders: []usecase.SMSProvider{ + {Sort: 1, Repo: mockSMS}, + }, + DeliveryConfig: config.DeliveryConfig{ + MaxRetries: 3, + InitialDelay: 10 * time.Millisecond, + BackoffFactor: 2.0, + MaxDelay: 100 * time.Millisecond, + Timeout: 5 * time.Second, + EnableHistory: true, + }, + HistoryRepo: mockHistory, + }) + + ctx := context.Background() + err := uc.SendMessage(ctx, usecase.SMSMessageRequest{ + PhoneNumber: "+886912345678", + RecipientName: "Test User", + MessageContent: "Your code: 123456", + }) + + assert.NoError(t, err) + assert.Equal(t, 1, len(mockHistory.histories)) + assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status) +} + +func TestDeliveryUseCase_CalculateDelay(t *testing.T) { + uc := &DeliveryUseCase{ + param: DeliveryUseCaseParam{ + DeliveryConfig: config.DeliveryConfig{ + InitialDelay: 100 * time.Millisecond, + BackoffFactor: 2.0, + MaxDelay: 1 * time.Second, + }, + }, + } + + tests := []struct { + name string + attempt int + expected time.Duration + }{ + { + name: "第 0 次重試", + attempt: 0, + expected: 100 * time.Millisecond, + }, + { + name: "第 1 次重試", + attempt: 1, + expected: 200 * time.Millisecond, + }, + { + name: "第 2 次重試", + attempt: 2, + expected: 400 * time.Millisecond, + }, + { + name: "第 3 次重試", + attempt: 3, + expected: 800 * time.Millisecond, + }, + { + name: "第 10 次重試(達到 MaxDelay)", + attempt: 10, + expected: 1 * time.Second, // 受限於 MaxDelay + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + delay := uc.calculateDelay(tt.attempt) + assert.Equal(t, tt.expected, delay) + }) + } +} + +func TestDeliveryUseCase_ContextCancellation(t *testing.T) { + mockMail := &mockMailRepository{ + sendFunc: func(ctx context.Context, req repository.MailReq) error { + // 模擬慢速操作 + time.Sleep(100 * time.Millisecond) + return errors.New("should not reach here") + }, + } + + uc := MustDeliveryUseCase(DeliveryUseCaseParam{ + EmailProviders: []usecase.EmailProvider{ + {Sort: 1, Repo: mockMail}, + }, + DeliveryConfig: config.DeliveryConfig{ + MaxRetries: 3, + InitialDelay: 50 * time.Millisecond, + BackoffFactor: 2.0, + MaxDelay: 500 * time.Millisecond, + Timeout: 1 * time.Second, + EnableHistory: false, + }, + }) + + // 創建會被取消的 context + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + err := uc.SendEmail(ctx, usecase.MailReq{ + From: "test@example.com", + To: []string{"user@example.com"}, + Subject: "Test", + Body: "

Test

", + }) + + assert.Error(t, err) + assert.Equal(t, context.DeadlineExceeded, err) +} diff --git a/pkg/notification/usecase/template_test.go b/pkg/notification/usecase/template_test.go new file mode 100644 index 0000000..d5d9dcc --- /dev/null +++ b/pkg/notification/usecase/template_test.go @@ -0,0 +1,255 @@ +package usecase + +import ( + "backend/pkg/notification/domain/entity" + "backend/pkg/notification/domain/template" + "backend/pkg/notification/domain/usecase" + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTemplateUseCase_RenderEmailTemplate(t *testing.T) { + uc := MustTemplateUseCase(TemplateUseCaseParam{ + TemplateRepo: nil, + }) + + ctx := context.Background() + + tests := []struct { + name string + tmpl template.EmailTemplate + params entity.TemplateParams + expectedSubj string + expectedBody string + shouldContain []string + shouldNotError bool + }{ + { + name: "渲染基本參數", + tmpl: template.EmailTemplate{ + Title: "Hello {{.Username}}", + Body: "

Your code is: {{.VerifyCode}}

", + }, + params: entity.TemplateParams{ + Username: "張三", + VerifyCode: "123456", + }, + expectedSubj: "Hello 張三", + shouldContain: []string{"123456"}, + shouldNotError: true, + }, + { + name: "渲染額外參數", + tmpl: template.EmailTemplate{ + Title: "Welcome", + Body: "

Hello {{.Username}}, your link: {{.Link}}

", + }, + params: entity.TemplateParams{ + Username: "John", + Extra: map[string]string{ + "Link": "https://example.com", + }, + }, + shouldContain: []string{"John", "https://example.com"}, + shouldNotError: true, + }, + { + name: "特殊字符不轉義(簡單字符串替換)", + tmpl: template.EmailTemplate{ + Title: "Test", + Body: "

Name: {{.Username}}

", + }, + params: entity.TemplateParams{ + Username: "", + }, + shouldContain: []string{""}, // 使用簡單字符串替換,不轉義 + shouldNotError: true, + }, + { + name: "空模板", + tmpl: template.EmailTemplate{ + Title: "", + Body: "", + }, + params: entity.TemplateParams{ + Username: "Test", + }, + expectedSubj: "", + expectedBody: "", + shouldNotError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := uc.RenderEmailTemplate(ctx, tt.tmpl, tt.params) + + if tt.shouldNotError { + assert.NoError(t, err) + if tt.expectedSubj != "" { + assert.Equal(t, tt.expectedSubj, result.Subject) + } + if tt.expectedBody != "" { + assert.Equal(t, tt.expectedBody, result.Body) + } + for _, contain := range tt.shouldContain { + assert.Contains(t, result.Body, contain) + } + } else { + assert.Error(t, err) + } + }) + } +} + +func TestTemplateUseCase_RenderSMSTemplate(t *testing.T) { + uc := MustTemplateUseCase(TemplateUseCaseParam{ + TemplateRepo: nil, + }) + + ctx := context.Background() + + tests := []struct { + name string + tmpl usecase.SMSTemplateResp + params entity.TemplateParams + expectedBody string + shouldContain []string + shouldNotError bool + }{ + { + name: "渲染 SMS 驗證碼", + tmpl: usecase.SMSTemplateResp{ + Body: "您的驗證碼是:{{.VerifyCode}},請在5分鐘內使用。", + }, + params: entity.TemplateParams{ + VerifyCode: "654321", + }, + shouldContain: []string{"654321", "5分鐘"}, + shouldNotError: true, + }, + { + name: "SMS 純文本替換", + tmpl: usecase.SMSTemplateResp{ + Body: "Hi {{.Username}}, your code: {{.VerifyCode}}", + }, + params: entity.TemplateParams{ + Username: "", + VerifyCode: "111111", + }, + shouldContain: []string{"", "111111"}, // 使用簡單字符串替換 + shouldNotError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := uc.RenderSMSTemplate(ctx, tt.tmpl, tt.params) + + if tt.shouldNotError { + assert.NoError(t, err) + if tt.expectedBody != "" { + assert.Equal(t, tt.expectedBody, result.Body) + } + for _, contain := range tt.shouldContain { + assert.Contains(t, result.Body, contain) + } + } else { + assert.Error(t, err) + } + }) + } +} + +func TestTemplateUseCase_GetEmailTemplateByStatic(t *testing.T) { + uc := MustTemplateUseCase(TemplateUseCaseParam{ + TemplateRepo: nil, + }) + + ctx := context.Background() + + tests := []struct { + name string + language template.Language + templateID template.Type + wantErr bool + }{ + { + name: "獲取忘記密碼模板 (zh-tw)", + language: template.LanguageZhTW, + templateID: template.ForgetPasswordVerify, + wantErr: false, + }, + { + name: "獲取綁定郵箱模板 (zh-tw)", + language: template.LanguageZhTW, + templateID: template.BindingEmail, + wantErr: false, + }, + { + name: "不存在的語言", + language: template.Language("xx-xx"), + templateID: template.ForgetPasswordVerify, + wantErr: true, + }, + { + name: "不存在的模板類型", + language: template.LanguageZhTW, + templateID: template.Type("non_existent"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := uc.GetEmailTemplateByStatic(ctx, tt.language, tt.templateID) + + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.NotEmpty(t, result.Title) + assert.NotEmpty(t, result.Body) + } + }) + } +} + +func TestTemplateUseCase_GetDefaultSMSTemplate(t *testing.T) { + uc := &TemplateUseCase{} + + tests := []struct { + name string + templateID template.Type + shouldContain []string + }{ + { + name: "忘記密碼模板", + templateID: template.ForgetPasswordVerify, + shouldContain: []string{"密碼重設", "驗證碼"}, + }, + { + name: "綁定郵箱模板", + templateID: template.BindingEmail, + shouldContain: []string{"綁定", "驗證碼"}, + }, + { + name: "默認模板", + templateID: template.Type("unknown"), + shouldContain: []string{"驗證碼"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := uc.getDefaultSMSTemplate(tt.templateID) + + assert.NotEmpty(t, result.Body) + for _, contain := range tt.shouldContain { + assert.Contains(t, result.Body, contain) + } + }) + } +} diff --git a/pkg/permission/domain/config/permission.go b/pkg/permission/domain/config/permission.go index 78decd0..1a022b3 100644 --- a/pkg/permission/domain/config/permission.go +++ b/pkg/permission/domain/config/permission.go @@ -40,7 +40,7 @@ func DefaultConfig() Config { UIDLength: 6, AdminRoleUID: "AM000000", AdminUserUID: "B000000", - DefaultRoleName: "user", + DefaultRoleName: "USER", }, } } diff --git a/pkg/permission/domain/permission/role.go b/pkg/permission/domain/permission/role.go new file mode 100644 index 0000000..e82ff52 --- /dev/null +++ b/pkg/permission/domain/permission/role.go @@ -0,0 +1,5 @@ +package permission + +const ( + DefaultRole = "user" +) diff --git a/pkg/permission/domain/token/context.go b/pkg/permission/domain/token/context.go new file mode 100644 index 0000000..2c71336 --- /dev/null +++ b/pkg/permission/domain/token/context.go @@ -0,0 +1,50 @@ +package token + +import ( + "context" +) + +type ContextKey string + +func (c ContextKey) String() string { + return string(c) +} + +const ( + KeyRole ContextKey = "role" + KeyDeviceID ContextKey = "device_id" + KeyScope ContextKey = "scope" + KeyUID ContextKey = "uid" + KeyLoginID ContextKey = "login_id" +) + +func UID(ctx context.Context) string { return getString(ctx, KeyUID) } +func Scope(ctx context.Context) string { return getString(ctx, KeyScope) } +func Role(ctx context.Context) string { return getString(ctx, KeyRole) } +func DeviceID(ctx context.Context) string { return getString(ctx, KeyDeviceID) } +func LoginID(ctx context.Context) string { return getString(ctx, KeyLoginID) } + +func WithUID(ctx context.Context, uid string) context.Context { + return context.WithValue(ctx, KeyUID, uid) +} +func WithScope(ctx context.Context, scope string) context.Context { + return context.WithValue(ctx, KeyScope, scope) +} +func WithRole(ctx context.Context, role string) context.Context { + return context.WithValue(ctx, KeyRole, role) +} +func WithDeviceID(ctx context.Context, id string) context.Context { + return context.WithValue(ctx, KeyDeviceID, id) +} +func WithLoginID(ctx context.Context, login string) context.Context { + return context.WithValue(ctx, KeyLoginID, login) +} + +// --- Internal helper --- +func getString(ctx context.Context, key ContextKey) string { + if v, ok := ctx.Value(key).(string); ok { + return v + } + + return "" +} diff --git a/pkg/permission/repository/permission.go b/pkg/permission/repository/permission.go index 1f99fa1..90c4c61 100644 --- a/pkg/permission/repository/permission.go +++ b/pkg/permission/repository/permission.go @@ -26,7 +26,7 @@ type PermissionRepository struct { DB mongo.DocumentDBWithCacheUseCase } -func NewAccountRepository(param PermissionRepositoryParam) repository.PermissionRepository { +func NewPermissionRepository(param PermissionRepositoryParam) repository.PermissionRepository { e := entity.Permission{} documentDB, err := mongo.MustDocumentDBWithCache( param.Conf, diff --git a/pkg/permission/repository/permission_test.go b/pkg/permission/repository/permission_test.go index 2a5e5ff..69e0fba 100644 --- a/pkg/permission/repository/permission_test.go +++ b/pkg/permission/repository/permission_test.go @@ -59,7 +59,7 @@ func setupPermissionRepo(db string) (domainRepo.PermissionRepository, func(), er CacheConf: cacheConf, CacheOpts: cacheOpts, } - repo := NewAccountRepository(param) + repo := NewPermissionRepository(param) _, _ = repo.Index20251009001UP(context.Background()) return repo, tearDown, nil diff --git a/pkg/permission/usecase/token.go b/pkg/permission/usecase/token.go index 6728802..ab4dcc8 100755 --- a/pkg/permission/usecase/token.go +++ b/pkg/permission/usecase/token.go @@ -29,11 +29,11 @@ type TokenUseCase struct { } func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) { - claims, err := parseClaims(token, use.Config.Token.AccessSecret, false) + claims, err := ParseClaims(token, use.Config.Token.AccessSecret, false) if err != nil { return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{ - funcName: "parseClaims", + funcName: "ParseClaims", req: token, err: err, message: "validate token claims error", @@ -107,7 +107,7 @@ func (use *TokenUseCase) newToken(ctx context.Context, req *entity.Authorization RefreshCreateAt: now, } - tc := make(tokenClaims) + tc := make(TokenClaims) if req.Data != nil { for k, v := range req.Data { tc[k] = v @@ -116,7 +116,7 @@ func (use *TokenUseCase) newToken(ctx context.Context, req *entity.Authorization tc.SetRole(req.Role) tc.SetID(token.ID) tc.SetScope(req.Scope) - tc.SetAccount(req.Account) + tc.SetLoginID(req.Account) token.UID = tc.UID() @@ -158,7 +158,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok } // Step 2: 提取 Claims Data - claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false) + claimsData, err := ParseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false) if err != nil { return entity.RefreshTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ @@ -179,7 +179,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok Data: claimsData, Expires: req.Expires, IsRefreshToken: true, - Account: claimsData.Account(), + Account: claimsData.LoginID(), Role: claimsData.Role(), }) if err != nil { @@ -226,7 +226,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok } func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error { - claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false) + claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false) if err != nil { return use.wrapTokenError(ctx, wrapTokenErrorReq{ funcName: "CancelToken extractClaims", @@ -263,11 +263,11 @@ func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelToken } func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) { - claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, true) + claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, true) if err != nil { return entity.ValidationTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ - funcName: "parseClaims", + funcName: "ParseClaims", req: req, err: err, message: "validate token claims error", @@ -400,11 +400,11 @@ func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.Quer func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) { // 驗證Token - claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false) + claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false) if err != nil { return entity.CreateOneTimeTokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{ - funcName: "parseClaims", + funcName: "ParseClaims", req: req, err: err, message: "failed to get token claims", @@ -637,7 +637,7 @@ func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, // 為每個 token 創建黑名單條目 for _, token := range tokens { // 解析 token 獲取 JTI 和過期時間 - claims, err := parseClaims(token.AccessToken, use.Config.Token.AccessSecret, false) + claims, err := ParseClaims(token.AccessToken, use.Config.Token.AccessSecret, false) if err != nil { logx.WithContext(ctx).Errorw("failed to parse token for blacklisting", logx.Field("uid", uid), diff --git a/pkg/permission/usecase/token_claims.go b/pkg/permission/usecase/token_claims.go index f83eda8..e5efb5e 100755 --- a/pkg/permission/usecase/token_claims.go +++ b/pkg/permission/usecase/token_claims.go @@ -1,28 +1,28 @@ package usecase -type tokenClaims map[string]string +type TokenClaims map[string]string -func (tc tokenClaims) SetID(id string) { +func (tc TokenClaims) SetID(id string) { tc["id"] = id } -func (tc tokenClaims) SetRole(role string) { +func (tc TokenClaims) SetRole(role string) { tc["role"] = role } -func (tc tokenClaims) SetDeviceID(deviceID string) { +func (tc TokenClaims) SetDeviceID(deviceID string) { tc["device_id"] = deviceID } -func (tc tokenClaims) SetScope(scope string) { +func (tc TokenClaims) SetScope(scope string) { tc["scope"] = scope } -func (tc tokenClaims) SetAccount(account string) { - tc["account"] = account +func (tc TokenClaims) SetLoginID(loginID string) { + tc["login_id"] = loginID } -func (tc tokenClaims) Role() string { +func (tc TokenClaims) Role() string { role, ok := tc["role"] if !ok { return "" @@ -31,7 +31,7 @@ func (tc tokenClaims) Role() string { return role } -func (tc tokenClaims) ID() string { +func (tc TokenClaims) ID() string { id, ok := tc["id"] if !ok { return "" @@ -40,7 +40,7 @@ func (tc tokenClaims) ID() string { return id } -func (tc tokenClaims) DeviceID() string { +func (tc TokenClaims) DeviceID() string { deviceID, ok := tc["device_id"] if !ok { return "" @@ -49,7 +49,7 @@ func (tc tokenClaims) DeviceID() string { return deviceID } -func (tc tokenClaims) UID() string { +func (tc TokenClaims) UID() string { uid, ok := tc["uid"] if !ok { return "" @@ -58,7 +58,7 @@ func (tc tokenClaims) UID() string { return uid } -func (tc tokenClaims) Scope() string { +func (tc TokenClaims) Scope() string { scope, ok := tc["scope"] if !ok { return "" @@ -67,8 +67,8 @@ func (tc tokenClaims) Scope() string { return scope } -func (tc tokenClaims) Account() string { - scope, ok := tc["account"] +func (tc TokenClaims) LoginID() string { + scope, ok := tc["login_id"] if !ok { return "" } diff --git a/pkg/permission/usecase/token_claims_test.go b/pkg/permission/usecase/token_claims_test.go index 97f2ff9..645e7df 100644 --- a/pkg/permission/usecase/token_claims_test.go +++ b/pkg/permission/usecase/token_claims_test.go @@ -27,7 +27,7 @@ func TestTokenClaims_SetAndGetID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) tc.SetID(tt.id) result := tc.ID() @@ -61,7 +61,7 @@ func TestTokenClaims_SetAndGetRole(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) tc.SetRole(tt.role) result := tc.Role() @@ -91,7 +91,7 @@ func TestTokenClaims_SetAndGetDeviceID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) tc.SetDeviceID(tt.deviceID) result := tc.DeviceID() @@ -125,7 +125,7 @@ func TestTokenClaims_SetAndGetScope(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) tc.SetScope(tt.scope) // Note: there's no GetScope method, so we just verify it's set @@ -159,7 +159,7 @@ func TestTokenClaims_SetAndGetAccount(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) tc.SetAccount(tt.account) // Note: there's no GetAccount method, so we just verify it's set @@ -189,7 +189,7 @@ func TestTokenClaims_SetAndGetUID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) tc["uid"] = tt.uid result := tc.UID() @@ -199,7 +199,7 @@ func TestTokenClaims_SetAndGetUID(t *testing.T) { } func TestTokenClaims_GetNonExistentField(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) t.Run("get non-existent ID", func(t *testing.T) { result := tc.ID() @@ -223,7 +223,7 @@ func TestTokenClaims_GetNonExistentField(t *testing.T) { } func TestTokenClaims_MultipleFields(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) tc.SetID("token123") tc.SetRole("admin") @@ -243,7 +243,7 @@ func TestTokenClaims_MultipleFields(t *testing.T) { } func TestTokenClaims_Overwrite(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) t.Run("overwrite ID", func(t *testing.T) { tc.SetID("token123") @@ -263,7 +263,7 @@ func TestTokenClaims_Overwrite(t *testing.T) { } func TestTokenClaims_MapBehavior(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) t.Run("can set custom fields", func(t *testing.T) { tc["custom_field"] = "custom_value" @@ -271,7 +271,7 @@ func TestTokenClaims_MapBehavior(t *testing.T) { }) t.Run("can iterate over fields", func(t *testing.T) { - tc2 := make(tokenClaims) + tc2 := make(TokenClaims) tc2.SetID("token123") tc2.SetRole("admin") tc2["uid"] = "user123" @@ -303,7 +303,7 @@ func TestTokenClaims_MapBehavior(t *testing.T) { } func TestTokenClaims_EmptyMap(t *testing.T) { - tc := make(tokenClaims) + tc := make(TokenClaims) assert.Empty(t, tc.ID()) assert.Empty(t, tc.Role()) @@ -313,7 +313,7 @@ func TestTokenClaims_EmptyMap(t *testing.T) { } func TestTokenClaims_NilMap(t *testing.T) { - var tc tokenClaims + var tc TokenClaims t.Run("get from nil map", func(t *testing.T) { assert.Empty(t, tc.ID()) @@ -322,4 +322,3 @@ func TestTokenClaims_NilMap(t *testing.T) { assert.Empty(t, tc.UID()) }) } - diff --git a/pkg/permission/usecase/token_jwt.go b/pkg/permission/usecase/token_jwt.go index 396b8e6..d4b792a 100755 --- a/pkg/permission/usecase/token_jwt.go +++ b/pkg/permission/usecase/token_jwt.go @@ -21,6 +21,7 @@ func createAccessToken(token entity.Token, data any, secretKey string) (string, RegisteredClaims: jwt.RegisteredClaims{ ID: token.ID, ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)), + IssuedAt: jwt.NewNumericDate(time.Now()), Issuer: "permission", }, } @@ -76,10 +77,10 @@ func parseToken(accessToken string, secret string, validate bool) (jwt.MapClaims return claims, nil } -func parseClaims(accessToken string, secret string, validate bool) (tokenClaims, error) { +func ParseClaims(accessToken string, secret string, validate bool) (TokenClaims, error) { claimMap, err := parseToken(accessToken, secret, validate) if err != nil { - return tokenClaims{}, err + return TokenClaims{}, err } claimsData, ok := claimMap["data"].(map[string]any) @@ -87,7 +88,7 @@ func parseClaims(accessToken string, secret string, validate bool) (tokenClaims, return convertMap(claimsData), nil } - return tokenClaims{}, fmt.Errorf("get data from claim map error") + return TokenClaims{}, fmt.Errorf("get data from claim map error") } func convertMap(input map[string]interface{}) map[string]string { diff --git a/pkg/permission/usecase/token_jwt_test.go b/pkg/permission/usecase/token_jwt_test.go index 590d87d..ff51fe0 100644 --- a/pkg/permission/usecase/token_jwt_test.go +++ b/pkg/permission/usecase/token_jwt_test.go @@ -58,7 +58,7 @@ func TestCreateAccessToken(t *testing.T) { token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { return []byte(tt.secretKey), nil }) - + if tt.secretKey != "" { assert.NoError(t, err) assert.True(t, token.Valid) @@ -125,7 +125,7 @@ func TestCreateRefreshToken(t *testing.T) { func TestParseToken(t *testing.T) { secretKey := "test-secret-key" - + // Create a valid token first token := entity.Token{ ID: "test-id", @@ -135,7 +135,7 @@ func TestParseToken(t *testing.T) { "uid": "user123", "role": "admin", } - + validTokenStr, err := createAccessToken(token, data, secretKey) assert.NoError(t, err) @@ -192,7 +192,7 @@ func TestParseToken(t *testing.T) { } else { assert.NoError(t, err) assert.NotNil(t, claims) - + if tt.accessToken == validTokenStr { assert.Equal(t, "test-id", claims["jti"]) assert.Equal(t, "permission", claims["iss"]) @@ -204,7 +204,7 @@ func TestParseToken(t *testing.T) { func TestParseClaims(t *testing.T) { secretKey := "test-secret-key" - + // Create a valid token with data claims token := entity.Token{ ID: "test-id", @@ -215,7 +215,7 @@ func TestParseClaims(t *testing.T) { "role": "admin", "deviceId": "device456", } - + validTokenStr, err := createAccessToken(token, data, secretKey) assert.NoError(t, err) @@ -248,20 +248,20 @@ func TestParseClaims(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - claims, err := parseClaims(tt.accessToken, tt.secret, tt.validate) + claims, err := ParseClaims(tt.accessToken, tt.secret, tt.validate) if tt.wantErr { assert.Error(t, err) } else { assert.NoError(t, err) assert.NotNil(t, claims) - + if tt.expectUID != "" { uid, exists := claims["uid"] assert.True(t, exists) assert.Equal(t, tt.expectUID, uid) } - + if tt.expectRole != "" { role, exists := claims["role"] assert.True(t, exists)