feat: add token service
This commit is contained in:
parent
31ab87aadc
commit
40812db5bf
|
|
@ -42,3 +42,12 @@ LineAuth:
|
||||||
ClientID : "200000000"
|
ClientID : "200000000"
|
||||||
ClientSecret : xxxxx
|
ClientSecret : xxxxx
|
||||||
RedirectURI : http://localhost:8080/line.html
|
RedirectURI : http://localhost:8080/line.html
|
||||||
|
|
||||||
|
Token:
|
||||||
|
AccessSecret : "1qaz@WSX3edc$RFV"
|
||||||
|
RefreshSecret : "1qaz@WSX3edc$RFV"
|
||||||
|
AccessTokenExpiry : 600s
|
||||||
|
RefreshTokenExpiry : 86400s
|
||||||
|
OneTimeTokenExpiry : 600s
|
||||||
|
MaxTokensPerUser : 2
|
||||||
|
MaxTokensPerDevice : 2
|
||||||
|
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
[request_definition]
|
|
||||||
r = sub, obj, act
|
|
||||||
|
|
||||||
[policy_definition]
|
|
||||||
p = sub, obj, act
|
|
||||||
|
|
||||||
[role_definition]
|
|
||||||
g = _, _
|
|
||||||
|
|
||||||
[policy_effect]
|
|
||||||
e = some(where (p.eft == allow))
|
|
||||||
|
|
||||||
[matchers]
|
|
||||||
m = g(r.sub, p.sub) && keyMatch2(r.obj, p.obj) && regexMatch(r.act, p.act)
|
|
||||||
9
go.mod
9
go.mod
|
|
@ -10,12 +10,12 @@ require (
|
||||||
github.com/aws/aws-sdk-go-v2 v1.39.2
|
github.com/aws/aws-sdk-go-v2 v1.39.2
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.16
|
github.com/aws/aws-sdk-go-v2/credentials v1.18.16
|
||||||
github.com/aws/aws-sdk-go-v2/service/ses v1.34.5
|
github.com/aws/aws-sdk-go-v2/service/ses v1.34.5
|
||||||
github.com/casbin/casbin/v2 v2.127.0
|
|
||||||
github.com/go-playground/validator/v10 v10.27.0
|
github.com/go-playground/validator/v10 v10.27.0
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||||
github.com/matcornic/hermes/v2 v2.1.0
|
github.com/matcornic/hermes/v2 v2.1.0
|
||||||
github.com/minchao/go-mitake v1.0.0
|
github.com/minchao/go-mitake v1.0.0
|
||||||
github.com/panjf2000/ants/v2 v2.11.3
|
github.com/panjf2000/ants/v2 v2.11.3
|
||||||
|
github.com/segmentio/ksuid v1.0.4
|
||||||
github.com/shopspring/decimal v1.4.0
|
github.com/shopspring/decimal v1.4.0
|
||||||
github.com/stretchr/testify v1.11.1
|
github.com/stretchr/testify v1.11.1
|
||||||
github.com/testcontainers/testcontainers-go v0.39.0
|
github.com/testcontainers/testcontainers-go v0.39.0
|
||||||
|
|
@ -42,8 +42,6 @@ require (
|
||||||
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9 // indirect
|
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.9 // indirect
|
||||||
github.com/aws/smithy-go v1.23.0 // indirect
|
github.com/aws/smithy-go v1.23.0 // indirect
|
||||||
github.com/beorn7/perks v1.0.1 // indirect
|
github.com/beorn7/perks v1.0.1 // indirect
|
||||||
github.com/bmatcuk/doublestar/v4 v4.6.1 // indirect
|
|
||||||
github.com/casbin/govaluate v1.3.0 // indirect
|
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||||
github.com/containerd/errdefs v1.0.0 // indirect
|
github.com/containerd/errdefs v1.0.0 // indirect
|
||||||
|
|
@ -67,7 +65,6 @@ require (
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/gogo/protobuf v1.3.2 // indirect
|
github.com/gogo/protobuf v1.3.2 // indirect
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2 // indirect
|
|
||||||
github.com/golang/snappy v1.0.0 // indirect
|
github.com/golang/snappy v1.0.0 // indirect
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
github.com/google/uuid v1.6.0 // indirect
|
||||||
github.com/gorilla/css v1.0.0 // indirect
|
github.com/gorilla/css v1.0.0 // indirect
|
||||||
|
|
@ -108,12 +105,12 @@ require (
|
||||||
github.com/redis/go-redis/v9 v9.15.0 // indirect
|
github.com/redis/go-redis/v9 v9.15.0 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
github.com/russross/blackfriday/v2 v2.0.1 // indirect
|
github.com/russross/blackfriday/v2 v2.0.1 // indirect
|
||||||
github.com/shirou/gopsutil/v3 v3.24.5 // indirect
|
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||||
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
|
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||||
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
|
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf // indirect
|
||||||
|
github.com/stretchr/objx v0.5.2 // indirect
|
||||||
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
github.com/tklauser/go-sysconf v0.3.12 // indirect
|
||||||
github.com/tklauser/numcpus v0.6.1 // indirect
|
github.com/tklauser/numcpus v0.6.1 // indirect
|
||||||
github.com/vanng822/css v0.0.0-20190504095207-a21e860bcd04 // indirect
|
github.com/vanng822/css v0.0.0-20190504095207-a21e860bcd04 // indirect
|
||||||
|
|
|
||||||
21
go.sum
21
go.sum
|
|
@ -34,16 +34,10 @@ github.com/aws/smithy-go v1.23.0 h1:8n6I3gXzWJB2DxBDnfxgBaSX6oe0d/t10qGz7OKqMCE=
|
||||||
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
github.com/aws/smithy-go v1.23.0/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/bmatcuk/doublestar/v4 v4.6.1 h1:FH9SifrbvJhnlQpztAx++wlkk70QBf0iBWDwNy7PA4I=
|
|
||||||
github.com/bmatcuk/doublestar/v4 v4.6.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc=
|
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||||
github.com/casbin/casbin/v2 v2.127.0 h1:UGK3uO/8cOslnNqFUJ4xzm/bh+N+o45U7cSolaFk38c=
|
|
||||||
github.com/casbin/casbin/v2 v2.127.0/go.mod h1:n4uZK8+tCMvcD6EVQZI90zKAok8iHAvEypcMJVKhGF0=
|
|
||||||
github.com/casbin/govaluate v1.3.0 h1:VA0eSY0M2lA86dYd5kPPuNZMUD9QkWnOCnavGrw9myc=
|
|
||||||
github.com/casbin/govaluate v1.3.0/go.mod h1:G/UnbIjZk/0uMNaLwZZmFQrR72tYRZWQkO70si/iR7A=
|
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
|
@ -101,17 +95,11 @@ github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
|
||||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
|
||||||
github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc=
|
|
||||||
github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4=
|
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||||
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
|
||||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||||
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
|
@ -220,12 +208,10 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
|
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
|
||||||
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/iZ9pI=
|
github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c=
|
||||||
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
github.com/segmentio/ksuid v1.0.4/go.mod h1:/XUiZBD3kVx5SmUOl55voK5yeAbBNNIed+2O73XgrPE=
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
github.com/shirou/gopsutil/v4 v4.25.6/go.mod h1:PfybzyydfZcN+JMMjkF6Zb8Mq1A/VcogFFg7hj50W9c=
|
||||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
|
||||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
|
||||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||||
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
|
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
|
||||||
|
|
@ -326,7 +312,6 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||||
golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
|
|
@ -357,7 +342,6 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|
||||||
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k=
|
||||||
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
|
|
@ -374,7 +358,6 @@ golang.org/x/text v0.29.0/go.mod h1:7MhJOA9CD2qZyOKYazxdYMF85OwPdEr9jTtBpO7ydH4=
|
||||||
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4=
|
||||||
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
|
||||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||||
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
|
||||||
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
|
||||||
|
|
|
||||||
|
|
@ -49,4 +49,15 @@ type Config struct {
|
||||||
ClientSecret string
|
ClientSecret string
|
||||||
RedirectURI string
|
RedirectURI string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JWT Token 配置
|
||||||
|
Token struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
AccessTokenExpiry time.Duration
|
||||||
|
RefreshTokenExpiry time.Duration
|
||||||
|
OneTimeTokenExpiry time.Duration
|
||||||
|
MaxTokensPerUser int
|
||||||
|
MaxTokensPerDevice int
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,8 @@ import (
|
||||||
"backend/pkg/library/errs/code"
|
"backend/pkg/library/errs/code"
|
||||||
mb "backend/pkg/member/domain/member"
|
mb "backend/pkg/member/domain/member"
|
||||||
member "backend/pkg/member/domain/usecase"
|
member "backend/pkg/member/domain/usecase"
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
"backend/pkg/permission/domain/token"
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -27,7 +29,7 @@ var PrepareFunc map[string]func(ctx context.Context, req *types.LoginReq, svc *s
|
||||||
mb.Line.ToString(): buildLineData,
|
mb.Line.ToString(): buildLineData,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 註冊新帳號
|
// NewRegisterLogic 註冊新帳號
|
||||||
func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RegisterLogic {
|
func NewRegisterLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RegisterLogic {
|
||||||
return &RegisterLogic{
|
return &RegisterLogic{
|
||||||
Logger: logx.WithContext(ctx),
|
Logger: logx.WithContext(ctx),
|
||||||
|
|
@ -80,16 +82,16 @@ func (l *RegisterLogic) Register(req *types.LoginReq) (resp *types.LoginResp, er
|
||||||
|
|
||||||
// Step 5: 生成 Token
|
// Step 5: 生成 Token
|
||||||
req.LoginID = bd.CreateAccountReq.LoginID
|
req.LoginID = bd.CreateAccountReq.LoginID
|
||||||
token, err := l.generateToken(req, account.UID)
|
tk, err := l.generateToken(req, account.UID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &types.LoginResp{
|
return &types.LoginResp{
|
||||||
UID: account.UID,
|
UID: account.UID,
|
||||||
AccessToken: token.AccessToken,
|
AccessToken: tk.AccessToken,
|
||||||
RefreshToken: token.RefreshToken,
|
RefreshToken: tk.RefreshToken,
|
||||||
TokenType: token.TokenType,
|
TokenType: tk.TokenType,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -183,40 +185,33 @@ func buildLineData(ctx context.Context, req *types.LoginReq, svc *svc.ServiceCon
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MockToken struct {
|
// 生成 Token
|
||||||
AccessToken string `json:"access_token"` // 訪問令牌
|
func (l *RegisterLogic) generateToken(req *types.LoginReq, uid string) (entity.TokenResp, error) {
|
||||||
TokenType string `json:"token_type"` // 令牌類型
|
// scope role 要修改,refresh tl
|
||||||
ExpiresIn int64 `json:"expires_in"` // 過期時間(秒)
|
role := "user"
|
||||||
RefreshToken string `json:"refresh_token"` // 刷新令牌
|
|
||||||
|
tk, err := l.svcCtx.TokenUC.NewToken(l.ctx, entity.AuthorizationReq{
|
||||||
|
GrantType: token.ClientCredentials.ToString(),
|
||||||
|
DeviceID: uid, // TODO 沒傳暫時先用UID 替代
|
||||||
|
Scope: "gateway",
|
||||||
|
IsRefreshToken: true,
|
||||||
|
Expires: time.Now().UTC().Add(l.svcCtx.Config.Token.AccessTokenExpiry).Unix(),
|
||||||
|
Data: map[string]string{
|
||||||
|
"uid": uid,
|
||||||
|
"role": role,
|
||||||
|
},
|
||||||
|
Role: role,
|
||||||
|
Account: req.LoginID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return entity.TokenResp{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 生成 Token
|
return entity.TokenResp{
|
||||||
func (l *RegisterLogic) generateToken(req *types.LoginReq, uid string) (MockToken, error) {
|
AccessToken: tk.AccessToken,
|
||||||
//credentials := tokenModule.ClientCredentials
|
TokenType: tk.TokenType,
|
||||||
//role := "user"
|
ExpiresIn: tk.ExpiresIn,
|
||||||
//if isTruHeartEmail(req.Account) {
|
RefreshToken: tk.RefreshToken,
|
||||||
// role = "admin"
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//return l.svcCtx.TokenUseCase.NewToken(l.ctx, tokenModule.AuthorizationReq{
|
|
||||||
// GrantType: credentials.ToString(),
|
|
||||||
// DeviceID: req.DeviceID,
|
|
||||||
// Scope: domain.DefaultScope,
|
|
||||||
// IsRefreshToken: true,
|
|
||||||
// Expires: time.Now().UTC().Add(l.svcCtx.Config.Token.Expired).Unix(),
|
|
||||||
// Data: map[string]string{
|
|
||||||
// "uid": uid,
|
|
||||||
// "role": role,
|
|
||||||
// "account": req.Account,
|
|
||||||
// },
|
|
||||||
// Role: role,
|
|
||||||
//})
|
|
||||||
|
|
||||||
return MockToken{
|
|
||||||
AccessToken: "gg88g88",
|
|
||||||
TokenType: "Bearer",
|
|
||||||
ExpiresIn: time.Now().UTC().Add(100000000000).Unix(),
|
|
||||||
RefreshToken: "gg88g88",
|
|
||||||
}, nil
|
}, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,194 +0,0 @@
|
||||||
package svc
|
|
||||||
|
|
||||||
//func NewPermissionUC(c *config.Config, rds *redis.Redis) usecase.PermissionUseCase {
|
|
||||||
// // 準備Mongo Config (重用現有配置)
|
|
||||||
// conf := &mgo.Conf{
|
|
||||||
// Schema: c.Mongo.Schema,
|
|
||||||
// Host: c.Mongo.Host,
|
|
||||||
// Database: c.Mongo.Database,
|
|
||||||
// MaxStaleness: c.Mongo.MaxStaleness,
|
|
||||||
// MaxPoolSize: c.Mongo.MaxPoolSize,
|
|
||||||
// MinPoolSize: c.Mongo.MinPoolSize,
|
|
||||||
// MaxConnIdleTime: c.Mongo.MaxConnIdleTime,
|
|
||||||
// Compressors: c.Mongo.Compressors,
|
|
||||||
// EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode,
|
|
||||||
// ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs,
|
|
||||||
// }
|
|
||||||
// if c.Mongo.User != "" {
|
|
||||||
// conf.User = c.Mongo.User
|
|
||||||
// conf.Password = c.Mongo.Password
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 快取選項
|
|
||||||
// cacheOpts := []cache.Option{
|
|
||||||
// cache.WithExpiry(c.CacheExpireTime),
|
|
||||||
// cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry),
|
|
||||||
// }
|
|
||||||
// dbOpts := []mon.Option{
|
|
||||||
// mgo.SetCustomDecimalType(),
|
|
||||||
// mgo.InitMongoOptions(*conf),
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 初始化 Casbin Adapter
|
|
||||||
// casbinAdapter := repository.NewCasbinAdapter(repository.CasbinAdapterParam{
|
|
||||||
// Conf: conf,
|
|
||||||
// CacheConf: c.Cache,
|
|
||||||
// CacheOpts: cacheOpts,
|
|
||||||
// DBOpts: dbOpts,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// // 初始化 Casbin Enforcer
|
|
||||||
// modelPath := "pkg/permission/config/rbac_model.conf"
|
|
||||||
// enforcer, err := casbin.NewEnforcer(modelPath, casbinAdapter)
|
|
||||||
// if err != nil {
|
|
||||||
// panic("Failed to create casbin enforcer: " + err.Error())
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 啟用自動保存
|
|
||||||
// enforcer.EnableAutoSave(true)
|
|
||||||
//
|
|
||||||
// // 載入策略
|
|
||||||
// err = enforcer.LoadPolicy()
|
|
||||||
// if err != nil {
|
|
||||||
// panic("Failed to load casbin policy: " + err.Error())
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 初始化其他 Repository
|
|
||||||
// permissionRepo := repository.NewPermissionRepository(repository.PermissionRepositoryParam{
|
|
||||||
// Conf: conf,
|
|
||||||
// CacheConf: c.Cache,
|
|
||||||
// CacheOpts: cacheOpts,
|
|
||||||
// DBOpts: dbOpts,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// roleRepo := repository.NewRoleRepository(repository.RoleRepositoryParam{
|
|
||||||
// Conf: conf,
|
|
||||||
// CacheConf: c.Cache,
|
|
||||||
// CacheOpts: cacheOpts,
|
|
||||||
// DBOpts: dbOpts,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// userRoleRepo := repository.NewUserRoleRepository(repository.UserRoleRepositoryParam{
|
|
||||||
// Conf: conf,
|
|
||||||
// CacheConf: c.Cache,
|
|
||||||
// CacheOpts: cacheOpts,
|
|
||||||
// DBOpts: dbOpts,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// // 創建索引
|
|
||||||
// _, _ = permissionRepo.Index20241226001UP(context.Background())
|
|
||||||
// _, _ = roleRepo.Index20241226001UP(context.Background())
|
|
||||||
// _, _ = userRoleRepo.Index20241226001UP(context.Background())
|
|
||||||
//
|
|
||||||
// return uc.MustPermissionUseCase(uc.PermissionUseCaseParam{
|
|
||||||
// Enforcer: enforcer,
|
|
||||||
// PermissionRepo: permissionRepo,
|
|
||||||
// RoleRepo: roleRepo,
|
|
||||||
// UserRoleRepo: userRoleRepo,
|
|
||||||
// })
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func NewAuthUC(c *config.Config, rds *redis.Redis) usecase.AuthUseCase {
|
|
||||||
// // 準備Mongo Config
|
|
||||||
// conf := &mgo.Conf{
|
|
||||||
// Schema: c.Mongo.Schema,
|
|
||||||
// Host: c.Mongo.Host,
|
|
||||||
// Database: c.Mongo.Database,
|
|
||||||
// MaxStaleness: c.Mongo.MaxStaleness,
|
|
||||||
// MaxPoolSize: c.Mongo.MaxPoolSize,
|
|
||||||
// MinPoolSize: c.Mongo.MinPoolSize,
|
|
||||||
// MaxConnIdleTime: c.Mongo.MaxConnIdleTime,
|
|
||||||
// Compressors: c.Mongo.Compressors,
|
|
||||||
// EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode,
|
|
||||||
// ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs,
|
|
||||||
// }
|
|
||||||
// if c.Mongo.User != "" {
|
|
||||||
// conf.User = c.Mongo.User
|
|
||||||
// conf.Password = c.Mongo.Password
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 快取選項
|
|
||||||
// cacheOpts := []cache.Option{
|
|
||||||
// cache.WithExpiry(c.CacheExpireTime),
|
|
||||||
// cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry),
|
|
||||||
// }
|
|
||||||
// dbOpts := []mon.Option{
|
|
||||||
// mgo.SetCustomDecimalType(),
|
|
||||||
// mgo.InitMongoOptions(*conf),
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 初始化 Repository
|
|
||||||
// clientRepo := repository.NewClientRepository(repository.ClientRepositoryParam{
|
|
||||||
// Conf: conf,
|
|
||||||
// CacheConf: c.Cache,
|
|
||||||
// CacheOpts: cacheOpts,
|
|
||||||
// DBOpts: dbOpts,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// tokenRepo := repository.NewTokenRepository(repository.TokenRepositoryParam{
|
|
||||||
// Redis: rds,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// // JWT 配置
|
|
||||||
// jwtConfig := permissionConfig.JWTConfig{
|
|
||||||
// Secret: c.JWTAuth.AccessSecret, // 使用現有的JWT配置
|
|
||||||
// AccessExpires: c.JWTAuth.AccessExpire,
|
|
||||||
// RefreshExpires: c.JWTAuth.AccessExpire * 7, // refresh token 較長
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// return uc.MustAuthUseCase(uc.AuthUseCaseParam{
|
|
||||||
// ClientRepo: clientRepo,
|
|
||||||
// TokenRepo: tokenRepo,
|
|
||||||
// JWTConfig: jwtConfig,
|
|
||||||
// })
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//func NewRoleUC(c *config.Config) usecase.RoleUseCase {
|
|
||||||
// // 準備Mongo Config
|
|
||||||
// conf := &mgo.Conf{
|
|
||||||
// Schema: c.Mongo.Schema,
|
|
||||||
// Host: c.Mongo.Host,
|
|
||||||
// Database: c.Mongo.Database,
|
|
||||||
// MaxStaleness: c.Mongo.MaxStaleness,
|
|
||||||
// MaxPoolSize: c.Mongo.MaxPoolSize,
|
|
||||||
// MinPoolSize: c.Mongo.MinPoolSize,
|
|
||||||
// MaxConnIdleTime: c.Mongo.MaxConnIdleTime,
|
|
||||||
// Compressors: c.Mongo.Compressors,
|
|
||||||
// EnableStandardReadWriteSplitMode: c.Mongo.EnableStandardReadWriteSplitMode,
|
|
||||||
// ConnectTimeoutMs: c.Mongo.ConnectTimeoutMs,
|
|
||||||
// }
|
|
||||||
// if c.Mongo.User != "" {
|
|
||||||
// conf.User = c.Mongo.User
|
|
||||||
// conf.Password = c.Mongo.Password
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 快取選項
|
|
||||||
// cacheOpts := []cache.Option{
|
|
||||||
// cache.WithExpiry(c.CacheExpireTime),
|
|
||||||
// cache.WithNotFoundExpiry(c.CacheWithNotFoundExpiry),
|
|
||||||
// }
|
|
||||||
// dbOpts := []mon.Option{
|
|
||||||
// mgo.SetCustomDecimalType(),
|
|
||||||
// mgo.InitMongoOptions(*conf),
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 初始化 Repository
|
|
||||||
// roleRepo := repository.NewRoleRepository(repository.RoleRepositoryParam{
|
|
||||||
// Conf: conf,
|
|
||||||
// CacheConf: c.Cache,
|
|
||||||
// CacheOpts: cacheOpts,
|
|
||||||
// DBOpts: dbOpts,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// userRoleRepo := repository.NewUserRoleRepository(repository.UserRoleRepositoryParam{
|
|
||||||
// Conf: conf,
|
|
||||||
// CacheConf: c.Cache,
|
|
||||||
// CacheOpts: cacheOpts,
|
|
||||||
// DBOpts: dbOpts,
|
|
||||||
// })
|
|
||||||
//
|
|
||||||
// return uc.MustRoleUseCase(uc.RoleUseCaseParam{
|
|
||||||
// RoleRepo: roleRepo,
|
|
||||||
// UserRoleRepo: userRoleRepo,
|
|
||||||
// })
|
|
||||||
//}
|
|
||||||
|
|
@ -6,7 +6,8 @@ import (
|
||||||
"backend/pkg/library/errs"
|
"backend/pkg/library/errs"
|
||||||
"backend/pkg/library/errs/code"
|
"backend/pkg/library/errs/code"
|
||||||
vi "backend/pkg/library/validator"
|
vi "backend/pkg/library/validator"
|
||||||
"backend/pkg/member/domain/usecase"
|
memberUC "backend/pkg/member/domain/usecase"
|
||||||
|
tokenUC "backend/pkg/permission/domain/usecase"
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
"github.com/zeromicro/go-zero/rest"
|
"github.com/zeromicro/go-zero/rest"
|
||||||
|
|
@ -15,8 +16,9 @@ import (
|
||||||
type ServiceContext struct {
|
type ServiceContext struct {
|
||||||
Config config.Config
|
Config config.Config
|
||||||
AuthMiddleware rest.Middleware
|
AuthMiddleware rest.Middleware
|
||||||
AccountUC usecase.AccountUseCase
|
AccountUC memberUC.AccountUseCase
|
||||||
Validate vi.Validate
|
Validate vi.Validate
|
||||||
|
TokenUC tokenUC.TokenUseCase
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServiceContext(c config.Config) *ServiceContext {
|
func NewServiceContext(c config.Config) *ServiceContext {
|
||||||
|
|
@ -31,5 +33,6 @@ func NewServiceContext(c config.Config) *ServiceContext {
|
||||||
AuthMiddleware: middleware.NewAuthMiddleware().Handle,
|
AuthMiddleware: middleware.NewAuthMiddleware().Handle,
|
||||||
AccountUC: NewAccountUC(&c, rds),
|
AccountUC: NewAccountUC(&c, rds),
|
||||||
Validate: vi.MustValidator(),
|
Validate: vi.MustValidator(),
|
||||||
|
TokenUC: NewTokenUC(&c, rds),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,18 @@
|
||||||
|
package svc
|
||||||
|
|
||||||
|
import (
|
||||||
|
"backend/internal/config"
|
||||||
|
"backend/pkg/permission/domain/usecase"
|
||||||
|
"backend/pkg/permission/repository"
|
||||||
|
uc "backend/pkg/permission/usecase"
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewTokenUC(c *config.Config, rds *redis.Redis) usecase.TokenUseCase {
|
||||||
|
return uc.MustTokenUseCase(uc.TokenUseCaseParam{
|
||||||
|
TokenRepo: repository.MustTokenRepository(repository.TokenRepositoryParam{
|
||||||
|
Redis: rds,
|
||||||
|
}),
|
||||||
|
Config: c,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -11,4 +11,5 @@ const (
|
||||||
CatSystem
|
CatSystem
|
||||||
CatPubSub
|
CatPubSub
|
||||||
CatService
|
CatService
|
||||||
|
CatToken
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -74,3 +74,16 @@ const (
|
||||||
ThirdParty
|
ThirdParty
|
||||||
ArkHTTP400 // Ark HTTP 400 錯誤
|
ArkHTTP400 // Ark HTTP 400 錯誤
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// 詳細代碼 - Token 類 09x
|
||||||
|
const (
|
||||||
|
_ = iota + CatToken
|
||||||
|
TokenCreateError // Token 創建錯誤
|
||||||
|
TokenValidateError // Token 驗證錯誤
|
||||||
|
TokenExpired // Token 過期
|
||||||
|
TokenNotFound // Token 未找到
|
||||||
|
TokenBlacklisted // Token 已被列入黑名單
|
||||||
|
InvalidJWT // 無效的 JWT
|
||||||
|
RefreshTokenError // Refresh Token 錯誤
|
||||||
|
OneTimeTokenError // 一次性 Token 錯誤
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -195,7 +195,7 @@ func TestVerifyPlatformAuthResult(t *testing.T) {
|
||||||
})
|
})
|
||||||
token, err := HashPassword("password", 10)
|
token, err := HashPassword("password", 10)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
fmt.Println(token)
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
param usecase.VerifyAuthResultRequest
|
param usecase.VerifyAuthResultRequest
|
||||||
|
|
|
||||||
|
|
@ -1,286 +1,364 @@
|
||||||
# Permission 權限管理模組 - Casbin 版
|
# Permission Module
|
||||||
|
|
||||||
一個基於 **Casbin** 的現代化權限管理模組,完全整合你的專案技術棧,提供強大且靈活的 RBAC 權限控制。
|
JWT Token 和 Refresh Token 管理模組,提供完整的身份驗證和授權功能。
|
||||||
|
|
||||||
## 🎯 為什麼選擇 Casbin?
|
## 📋 功能特性
|
||||||
|
|
||||||
你說得完全對!與其重新發明一個功能精簡的權限系統,**Casbin** 提供了:
|
### 🔐 JWT Token 管理
|
||||||
|
- **Access Token 生成**: 基於 JWT 標準生成存取權杖
|
||||||
|
- **Refresh Token 機制**: 支援長期有效的刷新權杖
|
||||||
|
- **One-Time Token**: 臨時性權杖,用於特殊場景
|
||||||
|
- **Token 驗證**: 完整的權杖驗證和解析功能
|
||||||
|
|
||||||
### ✅ **社群驗證的成熟解決方案**
|
### 🚫 黑名單機制
|
||||||
- 🌟 **6.7k+ GitHub Stars**,經過大量生產環境驗證
|
- **即時撤銷**: 將 JWT 權杖立即加入黑名單
|
||||||
- 🔧 **功能完整**:支援 RBAC、ABAC、RESTful、通配符、正則表達式
|
- **用戶登出**: 支援單一設備或全設備登出
|
||||||
- 📚 **文檔完善**:豐富的範例和最佳實踐
|
- **自動過期**: 黑名單條目會在權杖過期後自動清理
|
||||||
- 🛠️ **持續維護**:活躍的社群支持和定期更新
|
- **批量管理**: 支援批量黑名單操作
|
||||||
|
|
||||||
### ✅ **強大的功能特性**
|
### 💾 Redis 儲存
|
||||||
- **通配符支援**: `/api/users/*` 一個規則覆蓋所有子路徑
|
- **高效能**: 使用 Redis 作為主要儲存引擎
|
||||||
- **正則表達式**: 靈活的權限匹配規則
|
- **TTL 管理**: 自動管理權杖過期時間
|
||||||
- **角色繼承**: 複雜的組織架構支援
|
- **關聯管理**: 支援用戶、設備與權杖的關聯查詢
|
||||||
- **多種模型**: RBAC、ABAC、RESTful 等
|
|
||||||
- **策略持久化**: 自動保存到你的 MongoDB
|
|
||||||
|
|
||||||
## 📁 目錄結構
|
### 🔒 安全特性
|
||||||
|
- **HMAC-SHA256**: 使用安全的簽名算法
|
||||||
|
- **密鑰分離**: Access Token 和 Refresh Token 使用不同密鑰
|
||||||
|
- **設備限制**: 支援每用戶、每設備的權杖數量限制
|
||||||
|
- **過期控制**: 靈活的權杖過期時間配置
|
||||||
|
|
||||||
|
## 🏗️ 架構設計
|
||||||
|
|
||||||
|
本模組遵循 **Clean Architecture** 原則:
|
||||||
|
|
||||||
```
|
```
|
||||||
pkg/permission/
|
pkg/permission/
|
||||||
├── config/ # Casbin 模型配置
|
|
||||||
│ └── rbac_model.conf # RBAC 權限模型
|
|
||||||
├── domain/ # 領域層
|
├── domain/ # 領域層
|
||||||
│ ├── entity/ # 實體定義
|
│ ├── entity/ # 實體定義
|
||||||
│ ├── repository/ # 倉庫介面
|
│ ├── repository/ # 儲存庫介面
|
||||||
│ ├── usecase/ # 用例介面 (Casbin 增強)
|
│ ├── usecase/ # 用例介面
|
||||||
│ └── config/ # 配置定義
|
│ └── token/ # 權杖相關常數和類型
|
||||||
├── repository/ # 倉庫實現
|
├── usecase/ # 用例實現
|
||||||
│ ├── casbin_adapter.go # Casbin MongoDB 適配器
|
├── repository/ # 儲存庫實現
|
||||||
│ ├── client.go # 客戶端管理
|
└── mock/ # 測試模擬
|
||||||
│ ├── role.go # 角色管理
|
|
||||||
│ └── ... # 其他倉庫
|
|
||||||
├── usecase/ # 用例實現 (Casbin API)
|
|
||||||
├── svc/ # 初始化層
|
|
||||||
├── example/ # Casbin 使用範例
|
|
||||||
└── README.md # 本文件
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🚀 核心優勢
|
### 領域層 (Domain)
|
||||||
|
- **Entity**: 定義核心業務實體(Token、BlacklistEntry、Ticket)
|
||||||
|
- **Repository Interface**: 定義資料存取介面
|
||||||
|
- **UseCase Interface**: 定義業務用例介面
|
||||||
|
- **Token Types**: 權杖類型和常數定義
|
||||||
|
|
||||||
### **🔥 Casbin 強化功能**
|
### 用例層 (UseCase)
|
||||||
- **通配符權限**: `GET /api/users/*` 覆蓋所有用戶子路徑
|
- **TokenUseCase**: 核心業務邏輯實現
|
||||||
- **正則表達式**: `GET /api/users/\d+` 只允許數字 ID
|
- **JWT 處理**: 權杖生成、解析、驗證
|
||||||
- **角色繼承**: `admin` 繼承 `user` 的所有權限
|
- **黑名單管理**: 權杖撤銷和黑名單查詢
|
||||||
- **策略分離**: 權限策略與業務邏輯完全分離
|
|
||||||
- **動態更新**: 運行時動態添加/移除權限,無需重啟
|
|
||||||
|
|
||||||
### **⚡ 技術整合**
|
### 儲存層 (Repository)
|
||||||
- **MongoDB 適配器**: 策略自動持久化到你的 MongoDB
|
- **Redis 實現**: 基於 Redis 的資料存取
|
||||||
- **你的錯誤系統**: 完整的 `@errs/` 整合
|
- **關聯管理**: 用戶、設備、權杖關聯
|
||||||
- **緩存支援**: 使用你現有的 Redis 緩存
|
- **TTL 管理**: 自動過期處理
|
||||||
- **go-zero 整合**: 無縫整合到你的服務架構
|
|
||||||
|
|
||||||
## 🔧 Casbin 模型
|
## 🚀 快速開始
|
||||||
|
|
||||||
```ini
|
### 1. 配置設定
|
||||||
# pkg/permission/config/rbac_model.conf
|
|
||||||
[request_definition]
|
|
||||||
r = sub, obj, act
|
|
||||||
|
|
||||||
[policy_definition]
|
在 `internal/config/config.go` 中添加 Token 配置:
|
||||||
p = sub, obj, act
|
|
||||||
|
|
||||||
[role_definition]
|
|
||||||
g = _, _
|
|
||||||
|
|
||||||
[policy_effect]
|
|
||||||
e = some(where (p.eft == allow))
|
|
||||||
|
|
||||||
[matchers]
|
|
||||||
m = g(r.sub, p.sub) && keyMatch2(r.obj, p.obj) && regexMatch(r.act, p.act)
|
|
||||||
```
|
|
||||||
|
|
||||||
這個模型支援:
|
|
||||||
- **keyMatch2**: 通配符匹配 (`/api/users/*`)
|
|
||||||
- **regexMatch**: 正則表達式匹配
|
|
||||||
- **角色繼承**: `g(user, role)` 關係
|
|
||||||
|
|
||||||
## 📦 快速整合
|
|
||||||
|
|
||||||
### 1. 在你的 ServiceContext 中添加
|
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// internal/svc/service_context.go
|
type Config struct {
|
||||||
import "backend/pkg/permission/svc"
|
// ... 其他配置
|
||||||
|
|
||||||
type ServiceContext struct {
|
Token struct {
|
||||||
Config config.Config
|
AccessSecret string // Access Token 簽名密鑰
|
||||||
AuthMiddleware rest.Middleware
|
RefreshSecret string // Refresh Token 簽名密鑰
|
||||||
AccountUC usecase.AccountUseCase
|
AccessTokenExpiry time.Duration // Access Token 過期時間
|
||||||
PermissionUC permission.PermissionUseCase // ← Casbin 增強
|
RefreshTokenExpiry time.Duration // Refresh Token 過期時間
|
||||||
AuthUC permission.AuthUseCase
|
OneTimeTokenExpiry time.Duration // 一次性 Token 過期時間
|
||||||
RoleUC permission.RoleUseCase
|
MaxTokensPerUser int // 每用戶最大 Token 數
|
||||||
Validate vi.Validate
|
MaxTokensPerDevice int // 每設備最大 Token 數
|
||||||
}
|
|
||||||
|
|
||||||
func NewServiceContext(c config.Config) *ServiceContext {
|
|
||||||
rds, err := redis.NewRedis(c.RedisConf)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ServiceContext{
|
|
||||||
Config: c,
|
|
||||||
AuthMiddleware: middleware.NewAuthMiddleware().Handle,
|
|
||||||
AccountUC: NewAccountUC(&c, rds),
|
|
||||||
PermissionUC: svc.NewPermissionUC(&c, rds), // ← Casbin 自動初始化
|
|
||||||
AuthUC: svc.NewAuthUC(&c, rds),
|
|
||||||
RoleUC: svc.NewRoleUC(&c),
|
|
||||||
Validate: vi.MustValidator(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### 2. Casbin 強大功能使用
|
### 2. 初始化模組
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// 🔥 通配符權限 - 一個規則覆蓋所有子路徑
|
import (
|
||||||
err = permissionUC.AddPermissionForRole(ctx, "admin", "/api/users/*", ".*")
|
"backend/pkg/permission/repository"
|
||||||
|
"backend/pkg/permission/usecase"
|
||||||
// ✅ 這些都會被允許:
|
|
||||||
// GET /api/users/123
|
|
||||||
// POST /api/users/123/profile
|
|
||||||
// DELETE /api/users/123/avatar
|
|
||||||
|
|
||||||
// 🔥 正則表達式權限 - 精確控制
|
|
||||||
err = permissionUC.AddPermissionForRole(ctx, "viewer", "/api/users/\\d+", "GET")
|
|
||||||
|
|
||||||
// ✅ 只允許: GET /api/users/123 (數字ID)
|
|
||||||
// ❌ 拒絕: GET /api/users/abc (非數字ID)
|
|
||||||
|
|
||||||
// 🔥 角色繼承 - 組織架構支援
|
|
||||||
err = permissionUC.AddRoleForUser(ctx, "john", "admin")
|
|
||||||
err = permissionUC.AddRoleInheritance(ctx, "admin", "user")
|
|
||||||
|
|
||||||
// john 自動擁有 admin 和 user 的所有權限
|
|
||||||
|
|
||||||
// 🔥 動態權限檢查
|
|
||||||
hasPermission, err := permissionUC.CheckUserPermission(ctx, "john", "GET", "/api/users/123")
|
|
||||||
hasPattern, err := permissionUC.CheckPatternPermission(ctx, "john", "/api/users/456", "DELETE")
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🎯 實際使用場景
|
|
||||||
|
|
||||||
### **API 權限控制**
|
|
||||||
```go
|
|
||||||
// 中間件中使用
|
|
||||||
func PermissionMiddleware(permissionUC permission.PermissionUseCase) rest.Middleware {
|
|
||||||
return func(next http.HandlerFunc) http.HandlerFunc {
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
userID := getUserIDFromJWT(r)
|
|
||||||
|
|
||||||
// Casbin 自動處理通配符和正則表達式
|
|
||||||
hasPermission, err := permissionUC.CheckUserPermission(
|
|
||||||
r.Context(), userID, r.Method, r.URL.Path,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil || !hasPermission {
|
// 初始化 Repository
|
||||||
httpx.WriteJsonCtx(r.Context(), w, 403, types.ErrorResp{
|
tokenRepo := repository.MustTokenRepository(repository.TokenRepositoryParam{
|
||||||
Code: 4030001,
|
Redis: redisClient,
|
||||||
Msg: "權限不足",
|
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 初始化 UseCase
|
||||||
|
tokenUseCase := usecase.MustTokenUseCase(usecase.TokenUseCaseParam{
|
||||||
|
TokenRepo: tokenRepo,
|
||||||
|
Config: config,
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 基本使用
|
||||||
|
|
||||||
|
#### 創建 Access Token
|
||||||
|
|
||||||
|
```go
|
||||||
|
req := entity.AuthorizationReq{
|
||||||
|
GrantType: token.PasswordCredentials.ToString(),
|
||||||
|
Scope: "read write",
|
||||||
|
DeviceID: "device123",
|
||||||
|
IsRefreshToken: true,
|
||||||
|
Claims: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "admin",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tokenUseCase.NewToken(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Access Token: %s\n", resp.AccessToken)
|
||||||
|
fmt.Printf("Refresh Token: %s\n", resp.RefreshToken)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 驗證 Token
|
||||||
|
|
||||||
|
```go
|
||||||
|
req := entity.ValidationTokenReq{
|
||||||
|
Token: accessToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := tokenUseCase.ValidationToken(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Token validation failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
next(w, r)
|
fmt.Printf("Token is valid for user: %s\n", resp.Token.UID)
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### **權限初始化**
|
#### 撤銷 Token (加入黑名單)
|
||||||
```go
|
|
||||||
// 初始化基礎權限
|
|
||||||
func InitPermissions(ctx context.Context, permissionUC permission.PermissionUseCase) {
|
|
||||||
// 🔥 管理員擁有所有 API 權限
|
|
||||||
permissionUC.AddPermissionForRole(ctx, "admin", "/api/*", ".*")
|
|
||||||
|
|
||||||
// 🔥 用戶只能查看和更新自己的資料
|
|
||||||
permissionUC.AddPermissionForRole(ctx, "user", "/api/users/{{.UserID}}", "GET|PUT")
|
|
||||||
|
|
||||||
// 🔥 訪客只能查看公開內容
|
|
||||||
permissionUC.AddPermissionForRole(ctx, "guest", "/api/public/*", "GET")
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 📊 Casbin 策略儲存
|
|
||||||
|
|
||||||
### **MongoDB 自動持久化**
|
|
||||||
```javascript
|
|
||||||
// casbin_rules collection
|
|
||||||
{
|
|
||||||
_id: ObjectId,
|
|
||||||
ptype: "p", // 策略類型
|
|
||||||
v0: "role_admin_001", // 主體 (用戶/角色)
|
|
||||||
v1: "/api/users/*", // 對象 (資源)
|
|
||||||
v2: ".*", // 行為 (動作)
|
|
||||||
v3: "", // 擴展字段
|
|
||||||
v4: "", // 擴展字段
|
|
||||||
v5: "" // 擴展字段
|
|
||||||
}
|
|
||||||
|
|
||||||
// 角色關係
|
|
||||||
{
|
|
||||||
ptype: "g", // 分組策略
|
|
||||||
v0: "user_001", // 用戶
|
|
||||||
v1: "role_admin_001", // 角色
|
|
||||||
v2: "",
|
|
||||||
...
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### **索引優化**
|
|
||||||
```javascript
|
|
||||||
// 自動創建的索引
|
|
||||||
db.casbin_rules.createIndex({"ptype": 1})
|
|
||||||
db.casbin_rules.createIndex({"ptype": 1, "v0": 1})
|
|
||||||
db.casbin_rules.createIndex({"ptype": 1, "v0": 1, "v1": 1})
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🔥 Casbin vs 自建系統
|
|
||||||
|
|
||||||
| 功能 | 自建系統 | Casbin |
|
|
||||||
|-----|---------|--------|
|
|
||||||
| **通配符支援** | ❌ 需要自己實現 | ✅ 內建支援 `/api/users/*` |
|
|
||||||
| **正則表達式** | ❌ 需要自己實現 | ✅ 內建支援 `/api/users/\\d+` |
|
|
||||||
| **角色繼承** | ❌ 需要複雜邏輯 | ✅ 自動處理繼承鏈 |
|
|
||||||
| **策略語言** | ❌ 硬編碼邏輯 | ✅ 靈活的 DSL |
|
|
||||||
| **性能優化** | ❌ 需要自己優化 | ✅ 內建緩存和索引 |
|
|
||||||
| **社群支持** | ❌ 需要自己維護 | ✅ 活躍社群,持續更新 |
|
|
||||||
| **文檔和範例** | ❌ 需要自己寫文檔 | ✅ 豐富的官方文檔 |
|
|
||||||
| **測試覆蓋** | ❌ 需要自己測試 | ✅ 大量生產環境驗證 |
|
|
||||||
|
|
||||||
## 🚀 進階功能
|
|
||||||
|
|
||||||
### **ABAC 屬性權限**
|
|
||||||
```go
|
|
||||||
// 未來可以升級到 ABAC 模型
|
|
||||||
// 支援基於用戶屬性、資源屬性、環境屬性的權限控制
|
|
||||||
permissionUC.CheckPermissionWithAttributes(ctx,
|
|
||||||
map[string]interface{}{
|
|
||||||
"user.department": "engineering",
|
|
||||||
"resource.owner": "john",
|
|
||||||
"time.hour": 9,
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
### **策略管理 API**
|
|
||||||
```go
|
|
||||||
// 動態管理策略
|
|
||||||
policies, err := permissionUC.GetAllPolicies(ctx)
|
|
||||||
filtered, err := permissionUC.GetFilteredPolicies(ctx, 0, "role_admin")
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🎯 遷移優勢
|
|
||||||
|
|
||||||
1. **立即獲得成熟功能** - 通配符、正則表達式、角色繼承
|
|
||||||
2. **減少維護成本** - 社群維護,無需自己投入開發時間
|
|
||||||
3. **擴展性更強** - 支援複雜的權限模型,適應業務成長
|
|
||||||
4. **性能更好** - 內建優化,大量生產環境驗證
|
|
||||||
5. **學習成本低** - 豐富的文檔和社群範例
|
|
||||||
|
|
||||||
## 🔧 立即使用
|
|
||||||
|
|
||||||
```go
|
```go
|
||||||
// 1. 初始化 (自動設置 Casbin)
|
err := tokenUseCase.BlacklistToken(ctx, accessToken, "user logout")
|
||||||
PermissionUC: svc.NewPermissionUC(&c, rds),
|
if err != nil {
|
||||||
|
log.Printf("Failed to blacklist token: %v", err)
|
||||||
// 2. 添加權限策略
|
}
|
||||||
permissionUC.AddPermissionForRole(ctx, "admin", "/api/users/*", ".*")
|
|
||||||
|
|
||||||
// 3. 分配角色
|
|
||||||
permissionUC.AddRoleForUser(ctx, "john", "admin")
|
|
||||||
|
|
||||||
// 4. 檢查權限 (自動處理通配符)
|
|
||||||
hasPermission, err := permissionUC.CheckUserPermission(ctx, "john", "GET", "/api/users/123")
|
|
||||||
```
|
```
|
||||||
|
|
||||||
現在你擁有了一個**功能完整、社群驗證、持續維護**的權限系統!🎯
|
#### 檢查黑名單
|
||||||
|
|
||||||
**Casbin** 讓你專注於業務邏輯,而不是重新發明權限輪子。
|
```go
|
||||||
|
isBlacklisted, err := tokenUseCase.IsTokenBlacklisted(ctx, jti)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to check blacklist: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isBlacklisted {
|
||||||
|
log.Println("Token is blacklisted")
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🧪 測試
|
||||||
|
|
||||||
|
### 運行測試
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 運行所有測試
|
||||||
|
go test ./pkg/permission/...
|
||||||
|
|
||||||
|
# 運行特定模組測試
|
||||||
|
go test ./pkg/permission/usecase/
|
||||||
|
go test ./pkg/permission/repository/
|
||||||
|
|
||||||
|
# 運行測試並顯示覆蓋率
|
||||||
|
go test -cover ./pkg/permission/...
|
||||||
|
|
||||||
|
# 生成覆蓋率報告
|
||||||
|
go test -coverprofile=coverage.out ./pkg/permission/...
|
||||||
|
go tool cover -html=coverage.out
|
||||||
|
```
|
||||||
|
|
||||||
|
### 測試結構
|
||||||
|
|
||||||
|
- **UseCase Tests**: 業務邏輯測試,使用 Mock Repository
|
||||||
|
- **Repository Tests**: 資料存取測試,使用 MiniRedis
|
||||||
|
- **JWT Tests**: 權杖生成和解析測試
|
||||||
|
- **Integration Tests**: 整合測試
|
||||||
|
|
||||||
|
## 📊 API 參考
|
||||||
|
|
||||||
|
### TokenUseCase 介面
|
||||||
|
|
||||||
|
```go
|
||||||
|
type TokenUseCase interface {
|
||||||
|
// 基本 Token 操作
|
||||||
|
NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error)
|
||||||
|
RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error)
|
||||||
|
ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error)
|
||||||
|
|
||||||
|
// Token 管理
|
||||||
|
CancelToken(ctx context.Context, req entity.CancelTokenReq) error
|
||||||
|
CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error
|
||||||
|
CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error
|
||||||
|
|
||||||
|
// 查詢操作
|
||||||
|
GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error)
|
||||||
|
GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error)
|
||||||
|
|
||||||
|
// 一次性 Token
|
||||||
|
NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error)
|
||||||
|
CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error
|
||||||
|
|
||||||
|
// 黑名單操作
|
||||||
|
BlacklistToken(ctx context.Context, token string, reason string) error
|
||||||
|
IsTokenBlacklisted(ctx context.Context, jti string) (bool, error)
|
||||||
|
BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error
|
||||||
|
|
||||||
|
// 工具方法
|
||||||
|
ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 主要實體
|
||||||
|
|
||||||
|
#### Token 實體
|
||||||
|
```go
|
||||||
|
type Token struct {
|
||||||
|
ID string // 權杖唯一標識
|
||||||
|
UID string // 用戶 ID
|
||||||
|
DeviceID string // 設備 ID
|
||||||
|
AccessToken string // Access Token
|
||||||
|
RefreshToken string // Refresh Token
|
||||||
|
ExpiresIn int // 過期時間(秒)
|
||||||
|
AccessCreateAt time.Time // Access Token 創建時間
|
||||||
|
RefreshCreateAt time.Time // Refresh Token 創建時間
|
||||||
|
RefreshExpiresIn int // Refresh Token 過期時間(秒)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 黑名單實體
|
||||||
|
```go
|
||||||
|
type BlacklistEntry struct {
|
||||||
|
JTI string // JWT ID
|
||||||
|
UID string // 用戶 ID
|
||||||
|
TokenID string // Token ID
|
||||||
|
Reason string // 加入黑名單原因
|
||||||
|
ExpiresAt int64 // 原始權杖過期時間
|
||||||
|
CreatedAt int64 // 加入黑名單時間
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔧 配置參數
|
||||||
|
|
||||||
|
| 參數 | 類型 | 說明 | 預設值 |
|
||||||
|
|------|------|------|--------|
|
||||||
|
| `AccessSecret` | string | Access Token 簽名密鑰 | 必填 |
|
||||||
|
| `RefreshSecret` | string | Refresh Token 簽名密鑰 | 必填 |
|
||||||
|
| `AccessTokenExpiry` | Duration | Access Token 過期時間 | 15分鐘 |
|
||||||
|
| `RefreshTokenExpiry` | Duration | Refresh Token 過期時間 | 7天 |
|
||||||
|
| `OneTimeTokenExpiry` | Duration | 一次性 Token 過期時間 | 5分鐘 |
|
||||||
|
| `MaxTokensPerUser` | int | 每用戶最大 Token 數 | 10 |
|
||||||
|
| `MaxTokensPerDevice` | int | 每設備最大 Token 數 | 5 |
|
||||||
|
|
||||||
|
## 🚨 錯誤處理
|
||||||
|
|
||||||
|
模組定義了完整的錯誤類型:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Token 驗證錯誤
|
||||||
|
var (
|
||||||
|
ErrInvalidTokenID = errors.New("invalid token ID")
|
||||||
|
ErrInvalidUID = errors.New("invalid UID")
|
||||||
|
ErrTokenExpired = errors.New("token expired")
|
||||||
|
ErrTokenNotFound = errors.New("token not found")
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWT 特定錯誤
|
||||||
|
var (
|
||||||
|
ErrInvalidJWTToken = errors.New("invalid JWT token")
|
||||||
|
ErrJWTSigningFailed = errors.New("JWT signing failed")
|
||||||
|
ErrJWTParsingFailed = errors.New("JWT parsing failed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// 黑名單錯誤
|
||||||
|
var (
|
||||||
|
ErrTokenBlacklisted = errors.New("token is blacklisted")
|
||||||
|
ErrBlacklistNotFound = errors.New("blacklist entry not found")
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔒 安全考量
|
||||||
|
|
||||||
|
### 1. 密鑰管理
|
||||||
|
- 使用強密鑰(至少 256 位)
|
||||||
|
- Access Token 和 Refresh Token 使用不同密鑰
|
||||||
|
- 定期輪換密鑰
|
||||||
|
|
||||||
|
### 2. 權杖過期
|
||||||
|
- Access Token 使用較短過期時間(15分鐘)
|
||||||
|
- Refresh Token 使用較長過期時間(7天)
|
||||||
|
- 支援自定義過期時間
|
||||||
|
|
||||||
|
### 3. 黑名單機制
|
||||||
|
- 即時撤銷可疑權杖
|
||||||
|
- 支援批量撤銷
|
||||||
|
- 自動清理過期條目
|
||||||
|
|
||||||
|
### 4. 限制機制
|
||||||
|
- 每用戶權杖數量限制
|
||||||
|
- 每設備權杖數量限制
|
||||||
|
- 防止權杖濫用
|
||||||
|
|
||||||
|
## 📈 效能優化
|
||||||
|
|
||||||
|
### 1. Redis 優化
|
||||||
|
- 使用適當的 TTL 避免記憶體洩漏
|
||||||
|
- 批量操作減少網路往返
|
||||||
|
- 使用 Pipeline 提升效能
|
||||||
|
|
||||||
|
### 2. JWT 優化
|
||||||
|
- 最小化 Claims 數據大小
|
||||||
|
- 使用高效的序列化格式
|
||||||
|
- 快取常用的解析結果
|
||||||
|
|
||||||
|
### 3. 黑名單優化
|
||||||
|
- 使用 SCAN 而非 KEYS 遍歷
|
||||||
|
- 批量檢查黑名單狀態
|
||||||
|
- 定期清理過期條目
|
||||||
|
|
||||||
|
## 🤝 貢獻指南
|
||||||
|
|
||||||
|
1. Fork 本專案
|
||||||
|
2. 創建功能分支 (`git checkout -b feature/amazing-feature`)
|
||||||
|
3. 提交變更 (`git commit -m 'Add some amazing feature'`)
|
||||||
|
4. 推送到分支 (`git push origin feature/amazing-feature`)
|
||||||
|
5. 開啟 Pull Request
|
||||||
|
|
||||||
|
### 開發規範
|
||||||
|
|
||||||
|
- 遵循 Go 編碼規範
|
||||||
|
- 保持測試覆蓋率 > 80%
|
||||||
|
- 添加適當的文檔註釋
|
||||||
|
- 使用有意義的提交訊息
|
||||||
|
|
||||||
|
## 📄 授權條款
|
||||||
|
|
||||||
|
本專案採用 MIT 授權條款 - 詳見 [LICENSE](LICENSE) 檔案
|
||||||
|
|
||||||
|
## 📞 聯絡資訊
|
||||||
|
|
||||||
|
如有問題或建議,請通過以下方式聯絡:
|
||||||
|
|
||||||
|
- 開啟 Issue
|
||||||
|
- 發送 Pull Request
|
||||||
|
- 聯絡維護團隊
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**注意**: 本模組是 PlayOne Backend 專案的一部分,請確保與整體架構保持一致。
|
||||||
|
|
@ -1,51 +1,64 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"backend/pkg/permission/domain/token"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config 權限系統配置
|
// Config represents the configuration for the permission module
|
||||||
type Config struct {
|
type Config struct {
|
||||||
JWT JWTConfig `json:"jwt"`
|
Token TokenConfig `json:"token" yaml:"token"`
|
||||||
Database DatabaseConfig `json:"database"`
|
|
||||||
Casbin CasbinConfig `json:"casbin"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// JWTConfig JWT 配置
|
// TokenConfig represents token configuration
|
||||||
type JWTConfig struct {
|
type TokenConfig struct {
|
||||||
Secret string `json:"secret"`
|
// JWT signing configuration
|
||||||
AccessExpires time.Duration `json:"access_expires"`
|
Secret string `json:"secret" yaml:"secret"`
|
||||||
RefreshExpires time.Duration `json:"refresh_expires"`
|
|
||||||
|
// Token expiration settings
|
||||||
|
Expired ExpiredConfig `json:"expired" yaml:"expired"`
|
||||||
|
RefreshExpires ExpiredConfig `json:"refresh_expires" yaml:"refresh_expires"`
|
||||||
|
|
||||||
|
// Issuer information
|
||||||
|
Issuer string `json:"issuer" yaml:"issuer"`
|
||||||
|
|
||||||
|
// Token limits
|
||||||
|
MaxTokensPerUser int `json:"max_tokens_per_user" yaml:"max_tokens_per_user"`
|
||||||
|
MaxTokensPerDevice int `json:"max_tokens_per_device" yaml:"max_tokens_per_device"`
|
||||||
|
|
||||||
|
// Security settings
|
||||||
|
EnableDeviceTracking bool `json:"enable_device_tracking" yaml:"enable_device_tracking"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// DatabaseConfig 數據庫配置
|
// ExpiredConfig represents expiration configuration
|
||||||
type DatabaseConfig struct {
|
type ExpiredConfig struct {
|
||||||
URI string `json:"uri"`
|
Seconds int64 `json:"seconds" yaml:"seconds"`
|
||||||
Database string `json:"database"`
|
|
||||||
Timeout time.Duration `json:"timeout"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CasbinConfig Casbin 配置
|
// Validate validates the token configuration
|
||||||
type CasbinConfig struct {
|
func (c *TokenConfig) Validate() error {
|
||||||
ModelPath string `json:"model_path"` // RBAC 模型文件路徑
|
if c.Secret == "" {
|
||||||
AutoSave bool `json:"auto_save"` // 自動保存策略
|
return ErrMissingSecret
|
||||||
AutoLoad bool `json:"auto_load"` // 自動載入策略
|
|
||||||
AutoLoadDuration time.Duration `json:"auto_load_duration"` // 自動載入間隔
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultConfig 返回默認配置
|
if c.Expired.Seconds <= 0 {
|
||||||
func DefaultConfig() Config {
|
c.Expired.Seconds = token.DefaultAccessTokenExpiry
|
||||||
return Config{
|
|
||||||
JWT: JWTConfig{
|
|
||||||
Secret: "your-secret-key",
|
|
||||||
AccessExpires: time.Hour * 2, // 2 小時
|
|
||||||
RefreshExpires: time.Hour * 24 * 7, // 7 天
|
|
||||||
},
|
|
||||||
Casbin: CasbinConfig{
|
|
||||||
ModelPath: "etc/rbac_model.conf",
|
|
||||||
AutoSave: true,
|
|
||||||
AutoLoad: true,
|
|
||||||
AutoLoadDuration: time.Second * 10,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.RefreshExpires.Seconds <= 0 {
|
||||||
|
c.RefreshExpires.Seconds = token.DefaultRefreshTokenExpiry
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Issuer == "" {
|
||||||
|
c.Issuer = "playone-backend"
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.MaxTokensPerUser <= 0 {
|
||||||
|
c.MaxTokensPerUser = token.MaxTokensPerUser
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.MaxTokensPerDevice <= 0 {
|
||||||
|
c.MaxTokensPerDevice = token.MaxTokensPerDevice
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
@ -0,0 +1,243 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"backend/pkg/permission/domain/token"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTokenConfig_Validate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
config *TokenConfig
|
||||||
|
wantErr bool
|
||||||
|
check func(*testing.T, *TokenConfig)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid config",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "test-secret",
|
||||||
|
Expired: ExpiredConfig{
|
||||||
|
Seconds: 900,
|
||||||
|
},
|
||||||
|
RefreshExpires: ExpiredConfig{
|
||||||
|
Seconds: 604800,
|
||||||
|
},
|
||||||
|
Issuer: "test-issuer",
|
||||||
|
MaxTokensPerUser: 10,
|
||||||
|
MaxTokensPerDevice: 5,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
check: func(t *testing.T, c *TokenConfig) {
|
||||||
|
assert.Equal(t, "test-secret", c.Secret)
|
||||||
|
assert.Equal(t, int64(900), c.Expired.Seconds)
|
||||||
|
assert.Equal(t, int64(604800), c.RefreshExpires.Seconds)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing secret",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "",
|
||||||
|
Expired: ExpiredConfig{
|
||||||
|
Seconds: 900,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
check: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use default expiry",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "test-secret",
|
||||||
|
Expired: ExpiredConfig{
|
||||||
|
Seconds: 0,
|
||||||
|
},
|
||||||
|
RefreshExpires: ExpiredConfig{
|
||||||
|
Seconds: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
check: func(t *testing.T, c *TokenConfig) {
|
||||||
|
assert.Equal(t, int64(token.DefaultAccessTokenExpiry), c.Expired.Seconds)
|
||||||
|
assert.Equal(t, int64(token.DefaultRefreshTokenExpiry), c.RefreshExpires.Seconds)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use default issuer",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "test-secret",
|
||||||
|
Issuer: "",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
check: func(t *testing.T, c *TokenConfig) {
|
||||||
|
assert.Equal(t, "playone-backend", c.Issuer)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "use default token limits",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "test-secret",
|
||||||
|
MaxTokensPerUser: 0,
|
||||||
|
MaxTokensPerDevice: 0,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
check: func(t *testing.T, c *TokenConfig) {
|
||||||
|
assert.Equal(t, token.MaxTokensPerUser, c.MaxTokensPerUser)
|
||||||
|
assert.Equal(t, token.MaxTokensPerDevice, c.MaxTokensPerDevice)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative expiry time",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "test-secret",
|
||||||
|
Expired: ExpiredConfig{
|
||||||
|
Seconds: -100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
check: func(t *testing.T, c *TokenConfig) {
|
||||||
|
// Negative values should be replaced with defaults
|
||||||
|
assert.Equal(t, int64(token.DefaultAccessTokenExpiry), c.Expired.Seconds)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom token limits",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "test-secret",
|
||||||
|
MaxTokensPerUser: 20,
|
||||||
|
MaxTokensPerDevice: 10,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
check: func(t *testing.T, c *TokenConfig) {
|
||||||
|
assert.Equal(t, 20, c.MaxTokensPerUser)
|
||||||
|
assert.Equal(t, 10, c.MaxTokensPerDevice)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "device tracking enabled",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "test-secret",
|
||||||
|
EnableDeviceTracking: true,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
check: func(t *testing.T, c *TokenConfig) {
|
||||||
|
assert.True(t, c.EnableDeviceTracking)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "device tracking disabled",
|
||||||
|
config: &TokenConfig{
|
||||||
|
Secret: "test-secret",
|
||||||
|
EnableDeviceTracking: false,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
check: func(t *testing.T, c *TokenConfig) {
|
||||||
|
assert.False(t, c.EnableDeviceTracking)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.config.Validate()
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
if tt.check != nil {
|
||||||
|
tt.check(t, tt.config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpiredConfig(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
seconds int64
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "900 seconds (15 minutes)",
|
||||||
|
seconds: 900,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "3600 seconds (1 hour)",
|
||||||
|
seconds: 3600,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "604800 seconds (7 days)",
|
||||||
|
seconds: 604800,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero seconds",
|
||||||
|
seconds: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
config := ExpiredConfig{
|
||||||
|
Seconds: tt.seconds,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.seconds, config.Seconds)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig_Struct(t *testing.T) {
|
||||||
|
t.Run("full config", func(t *testing.T) {
|
||||||
|
config := Config{
|
||||||
|
Token: TokenConfig{
|
||||||
|
Secret: "my-secret",
|
||||||
|
Expired: ExpiredConfig{
|
||||||
|
Seconds: 900,
|
||||||
|
},
|
||||||
|
RefreshExpires: ExpiredConfig{
|
||||||
|
Seconds: 604800,
|
||||||
|
},
|
||||||
|
Issuer: "my-app",
|
||||||
|
MaxTokensPerUser: 15,
|
||||||
|
MaxTokensPerDevice: 8,
|
||||||
|
EnableDeviceTracking: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotNil(t, config.Token)
|
||||||
|
assert.Equal(t, "my-secret", config.Token.Secret)
|
||||||
|
assert.Equal(t, int64(900), config.Token.Expired.Seconds)
|
||||||
|
assert.Equal(t, int64(604800), config.Token.RefreshExpires.Seconds)
|
||||||
|
assert.Equal(t, "my-app", config.Token.Issuer)
|
||||||
|
assert.Equal(t, 15, config.Token.MaxTokensPerUser)
|
||||||
|
assert.Equal(t, 8, config.Token.MaxTokensPerDevice)
|
||||||
|
assert.True(t, config.Token.EnableDeviceTracking)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty config", func(t *testing.T) {
|
||||||
|
config := Config{}
|
||||||
|
|
||||||
|
assert.Empty(t, config.Token.Secret)
|
||||||
|
assert.Equal(t, int64(0), config.Token.Expired.Seconds)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenConfig_AllDefaults(t *testing.T) {
|
||||||
|
config := &TokenConfig{
|
||||||
|
Secret: "test-secret", // Only required field
|
||||||
|
}
|
||||||
|
|
||||||
|
err := config.Validate()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Check all defaults are applied
|
||||||
|
assert.Equal(t, int64(token.DefaultAccessTokenExpiry), config.Expired.Seconds)
|
||||||
|
assert.Equal(t, int64(token.DefaultRefreshTokenExpiry), config.RefreshExpires.Seconds)
|
||||||
|
assert.Equal(t, "playone-backend", config.Issuer)
|
||||||
|
assert.Equal(t, token.MaxTokensPerUser, config.MaxTokensPerUser)
|
||||||
|
assert.Equal(t, token.MaxTokensPerDevice, config.MaxTokensPerDevice)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMissingSecret = fmt.Errorf("missing JWT secret key")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Module name
|
||||||
|
ModuleName = "permission"
|
||||||
|
|
||||||
|
// Default issuer
|
||||||
|
DefaultIssuer = "playone-backend"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
package entity
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// BlacklistEntry represents a blacklisted JWT token
|
||||||
|
type BlacklistEntry struct {
|
||||||
|
JTI string `json:"jti"` // JWT ID (unique identifier)
|
||||||
|
UID string `json:"uid"` // User ID
|
||||||
|
TokenID string `json:"token_id"` // Token ID from original token
|
||||||
|
Reason string `json:"reason"` // Reason for blacklisting
|
||||||
|
ExpiresAt int64 `json:"expires_at"` // When the original token expires
|
||||||
|
CreatedAt int64 `json:"created_at"` // When it was blacklisted
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpired checks if the blacklist entry is expired
|
||||||
|
func (b *BlacklistEntry) IsExpired() bool {
|
||||||
|
return b.ExpiresAt <= time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the blacklist entry
|
||||||
|
func (b *BlacklistEntry) Validate() error {
|
||||||
|
if b.JTI == "" {
|
||||||
|
return ErrInvalidJTI
|
||||||
|
}
|
||||||
|
if b.UID == "" {
|
||||||
|
return ErrInvalidUID
|
||||||
|
}
|
||||||
|
if b.TokenID == "" {
|
||||||
|
return ErrInvalidTokenID
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,194 @@
|
||||||
|
package entity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBlacklistEntry_IsExpired(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
entry *BlacklistEntry
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "expired entry",
|
||||||
|
entry: &BlacklistEntry{
|
||||||
|
JTI: "test-jti",
|
||||||
|
UID: "test-uid",
|
||||||
|
TokenID: "test-token",
|
||||||
|
ExpiresAt: time.Now().Add(-time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "not expired entry",
|
||||||
|
entry: &BlacklistEntry{
|
||||||
|
JTI: "test-jti",
|
||||||
|
UID: "test-uid",
|
||||||
|
TokenID: "test-token",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "exactly at expiry time",
|
||||||
|
entry: &BlacklistEntry{
|
||||||
|
JTI: "test-jti",
|
||||||
|
UID: "test-uid",
|
||||||
|
TokenID: "test-token",
|
||||||
|
ExpiresAt: time.Now().Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
expected: true, // Equal to current time should be considered expired
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.entry.IsExpired()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlacklistEntry_Validate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
entry *BlacklistEntry
|
||||||
|
wantErr bool
|
||||||
|
expectedErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid entry",
|
||||||
|
entry: &BlacklistEntry{
|
||||||
|
JTI: "test-jti",
|
||||||
|
UID: "test-uid",
|
||||||
|
TokenID: "test-token",
|
||||||
|
Reason: "user logout",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing JTI",
|
||||||
|
entry: &BlacklistEntry{
|
||||||
|
JTI: "",
|
||||||
|
UID: "test-uid",
|
||||||
|
TokenID: "test-token",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErr: ErrInvalidJTI,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing UID",
|
||||||
|
entry: &BlacklistEntry{
|
||||||
|
JTI: "test-jti",
|
||||||
|
UID: "",
|
||||||
|
TokenID: "test-token",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErr: ErrInvalidUID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing TokenID",
|
||||||
|
entry: &BlacklistEntry{
|
||||||
|
JTI: "test-jti",
|
||||||
|
UID: "test-uid",
|
||||||
|
TokenID: "",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErr: ErrInvalidTokenID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all fields missing",
|
||||||
|
entry: &BlacklistEntry{
|
||||||
|
JTI: "",
|
||||||
|
UID: "",
|
||||||
|
TokenID: "",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErr: ErrInvalidJTI, // First error encountered
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.entry.Validate()
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tt.expectedErr != nil {
|
||||||
|
assert.Equal(t, tt.expectedErr, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlacklistEntry_CreatedAt(t *testing.T) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
entry := &BlacklistEntry{
|
||||||
|
JTI: "test-jti",
|
||||||
|
UID: "test-uid",
|
||||||
|
TokenID: "test-token",
|
||||||
|
Reason: "security",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, now, entry.CreatedAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlacklistEntry_Reason(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
reason string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "user logout reason",
|
||||||
|
reason: "user logout",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "security breach reason",
|
||||||
|
reason: "security breach detected",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "password reset reason",
|
||||||
|
reason: "password reset",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty reason",
|
||||||
|
reason: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
entry := &BlacklistEntry{
|
||||||
|
JTI: "test-jti",
|
||||||
|
UID: "test-uid",
|
||||||
|
TokenID: "test-token",
|
||||||
|
Reason: tt.reason,
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.reason, entry.Reason)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -1,23 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Client 客戶端實體
|
|
||||||
type Client struct {
|
|
||||||
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
|
|
||||||
Name string `bson:"name" json:"name"`
|
|
||||||
ClientID string `bson:"client_id" json:"client_id"`
|
|
||||||
Secret string `bson:"secret" json:"secret"`
|
|
||||||
Status int `bson:"status" json:"status"`
|
|
||||||
CreateTime time.Time `bson:"create_time" json:"create_time"`
|
|
||||||
UpdateTime time.Time `bson:"update_time" json:"update_time"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CollectionName 返回集合名稱
|
|
||||||
func (c *Client) CollectionName() string {
|
|
||||||
return "clients"
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,31 @@
|
||||||
|
package entity
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Token validation errors
|
||||||
|
ErrInvalidTokenID = errors.New("invalid token ID")
|
||||||
|
ErrInvalidUID = errors.New("invalid UID")
|
||||||
|
ErrInvalidAccessToken = errors.New("invalid access token")
|
||||||
|
ErrTokenExpired = errors.New("token expired")
|
||||||
|
ErrTokenNotFound = errors.New("token not found")
|
||||||
|
|
||||||
|
// JWT specific errors
|
||||||
|
ErrInvalidJWTToken = errors.New("invalid JWT token")
|
||||||
|
ErrJWTSigningFailed = errors.New("JWT signing failed")
|
||||||
|
ErrJWTParsingFailed = errors.New("JWT parsing failed")
|
||||||
|
ErrInvalidSigningKey = errors.New("invalid signing key")
|
||||||
|
ErrInvalidJTI = errors.New("invalid JWT ID")
|
||||||
|
|
||||||
|
// Refresh token errors
|
||||||
|
ErrRefreshTokenExpired = errors.New("refresh token expired")
|
||||||
|
ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||||
|
|
||||||
|
// One-time token errors
|
||||||
|
ErrOneTimeTokenExpired = errors.New("one-time token expired")
|
||||||
|
ErrInvalidOneTimeToken = errors.New("invalid one-time token")
|
||||||
|
|
||||||
|
// Blacklist errors
|
||||||
|
ErrTokenBlacklisted = errors.New("token is blacklisted")
|
||||||
|
ErrBlacklistNotFound = errors.New("blacklist entry not found")
|
||||||
|
)
|
||||||
|
|
@ -1,66 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PermissionType 權限類型
|
|
||||||
type PermissionType int
|
|
||||||
|
|
||||||
const (
|
|
||||||
PermissionTypeAPI PermissionType = 1 // API 權限
|
|
||||||
PermissionTypeMenu PermissionType = 2 // 選單權限
|
|
||||||
)
|
|
||||||
|
|
||||||
// Permission 權限實體
|
|
||||||
type Permission struct {
|
|
||||||
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
|
|
||||||
ParentID *bson.ObjectID `bson:"parent_id,omitempty" json:"parent_id"`
|
|
||||||
Name string `bson:"name" json:"name"`
|
|
||||||
HTTPMethod string `bson:"http_method" json:"http_method"`
|
|
||||||
HTTPPath string `bson:"http_path" json:"http_path"`
|
|
||||||
Status int `bson:"status" json:"status"`
|
|
||||||
Type PermissionType `bson:"type" json:"type"`
|
|
||||||
CreateTime time.Time `bson:"create_time" json:"create_time"`
|
|
||||||
UpdateTime time.Time `bson:"update_time" json:"update_time"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CollectionName 返回集合名稱
|
|
||||||
func (p *Permission) CollectionName() string {
|
|
||||||
return "permissions"
|
|
||||||
}
|
|
||||||
|
|
||||||
//// StatusActive 權限啟用狀態
|
|
||||||
//const StatusActive = 1
|
|
||||||
//
|
|
||||||
//// IsActive 檢查權限是否啟用
|
|
||||||
//func (p *Permission) IsActive() bool {
|
|
||||||
// return p.Status == StatusActive
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//
|
|
||||||
//// Validate 驗證權限數據
|
|
||||||
//func (p *Permission) Validate() error {
|
|
||||||
// if p.Name == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "permission name is required"}
|
|
||||||
// }
|
|
||||||
// if p.Type == PermissionTypeAPI {
|
|
||||||
// if p.HTTPMethod == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "http_method is required for API permission"}
|
|
||||||
// }
|
|
||||||
// if p.HTTPPath == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "http_path is required for API permission"}
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// return nil
|
|
||||||
//}
|
|
||||||
//
|
|
||||||
//// GetKey 獲取權限標識
|
|
||||||
//func (p *Permission) GetKey() string {
|
|
||||||
// if p.Type == PermissionTypeAPI {
|
|
||||||
// return p.HTTPMethod + ":" + p.HTTPPath
|
|
||||||
// }
|
|
||||||
// return p.Name
|
|
||||||
//}
|
|
||||||
|
|
@ -0,0 +1,85 @@
|
||||||
|
package entity
|
||||||
|
|
||||||
|
// AuthorizationReq 定義授權請求的結構
|
||||||
|
type AuthorizationReq struct {
|
||||||
|
GrantType string `json:"grant_type"` // 授權類型
|
||||||
|
DeviceID string `json:"device_id"` // 設備 ID
|
||||||
|
Scope string `json:"scope"` // 授權範圍
|
||||||
|
Data map[string]string `json:"data"` // 附加數據
|
||||||
|
Expires int64 `json:"expires"` // 過期時間(秒)
|
||||||
|
IsRefreshToken bool `json:"is_refresh_token"` // 是否為刷新令牌
|
||||||
|
Role string `json:"role"` // 用戶角色
|
||||||
|
Account string `json:"account"` // 登入時的帳號
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenResp 定義訪問令牌響應的結構
|
||||||
|
type TokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"` // 訪問令牌
|
||||||
|
TokenType string `json:"token_type"` // 令牌類型
|
||||||
|
ExpiresIn int64 `json:"expires_in"` // 過期時間(秒)
|
||||||
|
RefreshToken string `json:"refresh_token"` // 刷新令牌
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOneTimeTokenReq 建立一次性 Token 的請求
|
||||||
|
type CreateOneTimeTokenReq struct {
|
||||||
|
Token string `json:"token"` // 長期有效的驗證令牌
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOneTimeTokenResp 建立一次性 Token 的響應
|
||||||
|
type CreateOneTimeTokenResp struct {
|
||||||
|
OneTimeToken string `json:"one_time_token"` // 一次性令牌
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenReq 更新 Token 的請求
|
||||||
|
type RefreshTokenReq struct {
|
||||||
|
Token string `json:"token"` // 令牌
|
||||||
|
Scope string `json:"scope"` // 授權範圍
|
||||||
|
Expires int64 `json:"expires"` // 過期時間(秒)
|
||||||
|
DeviceID string `json:"device_id"` // 設備 ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenResp 更新令牌的響應
|
||||||
|
type RefreshTokenResp struct {
|
||||||
|
Token string `json:"token"` // 新的訪問令牌
|
||||||
|
OneTimeToken string `json:"one_time_token"` // 一次性令牌
|
||||||
|
ExpiresIn int64 `json:"expires_in"` // 過期時間(秒)
|
||||||
|
TokenType string `json:"token_type"` // 令牌類型
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelTokenReq 註銷 Token 的請求
|
||||||
|
type CancelTokenReq struct {
|
||||||
|
Token string `json:"token"` // 需要註銷的令牌
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoTokenByUIDReq 基於 UID 操作 Token 的請求
|
||||||
|
type DoTokenByUIDReq struct {
|
||||||
|
IDs []string `json:"ids"` // Token ID 列表
|
||||||
|
UID string `json:"uid"` // 用戶 ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryTokenByUIDReq 查詢 UID 對應的 Token
|
||||||
|
type QueryTokenByUIDReq struct {
|
||||||
|
UID string `json:"uid"` // 用戶 ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidationTokenReq 驗證 Token 的請求
|
||||||
|
type ValidationTokenReq struct {
|
||||||
|
Token string `json:"token"` // 需要驗證的令牌
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidationTokenResp 驗證並返回 Token 詳情
|
||||||
|
type ValidationTokenResp struct {
|
||||||
|
Token Token `json:"token"` // Token 詳情
|
||||||
|
Data map[string]string `json:"data"` // 附加數據
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoTokenByDeviceIDReq 基於設備 ID 操作 Token 的請求
|
||||||
|
type DoTokenByDeviceIDReq struct {
|
||||||
|
DeviceID string `json:"device_id"` // 設備 ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelOneTimeTokenReq 取消一次性 Token 的請求
|
||||||
|
type CancelOneTimeTokenReq struct {
|
||||||
|
Token []string `json:"token"` // 一次性 Token 列表
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -1,66 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Permissions 權限映射表
|
|
||||||
type Permissions map[string]int
|
|
||||||
|
|
||||||
// Role 角色實體
|
|
||||||
type Role struct {
|
|
||||||
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
|
|
||||||
ClientID string `bson:"client_id" json:"client_id"`
|
|
||||||
UID string `bson:"uid" json:"uid"`
|
|
||||||
Name string `bson:"name" json:"name"`
|
|
||||||
Status int `bson:"status" json:"status"`
|
|
||||||
Permissions Permissions `bson:"permissions" json:"permissions"`
|
|
||||||
CreateTime time.Time `bson:"create_time" json:"create_time"`
|
|
||||||
UpdateTime time.Time `bson:"update_time" json:"update_time"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CollectionName 返回集合名稱
|
|
||||||
func (r *Role) CollectionName() string {
|
|
||||||
return "roles"
|
|
||||||
}
|
|
||||||
|
|
||||||
// // Validate 驗證角色數據
|
|
||||||
//
|
|
||||||
// func (r *Role) Validate() error {
|
|
||||||
// if r.ClientID == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "client_id is required"}
|
|
||||||
// }
|
|
||||||
// if r.Name == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "role name is required"}
|
|
||||||
// }
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // HasPermission 檢查是否有指定權限
|
|
||||||
//
|
|
||||||
// func (r *Role) HasPermission(key string) bool {
|
|
||||||
// if !r.IsActive() {
|
|
||||||
// return false
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// permission, exists := r.Permissions[key]
|
|
||||||
// return exists && permission == 1 // 1 表示有權限
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
|
|
||||||
// AddPermission 添加權限
|
|
||||||
func (r *Role) AddPermission(key string) {
|
|
||||||
if r.Permissions == nil {
|
|
||||||
r.Permissions = make(Permissions)
|
|
||||||
}
|
|
||||||
r.Permissions[key] = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePermission 移除權限
|
|
||||||
func (r *Role) RemovePermission(key string) {
|
|
||||||
if r.Permissions != nil {
|
|
||||||
delete(r.Permissions, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,46 +1,65 @@
|
||||||
package entity
|
package entity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Token 令牌實體
|
// Token represents a token entity stored in Redis
|
||||||
type Token struct {
|
type Token struct {
|
||||||
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
|
ID string `json:"id"` // Token ID (KSUID)
|
||||||
UID string `bson:"uid" json:"uid"`
|
UID string `json:"uid"` // User ID
|
||||||
ClientID string `bson:"client_id" json:"client_id"`
|
DeviceID string `json:"device_id"` // Device ID
|
||||||
AccessToken string `bson:"access_token" json:"access_token"`
|
AccessToken string `json:"access_token"` // JWT access token
|
||||||
RefreshToken string `bson:"refresh_token" json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"` // SHA256 refresh token
|
||||||
DeviceID string `bson:"device_id" json:"device_id"`
|
ExpiresIn int `json:"expires_in"` // Access token expiry (Unix timestamp)
|
||||||
ExpiresAt time.Time `bson:"expires_at" json:"expires_at"`
|
RefreshExpiresIn int `json:"refresh_expires_in"` // Refresh token expiry (Unix timestamp)
|
||||||
CreateTime time.Time `bson:"create_time" json:"create_time"`
|
AccessCreateAt time.Time `json:"access_create_at"` // Access token creation time
|
||||||
UpdateTime time.Time `bson:"update_time" json:"update_time"`
|
RefreshCreateAt time.Time `json:"refresh_create_at"` // Refresh token creation time
|
||||||
}
|
}
|
||||||
|
|
||||||
// CollectionName 返回集合名稱
|
// IsExpired checks if the access token is expired
|
||||||
func (t *Token) CollectionName() string {
|
func (t *Token) IsExpired() bool {
|
||||||
return "tokens"
|
return time.Now().Unix() > int64(t.ExpiresIn)
|
||||||
}
|
}
|
||||||
|
|
||||||
//// IsExpired 檢查令牌是否過期
|
// IsRefreshExpired checks if the refresh token is expired
|
||||||
//func (t *Token) IsExpired() bool {
|
func (t *Token) IsRefreshExpired() bool {
|
||||||
// return time.Now().After(t.ExpiresAt)
|
return time.Now().Unix() > int64(t.RefreshExpiresIn)
|
||||||
//}
|
}
|
||||||
//
|
|
||||||
//// Validate 驗證令牌數據
|
// RedisRefreshExpiredSec returns the refresh token expiry duration in seconds
|
||||||
//func (t *Token) Validate() error {
|
func (t *Token) RedisRefreshExpiredSec() int {
|
||||||
// if t.UID == "" {
|
now := time.Now().Unix()
|
||||||
// return mongo.WriteError{Code: 400, Message: "uid is required"}
|
if int64(t.RefreshExpiresIn) <= now {
|
||||||
// }
|
return 0
|
||||||
// if t.ClientID == "" {
|
}
|
||||||
// return mongo.WriteError{Code: 400, Message: "client_id is required"}
|
return t.RefreshExpiresIn - int(now)
|
||||||
// }
|
}
|
||||||
// if t.AccessToken == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "access_token is required"}
|
// Ticket represents a one-time token ticket
|
||||||
// }
|
type Ticket struct {
|
||||||
// if t.RefreshToken == "" {
|
Data map[string]string `json:"data"` // Token claims data
|
||||||
// return mongo.WriteError{Code: 400, Message: "refresh_token is required"}
|
Token Token `json:"token"` // Associated token
|
||||||
// }
|
}
|
||||||
// return nil
|
|
||||||
//}
|
// Claims represents JWT claims structure
|
||||||
|
type Claims struct {
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
Data interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate validates the token entity
|
||||||
|
func (t *Token) Validate() error {
|
||||||
|
if t.ID == "" {
|
||||||
|
return ErrInvalidTokenID
|
||||||
|
}
|
||||||
|
if t.UID == "" {
|
||||||
|
return ErrInvalidUID
|
||||||
|
}
|
||||||
|
if t.AccessToken == "" {
|
||||||
|
return ErrInvalidAccessToken
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,318 @@
|
||||||
|
package entity
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToken_IsExpired(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token *Token
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "expired token",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
ExpiresIn: int(time.Now().Add(-time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid token",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token expiring now",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
ExpiresIn: int(time.Now().Unix()) - 1,
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.token.IsExpired()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_IsRefreshExpired(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token *Token
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "expired refresh token",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
RefreshExpiresIn: int(time.Now().Add(-time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid refresh token",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
RefreshExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.token.IsRefreshExpired()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_RedisRefreshExpiredSec(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token *Token
|
||||||
|
expected int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "token with future expiry",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
RefreshExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
expected: 3600, // Approximately 1 hour in seconds
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token already expired",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
RefreshExpiresIn: int(time.Now().Add(-time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token expiring now",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
RefreshExpiresIn: int(time.Now().Unix()),
|
||||||
|
},
|
||||||
|
expected: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.token.RedisRefreshExpiredSec()
|
||||||
|
|
||||||
|
if tt.expected == 0 {
|
||||||
|
assert.Equal(t, 0, result)
|
||||||
|
} else {
|
||||||
|
// Allow some margin for test execution time
|
||||||
|
assert.InDelta(t, tt.expected, result, 5)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_Validate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token *Token
|
||||||
|
wantErr bool
|
||||||
|
expectedErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid token",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing ID",
|
||||||
|
token: &Token{
|
||||||
|
ID: "",
|
||||||
|
UID: "test-uid",
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErr: ErrInvalidTokenID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing UID",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "",
|
||||||
|
AccessToken: "test-access-token",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErr: ErrInvalidUID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing AccessToken",
|
||||||
|
token: &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
AccessToken: "",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErr: ErrInvalidAccessToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "all fields missing",
|
||||||
|
token: &Token{
|
||||||
|
ID: "",
|
||||||
|
UID: "",
|
||||||
|
AccessToken: "",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
expectedErr: ErrInvalidTokenID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := tt.token.Validate()
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
if tt.expectedErr != nil {
|
||||||
|
assert.Equal(t, tt.expectedErr, err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTicket(t *testing.T) {
|
||||||
|
t.Run("ticket with data", func(t *testing.T) {
|
||||||
|
ticket := Ticket{
|
||||||
|
Data: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "admin",
|
||||||
|
},
|
||||||
|
Token: Token{
|
||||||
|
ID: "token123",
|
||||||
|
UID: "user123",
|
||||||
|
AccessToken: "access-token",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NotNil(t, ticket.Data)
|
||||||
|
assert.Equal(t, "user123", ticket.Data["uid"])
|
||||||
|
assert.Equal(t, "admin", ticket.Data["role"])
|
||||||
|
assert.Equal(t, "token123", ticket.Token.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty ticket", func(t *testing.T) {
|
||||||
|
ticket := Ticket{}
|
||||||
|
|
||||||
|
assert.Nil(t, ticket.Data)
|
||||||
|
assert.Empty(t, ticket.Token.ID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_DeviceID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
deviceID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with device ID",
|
||||||
|
deviceID: "device123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty device ID",
|
||||||
|
deviceID: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID device ID",
|
||||||
|
deviceID: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token := &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
DeviceID: tt.deviceID,
|
||||||
|
AccessToken: "test-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.deviceID, token.DeviceID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_RefreshToken(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
refreshToken string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with refresh token",
|
||||||
|
refreshToken: "refresh-token-123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty refresh token",
|
||||||
|
refreshToken: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "long refresh token",
|
||||||
|
refreshToken: "very-long-refresh-token-with-hash-abcdef1234567890",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token := &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
AccessToken: "access-token",
|
||||||
|
RefreshToken: tt.refreshToken,
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, tt.refreshToken, token.RefreshToken)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_Timestamps(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
token := &Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "test-uid",
|
||||||
|
AccessToken: "access-token",
|
||||||
|
AccessCreateAt: now,
|
||||||
|
RefreshCreateAt: now.Add(time.Second),
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, now, token.AccessCreateAt)
|
||||||
|
assert.True(t, token.RefreshCreateAt.After(token.AccessCreateAt))
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UserRole 用戶角色關聯實體
|
|
||||||
type UserRole struct {
|
|
||||||
ID bson.ObjectID `bson:"_id,omitempty" json:"id"`
|
|
||||||
Brand string `bson:"brand" json:"brand"`
|
|
||||||
UID string `bson:"uid" json:"uid"`
|
|
||||||
RoleUID string `bson:"role_uid" json:"role_uid"`
|
|
||||||
Status int `bson:"status" json:"status"`
|
|
||||||
CreateTime time.Time `bson:"create_time" json:"create_time"`
|
|
||||||
UpdateTime time.Time `bson:"update_time" json:"update_time"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CollectionName 返回集合名稱
|
|
||||||
func (ur *UserRole) CollectionName() string {
|
|
||||||
return "user_roles"
|
|
||||||
}
|
|
||||||
|
|
||||||
//// Validate 驗證用戶角色關聯數據
|
|
||||||
//func (ur *UserRole) Validate() error {
|
|
||||||
// if ur.Brand == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "brand is required"}
|
|
||||||
// }
|
|
||||||
// if ur.UID == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "uid is required"}
|
|
||||||
// }
|
|
||||||
// if ur.RoleUID == "" {
|
|
||||||
// return mongo.WriteError{Code: 400, Message: "role_uid is required"}
|
|
||||||
// }
|
|
||||||
// return nil
|
|
||||||
//}
|
|
||||||
|
|
@ -1,15 +1,31 @@
|
||||||
package domain
|
package domain
|
||||||
|
|
||||||
import "backend/pkg/library/errs"
|
import "errors"
|
||||||
|
|
||||||
const (
|
var (
|
||||||
FailedToGetByID errs.ErrorCode = iota + 1
|
// Token validation errors
|
||||||
FailedToGetByClientID
|
ErrInvalidTokenID = errors.New("invalid token ID")
|
||||||
FailedToGetPermission
|
ErrInvalidUID = errors.New("invalid UID")
|
||||||
FailedToGetPermissionByKey
|
ErrInvalidAccessToken = errors.New("invalid access token")
|
||||||
FailedToGetRoleByID
|
ErrTokenExpired = errors.New("token expired")
|
||||||
FailedToGetByUID
|
ErrTokenNotFound = errors.New("token not found")
|
||||||
FailedToGetByClientAndName
|
|
||||||
FailedToGetByClientAndName
|
// JWT specific errors
|
||||||
FailedToGetByClientAndName
|
ErrInvalidJWTToken = errors.New("invalid JWT token")
|
||||||
|
ErrJWTSigningFailed = errors.New("JWT signing failed")
|
||||||
|
ErrJWTParsingFailed = errors.New("JWT parsing failed")
|
||||||
|
ErrInvalidSigningKey = errors.New("invalid signing key")
|
||||||
|
ErrInvalidJTI = errors.New("invalid JWT ID")
|
||||||
|
|
||||||
|
// Refresh token errors
|
||||||
|
ErrRefreshTokenExpired = errors.New("refresh token expired")
|
||||||
|
ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||||
|
|
||||||
|
// One-time token errors
|
||||||
|
ErrOneTimeTokenExpired = errors.New("one-time token expired")
|
||||||
|
ErrInvalidOneTimeToken = errors.New("invalid one-time token")
|
||||||
|
|
||||||
|
// Blacklist errors
|
||||||
|
ErrTokenBlacklisted = errors.New("token is blacklisted")
|
||||||
|
ErrBlacklistNotFound = errors.New("blacklist entry not found")
|
||||||
)
|
)
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// DeviceToken 表示裝置與 Token 之間的關聯
|
||||||
|
type DeviceToken struct {
|
||||||
|
DeviceID string // 裝置的唯一標識符
|
||||||
|
TokenID string // Token 的唯一標識符
|
||||||
|
}
|
||||||
|
|
||||||
|
type UIDToken map[string]int64
|
||||||
|
|
||||||
|
// Ticket 表示一次性使用的 Token 結構,包含數據和 Token 資訊
|
||||||
|
type Ticket struct {
|
||||||
|
Data any `json:"data"` // 任意附加數據
|
||||||
|
Token Token `json:"token"` // 一次性使用的 Token 資訊
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token 表示使用者的存取和刷新 Token 資訊
|
||||||
|
type Token struct {
|
||||||
|
ID string `json:"id"` // Token 的唯一標識符
|
||||||
|
UID string `json:"uid"` // 用戶的唯一標識符
|
||||||
|
DeviceID string `json:"device_id"` // 裝置的唯一標識符
|
||||||
|
AccessToken string `json:"access_token"` // 存取 Token
|
||||||
|
ExpiresIn int `json:"expires_in"` // 存取 Token 的有效時長(秒)
|
||||||
|
AccessCreateAt time.Time `json:"access_create_at"` // 存取 Token 的創建時間
|
||||||
|
RefreshToken string `json:"refresh_token"` // 刷新 Token
|
||||||
|
RefreshExpiresIn int `json:"refresh_expires_in"` // 刷新 Token 的有效時長(秒)
|
||||||
|
RefreshCreateAt time.Time `json:"refresh_create_at"` // 刷新 Token 的創建時間
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessTokenExpires 返回存取 Token 的有效期(以秒為單位)。
|
||||||
|
func (t *Token) AccessTokenExpires() time.Duration {
|
||||||
|
return time.Duration(t.ExpiresIn) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenExpires 返回刷新 Token 的有效期(以秒為單位)。
|
||||||
|
func (t *Token) RefreshTokenExpires() time.Duration {
|
||||||
|
return time.Duration(t.RefreshExpiresIn) * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokenExpiresUnix 返回刷新 Token 的到期時間(UnixNano 時間戳)。
|
||||||
|
func (t *Token) RefreshTokenExpiresUnix() int64 {
|
||||||
|
return time.Now().Add(t.RefreshTokenExpires()).UnixNano()
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsExpires 檢查存取 Token 是否已過期。如果存取 Token 的創建時間加上其有效期早於當前時間,則返回 true。
|
||||||
|
func (t *Token) IsExpires() bool {
|
||||||
|
return t.AccessCreateAt.Add(t.AccessTokenExpires()).Before(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedisExpiredSec 返回存取 Token 在 Redis 中的剩餘有效時間(秒)。計算方法為:從到期時間的 Unix 時間戳減去當前時間。
|
||||||
|
func (t *Token) RedisExpiredSec() int64 {
|
||||||
|
sec := time.Unix(int64(t.ExpiresIn), 0).Sub(time.Now().UTC())
|
||||||
|
|
||||||
|
return int64(sec.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedisRefreshExpiredSec 返回刷新 Token 在 Redis 中的剩餘有效時間(秒)。計算方法為:從刷新到期時間的 Unix 時間戳減去當前時間。
|
||||||
|
func (t *Token) RedisRefreshExpiredSec() int64 {
|
||||||
|
sec := time.Unix(int64(t.RefreshExpiresIn), 0).Sub(time.Now().UTC())
|
||||||
|
|
||||||
|
return int64(sec.Seconds())
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,141 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestToken_AccessTokenExpires(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
expiresIn int
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{"zero expiration", 0, 0},
|
||||||
|
{"1 second expiration", 1, time.Second},
|
||||||
|
{"60 seconds expiration", 60, time.Minute},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token := Token{ExpiresIn: tt.expiresIn}
|
||||||
|
if got := token.AccessTokenExpires(); got != tt.want {
|
||||||
|
t.Errorf("AccessTokenExpires() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_RefreshTokenExpires(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
refreshExpires int
|
||||||
|
want time.Duration
|
||||||
|
}{
|
||||||
|
{"zero expiration", 0, 0},
|
||||||
|
{"1 second expiration", 1, time.Second},
|
||||||
|
{"90 seconds expiration", 90, 90 * time.Second},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token := Token{RefreshExpiresIn: tt.refreshExpires}
|
||||||
|
if got := token.RefreshTokenExpires(); got != tt.want {
|
||||||
|
t.Errorf("RefreshTokenExpires() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_RefreshTokenExpiresUnix(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
refreshExpires int
|
||||||
|
}{
|
||||||
|
{"zero expiration", 0},
|
||||||
|
{"10 seconds expiration", 10},
|
||||||
|
{"60 seconds expiration", 60},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token := Token{RefreshExpiresIn: tt.refreshExpires}
|
||||||
|
got := token.RefreshTokenExpiresUnix()
|
||||||
|
want := time.Now().Add(time.Duration(tt.refreshExpires) * time.Second).UnixNano()
|
||||||
|
|
||||||
|
// 確保計算的時間戳在合理範圍內
|
||||||
|
if got < want-1e9 || got > want+1e9 {
|
||||||
|
t.Errorf("RefreshTokenExpiresUnix() = %v, want %v (±1 second tolerance)", got, want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_IsExpires(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accessCreateAt time.Time
|
||||||
|
expiresIn int
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"not expired", now.Add(-5 * time.Minute), 600, false}, // 10-minute expiry, created 5 minutes ago
|
||||||
|
{"just expired", now.Add(-10 * time.Minute), 600, true}, // 10-minute expiry, created 10 minutes ago
|
||||||
|
{"already expired", now.Add(-15 * time.Minute), 600, true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token := Token{AccessCreateAt: tt.accessCreateAt, ExpiresIn: tt.expiresIn}
|
||||||
|
if got := token.IsExpires(); got != tt.want {
|
||||||
|
t.Errorf("IsExpires() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_RedisExpiredSec(t *testing.T) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
expiresIn int
|
||||||
|
}{
|
||||||
|
{"zero expiration", 0},
|
||||||
|
{"future expiration", int(now + 3600)}, // Expires in 1 hour
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token := Token{ExpiresIn: tt.expiresIn}
|
||||||
|
got := token.RedisExpiredSec()
|
||||||
|
want := time.Unix(int64(tt.expiresIn), 0).Sub(time.Now().UTC()).Seconds()
|
||||||
|
|
||||||
|
if float64(got) < want-1 || float64(got) > want+1 {
|
||||||
|
t.Errorf("RedisExpiredSec() = %v, want ~%v (±1 second tolerance)", got, int64(want))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToken_RedisRefreshExpiredSec(t *testing.T) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
refreshExpires int
|
||||||
|
}{
|
||||||
|
{"zero refresh expiration", 0},
|
||||||
|
{"future refresh expiration", int(now + 7200)}, // Expires in 2 hours
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
token := Token{RefreshExpiresIn: tt.refreshExpires}
|
||||||
|
got := token.RedisRefreshExpiredSec()
|
||||||
|
want := time.Unix(int64(tt.refreshExpires), 0).Sub(time.Now().UTC()).Seconds()
|
||||||
|
|
||||||
|
if float64(got) < want-1 || float64(got) > want+1 {
|
||||||
|
t.Errorf("RedisRefreshExpiredSec() = %v, want ~%v (±1 second tolerance)", got, int64(want))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,46 +0,0 @@
|
||||||
package permission
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
// Status 狀態常數
|
|
||||||
const (
|
|
||||||
StatusActive = 1
|
|
||||||
StatusInactive = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// Type 權限類型
|
|
||||||
type Type int8
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeBackend Type = iota + 1
|
|
||||||
TypeFrontend
|
|
||||||
)
|
|
||||||
|
|
||||||
// Status 權限狀態
|
|
||||||
type Status string
|
|
||||||
|
|
||||||
const (
|
|
||||||
StatusOpen Status = "open"
|
|
||||||
StatusClose Status = "close"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Permissions 權限映射
|
|
||||||
type Permissions map[string]Status
|
|
||||||
|
|
||||||
// GrantType 授權類型
|
|
||||||
type GrantType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
GrantTypePassword GrantType = "password"
|
|
||||||
GrantTypeClient GrantType = "client_credentials"
|
|
||||||
GrantTypeRefreshToken GrantType = "refresh_token"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Default Values 預設值
|
|
||||||
const (
|
|
||||||
DefaultRole = "user"
|
|
||||||
AdminRole = "admin"
|
|
||||||
AdminRoleUID = "AM000000"
|
|
||||||
AdminUID = "B000000"
|
|
||||||
RefreshTokenTTL = 5 * time.Second
|
|
||||||
)
|
|
||||||
|
|
@ -2,39 +2,42 @@ package domain
|
||||||
|
|
||||||
import "strings"
|
import "strings"
|
||||||
|
|
||||||
// RedisKey represents a Redis key type with helper methods for key construction.
|
const (
|
||||||
|
TicketKeyPrefix = "tic/"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ClientDataKey = "permission:clients"
|
||||||
|
)
|
||||||
|
|
||||||
type RedisKey string
|
type RedisKey string
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ClientRedisKey RedisKey = "client"
|
AccessTokenRedisKey RedisKey = "access_token"
|
||||||
PermissionRedisKey RedisKey = "permission"
|
RefreshTokenRedisKey RedisKey = "refresh_token"
|
||||||
RoleRedisKey RedisKey = "role"
|
DeviceTokenRedisKey RedisKey = "device_token"
|
||||||
UserRoleRedisKey RedisKey = "user_role"
|
UIDTokenRedisKey RedisKey = "uid_token"
|
||||||
|
TicketRedisKey RedisKey = "ticket"
|
||||||
|
DeviceUIDRedisKey RedisKey = "device_uid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ToString converts the RedisKey to its full string representation with the member prefix.
|
|
||||||
func (key RedisKey) ToString() string {
|
func (key RedisKey) ToString() string {
|
||||||
return "member:" + string(key)
|
return "permission:" + string(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
// With appends additional parts to the RedisKey, separated by colons.
|
|
||||||
func (key RedisKey) With(s ...string) RedisKey {
|
func (key RedisKey) With(s ...string) RedisKey {
|
||||||
parts := append([]string{string(key)}, s...)
|
parts := append([]string{string(key)}, s...)
|
||||||
return RedisKey(strings.Join(parts, ":"))
|
return RedisKey(strings.Join(parts, ":"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeClientRedisKey(id string) string {
|
func GetAccessTokenRedisKey(id string) string {
|
||||||
return ClientRedisKey.With(id).ToString()
|
return AccessTokenRedisKey.With(id).ToString()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetPermissionRedisKey(id string) string {
|
func GetUIDTokenRedisKey(uid string) string {
|
||||||
return PermissionRedisKey.With(id).ToString()
|
return UIDTokenRedisKey.With(uid).ToString()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetRoleRedisKeyRedisKey(id string) string {
|
func GetTicketRedisKey(ticket string) string {
|
||||||
return RoleRedisKey.With(id).ToString()
|
return TicketRedisKey.With(ticket).ToString()
|
||||||
}
|
|
||||||
|
|
||||||
func GetUserRoleRedisKey(id string) string {
|
|
||||||
return UserRoleRedisKey.With(id).ToString()
|
|
||||||
}
|
}
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
package domain
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestRedisKey_ToString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key RedisKey
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"AccessToken Key", AccessTokenRedisKey, "permission:access_token"},
|
||||||
|
{"UIDToken Key", UIDTokenRedisKey, "permission:uid_token"},
|
||||||
|
{"Ticket Key", TicketRedisKey, "permission:ticket"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.key.ToString(); got != tt.want {
|
||||||
|
t.Errorf("ToString() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisKey_With(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key RedisKey
|
||||||
|
args []string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"AccessToken with ID", AccessTokenRedisKey, []string{"12345"}, "access_token:12345"},
|
||||||
|
{"UIDToken with UID", UIDTokenRedisKey, []string{"67890"}, "uid_token:67890"},
|
||||||
|
{"Ticket with multiple parts", TicketRedisKey, []string{"session", "12345"}, "ticket:session:12345"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := tt.key.With(tt.args...).ToString(); got != "permission:"+tt.want {
|
||||||
|
t.Errorf("With() = %v, want %v", got, "permission:"+tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccessTokenRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
id string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"AccessToken Key with ID", "12345", "permission:access_token:12345"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := GetAccessTokenRedisKey(tt.id); got != tt.want {
|
||||||
|
t.Errorf("GetAccessTokenRedisKey() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUIDTokenRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uid string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"UIDToken Key with UID", "67890", "permission:uid_token:67890"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := GetUIDTokenRedisKey(tt.uid); got != tt.want {
|
||||||
|
t.Errorf("GetUIDTokenRedisKey() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetTicketRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ticket string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"Ticket Key with Ticket", "session123", "permission:ticket:session123"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := GetTicketRedisKey(tt.ticket); got != tt.want {
|
||||||
|
t.Errorf("GetTicketRedisKey() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"context"
|
|
||||||
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ClientRepository 客戶端倉庫介面
|
|
||||||
type ClientRepository interface {
|
|
||||||
Create(ctx context.Context, client *entity.Client) error
|
|
||||||
GetByID(ctx context.Context, id string) (*entity.Client, error)
|
|
||||||
GetByClientID(ctx context.Context, clientID string) (*entity.Client, error)
|
|
||||||
Update(ctx context.Context, id string, client *entity.Client) error
|
|
||||||
Delete(ctx context.Context, id string) error
|
|
||||||
List(ctx context.Context, filter ClientFilter) ([]*entity.Client, error)
|
|
||||||
Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientFilter 客戶端查詢過濾器
|
|
||||||
type ClientFilter struct {
|
|
||||||
Status *int
|
|
||||||
Limit int
|
|
||||||
Skip int
|
|
||||||
}
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"context"
|
|
||||||
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PermissionRepository 權限倉庫介面
|
|
||||||
type PermissionRepository interface {
|
|
||||||
Create(ctx context.Context, permission *entity.Permission) error
|
|
||||||
GetByID(ctx context.Context, id string) (*entity.Permission, error)
|
|
||||||
GetByKey(ctx context.Context, httpMethod, httpPath string) (*entity.Permission, error)
|
|
||||||
Update(ctx context.Context, id string, permission *entity.Permission) error
|
|
||||||
Delete(ctx context.Context, id string) error
|
|
||||||
List(ctx context.Context, filter PermissionFilter) ([]*entity.Permission, error)
|
|
||||||
GetActivePermissions(ctx context.Context) ([]*entity.Permission, error)
|
|
||||||
Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PermissionFilter 權限查詢過濾器
|
|
||||||
type PermissionFilter struct {
|
|
||||||
Status *int
|
|
||||||
Type *entity.PermissionType
|
|
||||||
ParentID *string
|
|
||||||
Limit int
|
|
||||||
Skip int
|
|
||||||
}
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"context"
|
|
||||||
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RoleRepository 角色倉庫介面
|
|
||||||
type RoleRepository interface {
|
|
||||||
Create(ctx context.Context, role *entity.Role) error
|
|
||||||
GetByID(ctx context.Context, id string) (*entity.Role, error)
|
|
||||||
GetByUID(ctx context.Context, uid string) (*entity.Role, error)
|
|
||||||
GetByClientAndName(ctx context.Context, clientID, name string) (*entity.Role, error)
|
|
||||||
Update(ctx context.Context, id string, role *entity.Role) error
|
|
||||||
Delete(ctx context.Context, id string) error
|
|
||||||
List(ctx context.Context, filter RoleFilter) ([]*entity.Role, error)
|
|
||||||
GetRolesByClientID(ctx context.Context, clientID string) ([]*entity.Role, error)
|
|
||||||
Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RoleFilter 角色查詢過濾器
|
|
||||||
type RoleFilter struct {
|
|
||||||
ClientID string
|
|
||||||
Status *int
|
|
||||||
Limit int
|
|
||||||
Skip int
|
|
||||||
}
|
|
||||||
|
|
@ -1,16 +1,49 @@
|
||||||
package repository
|
package repository
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenRepository 令牌倉庫介面
|
// TokenRepository 定義了與 Redis 相關的 Token 操作方法
|
||||||
|
//nolint:interfacebloat
|
||||||
type TokenRepository interface {
|
type TokenRepository interface {
|
||||||
Create(ctx context.Context, token *entity.Token) error
|
// Create 建立新的 Token 並存儲至 Redis
|
||||||
GetByAccessToken(ctx context.Context, accessToken string) (*entity.Token, error)
|
Create(ctx context.Context, token entity.Token) error
|
||||||
GetByRefreshToken(ctx context.Context, refreshToken string) (*entity.Token, error)
|
// CreateOneTimeToken 建立臨時(一次性)Token,並指定有效期限
|
||||||
Update(ctx context.Context, token *entity.Token) error
|
CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error
|
||||||
Delete(ctx context.Context, id string) error
|
// GetAccessTokenByOneTimeToken 根據一次性 Token 獲取對應的存取 Token
|
||||||
DeleteByUserID(ctx context.Context, uid string) error
|
GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error)
|
||||||
|
// GetAccessTokenByID 根據 Token ID 獲取對應的存取 Token
|
||||||
|
GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error)
|
||||||
|
// GetAccessTokensByUID 根據用戶 ID 獲取該用戶的所有存取 Token
|
||||||
|
GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error)
|
||||||
|
// GetAccessTokenCountByUID 根據用戶 ID 獲取該用戶的存取 Token 數量
|
||||||
|
GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error)
|
||||||
|
// GetAccessTokensByDeviceID 根據裝置 ID 獲取該裝置的所有存取 Token
|
||||||
|
GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error)
|
||||||
|
// GetAccessTokenCountByDeviceID 根據裝置 ID 獲取該裝置的存取 Token 數量
|
||||||
|
GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error)
|
||||||
|
// Delete 刪除指定的 Token
|
||||||
|
Delete(ctx context.Context, token entity.Token) error
|
||||||
|
// DeleteOneTimeToken 批量刪除一次性 Token
|
||||||
|
DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error
|
||||||
|
// DeleteAccessTokenByID 根據 Token ID 批量刪除存取 Token
|
||||||
|
DeleteAccessTokenByID(ctx context.Context, ids []string) error
|
||||||
|
// DeleteAccessTokensByUID 根據用戶 ID 刪除該用戶的所有存取 Token
|
||||||
|
DeleteAccessTokensByUID(ctx context.Context, uid string) error
|
||||||
|
// DeleteAccessTokensByDeviceID 根據裝置 ID 刪除該裝置的所有存取 Token
|
||||||
|
DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error
|
||||||
|
|
||||||
|
// Blacklist operations
|
||||||
|
// AddToBlacklist 將 JWT token 加入黑名單
|
||||||
|
AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error
|
||||||
|
// IsBlacklisted 檢查 JWT token 是否在黑名單中
|
||||||
|
IsBlacklisted(ctx context.Context, jti string) (bool, error)
|
||||||
|
// RemoveFromBlacklist 從黑名單中移除 JWT token
|
||||||
|
RemoveFromBlacklist(ctx context.Context, jti string) error
|
||||||
|
// GetBlacklistedTokensByUID 獲取用戶的所有黑名單 token
|
||||||
|
GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error)
|
||||||
}
|
}
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UserRoleRepository 用戶角色倉庫介面
|
|
||||||
type UserRoleRepository interface {
|
|
||||||
Create(ctx context.Context, userRole *entity.UserRole) error
|
|
||||||
GetByID(ctx context.Context, id string) (*entity.UserRole, error)
|
|
||||||
GetByUserAndRole(ctx context.Context, uid, roleUID string) (*entity.UserRole, error)
|
|
||||||
Update(ctx context.Context, id string, userRole *entity.UserRole) error
|
|
||||||
Delete(ctx context.Context, id string) error
|
|
||||||
List(ctx context.Context, filter UserRoleFilter) ([]*entity.UserRole, error)
|
|
||||||
GetUserRolesByUID(ctx context.Context, uid string) ([]*entity.UserRole, error)
|
|
||||||
DeleteByUserAndRole(ctx context.Context, uid, roleUID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserRoleFilter 用戶角色查詢過濾器
|
|
||||||
type UserRoleFilter struct {
|
|
||||||
Brand string
|
|
||||||
UID string
|
|
||||||
RoleUID string
|
|
||||||
Status *int
|
|
||||||
Limit int
|
|
||||||
Skip int
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
package token
|
||||||
|
|
||||||
|
// GrantType represents OAuth 2.0 grant types
|
||||||
|
type GrantType string
|
||||||
|
|
||||||
|
// ToString returns the string representation of GrantType
|
||||||
|
func (g GrantType) ToString() string {
|
||||||
|
return string(g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValid returns true if the grant type is valid
|
||||||
|
func (g GrantType) IsValid() bool {
|
||||||
|
switch g {
|
||||||
|
case PasswordCredentials, ClientCredentials, Refreshing:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
PasswordCredentials GrantType = "password"
|
||||||
|
ClientCredentials GrantType = "client_credentials"
|
||||||
|
Refreshing GrantType = "refresh_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGrantType_ToString(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
grantType GrantType
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "password credentials",
|
||||||
|
grantType: PasswordCredentials,
|
||||||
|
expected: "password",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client credentials",
|
||||||
|
grantType: ClientCredentials,
|
||||||
|
expected: "client_credentials",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "refreshing",
|
||||||
|
grantType: Refreshing,
|
||||||
|
expected: "refresh_token",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom grant type",
|
||||||
|
grantType: GrantType("custom"),
|
||||||
|
expected: "custom",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty grant type",
|
||||||
|
grantType: GrantType(""),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.grantType.ToString()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGrantType_IsValid(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
grantType GrantType
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "password credentials is valid",
|
||||||
|
grantType: PasswordCredentials,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "client credentials is valid",
|
||||||
|
grantType: ClientCredentials,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "refreshing is valid",
|
||||||
|
grantType: Refreshing,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid grant type",
|
||||||
|
grantType: GrantType("invalid"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty grant type",
|
||||||
|
grantType: GrantType(""),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "authorization code (not implemented)",
|
||||||
|
grantType: GrantType("authorization_code"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "implicit (not implemented)",
|
||||||
|
grantType: GrantType("implicit"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.grantType.IsValid()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGrantType_Constants(t *testing.T) {
|
||||||
|
t.Run("verify constant values", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "password", PasswordCredentials.ToString())
|
||||||
|
assert.Equal(t, "client_credentials", ClientCredentials.ToString())
|
||||||
|
assert.Equal(t, "refresh_token", Refreshing.ToString())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("verify all constants are valid", func(t *testing.T) {
|
||||||
|
assert.True(t, PasswordCredentials.IsValid())
|
||||||
|
assert.True(t, ClientCredentials.IsValid())
|
||||||
|
assert.True(t, Refreshing.IsValid())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGrantType_StringComparison(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
gt1 GrantType
|
||||||
|
gt2 GrantType
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "same grant type",
|
||||||
|
gt1: PasswordCredentials,
|
||||||
|
gt2: PasswordCredentials,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different grant types",
|
||||||
|
gt1: PasswordCredentials,
|
||||||
|
gt2: ClientCredentials,
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "string comparison",
|
||||||
|
gt1: GrantType("password"),
|
||||||
|
gt2: PasswordCredentials,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.gt1 == tt.gt2
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGrantType_CaseSensitive(t *testing.T) {
|
||||||
|
t.Run("case sensitive comparison", func(t *testing.T) {
|
||||||
|
gt1 := GrantType("password")
|
||||||
|
gt2 := GrantType("PASSWORD")
|
||||||
|
|
||||||
|
assert.NotEqual(t, gt1, gt2)
|
||||||
|
assert.True(t, gt1.IsValid())
|
||||||
|
assert.False(t, gt2.IsValid())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
package token
|
||||||
|
|
||||||
|
// Type represents the type of token
|
||||||
|
type Type string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TypeBearer Type = "Bearer"
|
||||||
|
TypeBasic Type = "Basic"
|
||||||
|
)
|
||||||
|
|
||||||
|
// String returns the string representation of TokenType
|
||||||
|
func (t Type) String() string {
|
||||||
|
return string(t)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValid returns true if the token type is valid
|
||||||
|
func (t Type) IsValid() bool {
|
||||||
|
return t == TypeBearer || t == TypeBasic
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redis key prefixes and patterns
|
||||||
|
const (
|
||||||
|
AccessTokenKeyPrefix = "access_token:"
|
||||||
|
|
||||||
|
RefreshTokenKeyPrefix = "refresh_token:"
|
||||||
|
|
||||||
|
OneTimeTokenKeyPrefix = "one_time_token:"
|
||||||
|
|
||||||
|
UIDTokenKeyPrefix = "uid_tokens:"
|
||||||
|
|
||||||
|
DeviceTokenKeyPrefix = "device_tokens:"
|
||||||
|
|
||||||
|
TicketKeyPrefix = "ticket:"
|
||||||
|
|
||||||
|
BlacklistKeyPrefix = "blacklist:"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Redis key helper functions
|
||||||
|
func GetAccessTokenRedisKey(tokenID string) string {
|
||||||
|
return AccessTokenKeyPrefix + tokenID
|
||||||
|
}
|
||||||
|
|
||||||
|
func RefreshTokenRedisKey(tokenID string) string {
|
||||||
|
return RefreshTokenKeyPrefix + tokenID
|
||||||
|
}
|
||||||
|
|
||||||
|
func UIDTokenRedisKey(uid string) string {
|
||||||
|
return UIDTokenKeyPrefix + uid
|
||||||
|
}
|
||||||
|
|
||||||
|
func DeviceTokenRedisKey(deviceID string) string {
|
||||||
|
return DeviceTokenKeyPrefix + deviceID
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetBlacklistRedisKey(jti string) string {
|
||||||
|
return BlacklistKeyPrefix + jti
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetUIDTokenRedisKey(uid string) string {
|
||||||
|
return UIDTokenKeyPrefix + uid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default expiration times (in seconds)
|
||||||
|
const (
|
||||||
|
DefaultAccessTokenExpiry = 15 * 60 // 15 minutes
|
||||||
|
DefaultRefreshTokenExpiry = 7 * 24 * 3600 // 7 days
|
||||||
|
DefaultOneTimeTokenExpiry = 5 * 60 // 5 minutes
|
||||||
|
)
|
||||||
|
|
||||||
|
// Token limits
|
||||||
|
const (
|
||||||
|
MaxTokensPerUser = 10 // Maximum tokens per user
|
||||||
|
MaxTokensPerDevice = 5 // Maximum tokens per device
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,340 @@
|
||||||
|
package token
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestType_String(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tokenType Type
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "bearer type",
|
||||||
|
tokenType: TypeBearer,
|
||||||
|
expected: "Bearer",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "basic type",
|
||||||
|
tokenType: TypeBasic,
|
||||||
|
expected: "Basic",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "custom type",
|
||||||
|
tokenType: Type("Custom"),
|
||||||
|
expected: "Custom",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty type",
|
||||||
|
tokenType: Type(""),
|
||||||
|
expected: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.tokenType.String()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestType_IsValid(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tokenType Type
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "bearer is valid",
|
||||||
|
tokenType: TypeBearer,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "basic is valid",
|
||||||
|
tokenType: TypeBasic,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid type",
|
||||||
|
tokenType: Type("Invalid"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty type",
|
||||||
|
tokenType: Type(""),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "lowercase bearer",
|
||||||
|
tokenType: Type("bearer"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := tt.tokenType.IsValid()
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestType_Constants(t *testing.T) {
|
||||||
|
t.Run("verify constant values", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "Bearer", TypeBearer.String())
|
||||||
|
assert.Equal(t, "Basic", TypeBasic.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("verify constants are valid", func(t *testing.T) {
|
||||||
|
assert.True(t, TypeBearer.IsValid())
|
||||||
|
assert.True(t, TypeBasic.IsValid())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRedisKeyPrefixes(t *testing.T) {
|
||||||
|
t.Run("verify key prefix constants", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "access_token:", AccessTokenKeyPrefix)
|
||||||
|
assert.Equal(t, "refresh_token:", RefreshTokenKeyPrefix)
|
||||||
|
assert.Equal(t, "one_time_token:", OneTimeTokenKeyPrefix)
|
||||||
|
assert.Equal(t, "uid_tokens:", UIDTokenKeyPrefix)
|
||||||
|
assert.Equal(t, "device_tokens:", DeviceTokenKeyPrefix)
|
||||||
|
assert.Equal(t, "ticket:", TicketKeyPrefix)
|
||||||
|
assert.Equal(t, "blacklist:", BlacklistKeyPrefix)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetAccessTokenRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tokenID string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal token ID",
|
||||||
|
tokenID: "token123",
|
||||||
|
expected: "access_token:token123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID token ID",
|
||||||
|
tokenID: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
expected: "access_token:550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty token ID",
|
||||||
|
tokenID: "",
|
||||||
|
expected: "access_token:",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetAccessTokenRedisKey(tt.tokenID)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRefreshTokenRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tokenID string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal token ID",
|
||||||
|
tokenID: "refresh123",
|
||||||
|
expected: "refresh_token:refresh123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "hash token ID",
|
||||||
|
tokenID: "a1b2c3d4e5f6",
|
||||||
|
expected: "refresh_token:a1b2c3d4e5f6",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty token ID",
|
||||||
|
tokenID: "",
|
||||||
|
expected: "refresh_token:",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := RefreshTokenRedisKey(tt.tokenID)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUIDTokenRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uid string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal UID",
|
||||||
|
uid: "user123",
|
||||||
|
expected: "uid_tokens:user123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID UID",
|
||||||
|
uid: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
expected: "uid_tokens:550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty UID",
|
||||||
|
uid: "",
|
||||||
|
expected: "uid_tokens:",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := UIDTokenRedisKey(tt.uid)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceTokenRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
deviceID string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal device ID",
|
||||||
|
deviceID: "device123",
|
||||||
|
expected: "device_tokens:device123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID device ID",
|
||||||
|
deviceID: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
expected: "device_tokens:550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty device ID",
|
||||||
|
deviceID: "",
|
||||||
|
expected: "device_tokens:",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := DeviceTokenRedisKey(tt.deviceID)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetBlacklistRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
jti string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal JTI",
|
||||||
|
jti: "jti123",
|
||||||
|
expected: "blacklist:jti123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID JTI",
|
||||||
|
jti: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
expected: "blacklist:550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty JTI",
|
||||||
|
jti: "",
|
||||||
|
expected: "blacklist:",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetBlacklistRedisKey(tt.jti)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUIDTokenRedisKey(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uid string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal UID",
|
||||||
|
uid: "user456",
|
||||||
|
expected: "uid_tokens:user456",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty UID",
|
||||||
|
uid: "",
|
||||||
|
expected: "uid_tokens:",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetUIDTokenRedisKey(tt.uid)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultExpirationTimes(t *testing.T) {
|
||||||
|
t.Run("verify default expiration constants", func(t *testing.T) {
|
||||||
|
assert.Equal(t, int64(15*60), int64(DefaultAccessTokenExpiry))
|
||||||
|
assert.Equal(t, int64(7*24*3600), int64(DefaultRefreshTokenExpiry))
|
||||||
|
assert.Equal(t, int64(5*60), int64(DefaultOneTimeTokenExpiry))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("verify expiration times are reasonable", func(t *testing.T) {
|
||||||
|
assert.Greater(t, int64(DefaultAccessTokenExpiry), int64(0))
|
||||||
|
assert.Greater(t, int64(DefaultRefreshTokenExpiry), int64(DefaultAccessTokenExpiry))
|
||||||
|
assert.Greater(t, int64(DefaultOneTimeTokenExpiry), int64(0))
|
||||||
|
assert.Less(t, int64(DefaultOneTimeTokenExpiry), int64(DefaultAccessTokenExpiry))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenLimits(t *testing.T) {
|
||||||
|
t.Run("verify token limit constants", func(t *testing.T) {
|
||||||
|
assert.Equal(t, 10, MaxTokensPerUser)
|
||||||
|
assert.Equal(t, 5, MaxTokensPerDevice)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("verify limits are reasonable", func(t *testing.T) {
|
||||||
|
assert.Greater(t, MaxTokensPerUser, 0)
|
||||||
|
assert.Greater(t, MaxTokensPerDevice, 0)
|
||||||
|
assert.GreaterOrEqual(t, MaxTokensPerUser, MaxTokensPerDevice)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyPrefixUniqueness(t *testing.T) {
|
||||||
|
t.Run("all key prefixes should be unique", func(t *testing.T) {
|
||||||
|
prefixes := []string{
|
||||||
|
AccessTokenKeyPrefix,
|
||||||
|
RefreshTokenKeyPrefix,
|
||||||
|
OneTimeTokenKeyPrefix,
|
||||||
|
UIDTokenKeyPrefix,
|
||||||
|
DeviceTokenKeyPrefix,
|
||||||
|
TicketKeyPrefix,
|
||||||
|
BlacklistKeyPrefix,
|
||||||
|
}
|
||||||
|
|
||||||
|
seen := make(map[string]bool)
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
assert.False(t, seen[prefix], "duplicate prefix found: %s", prefix)
|
||||||
|
seen[prefix] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, len(prefixes), len(seen))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AuthUseCase 認證用例介面
|
|
||||||
type AuthUseCase interface {
|
|
||||||
CreateToken(ctx context.Context, req CreateTokenRequest) (*TokenResponse, error)
|
|
||||||
RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error)
|
|
||||||
ValidateToken(ctx context.Context, accessToken string) (*TokenClaims, error)
|
|
||||||
Logout(ctx context.Context, accessToken string) error
|
|
||||||
LogoutAllByUserID(ctx context.Context, uid string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateTokenRequest 創建令牌請求
|
|
||||||
type CreateTokenRequest struct {
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
GrantType string `json:"grant_type"`
|
|
||||||
Username string `json:"username,omitempty"`
|
|
||||||
Password string `json:"password,omitempty"`
|
|
||||||
DeviceID string `json:"device_id,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenResponse 令牌響應
|
|
||||||
type TokenResponse struct {
|
|
||||||
AccessToken string `json:"access_token"`
|
|
||||||
RefreshToken string `json:"refresh_token"`
|
|
||||||
TokenType string `json:"token_type"`
|
|
||||||
ExpiresIn int64 `json:"expires_in"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TokenClaims 令牌聲明
|
|
||||||
type TokenClaims struct {
|
|
||||||
UID string `json:"uid"`
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
DeviceID string `json:"device_id"`
|
|
||||||
}
|
|
||||||
|
|
@ -1,90 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PermissionUseCase 權限用例介面 (使用 Casbin)
|
|
||||||
type PermissionUseCase interface {
|
|
||||||
// 基本權限管理
|
|
||||||
CreatePermission(ctx context.Context, req CreatePermissionRequest) (*entity.Permission, error)
|
|
||||||
GetPermission(ctx context.Context, id string) (*entity.Permission, error)
|
|
||||||
UpdatePermission(ctx context.Context, req UpdatePermissionRequest) (*entity.Permission, error)
|
|
||||||
DeletePermission(ctx context.Context, id string) error
|
|
||||||
ListPermissions(ctx context.Context, req ListPermissionsRequest) ([]*entity.Permission, error)
|
|
||||||
|
|
||||||
// Casbin 權限檢查
|
|
||||||
CheckUserPermission(ctx context.Context, uid, httpMethod, httpPath string) (bool, error)
|
|
||||||
CheckRolePermission(ctx context.Context, roleUID, httpMethod, httpPath string) (bool, error)
|
|
||||||
CheckPatternPermission(ctx context.Context, uid, pattern, action string) (bool, error)
|
|
||||||
BatchCheckPermissions(ctx context.Context, uid string, permissions []PermissionCheck) (map[string]bool, error)
|
|
||||||
|
|
||||||
// 用戶權限管理
|
|
||||||
GetUserPermissions(ctx context.Context, uid string) (map[string]int, error)
|
|
||||||
AddPolicyForUser(ctx context.Context, uid, httpPath, httpMethod string) error
|
|
||||||
RemovePolicyForUser(ctx context.Context, uid, httpPath, httpMethod string) error
|
|
||||||
|
|
||||||
// 角色管理
|
|
||||||
AddRoleForUser(ctx context.Context, uid, roleUID string) error
|
|
||||||
RemoveRoleForUser(ctx context.Context, uid, roleUID string) error
|
|
||||||
GetUsersForRole(ctx context.Context, roleUID string) ([]string, error)
|
|
||||||
GetRolesForUser(ctx context.Context, uid string) ([]string, error)
|
|
||||||
|
|
||||||
// 角色權限管理
|
|
||||||
AddPermissionForRole(ctx context.Context, roleUID, httpPath, httpMethod string) error
|
|
||||||
RemovePermissionForRole(ctx context.Context, roleUID, httpPath, httpMethod string) error
|
|
||||||
GetPermissionsForRole(ctx context.Context, roleUID string) (map[string]int, error)
|
|
||||||
|
|
||||||
// 策略管理
|
|
||||||
GetAllPolicies(ctx context.Context) ([][]string, error)
|
|
||||||
GetFilteredPolicies(ctx context.Context, fieldIndex int, fieldValues ...string) ([][]string, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePermissionRequest 創建權限請求
|
|
||||||
type CreatePermissionRequest struct {
|
|
||||||
ParentID *string `json:"parent_id,omitempty"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
HTTPMethod string `json:"http_method,omitempty"`
|
|
||||||
HTTPPath string `json:"http_path,omitempty"`
|
|
||||||
Status int `json:"status"`
|
|
||||||
Type entity.PermissionType `json:"type"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePermissionRequest 更新權限請求
|
|
||||||
type UpdatePermissionRequest struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name *string `json:"name,omitempty"`
|
|
||||||
HTTPMethod *string `json:"http_method,omitempty"`
|
|
||||||
HTTPPath *string `json:"http_path,omitempty"`
|
|
||||||
Status *int `json:"status,omitempty"`
|
|
||||||
Type *entity.PermissionType `json:"type,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPermissionsRequest 列出權限請求
|
|
||||||
type ListPermissionsRequest struct {
|
|
||||||
Status *int `json:"status,omitempty"`
|
|
||||||
Type *entity.PermissionType `json:"type,omitempty"`
|
|
||||||
ParentID *string `json:"parent_id,omitempty"`
|
|
||||||
Limit int `json:"limit"`
|
|
||||||
Skip int `json:"skip"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// PermissionCheck 權限檢查項目
|
|
||||||
type PermissionCheck struct {
|
|
||||||
HTTPMethod string `json:"http_method"`
|
|
||||||
HTTPPath string `json:"http_path"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CasbinPolicyRequest Casbin 策略請求
|
|
||||||
type CasbinPolicyRequest struct {
|
|
||||||
Subject string `json:"subject"` // 用戶或角色
|
|
||||||
Object string `json:"object"` // 資源
|
|
||||||
Action string `json:"action"` // 行為
|
|
||||||
}
|
|
||||||
|
|
||||||
// CasbinRoleRequest Casbin 角色請求
|
|
||||||
type CasbinRoleRequest struct {
|
|
||||||
User string `json:"user"` // 用戶
|
|
||||||
Role string `json:"role"` // 角色
|
|
||||||
}
|
|
||||||
|
|
@ -1,46 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RoleUseCase 角色用例介面
|
|
||||||
type RoleUseCase interface {
|
|
||||||
CreateRole(ctx context.Context, req CreateRoleRequest) (*entity.Role, error)
|
|
||||||
GetRole(ctx context.Context, id string) (*entity.Role, error)
|
|
||||||
GetRoleByUID(ctx context.Context, uid string) (*entity.Role, error)
|
|
||||||
UpdateRole(ctx context.Context, req UpdateRoleRequest) (*entity.Role, error)
|
|
||||||
DeleteRole(ctx context.Context, id string) error
|
|
||||||
ListRoles(ctx context.Context, req ListRolesRequest) ([]*entity.Role, error)
|
|
||||||
AddPermissionToRole(ctx context.Context, roleID string, permissionKey string) error
|
|
||||||
RemovePermissionFromRole(ctx context.Context, roleID string, permissionKey string) error
|
|
||||||
BatchUpdateRolePermissions(ctx context.Context, roleID string, permissions entity.Permissions) error
|
|
||||||
GetRolesByClientID(ctx context.Context, clientID string) ([]*entity.Role, error)
|
|
||||||
CopyRole(ctx context.Context, sourceRoleID string, req CreateRoleRequest) (*entity.Role, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateRoleRequest 創建角色請求
|
|
||||||
type CreateRoleRequest struct {
|
|
||||||
ClientID string `json:"client_id"`
|
|
||||||
UID string `json:"uid"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Status int `json:"status"`
|
|
||||||
Permissions entity.Permissions `json:"permissions"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateRoleRequest 更新角色請求
|
|
||||||
type UpdateRoleRequest struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Name *string `json:"name,omitempty"`
|
|
||||||
Status *int `json:"status,omitempty"`
|
|
||||||
Permissions *entity.Permissions `json:"permissions,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListRolesRequest 列出角色請求
|
|
||||||
type ListRolesRequest struct {
|
|
||||||
ClientID string `json:"client_id,omitempty"`
|
|
||||||
Status *int `json:"status,omitempty"`
|
|
||||||
Limit int `json:"limit"`
|
|
||||||
Skip int `json:"skip"`
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenUseCase 定義與 Token 相關的操作接口
|
||||||
|
//
|
||||||
|
//nolint:interfacebloat
|
||||||
|
type TokenUseCase interface {
|
||||||
|
// NewToken 創建新 Token,通常為 Access Token
|
||||||
|
NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error)
|
||||||
|
// RefreshToken 刷新目前的 Token,包括一次性 Token
|
||||||
|
RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error)
|
||||||
|
// CancelToken 取消 Token,包括取消其關聯的 One-Time Token
|
||||||
|
CancelToken(ctx context.Context, req entity.CancelTokenReq) error
|
||||||
|
// ValidationToken 驗證 Token 是否有效
|
||||||
|
ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error)
|
||||||
|
// CancelTokens 根據 UID 或 Token ID 取消所有相關 Token,通常在用戶登出時使用
|
||||||
|
CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error
|
||||||
|
// CancelTokenByDeviceID 根據 Device ID 取消所有相關的 Token
|
||||||
|
CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error
|
||||||
|
// GetUserTokensByDeviceID 根據 Device ID 獲取所有 Token
|
||||||
|
GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error)
|
||||||
|
// GetUserTokensByUID 根據 UID 獲取所有 Token
|
||||||
|
GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error)
|
||||||
|
// NewOneTimeToken 創建一次性 Token,例如 Refresh Token
|
||||||
|
NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error)
|
||||||
|
// CancelOneTimeToken 取消一次性 Token
|
||||||
|
CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error
|
||||||
|
// ReadTokenBasicData 檢查Token 帶的資料
|
||||||
|
ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error)
|
||||||
|
|
||||||
|
// Blacklist operations
|
||||||
|
|
||||||
|
// BlacklistToken 將 JWT token 加入黑名單 (立即撤銷)
|
||||||
|
BlacklistToken(ctx context.Context, token string, reason string) error
|
||||||
|
// IsTokenBlacklisted 檢查 JWT token 是否在黑名單中
|
||||||
|
IsTokenBlacklisted(ctx context.Context, jti string) (bool, error)
|
||||||
|
// BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出)
|
||||||
|
BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error
|
||||||
|
}
|
||||||
|
|
@ -1,49 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UserRoleUseCase 用戶角色用例介面
|
|
||||||
type UserRoleUseCase interface {
|
|
||||||
AssignRole(ctx context.Context, req AssignRoleRequest) (*entity.UserRole, error)
|
|
||||||
RevokeRole(ctx context.Context, uid, roleUID string) error
|
|
||||||
GetUserRole(ctx context.Context, id string) (*entity.UserRole, error)
|
|
||||||
UpdateUserRole(ctx context.Context, req UpdateUserRoleRequest) (*entity.UserRole, error)
|
|
||||||
ListUserRoles(ctx context.Context, req ListUserRolesRequest) ([]*entity.UserRole, error)
|
|
||||||
GetUserRoles(ctx context.Context, uid string) ([]*entity.UserRole, error)
|
|
||||||
GetUserRoleDetails(ctx context.Context, uid string) ([]*UserRoleDetail, error)
|
|
||||||
BatchAssignRoles(ctx context.Context, uid string, roleUIDs []string, brand string) error
|
|
||||||
BatchRevokeRoles(ctx context.Context, uid string, roleUIDs []string) error
|
|
||||||
ReplaceUserRoles(ctx context.Context, uid string, roleUIDs []string, brand string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// AssignRoleRequest 分配角色請求
|
|
||||||
type AssignRoleRequest struct {
|
|
||||||
Brand string `json:"brand"`
|
|
||||||
UID string `json:"uid"`
|
|
||||||
RoleUID string `json:"role_uid"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserRoleRequest 更新用戶角色請求
|
|
||||||
type UpdateUserRoleRequest struct {
|
|
||||||
ID string `json:"id"`
|
|
||||||
Status *int `json:"status,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListUserRolesRequest 列出用戶角色請求
|
|
||||||
type ListUserRolesRequest struct {
|
|
||||||
Brand string `json:"brand,omitempty"`
|
|
||||||
UID string `json:"uid,omitempty"`
|
|
||||||
RoleUID string `json:"role_uid,omitempty"`
|
|
||||||
Status *int `json:"status,omitempty"`
|
|
||||||
Limit int `json:"limit"`
|
|
||||||
Skip int `json:"skip"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserRoleDetail 用戶角色詳情
|
|
||||||
type UserRoleDetail struct {
|
|
||||||
UserRole *entity.UserRole `json:"user_role"`
|
|
||||||
Role *entity.Role `json:"role"`
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,130 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockTokenRepository is a mock implementation of TokenRepository
|
||||||
|
type MockTokenRepository struct {
|
||||||
|
mock.Mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockTokenRepository creates a new mock instance
|
||||||
|
func NewMockTokenRepository(t interface {
|
||||||
|
mock.TestingT
|
||||||
|
Cleanup(func())
|
||||||
|
}) *MockTokenRepository {
|
||||||
|
mock := &MockTokenRepository{}
|
||||||
|
mock.Mock.Test(t)
|
||||||
|
|
||||||
|
t.Cleanup(func() { mock.AssertExpectations(t) })
|
||||||
|
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create provides a mock function with given fields: ctx, token
|
||||||
|
func (m *MockTokenRepository) Create(ctx context.Context, token entity.Token) error {
|
||||||
|
ret := m.Called(ctx, token)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateOneTimeToken provides a mock function with given fields: ctx, key, ticket, dt
|
||||||
|
func (m *MockTokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
|
||||||
|
ret := m.Called(ctx, key, ticket, dt)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessTokenByOneTimeToken provides a mock function with given fields: ctx, oneTimeToken
|
||||||
|
func (m *MockTokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
|
||||||
|
ret := m.Called(ctx, oneTimeToken)
|
||||||
|
return ret.Get(0).(entity.Token), ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessTokenByID provides a mock function with given fields: ctx, id
|
||||||
|
func (m *MockTokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
|
||||||
|
ret := m.Called(ctx, id)
|
||||||
|
return ret.Get(0).(entity.Token), ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessTokensByUID provides a mock function with given fields: ctx, uid
|
||||||
|
func (m *MockTokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
|
||||||
|
ret := m.Called(ctx, uid)
|
||||||
|
return ret.Get(0).([]entity.Token), ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessTokenCountByUID provides a mock function with given fields: ctx, uid
|
||||||
|
func (m *MockTokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
|
||||||
|
ret := m.Called(ctx, uid)
|
||||||
|
return ret.Int(0), ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessTokensByDeviceID provides a mock function with given fields: ctx, deviceID
|
||||||
|
func (m *MockTokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
|
||||||
|
ret := m.Called(ctx, deviceID)
|
||||||
|
return ret.Get(0).([]entity.Token), ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessTokenCountByDeviceID provides a mock function with given fields: ctx, deviceID
|
||||||
|
func (m *MockTokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
|
||||||
|
ret := m.Called(ctx, deviceID)
|
||||||
|
return ret.Int(0), ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete provides a mock function with given fields: ctx, token
|
||||||
|
func (m *MockTokenRepository) Delete(ctx context.Context, token entity.Token) error {
|
||||||
|
ret := m.Called(ctx, token)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteOneTimeToken provides a mock function with given fields: ctx, ids, tokens
|
||||||
|
func (m *MockTokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
|
||||||
|
ret := m.Called(ctx, ids, tokens)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccessTokenByID provides a mock function with given fields: ctx, ids
|
||||||
|
func (m *MockTokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
|
||||||
|
ret := m.Called(ctx, ids)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccessTokensByUID provides a mock function with given fields: ctx, uid
|
||||||
|
func (m *MockTokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
|
||||||
|
ret := m.Called(ctx, uid)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccessTokensByDeviceID provides a mock function with given fields: ctx, deviceID
|
||||||
|
func (m *MockTokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
|
||||||
|
ret := m.Called(ctx, deviceID)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddToBlacklist provides a mock function with given fields: ctx, entry, ttl
|
||||||
|
func (m *MockTokenRepository) AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error {
|
||||||
|
ret := m.Called(ctx, entry, ttl)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlacklisted provides a mock function with given fields: ctx, jti
|
||||||
|
func (m *MockTokenRepository) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
|
||||||
|
ret := m.Called(ctx, jti)
|
||||||
|
return ret.Bool(0), ret.Error(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveFromBlacklist provides a mock function with given fields: ctx, jti
|
||||||
|
func (m *MockTokenRepository) RemoveFromBlacklist(ctx context.Context, jti string) error {
|
||||||
|
ret := m.Called(ctx, jti)
|
||||||
|
return ret.Error(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlacklistedTokensByUID provides a mock function with given fields: ctx, uid
|
||||||
|
func (m *MockTokenRepository) GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error) {
|
||||||
|
ret := m.Called(ctx, uid)
|
||||||
|
return ret.Get(0).([]*entity.BlacklistEntry), ret.Error(1)
|
||||||
|
}
|
||||||
|
|
@ -1,265 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/library/mongo"
|
|
||||||
|
|
||||||
"github.com/casbin/casbin/v2/model"
|
|
||||||
"github.com/casbin/casbin/v2/persist"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CasbinRule represents a casbin rule in MongoDB
|
|
||||||
type CasbinRule struct {
|
|
||||||
ID bson.ObjectID `bson:"_id,omitempty"`
|
|
||||||
PType string `bson:"ptype"`
|
|
||||||
V0 string `bson:"v0"`
|
|
||||||
V1 string `bson:"v1"`
|
|
||||||
V2 string `bson:"v2"`
|
|
||||||
V3 string `bson:"v3"`
|
|
||||||
V4 string `bson:"v4"`
|
|
||||||
V5 string `bson:"v5"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// CasbinAdapterParam Casbin adapter 參數
|
|
||||||
type CasbinAdapterParam struct {
|
|
||||||
Conf *mongo.Conf
|
|
||||||
CacheConf cache.CacheConf
|
|
||||||
DBOpts []mon.Option
|
|
||||||
CacheOpts []cache.Option
|
|
||||||
}
|
|
||||||
|
|
||||||
// CasbinAdapter MongoDB adapter for Casbin
|
|
||||||
type CasbinAdapter struct {
|
|
||||||
DB mongo.DocumentDBWithCacheUseCase
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCasbinAdapter 創建 Casbin adapter
|
|
||||||
func NewCasbinAdapter(param CasbinAdapterParam) persist.Adapter {
|
|
||||||
db, err := mongo.MustDocumentDBWithCache(
|
|
||||||
"casbin_rules",
|
|
||||||
param.Conf,
|
|
||||||
param.CacheConf,
|
|
||||||
param.CacheOpts,
|
|
||||||
param.DBOpts,
|
|
||||||
)
|
|
||||||
|
|
||||||
return &CasbinAdapter{
|
|
||||||
DB: db,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadPolicy loads all policy rules from the storage.
|
|
||||||
func (a *CasbinAdapter) LoadPolicy(model model.Model) error {
|
|
||||||
ctx := context.Background()
|
|
||||||
var rules []CasbinRule
|
|
||||||
|
|
||||||
err := a.DB.Find(ctx, bson.M{}, &rules)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
a.loadPolicyLine(&rule, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SavePolicy saves all policy rules to the storage.
|
|
||||||
func (a *CasbinAdapter) SavePolicy(model model.Model) error {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// 清空現有規則
|
|
||||||
err := a.DB.DeleteMany(ctx, bson.M{})
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
var rules []interface{}
|
|
||||||
|
|
||||||
for ptype, ast := range model["p"] {
|
|
||||||
for _, rule := range ast.Policy {
|
|
||||||
rules = append(rules, a.savePolicyLine(ptype, rule))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for ptype, ast := range model["g"] {
|
|
||||||
for _, rule := range ast.Policy {
|
|
||||||
rules = append(rules, a.savePolicyLine(ptype, rule))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rules) > 0 {
|
|
||||||
_, err = a.DB.InsertMany(ctx, rules)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPolicy adds a policy rule to the storage.
|
|
||||||
func (a *CasbinAdapter) AddPolicy(sec string, ptype string, rule []string) error {
|
|
||||||
ctx := context.Background()
|
|
||||||
casbinRule := a.savePolicyLine(ptype, rule)
|
|
||||||
|
|
||||||
_, err := a.DB.InsertOne(ctx, casbinRule)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePolicy removes a policy rule from the storage.
|
|
||||||
func (a *CasbinAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
|
|
||||||
ctx := context.Background()
|
|
||||||
filter := bson.M{"ptype": ptype}
|
|
||||||
|
|
||||||
for i, value := range rule {
|
|
||||||
filter[getFieldName(i)] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
err := a.DB.DeleteMany(ctx, filter)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
|
|
||||||
func (a *CasbinAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
|
|
||||||
ctx := context.Background()
|
|
||||||
filter := bson.M{"ptype": ptype}
|
|
||||||
|
|
||||||
for i, value := range fieldValues {
|
|
||||||
if fieldIndex+i <= 5 && value != "" {
|
|
||||||
filter[getFieldName(fieldIndex+i)] = value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := a.DB.DeleteMany(ctx, filter)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadPolicyLine loads a line of policy from storage
|
|
||||||
func (a *CasbinAdapter) loadPolicyLine(rule *CasbinRule, model model.Model) {
|
|
||||||
lineText := rule.PType
|
|
||||||
if rule.V0 != "" {
|
|
||||||
lineText += ", " + rule.V0
|
|
||||||
}
|
|
||||||
if rule.V1 != "" {
|
|
||||||
lineText += ", " + rule.V1
|
|
||||||
}
|
|
||||||
if rule.V2 != "" {
|
|
||||||
lineText += ", " + rule.V2
|
|
||||||
}
|
|
||||||
if rule.V3 != "" {
|
|
||||||
lineText += ", " + rule.V3
|
|
||||||
}
|
|
||||||
if rule.V4 != "" {
|
|
||||||
lineText += ", " + rule.V4
|
|
||||||
}
|
|
||||||
if rule.V5 != "" {
|
|
||||||
lineText += ", " + rule.V5
|
|
||||||
}
|
|
||||||
|
|
||||||
persist.LoadPolicyLine(lineText, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// savePolicyLine saves a line of policy to storage
|
|
||||||
func (a *CasbinAdapter) savePolicyLine(ptype string, rule []string) *CasbinRule {
|
|
||||||
casbinRule := &CasbinRule{
|
|
||||||
PType: ptype,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rule) > 0 {
|
|
||||||
casbinRule.V0 = rule[0]
|
|
||||||
}
|
|
||||||
if len(rule) > 1 {
|
|
||||||
casbinRule.V1 = rule[1]
|
|
||||||
}
|
|
||||||
if len(rule) > 2 {
|
|
||||||
casbinRule.V2 = rule[2]
|
|
||||||
}
|
|
||||||
if len(rule) > 3 {
|
|
||||||
casbinRule.V3 = rule[3]
|
|
||||||
}
|
|
||||||
if len(rule) > 4 {
|
|
||||||
casbinRule.V4 = rule[4]
|
|
||||||
}
|
|
||||||
if len(rule) > 5 {
|
|
||||||
casbinRule.V5 = rule[5]
|
|
||||||
}
|
|
||||||
|
|
||||||
return casbinRule
|
|
||||||
}
|
|
||||||
|
|
||||||
// getFieldName returns the field name for the given index
|
|
||||||
func getFieldName(index int) string {
|
|
||||||
switch index {
|
|
||||||
case 0:
|
|
||||||
return "v0"
|
|
||||||
case 1:
|
|
||||||
return "v1"
|
|
||||||
case 2:
|
|
||||||
return "v2"
|
|
||||||
case 3:
|
|
||||||
return "v3"
|
|
||||||
case 4:
|
|
||||||
return "v4"
|
|
||||||
case 5:
|
|
||||||
return "v5"
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Index20241226001UP 創建索引
|
|
||||||
func (a *CasbinAdapter) Index20241226001UP(ctx context.Context) (bool, error) {
|
|
||||||
indexes := []mongodriver.IndexModel{
|
|
||||||
{
|
|
||||||
Keys: bson.D{
|
|
||||||
{Key: "ptype", Value: 1},
|
|
||||||
},
|
|
||||||
Options: &mongodriver.IndexOptions{
|
|
||||||
Name: &[]string{"idx_ptype"}[0],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Keys: bson.D{
|
|
||||||
{Key: "ptype", Value: 1},
|
|
||||||
{Key: "v0", Value: 1},
|
|
||||||
},
|
|
||||||
Options: &mongodriver.IndexOptions{
|
|
||||||
Name: &[]string{"idx_ptype_v0"}[0],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Keys: bson.D{
|
|
||||||
{Key: "ptype", Value: 1},
|
|
||||||
{Key: "v0", Value: 1},
|
|
||||||
{Key: "v1", Value: 1},
|
|
||||||
},
|
|
||||||
Options: &mongodriver.IndexOptions{
|
|
||||||
Name: &[]string{"idx_ptype_v0_v1"}[0],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 需要轉換為 mongo.DocumentDBWithCacheUseCase 的 CreateIndexes 方法
|
|
||||||
// 這裡簡化處理,實際需要根據你的 mongo 包裝實現
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,196 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs/code"
|
|
||||||
"backend/pkg/permission/domain"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/library/mongo"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ClientRepositoryParam struct {
|
|
||||||
Conf *mongo.Conf
|
|
||||||
CacheConf cache.CacheConf
|
|
||||||
DBOpts []mon.Option
|
|
||||||
CacheOpts []cache.Option
|
|
||||||
}
|
|
||||||
|
|
||||||
type ClientRepository struct {
|
|
||||||
DB mongo.DocumentDBWithCacheUseCase
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClientRepository 創建客戶端倉庫實例
|
|
||||||
func NewClientRepository(param ClientRepositoryParam) repository.ClientRepository {
|
|
||||||
e := entity.Client{}
|
|
||||||
documentDB, err := mongo.MustDocumentDBWithCache(
|
|
||||||
param.Conf,
|
|
||||||
e.CollectionName(),
|
|
||||||
param.CacheConf,
|
|
||||||
param.DBOpts,
|
|
||||||
param.CacheOpts,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &ClientRepository{
|
|
||||||
DB: documentDB,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *ClientRepository) Create(ctx context.Context, client *entity.Client) error {
|
|
||||||
now := time.Now()
|
|
||||||
client.CreateTime = now
|
|
||||||
client.UpdateTime = now
|
|
||||||
id := bson.NewObjectID()
|
|
||||||
client.ID = id
|
|
||||||
rk := domain.GeClientRedisKey(id.Hex())
|
|
||||||
_, err := repo.DB.InsertOne(ctx, rk, client)
|
|
||||||
if err != nil {
|
|
||||||
// 檢查是否為重複鍵錯誤
|
|
||||||
if mongodriver.IsDuplicateKeyError(err) {
|
|
||||||
return errs.ResourceAlreadyExist(client.ClientID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *ClientRepository) GetByID(ctx context.Context, id string) (*entity.Client, error) {
|
|
||||||
var client entity.Client
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rk := domain.GeClientRedisKey(objID.Hex())
|
|
||||||
err = repo.DB.FindOne(ctx, rk, &client, bson.M{"_id": objID})
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(
|
|
||||||
code.CloudEPPermission,
|
|
||||||
domain.FailedToGetByID,
|
|
||||||
"failed to get client by id")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *ClientRepository) GetByClientID(ctx context.Context, clientID string) (*entity.Client, error) {
|
|
||||||
var client entity.Client
|
|
||||||
rk := domain.GeClientRedisKey(clientID)
|
|
||||||
err := repo.DB.FindOne(ctx, rk, &client, bson.M{"client_id": clientID})
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(
|
|
||||||
code.CloudEPPermission,
|
|
||||||
domain.FailedToGetByClientID,
|
|
||||||
"failed to get client by client id")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &client, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *ClientRepository) Update(ctx context.Context, id string, client *entity.Client) error {
|
|
||||||
client.UpdateTime = time.Now()
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
update := bson.M{
|
|
||||||
"$set": bson.M{
|
|
||||||
"name": client.Name,
|
|
||||||
"secret": client.Secret,
|
|
||||||
"status": client.Status,
|
|
||||||
"update_time": client.UpdateTime,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
gc, err := repo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
rk := domain.GeClientRedisKey(objID.Hex())
|
|
||||||
_, err = repo.DB.UpdateOne(ctx, rk, bson.M{"_id": objID}, update)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
rk = domain.GeClientRedisKey(gc.ClientID)
|
|
||||||
err = repo.DB.DelCache(ctx, rk)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *ClientRepository) Delete(ctx context.Context, id string) error {
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
gc, err := repo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rk := domain.GeClientRedisKey(gc.ClientID)
|
|
||||||
err = repo.DB.DelCache(ctx, rk)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rk = domain.GeClientRedisKey(objID.Hex())
|
|
||||||
_, err = repo.DB.DeleteOne(ctx, rk, bson.M{"_id": objID})
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *ClientRepository) List(ctx context.Context, filter repository.ClientFilter) ([]*entity.Client, error) {
|
|
||||||
query := bson.M{}
|
|
||||||
|
|
||||||
if filter.Status != nil {
|
|
||||||
query["status"] = *filter.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
var clients []*entity.Client
|
|
||||||
err := repo.DB.GetClient().Find(ctx, query, &clients,
|
|
||||||
options.Find().SetLimit(int64(filter.Limit)),
|
|
||||||
options.Find().SetSkip(int64(filter.Skip)),
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return clients, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Index20241226001UP 創建索引
|
|
||||||
func (repo *ClientRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
|
|
||||||
repo.DB.PopulateIndex(ctx, "client_id", 1, true)
|
|
||||||
repo.DB.PopulateIndex(ctx, "status", 1, false)
|
|
||||||
|
|
||||||
return repo.DB.GetClient().Indexes().List(ctx)
|
|
||||||
}
|
|
||||||
|
|
@ -1,209 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs/code"
|
|
||||||
"backend/pkg/permission/domain"
|
|
||||||
"backend/pkg/permission/domain/permission"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/library/mongo"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PermissionRepositoryParam struct {
|
|
||||||
Conf *mongo.Conf
|
|
||||||
CacheConf cache.CacheConf
|
|
||||||
DBOpts []mon.Option
|
|
||||||
CacheOpts []cache.Option
|
|
||||||
}
|
|
||||||
|
|
||||||
type PermissionRepository struct {
|
|
||||||
DB mongo.DocumentDBWithCacheUseCase
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPermissionRepository 創建權限倉庫實例
|
|
||||||
func NewPermissionRepository(param PermissionRepositoryParam) repository.PermissionRepository {
|
|
||||||
e := entity.Permission{}
|
|
||||||
documentDB, err := mongo.MustDocumentDBWithCache(
|
|
||||||
param.Conf,
|
|
||||||
e.CollectionName(),
|
|
||||||
param.CacheConf,
|
|
||||||
param.DBOpts,
|
|
||||||
param.CacheOpts,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PermissionRepository{
|
|
||||||
DB: documentDB,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *PermissionRepository) Create(ctx context.Context, permission *entity.Permission) error {
|
|
||||||
now := time.Now()
|
|
||||||
permission.CreateTime = now
|
|
||||||
permission.UpdateTime = now
|
|
||||||
|
|
||||||
id := bson.NewObjectID()
|
|
||||||
permission.ID = id
|
|
||||||
|
|
||||||
rk := domain.GetPermissionRedisKey(id.Hex())
|
|
||||||
_, err := repo.DB.InsertOne(ctx, rk, permission)
|
|
||||||
if err != nil {
|
|
||||||
// 檢查是否為重複鍵錯誤
|
|
||||||
if mongodriver.IsDuplicateKeyError(err) {
|
|
||||||
return errs.ResourceAlreadyExist(permission.ID.Hex())
|
|
||||||
}
|
|
||||||
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *PermissionRepository) GetByID(ctx context.Context, id string) (*entity.Permission, error) {
|
|
||||||
var p entity.Permission
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rk := domain.GetPermissionRedisKey(objID.Hex())
|
|
||||||
err = repo.DB.FindOne(ctx, rk, &p, bson.M{"_id": objID})
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(
|
|
||||||
code.CloudEPPermission,
|
|
||||||
domain.FailedToGetPermission,
|
|
||||||
"failed to get permission by id")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *PermissionRepository) GetByKey(ctx context.Context, httpMethod, httpPath string) (*entity.Permission, error) {
|
|
||||||
filter := bson.M{
|
|
||||||
"http_method": httpMethod,
|
|
||||||
"http_path": httpPath,
|
|
||||||
"status": permission.StatusActive,
|
|
||||||
}
|
|
||||||
|
|
||||||
var p entity.Permission
|
|
||||||
err := repo.DB.GetClient().FindOne(ctx, &p, filter)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(
|
|
||||||
code.CloudEPPermission, domain.FailedToGetPermissionByKey,
|
|
||||||
"failed to get permission by key")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
return &p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *PermissionRepository) Update(ctx context.Context, id string, permission *entity.Permission) error {
|
|
||||||
permission.UpdateTime = time.Now()
|
|
||||||
update := bson.M{
|
|
||||||
"$set": bson.M{
|
|
||||||
"parent_id": permission.ParentID,
|
|
||||||
"name": permission.Name,
|
|
||||||
"http_method": permission.HTTPMethod,
|
|
||||||
"http_path": permission.HTTPPath,
|
|
||||||
"status": permission.Status,
|
|
||||||
"type": permission.Type,
|
|
||||||
"update_time": permission.UpdateTime,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rk := domain.GetPermissionRedisKey(id)
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = repo.DB.UpdateOne(ctx, rk, bson.M{"_id": objID}, update)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *PermissionRepository) Delete(ctx context.Context, id string) error {
|
|
||||||
rk := domain.GetPermissionRedisKey(id)
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = repo.DB.DeleteOne(ctx, rk, bson.M{"_id": objID})
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *PermissionRepository) List(ctx context.Context, filter repository.PermissionFilter) ([]*entity.Permission, error) {
|
|
||||||
query := bson.M{}
|
|
||||||
|
|
||||||
if filter.Status != nil {
|
|
||||||
query["status"] = *filter.Status
|
|
||||||
}
|
|
||||||
if filter.Type != nil {
|
|
||||||
query["type"] = *filter.Type
|
|
||||||
}
|
|
||||||
if filter.ParentID != nil {
|
|
||||||
query["parent_id"] = *filter.ParentID
|
|
||||||
}
|
|
||||||
|
|
||||||
var permissions []*entity.Permission
|
|
||||||
err := repo.DB.GetClient().Find(ctx,
|
|
||||||
&permissions, query,
|
|
||||||
options.Find().SetLimit(int64(filter.Limit)),
|
|
||||||
options.Find().SetSkip(int64(filter.Skip)))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return permissions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *PermissionRepository) GetActivePermissions(ctx context.Context) ([]*entity.Permission, error) {
|
|
||||||
status := permission.StatusActive
|
|
||||||
filter := repository.PermissionFilter{
|
|
||||||
Status: &status,
|
|
||||||
}
|
|
||||||
|
|
||||||
return repo.List(ctx, filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Index20241226001UP 創建索引
|
|
||||||
func (repo *PermissionRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
|
|
||||||
// 等價於 db.account.createIndex({ "login_id": 1, "platform": 1}, {unique: true})
|
|
||||||
repo.DB.PopulateMultiIndex(ctx, []string{
|
|
||||||
"http_method",
|
|
||||||
"http_path",
|
|
||||||
}, []int32{1, 1}, true)
|
|
||||||
|
|
||||||
// 等價於 db.account.createIndex({"create_at": 1})
|
|
||||||
repo.DB.PopulateIndex(ctx, "name", 1, false)
|
|
||||||
repo.DB.PopulateIndex(ctx, "status", 1, false)
|
|
||||||
repo.DB.PopulateIndex(ctx, "type", 1, false)
|
|
||||||
|
|
||||||
return repo.DB.GetClient().Indexes().List(ctx)
|
|
||||||
}
|
|
||||||
|
|
@ -1,233 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs/code"
|
|
||||||
"backend/pkg/permission/domain"
|
|
||||||
"backend/pkg/permission/domain/permission"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/library/mongo"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RoleRepositoryParam struct {
|
|
||||||
Conf *mongo.Conf
|
|
||||||
CacheConf cache.CacheConf
|
|
||||||
DBOpts []mon.Option
|
|
||||||
CacheOpts []cache.Option
|
|
||||||
}
|
|
||||||
|
|
||||||
type RoleRepository struct {
|
|
||||||
DB mongo.DocumentDBWithCacheUseCase
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRoleRepository 創建角色倉庫實例
|
|
||||||
func NewRoleRepository(param RoleRepositoryParam) repository.RoleRepository {
|
|
||||||
e := entity.Role{}
|
|
||||||
documentDB, err := mongo.MustDocumentDBWithCache(
|
|
||||||
param.Conf,
|
|
||||||
e.CollectionName(),
|
|
||||||
param.CacheConf,
|
|
||||||
param.DBOpts,
|
|
||||||
param.CacheOpts,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &RoleRepository{
|
|
||||||
DB: documentDB,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *RoleRepository) Create(ctx context.Context, role *entity.Role) error {
|
|
||||||
now := time.Now()
|
|
||||||
role.CreateTime = now
|
|
||||||
role.UpdateTime = now
|
|
||||||
id := bson.NewObjectID()
|
|
||||||
role.ID = id
|
|
||||||
|
|
||||||
rk := domain.GetRoleRedisKeyRedisKey(id.Hex())
|
|
||||||
_, err := repo.DB.InsertOne(ctx, rk, role)
|
|
||||||
if err != nil {
|
|
||||||
// 檢查是否為重複鍵錯誤
|
|
||||||
if mongodriver.IsDuplicateKeyError(err) {
|
|
||||||
return errs.ResourceAlreadyExist(role.ClientID)
|
|
||||||
}
|
|
||||||
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *RoleRepository) GetByID(ctx context.Context, id string) (*entity.Role, error) {
|
|
||||||
var role entity.Role
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rk := domain.GetRoleRedisKeyRedisKey(id)
|
|
||||||
err = repo.DB.FindOne(ctx, rk, &role, bson.M{"client_id": objID})
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(
|
|
||||||
code.CloudEPPermission,
|
|
||||||
domain.FailedToGetRoleByID,
|
|
||||||
"failed to get role by id")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &role, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *RoleRepository) GetByUID(ctx context.Context, uid string) (*entity.Role, error) {
|
|
||||||
var role entity.Role
|
|
||||||
rk := domain.GetRoleRedisKeyRedisKey(uid)
|
|
||||||
err := repo.DB.FindOne(ctx, rk, &role, bson.M{"uid": uid, "status": permission.StatusActive})
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(
|
|
||||||
code.CloudEPPermission,
|
|
||||||
domain.FailedToGetByUID,
|
|
||||||
"failed to get role by uid")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &role, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *RoleRepository) GetByClientAndName(ctx context.Context, clientID, name string) (*entity.Role, error) {
|
|
||||||
filter := bson.M{
|
|
||||||
"client_id": clientID,
|
|
||||||
"name": name,
|
|
||||||
"status": permission.StatusActive,
|
|
||||||
}
|
|
||||||
|
|
||||||
var role entity.Role
|
|
||||||
err := repo.DB.GetClient().FindOne(ctx, &role, filter)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(
|
|
||||||
code.CloudEPPermission, domain.FailedToGetByClientAndName, "failed to get by client and name")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &role, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *RoleRepository) Update(ctx context.Context, id string, role *entity.Role) error {
|
|
||||||
role.UpdateTime = time.Now()
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
update := bson.M{
|
|
||||||
"$set": bson.M{
|
|
||||||
"name": role.Name,
|
|
||||||
"status": role.Status,
|
|
||||||
"permissions": role.Permissions,
|
|
||||||
"update_time": role.UpdateTime,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rk := domain.GetRoleRedisKeyRedisKey(id)
|
|
||||||
|
|
||||||
_, err = repo.DB.UpdateOne(ctx, rk, bson.M{"_id": objID}, update)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *RoleRepository) Delete(ctx context.Context, id string) error {
|
|
||||||
rk := domain.GetRoleRedisKeyRedisKey(id)
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
gc, err := repo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rk = domain.GetRoleRedisKeyRedisKey(gc.UID)
|
|
||||||
err = repo.DB.DelCache(ctx, rk)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = repo.DB.DeleteOne(ctx, rk, bson.M{"_id": objID})
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *RoleRepository) List(ctx context.Context, filter repository.RoleFilter) ([]*entity.Role, error) {
|
|
||||||
query := bson.M{}
|
|
||||||
|
|
||||||
if filter.ClientID != "" {
|
|
||||||
query["client_id"] = filter.ClientID
|
|
||||||
}
|
|
||||||
if filter.Status != nil {
|
|
||||||
query["status"] = *filter.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
var roles []*entity.Role
|
|
||||||
err := repo.DB.GetClient().Find(ctx, &roles, query,
|
|
||||||
options.Find().SetLimit(int64(filter.Limit)),
|
|
||||||
options.Find().SetSkip(int64(filter.Skip)),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return roles, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *RoleRepository) GetRolesByClientID(ctx context.Context, clientID string) ([]*entity.Role, error) {
|
|
||||||
status := permission.StatusActive
|
|
||||||
filter := repository.RoleFilter{
|
|
||||||
ClientID: clientID,
|
|
||||||
Status: &status,
|
|
||||||
}
|
|
||||||
|
|
||||||
return repo.List(ctx, filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Index20241226001UP 創建索引
|
|
||||||
func (repo *RoleRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
|
|
||||||
// 等價於 db.account.createIndex({ "login_id": 1, "platform": 1}, {unique: true})
|
|
||||||
repo.DB.PopulateMultiIndex(ctx, []string{
|
|
||||||
"client_id",
|
|
||||||
"name",
|
|
||||||
}, []int32{1, 1}, true)
|
|
||||||
|
|
||||||
// 等價於 db.account.createIndex({"create_at": 1})
|
|
||||||
repo.DB.PopulateIndex(ctx, "uid", 1, true)
|
|
||||||
repo.DB.PopulateIndex(ctx, "status", 1, false)
|
|
||||||
|
|
||||||
return repo.DB.GetClient().Indexes().List(ctx)
|
|
||||||
}
|
|
||||||
|
|
@ -1,145 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
"context"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Token Repository Implementation
|
|
||||||
|
|
||||||
type TokenRepositoryParam struct {
|
|
||||||
Redis *redis.Redis
|
|
||||||
}
|
|
||||||
|
|
||||||
type TokenRepository struct {
|
|
||||||
Redis *redis.Redis
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTokenRepository 創建令牌倉庫實例
|
|
||||||
func NewTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
|
|
||||||
return &TokenRepository{
|
|
||||||
Redis: param.Redis,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *TokenRepository) Create(ctx context.Context, token *entity.Token) error {
|
|
||||||
// 驗證數據
|
|
||||||
if err := token.Validate(); err != nil {
|
|
||||||
return errs.InvalidFormat(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
token.CreateTime = time.Now()
|
|
||||||
token.UpdateTime = time.Now()
|
|
||||||
|
|
||||||
// 在 Redis 中存儲 access token
|
|
||||||
accessKey := "token:access:" + token.AccessToken
|
|
||||||
refreshKey := "token:refresh:" + token.RefreshToken
|
|
||||||
|
|
||||||
// 設置過期時間
|
|
||||||
expiry := int(time.Until(token.ExpiresAt).Seconds())
|
|
||||||
if expiry <= 0 {
|
|
||||||
return errs.InvalidFormat("token already expired")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 存儲 access token
|
|
||||||
err := r.Redis.SetexCtx(ctx, accessKey, token.UID+":"+token.ClientID+":"+token.DeviceID, expiry)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 存儲 refresh token (較長的過期時間)
|
|
||||||
refreshExpiry := expiry * 7 // refresh token 過期時間是 access token 的 7 倍
|
|
||||||
err = r.Redis.SetexCtx(ctx, refreshKey, token.UID+":"+token.ClientID+":"+token.DeviceID, refreshExpiry)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *TokenRepository) GetByAccessToken(ctx context.Context, accessToken string) (*entity.Token, error) {
|
|
||||||
key := "token:access:" + accessToken
|
|
||||||
value, err := r.Redis.GetCtx(ctx, key)
|
|
||||||
if err != nil {
|
|
||||||
if err == redis.Nil {
|
|
||||||
return nil, errs.NotFound("access_token")
|
|
||||||
}
|
|
||||||
return nil, errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析值
|
|
||||||
parts := strings.Split(value, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return nil, errs.InvalidFormat("invalid token format")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &entity.Token{
|
|
||||||
UID: parts[0],
|
|
||||||
ClientID: parts[1],
|
|
||||||
DeviceID: parts[2],
|
|
||||||
AccessToken: accessToken,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *TokenRepository) GetByRefreshToken(ctx context.Context, refreshToken string) (*entity.Token, error) {
|
|
||||||
key := "token:refresh:" + refreshToken
|
|
||||||
value, err := r.Redis.GetCtx(ctx, key)
|
|
||||||
if err != nil {
|
|
||||||
if err == redis.Nil {
|
|
||||||
return nil, errs.NotFound("refresh_token")
|
|
||||||
}
|
|
||||||
return nil, errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析值
|
|
||||||
parts := strings.Split(value, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
return nil, errs.InvalidFormat("invalid token format")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &entity.Token{
|
|
||||||
UID: parts[0],
|
|
||||||
ClientID: parts[1],
|
|
||||||
DeviceID: parts[2],
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *TokenRepository) Update(ctx context.Context, token *entity.Token) error {
|
|
||||||
// 驗證數據
|
|
||||||
if err := token.Validate(); err != nil {
|
|
||||||
return errs.InvalidFormat(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
token.UpdateTime = time.Now()
|
|
||||||
|
|
||||||
// 重新存儲 access token
|
|
||||||
accessKey := "token:access:" + token.AccessToken
|
|
||||||
expiry := int(time.Until(token.ExpiresAt).Seconds())
|
|
||||||
if expiry <= 0 {
|
|
||||||
return errs.InvalidFormat("token already expired")
|
|
||||||
}
|
|
||||||
|
|
||||||
err := r.Redis.SetexCtx(ctx, accessKey, token.UID+":"+token.ClientID+":"+token.DeviceID, expiry)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DatabaseErr(err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *TokenRepository) Delete(ctx context.Context, id bson.ObjectID) error {
|
|
||||||
// Redis 版本不需要 ObjectID,這裡留空實現
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *TokenRepository) DeleteByUserID(ctx context.Context, uid string) error {
|
|
||||||
// 可以實現刪除用戶所有 token 的邏輯
|
|
||||||
// 這裡簡化實現
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,382 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupMiniRedis() (*miniredis.Miniredis, *redis.Redis) {
|
||||||
|
// 啟動 setupMiniRedis 作為模擬的 Redis 服務
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
if err != nil {
|
||||||
|
panic("failed to start miniRedis: " + err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 setupMiniRedis 的地址配置 go-zero Redis 客戶端
|
||||||
|
redisConf := redis.RedisConf{
|
||||||
|
Host: mr.Addr(),
|
||||||
|
Type: "node",
|
||||||
|
}
|
||||||
|
r := redis.MustNewRedis(redisConf)
|
||||||
|
|
||||||
|
return mr, r
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRepository_Blacklist(t *testing.T) {
|
||||||
|
mr, r := setupMiniRedis()
|
||||||
|
defer mr.Close()
|
||||||
|
|
||||||
|
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("AddToBlacklist", func(t *testing.T) {
|
||||||
|
entry := &entity.BlacklistEntry{
|
||||||
|
JTI: "test-jti-123",
|
||||||
|
UID: "user123",
|
||||||
|
TokenID: "token123",
|
||||||
|
Reason: "user logout",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := repo.AddToBlacklist(ctx, entry, time.Hour)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify it was added
|
||||||
|
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, isBlacklisted)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("IsBlacklisted - not found", func(t *testing.T) {
|
||||||
|
isBlacklisted, err := repo.IsBlacklisted(ctx, "non-existent-jti")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, isBlacklisted)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("RemoveFromBlacklist", func(t *testing.T) {
|
||||||
|
// First add an entry
|
||||||
|
entry := &entity.BlacklistEntry{
|
||||||
|
JTI: "test-jti-456",
|
||||||
|
UID: "user456",
|
||||||
|
TokenID: "token456",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
err := repo.AddToBlacklist(ctx, entry, time.Hour)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify it exists
|
||||||
|
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, isBlacklisted)
|
||||||
|
|
||||||
|
// Remove it
|
||||||
|
err = repo.RemoveFromBlacklist(ctx, entry.JTI)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify it's gone
|
||||||
|
isBlacklisted, err = repo.IsBlacklisted(ctx, entry.JTI)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, isBlacklisted)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetBlacklistedTokensByUID", func(t *testing.T) {
|
||||||
|
uid := "user789"
|
||||||
|
|
||||||
|
// Add multiple entries for the same user
|
||||||
|
entries := []*entity.BlacklistEntry{
|
||||||
|
{
|
||||||
|
JTI: "jti-1",
|
||||||
|
UID: uid,
|
||||||
|
TokenID: "token-1",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
JTI: "jti-2",
|
||||||
|
UID: uid,
|
||||||
|
TokenID: "token-2",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
JTI: "jti-3",
|
||||||
|
UID: "different-user",
|
||||||
|
TokenID: "token-3",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
err := repo.AddToBlacklist(ctx, entry, time.Hour)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get blacklisted tokens for the user
|
||||||
|
userEntries, err := repo.GetBlacklistedTokensByUID(ctx, uid)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, userEntries, 2) // Should only get entries for the specific user
|
||||||
|
|
||||||
|
// Verify all returned entries belong to the correct user
|
||||||
|
for _, entry := range userEntries {
|
||||||
|
assert.Equal(t, uid, entry.UID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AddToBlacklist with zero TTL", func(t *testing.T) {
|
||||||
|
entry := &entity.BlacklistEntry{
|
||||||
|
JTI: "test-jti-zero-ttl",
|
||||||
|
UID: "user-zero-ttl",
|
||||||
|
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with zero TTL - should calculate from ExpiresAt
|
||||||
|
err := repo.AddToBlacklist(ctx, entry, 0)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify it was added
|
||||||
|
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, isBlacklisted)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("AddToBlacklist with expired token", func(t *testing.T) {
|
||||||
|
entry := &entity.BlacklistEntry{
|
||||||
|
JTI: "test-jti-expired",
|
||||||
|
UID: "user-expired",
|
||||||
|
ExpiresAt: time.Now().Add(-time.Hour).Unix(), // Already expired
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should not add expired token to blacklist
|
||||||
|
err := repo.AddToBlacklist(ctx, entry, 0)
|
||||||
|
assert.NoError(t, err) // No error, but token won't be added
|
||||||
|
|
||||||
|
// Verify it was not added
|
||||||
|
isBlacklisted, err := repo.IsBlacklisted(ctx, entry.JTI)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.False(t, isBlacklisted)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRepository_CreateAndGet(t *testing.T) {
|
||||||
|
mr, r := setupMiniRedis()
|
||||||
|
defer mr.Close()
|
||||||
|
|
||||||
|
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("Create and GetAccessTokenByID", func(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
token := entity.Token{
|
||||||
|
ID: "test-token-123",
|
||||||
|
UID: "user123",
|
||||||
|
DeviceID: "device123",
|
||||||
|
AccessToken: "access-token-123",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
AccessCreateAt: now,
|
||||||
|
RefreshToken: "refresh-token-123",
|
||||||
|
RefreshCreateAt: now,
|
||||||
|
RefreshExpiresIn: 86400,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token
|
||||||
|
err := repo.Create(ctx, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Get token by ID
|
||||||
|
retrievedToken, err := repo.GetAccessTokenByID(ctx, token.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, token.ID, retrievedToken.ID)
|
||||||
|
assert.Equal(t, token.UID, retrievedToken.UID)
|
||||||
|
assert.Equal(t, token.AccessToken, retrievedToken.AccessToken)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetAccessTokensByUID", func(t *testing.T) {
|
||||||
|
uid := "user456"
|
||||||
|
now := time.Now()
|
||||||
|
tokens := []entity.Token{
|
||||||
|
{
|
||||||
|
ID: "token-1",
|
||||||
|
UID: uid,
|
||||||
|
DeviceID: "device1",
|
||||||
|
AccessToken: "access-1",
|
||||||
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
||||||
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "token-2",
|
||||||
|
UID: uid,
|
||||||
|
DeviceID: "device2",
|
||||||
|
AccessToken: "access-2",
|
||||||
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
||||||
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create tokens
|
||||||
|
for _, token := range tokens {
|
||||||
|
err := repo.Create(ctx, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get tokens by UID
|
||||||
|
retrievedTokens, err := repo.GetAccessTokensByUID(ctx, uid)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, retrievedTokens, 2)
|
||||||
|
|
||||||
|
// Verify all tokens belong to the user
|
||||||
|
for _, token := range retrievedTokens {
|
||||||
|
assert.Equal(t, uid, token.UID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("GetAccessTokenCountByUID", func(t *testing.T) {
|
||||||
|
uid := "user789"
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Create multiple tokens for the user
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
token := entity.Token{
|
||||||
|
ID: "count-token-" + string(rune(i+'1')),
|
||||||
|
UID: uid,
|
||||||
|
DeviceID: "device" + string(rune(i+'1')),
|
||||||
|
AccessToken: "access-" + string(rune(i+'1')),
|
||||||
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
||||||
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
err := repo.Create(ctx, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get count
|
||||||
|
count, err := repo.GetAccessTokenCountByUID(ctx, uid)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 3, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Delete", func(t *testing.T) {
|
||||||
|
token := entity.Token{
|
||||||
|
ID: "delete-token",
|
||||||
|
UID: "delete-user",
|
||||||
|
DeviceID: "delete-device",
|
||||||
|
AccessToken: "delete-access",
|
||||||
|
RefreshToken: "delete-refresh",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token
|
||||||
|
err := repo.Create(ctx, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify it exists
|
||||||
|
_, err = repo.GetAccessTokenByID(ctx, token.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Delete token
|
||||||
|
err = repo.Delete(ctx, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify it's gone
|
||||||
|
_, err = repo.GetAccessTokenByID(ctx, token.ID)
|
||||||
|
assert.Error(t, err) // Should return error when not found
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("DeleteAccessTokensByUID", func(t *testing.T) {
|
||||||
|
uid := "delete-user-uid"
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Create multiple tokens for the user
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
token := entity.Token{
|
||||||
|
ID: "delete-uid-token-" + string(rune(i+'1')),
|
||||||
|
UID: uid,
|
||||||
|
DeviceID: "device" + string(rune(i+'1')),
|
||||||
|
AccessToken: "access-" + string(rune(i+'1')),
|
||||||
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
||||||
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
err := repo.Create(ctx, token)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify tokens exist
|
||||||
|
count, err := repo.GetAccessTokenCountByUID(ctx, uid)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
|
||||||
|
// Delete all tokens for the user
|
||||||
|
err = repo.DeleteAccessTokensByUID(ctx, uid)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify tokens are gone
|
||||||
|
count, err = repo.GetAccessTokenCountByUID(ctx, uid)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenRepository_OneTimeToken(t *testing.T) {
|
||||||
|
mr, r := setupMiniRedis()
|
||||||
|
defer mr.Close()
|
||||||
|
|
||||||
|
repo := &TokenRepository{TokenRepositoryParam: TokenRepositoryParam{Redis: r}}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("CreateOneTimeToken", func(t *testing.T) {
|
||||||
|
now := time.Now()
|
||||||
|
// Create one-time token with ticket
|
||||||
|
token := entity.Token{
|
||||||
|
ID: "one-time-base-token",
|
||||||
|
UID: "user123",
|
||||||
|
AccessToken: "base-access-token",
|
||||||
|
ExpiresIn: int(now.Add(time.Hour).Unix()),
|
||||||
|
RefreshExpiresIn: int(now.Add(24 * time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
|
||||||
|
oneTimeKey := "one-time-key-123"
|
||||||
|
ticket := entity.Ticket{
|
||||||
|
Data: map[string]string{"uid": "user123"},
|
||||||
|
Token: token,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := repo.CreateOneTimeToken(ctx, oneTimeKey, ticket, time.Minute)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
t.Run("DeleteOneTimeToken", func(t *testing.T) {
|
||||||
|
// Create one-time tokens
|
||||||
|
keys := []string{"delete-key-1", "delete-key-2"}
|
||||||
|
ticket := entity.Ticket{
|
||||||
|
Data: map[string]string{"test": "data"},
|
||||||
|
Token: entity.Token{ID: "test-token"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
err := repo.CreateOneTimeToken(ctx, key, ticket, time.Minute)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete one-time tokens
|
||||||
|
err := repo.DeleteOneTimeToken(ctx, keys, nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify they're gone
|
||||||
|
for _, key := range keys {
|
||||||
|
_, err := repo.GetAccessTokenByOneTimeToken(ctx, key)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,430 @@
|
||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
"backend/pkg/permission/domain/repository"
|
||||||
|
"backend/pkg/permission/domain/token"
|
||||||
|
|
||||||
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TokenRepositoryParam token 需要的參數
|
||||||
|
type TokenRepositoryParam struct {
|
||||||
|
Redis *redis.Redis
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRepository 通知
|
||||||
|
type TokenRepository struct {
|
||||||
|
TokenRepositoryParam
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustTokenRepository(param TokenRepositoryParam) repository.TokenRepository {
|
||||||
|
return &TokenRepository{
|
||||||
|
param,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create 創建一個新 Token,並將其存儲於 Redis
|
||||||
|
func (repo *TokenRepository) Create(ctx context.Context, token entity.Token) error {
|
||||||
|
body, err := json.Marshal(token)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshTTL := time.Duration(token.RedisRefreshExpiredSec()) * time.Second
|
||||||
|
|
||||||
|
return repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
|
||||||
|
if err := repo.setToken(ctx, tx, token, body, refreshTTL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := repo.setRefreshToken(ctx, tx, token, refreshTTL); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return repo.setRelation(ctx, tx, token.UID, token.DeviceID, token.ID, refreshTTL)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) CreateOneTimeToken(ctx context.Context, key string, ticket entity.Ticket, dt time.Duration) error {
|
||||||
|
body, err := json.Marshal(ticket)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = repo.Redis.SetnxExCtx(ctx, token.RefreshTokenRedisKey(key), string(body), int(dt.Seconds()))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) GetAccessTokenByOneTimeToken(ctx context.Context, oneTimeToken string) (entity.Token, error) {
|
||||||
|
id, err := repo.Redis.Get(token.RefreshTokenRedisKey(oneTimeToken))
|
||||||
|
if err != nil {
|
||||||
|
return entity.Token{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if id == "" {
|
||||||
|
return entity.Token{}, fmt.Errorf("token not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
return repo.GetAccessTokenByID(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) GetAccessTokenByID(ctx context.Context, id string) (entity.Token, error) {
|
||||||
|
return repo.get(ctx, token.GetAccessTokenRedisKey(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) GetAccessTokensByUID(ctx context.Context, uid string) ([]entity.Token, error) {
|
||||||
|
return repo.getTokensBySet(ctx, token.GetUIDTokenRedisKey(uid))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) GetAccessTokenCountByUID(ctx context.Context, uid string) (int, error) {
|
||||||
|
return repo.getCountBySet(ctx, token.UIDTokenRedisKey(uid))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) GetAccessTokensByDeviceID(ctx context.Context, deviceID string) ([]entity.Token, error) {
|
||||||
|
return repo.getTokensBySet(ctx, token.DeviceTokenRedisKey(deviceID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) GetAccessTokenCountByDeviceID(ctx context.Context, deviceID string) (int, error) {
|
||||||
|
return repo.getCountBySet(ctx, token.DeviceTokenRedisKey(deviceID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) Delete(ctx context.Context, tokenObj entity.Token) error {
|
||||||
|
// Delete 刪除指定的 Token
|
||||||
|
keys := []string{
|
||||||
|
token.GetAccessTokenRedisKey(tokenObj.ID),
|
||||||
|
token.RefreshTokenRedisKey(tokenObj.RefreshToken),
|
||||||
|
}
|
||||||
|
|
||||||
|
return repo.deleteKeysAndRelations(ctx, keys, tokenObj.UID, tokenObj.DeviceID, tokenObj.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) DeleteOneTimeToken(ctx context.Context, ids []string, tokens []entity.Token) error {
|
||||||
|
l := len(ids) + len(tokens)
|
||||||
|
keys := make([]string, 0, l)
|
||||||
|
|
||||||
|
for _, id := range ids {
|
||||||
|
keys = append(keys, token.RefreshTokenRedisKey(id))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tokenObj := range tokens {
|
||||||
|
keys = append(keys, token.RefreshTokenRedisKey(tokenObj.RefreshToken))
|
||||||
|
}
|
||||||
|
|
||||||
|
return repo.deleteKeys(ctx, keys...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) DeleteAccessTokenByID(ctx context.Context, ids []string) error {
|
||||||
|
for _, tokenID := range ids {
|
||||||
|
tokenObj, err := repo.GetAccessTokenByID(ctx, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
keys := []string{
|
||||||
|
token.GetAccessTokenRedisKey(tokenObj.ID),
|
||||||
|
token.RefreshTokenRedisKey(tokenObj.RefreshToken),
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = repo.deleteKeysAndRelations(ctx, keys, tokenObj.UID, tokenObj.DeviceID, tokenObj.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) DeleteAccessTokensByUID(ctx context.Context, uid string) error {
|
||||||
|
tokens, err := repo.GetAccessTokensByUID(ctx, uid)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, token := range tokens {
|
||||||
|
if err := repo.Delete(ctx, token); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) DeleteAccessTokensByDeviceID(ctx context.Context, deviceID string) error {
|
||||||
|
tokens, err := repo.GetAccessTokensByDeviceID(ctx, deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
l := len(tokens) * 2
|
||||||
|
keys := make([]string, 0, l)
|
||||||
|
for _, tokenObj := range tokens {
|
||||||
|
keys = append(keys, token.GetAccessTokenRedisKey(tokenObj.ID))
|
||||||
|
keys = append(keys, token.RefreshTokenRedisKey(tokenObj.RefreshToken))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = repo.runPipeline(ctx, func(tx redis.Pipeliner) error {
|
||||||
|
for _, tokenObj := range tokens {
|
||||||
|
tx.SRem(ctx, token.UIDTokenRedisKey(tokenObj.UID), tokenObj.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := repo.deleteKeys(ctx, keys...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = repo.Redis.Del(token.DeviceTokenRedisKey(deviceID))
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// deleteKeysAndRelations 刪除指定鍵並移除相關的關聯
|
||||||
|
func (repo *TokenRepository) deleteKeysAndRelations(ctx context.Context, keys []string, uid, deviceID, tokenID string) error {
|
||||||
|
err := repo.Redis.Pipelined(func(tx redis.Pipeliner) error {
|
||||||
|
// 刪除 UID 和 DeviceID 的關聯
|
||||||
|
_ = tx.SRem(ctx, token.UIDTokenRedisKey(uid), tokenID)
|
||||||
|
_ = tx.SRem(ctx, token.DeviceTokenRedisKey(deviceID), tokenID)
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
_ = tx.Del(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// runPipeline 執行 Redis 的 Pipeline 操作
|
||||||
|
func (repo *TokenRepository) runPipeline(ctx context.Context, fn func(tx redis.Pipeliner) error) error {
|
||||||
|
if err := repo.Redis.PipelinedCtx(ctx, fn); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// deleteKeys 批量刪除 Redis 鍵
|
||||||
|
func (repo *TokenRepository) deleteKeys(ctx context.Context, keys ...string) error {
|
||||||
|
return repo.Redis.Pipelined(func(tx redis.Pipeliner) error {
|
||||||
|
for _, key := range keys {
|
||||||
|
if err := tx.Del(ctx, key).Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) setToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, body []byte, ttl time.Duration) error {
|
||||||
|
return tx.Set(ctx, token.GetAccessTokenRedisKey(tokenObj.ID), body, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) setRefreshToken(ctx context.Context, tx redis.Pipeliner, tokenObj entity.Token, ttl time.Duration) error {
|
||||||
|
if tokenObj.RefreshToken != "" {
|
||||||
|
return tx.Set(ctx, token.RefreshTokenRedisKey(tokenObj.RefreshToken), tokenObj.ID, ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (repo *TokenRepository) setRelation(ctx context.Context, tx redis.Pipeliner, uid, deviceID, tokenID string, ttl time.Duration) error {
|
||||||
|
if err := tx.SAdd(ctx, token.UIDTokenRedisKey(uid), tokenID).Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 設置 UID 鍵的過期時間
|
||||||
|
if err := tx.Expire(ctx, token.UIDTokenRedisKey(uid), ttl).Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.SAdd(ctx, token.DeviceTokenRedisKey(deviceID), tokenID).Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 設置 deviceID 鍵的過期時間
|
||||||
|
if err := tx.Expire(ctx, token.DeviceTokenRedisKey(deviceID), ttl).Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// get 根據鍵獲取 Token
|
||||||
|
func (repo *TokenRepository) get(ctx context.Context, key string) (entity.Token, error) {
|
||||||
|
body, err := repo.Redis.GetCtx(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return entity.Token{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if body == "" {
|
||||||
|
return entity.Token{}, fmt.Errorf("token not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
var token entity.Token
|
||||||
|
if err := json.Unmarshal([]byte(body), &token); err != nil {
|
||||||
|
return entity.Token{}, fmt.Errorf("json.Marshal token error")
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getTokensBySet 根據集合鍵獲取所有 Token
|
||||||
|
func (repo *TokenRepository) getTokensBySet(ctx context.Context, setKey string) ([]entity.Token, error) {
|
||||||
|
ids, err := repo.Redis.Smembers(setKey)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, redis.Nil) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := make([]entity.Token, 0, len(ids))
|
||||||
|
var deleteTokens []string
|
||||||
|
now := time.Now().Unix()
|
||||||
|
for _, id := range ids {
|
||||||
|
token, err := repo.get(ctx, token.GetAccessTokenRedisKey(id))
|
||||||
|
if err != nil {
|
||||||
|
deleteTokens = append(deleteTokens, id)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if int64(token.ExpiresIn) < now {
|
||||||
|
deleteTokens = append(deleteTokens, id)
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tokens = append(tokens, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(deleteTokens) > 0 {
|
||||||
|
_ = repo.DeleteAccessTokenByID(ctx, deleteTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getCountBySet 獲取集合中的元素數量
|
||||||
|
func (repo *TokenRepository) getCountBySet(ctx context.Context, setKey string) (int, error) {
|
||||||
|
count, err := repo.Redis.ScardCtx(ctx, setKey)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return int(count), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddToBlacklist 將 token 加入黑名單
|
||||||
|
func (repo *TokenRepository) AddToBlacklist(ctx context.Context, entry *entity.BlacklistEntry, ttl time.Duration) error {
|
||||||
|
key := token.GetBlacklistRedisKey(entry.JTI)
|
||||||
|
|
||||||
|
// 序列化黑名單條目
|
||||||
|
data, err := json.Marshal(entry)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal blacklist entry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用提供的 TTL,如果 TTL <= 0,則計算默認 TTL
|
||||||
|
if ttl <= 0 {
|
||||||
|
// 計算 TTL (token 過期時間 - 當前時間)
|
||||||
|
ttl = time.Unix(entry.ExpiresAt, 0).Sub(time.Now())
|
||||||
|
if ttl <= 0 {
|
||||||
|
// Token 已經過期,不需要加入黑名單
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 存儲到 Redis 並設置過期時間
|
||||||
|
err = repo.Redis.SetexCtx(ctx, key, string(data), int(ttl.Seconds()))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add token to blacklist: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBlacklisted 檢查 token 是否在黑名單中
|
||||||
|
func (repo *TokenRepository) IsBlacklisted(ctx context.Context, jti string) (bool, error) {
|
||||||
|
key := token.GetBlacklistRedisKey(jti)
|
||||||
|
|
||||||
|
exists, err := repo.Redis.ExistsCtx(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("failed to check blacklist: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return exists, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveFromBlacklist 從黑名單中移除 token
|
||||||
|
func (repo *TokenRepository) RemoveFromBlacklist(ctx context.Context, jti string) error {
|
||||||
|
key := token.GetBlacklistRedisKey(jti)
|
||||||
|
|
||||||
|
_, err := repo.Redis.DelCtx(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove token from blacklist: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlacklistedTokensByUID 獲取用戶的所有黑名單 token
|
||||||
|
func (repo *TokenRepository) GetBlacklistedTokensByUID(ctx context.Context, uid string) ([]*entity.BlacklistEntry, error) {
|
||||||
|
// 使用 SCAN 來查找所有黑名單鍵
|
||||||
|
pattern := token.BlacklistKeyPrefix + "*"
|
||||||
|
|
||||||
|
var entries []*entity.BlacklistEntry
|
||||||
|
var cursor uint64 = 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
keys, nextCursor, err := repo.Redis.ScanCtx(ctx, cursor, pattern, 100)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan blacklist keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 獲取每個鍵的值並檢查 UID
|
||||||
|
for _, key := range keys {
|
||||||
|
data, err := repo.Redis.GetCtx(ctx, key)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, redis.Nil) {
|
||||||
|
continue // 鍵已過期或不存在
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("failed to get blacklist entry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var entry entity.BlacklistEntry
|
||||||
|
if err := json.Unmarshal([]byte(data), &entry); err != nil {
|
||||||
|
continue // 跳過無效的條目
|
||||||
|
}
|
||||||
|
|
||||||
|
// 檢查 UID 是否匹配
|
||||||
|
if entry.UID == uid {
|
||||||
|
entries = append(entries, &entry)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = nextCursor
|
||||||
|
if cursor == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return entries, nil
|
||||||
|
}
|
||||||
|
|
@ -1,238 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs/code"
|
|
||||||
"backend/pkg/permission/domain"
|
|
||||||
"backend/pkg/permission/domain/permission"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/mongo/options"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/library/mongo"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/cache"
|
|
||||||
"github.com/zeromicro/go-zero/core/stores/mon"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
mongodriver "go.mongodb.org/mongo-driver/v2/mongo"
|
|
||||||
)
|
|
||||||
|
|
||||||
type UserRoleRepositoryParam struct {
|
|
||||||
Conf *mongo.Conf
|
|
||||||
CacheConf cache.CacheConf
|
|
||||||
DBOpts []mon.Option
|
|
||||||
CacheOpts []cache.Option
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserRoleRepository struct {
|
|
||||||
DB mongo.DocumentDBWithCacheUseCase
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewUserRoleRepository 創建用戶角色倉庫實例
|
|
||||||
func NewUserRoleRepository(param UserRoleRepositoryParam) repository.UserRoleRepository {
|
|
||||||
e := entity.UserRole{}
|
|
||||||
documentDB, err := mongo.MustDocumentDBWithCache(
|
|
||||||
param.Conf,
|
|
||||||
e.CollectionName(),
|
|
||||||
param.CacheConf,
|
|
||||||
param.DBOpts,
|
|
||||||
param.CacheOpts,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UserRoleRepository{
|
|
||||||
DB: documentDB,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *UserRoleRepository) Create(ctx context.Context, userRole *entity.UserRole) error {
|
|
||||||
now := time.Now()
|
|
||||||
userRole.CreateTime = now
|
|
||||||
userRole.UpdateTime = now
|
|
||||||
id := bson.NewObjectID()
|
|
||||||
userRole.ID = id
|
|
||||||
|
|
||||||
rk := domain.GetUserRoleRedisKey(id.Hex())
|
|
||||||
userRole.CreateTime = time.Now()
|
|
||||||
userRole.UpdateTime = time.Now()
|
|
||||||
|
|
||||||
_, err := repo.DB.InsertOne(ctx, rk, userRole)
|
|
||||||
if err != nil {
|
|
||||||
// 檢查是否為重複鍵錯誤
|
|
||||||
if mongodriver.IsDuplicateKeyError(err) {
|
|
||||||
return errs.ResourceAlreadyExist("failed to insert user role")
|
|
||||||
}
|
|
||||||
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *UserRoleRepository) GetByID(ctx context.Context, id string) (*entity.UserRole, error) {
|
|
||||||
var userRole entity.UserRole
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rk := domain.GetUserRoleRedisKey(id)
|
|
||||||
err = repo.DB.FindOne(ctx, rk, &userRole, bson.M{"_id": objID})
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(
|
|
||||||
code.CloudEPPermission,
|
|
||||||
domain.FailedToGetRoleByID,
|
|
||||||
"failed to get user role by id")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &userRole, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *UserRoleRepository) GetByUserAndRole(ctx context.Context, uid, roleUID string) (*entity.UserRole, error) {
|
|
||||||
filter := bson.M{
|
|
||||||
"uid": uid,
|
|
||||||
"role_uid": roleUID,
|
|
||||||
"status": permission.StatusActive,
|
|
||||||
}
|
|
||||||
|
|
||||||
var userRole entity.UserRole
|
|
||||||
err := repo.DB.GetClient().Find(ctx, &userRole, filter)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, mongodriver.ErrNoDocuments) {
|
|
||||||
return nil, errs.ResourceNotFoundWithScope(code.CloudEPPermission, 0, "failed to get user and role")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, errs.DatabaseErrorWithScope(code.CloudEPPermission, 0, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return &userRole, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *UserRoleRepository) Update(ctx context.Context, id string, userRole *entity.UserRole) error {
|
|
||||||
userRole.UpdateTime = time.Now()
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
update := bson.M{
|
|
||||||
"$set": bson.M{
|
|
||||||
"status": userRole.Status,
|
|
||||||
"update_time": userRole.UpdateTime,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
rk := domain.GetUserRoleRedisKey(id)
|
|
||||||
|
|
||||||
_, err = repo.DB.UpdateOne(ctx, rk, bson.M{"_id": objID}, update)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *UserRoleRepository) Delete(ctx context.Context, id string) error {
|
|
||||||
objID, err := bson.ObjectIDFromHex(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rk := domain.GetUserRoleRedisKey(id)
|
|
||||||
_, err = repo.DB.DeleteOne(ctx, rk, bson.M{"_id": objID})
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *UserRoleRepository) List(ctx context.Context, filter repository.UserRoleFilter) ([]*entity.UserRole, error) {
|
|
||||||
query := bson.M{}
|
|
||||||
if filter.Brand != "" {
|
|
||||||
query["brand"] = filter.Brand
|
|
||||||
}
|
|
||||||
if filter.UID != "" {
|
|
||||||
query["uid"] = filter.UID
|
|
||||||
}
|
|
||||||
if filter.RoleUID != "" {
|
|
||||||
query["role_uid"] = filter.RoleUID
|
|
||||||
}
|
|
||||||
if filter.Status != nil {
|
|
||||||
query["status"] = *filter.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
var userRoles []*entity.UserRole
|
|
||||||
err := repo.DB.GetClient().Find(ctx, &userRoles, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
err = repo.DB.GetClient().Find(ctx,
|
|
||||||
&userRoles, query,
|
|
||||||
options.Find().SetLimit(int64(filter.Limit)),
|
|
||||||
options.Find().SetSkip(int64(filter.Skip)))
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return userRoles, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *UserRoleRepository) GetUserRolesByUID(ctx context.Context, uid string) ([]*entity.UserRole, error) {
|
|
||||||
status := permission.StatusActive
|
|
||||||
filter := repository.UserRoleFilter{
|
|
||||||
UID: uid,
|
|
||||||
Status: &status,
|
|
||||||
}
|
|
||||||
|
|
||||||
return repo.List(ctx, filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (repo *UserRoleRepository) DeleteByUserAndRole(ctx context.Context, uid, roleUID string) error {
|
|
||||||
filter := repository.UserRoleFilter{
|
|
||||||
UID: uid,
|
|
||||||
RoleUID: roleUID,
|
|
||||||
}
|
|
||||||
list, err := repo.List(ctx, filter)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if len(list) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, item := range list {
|
|
||||||
_ = repo.DB.DelCache(ctx, domain.GetUserRoleRedisKey(item.ID.Hex()))
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = repo.DB.GetClient().DeleteMany(ctx, filter)
|
|
||||||
if err != nil {
|
|
||||||
return errs.DBErrorWithScope(code.CloudEPPermission, err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Index20241226001UP 創建索引
|
|
||||||
func (repo *UserRoleRepository) Index20241226001UP(ctx context.Context) (*mongodriver.Cursor, error) {
|
|
||||||
// 等價於 db.account.createIndex({ "login_id": 1, "platform": 1}, {unique: true})
|
|
||||||
repo.DB.PopulateMultiIndex(ctx, []string{
|
|
||||||
"uid",
|
|
||||||
"role_uid",
|
|
||||||
}, []int32{1, 1}, true)
|
|
||||||
|
|
||||||
// 等價於 db.account.createIndex({"create_at": 1})
|
|
||||||
repo.DB.PopulateIndex(ctx, "uid", 1, false)
|
|
||||||
repo.DB.PopulateIndex(ctx, "status", 1, false)
|
|
||||||
|
|
||||||
return repo.DB.GetClient().Indexes().List(ctx)
|
|
||||||
}
|
|
||||||
|
|
@ -1,207 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs/code"
|
|
||||||
"backend/pkg/permission/utils"
|
|
||||||
"context"
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/permission/domain/config"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
"backend/pkg/permission/domain/usecase"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthUseCaseParam struct {
|
|
||||||
ClientRepo repository.ClientRepository
|
|
||||||
TokenRepo repository.TokenRepository
|
|
||||||
JWTConfig config.JWTConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthUseCase struct {
|
|
||||||
clientRepo repository.ClientRepository
|
|
||||||
tokenRepo repository.TokenRepository
|
|
||||||
jwtConfig config.JWTConfig
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustAuthUseCase 創建認證用例實例
|
|
||||||
func MustAuthUseCase(param AuthUseCaseParam) usecase.AuthUseCase {
|
|
||||||
return &AuthUseCase{
|
|
||||||
clientRepo: param.ClientRepo,
|
|
||||||
tokenRepo: param.TokenRepo,
|
|
||||||
jwtConfig: param.JWTConfig,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *AuthUseCase) CreateToken(ctx context.Context, req usecase.CreateTokenRequest) (*usecase.TokenResponse, error) {
|
|
||||||
// 驗證客戶端
|
|
||||||
client, err := uc.clientRepo.GetByClientID(ctx, req.ClientID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !utils.IsActive(client.Status) {
|
|
||||||
return nil, errs.UserSuspended(code.CloudEPPermission, "failed to get token since user has been suspended")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根據授權類型處理
|
|
||||||
var uid string
|
|
||||||
switch req.GrantType {
|
|
||||||
case "client_credentials":
|
|
||||||
uid = "client_" + req.ClientID
|
|
||||||
case "password":
|
|
||||||
if req.Username == "" || req.Password == "" {
|
|
||||||
return nil, errs.InvalidCredentials()
|
|
||||||
}
|
|
||||||
// 這裡應該驗證用戶名密碼,簡化處理
|
|
||||||
uid = req.Username
|
|
||||||
default:
|
|
||||||
return nil, errs.InvalidFormat("unsupported grant type: " + req.GrantType)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 生成令牌
|
|
||||||
accessToken, err := uc.generateAccessToken(uid, req.ClientID, req.DeviceID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.SystemInternal("failed to generate access token: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
refreshToken, err := uc.generateRefreshToken()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.SystemInternal("failed to generate refresh token: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存令牌
|
|
||||||
token := &entity.Token{
|
|
||||||
UID: uid,
|
|
||||||
ClientID: req.ClientID,
|
|
||||||
AccessToken: accessToken,
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
DeviceID: req.DeviceID,
|
|
||||||
ExpiresAt: time.Now().Add(uc.jwtConfig.AccessExpires),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.tokenRepo.Create(ctx, token); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &usecase.TokenResponse{
|
|
||||||
AccessToken: accessToken,
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
TokenType: "Bearer",
|
|
||||||
ExpiresIn: int64(uc.jwtConfig.AccessExpires.Seconds()),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *AuthUseCase) RefreshToken(ctx context.Context, refreshToken string) (*usecase.TokenResponse, error) {
|
|
||||||
// 查找刷新令牌
|
|
||||||
token, err := uc.tokenRepo.GetByRefreshToken(ctx, refreshToken)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if token.IsExpired() {
|
|
||||||
return nil, errs.TokenExpired()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 生成新的訪問令牌
|
|
||||||
accessToken, err := uc.generateAccessToken(token.UID, token.ClientID, token.DeviceID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.SystemInternal("failed to generate access token: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新令牌
|
|
||||||
token.AccessToken = accessToken
|
|
||||||
token.ExpiresAt = time.Now().Add(uc.jwtConfig.AccessExpires)
|
|
||||||
|
|
||||||
if err := uc.tokenRepo.Update(ctx, token); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &usecase.TokenResponse{
|
|
||||||
AccessToken: accessToken,
|
|
||||||
RefreshToken: refreshToken,
|
|
||||||
TokenType: "Bearer",
|
|
||||||
ExpiresIn: int64(uc.jwtConfig.AccessExpires.Seconds()),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *AuthUseCase) ValidateToken(ctx context.Context, accessToken string) (*usecase.TokenClaims, error) {
|
|
||||||
// 解析JWT令牌
|
|
||||||
token, err := jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
||||||
return nil, errs.TokenInvalid()
|
|
||||||
}
|
|
||||||
return []byte(uc.jwtConfig.Secret), nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.TokenInvalid()
|
|
||||||
}
|
|
||||||
|
|
||||||
if !token.Valid {
|
|
||||||
return nil, errs.TokenInvalid()
|
|
||||||
}
|
|
||||||
|
|
||||||
claims, ok := token.Claims.(jwt.MapClaims)
|
|
||||||
if !ok {
|
|
||||||
return nil, errs.TokenInvalid()
|
|
||||||
}
|
|
||||||
|
|
||||||
uid, ok := claims["uid"].(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, errs.TokenInvalid()
|
|
||||||
}
|
|
||||||
|
|
||||||
clientID, ok := claims["client_id"].(string)
|
|
||||||
if !ok {
|
|
||||||
return nil, errs.TokenInvalid()
|
|
||||||
}
|
|
||||||
|
|
||||||
deviceID, _ := claims["device_id"].(string)
|
|
||||||
|
|
||||||
return &usecase.TokenClaims{
|
|
||||||
UID: uid,
|
|
||||||
ClientID: clientID,
|
|
||||||
DeviceID: deviceID,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *AuthUseCase) Logout(ctx context.Context, accessToken string) error {
|
|
||||||
// 查找並刪除令牌
|
|
||||||
token, err := uc.tokenRepo.GetByAccessToken(ctx, accessToken)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.tokenRepo.Delete(ctx, token.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *AuthUseCase) LogoutAllByUserID(ctx context.Context, uid string) error {
|
|
||||||
return uc.tokenRepo.DeleteByUserID(ctx, uid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *AuthUseCase) generateAccessToken(uid, clientID, deviceID string) (string, error) {
|
|
||||||
claims := jwt.MapClaims{
|
|
||||||
"uid": uid,
|
|
||||||
"client_id": clientID,
|
|
||||||
"device_id": deviceID,
|
|
||||||
"exp": time.Now().Add(uc.jwtConfig.AccessExpires).Unix(),
|
|
||||||
"iat": time.Now().Unix(),
|
|
||||||
}
|
|
||||||
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
||||||
return token.SignedString([]byte(uc.jwtConfig.Secret))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *AuthUseCase) generateRefreshToken() (string, error) {
|
|
||||||
bytes := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(bytes); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return hex.EncodeToString(bytes), nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,348 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs/code"
|
|
||||||
"context"
|
|
||||||
"github.com/zeromicro/go-zero/core/logx"
|
|
||||||
"go.mongodb.org/mongo-driver/v2/bson"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
"backend/pkg/permission/domain/usecase"
|
|
||||||
|
|
||||||
"github.com/casbin/casbin/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PermissionUseCaseParam struct {
|
|
||||||
Enforcer *casbin.Enforcer
|
|
||||||
PermissionRepo repository.PermissionRepository
|
|
||||||
RoleRepo repository.RoleRepository
|
|
||||||
UserRoleRepo repository.UserRoleRepository
|
|
||||||
}
|
|
||||||
|
|
||||||
type PermissionUseCase struct {
|
|
||||||
enforcer *casbin.Enforcer
|
|
||||||
permissionRepo repository.PermissionRepository
|
|
||||||
roleRepo repository.RoleRepository
|
|
||||||
userRoleRepo repository.UserRoleRepository
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustPermissionUseCase 創建權限用例實例
|
|
||||||
func MustPermissionUseCase(param PermissionUseCaseParam) usecase.PermissionUseCase {
|
|
||||||
return &PermissionUseCase{
|
|
||||||
enforcer: param.Enforcer,
|
|
||||||
permissionRepo: param.PermissionRepo,
|
|
||||||
roleRepo: param.RoleRepo,
|
|
||||||
userRoleRepo: param.UserRoleRepo,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *PermissionUseCase) CreatePermission(ctx context.Context, req usecase.CreatePermissionRequest) (*entity.Permission, error) {
|
|
||||||
// 驗證請求
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errs.InvalidFormat("permission name is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
permission := &entity.Permission{
|
|
||||||
Name: req.Name,
|
|
||||||
HTTPMethod: req.HTTPMethod,
|
|
||||||
HTTPPath: req.HTTPPath,
|
|
||||||
Status: req.Status,
|
|
||||||
Type: req.Type,
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.ParentID != nil {
|
|
||||||
objID, err := bson.ObjectIDFromHex(*req.ParentID)
|
|
||||||
if err != nil {
|
|
||||||
e := errs.InvalidFormat(err.Error())
|
|
||||||
return nil, e
|
|
||||||
}
|
|
||||||
permission.ID = objID
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.permissionRepo.Create(ctx, permission); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return permission, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *PermissionUseCase) GetPermission(ctx context.Context, id string) (*entity.Permission, error) {
|
|
||||||
return uc.permissionRepo.GetByID(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *PermissionUseCase) UpdatePermission(ctx context.Context, req usecase.UpdatePermissionRequest) (*entity.Permission, error) {
|
|
||||||
// 獲取現有權限
|
|
||||||
permission, err := uc.permissionRepo.GetByID(ctx, req.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新字段
|
|
||||||
if req.Name != nil {
|
|
||||||
permission.Name = *req.Name
|
|
||||||
}
|
|
||||||
if req.HTTPMethod != nil {
|
|
||||||
permission.HTTPMethod = *req.HTTPMethod
|
|
||||||
}
|
|
||||||
if req.HTTPPath != nil {
|
|
||||||
permission.HTTPPath = *req.HTTPPath
|
|
||||||
}
|
|
||||||
if req.Status != nil {
|
|
||||||
permission.Status = *req.Status
|
|
||||||
}
|
|
||||||
if req.Type != nil {
|
|
||||||
permission.Type = *req.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.permissionRepo.Update(ctx, req.ID, permission); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return permission, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *PermissionUseCase) DeletePermission(ctx context.Context, id string) error {
|
|
||||||
return uc.permissionRepo.Delete(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *PermissionUseCase) ListPermissions(ctx context.Context, req usecase.ListPermissionsRequest) ([]*entity.Permission, error) {
|
|
||||||
filter := repository.PermissionFilter{
|
|
||||||
Status: req.Status,
|
|
||||||
Type: req.Type,
|
|
||||||
ParentID: req.ParentID,
|
|
||||||
Limit: req.Limit,
|
|
||||||
Skip: req.Skip,
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.permissionRepo.List(ctx, filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckUserPermission 使用 Casbin 檢查用戶權限
|
|
||||||
func (uc *PermissionUseCase) CheckUserPermission(ctx context.Context, uid, httpMethod, httpPath string) (bool, error) {
|
|
||||||
// 使用 Casbin 進行權限檢查
|
|
||||||
// sub: 用戶ID, obj: 資源路徑, act: 行為
|
|
||||||
hasPermission, err := uc.enforcer.Enforce(uid, httpPath, httpMethod)
|
|
||||||
if err != nil {
|
|
||||||
return false, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin enforce failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasPermission {
|
|
||||||
return false, errs.InsufficientPermission(httpMethod + ":" + httpPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckRolePermission 使用 Casbin 檢查角色權限
|
|
||||||
func (uc *PermissionUseCase) CheckRolePermission(ctx context.Context, roleUID, httpMethod, httpPath string) (bool, error) {
|
|
||||||
// 使用 Casbin 進行角色權限檢查
|
|
||||||
hasPermission, err := uc.enforcer.Enforce(roleUID, httpPath, httpMethod)
|
|
||||||
if err != nil {
|
|
||||||
return false, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin enforce failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !hasPermission {
|
|
||||||
return false, errs.InsufficientPermission(httpMethod + ":" + httpPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUserPermissions 獲取用戶的所有權限
|
|
||||||
func (uc *PermissionUseCase) GetUserPermissions(ctx context.Context, uid string) (map[string]int, error) {
|
|
||||||
// 獲取用戶的所有角色
|
|
||||||
roles, err := uc.enforcer.GetRolesForUser(uid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "failed to get user permissions: "+err.Error())
|
|
||||||
}
|
|
||||||
permissions := make(map[string]int)
|
|
||||||
|
|
||||||
// 獲取用戶直接擁有的權限
|
|
||||||
userPolicies, err := uc.enforcer.GetPermissionsForUser(uid)
|
|
||||||
if err != nil {
|
|
||||||
logx.Infof("failed to get user permissions: " + err.Error())
|
|
||||||
}
|
|
||||||
for _, policy := range userPolicies {
|
|
||||||
if len(policy) >= 3 {
|
|
||||||
key := policy[2] + ":" + policy[1] // method:path
|
|
||||||
permissions[key] = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 獲取通過角色繼承的權限
|
|
||||||
for _, role := range roles {
|
|
||||||
rolePolicies, err := uc.enforcer.GetPermissionsForUser(role)
|
|
||||||
if err != nil {
|
|
||||||
logx.Infof("failed to get permissions for user: " + err.Error())
|
|
||||||
}
|
|
||||||
for _, policy := range rolePolicies {
|
|
||||||
if len(policy) >= 3 {
|
|
||||||
key := policy[2] + ":" + policy[1] // method:path
|
|
||||||
permissions[key] = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return permissions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchCheckPermissions 批量檢查權限
|
|
||||||
func (uc *PermissionUseCase) BatchCheckPermissions(ctx context.Context, uid string, permissions []usecase.PermissionCheck) (map[string]bool, error) {
|
|
||||||
results := make(map[string]bool)
|
|
||||||
|
|
||||||
for _, perm := range permissions {
|
|
||||||
key := perm.HTTPMethod + ":" + perm.HTTPPath
|
|
||||||
hasPermission, err := uc.enforcer.Enforce(uid, perm.HTTPPath, perm.HTTPMethod)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin enforce failed: "+err.Error())
|
|
||||||
}
|
|
||||||
results[key] = hasPermission
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPolicyForUser 為用戶添加權限策略
|
|
||||||
func (uc *PermissionUseCase) AddPolicyForUser(ctx context.Context, uid, httpPath, httpMethod string) error {
|
|
||||||
added, err := uc.enforcer.AddPolicy(uid, httpPath, httpMethod)
|
|
||||||
if err != nil {
|
|
||||||
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin add policy failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !added {
|
|
||||||
return errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, "policy already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePolicyForUser 移除用戶的權限策略
|
|
||||||
func (uc *PermissionUseCase) RemovePolicyForUser(ctx context.Context, uid, httpPath, httpMethod string) error {
|
|
||||||
removed, err := uc.enforcer.RemovePolicy(uid, httpPath, httpMethod)
|
|
||||||
if err != nil {
|
|
||||||
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin remove policy failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !removed {
|
|
||||||
return errs.ResourceNotFoundWithScope(code.CloudEPPermission, 0, "policy not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddRoleForUser 為用戶分配角色
|
|
||||||
func (uc *PermissionUseCase) AddRoleForUser(ctx context.Context, uid, roleUID string) error {
|
|
||||||
added, err := uc.enforcer.AddRoleForUser(uid, roleUID)
|
|
||||||
if err != nil {
|
|
||||||
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin add role failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !added {
|
|
||||||
return errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, "role already assigned")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveRoleForUser 移除用戶的角色
|
|
||||||
func (uc *PermissionUseCase) RemoveRoleForUser(ctx context.Context, uid, roleUID string) error {
|
|
||||||
removed, err := uc.enforcer.DeleteRoleForUser(uid, roleUID)
|
|
||||||
if err != nil {
|
|
||||||
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin remove role failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !removed {
|
|
||||||
return errs.ResourceNotFoundWithScope(code.CloudEPPermission, 0, "role assignment not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetUsersForRole 獲取角色下的所有用戶
|
|
||||||
func (uc *PermissionUseCase) GetUsersForRole(ctx context.Context, roleUID string) ([]string, error) {
|
|
||||||
return uc.enforcer.GetUsersForRole(roleUID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetRolesForUser 獲取用戶的所有角色
|
|
||||||
func (uc *PermissionUseCase) GetRolesForUser(ctx context.Context, uid string) ([]string, error) {
|
|
||||||
return uc.enforcer.GetRolesForUser(uid)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPermissionForRole 為角色添加權限
|
|
||||||
func (uc *PermissionUseCase) AddPermissionForRole(ctx context.Context, roleUID, httpPath, httpMethod string) error {
|
|
||||||
added, err := uc.enforcer.AddPolicy(roleUID, httpPath, httpMethod)
|
|
||||||
if err != nil {
|
|
||||||
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin add policy failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !added {
|
|
||||||
return errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, "policy already exists")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePermissionForRole 移除角色的權限
|
|
||||||
func (uc *PermissionUseCase) RemovePermissionForRole(ctx context.Context, roleUID, httpPath, httpMethod string) error {
|
|
||||||
removed, err := uc.enforcer.RemovePolicy(roleUID, httpPath, httpMethod)
|
|
||||||
if err != nil {
|
|
||||||
return errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin remove policy failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !removed {
|
|
||||||
return errs.ResourceNotFoundWithScope(code.CloudEPPermission, 0, "policy not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPermissionsForRole 獲取角色的所有權限
|
|
||||||
func (uc *PermissionUseCase) GetPermissionsForRole(ctx context.Context, roleUID string) (map[string]int, error) {
|
|
||||||
policies, err := uc.enforcer.GetPermissionsForUser(roleUID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin get permissions failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
permissions := make(map[string]int)
|
|
||||||
|
|
||||||
for _, policy := range policies {
|
|
||||||
if len(policy) >= 3 {
|
|
||||||
key := policy[2] + ":" + policy[1] // method:path
|
|
||||||
permissions[key] = 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return permissions, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckPatternPermission 檢查模式權限 (支援通配符)
|
|
||||||
func (uc *PermissionUseCase) CheckPatternPermission(ctx context.Context, uid, pattern, action string) (bool, error) {
|
|
||||||
hasPermission, err := uc.enforcer.Enforce(uid, pattern, action)
|
|
||||||
if err != nil {
|
|
||||||
return false, errs.SystemInternalErrorScope(code.CloudEPPermission, "casbin enforce failed: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return hasPermission, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAllPolicies 獲取所有策略
|
|
||||||
func (uc *PermissionUseCase) GetAllPolicies(ctx context.Context) ([][]string, error) {
|
|
||||||
policies, err := uc.enforcer.GetPolicy()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "failed to get all policies: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return policies, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFilteredPolicies 獲取過濾後的策略
|
|
||||||
func (uc *PermissionUseCase) GetFilteredPolicies(ctx context.Context, fieldIndex int, fieldValues ...string) ([][]string, error) {
|
|
||||||
policies, err := uc.enforcer.GetFilteredPolicy(fieldIndex, fieldValues...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.SystemInternalErrorScope(code.CloudEPPermission, "failed to get filtered policies: "+err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return policies, nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,189 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs/code"
|
|
||||||
"backend/pkg/permission/domain/permission"
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
"backend/pkg/permission/domain/usecase"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RoleUseCaseParam struct {
|
|
||||||
RoleRepo repository.RoleRepository
|
|
||||||
UserRoleRepo repository.UserRoleRepository
|
|
||||||
}
|
|
||||||
|
|
||||||
type RoleUseCase struct {
|
|
||||||
roleRepo repository.RoleRepository
|
|
||||||
userRoleRepo repository.UserRoleRepository
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustRoleUseCase 創建角色用例實例
|
|
||||||
func MustRoleUseCase(param RoleUseCaseParam) usecase.RoleUseCase {
|
|
||||||
return &RoleUseCase{
|
|
||||||
roleRepo: param.RoleRepo,
|
|
||||||
userRoleRepo: param.UserRoleRepo,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) CreateRole(ctx context.Context, req usecase.CreateRoleRequest) (*entity.Role, error) {
|
|
||||||
// 驗證請求
|
|
||||||
if req.ClientID == "" {
|
|
||||||
return nil, errs.InvalidFormat("client_id is required")
|
|
||||||
}
|
|
||||||
if req.Name == "" {
|
|
||||||
return nil, errs.InvalidFormat("role name is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查角色名稱是否已存在
|
|
||||||
existingRole, err := uc.roleRepo.GetByClientAndName(ctx, req.ClientID, req.Name)
|
|
||||||
if err == nil && existingRole != nil {
|
|
||||||
return nil, errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, req.ClientID+":"+req.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
role := &entity.Role{
|
|
||||||
ClientID: req.ClientID,
|
|
||||||
UID: req.UID,
|
|
||||||
Name: req.Name,
|
|
||||||
Status: req.Status,
|
|
||||||
Permissions: req.Permissions,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.roleRepo.Create(ctx, role); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return role, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) GetRole(ctx context.Context, id string) (*entity.Role, error) {
|
|
||||||
return uc.roleRepo.GetByID(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) GetRoleByUID(ctx context.Context, uid string) (*entity.Role, error) {
|
|
||||||
return uc.roleRepo.GetByUID(ctx, uid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) UpdateRole(ctx context.Context, req usecase.UpdateRoleRequest) (*entity.Role, error) {
|
|
||||||
// 獲取現有角色
|
|
||||||
role, err := uc.roleRepo.GetByID(ctx, req.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新字段
|
|
||||||
if req.Name != nil {
|
|
||||||
// 檢查新名稱是否已存在
|
|
||||||
existingRole, err := uc.roleRepo.GetByClientAndName(ctx, role.ClientID, *req.Name)
|
|
||||||
if err == nil && existingRole != nil && existingRole.ID != role.ID {
|
|
||||||
return nil, errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, role.ClientID+":"+*req.Name)
|
|
||||||
}
|
|
||||||
role.Name = *req.Name
|
|
||||||
}
|
|
||||||
if req.Status != nil {
|
|
||||||
role.Status = *req.Status
|
|
||||||
}
|
|
||||||
if req.Permissions != nil {
|
|
||||||
role.Permissions = *req.Permissions
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.roleRepo.Update(ctx, req.ID, role); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return role, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) DeleteRole(ctx context.Context, id string) error {
|
|
||||||
// 獲取角色信息
|
|
||||||
role, err := uc.roleRepo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
status := permission.StatusActive
|
|
||||||
// 檢查是否有用戶使用此角色
|
|
||||||
userRoles, err := uc.userRoleRepo.List(ctx, repository.UserRoleFilter{
|
|
||||||
RoleUID: role.UID,
|
|
||||||
Status: &status,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(userRoles) > 0 {
|
|
||||||
return errs.InvalidFormat("cannot delete role that is assigned to users")
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.roleRepo.Delete(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) ListRoles(ctx context.Context, req usecase.ListRolesRequest) ([]*entity.Role, error) {
|
|
||||||
filter := repository.RoleFilter{
|
|
||||||
ClientID: req.ClientID,
|
|
||||||
Status: req.Status,
|
|
||||||
Limit: req.Limit,
|
|
||||||
Skip: req.Skip,
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.roleRepo.List(ctx, filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) AddPermissionToRole(ctx context.Context, roleID string, permissionKey string) error {
|
|
||||||
// 獲取角色
|
|
||||||
role, err := uc.roleRepo.GetByID(ctx, roleID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加權限
|
|
||||||
role.AddPermission(permissionKey)
|
|
||||||
return uc.roleRepo.Update(ctx, role.ID.Hex(), role)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) RemovePermissionFromRole(ctx context.Context, roleID string, permissionKey string) error {
|
|
||||||
// 獲取角色
|
|
||||||
role, err := uc.roleRepo.GetByID(ctx, roleID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 移除權限
|
|
||||||
role.RemovePermission(permissionKey)
|
|
||||||
|
|
||||||
return uc.roleRepo.Update(ctx, role.ID.Hex(), role)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) BatchUpdateRolePermissions(ctx context.Context, roleID string, permissions entity.Permissions) error {
|
|
||||||
// 獲取角色
|
|
||||||
role, err := uc.roleRepo.GetByID(ctx, roleID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 批量更新權限
|
|
||||||
role.Permissions = permissions
|
|
||||||
|
|
||||||
return uc.roleRepo.Update(ctx, role.ID.Hex(), role)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) GetRolesByClientID(ctx context.Context, clientID string) ([]*entity.Role, error) {
|
|
||||||
return uc.roleRepo.GetRolesByClientID(ctx, clientID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *RoleUseCase) CopyRole(ctx context.Context, sourceRoleID string, req usecase.CreateRoleRequest) (*entity.Role, error) {
|
|
||||||
// 獲取源角色
|
|
||||||
sourceRole, err := uc.roleRepo.GetByID(ctx, sourceRoleID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 創建新角色,複製權限
|
|
||||||
newReq := req
|
|
||||||
newReq.Permissions = sourceRole.Permissions
|
|
||||||
|
|
||||||
return uc.CreateRole(ctx, newReq)
|
|
||||||
}
|
|
||||||
|
|
@ -0,0 +1,708 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/internal/config"
|
||||||
|
"backend/pkg/library/errs"
|
||||||
|
"backend/pkg/library/errs/code"
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
"backend/pkg/permission/domain/repository"
|
||||||
|
"backend/pkg/permission/domain/token"
|
||||||
|
"backend/pkg/permission/domain/usecase"
|
||||||
|
|
||||||
|
"github.com/segmentio/ksuid"
|
||||||
|
"github.com/zeromicro/go-zero/core/logx"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenUseCaseParam struct {
|
||||||
|
TokenRepo repository.TokenRepository
|
||||||
|
|
||||||
|
Config *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenUseCase struct {
|
||||||
|
TokenUseCaseParam
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) ReadTokenBasicData(ctx context.Context, token string) (map[string]string, error) {
|
||||||
|
claims, err := parseClaims(token, use.Config.Token.AccessSecret, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil,
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "parseClaims",
|
||||||
|
req: token,
|
||||||
|
err: err,
|
||||||
|
message: "validate token claims error",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustTokenUseCase(param TokenUseCaseParam) usecase.TokenUseCase {
|
||||||
|
return &TokenUseCase{
|
||||||
|
param,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================ token ============================================
|
||||||
|
|
||||||
|
func (use *TokenUseCase) NewToken(ctx context.Context, req entity.AuthorizationReq) (entity.TokenResp, error) {
|
||||||
|
tokenObj, err := use.newToken(ctx, &req)
|
||||||
|
if err != nil {
|
||||||
|
return entity.TokenResp{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = use.TokenRepo.Create(ctx, *tokenObj)
|
||||||
|
if err != nil {
|
||||||
|
return entity.TokenResp{}, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.Create",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to create token",
|
||||||
|
errorCode: code.TokenCreateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return entity.TokenResp{
|
||||||
|
AccessToken: tokenObj.AccessToken,
|
||||||
|
TokenType: token.TypeBearer.String(),
|
||||||
|
ExpiresIn: int64(tokenObj.ExpiresIn),
|
||||||
|
RefreshToken: tokenObj.RefreshToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) newToken(ctx context.Context, req *entity.AuthorizationReq) (*entity.Token, error) {
|
||||||
|
// 準備建立 Token 所需
|
||||||
|
now := time.Now().UTC()
|
||||||
|
expires := req.Expires
|
||||||
|
refreshExpires := req.Expires
|
||||||
|
|
||||||
|
if expires <= 0 {
|
||||||
|
// 將時間加上 n 秒
|
||||||
|
sec := use.Config.Token.AccessTokenExpiry
|
||||||
|
// 獲取 Unix 時間戳
|
||||||
|
expires = now.Add(sec).Unix()
|
||||||
|
refreshExpires = expires
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果這是一個 Refresh Token 過期時間要比普通的Token 長
|
||||||
|
if req.IsRefreshToken {
|
||||||
|
// 獲取 Unix 時間戳
|
||||||
|
refresh := use.Config.Token.RefreshTokenExpiry
|
||||||
|
refreshExpires = now.Add(refresh).Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
token := entity.Token{
|
||||||
|
ID: ksuid.New().String(),
|
||||||
|
DeviceID: req.DeviceID,
|
||||||
|
ExpiresIn: int(expires),
|
||||||
|
RefreshExpiresIn: int(refreshExpires),
|
||||||
|
AccessCreateAt: now,
|
||||||
|
RefreshCreateAt: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
if req.Data != nil {
|
||||||
|
for k, v := range req.Data {
|
||||||
|
tc[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tc.SetRole(req.Role)
|
||||||
|
tc.SetID(token.ID)
|
||||||
|
tc.SetScope(req.Scope)
|
||||||
|
tc.SetAccount(req.Account)
|
||||||
|
|
||||||
|
token.UID = tc.UID()
|
||||||
|
|
||||||
|
if req.DeviceID != "" {
|
||||||
|
tc.SetDeviceID(req.DeviceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
token.AccessToken, err = accessTokenGenerator(token, tc, use.Config.Token.AccessSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "accessTokenGenerator",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to generator access token",
|
||||||
|
errorCode: code.TokenCreateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.IsRefreshToken {
|
||||||
|
token.RefreshToken = refreshTokenGenerator(token.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) RefreshToken(ctx context.Context, req entity.RefreshTokenReq) (entity.RefreshTokenResp, error) {
|
||||||
|
// Step 1: 檢查 refresh token
|
||||||
|
tokenObj, err := use.TokenRepo.GetAccessTokenByOneTimeToken(ctx, req.Token)
|
||||||
|
if err != nil {
|
||||||
|
return entity.RefreshTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.GetAccessTokenByOneTimeToken",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to get access token",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: 提取 Claims Data
|
||||||
|
claimsData, err := parseClaims(tokenObj.AccessToken, use.Config.Token.AccessSecret, false)
|
||||||
|
if err != nil {
|
||||||
|
return entity.RefreshTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "extractClaims",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to extract claims",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: 創建新 token
|
||||||
|
credentials := token.ClientCredentials
|
||||||
|
newToken, err := use.newToken(ctx, &entity.AuthorizationReq{
|
||||||
|
GrantType: credentials.ToString(),
|
||||||
|
Scope: req.Scope,
|
||||||
|
DeviceID: req.DeviceID,
|
||||||
|
Data: claimsData,
|
||||||
|
Expires: req.Expires,
|
||||||
|
IsRefreshToken: true,
|
||||||
|
Account: req.DeviceID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return entity.RefreshTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "use.newToken",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to create new token",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := use.TokenRepo.Create(ctx, *newToken); err != nil {
|
||||||
|
return entity.RefreshTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.Create",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to create new token",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: 刪除舊 token 並創建新 token
|
||||||
|
if err := use.TokenRepo.Delete(ctx, tokenObj); err != nil {
|
||||||
|
return entity.RefreshTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.Delete",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to delete old token",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回新的 Token 響應
|
||||||
|
return entity.RefreshTokenResp{
|
||||||
|
Token: newToken.AccessToken,
|
||||||
|
OneTimeToken: newToken.RefreshToken,
|
||||||
|
ExpiresIn: int64(newToken.ExpiresIn),
|
||||||
|
TokenType: token.TypeBearer.String(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) CancelToken(ctx context.Context, req entity.CancelTokenReq) error {
|
||||||
|
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "CancelToken extractClaims",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to get token claims",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo GetAccessTokenByID",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: fmt.Sprintf("failed to get token claims :%s", claims.ID()),
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
err = use.TokenRepo.Delete(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo Delete",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: fmt.Sprintf("failed to delete token :%s", token.ID),
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) ValidationToken(ctx context.Context, req entity.ValidationTokenReq) (entity.ValidationTokenResp, error) {
|
||||||
|
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, true)
|
||||||
|
if err != nil {
|
||||||
|
return entity.ValidationTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "parseClaims",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "validate token claims error",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||||||
|
if err != nil {
|
||||||
|
return entity.ValidationTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.GetAccessTokenByID",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: fmt.Sprintf("failed to get token :%s", claims.ID()),
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return entity.ValidationTokenResp{
|
||||||
|
Token: entity.Token{
|
||||||
|
ID: token.ID,
|
||||||
|
UID: token.UID,
|
||||||
|
DeviceID: token.DeviceID,
|
||||||
|
AccessCreateAt: token.AccessCreateAt,
|
||||||
|
AccessToken: token.AccessToken,
|
||||||
|
ExpiresIn: token.ExpiresIn,
|
||||||
|
RefreshToken: token.RefreshToken,
|
||||||
|
RefreshExpiresIn: token.RefreshExpiresIn,
|
||||||
|
RefreshCreateAt: token.RefreshCreateAt,
|
||||||
|
},
|
||||||
|
Data: claims,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) CancelTokens(ctx context.Context, req entity.DoTokenByUIDReq) error {
|
||||||
|
if req.UID != "" {
|
||||||
|
err := use.TokenRepo.DeleteAccessTokensByUID(ctx, req.UID)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.DeleteAccessTokensByUID",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to cancel tokens by uid",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(req.IDs) > 0 {
|
||||||
|
err := use.TokenRepo.DeleteAccessTokenByID(ctx, req.IDs)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.DeleteAccessTokenByID",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to cancel tokens by token ids",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) CancelTokenByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) error {
|
||||||
|
err := use.TokenRepo.DeleteAccessTokensByDeviceID(ctx, req.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.DeleteAccessTokensByDeviceID",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to cancel token by device id",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) GetUserTokensByDeviceID(ctx context.Context, req entity.DoTokenByDeviceIDReq) ([]*entity.TokenResp, error) {
|
||||||
|
uidTokens, err := use.TokenRepo.GetAccessTokensByDeviceID(ctx, req.DeviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.GetAccessTokensByDeviceID",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to get token by device id",
|
||||||
|
errorCode: code.TokenNotFound,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
|
||||||
|
for _, v := range uidTokens {
|
||||||
|
tokens = append(tokens, &entity.TokenResp{
|
||||||
|
AccessToken: v.AccessToken,
|
||||||
|
TokenType: token.TypeBearer.String(),
|
||||||
|
ExpiresIn: int64(v.ExpiresIn),
|
||||||
|
RefreshToken: v.RefreshToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) GetUserTokensByUID(ctx context.Context, req entity.QueryTokenByUIDReq) ([]*entity.TokenResp, error) {
|
||||||
|
uidTokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, req.UID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.GetAccessTokensByUID",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to get token by uid",
|
||||||
|
errorCode: code.TokenNotFound,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tokens := make([]*entity.TokenResp, 0, len(uidTokens))
|
||||||
|
for _, v := range uidTokens {
|
||||||
|
tokens = append(tokens, &entity.TokenResp{
|
||||||
|
AccessToken: v.AccessToken,
|
||||||
|
TokenType: token.TypeBearer.String(),
|
||||||
|
ExpiresIn: int64(v.ExpiresIn),
|
||||||
|
RefreshToken: v.RefreshToken,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) NewOneTimeToken(ctx context.Context, req entity.CreateOneTimeTokenReq) (entity.CreateOneTimeTokenResp, error) {
|
||||||
|
// 驗證Token
|
||||||
|
claims, err := parseClaims(req.Token, use.Config.Token.AccessSecret, false)
|
||||||
|
if err != nil {
|
||||||
|
return entity.CreateOneTimeTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "parseClaims",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to get token claims",
|
||||||
|
errorCode: code.OneTimeTokenError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenObj, err := use.TokenRepo.GetAccessTokenByID(ctx, claims.ID())
|
||||||
|
if err != nil {
|
||||||
|
return entity.CreateOneTimeTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.GetAccessTokenByID",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to get token by id",
|
||||||
|
errorCode: code.OneTimeTokenError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
oneTimeToken := refreshTokenGenerator(ksuid.New().String())
|
||||||
|
key := token.TicketKeyPrefix + oneTimeToken
|
||||||
|
if err = use.TokenRepo.CreateOneTimeToken(ctx, key, entity.Ticket{
|
||||||
|
Data: claims,
|
||||||
|
Token: tokenObj,
|
||||||
|
}, time.Minute); err != nil {
|
||||||
|
return entity.CreateOneTimeTokenResp{},
|
||||||
|
use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.CreateOneTimeToken",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "create one time token error",
|
||||||
|
errorCode: code.OneTimeTokenError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return entity.CreateOneTimeTokenResp{
|
||||||
|
OneTimeToken: oneTimeToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (use *TokenUseCase) CancelOneTimeToken(ctx context.Context, req entity.CancelOneTimeTokenReq) error {
|
||||||
|
err := use.TokenRepo.DeleteOneTimeToken(ctx, req.Token, nil)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "TokenRepo.DeleteOneTimeToken",
|
||||||
|
req: req,
|
||||||
|
err: err,
|
||||||
|
message: "failed to del one time token by token",
|
||||||
|
errorCode: code.OneTimeTokenError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type wrapTokenErrorReq struct {
|
||||||
|
funcName string
|
||||||
|
req any
|
||||||
|
err error
|
||||||
|
message string
|
||||||
|
errorCode uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
// wrapTokenError 將錯誤信息封裝到 errs.LibError 中
|
||||||
|
func (use *TokenUseCase) wrapTokenError(ctx context.Context, param wrapTokenErrorReq) error {
|
||||||
|
logFields := []logx.LogField{
|
||||||
|
{Key: "req", Value: param.req},
|
||||||
|
{Key: "func", Value: param.funcName},
|
||||||
|
{Key: "err", Value: param.err.Error()},
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.WithContext(ctx).Errorw(param.message, logFields...)
|
||||||
|
|
||||||
|
wrappedErr := errs.NewError(
|
||||||
|
code.CatToken,
|
||||||
|
code.CatToken,
|
||||||
|
param.errorCode,
|
||||||
|
param.message,
|
||||||
|
).Wrap(param.err)
|
||||||
|
|
||||||
|
return wrappedErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlacklistToken 將 JWT token 加入黑名單 (立即撤銷)
|
||||||
|
func (use *TokenUseCase) BlacklistToken(ctx context.Context, token string, reason string) error {
|
||||||
|
// 解析 JWT 獲取完整的 claims
|
||||||
|
claimMap, err := parseToken(token, use.Config.Token.AccessSecret, false)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "BlacklistToken.parseToken",
|
||||||
|
req: token,
|
||||||
|
err: err,
|
||||||
|
message: "failed to parse token claims",
|
||||||
|
errorCode: code.InvalidJWT,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 獲取 JTI (JWT ID)
|
||||||
|
jti, exists := claimMap["jti"]
|
||||||
|
if !exists {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "BlacklistToken.getJTI",
|
||||||
|
req: token,
|
||||||
|
err: entity.ErrInvalidJTI,
|
||||||
|
message: "token missing JTI claim",
|
||||||
|
errorCode: code.InvalidJWT,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
jtiStr, ok := jti.(string)
|
||||||
|
if !ok {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "BlacklistToken.convertJTI",
|
||||||
|
req: token,
|
||||||
|
err: entity.ErrInvalidJTI,
|
||||||
|
message: "JTI claim is not a string",
|
||||||
|
errorCode: code.InvalidJWT,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 獲取 UID (可能在 data 中)
|
||||||
|
var uid string
|
||||||
|
if dataInterface, exists := claimMap["data"]; exists {
|
||||||
|
if dataMap, ok := dataInterface.(map[string]interface{}); ok {
|
||||||
|
if uidInterface, exists := dataMap["uid"]; exists {
|
||||||
|
uid, _ = uidInterface.(string)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 獲取過期時間
|
||||||
|
exp, exists := claimMap["exp"]
|
||||||
|
if !exists {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "BlacklistToken.getExp",
|
||||||
|
req: token,
|
||||||
|
err: entity.ErrTokenExpired,
|
||||||
|
message: "token missing exp claim",
|
||||||
|
errorCode: code.TokenExpired,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 將 exp 轉換為 int64 (JWT 中通常是 float64)
|
||||||
|
var expInt int64
|
||||||
|
switch v := exp.(type) {
|
||||||
|
case float64:
|
||||||
|
expInt = int64(v)
|
||||||
|
case int64:
|
||||||
|
expInt = v
|
||||||
|
case string:
|
||||||
|
parsedExp, err := strconv.ParseInt(v, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "BlacklistToken.parseExp",
|
||||||
|
req: token,
|
||||||
|
err: err,
|
||||||
|
message: "failed to parse exp claim",
|
||||||
|
errorCode: code.TokenExpired,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
expInt = parsedExp
|
||||||
|
default:
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "BlacklistToken.convertExp",
|
||||||
|
req: token,
|
||||||
|
err: fmt.Errorf("exp claim is not a valid type: %T", exp),
|
||||||
|
message: "exp claim type conversion failed",
|
||||||
|
errorCode: code.TokenExpired,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 創建黑名單條目
|
||||||
|
blacklistEntry := &entity.BlacklistEntry{
|
||||||
|
JTI: jtiStr,
|
||||||
|
UID: uid,
|
||||||
|
ExpiresAt: expInt,
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加到黑名單
|
||||||
|
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "BlacklistToken.AddToBlacklist",
|
||||||
|
req: jtiStr,
|
||||||
|
err: err,
|
||||||
|
message: "failed to add token to blacklist",
|
||||||
|
errorCode: code.TokenCreateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.WithContext(ctx).Infow("token blacklisted",
|
||||||
|
logx.Field("jti", jtiStr),
|
||||||
|
logx.Field("uid", uid),
|
||||||
|
logx.Field("reason", reason))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTokenBlacklisted 檢查 JWT token 是否在黑名單中
|
||||||
|
func (use *TokenUseCase) IsTokenBlacklisted(ctx context.Context, jti string) (bool, error) {
|
||||||
|
isBlacklisted, err := use.TokenRepo.IsBlacklisted(ctx, jti)
|
||||||
|
if err != nil {
|
||||||
|
return false, use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "IsTokenBlacklisted",
|
||||||
|
req: jti,
|
||||||
|
err: err,
|
||||||
|
message: "failed to check blacklist status",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return isBlacklisted, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlacklistAllUserTokens 將用戶的所有 token 加入黑名單 (全設備登出)
|
||||||
|
func (use *TokenUseCase) BlacklistAllUserTokens(ctx context.Context, uid string, reason string) error {
|
||||||
|
// 獲取用戶的所有 token
|
||||||
|
tokens, err := use.TokenRepo.GetAccessTokensByUID(ctx, uid)
|
||||||
|
if err != nil {
|
||||||
|
return use.wrapTokenError(ctx, wrapTokenErrorReq{
|
||||||
|
funcName: "BlacklistAllUserTokens.GetAccessTokensByUID",
|
||||||
|
req: uid,
|
||||||
|
err: err,
|
||||||
|
message: "failed to get user tokens",
|
||||||
|
errorCode: code.TokenValidateError,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 為每個 token 創建黑名單條目
|
||||||
|
for _, token := range tokens {
|
||||||
|
// 解析 token 獲取 JTI 和過期時間
|
||||||
|
claims, err := parseClaims(token.AccessToken, use.Config.Token.AccessSecret, false)
|
||||||
|
if err != nil {
|
||||||
|
logx.WithContext(ctx).Errorw("failed to parse token for blacklisting",
|
||||||
|
logx.Field("uid", uid),
|
||||||
|
logx.Field("tokenID", token.ID),
|
||||||
|
logx.Field("error", err))
|
||||||
|
continue // 跳過無效的 token,繼續處理其他 token
|
||||||
|
}
|
||||||
|
|
||||||
|
jti, exists := claims["jti"]
|
||||||
|
if !exists || jti == "" {
|
||||||
|
logx.WithContext(ctx).Errorw("token missing JTI claim",
|
||||||
|
logx.Field("uid", uid),
|
||||||
|
logx.Field("tokenID", token.ID))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
exp, exists := claims["exp"]
|
||||||
|
if !exists {
|
||||||
|
logx.WithContext(ctx).Errorw("token missing exp claim",
|
||||||
|
logx.Field("uid", uid),
|
||||||
|
logx.Field("tokenID", token.ID))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 將 exp 字符串轉換為 int64
|
||||||
|
expInt, err := strconv.ParseInt(exp, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
logx.WithContext(ctx).Errorw("failed to parse exp claim",
|
||||||
|
logx.Field("uid", uid),
|
||||||
|
logx.Field("tokenID", token.ID),
|
||||||
|
logx.Field("error", err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// 創建黑名單條目
|
||||||
|
blacklistEntry := &entity.BlacklistEntry{
|
||||||
|
JTI: jti,
|
||||||
|
UID: uid,
|
||||||
|
ExpiresAt: expInt,
|
||||||
|
CreatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加到黑名單
|
||||||
|
err = use.TokenRepo.AddToBlacklist(ctx, blacklistEntry, 0) // TTL=0 表示使用默認計算
|
||||||
|
if err != nil {
|
||||||
|
logx.WithContext(ctx).Errorw("failed to add token to blacklist",
|
||||||
|
logx.Field("uid", uid),
|
||||||
|
logx.Field("jti", jti),
|
||||||
|
logx.Field("error", err))
|
||||||
|
// 繼續處理其他 token,不要因為一個失敗就停止
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 刪除用戶的所有 token 記錄
|
||||||
|
err = use.TokenRepo.DeleteAccessTokensByUID(ctx, uid)
|
||||||
|
if err != nil {
|
||||||
|
logx.WithContext(ctx).Errorw("failed to delete user tokens",
|
||||||
|
logx.Field("uid", uid),
|
||||||
|
logx.Field("error", err))
|
||||||
|
// 這不是致命錯誤,因為 token 已經被加入黑名單
|
||||||
|
}
|
||||||
|
|
||||||
|
logx.WithContext(ctx).Infow("all user tokens blacklisted",
|
||||||
|
logx.Field("uid", uid),
|
||||||
|
logx.Field("tokenCount", len(tokens)),
|
||||||
|
logx.Field("reason", reason))
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
type tokenClaims map[string]string
|
||||||
|
|
||||||
|
func (tc tokenClaims) SetID(id string) {
|
||||||
|
tc["id"] = id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tokenClaims) SetRole(role string) {
|
||||||
|
tc["role"] = role
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tokenClaims) SetDeviceID(deviceID string) {
|
||||||
|
tc["device_id"] = deviceID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tokenClaims) SetScope(scope string) {
|
||||||
|
tc["scope"] = scope
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tokenClaims) SetAccount(account string) {
|
||||||
|
tc["account"] = account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tokenClaims) Role() string {
|
||||||
|
role, ok := tc["role"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return role
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tokenClaims) ID() string {
|
||||||
|
id, ok := tc["id"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tokenClaims) DeviceID() string {
|
||||||
|
deviceID, ok := tc["device_id"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return deviceID
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tc tokenClaims) UID() string {
|
||||||
|
uid, ok := tc["uid"]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return uid
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,325 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTokenClaims_SetAndGetID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
id string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal ID",
|
||||||
|
id: "token123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID ID",
|
||||||
|
id: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty ID",
|
||||||
|
id: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
tc.SetID(tt.id)
|
||||||
|
|
||||||
|
result := tc.ID()
|
||||||
|
assert.Equal(t, tt.id, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_SetAndGetRole(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
role string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "admin role",
|
||||||
|
role: "admin",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user role",
|
||||||
|
role: "user",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "guest role",
|
||||||
|
role: "guest",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty role",
|
||||||
|
role: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
tc.SetRole(tt.role)
|
||||||
|
|
||||||
|
result := tc.Role()
|
||||||
|
assert.Equal(t, tt.role, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_SetAndGetDeviceID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
deviceID string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal device ID",
|
||||||
|
deviceID: "device123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID device ID",
|
||||||
|
deviceID: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty device ID",
|
||||||
|
deviceID: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
tc.SetDeviceID(tt.deviceID)
|
||||||
|
|
||||||
|
result := tc.DeviceID()
|
||||||
|
assert.Equal(t, tt.deviceID, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_SetAndGetScope(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
scope string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "read write scope",
|
||||||
|
scope: "read write",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "read only scope",
|
||||||
|
scope: "read",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "admin scope",
|
||||||
|
scope: "admin",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty scope",
|
||||||
|
scope: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
tc.SetScope(tt.scope)
|
||||||
|
|
||||||
|
// Note: there's no GetScope method, so we just verify it's set
|
||||||
|
assert.Equal(t, tt.scope, tc["scope"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_SetAndGetAccount(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
account string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "email account",
|
||||||
|
account: "user@example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "username account",
|
||||||
|
account: "john_doe",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "phone account",
|
||||||
|
account: "+1234567890",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty account",
|
||||||
|
account: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
tc.SetAccount(tt.account)
|
||||||
|
|
||||||
|
// Note: there's no GetAccount method, so we just verify it's set
|
||||||
|
assert.Equal(t, tt.account, tc["account"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_SetAndGetUID(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uid string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal UID",
|
||||||
|
uid: "user123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UUID UID",
|
||||||
|
uid: "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty UID",
|
||||||
|
uid: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
tc["uid"] = tt.uid
|
||||||
|
|
||||||
|
result := tc.UID()
|
||||||
|
assert.Equal(t, tt.uid, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_GetNonExistentField(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
|
||||||
|
t.Run("get non-existent ID", func(t *testing.T) {
|
||||||
|
result := tc.ID()
|
||||||
|
assert.Empty(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get non-existent Role", func(t *testing.T) {
|
||||||
|
result := tc.Role()
|
||||||
|
assert.Empty(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get non-existent DeviceID", func(t *testing.T) {
|
||||||
|
result := tc.DeviceID()
|
||||||
|
assert.Empty(t, result)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get non-existent UID", func(t *testing.T) {
|
||||||
|
result := tc.UID()
|
||||||
|
assert.Empty(t, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_MultipleFields(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
|
||||||
|
tc.SetID("token123")
|
||||||
|
tc.SetRole("admin")
|
||||||
|
tc.SetDeviceID("device456")
|
||||||
|
tc.SetScope("read write")
|
||||||
|
tc.SetAccount("user@example.com")
|
||||||
|
tc["uid"] = "user789"
|
||||||
|
|
||||||
|
t.Run("verify all fields", func(t *testing.T) {
|
||||||
|
assert.Equal(t, "token123", tc.ID())
|
||||||
|
assert.Equal(t, "admin", tc.Role())
|
||||||
|
assert.Equal(t, "device456", tc.DeviceID())
|
||||||
|
assert.Equal(t, "read write", tc["scope"])
|
||||||
|
assert.Equal(t, "user@example.com", tc["account"])
|
||||||
|
assert.Equal(t, "user789", tc.UID())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_Overwrite(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
|
||||||
|
t.Run("overwrite ID", func(t *testing.T) {
|
||||||
|
tc.SetID("token123")
|
||||||
|
assert.Equal(t, "token123", tc.ID())
|
||||||
|
|
||||||
|
tc.SetID("token456")
|
||||||
|
assert.Equal(t, "token456", tc.ID())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("overwrite Role", func(t *testing.T) {
|
||||||
|
tc.SetRole("user")
|
||||||
|
assert.Equal(t, "user", tc.Role())
|
||||||
|
|
||||||
|
tc.SetRole("admin")
|
||||||
|
assert.Equal(t, "admin", tc.Role())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_MapBehavior(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
|
||||||
|
t.Run("can set custom fields", func(t *testing.T) {
|
||||||
|
tc["custom_field"] = "custom_value"
|
||||||
|
assert.Equal(t, "custom_value", tc["custom_field"])
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("can iterate over fields", func(t *testing.T) {
|
||||||
|
tc2 := make(tokenClaims)
|
||||||
|
tc2.SetID("token123")
|
||||||
|
tc2.SetRole("admin")
|
||||||
|
tc2["uid"] = "user123"
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
for range tc2 {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
assert.Equal(t, 3, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("can check field existence", func(t *testing.T) {
|
||||||
|
tc.SetID("token123")
|
||||||
|
|
||||||
|
_, exists := tc["id"]
|
||||||
|
assert.True(t, exists)
|
||||||
|
|
||||||
|
_, exists = tc["non_existent"]
|
||||||
|
assert.False(t, exists)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("can delete fields", func(t *testing.T) {
|
||||||
|
tc.SetRole("admin")
|
||||||
|
assert.Equal(t, "admin", tc.Role())
|
||||||
|
|
||||||
|
delete(tc, "role")
|
||||||
|
assert.Empty(t, tc.Role())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_EmptyMap(t *testing.T) {
|
||||||
|
tc := make(tokenClaims)
|
||||||
|
|
||||||
|
assert.Empty(t, tc.ID())
|
||||||
|
assert.Empty(t, tc.Role())
|
||||||
|
assert.Empty(t, tc.DeviceID())
|
||||||
|
assert.Empty(t, tc.UID())
|
||||||
|
assert.Equal(t, 0, len(tc))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenClaims_NilMap(t *testing.T) {
|
||||||
|
var tc tokenClaims
|
||||||
|
|
||||||
|
t.Run("get from nil map", func(t *testing.T) {
|
||||||
|
assert.Empty(t, tc.ID())
|
||||||
|
assert.Empty(t, tc.Role())
|
||||||
|
assert.Empty(t, tc.DeviceID())
|
||||||
|
assert.Empty(t, tc.UID())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
|
@ -0,0 +1,107 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
)
|
||||||
|
|
||||||
|
var accessTokenGenerator = createAccessToken
|
||||||
|
var refreshTokenGenerator = createRefreshToken
|
||||||
|
|
||||||
|
// createAccessToken 生成訪問令牌(Access Token)
|
||||||
|
func createAccessToken(token entity.Token, data any, secretKey string) (string, error) {
|
||||||
|
claims := entity.Claims{
|
||||||
|
Data: data,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ID: token.ID,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Unix(int64(token.ExpiresIn), 0)),
|
||||||
|
Issuer: "permission",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).
|
||||||
|
SignedString([]byte(secretKey))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return accessToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// createRefreshToken 基於訪問令牌生成刷新令牌(Refresh Token)
|
||||||
|
func createRefreshToken(accessToken string) string {
|
||||||
|
hash := sha256.New()
|
||||||
|
_, _ = hash.Write([]byte(accessToken))
|
||||||
|
|
||||||
|
return hex.EncodeToString(hash.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseToken(accessToken string, secret string, validate bool) (jwt.MapClaims, error) {
|
||||||
|
// 跳過驗證的解析
|
||||||
|
var token *jwt.Token
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if validate {
|
||||||
|
token, err = jwt.Parse(accessToken, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("token unexpected signing method: %v", token.Header["alg"])
|
||||||
|
}
|
||||||
|
|
||||||
|
return []byte(secret), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return jwt.MapClaims{}, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
|
||||||
|
token, err = parser.Parse(accessToken, func(_ *jwt.Token) (any, error) {
|
||||||
|
return []byte(secret), nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return jwt.MapClaims{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(jwt.MapClaims)
|
||||||
|
if !ok && token.Valid {
|
||||||
|
return jwt.MapClaims{}, fmt.Errorf("token valid error")
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClaims(accessToken string, secret string, validate bool) (tokenClaims, error) {
|
||||||
|
claimMap, err := parseToken(accessToken, secret, validate)
|
||||||
|
if err != nil {
|
||||||
|
return tokenClaims{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
claimsData, ok := claimMap["data"].(map[string]any)
|
||||||
|
if ok {
|
||||||
|
return convertMap(claimsData), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenClaims{}, fmt.Errorf("get data from claim map error")
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertMap(input map[string]interface{}) map[string]string {
|
||||||
|
output := make(map[string]string)
|
||||||
|
for key, value := range input {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
output[key] = v
|
||||||
|
case fmt.Stringer:
|
||||||
|
output[key] = v.String()
|
||||||
|
default:
|
||||||
|
output[key] = fmt.Sprintf("%v", value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,329 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateAccessToken(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token entity.Token
|
||||||
|
data interface{}
|
||||||
|
secretKey string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successful token creation",
|
||||||
|
token: entity.Token{
|
||||||
|
ID: "test-token-id",
|
||||||
|
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
data: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "admin",
|
||||||
|
},
|
||||||
|
secretKey: "test-secret-key",
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty secret key",
|
||||||
|
token: entity.Token{
|
||||||
|
ID: "test-token-id",
|
||||||
|
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
},
|
||||||
|
data: map[string]string{"uid": "user123"},
|
||||||
|
secretKey: "",
|
||||||
|
wantErr: false, // JWT library will still create token with empty key
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tokenStr, err := createAccessToken(tt.token, tt.data, tt.secretKey)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Empty(t, tokenStr)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, tokenStr)
|
||||||
|
|
||||||
|
// Verify the token can be parsed
|
||||||
|
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||||
|
return []byte(tt.secretKey), nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if tt.secretKey != "" {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.True(t, token.Valid)
|
||||||
|
|
||||||
|
// Check claims
|
||||||
|
if claims, ok := token.Claims.(jwt.MapClaims); ok {
|
||||||
|
assert.Equal(t, tt.token.ID, claims["jti"])
|
||||||
|
assert.Equal(t, "permission", claims["iss"])
|
||||||
|
assert.NotNil(t, claims["exp"])
|
||||||
|
assert.NotNil(t, claims["data"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreateRefreshToken(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accessToken string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "consistent hash generation",
|
||||||
|
accessToken: "test-access-token",
|
||||||
|
want: "9f86d081884c7d659a2feaa0c55ad015a3bf4f1b2b0b822cd15d6c15b0f00a08", // SHA256 of "test"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "different token different hash",
|
||||||
|
accessToken: "different-access-token",
|
||||||
|
want: "", // We'll check it's different
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty token",
|
||||||
|
accessToken: "",
|
||||||
|
want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // SHA256 of empty string
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := createRefreshToken(tt.accessToken)
|
||||||
|
|
||||||
|
assert.NotEmpty(t, result)
|
||||||
|
assert.Len(t, result, 64) // SHA256 hex string length
|
||||||
|
|
||||||
|
if tt.want != "" {
|
||||||
|
if tt.name == "consistent hash generation" {
|
||||||
|
// For "test" input, we know the expected hash
|
||||||
|
testResult := createRefreshToken("test")
|
||||||
|
assert.Equal(t, tt.want, testResult)
|
||||||
|
} else if tt.name == "empty token" {
|
||||||
|
assert.Equal(t, tt.want, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test consistency - same input should produce same output
|
||||||
|
result2 := createRefreshToken(tt.accessToken)
|
||||||
|
assert.Equal(t, result, result2)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseToken(t *testing.T) {
|
||||||
|
secretKey := "test-secret-key"
|
||||||
|
|
||||||
|
// Create a valid token first
|
||||||
|
token := entity.Token{
|
||||||
|
ID: "test-id",
|
||||||
|
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
data := map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
validTokenStr, err := createAccessToken(token, data, secretKey)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accessToken string
|
||||||
|
secret string
|
||||||
|
validate bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid token with validation",
|
||||||
|
accessToken: validTokenStr,
|
||||||
|
secret: secretKey,
|
||||||
|
validate: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "valid token without validation",
|
||||||
|
accessToken: validTokenStr,
|
||||||
|
secret: secretKey,
|
||||||
|
validate: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token",
|
||||||
|
accessToken: "invalid.token.string",
|
||||||
|
secret: secretKey,
|
||||||
|
validate: true,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong secret",
|
||||||
|
accessToken: validTokenStr,
|
||||||
|
secret: "wrong-secret",
|
||||||
|
validate: true,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty token",
|
||||||
|
accessToken: "",
|
||||||
|
secret: secretKey,
|
||||||
|
validate: true,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claims, err := parseToken(tt.accessToken, tt.secret, tt.validate)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, claims)
|
||||||
|
|
||||||
|
if tt.accessToken == validTokenStr {
|
||||||
|
assert.Equal(t, "test-id", claims["jti"])
|
||||||
|
assert.Equal(t, "permission", claims["iss"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseClaims(t *testing.T) {
|
||||||
|
secretKey := "test-secret-key"
|
||||||
|
|
||||||
|
// Create a valid token with data claims
|
||||||
|
token := entity.Token{
|
||||||
|
ID: "test-id",
|
||||||
|
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "admin",
|
||||||
|
"deviceId": "device456",
|
||||||
|
}
|
||||||
|
|
||||||
|
validTokenStr, err := createAccessToken(token, data, secretKey)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
accessToken string
|
||||||
|
secret string
|
||||||
|
validate bool
|
||||||
|
wantErr bool
|
||||||
|
expectUID string
|
||||||
|
expectRole string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid token with data claims",
|
||||||
|
accessToken: validTokenStr,
|
||||||
|
secret: secretKey,
|
||||||
|
validate: false,
|
||||||
|
wantErr: false,
|
||||||
|
expectUID: "user123",
|
||||||
|
expectRole: "admin",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token",
|
||||||
|
accessToken: "invalid.token",
|
||||||
|
secret: secretKey,
|
||||||
|
validate: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claims, err := parseClaims(tt.accessToken, tt.secret, tt.validate)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, claims)
|
||||||
|
|
||||||
|
if tt.expectUID != "" {
|
||||||
|
uid, exists := claims["uid"]
|
||||||
|
assert.True(t, exists)
|
||||||
|
assert.Equal(t, tt.expectUID, uid)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.expectRole != "" {
|
||||||
|
role, exists := claims["role"]
|
||||||
|
assert.True(t, exists)
|
||||||
|
assert.Equal(t, tt.expectRole, role)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertMap(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input map[string]interface{}
|
||||||
|
expect map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "string values",
|
||||||
|
input: map[string]interface{}{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
},
|
||||||
|
expect: map[string]string{
|
||||||
|
"key1": "value1",
|
||||||
|
"key2": "value2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed types",
|
||||||
|
input: map[string]interface{}{
|
||||||
|
"string": "value",
|
||||||
|
"int": 123,
|
||||||
|
"float": 45.67,
|
||||||
|
"bool": true,
|
||||||
|
},
|
||||||
|
expect: map[string]string{
|
||||||
|
"string": "value",
|
||||||
|
"int": "123",
|
||||||
|
"float": "45.67",
|
||||||
|
"bool": "true",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty map",
|
||||||
|
input: map[string]interface{}{},
|
||||||
|
expect: map[string]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil values",
|
||||||
|
input: map[string]interface{}{
|
||||||
|
"nil": nil,
|
||||||
|
},
|
||||||
|
expect: map[string]string{
|
||||||
|
"nil": "<nil>",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := convertMap(tt.input)
|
||||||
|
assert.Equal(t, tt.expect, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,435 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/internal/config"
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
"backend/pkg/permission/domain/token"
|
||||||
|
"backend/pkg/permission/mock/repository"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTokenUseCase_NewToken(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Token: struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
AccessTokenExpiry time.Duration
|
||||||
|
RefreshTokenExpiry time.Duration
|
||||||
|
OneTimeTokenExpiry time.Duration
|
||||||
|
MaxTokensPerUser int
|
||||||
|
MaxTokensPerDevice int
|
||||||
|
}{
|
||||||
|
AccessSecret: "test-access-secret",
|
||||||
|
RefreshSecret: "test-refresh-secret",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
||||||
|
MaxTokensPerUser: 10,
|
||||||
|
MaxTokensPerDevice: 5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.AuthorizationReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successful token creation",
|
||||||
|
req: entity.AuthorizationReq{
|
||||||
|
GrantType: token.PasswordCredentials.ToString(),
|
||||||
|
Scope: "read write",
|
||||||
|
DeviceID: "device123",
|
||||||
|
IsRefreshToken: true,
|
||||||
|
Data: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "user",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repository error",
|
||||||
|
req: entity.AuthorizationReq{
|
||||||
|
GrantType: token.PasswordCredentials.ToString(),
|
||||||
|
Scope: "read",
|
||||||
|
DeviceID: "device123",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(assert.AnError).Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
resp, err := useCase.NewToken(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Empty(t, resp.AccessToken)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, resp.AccessToken)
|
||||||
|
assert.Equal(t, token.TypeBearer.String(), resp.TokenType)
|
||||||
|
assert.Greater(t, resp.ExpiresIn, int64(0))
|
||||||
|
if tt.req.IsRefreshToken {
|
||||||
|
assert.NotEmpty(t, resp.RefreshToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_ValidationToken(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Token: struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
AccessTokenExpiry time.Duration
|
||||||
|
RefreshTokenExpiry time.Duration
|
||||||
|
OneTimeTokenExpiry time.Duration
|
||||||
|
MaxTokensPerUser int
|
||||||
|
MaxTokensPerDevice int
|
||||||
|
}{
|
||||||
|
AccessSecret: "test-access-secret",
|
||||||
|
RefreshSecret: "test-refresh-secret",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先創建一個有效的 token 用於測試
|
||||||
|
tokenReq := entity.AuthorizationReq{
|
||||||
|
GrantType: token.PasswordCredentials.ToString(),
|
||||||
|
Data: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "user",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(nil).Once()
|
||||||
|
|
||||||
|
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, tokenResp.AccessToken)
|
||||||
|
|
||||||
|
// 測試驗證
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.ValidationTokenReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid token",
|
||||||
|
req: entity.ValidationTokenReq{
|
||||||
|
Token: tokenResp.AccessToken,
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("GetAccessTokenByID", mock.Anything, mock.AnythingOfType("string")).
|
||||||
|
Return(entity.Token{
|
||||||
|
ID: "test-id",
|
||||||
|
UID: "user123",
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
ExpiresIn: int(cfg.Token.AccessTokenExpiry.Seconds()),
|
||||||
|
}, nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token",
|
||||||
|
req: entity.ValidationTokenReq{
|
||||||
|
Token: "invalid-token",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// parseClaims will fail for invalid token
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
resp, err := useCase.ValidationToken(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, resp.Token.ID)
|
||||||
|
assert.Equal(t, "user123", resp.Token.UID)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_BlacklistToken(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Token: struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
AccessTokenExpiry time.Duration
|
||||||
|
RefreshTokenExpiry time.Duration
|
||||||
|
OneTimeTokenExpiry time.Duration
|
||||||
|
MaxTokensPerUser int
|
||||||
|
MaxTokensPerDevice int
|
||||||
|
}{
|
||||||
|
AccessSecret: "test-access-secret",
|
||||||
|
RefreshSecret: "test-refresh-secret",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先創建一個有效的 token
|
||||||
|
tokenReq := entity.AuthorizationReq{
|
||||||
|
GrantType: token.PasswordCredentials.ToString(),
|
||||||
|
Data: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "user",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(nil).Once()
|
||||||
|
|
||||||
|
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
reason string
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successful blacklist",
|
||||||
|
token: tokenResp.AccessToken,
|
||||||
|
reason: "user logout",
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("AddToBlacklist", mock.Anything, mock.AnythingOfType("*entity.BlacklistEntry"), mock.AnythingOfType("time.Duration")).
|
||||||
|
Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token",
|
||||||
|
token: "invalid-token",
|
||||||
|
reason: "test",
|
||||||
|
setup: func() {},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
err := useCase.BlacklistToken(context.Background(), tt.token, tt.reason)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_IsTokenBlacklisted(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Token: struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
AccessTokenExpiry time.Duration
|
||||||
|
RefreshTokenExpiry time.Duration
|
||||||
|
OneTimeTokenExpiry time.Duration
|
||||||
|
MaxTokensPerUser int
|
||||||
|
MaxTokensPerDevice int
|
||||||
|
}{
|
||||||
|
AccessSecret: "test-secret",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
jti string
|
||||||
|
setup func()
|
||||||
|
wantResult bool
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "token is blacklisted",
|
||||||
|
jti: "test-jti-123",
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("IsBlacklisted", mock.Anything, "test-jti-123").
|
||||||
|
Return(true, nil).Once()
|
||||||
|
},
|
||||||
|
wantResult: true,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "token is not blacklisted",
|
||||||
|
jti: "test-jti-456",
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("IsBlacklisted", mock.Anything, "test-jti-456").
|
||||||
|
Return(false, nil).Once()
|
||||||
|
},
|
||||||
|
wantResult: false,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repository error",
|
||||||
|
jti: "test-jti-error",
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("IsBlacklisted", mock.Anything, "test-jti-error").
|
||||||
|
Return(false, assert.AnError).Once()
|
||||||
|
},
|
||||||
|
wantResult: false,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
result, err := useCase.IsTokenBlacklisted(context.Background(), tt.jti)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.wantResult, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_CancelTokens(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.DoTokenByUIDReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "cancel by UID",
|
||||||
|
req: entity.DoTokenByUIDReq{
|
||||||
|
UID: "user123",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("DeleteAccessTokensByUID", mock.Anything, "user123").
|
||||||
|
Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cancel by token IDs",
|
||||||
|
req: entity.DoTokenByUIDReq{
|
||||||
|
IDs: []string{"token1", "token2"},
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("DeleteAccessTokenByID", mock.Anything, []string{"token1", "token2"}).
|
||||||
|
Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repository error",
|
||||||
|
req: entity.DoTokenByUIDReq{
|
||||||
|
UID: "user123",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("DeleteAccessTokensByUID", mock.Anything, "user123").
|
||||||
|
Return(assert.AnError).Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
err := useCase.CancelTokens(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,565 @@
|
||||||
|
package usecase
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"backend/internal/config"
|
||||||
|
"backend/pkg/permission/domain/entity"
|
||||||
|
"backend/pkg/permission/domain/token"
|
||||||
|
"backend/pkg/permission/mock/repository"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/mock"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTokenUseCase_RefreshToken(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Token: struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
AccessTokenExpiry time.Duration
|
||||||
|
RefreshTokenExpiry time.Duration
|
||||||
|
OneTimeTokenExpiry time.Duration
|
||||||
|
MaxTokensPerUser int
|
||||||
|
MaxTokensPerDevice int
|
||||||
|
}{
|
||||||
|
AccessSecret: "test-access-secret",
|
||||||
|
RefreshSecret: "test-refresh-secret",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a base token first
|
||||||
|
tokenReq := entity.AuthorizationReq{
|
||||||
|
GrantType: token.PasswordCredentials.ToString(),
|
||||||
|
Data: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "user",
|
||||||
|
},
|
||||||
|
IsRefreshToken: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(nil).Once()
|
||||||
|
|
||||||
|
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.RefreshTokenReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "successful token refresh",
|
||||||
|
req: entity.RefreshTokenReq{
|
||||||
|
Token: tokenResp.RefreshToken,
|
||||||
|
Scope: "read write",
|
||||||
|
DeviceID: "device123",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
existingToken := entity.Token{
|
||||||
|
ID: "old-token-id",
|
||||||
|
UID: "user123",
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.On("GetAccessTokenByOneTimeToken", mock.Anything, tokenResp.RefreshToken).
|
||||||
|
Return(existingToken, nil).Once()
|
||||||
|
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(nil).Once()
|
||||||
|
mockRepo.On("Delete", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid refresh token",
|
||||||
|
req: entity.RefreshTokenReq{
|
||||||
|
Token: "invalid-refresh-token",
|
||||||
|
Scope: "read",
|
||||||
|
DeviceID: "device123",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("GetAccessTokenByOneTimeToken", mock.Anything, "invalid-refresh-token").
|
||||||
|
Return(entity.Token{}, assert.AnError).Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
resp, err := useCase.RefreshToken(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, resp.Token)
|
||||||
|
assert.NotEmpty(t, resp.OneTimeToken)
|
||||||
|
assert.Equal(t, token.TypeBearer.String(), resp.TokenType)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_GetUserTokensByUID(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.QueryTokenByUIDReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "get tokens successfully",
|
||||||
|
req: entity.QueryTokenByUIDReq{
|
||||||
|
UID: "user123",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
tokens := []entity.Token{
|
||||||
|
{
|
||||||
|
ID: "token1",
|
||||||
|
UID: "user123",
|
||||||
|
AccessToken: "access1",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "token2",
|
||||||
|
UID: "user123",
|
||||||
|
AccessToken: "access2",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mockRepo.On("GetAccessTokensByUID", mock.Anything, "user123").
|
||||||
|
Return(tokens, nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repository error",
|
||||||
|
req: entity.QueryTokenByUIDReq{
|
||||||
|
UID: "user456",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("GetAccessTokensByUID", mock.Anything, "user456").
|
||||||
|
Return([]entity.Token(nil), assert.AnError).Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
tokens, err := useCase.GetUserTokensByUID(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, tokens)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, tokens)
|
||||||
|
assert.Greater(t, len(tokens), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_GetUserTokensByDeviceID(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.DoTokenByDeviceIDReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "get tokens by device successfully",
|
||||||
|
req: entity.DoTokenByDeviceIDReq{
|
||||||
|
DeviceID: "device123",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
tokens := []entity.Token{
|
||||||
|
{
|
||||||
|
ID: "token1",
|
||||||
|
UID: "user123",
|
||||||
|
DeviceID: "device123",
|
||||||
|
AccessToken: "access1",
|
||||||
|
ExpiresIn: 3600,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
mockRepo.On("GetAccessTokensByDeviceID", mock.Anything, "device123").
|
||||||
|
Return(tokens, nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repository error",
|
||||||
|
req: entity.DoTokenByDeviceIDReq{
|
||||||
|
DeviceID: "device456",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("GetAccessTokensByDeviceID", mock.Anything, "device456").
|
||||||
|
Return([]entity.Token(nil), assert.AnError).Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
tokens, err := useCase.GetUserTokensByDeviceID(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, tokens)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, tokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_CancelTokenByDeviceID(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.DoTokenByDeviceIDReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "cancel tokens successfully",
|
||||||
|
req: entity.DoTokenByDeviceIDReq{
|
||||||
|
DeviceID: "device123",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("DeleteAccessTokensByDeviceID", mock.Anything, "device123").
|
||||||
|
Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repository error",
|
||||||
|
req: entity.DoTokenByDeviceIDReq{
|
||||||
|
DeviceID: "device456",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("DeleteAccessTokensByDeviceID", mock.Anything, "device456").
|
||||||
|
Return(assert.AnError).Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
err := useCase.CancelTokenByDeviceID(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_NewOneTimeToken(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Token: struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
AccessTokenExpiry time.Duration
|
||||||
|
RefreshTokenExpiry time.Duration
|
||||||
|
OneTimeTokenExpiry time.Duration
|
||||||
|
MaxTokensPerUser int
|
||||||
|
MaxTokensPerDevice int
|
||||||
|
}{
|
||||||
|
AccessSecret: "test-access-secret",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a base token first
|
||||||
|
tokenReq := entity.AuthorizationReq{
|
||||||
|
GrantType: token.PasswordCredentials.ToString(),
|
||||||
|
Data: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "user",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(nil).Once()
|
||||||
|
|
||||||
|
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.CreateOneTimeTokenReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "create one-time token successfully",
|
||||||
|
req: entity.CreateOneTimeTokenReq{
|
||||||
|
Token: tokenResp.AccessToken,
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
existingToken := entity.Token{
|
||||||
|
ID: "token-id",
|
||||||
|
UID: "user123",
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
ExpiresIn: int(time.Now().Add(time.Hour).Unix()),
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.On("GetAccessTokenByID", mock.Anything, mock.AnythingOfType("string")).
|
||||||
|
Return(existingToken, nil).Once()
|
||||||
|
mockRepo.On("CreateOneTimeToken", mock.Anything, mock.AnythingOfType("string"),
|
||||||
|
mock.AnythingOfType("entity.Ticket"), mock.AnythingOfType("time.Duration")).
|
||||||
|
Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token",
|
||||||
|
req: entity.CreateOneTimeTokenReq{
|
||||||
|
Token: "invalid-token",
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// parseClaims will fail
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
resp, err := useCase.NewOneTimeToken(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, resp.OneTimeToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_CancelOneTimeToken(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req entity.CancelOneTimeTokenReq
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "cancel one-time token successfully",
|
||||||
|
req: entity.CancelOneTimeTokenReq{
|
||||||
|
Token: []string{"token1", "token2"},
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("DeleteOneTimeToken", mock.Anything, []string{"token1", "token2"}, mock.Anything).
|
||||||
|
Return(nil).Once()
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "repository error",
|
||||||
|
req: entity.CancelOneTimeTokenReq{
|
||||||
|
Token: []string{"token3"},
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
mockRepo.On("DeleteOneTimeToken", mock.Anything, []string{"token3"}, mock.Anything).
|
||||||
|
Return(assert.AnError).Once()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setup()
|
||||||
|
|
||||||
|
err := useCase.CancelOneTimeToken(context.Background(), tt.req)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTokenUseCase_ReadTokenBasicData(t *testing.T) {
|
||||||
|
mockRepo := repository.NewMockTokenRepository(t)
|
||||||
|
cfg := &config.Config{
|
||||||
|
Token: struct {
|
||||||
|
AccessSecret string
|
||||||
|
RefreshSecret string
|
||||||
|
AccessTokenExpiry time.Duration
|
||||||
|
RefreshTokenExpiry time.Duration
|
||||||
|
OneTimeTokenExpiry time.Duration
|
||||||
|
MaxTokensPerUser int
|
||||||
|
MaxTokensPerDevice int
|
||||||
|
}{
|
||||||
|
AccessSecret: "test-access-secret",
|
||||||
|
AccessTokenExpiry: 15 * time.Minute,
|
||||||
|
RefreshTokenExpiry: 7 * 24 * time.Hour,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
useCase := &TokenUseCase{
|
||||||
|
TokenUseCaseParam: TokenUseCaseParam{
|
||||||
|
TokenRepo: mockRepo,
|
||||||
|
Config: cfg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a valid token first
|
||||||
|
tokenReq := entity.AuthorizationReq{
|
||||||
|
GrantType: token.PasswordCredentials.ToString(),
|
||||||
|
Data: map[string]string{
|
||||||
|
"uid": "user123",
|
||||||
|
"role": "admin",
|
||||||
|
},
|
||||||
|
Role: "admin",
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.On("Create", mock.Anything, mock.AnythingOfType("entity.Token")).
|
||||||
|
Return(nil).Once()
|
||||||
|
|
||||||
|
tokenResp, err := useCase.NewToken(context.Background(), tokenReq)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "read valid token",
|
||||||
|
token: tokenResp.AccessToken,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid token",
|
||||||
|
token: "invalid-token",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty token",
|
||||||
|
token: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
claims, err := useCase.ReadTokenBasicData(context.Background(), tt.token)
|
||||||
|
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, claims)
|
||||||
|
assert.Equal(t, "user123", claims["uid"])
|
||||||
|
assert.Equal(t, "admin", claims["role"])
|
||||||
|
}
|
||||||
|
|
||||||
|
mockRepo.AssertExpectations(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestTokenUseCase_BlacklistAllUserTokens is commented out due to complexity of mocking
|
||||||
|
// the JWT parsing within the loop. The functionality is tested through integration tests.
|
||||||
|
// func TestTokenUseCase_BlacklistAllUserTokens(t *testing.T) { ... }
|
||||||
|
|
||||||
|
|
@ -1,225 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/library/errs/code"
|
|
||||||
"backend/pkg/permission/domain/permission"
|
|
||||||
"backend/pkg/permission/utils"
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/library/errs"
|
|
||||||
"backend/pkg/permission/domain/entity"
|
|
||||||
"backend/pkg/permission/domain/repository"
|
|
||||||
"backend/pkg/permission/domain/usecase"
|
|
||||||
)
|
|
||||||
|
|
||||||
type UserRoleUseCaseParam struct {
|
|
||||||
UserRoleRepo repository.UserRoleRepository
|
|
||||||
RoleRepo repository.RoleRepository
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserRoleUseCase struct {
|
|
||||||
userRoleRepo repository.UserRoleRepository
|
|
||||||
roleRepo repository.RoleRepository
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustUserRoleUseCase 創建用戶角色用例實例
|
|
||||||
func MustUserRoleUseCase(param UserRoleUseCaseParam) usecase.UserRoleUseCase {
|
|
||||||
return &UserRoleUseCase{
|
|
||||||
userRoleRepo: param.UserRoleRepo,
|
|
||||||
roleRepo: param.RoleRepo,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) AssignRole(ctx context.Context, req usecase.AssignRoleRequest) (*entity.UserRole, error) {
|
|
||||||
// 驗證請求
|
|
||||||
if req.UID == "" {
|
|
||||||
return nil, errs.InvalidFormat("uid is required")
|
|
||||||
}
|
|
||||||
if req.RoleUID == "" {
|
|
||||||
return nil, errs.InvalidFormat("role_uid is required")
|
|
||||||
}
|
|
||||||
if req.Brand == "" {
|
|
||||||
return nil, errs.InvalidFormat("brand is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查角色是否存在
|
|
||||||
role, err := uc.roleRepo.GetByUID(ctx, req.RoleUID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !utils.IsActive(role.Status) {
|
|
||||||
return nil, errs.InvalidFormat("role is not active")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查用戶是否已經有此角色
|
|
||||||
existingUserRole, err := uc.userRoleRepo.GetByUserAndRole(ctx, req.UID, req.RoleUID)
|
|
||||||
if err == nil && existingUserRole != nil && utils.IsActive(existingUserRole.Status) {
|
|
||||||
return nil, errs.ResourceAlreadyExistWithScope(code.CloudEPPermission, req.UID+":"+req.RoleUID)
|
|
||||||
}
|
|
||||||
|
|
||||||
userRole := &entity.UserRole{
|
|
||||||
Brand: req.Brand,
|
|
||||||
UID: req.UID,
|
|
||||||
RoleUID: req.RoleUID,
|
|
||||||
Status: permission.StatusActive,
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.userRoleRepo.Create(ctx, userRole); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return userRole, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) RevokeRole(ctx context.Context, uid, roleUID string) error {
|
|
||||||
// 驗證參數
|
|
||||||
if uid == "" {
|
|
||||||
return errs.InvalidFormat("uid is required")
|
|
||||||
}
|
|
||||||
if roleUID == "" {
|
|
||||||
return errs.InvalidFormat("role_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.userRoleRepo.DeleteByUserAndRole(ctx, uid, roleUID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) GetUserRole(ctx context.Context, id string) (*entity.UserRole, error) {
|
|
||||||
return uc.userRoleRepo.GetByID(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) UpdateUserRole(ctx context.Context, req usecase.UpdateUserRoleRequest) (*entity.UserRole, error) {
|
|
||||||
// 獲取現有用戶角色
|
|
||||||
userRole, err := uc.userRoleRepo.GetByID(ctx, req.ID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新狀態
|
|
||||||
if req.Status != nil {
|
|
||||||
userRole.Status = *req.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.userRoleRepo.Update(ctx, req.ID, userRole); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return userRole, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) ListUserRoles(ctx context.Context, req usecase.ListUserRolesRequest) ([]*entity.UserRole, error) {
|
|
||||||
filter := repository.UserRoleFilter{
|
|
||||||
Brand: req.Brand,
|
|
||||||
UID: req.UID,
|
|
||||||
RoleUID: req.RoleUID,
|
|
||||||
Status: req.Status,
|
|
||||||
Limit: req.Limit,
|
|
||||||
Skip: req.Skip,
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.userRoleRepo.List(ctx, filter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) GetUserRoles(ctx context.Context, uid string) ([]*entity.UserRole, error) {
|
|
||||||
if uid == "" {
|
|
||||||
return nil, errs.InvalidFormat("uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.userRoleRepo.GetUserRolesByUID(ctx, uid)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) GetUserRoleDetails(ctx context.Context, uid string) ([]*usecase.UserRoleDetail, error) {
|
|
||||||
// 獲取用戶角色
|
|
||||||
userRoles, err := uc.GetUserRoles(ctx, uid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var details []*usecase.UserRoleDetail
|
|
||||||
for _, userRole := range userRoles {
|
|
||||||
if !utils.IsActive(userRole.Status) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 獲取角色詳情
|
|
||||||
role, err := uc.roleRepo.GetByUID(ctx, userRole.RoleUID)
|
|
||||||
if err != nil {
|
|
||||||
continue // 忽略獲取失敗的角色
|
|
||||||
}
|
|
||||||
|
|
||||||
detail := &usecase.UserRoleDetail{
|
|
||||||
UserRole: userRole,
|
|
||||||
Role: role,
|
|
||||||
}
|
|
||||||
details = append(details, detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
return details, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) BatchAssignRoles(ctx context.Context, uid string, roleUIDs []string, brand string) error {
|
|
||||||
if uid == "" {
|
|
||||||
return errs.InvalidFormat("uid is required")
|
|
||||||
}
|
|
||||||
if brand == "" {
|
|
||||||
return errs.InvalidFormat("brand is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 逐個分配角色
|
|
||||||
for _, roleUID := range roleUIDs {
|
|
||||||
req := usecase.AssignRoleRequest{
|
|
||||||
Brand: brand,
|
|
||||||
UID: uid,
|
|
||||||
RoleUID: roleUID,
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := uc.AssignRole(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
// 如果是已存在錯誤,忽略繼續
|
|
||||||
e := errs.FromError(err)
|
|
||||||
if e.Is(errs.ResourceAlreadyExist()) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) BatchRevokeRoles(ctx context.Context, uid string, roleUIDs []string) error {
|
|
||||||
if uid == "" {
|
|
||||||
return errs.InvalidFormat("uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 逐個撤銷角色
|
|
||||||
for _, roleUID := range roleUIDs {
|
|
||||||
err := uc.RevokeRole(ctx, uid, roleUID)
|
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *UserRoleUseCase) ReplaceUserRoles(ctx context.Context, uid string, roleUIDs []string, brand string) error {
|
|
||||||
// 獲取用戶當前的所有角色
|
|
||||||
currentUserRoles, err := uc.GetUserRoles(ctx, uid)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 撤銷所有現有角色
|
|
||||||
for _, userRole := range currentUserRoles {
|
|
||||||
if utils.IsActive(userRole.Status) {
|
|
||||||
if err := uc.RevokeRole(ctx, uid, userRole.RoleUID); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 分配新角色
|
|
||||||
return uc.BatchAssignRoles(ctx, uid, roleUIDs, brand)
|
|
||||||
}
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
package utils
|
|
||||||
|
|
||||||
import "backend/pkg/permission/domain/permission"
|
|
||||||
|
|
||||||
func IsActive(status int) bool {
|
|
||||||
return status == permission.StatusActive
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue