test push

This commit is contained in:
王性驊 2025-10-22 21:40:31 +08:00
parent ef9b218f3b
commit d71ffea750
56 changed files with 2390 additions and 381 deletions

View File

@ -51,3 +51,11 @@ Token:
OneTimeTokenExpiry : 600s OneTimeTokenExpiry : 600s
MaxTokensPerUser : 2 MaxTokensPerUser : 2
MaxTokensPerDevice : 2 MaxTokensPerDevice : 2
RoleConfig:
UIDPrefix: "AM"
UIDLength: 6
AdminRoleUID: "AM000000"
AdminUserUID: "B000000"
DefaultRoleName: "USER"

View File

@ -3,11 +3,7 @@ syntax = "v1"
// ================ 通用響應 ================ // ================ 通用響應 ================
type ( type (
// 成功響應 // 成功響應
RespOK { RespOK {}
Code int `json:"code"`
Msg string `json:"msg"`
Data interface{} `json:"data,omitempty"`
}
// 分頁響應 // 分頁響應
PagerResp { PagerResp {
@ -29,4 +25,9 @@ type (
Authorization { Authorization {
Authorization string `header:"Authorization" validate:"required"` Authorization string `header:"Authorization" validate:"required"`
} }
Status {
Code int64 `json:"code"` // 狀態碼
Message string `json:"message"` // 訊息
Data interface{} `json:"data,omitempty"` // 可選的資料,當有返回時才出現
}
) )

View File

@ -38,7 +38,7 @@ type (
// RequestPasswordResetReq 請求發送「忘記密碼」的驗證碼 // RequestPasswordResetReq 請求發送「忘記密碼」的驗證碼
RequestPasswordResetReq { RequestPasswordResetReq {
Identifier string `json:"identifier" validate:"required,email|phone"` // 使用者帳號 (信箱或手機) Identifier string `json:"identifier" validate:"required"` // 使用者帳號 (信箱或手機)
AccountType string `json:"account_type" validate:"required,oneof=email phone"` AccountType string `json:"account_type" validate:"required,oneof=email phone"`
} }
@ -141,6 +141,31 @@ type (
VerifyCode string `json:"verify_code" validate:"required,len=6"` VerifyCode string `json:"verify_code" validate:"required,len=6"`
Authorization Authorization
} }
// MyInfo 用於獲取會員資訊的標準響應結構
MyInfo {
Platform string `json:"platform"` // 註冊平台
UID string `json:"uid"` // 用戶 UID
AvatarURL *string `json:"avatar_url,omitempty"` // 頭像 URL
FullName *string `json:"full_name,omitempty"` // 用戶全名
Nickname *string `json:"nickname,omitempty"` // 暱稱
GenderCode *string `json:"gender_code,omitempty"` // 性別代碼
Birthdate *string `json:"birthdate,omitempty"` // 生日 (格式: 1993-04-17)
PhoneNumber *string `json:"phone_number,omitempty"` // 電話
IsPhoneVerified *bool `json:"is_phone_verified,omitempty"` // 手機是否已驗證
Email *string `json:"email,omitempty"` // 信箱
IsEmailVerified *bool `json:"is_email_verified,omitempty"` // 信箱是否已驗證
Address *string `json:"address,omitempty"` // 地址
UserStatus string `json:"user_status,omitempty"` // 用戶狀態
PreferredLanguage string `json:"preferred_language,omitempty"` // 偏好語言
Currency string `json:"currency,omitempty"` // 偏好幣種
AlarmCategory string `json:"alarm_category,omitempty"` // 告警狀態
PostCode *string `json:"post_code,omitempty"` // 郵遞區號
Carrier *string `json:"carrier,omitempty"` // 載具
Role string `json:"role"` // 角色
UpdateAt string `json:"update_at"`
CreateAt string `json:"create_at"`
}
) )
// ================================================================= // =================================================================
@ -251,7 +276,7 @@ service gateway {
@respdoc-500 (ErrorResp) // 伺服器內部錯誤 @respdoc-500 (ErrorResp) // 伺服器內部錯誤
*/ */
@handler getUserInfo @handler getUserInfo
get /me (Authorization) returns (UserInfoResp) get /me (Authorization) returns (MyInfo)
@doc( @doc(
summary: "更新當前登入的會員資訊" summary: "更新當前登入的會員資訊"

View File

@ -0,0 +1,4 @@
db.role.deleteMany({
"uid": { "$in": ["ADMIN", "OPERATOR", "USER"] }
});

View File

@ -0,0 +1,32 @@
db.role.insertMany([
{
"client_id": 1,
"uid": "ADMIN",
"name": "管理員",
"status": 1,
"create_time": NumberLong(1728745200),
"update_time": NumberLong(1728745200)
},
{
"client_id": 1,
"uid": "OPERATOR",
"name": "操作員",
"status": 1,
"create_time": NumberLong(1728745200),
"update_time": NumberLong(1728745200)
},
{
"client_id": 1,
"uid": "USER",
"name": "一般使用者",
"status": 1,
"create_time": NumberLong(1728745200),
"update_time": NumberLong(1728745200)
}
]);
// 建立索引
db.role.createIndex({ "uid": 1 }, { unique: true });
db.role.createIndex({ "client_id": 1 });
db.role.createIndex({ "status": 1 });

View File

@ -52,12 +52,30 @@ type Config struct {
// JWT Token 配置 // JWT Token 配置
Token struct { Token struct {
AccessSecret string AccessSecret string
RefreshSecret string RefreshSecret string
AccessTokenExpiry time.Duration AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration OneTimeTokenExpiry time.Duration
MaxTokensPerUser int MaxTokensPerUser int
MaxTokensPerDevice int MaxTokensPerDevice int
}
// RoleConfig 角色配置
RoleConfig struct {
// UID 前綴 (例如: AM, RL)
UIDPrefix string
// UID 數字長度
UIDLength int
// 管理員角色 UID
AdminRoleUID string
// 管理員用戶 UID
AdminUserUID string
// 預設角色名稱
DefaultRoleName string
} }
} }

19
internal/domain/redis.go Normal file
View File

@ -0,0 +1,19 @@
package domain
import "strings"
type RedisKey string
const (
GenerateVerifyCodeRedisKey RedisKey = "rf_code"
)
func (key RedisKey) ToString() string {
return string(key)
}
func (key RedisKey) With(s ...string) RedisKey {
parts := append([]string{string(key)}, s...)
return RedisKey(strings.Join(parts, ":"))
}

View File

@ -6,7 +6,6 @@ import (
"backend/internal/svc" "backend/internal/svc"
"backend/internal/types" "backend/internal/types"
"backend/pkg/library/errs" "backend/pkg/library/errs"
ers "backend/pkg/library/errs"
"net/http" "net/http"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
@ -18,38 +17,39 @@ func LoginHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
var req types.LoginReq var req types.LoginReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {
e := errs.InvalidFormat(err.Error()) e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{ httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int(e.FullCode()), Code: int64(e.FullCode()),
Msg: err.Error(), Message: err.Error(),
}) })
return return
} }
//if err := svcCtx.Validate.ValidateAll(req); err != nil { if err := svcCtx.Validate.ValidateAll(req); err != nil {
// e := errs.InvalidFormat(err.Error()) e := errs.InvalidFormat(err.Error())
// httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{ httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
// Code: int(e.FullCode()), Code: int64(e.FullCode()),
// Msg: err.Error(), Message: err.Error(),
// }) })
//
// return return
//} }
l := auth.NewLoginLogic(r.Context(), svcCtx) l := auth.NewLoginLogic(r.Context(), svcCtx)
resp, err := l.Login(&req) resp, err := l.Login(&req)
if err != nil { if err != nil {
e := ers.FromError(err) e := errs.FromError(err)
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{ httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
Code: int(e.FullCode()), Code: int(e.FullCode()),
Msg: e.Error(), Msg: e.Error(),
Error: e, Error: e,
}) })
} else { } else {
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.RespOK{ httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
Code: domain.SuccessCode, Code: domain.SuccessCode,
Msg: domain.SuccessMessage, Message: domain.SuccessMessage,
Data: resp, Data: resp,
}) })
} }
} }

View File

@ -1,6 +1,8 @@
package auth package auth
import ( import (
"backend/internal/domain"
"backend/pkg/library/errs"
"net/http" "net/http"
"backend/internal/logic/auth" "backend/internal/logic/auth"
@ -15,16 +17,40 @@ func RefreshTokenHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.RefreshTokenReq var req types.RefreshTokenReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return
}
if err := svcCtx.Validate.ValidateAll(req); err != nil {
e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return return
} }
l := auth.NewRefreshTokenLogic(r.Context(), svcCtx) l := auth.NewRefreshTokenLogic(r.Context(), svcCtx)
resp, err := l.RefreshToken(&req) resp, err := l.RefreshToken(&req)
if err != nil { if err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.FromError(err)
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
Code: int(e.FullCode()),
Msg: e.Error(),
Error: e,
})
} else { } else {
httpx.OkJsonCtx(r.Context(), w, resp) httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
Code: domain.SuccessCode,
Message: domain.SuccessMessage,
Data: resp,
})
} }
} }
} }

View File

@ -12,15 +12,14 @@ import (
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
) )
// 註冊新帳號
func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc { func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.LoginReq var req types.LoginReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {
e := errs.InvalidFormat(err.Error()) e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{ httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int(e.FullCode()), Code: int64(e.FullCode()),
Msg: err.Error(), Message: err.Error(),
}) })
return return
@ -28,9 +27,9 @@ func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
if err := svcCtx.Validate.ValidateAll(req); err != nil { if err := svcCtx.Validate.ValidateAll(req); err != nil {
e := errs.InvalidFormat(err.Error()) e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.RespOK{ httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int(e.FullCode()), Code: int64(e.FullCode()),
Msg: err.Error(), Message: err.Error(),
}) })
return return
@ -46,10 +45,10 @@ func RegisterHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
Error: e, Error: e,
}) })
} else { } else {
httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.RespOK{ httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
Code: domain.SuccessCode, Code: domain.SuccessCode,
Msg: domain.SuccessMessage, Message: domain.SuccessMessage,
Data: resp, Data: resp,
}) })
} }
} }

View File

@ -1,6 +1,8 @@
package auth package auth
import ( import (
"backend/internal/domain"
"backend/pkg/library/errs"
"net/http" "net/http"
"backend/internal/logic/auth" "backend/internal/logic/auth"
@ -15,16 +17,40 @@ func RequestPasswordResetHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.RequestPasswordResetReq var req types.RequestPasswordResetReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return
}
if err := svcCtx.Validate.ValidateAll(req); err != nil {
e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return return
} }
l := auth.NewRequestPasswordResetLogic(r.Context(), svcCtx) l := auth.NewRequestPasswordResetLogic(r.Context(), svcCtx)
resp, err := l.RequestPasswordReset(&req) resp, err := l.RequestPasswordReset(&req)
if err != nil { if err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.FromError(err)
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
Code: int(e.FullCode()),
Msg: e.Error(),
Error: e,
})
} else { } else {
httpx.OkJsonCtx(r.Context(), w, resp) httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
Code: domain.SuccessCode,
Message: domain.SuccessMessage,
Data: resp,
})
} }
} }
} }

View File

@ -1,6 +1,8 @@
package auth package auth
import ( import (
"backend/internal/domain"
"backend/pkg/library/errs"
"net/http" "net/http"
"backend/internal/logic/auth" "backend/internal/logic/auth"
@ -15,16 +17,40 @@ func ResetPasswordHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.ResetPasswordReq var req types.ResetPasswordReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return
}
if err := svcCtx.Validate.ValidateAll(req); err != nil {
e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return return
} }
l := auth.NewResetPasswordLogic(r.Context(), svcCtx) l := auth.NewResetPasswordLogic(r.Context(), svcCtx)
resp, err := l.ResetPassword(&req) resp, err := l.ResetPassword(&req)
if err != nil { if err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.FromError(err)
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
Code: int(e.FullCode()),
Msg: e.Error(),
Error: e,
})
} else { } else {
httpx.OkJsonCtx(r.Context(), w, resp) httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
Code: domain.SuccessCode,
Message: domain.SuccessMessage,
Data: resp,
})
} }
} }
} }

View File

@ -1,6 +1,8 @@
package auth package auth
import ( import (
"backend/internal/domain"
"backend/pkg/library/errs"
"net/http" "net/http"
"backend/internal/logic/auth" "backend/internal/logic/auth"
@ -15,16 +17,40 @@ func VerifyPasswordResetCodeHandler(svcCtx *svc.ServiceContext) http.HandlerFunc
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.VerifyCodeReq var req types.VerifyCodeReq
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return
}
if err := svcCtx.Validate.ValidateAll(req); err != nil {
e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return return
} }
l := auth.NewVerifyPasswordResetCodeLogic(r.Context(), svcCtx) l := auth.NewVerifyPasswordResetCodeLogic(r.Context(), svcCtx)
resp, err := l.VerifyPasswordResetCode(&req) resp, err := l.VerifyPasswordResetCode(&req)
if err != nil { if err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.FromError(err)
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
Code: int(e.FullCode()),
Msg: e.Error(),
Error: e,
})
} else { } else {
httpx.OkJsonCtx(r.Context(), w, resp) httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
Code: domain.SuccessCode,
Message: domain.SuccessMessage,
Data: resp,
})
} }
} }
} }

View File

@ -1,12 +1,13 @@
package user package user
import ( import (
"backend/internal/domain"
"backend/pkg/library/errs"
"net/http" "net/http"
"backend/internal/logic/user" "backend/internal/logic/user"
"backend/internal/svc" "backend/internal/svc"
"backend/internal/types" "backend/internal/types"
"github.com/zeromicro/go-zero/rest/httpx" "github.com/zeromicro/go-zero/rest/httpx"
) )
@ -15,16 +16,40 @@ func GetUserInfoHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var req types.Authorization var req types.Authorization
if err := httpx.Parse(r, &req); err != nil { if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return
}
if err := svcCtx.Validate.ValidateAll(req); err != nil {
e := errs.InvalidFormat(err.Error())
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.Status{
Code: int64(e.FullCode()),
Message: err.Error(),
})
return return
} }
l := user.NewGetUserInfoLogic(r.Context(), svcCtx) l := user.NewGetUserInfoLogic(r.Context(), svcCtx)
resp, err := l.GetUserInfo(&req) resp, err := l.GetUserInfo(&req)
if err != nil { if err != nil {
httpx.ErrorCtx(r.Context(), w, err) e := errs.FromError(err)
httpx.WriteJsonCtx(r.Context(), w, e.HTTPStatus(), types.ErrorResp{
Code: int(e.FullCode()),
Msg: e.Error(),
Error: e,
})
} else { } else {
httpx.OkJsonCtx(r.Context(), w, resp) httpx.WriteJsonCtx(r.Context(), w, http.StatusOK, types.Status{
Code: domain.SuccessCode,
Message: domain.SuccessMessage,
Data: resp,
})
} }
} }
} }

View File

@ -10,10 +10,7 @@ import (
) )
// 生成 Token // 生成 Token
func generateToken(svc *svc.ServiceContext, ctx context.Context, req *types.LoginReq, uid string) (entity.TokenResp, error) { func generateToken(svc *svc.ServiceContext, ctx context.Context, req *types.LoginReq, uid string, role string) (entity.TokenResp, error) {
// scope role 要修改refresh tl
role := "user"
tk, err := svc.TokenUC.NewToken(ctx, entity.AuthorizationReq{ tk, err := svc.TokenUC.NewToken(ctx, entity.AuthorizationReq{
GrantType: token.ClientCredentials.ToString(), GrantType: token.ClientCredentials.ToString(),
DeviceID: uid, // TODO 沒傳暫時先用UID 替代 DeviceID: uid, // TODO 沒傳暫時先用UID 替代

View File

@ -79,7 +79,12 @@ func (l *LoginLogic) Login(req *types.LoginReq) (resp *types.LoginResp, err erro
return nil, err return nil, err
} }
tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID) userRole, err := l.svcCtx.UserRoleUC.Get(l.ctx, account.UID)
if err != nil {
return nil, err
}
tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID, userRole.RoleUID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -7,6 +7,7 @@ import (
"backend/pkg/library/errs/code" "backend/pkg/library/errs/code"
mb "backend/pkg/member/domain/member" mb "backend/pkg/member/domain/member"
member "backend/pkg/member/domain/usecase" member "backend/pkg/member/domain/usecase"
"backend/pkg/permission/domain/usecase"
"context" "context"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -76,9 +77,18 @@ func (l *RegisterLogic) Register(req *types.LoginReq) (resp *types.LoginResp, er
return nil, err return nil, err
} }
_, err = l.svcCtx.UserRoleUC.Assign(l.ctx, usecase.AssignRoleRequest{
RoleUID: l.svcCtx.Config.RoleConfig.DefaultRoleName,
UserUID: account.UID,
Brand: "digimon",
})
if err != nil {
return nil, err
}
// Step 5: 生成 Token // Step 5: 生成 Token
req.LoginID = bd.CreateAccountReq.LoginID req.LoginID = bd.CreateAccountReq.LoginID
tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID) tk, err := generateToken(l.svcCtx, l.ctx, req, account.UID, l.svcCtx.Config.RoleConfig.DefaultRoleName)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,7 +1,14 @@
package auth package auth
import ( import (
"backend/internal/domain"
"backend/internal/utils"
"backend/pkg/library/errs"
"backend/pkg/library/errs/code"
"backend/pkg/member/domain/member"
"backend/pkg/member/domain/usecase"
"context" "context"
"fmt"
"backend/internal/svc" "backend/internal/svc"
"backend/internal/types" "backend/internal/types"
@ -25,7 +32,109 @@ func NewRequestPasswordResetLogic(ctx context.Context, svcCtx *svc.ServiceContex
// RequestPasswordReset 請求發送密碼重設驗證碼 aka 忘記密碼 // RequestPasswordReset 請求發送密碼重設驗證碼 aka 忘記密碼
func (l *RequestPasswordResetLogic) RequestPasswordReset(req *types.RequestPasswordResetReq) (resp *types.RespOK, err error) { func (l *RequestPasswordResetLogic) RequestPasswordReset(req *types.RequestPasswordResetReq) (resp *types.RespOK, err error) {
// todo: add your logic here and delete this line // 驗證並標準化帳號
acc, err := l.validateAndNormalizeAccount(req.AccountType, req.Identifier)
if err != nil {
return nil, err
}
return // 檢查發送冷卻時間
rk := domain.GenerateVerifyCodeRedisKey.With(fmt.Sprintf("%s:%d", acc, member.GenerateCodeTypeForgetPassword)).ToString()
if err := l.checkVerifyCodeCooldown(rk); err != nil {
return nil, err
}
// 確認帳號是否註冊並檢查平台限制
if err := l.checkAccountAndPlatform(acc); err != nil {
return nil, err
}
// 生成驗證碼
vcode, err := l.svcCtx.AccountUC.GenerateRefreshCode(l.ctx, usecase.GenerateRefreshCodeRequest{
LoginID: acc,
CodeType: member.GenerateCodeTypeForgetPassword,
})
if err != nil {
return nil, err
}
// 獲取用戶資訊並確認綁定帳號
account, err := l.svcCtx.AccountUC.GetUIDByAccount(l.ctx, usecase.GetUIDByAccountRequest{Account: acc})
if err != nil {
return nil, errs.ResourceNotFoundWithScope(code.CloudEPMember, 0, fmt.Sprintf("account not found:%s", acc))
}
info, err := l.svcCtx.AccountUC.GetUserInfo(l.ctx, usecase.GetUserInfoRequest{UID: account.UID})
if err != nil {
return nil, err
}
// 發送驗證碼
fmt.Println("======= send", vcode.Data.VerifyCode, &info)
//nickname := getEmailShowName(&info)
//if err := l.sendVerificationCode(req.AccountType, acc, &info, vcode.Data.VerifyCode, nickname); err != nil {
// return nil, err
//}
// 設置 Redis 鍵
l.setRedisKeyWithExpiry(rk, vcode.Data.VerifyCode, 60)
return &types.RespOK{}, nil
}
// validateAndNormalizeAccount 驗證並標準化帳號
func (l *RequestPasswordResetLogic) validateAndNormalizeAccount(accountType, account string) (string, error) {
switch member.GetAccountTypeByCode(accountType) {
case member.AccountTypePhone:
phone, isPhone := utils.NormalizeTaiwanMobile(account)
if !isPhone {
return "", errs.InvalidFormatWithScope(code.CloudEPMember, "phone number is invalid")
}
return phone, nil
case member.AccountTypeMail:
if !utils.IsValidEmail(account) {
return "", errs.InvalidFormatWithScope(code.CloudEPMember, "email is invalid")
}
return account, nil
case member.AccountTypeNone, member.AccountTypeDefine:
default:
}
return "", errs.InvalidFormatWithScope(code.CloudEPMember, "unsupported account type")
}
// checkVerifyCodeCooldown 檢查是否已在限制時間內發送過驗證碼
func (l *RequestPasswordResetLogic) checkVerifyCodeCooldown(rk string) error {
if cachedCode, err := l.svcCtx.Redis.GetCtx(l.ctx, rk); err != nil || cachedCode != "" {
return errs.TooManyWithScope(code.CloudEPMember, "verification code already sent, please wait 3min for system to send again")
}
return nil
}
// checkAccountAndPlatform 檢查帳號是否註冊及平台限制
func (l *RequestPasswordResetLogic) checkAccountAndPlatform(acc string) error {
accountInfo, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{Account: acc})
if err != nil {
return err
}
if accountInfo.Data.Platform != member.Digimon {
return errs.InvalidFormatWithScope(code.CloudEPMember,
"failed to send verify code since platform not correct")
}
return nil
}
// setRedisKeyWithExpiry 設置 Redis 鍵
func (l *RequestPasswordResetLogic) setRedisKeyWithExpiry(rk, verifyCode string, expiry int) {
if status, err := l.svcCtx.Redis.SetnxExCtx(l.ctx, rk, verifyCode, expiry); err != nil || !status {
_ = errs.DatabaseErrorWithScopeL(code.CloudEPMember, 0, logx.WithContext(l.ctx), []logx.LogField{
{Key: "redisKey", Value: rk},
{Key: "error", Value: err.Error()},
}, "failed to set redis expire").Wrap(err)
}
} }

View File

@ -1,7 +1,15 @@
package auth package auth
import ( import (
"backend/internal/domain"
"backend/pkg/library/errs"
"backend/pkg/library/errs/code"
"backend/pkg/member/domain/member"
"backend/pkg/member/domain/usecase"
"backend/pkg/permission/domain/entity"
"context" "context"
"fmt"
"backend/internal/svc" "backend/internal/svc"
"backend/internal/types" "backend/internal/types"
@ -15,7 +23,7 @@ type ResetPasswordLogic struct {
svcCtx *svc.ServiceContext svcCtx *svc.ServiceContext
} }
// 執行密碼重設 // NewResetPasswordLogic 執行密碼重設
func NewResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ResetPasswordLogic { func NewResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ResetPasswordLogic {
return &ResetPasswordLogic{ return &ResetPasswordLogic{
Logger: logx.WithContext(ctx), Logger: logx.WithContext(ctx),
@ -24,8 +32,58 @@ func NewResetPasswordLogic(ctx context.Context, svcCtx *svc.ServiceContext) *Res
} }
} }
func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordReq) (resp *types.RespOK, err error) { func (l *ResetPasswordLogic) ResetPassword(req *types.ResetPasswordReq) (*types.RespOK, error) {
// todo: add your logic here and delete this line // 驗證密碼,兩次密碼要一致
if req.Password != req.PasswordConfirm {
return nil, errs.InvalidFormatWithScope(code.CloudEPMember, "password confirmation does not match")
}
return // 驗證碼
err := l.svcCtx.AccountUC.VerifyRefreshCode(l.ctx, usecase.VerifyRefreshCodeRequest{
LoginID: req.Identifier,
CodeType: member.GenerateCodeTypeForgetPassword,
VerifyCode: req.VerifyCode,
})
if err != nil {
// 表使沒有這驗證碼
return nil, errs.ForbiddenWithScope(code.CloudEPMember, 0, "failed to get verify code")
}
info, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{Account: req.Identifier})
if err != nil {
return nil, err
}
if info.Data.Platform != member.Digimon {
return nil, errs.ForbiddenWithScope(code.CloudEPMember, 0, "invalid platform")
}
// 更新
err = l.svcCtx.AccountUC.UpdateUserToken(l.ctx, usecase.UpdateTokenRequest{
Account: req.Identifier,
Token: req.Password,
Platform: member.Digimon.ToInt64(),
})
if err != nil {
return nil, err
}
rk := domain.GenerateVerifyCodeRedisKey.With(
fmt.Sprintf("%s-%d", req.Identifier, member.GenerateCodeTypeForgetPassword),
).ToString()
_, _ = l.svcCtx.Redis.Del(rk)
ac, err := l.svcCtx.AccountUC.GetUIDByAccount(l.ctx, usecase.GetUIDByAccountRequest{Account: req.Identifier})
if err != nil {
return nil, err
}
err = l.svcCtx.TokenUC.CancelTokens(l.ctx, entity.DoTokenByUIDReq{UID: ac.UID})
if err != nil {
return nil, err
}
// 返回成功響應
return &types.RespOK{}, nil
} }

View File

@ -1,6 +1,9 @@
package auth package auth
import ( import (
"backend/pkg/library/errs"
"backend/pkg/member/domain/member"
"backend/pkg/member/domain/usecase"
"context" "context"
"backend/internal/svc" "backend/internal/svc"
@ -25,7 +28,16 @@ func NewVerifyPasswordResetCodeLogic(ctx context.Context, svcCtx *svc.ServiceCon
// VerifyPasswordResetCode 校驗密碼重設驗證碼(頁面需求,預先檢查看看, 顯示表演用) // VerifyPasswordResetCode 校驗密碼重設驗證碼(頁面需求,預先檢查看看, 顯示表演用)
func (l *VerifyPasswordResetCodeLogic) VerifyPasswordResetCode(req *types.VerifyCodeReq) (resp *types.RespOK, err error) { func (l *VerifyPasswordResetCodeLogic) VerifyPasswordResetCode(req *types.VerifyCodeReq) (resp *types.RespOK, err error) {
// todo: add your logic here and delete this line // 先驗證,不刪除
if err := l.svcCtx.AccountUC.CheckRefreshCode(l.ctx, usecase.VerifyRefreshCodeRequest{
VerifyCode: req.VerifyCode,
LoginID: req.Identifier,
CodeType: member.GenerateCodeTypeForgetPassword,
}); err != nil {
e := errs.Forbidden("failed to get verify code").Wrap(err)
return return nil, e
}
return &types.RespOK{}, nil
} }

View File

@ -1,7 +1,12 @@
package user package user
import ( import (
"backend/pkg/member/domain/member"
"backend/pkg/member/domain/usecase"
"backend/pkg/permission/domain/token"
"context" "context"
"google.golang.org/protobuf/proto"
"time"
"backend/internal/svc" "backend/internal/svc"
"backend/internal/types" "backend/internal/types"
@ -15,7 +20,7 @@ type GetUserInfoLogic struct {
svcCtx *svc.ServiceContext svcCtx *svc.ServiceContext
} }
// 取得當前登入的會員資訊(自己) // NewGetUserInfoLogic 取得當前登入的會員資訊(自己)
func NewGetUserInfoLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserInfoLogic { func NewGetUserInfoLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUserInfoLogic {
return &GetUserInfoLogic{ return &GetUserInfoLogic{
Logger: logx.WithContext(ctx), Logger: logx.WithContext(ctx),
@ -24,8 +29,88 @@ func NewGetUserInfoLogic(ctx context.Context, svcCtx *svc.ServiceContext) *GetUs
} }
} }
func (l *GetUserInfoLogic) GetUserInfo(req *types.Authorization) (resp *types.UserInfoResp, err error) { func (l *GetUserInfoLogic) GetUserInfo(req *types.Authorization) (*types.MyInfo, error) {
// todo: add your logic here and delete this line uid := token.UID(l.ctx)
info, err := l.svcCtx.AccountUC.GetUserInfo(l.ctx, usecase.GetUserInfoRequest{
UID: uid,
})
if err != nil {
return nil, err
}
return byUID, err := l.svcCtx.AccountUC.FindLoginIDByUID(l.ctx, uid)
if err != nil {
return nil, err
}
accountInfo, err := l.svcCtx.AccountUC.GetUserAccountInfo(l.ctx, usecase.GetUIDByAccountRequest{
Account: byUID.LoginID,
})
if err != nil {
return nil, err
}
userRole, err := l.svcCtx.UserRoleUC.Get(l.ctx, uid)
if err != nil {
return nil, err
}
role := userRole.RoleUID
res := &types.MyInfo{
Platform: accountInfo.Data.Platform.ToString(),
UID: info.UID,
UpdateAt: time.Unix(0, info.CreateTime).UTC().Format(time.RFC3339),
CreateAt: time.Unix(0, info.UpdateTime).UTC().Format(time.RFC3339),
Role: role,
UserStatus: info.UserStatus.CodeToString(),
PreferredLanguage: info.PreferredLanguage,
Currency: info.Currency,
AlarmCategory: info.AlarmCategory.CodeToString(),
}
if info.Address != nil {
res.Address = info.Address
}
if info.AvatarURL != nil {
res.AvatarURL = info.AvatarURL
}
if info.FullName != nil {
res.FullName = info.FullName
}
if info.Birthdate != nil {
b := ToDate(info.Birthdate)
res.Birthdate = b
}
if info.Address != nil {
res.Address = info.Address
}
if info.Nickname != nil {
res.Nickname = info.Nickname
}
if info.Email != nil {
res.Email = info.Email
res.IsEmailVerified = proto.Bool(true)
}
if info.PhoneNumber != nil {
res.PhoneNumber = info.PhoneNumber
res.IsPhoneVerified = proto.Bool(true)
}
if info.GenderCode != nil {
gc := member.GetGenderByCode(*info.GenderCode)
res.GenderCode = &gc
}
return res, nil
}
func ToDate(n *int64) *string {
result := ""
if n != nil {
result = time.Unix(*n, 0).UTC().Format(time.DateOnly)
}
return &result
} }

View File

@ -1,19 +1,85 @@
package middleware package middleware
import "net/http" import (
"backend/internal/types"
"backend/pkg/library/errs"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/token"
"context"
"github.com/zeromicro/go-zero/rest/httpx"
type AuthMiddleware struct { "backend/pkg/permission/domain/usecase"
uc "backend/pkg/permission/usecase"
"net/http"
)
type AuthMiddlewareParam struct {
TokenSec string
TokenUseCase usecase.TokenUseCase
} }
func NewAuthMiddleware() *AuthMiddleware { type AuthMiddleware struct {
return &AuthMiddleware{} TokenSec string
TokenUseCase usecase.TokenUseCase
}
func NewAuthMiddleware(param AuthMiddlewareParam) *AuthMiddleware {
return &AuthMiddleware{
TokenSec: param.TokenSec,
TokenUseCase: param.TokenUseCase,
}
} }
func (m *AuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { func (m *AuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
// TODO generate middleware implement function, delete after code implementation // 解析 Header
header := types.Authorization{}
if err := httpx.ParseHeaders(r, &header); err != nil {
m.writeErrorResponse(w, r, http.StatusBadRequest, "Failed to parse headers", int64(errs.InvalidFormat("").FullCode()))
// Passthrough to next handler if need return
next(w, r) }
// 驗證 Token
claim, err := uc.ParseClaims(header.Authorization, m.TokenSec, true)
if err != nil {
// 是否需要紀錄錯誤,是不是只要紀錄除了驗證失敗或過期之外的真錯誤
m.writeErrorResponse(w, r,
http.StatusUnauthorized, "failed to verify toke",
int64(100400))
return
}
// 驗證 Token 是否在黑名單中
if _, err := m.TokenUseCase.ValidationToken(r.Context(), entity.ValidationTokenReq{Token: header.Authorization}); err != nil {
m.writeErrorResponse(w, r, http.StatusForbidden,
"failed to get toke",
int64(100400))
return
}
// 設置 context 並傳遞給下一個處理器
ctx := SetContext(r, claim)
next(w, r.WithContext(ctx))
} }
} }
func SetContext(r *http.Request, claim uc.TokenClaims) context.Context {
ctx := context.WithValue(r.Context(), token.KeyRole, claim.Role())
ctx = context.WithValue(ctx, token.KeyUID, claim.UID())
ctx = context.WithValue(ctx, token.KeyDeviceID, claim.DeviceID())
ctx = context.WithValue(ctx, token.KeyScope, claim.Scope())
ctx = context.WithValue(ctx, token.KeyLoginID, claim.LoginID())
return ctx
}
// writeErrorResponse 用於處理錯誤回應
func (m *AuthMiddleware) writeErrorResponse(w http.ResponseWriter, r *http.Request, statusCode int, message string, code int64) {
httpx.WriteJsonCtx(r.Context(), w, statusCode, types.ErrorResp{
Code: int(code),
Msg: message,
})
}

View File

@ -19,6 +19,11 @@ type ServiceContext struct {
AccountUC memberUC.AccountUseCase AccountUC memberUC.AccountUseCase
Validate vi.Validate Validate vi.Validate
TokenUC tokenUC.TokenUseCase TokenUC tokenUC.TokenUseCase
PermissionUC tokenUC.PermissionUseCase
RoleUC tokenUC.RoleUseCase
RolePermission tokenUC.RolePermissionUseCase
UserRoleUC tokenUC.UserRoleUseCase
Redis *redis.Redis
} }
func NewServiceContext(c config.Config) *ServiceContext { func NewServiceContext(c config.Config) *ServiceContext {
@ -28,11 +33,22 @@ func NewServiceContext(c config.Config) *ServiceContext {
} }
errs.Scope = code.CloudEPPortalGW errs.Scope = code.CloudEPPortalGW
rp := NewPermissionUC(&c)
tkUC := NewTokenUC(&c, rds)
return &ServiceContext{ return &ServiceContext{
Config: c, Config: c,
AuthMiddleware: middleware.NewAuthMiddleware().Handle, AuthMiddleware: middleware.NewAuthMiddleware(middleware.AuthMiddlewareParam{
TokenSec: c.Token.AccessSecret,
TokenUseCase: tkUC,
}).Handle,
AccountUC: NewAccountUC(&c, rds), AccountUC: NewAccountUC(&c, rds),
Validate: vi.MustValidator(), Validate: vi.MustValidator(),
TokenUC: NewTokenUC(&c, rds), TokenUC: tkUC,
PermissionUC: rp.PermissionUC,
RoleUC: rp.RoleUC,
RolePermission: rp.RolePermission,
UserRoleUC: rp.UserRole,
Redis: rds,
} }
} }

View File

@ -2,9 +2,12 @@ package svc
import ( import (
"backend/internal/config" "backend/internal/config"
mgo "backend/pkg/library/mongo"
"backend/pkg/permission/domain/usecase" "backend/pkg/permission/domain/usecase"
"backend/pkg/permission/repository" "backend/pkg/permission/repository"
uc "backend/pkg/permission/usecase" uc "backend/pkg/permission/usecase"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"github.com/zeromicro/go-zero/core/stores/redis" "github.com/zeromicro/go-zero/core/stores/redis"
) )
@ -16,3 +19,102 @@ func NewTokenUC(c *config.Config, rds *redis.Redis) usecase.TokenUseCase {
Config: c, Config: c,
}) })
} }
type PermissionUC struct {
PermissionUC usecase.PermissionUseCase
RoleUC usecase.RoleUseCase
RolePermission usecase.RolePermissionUseCase
UserRole usecase.UserRoleUseCase
}
func NewPermissionUC(c *config.Config) PermissionUC {
// 準備Mongo Config
conf := &mgo.Conf{
Schema: c.Mongo.Schema,
Host: c.Mongo.Host,
Database: c.Mongo.Database,
MaxStaleness: c.Mongo.MaxStaleness,
MaxPoolSize: c.Mongo.MaxPoolSize,
MinPoolSize: c.Mongo.MinPoolSize,
MaxConnIdleTime: c.Mongo.MaxConnIdleTime,
Compressors: c.Mongo.Compressors,
EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode,
ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs,
}
if c.Mongo.User != "" {
conf.User = c.Mongo.User
conf.Password = c.Mongo.Password
}
// 快取選項
cacheOpts := []cache.Option{
cache.WithExpiry(c.CacheExpireTime),
cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry),
}
dbOpts := []mon.Option{
mgo.SetCustomDecimalType(),
mgo.InitMongoOptions(*conf),
}
permRepo := repository.NewPermissionRepository(repository.PermissionRepositoryParam{
Conf: conf,
CacheConf: c.Cache,
CacheOpts: cacheOpts,
DBOpts: dbOpts,
})
rolePermRepo := repository.NewRolePermissionRepository(repository.RolePermissionRepositoryParam{
Conf: conf,
CacheConf: c.Cache,
CacheOpts: cacheOpts,
DBOpts: dbOpts,
})
roleRepo := repository.NewRoleRepository(repository.RoleRepositoryParam{
Conf: conf,
CacheConf: c.Cache,
CacheOpts: cacheOpts,
DBOpts: dbOpts,
})
userRoleRepo := repository.NewUserRoleRepository(repository.UserRoleRepositoryParam{
Conf: conf,
CacheConf: c.Cache,
CacheOpts: cacheOpts,
DBOpts: dbOpts,
})
puc := uc.NewPermissionUseCase(uc.PermissionUseCaseParam{
RoleRepo: roleRepo,
RolePermRepo: rolePermRepo,
UserRoleRepo: userRoleRepo,
PermRepo: permRepo,
})
rpuc := uc.NewRolePermissionUseCase(uc.RolePermissionUseCaseParam{
RoleRepo: roleRepo,
RolePermRepo: rolePermRepo,
UserRoleRepo: userRoleRepo,
PermRepo: permRepo,
PermUseCase: puc,
AdminRoleUID: c.RoleConfig.AdminRoleUID,
})
ruc := uc.NewRoleUseCase(uc.RoleUseCaseParam{
RoleRepo: roleRepo,
UserRoleRepo: userRoleRepo,
Config: uc.RoleUseCaseConfig{
AdminRoleUID: c.RoleConfig.AdminRoleUID,
UIDPrefix: c.RoleConfig.UIDPrefix,
UIDLength: c.RoleConfig.UIDLength,
},
RolePermUseCase: rpuc,
})
return PermissionUC{
PermissionUC: puc,
RolePermission: rpuc,
RoleUC: ruc,
UserRole: uc.NewUserRoleUseCase(uc.UserRoleUseCaseParam{
UserRoleRepo: userRoleRepo,
RoleRepo: roleRepo,
}),
}
}

