test push
This commit is contained in:
parent
ef9b218f3b
commit
d71ffea750
|
|
@ -51,3 +51,11 @@ Token:
|
|||
OneTimeTokenExpiry : 600s
|
||||
MaxTokensPerUser : 2
|
||||
MaxTokensPerDevice : 2
|
||||
|
||||
|
||||
RoleConfig:
|
||||
UIDPrefix: "AM"
|
||||
UIDLength: 6
|
||||
AdminRoleUID: "AM000000"
|
||||
AdminUserUID: "B000000"
|
||||
DefaultRoleName: "USER"
|
||||
|
|
@ -3,11 +3,7 @@ syntax = "v1"
|
|||
// ================ 通用響應 ================
|
||||
type (
|
||||
// 成功響應
|
||||
RespOK {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
RespOK {}
|
||||
|
||||
// 分頁響應
|
||||
PagerResp {
|
||||
|
|
@ -29,4 +25,9 @@ type (
|
|||
Authorization {
|
||||
Authorization string `header:"Authorization" validate:"required"`
|
||||
}
|
||||
Status {
|
||||
Code int64 `json:"code"` // 狀態碼
|
||||
Message string `json:"message"` // 訊息
|
||||
Data interface{} `json:"data,omitempty"` // 可選的資料,當有返回時才出現
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ type (
|
|||
|
||||
// RequestPasswordResetReq 請求發送「忘記密碼」的驗證碼
|
||||
RequestPasswordResetReq {
|
||||
Identifier string `json:"identifier" validate:"required,email|phone"` // 使用者帳號 (信箱或手機)
|
||||
Identifier string `json:"identifier" validate:"required"` // 使用者帳號 (信箱或手機)
|
||||
AccountType string `json:"account_type" validate:"required,oneof=email phone"`
|
||||
}
|
||||
|
||||
|
|
@ -141,6 +141,31 @@ type (
|
|||
VerifyCode string `json:"verify_code" validate:"required,len=6"`
|
||||
Authorization
|
||||
}
|
||||
|
||||
// MyInfo 用於獲取會員資訊的標準響應結構
|
||||
MyInfo {
|
||||
Platform string `json:"platform"` // 註冊平台
|
||||
UID string `json:"uid"` // 用戶 UID
|
||||
AvatarURL *string `json:"avatar_url,omitempty"` // 頭像 URL
|
||||
FullName *string `json:"full_name,omitempty"` // 用戶全名
|
||||
Nickname *string `json:"nickname,omitempty"` // 暱稱
|
||||
GenderCode *string `json:"gender_code,omitempty"` // 性別代碼
|
||||
Birthdate *string `json:"birthdate,omitempty"` // 生日 (格式: 1993-04-17)
|
||||
PhoneNumber *string `json:"phone_number,omitempty"` // 電話
|
||||
IsPhoneVerified *bool `json:"is_phone_verified,omitempty"` // 手機是否已驗證
|
||||
Email *string `json:"email,omitempty"` // 信箱
|
||||
IsEmailVerified *bool `json:"is_email_verified,omitempty"` // 信箱是否已驗證
|
||||
Address *string `json:"address,omitempty"` // 地址
|
||||
UserStatus string `json:"user_status,omitempty"` // 用戶狀態
|
||||
PreferredLanguage string `json:"preferred_language,omitempty"` // 偏好語言
|
||||
Currency string `json:"currency,omitempty"` // 偏好幣種
|
||||
AlarmCategory string `json:"alarm_category,omitempty"` // 告警狀態
|
||||
PostCode *string `json:"post_code,omitempty"` // 郵遞區號
|
||||
Carrier *string `json:"carrier,omitempty"` // 載具
|
||||
Role string `json:"role"` // 角色
|
||||
UpdateAt string `json:"update_at"`
|
||||
CreateAt string `json:"create_at"`
|
||||
}
|
||||
)
|
||||
|
||||
// =================================================================
|
||||
|
|
@ -251,7 +276,7 @@ service gateway {
|
|||
@respdoc-500 (ErrorResp) // 伺服器內部錯誤
|
||||
*/
|
||||
@handler getUserInfo
|
||||
get /me (Authorization) returns (UserInfoResp)
|
||||
get /me (Authorization) returns (MyInfo)
|
||||
|
||||
@doc(
|
||||
summary: "更新當前登入的會員資訊"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,4 @@
|
|||
db.role.deleteMany({
|
||||
"uid": { "$in": ["ADMIN", "OPERATOR", "USER"] }
|
||||
});
|
||||
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
db.role.insertMany([
|
||||
{
|
||||
"client_id": 1,
|
||||
"uid": "ADMIN",
|
||||
"name": "管理員",
|
||||
"status": 1,
|
||||
"create_time": NumberLong(1728745200),
|
||||
"update_time": NumberLong(1728745200)
|
||||
},
|
||||
{
|
||||
"client_id": 1,
|
||||
"uid": "OPERATOR",
|
||||
"name": "操作員",
|
||||
"status": 1,
|
||||
"create_time": NumberLong(1728745200),
|
||||
"update_time": NumberLong(1728745200)
|
||||
},
|
||||
{
|
||||
"client_id": 1,
|
||||
"uid": "USER",
|
||||
"name": "一般使用者",
|
||||
"status": 1,
|
||||
"create_time": NumberLong(1728745200),
|
||||
"update_time": NumberLong(1728745200)
|
||||
}
|
||||
]);
|
||||
|
||||
// 建立索引
|
||||
db.role.createIndex({ "uid": 1 }, { unique: true });
|
||||
db.role.createIndex({ "client_id": 1 });
|
||||
db.role.createIndex({ "status": 1 });
|
||||
|
||||
|
|
@ -52,12 +52,30 @@ type Config struct {
|
|||
|
||||
// JWT Token 配置
|
||||
Token struct {
|
||||
AccessSecret string
|
||||
RefreshSecret string
|
||||
AccessTokenExpiry time.Duration
|
||||
RefreshTokenExpiry time.Duration
|
||||
OneTimeTokenExpiry time.Duration
|
||||
MaxTokensPerUser int
|
||||
MaxTokensPerDevice int
|
||||
AccessSecret string
|
||||
RefreshSecret string
|
||||
AccessTokenExpiry time.Duration
|
||||
RefreshTokenExpiry time.Duration
|
||||
OneTimeTokenExpiry time.Duration
|
||||
MaxTokensPerUser int
|
||||
MaxTokensPerDevice int
|
||||
}
|
||||
|
||||
// RoleConfig 角色配置
|
||||
RoleConfig struct {
|
||||
// UID 前綴 (例如: AM, RL)
|
||||
UIDPrefix string
|
||||
|
||||
// UID 數字長度
|
||||
UIDLength int
|
||||
|
||||
// 管理員角色 UID
|
||||
AdminRoleUID string
|
||||
|
||||
// 管理員用戶 UID
|
||||
AdminUserUID string
|
||||
|
||||
// 預設角色名稱
|
||||
DefaultRoleName string
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
package domain
|
||||
|
||||
import "strings"
|
||||
|
||||
type RedisKey string
|
||||
|
||||
const (
|
||||
GenerateVerifyCodeRedisKey RedisKey = "rf_code"
|
||||
)
|
||||
|
||||
func (key RedisKey) ToString() string {
|
||||
return string(key)
|
||||
}
|
||||
|
||||
func (key RedisKey) With(s ...string) RedisKey {
|
||||
parts := append([]string{string(key)}, s...)
|
||||
|
||||
return RedisKey(strings.Join(parts, ":"))
|
||||
}
|
||||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"backend/internal/svc"
|
||||
"backend/internal/types"
|
||||
"backend/pkg/library/errs"
|
||||
ers "backend/pkg/library/errs"
|
||||
"net/http"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
|
|
@ -18,38 +17,39 @@ func LoginHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||
var req types.LoginReq
|
||||
if err := httpx.Parse(r, &req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: err.Error(),
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
//if err := svcCtx.Validate.ValidateAll(req); err != nil {
|
||||
// e := errs.InvalidFormat(err.Error())
|
||||
// httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{
|
||||
// Code: int(e.FullCode()),
|
||||
// Msg: err.Error(),
|
||||
// })
|
||||
//
|
||||
// return
|
||||
//}
|
||||
if err := svcCtx.Validate.ValidateAll(req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l := auth.NewLoginLogic(r.Context(), svcCtx)
|
||||
resp, err := l.Login(&req)
|
||||
|
||||
if err != nil {
|
||||
e := ers.FromError(err)
|
||||
e := errs.FromError(err)
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: e.Error(),
|
||||
Error: e,
|
||||
})
|
||||
} else {
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.RespOK{
|
||||
Code: domain.SuccessCode,
|
||||
Msg: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
|
||||
Code: domain.SuccessCode,
|
||||
Message: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"backend/internal/domain"
|
||||
"backend/pkg/library/errs"
|
||||
"net/http"
|
||||
|
||||
"backend/internal/logic/auth"
|
||||
|
|
@ -15,16 +17,40 @@ func RefreshTokenHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req types.RefreshTokenReq
|
||||
if err := httpx.Parse(r, &req); err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := svcCtx.Validate.ValidateAll(req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l := auth.NewRefreshTokenLogic(r.Context(), svcCtx)
|
||||
resp, err := l.RefreshToken(&req)
|
||||
if err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.FromError(err)
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: e.Error(),
|
||||
Error: e,
|
||||
})
|
||||
} else {
|
||||
httpx.OkJsonCtx(r.Context(), w, resp)
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
|
||||
Code: domain.SuccessCode,
|
||||
Message: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,15 +12,14 @@ import (
|
|||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
)
|
||||
|
||||
// 註冊新帳號
|
||||
func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req types.LoginReq
|
||||
if err := httpx.Parse(r, &req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: err.Error(),
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
|
|
@ -28,9 +27,9 @@ func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||
|
||||
if err := svcCtx.Validate.ValidateAll(req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: err.Error(),
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
|
|
@ -46,10 +45,10 @@ func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||
Error: e,
|
||||
})
|
||||
} else {
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.RespOK{
|
||||
Code: domain.SuccessCode,
|
||||
Msg: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
|
||||
Code: domain.SuccessCode,
|
||||
Message: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"backend/internal/domain"
|
||||
"backend/pkg/library/errs"
|
||||
"net/http"
|
||||
|
||||
"backend/internal/logic/auth"
|
||||
|
|
@ -15,16 +17,40 @@ func RequestPasswordResetHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req types.RequestPasswordResetReq
|
||||
if err := httpx.Parse(r, &req); err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := svcCtx.Validate.ValidateAll(req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l := auth.NewRequestPasswordResetLogic(r.Context(), svcCtx)
|
||||
resp, err := l.RequestPasswordReset(&req)
|
||||
if err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.FromError(err)
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: e.Error(),
|
||||
Error: e,
|
||||
})
|
||||
} else {
|
||||
httpx.OkJsonCtx(r.Context(), w, resp)
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
|
||||
Code: domain.SuccessCode,
|
||||
Message: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"backend/internal/domain"
|
||||
"backend/pkg/library/errs"
|
||||
"net/http"
|
||||
|
||||
"backend/internal/logic/auth"
|
||||
|
|
@ -15,16 +17,40 @@ func ResetPasswordHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req types.ResetPasswordReq
|
||||
if err := httpx.Parse(r, &req); err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := svcCtx.Validate.ValidateAll(req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l := auth.NewResetPasswordLogic(r.Context(), svcCtx)
|
||||
resp, err := l.ResetPassword(&req)
|
||||
if err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.FromError(err)
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: e.Error(),
|
||||
Error: e,
|
||||
})
|
||||
} else {
|
||||
httpx.OkJsonCtx(r.Context(), w, resp)
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
|
||||
Code: domain.SuccessCode,
|
||||
Message: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"backend/internal/domain"
|
||||
"backend/pkg/library/errs"
|
||||
"net/http"
|
||||
|
||||
"backend/internal/logic/auth"
|
||||
|
|
@ -15,16 +17,40 @@ func VerifyPasswordResetCodeHandler(svcCtx *svc.ServiceContext) http.HandlerFunc
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req types.VerifyCodeReq
|
||||
if err := httpx.Parse(r, &req); err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := svcCtx.Validate.ValidateAll(req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l := auth.NewVerifyPasswordResetCodeLogic(r.Context(), svcCtx)
|
||||
resp, err := l.VerifyPasswordResetCode(&req)
|
||||
if err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.FromError(err)
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: e.Error(),
|
||||
Error: e,
|
||||
})
|
||||
} else {
|
||||
httpx.OkJsonCtx(r.Context(), w, resp)
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
|
||||
Code: domain.SuccessCode,
|
||||
Message: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,13 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"backend/internal/domain"
|
||||
"backend/pkg/library/errs"
|
||||
"net/http"
|
||||
|
||||
"backend/internal/logic/user"
|
||||
"backend/internal/svc"
|
||||
"backend/internal/types"
|
||||
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
)
|
||||
|
||||
|
|
@ -15,16 +16,40 @@ func GetUserInfoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
|
|||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req types.Authorization
|
||||
if err := httpx.Parse(r, &req); err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := svcCtx.Validate.ValidateAll(req); err != nil {
|
||||
e := errs.InvalidFormat(err.Error())
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
|
||||
Code: int64(e.FullCode()),
|
||||
Message: err.Error(),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
l := user.NewGetUserInfoLogic(r.Context(), svcCtx)
|
||||
resp, err := l.GetUserInfo(&req)
|
||||
if err != nil {
|
||||
httpx.ErrorCtx(r.Context(), w, err)
|
||||
e := errs.FromError(err)
|
||||
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
|
||||
Code: int(e.FullCode()),
|
||||
Msg: e.Error(),
|
||||
Error: e,
|
||||
})
|
||||
} else {
|
||||
httpx.OkJsonCtx(r.Context(), w, resp)
|
||||
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
|
||||
Code: domain.SuccessCode,
|
||||
Message: domain.SuccessMessage,
|
||||
Data: resp,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,10 +10,7 @@ import (
|
|||
)
|
||||
|
||||
// 生成 Token
|
||||
func generateToken(svc *svc.ServiceContext, ctx context.Context, req *types.LoginReq, uid string) (entity.TokenResp, error) {
|
||||
// scope role 要修改,refresh tl
|
||||
role := "user"
|
||||
|
||||
func generateToken(svc *svc.ServiceContext, ctx context.Context, req *types.LoginReq, uid string, role string) (entity.TokenResp, error) {
|
||||
tk, err := svc.TokenUC.NewToken(ctx, entity.AuthorizationReq{
|
||||
GrantType: token.ClientCredentials.ToString(),
|
||||
DeviceID: uid, // TODO 沒傳暫時先用UID 替代
|
||||
|
|
|
|||
|
|
@ -79,7 +79,12 @@ func (l *LoginLogic) Login(req *types.LoginReq) (resp *types.LoginResp, err erro
|
|||
return nil, err
|
||||
}
|
||||
|
||||
tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID)
|
||||
userRole, err := l.svcCtx.UserRoleUC.Get(l.ctx, account.UID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID, userRole.RoleUID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ import (
|
|||
"backend/pkg/library/errs/code"
|
||||
mb "backend/pkg/member/domain/member"
|
||||
member "backend/pkg/member/domain/usecase"
|
||||
"backend/pkg/permission/domain/usecase"
|
||||
"context"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
|
|
@ -76,9 +77,18 @@ func (l *RegisterLogic) Register(req *types.LoginReq) (resp *types.LoginResp, er
|
|||
return nil, err
|
||||
}
|
||||
|
||||
_, err = l.svcCtx.UserRoleUC.Assign(l.ctx, usecase.AssignRoleRequest{
|
||||
RoleUID: l.svcCtx.Config.RoleConfig.DefaultRoleName,
|
||||
UserUID: account.UID,
|
||||
Brand: "digimon",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Step 5: 生成 Token
|
||||
req.LoginID = bd.CreateAccountReq.LoginID
|
||||
tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID)
|
||||
tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID, l.svcCtx.Config.RoleConfig.DefaultRoleName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,14 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"backend/internal/domain"
|
||||
"backend/internal/utils"
|
||||
"backend/pkg/library/errs"
|
||||
"backend/pkg/library/errs/code"
|
||||
"backend/pkg/member/domain/member"
|
||||
"backend/pkg/member/domain/usecase"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"backend/internal/svc"
|
||||
"backend/internal/types"
|
||||
|
|
@ -25,7 +32,109 @@ func NewRequestPasswordResetLogic(ctx context.Context, svcCtx *svc.ServiceContex
|
|||
|
||||
// RequestPasswordReset 請求發送密碼重設驗證碼 aka 忘記密碼
|
||||
func (l *RequestPasswordResetLogic) RequestPasswordReset(req *types.RequestPasswordResetReq) (resp *types.RespOK, err error) {
|
||||
// todo: add your logic here and delete this line
|
||||
// 驗證並標準化帳號
|
||||
acc, err := l.validateAndNormalizeAccount(req.AccountType, req.Identifier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
// 檢查發送冷卻時間
|
||||
rk := domain.GenerateVerifyCodeRedisKey.With(fmt.Sprintf("%s:%d", acc, member.GenerateCodeTypeForgetPassword)).ToString()
|
||||
if err := l.checkVerifyCodeCooldown(rk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 確認帳號是否註冊並檢查平台限制
|
||||
if err := l.checkAccountAndPlatform(acc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 生成驗證碼
|
||||
vcode, err := l.svcCtx.AccountUC.GenerateRefreshCode(l.ctx, usecase.GenerateRefreshCodeRequest{
|
||||
LoginID: acc,
|
||||
CodeType: member.GenerateCodeTypeForgetPassword,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 獲取用戶資訊並確認綁定帳號
|
||||
account, err := l.svcCtx.AccountUC.GetUIDByAccount(l.ctx, usecase.GetUIDByAccountRequest{Account: acc})
|
||||
if err != nil {
|
||||
return nil, errs.ResourceNotFoundWithScope(code.CloudEPMember, 0, fmt.Sprintf("account not found:%s", acc))
|
||||
}
|
||||
info, err := l.svcCtx.AccountUC.GetUserInfo(l.ctx, usecase.GetUserInfoRequest{UID: account.UID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 發送驗證碼
|
||||
fmt.Println("======= send", vcode.Data.VerifyCode, &info)
|
||||
|
||||
//nickname := getEmailShowName(&info)
|
||||
//if err := l.sendVerificationCode(req.AccountType, acc, &info, vcode.Data.VerifyCode, nickname); err != nil {
|
||||
// return nil, err
|
||||
//}
|
||||
|
||||
// 設置 Redis 鍵
|
||||
l.setRedisKeyWithExpiry(rk, vcode.Data.VerifyCode, 60)
|
||||
|
||||
return &types.RespOK{}, nil
|
||||
}
|
||||
|
||||
// validateAndNormalizeAccount 驗證並標準化帳號
|
||||
func (l *RequestPasswordResetLogic) validateAndNormalizeAccount(accountType, account string) (string, error) {
|
||||
switch member.GetAccountTypeByCode(accountType) {
|
||||
case member.AccountTypePhone:
|
||||
phone, isPhone := utils.NormalizeTaiwanMobile(account)
|
||||
if !isPhone {
|
||||
return "", errs.InvalidFormatWithScope(code.CloudEPMember, "phone number is invalid")
|
||||
}
|
||||
|
||||
return phone, nil
|
||||
case member.AccountTypeMail:
|
||||
if !utils.IsValidEmail(account) {
|
||||
return "", errs.InvalidFormatWithScope(code.CloudEPMember, "email is invalid")
|
||||
}
|
||||
|
||||
return account, nil
|
||||
case member.AccountTypeNone, member.AccountTypeDefine:
|
||||
default:
|
||||
}
|
||||
|
||||
return "", errs.InvalidFormatWithScope(code.CloudEPMember, "unsupported account type")
|
||||
}
|
||||
|
||||
// checkVerifyCodeCooldown 檢查是否已在限制時間內發送過驗證碼
|
||||
func (l *RequestPasswordResetLogic) checkVerifyCodeCooldown(rk string) error {
|
||||
if cachedCode, err := l.svcCtx.Redis.GetCtx(l.ctx, rk); err != nil || cachedCode != "" {
|
||||
return errs.TooManyWithScope(code.CloudEPMember, "verification code already sent, please wait 3min for system to send again")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAccountAndPlatform 檢查帳號是否註冊及平台限制
|
||||
func (l *RequestPasswordResetLogic) checkAccountAndPlatform(acc string) error {
|
||||
accountInfo, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{Account: acc})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if accountInfo.Data.Platform != member.Digimon {
|
||||
return errs.InvalidFormatWithScope(code.CloudEPMember,
|
||||
"failed to send verify code since platform not correct")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setRedisKeyWithExpiry 設置 Redis 鍵
|
||||
func (l *RequestPasswordResetLogic) setRedisKeyWithExpiry(rk, verifyCode string, expiry int) {
|
||||
if status, err := l.svcCtx.Redis.SetnxExCtx(l.ctx, rk, verifyCode, expiry); err != nil || !status {
|
||||
_ = errs.DatabaseErrorWithScopeL(code.CloudEPMember, 0, logx.WithContext(l.ctx), []logx.LogField{
|
||||
{Key: "redisKey", Value: rk},
|
||||
{Key: "error", Value: err.Error()},
|
||||
}, "failed to set redis expire").Wrap(err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,15 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"backend/internal/domain"
|
||||
"backend/pkg/library/errs"
|
||||
"backend/pkg/library/errs/code"
|
||||
"backend/pkg/member/domain/member"
|
||||
"backend/pkg/member/domain/usecase"
|
||||
"backend/pkg/permission/domain/entity"
|
||||
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"backend/internal/svc"
|
||||
"backend/internal/types"
|
||||
|
|
@ -15,7 +23,7 @@ type ResetPasswordLogic struct {
|
|||
svcCtx *svc.ServiceContext
|
||||
}
|
||||
|
||||
// 執行密碼重設
|
||||
// NewResetPasswordLogic 執行密碼重設
|
||||
func NewResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ResetPasswordLogic {
|
||||
return &ResetPasswordLogic{
|
||||
Logger: logx.WithContext(ctx),
|
||||
|
|
@ -24,8 +32,58 @@ func NewResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Res
|
|||
}
|
||||
}
|
||||
|
||||
func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordReq) (resp *types.RespOK, err error) {
|
||||
// todo: add your logic here and delete this line
|
||||
func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordReq) (*types.RespOK, error) {
|
||||
// 驗證密碼,兩次密碼要一致
|
||||
if req.Password != req.PasswordConfirm {
|
||||
return nil, errs.InvalidFormatWithScope(code.CloudEPMember, "password confirmation does not match")
|
||||
}
|
||||
|
||||
return
|
||||
// 驗證碼
|
||||
err := l.svcCtx.AccountUC.VerifyRefreshCode(l.ctx, usecase.VerifyRefreshCodeRequest{
|
||||
LoginID: req.Identifier,
|
||||
CodeType: member.GenerateCodeTypeForgetPassword,
|
||||
VerifyCode: req.VerifyCode,
|
||||
})
|
||||
if err != nil {
|
||||
// 表使沒有這驗證碼
|
||||
return nil, errs.ForbiddenWithScope(code.CloudEPMember, 0, "failed to get verify code")
|
||||
}
|
||||
|
||||
info, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{Account: req.Identifier})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if info.Data.Platform != member.Digimon {
|
||||
return nil, errs.ForbiddenWithScope(code.CloudEPMember, 0, "invalid platform")
|
||||
}
|
||||
|
||||
// 更新
|
||||
err = l.svcCtx.AccountUC.UpdateUserToken(l.ctx, usecase.UpdateTokenRequest{
|
||||
Account: req.Identifier,
|
||||
Token: req.Password,
|
||||
Platform: member.Digimon.ToInt64(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rk := domain.GenerateVerifyCodeRedisKey.With(
|
||||
fmt.Sprintf("%s-%d", req.Identifier, member.GenerateCodeTypeForgetPassword),
|
||||
).ToString()
|
||||
|
||||
_, _ = l.svcCtx.Redis.Del(rk)
|
||||
|
||||
ac, err := l.svcCtx.AccountUC.GetUIDByAccount(l.ctx, usecase.GetUIDByAccountRequest{Account: req.Identifier})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = l.svcCtx.TokenUC.CancelTokens(l.ctx, entity.DoTokenByUIDReq{UID: ac.UID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 返回成功響應
|
||||
return &types.RespOK{}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,9 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"backend/pkg/library/errs"
|
||||
"backend/pkg/member/domain/member"
|
||||
"backend/pkg/member/domain/usecase"
|
||||
"context"
|
||||
|
||||
"backend/internal/svc"
|
||||
|
|
@ -25,7 +28,16 @@ func NewVerifyPasswordResetCodeLogic(ctx context.Context, svcCtx *svc.ServiceCon
|
|||
|
||||
// VerifyPasswordResetCode 校驗密碼重設驗證碼(頁面需求,預先檢查看看, 顯示表演用)
|
||||
func (l *VerifyPasswordResetCodeLogic) VerifyPasswordResetCode(req *types.VerifyCodeReq) (resp *types.RespOK, err error) {
|
||||
// todo: add your logic here and delete this line
|
||||
// 先驗證,不刪除
|
||||
if err := l.svcCtx.AccountUC.CheckRefreshCode(l.ctx, usecase.VerifyRefreshCodeRequest{
|
||||
VerifyCode: req.VerifyCode,
|
||||
LoginID: req.Identifier,
|
||||
CodeType: member.GenerateCodeTypeForgetPassword,
|
||||
}); err != nil {
|
||||
e := errs.Forbidden("failed to get verify code").Wrap(err)
|
||||
|
||||
return
|
||||
return nil, e
|
||||
}
|
||||
|
||||
return &types.RespOK{}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,12 @@
|
|||
package user
|
||||
|
||||
import (
|
||||
"backend/pkg/member/domain/member"
|
||||
"backend/pkg/member/domain/usecase"
|
||||
"backend/pkg/permission/domain/token"
|
||||
"context"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"time"
|
||||
|
||||
"backend/internal/svc"
|
||||
"backend/internal/types"
|
||||
|
|
@ -15,7 +20,7 @@ type GetUserInfoLogic struct {
|
|||
svcCtx *svc.ServiceContext
|
||||
}
|
||||
|
||||
// 取得當前登入的會員資訊(自己)
|
||||
// NewGetUserInfoLogic 取得當前登入的會員資訊(自己)
|
||||
func NewGetUserInfoLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserInfoLogic {
|
||||
return &GetUserInfoLogic{
|
||||
Logger: logx.WithContext(ctx),
|
||||
|
|
@ -24,8 +29,88 @@ func NewGetUserInfoLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUs
|
|||
}
|
||||
}
|
||||
|
||||
func (l *GetUserInfoLogic) GetUserInfo(req *types.Authorization) (resp *types.UserInfoResp, err error) {
|
||||
// todo: add your logic here and delete this line
|
||||
func (l *GetUserInfoLogic) GetUserInfo(req *types.Authorization) (*types.MyInfo, error) {
|
||||
uid := token.UID(l.ctx)
|
||||
info, err := l.svcCtx.AccountUC.GetUserInfo(l.ctx, usecase.GetUserInfoRequest{
|
||||
UID: uid,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
byUID, err := l.svcCtx.AccountUC.FindLoginIDByUID(l.ctx, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountInfo, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{
|
||||
Account: byUID.LoginID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userRole, err := l.svcCtx.UserRoleUC.Get(l.ctx, uid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
role := userRole.RoleUID
|
||||
res := &types.MyInfo{
|
||||
Platform: accountInfo.Data.Platform.ToString(),
|
||||
UID: info.UID,
|
||||
UpdateAt: time.Unix(0, info.CreateTime).UTC().Format(time.RFC3339),
|
||||
CreateAt: time.Unix(0, info.UpdateTime).UTC().Format(time.RFC3339),
|
||||
Role: role,
|
||||
UserStatus: info.UserStatus.CodeToString(),
|
||||
PreferredLanguage: info.PreferredLanguage,
|
||||
Currency: info.Currency,
|
||||
AlarmCategory: info.AlarmCategory.CodeToString(),
|
||||
}
|
||||
if info.Address != nil {
|
||||
res.Address = info.Address
|
||||
}
|
||||
if info.AvatarURL != nil {
|
||||
res.AvatarURL = info.AvatarURL
|
||||
}
|
||||
if info.FullName != nil {
|
||||
res.FullName = info.FullName
|
||||
}
|
||||
|
||||
if info.Birthdate != nil {
|
||||
b := ToDate(info.Birthdate)
|
||||
res.Birthdate = b
|
||||
}
|
||||
|
||||
if info.Address != nil {
|
||||
res.Address = info.Address
|
||||
}
|
||||
|
||||
if info.Nickname != nil {
|
||||
res.Nickname = info.Nickname
|
||||
}
|
||||
|
||||
if info.Email != nil {
|
||||
res.Email = info.Email
|
||||
res.IsEmailVerified = proto.Bool(true)
|
||||
}
|
||||
|
||||
if info.PhoneNumber != nil {
|
||||
res.PhoneNumber = info.PhoneNumber
|
||||
res.IsPhoneVerified = proto.Bool(true)
|
||||
}
|
||||
if info.GenderCode != nil {
|
||||
gc := member.GetGenderByCode(*info.GenderCode)
|
||||
res.GenderCode = &gc
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func ToDate(n *int64) *string {
|
||||
result := ""
|
||||
if n != nil {
|
||||
result = time.Unix(*n, 0).UTC().Format(time.DateOnly)
|
||||
}
|
||||
|
||||
return &result
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,19 +1,85 @@
|
|||
package middleware
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"backend/internal/types"
|
||||
"backend/pkg/library/errs"
|
||||
"backend/pkg/permission/domain/entity"
|
||||
"backend/pkg/permission/domain/token"
|
||||
"context"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
|
||||
type AuthMiddleware struct {
|
||||
"backend/pkg/permission/domain/usecase"
|
||||
uc "backend/pkg/permission/usecase"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type AuthMiddlewareParam struct {
|
||||
TokenSec string
|
||||
TokenUseCase usecase.TokenUseCase
|
||||
}
|
||||
|
||||
func NewAuthMiddleware() *AuthMiddleware {
|
||||
return &AuthMiddleware{}
|
||||
type AuthMiddleware struct {
|
||||
TokenSec string
|
||||
TokenUseCase usecase.TokenUseCase
|
||||
}
|
||||
|
||||
func NewAuthMiddleware(param AuthMiddlewareParam) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
TokenSec: param.TokenSec,
|
||||
TokenUseCase: param.TokenUseCase,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// TODO generate middleware implement function, delete after code implementation
|
||||
// 解析 Header
|
||||
header := types.Authorization{}
|
||||
if err := httpx.ParseHeaders(r, &header); err != nil {
|
||||
m.writeErrorResponse(w, r, http.StatusBadRequest, "Failed to parse headers", int64(errs.InvalidFormat("").FullCode()))
|
||||
|
||||
// Passthrough to next handler if need
|
||||
next(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// 驗證 Token
|
||||
claim, err := uc.ParseClaims(header.Authorization, m.TokenSec, true)
|
||||
if err != nil {
|
||||
// 是否需要紀錄錯誤,是不是只要紀錄除了驗證失敗或過期之外的真錯誤
|
||||
m.writeErrorResponse(w, r,
|
||||
http.StatusUnauthorized, "failed to verify toke",
|
||||
int64(100400))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 驗證 Token 是否在黑名單中
|
||||
if _, err := m.TokenUseCase.ValidationToken(r.Context(), entity.ValidationTokenReq{Token: header.Authorization}); err != nil {
|
||||
m.writeErrorResponse(w, r, http.StatusForbidden,
|
||||
"failed to get toke",
|
||||
int64(100400))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 設置 context 並傳遞給下一個處理器
|
||||
ctx := SetContext(r, claim)
|
||||
next(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func SetContext(r *http.Request, claim uc.TokenClaims) context.Context {
|
||||
ctx := context.WithValue(r.Context(), token.KeyRole, claim.Role())
|
||||
ctx = context.WithValue(ctx, token.KeyUID, claim.UID())
|
||||
ctx = context.WithValue(ctx, token.KeyDeviceID, claim.DeviceID())
|
||||
ctx = context.WithValue(ctx, token.KeyScope, claim.Scope())
|
||||
ctx = context.WithValue(ctx, token.KeyLoginID, claim.LoginID())
|
||||
|
||||
return ctx
|
||||
}
|
||||
|
||||
// writeErrorResponse 用於處理錯誤回應
|
||||
func (m *AuthMiddleware) writeErrorResponse(w http.ResponseWriter, r *http.Request, statusCode int, message string, code int64) {
|
||||
httpx.WriteJsonCtx(r.Context(), w, statusCode, types.ErrorResp{
|
||||
Code: int(code),
|
||||
Msg: message,
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,11 @@ type ServiceContext struct {
|
|||
AccountUC memberUC.AccountUseCase
|
||||
Validate vi.Validate
|
||||
TokenUC tokenUC.TokenUseCase
|
||||
PermissionUC tokenUC.PermissionUseCase
|
||||
RoleUC tokenUC.RoleUseCase
|
||||
RolePermission tokenUC.RolePermissionUseCase
|
||||
UserRoleUC tokenUC.UserRoleUseCase
|
||||
Redis *redis.Redis
|
||||
}
|
||||
|
||||
func NewServiceContext(c config.Config) *ServiceContext {
|
||||
|
|
@ -28,11 +33,22 @@ func NewServiceContext(c config.Config) *ServiceContext {
|
|||
}
|
||||
errs.Scope = code.CloudEPPortalGW
|
||||
|
||||
rp := NewPermissionUC(&c)
|
||||
tkUC := NewTokenUC(&c, rds)
|
||||
|
||||
return &ServiceContext{
|
||||
Config: c,
|
||||
AuthMiddleware: middleware.NewAuthMiddleware().Handle,
|
||||
Config: c,
|
||||
AuthMiddleware: middleware.NewAuthMiddleware(middleware.AuthMiddlewareParam{
|
||||
TokenSec: c.Token.AccessSecret,
|
||||
TokenUseCase: tkUC,
|
||||
}).Handle,
|
||||
AccountUC: NewAccountUC(&c, rds),
|
||||
Validate: vi.MustValidator(),
|
||||
TokenUC: NewTokenUC(&c, rds),
|
||||
TokenUC: tkUC,
|
||||
PermissionUC: rp.PermissionUC,
|
||||
RoleUC: rp.RoleUC,
|
||||
RolePermission: rp.RolePermission,
|
||||
UserRoleUC: rp.UserRole,
|
||||
Redis: rds,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,12 @@ package svc
|
|||
|
||||
import (
|
||||
"backend/internal/config"
|
||||
mgo "backend/pkg/library/mongo"
|
||||
"backend/pkg/permission/domain/usecase"
|
||||
"backend/pkg/permission/repository"
|
||||
uc "backend/pkg/permission/usecase"
|
||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
|
|
@ -16,3 +19,102 @@ func NewTokenUC(c *config.Config, rds *redis.Redis) usecase.TokenUseCase {
|
|||
Config: c,
|
||||
})
|
||||
}
|
||||
|
||||
type PermissionUC struct {
|
||||
PermissionUC usecase.PermissionUseCase
|
||||
RoleUC usecase.RoleUseCase
|
||||
RolePermission usecase.RolePermissionUseCase
|
||||
UserRole usecase.UserRoleUseCase
|
||||
}
|
||||
|
||||
func NewPermissionUC(c *config.Config) PermissionUC {
|
||||
// 準備Mongo Config
|
||||
conf := &mgo.Conf{
|
||||
Schema: c.Mongo.Schema,
|
||||
Host: c.Mongo.Host,
|
||||
Database: c.Mongo.Database,
|
||||
MaxStaleness: c.Mongo.MaxStaleness,
|
||||
MaxPoolSize: c.Mongo.MaxPoolSize,
|
||||
MinPoolSize: c.Mongo.MinPoolSize,
|
||||
MaxConnIdleTime: c.Mongo.MaxConnIdleTime,
|
||||
Compressors: c.Mongo.Compressors,
|
||||
EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode,
|
||||
ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs,
|
||||
}
|
||||
if c.Mongo.User != "" {
|
||||
conf.User = c.Mongo.User
|
||||
conf.Password = c.Mongo.Password
|
||||
}
|
||||
|
||||
// 快取選項
|
||||
cacheOpts := []cache.Option{
|
||||
cache.WithExpiry(c.CacheExpireTime),
|
||||
cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry),
|
||||
}
|
||||
dbOpts := []mon.Option{
|
||||
mgo.SetCustomDecimalType(),
|
||||
mgo.InitMongoOptions(*conf),
|
||||
}
|
||||
permRepo := repository.NewPermissionRepository(repository.PermissionRepositoryParam{
|
||||
Conf: conf,
|
||||
CacheConf: c.Cache,
|
||||
CacheOpts: cacheOpts,
|
||||
DBOpts: dbOpts,
|
||||
})
|
||||
|
||||
rolePermRepo := repository.NewRolePermissionRepository(repository.RolePermissionRepositoryParam{
|
||||
Conf: conf,
|
||||
CacheConf: c.Cache,
|
||||
CacheOpts: cacheOpts,
|
||||
DBOpts: dbOpts,
|
||||
})
|
||||
|
||||
roleRepo := repository.NewRoleRepository(repository.RoleRepositoryParam{
|
||||
Conf: conf,
|
||||
CacheConf: c.Cache,
|
||||
CacheOpts: cacheOpts,
|
||||
DBOpts: dbOpts,
|
||||
})
|
||||
|
||||
userRoleRepo := repository.NewUserRoleRepository(repository.UserRoleRepositoryParam{
|
||||
Conf: conf,
|
||||
CacheConf: c.Cache,
|
||||
CacheOpts: cacheOpts,
|
||||
DBOpts: dbOpts,
|
||||
})
|
||||
|
||||
puc := uc.NewPermissionUseCase(uc.PermissionUseCaseParam{
|
||||
RoleRepo: roleRepo,
|
||||
RolePermRepo: rolePermRepo,
|
||||
UserRoleRepo: userRoleRepo,
|
||||
PermRepo: permRepo,
|
||||
})
|
||||
rpuc := uc.NewRolePermissionUseCase(uc.RolePermissionUseCaseParam{
|
||||
RoleRepo: roleRepo,
|
||||
RolePermRepo: rolePermRepo,
|
||||
UserRoleRepo: userRoleRepo,
|
||||
PermRepo: permRepo,
|
||||
PermUseCase: puc,
|
||||
AdminRoleUID: c.RoleConfig.AdminRoleUID,
|
||||
})
|
||||
ruc := uc.NewRoleUseCase(uc.RoleUseCaseParam{
|
||||
RoleRepo: roleRepo,
|
||||
UserRoleRepo: userRoleRepo,
|
||||
Config: uc.RoleUseCaseConfig{
|
||||
AdminRoleUID: c.RoleConfig.AdminRoleUID,
|
||||
UIDPrefix: c.RoleConfig.UIDPrefix,
|
||||
UIDLength: c.RoleConfig.UIDLength,
|
||||
},
|
||||
RolePermUseCase: rpuc,
|
||||
})
|
||||
|
||||
return PermissionUC{
|
||||
PermissionUC: puc,
|
||||
RolePermission: rpuc,
|
||||
RoleUC: ruc,
|
||||
UserRole: uc.NewUserRoleUseCase(uc.UserRoleUseCaseParam{
|
||||
UserRoleRepo: userRoleRepo,
|
||||
RoleRepo: roleRepo,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,6 +37,30 @@ type LoginResp struct {
|
|||
TokenType string `json:"token_type"` // 通常固定為 "Bearer"
|
||||
}
|
||||
|
||||
type MyInfo struct {
|
||||
Platform string `json:"platform"` // 註冊平台
|
||||
UID string `json:"uid"` // 用戶 UID
|
||||
AvatarURL *string `json:"avatar_url,omitempty"` // 頭像 URL
|
||||
FullName *string `json:"full_name,omitempty"` // 用戶全名
|
||||
Nickname *string `json:"nickname,omitempty"` // 暱稱
|
||||
GenderCode *string `json:"gender_code,omitempty"` // 性別代碼
|
||||
Birthdate *string `json:"birthdate,omitempty"` // 生日 (格式: 1993-04-17)
|
||||
PhoneNumber *string `json:"phone_number,omitempty"` // 電話
|
||||
IsPhoneVerified *bool `json:"is_phone_verified,omitempty"` // 手機是否已驗證
|
||||
Email *string `json:"email,omitempty"` // 信箱
|
||||
IsEmailVerified *bool `json:"is_email_verified,omitempty"` // 信箱是否已驗證
|
||||
Address *string `json:"address,omitempty"` // 地址
|
||||
UserStatus string `json:"user_status,omitempty"` // 用戶狀態
|
||||
PreferredLanguage string `json:"preferred_language,omitempty"` // 偏好語言
|
||||
Currency string `json:"currency,omitempty"` // 偏好幣種
|
||||
AlarmCategory string `json:"alarm_category,omitempty"` // 告警狀態
|
||||
PostCode *string `json:"post_code,omitempty"` // 郵遞區號
|
||||
Carrier *string `json:"carrier,omitempty"` // 載具
|
||||
Role string `json:"role"` // 角色
|
||||
UpdateAt string `json:"update_at"`
|
||||
CreateAt string `json:"create_at"`
|
||||
}
|
||||
|
||||
type PagerResp struct {
|
||||
Total int64 `json:"total"`
|
||||
Size int64 `json:"size"`
|
||||
|
|
@ -60,7 +84,7 @@ type RefreshTokenResp struct {
|
|||
}
|
||||
|
||||
type RequestPasswordResetReq struct {
|
||||
Identifier string `json:"identifier" validate:"required,email|phone"` // 使用者帳號 (信箱或手機)
|
||||
Identifier string `json:"identifier" validate:"required"` // 使用者帳號 (信箱或手機)
|
||||
AccountType string `json:"account_type" validate:"required,oneof=email phone"`
|
||||
}
|
||||
|
||||
|
|
@ -77,9 +101,12 @@ type ResetPasswordReq struct {
|
|||
}
|
||||
|
||||
type RespOK struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
Code int64 `json:"code"` // 狀態碼
|
||||
Message string `json:"message"` // 訊息
|
||||
Data interface{} `json:"data,omitempty"` // 可選的資料,當有返回時才出現
|
||||
}
|
||||
|
||||
type SubmitVerificationCodeReq struct {
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
package utils
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NormalizeTaiwanMobile 標準化號碼並驗證是否為合法台灣手機號碼
|
||||
func NormalizeTaiwanMobile(phone string) (string, bool) {
|
||||
// 移除空格
|
||||
phone = strings.ReplaceAll(phone, " ", "")
|
||||
|
||||
// 移除 "+886" 並將剩餘部分標準化
|
||||
if strings.HasPrefix(phone, "+886") {
|
||||
phone = strings.TrimPrefix(phone, "+886")
|
||||
if !strings.HasPrefix(phone, "0") {
|
||||
phone = "0" + phone
|
||||
}
|
||||
}
|
||||
|
||||
// 正則表達式驗證標準化後的號碼
|
||||
regex := regexp.MustCompile(`^(09\d{8})$`)
|
||||
if regex.MatchString(phone) {
|
||||
return phone, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
||||
// IsValidEmail 驗證 Email 格式的函數
|
||||
func IsValidEmail(email string) bool {
|
||||
// 定義正則表達式
|
||||
regex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
|
||||
|
||||
return regex.MatchString(email)
|
||||
}
|
||||
|
|
@ -35,6 +35,7 @@ const (
|
|||
InsufficientQuota // 配額不足
|
||||
ResourceHasMultiOwner // 資源有多個所有者
|
||||
UserSuspended // 沒有權限使用該資源
|
||||
TooManyRequest // 單位時間內請求太多次
|
||||
)
|
||||
|
||||
/* 詳細代碼 - GRPC */
|
||||
|
|
@ -77,13 +78,13 @@ const (
|
|||
|
||||
// 詳細代碼 - Token 類 09x
|
||||
const (
|
||||
_ = iota + CatToken
|
||||
TokenCreateError // Token 創建錯誤
|
||||
TokenValidateError // Token 驗證錯誤
|
||||
TokenExpired // Token 過期
|
||||
TokenNotFound // Token 未找到
|
||||
TokenBlacklisted // Token 已被列入黑名單
|
||||
InvalidJWT // 無效的 JWT
|
||||
RefreshTokenError // Refresh Token 錯誤
|
||||
OneTimeTokenError // 一次性 Token 錯誤
|
||||
_ = iota + CatToken
|
||||
TokenCreateError // Token 創建錯誤
|
||||
TokenValidateError // Token 驗證錯誤
|
||||
TokenExpired // Token 過期
|
||||
TokenNotFound // Token 未找到
|
||||
TokenBlacklisted // Token 已被列入黑名單
|
||||
InvalidJWT // 無效的 JWT
|
||||
RefreshTokenError // Refresh Token 錯誤
|
||||
OneTimeTokenError // 一次性 Token 錯誤
|
||||
)
|
||||
|
|
|
|||
|
|
@ -556,3 +556,8 @@ func MsgSizeTooLargeL(l logx.Logger, filed []logx.LogField, s ...string) *LibErr
|
|||
|
||||
return e
|
||||
}
|
||||
|
||||
func TooManyWithScope(scope uint32, s ...string) *LibError {
|
||||
return NewError(scope, code.TooManyRequest, defaultDetailCode,
|
||||
fmt.Sprintf("%s", strings.Join(s, " ")))
|
||||
}
|
||||
|
|
|
|||
|
|
@ -154,8 +154,13 @@ func (e *LibError) HTTPStatus() int {
|
|||
if e == nil || e.Code() == code.OK {
|
||||
return http.StatusOK
|
||||
}
|
||||
|
||||
// 將 code 轉換為與常量定義相同的格式 (category + detail)
|
||||
// 例如:code=3004 -> (3004%100) + 30 = 4 + 30 = 34
|
||||
codeValue := (e.Code() % 100) + e.Category()
|
||||
|
||||
// 根據錯誤碼判斷對應的 HTTP 狀態碼
|
||||
switch e.Code() / 100 {
|
||||
switch codeValue {
|
||||
case code.ResourceInsufficient, code.InvalidFormat:
|
||||
// 如果資源不足,返回 400 狀態碼
|
||||
return http.StatusBadRequest
|
||||
|
|
@ -177,6 +182,9 @@ func (e *LibError) HTTPStatus() int {
|
|||
case code.NotValidImplementation:
|
||||
// 如果實現無效,返回 501 狀態碼
|
||||
return http.StatusNotImplemented
|
||||
case code.TooManyRequest:
|
||||
// 如果實現無效,返回 501 狀態碼
|
||||
return http.StatusTooManyRequests
|
||||
default:
|
||||
// 如果沒有匹配的錯誤碼,則繼續下一步
|
||||
}
|
||||
|
|
|
|||
|
|
@ -71,12 +71,12 @@ func TestLibError_HTTPStatus(t *testing.T) {
|
|||
err *LibError
|
||||
expected int
|
||||
}{
|
||||
{"bad request", NewError(1, code.CatService, code.ResourceInsufficient, "bad request"), http.StatusBadRequest},
|
||||
{"unauthorized", NewError(1, code.CatAuth, code.Unauthorized, "unauthorized"), http.StatusUnauthorized},
|
||||
{"forbidden", NewError(1, code.CatAuth, code.Forbidden, "forbidden"), http.StatusForbidden},
|
||||
{"not found", NewError(1, code.CatResource, code.ResourceNotFound, "not found"), http.StatusNotFound},
|
||||
{"internal server error", NewError(1, code.CatDB, 1095, "not found"), http.StatusInternalServerError},
|
||||
{"input err", NewError(1, code.CatInput, 1095, "not found"), http.StatusBadRequest},
|
||||
{"bad request - ResourceInsufficient", NewError(1, code.CatResource, 4, "bad request"), http.StatusBadRequest},
|
||||
{"unauthorized", NewError(1, code.CatAuth, 1, "unauthorized"), http.StatusUnauthorized},
|
||||
{"forbidden", NewError(1, code.CatAuth, 5, "forbidden"), http.StatusForbidden},
|
||||
{"not found", NewError(1, code.CatResource, 1, "not found"), http.StatusNotFound},
|
||||
{"internal server error", NewError(1, code.CatDB, 95, "db error"), http.StatusInternalServerError},
|
||||
{"input err", NewError(1, code.CatInput, 1, "input error"), http.StatusBadRequest},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
|
|||
|
|
@ -19,3 +19,35 @@ const (
|
|||
const (
|
||||
CurrencyTWD Currency = "TWD"
|
||||
)
|
||||
|
||||
var genderMap = map[int64]string{
|
||||
0: "",
|
||||
1: "male",
|
||||
2: "female",
|
||||
3: "secret",
|
||||
}
|
||||
|
||||
func GetGenderByCode(g int64) string {
|
||||
r, ok := genderMap[g]
|
||||
if !ok {
|
||||
return genderMap[0]
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
var genderCodeMap = map[string]int64{
|
||||
"": 0,
|
||||
"male": 1,
|
||||
"female": 2,
|
||||
"secret": 3,
|
||||
}
|
||||
|
||||
func GetGenderCodeByStr(g string) int64 {
|
||||
r, ok := genderCodeMap[g]
|
||||
if !ok {
|
||||
return genderCodeMap[""]
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ type AccountUIDRepository interface {
|
|||
Update(ctx context.Context, data *entity.AccountUID) (*mongo.UpdateResult, error)
|
||||
Delete(ctx context.Context, id string) (int64, error)
|
||||
FindUIDByLoginID(ctx context.Context, loginID string) (*entity.AccountUID, error)
|
||||
FindOneByUID(ctx context.Context, uid string) (*entity.AccountUID, error)
|
||||
AccountUIDIndexUP
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,8 @@ type MemberUseCase interface {
|
|||
GetUserInfo(ctx context.Context, req GetUserInfoRequest) (UserInfo, error)
|
||||
// ListMember 取得會員列表
|
||||
ListMember(ctx context.Context, req ListUserInfoRequest) (ListUserInfoResponse, error)
|
||||
// FindLoginIDByUID 取得login id
|
||||
FindLoginIDByUID(ctx context.Context, uid string) (BindingUser, error)
|
||||
}
|
||||
|
||||
type BindingMemberUseCase interface {
|
||||
|
|
|
|||
|
|
@ -112,6 +112,20 @@ func (repo *AccountUIDRepository) FindUIDByLoginID(ctx context.Context, loginID
|
|||
}
|
||||
}
|
||||
|
||||
func (repo *AccountUIDRepository) FindOneByUID(ctx context.Context, uid string) (*entity.AccountUID, error) {
|
||||
var data entity.AccountUID
|
||||
|
||||
err := repo.DB.GetClient().FindOne(ctx, &data, bson.M{"uid": uid})
|
||||
switch {
|
||||
case err == nil:
|
||||
return &data, nil
|
||||
case errors.Is(err, mon.ErrNotFound):
|
||||
return nil, ErrNotFound
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *AccountUIDRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
|
||||
// 等價於 db.account_uid_binding.createIndex({"login_id": 1}, {unique: true})
|
||||
repo.DB.PopulateIndex(ctx, "login_id", 1, true)
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ func (repo *UserRepository) FindOneByUID(ctx context.Context, uid string) (*enti
|
|||
// 不常寫,再找一次可接受
|
||||
id := repo.UIDToID(ctx, uid)
|
||||
if id == "" {
|
||||
return nil, errors.New("invalid uid")
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
rk := domain.GetUserRedisKey(id)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"backend/pkg/member/domain/config"
|
||||
"backend/pkg/member/domain/repository"
|
||||
"backend/pkg/member/domain/usecase"
|
||||
"context"
|
||||
)
|
||||
|
||||
type MemberUseCaseParam struct {
|
||||
|
|
@ -24,3 +25,15 @@ func MustMemberUseCase(param MemberUseCaseParam) usecase.AccountUseCase {
|
|||
param,
|
||||
}
|
||||
}
|
||||
|
||||
func (use *MemberUseCase) FindLoginIDByUID(ctx context.Context, uid string) (usecase.BindingUser, error) {
|
||||
data, err := use.AccountUID.FindOneByUID(ctx, uid)
|
||||
if err != nil {
|
||||
return usecase.BindingUser{}, err
|
||||
}
|
||||
|
||||
return usecase.BindingUser{
|
||||
UID: data.UID,
|
||||
LoginID: data.LoginID,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
package config
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SMTPConfig struct {
|
||||
Enable bool
|
||||
|
|
@ -13,6 +17,35 @@ type SMTPConfig struct {
|
|||
Password string
|
||||
}
|
||||
|
||||
// Validate 驗證 SMTP 配置
|
||||
func (c *SMTPConfig) Validate() error {
|
||||
if !c.Enable {
|
||||
return nil // 未啟用則不驗證
|
||||
}
|
||||
|
||||
if c.Host == "" {
|
||||
return errors.New("smtp host is required")
|
||||
}
|
||||
|
||||
if c.Port <= 0 || c.Port > 65535 {
|
||||
return fmt.Errorf("smtp port must be between 1 and 65535, got %d", c.Port)
|
||||
}
|
||||
|
||||
if c.Username == "" {
|
||||
return errors.New("smtp username is required")
|
||||
}
|
||||
|
||||
if c.Password == "" {
|
||||
return errors.New("smtp password is required")
|
||||
}
|
||||
|
||||
if c.Sort < 0 {
|
||||
return fmt.Errorf("smtp sort must be >= 0, got %d", c.Sort)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type AmazonSesSettings struct {
|
||||
Enable bool
|
||||
Sort int
|
||||
|
|
@ -26,6 +59,39 @@ type AmazonSesSettings struct {
|
|||
Token string
|
||||
}
|
||||
|
||||
// Validate 驗證 AWS SES 配置
|
||||
func (c *AmazonSesSettings) Validate() error {
|
||||
if !c.Enable {
|
||||
return nil // 未啟用則不驗證
|
||||
}
|
||||
|
||||
if c.Region == "" {
|
||||
return errors.New("aws ses region is required")
|
||||
}
|
||||
|
||||
if c.Sender == "" {
|
||||
return errors.New("aws ses sender is required")
|
||||
}
|
||||
|
||||
if c.AccessKey == "" {
|
||||
return errors.New("aws ses access key is required")
|
||||
}
|
||||
|
||||
if c.SecretKey == "" {
|
||||
return errors.New("aws ses secret key is required")
|
||||
}
|
||||
|
||||
if c.Sort < 0 {
|
||||
return fmt.Errorf("aws ses sort must be >= 0, got %d", c.Sort)
|
||||
}
|
||||
|
||||
if c.PoolSize < 0 {
|
||||
return fmt.Errorf("aws ses pool size must be >= 0, got %d", c.PoolSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type MitakeSMSSender struct {
|
||||
Enable bool
|
||||
Sort int
|
||||
|
|
@ -35,6 +101,31 @@ type MitakeSMSSender struct {
|
|||
Password string
|
||||
}
|
||||
|
||||
// Validate 驗證 Mitake SMS 配置
|
||||
func (c *MitakeSMSSender) Validate() error {
|
||||
if !c.Enable {
|
||||
return nil // 未啟用則不驗證
|
||||
}
|
||||
|
||||
if c.User == "" {
|
||||
return errors.New("mitake user is required")
|
||||
}
|
||||
|
||||
if c.Password == "" {
|
||||
return errors.New("mitake password is required")
|
||||
}
|
||||
|
||||
if c.Sort < 0 {
|
||||
return fmt.Errorf("mitake sort must be >= 0, got %d", c.Sort)
|
||||
}
|
||||
|
||||
if c.PoolSize < 0 {
|
||||
return fmt.Errorf("mitake pool size must be >= 0, got %d", c.PoolSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeliveryConfig 傳送重試配置
|
||||
type DeliveryConfig struct {
|
||||
MaxRetries int `json:"max_retries"` // 最大重試次數
|
||||
|
|
@ -44,3 +135,72 @@ type DeliveryConfig struct {
|
|||
Timeout time.Duration `json:"timeout"` // 單次發送超時時間
|
||||
EnableHistory bool `json:"enable_history"` // 是否啟用歷史記錄
|
||||
}
|
||||
|
||||
// Validate 驗證 DeliveryConfig 配置
|
||||
func (c *DeliveryConfig) Validate() error {
|
||||
if c.MaxRetries < 0 {
|
||||
return fmt.Errorf("max_retries must be >= 0, got %d", c.MaxRetries)
|
||||
}
|
||||
|
||||
if c.MaxRetries > 10 {
|
||||
return fmt.Errorf("max_retries should not exceed 10, got %d", c.MaxRetries)
|
||||
}
|
||||
|
||||
if c.InitialDelay < 0 {
|
||||
return fmt.Errorf("initial_delay must be >= 0, got %v", c.InitialDelay)
|
||||
}
|
||||
|
||||
if c.InitialDelay > 10*time.Second {
|
||||
return fmt.Errorf("initial_delay is too large (> 10s), got %v", c.InitialDelay)
|
||||
}
|
||||
|
||||
if c.BackoffFactor < 1.0 {
|
||||
return fmt.Errorf("backoff_factor must be >= 1.0, got %v", c.BackoffFactor)
|
||||
}
|
||||
|
||||
if c.BackoffFactor > 10.0 {
|
||||
return fmt.Errorf("backoff_factor is too large (> 10.0), got %v", c.BackoffFactor)
|
||||
}
|
||||
|
||||
if c.MaxDelay < 0 {
|
||||
return fmt.Errorf("max_delay must be >= 0, got %v", c.MaxDelay)
|
||||
}
|
||||
|
||||
if c.MaxDelay > 5*time.Minute {
|
||||
return fmt.Errorf("max_delay is too large (> 5m), got %v", c.MaxDelay)
|
||||
}
|
||||
|
||||
if c.Timeout <= 0 {
|
||||
return fmt.Errorf("timeout must be > 0, got %v", c.Timeout)
|
||||
}
|
||||
|
||||
if c.Timeout > 5*time.Minute {
|
||||
return fmt.Errorf("timeout is too large (> 5m), got %v", c.Timeout)
|
||||
}
|
||||
|
||||
// 檢查 InitialDelay 和 MaxDelay 的關係
|
||||
if c.InitialDelay > c.MaxDelay && c.MaxDelay > 0 {
|
||||
return fmt.Errorf("initial_delay (%v) should not exceed max_delay (%v)", c.InitialDelay, c.MaxDelay)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaults 設置默認值
|
||||
func (c *DeliveryConfig) SetDefaults() {
|
||||
if c.MaxRetries == 0 {
|
||||
c.MaxRetries = 3
|
||||
}
|
||||
if c.InitialDelay == 0 {
|
||||
c.InitialDelay = 100 * time.Millisecond
|
||||
}
|
||||
if c.BackoffFactor == 0 {
|
||||
c.BackoffFactor = 2.0
|
||||
}
|
||||
if c.MaxDelay == 0 {
|
||||
c.MaxDelay = 30 * time.Second
|
||||
}
|
||||
if c.Timeout == 0 {
|
||||
c.Timeout = 30 * time.Second
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,314 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestSMTPConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config SMTPConfig
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "有效的 SMTP 配置",
|
||||
config: SMTPConfig{
|
||||
Enable: true,
|
||||
Sort: 1,
|
||||
Host: "smtp.gmail.com",
|
||||
Port: 587,
|
||||
Username: "test@gmail.com",
|
||||
Password: "password",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "未啟用的配置(不驗證)",
|
||||
config: SMTPConfig{
|
||||
Enable: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "缺少 Host",
|
||||
config: SMTPConfig{
|
||||
Enable: true,
|
||||
Port: 587,
|
||||
Username: "test@gmail.com",
|
||||
Password: "password",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "smtp host is required",
|
||||
},
|
||||
{
|
||||
name: "無效的 Port",
|
||||
config: SMTPConfig{
|
||||
Enable: true,
|
||||
Host: "smtp.gmail.com",
|
||||
Port: 99999,
|
||||
Username: "test@gmail.com",
|
||||
Password: "password",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "smtp port must be between 1 and 65535",
|
||||
},
|
||||
{
|
||||
name: "缺少 Username",
|
||||
config: SMTPConfig{
|
||||
Enable: true,
|
||||
Host: "smtp.gmail.com",
|
||||
Port: 587,
|
||||
Password: "password",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "smtp username is required",
|
||||
},
|
||||
{
|
||||
name: "負數的 Sort",
|
||||
config: SMTPConfig{
|
||||
Enable: true,
|
||||
Sort: -1,
|
||||
Host: "smtp.gmail.com",
|
||||
Port: 587,
|
||||
Username: "test@gmail.com",
|
||||
Password: "password",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "smtp sort must be >= 0",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAmazonSesSettings_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config AmazonSesSettings
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "有效的 AWS SES 配置",
|
||||
config: AmazonSesSettings{
|
||||
Enable: true,
|
||||
Sort: 1,
|
||||
PoolSize: 10,
|
||||
Region: "us-west-2",
|
||||
Sender: "noreply@example.com",
|
||||
AccessKey: "AKIAIOSFODNN7EXAMPLE",
|
||||
SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "未啟用的配置",
|
||||
config: AmazonSesSettings{
|
||||
Enable: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "缺少 Region",
|
||||
config: AmazonSesSettings{
|
||||
Enable: true,
|
||||
Sender: "noreply@example.com",
|
||||
AccessKey: "AKIAIOSFODNN7EXAMPLE",
|
||||
SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "aws ses region is required",
|
||||
},
|
||||
{
|
||||
name: "缺少 AccessKey",
|
||||
config: AmazonSesSettings{
|
||||
Enable: true,
|
||||
Region: "us-west-2",
|
||||
Sender: "noreply@example.com",
|
||||
SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "aws ses access key is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMitakeSMSSender_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config MitakeSMSSender
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "有效的 Mitake 配置",
|
||||
config: MitakeSMSSender{
|
||||
Enable: true,
|
||||
Sort: 1,
|
||||
PoolSize: 5,
|
||||
User: "testuser",
|
||||
Password: "testpass",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "缺少 User",
|
||||
config: MitakeSMSSender{
|
||||
Enable: true,
|
||||
Password: "testpass",
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "mitake user is required",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeliveryConfig_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
config DeliveryConfig
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
name: "有效的配置",
|
||||
config: DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 30 * time.Second,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "MaxRetries 為負數",
|
||||
config: DeliveryConfig{
|
||||
MaxRetries: -1,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "max_retries must be >= 0",
|
||||
},
|
||||
{
|
||||
name: "MaxRetries 過大",
|
||||
config: DeliveryConfig{
|
||||
MaxRetries: 20,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "max_retries should not exceed 10",
|
||||
},
|
||||
{
|
||||
name: "BackoffFactor 小於 1.0",
|
||||
config: DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
BackoffFactor: 0.5,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "backoff_factor must be >= 1.0",
|
||||
},
|
||||
{
|
||||
name: "Timeout 為 0",
|
||||
config: DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
BackoffFactor: 2.0,
|
||||
Timeout: 0,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "timeout must be > 0",
|
||||
},
|
||||
{
|
||||
name: "InitialDelay 大於 MaxDelay",
|
||||
config: DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 1 * time.Minute,
|
||||
MaxDelay: 10 * time.Second,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
wantErr: true,
|
||||
errMsg: "initial_delay",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.config.Validate()
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeliveryConfig_SetDefaults(t *testing.T) {
|
||||
config := DeliveryConfig{}
|
||||
config.SetDefaults()
|
||||
|
||||
assert.Equal(t, 3, config.MaxRetries)
|
||||
assert.Equal(t, 100*time.Millisecond, config.InitialDelay)
|
||||
assert.Equal(t, 2.0, config.BackoffFactor)
|
||||
assert.Equal(t, 30*time.Second, config.MaxDelay)
|
||||
assert.Equal(t, 30*time.Second, config.Timeout)
|
||||
|
||||
// 測試不覆蓋已設置的值
|
||||
config2 := DeliveryConfig{
|
||||
MaxRetries: 5,
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
config2.SetDefaults()
|
||||
|
||||
assert.Equal(t, 5, config2.MaxRetries) // 保持原值
|
||||
assert.Equal(t, 60*time.Second, config2.Timeout) // 保持原值
|
||||
assert.Equal(t, 100*time.Millisecond, config2.InitialDelay) // 設置默認值
|
||||
}
|
||||
|
|
@ -8,6 +8,7 @@ const (
|
|||
FailedToSendEmailErrorCode
|
||||
FailedToSendSMSErrorCode
|
||||
FailedToGetTemplateErrorCode
|
||||
FailedToRenderTemplateErrorCode
|
||||
FailedToSaveHistoryErrorCode
|
||||
FailedToRetryDeliveryErrorCode
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,24 @@
|
|||
package usecase
|
||||
|
||||
import (
|
||||
"backend/pkg/notification/domain/entity"
|
||||
"backend/pkg/notification/domain/template"
|
||||
"context"
|
||||
)
|
||||
|
||||
type TemplateUseCase interface {
|
||||
// GetEmailTemplateByStatic 從靜態模板獲取郵件模板
|
||||
GetEmailTemplateByStatic(ctx context.Context, language template.Language, templateID template.Type) (template.EmailTemplate, error)
|
||||
|
||||
// GetEmailTemplate 獲取郵件模板(優先從資料庫,回退到靜態模板)
|
||||
GetEmailTemplate(ctx context.Context, language template.Language, templateID template.Type) (template.EmailTemplate, error)
|
||||
|
||||
// GetSMSTemplate 獲取 SMS 模板(優先從資料庫,回退到靜態模板)
|
||||
GetSMSTemplate(ctx context.Context, language template.Language, templateID template.Type) (SMSTemplateResp, error)
|
||||
|
||||
// RenderEmailTemplate 渲染郵件模板(替換變數)
|
||||
RenderEmailTemplate(ctx context.Context, tmpl template.EmailTemplate, params entity.TemplateParams) (EmailTemplateResp, error)
|
||||
|
||||
// RenderSMSTemplate 渲染 SMS 模板(替換變數)
|
||||
RenderSMSTemplate(ctx context.Context, tmpl SMSTemplateResp, params entity.TemplateParams) (SMSTemplateResp, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,11 +5,10 @@ import (
|
|||
"backend/pkg/notification/domain"
|
||||
"backend/pkg/notification/domain/repository"
|
||||
"context"
|
||||
"time"
|
||||
"fmt"
|
||||
|
||||
"backend/pkg/library/errs"
|
||||
"backend/pkg/library/errs/code"
|
||||
pool "backend/pkg/library/worker_pool"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/ses/types"
|
||||
|
|
@ -25,8 +24,8 @@ type AwsEmailDeliveryParam struct {
|
|||
}
|
||||
|
||||
type AwsEmailDeliveryRepository struct {
|
||||
Client *ses.Client
|
||||
Pool pool.WorkerPool
|
||||
Client *ses.Client
|
||||
Timeout int // 超時時間(秒),預設 30
|
||||
}
|
||||
|
||||
func MustAwsSesMailRepository(param AwsEmailDeliveryParam) repository.MailRepository {
|
||||
|
|
@ -42,70 +41,62 @@ func MustAwsSesMailRepository(param AwsEmailDeliveryParam) repository.MailReposi
|
|||
// 創建 SES 客戶端
|
||||
sesClient := ses.NewFromConfig(cfg)
|
||||
|
||||
// 設置默認超時時間
|
||||
timeout := 30
|
||||
if param.Conf.PoolSize > 0 {
|
||||
timeout = param.Conf.PoolSize // 可以復用這個配置項,或新增專門的 Timeout 配置
|
||||
}
|
||||
|
||||
return &AwsEmailDeliveryRepository{
|
||||
Client: sesClient,
|
||||
Pool: pool.NewWorkerPool(param.Conf.PoolSize),
|
||||
Client: sesClient,
|
||||
Timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (use *AwsEmailDeliveryRepository) SendMail(ctx context.Context, req repository.MailReq) error {
|
||||
err := use.Pool.Submit(func() {
|
||||
// 設置郵件參數
|
||||
to := make([]string, 0, len(req.To))
|
||||
to = append(to, req.To...)
|
||||
func (repo *AwsEmailDeliveryRepository) SendMail(ctx context.Context, req repository.MailReq) error {
|
||||
// 檢查 context 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
input := &ses.SendEmailInput{
|
||||
Destination: &types.Destination{
|
||||
ToAddresses: to,
|
||||
},
|
||||
Message: &types.Message{
|
||||
Body: &types.Body{
|
||||
Html: &types.Content{
|
||||
Charset: aws.String("UTF-8"),
|
||||
Data: aws.String(req.Body),
|
||||
},
|
||||
},
|
||||
Subject: &types.Content{
|
||||
// 設置郵件參數
|
||||
to := make([]string, 0, len(req.To))
|
||||
to = append(to, req.To...)
|
||||
|
||||
input := &ses.SendEmailInput{
|
||||
Destination: &types.Destination{
|
||||
ToAddresses: to,
|
||||
},
|
||||
Message: &types.Message{
|
||||
Body: &types.Body{
|
||||
Html: &types.Content{
|
||||
Charset: aws.String("UTF-8"),
|
||||
Data: aws.String(req.Subject),
|
||||
Data: aws.String(req.Body),
|
||||
},
|
||||
},
|
||||
Source: aws.String(req.From),
|
||||
}
|
||||
Subject: &types.Content{
|
||||
Charset: aws.String("UTF-8"),
|
||||
Data: aws.String(req.Subject),
|
||||
},
|
||||
},
|
||||
Source: aws.String(req.From),
|
||||
}
|
||||
|
||||
// 發送郵件
|
||||
// TODO 不明原因送不出去,會被 context cancel 這裡先把它手動加到100sec
|
||||
newCtx, cancel := context.WithTimeout(context.Background(), 100*time.Second)
|
||||
defer cancel()
|
||||
|
||||
//nolint:contextcheck
|
||||
if _, err := use.Client.SendEmail(newCtx, input); err != nil {
|
||||
_ = errs.ThirdPartyErrorL(
|
||||
code.CloudEPNotification,
|
||||
domain.FailedToSendEmailErrorCode,
|
||||
logx.WithContext(ctx),
|
||||
[]logx.LogField{
|
||||
{Key: "req", Value: req},
|
||||
{Key: "func", Value: "AwsEmailDeliveryU.SendEmail"},
|
||||
{Key: "err", Value: err.Error()},
|
||||
},
|
||||
"failed to send mail by aws ses")
|
||||
}
|
||||
})
|
||||
// 發送郵件(直接使用傳入的 context,不創建新的 context)
|
||||
_, err := repo.Client.SendEmail(ctx, input)
|
||||
if err != nil {
|
||||
e := errs.ThirdPartyErrorL(
|
||||
return errs.ThirdPartyErrorL(
|
||||
code.CloudEPNotification,
|
||||
domain.FailedToSendEmailErrorCode,
|
||||
logx.WithContext(ctx),
|
||||
[]logx.LogField{
|
||||
{Key: "req", Value: req},
|
||||
{Key: "func", Value: "AwsEmailDeliveryU.SendEmail"},
|
||||
{Key: "func", Value: "AwsEmailDeliveryRepository.SendEmail"},
|
||||
{Key: "err", Value: err.Error()},
|
||||
},
|
||||
"failed to send mail by aws ses")
|
||||
|
||||
return e
|
||||
fmt.Sprintf("failed to send mail by aws ses: %v", err)).Wrap(err)
|
||||
}
|
||||
|
||||
logx.WithContext(ctx).Infof("Email sent successfully via AWS SES to %v", req.To)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@ import (
|
|||
"backend/pkg/notification/domain"
|
||||
"backend/pkg/notification/domain/repository"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"backend/pkg/library/errs"
|
||||
"backend/pkg/library/errs/code"
|
||||
pool "backend/pkg/library/worker_pool"
|
||||
|
||||
"github.com/minchao/go-mitake"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
|
|
@ -21,45 +21,43 @@ type MitakeSMSDeliveryParam struct {
|
|||
|
||||
type MitakeSMSDeliveryRepository struct {
|
||||
Client *mitake.Client
|
||||
Pool pool.WorkerPool
|
||||
}
|
||||
|
||||
func (use *MitakeSMSDeliveryRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error {
|
||||
// 用 goroutine pool 送,否則會超時
|
||||
err := use.Pool.Submit(func() {
|
||||
message := mitake.Message{
|
||||
Dstaddr: req.PhoneNumber,
|
||||
Destname: req.RecipientName,
|
||||
Smbody: req.MessageContent,
|
||||
}
|
||||
_, err := use.Client.Send(message)
|
||||
if err != nil {
|
||||
logx.Error("failed to send sms via mitake")
|
||||
}
|
||||
})
|
||||
func (repo *MitakeSMSDeliveryRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error {
|
||||
// 檢查 context 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 構建簡訊訊息
|
||||
message := mitake.Message{
|
||||
Dstaddr: req.PhoneNumber,
|
||||
Destname: req.RecipientName,
|
||||
Smbody: req.MessageContent,
|
||||
}
|
||||
|
||||
// 直接發送,不使用 goroutine pool
|
||||
// 讓 delivery usecase 統一管理重試和超時
|
||||
_, err := repo.Client.Send(message)
|
||||
if err != nil {
|
||||
// 錯誤代碼 20-201-04
|
||||
e := errs.ThirdPartyErrorL(
|
||||
return errs.ThirdPartyErrorL(
|
||||
code.CloudEPNotification,
|
||||
domain.FailedToSendSMSErrorCode,
|
||||
logx.WithContext(ctx),
|
||||
[]logx.LogField{
|
||||
{Key: "req", Value: req},
|
||||
{Key: "func", Value: "MitakeSMSDeliveryRepository.Client.Send"},
|
||||
{Key: "func", Value: "MitakeSMSDeliveryRepository.Send"},
|
||||
{Key: "err", Value: err.Error()},
|
||||
},
|
||||
"failed to send sns by mitake").Wrap(err)
|
||||
|
||||
return e
|
||||
fmt.Sprintf("failed to send sms by mitake: %v", err)).Wrap(err)
|
||||
}
|
||||
|
||||
logx.WithContext(ctx).Infof("SMS sent successfully via Mitake to %s", req.PhoneNumber)
|
||||
return nil
|
||||
}
|
||||
|
||||
func MustMitakeRepository(param MitakeSMSDeliveryParam) repository.SMSClientRepository {
|
||||
return &MitakeSMSDeliveryRepository{
|
||||
Client: mitake.NewClient(param.Conf.User, param.Conf.Password, nil),
|
||||
Pool: pool.NewWorkerPool(param.Conf.PoolSize),
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,10 +2,13 @@ package repository
|
|||
|
||||
import (
|
||||
"backend/pkg/notification/config"
|
||||
"backend/pkg/notification/domain"
|
||||
"backend/pkg/notification/domain/repository"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
pool "backend/pkg/library/worker_pool"
|
||||
"backend/pkg/library/errs"
|
||||
"backend/pkg/library/errs/code"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"gopkg.in/gomail.v2"
|
||||
|
|
@ -17,7 +20,6 @@ type SMTPMailUseCaseParam struct {
|
|||
|
||||
type SMTPMailRepository struct {
|
||||
Client *gomail.Dialer
|
||||
Pool pool.WorkerPool
|
||||
}
|
||||
|
||||
func MustSMTPUseCase(param SMTPMailUseCaseParam) repository.MailRepository {
|
||||
|
|
@ -28,26 +30,37 @@ func MustSMTPUseCase(param SMTPMailUseCaseParam) repository.MailRepository {
|
|||
param.Conf.Username,
|
||||
param.Conf.Password,
|
||||
),
|
||||
Pool: pool.NewWorkerPool(param.Conf.GoroutinePoolNum),
|
||||
}
|
||||
}
|
||||
|
||||
func (repo *SMTPMailRepository) SendMail(_ context.Context, req repository.MailReq) error {
|
||||
// 用 goroutine pool 送,否則會超時
|
||||
err := repo.Pool.Submit(func() {
|
||||
m := gomail.NewMessage()
|
||||
m.SetHeader("From", req.From)
|
||||
m.SetHeader("To", req.To...)
|
||||
m.SetHeader("Subject", req.Subject)
|
||||
m.SetBody("text/html", req.Body)
|
||||
if err := repo.Client.DialAndSend(m); err != nil {
|
||||
logx.WithCallerSkip(1).WithFields(
|
||||
logx.Field("func", "MailUseCase.SendMail"),
|
||||
logx.Field("req", req),
|
||||
logx.Field("err", err),
|
||||
).Error("failed to send mail by mailgun")
|
||||
}
|
||||
})
|
||||
func (repo *SMTPMailRepository) SendMail(ctx context.Context, req repository.MailReq) error {
|
||||
// 檢查 context 是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
return err
|
||||
// 構建郵件
|
||||
m := gomail.NewMessage()
|
||||
m.SetHeader("From", req.From)
|
||||
m.SetHeader("To", req.To...)
|
||||
m.SetHeader("Subject", req.Subject)
|
||||
m.SetBody("text/html", req.Body)
|
||||
|
||||
// 直接發送,不使用 goroutine pool
|
||||
// 讓 delivery usecase 統一管理重試和超時
|
||||
if err := repo.Client.DialAndSend(m); err != nil {
|
||||
return errs.ThirdPartyErrorL(
|
||||
code.CloudEPNotification,
|
||||
domain.FailedToSendEmailErrorCode,
|
||||
logx.WithContext(ctx),
|
||||
[]logx.LogField{
|
||||
{Key: "func", Value: "SMTPMailRepository.SendMail"},
|
||||
{Key: "req", Value: req},
|
||||
{Key: "err", Value: err.Error()},
|
||||
},
|
||||
fmt.Sprintf("failed to send mail by smtp: %v", err)).Wrap(err)
|
||||
}
|
||||
|
||||
logx.WithContext(ctx).Infof("Email sent successfully via SMTP to %v", req.To)
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,7 +74,10 @@ func (use *DeliveryUseCase) SendMessage(ctx context.Context, req usecase.SMSMess
|
|||
}
|
||||
|
||||
// 執行發送邏輯
|
||||
return use.sendSMSWithRetry(ctx, req, history)
|
||||
return use.sendWithRetry(ctx, history, &smsProviderAdapter{
|
||||
providers: use.param.SMSProviders,
|
||||
request: req,
|
||||
})
|
||||
}
|
||||
|
||||
func (use *DeliveryUseCase) SendEmail(ctx context.Context, req usecase.MailReq) error {
|
||||
|
|
@ -97,30 +100,130 @@ func (use *DeliveryUseCase) SendEmail(ctx context.Context, req usecase.MailReq)
|
|||
}
|
||||
|
||||
// 執行發送邏輯
|
||||
return use.sendEmailWithRetry(ctx, req, history)
|
||||
return use.sendWithRetry(ctx, history, &emailProviderAdapter{
|
||||
providers: use.param.EmailProviders,
|
||||
request: req,
|
||||
})
|
||||
}
|
||||
|
||||
// sendSMSWithRetry 發送 SMS 並實現重試機制
|
||||
func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SMSMessageRequest, history *entity.DeliveryHistory) error {
|
||||
// 根據 Sort 欄位對 SMSProviders 進行排序
|
||||
providers := make([]usecase.SMSProvider, len(use.param.SMSProviders))
|
||||
copy(providers, use.param.SMSProviders)
|
||||
sort.Slice(providers, func(i, j int) bool {
|
||||
return providers[i].Sort < providers[j].Sort
|
||||
// providerAdapter 統一的供應商適配器接口
|
||||
type providerAdapter interface {
|
||||
getProviderCount() int
|
||||
getProviderName(index int) string
|
||||
getProviderSort(index int) int64
|
||||
send(ctx context.Context, providerIndex int) error
|
||||
getErrorCode() errs.ErrorCode
|
||||
getType() string
|
||||
}
|
||||
|
||||
// smsProviderAdapter SMS 供應商適配器
|
||||
type smsProviderAdapter struct {
|
||||
providers []usecase.SMSProvider
|
||||
request usecase.SMSMessageRequest
|
||||
}
|
||||
|
||||
func (a *smsProviderAdapter) getProviderCount() int {
|
||||
return len(a.providers)
|
||||
}
|
||||
|
||||
func (a *smsProviderAdapter) getProviderName(index int) string {
|
||||
return fmt.Sprintf("sms_provider_%d", index)
|
||||
}
|
||||
|
||||
func (a *smsProviderAdapter) getProviderSort(index int) int64 {
|
||||
return a.providers[index].Sort
|
||||
}
|
||||
|
||||
func (a *smsProviderAdapter) send(ctx context.Context, providerIndex int) error {
|
||||
return a.providers[providerIndex].Repo.SendSMS(ctx, repository.SMSMessageRequest{
|
||||
PhoneNumber: a.request.PhoneNumber,
|
||||
RecipientName: a.request.RecipientName,
|
||||
MessageContent: a.request.MessageContent,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *smsProviderAdapter) getErrorCode() errs.ErrorCode {
|
||||
return domain.FailedToSendSMSErrorCode
|
||||
}
|
||||
|
||||
func (a *smsProviderAdapter) getType() string {
|
||||
return "SMS"
|
||||
}
|
||||
|
||||
// emailProviderAdapter Email 供應商適配器
|
||||
type emailProviderAdapter struct {
|
||||
providers []usecase.EmailProvider
|
||||
request usecase.MailReq
|
||||
}
|
||||
|
||||
func (a *emailProviderAdapter) getProviderCount() int {
|
||||
return len(a.providers)
|
||||
}
|
||||
|
||||
func (a *emailProviderAdapter) getProviderName(index int) string {
|
||||
return fmt.Sprintf("email_provider_%d", index)
|
||||
}
|
||||
|
||||
func (a *emailProviderAdapter) getProviderSort(index int) int64 {
|
||||
return a.providers[index].Sort
|
||||
}
|
||||
|
||||
func (a *emailProviderAdapter) send(ctx context.Context, providerIndex int) error {
|
||||
return a.providers[providerIndex].Repo.SendMail(ctx, repository.MailReq{
|
||||
From: a.request.From,
|
||||
To: a.request.To,
|
||||
Subject: a.request.Subject,
|
||||
Body: a.request.Body,
|
||||
})
|
||||
}
|
||||
|
||||
func (a *emailProviderAdapter) getErrorCode() errs.ErrorCode {
|
||||
return domain.FailedToSendEmailErrorCode
|
||||
}
|
||||
|
||||
func (a *emailProviderAdapter) getType() string {
|
||||
return "Email"
|
||||
}
|
||||
|
||||
// providerWithIndex 用於排序的結構
|
||||
type providerWithIndex struct {
|
||||
index int
|
||||
sort int64
|
||||
}
|
||||
|
||||
// sendWithRetry 統一的發送重試邏輯
|
||||
func (use *DeliveryUseCase) sendWithRetry(
|
||||
ctx context.Context,
|
||||
history *entity.DeliveryHistory,
|
||||
adapter providerAdapter,
|
||||
) error {
|
||||
// 按 Sort 欄位對供應商進行排序
|
||||
providerCount := adapter.getProviderCount()
|
||||
sortedProviders := make([]providerWithIndex, providerCount)
|
||||
for i := 0; i < providerCount; i++ {
|
||||
sortedProviders[i] = providerWithIndex{
|
||||
index: i,
|
||||
sort: adapter.getProviderSort(i),
|
||||
}
|
||||
}
|
||||
sort.Slice(sortedProviders, func(i, j int) bool {
|
||||
return sortedProviders[i].sort < sortedProviders[j].sort
|
||||
})
|
||||
|
||||
var lastErr error
|
||||
totalAttempts := 0
|
||||
|
||||
// 嘗試所有 providers
|
||||
for providerIndex, provider := range providers {
|
||||
for _, provider := range sortedProviders {
|
||||
providerIndex := provider.index
|
||||
|
||||
// 為每個 provider 嘗試發送
|
||||
for attempt := 0; attempt < use.param.DeliveryConfig.MaxRetries; attempt++ {
|
||||
totalAttempts++
|
||||
|
||||
// 更新歷史記錄狀態
|
||||
history.Status = entity.DeliveryStatusSending
|
||||
history.Provider = fmt.Sprintf("sms_provider_%d", providerIndex)
|
||||
history.Provider = adapter.getProviderName(providerIndex)
|
||||
history.AttemptCount = totalAttempts
|
||||
history.UpdatedAt = time.Now()
|
||||
use.updateHistory(ctx, history)
|
||||
|
|
@ -131,11 +234,7 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM
|
|||
// 創建帶超時的 context
|
||||
sendCtx, cancel := context.WithTimeout(ctx, use.param.DeliveryConfig.Timeout)
|
||||
|
||||
err := provider.Repo.SendSMS(sendCtx, repository.SMSMessageRequest{
|
||||
PhoneNumber: req.PhoneNumber,
|
||||
RecipientName: req.RecipientName,
|
||||
MessageContent: req.MessageContent,
|
||||
})
|
||||
err := adapter.send(sendCtx, providerIndex)
|
||||
|
||||
cancel()
|
||||
|
||||
|
|
@ -153,8 +252,8 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM
|
|||
attemptRecord.ErrorMessage = err.Error()
|
||||
lastErr = err
|
||||
|
||||
logx.WithContext(ctx).Errorf("SMS send attempt %d failed for provider %d: %v",
|
||||
attempt+1, providerIndex, err)
|
||||
logx.WithContext(ctx).Errorf("%s send attempt %d failed for provider %d: %v",
|
||||
adapter.getType(), attempt+1, providerIndex, err)
|
||||
|
||||
// 如果不是最後一次嘗試,等待後重試
|
||||
if attempt < use.param.DeliveryConfig.MaxRetries-1 {
|
||||
|
|
@ -179,7 +278,8 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM
|
|||
use.updateHistory(ctx, history)
|
||||
use.addAttemptRecord(ctx, history.ID, attemptRecord)
|
||||
|
||||
logx.WithContext(ctx).Infof("SMS sent successfully after %d attempts", totalAttempts)
|
||||
logx.WithContext(ctx).Infof("%s sent successfully after %d attempts",
|
||||
adapter.getType(), totalAttempts)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -197,112 +297,9 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM
|
|||
|
||||
return errs.ThirdPartyError(
|
||||
code.CloudEPNotification,
|
||||
domain.FailedToSendSMSErrorCode,
|
||||
fmt.Sprintf("Failed to send SMS after %d attempts across %d providers",
|
||||
totalAttempts, len(providers)))
|
||||
}
|
||||
|
||||
// sendEmailWithRetry 發送 Email 並實現重試機制
|
||||
func (use *DeliveryUseCase) sendEmailWithRetry(ctx context.Context, req usecase.MailReq, history *entity.DeliveryHistory) error {
|
||||
// 根據 Sort 欄位對 EmailProviders 進行排序
|
||||
providers := make([]usecase.EmailProvider, len(use.param.EmailProviders))
|
||||
copy(providers, use.param.EmailProviders)
|
||||
sort.Slice(providers, func(i, j int) bool {
|
||||
return providers[i].Sort < providers[j].Sort
|
||||
})
|
||||
|
||||
var lastErr error
|
||||
totalAttempts := 0
|
||||
|
||||
// 嘗試所有 providers
|
||||
for providerIndex, provider := range providers {
|
||||
// 為每個 provider 嘗試發送
|
||||
for attempt := 0; attempt < use.param.DeliveryConfig.MaxRetries; attempt++ {
|
||||
totalAttempts++
|
||||
|
||||
// 更新歷史記錄狀態
|
||||
history.Status = entity.DeliveryStatusSending
|
||||
history.Provider = fmt.Sprintf("email_provider_%d", providerIndex)
|
||||
history.AttemptCount = totalAttempts
|
||||
history.UpdatedAt = time.Now()
|
||||
use.updateHistory(ctx, history)
|
||||
|
||||
// 記錄發送嘗試
|
||||
attemptStart := time.Now()
|
||||
|
||||
// 創建帶超時的 context
|
||||
sendCtx, cancel := context.WithTimeout(ctx, use.param.DeliveryConfig.Timeout)
|
||||
|
||||
err := provider.Repo.SendMail(sendCtx, repository.MailReq{
|
||||
From: req.From,
|
||||
To: req.To,
|
||||
Subject: req.Subject,
|
||||
Body: req.Body,
|
||||
})
|
||||
|
||||
cancel()
|
||||
|
||||
// 記錄嘗試結果
|
||||
attemptDuration := time.Since(attemptStart)
|
||||
attemptRecord := entity.DeliveryAttempt{
|
||||
Provider: history.Provider,
|
||||
AttemptAt: attemptStart,
|
||||
Success: err == nil,
|
||||
ErrorMessage: "",
|
||||
Duration: attemptDuration.Milliseconds(),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
attemptRecord.ErrorMessage = err.Error()
|
||||
lastErr = err
|
||||
|
||||
logx.WithContext(ctx).Errorf("Email send attempt %d failed for provider %d: %v",
|
||||
attempt+1, providerIndex, err)
|
||||
|
||||
// 如果不是最後一次嘗試,等待後重試
|
||||
if attempt < use.param.DeliveryConfig.MaxRetries-1 {
|
||||
delay := use.calculateDelay(attempt)
|
||||
history.Status = entity.DeliveryStatusRetrying
|
||||
use.updateHistory(ctx, history)
|
||||
use.addAttemptRecord(ctx, history.ID, attemptRecord)
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
continue
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 發送成功
|
||||
history.Status = entity.DeliveryStatusSuccess
|
||||
history.UpdatedAt = time.Now()
|
||||
now := time.Now()
|
||||
history.CompletedAt = &now
|
||||
use.updateHistory(ctx, history)
|
||||
use.addAttemptRecord(ctx, history.ID, attemptRecord)
|
||||
|
||||
logx.WithContext(ctx).Infof("Email sent successfully after %d attempts", totalAttempts)
|
||||
return nil
|
||||
}
|
||||
|
||||
use.addAttemptRecord(ctx, history.ID, attemptRecord)
|
||||
}
|
||||
}
|
||||
|
||||
// 所有 providers 都失敗了
|
||||
history.Status = entity.DeliveryStatusFailed
|
||||
history.ErrorMessage = fmt.Sprintf("All providers failed. Last error: %v", lastErr)
|
||||
history.UpdatedAt = time.Now()
|
||||
now := time.Now()
|
||||
history.CompletedAt = &now
|
||||
use.updateHistory(ctx, history)
|
||||
|
||||
return errs.ThirdPartyError(
|
||||
code.CloudEPNotification,
|
||||
domain.FailedToSendEmailErrorCode,
|
||||
fmt.Sprintf("Failed to send email after %d attempts across %d providers",
|
||||
totalAttempts, len(providers)))
|
||||
adapter.getErrorCode(),
|
||||
fmt.Sprintf("Failed to send %s after %d attempts across %d providers",
|
||||
adapter.getType(), totalAttempts, providerCount))
|
||||
}
|
||||
|
||||
// calculateDelay 計算指數退避延遲
|
||||
|
|
|
|||
|
|
@ -0,0 +1,377 @@
|
|||
package usecase
|
||||
|
||||
import (
|
||||
"backend/pkg/notification/config"
|
||||
"backend/pkg/notification/domain/entity"
|
||||
"backend/pkg/notification/domain/repository"
|
||||
"backend/pkg/notification/domain/usecase"
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockSMSRepository 模擬 SMS Repository
|
||||
type mockSMSRepository struct {
|
||||
sendFunc func(ctx context.Context, req repository.SMSMessageRequest) error
|
||||
}
|
||||
|
||||
func (m *mockSMSRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error {
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(ctx, req)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockMailRepository 模擬 Mail Repository
|
||||
type mockMailRepository struct {
|
||||
sendFunc func(ctx context.Context, req repository.MailReq) error
|
||||
}
|
||||
|
||||
func (m *mockMailRepository) SendMail(ctx context.Context, req repository.MailReq) error {
|
||||
if m.sendFunc != nil {
|
||||
return m.sendFunc(ctx, req)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockHistoryRepository 模擬 History Repository
|
||||
type mockHistoryRepository struct {
|
||||
histories []entity.DeliveryHistory
|
||||
attempts map[string][]entity.DeliveryAttempt
|
||||
}
|
||||
|
||||
func (m *mockHistoryRepository) CreateHistory(ctx context.Context, history *entity.DeliveryHistory) error {
|
||||
m.histories = append(m.histories, *history)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHistoryRepository) UpdateHistory(ctx context.Context, history *entity.DeliveryHistory) error {
|
||||
for i := range m.histories {
|
||||
if m.histories[i].ID == history.ID {
|
||||
m.histories[i] = *history
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHistoryRepository) GetHistory(ctx context.Context, id string) (*entity.DeliveryHistory, error) {
|
||||
for i := range m.histories {
|
||||
if m.histories[i].ID == id {
|
||||
return &m.histories[i], nil
|
||||
}
|
||||
}
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
|
||||
func (m *mockHistoryRepository) AddAttempt(ctx context.Context, historyID string, attempt entity.DeliveryAttempt) error {
|
||||
if m.attempts == nil {
|
||||
m.attempts = make(map[string][]entity.DeliveryAttempt)
|
||||
}
|
||||
m.attempts[historyID] = append(m.attempts[historyID], attempt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockHistoryRepository) ListHistory(ctx context.Context, filter repository.HistoryFilter) ([]*entity.DeliveryHistory, error) {
|
||||
var result []*entity.DeliveryHistory
|
||||
for i := range m.histories {
|
||||
result = append(result, &m.histories[i])
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func TestDeliveryUseCase_SendEmail_Success(t *testing.T) {
|
||||
mockMail := &mockMailRepository{
|
||||
sendFunc: func(ctx context.Context, req repository.MailReq) error {
|
||||
return nil // 成功
|
||||
},
|
||||
}
|
||||
|
||||
mockHistory := &mockHistoryRepository{}
|
||||
|
||||
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
|
||||
EmailProviders: []usecase.EmailProvider{
|
||||
{Sort: 1, Repo: mockMail},
|
||||
},
|
||||
DeliveryConfig: config.DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Timeout: 5 * time.Second,
|
||||
EnableHistory: true,
|
||||
},
|
||||
HistoryRepo: mockHistory,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := uc.SendEmail(ctx, usecase.MailReq{
|
||||
From: "test@example.com",
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test",
|
||||
Body: "<p>Test email</p>",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
// 驗證歷史記錄
|
||||
assert.Equal(t, 1, len(mockHistory.histories))
|
||||
assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status)
|
||||
}
|
||||
|
||||
func TestDeliveryUseCase_SendEmail_RetryAndSuccess(t *testing.T) {
|
||||
attemptCount := 0
|
||||
mockMail := &mockMailRepository{
|
||||
sendFunc: func(ctx context.Context, req repository.MailReq) error {
|
||||
attemptCount++
|
||||
if attemptCount < 3 {
|
||||
return errors.New("temporary error")
|
||||
}
|
||||
return nil // 第三次成功
|
||||
},
|
||||
}
|
||||
|
||||
mockHistory := &mockHistoryRepository{}
|
||||
|
||||
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
|
||||
EmailProviders: []usecase.EmailProvider{
|
||||
{Sort: 1, Repo: mockMail},
|
||||
},
|
||||
DeliveryConfig: config.DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Timeout: 5 * time.Second,
|
||||
EnableHistory: true,
|
||||
},
|
||||
HistoryRepo: mockHistory,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := uc.SendEmail(ctx, usecase.MailReq{
|
||||
From: "test@example.com",
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test",
|
||||
Body: "<p>Test</p>",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 3, attemptCount) // 重試了 3 次
|
||||
assert.Equal(t, 1, len(mockHistory.histories))
|
||||
assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status)
|
||||
assert.Equal(t, 3, mockHistory.histories[0].AttemptCount)
|
||||
}
|
||||
|
||||
func TestDeliveryUseCase_SendEmail_AllRetries_Failed(t *testing.T) {
|
||||
mockMail := &mockMailRepository{
|
||||
sendFunc: func(ctx context.Context, req repository.MailReq) error {
|
||||
return errors.New("persistent error")
|
||||
},
|
||||
}
|
||||
|
||||
mockHistory := &mockHistoryRepository{}
|
||||
|
||||
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
|
||||
EmailProviders: []usecase.EmailProvider{
|
||||
{Sort: 1, Repo: mockMail},
|
||||
},
|
||||
DeliveryConfig: config.DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Timeout: 5 * time.Second,
|
||||
EnableHistory: true,
|
||||
},
|
||||
HistoryRepo: mockHistory,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := uc.SendEmail(ctx, usecase.MailReq{
|
||||
From: "test@example.com",
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test",
|
||||
Body: "<p>Test</p>",
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "Failed to send Email")
|
||||
assert.Equal(t, 1, len(mockHistory.histories))
|
||||
assert.Equal(t, entity.DeliveryStatusFailed, mockHistory.histories[0].Status)
|
||||
assert.Equal(t, 3, mockHistory.histories[0].AttemptCount) // 嘗試了 3 次
|
||||
}
|
||||
|
||||
func TestDeliveryUseCase_SendEmail_Failover(t *testing.T) {
|
||||
mockMail1 := &mockMailRepository{
|
||||
sendFunc: func(ctx context.Context, req repository.MailReq) error {
|
||||
return errors.New("provider 1 failed")
|
||||
},
|
||||
}
|
||||
|
||||
mockMail2 := &mockMailRepository{
|
||||
sendFunc: func(ctx context.Context, req repository.MailReq) error {
|
||||
return nil // 備援成功
|
||||
},
|
||||
}
|
||||
|
||||
mockHistory := &mockHistoryRepository{}
|
||||
|
||||
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
|
||||
EmailProviders: []usecase.EmailProvider{
|
||||
{Sort: 1, Repo: mockMail1}, // 主要供應商
|
||||
{Sort: 2, Repo: mockMail2}, // 備援供應商
|
||||
},
|
||||
DeliveryConfig: config.DeliveryConfig{
|
||||
MaxRetries: 2,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Timeout: 5 * time.Second,
|
||||
EnableHistory: true,
|
||||
},
|
||||
HistoryRepo: mockHistory,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := uc.SendEmail(ctx, usecase.MailReq{
|
||||
From: "test@example.com",
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test",
|
||||
Body: "<p>Test</p>",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
// 驗證使用了備援供應商
|
||||
assert.Equal(t, 1, len(mockHistory.histories))
|
||||
assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status)
|
||||
// 總共嘗試次數:provider1 重試 2 次 + provider2 成功 1 次 = 3 次
|
||||
assert.Equal(t, 3, mockHistory.histories[0].AttemptCount)
|
||||
}
|
||||
|
||||
func TestDeliveryUseCase_SendSMS_Success(t *testing.T) {
|
||||
mockSMS := &mockSMSRepository{
|
||||
sendFunc: func(ctx context.Context, req repository.SMSMessageRequest) error {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
mockHistory := &mockHistoryRepository{}
|
||||
|
||||
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
|
||||
SMSProviders: []usecase.SMSProvider{
|
||||
{Sort: 1, Repo: mockSMS},
|
||||
},
|
||||
DeliveryConfig: config.DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 10 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
Timeout: 5 * time.Second,
|
||||
EnableHistory: true,
|
||||
},
|
||||
HistoryRepo: mockHistory,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
err := uc.SendMessage(ctx, usecase.SMSMessageRequest{
|
||||
PhoneNumber: "+886912345678",
|
||||
RecipientName: "Test User",
|
||||
MessageContent: "Your code: 123456",
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(mockHistory.histories))
|
||||
assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status)
|
||||
}
|
||||
|
||||
func TestDeliveryUseCase_CalculateDelay(t *testing.T) {
|
||||
uc := &DeliveryUseCase{
|
||||
param: DeliveryUseCaseParam{
|
||||
DeliveryConfig: config.DeliveryConfig{
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 1 * time.Second,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
attempt int
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "第 0 次重試",
|
||||
attempt: 0,
|
||||
expected: 100 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "第 1 次重試",
|
||||
attempt: 1,
|
||||
expected: 200 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "第 2 次重試",
|
||||
attempt: 2,
|
||||
expected: 400 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "第 3 次重試",
|
||||
attempt: 3,
|
||||
expected: 800 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "第 10 次重試(達到 MaxDelay)",
|
||||
attempt: 10,
|
||||
expected: 1 * time.Second, // 受限於 MaxDelay
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
delay := uc.calculateDelay(tt.attempt)
|
||||
assert.Equal(t, tt.expected, delay)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeliveryUseCase_ContextCancellation(t *testing.T) {
|
||||
mockMail := &mockMailRepository{
|
||||
sendFunc: func(ctx context.Context, req repository.MailReq) error {
|
||||
// 模擬慢速操作
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return errors.New("should not reach here")
|
||||
},
|
||||
}
|
||||
|
||||
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
|
||||
EmailProviders: []usecase.EmailProvider{
|
||||
{Sort: 1, Repo: mockMail},
|
||||
},
|
||||
DeliveryConfig: config.DeliveryConfig{
|
||||
MaxRetries: 3,
|
||||
InitialDelay: 50 * time.Millisecond,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 500 * time.Millisecond,
|
||||
Timeout: 1 * time.Second,
|
||||
EnableHistory: false,
|
||||
},
|
||||
})
|
||||
|
||||
// 創建會被取消的 context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := uc.SendEmail(ctx, usecase.MailReq{
|
||||
From: "test@example.com",
|
||||
To: []string{"user@example.com"},
|
||||
Subject: "Test",
|
||||
Body: "<p>Test</p>",
|
||||
})
|
||||
|
||||
assert.Error(t, err)
|
||||
assert.Equal(t, context.DeadlineExceeded, err)
|
||||
}
|
||||
|
|
@ -0,0 +1,255 @@
|
|||
package usecase
|
||||
|
||||
import (
|
||||
"backend/pkg/notification/domain/entity"
|
||||
"backend/pkg/notification/domain/template"
|
||||
"backend/pkg/notification/domain/usecase"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTemplateUseCase_RenderEmailTemplate(t *testing.T) {
|
||||
uc := MustTemplateUseCase(TemplateUseCaseParam{
|
||||
TemplateRepo: nil,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tmpl template.EmailTemplate
|
||||
params entity.TemplateParams
|
||||
expectedSubj string
|
||||
expectedBody string
|
||||
shouldContain []string
|
||||
shouldNotError bool
|
||||
}{
|
||||
{
|
||||
name: "渲染基本參數",
|
||||
tmpl: template.EmailTemplate{
|
||||
Title: "Hello {{.Username}}",
|
||||
Body: "<p>Your code is: {{.VerifyCode}}</p>",
|
||||
},
|
||||
params: entity.TemplateParams{
|
||||
Username: "張三",
|
||||
VerifyCode: "123456",
|
||||
},
|
||||
expectedSubj: "Hello 張三",
|
||||
shouldContain: []string{"123456"},
|
||||
shouldNotError: true,
|
||||
},
|
||||
{
|
||||
name: "渲染額外參數",
|
||||
tmpl: template.EmailTemplate{
|
||||
Title: "Welcome",
|
||||
Body: "<p>Hello {{.Username}}, your link: {{.Link}}</p>",
|
||||
},
|
||||
params: entity.TemplateParams{
|
||||
Username: "John",
|
||||
Extra: map[string]string{
|
||||
"Link": "https://example.com",
|
||||
},
|
||||
},
|
||||
shouldContain: []string{"John", "https://example.com"},
|
||||
shouldNotError: true,
|
||||
},
|
||||
{
|
||||
name: "特殊字符不轉義(簡單字符串替換)",
|
||||
tmpl: template.EmailTemplate{
|
||||
Title: "Test",
|
||||
Body: "<p>Name: {{.Username}}</p>",
|
||||
},
|
||||
params: entity.TemplateParams{
|
||||
Username: "<script>alert('xss')</script>",
|
||||
},
|
||||
shouldContain: []string{"<script>alert('xss')</script>"}, // 使用簡單字符串替換,不轉義
|
||||
shouldNotError: true,
|
||||
},
|
||||
{
|
||||
name: "空模板",
|
||||
tmpl: template.EmailTemplate{
|
||||
Title: "",
|
||||
Body: "",
|
||||
},
|
||||
params: entity.TemplateParams{
|
||||
Username: "Test",
|
||||
},
|
||||
expectedSubj: "",
|
||||
expectedBody: "",
|
||||
shouldNotError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := uc.RenderEmailTemplate(ctx, tt.tmpl, tt.params)
|
||||
|
||||
if tt.shouldNotError {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedSubj != "" {
|
||||
assert.Equal(t, tt.expectedSubj, result.Subject)
|
||||
}
|
||||
if tt.expectedBody != "" {
|
||||
assert.Equal(t, tt.expectedBody, result.Body)
|
||||
}
|
||||
for _, contain := range tt.shouldContain {
|
||||
assert.Contains(t, result.Body, contain)
|
||||
}
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateUseCase_RenderSMSTemplate(t *testing.T) {
|
||||
uc := MustTemplateUseCase(TemplateUseCaseParam{
|
||||
TemplateRepo: nil,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tmpl usecase.SMSTemplateResp
|
||||
params entity.TemplateParams
|
||||
expectedBody string
|
||||
shouldContain []string
|
||||
shouldNotError bool
|
||||
}{
|
||||
{
|
||||
name: "渲染 SMS 驗證碼",
|
||||
tmpl: usecase.SMSTemplateResp{
|
||||
Body: "您的驗證碼是:{{.VerifyCode}},請在5分鐘內使用。",
|
||||
},
|
||||
params: entity.TemplateParams{
|
||||
VerifyCode: "654321",
|
||||
},
|
||||
shouldContain: []string{"654321", "5分鐘"},
|
||||
shouldNotError: true,
|
||||
},
|
||||
{
|
||||
name: "SMS 純文本替換",
|
||||
tmpl: usecase.SMSTemplateResp{
|
||||
Body: "Hi {{.Username}}, your code: {{.VerifyCode}}",
|
||||
},
|
||||
params: entity.TemplateParams{
|
||||
Username: "<test>",
|
||||
VerifyCode: "111111",
|
||||
},
|
||||
shouldContain: []string{"<test>", "111111"}, // 使用簡單字符串替換
|
||||
shouldNotError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := uc.RenderSMSTemplate(ctx, tt.tmpl, tt.params)
|
||||
|
||||
if tt.shouldNotError {
|
||||
assert.NoError(t, err)
|
||||
if tt.expectedBody != "" {
|
||||
assert.Equal(t, tt.expectedBody, result.Body)
|
||||
}
|
||||
for _, contain := range tt.shouldContain {
|
||||
assert.Contains(t, result.Body, contain)
|
||||
}
|
||||
} else {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateUseCase_GetEmailTemplateByStatic(t *testing.T) {
|
||||
uc := MustTemplateUseCase(TemplateUseCaseParam{
|
||||
TemplateRepo: nil,
|
||||
})
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
language template.Language
|
||||
templateID template.Type
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "獲取忘記密碼模板 (zh-tw)",
|
||||
language: template.LanguageZhTW,
|
||||
templateID: template.ForgetPasswordVerify,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "獲取綁定郵箱模板 (zh-tw)",
|
||||
language: template.LanguageZhTW,
|
||||
templateID: template.BindingEmail,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "不存在的語言",
|
||||
language: template.Language("xx-xx"),
|
||||
templateID: template.ForgetPasswordVerify,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "不存在的模板類型",
|
||||
language: template.LanguageZhTW,
|
||||
templateID: template.Type("non_existent"),
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := uc.GetEmailTemplateByStatic(ctx, tt.language, tt.templateID)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, result.Title)
|
||||
assert.NotEmpty(t, result.Body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateUseCase_GetDefaultSMSTemplate(t *testing.T) {
|
||||
uc := &TemplateUseCase{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
templateID template.Type
|
||||
shouldContain []string
|
||||
}{
|
||||
{
|
||||
name: "忘記密碼模板",
|
||||
templateID: template.ForgetPasswordVerify,
|
||||
shouldContain: []string{"密碼重設", "驗證碼"},
|
||||
},
|
||||
{
|
||||
name: "綁定郵箱模板",
|
||||
templateID: template.BindingEmail,
|
||||
shouldContain: []string{"綁定", "驗證碼"},
|
||||
},
|
||||
{
|
||||
name: "默認模板",
|
||||
templateID: template.Type("unknown"),
|
||||
shouldContain: []string{"驗證碼"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := uc.getDefaultSMSTemplate(tt.templateID)
|
||||
|
||||
assert.NotEmpty(t, result.Body)
|
||||
for _, contain := range tt.shouldContain {
|
||||
assert.Contains(t, result.Body, contain)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -40,7 +40,7 @@ func DefaultConfig() Config {
|
|||
UIDLength: 6,
|
||||
AdminRoleUID: "AM000000",
|
||||
AdminUserUID: "B000000",
|
||||
DefaultRoleName: "user",
|
||||
DefaultRoleName: "USER",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
package permission
|
||||
|
||||
const (
|
||||
DefaultRole = "user"
|
||||
)
|
||||
|
|
@ -0,0 +1,50 @@
|
|||
package token
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type ContextKey string
|
||||
|
||||
func (c ContextKey) String() string {
|
||||
return string(c)
|
||||
}
|
||||
|
||||
const (
|
||||
KeyRole ContextKey = "role"
|
||||
KeyDeviceID ContextKey = "device_id"
|
||||
KeyScope ContextKey = "scope"
|
||||
KeyUID ContextKey = "uid"
|
||||
KeyLoginID ContextKey = "login_id"
|
||||
)
|
||||
|
||||
func UID(ctx context.Context) string { return getString(ctx, KeyUID) }
|
||||
func Scope(ctx context.Context) string { return getString(ctx, KeyScope) }
|
||||
func Role(ctx context.Context) string { return getString(ctx, KeyRole) }
|
||||
func DeviceID(ctx context.Context) string { return getString(ctx, KeyDeviceID) }
|
||||
func LoginID(ctx context.Context) string { return getString(ctx, KeyLoginID) }
|
||||
|
||||
func WithUID(ctx context.Context, uid string) context.Context {
|
||||
return context.WithValue(ctx, KeyUID, uid)
|
||||
}
|
||||
func WithScope(ctx context.Context, scope string) context.Context {
|
||||
return context.WithValue(ctx, KeyScope, scope)
|
||||
}
|
||||
func WithRole(ctx context.Context, role string) context.Context {
|
||||
return context.WithValue(ctx, KeyRole, role)
|
||||
}
|
||||
func WithDeviceID(ctx context.Context, id string) context.Context {
|
||||
return context.WithValue(ctx, KeyDeviceID, id)
|
||||
}
|
||||
func WithLoginID(ctx context.Context, login string) context.Context {
|
||||
return context.WithValue(ctx, KeyLoginID, login)
|
||||
}
|
||||
|
||||
// --- Internal helper ---
|
||||
func getString(ctx context.Context, key ContextKey) string {
|
||||
if v, ok := ctx.Value(key).(string); ok {
|
||||
return v
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
|
@ -26,7 +26,7 @@ type PermissionRepository struct {
|
|||
DB mongo.DocumentDBWithCacheUseCase
|
||||
}
|
||||
|
||||
func NewAccountRepository(param PermissionRepositoryParam) repository.PermissionRepository {
|
||||
func NewPermissionRepository(param PermissionRepositoryParam) repository.PermissionRepository {
|
||||
e := entity.Permission{}
|
||||
documentDB, err := mongo.MustDocumentDBWithCache(
|
||||
param.Conf,
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ func setupPermissionRepo(db string) (domainRepo.PermissionRepository, func(), er
|
|||
CacheConf: cacheConf,
|
||||
CacheOpts: cacheOpts,
|
||||
}
|
||||
repo := NewAccountRepository(param)
|
||||
repo := NewPermissionRepository(param)
|
||||
_, _ = repo.Index20251009001UP(context.Background())
|
||||
|
||||
return repo, tearDown, nil
|
||||
|
|
|
|||
|
|
@ -29,11 +29,11 @@ type TokenUseCase struct {
|
|||
}
|
||||
|
||||
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
|
||||
claims, err := parseClaims(token, use.Config.Token.AccessSecret, false)
|
||||
claims, err := ParseClaims(token, use.Config.Token.AccessSecret, false)
|
||||
if err != nil {
|
||||
return nil,
|
||||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||
funcName: "parseClaims",
|
||||
funcName: "ParseClaims",
|
||||
req: token,
|
||||
err: err,
|
||||
message: "validate token claims error",
|
||||
|
|
@ -107,7 +107,7 @@ func (use *TokenUseCase) newToken(ctx context.Context, req *entity.Authorization
|
|||
RefreshCreateAt: now,
|
||||
}
|
||||
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
if req.Data != nil {
|
||||
for k, v := range req.Data {
|
||||
tc[k] = v
|
||||
|
|
@ -116,7 +116,7 @@ func (use *TokenUseCase) newToken(ctx context.Context, req *entity.Authorization
|
|||
tc.SetRole(req.Role)
|
||||
tc.SetID(token.ID)
|
||||
tc.SetScope(req.Scope)
|
||||
tc.SetAccount(req.Account)
|
||||
tc.SetLoginID(req.Account)
|
||||
|
||||
token.UID = tc.UID()
|
||||
|
||||
|
|
@ -158,7 +158,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok
|
|||
}
|
||||
|
||||
// Step 2: 提取 Claims Data
|
||||
claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
|
||||
claimsData, err := ParseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
|
||||
if err != nil {
|
||||
return entity.RefreshTokenResp{},
|
||||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||
|
|
@ -179,7 +179,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok
|
|||
Data: claimsData,
|
||||
Expires: req.Expires,
|
||||
IsRefreshToken: true,
|
||||
Account: claimsData.Account(),
|
||||
Account: claimsData.LoginID(),
|
||||
Role: claimsData.Role(),
|
||||
})
|
||||
if err != nil {
|
||||
|
|
@ -226,7 +226,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok
|
|||
}
|
||||
|
||||
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
|
||||
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||||
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||||
if err != nil {
|
||||
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||
funcName: "CancelToken extractClaims",
|
||||
|
|
@ -263,11 +263,11 @@ func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelToken
|
|||
}
|
||||
|
||||
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
|
||||
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, true)
|
||||
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, true)
|
||||
if err != nil {
|
||||
return entity.ValidationTokenResp{},
|
||||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||
funcName: "parseClaims",
|
||||
funcName: "ParseClaims",
|
||||
req: req,
|
||||
err: err,
|
||||
message: "validate token claims error",
|
||||
|
|
@ -400,11 +400,11 @@ func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.Quer
|
|||
|
||||
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
|
||||
// 驗證Token
|
||||
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||||
claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||||
if err != nil {
|
||||
return entity.CreateOneTimeTokenResp{},
|
||||
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||
funcName: "parseClaims",
|
||||
funcName: "ParseClaims",
|
||||
req: req,
|
||||
err: err,
|
||||
message: "failed to get token claims",
|
||||
|
|
@ -637,7 +637,7 @@ func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string,
|
|||
// 為每個 token 創建黑名單條目
|
||||
for _, token := range tokens {
|
||||
// 解析 token 獲取 JTI 和過期時間
|
||||
claims, err := parseClaims(token.AccessToken, use.Config.Token.AccessSecret, false)
|
||||
claims, err := ParseClaims(token.AccessToken, use.Config.Token.AccessSecret, false)
|
||||
if err != nil {
|
||||
logx.WithContext(ctx).Errorw("failed to parse token for blacklisting",
|
||||
logx.Field("uid", uid),
|
||||
|
|
|
|||
|
|
@ -1,28 +1,28 @@
|
|||
package usecase
|
||||
|
||||
type tokenClaims map[string]string
|
||||
type TokenClaims map[string]string
|
||||
|
||||
func (tc tokenClaims) SetID(id string) {
|
||||
func (tc TokenClaims) SetID(id string) {
|
||||
tc["id"] = id
|
||||
}
|
||||
|
||||
func (tc tokenClaims) SetRole(role string) {
|
||||
func (tc TokenClaims) SetRole(role string) {
|
||||
tc["role"] = role
|
||||
}
|
||||
|
||||
func (tc tokenClaims) SetDeviceID(deviceID string) {
|
||||
func (tc TokenClaims) SetDeviceID(deviceID string) {
|
||||
tc["device_id"] = deviceID
|
||||
}
|
||||
|
||||
func (tc tokenClaims) SetScope(scope string) {
|
||||
func (tc TokenClaims) SetScope(scope string) {
|
||||
tc["scope"] = scope
|
||||
}
|
||||
|
||||
func (tc tokenClaims) SetAccount(account string) {
|
||||
tc["account"] = account
|
||||
func (tc TokenClaims) SetLoginID(loginID string) {
|
||||
tc["login_id"] = loginID
|
||||
}
|
||||
|
||||
func (tc tokenClaims) Role() string {
|
||||
func (tc TokenClaims) Role() string {
|
||||
role, ok := tc["role"]
|
||||
if !ok {
|
||||
return ""
|
||||
|
|
@ -31,7 +31,7 @@ func (tc tokenClaims) Role() string {
|
|||
return role
|
||||
}
|
||||
|
||||
func (tc tokenClaims) ID() string {
|
||||
func (tc TokenClaims) ID() string {
|
||||
id, ok := tc["id"]
|
||||
if !ok {
|
||||
return ""
|
||||
|
|
@ -40,7 +40,7 @@ func (tc tokenClaims) ID() string {
|
|||
return id
|
||||
}
|
||||
|
||||
func (tc tokenClaims) DeviceID() string {
|
||||
func (tc TokenClaims) DeviceID() string {
|
||||
deviceID, ok := tc["device_id"]
|
||||
if !ok {
|
||||
return ""
|
||||
|
|
@ -49,7 +49,7 @@ func (tc tokenClaims) DeviceID() string {
|
|||
return deviceID
|
||||
}
|
||||
|
||||
func (tc tokenClaims) UID() string {
|
||||
func (tc TokenClaims) UID() string {
|
||||
uid, ok := tc["uid"]
|
||||
if !ok {
|
||||
return ""
|
||||
|
|
@ -58,7 +58,7 @@ func (tc tokenClaims) UID() string {
|
|||
return uid
|
||||
}
|
||||
|
||||
func (tc tokenClaims) Scope() string {
|
||||
func (tc TokenClaims) Scope() string {
|
||||
scope, ok := tc["scope"]
|
||||
if !ok {
|
||||
return ""
|
||||
|
|
@ -67,8 +67,8 @@ func (tc tokenClaims) Scope() string {
|
|||
return scope
|
||||
}
|
||||
|
||||
func (tc tokenClaims) Account() string {
|
||||
scope, ok := tc["account"]
|
||||
func (tc TokenClaims) LoginID() string {
|
||||
scope, ok := tc["login_id"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ func TestTokenClaims_SetAndGetID(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
tc.SetID(tt.id)
|
||||
|
||||
result := tc.ID()
|
||||
|
|
@ -61,7 +61,7 @@ func TestTokenClaims_SetAndGetRole(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
tc.SetRole(tt.role)
|
||||
|
||||
result := tc.Role()
|
||||
|
|
@ -91,7 +91,7 @@ func TestTokenClaims_SetAndGetDeviceID(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
tc.SetDeviceID(tt.deviceID)
|
||||
|
||||
result := tc.DeviceID()
|
||||
|
|
@ -125,7 +125,7 @@ func TestTokenClaims_SetAndGetScope(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
tc.SetScope(tt.scope)
|
||||
|
||||
// Note: there's no GetScope method, so we just verify it's set
|
||||
|
|
@ -159,7 +159,7 @@ func TestTokenClaims_SetAndGetAccount(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
tc.SetAccount(tt.account)
|
||||
|
||||
// Note: there's no GetAccount method, so we just verify it's set
|
||||
|
|
@ -189,7 +189,7 @@ func TestTokenClaims_SetAndGetUID(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
tc["uid"] = tt.uid
|
||||
|
||||
result := tc.UID()
|
||||
|
|
@ -199,7 +199,7 @@ func TestTokenClaims_SetAndGetUID(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTokenClaims_GetNonExistentField(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
|
||||
t.Run("get non-existent ID", func(t *testing.T) {
|
||||
result := tc.ID()
|
||||
|
|
@ -223,7 +223,7 @@ func TestTokenClaims_GetNonExistentField(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTokenClaims_MultipleFields(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
|
||||
tc.SetID("token123")
|
||||
tc.SetRole("admin")
|
||||
|
|
@ -243,7 +243,7 @@ func TestTokenClaims_MultipleFields(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTokenClaims_Overwrite(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
|
||||
t.Run("overwrite ID", func(t *testing.T) {
|
||||
tc.SetID("token123")
|
||||
|
|
@ -263,7 +263,7 @@ func TestTokenClaims_Overwrite(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTokenClaims_MapBehavior(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
|
||||
t.Run("can set custom fields", func(t *testing.T) {
|
||||
tc["custom_field"] = "custom_value"
|
||||
|
|
@ -271,7 +271,7 @@ func TestTokenClaims_MapBehavior(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("can iterate over fields", func(t *testing.T) {
|
||||
tc2 := make(tokenClaims)
|
||||
tc2 := make(TokenClaims)
|
||||
tc2.SetID("token123")
|
||||
tc2.SetRole("admin")
|
||||
tc2["uid"] = "user123"
|
||||
|
|
@ -303,7 +303,7 @@ func TestTokenClaims_MapBehavior(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTokenClaims_EmptyMap(t *testing.T) {
|
||||
tc := make(tokenClaims)
|
||||
tc := make(TokenClaims)
|
||||
|
||||
assert.Empty(t, tc.ID())
|
||||
assert.Empty(t, tc.Role())
|
||||
|
|
@ -313,7 +313,7 @@ func TestTokenClaims_EmptyMap(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTokenClaims_NilMap(t *testing.T) {
|
||||
var tc tokenClaims
|
||||
var tc TokenClaims
|
||||
|
||||
t.Run("get from nil map", func(t *testing.T) {
|
||||
assert.Empty(t, tc.ID())
|
||||
|
|
@ -322,4 +322,3 @@ func TestTokenClaims_NilMap(t *testing.T) {
|
|||
assert.Empty(t, tc.UID())
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ func createAccessToken(token entity.Token, data any, secretKey string) (string,
|
|||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ID: token.ID,
|
||||
ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: "permission",
|
||||
},
|
||||
}
|
||||
|
|
@ -76,10 +77,10 @@ func parseToken(accessToken string, secret string, validate bool) (jwt.MapClaims
|
|||
return claims, nil
|
||||
}
|
||||
|
||||
func parseClaims(accessToken string, secret string, validate bool) (tokenClaims, error) {
|
||||
func ParseClaims(accessToken string, secret string, validate bool) (TokenClaims, error) {
|
||||
claimMap, err := parseToken(accessToken, secret, validate)
|
||||
if err != nil {
|
||||
return tokenClaims{}, err
|
||||
return TokenClaims{}, err
|
||||
}
|
||||
|
||||
claimsData, ok := claimMap["data"].(map[string]any)
|
||||
|
|
@ -87,7 +88,7 @@ func parseClaims(accessToken string, secret string, validate bool) (tokenClaims,
|
|||
return convertMap(claimsData), nil
|
||||
}
|
||||
|
||||
return tokenClaims{}, fmt.Errorf("get data from claim map error")
|
||||
return TokenClaims{}, fmt.Errorf("get data from claim map error")
|
||||
}
|
||||
|
||||
func convertMap(input map[string]interface{}) map[string]string {
|
||||
|
|
|
|||
|
|
@ -248,7 +248,7 @@ func TestParseClaims(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims, err := parseClaims(tt.accessToken, tt.secret, tt.validate)
|
||||
claims, err := ParseClaims(tt.accessToken, tt.secret, tt.validate)
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
|
|
|
|||
Loading…
Reference in New Issue