feat: add token service

This commit is contained in:
王性驊 2025-10-06 16:28:39 +08:00
parent 31ab87aadc
commit 40812db5bf
68 changed files with 5957 additions and 3465 deletions

View File

@ -41,4 +41,13 @@ GoogleAuth:
LineAuth:
ClientID : "200000000"
ClientSecret : xxxxx
RedirectURI : http://localhost:8080/line.html
RedirectURI : http://localhost:8080/line.html
Token:
AccessSecret : "1qaz@WSX3edc$RFV"
RefreshSecret : "1qaz@WSX3edc$RFV"
AccessTokenExpiry : 600s
RefreshTokenExpiry : 86400s
OneTimeTokenExpiry : 600s
MaxTokensPerUser : 2
MaxTokensPerDevice : 2

View File

@ -1,14 +0,0 @@
[request_definition]
r = sub, obj, act
[policy_definition]
p = sub, obj, act
[role_definition]
g = _, _
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = g(r.sub, p.sub) && keyMatch2(r.obj, p.obj) && regexMatch(r.act, p.act)

9
go.mod
View File

@ -10,12 +10,12 @@ require (
github.com/aws/aws-sdk-go-v2 v1.39.2
github.com/aws/aws-sdk-go-v2/credentials v1.18.16
github.com/aws/aws-sdk-go-v2/service/ses v1.34.5
github.com/casbin/casbin/v2 v2.127.0
github.com/go-playground/validator/v10 v10.27.0
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/golang-jwt/jwt/v4 v4.5.2
github.com/matcornic/hermes/v2 v2.1.0
github.com/minchao/go-mitake v1.0.0
github.com/panjf2000/ants/v2 v2.11.3
github.com/segmentio/ksuid v1.0.4
github.com/shopspring/decimal v1.4.0
github.com/stretchr/testify v1.11.1
github.com/testcontainers/testcontainers-go v0.39.0
@ -42,8 +42,6 @@ require (
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9 // indirect
github.com/aws/smithy-go v1.23.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect
github.com/casbin/govaluate v1.3.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/errdefs v1.0.0 // indirect
@ -67,7 +65,6 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
github.com/golang/snappy v1.0.0 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/gorilla/css v1.0.0 // indirect
@ -108,12 +105,12 @@ require (
github.com/redis/go-redis/v9 v9.15.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/russross/blackfriday/v2 v2.0.1 // indirect
github.com/shirou/gopsutil/v3 v3.24.5 // indirect
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/sirupsen/logrus v1.9.3 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tklauser/go-sysconf v0.3.12 // indirect
github.com/tklauser/numcpus v0.6.1 // indirect
github.com/vanng822/css v0.0.0-20190504095207-a21e860bcd04 // indirect

21
go.sum
View File

@ -34,16 +34,10 @@ github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE=
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I=
github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
github.com/casbin/casbin/v2 v2.127.0 h1:UGK3uO/8cOslnNqFUJ4xzm/bh+N+o45U7cSolaFk38c=
github.com/casbin/casbin/v2 v2.127.0/go.mod h1:n4uZK8+tCMvcD6EVQZI90zKAok8iHAvEypcMJVKhGF0=
github.com/casbin/govaluate v1.3.0 h1:VA0eSY0M2lA86dYd5kPPuNZMUD9QkWnOCnavGrw9myc=
github.com/casbin/govaluate v1.3.0/go.mod h1:G/UnbIjZk/0uMNaLwZZmFQrR72tYRZWQkO70si/iR7A=
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
@ -101,17 +95,11 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc=
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@ -220,12 +208,10 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c=
github.com/segmentio/ksuid v1.0.4/go.mod h1:/XUiZBD3kVx5SmUOl55voK5yeAbBNNIed+2O73XgrPE=
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
@ -326,7 +312,6 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -357,7 +342,6 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
@ -374,7 +358,6 @@ golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=

View File

@ -49,4 +49,15 @@ type Config struct {
ClientSecret string
RedirectURI string
}
// JWT Token 配置
Token struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}
}

View File

@ -7,6 +7,8 @@ import (
"backend/pkg/library/errs/code"
mb "backend/pkg/member/domain/member"
member "backend/pkg/member/domain/usecase"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/token"
"context"
"time"
@ -27,7 +29,7 @@ var PrepareFunc map[string]func(ctx context.Context, req *types.LoginReq, svc *s
mb.Line.ToString(): buildLineData,
}
// 註冊新帳號
// NewRegisterLogic 註冊新帳號
func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RegisterLogic {
return &RegisterLogic{
Logger: logx.WithContext(ctx),
@ -80,16 +82,16 @@ func (l *RegisterLogic) Register(req *types.LoginReq) (resp *types.LoginResp, er
// Step 5: 生成 Token
req.LoginID = bd.CreateAccountReq.LoginID
token, err := l.generateToken(req, account.UID)
tk, err := l.generateToken(req, account.UID)
if err != nil {
return nil, err
}
return &types.LoginResp{
UID: account.UID,
AccessToken: token.AccessToken,
RefreshToken: token.RefreshToken,
TokenType: token.TokenType,
AccessToken: tk.AccessToken,
RefreshToken: tk.RefreshToken,
TokenType: tk.TokenType,
}, nil
}
@ -183,40 +185,33 @@ func buildLineData(ctx context.Context, req *types.LoginReq, svc *svc.ServiceCon
}, nil
}
type MockToken struct {
AccessToken string `json:"access_token"` // 訪問令牌
TokenType string `json:"token_type"` // 令牌類型
ExpiresIn int64 `json:"expires_in"` // 過期時間(秒)
RefreshToken string `json:"refresh_token"` // 刷新令牌
}
// 生成 Token
func (l *RegisterLogic) generateToken(req *types.LoginReq, uid string) (MockToken, error) {
//credentials := tokenModule.ClientCredentials
//role := "user"
//if isTruHeartEmail(req.Account) {
// role = "admin"
//}
//
//return l.svcCtx.TokenUseCase.NewToken(l.ctx, tokenModule.AuthorizationReq{
// GrantType: credentials.ToString(),
// DeviceID: req.DeviceID,
// Scope: domain.DefaultScope,
// IsRefreshToken: true,
// Expires: time.Now().UTC().Add(l.svcCtx.Config.Token.Expired).Unix(),
// Data: map[string]string{
// "uid": uid,
// "role": role,
// "account": req.Account,
// },
// Role: role,
//})
func (l *RegisterLogic) generateToken(req *types.LoginReq, uid string) (entity.TokenResp, error) {
// scope role 要修改refresh tl
role := "user"
return MockToken{
AccessToken: "gg88g88",
TokenType: "Bearer",
ExpiresIn: time.Now().UTC().Add(100000000000).Unix(),
RefreshToken: "gg88g88",
tk, err := l.svcCtx.TokenUC.NewToken(l.ctx, entity.AuthorizationReq{
GrantType: token.ClientCredentials.ToString(),
DeviceID: uid, // TODO 沒傳暫時先用UID 替代
Scope: "gateway",
IsRefreshToken: true,
Expires: time.Now().UTC().Add(l.svcCtx.Config.Token.AccessTokenExpiry).Unix(),
Data: map[string]string{
"uid": uid,
"role": role,
},
Role: role,
Account: req.LoginID,
})
if err != nil {
return entity.TokenResp{}, err
}
return entity.TokenResp{
AccessToken: tk.AccessToken,
TokenType: tk.TokenType,
ExpiresIn: tk.ExpiresIn,
RefreshToken: tk.RefreshToken,
}, nil
}

View File

@ -1,194 +0,0 @@
package svc
//func NewPermissionUC(c *config.Config, rds *redis.Redis) usecase.PermissionUseCase {
// // 準備Mongo Config (重用現有配置)
// conf := &mgo.Conf{
// Schema: c.Mongo.Schema,
// Host: c.Mongo.Host,
// Database: c.Mongo.Database,
// MaxStaleness: c.Mongo.MaxStaleness,
// MaxPoolSize: c.Mongo.MaxPoolSize,
// MinPoolSize: c.Mongo.MinPoolSize,
// MaxConnIdleTime: c.Mongo.MaxConnIdleTime,
// Compressors: c.Mongo.Compressors,
// EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode,
// ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs,
// }
// if c.Mongo.User != "" {
// conf.User = c.Mongo.User
// conf.Password = c.Mongo.Password
// }
//
// // 快取選項
// cacheOpts := []cache.Option{
// cache.WithExpiry(c.CacheExpireTime),
// cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry),
// }
// dbOpts := []mon.Option{
// mgo.SetCustomDecimalType(),
// mgo.InitMongoOptions(*conf),
// }
//
// // 初始化 Casbin Adapter
// casbinAdapter := repository.NewCasbinAdapter(repository.CasbinAdapterParam{
// Conf: conf,
// CacheConf: c.Cache,
// CacheOpts: cacheOpts,
// DBOpts: dbOpts,
// })
//
// // 初始化 Casbin Enforcer
// modelPath := "pkg/permission/config/rbac_model.conf"
// enforcer, err := casbin.NewEnforcer(modelPath, casbinAdapter)
// if err != nil {
// panic("Failed to create casbin enforcer: " + err.Error())
// }
//
// // 啟用自動保存
// enforcer.EnableAutoSave(true)
//
// // 載入策略
// err = enforcer.LoadPolicy()
// if err != nil {
// panic("Failed to load casbin policy: " + err.Error())
// }
//
// // 初始化其他 Repository
// permissionRepo := repository.NewPermissionRepository(repository.PermissionRepositoryParam{
// Conf: conf,
// CacheConf: c.Cache,
// CacheOpts: cacheOpts,
// DBOpts: dbOpts,
// })
//
// roleRepo := repository.NewRoleRepository(repository.RoleRepositoryParam{
// Conf: conf,
// CacheConf: c.Cache,
// CacheOpts: cacheOpts,
// DBOpts: dbOpts,
// })
//
// userRoleRepo := repository.NewUserRoleRepository(repository.UserRoleRepositoryParam{
// Conf: conf,
// CacheConf: c.Cache,
// CacheOpts: cacheOpts,
// DBOpts: dbOpts,
// })
//
// // 創建索引
// _, _ = permissionRepo.Index20241226001UP(context.Background())
// _, _ = roleRepo.Index20241226001UP(context.Background())
// _, _ = userRoleRepo.Index20241226001UP(context.Background())
//
// return uc.MustPermissionUseCase(uc.PermissionUseCaseParam{
// Enforcer: enforcer,
// PermissionRepo: permissionRepo,
// RoleRepo: roleRepo,
// UserRoleRepo: userRoleRepo,
// })
//}
//
//func NewAuthUC(c *config.Config, rds *redis.Redis) usecase.AuthUseCase {
// // 準備Mongo Config
// conf := &mgo.Conf{
// Schema: c.Mongo.Schema,
// Host: c.Mongo.Host,
// Database: c.Mongo.Database,
// MaxStaleness: c.Mongo.MaxStaleness,
// MaxPoolSize: c.Mongo.MaxPoolSize,
// MinPoolSize: c.Mongo.MinPoolSize,
// MaxConnIdleTime: c.Mongo.MaxConnIdleTime,
// Compressors: c.Mongo.Compressors,
// EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode,
// ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs,
// }
// if c.Mongo.User != "" {
// conf.User = c.Mongo.User
// conf.Password = c.Mongo.Password
// }
//
// // 快取選項
// cacheOpts := []cache.Option{
// cache.WithExpiry(c.CacheExpireTime),
// cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry),
// }
// dbOpts := []mon.Option{
// mgo.SetCustomDecimalType(),
// mgo.InitMongoOptions(*conf),
// }
//
// // 初始化 Repository
// clientRepo := repository.NewClientRepository(repository.ClientRepositoryParam{
// Conf: conf,
// CacheConf: c.Cache,
// CacheOpts: cacheOpts,
// DBOpts: dbOpts,
// })
//
// tokenRepo := repository.NewTokenRepository(repository.TokenRepositoryParam{
// Redis: rds,
// })
//
// // JWT 配置
// jwtConfig := permissionConfig.JWTConfig{
// Secret: c.JWTAuth.AccessSecret, // 使用現有的JWT配置
// AccessExpires: c.JWTAuth.AccessExpire,
// RefreshExpires: c.JWTAuth.AccessExpire * 7, // refresh token 較長
// }
//
// return uc.MustAuthUseCase(uc.AuthUseCaseParam{
// ClientRepo: clientRepo,
// TokenRepo: tokenRepo,
// JWTConfig: jwtConfig,
// })
//}
//
//func NewRoleUC(c *config.Config) usecase.RoleUseCase {
// // 準備Mongo Config
// conf := &mgo.Conf{
// Schema: c.Mongo.Schema,
// Host: c.Mongo.Host,
// Database: c.Mongo.Database,
// MaxStaleness: c.Mongo.MaxStaleness,
// MaxPoolSize: c.Mongo.MaxPoolSize,
// MinPoolSize: c.Mongo.MinPoolSize,
// MaxConnIdleTime: c.Mongo.MaxConnIdleTime,
// Compressors: c.Mongo.Compressors,
// EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode,
// ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs,
// }
// if c.Mongo.User != "" {
// conf.User = c.Mongo.User
// conf.Password = c.Mongo.Password
// }
//
// // 快取選項
// cacheOpts := []cache.Option{
// cache.WithExpiry(c.CacheExpireTime),
// cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry),
// }
// dbOpts := []mon.Option{
// mgo.SetCustomDecimalType(),
// mgo.InitMongoOptions(*conf),
// }
//
// // 初始化 Repository
// roleRepo := repository.NewRoleRepository(repository.RoleRepositoryParam{
// Conf: conf,
// CacheConf: c.Cache,
// CacheOpts: cacheOpts,
// DBOpts: dbOpts,
// })
//
// userRoleRepo := repository.NewUserRoleRepository(repository.UserRoleRepositoryParam{
// Conf: conf,
// CacheConf: c.Cache,
// CacheOpts: cacheOpts,
// DBOpts: dbOpts,
// })
//
// return uc.MustRoleUseCase(uc.RoleUseCaseParam{
// RoleRepo: roleRepo,
// UserRoleRepo: userRoleRepo,
// })
//}

View File

@ -6,7 +6,8 @@ import (
"backend/pkg/library/errs"
"backend/pkg/library/errs/code"
vi "backend/pkg/library/validator"
"backend/pkg/member/domain/usecase"
memberUC "backend/pkg/member/domain/usecase"
tokenUC "backend/pkg/permission/domain/usecase"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/rest"
@ -15,8 +16,9 @@ import (
type ServiceContext struct {
Config config.Config
AuthMiddleware rest.Middleware
AccountUC usecase.AccountUseCase
AccountUC memberUC.AccountUseCase
Validate vi.Validate
TokenUC tokenUC.TokenUseCase
}
func NewServiceContext(c config.Config) *ServiceContext {
@ -31,5 +33,6 @@ func NewServiceContext(c config.Config) *ServiceContext {
AuthMiddleware: middleware.NewAuthMiddleware().Handle,
AccountUC: NewAccountUC(&c, rds),
Validate: vi.MustValidator(),
TokenUC: NewTokenUC(&c, rds),
}
}

18
internal/svc/token.go Normal file
View File

@ -0,0 +1,18 @@
package svc
import (
"backend/internal/config"
"backend/pkg/permission/domain/usecase"
"backend/pkg/permission/repository"
uc "backend/pkg/permission/usecase"
"github.com/zeromicro/go-zero/core/stores/redis"
)
func NewTokenUC(c *config.Config, rds *redis.Redis) usecase.TokenUseCase {
return uc.MustTokenUseCase(uc.TokenUseCaseParam{
TokenRepo: repository.MustTokenRepository(repository.TokenRepositoryParam{
Redis: rds,
}),
Config: c,
})
}

View File

@ -11,4 +11,5 @@ const (
CatSystem
CatPubSub
CatService
CatToken
)

View File

@ -74,3 +74,16 @@ const (
ThirdParty
ArkHTTP400 // Ark HTTP 400 錯誤
)
// 詳細代碼 - Token 類 09x
const (
_ = iota + CatToken
TokenCreateError // Token 創建錯誤
TokenValidateError // Token 驗證錯誤
TokenExpired // Token 過期
TokenNotFound // Token 未找到
TokenBlacklisted // Token 已被列入黑名單
InvalidJWT // 無效的 JWT
RefreshTokenError // Refresh Token 錯誤
OneTimeTokenError // 一次性 Token 錯誤
)

View File

@ -195,7 +195,7 @@ func TestVerifyPlatformAuthResult(t *testing.T) {
})
token, err := HashPassword("password", 10)
assert.NoError(t, err)
fmt.Println(token)
tests := []struct {
name string
param usecase.VerifyAuthResultRequest

View File

@ -1,286 +1,364 @@
# Permission 權限管理模組 - Casbin 版
# Permission Module
一個基於 **Casbin** 的現代化權限管理模組,完全整合你的專案技術棧,提供強大且靈活的 RBAC 權限控制
JWT Token 和 Refresh Token 管理模組,提供完整的身份驗證和授權功能
## 🎯 為什麼選擇 Casbin
## 📋 功能特性
你說得完全對!與其重新發明一個功能精簡的權限系統,**Casbin** 提供了:
### 🔐 JWT Token 管理
- **Access Token 生成**: 基於 JWT 標準生成存取權杖
- **Refresh Token 機制**: 支援長期有效的刷新權杖
- **One-Time Token**: 臨時性權杖,用於特殊場景
- **Token 驗證**: 完整的權杖驗證和解析功能
### ✅ **社群驗證的成熟解決方案**
- 🌟 **6.7k+ GitHub Stars**,經過大量生產環境驗證
- 🔧 **功能完整**:支援 RBAC、ABAC、RESTful、通配符、正則表達式
- 📚 **文檔完善**:豐富的範例和最佳實踐
- 🛠️ **持續維護**:活躍的社群支持和定期更新
### 🚫 黑名單機制
- **即時撤銷**: 將 JWT 權杖立即加入黑名單
- **用戶登出**: 支援單一設備或全設備登出
- **自動過期**: 黑名單條目會在權杖過期後自動清理
- **批量管理**: 支援批量黑名單操作
### ✅ **強大的功能特性**
- **通配符支援**: `/api/users/*` 一個規則覆蓋所有子路徑
- **正則表達式**: 靈活的權限匹配規則
- **角色繼承**: 複雜的組織架構支援
- **多種模型**: RBAC、ABAC、RESTful 等
- **策略持久化**: 自動保存到你的 MongoDB
### 💾 Redis 儲存
- **高效能**: 使用 Redis 作為主要儲存引擎
- **TTL 管理**: 自動管理權杖過期時間
- **關聯管理**: 支援用戶、設備與權杖的關聯查詢
## 📁 目錄結構
### 🔒 安全特性
- **HMAC-SHA256**: 使用安全的簽名算法
- **密鑰分離**: Access Token 和 Refresh Token 使用不同密鑰
- **設備限制**: 支援每用戶、每設備的權杖數量限制
- **過期控制**: 靈活的權杖過期時間配置
## 🏗️ 架構設計
本模組遵循 **Clean Architecture** 原則:
```
pkg/permission/
├── config/ # Casbin 模型配置
│ └── rbac_model.conf # RBAC 權限模型
├── domain/ # 領域層
│ ├── entity/ # 實體定義
│ ├── repository/ # 倉庫介面
│ ├── usecase/ # 用例介面 (Casbin 增強)
│ └── config/ # 配置定義
├── repository/ # 倉庫實現
│ ├── casbin_adapter.go # Casbin MongoDB 適配器
│ ├── client.go # 客戶端管理
│ ├── role.go # 角色管理
│ └── ... # 其他倉庫
├── usecase/ # 用例實現 (Casbin API)
├── svc/ # 初始化層
├── example/ # Casbin 使用範例
└── README.md # 本文件
├── domain/ # 領域層
│ ├── entity/ # 實體定義
│ ├── repository/ # 儲存庫介面
│ ├── usecase/ # 用例介面
│ └── token/ # 權杖相關常數和類型
├── usecase/ # 用例實現
├── repository/ # 儲存庫實現
└── mock/ # 測試模擬
```
## 🚀 核心優勢
### 領域層 (Domain)
- **Entity**: 定義核心業務實體Token、BlacklistEntry、Ticket
- **Repository Interface**: 定義資料存取介面
- **UseCase Interface**: 定義業務用例介面
- **Token Types**: 權杖類型和常數定義
### **🔥 Casbin 強化功能**
- **通配符權限**: `GET /api/users/*` 覆蓋所有用戶子路徑
- **正則表達式**: `GET /api/users/\d+` 只允許數字 ID
- **角色繼承**: `admin` 繼承 `user` 的所有權限
- **策略分離**: 權限策略與業務邏輯完全分離
- **動態更新**: 運行時動態添加/移除權限,無需重啟
### 用例層 (UseCase)
- **TokenUseCase**: 核心業務邏輯實現
- **JWT 處理**: 權杖生成、解析、驗證
- **黑名單管理**: 權杖撤銷和黑名單查詢
### **⚡ 技術整合**
- **MongoDB 適配器**: 策略自動持久化到你的 MongoDB
- **你的錯誤系統**: 完整的 `@errs/` 整合
- **緩存支援**: 使用你現有的 Redis 緩存
- **go-zero 整合**: 無縫整合到你的服務架構
### 儲存層 (Repository)
- **Redis 實現**: 基於 Redis 的資料存取
- **關聯管理**: 用戶、設備、權杖關聯
- **TTL 管理**: 自動過期處理
## 🔧 Casbin 模型
## 🚀 快速開始
```ini
# pkg/permission/config/rbac_model.conf
[request_definition]
r = sub, obj, act
### 1. 配置設定
[policy_definition]
p = sub, obj, act
[role_definition]
g = _, _
[policy_effect]
e = some(where (p.eft == allow))
[matchers]
m = g(r.sub, p.sub) && keyMatch2(r.obj, p.obj) && regexMatch(r.act, p.act)
```
這個模型支援:
- **keyMatch2**: 通配符匹配 (`/api/users/*`)
- **regexMatch**: 正則表達式匹配
- **角色繼承**: `g(user, role)` 關係
## 📦 快速整合
### 1. 在你的 ServiceContext 中添加
`internal/config/config.go` 中添加 Token 配置:
```go
// internal/svc/service_context.go
import "backend/pkg/permission/svc"
type ServiceContext struct {
Config config.Config
AuthMiddleware rest.Middleware
AccountUC usecase.AccountUseCase
PermissionUC permission.PermissionUseCase // ← Casbin 增強
AuthUC permission.AuthUseCase
RoleUC permission.RoleUseCase
Validate vi.Validate
}
func NewServiceContext(c config.Config) *ServiceContext {
rds, err := redis.NewRedis(c.RedisConf)
if err != nil {
panic(err)
}
return &ServiceContext{
Config: c,
AuthMiddleware: middleware.NewAuthMiddleware().Handle,
AccountUC: NewAccountUC(&c, rds),
PermissionUC: svc.NewPermissionUC(&c, rds), // ← Casbin 自動初始化
AuthUC: svc.NewAuthUC(&c, rds),
RoleUC: svc.NewRoleUC(&c),
Validate: vi.MustValidator(),
}
}
```
### 2. Casbin 強大功能使用
```go
// 🔥 通配符權限 - 一個規則覆蓋所有子路徑
err = permissionUC.AddPermissionForRole(ctx, "admin", "/api/users/*", ".*")
// ✅ 這些都會被允許:
// GET /api/users/123
// POST /api/users/123/profile
// DELETE /api/users/123/avatar
// 🔥 正則表達式權限 - 精確控制
err = permissionUC.AddPermissionForRole(ctx, "viewer", "/api/users/\\d+", "GET")
// ✅ 只允許: GET /api/users/123 (數字ID)
// ❌ 拒絕: GET /api/users/abc (非數字ID)
// 🔥 角色繼承 - 組織架構支援
err = permissionUC.AddRoleForUser(ctx, "john", "admin")
err = permissionUC.AddRoleInheritance(ctx, "admin", "user")
// john 自動擁有 admin 和 user 的所有權限
// 🔥 動態權限檢查
hasPermission, err := permissionUC.CheckUserPermission(ctx, "john", "GET", "/api/users/123")
hasPattern, err := permissionUC.CheckPatternPermission(ctx, "john", "/api/users/456", "DELETE")
```
## 🎯 實際使用場景
### **API 權限控制**
```go
// 中間件中使用
func PermissionMiddleware(permissionUC permission.PermissionUseCase) rest.Middleware {
return func(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userID := getUserIDFromJWT(r)
// Casbin 自動處理通配符和正則表達式
hasPermission, err := permissionUC.CheckUserPermission(
r.Context(), userID, r.Method, r.URL.Path,
)
if err != nil || !hasPermission {
httpx.WriteJsonCtx(r.Context(), w, 403, types.ErrorResp{
Code: 4030001,
Msg: "權限不足",
})
return
}
next(w, r)
}
}
}
```
### **權限初始化**
```go
// 初始化基礎權限
func InitPermissions(ctx context.Context, permissionUC permission.PermissionUseCase) {
// 🔥 管理員擁有所有 API 權限
permissionUC.AddPermissionForRole(ctx, "admin", "/api/*", ".*")
type Config struct {
// ... 其他配置
// 🔥 用戶只能查看和更新自己的資料
permissionUC.AddPermissionForRole(ctx, "user", "/api/users/{{.UserID}}", "GET|PUT")
Token struct {
AccessSecret string // Access Token 簽名密鑰
RefreshSecret string // Refresh Token 簽名密鑰
AccessTokenExpiry time.Duration // Access Token 過期時間
RefreshTokenExpiry time.Duration // Refresh Token 過期時間
OneTimeTokenExpiry time.Duration // 一次性 Token 過期時間
MaxTokensPerUser int // 每用戶最大 Token 數
MaxTokensPerDevice int // 每設備最大 Token 數
}
}
```
### 2. 初始化模組
```go
import (
"backend/pkg/permission/repository"
"backend/pkg/permission/usecase"
)
// 初始化 Repository
tokenRepo := repository.MustTokenRepository(repository.TokenRepositoryParam{
Redis: redisClient,
})
// 初始化 UseCase
tokenUseCase := usecase.MustTokenUseCase(usecase.TokenUseCaseParam{
TokenRepo: tokenRepo,
Config: config,
})
```
### 3. 基本使用
#### 創建 Access Token
```go
req := entity.AuthorizationReq{
GrantType: token.PasswordCredentials.ToString(),
Scope: "read write",
DeviceID: "device123",
IsRefreshToken: true,
Claims: map[string]string{
"uid": "user123",
"role": "admin",
},
}
resp, err := tokenUseCase.NewToken(ctx, req)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Access Token: %s\n", resp.AccessToken)
fmt.Printf("Refresh Token: %s\n", resp.RefreshToken)
```
#### 驗證 Token
```go
req := entity.ValidationTokenReq{
Token: accessToken,
}
resp, err := tokenUseCase.ValidationToken(ctx, req)
if err != nil {
log.Printf("Token validation failed: %v", err)
return
}
fmt.Printf("Token is valid for user: %s\n", resp.Token.UID)
```
#### 撤銷 Token (加入黑名單)
```go
err := tokenUseCase.BlacklistToken(ctx, accessToken, "user logout")
if err != nil {
log.Printf("Failed to blacklist token: %v", err)
}
```
#### 檢查黑名單
```go
isBlacklisted, err := tokenUseCase.IsTokenBlacklisted(ctx, jti)
if err != nil {
log.Printf("Failed to check blacklist: %v", err)
}
if isBlacklisted {
log.Println("Token is blacklisted")
}
```
## 🧪 測試
### 運行測試
```bash
# 運行所有測試
go test ./pkg/permission/...
# 運行特定模組測試
go test ./pkg/permission/usecase/
go test ./pkg/permission/repository/
# 運行測試並顯示覆蓋率
go test -cover ./pkg/permission/...
# 生成覆蓋率報告
go test -coverprofile=coverage.out ./pkg/permission/...
go tool cover -html=coverage.out
```
### 測試結構
- **UseCase Tests**: 業務邏輯測試,使用 Mock Repository
- **Repository Tests**: 資料存取測試,使用 MiniRedis
- **JWT Tests**: 權杖生成和解析測試
- **Integration Tests**: 整合測試
## 📊 API 參考
### TokenUseCase 介面
```go
type TokenUseCase interface {
// 基本 Token 操作
NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error)
RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error)
ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error)
// 🔥 訪客只能查看公開內容
permissionUC.AddPermissionForRole(ctx, "guest", "/api/public/*", "GET")
// Token 管理
CancelToken(ctx context.Context, req entity.CancelTokenReq) error
CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error
CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error
// 查詢操作
GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error)
GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error)
// 一次性 Token
NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error)
CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error
// 黑名單操作
BlacklistToken(ctx context.Context, token string, reason string) error
IsTokenBlacklisted(ctx context.Context, jti string) (bool, error)
BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error
// 工具方法
ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error)
}
```
## 📊 Casbin 策略儲存
### 主要實體
### **MongoDB 自動持久化**
```javascript
// casbin_rules collection
{
_id: ObjectId,
ptype: "p", // 策略類型
v0: "role_admin_001", // 主體 (用戶/角色)
v1: "/api/users/*", // 對象 (資源)
v2: ".*", // 行為 (動作)
v3: "", // 擴展字段
v4: "", // 擴展字段
v5: "" // 擴展字段
}
// 角色關係
{
ptype: "g", // 分組策略
v0: "user_001", // 用戶
v1: "role_admin_001", // 角色
v2: "",
...
#### Token 實體
```go
type Token struct {
ID string // 權杖唯一標識
UID string // 用戶 ID
DeviceID string // 設備 ID
AccessToken string // Access Token
RefreshToken string // Refresh Token
ExpiresIn int // 過期時間(秒)
AccessCreateAt time.Time // Access Token 創建時間
RefreshCreateAt time.Time // Refresh Token 創建時間
RefreshExpiresIn int // Refresh Token 過期時間(秒)
}
```
### **索引優化**
```javascript
// 自動創建的索引
db.casbin_rules.createIndex({"ptype": 1})
db.casbin_rules.createIndex({"ptype": 1, "v0": 1})
db.casbin_rules.createIndex({"ptype": 1, "v0": 1, "v1": 1})
```
## 🔥 Casbin vs 自建系統
| 功能 | 自建系統 | Casbin |
|-----|---------|--------|
| **通配符支援** | ❌ 需要自己實現 | ✅ 內建支援 `/api/users/*` |
| **正則表達式** | ❌ 需要自己實現 | ✅ 內建支援 `/api/users/\\d+` |
| **角色繼承** | ❌ 需要複雜邏輯 | ✅ 自動處理繼承鏈 |
| **策略語言** | ❌ 硬編碼邏輯 | ✅ 靈活的 DSL |
| **性能優化** | ❌ 需要自己優化 | ✅ 內建緩存和索引 |
| **社群支持** | ❌ 需要自己維護 | ✅ 活躍社群,持續更新 |
| **文檔和範例** | ❌ 需要自己寫文檔 | ✅ 豐富的官方文檔 |
| **測試覆蓋** | ❌ 需要自己測試 | ✅ 大量生產環境驗證 |
## 🚀 進階功能
### **ABAC 屬性權限**
#### 黑名單實體
```go
// 未來可以升級到 ABAC 模型
// 支援基於用戶屬性、資源屬性、環境屬性的權限控制
permissionUC.CheckPermissionWithAttributes(ctx,
map[string]interface{}{
"user.department": "engineering",
"resource.owner": "john",
"time.hour": 9,
})
type BlacklistEntry struct {
JTI string // JWT ID
UID string // 用戶 ID
TokenID string // Token ID
Reason string // 加入黑名單原因
ExpiresAt int64 // 原始權杖過期時間
CreatedAt int64 // 加入黑名單時間
}
```
### **策略管理 API**
```go
// 動態管理策略
policies, err := permissionUC.GetAllPolicies(ctx)
filtered, err := permissionUC.GetFilteredPolicies(ctx, 0, "role_admin")
```
## 🔧 配置參數
## 🎯 遷移優勢
| 參數 | 類型 | 說明 | 預設值 |
|------|------|------|--------|
| `AccessSecret` | string | Access Token 簽名密鑰 | 必填 |
| `RefreshSecret` | string | Refresh Token 簽名密鑰 | 必填 |
| `AccessTokenExpiry` | Duration | Access Token 過期時間 | 15分鐘 |
| `RefreshTokenExpiry` | Duration | Refresh Token 過期時間 | 7天 |
| `OneTimeTokenExpiry` | Duration | 一次性 Token 過期時間 | 5分鐘 |
| `MaxTokensPerUser` | int | 每用戶最大 Token 數 | 10 |
| `MaxTokensPerDevice` | int | 每設備最大 Token 數 | 5 |
1. **立即獲得成熟功能** - 通配符、正則表達式、角色繼承
2. **減少維護成本** - 社群維護,無需自己投入開發時間
3. **擴展性更強** - 支援複雜的權限模型,適應業務成長
4. **性能更好** - 內建優化,大量生產環境驗證
5. **學習成本低** - 豐富的文檔和社群範例
## 🚨 錯誤處理
## 🔧 立即使用
模組定義了完整的錯誤類型:
```go
// 1. 初始化 (自動設置 Casbin)
PermissionUC: svc.NewPermissionUC(&c, rds),
// Token 驗證錯誤
var (
ErrInvalidTokenID = errors.New("invalid token ID")
ErrInvalidUID = errors.New("invalid UID")
ErrTokenExpired = errors.New("token expired")
ErrTokenNotFound = errors.New("token not found")
)
// 2. 添加權限策略
permissionUC.AddPermissionForRole(ctx, "admin", "/api/users/*", ".*")
// JWT 特定錯誤
var (
ErrInvalidJWTToken = errors.New("invalid JWT token")
ErrJWTSigningFailed = errors.New("JWT signing failed")
ErrJWTParsingFailed = errors.New("JWT parsing failed")
)
// 3. 分配角色
permissionUC.AddRoleForUser(ctx, "john", "admin")
// 4. 檢查權限 (自動處理通配符)
hasPermission, err := permissionUC.CheckUserPermission(ctx, "john", "GET", "/api/users/123")
// 黑名單錯誤
var (
ErrTokenBlacklisted = errors.New("token is blacklisted")
ErrBlacklistNotFound = errors.New("blacklist entry not found")
)
```
現在你擁有了一個**功能完整、社群驗證、持續維護**的權限系統!🎯
## 🔒 安全考量
**Casbin** 讓你專注於業務邏輯,而不是重新發明權限輪子。
### 1. 密鑰管理
- 使用強密鑰(至少 256 位)
- Access Token 和 Refresh Token 使用不同密鑰
- 定期輪換密鑰
### 2. 權杖過期
- Access Token 使用較短過期時間15分鐘
- Refresh Token 使用較長過期時間7天
- 支援自定義過期時間
### 3. 黑名單機制
- 即時撤銷可疑權杖
- 支援批量撤銷
- 自動清理過期條目
### 4. 限制機制
- 每用戶權杖數量限制
- 每設備權杖數量限制
- 防止權杖濫用
## 📈 效能優化
### 1. Redis 優化
- 使用適當的 TTL 避免記憶體洩漏
- 批量操作減少網路往返
- 使用 Pipeline 提升效能
### 2. JWT 優化
- 最小化 Claims 數據大小
- 使用高效的序列化格式
- 快取常用的解析結果
### 3. 黑名單優化
- 使用 SCAN 而非 KEYS 遍歷
- 批量檢查黑名單狀態
- 定期清理過期條目
## 🤝 貢獻指南
1. Fork 本專案
2. 創建功能分支 (`git checkout -b feature/amazing-feature`)
3. 提交變更 (`git commit -m 'Add some amazing feature'`)
4. 推送到分支 (`git push origin feature/amazing-feature`)
5. 開啟 Pull Request
### 開發規範
- 遵循 Go 編碼規範
- 保持測試覆蓋率 > 80%
- 添加適當的文檔註釋
- 使用有意義的提交訊息
## 📄 授權條款
本專案採用 MIT 授權條款 - 詳見 [LICENSE](LICENSE) 檔案
## 📞 聯絡資訊
如有問題或建議,請通過以下方式聯絡:
- 開啟 Issue
- 發送 Pull Request
- 聯絡維護團隊
---
**注意**: 本模組是 PlayOne Backend 專案的一部分,請確保與整體架構保持一致。

View File

@ -1,51 +1,64 @@
package config
import (
"time"
"backend/pkg/permission/domain/token"
)
// Config 權限系統配置
// Config represents the configuration for the permission module
type Config struct {
JWT JWTConfig `json:"jwt"`
Database DatabaseConfig `json:"database"`
Casbin CasbinConfig `json:"casbin"`
Token TokenConfig `json:"token" yaml:"token"`
}
// JWTConfig JWT 配置
type JWTConfig struct {
Secret string `json:"secret"`
AccessExpires time.Duration `json:"access_expires"`
RefreshExpires time.Duration `json:"refresh_expires"`
// TokenConfig represents token configuration
type TokenConfig struct {
// JWT signing configuration
Secret string `json:"secret" yaml:"secret"`
// Token expiration settings
Expired ExpiredConfig `json:"expired" yaml:"expired"`
RefreshExpires ExpiredConfig `json:"refresh_expires" yaml:"refresh_expires"`
// Issuer information
Issuer string `json:"issuer" yaml:"issuer"`
// Token limits
MaxTokensPerUser int `json:"max_tokens_per_user" yaml:"max_tokens_per_user"`
MaxTokensPerDevice int `json:"max_tokens_per_device" yaml:"max_tokens_per_device"`
// Security settings
EnableDeviceTracking bool `json:"enable_device_tracking" yaml:"enable_device_tracking"`
}
// DatabaseConfig 數據庫配置
type DatabaseConfig struct {
URI string `json:"uri"`
Database string `json:"database"`
Timeout time.Duration `json:"timeout"`
// ExpiredConfig represents expiration configuration
type ExpiredConfig struct {
Seconds int64 `json:"seconds" yaml:"seconds"`
}
// CasbinConfig Casbin 配置
type CasbinConfig struct {
ModelPath string `json:"model_path"` // RBAC 模型文件路徑
AutoSave bool `json:"auto_save"` // 自動保存策略
AutoLoad bool `json:"auto_load"` // 自動載入策略
AutoLoadDuration time.Duration `json:"auto_load_duration"` // 自動載入間隔
}
// DefaultConfig 返回默認配置
func DefaultConfig() Config {
return Config{
JWT: JWTConfig{
Secret: "your-secret-key",
AccessExpires: time.Hour * 2, // 2 小時
RefreshExpires: time.Hour * 24 * 7, // 7 天
},
Casbin: CasbinConfig{
ModelPath: "etc/rbac_model.conf",
AutoSave: true,
AutoLoad: true,
AutoLoadDuration: time.Second * 10,
},
// Validate validates the token configuration
func (c *TokenConfig) Validate() error {
if c.Secret == "" {
return ErrMissingSecret
}
}
if c.Expired.Seconds <= 0 {
c.Expired.Seconds = token.DefaultAccessTokenExpiry
}
if c.RefreshExpires.Seconds <= 0 {
c.RefreshExpires.Seconds = token.DefaultRefreshTokenExpiry
}
if c.Issuer == "" {
c.Issuer = "playone-backend"
}
if c.MaxTokensPerUser <= 0 {
c.MaxTokensPerUser = token.MaxTokensPerUser
}
if c.MaxTokensPerDevice <= 0 {
c.MaxTokensPerDevice = token.MaxTokensPerDevice
}
return nil
}

View File

@ -0,0 +1,243 @@
package config
import (
"testing"
"backend/pkg/permission/domain/token"
"github.com/stretchr/testify/assert"
)
func TestTokenConfig_Validate(t *testing.T) {
tests := []struct {
name string
config *TokenConfig
wantErr bool
check func(*testing.T, *TokenConfig)
}{
{
name: "valid config",
config: &TokenConfig{
Secret: "test-secret",
Expired: ExpiredConfig{
Seconds: 900,
},
RefreshExpires: ExpiredConfig{
Seconds: 604800,
},
Issuer: "test-issuer",
MaxTokensPerUser: 10,
MaxTokensPerDevice: 5,
},
wantErr: false,
check: func(t *testing.T, c *TokenConfig) {
assert.Equal(t, "test-secret", c.Secret)
assert.Equal(t, int64(900), c.Expired.Seconds)
assert.Equal(t, int64(604800), c.RefreshExpires.Seconds)
},
},
{
name: "missing secret",
config: &TokenConfig{
Secret: "",
Expired: ExpiredConfig{
Seconds: 900,
},
},
wantErr: true,
check: nil,
},
{
name: "use default expiry",
config: &TokenConfig{
Secret: "test-secret",
Expired: ExpiredConfig{
Seconds: 0,
},
RefreshExpires: ExpiredConfig{
Seconds: 0,
},
},
wantErr: false,
check: func(t *testing.T, c *TokenConfig) {
assert.Equal(t, int64(token.DefaultAccessTokenExpiry), c.Expired.Seconds)
assert.Equal(t, int64(token.DefaultRefreshTokenExpiry), c.RefreshExpires.Seconds)
},
},
{
name: "use default issuer",
config: &TokenConfig{
Secret: "test-secret",
Issuer: "",
},
wantErr: false,
check: func(t *testing.T, c *TokenConfig) {
assert.Equal(t, "playone-backend", c.Issuer)
},
},
{
name: "use default token limits",
config: &TokenConfig{
Secret: "test-secret",
MaxTokensPerUser: 0,
MaxTokensPerDevice: 0,
},
wantErr: false,
check: func(t *testing.T, c *TokenConfig) {
assert.Equal(t, token.MaxTokensPerUser, c.MaxTokensPerUser)
assert.Equal(t, token.MaxTokensPerDevice, c.MaxTokensPerDevice)
},
},
{
name: "negative expiry time",
config: &TokenConfig{
Secret: "test-secret",
Expired: ExpiredConfig{
Seconds: -100,
},
},
wantErr: false,
check: func(t *testing.T, c *TokenConfig) {
// Negative values should be replaced with defaults
assert.Equal(t, int64(token.DefaultAccessTokenExpiry), c.Expired.Seconds)
},
},
{
name: "custom token limits",
config: &TokenConfig{
Secret: "test-secret",
MaxTokensPerUser: 20,
MaxTokensPerDevice: 10,
},
wantErr: false,
check: func(t *testing.T, c *TokenConfig) {
assert.Equal(t, 20, c.MaxTokensPerUser)
assert.Equal(t, 10, c.MaxTokensPerDevice)
},
},
{
name: "device tracking enabled",
config: &TokenConfig{
Secret: "test-secret",
EnableDeviceTracking: true,
},
wantErr: false,
check: func(t *testing.T, c *TokenConfig) {
assert.True(t, c.EnableDeviceTracking)
},
},
{
name: "device tracking disabled",
config: &TokenConfig{
Secret: "test-secret",
EnableDeviceTracking: false,
},
wantErr: false,
check: func(t *testing.T, c *TokenConfig) {
assert.False(t, c.EnableDeviceTracking)
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.config.Validate()
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
if tt.check != nil {
tt.check(t, tt.config)
}
}
})
}
}
func TestExpiredConfig(t *testing.T) {
tests := []struct {
name string
seconds int64
}{
{
name: "900 seconds (15 minutes)",
seconds: 900,
},
{
name: "3600 seconds (1 hour)",
seconds: 3600,
},
{
name: "604800 seconds (7 days)",
seconds: 604800,
},
{
name: "zero seconds",
seconds: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := ExpiredConfig{
Seconds: tt.seconds,
}
assert.Equal(t, tt.seconds, config.Seconds)
})
}
}
func TestConfig_Struct(t *testing.T) {
t.Run("full config", func(t *testing.T) {
config := Config{
Token: TokenConfig{
Secret: "my-secret",
Expired: ExpiredConfig{
Seconds: 900,
},
RefreshExpires: ExpiredConfig{
Seconds: 604800,
},
Issuer: "my-app",
MaxTokensPerUser: 15,
MaxTokensPerDevice: 8,
EnableDeviceTracking: true,
},
}
assert.NotNil(t, config.Token)
assert.Equal(t, "my-secret", config.Token.Secret)
assert.Equal(t, int64(900), config.Token.Expired.Seconds)
assert.Equal(t, int64(604800), config.Token.RefreshExpires.Seconds)
assert.Equal(t, "my-app", config.Token.Issuer)
assert.Equal(t, 15, config.Token.MaxTokensPerUser)
assert.Equal(t, 8, config.Token.MaxTokensPerDevice)
assert.True(t, config.Token.EnableDeviceTracking)
})
t.Run("empty config", func(t *testing.T) {
config := Config{}
assert.Empty(t, config.Token.Secret)
assert.Equal(t, int64(0), config.Token.Expired.Seconds)
})
}
func TestTokenConfig_AllDefaults(t *testing.T) {
config := &TokenConfig{
Secret: "test-secret", // Only required field
}
err := config.Validate()
assert.NoError(t, err)
// Check all defaults are applied
assert.Equal(t, int64(token.DefaultAccessTokenExpiry), config.Expired.Seconds)
assert.Equal(t, int64(token.DefaultRefreshTokenExpiry), config.RefreshExpires.Seconds)
assert.Equal(t, "playone-backend", config.Issuer)
assert.Equal(t, token.MaxTokensPerUser, config.MaxTokensPerUser)
assert.Equal(t, token.MaxTokensPerDevice, config.MaxTokensPerDevice)
}

View File

@ -0,0 +1,10 @@
package config
import (
"fmt"
)
var (
ErrMissingSecret = fmt.Errorf("missing JWT secret key")
)

9
pkg/permission/domain/const.go Executable file
View File

@ -0,0 +1,9 @@
package domain
const (
// Module name
ModuleName = "permission"
// Default issuer
DefaultIssuer = "playone-backend"
)

View File

@ -0,0 +1,33 @@
package entity
import "time"
// BlacklistEntry represents a blacklisted JWT token
type BlacklistEntry struct {
JTI string `json:"jti"` // JWT ID (unique identifier)
UID string `json:"uid"` // User ID
TokenID string `json:"token_id"` // Token ID from original token
Reason string `json:"reason"` // Reason for blacklisting
ExpiresAt int64 `json:"expires_at"` // When the original token expires
CreatedAt int64 `json:"created_at"` // When it was blacklisted
}
// IsExpired checks if the blacklist entry is expired
func (b *BlacklistEntry) IsExpired() bool {
return b.ExpiresAt <= time.Now().Unix()
}
// Validate validates the blacklist entry
func (b *BlacklistEntry) Validate() error {
if b.JTI == "" {
return ErrInvalidJTI
}
if b.UID == "" {
return ErrInvalidUID
}
if b.TokenID == "" {
return ErrInvalidTokenID
}
return nil
}

View File

@ -0,0 +1,194 @@
package entity
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestBlacklistEntry_IsExpired(t *testing.T) {
tests := []struct {
name string
entry *BlacklistEntry
expected bool
}{
{
name: "expired entry",
entry: &BlacklistEntry{
JTI: "test-jti",
UID: "test-uid",
TokenID: "test-token",
ExpiresAt: time.Now().Add(-time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
expected: true,
},
{
name: "not expired entry",
entry: &BlacklistEntry{
JTI: "test-jti",
UID: "test-uid",
TokenID: "test-token",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
expected: false,
},
{
name: "exactly at expiry time",
entry: &BlacklistEntry{
JTI: "test-jti",
UID: "test-uid",
TokenID: "test-token",
ExpiresAt: time.Now().Unix(),
CreatedAt: time.Now().Unix(),
},
expected: true, // Equal to current time should be considered expired
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.entry.IsExpired()
assert.Equal(t, tt.expected, result)
})
}
}
func TestBlacklistEntry_Validate(t *testing.T) {
tests := []struct {
name string
entry *BlacklistEntry
wantErr bool
expectedErr error
}{
{
name: "valid entry",
entry: &BlacklistEntry{
JTI: "test-jti",
UID: "test-uid",
TokenID: "test-token",
Reason: "user logout",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
wantErr: false,
},
{
name: "missing JTI",
entry: &BlacklistEntry{
JTI: "",
UID: "test-uid",
TokenID: "test-token",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
wantErr: true,
expectedErr: ErrInvalidJTI,
},
{
name: "missing UID",
entry: &BlacklistEntry{
JTI: "test-jti",
UID: "",
TokenID: "test-token",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
wantErr: true,
expectedErr: ErrInvalidUID,
},
{
name: "missing TokenID",
entry: &BlacklistEntry{
JTI: "test-jti",
UID: "test-uid",
TokenID: "",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
wantErr: true,
expectedErr: ErrInvalidTokenID,
},
{
name: "all fields missing",
entry: &BlacklistEntry{
JTI: "",
UID: "",
TokenID: "",
},
wantErr: true,
expectedErr: ErrInvalidJTI, // First error encountered
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.entry.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.expectedErr != nil {
assert.Equal(t, tt.expectedErr, err)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestBlacklistEntry_CreatedAt(t *testing.T) {
now := time.Now().Unix()
entry := &BlacklistEntry{
JTI: "test-jti",
UID: "test-uid",
TokenID: "test-token",
Reason: "security",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: now,
}
assert.Equal(t, now, entry.CreatedAt)
}
func TestBlacklistEntry_Reason(t *testing.T) {
tests := []struct {
name string
reason string
}{
{
name: "user logout reason",
reason: "user logout",
},
{
name: "security breach reason",
reason: "security breach detected",
},
{
name: "password reset reason",
reason: "password reset",
},
{
name: "empty reason",
reason: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
entry := &BlacklistEntry{
JTI: "test-jti",
UID: "test-uid",
TokenID: "test-token",
Reason: tt.reason,
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
}
assert.Equal(t, tt.reason, entry.Reason)
})
}
}

View File

@ -1,23 +0,0 @@
package entity
import (
"time"
"go.mongodb.org/mongo-driver/v2/bson"
)
// Client 客戶端實體
type Client struct {
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
Name string `bson:"name" json:"name"`
ClientID string `bson:"client_id" json:"client_id"`
Secret string `bson:"secret" json:"secret"`
Status int `bson:"status" json:"status"`
CreateTime time.Time `bson:"create_time" json:"create_time"`
UpdateTime time.Time `bson:"update_time" json:"update_time"`
}
// CollectionName 返回集合名稱
func (c *Client) CollectionName() string {
return "clients"
}

View File

@ -0,0 +1,31 @@
package entity
import "errors"
var (
// Token validation errors
ErrInvalidTokenID = errors.New("invalid token ID")
ErrInvalidUID = errors.New("invalid UID")
ErrInvalidAccessToken = errors.New("invalid access token")
ErrTokenExpired = errors.New("token expired")
ErrTokenNotFound = errors.New("token not found")
// JWT specific errors
ErrInvalidJWTToken = errors.New("invalid JWT token")
ErrJWTSigningFailed = errors.New("JWT signing failed")
ErrJWTParsingFailed = errors.New("JWT parsing failed")
ErrInvalidSigningKey = errors.New("invalid signing key")
ErrInvalidJTI = errors.New("invalid JWT ID")
// Refresh token errors
ErrRefreshTokenExpired = errors.New("refresh token expired")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
// One-time token errors
ErrOneTimeTokenExpired = errors.New("one-time token expired")
ErrInvalidOneTimeToken = errors.New("invalid one-time token")
// Blacklist errors
ErrTokenBlacklisted = errors.New("token is blacklisted")
ErrBlacklistNotFound = errors.New("blacklist entry not found")
)

View File

@ -1,66 +0,0 @@
package entity
import (
"time"
"go.mongodb.org/mongo-driver/v2/bson"
)
// PermissionType 權限類型
type PermissionType int
const (
PermissionTypeAPI PermissionType = 1 // API 權限
PermissionTypeMenu PermissionType = 2 // 選單權限
)
// Permission 權限實體
type Permission struct {
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
ParentID *bson.ObjectID `bson:"parent_id,omitempty" json:"parent_id"`
Name string `bson:"name" json:"name"`
HTTPMethod string `bson:"http_method" json:"http_method"`
HTTPPath string `bson:"http_path" json:"http_path"`
Status int `bson:"status" json:"status"`
Type PermissionType `bson:"type" json:"type"`
CreateTime time.Time `bson:"create_time" json:"create_time"`
UpdateTime time.Time `bson:"update_time" json:"update_time"`
}
// CollectionName 返回集合名稱
func (p *Permission) CollectionName() string {
return "permissions"
}
//// StatusActive 權限啟用狀態
//const StatusActive = 1
//
//// IsActive 檢查權限是否啟用
//func (p *Permission) IsActive() bool {
// return p.Status == StatusActive
//}
//
//
//// Validate 驗證權限數據
//func (p *Permission) Validate() error {
// if p.Name == "" {
// return mongo.WriteError{Code: 400, Message: "permission name is required"}
// }
// if p.Type == PermissionTypeAPI {
// if p.HTTPMethod == "" {
// return mongo.WriteError{Code: 400, Message: "http_method is required for API permission"}
// }
// if p.HTTPPath == "" {
// return mongo.WriteError{Code: 400, Message: "http_path is required for API permission"}
// }
// }
// return nil
//}
//
//// GetKey 獲取權限標識
//func (p *Permission) GetKey() string {
// if p.Type == PermissionTypeAPI {
// return p.HTTPMethod + ":" + p.HTTPPath
// }
// return p.Name
//}

View File

@ -0,0 +1,85 @@
package entity
// AuthorizationReq 定義授權請求的結構
type AuthorizationReq struct {
GrantType string `json:"grant_type"` // 授權類型
DeviceID string `json:"device_id"` // 設備 ID
Scope string `json:"scope"` // 授權範圍
Data map[string]string `json:"data"` // 附加數據
Expires int64 `json:"expires"` // 過期時間(秒)
IsRefreshToken bool `json:"is_refresh_token"` // 是否為刷新令牌
Role string `json:"role"` // 用戶角色
Account string `json:"account"` // 登入時的帳號
}
// TokenResp 定義訪問令牌響應的結構
type TokenResp struct {
AccessToken string `json:"access_token"` // 訪問令牌
TokenType string `json:"token_type"` // 令牌類型
ExpiresIn int64 `json:"expires_in"` // 過期時間(秒)
RefreshToken string `json:"refresh_token"` // 刷新令牌
}
// CreateOneTimeTokenReq 建立一次性 Token 的請求
type CreateOneTimeTokenReq struct {
Token string `json:"token"` // 長期有效的驗證令牌
}
// CreateOneTimeTokenResp 建立一次性 Token 的響應
type CreateOneTimeTokenResp struct {
OneTimeToken string `json:"one_time_token"` // 一次性令牌
}
// RefreshTokenReq 更新 Token 的請求
type RefreshTokenReq struct {
Token string `json:"token"` // 令牌
Scope string `json:"scope"` // 授權範圍
Expires int64 `json:"expires"` // 過期時間(秒)
DeviceID string `json:"device_id"` // 設備 ID
}
// RefreshTokenResp 更新令牌的響應
type RefreshTokenResp struct {
Token string `json:"token"` // 新的訪問令牌
OneTimeToken string `json:"one_time_token"` // 一次性令牌
ExpiresIn int64 `json:"expires_in"` // 過期時間(秒)
TokenType string `json:"token_type"` // 令牌類型
}
// CancelTokenReq 註銷 Token 的請求
type CancelTokenReq struct {
Token string `json:"token"` // 需要註銷的令牌
}
// DoTokenByUIDReq 基於 UID 操作 Token 的請求
type DoTokenByUIDReq struct {
IDs []string `json:"ids"` // Token ID 列表
UID string `json:"uid"` // 用戶 ID
}
// QueryTokenByUIDReq 查詢 UID 對應的 Token
type QueryTokenByUIDReq struct {
UID string `json:"uid"` // 用戶 ID
}
// ValidationTokenReq 驗證 Token 的請求
type ValidationTokenReq struct {
Token string `json:"token"` // 需要驗證的令牌
}
// ValidationTokenResp 驗證並返回 Token 詳情
type ValidationTokenResp struct {
Token Token `json:"token"` // Token 詳情
Data map[string]string `json:"data"` // 附加數據
}
// DoTokenByDeviceIDReq 基於設備 ID 操作 Token 的請求
type DoTokenByDeviceIDReq struct {
DeviceID string `json:"device_id"` // 設備 ID
}
// CancelOneTimeTokenReq 取消一次性 Token 的請求
type CancelOneTimeTokenReq struct {
Token []string `json:"token"` // 一次性 Token 列表
}

View File

@ -1,66 +0,0 @@
package entity
import (
"time"
"go.mongodb.org/mongo-driver/v2/bson"
)
// Permissions 權限映射表
type Permissions map[string]int
// Role 角色實體
type Role struct {
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
ClientID string `bson:"client_id" json:"client_id"`
UID string `bson:"uid" json:"uid"`
Name string `bson:"name" json:"name"`
Status int `bson:"status" json:"status"`
Permissions Permissions `bson:"permissions" json:"permissions"`
CreateTime time.Time `bson:"create_time" json:"create_time"`
UpdateTime time.Time `bson:"update_time" json:"update_time"`
}
// CollectionName 返回集合名稱
func (r *Role) CollectionName() string {
return "roles"
}
// // Validate 驗證角色數據
//
// func (r *Role) Validate() error {
// if r.ClientID == "" {
// return mongo.WriteError{Code: 400, Message: "client_id is required"}
// }
// if r.Name == "" {
// return mongo.WriteError{Code: 400, Message: "role name is required"}
// }
// return nil
// }
//
// // HasPermission 檢查是否有指定權限
//
// func (r *Role) HasPermission(key string) bool {
// if !r.IsActive() {
// return false
// }
//
// permission, exists := r.Permissions[key]
// return exists && permission == 1 // 1 表示有權限
// }
//
// AddPermission 添加權限
func (r *Role) AddPermission(key string) {
if r.Permissions == nil {
r.Permissions = make(Permissions)
}
r.Permissions[key] = 1
}
// RemovePermission 移除權限
func (r *Role) RemovePermission(key string) {
if r.Permissions != nil {
delete(r.Permissions, key)
}
}

View File

@ -1,46 +1,65 @@
package entity
import (
"go.mongodb.org/mongo-driver/v2/bson"
"time"
"github.com/golang-jwt/jwt/v4"
)
// Token 令牌實體
// Token represents a token entity stored in Redis
type Token struct {
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
UID string `bson:"uid" json:"uid"`
ClientID string `bson:"client_id" json:"client_id"`
AccessToken string `bson:"access_token" json:"access_token"`
RefreshToken string `bson:"refresh_token" json:"refresh_token"`
DeviceID string `bson:"device_id" json:"device_id"`
ExpiresAt time.Time `bson:"expires_at" json:"expires_at"`
CreateTime time.Time `bson:"create_time" json:"create_time"`
UpdateTime time.Time `bson:"update_time" json:"update_time"`
ID string `json:"id"` // Token ID (KSUID)
UID string `json:"uid"` // User ID
DeviceID string `json:"device_id"` // Device ID
AccessToken string `json:"access_token"` // JWT access token
RefreshToken string `json:"refresh_token"` // SHA256 refresh token
ExpiresIn int `json:"expires_in"` // Access token expiry (Unix timestamp)
RefreshExpiresIn int `json:"refresh_expires_in"` // Refresh token expiry (Unix timestamp)
AccessCreateAt time.Time `json:"access_create_at"` // Access token creation time
RefreshCreateAt time.Time `json:"refresh_create_at"` // Refresh token creation time
}
// CollectionName 返回集合名稱
func (t *Token) CollectionName() string {
return "tokens"
// IsExpired checks if the access token is expired
func (t *Token) IsExpired() bool {
return time.Now().Unix() > int64(t.ExpiresIn)
}
//// IsExpired 檢查令牌是否過期
//func (t *Token) IsExpired() bool {
// return time.Now().After(t.ExpiresAt)
//}
//
//// Validate 驗證令牌數據
//func (t *Token) Validate() error {
// if t.UID == "" {
// return mongo.WriteError{Code: 400, Message: "uid is required"}
// }
// if t.ClientID == "" {
// return mongo.WriteError{Code: 400, Message: "client_id is required"}
// }
// if t.AccessToken == "" {
// return mongo.WriteError{Code: 400, Message: "access_token is required"}
// }
// if t.RefreshToken == "" {
// return mongo.WriteError{Code: 400, Message: "refresh_token is required"}
// }
// return nil
//}
// IsRefreshExpired checks if the refresh token is expired
func (t *Token) IsRefreshExpired() bool {
return time.Now().Unix() > int64(t.RefreshExpiresIn)
}
// RedisRefreshExpiredSec returns the refresh token expiry duration in seconds
func (t *Token) RedisRefreshExpiredSec() int {
now := time.Now().Unix()
if int64(t.RefreshExpiresIn) <= now {
return 0
}
return t.RefreshExpiresIn - int(now)
}
// Ticket represents a one-time token ticket
type Ticket struct {
Data map[string]string `json:"data"` // Token claims data
Token Token `json:"token"` // Associated token
}
// Claims represents JWT claims structure
type Claims struct {
jwt.RegisteredClaims
Data interface{} `json:"data"`
}
// Validate validates the token entity
func (t *Token) Validate() error {
if t.ID == "" {
return ErrInvalidTokenID
}
if t.UID == "" {
return ErrInvalidUID
}
if t.AccessToken == "" {
return ErrInvalidAccessToken
}
return nil
}

View File

@ -0,0 +1,318 @@
package entity
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestToken_IsExpired(t *testing.T) {
tests := []struct {
name string
token *Token
expected bool
}{
{
name: "expired token",
token: &Token{
ID: "test-id",
UID: "test-uid",
ExpiresIn: int(time.Now().Add(-time.Hour).Unix()),
},
expected: true,
},
{
name: "valid token",
token: &Token{
ID: "test-id",
UID: "test-uid",
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
},
expected: false,
},
{
name: "token expiring now",
token: &Token{
ID: "test-id",
UID: "test-uid",
ExpiresIn: int(time.Now().Unix()) - 1,
},
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.token.IsExpired()
assert.Equal(t, tt.expected, result)
})
}
}
func TestToken_IsRefreshExpired(t *testing.T) {
tests := []struct {
name string
token *Token
expected bool
}{
{
name: "expired refresh token",
token: &Token{
ID: "test-id",
UID: "test-uid",
RefreshExpiresIn: int(time.Now().Add(-time.Hour).Unix()),
},
expected: true,
},
{
name: "valid refresh token",
token: &Token{
ID: "test-id",
UID: "test-uid",
RefreshExpiresIn: int(time.Now().Add(time.Hour).Unix()),
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.token.IsRefreshExpired()
assert.Equal(t, tt.expected, result)
})
}
}
func TestToken_RedisRefreshExpiredSec(t *testing.T) {
tests := []struct {
name string
token *Token
expected int
}{
{
name: "token with future expiry",
token: &Token{
ID: "test-id",
UID: "test-uid",
RefreshExpiresIn: int(time.Now().Add(time.Hour).Unix()),
},
expected: 3600, // Approximately 1 hour in seconds
},
{
name: "token already expired",
token: &Token{
ID: "test-id",
UID: "test-uid",
RefreshExpiresIn: int(time.Now().Add(-time.Hour).Unix()),
},
expected: 0,
},
{
name: "token expiring now",
token: &Token{
ID: "test-id",
UID: "test-uid",
RefreshExpiresIn: int(time.Now().Unix()),
},
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.token.RedisRefreshExpiredSec()
if tt.expected == 0 {
assert.Equal(t, 0, result)
} else {
// Allow some margin for test execution time
assert.InDelta(t, tt.expected, result, 5)
}
})
}
}
func TestToken_Validate(t *testing.T) {
tests := []struct {
name string
token *Token
wantErr bool
expectedErr error
}{
{
name: "valid token",
token: &Token{
ID: "test-id",
UID: "test-uid",
AccessToken: "test-access-token",
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
},
wantErr: false,
},
{
name: "missing ID",
token: &Token{
ID: "",
UID: "test-uid",
AccessToken: "test-access-token",
},
wantErr: true,
expectedErr: ErrInvalidTokenID,
},
{
name: "missing UID",
token: &Token{
ID: "test-id",
UID: "",
AccessToken: "test-access-token",
},
wantErr: true,
expectedErr: ErrInvalidUID,
},
{
name: "missing AccessToken",
token: &Token{
ID: "test-id",
UID: "test-uid",
AccessToken: "",
},
wantErr: true,
expectedErr: ErrInvalidAccessToken,
},
{
name: "all fields missing",
token: &Token{
ID: "",
UID: "",
AccessToken: "",
},
wantErr: true,
expectedErr: ErrInvalidTokenID,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.token.Validate()
if tt.wantErr {
assert.Error(t, err)
if tt.expectedErr != nil {
assert.Equal(t, tt.expectedErr, err)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestTicket(t *testing.T) {
t.Run("ticket with data", func(t *testing.T) {
ticket := Ticket{
Data: map[string]string{
"uid": "user123",
"role": "admin",
},
Token: Token{
ID: "token123",
UID: "user123",
AccessToken: "access-token",
},
}
assert.NotNil(t, ticket.Data)
assert.Equal(t, "user123", ticket.Data["uid"])
assert.Equal(t, "admin", ticket.Data["role"])
assert.Equal(t, "token123", ticket.Token.ID)
})
t.Run("empty ticket", func(t *testing.T) {
ticket := Ticket{}
assert.Nil(t, ticket.Data)
assert.Empty(t, ticket.Token.ID)
})
}
func TestToken_DeviceID(t *testing.T) {
tests := []struct {
name string
deviceID string
}{
{
name: "with device ID",
deviceID: "device123",
},
{
name: "empty device ID",
deviceID: "",
},
{
name: "UUID device ID",
deviceID: "550e8400-e29b-41d4-a716-446655440000",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := &Token{
ID: "test-id",
UID: "test-uid",
DeviceID: tt.deviceID,
AccessToken: "test-token",
}
assert.Equal(t, tt.deviceID, token.DeviceID)
})
}
}
func TestToken_RefreshToken(t *testing.T) {
tests := []struct {
name string
refreshToken string
}{
{
name: "with refresh token",
refreshToken: "refresh-token-123",
},
{
name: "empty refresh token",
refreshToken: "",
},
{
name: "long refresh token",
refreshToken: "very-long-refresh-token-with-hash-abcdef1234567890",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := &Token{
ID: "test-id",
UID: "test-uid",
AccessToken: "access-token",
RefreshToken: tt.refreshToken,
}
assert.Equal(t, tt.refreshToken, token.RefreshToken)
})
}
}
func TestToken_Timestamps(t *testing.T) {
now := time.Now()
token := &Token{
ID: "test-id",
UID: "test-uid",
AccessToken: "access-token",
AccessCreateAt: now,
RefreshCreateAt: now.Add(time.Second),
}
assert.Equal(t, now, token.AccessCreateAt)
assert.True(t, token.RefreshCreateAt.After(token.AccessCreateAt))
}

View File

@ -1,37 +0,0 @@
package entity
import (
"time"
"go.mongodb.org/mongo-driver/v2/bson"
)
// UserRole 用戶角色關聯實體
type UserRole struct {
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
Brand string `bson:"brand" json:"brand"`
UID string `bson:"uid" json:"uid"`
RoleUID string `bson:"role_uid" json:"role_uid"`
Status int `bson:"status" json:"status"`
CreateTime time.Time `bson:"create_time" json:"create_time"`
UpdateTime time.Time `bson:"update_time" json:"update_time"`
}
// CollectionName 返回集合名稱
func (ur *UserRole) CollectionName() string {
return "user_roles"
}
//// Validate 驗證用戶角色關聯數據
//func (ur *UserRole) Validate() error {
// if ur.Brand == "" {
// return mongo.WriteError{Code: 400, Message: "brand is required"}
// }
// if ur.UID == "" {
// return mongo.WriteError{Code: 400, Message: "uid is required"}
// }
// if ur.RoleUID == "" {
// return mongo.WriteError{Code: 400, Message: "role_uid is required"}
// }
// return nil
//}

View File

@ -1,15 +1,31 @@
package domain
import "backend/pkg/library/errs"
import "errors"
const (
FailedToGetByID errs.ErrorCode = iota + 1
FailedToGetByClientID
FailedToGetPermission
FailedToGetPermissionByKey
FailedToGetRoleByID
FailedToGetByUID
FailedToGetByClientAndName
FailedToGetByClientAndName
FailedToGetByClientAndName
)
var (
// Token validation errors
ErrInvalidTokenID = errors.New("invalid token ID")
ErrInvalidUID = errors.New("invalid UID")
ErrInvalidAccessToken = errors.New("invalid access token")
ErrTokenExpired = errors.New("token expired")
ErrTokenNotFound = errors.New("token not found")
// JWT specific errors
ErrInvalidJWTToken = errors.New("invalid JWT token")
ErrJWTSigningFailed = errors.New("JWT signing failed")
ErrJWTParsingFailed = errors.New("JWT parsing failed")
ErrInvalidSigningKey = errors.New("invalid signing key")
ErrInvalidJTI = errors.New("invalid JWT ID")
// Refresh token errors
ErrRefreshTokenExpired = errors.New("refresh token expired")
ErrInvalidRefreshToken = errors.New("invalid refresh token")
// One-time token errors
ErrOneTimeTokenExpired = errors.New("one-time token expired")
ErrInvalidOneTimeToken = errors.New("invalid one-time token")
// Blacklist errors
ErrTokenBlacklisted = errors.New("token is blacklisted")
ErrBlacklistNotFound = errors.New("blacklist entry not found")
)

View File

@ -0,0 +1,64 @@
package domain
import "time"
// DeviceToken 表示裝置與 Token 之間的關聯
type DeviceToken struct {
DeviceID string // 裝置的唯一標識符
TokenID string // Token 的唯一標識符
}
type UIDToken map[string]int64
// Ticket 表示一次性使用的 Token 結構,包含數據和 Token 資訊
type Ticket struct {
Data any `json:"data"` // 任意附加數據
Token Token `json:"token"` // 一次性使用的 Token 資訊
}
// Token 表示使用者的存取和刷新 Token 資訊
type Token struct {
ID string `json:"id"` // Token 的唯一標識符
UID string `json:"uid"` // 用戶的唯一標識符
DeviceID string `json:"device_id"` // 裝置的唯一標識符
AccessToken string `json:"access_token"` // 存取 Token
ExpiresIn int `json:"expires_in"` // 存取 Token 的有效時長(秒)
AccessCreateAt time.Time `json:"access_create_at"` // 存取 Token 的創建時間
RefreshToken string `json:"refresh_token"` // 刷新 Token
RefreshExpiresIn int `json:"refresh_expires_in"` // 刷新 Token 的有效時長(秒)
RefreshCreateAt time.Time `json:"refresh_create_at"` // 刷新 Token 的創建時間
}
// AccessTokenExpires 返回存取 Token 的有效期(以秒為單位)。
func (t *Token) AccessTokenExpires() time.Duration {
return time.Duration(t.ExpiresIn) * time.Second
}
// RefreshTokenExpires 返回刷新 Token 的有效期(以秒為單位)。
func (t *Token) RefreshTokenExpires() time.Duration {
return time.Duration(t.RefreshExpiresIn) * time.Second
}
// RefreshTokenExpiresUnix 返回刷新 Token 的到期時間UnixNano 時間戳)。
func (t *Token) RefreshTokenExpiresUnix() int64 {
return time.Now().Add(t.RefreshTokenExpires()).UnixNano()
}
// IsExpires 檢查存取 Token 是否已過期。如果存取 Token 的創建時間加上其有效期早於當前時間,則返回 true。
func (t *Token) IsExpires() bool {
return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now())
}
// RedisExpiredSec 返回存取 Token 在 Redis 中的剩餘有效時間(秒)。計算方法為:從到期時間的 Unix 時間戳減去當前時間。
func (t *Token) RedisExpiredSec() int64 {
sec := time.Unix(int64(t.ExpiresIn), 0).Sub(time.Now().UTC())
return int64(sec.Seconds())
}
// RedisRefreshExpiredSec 返回刷新 Token 在 Redis 中的剩餘有效時間(秒)。計算方法為:從刷新到期時間的 Unix 時間戳減去當前時間。
func (t *Token) RedisRefreshExpiredSec() int64 {
sec := time.Unix(int64(t.RefreshExpiresIn), 0).Sub(time.Now().UTC())
return int64(sec.Seconds())
}

View File

@ -0,0 +1,141 @@
package domain
import (
"testing"
"time"
)
func TestToken_AccessTokenExpires(t *testing.T) {
tests := []struct {
name string
expiresIn int
want time.Duration
}{
{"zero expiration", 0, 0},
{"1 second expiration", 1, time.Second},
{"60 seconds expiration", 60, time.Minute},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := Token{ExpiresIn: tt.expiresIn}
if got := token.AccessTokenExpires(); got != tt.want {
t.Errorf("AccessTokenExpires() = %v, want %v", got, tt.want)
}
})
}
}
func TestToken_RefreshTokenExpires(t *testing.T) {
tests := []struct {
name string
refreshExpires int
want time.Duration
}{
{"zero expiration", 0, 0},
{"1 second expiration", 1, time.Second},
{"90 seconds expiration", 90, 90 * time.Second},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := Token{RefreshExpiresIn: tt.refreshExpires}
if got := token.RefreshTokenExpires(); got != tt.want {
t.Errorf("RefreshTokenExpires() = %v, want %v", got, tt.want)
}
})
}
}
func TestToken_RefreshTokenExpiresUnix(t *testing.T) {
tests := []struct {
name string
refreshExpires int
}{
{"zero expiration", 0},
{"10 seconds expiration", 10},
{"60 seconds expiration", 60},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := Token{RefreshExpiresIn: tt.refreshExpires}
got := token.RefreshTokenExpiresUnix()
want := time.Now().Add(time.Duration(tt.refreshExpires) * time.Second).UnixNano()
// 確保計算的時間戳在合理範圍內
if got < want-1e9 || got > want+1e9 {
t.Errorf("RefreshTokenExpiresUnix() = %v, want %v (±1 second tolerance)", got, want)
}
})
}
}
func TestToken_IsExpires(t *testing.T) {
now := time.Now()
tests := []struct {
name string
accessCreateAt time.Time
expiresIn int
want bool
}{
{"not expired", now.Add(-5 * time.Minute), 600, false}, // 10-minute expiry, created 5 minutes ago
{"just expired", now.Add(-10 * time.Minute), 600, true}, // 10-minute expiry, created 10 minutes ago
{"already expired", now.Add(-15 * time.Minute), 600, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := Token{AccessCreateAt: tt.accessCreateAt, ExpiresIn: tt.expiresIn}
if got := token.IsExpires(); got != tt.want {
t.Errorf("IsExpires() = %v, want %v", got, tt.want)
}
})
}
}
func TestToken_RedisExpiredSec(t *testing.T) {
now := time.Now().Unix()
tests := []struct {
name string
expiresIn int
}{
{"zero expiration", 0},
{"future expiration", int(now + 3600)}, // Expires in 1 hour
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := Token{ExpiresIn: tt.expiresIn}
got := token.RedisExpiredSec()
want := time.Unix(int64(tt.expiresIn), 0).Sub(time.Now().UTC()).Seconds()
if float64(got) < want-1 || float64(got) > want+1 {
t.Errorf("RedisExpiredSec() = %v, want ~%v (±1 second tolerance)", got, int64(want))
}
})
}
}
func TestToken_RedisRefreshExpiredSec(t *testing.T) {
now := time.Now().Unix()
tests := []struct {
name string
refreshExpires int
}{
{"zero refresh expiration", 0},
{"future refresh expiration", int(now + 7200)}, // Expires in 2 hours
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
token := Token{RefreshExpiresIn: tt.refreshExpires}
got := token.RedisRefreshExpiredSec()
want := time.Unix(int64(tt.refreshExpires), 0).Sub(time.Now().UTC()).Seconds()
if float64(got) < want-1 || float64(got) > want+1 {
t.Errorf("RedisRefreshExpiredSec() = %v, want ~%v (±1 second tolerance)", got, int64(want))
}
})
}
}

View File

@ -1,46 +0,0 @@
package permission
import "time"
// Status 狀態常數
const (
StatusActive = 1
StatusInactive = 2
)
// Type 權限類型
type Type int8
const (
TypeBackend Type = iota + 1
TypeFrontend
)
// Status 權限狀態
type Status string
const (
StatusOpen Status = "open"
StatusClose Status = "close"
)
// Permissions 權限映射
type Permissions map[string]Status
// GrantType 授權類型
type GrantType string
const (
GrantTypePassword GrantType = "password"
GrantTypeClient GrantType = "client_credentials"
GrantTypeRefreshToken GrantType = "refresh_token"
)
// Default Values 預設值
const (
DefaultRole = "user"
AdminRole = "admin"
AdminRoleUID = "AM000000"
AdminUID = "B000000"
RefreshTokenTTL = 5 * time.Second
)

41
pkg/permission/domain/redis.go Normal file → Executable file
View File

@ -2,39 +2,42 @@ package domain
import "strings"
// RedisKey represents a Redis key type with helper methods for key construction.
const (
TicketKeyPrefix = "tic/"
)
const (
ClientDataKey = "permission:clients"
)
type RedisKey string
const (
ClientRedisKey RedisKey = "client"
PermissionRedisKey RedisKey = "permission"
RoleRedisKey RedisKey = "role"
UserRoleRedisKey RedisKey = "user_role"
AccessTokenRedisKey RedisKey = "access_token"
RefreshTokenRedisKey RedisKey = "refresh_token"
DeviceTokenRedisKey RedisKey = "device_token"
UIDTokenRedisKey RedisKey = "uid_token"
TicketRedisKey RedisKey = "ticket"
DeviceUIDRedisKey RedisKey = "device_uid"
)
// ToString converts the RedisKey to its full string representation with the member prefix.
func (key RedisKey) ToString() string {
return "member:" + string(key)
return "permission:" + string(key)
}
// With appends additional parts to the RedisKey, separated by colons.
func (key RedisKey) With(s ...string) RedisKey {
parts := append([]string{string(key)}, s...)
return RedisKey(strings.Join(parts, ":"))
}
func GeClientRedisKey(id string) string {
return ClientRedisKey.With(id).ToString()
func GetAccessTokenRedisKey(id string) string {
return AccessTokenRedisKey.With(id).ToString()
}
func GetPermissionRedisKey(id string) string {
return PermissionRedisKey.With(id).ToString()
func GetUIDTokenRedisKey(uid string) string {
return UIDTokenRedisKey.With(uid).ToString()
}
func GetRoleRedisKeyRedisKey(id string) string {
return RoleRedisKey.With(id).ToString()
}
func GetUserRoleRedisKey(id string) string {
return UserRoleRedisKey.With(id).ToString()
}
func GetTicketRedisKey(ticket string) string {
return TicketRedisKey.With(ticket).ToString()
}

View File

@ -0,0 +1,98 @@
package domain
import "testing"
func TestRedisKey_ToString(t *testing.T) {
tests := []struct {
name string
key RedisKey
want string
}{
{"AccessToken Key", AccessTokenRedisKey, "permission:access_token"},
{"UIDToken Key", UIDTokenRedisKey, "permission:uid_token"},
{"Ticket Key", TicketRedisKey, "permission:ticket"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.key.ToString(); got != tt.want {
t.Errorf("ToString() = %v, want %v", got, tt.want)
}
})
}
}
func TestRedisKey_With(t *testing.T) {
tests := []struct {
name string
key RedisKey
args []string
want string
}{
{"AccessToken with ID", AccessTokenRedisKey, []string{"12345"}, "access_token:12345"},
{"UIDToken with UID", UIDTokenRedisKey, []string{"67890"}, "uid_token:67890"},
{"Ticket with multiple parts", TicketRedisKey, []string{"session", "12345"}, "ticket:session:12345"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.key.With(tt.args...).ToString(); got != "permission:"+tt.want {
t.Errorf("With() = %v, want %v", got, "permission:"+tt.want)
}
})
}
}
func TestGetAccessTokenRedisKey(t *testing.T) {
tests := []struct {
name string
id string
want string
}{
{"AccessToken Key with ID", "12345", "permission:access_token:12345"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetAccessTokenRedisKey(tt.id); got != tt.want {
t.Errorf("GetAccessTokenRedisKey() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetUIDTokenRedisKey(t *testing.T) {
tests := []struct {
name string
uid string
want string
}{
{"UIDToken Key with UID", "67890", "permission:uid_token:67890"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetUIDTokenRedisKey(tt.uid); got != tt.want {
t.Errorf("GetUIDTokenRedisKey() = %v, want %v", got, tt.want)
}
})
}
}
func TestGetTicketRedisKey(t *testing.T) {
tests := []struct {
name string
ticket string
want string
}{
{"Ticket Key with Ticket", "session123", "permission:ticket:session123"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetTicketRedisKey(tt.ticket); got != tt.want {
t.Errorf("GetTicketRedisKey() = %v, want %v", got, tt.want)
}
})
}
}

View File

@ -1,25 +0,0 @@
package repository
import (
"backend/pkg/permission/domain/entity"
"context"
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
)
// ClientRepository 客戶端倉庫介面
type ClientRepository interface {
Create(ctx context.Context, client *entity.Client) error
GetByID(ctx context.Context, id string) (*entity.Client, error)
GetByClientID(ctx context.Context, clientID string) (*entity.Client, error)
Update(ctx context.Context, id string, client *entity.Client) error
Delete(ctx context.Context, id string) error
List(ctx context.Context, filter ClientFilter) ([]*entity.Client, error)
Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error)
}
// ClientFilter 客戶端查詢過濾器
type ClientFilter struct {
Status *int
Limit int
Skip int
}

View File

@ -1,28 +0,0 @@
package repository
import (
"backend/pkg/permission/domain/entity"
"context"
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
)
// PermissionRepository 權限倉庫介面
type PermissionRepository interface {
Create(ctx context.Context, permission *entity.Permission) error
GetByID(ctx context.Context, id string) (*entity.Permission, error)
GetByKey(ctx context.Context, httpMethod, httpPath string) (*entity.Permission, error)
Update(ctx context.Context, id string, permission *entity.Permission) error
Delete(ctx context.Context, id string) error
List(ctx context.Context, filter PermissionFilter) ([]*entity.Permission, error)
GetActivePermissions(ctx context.Context) ([]*entity.Permission, error)
Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error)
}
// PermissionFilter 權限查詢過濾器
type PermissionFilter struct {
Status *int
Type *entity.PermissionType
ParentID *string
Limit int
Skip int
}

View File

@ -1,28 +0,0 @@
package repository
import (
"backend/pkg/permission/domain/entity"
"context"
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
)
// RoleRepository 角色倉庫介面
type RoleRepository interface {
Create(ctx context.Context, role *entity.Role) error
GetByID(ctx context.Context, id string) (*entity.Role, error)
GetByUID(ctx context.Context, uid string) (*entity.Role, error)
GetByClientAndName(ctx context.Context, clientID, name string) (*entity.Role, error)
Update(ctx context.Context, id string, role *entity.Role) error
Delete(ctx context.Context, id string) error
List(ctx context.Context, filter RoleFilter) ([]*entity.Role, error)
GetRolesByClientID(ctx context.Context, clientID string) ([]*entity.Role, error)
Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error)
}
// RoleFilter 角色查詢過濾器
type RoleFilter struct {
ClientID string
Status *int
Limit int
Skip int
}

View File

@ -1,16 +1,49 @@
package repository
import (
"backend/pkg/permission/domain/entity"
"context"
"time"
"backend/pkg/permission/domain/entity"
)
// TokenRepository 令牌倉庫介面
// TokenRepository 定義了與 Redis 相關的 Token 操作方法
//nolint:interfacebloat
type TokenRepository interface {
Create(ctx context.Context, token *entity.Token) error
GetByAccessToken(ctx context.Context, accessToken string) (*entity.Token, error)
GetByRefreshToken(ctx context.Context, refreshToken string) (*entity.Token, error)
Update(ctx context.Context, token *entity.Token) error
Delete(ctx context.Context, id string) error
DeleteByUserID(ctx context.Context, uid string) error
}
// Create 建立新的 Token 並存儲至 Redis
Create(ctx context.Context, token entity.Token) error
// CreateOneTimeToken 建立臨時一次性Token並指定有效期限
CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error
// GetAccessTokenByOneTimeToken 根據一次性 Token 獲取對應的存取 Token
GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error)
// GetAccessTokenByID 根據 Token ID 獲取對應的存取 Token
GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error)
// GetAccessTokensByUID 根據用戶 ID 獲取該用戶的所有存取 Token
GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error)
// GetAccessTokenCountByUID 根據用戶 ID 獲取該用戶的存取 Token 數量
GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error)
// GetAccessTokensByDeviceID 根據裝置 ID 獲取該裝置的所有存取 Token
GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error)
// GetAccessTokenCountByDeviceID 根據裝置 ID 獲取該裝置的存取 Token 數量
GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error)
// Delete 刪除指定的 Token
Delete(ctx context.Context, token entity.Token) error
// DeleteOneTimeToken 批量刪除一次性 Token
DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error
// DeleteAccessTokenByID 根據 Token ID 批量刪除存取 Token
DeleteAccessTokenByID(ctx context.Context, ids []string) error
// DeleteAccessTokensByUID 根據用戶 ID 刪除該用戶的所有存取 Token
DeleteAccessTokensByUID(ctx context.Context, uid string) error
// DeleteAccessTokensByDeviceID 根據裝置 ID 刪除該裝置的所有存取 Token
DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error
// Blacklist operations
// AddToBlacklist 將 JWT token 加入黑名單
AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error
// IsBlacklisted 檢查 JWT token 是否在黑名單中
IsBlacklisted(ctx context.Context, jti string) (bool, error)
// RemoveFromBlacklist 從黑名單中移除 JWT token
RemoveFromBlacklist(ctx context.Context, jti string) error
// GetBlacklistedTokensByUID 獲取用戶的所有黑名單 token
GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error)
}

View File

@ -1,28 +0,0 @@
package repository
import (
"backend/pkg/permission/domain/entity"
"context"
)
// UserRoleRepository 用戶角色倉庫介面
type UserRoleRepository interface {
Create(ctx context.Context, userRole *entity.UserRole) error
GetByID(ctx context.Context, id string) (*entity.UserRole, error)
GetByUserAndRole(ctx context.Context, uid, roleUID string) (*entity.UserRole, error)
Update(ctx context.Context, id string, userRole *entity.UserRole) error
Delete(ctx context.Context, id string) error
List(ctx context.Context, filter UserRoleFilter) ([]*entity.UserRole, error)
GetUserRolesByUID(ctx context.Context, uid string) ([]*entity.UserRole, error)
DeleteByUserAndRole(ctx context.Context, uid, roleUID string) error
}
// UserRoleFilter 用戶角色查詢過濾器
type UserRoleFilter struct {
Brand string
UID string
RoleUID string
Status *int
Limit int
Skip int
}

View File

@ -0,0 +1,26 @@
package token
// GrantType represents OAuth 2.0 grant types
type GrantType string
// ToString returns the string representation of GrantType
func (g GrantType) ToString() string {
return string(g)
}
// IsValid returns true if the grant type is valid
func (g GrantType) IsValid() bool {
switch g {
case PasswordCredentials, ClientCredentials, Refreshing:
return true
default:
return false
}
}
const (
PasswordCredentials GrantType = "password"
ClientCredentials GrantType = "client_credentials"
Refreshing GrantType = "refresh_token"
)

View File

@ -0,0 +1,160 @@
package token
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestGrantType_ToString(t *testing.T) {
tests := []struct {
name string
grantType GrantType
expected string
}{
{
name: "password credentials",
grantType: PasswordCredentials,
expected: "password",
},
{
name: "client credentials",
grantType: ClientCredentials,
expected: "client_credentials",
},
{
name: "refreshing",
grantType: Refreshing,
expected: "refresh_token",
},
{
name: "custom grant type",
grantType: GrantType("custom"),
expected: "custom",
},
{
name: "empty grant type",
grantType: GrantType(""),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.grantType.ToString()
assert.Equal(t, tt.expected, result)
})
}
}
func TestGrantType_IsValid(t *testing.T) {
tests := []struct {
name string
grantType GrantType
expected bool
}{
{
name: "password credentials is valid",
grantType: PasswordCredentials,
expected: true,
},
{
name: "client credentials is valid",
grantType: ClientCredentials,
expected: true,
},
{
name: "refreshing is valid",
grantType: Refreshing,
expected: true,
},
{
name: "invalid grant type",
grantType: GrantType("invalid"),
expected: false,
},
{
name: "empty grant type",
grantType: GrantType(""),
expected: false,
},
{
name: "authorization code (not implemented)",
grantType: GrantType("authorization_code"),
expected: false,
},
{
name: "implicit (not implemented)",
grantType: GrantType("implicit"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.grantType.IsValid()
assert.Equal(t, tt.expected, result)
})
}
}
func TestGrantType_Constants(t *testing.T) {
t.Run("verify constant values", func(t *testing.T) {
assert.Equal(t, "password", PasswordCredentials.ToString())
assert.Equal(t, "client_credentials", ClientCredentials.ToString())
assert.Equal(t, "refresh_token", Refreshing.ToString())
})
t.Run("verify all constants are valid", func(t *testing.T) {
assert.True(t, PasswordCredentials.IsValid())
assert.True(t, ClientCredentials.IsValid())
assert.True(t, Refreshing.IsValid())
})
}
func TestGrantType_StringComparison(t *testing.T) {
tests := []struct {
name string
gt1 GrantType
gt2 GrantType
expected bool
}{
{
name: "same grant type",
gt1: PasswordCredentials,
gt2: PasswordCredentials,
expected: true,
},
{
name: "different grant types",
gt1: PasswordCredentials,
gt2: ClientCredentials,
expected: false,
},
{
name: "string comparison",
gt1: GrantType("password"),
gt2: PasswordCredentials,
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.gt1 == tt.gt2
assert.Equal(t, tt.expected, result)
})
}
}
func TestGrantType_CaseSensitive(t *testing.T) {
t.Run("case sensitive comparison", func(t *testing.T) {
gt1 := GrantType("password")
gt2 := GrantType("PASSWORD")
assert.NotEqual(t, gt1, gt2)
assert.True(t, gt1.IsValid())
assert.False(t, gt2.IsValid())
})
}

View File

@ -0,0 +1,74 @@
package token
// Type represents the type of token
type Type string
const (
TypeBearer Type = "Bearer"
TypeBasic Type = "Basic"
)
// String returns the string representation of TokenType
func (t Type) String() string {
return string(t)
}
// IsValid returns true if the token type is valid
func (t Type) IsValid() bool {
return t == TypeBearer || t == TypeBasic
}
// Redis key prefixes and patterns
const (
AccessTokenKeyPrefix = "access_token:"
RefreshTokenKeyPrefix = "refresh_token:"
OneTimeTokenKeyPrefix = "one_time_token:"
UIDTokenKeyPrefix = "uid_tokens:"
DeviceTokenKeyPrefix = "device_tokens:"
TicketKeyPrefix = "ticket:"
BlacklistKeyPrefix = "blacklist:"
)
// Redis key helper functions
func GetAccessTokenRedisKey(tokenID string) string {
return AccessTokenKeyPrefix + tokenID
}
func RefreshTokenRedisKey(tokenID string) string {
return RefreshTokenKeyPrefix + tokenID
}
func UIDTokenRedisKey(uid string) string {
return UIDTokenKeyPrefix + uid
}
func DeviceTokenRedisKey(deviceID string) string {
return DeviceTokenKeyPrefix + deviceID
}
func GetBlacklistRedisKey(jti string) string {
return BlacklistKeyPrefix + jti
}
func GetUIDTokenRedisKey(uid string) string {
return UIDTokenKeyPrefix + uid
}
// Default expiration times (in seconds)
const (
DefaultAccessTokenExpiry = 15 * 60 // 15 minutes
DefaultRefreshTokenExpiry = 7 * 24 * 3600 // 7 days
DefaultOneTimeTokenExpiry = 5 * 60 // 5 minutes
)
// Token limits
const (
MaxTokensPerUser = 10 // Maximum tokens per user
MaxTokensPerDevice = 5 // Maximum tokens per device
)

View File

@ -0,0 +1,340 @@
package token
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestType_String(t *testing.T) {
tests := []struct {
name string
tokenType Type
expected string
}{
{
name: "bearer type",
tokenType: TypeBearer,
expected: "Bearer",
},
{
name: "basic type",
tokenType: TypeBasic,
expected: "Basic",
},
{
name: "custom type",
tokenType: Type("Custom"),
expected: "Custom",
},
{
name: "empty type",
tokenType: Type(""),
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.tokenType.String()
assert.Equal(t, tt.expected, result)
})
}
}
func TestType_IsValid(t *testing.T) {
tests := []struct {
name string
tokenType Type
expected bool
}{
{
name: "bearer is valid",
tokenType: TypeBearer,
expected: true,
},
{
name: "basic is valid",
tokenType: TypeBasic,
expected: true,
},
{
name: "invalid type",
tokenType: Type("Invalid"),
expected: false,
},
{
name: "empty type",
tokenType: Type(""),
expected: false,
},
{
name: "lowercase bearer",
tokenType: Type("bearer"),
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.tokenType.IsValid()
assert.Equal(t, tt.expected, result)
})
}
}
func TestType_Constants(t *testing.T) {
t.Run("verify constant values", func(t *testing.T) {
assert.Equal(t, "Bearer", TypeBearer.String())
assert.Equal(t, "Basic", TypeBasic.String())
})
t.Run("verify constants are valid", func(t *testing.T) {
assert.True(t, TypeBearer.IsValid())
assert.True(t, TypeBasic.IsValid())
})
}
func TestRedisKeyPrefixes(t *testing.T) {
t.Run("verify key prefix constants", func(t *testing.T) {
assert.Equal(t, "access_token:", AccessTokenKeyPrefix)
assert.Equal(t, "refresh_token:", RefreshTokenKeyPrefix)
assert.Equal(t, "one_time_token:", OneTimeTokenKeyPrefix)
assert.Equal(t, "uid_tokens:", UIDTokenKeyPrefix)
assert.Equal(t, "device_tokens:", DeviceTokenKeyPrefix)
assert.Equal(t, "ticket:", TicketKeyPrefix)
assert.Equal(t, "blacklist:", BlacklistKeyPrefix)
})
}
func TestGetAccessTokenRedisKey(t *testing.T) {
tests := []struct {
name string
tokenID string
expected string
}{
{
name: "normal token ID",
tokenID: "token123",
expected: "access_token:token123",
},
{
name: "UUID token ID",
tokenID: "550e8400-e29b-41d4-a716-446655440000",
expected: "access_token:550e8400-e29b-41d4-a716-446655440000",
},
{
name: "empty token ID",
tokenID: "",
expected: "access_token:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetAccessTokenRedisKey(tt.tokenID)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRefreshTokenRedisKey(t *testing.T) {
tests := []struct {
name string
tokenID string
expected string
}{
{
name: "normal token ID",
tokenID: "refresh123",
expected: "refresh_token:refresh123",
},
{
name: "hash token ID",
tokenID: "a1b2c3d4e5f6",
expected: "refresh_token:a1b2c3d4e5f6",
},
{
name: "empty token ID",
tokenID: "",
expected: "refresh_token:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := RefreshTokenRedisKey(tt.tokenID)
assert.Equal(t, tt.expected, result)
})
}
}
func TestUIDTokenRedisKey(t *testing.T) {
tests := []struct {
name string
uid string
expected string
}{
{
name: "normal UID",
uid: "user123",
expected: "uid_tokens:user123",
},
{
name: "UUID UID",
uid: "550e8400-e29b-41d4-a716-446655440000",
expected: "uid_tokens:550e8400-e29b-41d4-a716-446655440000",
},
{
name: "empty UID",
uid: "",
expected: "uid_tokens:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := UIDTokenRedisKey(tt.uid)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDeviceTokenRedisKey(t *testing.T) {
tests := []struct {
name string
deviceID string
expected string
}{
{
name: "normal device ID",
deviceID: "device123",
expected: "device_tokens:device123",
},
{
name: "UUID device ID",
deviceID: "550e8400-e29b-41d4-a716-446655440000",
expected: "device_tokens:550e8400-e29b-41d4-a716-446655440000",
},
{
name: "empty device ID",
deviceID: "",
expected: "device_tokens:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := DeviceTokenRedisKey(tt.deviceID)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetBlacklistRedisKey(t *testing.T) {
tests := []struct {
name string
jti string
expected string
}{
{
name: "normal JTI",
jti: "jti123",
expected: "blacklist:jti123",
},
{
name: "UUID JTI",
jti: "550e8400-e29b-41d4-a716-446655440000",
expected: "blacklist:550e8400-e29b-41d4-a716-446655440000",
},
{
name: "empty JTI",
jti: "",
expected: "blacklist:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetBlacklistRedisKey(tt.jti)
assert.Equal(t, tt.expected, result)
})
}
}
func TestGetUIDTokenRedisKey(t *testing.T) {
tests := []struct {
name string
uid string
expected string
}{
{
name: "normal UID",
uid: "user456",
expected: "uid_tokens:user456",
},
{
name: "empty UID",
uid: "",
expected: "uid_tokens:",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetUIDTokenRedisKey(tt.uid)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDefaultExpirationTimes(t *testing.T) {
t.Run("verify default expiration constants", func(t *testing.T) {
assert.Equal(t, int64(15*60), int64(DefaultAccessTokenExpiry))
assert.Equal(t, int64(7*24*3600), int64(DefaultRefreshTokenExpiry))
assert.Equal(t, int64(5*60), int64(DefaultOneTimeTokenExpiry))
})
t.Run("verify expiration times are reasonable", func(t *testing.T) {
assert.Greater(t, int64(DefaultAccessTokenExpiry), int64(0))
assert.Greater(t, int64(DefaultRefreshTokenExpiry), int64(DefaultAccessTokenExpiry))
assert.Greater(t, int64(DefaultOneTimeTokenExpiry), int64(0))
assert.Less(t, int64(DefaultOneTimeTokenExpiry), int64(DefaultAccessTokenExpiry))
})
}
func TestTokenLimits(t *testing.T) {
t.Run("verify token limit constants", func(t *testing.T) {
assert.Equal(t, 10, MaxTokensPerUser)
assert.Equal(t, 5, MaxTokensPerDevice)
})
t.Run("verify limits are reasonable", func(t *testing.T) {
assert.Greater(t, MaxTokensPerUser, 0)
assert.Greater(t, MaxTokensPerDevice, 0)
assert.GreaterOrEqual(t, MaxTokensPerUser, MaxTokensPerDevice)
})
}
func TestKeyPrefixUniqueness(t *testing.T) {
t.Run("all key prefixes should be unique", func(t *testing.T) {
prefixes := []string{
AccessTokenKeyPrefix,
RefreshTokenKeyPrefix,
OneTimeTokenKeyPrefix,
UIDTokenKeyPrefix,
DeviceTokenKeyPrefix,
TicketKeyPrefix,
BlacklistKeyPrefix,
}
seen := make(map[string]bool)
for _, prefix := range prefixes {
assert.False(t, seen[prefix], "duplicate prefix found: %s", prefix)
seen[prefix] = true
}
assert.Equal(t, len(prefixes), len(seen))
})
}

View File

@ -1,38 +0,0 @@
package usecase
import (
"context"
)
// AuthUseCase 認證用例介面
type AuthUseCase interface {
CreateToken(ctx context.Context, req CreateTokenRequest) (*TokenResponse, error)
RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error)
ValidateToken(ctx context.Context, accessToken string) (*TokenClaims, error)
Logout(ctx context.Context, accessToken string) error
LogoutAllByUserID(ctx context.Context, uid string) error
}
// CreateTokenRequest 創建令牌請求
type CreateTokenRequest struct {
ClientID string `json:"client_id"`
GrantType string `json:"grant_type"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
DeviceID string `json:"device_id,omitempty"`
}
// TokenResponse 令牌響應
type TokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
TokenType string `json:"token_type"`
ExpiresIn int64 `json:"expires_in"`
}
// TokenClaims 令牌聲明
type TokenClaims struct {
UID string `json:"uid"`
ClientID string `json:"client_id"`
DeviceID string `json:"device_id"`
}

View File

@ -1,90 +0,0 @@
package usecase
import (
"backend/pkg/permission/domain/entity"
"context"
)
// PermissionUseCase 權限用例介面 (使用 Casbin)
type PermissionUseCase interface {
// 基本權限管理
CreatePermission(ctx context.Context, req CreatePermissionRequest) (*entity.Permission, error)
GetPermission(ctx context.Context, id string) (*entity.Permission, error)
UpdatePermission(ctx context.Context, req UpdatePermissionRequest) (*entity.Permission, error)
DeletePermission(ctx context.Context, id string) error
ListPermissions(ctx context.Context, req ListPermissionsRequest) ([]*entity.Permission, error)
// Casbin 權限檢查
CheckUserPermission(ctx context.Context, uid, httpMethod, httpPath string) (bool, error)
CheckRolePermission(ctx context.Context, roleUID, httpMethod, httpPath string) (bool, error)
CheckPatternPermission(ctx context.Context, uid, pattern, action string) (bool, error)
BatchCheckPermissions(ctx context.Context, uid string, permissions []PermissionCheck) (map[string]bool, error)
// 用戶權限管理
GetUserPermissions(ctx context.Context, uid string) (map[string]int, error)
AddPolicyForUser(ctx context.Context, uid, httpPath, httpMethod string) error
RemovePolicyForUser(ctx context.Context, uid, httpPath, httpMethod string) error
// 角色管理
AddRoleForUser(ctx context.Context, uid, roleUID string) error
RemoveRoleForUser(ctx context.Context, uid, roleUID string) error
GetUsersForRole(ctx context.Context, roleUID string) ([]string, error)
GetRolesForUser(ctx context.Context, uid string) ([]string, error)
// 角色權限管理
AddPermissionForRole(ctx context.Context, roleUID, httpPath, httpMethod string) error
RemovePermissionForRole(ctx context.Context, roleUID, httpPath, httpMethod string) error
GetPermissionsForRole(ctx context.Context, roleUID string) (map[string]int, error)
// 策略管理
GetAllPolicies(ctx context.Context) ([][]string, error)
GetFilteredPolicies(ctx context.Context, fieldIndex int, fieldValues ...string) ([][]string, error)
}
// CreatePermissionRequest 創建權限請求
type CreatePermissionRequest struct {
ParentID *string `json:"parent_id,omitempty"`
Name string `json:"name"`
HTTPMethod string `json:"http_method,omitempty"`
HTTPPath string `json:"http_path,omitempty"`
Status int `json:"status"`
Type entity.PermissionType `json:"type"`
}
// UpdatePermissionRequest 更新權限請求
type UpdatePermissionRequest struct {
ID string `json:"id"`
Name *string `json:"name,omitempty"`
HTTPMethod *string `json:"http_method,omitempty"`
HTTPPath *string `json:"http_path,omitempty"`
Status *int `json:"status,omitempty"`
Type *entity.PermissionType `json:"type,omitempty"`
}
// ListPermissionsRequest 列出權限請求
type ListPermissionsRequest struct {
Status *int `json:"status,omitempty"`
Type *entity.PermissionType `json:"type,omitempty"`
ParentID *string `json:"parent_id,omitempty"`
Limit int `json:"limit"`
Skip int `json:"skip"`
}
// PermissionCheck 權限檢查項目
type PermissionCheck struct {
HTTPMethod string `json:"http_method"`
HTTPPath string `json:"http_path"`
}
// CasbinPolicyRequest Casbin 策略請求
type CasbinPolicyRequest struct {
Subject string `json:"subject"` // 用戶或角色
Object string `json:"object"` // 資源
Action string `json:"action"` // 行為
}
// CasbinRoleRequest Casbin 角色請求
type CasbinRoleRequest struct {
User string `json:"user"` // 用戶
Role string `json:"role"` // 角色
}

View File

@ -1,46 +0,0 @@
package usecase
import (
"backend/pkg/permission/domain/entity"
"context"
)
// RoleUseCase 角色用例介面
type RoleUseCase interface {
CreateRole(ctx context.Context, req CreateRoleRequest) (*entity.Role, error)
GetRole(ctx context.Context, id string) (*entity.Role, error)
GetRoleByUID(ctx context.Context, uid string) (*entity.Role, error)
UpdateRole(ctx context.Context, req UpdateRoleRequest) (*entity.Role, error)
DeleteRole(ctx context.Context, id string) error
ListRoles(ctx context.Context, req ListRolesRequest) ([]*entity.Role, error)
AddPermissionToRole(ctx context.Context, roleID string, permissionKey string) error
RemovePermissionFromRole(ctx context.Context, roleID string, permissionKey string) error
BatchUpdateRolePermissions(ctx context.Context, roleID string, permissions entity.Permissions) error
GetRolesByClientID(ctx context.Context, clientID string) ([]*entity.Role, error)
CopyRole(ctx context.Context, sourceRoleID string, req CreateRoleRequest) (*entity.Role, error)
}
// CreateRoleRequest 創建角色請求
type CreateRoleRequest struct {
ClientID string `json:"client_id"`
UID string `json:"uid"`
Name string `json:"name"`
Status int `json:"status"`
Permissions entity.Permissions `json:"permissions"`
}
// UpdateRoleRequest 更新角色請求
type UpdateRoleRequest struct {
ID string `json:"id"`
Name *string `json:"name,omitempty"`
Status *int `json:"status,omitempty"`
Permissions *entity.Permissions `json:"permissions,omitempty"`
}
// ListRolesRequest 列出角色請求
type ListRolesRequest struct {
ClientID string `json:"client_id,omitempty"`
Status *int `json:"status,omitempty"`
Limit int `json:"limit"`
Skip int `json:"skip"`
}

View File

@ -0,0 +1,44 @@
package usecase
import (
"context"
"backend/pkg/permission/domain/entity"
)
// TokenUseCase 定義與 Token 相關的操作接口
//
//nolint:interfacebloat
type TokenUseCase interface {
// NewToken 創建新 Token通常為 Access Token
NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error)
// RefreshToken 刷新目前的 Token包括一次性 Token
RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error)
// CancelToken 取消 Token包括取消其關聯的 One-Time Token
CancelToken(ctx context.Context, req entity.CancelTokenReq) error
// ValidationToken 驗證 Token 是否有效
ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error)
// CancelTokens 根據 UID 或 Token ID 取消所有相關 Token通常在用戶登出時使用
CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error
// CancelTokenByDeviceID 根據 Device ID 取消所有相關的 Token
CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error
// GetUserTokensByDeviceID 根據 Device ID 獲取所有 Token
GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error)
// GetUserTokensByUID 根據 UID 獲取所有 Token
GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error)
// NewOneTimeToken 創建一次性 Token例如 Refresh Token
NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error)
// CancelOneTimeToken 取消一次性 Token
CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error
// ReadTokenBasicData 檢查Token 帶的資料
ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error)
// Blacklist operations
// BlacklistToken 將 JWT token 加入黑名單 (立即撤銷)
BlacklistToken(ctx context.Context, token string, reason string) error
// IsTokenBlacklisted 檢查 JWT token 是否在黑名單中
IsTokenBlacklisted(ctx context.Context, jti string) (bool, error)
// BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出)
BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error
}

View File

@ -1,49 +0,0 @@
package usecase
import (
"backend/pkg/permission/domain/entity"
"context"
)
// UserRoleUseCase 用戶角色用例介面
type UserRoleUseCase interface {
AssignRole(ctx context.Context, req AssignRoleRequest) (*entity.UserRole, error)
RevokeRole(ctx context.Context, uid, roleUID string) error
GetUserRole(ctx context.Context, id string) (*entity.UserRole, error)
UpdateUserRole(ctx context.Context, req UpdateUserRoleRequest) (*entity.UserRole, error)
ListUserRoles(ctx context.Context, req ListUserRolesRequest) ([]*entity.UserRole, error)
GetUserRoles(ctx context.Context, uid string) ([]*entity.UserRole, error)
GetUserRoleDetails(ctx context.Context, uid string) ([]*UserRoleDetail, error)
BatchAssignRoles(ctx context.Context, uid string, roleUIDs []string, brand string) error
BatchRevokeRoles(ctx context.Context, uid string, roleUIDs []string) error
ReplaceUserRoles(ctx context.Context, uid string, roleUIDs []string, brand string) error
}
// AssignRoleRequest 分配角色請求
type AssignRoleRequest struct {
Brand string `json:"brand"`
UID string `json:"uid"`
RoleUID string `json:"role_uid"`
}
// UpdateUserRoleRequest 更新用戶角色請求
type UpdateUserRoleRequest struct {
ID string `json:"id"`
Status *int `json:"status,omitempty"`
}
// ListUserRolesRequest 列出用戶角色請求
type ListUserRolesRequest struct {
Brand string `json:"brand,omitempty"`
UID string `json:"uid,omitempty"`
RoleUID string `json:"role_uid,omitempty"`
Status *int `json:"status,omitempty"`
Limit int `json:"limit"`
Skip int `json:"skip"`
}
// UserRoleDetail 用戶角色詳情
type UserRoleDetail struct {
UserRole *entity.UserRole `json:"user_role"`
Role *entity.Role `json:"role"`
}

View File

@ -0,0 +1,130 @@
package repository
import (
"context"
"time"
"backend/pkg/permission/domain/entity"
"github.com/stretchr/testify/mock"
)
// MockTokenRepository is a mock implementation of TokenRepository
type MockTokenRepository struct {
mock.Mock
}
// NewMockTokenRepository creates a new mock instance
func NewMockTokenRepository(t interface {
mock.TestingT
Cleanup(func())
}) *MockTokenRepository {
mock := &MockTokenRepository{}
mock.Mock.Test(t)
t.Cleanup(func() { mock.AssertExpectations(t) })
return mock
}
// Create provides a mock function with given fields: ctx, token
func (m *MockTokenRepository) Create(ctx context.Context, token entity.Token) error {
ret := m.Called(ctx, token)
return ret.Error(0)
}
// CreateOneTimeToken provides a mock function with given fields: ctx, key, ticket, dt
func (m *MockTokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
ret := m.Called(ctx, key, ticket, dt)
return ret.Error(0)
}
// GetAccessTokenByOneTimeToken provides a mock function with given fields: ctx, oneTimeToken
func (m *MockTokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
ret := m.Called(ctx, oneTimeToken)
return ret.Get(0).(entity.Token), ret.Error(1)
}
// GetAccessTokenByID provides a mock function with given fields: ctx, id
func (m *MockTokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
ret := m.Called(ctx, id)
return ret.Get(0).(entity.Token), ret.Error(1)
}
// GetAccessTokensByUID provides a mock function with given fields: ctx, uid
func (m *MockTokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
ret := m.Called(ctx, uid)
return ret.Get(0).([]entity.Token), ret.Error(1)
}
// GetAccessTokenCountByUID provides a mock function with given fields: ctx, uid
func (m *MockTokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
ret := m.Called(ctx, uid)
return ret.Int(0), ret.Error(1)
}
// GetAccessTokensByDeviceID provides a mock function with given fields: ctx, deviceID
func (m *MockTokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
ret := m.Called(ctx, deviceID)
return ret.Get(0).([]entity.Token), ret.Error(1)
}
// GetAccessTokenCountByDeviceID provides a mock function with given fields: ctx, deviceID
func (m *MockTokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
ret := m.Called(ctx, deviceID)
return ret.Int(0), ret.Error(1)
}
// Delete provides a mock function with given fields: ctx, token
func (m *MockTokenRepository) Delete(ctx context.Context, token entity.Token) error {
ret := m.Called(ctx, token)
return ret.Error(0)
}
// DeleteOneTimeToken provides a mock function with given fields: ctx, ids, tokens
func (m *MockTokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
ret := m.Called(ctx, ids, tokens)
return ret.Error(0)
}
// DeleteAccessTokenByID provides a mock function with given fields: ctx, ids
func (m *MockTokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
ret := m.Called(ctx, ids)
return ret.Error(0)
}
// DeleteAccessTokensByUID provides a mock function with given fields: ctx, uid
func (m *MockTokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
ret := m.Called(ctx, uid)
return ret.Error(0)
}
// DeleteAccessTokensByDeviceID provides a mock function with given fields: ctx, deviceID
func (m *MockTokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
ret := m.Called(ctx, deviceID)
return ret.Error(0)
}
// AddToBlacklist provides a mock function with given fields: ctx, entry, ttl
func (m *MockTokenRepository) AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error {
ret := m.Called(ctx, entry, ttl)
return ret.Error(0)
}
// IsBlacklisted provides a mock function with given fields: ctx, jti
func (m *MockTokenRepository) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
ret := m.Called(ctx, jti)
return ret.Bool(0), ret.Error(1)
}
// RemoveFromBlacklist provides a mock function with given fields: ctx, jti
func (m *MockTokenRepository) RemoveFromBlacklist(ctx context.Context, jti string) error {
ret := m.Called(ctx, jti)
return ret.Error(0)
}
// GetBlacklistedTokensByUID provides a mock function with given fields: ctx, uid
func (m *MockTokenRepository) GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error) {
ret := m.Called(ctx, uid)
return ret.Get(0).([]*entity.BlacklistEntry), ret.Error(1)
}

View File

@ -1,265 +0,0 @@
package repository
import (
"context"
"backend/pkg/library/errs"
"backend/pkg/library/mongo"
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/bson"
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
)
// CasbinRule represents a casbin rule in MongoDB
type CasbinRule struct {
ID bson.ObjectID `bson:"_id,omitempty"`
PType string `bson:"ptype"`
V0 string `bson:"v0"`
V1 string `bson:"v1"`
V2 string `bson:"v2"`
V3 string `bson:"v3"`
V4 string `bson:"v4"`
V5 string `bson:"v5"`
}
// CasbinAdapterParam Casbin adapter 參數
type CasbinAdapterParam struct {
Conf *mongo.Conf
CacheConf cache.CacheConf
DBOpts []mon.Option
CacheOpts []cache.Option
}
// CasbinAdapter MongoDB adapter for Casbin
type CasbinAdapter struct {
DB mongo.DocumentDBWithCacheUseCase
}
// NewCasbinAdapter 創建 Casbin adapter
func NewCasbinAdapter(param CasbinAdapterParam) persist.Adapter {
db, err := mongo.MustDocumentDBWithCache(
"casbin_rules",
param.Conf,
param.CacheConf,
param.CacheOpts,
param.DBOpts,
)
return &CasbinAdapter{
DB: db,
}
}
// LoadPolicy loads all policy rules from the storage.
func (a *CasbinAdapter) LoadPolicy(model model.Model) error {
ctx := context.Background()
var rules []CasbinRule
err := a.DB.Find(ctx, bson.M{}, &rules)
if err != nil {
return errs.DatabaseErr(err.Error())
}
for _, rule := range rules {
a.loadPolicyLine(&rule, model)
}
return nil
}
// SavePolicy saves all policy rules to the storage.
func (a *CasbinAdapter) SavePolicy(model model.Model) error {
ctx := context.Background()
// 清空現有規則
err := a.DB.DeleteMany(ctx, bson.M{})
if err != nil {
return errs.DatabaseErr(err.Error())
}
var rules []interface{}
for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
rules = append(rules, a.savePolicyLine(ptype, rule))
}
}
for ptype, ast := range model["g"] {
for _, rule := range ast.Policy {
rules = append(rules, a.savePolicyLine(ptype, rule))
}
}
if len(rules) > 0 {
_, err = a.DB.InsertMany(ctx, rules)
if err != nil {
return errs.DatabaseErr(err.Error())
}
}
return nil
}
// AddPolicy adds a policy rule to the storage.
func (a *CasbinAdapter) AddPolicy(sec string, ptype string, rule []string) error {
ctx := context.Background()
casbinRule := a.savePolicyLine(ptype, rule)
_, err := a.DB.InsertOne(ctx, casbinRule)
if err != nil {
return errs.DatabaseErr(err.Error())
}
return nil
}
// RemovePolicy removes a policy rule from the storage.
func (a *CasbinAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
ctx := context.Background()
filter := bson.M{"ptype": ptype}
for i, value := range rule {
filter[getFieldName(i)] = value
}
err := a.DB.DeleteMany(ctx, filter)
if err != nil {
return errs.DatabaseErr(err.Error())
}
return nil
}
// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (a *CasbinAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
ctx := context.Background()
filter := bson.M{"ptype": ptype}
for i, value := range fieldValues {
if fieldIndex+i <= 5 && value != "" {
filter[getFieldName(fieldIndex+i)] = value
}
}
err := a.DB.DeleteMany(ctx, filter)
if err != nil {
return errs.DatabaseErr(err.Error())
}
return nil
}
// loadPolicyLine loads a line of policy from storage
func (a *CasbinAdapter) loadPolicyLine(rule *CasbinRule, model model.Model) {
lineText := rule.PType
if rule.V0 != "" {
lineText += ", " + rule.V0
}
if rule.V1 != "" {
lineText += ", " + rule.V1
}
if rule.V2 != "" {
lineText += ", " + rule.V2
}
if rule.V3 != "" {
lineText += ", " + rule.V3
}
if rule.V4 != "" {
lineText += ", " + rule.V4
}
if rule.V5 != "" {
lineText += ", " + rule.V5
}
persist.LoadPolicyLine(lineText, model)
}
// savePolicyLine saves a line of policy to storage
func (a *CasbinAdapter) savePolicyLine(ptype string, rule []string) *CasbinRule {
casbinRule := &CasbinRule{
PType: ptype,
}
if len(rule) > 0 {
casbinRule.V0 = rule[0]
}
if len(rule) > 1 {
casbinRule.V1 = rule[1]
}
if len(rule) > 2 {
casbinRule.V2 = rule[2]
}
if len(rule) > 3 {
casbinRule.V3 = rule[3]
}
if len(rule) > 4 {
casbinRule.V4 = rule[4]
}
if len(rule) > 5 {
casbinRule.V5 = rule[5]
}
return casbinRule
}
// getFieldName returns the field name for the given index
func getFieldName(index int) string {
switch index {
case 0:
return "v0"
case 1:
return "v1"
case 2:
return "v2"
case 3:
return "v3"
case 4:
return "v4"
case 5:
return "v5"
default:
return ""
}
}
// Index20241226001UP 創建索引
func (a *CasbinAdapter) Index20241226001UP(ctx context.Context) (bool, error) {
indexes := []mongodriver.IndexModel{
{
Keys: bson.D{
{Key: "ptype", Value: 1},
},
Options: &mongodriver.IndexOptions{
Name: &[]string{"idx_ptype"}[0],
},
},
{
Keys: bson.D{
{Key: "ptype", Value: 1},
{Key: "v0", Value: 1},
},
Options: &mongodriver.IndexOptions{
Name: &[]string{"idx_ptype_v0"}[0],
},
},
{
Keys: bson.D{
{Key: "ptype", Value: 1},
{Key: "v0", Value: 1},
{Key: "v1", Value: 1},
},
Options: &mongodriver.IndexOptions{
Name: &[]string{"idx_ptype_v0_v1"}[0],
},
},
}
// 需要轉換為 mongo.DocumentDBWithCacheUseCase 的 CreateIndexes 方法
// 這裡簡化處理,實際需要根據你的 mongo 包裝實現
return true, nil
}

View File

@ -1,196 +0,0 @@
package repository
import (
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain"
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"time"
"backend/pkg/library/errs"
"backend/pkg/library/mongo"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/bson"
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
)
type ClientRepositoryParam struct {
Conf *mongo.Conf
CacheConf cache.CacheConf
DBOpts []mon.Option
CacheOpts []cache.Option
}
type ClientRepository struct {
DB mongo.DocumentDBWithCacheUseCase
}
// NewClientRepository 創建客戶端倉庫實例
func NewClientRepository(param ClientRepositoryParam) repository.ClientRepository {
e := entity.Client{}
documentDB, err := mongo.MustDocumentDBWithCache(
param.Conf,
e.CollectionName(),
param.CacheConf,
param.DBOpts,
param.CacheOpts,
)
if err != nil {
panic(err)
}
return &ClientRepository{
DB: documentDB,
}
}
func (repo *ClientRepository) Create(ctx context.Context, client *entity.Client) error {
now := time.Now()
client.CreateTime = now
client.UpdateTime = now
id := bson.NewObjectID()
client.ID = id
rk := domain.GeClientRedisKey(id.Hex())
_, err := repo.DB.InsertOne(ctx, rk, client)
if err != nil {
// 檢查是否為重複鍵錯誤
if mongodriver.IsDuplicateKeyError(err) {
return errs.ResourceAlreadyExist(client.ClientID)
}
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *ClientRepository) GetByID(ctx context.Context, id string) (*entity.Client, error) {
var client entity.Client
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return nil, err
}
rk := domain.GeClientRedisKey(objID.Hex())
err = repo.DB.FindOne(ctx, rk, &client, bson.M{"_id": objID})
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(
code.CloudEPPermission,
domain.FailedToGetByID,
"failed to get client by id")
}
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return &client, nil
}
func (repo *ClientRepository) GetByClientID(ctx context.Context, clientID string) (*entity.Client, error) {
var client entity.Client
rk := domain.GeClientRedisKey(clientID)
err := repo.DB.FindOne(ctx, rk, &client, bson.M{"client_id": clientID})
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(
code.CloudEPPermission,
domain.FailedToGetByClientID,
"failed to get client by client id")
}
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return &client, nil
}
func (repo *ClientRepository) Update(ctx context.Context, id string, client *entity.Client) error {
client.UpdateTime = time.Now()
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return err
}
update := bson.M{
"$set": bson.M{
"name": client.Name,
"secret": client.Secret,
"status": client.Status,
"update_time": client.UpdateTime,
},
}
gc, err := repo.GetByID(ctx, id)
if err != nil {
return err
}
rk := domain.GeClientRedisKey(objID.Hex())
_, err = repo.DB.UpdateOne(ctx, rk, bson.M{"_id": objID}, update)
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
rk = domain.GeClientRedisKey(gc.ClientID)
err = repo.DB.DelCache(ctx, rk)
if err != nil {
return err
}
return nil
}
func (repo *ClientRepository) Delete(ctx context.Context, id string) error {
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return err
}
gc, err := repo.GetByID(ctx, id)
if err != nil {
return err
}
rk := domain.GeClientRedisKey(gc.ClientID)
err = repo.DB.DelCache(ctx, rk)
if err != nil {
return err
}
rk = domain.GeClientRedisKey(objID.Hex())
_, err = repo.DB.DeleteOne(ctx, rk, bson.M{"_id": objID})
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *ClientRepository) List(ctx context.Context, filter repository.ClientFilter) ([]*entity.Client, error) {
query := bson.M{}
if filter.Status != nil {
query["status"] = *filter.Status
}
var clients []*entity.Client
err := repo.DB.GetClient().Find(ctx, query, &clients,
options.Find().SetLimit(int64(filter.Limit)),
options.Find().SetSkip(int64(filter.Skip)),
)
if err != nil {
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return clients, nil
}
// Index20241226001UP 創建索引
func (repo *ClientRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
repo.DB.PopulateIndex(ctx, "client_id", 1, true)
repo.DB.PopulateIndex(ctx, "status", 1, false)
return repo.DB.GetClient().Indexes().List(ctx)
}

View File

@ -1,209 +0,0 @@
package repository
import (
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain"
"backend/pkg/permission/domain/permission"
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"time"
"backend/pkg/library/errs"
"backend/pkg/library/mongo"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/bson"
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
)
type PermissionRepositoryParam struct {
Conf *mongo.Conf
CacheConf cache.CacheConf
DBOpts []mon.Option
CacheOpts []cache.Option
}
type PermissionRepository struct {
DB mongo.DocumentDBWithCacheUseCase
}
// NewPermissionRepository 創建權限倉庫實例
func NewPermissionRepository(param PermissionRepositoryParam) repository.PermissionRepository {
e := entity.Permission{}
documentDB, err := mongo.MustDocumentDBWithCache(
param.Conf,
e.CollectionName(),
param.CacheConf,
param.DBOpts,
param.CacheOpts,
)
if err != nil {
panic(err)
}
return &PermissionRepository{
DB: documentDB,
}
}
func (repo *PermissionRepository) Create(ctx context.Context, permission *entity.Permission) error {
now := time.Now()
permission.CreateTime = now
permission.UpdateTime = now
id := bson.NewObjectID()
permission.ID = id
rk := domain.GetPermissionRedisKey(id.Hex())
_, err := repo.DB.InsertOne(ctx, rk, permission)
if err != nil {
// 檢查是否為重複鍵錯誤
if mongodriver.IsDuplicateKeyError(err) {
return errs.ResourceAlreadyExist(permission.ID.Hex())
}
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *PermissionRepository) GetByID(ctx context.Context, id string) (*entity.Permission, error) {
var p entity.Permission
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return nil, err
}
rk := domain.GetPermissionRedisKey(objID.Hex())
err = repo.DB.FindOne(ctx, rk, &p, bson.M{"_id": objID})
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(
code.CloudEPPermission,
domain.FailedToGetPermission,
"failed to get permission by id")
}
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return &p, nil
}
func (repo *PermissionRepository) GetByKey(ctx context.Context, httpMethod, httpPath string) (*entity.Permission, error) {
filter := bson.M{
"http_method": httpMethod,
"http_path": httpPath,
"status": permission.StatusActive,
}
var p entity.Permission
err := repo.DB.GetClient().FindOne(ctx, &p, filter)
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(
code.CloudEPPermission, domain.FailedToGetPermissionByKey,
"failed to get permission by key")
}
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return &p, nil
}
func (repo *PermissionRepository) Update(ctx context.Context, id string, permission *entity.Permission) error {
permission.UpdateTime = time.Now()
update := bson.M{
"$set": bson.M{
"parent_id": permission.ParentID,
"name": permission.Name,
"http_method": permission.HTTPMethod,
"http_path": permission.HTTPPath,
"status": permission.Status,
"type": permission.Type,
"update_time": permission.UpdateTime,
},
}
rk := domain.GetPermissionRedisKey(id)
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return err
}
_, err = repo.DB.UpdateOne(ctx, rk, bson.M{"_id": objID}, update)
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *PermissionRepository) Delete(ctx context.Context, id string) error {
rk := domain.GetPermissionRedisKey(id)
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return err
}
_, err = repo.DB.DeleteOne(ctx, rk, bson.M{"_id": objID})
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *PermissionRepository) List(ctx context.Context, filter repository.PermissionFilter) ([]*entity.Permission, error) {
query := bson.M{}
if filter.Status != nil {
query["status"] = *filter.Status
}
if filter.Type != nil {
query["type"] = *filter.Type
}
if filter.ParentID != nil {
query["parent_id"] = *filter.ParentID
}
var permissions []*entity.Permission
err := repo.DB.GetClient().Find(ctx,
&permissions, query,
options.Find().SetLimit(int64(filter.Limit)),
options.Find().SetSkip(int64(filter.Skip)))
if err != nil {
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return permissions, nil
}
func (repo *PermissionRepository) GetActivePermissions(ctx context.Context) ([]*entity.Permission, error) {
status := permission.StatusActive
filter := repository.PermissionFilter{
Status: &status,
}
return repo.List(ctx, filter)
}
// Index20241226001UP 創建索引
func (repo *PermissionRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
// 等價於 db.account.createIndex({ "login_id": 1, "platform": 1}, {unique: true})
repo.DB.PopulateMultiIndex(ctx, []string{
"http_method",
"http_path",
}, []int32{1, 1}, true)
// 等價於 db.account.createIndex({"create_at": 1})
repo.DB.PopulateIndex(ctx, "name", 1, false)
repo.DB.PopulateIndex(ctx, "status", 1, false)
repo.DB.PopulateIndex(ctx, "type", 1, false)
return repo.DB.GetClient().Indexes().List(ctx)
}

View File

@ -1,233 +0,0 @@
package repository
import (
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain"
"backend/pkg/permission/domain/permission"
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"time"
"backend/pkg/library/errs"
"backend/pkg/library/mongo"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/bson"
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
)
type RoleRepositoryParam struct {
Conf *mongo.Conf
CacheConf cache.CacheConf
DBOpts []mon.Option
CacheOpts []cache.Option
}
type RoleRepository struct {
DB mongo.DocumentDBWithCacheUseCase
}
// NewRoleRepository 創建角色倉庫實例
func NewRoleRepository(param RoleRepositoryParam) repository.RoleRepository {
e := entity.Role{}
documentDB, err := mongo.MustDocumentDBWithCache(
param.Conf,
e.CollectionName(),
param.CacheConf,
param.DBOpts,
param.CacheOpts,
)
if err != nil {
panic(err)
}
return &RoleRepository{
DB: documentDB,
}
}
func (repo *RoleRepository) Create(ctx context.Context, role *entity.Role) error {
now := time.Now()
role.CreateTime = now
role.UpdateTime = now
id := bson.NewObjectID()
role.ID = id
rk := domain.GetRoleRedisKeyRedisKey(id.Hex())
_, err := repo.DB.InsertOne(ctx, rk, role)
if err != nil {
// 檢查是否為重複鍵錯誤
if mongodriver.IsDuplicateKeyError(err) {
return errs.ResourceAlreadyExist(role.ClientID)
}
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *RoleRepository) GetByID(ctx context.Context, id string) (*entity.Role, error) {
var role entity.Role
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return nil, err
}
rk := domain.GetRoleRedisKeyRedisKey(id)
err = repo.DB.FindOne(ctx, rk, &role, bson.M{"client_id": objID})
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(
code.CloudEPPermission,
domain.FailedToGetRoleByID,
"failed to get role by id")
}
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return &role, nil
}
func (repo *RoleRepository) GetByUID(ctx context.Context, uid string) (*entity.Role, error) {
var role entity.Role
rk := domain.GetRoleRedisKeyRedisKey(uid)
err := repo.DB.FindOne(ctx, rk, &role, bson.M{"uid": uid, "status": permission.StatusActive})
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(
code.CloudEPPermission,
domain.FailedToGetByUID,
"failed to get role by uid")
}
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return &role, nil
}
func (repo *RoleRepository) GetByClientAndName(ctx context.Context, clientID, name string) (*entity.Role, error) {
filter := bson.M{
"client_id": clientID,
"name": name,
"status": permission.StatusActive,
}
var role entity.Role
err := repo.DB.GetClient().FindOne(ctx, &role, filter)
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(
code.CloudEPPermission, domain.FailedToGetByClientAndName, "failed to get by client and name")
}
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return &role, nil
}
func (repo *RoleRepository) Update(ctx context.Context, id string, role *entity.Role) error {
role.UpdateTime = time.Now()
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return err
}
update := bson.M{
"$set": bson.M{
"name": role.Name,
"status": role.Status,
"permissions": role.Permissions,
"update_time": role.UpdateTime,
},
}
rk := domain.GetRoleRedisKeyRedisKey(id)
_, err = repo.DB.UpdateOne(ctx, rk, bson.M{"_id": objID}, update)
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *RoleRepository) Delete(ctx context.Context, id string) error {
rk := domain.GetRoleRedisKeyRedisKey(id)
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return err
}
gc, err := repo.GetByID(ctx, id)
if err != nil {
return err
}
rk = domain.GetRoleRedisKeyRedisKey(gc.UID)
err = repo.DB.DelCache(ctx, rk)
if err != nil {
return err
}
_, err = repo.DB.DeleteOne(ctx, rk, bson.M{"_id": objID})
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *RoleRepository) List(ctx context.Context, filter repository.RoleFilter) ([]*entity.Role, error) {
query := bson.M{}
if filter.ClientID != "" {
query["client_id"] = filter.ClientID
}
if filter.Status != nil {
query["status"] = *filter.Status
}
var roles []*entity.Role
err := repo.DB.GetClient().Find(ctx, &roles, query,
options.Find().SetLimit(int64(filter.Limit)),
options.Find().SetSkip(int64(filter.Skip)),
)
if err != nil {
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return roles, nil
}
func (repo *RoleRepository) GetRolesByClientID(ctx context.Context, clientID string) ([]*entity.Role, error) {
status := permission.StatusActive
filter := repository.RoleFilter{
ClientID: clientID,
Status: &status,
}
return repo.List(ctx, filter)
}
// Index20241226001UP 創建索引
func (repo *RoleRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
// 等價於 db.account.createIndex({ "login_id": 1, "platform": 1}, {unique: true})
repo.DB.PopulateMultiIndex(ctx, []string{
"client_id",
"name",
}, []int32{1, 1}, true)
// 等價於 db.account.createIndex({"create_at": 1})
repo.DB.PopulateIndex(ctx, "uid", 1, true)
repo.DB.PopulateIndex(ctx, "status", 1, false)
return repo.DB.GetClient().Indexes().List(ctx)
}

View File

@ -1,145 +0,0 @@
package repository
import (
"backend/pkg/library/errs"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"context"
"github.com/zeromicro/go-zero/core/stores/redis"
"strings"
"time"
)
// Token Repository Implementation
type TokenRepositoryParam struct {
Redis *redis.Redis
}
type TokenRepository struct {
Redis *redis.Redis
}
// NewTokenRepository 創建令牌倉庫實例
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
return &TokenRepository{
Redis: param.Redis,
}
}
func (r *TokenRepository) Create(ctx context.Context, token *entity.Token) error {
// 驗證數據
if err := token.Validate(); err != nil {
return errs.InvalidFormat(err.Error())
}
token.CreateTime = time.Now()
token.UpdateTime = time.Now()
// 在 Redis 中存儲 access token
accessKey := "token:access:" + token.AccessToken
refreshKey := "token:refresh:" + token.RefreshToken
// 設置過期時間
expiry := int(time.Until(token.ExpiresAt).Seconds())
if expiry <= 0 {
return errs.InvalidFormat("token already expired")
}
// 存儲 access token
err := r.Redis.SetexCtx(ctx, accessKey, token.UID+":"+token.ClientID+":"+token.DeviceID, expiry)
if err != nil {
return errs.DatabaseErr(err.Error())
}
// 存儲 refresh token (較長的過期時間)
refreshExpiry := expiry * 7 // refresh token 過期時間是 access token 的 7 倍
err = r.Redis.SetexCtx(ctx, refreshKey, token.UID+":"+token.ClientID+":"+token.DeviceID, refreshExpiry)
if err != nil {
return errs.DatabaseErr(err.Error())
}
return nil
}
func (r *TokenRepository) GetByAccessToken(ctx context.Context, accessToken string) (*entity.Token, error) {
key := "token:access:" + accessToken
value, err := r.Redis.GetCtx(ctx, key)
if err != nil {
if err == redis.Nil {
return nil, errs.NotFound("access_token")
}
return nil, errs.DatabaseErr(err.Error())
}
// 解析值
parts := strings.Split(value, ":")
if len(parts) != 3 {
return nil, errs.InvalidFormat("invalid token format")
}
return &entity.Token{
UID: parts[0],
ClientID: parts[1],
DeviceID: parts[2],
AccessToken: accessToken,
}, nil
}
func (r *TokenRepository) GetByRefreshToken(ctx context.Context, refreshToken string) (*entity.Token, error) {
key := "token:refresh:" + refreshToken
value, err := r.Redis.GetCtx(ctx, key)
if err != nil {
if err == redis.Nil {
return nil, errs.NotFound("refresh_token")
}
return nil, errs.DatabaseErr(err.Error())
}
// 解析值
parts := strings.Split(value, ":")
if len(parts) != 3 {
return nil, errs.InvalidFormat("invalid token format")
}
return &entity.Token{
UID: parts[0],
ClientID: parts[1],
DeviceID: parts[2],
RefreshToken: refreshToken,
}, nil
}
func (r *TokenRepository) Update(ctx context.Context, token *entity.Token) error {
// 驗證數據
if err := token.Validate(); err != nil {
return errs.InvalidFormat(err.Error())
}
token.UpdateTime = time.Now()
// 重新存儲 access token
accessKey := "token:access:" + token.AccessToken
expiry := int(time.Until(token.ExpiresAt).Seconds())
if expiry <= 0 {
return errs.InvalidFormat("token already expired")
}
err := r.Redis.SetexCtx(ctx, accessKey, token.UID+":"+token.ClientID+":"+token.DeviceID, expiry)
if err != nil {
return errs.DatabaseErr(err.Error())
}
return nil
}
func (r *TokenRepository) Delete(ctx context.Context, id bson.ObjectID) error {
// Redis 版本不需要 ObjectID這裡留空實現
return nil
}
func (r *TokenRepository) DeleteByUserID(ctx context.Context, uid string) error {
// 可以實現刪除用戶所有 token 的邏輯
// 這裡簡化實現
return nil
}

View File

@ -0,0 +1,382 @@
package repository
import (
"context"
"testing"
"time"
"backend/pkg/permission/domain/entity"
"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/redis"
)
func setupMiniRedis() (*miniredis.Miniredis, *redis.Redis) {
// 啟動 setupMiniRedis 作為模擬的 Redis 服務
mr, err := miniredis.Run()
if err != nil {
panic("failed to start miniRedis: " + err.Error())
}
// 使用 setupMiniRedis 的地址配置 go-zero Redis 客戶端
redisConf := redis.RedisConf{
Host: mr.Addr(),
Type: "node",
}
r := redis.MustNewRedis(redisConf)
return mr, r
}
func TestTokenRepository_Blacklist(t *testing.T) {
mr, r := setupMiniRedis()
defer mr.Close()
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
ctx := context.Background()
t.Run("AddToBlacklist", func(t *testing.T) {
entry := &entity.BlacklistEntry{
JTI: "test-jti-123",
UID: "user123",
TokenID: "token123",
Reason: "user logout",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
}
err := repo.AddToBlacklist(ctx, entry, time.Hour)
assert.NoError(t, err)
// Verify it was added
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
assert.NoError(t, err)
assert.True(t, isBlacklisted)
})
t.Run("IsBlacklisted - not found", func(t *testing.T) {
isBlacklisted, err := repo.IsBlacklisted(ctx, "non-existent-jti")
assert.NoError(t, err)
assert.False(t, isBlacklisted)
})
t.Run("RemoveFromBlacklist", func(t *testing.T) {
// First add an entry
entry := &entity.BlacklistEntry{
JTI: "test-jti-456",
UID: "user456",
TokenID: "token456",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
}
err := repo.AddToBlacklist(ctx, entry, time.Hour)
assert.NoError(t, err)
// Verify it exists
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
assert.NoError(t, err)
assert.True(t, isBlacklisted)
// Remove it
err = repo.RemoveFromBlacklist(ctx, entry.JTI)
assert.NoError(t, err)
// Verify it's gone
isBlacklisted, err = repo.IsBlacklisted(ctx, entry.JTI)
assert.NoError(t, err)
assert.False(t, isBlacklisted)
})
t.Run("GetBlacklistedTokensByUID", func(t *testing.T) {
uid := "user789"
// Add multiple entries for the same user
entries := []*entity.BlacklistEntry{
{
JTI: "jti-1",
UID: uid,
TokenID: "token-1",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
{
JTI: "jti-2",
UID: uid,
TokenID: "token-2",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
{
JTI: "jti-3",
UID: "different-user",
TokenID: "token-3",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
},
}
for _, entry := range entries {
err := repo.AddToBlacklist(ctx, entry, time.Hour)
assert.NoError(t, err)
}
// Get blacklisted tokens for the user
userEntries, err := repo.GetBlacklistedTokensByUID(ctx, uid)
assert.NoError(t, err)
assert.Len(t, userEntries, 2) // Should only get entries for the specific user
// Verify all returned entries belong to the correct user
for _, entry := range userEntries {
assert.Equal(t, uid, entry.UID)
}
})
t.Run("AddToBlacklist with zero TTL", func(t *testing.T) {
entry := &entity.BlacklistEntry{
JTI: "test-jti-zero-ttl",
UID: "user-zero-ttl",
ExpiresAt: time.Now().Add(time.Hour).Unix(),
CreatedAt: time.Now().Unix(),
}
// Test with zero TTL - should calculate from ExpiresAt
err := repo.AddToBlacklist(ctx, entry, 0)
assert.NoError(t, err)
// Verify it was added
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
assert.NoError(t, err)
assert.True(t, isBlacklisted)
})
t.Run("AddToBlacklist with expired token", func(t *testing.T) {
entry := &entity.BlacklistEntry{
JTI: "test-jti-expired",
UID: "user-expired",
ExpiresAt: time.Now().Add(-time.Hour).Unix(), // Already expired
CreatedAt: time.Now().Unix(),
}
// Should not add expired token to blacklist
err := repo.AddToBlacklist(ctx, entry, 0)
assert.NoError(t, err) // No error, but token won't be added
// Verify it was not added
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
assert.NoError(t, err)
assert.False(t, isBlacklisted)
})
}
func TestTokenRepository_CreateAndGet(t *testing.T) {
mr, r := setupMiniRedis()
defer mr.Close()
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
ctx := context.Background()
t.Run("Create and GetAccessTokenByID", func(t *testing.T) {
now := time.Now()
token := entity.Token{
ID: "test-token-123",
UID: "user123",
DeviceID: "device123",
AccessToken: "access-token-123",
ExpiresIn: 3600,
AccessCreateAt: now,
RefreshToken: "refresh-token-123",
RefreshCreateAt: now,
RefreshExpiresIn: 86400,
}
// Create token
err := repo.Create(ctx, token)
assert.NoError(t, err)
// Get token by ID
retrievedToken, err := repo.GetAccessTokenByID(ctx, token.ID)
assert.NoError(t, err)
assert.Equal(t, token.ID, retrievedToken.ID)
assert.Equal(t, token.UID, retrievedToken.UID)
assert.Equal(t, token.AccessToken, retrievedToken.AccessToken)
})
t.Run("GetAccessTokensByUID", func(t *testing.T) {
uid := "user456"
now := time.Now()
tokens := []entity.Token{
{
ID: "token-1",
UID: uid,
DeviceID: "device1",
AccessToken: "access-1",
ExpiresIn: int(now.Add(time.Hour).Unix()),
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
},
{
ID: "token-2",
UID: uid,
DeviceID: "device2",
AccessToken: "access-2",
ExpiresIn: int(now.Add(time.Hour).Unix()),
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
},
}
// Create tokens
for _, token := range tokens {
err := repo.Create(ctx, token)
assert.NoError(t, err)
}
// Get tokens by UID
retrievedTokens, err := repo.GetAccessTokensByUID(ctx, uid)
assert.NoError(t, err)
assert.Len(t, retrievedTokens, 2)
// Verify all tokens belong to the user
for _, token := range retrievedTokens {
assert.Equal(t, uid, token.UID)
}
})
t.Run("GetAccessTokenCountByUID", func(t *testing.T) {
uid := "user789"
now := time.Now()
// Create multiple tokens for the user
for i := 0; i < 3; i++ {
token := entity.Token{
ID: "count-token-" + string(rune(i+'1')),
UID: uid,
DeviceID: "device" + string(rune(i+'1')),
AccessToken: "access-" + string(rune(i+'1')),
ExpiresIn: int(now.Add(time.Hour).Unix()),
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
}
err := repo.Create(ctx, token)
assert.NoError(t, err)
}
// Get count
count, err := repo.GetAccessTokenCountByUID(ctx, uid)
assert.NoError(t, err)
assert.Equal(t, 3, count)
})
t.Run("Delete", func(t *testing.T) {
token := entity.Token{
ID: "delete-token",
UID: "delete-user",
DeviceID: "delete-device",
AccessToken: "delete-access",
RefreshToken: "delete-refresh",
ExpiresIn: 3600,
}
// Create token
err := repo.Create(ctx, token)
assert.NoError(t, err)
// Verify it exists
_, err = repo.GetAccessTokenByID(ctx, token.ID)
assert.NoError(t, err)
// Delete token
err = repo.Delete(ctx, token)
assert.NoError(t, err)
// Verify it's gone
_, err = repo.GetAccessTokenByID(ctx, token.ID)
assert.Error(t, err) // Should return error when not found
})
t.Run("DeleteAccessTokensByUID", func(t *testing.T) {
uid := "delete-user-uid"
now := time.Now()
// Create multiple tokens for the user
for i := 0; i < 2; i++ {
token := entity.Token{
ID: "delete-uid-token-" + string(rune(i+'1')),
UID: uid,
DeviceID: "device" + string(rune(i+'1')),
AccessToken: "access-" + string(rune(i+'1')),
ExpiresIn: int(now.Add(time.Hour).Unix()),
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
}
err := repo.Create(ctx, token)
assert.NoError(t, err)
}
// Verify tokens exist
count, err := repo.GetAccessTokenCountByUID(ctx, uid)
assert.NoError(t, err)
assert.Equal(t, 2, count)
// Delete all tokens for the user
err = repo.DeleteAccessTokensByUID(ctx, uid)
assert.NoError(t, err)
// Verify tokens are gone
count, err = repo.GetAccessTokenCountByUID(ctx, uid)
assert.NoError(t, err)
assert.Equal(t, 0, count)
})
}
func TestTokenRepository_OneTimeToken(t *testing.T) {
mr, r := setupMiniRedis()
defer mr.Close()
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
ctx := context.Background()
t.Run("CreateOneTimeToken", func(t *testing.T) {
now := time.Now()
// Create one-time token with ticket
token := entity.Token{
ID: "one-time-base-token",
UID: "user123",
AccessToken: "base-access-token",
ExpiresIn: int(now.Add(time.Hour).Unix()),
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
}
oneTimeKey := "one-time-key-123"
ticket := entity.Ticket{
Data: map[string]string{"uid": "user123"},
Token: token,
}
err := repo.CreateOneTimeToken(ctx, oneTimeKey, ticket, time.Minute)
assert.NoError(t, err)
})
t.Run("DeleteOneTimeToken", func(t *testing.T) {
// Create one-time tokens
keys := []string{"delete-key-1", "delete-key-2"}
ticket := entity.Ticket{
Data: map[string]string{"test": "data"},
Token: entity.Token{ID: "test-token"},
}
for _, key := range keys {
err := repo.CreateOneTimeToken(ctx, key, ticket, time.Minute)
assert.NoError(t, err)
}
// Delete one-time tokens
err := repo.DeleteOneTimeToken(ctx, keys, nil)
assert.NoError(t, err)
// Verify they're gone
for _, key := range keys {
_, err := repo.GetAccessTokenByOneTimeToken(ctx, key)
assert.Error(t, err)
}
})
}

View File

@ -0,0 +1,430 @@
package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/token"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// TokenRepositoryParam token 需要的參數
type TokenRepositoryParam struct {
Redis *redis.Redis
}
// TokenRepository 通知
type TokenRepository struct {
TokenRepositoryParam
}
func MustTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
return &TokenRepository{
param,
}
}
// Create 創建一個新 Token並將其存儲於 Redis
func (repo *TokenRepository) Create(ctx context.Context, token entity.Token) error {
body, err := json.Marshal(token)
if err != nil {
return err
}
refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second
return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
if err := repo.setToken(ctx, tx, token, body, refreshTTL); err != nil {
return err
}
if err := repo.setRefreshToken(ctx, tx, token, refreshTTL); err != nil {
return err
}
return repo.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL)
})
}
func (repo *TokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
body, err := json.Marshal(ticket)
if err != nil {
return err
}
_, err = repo.Redis.SetnxExCtx(ctx, token.RefreshTokenRedisKey(key), string(body), int(dt.Seconds()))
if err != nil {
return err
}
return nil
}
func (repo *TokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
id, err := repo.Redis.Get(token.RefreshTokenRedisKey(oneTimeToken))
if err != nil {
return entity.Token{}, err
}
if id == "" {
return entity.Token{}, fmt.Errorf("token not found")
}
return repo.GetAccessTokenByID(ctx, id)
}
func (repo *TokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
return repo.get(ctx, token.GetAccessTokenRedisKey(id))
}
func (repo *TokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
return repo.getTokensBySet(ctx, token.GetUIDTokenRedisKey(uid))
}
func (repo *TokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
return repo.getCountBySet(ctx, token.UIDTokenRedisKey(uid))
}
func (repo *TokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
return repo.getTokensBySet(ctx, token.DeviceTokenRedisKey(deviceID))
}
func (repo *TokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
return repo.getCountBySet(ctx, token.DeviceTokenRedisKey(deviceID))
}
func (repo *TokenRepository) Delete(ctx context.Context, tokenObj entity.Token) error {
// Delete 刪除指定的 Token
keys := []string{
token.GetAccessTokenRedisKey(tokenObj.ID),
token.RefreshTokenRedisKey(tokenObj.RefreshToken),
}
return repo.deleteKeysAndRelations(ctx, keys, tokenObj.UID, tokenObj.DeviceID, tokenObj.ID)
}
func (repo *TokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
l := len(ids) + len(tokens)
keys := make([]string, 0, l)
for _, id := range ids {
keys = append(keys, token.RefreshTokenRedisKey(id))
}
for _, tokenObj := range tokens {
keys = append(keys, token.RefreshTokenRedisKey(tokenObj.RefreshToken))
}
return repo.deleteKeys(ctx, keys...)
}
func (repo *TokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
for _, tokenID := range ids {
tokenObj, err := repo.GetAccessTokenByID(ctx, tokenID)
if err != nil {
continue
}
keys := []string{
token.GetAccessTokenRedisKey(tokenObj.ID),
token.RefreshTokenRedisKey(tokenObj.RefreshToken),
}
_ = repo.deleteKeysAndRelations(ctx, keys, tokenObj.UID, tokenObj.DeviceID, tokenObj.ID)
}
return nil
}
func (repo *TokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
tokens, err := repo.GetAccessTokensByUID(ctx, uid)
if err != nil {
return err
}
for _, token := range tokens {
if err := repo.Delete(ctx, token); err != nil {
return err
}
}
return nil
}
func (repo *TokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
tokens, err := repo.GetAccessTokensByDeviceID(ctx, deviceID)
if err != nil {
return err
}
l := len(tokens) * 2
keys := make([]string, 0, l)
for _, tokenObj := range tokens {
keys = append(keys, token.GetAccessTokenRedisKey(tokenObj.ID))
keys = append(keys, token.RefreshTokenRedisKey(tokenObj.RefreshToken))
}
err = repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
for _, tokenObj := range tokens {
tx.SRem(ctx, token.UIDTokenRedisKey(tokenObj.UID), tokenObj.ID)
}
return nil
})
if err != nil {
return err
}
if err := repo.deleteKeys(ctx, keys...); err != nil {
return err
}
_, err = repo.Redis.Del(token.DeviceTokenRedisKey(deviceID))
return err
}
// ========================================================================
// deleteKeysAndRelations 刪除指定鍵並移除相關的關聯
func (repo *TokenRepository) deleteKeysAndRelations(ctx context.Context, keys []string, uid, deviceID, tokenID string) error {
err := repo.Redis.Pipelined(func(tx redis.Pipeliner) error {
// 刪除 UID 和 DeviceID 的關聯
_ = tx.SRem(ctx, token.UIDTokenRedisKey(uid), tokenID)
_ = tx.SRem(ctx, token.DeviceTokenRedisKey(deviceID), tokenID)
for _, key := range keys {
_ = tx.Del(ctx, key)
}
return nil
})
if err != nil {
return err
}
return nil
}
// runPipeline 執行 Redis 的 Pipeline 操作
func (repo *TokenRepository) runPipeline(ctx context.Context, fn func(tx redis.Pipeliner) error) error {
if err := repo.Redis.PipelinedCtx(ctx, fn); err != nil {
return err
}
return nil
}
// deleteKeys 批量刪除 Redis 鍵
func (repo *TokenRepository) deleteKeys(ctx context.Context, keys ...string) error {
return repo.Redis.Pipelined(func(tx redis.Pipeliner) error {
for _, key := range keys {
if err := tx.Del(ctx, key).Err(); err != nil {
return err
}
}
return nil
})
}
func (repo *TokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, body []byte, ttl time.Duration) error {
return tx.Set(ctx, token.GetAccessTokenRedisKey(tokenObj.ID), body, ttl).Err()
}
func (repo *TokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, ttl time.Duration) error {
if tokenObj.RefreshToken != "" {
return tx.Set(ctx, token.RefreshTokenRedisKey(tokenObj.RefreshToken), tokenObj.ID, ttl).Err()
}
return nil
}
func (repo *TokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
if err := tx.SAdd(ctx, token.UIDTokenRedisKey(uid), tokenID).Err(); err != nil {
return err
}
// 設置 UID 鍵的過期時間
if err := tx.Expire(ctx, token.UIDTokenRedisKey(uid), ttl).Err(); err != nil {
return err
}
if err := tx.SAdd(ctx, token.DeviceTokenRedisKey(deviceID), tokenID).Err(); err != nil {
return err
}
// 設置 deviceID 鍵的過期時間
if err := tx.Expire(ctx, token.DeviceTokenRedisKey(deviceID), ttl).Err(); err != nil {
return err
}
return nil
}
// get 根據鍵獲取 Token
func (repo *TokenRepository) get(ctx context.Context, key string) (entity.Token, error) {
body, err := repo.Redis.GetCtx(ctx, key)
if err != nil {
return entity.Token{}, err
}
if body == "" {
return entity.Token{}, fmt.Errorf("token not found")
}
var token entity.Token
if err := json.Unmarshal([]byte(body), &token); err != nil {
return entity.Token{}, fmt.Errorf("json.Marshal token error")
}
return token, nil
}
// getTokensBySet 根據集合鍵獲取所有 Token
func (repo *TokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
ids, err := repo.Redis.Smembers(setKey)
if err != nil {
if errors.Is(err, redis.Nil) {
return nil, nil
}
return nil, err
}
tokens := make([]entity.Token, 0, len(ids))
var deleteTokens []string
now := time.Now().Unix()
for _, id := range ids {
token, err := repo.get(ctx, token.GetAccessTokenRedisKey(id))
if err != nil {
deleteTokens = append(deleteTokens, id)
continue
}
if int64(token.ExpiresIn) < now {
deleteTokens = append(deleteTokens, id)
continue
}
tokens = append(tokens, token)
}
if len(deleteTokens) > 0 {
_ = repo.DeleteAccessTokenByID(ctx, deleteTokens)
}
return tokens, nil
}
// getCountBySet 獲取集合中的元素數量
func (repo *TokenRepository) getCountBySet(ctx context.Context, setKey string) (int, error) {
count, err := repo.Redis.ScardCtx(ctx, setKey)
if err != nil {
return 0, err
}
return int(count), nil
}
// AddToBlacklist 將 token 加入黑名單
func (repo *TokenRepository) AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error {
key := token.GetBlacklistRedisKey(entry.JTI)
// 序列化黑名單條目
data, err := json.Marshal(entry)
if err != nil {
return fmt.Errorf("failed to marshal blacklist entry: %w", err)
}
// 使用提供的 TTL如果 TTL <= 0則計算默認 TTL
if ttl <= 0 {
// 計算 TTL (token 過期時間 - 當前時間)
ttl = time.Unix(entry.ExpiresAt, 0).Sub(time.Now())
if ttl <= 0 {
// Token 已經過期,不需要加入黑名單
return nil
}
}
// 存儲到 Redis 並設置過期時間
err = repo.Redis.SetexCtx(ctx, key, string(data), int(ttl.Seconds()))
if err != nil {
return fmt.Errorf("failed to add token to blacklist: %w", err)
}
return nil
}
// IsBlacklisted 檢查 token 是否在黑名單中
func (repo *TokenRepository) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
key := token.GetBlacklistRedisKey(jti)
exists, err := repo.Redis.ExistsCtx(ctx, key)
if err != nil {
return false, fmt.Errorf("failed to check blacklist: %w", err)
}
return exists, nil
}
// RemoveFromBlacklist 從黑名單中移除 token
func (repo *TokenRepository) RemoveFromBlacklist(ctx context.Context, jti string) error {
key := token.GetBlacklistRedisKey(jti)
_, err := repo.Redis.DelCtx(ctx, key)
if err != nil {
return fmt.Errorf("failed to remove token from blacklist: %w", err)
}
return nil
}
// GetBlacklistedTokensByUID 獲取用戶的所有黑名單 token
func (repo *TokenRepository) GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error) {
// 使用 SCAN 來查找所有黑名單鍵
pattern := token.BlacklistKeyPrefix + "*"
var entries []*entity.BlacklistEntry
var cursor uint64 = 0
for {
keys, nextCursor, err := repo.Redis.ScanCtx(ctx, cursor, pattern, 100)
if err != nil {
return nil, fmt.Errorf("failed to scan blacklist keys: %w", err)
}
// 獲取每個鍵的值並檢查 UID
for _, key := range keys {
data, err := repo.Redis.GetCtx(ctx, key)
if err != nil {
if errors.Is(err, redis.Nil) {
continue // 鍵已過期或不存在
}
return nil, fmt.Errorf("failed to get blacklist entry: %w", err)
}
var entry entity.BlacklistEntry
if err := json.Unmarshal([]byte(data), &entry); err != nil {
continue // 跳過無效的條目
}
// 檢查 UID 是否匹配
if entry.UID == uid {
entries = append(entries, &entry)
}
}
cursor = nextCursor
if cursor == 0 {
break
}
}
return entries, nil
}

View File

@ -1,238 +0,0 @@
package repository
import (
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain"
"backend/pkg/permission/domain/permission"
"context"
"errors"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"time"
"backend/pkg/library/errs"
"backend/pkg/library/mongo"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/mon"
"go.mongodb.org/mongo-driver/v2/bson"
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
)
type UserRoleRepositoryParam struct {
Conf *mongo.Conf
CacheConf cache.CacheConf
DBOpts []mon.Option
CacheOpts []cache.Option
}
type UserRoleRepository struct {
DB mongo.DocumentDBWithCacheUseCase
}
// NewUserRoleRepository 創建用戶角色倉庫實例
func NewUserRoleRepository(param UserRoleRepositoryParam) repository.UserRoleRepository {
e := entity.UserRole{}
documentDB, err := mongo.MustDocumentDBWithCache(
param.Conf,
e.CollectionName(),
param.CacheConf,
param.DBOpts,
param.CacheOpts,
)
if err != nil {
panic(err)
}
return &UserRoleRepository{
DB: documentDB,
}
}
func (repo *UserRoleRepository) Create(ctx context.Context, userRole *entity.UserRole) error {
now := time.Now()
userRole.CreateTime = now
userRole.UpdateTime = now
id := bson.NewObjectID()
userRole.ID = id
rk := domain.GetUserRoleRedisKey(id.Hex())
userRole.CreateTime = time.Now()
userRole.UpdateTime = time.Now()
_, err := repo.DB.InsertOne(ctx, rk, userRole)
if err != nil {
// 檢查是否為重複鍵錯誤
if mongodriver.IsDuplicateKeyError(err) {
return errs.ResourceAlreadyExist("failed to insert user role")
}
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *UserRoleRepository) GetByID(ctx context.Context, id string) (*entity.UserRole, error) {
var userRole entity.UserRole
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return nil, err
}
rk := domain.GetUserRoleRedisKey(id)
err = repo.DB.FindOne(ctx, rk, &userRole, bson.M{"_id": objID})
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(
code.CloudEPPermission,
domain.FailedToGetRoleByID,
"failed to get user role by id")
}
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return &userRole, nil
}
func (repo *UserRoleRepository) GetByUserAndRole(ctx context.Context, uid, roleUID string) (*entity.UserRole, error) {
filter := bson.M{
"uid": uid,
"role_uid": roleUID,
"status": permission.StatusActive,
}
var userRole entity.UserRole
err := repo.DB.GetClient().Find(ctx, &userRole, filter)
if err != nil {
if errors.Is(err, mongodriver.ErrNoDocuments) {
return nil, errs.ResourceNotFoundWithScope(code.CloudEPPermission, 0, "failed to get user and role")
}
return nil, errs.DatabaseErrorWithScope(code.CloudEPPermission, 0, err.Error())
}
return &userRole, nil
}
func (repo *UserRoleRepository) Update(ctx context.Context, id string, userRole *entity.UserRole) error {
userRole.UpdateTime = time.Now()
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return err
}
update := bson.M{
"$set": bson.M{
"status": userRole.Status,
"update_time": userRole.UpdateTime,
},
}
rk := domain.GetUserRoleRedisKey(id)
_, err = repo.DB.UpdateOne(ctx, rk, bson.M{"_id": objID}, update)
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *UserRoleRepository) Delete(ctx context.Context, id string) error {
objID, err := bson.ObjectIDFromHex(id)
if err != nil {
return err
}
rk := domain.GetUserRoleRedisKey(id)
_, err = repo.DB.DeleteOne(ctx, rk, bson.M{"_id": objID})
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
func (repo *UserRoleRepository) List(ctx context.Context, filter repository.UserRoleFilter) ([]*entity.UserRole, error) {
query := bson.M{}
if filter.Brand != "" {
query["brand"] = filter.Brand
}
if filter.UID != "" {
query["uid"] = filter.UID
}
if filter.RoleUID != "" {
query["role_uid"] = filter.RoleUID
}
if filter.Status != nil {
query["status"] = *filter.Status
}
var userRoles []*entity.UserRole
err := repo.DB.GetClient().Find(ctx, &userRoles, query)
if err != nil {
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
err = repo.DB.GetClient().Find(ctx,
&userRoles, query,
options.Find().SetLimit(int64(filter.Limit)),
options.Find().SetSkip(int64(filter.Skip)))
if err != nil {
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return userRoles, nil
}
func (repo *UserRoleRepository) GetUserRolesByUID(ctx context.Context, uid string) ([]*entity.UserRole, error) {
status := permission.StatusActive
filter := repository.UserRoleFilter{
UID: uid,
Status: &status,
}
return repo.List(ctx, filter)
}
func (repo *UserRoleRepository) DeleteByUserAndRole(ctx context.Context, uid, roleUID string) error {
filter := repository.UserRoleFilter{
UID: uid,
RoleUID: roleUID,
}
list, err := repo.List(ctx, filter)
if err != nil {
return err
}
if len(list) == 0 {
return nil
}
for _, item := range list {
_ = repo.DB.DelCache(ctx, domain.GetUserRoleRedisKey(item.ID.Hex()))
}
_, err = repo.DB.GetClient().DeleteMany(ctx, filter)
if err != nil {
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
}
return nil
}
// Index20241226001UP 創建索引
func (repo *UserRoleRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
// 等價於 db.account.createIndex({ "login_id": 1, "platform": 1}, {unique: true})
repo.DB.PopulateMultiIndex(ctx, []string{
"uid",
"role_uid",
}, []int32{1, 1}, true)
// 等價於 db.account.createIndex({"create_at": 1})
repo.DB.PopulateIndex(ctx, "uid", 1, false)
repo.DB.PopulateIndex(ctx, "status", 1, false)
return repo.DB.GetClient().Indexes().List(ctx)
}

View File

@ -1,207 +0,0 @@
package usecase
import (
"backend/pkg/library/errs/code"
"backend/pkg/permission/utils"
"context"
"crypto/rand"
"encoding/hex"
"time"
"backend/pkg/library/errs"
"backend/pkg/permission/domain/config"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/usecase"
"github.com/golang-jwt/jwt/v5"
)
type AuthUseCaseParam struct {
ClientRepo repository.ClientRepository
TokenRepo repository.TokenRepository
JWTConfig config.JWTConfig
}
type AuthUseCase struct {
clientRepo repository.ClientRepository
tokenRepo repository.TokenRepository
jwtConfig config.JWTConfig
}
// MustAuthUseCase 創建認證用例實例
func MustAuthUseCase(param AuthUseCaseParam) usecase.AuthUseCase {
return &AuthUseCase{
clientRepo: param.ClientRepo,
tokenRepo: param.TokenRepo,
jwtConfig: param.JWTConfig,
}
}
func (uc *AuthUseCase) CreateToken(ctx context.Context, req usecase.CreateTokenRequest) (*usecase.TokenResponse, error) {
// 驗證客戶端
client, err := uc.clientRepo.GetByClientID(ctx, req.ClientID)
if err != nil {
return nil, err
}
if !utils.IsActive(client.Status) {
return nil, errs.UserSuspended(code.CloudEPPermission, "failed to get token since user has been suspended")
}
// 根據授權類型處理
var uid string
switch req.GrantType {
case "client_credentials":
uid = "client_" + req.ClientID
case "password":
if req.Username == "" || req.Password == "" {
return nil, errs.InvalidCredentials()
}
// 這裡應該驗證用戶名密碼,簡化處理
uid = req.Username
default:
return nil, errs.InvalidFormat("unsupported grant type: " + req.GrantType)
}
// 生成令牌
accessToken, err := uc.generateAccessToken(uid, req.ClientID, req.DeviceID)
if err != nil {
return nil, errs.SystemInternal("failed to generate access token: " + err.Error())
}
refreshToken, err := uc.generateRefreshToken()
if err != nil {
return nil, errs.SystemInternal("failed to generate refresh token: " + err.Error())
}
// 保存令牌
token := &entity.Token{
UID: uid,
ClientID: req.ClientID,
AccessToken: accessToken,
RefreshToken: refreshToken,
DeviceID: req.DeviceID,
ExpiresAt: time.Now().Add(uc.jwtConfig.AccessExpires),
}
if err := uc.tokenRepo.Create(ctx, token); err != nil {
return nil, err
}
return &usecase.TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
ExpiresIn: int64(uc.jwtConfig.AccessExpires.Seconds()),
}, nil
}
func (uc *AuthUseCase) RefreshToken(ctx context.Context, refreshToken string) (*usecase.TokenResponse, error) {
// 查找刷新令牌
token, err := uc.tokenRepo.GetByRefreshToken(ctx, refreshToken)
if err != nil {
return nil, err
}
if token.IsExpired() {
return nil, errs.TokenExpired()
}
// 生成新的訪問令牌
accessToken, err := uc.generateAccessToken(token.UID, token.ClientID, token.DeviceID)
if err != nil {
return nil, errs.SystemInternal("failed to generate access token: " + err.Error())
}
// 更新令牌
token.AccessToken = accessToken
token.ExpiresAt = time.Now().Add(uc.jwtConfig.AccessExpires)
if err := uc.tokenRepo.Update(ctx, token); err != nil {
return nil, err
}
return &usecase.TokenResponse{
AccessToken: accessToken,
RefreshToken: refreshToken,
TokenType: "Bearer",
ExpiresIn: int64(uc.jwtConfig.AccessExpires.Seconds()),
}, nil
}
func (uc *AuthUseCase) ValidateToken(ctx context.Context, accessToken string) (*usecase.TokenClaims, error) {
// 解析JWT令牌
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errs.TokenInvalid()
}
return []byte(uc.jwtConfig.Secret), nil
})
if err != nil {
return nil, errs.TokenInvalid()
}
if !token.Valid {
return nil, errs.TokenInvalid()
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return nil, errs.TokenInvalid()
}
uid, ok := claims["uid"].(string)
if !ok {
return nil, errs.TokenInvalid()
}
clientID, ok := claims["client_id"].(string)
if !ok {
return nil, errs.TokenInvalid()
}
deviceID, _ := claims["device_id"].(string)
return &usecase.TokenClaims{
UID: uid,
ClientID: clientID,
DeviceID: deviceID,
}, nil
}
func (uc *AuthUseCase) Logout(ctx context.Context, accessToken string) error {
// 查找並刪除令牌
token, err := uc.tokenRepo.GetByAccessToken(ctx, accessToken)
if err != nil {
return err
}
return uc.tokenRepo.Delete(ctx, token.ID)
}
func (uc *AuthUseCase) LogoutAllByUserID(ctx context.Context, uid string) error {
return uc.tokenRepo.DeleteByUserID(ctx, uid)
}
func (uc *AuthUseCase) generateAccessToken(uid, clientID, deviceID string) (string, error) {
claims := jwt.MapClaims{
"uid": uid,
"client_id": clientID,
"device_id": deviceID,
"exp": time.Now().Add(uc.jwtConfig.AccessExpires).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(uc.jwtConfig.Secret))
}
func (uc *AuthUseCase) generateRefreshToken() (string, error) {
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", err
}
return hex.EncodeToString(bytes), nil
}

View File

@ -1,348 +0,0 @@
package usecase
import (
"backend/pkg/library/errs/code"
"context"
"github.com/zeromicro/go-zero/core/logx"
"go.mongodb.org/mongo-driver/v2/bson"
"backend/pkg/library/errs"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/usecase"
"github.com/casbin/casbin/v2"
)
type PermissionUseCaseParam struct {
Enforcer *casbin.Enforcer
PermissionRepo repository.PermissionRepository
RoleRepo repository.RoleRepository
UserRoleRepo repository.UserRoleRepository
}
type PermissionUseCase struct {
enforcer *casbin.Enforcer
permissionRepo repository.PermissionRepository
roleRepo repository.RoleRepository
userRoleRepo repository.UserRoleRepository
}
// MustPermissionUseCase 創建權限用例實例
func MustPermissionUseCase(param PermissionUseCaseParam) usecase.PermissionUseCase {
return &PermissionUseCase{
enforcer: param.Enforcer,
permissionRepo: param.PermissionRepo,
roleRepo: param.RoleRepo,
userRoleRepo: param.UserRoleRepo,
}
}
func (uc *PermissionUseCase) CreatePermission(ctx context.Context, req usecase.CreatePermissionRequest) (*entity.Permission, error) {
// 驗證請求
if req.Name == "" {
return nil, errs.InvalidFormat("permission name is required")
}
permission := &entity.Permission{
Name: req.Name,
HTTPMethod: req.HTTPMethod,
HTTPPath: req.HTTPPath,
Status: req.Status,
Type: req.Type,
}
if req.ParentID != nil {
objID, err := bson.ObjectIDFromHex(*req.ParentID)
if err != nil {
e := errs.InvalidFormat(err.Error())
return nil, e
}
permission.ID = objID
}
if err := uc.permissionRepo.Create(ctx, permission); err != nil {
return nil, err
}
return permission, nil
}
func (uc *PermissionUseCase) GetPermission(ctx context.Context, id string) (*entity.Permission, error) {
return uc.permissionRepo.GetByID(ctx, id)
}
func (uc *PermissionUseCase) UpdatePermission(ctx context.Context, req usecase.UpdatePermissionRequest) (*entity.Permission, error) {
// 獲取現有權限
permission, err := uc.permissionRepo.GetByID(ctx, req.ID)
if err != nil {
return nil, err
}
// 更新字段
if req.Name != nil {
permission.Name = *req.Name
}
if req.HTTPMethod != nil {
permission.HTTPMethod = *req.HTTPMethod
}
if req.HTTPPath != nil {
permission.HTTPPath = *req.HTTPPath
}
if req.Status != nil {
permission.Status = *req.Status
}
if req.Type != nil {
permission.Type = *req.Type
}
if err := uc.permissionRepo.Update(ctx, req.ID, permission); err != nil {
return nil, err
}
return permission, nil
}
func (uc *PermissionUseCase) DeletePermission(ctx context.Context, id string) error {
return uc.permissionRepo.Delete(ctx, id)
}
func (uc *PermissionUseCase) ListPermissions(ctx context.Context, req usecase.ListPermissionsRequest) ([]*entity.Permission, error) {
filter := repository.PermissionFilter{
Status: req.Status,
Type: req.Type,
ParentID: req.ParentID,
Limit: req.Limit,
Skip: req.Skip,
}
return uc.permissionRepo.List(ctx, filter)
}
// CheckUserPermission 使用 Casbin 檢查用戶權限
func (uc *PermissionUseCase) CheckUserPermission(ctx context.Context, uid, httpMethod, httpPath string) (bool, error) {
// 使用 Casbin 進行權限檢查
// sub: 用戶ID, obj: 資源路徑, act: 行為
hasPermission, err := uc.enforcer.Enforce(uid, httpPath, httpMethod)
if err != nil {
return false, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin enforce failed: "+err.Error())
}
if !hasPermission {
return false, errs.InsufficientPermission(httpMethod + ":" + httpPath)
}
return true, nil
}
// CheckRolePermission 使用 Casbin 檢查角色權限
func (uc *PermissionUseCase) CheckRolePermission(ctx context.Context, roleUID, httpMethod, httpPath string) (bool, error) {
// 使用 Casbin 進行角色權限檢查
hasPermission, err := uc.enforcer.Enforce(roleUID, httpPath, httpMethod)
if err != nil {
return false, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin enforce failed: "+err.Error())
}
if !hasPermission {
return false, errs.InsufficientPermission(httpMethod + ":" + httpPath)
}
return true, nil
}
// GetUserPermissions 獲取用戶的所有權限
func (uc *PermissionUseCase) GetUserPermissions(ctx context.Context, uid string) (map[string]int, error) {
// 獲取用戶的所有角色
roles, err := uc.enforcer.GetRolesForUser(uid)
if err != nil {
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "failed to get user permissions: "+err.Error())
}
permissions := make(map[string]int)
// 獲取用戶直接擁有的權限
userPolicies, err := uc.enforcer.GetPermissionsForUser(uid)
if err != nil {
logx.Infof("failed to get user permissions: " + err.Error())
}
for _, policy := range userPolicies {
if len(policy) >= 3 {
key := policy[2] + ":" + policy[1] // method:path
permissions[key] = 1
}
}
// 獲取通過角色繼承的權限
for _, role := range roles {
rolePolicies, err := uc.enforcer.GetPermissionsForUser(role)
if err != nil {
logx.Infof("failed to get permissions for user: " + err.Error())
}
for _, policy := range rolePolicies {
if len(policy) >= 3 {
key := policy[2] + ":" + policy[1] // method:path
permissions[key] = 1
}
}
}
return permissions, nil
}
// BatchCheckPermissions 批量檢查權限
func (uc *PermissionUseCase) BatchCheckPermissions(ctx context.Context, uid string, permissions []usecase.PermissionCheck) (map[string]bool, error) {
results := make(map[string]bool)
for _, perm := range permissions {
key := perm.HTTPMethod + ":" + perm.HTTPPath
hasPermission, err := uc.enforcer.Enforce(uid, perm.HTTPPath, perm.HTTPMethod)
if err != nil {
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin enforce failed: "+err.Error())
}
results[key] = hasPermission
}
return results, nil
}
// AddPolicyForUser 為用戶添加權限策略
func (uc *PermissionUseCase) AddPolicyForUser(ctx context.Context, uid, httpPath, httpMethod string) error {
added, err := uc.enforcer.AddPolicy(uid, httpPath, httpMethod)
if err != nil {
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin add policy failed: "+err.Error())
}
if !added {
return errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, "policy already exists")
}
return nil
}
// RemovePolicyForUser 移除用戶的權限策略
func (uc *PermissionUseCase) RemovePolicyForUser(ctx context.Context, uid, httpPath, httpMethod string) error {
removed, err := uc.enforcer.RemovePolicy(uid, httpPath, httpMethod)
if err != nil {
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin remove policy failed: "+err.Error())
}
if !removed {
return errs.ResourceNotFoundWithScope(code.CloudEPPermission, 0, "policy not found")
}
return nil
}
// AddRoleForUser 為用戶分配角色
func (uc *PermissionUseCase) AddRoleForUser(ctx context.Context, uid, roleUID string) error {
added, err := uc.enforcer.AddRoleForUser(uid, roleUID)
if err != nil {
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin add role failed: "+err.Error())
}
if !added {
return errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, "role already assigned")
}
return nil
}
// RemoveRoleForUser 移除用戶的角色
func (uc *PermissionUseCase) RemoveRoleForUser(ctx context.Context, uid, roleUID string) error {
removed, err := uc.enforcer.DeleteRoleForUser(uid, roleUID)
if err != nil {
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin remove role failed: "+err.Error())
}
if !removed {
return errs.ResourceNotFoundWithScope(code.CloudEPPermission, 0, "role assignment not found")
}
return nil
}
// GetUsersForRole 獲取角色下的所有用戶
func (uc *PermissionUseCase) GetUsersForRole(ctx context.Context, roleUID string) ([]string, error) {
return uc.enforcer.GetUsersForRole(roleUID)
}
// GetRolesForUser 獲取用戶的所有角色
func (uc *PermissionUseCase) GetRolesForUser(ctx context.Context, uid string) ([]string, error) {
return uc.enforcer.GetRolesForUser(uid)
}
// AddPermissionForRole 為角色添加權限
func (uc *PermissionUseCase) AddPermissionForRole(ctx context.Context, roleUID, httpPath, httpMethod string) error {
added, err := uc.enforcer.AddPolicy(roleUID, httpPath, httpMethod)
if err != nil {
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin add policy failed: "+err.Error())
}
if !added {
return errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, "policy already exists")
}
return nil
}
// RemovePermissionForRole 移除角色的權限
func (uc *PermissionUseCase) RemovePermissionForRole(ctx context.Context, roleUID, httpPath, httpMethod string) error {
removed, err := uc.enforcer.RemovePolicy(roleUID, httpPath, httpMethod)
if err != nil {
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin remove policy failed: "+err.Error())
}
if !removed {
return errs.ResourceNotFoundWithScope(code.CloudEPPermission, 0, "policy not found")
}
return nil
}
// GetPermissionsForRole 獲取角色的所有權限
func (uc *PermissionUseCase) GetPermissionsForRole(ctx context.Context, roleUID string) (map[string]int, error) {
policies, err := uc.enforcer.GetPermissionsForUser(roleUID)
if err != nil {
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin get permissions failed: "+err.Error())
}
permissions := make(map[string]int)
for _, policy := range policies {
if len(policy) >= 3 {
key := policy[2] + ":" + policy[1] // method:path
permissions[key] = 1
}
}
return permissions, nil
}
// CheckPatternPermission 檢查模式權限 (支援通配符)
func (uc *PermissionUseCase) CheckPatternPermission(ctx context.Context, uid, pattern, action string) (bool, error) {
hasPermission, err := uc.enforcer.Enforce(uid, pattern, action)
if err != nil {
return false, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin enforce failed: "+err.Error())
}
return hasPermission, nil
}
// GetAllPolicies 獲取所有策略
func (uc *PermissionUseCase) GetAllPolicies(ctx context.Context) ([][]string, error) {
policies, err := uc.enforcer.GetPolicy()
if err != nil {
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "failed to get all policies: "+err.Error())
}
return policies, nil
}
// GetFilteredPolicies 獲取過濾後的策略
func (uc *PermissionUseCase) GetFilteredPolicies(ctx context.Context, fieldIndex int, fieldValues ...string) ([][]string, error) {
policies, err := uc.enforcer.GetFilteredPolicy(fieldIndex, fieldValues...)
if err != nil {
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "failed to get filtered policies: "+err.Error())
}
return policies, nil
}

View File

@ -1,189 +0,0 @@
package usecase
import (
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain/permission"
"context"
"backend/pkg/library/errs"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/usecase"
)
type RoleUseCaseParam struct {
RoleRepo repository.RoleRepository
UserRoleRepo repository.UserRoleRepository
}
type RoleUseCase struct {
roleRepo repository.RoleRepository
userRoleRepo repository.UserRoleRepository
}
// MustRoleUseCase 創建角色用例實例
func MustRoleUseCase(param RoleUseCaseParam) usecase.RoleUseCase {
return &RoleUseCase{
roleRepo: param.RoleRepo,
userRoleRepo: param.UserRoleRepo,
}
}
func (uc *RoleUseCase) CreateRole(ctx context.Context, req usecase.CreateRoleRequest) (*entity.Role, error) {
// 驗證請求
if req.ClientID == "" {
return nil, errs.InvalidFormat("client_id is required")
}
if req.Name == "" {
return nil, errs.InvalidFormat("role name is required")
}
// 檢查角色名稱是否已存在
existingRole, err := uc.roleRepo.GetByClientAndName(ctx, req.ClientID, req.Name)
if err == nil && existingRole != nil {
return nil, errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, req.ClientID+":"+req.Name)
}
role := &entity.Role{
ClientID: req.ClientID,
UID: req.UID,
Name: req.Name,
Status: req.Status,
Permissions: req.Permissions,
}
if err := uc.roleRepo.Create(ctx, role); err != nil {
return nil, err
}
return role, nil
}
func (uc *RoleUseCase) GetRole(ctx context.Context, id string) (*entity.Role, error) {
return uc.roleRepo.GetByID(ctx, id)
}
func (uc *RoleUseCase) GetRoleByUID(ctx context.Context, uid string) (*entity.Role, error) {
return uc.roleRepo.GetByUID(ctx, uid)
}
func (uc *RoleUseCase) UpdateRole(ctx context.Context, req usecase.UpdateRoleRequest) (*entity.Role, error) {
// 獲取現有角色
role, err := uc.roleRepo.GetByID(ctx, req.ID)
if err != nil {
return nil, err
}
// 更新字段
if req.Name != nil {
// 檢查新名稱是否已存在
existingRole, err := uc.roleRepo.GetByClientAndName(ctx, role.ClientID, *req.Name)
if err == nil && existingRole != nil && existingRole.ID != role.ID {
return nil, errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, role.ClientID+":"+*req.Name)
}
role.Name = *req.Name
}
if req.Status != nil {
role.Status = *req.Status
}
if req.Permissions != nil {
role.Permissions = *req.Permissions
}
if err := uc.roleRepo.Update(ctx, req.ID, role); err != nil {
return nil, err
}
return role, nil
}
func (uc *RoleUseCase) DeleteRole(ctx context.Context, id string) error {
// 獲取角色信息
role, err := uc.roleRepo.GetByID(ctx, id)
if err != nil {
return err
}
status := permission.StatusActive
// 檢查是否有用戶使用此角色
userRoles, err := uc.userRoleRepo.List(ctx, repository.UserRoleFilter{
RoleUID: role.UID,
Status: &status,
})
if err != nil {
return err
}
if len(userRoles) > 0 {
return errs.InvalidFormat("cannot delete role that is assigned to users")
}
return uc.roleRepo.Delete(ctx, id)
}
func (uc *RoleUseCase) ListRoles(ctx context.Context, req usecase.ListRolesRequest) ([]*entity.Role, error) {
filter := repository.RoleFilter{
ClientID: req.ClientID,
Status: req.Status,
Limit: req.Limit,
Skip: req.Skip,
}
return uc.roleRepo.List(ctx, filter)
}
func (uc *RoleUseCase) AddPermissionToRole(ctx context.Context, roleID string, permissionKey string) error {
// 獲取角色
role, err := uc.roleRepo.GetByID(ctx, roleID)
if err != nil {
return err
}
// 添加權限
role.AddPermission(permissionKey)
return uc.roleRepo.Update(ctx, role.ID.Hex(), role)
}
func (uc *RoleUseCase) RemovePermissionFromRole(ctx context.Context, roleID string, permissionKey string) error {
// 獲取角色
role, err := uc.roleRepo.GetByID(ctx, roleID)
if err != nil {
return err
}
// 移除權限
role.RemovePermission(permissionKey)
return uc.roleRepo.Update(ctx, role.ID.Hex(), role)
}
func (uc *RoleUseCase) BatchUpdateRolePermissions(ctx context.Context, roleID string, permissions entity.Permissions) error {
// 獲取角色
role, err := uc.roleRepo.GetByID(ctx, roleID)
if err != nil {
return err
}
// 批量更新權限
role.Permissions = permissions
return uc.roleRepo.Update(ctx, role.ID.Hex(), role)
}
func (uc *RoleUseCase) GetRolesByClientID(ctx context.Context, clientID string) ([]*entity.Role, error) {
return uc.roleRepo.GetRolesByClientID(ctx, clientID)
}
func (uc *RoleUseCase) CopyRole(ctx context.Context, sourceRoleID string, req usecase.CreateRoleRequest) (*entity.Role, error) {
// 獲取源角色
sourceRole, err := uc.roleRepo.GetByID(ctx, sourceRoleID)
if err != nil {
return nil, err
}
// 創建新角色,複製權限
newReq := req
newReq.Permissions = sourceRole.Permissions
return uc.CreateRole(ctx, newReq)
}

708
pkg/permission/usecase/token.go Executable file
View File

@ -0,0 +1,708 @@
package usecase
import (
"context"
"fmt"
"strconv"
"time"
"backend/internal/config"
"backend/pkg/library/errs"
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/token"
"backend/pkg/permission/domain/usecase"
"github.com/segmentio/ksuid"
"github.com/zeromicro/go-zero/core/logx"
)
type TokenUseCaseParam struct {
TokenRepo repository.TokenRepository
Config *config.Config
}
type TokenUseCase struct {
TokenUseCaseParam
}
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
claims, err := parseClaims(token, use.Config.Token.AccessSecret, false)
if err != nil {
return nil,
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "parseClaims",
req: token,
err: err,
message: "validate token claims error",
errorCode: code.TokenValidateError,
})
}
return claims, nil
}
func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
return &TokenUseCase{
param,
}
}
// ============================================ token ============================================
func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) {
tokenObj, err := use.newToken(ctx, &req)
if err != nil {
return entity.TokenResp{}, err
}
err = use.TokenRepo.Create(ctx, *tokenObj)
if err != nil {
return entity.TokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.Create",
req: req,
err: err,
message: "failed to create token",
errorCode: code.TokenCreateError,
})
}
return entity.TokenResp{
AccessToken: tokenObj.AccessToken,
TokenType: token.TypeBearer.String(),
ExpiresIn: int64(tokenObj.ExpiresIn),
RefreshToken: tokenObj.RefreshToken,
}, nil
}
func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) {
// 準備建立 Token 所需
now := time.Now().UTC()
expires := req.Expires
refreshExpires := req.Expires
if expires <= 0 {
// 將時間加上 n 秒
sec := use.Config.Token.AccessTokenExpiry
// 獲取 Unix 時間戳
expires = now.Add(sec).Unix()
refreshExpires = expires
}
// 如果這是一個 Refresh Token 過期時間要比普通的Token 長
if req.IsRefreshToken {
// 獲取 Unix 時間戳
refresh := use.Config.Token.RefreshTokenExpiry
refreshExpires = now.Add(refresh).Unix()
}
token := entity.Token{
ID: ksuid.New().String(),
DeviceID: req.DeviceID,
ExpiresIn: int(expires),
RefreshExpiresIn: int(refreshExpires),
AccessCreateAt: now,
RefreshCreateAt: now,
}
tc := make(tokenClaims)
if req.Data != nil {
for k, v := range req.Data {
tc[k] = v
}
}
tc.SetRole(req.Role)
tc.SetID(token.ID)
tc.SetScope(req.Scope)
tc.SetAccount(req.Account)
token.UID = tc.UID()
if req.DeviceID != "" {
tc.SetDeviceID(req.DeviceID)
}
var err error
token.AccessToken, err = accessTokenGenerator(token, tc, use.Config.Token.AccessSecret)
if err != nil {
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "accessTokenGenerator",
req: req,
err: err,
message: "failed to generator access token",
errorCode: code.TokenCreateError,
})
}
if req.IsRefreshToken {
token.RefreshToken = refreshTokenGenerator(token.AccessToken)
}
return &token, nil
}
func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) {
// Step 1: 檢查 refresh token
tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
if err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.GetAccessTokenByOneTimeToken",
req: req,
err: err,
message: "failed to get access token",
errorCode: code.TokenValidateError,
})
}
// Step 2: 提取 Claims Data
claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
if err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "extractClaims",
req: req,
err: err,
message: "failed to extract claims",
errorCode: code.TokenValidateError,
})
}
// Step 3: 創建新 token
credentials := token.ClientCredentials
newToken, err := use.newToken(ctx, &entity.AuthorizationReq{
GrantType: credentials.ToString(),
Scope: req.Scope,
DeviceID: req.DeviceID,
Data: claimsData,
Expires: req.Expires,
IsRefreshToken: true,
Account: req.DeviceID,
})
if err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "use.newToken",
req: req,
err: err,
message: "failed to create new token",
errorCode: code.TokenValidateError,
})
}
if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.Create",
req: req,
err: err,
message: "failed to create new token",
errorCode: code.TokenValidateError,
})
}
// Step 4: 刪除舊 token 並創建新 token
if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil {
return entity.RefreshTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.Delete",
req: req,
err: err,
message: "failed to delete old token",
errorCode: code.TokenValidateError,
})
}
// 返回新的 Token 響應
return entity.RefreshTokenResp{
Token: newToken.AccessToken,
OneTimeToken: newToken.RefreshToken,
ExpiresIn: int64(newToken.ExpiresIn),
TokenType: token.TypeBearer.String(),
}, nil
}
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "CancelToken extractClaims",
req: req,
err: err,
message: "failed to get token claims",
errorCode: code.TokenValidateError,
})
}
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo GetAccessTokenByID",
req: req,
err: err,
message: fmt.Sprintf("failed to get token claims :%s", claims.ID()),
errorCode: code.TokenValidateError,
})
}
err = use.TokenRepo.Delete(ctx, token)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo Delete",
req: req,
err: err,
message: fmt.Sprintf("failed to delete token :%s", token.ID),
errorCode: code.TokenValidateError,
})
}
return nil
}
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, true)
if err != nil {
return entity.ValidationTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "parseClaims",
req: req,
err: err,
message: "validate token claims error",
errorCode: code.TokenValidateError,
})
}
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
if err != nil {
return entity.ValidationTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.GetAccessTokenByID",
req: req,
err: err,
message: fmt.Sprintf("failed to get token :%s", claims.ID()),
errorCode: code.TokenValidateError,
})
}
return entity.ValidationTokenResp{
Token: entity.Token{
ID: token.ID,
UID: token.UID,
DeviceID: token.DeviceID,
AccessCreateAt: token.AccessCreateAt,
AccessToken: token.AccessToken,
ExpiresIn: token.ExpiresIn,
RefreshToken: token.RefreshToken,
RefreshExpiresIn: token.RefreshExpiresIn,
RefreshCreateAt: token.RefreshCreateAt,
},
Data: claims,
}, nil
}
func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error {
if req.UID != "" {
err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.DeleteAccessTokensByUID",
req: req,
err: err,
message: "failed to cancel tokens by uid",
errorCode: code.TokenValidateError,
})
}
}
if len(req.IDs) > 0 {
err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.DeleteAccessTokenByID",
req: req,
err: err,
message: "failed to cancel tokens by token ids",
errorCode: code.TokenValidateError,
})
}
}
return nil
}
func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error {
err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.DeleteAccessTokensByDeviceID",
req: req,
err: err,
message: "failed to cancel token by device id",
errorCode: code.TokenValidateError,
})
}
return nil
}
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) {
uidTokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID)
if err != nil {
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.GetAccessTokensByDeviceID",
req: req,
err: err,
message: "failed to get token by device id",
errorCode: code.TokenNotFound,
})
}
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
for _, v := range uidTokens {
tokens = append(tokens, &entity.TokenResp{
AccessToken: v.AccessToken,
TokenType: token.TypeBearer.String(),
ExpiresIn: int64(v.ExpiresIn),
RefreshToken: v.RefreshToken,
})
}
return tokens, nil
}
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) {
uidTokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID)
if err != nil {
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.GetAccessTokensByUID",
req: req,
err: err,
message: "failed to get token by uid",
errorCode: code.TokenNotFound,
})
}
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
for _, v := range uidTokens {
tokens = append(tokens, &entity.TokenResp{
AccessToken: v.AccessToken,
TokenType: token.TypeBearer.String(),
ExpiresIn: int64(v.ExpiresIn),
RefreshToken: v.RefreshToken,
})
}
return tokens, nil
}
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
// 驗證Token
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
if err != nil {
return entity.CreateOneTimeTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "parseClaims",
req: req,
err: err,
message: "failed to get token claims",
errorCode: code.OneTimeTokenError,
})
}
tokenObj, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
if err != nil {
return entity.CreateOneTimeTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.GetAccessTokenByID",
req: req,
err: err,
message: "failed to get token by id",
errorCode: code.OneTimeTokenError,
})
}
oneTimeToken := refreshTokenGenerator(ksuid.New().String())
key := token.TicketKeyPrefix + oneTimeToken
if err = use.TokenRepo.CreateOneTimeToken(ctx, key, entity.Ticket{
Data: claims,
Token: tokenObj,
}, time.Minute); err != nil {
return entity.CreateOneTimeTokenResp{},
use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.CreateOneTimeToken",
req: req,
err: err,
message: "create one time token error",
errorCode: code.OneTimeTokenError,
})
}
return entity.CreateOneTimeTokenResp{
OneTimeToken: oneTimeToken,
}, nil
}
func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error {
err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "TokenRepo.DeleteOneTimeToken",
req: req,
err: err,
message: "failed to del one time token by token",
errorCode: code.OneTimeTokenError,
})
}
return nil
}
type wrapTokenErrorReq struct {
funcName string
req any
err error
message string
errorCode uint32
}
// wrapTokenError 將錯誤信息封裝到 errs.LibError 中
func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error {
logFields := []logx.LogField{
{Key: "req", Value: param.req},
{Key: "func", Value: param.funcName},
{Key: "err", Value: param.err.Error()},
}
logx.WithContext(ctx).Errorw(param.message, logFields...)
wrappedErr := errs.NewError(
code.CatToken,
code.CatToken,
param.errorCode,
param.message,
).Wrap(param.err)
return wrappedErr
}
// BlacklistToken 將 JWT token 加入黑名單 (立即撤銷)
func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error {
// 解析 JWT 獲取完整的 claims
claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.parseToken",
req: token,
err: err,
message: "failed to parse token claims",
errorCode: code.InvalidJWT,
})
}
// 獲取 JTI (JWT ID)
jti, exists := claimMap["jti"]
if !exists {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.getJTI",
req: token,
err: entity.ErrInvalidJTI,
message: "token missing JTI claim",
errorCode: code.InvalidJWT,
})
}
jtiStr, ok := jti.(string)
if !ok {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.convertJTI",
req: token,
err: entity.ErrInvalidJTI,
message: "JTI claim is not a string",
errorCode: code.InvalidJWT,
})
}
// 獲取 UID (可能在 data 中)
var uid string
if dataInterface, exists := claimMap["data"]; exists {
if dataMap, ok := dataInterface.(map[string]interface{}); ok {
if uidInterface, exists := dataMap["uid"]; exists {
uid, _ = uidInterface.(string)
}
}
}
// 獲取過期時間
exp, exists := claimMap["exp"]
if !exists {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.getExp",
req: token,
err: entity.ErrTokenExpired,
message: "token missing exp claim",
errorCode: code.TokenExpired,
})
}
// 將 exp 轉換為 int64 (JWT 中通常是 float64)
var expInt int64
switch v := exp.(type) {
case float64:
expInt = int64(v)
case int64:
expInt = v
case string:
parsedExp, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.parseExp",
req: token,
err: err,
message: "failed to parse exp claim",
errorCode: code.TokenExpired,
})
}
expInt = parsedExp
default:
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.convertExp",
req: token,
err: fmt.Errorf("exp claim is not a valid type: %T", exp),
message: "exp claim type conversion failed",
errorCode: code.TokenExpired,
})
}
// 創建黑名單條目
blacklistEntry := &entity.BlacklistEntry{
JTI: jtiStr,
UID: uid,
ExpiresAt: expInt,
CreatedAt: time.Now().Unix(),
}
// 添加到黑名單
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistToken.AddToBlacklist",
req: jtiStr,
err: err,
message: "failed to add token to blacklist",
errorCode: code.TokenCreateError,
})
}
logx.WithContext(ctx).Infow("token blacklisted",
logx.Field("jti", jtiStr),
logx.Field("uid", uid),
logx.Field("reason", reason))
return nil
}
// IsTokenBlacklisted 檢查 JWT token 是否在黑名單中
func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) {
isBlacklisted, err := use.TokenRepo.IsBlacklisted(ctx, jti)
if err != nil {
return false, use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "IsTokenBlacklisted",
req: jti,
err: err,
message: "failed to check blacklist status",
errorCode: code.TokenValidateError,
})
}
return isBlacklisted, nil
}
// BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出)
func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error {
// 獲取用戶的所有 token
tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
if err != nil {
return use.wrapTokenError(ctx, wrapTokenErrorReq{
funcName: "BlacklistAllUserTokens.GetAccessTokensByUID",
req: uid,
err: err,
message: "failed to get user tokens",
errorCode: code.TokenValidateError,
})
}
// 為每個 token 創建黑名單條目
for _, token := range tokens {
// 解析 token 獲取 JTI 和過期時間
claims, err := parseClaims(token.AccessToken, use.Config.Token.AccessSecret, false)
if err != nil {
logx.WithContext(ctx).Errorw("failed to parse token for blacklisting",
logx.Field("uid", uid),
logx.Field("tokenID", token.ID),
logx.Field("error", err))
continue // 跳過無效的 token繼續處理其他 token
}
jti, exists := claims["jti"]
if !exists || jti == "" {
logx.WithContext(ctx).Errorw("token missing JTI claim",
logx.Field("uid", uid),
logx.Field("tokenID", token.ID))
continue
}
exp, exists := claims["exp"]
if !exists {
logx.WithContext(ctx).Errorw("token missing exp claim",
logx.Field("uid", uid),
logx.Field("tokenID", token.ID))
continue
}
// 將 exp 字符串轉換為 int64
expInt, err := strconv.ParseInt(exp, 10, 64)
if err != nil {
logx.WithContext(ctx).Errorw("failed to parse exp claim",
logx.Field("uid", uid),
logx.Field("tokenID", token.ID),
logx.Field("error", err))
continue
}
// 創建黑名單條目
blacklistEntry := &entity.BlacklistEntry{
JTI: jti,
UID: uid,
ExpiresAt: expInt,
CreatedAt: time.Now().Unix(),
}
// 添加到黑名單
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
if err != nil {
logx.WithContext(ctx).Errorw("failed to add token to blacklist",
logx.Field("uid", uid),
logx.Field("jti", jti),
logx.Field("error", err))
// 繼續處理其他 token不要因為一個失敗就停止
}
}
// 刪除用戶的所有 token 記錄
err = use.TokenRepo.DeleteAccessTokensByUID(ctx, uid)
if err != nil {
logx.WithContext(ctx).Errorw("failed to delete user tokens",
logx.Field("uid", uid),
logx.Field("error", err))
// 這不是致命錯誤,因為 token 已經被加入黑名單
}
logx.WithContext(ctx).Infow("all user tokens blacklisted",
logx.Field("uid", uid),
logx.Field("tokenCount", len(tokens)),
logx.Field("reason", reason))
return nil
}

View File

@ -0,0 +1,59 @@
package usecase
type tokenClaims map[string]string
func (tc tokenClaims) SetID(id string) {
tc["id"] = id
}
func (tc tokenClaims) SetRole(role string) {
tc["role"] = role
}
func (tc tokenClaims) SetDeviceID(deviceID string) {
tc["device_id"] = deviceID
}
func (tc tokenClaims) SetScope(scope string) {
tc["scope"] = scope
}
func (tc tokenClaims) SetAccount(account string) {
tc["account"] = account
}
func (tc tokenClaims) Role() string {
role, ok := tc["role"]
if !ok {
return ""
}
return role
}
func (tc tokenClaims) ID() string {
id, ok := tc["id"]
if !ok {
return ""
}
return id
}
func (tc tokenClaims) DeviceID() string {
deviceID, ok := tc["device_id"]
if !ok {
return ""
}
return deviceID
}
func (tc tokenClaims) UID() string {
uid, ok := tc["uid"]
if !ok {
return ""
}
return uid
}

View File

@ -0,0 +1,325 @@
package usecase
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestTokenClaims_SetAndGetID(t *testing.T) {
tests := []struct {
name string
id string
}{
{
name: "normal ID",
id: "token123",
},
{
name: "UUID ID",
id: "550e8400-e29b-41d4-a716-446655440000",
},
{
name: "empty ID",
id: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims)
tc.SetID(tt.id)
result := tc.ID()
assert.Equal(t, tt.id, result)
})
}
}
func TestTokenClaims_SetAndGetRole(t *testing.T) {
tests := []struct {
name string
role string
}{
{
name: "admin role",
role: "admin",
},
{
name: "user role",
role: "user",
},
{
name: "guest role",
role: "guest",
},
{
name: "empty role",
role: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims)
tc.SetRole(tt.role)
result := tc.Role()
assert.Equal(t, tt.role, result)
})
}
}
func TestTokenClaims_SetAndGetDeviceID(t *testing.T) {
tests := []struct {
name string
deviceID string
}{
{
name: "normal device ID",
deviceID: "device123",
},
{
name: "UUID device ID",
deviceID: "550e8400-e29b-41d4-a716-446655440000",
},
{
name: "empty device ID",
deviceID: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims)
tc.SetDeviceID(tt.deviceID)
result := tc.DeviceID()
assert.Equal(t, tt.deviceID, result)
})
}
}
func TestTokenClaims_SetAndGetScope(t *testing.T) {
tests := []struct {
name string
scope string
}{
{
name: "read write scope",
scope: "read write",
},
{
name: "read only scope",
scope: "read",
},
{
name: "admin scope",
scope: "admin",
},
{
name: "empty scope",
scope: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims)
tc.SetScope(tt.scope)
// Note: there's no GetScope method, so we just verify it's set
assert.Equal(t, tt.scope, tc["scope"])
})
}
}
func TestTokenClaims_SetAndGetAccount(t *testing.T) {
tests := []struct {
name string
account string
}{
{
name: "email account",
account: "user@example.com",
},
{
name: "username account",
account: "john_doe",
},
{
name: "phone account",
account: "+1234567890",
},
{
name: "empty account",
account: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims)
tc.SetAccount(tt.account)
// Note: there's no GetAccount method, so we just verify it's set
assert.Equal(t, tt.account, tc["account"])
})
}
}
func TestTokenClaims_SetAndGetUID(t *testing.T) {
tests := []struct {
name string
uid string
}{
{
name: "normal UID",
uid: "user123",
},
{
name: "UUID UID",
uid: "550e8400-e29b-41d4-a716-446655440000",
},
{
name: "empty UID",
uid: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tc := make(tokenClaims)
tc["uid"] = tt.uid
result := tc.UID()
assert.Equal(t, tt.uid, result)
})
}
}
func TestTokenClaims_GetNonExistentField(t *testing.T) {
tc := make(tokenClaims)
t.Run("get non-existent ID", func(t *testing.T) {
result := tc.ID()
assert.Empty(t, result)
})
t.Run("get non-existent Role", func(t *testing.T) {
result := tc.Role()
assert.Empty(t, result)
})
t.Run("get non-existent DeviceID", func(t *testing.T) {
result := tc.DeviceID()
assert.Empty(t, result)
})
t.Run("get non-existent UID", func(t *testing.T) {
result := tc.UID()
assert.Empty(t, result)
})
}
func TestTokenClaims_MultipleFields(t *testing.T) {
tc := make(tokenClaims)
tc.SetID("token123")
tc.SetRole("admin")
tc.SetDeviceID("device456")
tc.SetScope("read write")
tc.SetAccount("user@example.com")
tc["uid"] = "user789"
t.Run("verify all fields", func(t *testing.T) {
assert.Equal(t, "token123", tc.ID())
assert.Equal(t, "admin", tc.Role())
assert.Equal(t, "device456", tc.DeviceID())
assert.Equal(t, "read write", tc["scope"])
assert.Equal(t, "user@example.com", tc["account"])
assert.Equal(t, "user789", tc.UID())
})
}
func TestTokenClaims_Overwrite(t *testing.T) {
tc := make(tokenClaims)
t.Run("overwrite ID", func(t *testing.T) {
tc.SetID("token123")
assert.Equal(t, "token123", tc.ID())
tc.SetID("token456")
assert.Equal(t, "token456", tc.ID())
})
t.Run("overwrite Role", func(t *testing.T) {
tc.SetRole("user")
assert.Equal(t, "user", tc.Role())
tc.SetRole("admin")
assert.Equal(t, "admin", tc.Role())
})
}
func TestTokenClaims_MapBehavior(t *testing.T) {
tc := make(tokenClaims)
t.Run("can set custom fields", func(t *testing.T) {
tc["custom_field"] = "custom_value"
assert.Equal(t, "custom_value", tc["custom_field"])
})
t.Run("can iterate over fields", func(t *testing.T) {
tc2 := make(tokenClaims)
tc2.SetID("token123")
tc2.SetRole("admin")
tc2["uid"] = "user123"
count := 0
for range tc2 {
count++
}
assert.Equal(t, 3, count)
})
t.Run("can check field existence", func(t *testing.T) {
tc.SetID("token123")
_, exists := tc["id"]
assert.True(t, exists)
_, exists = tc["non_existent"]
assert.False(t, exists)
})
t.Run("can delete fields", func(t *testing.T) {
tc.SetRole("admin")
assert.Equal(t, "admin", tc.Role())
delete(tc, "role")
assert.Empty(t, tc.Role())
})
}
func TestTokenClaims_EmptyMap(t *testing.T) {
tc := make(tokenClaims)
assert.Empty(t, tc.ID())
assert.Empty(t, tc.Role())
assert.Empty(t, tc.DeviceID())
assert.Empty(t, tc.UID())
assert.Equal(t, 0, len(tc))
}
func TestTokenClaims_NilMap(t *testing.T) {
var tc tokenClaims
t.Run("get from nil map", func(t *testing.T) {
assert.Empty(t, tc.ID())
assert.Empty(t, tc.Role())
assert.Empty(t, tc.DeviceID())
assert.Empty(t, tc.UID())
})
}

View File

@ -0,0 +1,107 @@
package usecase
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"time"
"backend/pkg/permission/domain/entity"
"github.com/golang-jwt/jwt/v4"
)
var accessTokenGenerator = createAccessToken
var refreshTokenGenerator = createRefreshToken
// createAccessToken 生成訪問令牌Access Token
func createAccessToken(token entity.Token, data any, secretKey string) (string, error) {
claims := entity.Claims{
Data: data,
RegisteredClaims: jwt.RegisteredClaims{
ID: token.ID,
ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)),
Issuer: "permission",
},
}
accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
SignedString([]byte(secretKey))
if err != nil {
return "", err
}
return accessToken, nil
}
// createRefreshToken 基於訪問令牌生成刷新令牌Refresh Token
func createRefreshToken(accessToken string) string {
hash := sha256.New()
_, _ = hash.Write([]byte(accessToken))
return hex.EncodeToString(hash.Sum(nil))
}
func parseToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) {
// 跳過驗證的解析
var token *jwt.Token
var err error
if validate {
token, err = jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("token unexpected signing method: %v", token.Header["alg"])
}
return []byte(secret), nil
})
if err != nil {
return jwt.MapClaims{}, err
}
} else {
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
token, err = parser.Parse(accessToken, func(_ *jwt.Token) (any, error) {
return []byte(secret), nil
})
if err != nil {
return jwt.MapClaims{}, err
}
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok && token.Valid {
return jwt.MapClaims{}, fmt.Errorf("token valid error")
}
return claims, nil
}
func parseClaims(accessToken string, secret string, validate bool) (tokenClaims, error) {
claimMap, err := parseToken(accessToken, secret, validate)
if err != nil {
return tokenClaims{}, err
}
claimsData, ok := claimMap["data"].(map[string]any)
if ok {
return convertMap(claimsData), nil
}
return tokenClaims{}, fmt.Errorf("get data from claim map error")
}
func convertMap(input map[string]interface{}) map[string]string {
output := make(map[string]string)
for key, value := range input {
switch v := value.(type) {
case string:
output[key] = v
case fmt.Stringer:
output[key] = v.String()
default:
output[key] = fmt.Sprintf("%v", value)
}
}
return output
}

View File

@ -0,0 +1,329 @@
package usecase
import (
"testing"
"time"
"backend/pkg/permission/domain/entity"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/assert"
)
func TestCreateAccessToken(t *testing.T) {
tests := []struct {
name string
token entity.Token
data interface{}
secretKey string
wantErr bool
}{
{
name: "successful token creation",
token: entity.Token{
ID: "test-token-id",
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
},
data: map[string]string{
"uid": "user123",
"role": "admin",
},
secretKey: "test-secret-key",
wantErr: false,
},
{
name: "empty secret key",
token: entity.Token{
ID: "test-token-id",
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
},
data: map[string]string{"uid": "user123"},
secretKey: "",
wantErr: false, // JWT library will still create token with empty key
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokenStr, err := createAccessToken(tt.token, tt.data, tt.secretKey)
if tt.wantErr {
assert.Error(t, err)
assert.Empty(t, tokenStr)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, tokenStr)
// Verify the token can be parsed
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
return []byte(tt.secretKey), nil
})
if tt.secretKey != "" {
assert.NoError(t, err)
assert.True(t, token.Valid)
// Check claims
if claims, ok := token.Claims.(jwt.MapClaims); ok {
assert.Equal(t, tt.token.ID, claims["jti"])
assert.Equal(t, "permission", claims["iss"])
assert.NotNil(t, claims["exp"])
assert.NotNil(t, claims["data"])
}
}
}
})
}
}
func TestCreateRefreshToken(t *testing.T) {
tests := []struct {
name string
accessToken string
want string
}{
{
name: "consistent hash generation",
accessToken: "test-access-token",
want: "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08", // SHA256 of "test"
},
{
name: "different token different hash",
accessToken: "different-access-token",
want: "", // We'll check it's different
},
{
name: "empty token",
accessToken: "",
want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // SHA256 of empty string
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := createRefreshToken(tt.accessToken)
assert.NotEmpty(t, result)
assert.Len(t, result, 64) // SHA256 hex string length
if tt.want != "" {
if tt.name == "consistent hash generation" {
// For "test" input, we know the expected hash
testResult := createRefreshToken("test")
assert.Equal(t, tt.want, testResult)
} else if tt.name == "empty token" {
assert.Equal(t, tt.want, result)
}
}
// Test consistency - same input should produce same output
result2 := createRefreshToken(tt.accessToken)
assert.Equal(t, result, result2)
})
}
}
func TestParseToken(t *testing.T) {
secretKey := "test-secret-key"
// Create a valid token first
token := entity.Token{
ID: "test-id",
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
}
data := map[string]string{
"uid": "user123",
"role": "admin",
}
validTokenStr, err := createAccessToken(token, data, secretKey)
assert.NoError(t, err)
tests := []struct {
name string
accessToken string
secret string
validate bool
wantErr bool
}{
{
name: "valid token with validation",
accessToken: validTokenStr,
secret: secretKey,
validate: true,
wantErr: false,
},
{
name: "valid token without validation",
accessToken: validTokenStr,
secret: secretKey,
validate: false,
wantErr: false,
},
{
name: "invalid token",
accessToken: "invalid.token.string",
secret: secretKey,
validate: true,
wantErr: true,
},
{
name: "wrong secret",
accessToken: validTokenStr,
secret: "wrong-secret",
validate: true,
wantErr: true,
},
{
name: "empty token",
accessToken: "",
secret: secretKey,
validate: true,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims, err := parseToken(tt.accessToken, tt.secret, tt.validate)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, claims)
if tt.accessToken == validTokenStr {
assert.Equal(t, "test-id", claims["jti"])
assert.Equal(t, "permission", claims["iss"])
}
}
})
}
}
func TestParseClaims(t *testing.T) {
secretKey := "test-secret-key"
// Create a valid token with data claims
token := entity.Token{
ID: "test-id",
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
}
data := map[string]interface{}{
"uid": "user123",
"role": "admin",
"deviceId": "device456",
}
validTokenStr, err := createAccessToken(token, data, secretKey)
assert.NoError(t, err)
tests := []struct {
name string
accessToken string
secret string
validate bool
wantErr bool
expectUID string
expectRole string
}{
{
name: "valid token with data claims",
accessToken: validTokenStr,
secret: secretKey,
validate: false,
wantErr: false,
expectUID: "user123",
expectRole: "admin",
},
{
name: "invalid token",
accessToken: "invalid.token",
secret: secretKey,
validate: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims, err := parseClaims(tt.accessToken, tt.secret, tt.validate)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, claims)
if tt.expectUID != "" {
uid, exists := claims["uid"]
assert.True(t, exists)
assert.Equal(t, tt.expectUID, uid)
}
if tt.expectRole != "" {
role, exists := claims["role"]
assert.True(t, exists)
assert.Equal(t, tt.expectRole, role)
}
}
})
}
}
func TestConvertMap(t *testing.T) {
tests := []struct {
name string
input map[string]interface{}
expect map[string]string
}{
{
name: "string values",
input: map[string]interface{}{
"key1": "value1",
"key2": "value2",
},
expect: map[string]string{
"key1": "value1",
"key2": "value2",
},
},
{
name: "mixed types",
input: map[string]interface{}{
"string": "value",
"int": 123,
"float": 45.67,
"bool": true,
},
expect: map[string]string{
"string": "value",
"int": "123",
"float": "45.67",
"bool": "true",
},
},
{
name: "empty map",
input: map[string]interface{}{},
expect: map[string]string{},
},
{
name: "nil values",
input: map[string]interface{}{
"nil": nil,
},
expect: map[string]string{
"nil": "<nil>",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertMap(tt.input)
assert.Equal(t, tt.expect, result)
})
}
}

View File

@ -0,0 +1,435 @@
package usecase
import (
"context"
"testing"
"time"
"backend/internal/config"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/token"
"backend/pkg/permission/mock/repository"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestTokenUseCase_NewToken(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-access-secret",
RefreshSecret: "test-refresh-secret",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour,
MaxTokensPerUser: 10,
MaxTokensPerDevice: 5,
},
}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
tests := []struct {
name string
req entity.AuthorizationReq
setup func()
wantErr bool
}{
{
name: "successful token creation",
req: entity.AuthorizationReq{
GrantType: token.PasswordCredentials.ToString(),
Scope: "read write",
DeviceID: "device123",
IsRefreshToken: true,
Data: map[string]string{
"uid": "user123",
"role": "user",
},
},
setup: func() {
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(nil).Once()
},
wantErr: false,
},
{
name: "repository error",
req: entity.AuthorizationReq{
GrantType: token.PasswordCredentials.ToString(),
Scope: "read",
DeviceID: "device123",
},
setup: func() {
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(assert.AnError).Once()
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
resp, err := useCase.NewToken(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
assert.Empty(t, resp.AccessToken)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, resp.AccessToken)
assert.Equal(t, token.TypeBearer.String(), resp.TokenType)
assert.Greater(t, resp.ExpiresIn, int64(0))
if tt.req.IsRefreshToken {
assert.NotEmpty(t, resp.RefreshToken)
}
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_ValidationToken(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-access-secret",
RefreshSecret: "test-refresh-secret",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour,
},
}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
// 先創建一個有效的 token 用於測試
tokenReq := entity.AuthorizationReq{
GrantType: token.PasswordCredentials.ToString(),
Data: map[string]string{
"uid": "user123",
"role": "user",
},
}
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(nil).Once()
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
assert.NoError(t, err)
assert.NotEmpty(t, tokenResp.AccessToken)
// 測試驗證
tests := []struct {
name string
req entity.ValidationTokenReq
setup func()
wantErr bool
}{
{
name: "valid token",
req: entity.ValidationTokenReq{
Token: tokenResp.AccessToken,
},
setup: func() {
mockRepo.On("GetAccessTokenByID", mock.Anything, mock.AnythingOfType("string")).
Return(entity.Token{
ID: "test-id",
UID: "user123",
AccessToken: tokenResp.AccessToken,
ExpiresIn: int(cfg.Token.AccessTokenExpiry.Seconds()),
}, nil).Once()
},
wantErr: false,
},
{
name: "invalid token",
req: entity.ValidationTokenReq{
Token: "invalid-token",
},
setup: func() {
// parseClaims will fail for invalid token
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
resp, err := useCase.ValidationToken(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, resp.Token.ID)
assert.Equal(t, "user123", resp.Token.UID)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_BlacklistToken(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-access-secret",
RefreshSecret: "test-refresh-secret",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour,
},
}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
// 先創建一個有效的 token
tokenReq := entity.AuthorizationReq{
GrantType: token.PasswordCredentials.ToString(),
Data: map[string]string{
"uid": "user123",
"role": "user",
},
}
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(nil).Once()
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
assert.NoError(t, err)
tests := []struct {
name string
token string
reason string
setup func()
wantErr bool
}{
{
name: "successful blacklist",
token: tokenResp.AccessToken,
reason: "user logout",
setup: func() {
mockRepo.On("AddToBlacklist", mock.Anything, mock.AnythingOfType("*entity.BlacklistEntry"), mock.AnythingOfType("time.Duration")).
Return(nil).Once()
},
wantErr: false,
},
{
name: "invalid token",
token: "invalid-token",
reason: "test",
setup: func() {},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
err := useCase.BlacklistToken(context.Background(), tt.token, tt.reason)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_IsTokenBlacklisted(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-secret",
},
}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
tests := []struct {
name string
jti string
setup func()
wantResult bool
wantErr bool
}{
{
name: "token is blacklisted",
jti: "test-jti-123",
setup: func() {
mockRepo.On("IsBlacklisted", mock.Anything, "test-jti-123").
Return(true, nil).Once()
},
wantResult: true,
wantErr: false,
},
{
name: "token is not blacklisted",
jti: "test-jti-456",
setup: func() {
mockRepo.On("IsBlacklisted", mock.Anything, "test-jti-456").
Return(false, nil).Once()
},
wantResult: false,
wantErr: false,
},
{
name: "repository error",
jti: "test-jti-error",
setup: func() {
mockRepo.On("IsBlacklisted", mock.Anything, "test-jti-error").
Return(false, assert.AnError).Once()
},
wantResult: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
result, err := useCase.IsTokenBlacklisted(context.Background(), tt.jti)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.wantResult, result)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_CancelTokens(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
tests := []struct {
name string
req entity.DoTokenByUIDReq
setup func()
wantErr bool
}{
{
name: "cancel by UID",
req: entity.DoTokenByUIDReq{
UID: "user123",
},
setup: func() {
mockRepo.On("DeleteAccessTokensByUID", mock.Anything, "user123").
Return(nil).Once()
},
wantErr: false,
},
{
name: "cancel by token IDs",
req: entity.DoTokenByUIDReq{
IDs: []string{"token1", "token2"},
},
setup: func() {
mockRepo.On("DeleteAccessTokenByID", mock.Anything, []string{"token1", "token2"}).
Return(nil).Once()
},
wantErr: false,
},
{
name: "repository error",
req: entity.DoTokenByUIDReq{
UID: "user123",
},
setup: func() {
mockRepo.On("DeleteAccessTokensByUID", mock.Anything, "user123").
Return(assert.AnError).Once()
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
err := useCase.CancelTokens(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
mockRepo.AssertExpectations(t)
})
}
}

View File

@ -0,0 +1,565 @@
package usecase
import (
"context"
"testing"
"time"
"backend/internal/config"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/token"
"backend/pkg/permission/mock/repository"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
)
func TestTokenUseCase_RefreshToken(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-access-secret",
RefreshSecret: "test-refresh-secret",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour,
},
}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
// Create a base token first
tokenReq := entity.AuthorizationReq{
GrantType: token.PasswordCredentials.ToString(),
Data: map[string]string{
"uid": "user123",
"role": "user",
},
IsRefreshToken: true,
}
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(nil).Once()
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
assert.NoError(t, err)
tests := []struct {
name string
req entity.RefreshTokenReq
setup func()
wantErr bool
}{
{
name: "successful token refresh",
req: entity.RefreshTokenReq{
Token: tokenResp.RefreshToken,
Scope: "read write",
DeviceID: "device123",
},
setup: func() {
existingToken := entity.Token{
ID: "old-token-id",
UID: "user123",
AccessToken: tokenResp.AccessToken,
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
}
mockRepo.On("GetAccessTokenByOneTimeToken", mock.Anything, tokenResp.RefreshToken).
Return(existingToken, nil).Once()
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(nil).Once()
mockRepo.On("Delete", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(nil).Once()
},
wantErr: false,
},
{
name: "invalid refresh token",
req: entity.RefreshTokenReq{
Token: "invalid-refresh-token",
Scope: "read",
DeviceID: "device123",
},
setup: func() {
mockRepo.On("GetAccessTokenByOneTimeToken", mock.Anything, "invalid-refresh-token").
Return(entity.Token{}, assert.AnError).Once()
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
resp, err := useCase.RefreshToken(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, resp.Token)
assert.NotEmpty(t, resp.OneTimeToken)
assert.Equal(t, token.TypeBearer.String(), resp.TokenType)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_GetUserTokensByUID(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
tests := []struct {
name string
req entity.QueryTokenByUIDReq
setup func()
wantErr bool
}{
{
name: "get tokens successfully",
req: entity.QueryTokenByUIDReq{
UID: "user123",
},
setup: func() {
tokens := []entity.Token{
{
ID: "token1",
UID: "user123",
AccessToken: "access1",
ExpiresIn: 3600,
},
{
ID: "token2",
UID: "user123",
AccessToken: "access2",
ExpiresIn: 3600,
},
}
mockRepo.On("GetAccessTokensByUID", mock.Anything, "user123").
Return(tokens, nil).Once()
},
wantErr: false,
},
{
name: "repository error",
req: entity.QueryTokenByUIDReq{
UID: "user456",
},
setup: func() {
mockRepo.On("GetAccessTokensByUID", mock.Anything, "user456").
Return([]entity.Token(nil), assert.AnError).Once()
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
tokens, err := useCase.GetUserTokensByUID(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, tokens)
} else {
assert.NoError(t, err)
assert.NotNil(t, tokens)
assert.Greater(t, len(tokens), 0)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_GetUserTokensByDeviceID(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
tests := []struct {
name string
req entity.DoTokenByDeviceIDReq
setup func()
wantErr bool
}{
{
name: "get tokens by device successfully",
req: entity.DoTokenByDeviceIDReq{
DeviceID: "device123",
},
setup: func() {
tokens := []entity.Token{
{
ID: "token1",
UID: "user123",
DeviceID: "device123",
AccessToken: "access1",
ExpiresIn: 3600,
},
}
mockRepo.On("GetAccessTokensByDeviceID", mock.Anything, "device123").
Return(tokens, nil).Once()
},
wantErr: false,
},
{
name: "repository error",
req: entity.DoTokenByDeviceIDReq{
DeviceID: "device456",
},
setup: func() {
mockRepo.On("GetAccessTokensByDeviceID", mock.Anything, "device456").
Return([]entity.Token(nil), assert.AnError).Once()
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
tokens, err := useCase.GetUserTokensByDeviceID(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, tokens)
} else {
assert.NoError(t, err)
assert.NotNil(t, tokens)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_CancelTokenByDeviceID(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
tests := []struct {
name string
req entity.DoTokenByDeviceIDReq
setup func()
wantErr bool
}{
{
name: "cancel tokens successfully",
req: entity.DoTokenByDeviceIDReq{
DeviceID: "device123",
},
setup: func() {
mockRepo.On("DeleteAccessTokensByDeviceID", mock.Anything, "device123").
Return(nil).Once()
},
wantErr: false,
},
{
name: "repository error",
req: entity.DoTokenByDeviceIDReq{
DeviceID: "device456",
},
setup: func() {
mockRepo.On("DeleteAccessTokensByDeviceID", mock.Anything, "device456").
Return(assert.AnError).Once()
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
err := useCase.CancelTokenByDeviceID(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_NewOneTimeToken(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-access-secret",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour,
},
}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
// Create a base token first
tokenReq := entity.AuthorizationReq{
GrantType: token.PasswordCredentials.ToString(),
Data: map[string]string{
"uid": "user123",
"role": "user",
},
}
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(nil).Once()
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
assert.NoError(t, err)
tests := []struct {
name string
req entity.CreateOneTimeTokenReq
setup func()
wantErr bool
}{
{
name: "create one-time token successfully",
req: entity.CreateOneTimeTokenReq{
Token: tokenResp.AccessToken,
},
setup: func() {
existingToken := entity.Token{
ID: "token-id",
UID: "user123",
AccessToken: tokenResp.AccessToken,
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
}
mockRepo.On("GetAccessTokenByID", mock.Anything, mock.AnythingOfType("string")).
Return(existingToken, nil).Once()
mockRepo.On("CreateOneTimeToken", mock.Anything, mock.AnythingOfType("string"),
mock.AnythingOfType("entity.Ticket"), mock.AnythingOfType("time.Duration")).
Return(nil).Once()
},
wantErr: false,
},
{
name: "invalid token",
req: entity.CreateOneTimeTokenReq{
Token: "invalid-token",
},
setup: func() {
// parseClaims will fail
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
resp, err := useCase.NewOneTimeToken(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotEmpty(t, resp.OneTimeToken)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_CancelOneTimeToken(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
tests := []struct {
name string
req entity.CancelOneTimeTokenReq
setup func()
wantErr bool
}{
{
name: "cancel one-time token successfully",
req: entity.CancelOneTimeTokenReq{
Token: []string{"token1", "token2"},
},
setup: func() {
mockRepo.On("DeleteOneTimeToken", mock.Anything, []string{"token1", "token2"}, mock.Anything).
Return(nil).Once()
},
wantErr: false,
},
{
name: "repository error",
req: entity.CancelOneTimeTokenReq{
Token: []string{"token3"},
},
setup: func() {
mockRepo.On("DeleteOneTimeToken", mock.Anything, []string{"token3"}, mock.Anything).
Return(assert.AnError).Once()
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.setup()
err := useCase.CancelOneTimeToken(context.Background(), tt.req)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
mockRepo.AssertExpectations(t)
})
}
}
func TestTokenUseCase_ReadTokenBasicData(t *testing.T) {
mockRepo := repository.NewMockTokenRepository(t)
cfg := &config.Config{
Token: struct {
AccessSecret string
RefreshSecret string
AccessTokenExpiry time.Duration
RefreshTokenExpiry time.Duration
OneTimeTokenExpiry time.Duration
MaxTokensPerUser int
MaxTokensPerDevice int
}{
AccessSecret: "test-access-secret",
AccessTokenExpiry: 15 * time.Minute,
RefreshTokenExpiry: 7 * 24 * time.Hour,
},
}
useCase := &TokenUseCase{
TokenUseCaseParam: TokenUseCaseParam{
TokenRepo: mockRepo,
Config: cfg,
},
}
// Create a valid token first
tokenReq := entity.AuthorizationReq{
GrantType: token.PasswordCredentials.ToString(),
Data: map[string]string{
"uid": "user123",
"role": "admin",
},
Role: "admin",
}
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
Return(nil).Once()
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
assert.NoError(t, err)
tests := []struct {
name string
token string
wantErr bool
}{
{
name: "read valid token",
token: tokenResp.AccessToken,
wantErr: false,
},
{
name: "invalid token",
token: "invalid-token",
wantErr: true,
},
{
name: "empty token",
token: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
claims, err := useCase.ReadTokenBasicData(context.Background(), tt.token)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.NotNil(t, claims)
assert.Equal(t, "user123", claims["uid"])
assert.Equal(t, "admin", claims["role"])
}
mockRepo.AssertExpectations(t)
})
}
}
// TestTokenUseCase_BlacklistAllUserTokens is commented out due to complexity of mocking
// the JWT parsing within the loop. The functionality is tested through integration tests.
// func TestTokenUseCase_BlacklistAllUserTokens(t *testing.T) { ... }

View File

@ -1,225 +0,0 @@
package usecase
import (
"backend/pkg/library/errs/code"
"backend/pkg/permission/domain/permission"
"backend/pkg/permission/utils"
"context"
"backend/pkg/library/errs"
"backend/pkg/permission/domain/entity"
"backend/pkg/permission/domain/repository"
"backend/pkg/permission/domain/usecase"
)
type UserRoleUseCaseParam struct {
UserRoleRepo repository.UserRoleRepository
RoleRepo repository.RoleRepository
}
type UserRoleUseCase struct {
userRoleRepo repository.UserRoleRepository
roleRepo repository.RoleRepository
}
// MustUserRoleUseCase 創建用戶角色用例實例
func MustUserRoleUseCase(param UserRoleUseCaseParam) usecase.UserRoleUseCase {
return &UserRoleUseCase{
userRoleRepo: param.UserRoleRepo,
roleRepo: param.RoleRepo,
}
}
func (uc *UserRoleUseCase) AssignRole(ctx context.Context, req usecase.AssignRoleRequest) (*entity.UserRole, error) {
// 驗證請求
if req.UID == "" {
return nil, errs.InvalidFormat("uid is required")
}
if req.RoleUID == "" {
return nil, errs.InvalidFormat("role_uid is required")
}
if req.Brand == "" {
return nil, errs.InvalidFormat("brand is required")
}
// 檢查角色是否存在
role, err := uc.roleRepo.GetByUID(ctx, req.RoleUID)
if err != nil {
return nil, err
}
if !utils.IsActive(role.Status) {
return nil, errs.InvalidFormat("role is not active")
}
// 檢查用戶是否已經有此角色
existingUserRole, err := uc.userRoleRepo.GetByUserAndRole(ctx, req.UID, req.RoleUID)
if err == nil && existingUserRole != nil && utils.IsActive(existingUserRole.Status) {
return nil, errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, req.UID+":"+req.RoleUID)
}
userRole := &entity.UserRole{
Brand: req.Brand,
UID: req.UID,
RoleUID: req.RoleUID,
Status: permission.StatusActive,
}
if err := uc.userRoleRepo.Create(ctx, userRole); err != nil {
return nil, err
}
return userRole, nil
}
func (uc *UserRoleUseCase) RevokeRole(ctx context.Context, uid, roleUID string) error {
// 驗證參數
if uid == "" {
return errs.InvalidFormat("uid is required")
}
if roleUID == "" {
return errs.InvalidFormat("role_uid is required")
}
return uc.userRoleRepo.DeleteByUserAndRole(ctx, uid, roleUID)
}
func (uc *UserRoleUseCase) GetUserRole(ctx context.Context, id string) (*entity.UserRole, error) {
return uc.userRoleRepo.GetByID(ctx, id)
}
func (uc *UserRoleUseCase) UpdateUserRole(ctx context.Context, req usecase.UpdateUserRoleRequest) (*entity.UserRole, error) {
// 獲取現有用戶角色
userRole, err := uc.userRoleRepo.GetByID(ctx, req.ID)
if err != nil {
return nil, err
}
// 更新狀態
if req.Status != nil {
userRole.Status = *req.Status
}
if err := uc.userRoleRepo.Update(ctx, req.ID, userRole); err != nil {
return nil, err
}
return userRole, nil
}
func (uc *UserRoleUseCase) ListUserRoles(ctx context.Context, req usecase.ListUserRolesRequest) ([]*entity.UserRole, error) {
filter := repository.UserRoleFilter{
Brand: req.Brand,
UID: req.UID,
RoleUID: req.RoleUID,
Status: req.Status,
Limit: req.Limit,
Skip: req.Skip,
}
return uc.userRoleRepo.List(ctx, filter)
}
func (uc *UserRoleUseCase) GetUserRoles(ctx context.Context, uid string) ([]*entity.UserRole, error) {
if uid == "" {
return nil, errs.InvalidFormat("uid is required")
}
return uc.userRoleRepo.GetUserRolesByUID(ctx, uid)
}
func (uc *UserRoleUseCase) GetUserRoleDetails(ctx context.Context, uid string) ([]*usecase.UserRoleDetail, error) {
// 獲取用戶角色
userRoles, err := uc.GetUserRoles(ctx, uid)
if err != nil {
return nil, err
}
var details []*usecase.UserRoleDetail
for _, userRole := range userRoles {
if !utils.IsActive(userRole.Status) {
continue
}
// 獲取角色詳情
role, err := uc.roleRepo.GetByUID(ctx, userRole.RoleUID)
if err != nil {
continue // 忽略獲取失敗的角色
}
detail := &usecase.UserRoleDetail{
UserRole: userRole,
Role: role,
}
details = append(details, detail)
}
return details, nil
}
func (uc *UserRoleUseCase) BatchAssignRoles(ctx context.Context, uid string, roleUIDs []string, brand string) error {
if uid == "" {
return errs.InvalidFormat("uid is required")
}
if brand == "" {
return errs.InvalidFormat("brand is required")
}
// 逐個分配角色
for _, roleUID := range roleUIDs {
req := usecase.AssignRoleRequest{
Brand: brand,
UID: uid,
RoleUID: roleUID,
}
_, err := uc.AssignRole(ctx, req)
if err != nil {
// 如果是已存在錯誤,忽略繼續
e := errs.FromError(err)
if e.Is(errs.ResourceAlreadyExist()) {
continue
}
return err
}
}
return nil
}
func (uc *UserRoleUseCase) BatchRevokeRoles(ctx context.Context, uid string, roleUIDs []string) error {
if uid == "" {
return errs.InvalidFormat("uid is required")
}
// 逐個撤銷角色
for _, roleUID := range roleUIDs {
err := uc.RevokeRole(ctx, uid, roleUID)
if err != nil {
continue
}
}
return nil
}
func (uc *UserRoleUseCase) ReplaceUserRoles(ctx context.Context, uid string, roleUIDs []string, brand string) error {
// 獲取用戶當前的所有角色
currentUserRoles, err := uc.GetUserRoles(ctx, uid)
if err != nil {
return err
}
// 撤銷所有現有角色
for _, userRole := range currentUserRoles {
if utils.IsActive(userRole.Status) {
if err := uc.RevokeRole(ctx, uid, userRole.RoleUID); err != nil {
return err
}
}
}
// 分配新角色
return uc.BatchAssignRoles(ctx, uid, roleUIDs, brand)
}

View File

@ -1,7 +0,0 @@
package utils
import "backend/pkg/permission/domain/permission"
func IsActive(status int) bool {
return status == permission.StatusActive
}