View File

@ -37,6 +37,30 @@ type LoginResp struct {
TokenType string `json:"token_type"` // 通常固定為 "Bearer" TokenType string `json:"token_type"` // 通常固定為 "Bearer"
} }
type MyInfo struct {
Platform string `json:"platform"` // 註冊平台
UID string `json:"uid"` // 用戶 UID
AvatarURL *string `json:"avatar_url,omitempty"` // 頭像 URL
FullName *string `json:"full_name,omitempty"` // 用戶全名
Nickname *string `json:"nickname,omitempty"` // 暱稱
GenderCode *string `json:"gender_code,omitempty"` // 性別代碼
Birthdate *string `json:"birthdate,omitempty"` // 生日 (格式: 1993-04-17)
PhoneNumber *string `json:"phone_number,omitempty"` // 電話
IsPhoneVerified *bool `json:"is_phone_verified,omitempty"` // 手機是否已驗證
Email *string `json:"email,omitempty"` // 信箱
IsEmailVerified *bool `json:"is_email_verified,omitempty"` // 信箱是否已驗證
Address *string `json:"address,omitempty"` // 地址
UserStatus string `json:"user_status,omitempty"` // 用戶狀態
PreferredLanguage string `json:"preferred_language,omitempty"` // 偏好語言
Currency string `json:"currency,omitempty"` // 偏好幣種
AlarmCategory string `json:"alarm_category,omitempty"` // 告警狀態
PostCode *string `json:"post_code,omitempty"` // 郵遞區號
Carrier *string `json:"carrier,omitempty"` // 載具
Role string `json:"role"` // 角色
UpdateAt string `json:"update_at"`
CreateAt string `json:"create_at"`
}
type PagerResp struct { type PagerResp struct {
Total int64 `json:"total"` Total int64 `json:"total"`
Size int64 `json:"size"` Size int64 `json:"size"`
@ -60,7 +84,7 @@ type RefreshTokenResp struct {
} }
type RequestPasswordResetReq struct { type RequestPasswordResetReq struct {
Identifier string `json:"identifier" validate:"required,email|phone"` // 使用者帳號 (信箱或手機) Identifier string `json:"identifier" validate:"required"` // 使用者帳號 (信箱或手機)
AccountType string `json:"account_type" validate:"required,oneof=email phone"` AccountType string `json:"account_type" validate:"required,oneof=email phone"`
} }
@ -77,9 +101,12 @@ type ResetPasswordReq struct {
} }
type RespOK struct { type RespOK struct {
Code int `json:"code"` }
Msg string `json:"msg"`
Data interface{} `json:"data,omitempty"` type Status struct {
Code int64 `json:"code"` // 狀態碼
Message string `json:"message"` // 訊息
Data interface{} `json:"data,omitempty"` // 可選的資料,當有返回時才出現
} }
type SubmitVerificationCodeReq struct { type SubmitVerificationCodeReq struct {

36
internal/utils/format.go Normal file
View File

@ -0,0 +1,36 @@
package utils
import (
"regexp"
"strings"
)
// NormalizeTaiwanMobile 標準化號碼並驗證是否為合法台灣手機號碼
func NormalizeTaiwanMobile(phone string) (string, bool) {
// 移除空格
phone = strings.ReplaceAll(phone, " ", "")
// 移除 "+886" 並將剩餘部分標準化
if strings.HasPrefix(phone, "+886") {
phone = strings.TrimPrefix(phone, "+886")
if !strings.HasPrefix(phone, "0") {
phone = "0" + phone
}
}
// 正則表達式驗證標準化後的號碼
regex := regexp.MustCompile(`^(09\d{8})$`)
if regex.MatchString(phone) {
return phone, true
}
return "", false
}
// IsValidEmail 驗證 Email 格式的函數
func IsValidEmail(email string) bool {
// 定義正則表達式
regex := regexp.MustCompile(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`)
return regex.MatchString(email)
}

View File

@ -35,6 +35,7 @@ const (
InsufficientQuota // 配額不足 InsufficientQuota // 配額不足
ResourceHasMultiOwner // 資源有多個所有者 ResourceHasMultiOwner // 資源有多個所有者
UserSuspended // 沒有權限使用該資源 UserSuspended // 沒有權限使用該資源
TooManyRequest // 單位時間內請求太多次
) )
/* 詳細代碼 - GRPC */ /* 詳細代碼 - GRPC */
@ -77,13 +78,13 @@ const (
// 詳細代碼 - Token 類 09x // 詳細代碼 - Token 類 09x
const ( const (
_ = iota + CatToken _ = iota + CatToken
TokenCreateError // Token 創建錯誤 TokenCreateError // Token 創建錯誤
TokenValidateError // Token 驗證錯誤 TokenValidateError // Token 驗證錯誤
TokenExpired // Token 過期 TokenExpired // Token 過期
TokenNotFound // Token 未找到 TokenNotFound // Token 未找到
TokenBlacklisted // Token 已被列入黑名單 TokenBlacklisted // Token 已被列入黑名單
InvalidJWT // 無效的 JWT InvalidJWT // 無效的 JWT
RefreshTokenError // Refresh Token 錯誤 RefreshTokenError // Refresh Token 錯誤
OneTimeTokenError // 一次性 Token 錯誤 OneTimeTokenError // 一次性 Token 錯誤
) )

View File

@ -556,3 +556,8 @@ func MsgSizeTooLargeL(l logx.Logger, filed []logx.LogField, s ...string) *LibErr
return e return e
} }
func TooManyWithScope(scope uint32, s ...string) *LibError {
return NewError(scope, code.TooManyRequest, defaultDetailCode,
fmt.Sprintf("%s", strings.Join(s, " ")))
}

View File

@ -154,8 +154,13 @@ func (e *LibError) HTTPStatus() int {
if e == nil || e.Code() == code.OK { if e == nil || e.Code() == code.OK {
return http.StatusOK return http.StatusOK
} }
// 將 code 轉換為與常量定義相同的格式 (category + detail)
// 例如code=3004 -> (3004%100) + 30 = 4 + 30 = 34
codeValue := (e.Code() % 100) + e.Category()
// 根據錯誤碼判斷對應的 HTTP 狀態碼 // 根據錯誤碼判斷對應的 HTTP 狀態碼
switch e.Code() / 100 { switch codeValue {
case code.ResourceInsufficient, code.InvalidFormat: case code.ResourceInsufficient, code.InvalidFormat:
// 如果資源不足,返回 400 狀態碼 // 如果資源不足,返回 400 狀態碼
return http.StatusBadRequest return http.StatusBadRequest
@ -177,6 +182,9 @@ func (e *LibError) HTTPStatus() int {
case code.NotValidImplementation: case code.NotValidImplementation:
// 如果實現無效,返回 501 狀態碼 // 如果實現無效,返回 501 狀態碼
return http.StatusNotImplemented return http.StatusNotImplemented
case code.TooManyRequest:
// 如果實現無效,返回 501 狀態碼
return http.StatusTooManyRequests
default: default:
// 如果沒有匹配的錯誤碼,則繼續下一步 // 如果沒有匹配的錯誤碼,則繼續下一步
} }

View File

@ -71,12 +71,12 @@ func TestLibError_HTTPStatus(t *testing.T) {
err *LibError err *LibError
expected int expected int
}{ }{
{"bad request", NewError(1, code.CatService, code.ResourceInsufficient, "bad request"), http.StatusBadRequest}, {"bad request - ResourceInsufficient", NewError(1, code.CatResource, 4, "bad request"), http.StatusBadRequest},
{"unauthorized", NewError(1, code.CatAuth, code.Unauthorized, "unauthorized"), http.StatusUnauthorized}, {"unauthorized", NewError(1, code.CatAuth, 1, "unauthorized"), http.StatusUnauthorized},
{"forbidden", NewError(1, code.CatAuth, code.Forbidden, "forbidden"), http.StatusForbidden}, {"forbidden", NewError(1, code.CatAuth, 5, "forbidden"), http.StatusForbidden},
{"not found", NewError(1, code.CatResource, code.ResourceNotFound, "not found"), http.StatusNotFound}, {"not found", NewError(1, code.CatResource, 1, "not found"), http.StatusNotFound},
{"internal server error", NewError(1, code.CatDB, 1095, "not found"), http.StatusInternalServerError}, {"internal server error", NewError(1, code.CatDB, 95, "db error"), http.StatusInternalServerError},
{"input err", NewError(1, code.CatInput, 1095, "not found"), http.StatusBadRequest}, {"input err", NewError(1, code.CatInput, 1, "input error"), http.StatusBadRequest},
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -19,3 +19,35 @@ const (
const ( const (
CurrencyTWD Currency = "TWD" CurrencyTWD Currency = "TWD"
) )
var genderMap = map[int64]string{
0: "",
1: "male",
2: "female",
3: "secret",
}
func GetGenderByCode(g int64) string {
r, ok := genderMap[g]
if !ok {
return genderMap[0]
}
return r
}
var genderCodeMap = map[string]int64{
"": 0,
"male": 1,
"female": 2,
"secret": 3,
}
func GetGenderCodeByStr(g string) int64 {
r, ok := genderCodeMap[g]
if !ok {
return genderCodeMap[""]
}
return r
}

View File

@ -14,6 +14,7 @@ type AccountUIDRepository interface {
Update(ctx context.Context, data *entity.AccountUID) (*mongo.UpdateResult, error) Update(ctx context.Context, data *entity.AccountUID) (*mongo.UpdateResult, error)
Delete(ctx context.Context, id string) (int64, error) Delete(ctx context.Context, id string) (int64, error)
FindUIDByLoginID(ctx context.Context, loginID string) (*entity.AccountUID, error) FindUIDByLoginID(ctx context.Context, loginID string) (*entity.AccountUID, error)
FindOneByUID(ctx context.Context, uid string) (*entity.AccountUID, error)
AccountUIDIndexUP AccountUIDIndexUP
} }

View File

@ -31,6 +31,8 @@ type MemberUseCase interface {
GetUserInfo(ctx context.Context, req GetUserInfoRequest) (UserInfo, error) GetUserInfo(ctx context.Context, req GetUserInfoRequest) (UserInfo, error)
// ListMember 取得會員列表 // ListMember 取得會員列表
ListMember(ctx context.Context, req ListUserInfoRequest) (ListUserInfoResponse, error) ListMember(ctx context.Context, req ListUserInfoRequest) (ListUserInfoResponse, error)
// FindLoginIDByUID 取得login id
FindLoginIDByUID(ctx context.Context, uid string) (BindingUser, error)
} }
type BindingMemberUseCase interface { type BindingMemberUseCase interface {

View File

@ -112,6 +112,20 @@ func (repo *AccountUIDRepository) FindUIDByLoginID(ctx context.Context, loginID
} }
} }
func (repo *AccountUIDRepository) FindOneByUID(ctx context.Context, uid string) (*entity.AccountUID, error) {
var data entity.AccountUID
err := repo.DB.GetClient().FindOne(ctx, &data, bson.M{"uid": uid})
switch {
case err == nil:
return &data, nil
case errors.Is(err, mon.ErrNotFound):
return nil, ErrNotFound
default:
return nil, err
}
}
func (repo *AccountUIDRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) { func (repo *AccountUIDRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
// 等價於 db.account_uid_binding.createIndex({"login_id": 1}, {unique: true}) // 等價於 db.account_uid_binding.createIndex({"login_id": 1}, {unique: true})
repo.DB.PopulateIndex(ctx, "login_id", 1, true) repo.DB.PopulateIndex(ctx, "login_id", 1, true)

View File

@ -205,7 +205,7 @@ func (repo *UserRepository) FindOneByUID(ctx context.Context, uid string) (*enti
// 不常寫,再找一次可接受 // 不常寫,再找一次可接受
id := repo.UIDToID(ctx, uid) id := repo.UIDToID(ctx, uid)
if id == "" { if id == "" {
return nil, errors.New("invalid uid") return nil, ErrNotFound
} }
rk := domain.GetUserRedisKey(id) rk := domain.GetUserRedisKey(id)

View File

@ -4,6 +4,7 @@ import (
"backend/pkg/member/domain/config" "backend/pkg/member/domain/config"
"backend/pkg/member/domain/repository" "backend/pkg/member/domain/repository"
"backend/pkg/member/domain/usecase" "backend/pkg/member/domain/usecase"
"context"
) )
type MemberUseCaseParam struct { type MemberUseCaseParam struct {
@ -24,3 +25,15 @@ func MustMemberUseCase(param MemberUseCaseParam) usecase.AccountUseCase {
param, param,
} }
} }
func (use *MemberUseCase) FindLoginIDByUID(ctx context.Context, uid string) (usecase.BindingUser, error) {
data, err := use.AccountUID.FindOneByUID(ctx, uid)
if err != nil {
return usecase.BindingUser{}, err
}
return usecase.BindingUser{
UID: data.UID,
LoginID: data.LoginID,
}, nil
}

View File

@ -1,6 +1,10 @@
package config package config
import "time" import (
"errors"
"fmt"
"time"
)
type SMTPConfig struct { type SMTPConfig struct {
Enable bool Enable bool
@ -13,6 +17,35 @@ type SMTPConfig struct {
Password string Password string
} }
// Validate 驗證 SMTP 配置
func (c *SMTPConfig) Validate() error {
if !c.Enable {
return nil // 未啟用則不驗證
}
if c.Host == "" {
return errors.New("smtp host is required")
}
if c.Port <= 0 || c.Port > 65535 {
return fmt.Errorf("smtp port must be between 1 and 65535, got %d", c.Port)
}
if c.Username == "" {
return errors.New("smtp username is required")
}
if c.Password == "" {
return errors.New("smtp password is required")
}
if c.Sort < 0 {
return fmt.Errorf("smtp sort must be >= 0, got %d", c.Sort)
}
return nil
}
type AmazonSesSettings struct { type AmazonSesSettings struct {
Enable bool Enable bool
Sort int Sort int
@ -26,6 +59,39 @@ type AmazonSesSettings struct {
Token string Token string
} }
// Validate 驗證 AWS SES 配置
func (c *AmazonSesSettings) Validate() error {
if !c.Enable {
return nil // 未啟用則不驗證
}
if c.Region == "" {
return errors.New("aws ses region is required")
}
if c.Sender == "" {
return errors.New("aws ses sender is required")
}
if c.AccessKey == "" {
return errors.New("aws ses access key is required")
}
if c.SecretKey == "" {
return errors.New("aws ses secret key is required")
}
if c.Sort < 0 {
return fmt.Errorf("aws ses sort must be >= 0, got %d", c.Sort)
}
if c.PoolSize < 0 {
return fmt.Errorf("aws ses pool size must be >= 0, got %d", c.PoolSize)
}
return nil
}
type MitakeSMSSender struct { type MitakeSMSSender struct {
Enable bool Enable bool
Sort int Sort int
@ -35,6 +101,31 @@ type MitakeSMSSender struct {
Password string Password string
} }
// Validate 驗證 Mitake SMS 配置
func (c *MitakeSMSSender) Validate() error {
if !c.Enable {
return nil // 未啟用則不驗證
}
if c.User == "" {
return errors.New("mitake user is required")
}
if c.Password == "" {
return errors.New("mitake password is required")
}
if c.Sort < 0 {
return fmt.Errorf("mitake sort must be >= 0, got %d", c.Sort)
}
if c.PoolSize < 0 {
return fmt.Errorf("mitake pool size must be >= 0, got %d", c.PoolSize)
}
return nil
}
// DeliveryConfig 傳送重試配置 // DeliveryConfig 傳送重試配置
type DeliveryConfig struct { type DeliveryConfig struct {
MaxRetries int `json:"max_retries"` // 最大重試次數 MaxRetries int `json:"max_retries"` // 最大重試次數
@ -44,3 +135,72 @@ type DeliveryConfig struct {
Timeout time.Duration `json:"timeout"` // 單次發送超時時間 Timeout time.Duration `json:"timeout"` // 單次發送超時時間
EnableHistory bool `json:"enable_history"` // 是否啟用歷史記錄 EnableHistory bool `json:"enable_history"` // 是否啟用歷史記錄
} }
// Validate 驗證 DeliveryConfig 配置
func (c *DeliveryConfig) Validate() error {
if c.MaxRetries < 0 {
return fmt.Errorf("max_retries must be >= 0, got %d", c.MaxRetries)
}
if c.MaxRetries > 10 {
return fmt.Errorf("max_retries should not exceed 10, got %d", c.MaxRetries)
}
if c.InitialDelay < 0 {
return fmt.Errorf("initial_delay must be >= 0, got %v", c.InitialDelay)
}
if c.InitialDelay > 10*time.Second {
return fmt.Errorf("initial_delay is too large (> 10s), got %v", c.InitialDelay)
}
if c.BackoffFactor < 1.0 {
return fmt.Errorf("backoff_factor must be >= 1.0, got %v", c.BackoffFactor)
}
if c.BackoffFactor > 10.0 {
return fmt.Errorf("backoff_factor is too large (> 10.0), got %v", c.BackoffFactor)
}
if c.MaxDelay < 0 {
return fmt.Errorf("max_delay must be >= 0, got %v", c.MaxDelay)
}
if c.MaxDelay > 5*time.Minute {
return fmt.Errorf("max_delay is too large (> 5m), got %v", c.MaxDelay)
}
if c.Timeout <= 0 {
return fmt.Errorf("timeout must be > 0, got %v", c.Timeout)
}
if c.Timeout > 5*time.Minute {
return fmt.Errorf("timeout is too large (> 5m), got %v", c.Timeout)
}
// 檢查 InitialDelay 和 MaxDelay 的關係
if c.InitialDelay > c.MaxDelay && c.MaxDelay > 0 {
return fmt.Errorf("initial_delay (%v) should not exceed max_delay (%v)", c.InitialDelay, c.MaxDelay)
}
return nil
}
// SetDefaults 設置默認值
func (c *DeliveryConfig) SetDefaults() {
if c.MaxRetries == 0 {
c.MaxRetries = 3
}
if c.InitialDelay == 0 {
c.InitialDelay = 100 * time.Millisecond
}
if c.BackoffFactor == 0 {
c.BackoffFactor = 2.0
}
if c.MaxDelay == 0 {
c.MaxDelay = 30 * time.Second
}
if c.Timeout == 0 {
c.Timeout = 30 * time.Second
}
}

View File

@ -0,0 +1,314 @@
package config
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestSMTPConfig_Validate(t *testing.T) {
tests := []struct {
name string
config SMTPConfig
wantErr bool
errMsg string
}{
{
name: "有效的 SMTP 配置",
config: SMTPConfig{
Enable: true,
Sort: 1,
Host: "smtp.gmail.com",
Port: 587,
Username: "test@gmail.com",
Password: "password",
},
wantErr: false,
},
{
name: "未啟用的配置(不驗證)",
config: SMTPConfig{
Enable: false,
},
wantErr: false,
},
{
name: "缺少 Host",
config: SMTPConfig{
Enable: true,
Port: 587,
Username: "test@gmail.com",
Password: "password",
},
wantErr: true,
errMsg: "smtp host is required",
},
{
name: "無效的 Port",
config: SMTPConfig{
Enable: true,
Host: "smtp.gmail.com",
Port: 99999,
Username: "test@gmail.com",
Password: "password",
},
wantErr: true,
errMsg: "smtp port must be between 1 and 65535",
},
{
name: "缺少 Username",
config: SMTPConfig{
Enable: true,
Host: "smtp.gmail.com",
Port: 587,
Password: "password",
},
wantErr: true,
errMsg: "smtp username is required",
},
{
name: "負數的 Sort",
config: SMTPConfig{
Enable: true,
Sort: -1,
Host: "smtp.gmail.com",
Port: 587,
Username: "test@gmail.com",
Password: "password",
},
wantErr: true,
errMsg: "smtp sort must be >= 0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestAmazonSesSettings_Validate(t *testing.T) {
tests := []struct {
name string
config AmazonSesSettings
wantErr bool
errMsg string
}{
{
name: "有效的 AWS SES 配置",
config: AmazonSesSettings{
Enable: true,
Sort: 1,
PoolSize: 10,
Region: "us-west-2",
Sender: "noreply@example.com",
AccessKey: "AKIAIOSFODNN7EXAMPLE",
SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
},
wantErr: false,
},
{
name: "未啟用的配置",
config: AmazonSesSettings{
Enable: false,
},
wantErr: false,
},
{
name: "缺少 Region",
config: AmazonSesSettings{
Enable: true,
Sender: "noreply@example.com",
AccessKey: "AKIAIOSFODNN7EXAMPLE",
SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
},
wantErr: true,
errMsg: "aws ses region is required",
},
{
name: "缺少 AccessKey",
config: AmazonSesSettings{
Enable: true,
Region: "us-west-2",
Sender: "noreply@example.com",
SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
},
wantErr: true,
errMsg: "aws ses access key is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestMitakeSMSSender_Validate(t *testing.T) {
tests := []struct {
name string
config MitakeSMSSender
wantErr bool
errMsg string
}{
{
name: "有效的 Mitake 配置",
config: MitakeSMSSender{
Enable: true,
Sort: 1,
PoolSize: 5,
User: "testuser",
Password: "testpass",
},
wantErr: false,
},
{
name: "缺少 User",
config: MitakeSMSSender{
Enable: true,
Password: "testpass",
},
wantErr: true,
errMsg: "mitake user is required",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestDeliveryConfig_Validate(t *testing.T) {
tests := []struct {
name string
config DeliveryConfig
wantErr bool
errMsg string
}{
{
name: "有效的配置",
config: DeliveryConfig{
MaxRetries: 3,
InitialDelay: 100 * time.Millisecond,
BackoffFactor: 2.0,
MaxDelay: 30 * time.Second,
Timeout: 30 * time.Second,
},
wantErr: false,
},
{
name: "MaxRetries 為負數",
config: DeliveryConfig{
MaxRetries: -1,
Timeout: 30 * time.Second,
},
wantErr: true,
errMsg: "max_retries must be >= 0",
},
{
name: "MaxRetries 過大",
config: DeliveryConfig{
MaxRetries: 20,
Timeout: 30 * time.Second,
},
wantErr: true,
errMsg: "max_retries should not exceed 10",
},
{
name: "BackoffFactor 小於 1.0",
config: DeliveryConfig{
MaxRetries: 3,
BackoffFactor: 0.5,
Timeout: 30 * time.Second,
},
wantErr: true,
errMsg: "backoff_factor must be >= 1.0",
},
{
name: "Timeout 為 0",
config: DeliveryConfig{
MaxRetries: 3,
BackoffFactor: 2.0,
Timeout: 0,
},
wantErr: true,
errMsg: "timeout must be > 0",
},
{
name: "InitialDelay 大於 MaxDelay",
config: DeliveryConfig{
MaxRetries: 3,
InitialDelay: 1 * time.Minute,
MaxDelay: 10 * time.Second,
Timeout: 30 * time.Second,
},
wantErr: true,
errMsg: "initial_delay",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.errMsg != "" {
assert.Contains(t, err.Error(), tt.errMsg)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestDeliveryConfig_SetDefaults(t *testing.T) {
config := DeliveryConfig{}
config.SetDefaults()
assert.Equal(t, 3, config.MaxRetries)
assert.Equal(t, 100*time.Millisecond, config.InitialDelay)
assert.Equal(t, 2.0, config.BackoffFactor)
assert.Equal(t, 30*time.Second, config.MaxDelay)
assert.Equal(t, 30*time.Second, config.Timeout)
// 測試不覆蓋已設置的值
config2 := DeliveryConfig{
MaxRetries: 5,
Timeout: 60 * time.Second,
}
config2.SetDefaults()
assert.Equal(t, 5, config2.MaxRetries) // 保持原值
assert.Equal(t, 60*time.Second, config2.Timeout) // 保持原值
assert.Equal(t, 100*time.Millisecond, config2.InitialDelay) // 設置默認值
}

View File

@ -8,6 +8,7 @@ const (
FailedToSendEmailErrorCode FailedToSendEmailErrorCode
FailedToSendSMSErrorCode FailedToSendSMSErrorCode
FailedToGetTemplateErrorCode FailedToGetTemplateErrorCode
FailedToRenderTemplateErrorCode
FailedToSaveHistoryErrorCode FailedToSaveHistoryErrorCode
FailedToRetryDeliveryErrorCode FailedToRetryDeliveryErrorCode
) )

View File

@ -1,10 +1,24 @@
package usecase package usecase
import ( import (
"backend/pkg/notification/domain/entity"
"backend/pkg/notification/domain/template" "backend/pkg/notification/domain/template"
"context" "context"
) )
type TemplateUseCase interface { type TemplateUseCase interface {
// GetEmailTemplateByStatic 從靜態模板獲取郵件模板
GetEmailTemplateByStatic(ctx context.Context, language template.Language, templateID template.Type) (template.EmailTemplate, error) GetEmailTemplateByStatic(ctx context.Context, language template.Language, templateID template.Type) (template.EmailTemplate, error)
// GetEmailTemplate 獲取郵件模板(優先從資料庫,回退到靜態模板)
GetEmailTemplate(ctx context.Context, language template.Language, templateID template.Type) (template.EmailTemplate, error)
// GetSMSTemplate 獲取 SMS 模板(優先從資料庫,回退到靜態模板)
GetSMSTemplate(ctx context.Context, language template.Language, templateID template.Type) (SMSTemplateResp, error)
// RenderEmailTemplate 渲染郵件模板(替換變數)
RenderEmailTemplate(ctx context.Context, tmpl template.EmailTemplate, params entity.TemplateParams) (EmailTemplateResp, error)
// RenderSMSTemplate 渲染 SMS 模板(替換變數)
RenderSMSTemplate(ctx context.Context, tmpl SMSTemplateResp, params entity.TemplateParams) (SMSTemplateResp, error)
} }

View File

@ -5,11 +5,10 @@ import (
"backend/pkg/notification/domain" "backend/pkg/notification/domain"
"backend/pkg/notification/domain/repository" "backend/pkg/notification/domain/repository"
"context" "context"
"time" "fmt"
"backend/pkg/library/errs" "backend/pkg/library/errs"
"backend/pkg/library/errs/code" "backend/pkg/library/errs/code"
pool "backend/pkg/library/worker_pool"
"github.com/aws/aws-sdk-go-v2/credentials" "github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/ses/types" "github.com/aws/aws-sdk-go-v2/service/ses/types"
@ -25,8 +24,8 @@ type AwsEmailDeliveryParam struct {
} }
type AwsEmailDeliveryRepository struct { type AwsEmailDeliveryRepository struct {
Client *ses.Client Client *ses.Client
Pool pool.WorkerPool Timeout int // 超時時間(秒),預設 30
} }
func MustAwsSesMailRepository(param AwsEmailDeliveryParam) repository.MailRepository { func MustAwsSesMailRepository(param AwsEmailDeliveryParam) repository.MailRepository {
@ -42,70 +41,62 @@ func MustAwsSesMailRepository(param AwsEmailDeliveryParam) repository.MailReposi
// 創建 SES 客戶端 // 創建 SES 客戶端
sesClient := ses.NewFromConfig(cfg) sesClient := ses.NewFromConfig(cfg)
// 設置默認超時時間
timeout := 30
if param.Conf.PoolSize > 0 {
timeout = param.Conf.PoolSize // 可以復用這個配置項,或新增專門的 Timeout 配置
}
return &AwsEmailDeliveryRepository{ return &AwsEmailDeliveryRepository{
Client: sesClient, Client: sesClient,
Pool: pool.NewWorkerPool(param.Conf.PoolSize), Timeout: timeout,
} }
} }
func (use *AwsEmailDeliveryRepository) SendMail(ctx context.Context, req repository.MailReq) error { func (repo *AwsEmailDeliveryRepository) SendMail(ctx context.Context, req repository.MailReq) error {
err := use.Pool.Submit(func() { // 檢查 context 是否已取消
// 設置郵件參數 if ctx.Err() != nil {
to := make([]string, 0, len(req.To)) return ctx.Err()
to = append(to, req.To...) }
input := &ses.SendEmailInput{ // 設置郵件參數
Destination: &types.Destination{ to := make([]string, 0, len(req.To))
ToAddresses: to, to = append(to, req.To...)
},
Message: &types.Message{ input := &ses.SendEmailInput{
Body: &types.Body{ Destination: &types.Destination{
Html: &types.Content{ ToAddresses: to,
Charset: aws.String("UTF-8"), },
Data: aws.String(req.Body), Message: &types.Message{
}, Body: &types.Body{
}, Html: &types.Content{
Subject: &types.Content{
Charset: aws.String("UTF-8"), Charset: aws.String("UTF-8"),
Data: aws.String(req.Subject), Data: aws.String(req.Body),
}, },
}, },
Source: aws.String(req.From), Subject: &types.Content{
} Charset: aws.String("UTF-8"),
Data: aws.String(req.Subject),
},
},
Source: aws.String(req.From),
}
// 發送郵件 // 發送郵件(直接使用傳入的 context不創建新的 context
// TODO 不明原因送不出去,會被 context cancel 這裡先把它手動加到100sec _, err := repo.Client.SendEmail(ctx, input)
newCtx, cancel := context.WithTimeout(context.Background(), 100*time.Second)
defer cancel()
//nolint:contextcheck
if _, err := use.Client.SendEmail(newCtx, input); err != nil {
_ = errs.ThirdPartyErrorL(
code.CloudEPNotification,
domain.FailedToSendEmailErrorCode,
logx.WithContext(ctx),
[]logx.LogField{
{Key: "req", Value: req},
{Key: "func", Value: "AwsEmailDeliveryU.SendEmail"},
{Key: "err", Value: err.Error()},
},
"failed to send mail by aws ses")
}
})
if err != nil { if err != nil {
e := errs.ThirdPartyErrorL( return errs.ThirdPartyErrorL(
code.CloudEPNotification, code.CloudEPNotification,
domain.FailedToSendEmailErrorCode, domain.FailedToSendEmailErrorCode,
logx.WithContext(ctx), logx.WithContext(ctx),
[]logx.LogField{ []logx.LogField{
{Key: "req", Value: req}, {Key: "req", Value: req},
{Key: "func", Value: "AwsEmailDeliveryU.SendEmail"}, {Key: "func", Value: "AwsEmailDeliveryRepository.SendEmail"},
{Key: "err", Value: err.Error()}, {Key: "err", Value: err.Error()},
}, },
"failed to send mail by aws ses") fmt.Sprintf("failed to send mail by aws ses: %v", err)).Wrap(err)
return e
} }
logx.WithContext(ctx).Infof("Email sent successfully via AWS SES to %v", req.To)
return nil return nil
} }

View File

@ -5,10 +5,10 @@ import (
"backend/pkg/notification/domain" "backend/pkg/notification/domain"
"backend/pkg/notification/domain/repository" "backend/pkg/notification/domain/repository"
"context" "context"
"fmt"
"backend/pkg/library/errs" "backend/pkg/library/errs"
"backend/pkg/library/errs/code" "backend/pkg/library/errs/code"
pool "backend/pkg/library/worker_pool"
"github.com/minchao/go-mitake" "github.com/minchao/go-mitake"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
@ -21,45 +21,43 @@ type MitakeSMSDeliveryParam struct {
type MitakeSMSDeliveryRepository struct { type MitakeSMSDeliveryRepository struct {
Client *mitake.Client Client *mitake.Client
Pool pool.WorkerPool
} }
func (use *MitakeSMSDeliveryRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error { func (repo *MitakeSMSDeliveryRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error {
// 用 goroutine pool 送,否則會超時 // 檢查 context 是否已取消
err := use.Pool.Submit(func() { if ctx.Err() != nil {
message := mitake.Message{ return ctx.Err()
Dstaddr: req.PhoneNumber, }
Destname: req.RecipientName,
Smbody: req.MessageContent,
}
_, err := use.Client.Send(message)
if err != nil {
logx.Error("failed to send sms via mitake")
}
})
// 構建簡訊訊息
message := mitake.Message{
Dstaddr: req.PhoneNumber,
Destname: req.RecipientName,
Smbody: req.MessageContent,
}
// 直接發送,不使用 goroutine pool
// 讓 delivery usecase 統一管理重試和超時
_, err := repo.Client.Send(message)
if err != nil { if err != nil {
// 錯誤代碼 20-201-04 return errs.ThirdPartyErrorL(
e := errs.ThirdPartyErrorL(
code.CloudEPNotification, code.CloudEPNotification,
domain.FailedToSendSMSErrorCode, domain.FailedToSendSMSErrorCode,
logx.WithContext(ctx), logx.WithContext(ctx),
[]logx.LogField{ []logx.LogField{
{Key: "req", Value: req}, {Key: "req", Value: req},
{Key: "func", Value: "MitakeSMSDeliveryRepository.Client.Send"}, {Key: "func", Value: "MitakeSMSDeliveryRepository.Send"},
{Key: "err", Value: err.Error()}, {Key: "err", Value: err.Error()},
}, },
"failed to send sns by mitake").Wrap(err) fmt.Sprintf("failed to send sms by mitake: %v", err)).Wrap(err)
return e
} }
logx.WithContext(ctx).Infof("SMS sent successfully via Mitake to %s", req.PhoneNumber)
return nil return nil
} }
func MustMitakeRepository(param MitakeSMSDeliveryParam) repository.SMSClientRepository { func MustMitakeRepository(param MitakeSMSDeliveryParam) repository.SMSClientRepository {
return &MitakeSMSDeliveryRepository{ return &MitakeSMSDeliveryRepository{
Client: mitake.NewClient(param.Conf.User, param.Conf.Password, nil), Client: mitake.NewClient(param.Conf.User, param.Conf.Password, nil),
Pool: pool.NewWorkerPool(param.Conf.PoolSize),
} }
} }

View File

@ -2,10 +2,13 @@ package repository
import ( import (
"backend/pkg/notification/config" "backend/pkg/notification/config"
"backend/pkg/notification/domain"
"backend/pkg/notification/domain/repository" "backend/pkg/notification/domain/repository"
"context" "context"
"fmt"
pool "backend/pkg/library/worker_pool" "backend/pkg/library/errs"
"backend/pkg/library/errs/code"
"github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/logx"
"gopkg.in/gomail.v2" "gopkg.in/gomail.v2"
@ -17,7 +20,6 @@ type SMTPMailUseCaseParam struct {
type SMTPMailRepository struct { type SMTPMailRepository struct {
Client *gomail.Dialer Client *gomail.Dialer
Pool pool.WorkerPool
} }
func MustSMTPUseCase(param SMTPMailUseCaseParam) repository.MailRepository { func MustSMTPUseCase(param SMTPMailUseCaseParam) repository.MailRepository {
@ -28,26 +30,37 @@ func MustSMTPUseCase(param SMTPMailUseCaseParam) repository.MailRepository {
param.Conf.Username, param.Conf.Username,
param.Conf.Password, param.Conf.Password,
), ),
Pool: pool.NewWorkerPool(param.Conf.GoroutinePoolNum),
} }
} }
func (repo *SMTPMailRepository) SendMail(_ context.Context, req repository.MailReq) error { func (repo *SMTPMailRepository) SendMail(ctx context.Context, req repository.MailReq) error {
// 用 goroutine pool 送,否則會超時 // 檢查 context 是否已取消
err := repo.Pool.Submit(func() { if ctx.Err() != nil {
m := gomail.NewMessage() return ctx.Err()
m.SetHeader("From", req.From) }
m.SetHeader("To", req.To...)
m.SetHeader("Subject", req.Subject)
m.SetBody("text/html", req.Body)
if err := repo.Client.DialAndSend(m); err != nil {
logx.WithCallerSkip(1).WithFields(
logx.Field("func", "MailUseCase.SendMail"),
logx.Field("req", req),
logx.Field("err", err),
).Error("failed to send mail by mailgun")
}
})
return err // 構建郵件
m := gomail.NewMessage()
m.SetHeader("From", req.From)
m.SetHeader("To", req.To...)
m.SetHeader("Subject", req.Subject)
m.SetBody("text/html", req.Body)
// 直接發送,不使用 goroutine pool
// 讓 delivery usecase 統一管理重試和超時
if err := repo.Client.DialAndSend(m); err != nil {
return errs.ThirdPartyErrorL(
code.CloudEPNotification,
domain.FailedToSendEmailErrorCode,
logx.WithContext(ctx),
[]logx.LogField{
{Key: "func", Value: "SMTPMailRepository.SendMail"},
{Key: "req", Value: req},
{Key: "err", Value: err.Error()},
},
fmt.Sprintf("failed to send mail by smtp: %v", err)).Wrap(err)
}
logx.WithContext(ctx).Infof("Email sent successfully via SMTP to %v", req.To)
return nil
} }

View File

@ -74,7 +74,10 @@ func (use *DeliveryUseCase) SendMessage(ctx context.Context, req usecase.SMSMess
} }
// 執行發送邏輯 // 執行發送邏輯
return use.sendSMSWithRetry(ctx, req, history) return use.sendWithRetry(ctx, history, &smsProviderAdapter{
providers: use.param.SMSProviders,
request: req,
})
} }
func (use *DeliveryUseCase) SendEmail(ctx context.Context, req usecase.MailReq) error { func (use *DeliveryUseCase) SendEmail(ctx context.Context, req usecase.MailReq) error {
@ -97,30 +100,130 @@ func (use *DeliveryUseCase) SendEmail(ctx context.Context, req usecase.MailReq)
} }
// 執行發送邏輯 // 執行發送邏輯
return use.sendEmailWithRetry(ctx, req, history) return use.sendWithRetry(ctx, history, &emailProviderAdapter{
providers: use.param.EmailProviders,
request: req,
})
} }
// sendSMSWithRetry 發送 SMS 並實現重試機制 // providerAdapter 統一的供應商適配器接口
func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SMSMessageRequest, history *entity.DeliveryHistory) error { type providerAdapter interface {
// 根據 Sort 欄位對 SMSProviders 進行排序 getProviderCount() int
providers := make([]usecase.SMSProvider, len(use.param.SMSProviders)) getProviderName(index int) string
copy(providers, use.param.SMSProviders) getProviderSort(index int) int64
sort.Slice(providers, func(i, j int) bool { send(ctx context.Context, providerIndex int) error
return providers[i].Sort < providers[j].Sort getErrorCode() errs.ErrorCode
getType() string
}
// smsProviderAdapter SMS 供應商適配器
type smsProviderAdapter struct {
providers []usecase.SMSProvider
request usecase.SMSMessageRequest
}
func (a *smsProviderAdapter) getProviderCount() int {
return len(a.providers)
}
func (a *smsProviderAdapter) getProviderName(index int) string {
return fmt.Sprintf("sms_provider_%d", index)
}
func (a *smsProviderAdapter) getProviderSort(index int) int64 {
return a.providers[index].Sort
}
func (a *smsProviderAdapter) send(ctx context.Context, providerIndex int) error {
return a.providers[providerIndex].Repo.SendSMS(ctx, repository.SMSMessageRequest{
PhoneNumber: a.request.PhoneNumber,
RecipientName: a.request.RecipientName,
MessageContent: a.request.MessageContent,
})
}
func (a *smsProviderAdapter) getErrorCode() errs.ErrorCode {
return domain.FailedToSendSMSErrorCode
}
func (a *smsProviderAdapter) getType() string {
return "SMS"
}
// emailProviderAdapter Email 供應商適配器
type emailProviderAdapter struct {
providers []usecase.EmailProvider
request usecase.MailReq
}
func (a *emailProviderAdapter) getProviderCount() int {
return len(a.providers)
}
func (a *emailProviderAdapter) getProviderName(index int) string {
return fmt.Sprintf("email_provider_%d", index)
}
func (a *emailProviderAdapter) getProviderSort(index int) int64 {
return a.providers[index].Sort
}
func (a *emailProviderAdapter) send(ctx context.Context, providerIndex int) error {
return a.providers[providerIndex].Repo.SendMail(ctx, repository.MailReq{
From: a.request.From,
To: a.request.To,
Subject: a.request.Subject,
Body: a.request.Body,
})
}
func (a *emailProviderAdapter) getErrorCode() errs.ErrorCode {
return domain.FailedToSendEmailErrorCode
}
func (a *emailProviderAdapter) getType() string {
return "Email"
}
// providerWithIndex 用於排序的結構
type providerWithIndex struct {
index int
sort int64
}
// sendWithRetry 統一的發送重試邏輯
func (use *DeliveryUseCase) sendWithRetry(
ctx context.Context,
history *entity.DeliveryHistory,
adapter providerAdapter,
) error {
// 按 Sort 欄位對供應商進行排序
providerCount := adapter.getProviderCount()
sortedProviders := make([]providerWithIndex, providerCount)
for i := 0; i < providerCount; i++ {
sortedProviders[i] = providerWithIndex{
index: i,
sort: adapter.getProviderSort(i),
}
}
sort.Slice(sortedProviders, func(i, j int) bool {
return sortedProviders[i].sort < sortedProviders[j].sort
}) })
var lastErr error var lastErr error
totalAttempts := 0 totalAttempts := 0
// 嘗試所有 providers // 嘗試所有 providers
for providerIndex, provider := range providers { for _, provider := range sortedProviders {
providerIndex := provider.index
// 為每個 provider 嘗試發送 // 為每個 provider 嘗試發送
for attempt := 0; attempt < use.param.DeliveryConfig.MaxRetries; attempt++ { for attempt := 0; attempt < use.param.DeliveryConfig.MaxRetries; attempt++ {
totalAttempts++ totalAttempts++
// 更新歷史記錄狀態 // 更新歷史記錄狀態
history.Status = entity.DeliveryStatusSending history.Status = entity.DeliveryStatusSending
history.Provider = fmt.Sprintf("sms_provider_%d", providerIndex) history.Provider = adapter.getProviderName(providerIndex)
history.AttemptCount = totalAttempts history.AttemptCount = totalAttempts
history.UpdatedAt = time.Now() history.UpdatedAt = time.Now()
use.updateHistory(ctx, history) use.updateHistory(ctx, history)
@ -131,11 +234,7 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM
// 創建帶超時的 context // 創建帶超時的 context
sendCtx, cancel := context.WithTimeout(ctx, use.param.DeliveryConfig.Timeout) sendCtx, cancel := context.WithTimeout(ctx, use.param.DeliveryConfig.Timeout)
err := provider.Repo.SendSMS(sendCtx, repository.SMSMessageRequest{ err := adapter.send(sendCtx, providerIndex)
PhoneNumber: req.PhoneNumber,
RecipientName: req.RecipientName,
MessageContent: req.MessageContent,
})
cancel() cancel()
@ -153,8 +252,8 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM
attemptRecord.ErrorMessage = err.Error() attemptRecord.ErrorMessage = err.Error()
lastErr = err lastErr = err
logx.WithContext(ctx).Errorf("SMS send attempt %d failed for provider %d: %v", logx.WithContext(ctx).Errorf("%s send attempt %d failed for provider %d: %v",
attempt+1, providerIndex, err) adapter.getType(), attempt+1, providerIndex, err)
// 如果不是最後一次嘗試,等待後重試 // 如果不是最後一次嘗試,等待後重試
if attempt < use.param.DeliveryConfig.MaxRetries-1 { if attempt < use.param.DeliveryConfig.MaxRetries-1 {
@ -179,7 +278,8 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM
use.updateHistory(ctx, history) use.updateHistory(ctx, history)
use.addAttemptRecord(ctx, history.ID, attemptRecord) use.addAttemptRecord(ctx, history.ID, attemptRecord)
logx.WithContext(ctx).Infof("SMS sent successfully after %d attempts", totalAttempts) logx.WithContext(ctx).Infof("%s sent successfully after %d attempts",
adapter.getType(), totalAttempts)
return nil return nil
} }
@ -197,112 +297,9 @@ func (use *DeliveryUseCase) sendSMSWithRetry(ctx context.Context, req usecase.SM
return errs.ThirdPartyError( return errs.ThirdPartyError(
code.CloudEPNotification, code.CloudEPNotification,
domain.FailedToSendSMSErrorCode, adapter.getErrorCode(),
fmt.Sprintf("Failed to send SMS after %d attempts across %d providers", fmt.Sprintf("Failed to send %s after %d attempts across %d providers",
totalAttempts, len(providers))) adapter.getType(), totalAttempts, providerCount))
}
// sendEmailWithRetry 發送 Email 並實現重試機制
func (use *DeliveryUseCase) sendEmailWithRetry(ctx context.Context, req usecase.MailReq, history *entity.DeliveryHistory) error {
// 根據 Sort 欄位對 EmailProviders 進行排序
providers := make([]usecase.EmailProvider, len(use.param.EmailProviders))
copy(providers, use.param.EmailProviders)
sort.Slice(providers, func(i, j int) bool {
return providers[i].Sort < providers[j].Sort
})
var lastErr error
totalAttempts := 0
// 嘗試所有 providers
for providerIndex, provider := range providers {
// 為每個 provider 嘗試發送
for attempt := 0; attempt < use.param.DeliveryConfig.MaxRetries; attempt++ {
totalAttempts++
// 更新歷史記錄狀態
history.Status = entity.DeliveryStatusSending
history.Provider = fmt.Sprintf("email_provider_%d", providerIndex)
history.AttemptCount = totalAttempts
history.UpdatedAt = time.Now()
use.updateHistory(ctx, history)
// 記錄發送嘗試
attemptStart := time.Now()
// 創建帶超時的 context
sendCtx, cancel := context.WithTimeout(ctx, use.param.DeliveryConfig.Timeout)
err := provider.Repo.SendMail(sendCtx, repository.MailReq{
From: req.From,
To: req.To,
Subject: req.Subject,
Body: req.Body,
})
cancel()
// 記錄嘗試結果
attemptDuration := time.Since(attemptStart)
attemptRecord := entity.DeliveryAttempt{
Provider: history.Provider,
AttemptAt: attemptStart,
Success: err == nil,
ErrorMessage: "",
Duration: attemptDuration.Milliseconds(),
}
if err != nil {
attemptRecord.ErrorMessage = err.Error()
lastErr = err
logx.WithContext(ctx).Errorf("Email send attempt %d failed for provider %d: %v",
attempt+1, providerIndex, err)
// 如果不是最後一次嘗試,等待後重試
if attempt < use.param.DeliveryConfig.MaxRetries-1 {
delay := use.calculateDelay(attempt)
history.Status = entity.DeliveryStatusRetrying
use.updateHistory(ctx, history)
use.addAttemptRecord(ctx, history.ID, attemptRecord)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
continue
}
}
} else {
// 發送成功
history.Status = entity.DeliveryStatusSuccess
history.UpdatedAt = time.Now()
now := time.Now()
history.CompletedAt = &now
use.updateHistory(ctx, history)
use.addAttemptRecord(ctx, history.ID, attemptRecord)
logx.WithContext(ctx).Infof("Email sent successfully after %d attempts", totalAttempts)
return nil
}
use.addAttemptRecord(ctx, history.ID, attemptRecord)
}
}
// 所有 providers 都失敗了
history.Status = entity.DeliveryStatusFailed
history.ErrorMessage = fmt.Sprintf("All providers failed. Last error: %v", lastErr)
history.UpdatedAt = time.Now()
now := time.Now()
history.CompletedAt = &now
use.updateHistory(ctx, history)
return errs.ThirdPartyError(
code.CloudEPNotification,
domain.FailedToSendEmailErrorCode,
fmt.Sprintf("Failed to send email after %d attempts across %d providers",
totalAttempts, len(providers)))
} }
// calculateDelay 計算指數退避延遲 // calculateDelay 計算指數退避延遲

View File

@ -0,0 +1,377 @@
package usecase
import (
"backend/pkg/notification/config"
"backend/pkg/notification/domain/entity"
"backend/pkg/notification/domain/repository"
"backend/pkg/notification/domain/usecase"
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// mockSMSRepository 模擬 SMS Repository
type mockSMSRepository struct {
sendFunc func(ctx context.Context, req repository.SMSMessageRequest) error
}
func (m *mockSMSRepository) SendSMS(ctx context.Context, req repository.SMSMessageRequest) error {
if m.sendFunc != nil {
return m.sendFunc(ctx, req)
}
return nil
}
// mockMailRepository 模擬 Mail Repository
type mockMailRepository struct {
sendFunc func(ctx context.Context, req repository.MailReq) error
}
func (m *mockMailRepository) SendMail(ctx context.Context, req repository.MailReq) error {
if m.sendFunc != nil {
return m.sendFunc(ctx, req)
}
return nil
}
// mockHistoryRepository 模擬 History Repository
type mockHistoryRepository struct {
histories []entity.DeliveryHistory
attempts map[string][]entity.DeliveryAttempt
}
func (m *mockHistoryRepository) CreateHistory(ctx context.Context, history *entity.DeliveryHistory) error {
m.histories = append(m.histories, *history)
return nil
}
func (m *mockHistoryRepository) UpdateHistory(ctx context.Context, history *entity.DeliveryHistory) error {
for i := range m.histories {
if m.histories[i].ID == history.ID {
m.histories[i] = *history
return nil
}
}
return nil
}
func (m *mockHistoryRepository) GetHistory(ctx context.Context, id string) (*entity.DeliveryHistory, error) {
for i := range m.histories {
if m.histories[i].ID == id {
return &m.histories[i], nil
}
}
return nil, errors.New("not found")
}
func (m *mockHistoryRepository) AddAttempt(ctx context.Context, historyID string, attempt entity.DeliveryAttempt) error {
if m.attempts == nil {
m.attempts = make(map[string][]entity.DeliveryAttempt)
}
m.attempts[historyID] = append(m.attempts[historyID], attempt)
return nil
}
func (m *mockHistoryRepository) ListHistory(ctx context.Context, filter repository.HistoryFilter) ([]*entity.DeliveryHistory, error) {
var result []*entity.DeliveryHistory
for i := range m.histories {
result = append(result, &m.histories[i])
}
return result, nil
}
func TestDeliveryUseCase_SendEmail_Success(t *testing.T) {
mockMail := &mockMailRepository{
sendFunc: func(ctx context.Context, req repository.MailReq) error {
return nil // 成功
},
}
mockHistory := &mockHistoryRepository{}
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
EmailProviders: []usecase.EmailProvider{
{Sort: 1, Repo: mockMail},
},
DeliveryConfig: config.DeliveryConfig{
MaxRetries: 3,
InitialDelay: 10 * time.Millisecond,
BackoffFactor: 2.0,
MaxDelay: 100 * time.Millisecond,
Timeout: 5 * time.Second,
EnableHistory: true,
},
HistoryRepo: mockHistory,
})
ctx := context.Background()
err := uc.SendEmail(ctx, usecase.MailReq{
From: "test@example.com",
To: []string{"user@example.com"},
Subject: "Test",
Body: "<p>Test email</p>",
})
assert.NoError(t, err)
// 驗證歷史記錄
assert.Equal(t, 1, len(mockHistory.histories))
assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status)
}
func TestDeliveryUseCase_SendEmail_RetryAndSuccess(t *testing.T) {
attemptCount := 0
mockMail := &mockMailRepository{
sendFunc: func(ctx context.Context, req repository.MailReq) error {
attemptCount++
if attemptCount < 3 {
return errors.New("temporary error")
}
return nil // 第三次成功
},
}
mockHistory := &mockHistoryRepository{}
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
EmailProviders: []usecase.EmailProvider{
{Sort: 1, Repo: mockMail},
},
DeliveryConfig: config.DeliveryConfig{
MaxRetries: 3,
InitialDelay: 10 * time.Millisecond,
BackoffFactor: 2.0,
MaxDelay: 100 * time.Millisecond,
Timeout: 5 * time.Second,
EnableHistory: true,
},
HistoryRepo: mockHistory,
})
ctx := context.Background()
err := uc.SendEmail(ctx, usecase.MailReq{
From: "test@example.com",
To: []string{"user@example.com"},
Subject: "Test",
Body: "<p>Test</p>",
})
assert.NoError(t, err)
assert.Equal(t, 3, attemptCount) // 重試了 3 次
assert.Equal(t, 1, len(mockHistory.histories))
assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status)
assert.Equal(t, 3, mockHistory.histories[0].AttemptCount)
}
func TestDeliveryUseCase_SendEmail_AllRetries_Failed(t *testing.T) {
mockMail := &mockMailRepository{
sendFunc: func(ctx context.Context, req repository.MailReq) error {
return errors.New("persistent error")
},
}
mockHistory := &mockHistoryRepository{}
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
EmailProviders: []usecase.EmailProvider{
{Sort: 1, Repo: mockMail},
},
DeliveryConfig: config.DeliveryConfig{
MaxRetries: 3,
InitialDelay: 10 * time.Millisecond,
BackoffFactor: 2.0,
MaxDelay: 100 * time.Millisecond,
Timeout: 5 * time.Second,
EnableHistory: true,
},
HistoryRepo: mockHistory,
})
ctx := context.Background()
err := uc.SendEmail(ctx, usecase.MailReq{
From: "test@example.com",
To: []string{"user@example.com"},
Subject: "Test",
Body: "<p>Test</p>",
})
assert.Error(t, err)
assert.Contains(t, err.Error(), "Failed to send Email")
assert.Equal(t, 1, len(mockHistory.histories))
assert.Equal(t, entity.DeliveryStatusFailed, mockHistory.histories[0].Status)
assert.Equal(t, 3, mockHistory.histories[0].AttemptCount) // 嘗試了 3 次
}
func TestDeliveryUseCase_SendEmail_Failover(t *testing.T) {
mockMail1 := &mockMailRepository{
sendFunc: func(ctx context.Context, req repository.MailReq) error {
return errors.New("provider 1 failed")
},
}
mockMail2 := &mockMailRepository{
sendFunc: func(ctx context.Context, req repository.MailReq) error {
return nil // 備援成功
},
}
mockHistory := &mockHistoryRepository{}
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
EmailProviders: []usecase.EmailProvider{
{Sort: 1, Repo: mockMail1}, // 主要供應商
{Sort: 2, Repo: mockMail2}, // 備援供應商
},
DeliveryConfig: config.DeliveryConfig{
MaxRetries: 2,
InitialDelay: 10 * time.Millisecond,
BackoffFactor: 2.0,
MaxDelay: 100 * time.Millisecond,
Timeout: 5 * time.Second,
EnableHistory: true,
},
HistoryRepo: mockHistory,
})
ctx := context.Background()
err := uc.SendEmail(ctx, usecase.MailReq{
From: "test@example.com",
To: []string{"user@example.com"},
Subject: "Test",
Body: "<p>Test</p>",
})
assert.NoError(t, err)
// 驗證使用了備援供應商
assert.Equal(t, 1, len(mockHistory.histories))
assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status)
// 總共嘗試次數provider1 重試 2 次 + provider2 成功 1 次 = 3 次
assert.Equal(t, 3, mockHistory.histories[0].AttemptCount)
}
func TestDeliveryUseCase_SendSMS_Success(t *testing.T) {
mockSMS := &mockSMSRepository{
sendFunc: func(ctx context.Context, req repository.SMSMessageRequest) error {
return nil
},
}
mockHistory := &mockHistoryRepository{}
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
SMSProviders: []usecase.SMSProvider{
{Sort: 1, Repo: mockSMS},
},
DeliveryConfig: config.DeliveryConfig{
MaxRetries: 3,
InitialDelay: 10 * time.Millisecond,
BackoffFactor: 2.0,
MaxDelay: 100 * time.Millisecond,
Timeout: 5 * time.Second,
EnableHistory: true,
},
HistoryRepo: mockHistory,
})
ctx := context.Background()
err := uc.SendMessage(ctx, usecase.SMSMessageRequest{
PhoneNumber: "+886912345678",
RecipientName: "Test User",
MessageContent: "Your code: 123456",
})
assert.NoError(t, err)
assert.Equal(t, 1, len(mockHistory.histories))
assert.Equal(t, entity.DeliveryStatusSuccess, mockHistory.histories[0].Status)
}
func TestDeliveryUseCase_CalculateDelay(t *testing.T) {
uc := &DeliveryUseCase{
param: DeliveryUseCaseParam{
DeliveryConfig: config.DeliveryConfig{
InitialDelay: 100 * time.Millisecond,
BackoffFactor: 2.0,
MaxDelay: 1 * time.Second,
},
},
}
tests := []struct {
name string
attempt int
expected time.Duration
}{
{
name: "第 0 次重試",
attempt: 0,
expected: 100 * time.Millisecond,
},
{
name: "第 1 次重試",
attempt: 1,
expected: 200 * time.Millisecond,
},
{
name: "第 2 次重試",
attempt: 2,
expected: 400 * time.Millisecond,
},
{
name: "第 3 次重試",
attempt: 3,
expected: 800 * time.Millisecond,
},
{
name: "第 10 次重試(達到 MaxDelay",
attempt: 10,
expected: 1 * time.Second, // 受限於 MaxDelay
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
delay := uc.calculateDelay(tt.attempt)
assert.Equal(t, tt.expected, delay)
})
}
}
func TestDeliveryUseCase_ContextCancellation(t *testing.T) {
mockMail := &mockMailRepository{
sendFunc: func(ctx context.Context, req repository.MailReq) error {
// 模擬慢速操作
time.Sleep(100 * time.Millisecond)
return errors.New("should not reach here")
},
}
uc := MustDeliveryUseCase(DeliveryUseCaseParam{
EmailProviders: []usecase.EmailProvider{
{Sort: 1, Repo: mockMail},
},
DeliveryConfig: config.DeliveryConfig{
MaxRetries: 3,
InitialDelay: 50 * time.Millisecond,
BackoffFactor: 2.0,
MaxDelay: 500 * time.Millisecond,
Timeout: 1 * time.Second,
EnableHistory: false,
},
})
// 創建會被取消的 context
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
err := uc.SendEmail(ctx, usecase.MailReq{
From: "test@example.com",
To: []string{"user@example.com"},
Subject: "Test",
Body: "<p>Test</p>",
})
assert.Error(t, err)
assert.Equal(t, context.DeadlineExceeded, err)
}

View File

@ -0,0 +1,255 @@
package usecase
import (
"backend/pkg/notification/domain/entity"
"backend/pkg/notification/domain/template"
"backend/pkg/notification/domain/usecase"
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func TestTemplateUseCase_RenderEmailTemplate(t *testing.T) {
uc := MustTemplateUseCase(TemplateUseCaseParam{
TemplateRepo: nil,
})
ctx := context.Background()
tests := []struct {
name string
tmpl template.EmailTemplate
params entity.TemplateParams
expectedSubj string
expectedBody string
shouldContain []string
shouldNotError bool
}{
{
name: "渲染基本參數",
tmpl: template.EmailTemplate{
Title: "Hello {{.Username}}",
Body: "<p>Your code is: {{.VerifyCode}}</p>",
},
params: entity.TemplateParams{
Username: "張三",
VerifyCode: "123456",
},
expectedSubj: "Hello 張三",
shouldContain: []string{"123456"},
shouldNotError: true,
},
{
name: "渲染額外參數",
tmpl: template.EmailTemplate{
Title: "Welcome",
Body: "<p>Hello {{.Username}}, your link: {{.Link}}</p>",
},
params: entity.TemplateParams{
Username: "John",
Extra: map[string]string{
"Link": "https://example.com",
},
},
shouldContain: []string{"John", "https://example.com"},
shouldNotError: true,
},
{
name: "特殊字符不轉義(簡單字符串替換)",
tmpl: template.EmailTemplate{
Title: "Test",
Body: "<p>Name: {{.Username}}</p>",
},
params: entity.TemplateParams{
Username: "<script>alert('xss')</script>",
},
shouldContain: []string{"<script>alert('xss')</script>"}, // 使用簡單字符串替換,不轉義
shouldNotError: true,
},
{
name: "空模板",
tmpl: template.EmailTemplate{
Title: "",
Body: "",
},
params: entity.TemplateParams{
Username: "Test",
},
expectedSubj: "",
expectedBody: "",
shouldNotError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := uc.RenderEmailTemplate(ctx, tt.tmpl, tt.params)
if tt.shouldNotError {
assert.NoError(t, err)
if tt.expectedSubj != "" {
assert.Equal(t, tt.expectedSubj, result.Subject)
}
if tt.expectedBody != "" {
assert.Equal(t, tt.expectedBody, result.Body)
}
for _, contain := range tt.shouldContain {
assert.Contains(t, result.Body, contain)
}
} else {
assert.Error(t, err)
}
})
}
}
func TestTemplateUseCase_RenderSMSTemplate(t *testing.T) {
uc := MustTemplateUseCase(TemplateUseCaseParam{
TemplateRepo: nil,
})
ctx := context.Background()
tests := []struct {
name string
tmpl usecase.SMSTemplateResp
params entity.TemplateParams
expectedBody string
shouldContain []string
shouldNotError bool
}{
{
name: "渲染 SMS 驗證碼",
tmpl: usecase.SMSTemplateResp{
Body: "您的驗證碼是:{{.VerifyCode}}請在5分鐘內使用。",
},
params: entity.TemplateParams{
VerifyCode: "654321",
},
shouldContain: []string{"654321", "5分鐘"},
shouldNotError: true,
},
{
name: "SMS 純文本替換",
tmpl: usecase.SMSTemplateResp{
Body: "Hi {{.Username}}, your code: {{.VerifyCode}}",
},
params: entity.TemplateParams{
Username: "<test>",
VerifyCode: "111111",
},
shouldContain: []string{"<test>", "111111"}, // 使用簡單字符串替換
shouldNotError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := uc.RenderSMSTemplate(ctx, tt.tmpl, tt.params)
if tt.shouldNotError {
assert.NoError(t, err)
if tt.expectedBody != "" {
assert.Equal(t, tt.expectedBody, result.Body)
}
for _, contain := range tt.shouldContain {
assert.Contains(t, result.Body, contain)
}
} else {
assert.Error(t, err)
}
})
}
}
func TestTemplateUseCase_GetEmailTemplateByStatic(t *testing.T) {
uc := MustTemplateUseCase(TemplateUseCaseParam{
TemplateRepo: nil,
})
ctx := context.Background()
tests := []struct {
name string
language template.Language
templateID template.Type
wantErr bool
}{
{
name: "獲取忘記密碼模板 (zh-tw)",
language: template.LanguageZhTW,
templateID: template.ForgetPasswordVerify,
wantErr: false,
},
{
name: "獲取綁定郵箱模板 (zh-tw)",
language: template.LanguageZhTW,
templateID: template.BindingEmail,
wantErr: false,
},
{
name: "不存在的語言",
language: template.Language("xx-xx"),
templateID: template.ForgetPasswordVerify,
wantErr: true,
},
{
name: "不存在的模板類型",
language: template.LanguageZhTW,
templateID: template.Type("non_existent"),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := uc.GetEmailTemplateByStatic(ctx, tt.language, tt.templateID)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, result.Title)
assert.NotEmpty(t, result.Body)
}
})
}
}
func TestTemplateUseCase_GetDefaultSMSTemplate(t *testing.T) {
uc := &TemplateUseCase{}
tests := []struct {
name string
templateID template.Type
shouldContain []string
}{
{
name: "忘記密碼模板",
templateID: template.ForgetPasswordVerify,
shouldContain: []string{"密碼重設", "驗證碼"},
},
{
name: "綁定郵箱模板",
templateID: template.BindingEmail,
shouldContain: []string{"綁定", "驗證碼"},
},
{
name: "默認模板",
templateID: template.Type("unknown"),
shouldContain: []string{"驗證碼"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := uc.getDefaultSMSTemplate(tt.templateID)
assert.NotEmpty(t, result.Body)
for _, contain := range tt.shouldContain {
assert.Contains(t, result.Body, contain)
}
})
}
}

View File

@ -40,7 +40,7 @@ func DefaultConfig() Config {
UIDLength: 6, UIDLength: 6,
AdminRoleUID: "AM000000", AdminRoleUID: "AM000000",
AdminUserUID: "B000000", AdminUserUID: "B000000",
DefaultRoleName: "user", DefaultRoleName: "USER",
}, },
} }
} }

View File

@ -0,0 +1,5 @@
package permission
const (
DefaultRole = "user"
)

View File

@ -0,0 +1,50 @@
package token
import (
"context"
)
type ContextKey string
func (c ContextKey) String() string {
return string(c)
}
const (
KeyRole ContextKey = "role"
KeyDeviceID ContextKey = "device_id"
KeyScope ContextKey = "scope"
KeyUID ContextKey = "uid"
KeyLoginID ContextKey = "login_id"
)
func UID(ctx context.Context) string { return getString(ctx, KeyUID) }
func Scope(ctx context.Context) string { return getString(ctx, KeyScope) }
func Role(ctx context.Context) string { return getString(ctx, KeyRole) }
func DeviceID(ctx context.Context) string { return getString(ctx, KeyDeviceID) }
func LoginID(ctx context.Context) string { return getString(ctx, KeyLoginID) }
func WithUID(ctx context.Context, uid string) context.Context {
return context.WithValue(ctx, KeyUID, uid)
}
func WithScope(ctx context.Context, scope string) context.Context {
return context.WithValue(ctx, KeyScope, scope)
}
func WithRole(ctx context.Context, role string) context.Context {
return context.WithValue(ctx, KeyRole, role)
}
func WithDeviceID(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, KeyDeviceID, id)
}
func WithLoginID(ctx context.Context, login string) context.Context {
return context.WithValue(ctx, KeyLoginID, login)
}
// --- Internal helper ---
func getString(ctx context.Context, key ContextKey) string {
if v, ok := ctx.Value(key).(string); ok {
return v
}
return ""
}

View File

@ -26,7 +26,7 @@ type PermissionRepository struct {
DB mongo.DocumentDBWithCacheUseCase DB mongo.DocumentDBWithCacheUseCase
} }
func NewAccountRepository(param PermissionRepositoryParam) repository.PermissionRepository { func NewPermissionRepository(param PermissionRepositoryParam) repository.PermissionRepository {
e := entity.Permission{} e := entity.Permission{}
documentDB, err := mongo.MustDocumentDBWithCache( documentDB, err := mongo.MustDocumentDBWithCache(
param.Conf, param.Conf,

View File

@ -59,7 +59,7 @@ func setupPermissionRepo(db string) (domainRepo.PermissionRepository, func(), er
CacheConf: cacheConf, CacheConf: cacheConf,
CacheOpts: cacheOpts, CacheOpts: cacheOpts,
} }
repo := NewAccountRepository(param) repo := NewPermissionRepository(param)
_, _ = repo.Index20251009001UP(context.Background()) _, _ = repo.Index20251009001UP(context.Background())
return repo, tearDown, nil return repo, tearDown, nil

View File

@ -29,11 +29,11 @@ type TokenUseCase struct {
} }
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) { func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
claims, err := parseClaims(token, use.Config.Token.AccessSecret, false) claims, err := ParseClaims(token, use.Config.Token.AccessSecret, false)
if err != nil { if err != nil {
return nil, return nil,
use.wrapTokenError(ctx, wrapTokenErrorReq{ use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "parseClaims", funcName: "ParseClaims",
req: token, req: token,
err: err, err: err,
message: "validate token claims error", message: "validate token claims error",
@ -107,7 +107,7 @@ func (use *TokenUseCase) newToken(ctx context.Context, req *entity.Authorization
RefreshCreateAt: now, RefreshCreateAt: now,
} }
tc := make(tokenClaims) tc := make(TokenClaims)
if req.Data != nil { if req.Data != nil {
for k, v := range req.Data { for k, v := range req.Data {
tc[k] = v tc[k] = v
@ -116,7 +116,7 @@ func (use *TokenUseCase) newToken(ctx context.Context, req *entity.Authorization
tc.SetRole(req.Role) tc.SetRole(req.Role)
tc.SetID(token.ID) tc.SetID(token.ID)
tc.SetScope(req.Scope) tc.SetScope(req.Scope)
tc.SetAccount(req.Account) tc.SetLoginID(req.Account)
token.UID = tc.UID() token.UID = tc.UID()
@ -158,7 +158,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok
} }
// Step 2: 提取 Claims Data // Step 2: 提取 Claims Data
claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false) claimsData, err := ParseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
if err != nil { if err != nil {
return entity.RefreshTokenResp{}, return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{ use.wrapTokenError(ctx, wrapTokenErrorReq{
@ -179,7 +179,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok
Data: claimsData, Data: claimsData,
Expires: req.Expires, Expires: req.Expires,
IsRefreshToken: true, IsRefreshToken: true,
Account: claimsData.Account(), Account: claimsData.LoginID(),
Role: claimsData.Role(), Role: claimsData.Role(),
}) })
if err != nil { if err != nil {
@ -226,7 +226,7 @@ func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTok
} }
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error { func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false) claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
if err != nil { if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{ return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "CancelToken extractClaims", funcName: "CancelToken extractClaims",
@ -263,11 +263,11 @@ func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelToken
} }
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) { func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, true) claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, true)
if err != nil { if err != nil {
return entity.ValidationTokenResp{}, return entity.ValidationTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{ use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "parseClaims", funcName: "ParseClaims",
req: req, req: req,
err: err, err: err,
message: "validate token claims error", message: "validate token claims error",
@ -400,11 +400,11 @@ func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.Quer
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) { func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
// 驗證Token // 驗證Token
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false) claims, err := ParseClaims(req.Token, use.Config.Token.AccessSecret, false)
if err != nil { if err != nil {
return entity.CreateOneTimeTokenResp{}, return entity.CreateOneTimeTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{ use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "parseClaims", funcName: "ParseClaims",
req: req, req: req,
err: err, err: err,
message: "failed to get token claims", message: "failed to get token claims",
@ -637,7 +637,7 @@ func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string,
// 為每個 token 創建黑名單條目 // 為每個 token 創建黑名單條目
for _, token := range tokens { for _, token := range tokens {
// 解析 token 獲取 JTI 和過期時間 // 解析 token 獲取 JTI 和過期時間
claims, err := parseClaims(token.AccessToken, use.Config.Token.AccessSecret, false) claims, err := ParseClaims(token.AccessToken, use.Config.Token.AccessSecret, false)
if err != nil { if err != nil {
logx.WithContext(ctx).Errorw("failed to parse token for blacklisting", logx.WithContext(ctx).Errorw("failed to parse token for blacklisting",
logx.Field("uid", uid), logx.Field("uid", uid),

View File

@ -1,28 +1,28 @@
package usecase package usecase
type tokenClaims map[string]string type TokenClaims map[string]string
func (tc tokenClaims) SetID(id string) { func (tc TokenClaims) SetID(id string) {
tc["id"] = id tc["id"] = id
} }
func (tc tokenClaims) SetRole(role string) { func (tc TokenClaims) SetRole(role string) {
tc["role"] = role tc["role"] = role
} }
func (tc tokenClaims) SetDeviceID(deviceID string) { func (tc TokenClaims) SetDeviceID(deviceID string) {
tc["device_id"] = deviceID tc["device_id"] = deviceID
} }
func (tc tokenClaims) SetScope(scope string) { func (tc TokenClaims) SetScope(scope string) {
tc["scope"] = scope tc["scope"] = scope
} }
func (tc tokenClaims) SetAccount(account string) { func (tc TokenClaims) SetLoginID(loginID string) {
tc["account"] = account tc["login_id"] = loginID
} }
func (tc tokenClaims) Role() string { func (tc TokenClaims) Role() string {
role, ok := tc["role"] role, ok := tc["role"]
if !ok { if !ok {
return "" return ""
@ -31,7 +31,7 @@ func (tc tokenClaims) Role() string {
return role return role
} }
func (tc tokenClaims) ID() string { func (tc TokenClaims) ID() string {
id, ok := tc["id"] id, ok := tc["id"]
if !ok { if !ok {
return "" return ""
@ -40,7 +40,7 @@ func (tc tokenClaims) ID() string {
return id return id
} }
func (tc tokenClaims) DeviceID() string { func (tc TokenClaims) DeviceID() string {
deviceID, ok := tc["device_id"] deviceID, ok := tc["device_id"]
if !ok { if !ok {
return "" return ""
@ -49,7 +49,7 @@ func (tc tokenClaims) DeviceID() string {
return deviceID return deviceID
} }
func (tc tokenClaims) UID() string { func (tc TokenClaims) UID() string {
uid, ok := tc["uid"] uid, ok := tc["uid"]
if !ok { if !ok {
return "" return ""
@ -58,7 +58,7 @@ func (tc tokenClaims) UID() string {
return uid return uid
} }
func (tc tokenClaims) Scope() string { func (tc TokenClaims) Scope() string {
scope, ok := tc["scope"] scope, ok := tc["scope"]
if !ok { if !ok {
return "" return ""
@ -67,8 +67,8 @@ func (tc tokenClaims) Scope() string {
return scope return scope
} }
func (tc tokenClaims) Account() string { func (tc TokenClaims) LoginID() string {
scope, ok := tc["account"] scope, ok := tc["login_id"]
if !ok { if !ok {
return "" return ""
} }

View File

@ -27,7 +27,7 @@ func TestTokenClaims_SetAndGetID(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
tc.SetID(tt.id) tc.SetID(tt.id)
result := tc.ID() result := tc.ID()
@ -61,7 +61,7 @@ func TestTokenClaims_SetAndGetRole(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
tc.SetRole(tt.role) tc.SetRole(tt.role)
result := tc.Role() result := tc.Role()
@ -91,7 +91,7 @@ func TestTokenClaims_SetAndGetDeviceID(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
tc.SetDeviceID(tt.deviceID) tc.SetDeviceID(tt.deviceID)
result := tc.DeviceID() result := tc.DeviceID()
@ -125,7 +125,7 @@ func TestTokenClaims_SetAndGetScope(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
tc.SetScope(tt.scope) tc.SetScope(tt.scope)
// Note: there's no GetScope method, so we just verify it's set // Note: there's no GetScope method, so we just verify it's set
@ -159,7 +159,7 @@ func TestTokenClaims_SetAndGetAccount(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
tc.SetAccount(tt.account) tc.SetAccount(tt.account)
// Note: there's no GetAccount method, so we just verify it's set // Note: there's no GetAccount method, so we just verify it's set
@ -189,7 +189,7 @@ func TestTokenClaims_SetAndGetUID(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
tc["uid"] = tt.uid tc["uid"] = tt.uid
result := tc.UID() result := tc.UID()
@ -199,7 +199,7 @@ func TestTokenClaims_SetAndGetUID(t *testing.T) {
} }
func TestTokenClaims_GetNonExistentField(t *testing.T) { func TestTokenClaims_GetNonExistentField(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
t.Run("get non-existent ID", func(t *testing.T) { t.Run("get non-existent ID", func(t *testing.T) {
result := tc.ID() result := tc.ID()
@ -223,7 +223,7 @@ func TestTokenClaims_GetNonExistentField(t *testing.T) {
} }
func TestTokenClaims_MultipleFields(t *testing.T) { func TestTokenClaims_MultipleFields(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
tc.SetID("token123") tc.SetID("token123")
tc.SetRole("admin") tc.SetRole("admin")
@ -243,7 +243,7 @@ func TestTokenClaims_MultipleFields(t *testing.T) {
} }
func TestTokenClaims_Overwrite(t *testing.T) { func TestTokenClaims_Overwrite(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
t.Run("overwrite ID", func(t *testing.T) { t.Run("overwrite ID", func(t *testing.T) {
tc.SetID("token123") tc.SetID("token123")
@ -263,7 +263,7 @@ func TestTokenClaims_Overwrite(t *testing.T) {
} }
func TestTokenClaims_MapBehavior(t *testing.T) { func TestTokenClaims_MapBehavior(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
t.Run("can set custom fields", func(t *testing.T) { t.Run("can set custom fields", func(t *testing.T) {
tc["custom_field"] = "custom_value" tc["custom_field"] = "custom_value"
@ -271,7 +271,7 @@ func TestTokenClaims_MapBehavior(t *testing.T) {
}) })
t.Run("can iterate over fields", func(t *testing.T) { t.Run("can iterate over fields", func(t *testing.T) {
tc2 := make(tokenClaims) tc2 := make(TokenClaims)
tc2.SetID("token123") tc2.SetID("token123")
tc2.SetRole("admin") tc2.SetRole("admin")
tc2["uid"] = "user123" tc2["uid"] = "user123"
@ -303,7 +303,7 @@ func TestTokenClaims_MapBehavior(t *testing.T) {
} }
func TestTokenClaims_EmptyMap(t *testing.T) { func TestTokenClaims_EmptyMap(t *testing.T) {
tc := make(tokenClaims) tc := make(TokenClaims)
assert.Empty(t, tc.ID()) assert.Empty(t, tc.ID())
assert.Empty(t, tc.Role()) assert.Empty(t, tc.Role())
@ -313,7 +313,7 @@ func TestTokenClaims_EmptyMap(t *testing.T) {
} }
func TestTokenClaims_NilMap(t *testing.T) { func TestTokenClaims_NilMap(t *testing.T) {
var tc tokenClaims var tc TokenClaims
t.Run("get from nil map", func(t *testing.T) { t.Run("get from nil map", func(t *testing.T) {
assert.Empty(t, tc.ID()) assert.Empty(t, tc.ID())
@ -322,4 +322,3 @@ func TestTokenClaims_NilMap(t *testing.T) {
assert.Empty(t, tc.UID()) assert.Empty(t, tc.UID())
}) })
} }

View File

@ -21,6 +21,7 @@ func createAccessToken(token entity.Token, data any, secretKey string) (string,
RegisteredClaims: jwt.RegisteredClaims{ RegisteredClaims: jwt.RegisteredClaims{
ID: token.ID, ID: token.ID,
ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)), ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: "permission", Issuer: "permission",
}, },
} }
@ -76,10 +77,10 @@ func parseToken(accessToken string, secret string, validate bool) (jwt.MapClaims
return claims, nil return claims, nil
} }
func parseClaims(accessToken string, secret string, validate bool) (tokenClaims, error) { func ParseClaims(accessToken string, secret string, validate bool) (TokenClaims, error) {
claimMap, err := parseToken(accessToken, secret, validate) claimMap, err := parseToken(accessToken, secret, validate)
if err != nil { if err != nil {
return tokenClaims{}, err return TokenClaims{}, err
} }
claimsData, ok := claimMap["data"].(map[string]any) claimsData, ok := claimMap["data"].(map[string]any)
@ -87,7 +88,7 @@ func parseClaims(accessToken string, secret string, validate bool) (tokenClaims,
return convertMap(claimsData), nil return convertMap(claimsData), nil
} }
return tokenClaims{}, fmt.Errorf("get data from claim map error") return TokenClaims{}, fmt.Errorf("get data from claim map error")
} }
func convertMap(input map[string]interface{}) map[string]string { func convertMap(input map[string]interface{}) map[string]string {

View File

@ -58,7 +58,7 @@ func TestCreateAccessToken(t *testing.T) {
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) { token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
return []byte(tt.secretKey), nil return []byte(tt.secretKey), nil
}) })
if tt.secretKey != "" { if tt.secretKey != "" {
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, token.Valid) assert.True(t, token.Valid)
@ -125,7 +125,7 @@ func TestCreateRefreshToken(t *testing.T) {
func TestParseToken(t *testing.T) { func TestParseToken(t *testing.T) {
secretKey := "test-secret-key" secretKey := "test-secret-key"
// Create a valid token first // Create a valid token first
token := entity.Token{ token := entity.Token{
ID: "test-id", ID: "test-id",
@ -135,7 +135,7 @@ func TestParseToken(t *testing.T) {
"uid": "user123", "uid": "user123",
"role": "admin", "role": "admin",
} }
validTokenStr, err := createAccessToken(token, data, secretKey) validTokenStr, err := createAccessToken(token, data, secretKey)
assert.NoError(t, err) assert.NoError(t, err)
@ -192,7 +192,7 @@ func TestParseToken(t *testing.T) {
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, claims) assert.NotNil(t, claims)
if tt.accessToken == validTokenStr { if tt.accessToken == validTokenStr {
assert.Equal(t, "test-id", claims["jti"]) assert.Equal(t, "test-id", claims["jti"])
assert.Equal(t, "permission", claims["iss"]) assert.Equal(t, "permission", claims["iss"])
@ -204,7 +204,7 @@ func TestParseToken(t *testing.T) {
func TestParseClaims(t *testing.T) { func TestParseClaims(t *testing.T) {
secretKey := "test-secret-key" secretKey := "test-secret-key"
// Create a valid token with data claims // Create a valid token with data claims
token := entity.Token{ token := entity.Token{
ID: "test-id", ID: "test-id",
@ -215,7 +215,7 @@ func TestParseClaims(t *testing.T) {
"role": "admin", "role": "admin",
"deviceId": "device456", "deviceId": "device456",
} }
validTokenStr, err := createAccessToken(token, data, secretKey) validTokenStr, err := createAccessToken(token, data, secretKey)
assert.NoError(t, err) assert.NoError(t, err)
@ -248,20 +248,20 @@ func TestParseClaims(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
claims, err := parseClaims(tt.accessToken, tt.secret, tt.validate) claims, err := ParseClaims(tt.accessToken, tt.secret, tt.validate)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, claims) assert.NotNil(t, claims)
if tt.expectUID != "" { if tt.expectUID != "" {
uid, exists := claims["uid"] uid, exists := claims["uid"]
assert.True(t, exists) assert.True(t, exists)
assert.Equal(t, tt.expectUID, uid) assert.Equal(t, tt.expectUID, uid)
} }
if tt.expectRole != "" { if tt.expectRole != "" {
role, exists := claims["role"] role, exists := claims["role"]
assert.True(t, exists) assert.True(t, exists)