Compare commits
No commits in common. "feat/notification" and "main" have entirely different histories.
feat/notif
...
main
|
|
@ -1 +0,0 @@
|
||||||
DROP TYPE IF EXISTS notification_event;
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
||||||
CREATE TABLE IF NOT EXISTS notification_event (
|
|
||||||
event_id uuid PRIMARY KEY, -- 事件 ID
|
|
||||||
|
|
||||||
event_type text, -- POST_PUBLISHED / COMMENT_ADDED / MENTIONED ...
|
|
||||||
actor_uid text, -- 觸發者 UID(例如 A)
|
|
||||||
object_type text, -- POST / COMMENT / USER ...
|
|
||||||
object_id text, -- 對應物件 ID(post_id 等)
|
|
||||||
|
|
||||||
title text, -- 顯示用標題
|
|
||||||
body text, -- 顯示用內容 / 摘要
|
|
||||||
payload text, -- JSON string(額外欄位,例如 {"postId": "..."})
|
|
||||||
|
|
||||||
priority smallint, -- 1=critical, 2=high, 3=normal, 4=low
|
|
||||||
created_at timestamp -- 事件時間(方便做 cross table 查詢)
|
|
||||||
) AND comment = 'notification_event';
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
DROP TYPE IF EXISTS user_notification;
|
|
||||||
|
|
@ -1,11 +0,0 @@
|
||||||
CREATE TABLE IF NOT EXISTS user_notification (
|
|
||||||
user_id text, -- 收通知的人
|
|
||||||
bucket text, -- 分桶,例如 '2025-11' 或 '2025-11-17'
|
|
||||||
ts timeuuid, -- 通知時間,用 now() 產生,排序用
|
|
||||||
|
|
||||||
event_id uuid, -- 對應 notification_event.event_id
|
|
||||||
status text, -- 'UNREAD' / 'READ' / 'ARCHIVED'
|
|
||||||
read_at timestamp, -- 已讀時間(非必填)
|
|
||||||
|
|
||||||
PRIMARY KEY ((user_id, bucket), ts)
|
|
||||||
) WITH CLUSTERING ORDER BY (ts DESC);
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
DROP TYPE IF EXISTS notification_cursor;
|
|
||||||
|
|
@ -1,5 +0,0 @@
|
||||||
CREATE TABLE IF NOT EXISTS notification_cursor (
|
|
||||||
user_id text PRIMARY KEY,
|
|
||||||
last_seen_ts timeuuid, -- 最後看到的通知 timeuuid
|
|
||||||
updated_at timestamp
|
|
||||||
);
|
|
||||||
5
go.mod
5
go.mod
|
|
@ -10,7 +10,6 @@ require (
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.18.21
|
github.com/aws/aws-sdk-go-v2/credentials v1.18.21
|
||||||
github.com/aws/aws-sdk-go-v2/service/ses v1.34.9
|
github.com/aws/aws-sdk-go-v2/service/ses v1.34.9
|
||||||
github.com/go-playground/validator/v10 v10.28.0
|
github.com/go-playground/validator/v10 v10.28.0
|
||||||
github.com/gocql/gocql v1.7.0
|
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2
|
github.com/golang-jwt/jwt/v4 v4.5.2
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
github.com/matcornic/hermes/v2 v2.1.0
|
github.com/matcornic/hermes/v2 v2.1.0
|
||||||
|
|
@ -69,7 +68,6 @@ require (
|
||||||
github.com/grafana/pyroscope-go v1.2.7 // indirect
|
github.com/grafana/pyroscope-go v1.2.7 // indirect
|
||||||
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
|
github.com/grafana/pyroscope-go/godeltaprof v0.1.9 // indirect
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 // indirect
|
||||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
|
|
||||||
github.com/huandu/xstrings v1.2.0 // indirect
|
github.com/huandu/xstrings v1.2.0 // indirect
|
||||||
github.com/imdario/mergo v0.3.6 // indirect
|
github.com/imdario/mergo v0.3.6 // indirect
|
||||||
github.com/jaytaylor/html2text v0.0.0-20180606194806-57d518f124b0 // indirect
|
github.com/jaytaylor/html2text v0.0.0-20180606194806-57d518f124b0 // indirect
|
||||||
|
|
@ -105,8 +103,6 @@ require (
|
||||||
github.com/redis/go-redis/v9 v9.14.0 // indirect
|
github.com/redis/go-redis/v9 v9.14.0 // indirect
|
||||||
github.com/rivo/uniseg v0.2.0 // indirect
|
github.com/rivo/uniseg v0.2.0 // indirect
|
||||||
github.com/russross/blackfriday/v2 v2.0.1 // indirect
|
github.com/russross/blackfriday/v2 v2.0.1 // indirect
|
||||||
github.com/scylladb/go-reflectx v1.0.1 // indirect
|
|
||||||
github.com/scylladb/gocqlx/v2 v2.8.0 // indirect
|
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
github.com/shirou/gopsutil/v4 v4.25.6 // indirect
|
||||||
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
|
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
|
||||||
github.com/sirupsen/logrus v1.9.3 // indirect
|
github.com/sirupsen/logrus v1.9.3 // indirect
|
||||||
|
|
@ -143,7 +139,6 @@ require (
|
||||||
google.golang.org/genproto/googleapis/api v0.0.0-20250804133106-a7a43d27e69b // indirect
|
google.golang.org/genproto/googleapis/api v0.0.0-20250804133106-a7a43d27e69b // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b // indirect
|
||||||
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
|
gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect
|
||||||
gopkg.in/inf.v0 v0.9.1 // indirect
|
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||||
)
|
)
|
||||||
|
|
|
||||||
18
go.sum
18
go.sum
|
|
@ -36,10 +36,6 @@ github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM=
|
||||||
github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
|
|
||||||
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
|
|
||||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
|
||||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||||
|
|
@ -97,13 +93,10 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||||
github.com/go-playground/validator/v10 v10.28.0 h1:Q7ibns33JjyW48gHkuFT91qX48KG0ktULL6FgHdG688=
|
github.com/go-playground/validator/v10 v10.28.0 h1:Q7ibns33JjyW48gHkuFT91qX48KG0ktULL6FgHdG688=
|
||||||
github.com/go-playground/validator/v10 v10.28.0/go.mod h1:GoI6I1SjPBh9p7ykNE/yj3fFYbyDOpwMn5KXd+m2hUU=
|
github.com/go-playground/validator/v10 v10.28.0/go.mod h1:GoI6I1SjPBh9p7ykNE/yj3fFYbyDOpwMn5KXd+m2hUU=
|
||||||
github.com/gocql/gocql v1.7.0 h1:O+7U7/1gSN7QTEAaMEsJc1Oq2QHXvCWoF3DFK9HDHus=
|
|
||||||
github.com/gocql/gocql v1.7.0/go.mod h1:vnlvXyFZeLBF0Wy+RS8hrOdbn0UWsWtdg07XJnFxZ+4=
|
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI=
|
||||||
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
|
||||||
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
|
||||||
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
|
||||||
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs=
|
||||||
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
|
|
@ -122,8 +115,6 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1
|
||||||
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k=
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
|
||||||
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI=
|
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI=
|
||||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
|
|
||||||
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
|
|
||||||
github.com/huandu/xstrings v1.2.0 h1:yPeWdRnmynF7p+lLYz0H2tthW9lqhMJrQV/U7yy4wX0=
|
github.com/huandu/xstrings v1.2.0 h1:yPeWdRnmynF7p+lLYz0H2tthW9lqhMJrQV/U7yy4wX0=
|
||||||
github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4=
|
github.com/huandu/xstrings v1.2.0/go.mod h1:DvyZB1rfVYsBIigL8HwpZgxHwXozlTgGqn63UyNX5k4=
|
||||||
github.com/imdario/mergo v0.3.6 h1:xTNEAn+kxVO7dTZGu0CegyqKZmoWFI0rF8UxjlB2d28=
|
github.com/imdario/mergo v0.3.6 h1:xTNEAn+kxVO7dTZGu0CegyqKZmoWFI0rF8UxjlB2d28=
|
||||||
|
|
@ -219,12 +210,6 @@ github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR
|
||||||
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
|
||||||
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
|
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
|
||||||
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||||
github.com/scylladb/go-reflectx v1.0.1 h1:b917wZM7189pZdlND9PbIJ6NQxfDPfBvUaQ7cjj1iZQ=
|
|
||||||
github.com/scylladb/go-reflectx v1.0.1/go.mod h1:rWnOfDIRWBGN0miMLIcoPt/Dhi2doCMZqwMCJ3KupFc=
|
|
||||||
github.com/scylladb/gocqlx/v2 v2.8.0 h1:f/oIgoEPjKDKd+RIoeHqexsIQVIbalVmT+axwvUqQUg=
|
|
||||||
github.com/scylladb/gocqlx/v2 v2.8.0/go.mod h1:4/+cga34PVqjhgSoo5Nr2fX1MQIqZB5eCE5DK4xeDig=
|
|
||||||
github.com/scylladb/gocqlx/v3 v3.0.4 h1:37rMVFEUlsGGNYB7OLR7991KwBYR2WA5TU7wtduClas=
|
|
||||||
github.com/scylladb/gocqlx/v3 v3.0.4/go.mod h1:3vBkGO+HRh/BYypLWXzurQ45u1BAO0VGBhg5VgperPY=
|
|
||||||
github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c=
|
github.com/segmentio/ksuid v1.0.4 h1:sBo2BdShXjmcugAMwjugoGUdUV0pcxY5mW4xKRn3v4c=
|
||||||
github.com/segmentio/ksuid v1.0.4/go.mod h1:/XUiZBD3kVx5SmUOl55voK5yeAbBNNIed+2O73XgrPE=
|
github.com/segmentio/ksuid v1.0.4/go.mod h1:/XUiZBD3kVx5SmUOl55voK5yeAbBNNIed+2O73XgrPE=
|
||||||
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
github.com/shirou/gopsutil/v4 v4.25.6 h1:kLysI2JsKorfaFPcYmcJqbzROzsBWEOAtw6A7dIfqXs=
|
||||||
|
|
@ -245,7 +230,6 @@ github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpE
|
||||||
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
|
||||||
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
|
||||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
|
||||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||||
|
|
@ -385,8 +369,6 @@ gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AW
|
||||||
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw=
|
gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw=
|
||||||
gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY=
|
gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY=
|
||||||
gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaDva0=
|
gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaDva0=
|
||||||
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
|
|
||||||
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
|
|
||||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||||
|
|
|
||||||
|
|
@ -1,758 +0,0 @@
|
||||||
# Cassandra Client Library
|
|
||||||
|
|
||||||
一個基於 Go Generics 的 Cassandra 客戶端庫,提供類型安全的 Repository 模式和流暢的查詢構建器 API。
|
|
||||||
|
|
||||||
## 功能特色
|
|
||||||
|
|
||||||
- **類型安全**: 使用 Go Generics 提供編譯時類型檢查
|
|
||||||
- **Repository 模式**: 簡潔的 CRUD 操作介面
|
|
||||||
- **流暢查詢**: 鏈式查詢構建器,支援條件、排序、限制
|
|
||||||
- **分散式鎖**: 基於 Cassandra 的 IF NOT EXISTS 實現分散式鎖
|
|
||||||
- **批次操作**: 支援批次插入、更新、刪除
|
|
||||||
- **SAI 索引支援**: 完整的 SAI (Storage-Attached Indexing) 索引管理功能
|
|
||||||
- **Option 模式**: 靈活的配置選項
|
|
||||||
- **錯誤處理**: 統一的錯誤處理機制
|
|
||||||
- **高效能**: 內建連接池、重試機制、Prepared Statement 快取
|
|
||||||
|
|
||||||
## 安裝
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go get github.com/scylladb/gocqlx/v2
|
|
||||||
go get github.com/gocql/gocql
|
|
||||||
```
|
|
||||||
|
|
||||||
## 快速開始
|
|
||||||
|
|
||||||
### 1. 定義資料模型
|
|
||||||
|
|
||||||
```go
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
)
|
|
||||||
|
|
||||||
// User 定義用戶資料模型
|
|
||||||
type User struct {
|
|
||||||
ID gocql.UUID `db:"id" partition_key:"true"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Email string `db:"email"`
|
|
||||||
Age int `db:"age"`
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
|
||||||
UpdatedAt time.Time `db:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName 實現 Table 介面
|
|
||||||
func (u User) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 初始化資料庫連接
|
|
||||||
|
|
||||||
```go
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// 創建資料庫連接
|
|
||||||
db, err := cassandra.New(
|
|
||||||
cassandra.WithHosts("127.0.0.1"),
|
|
||||||
cassandra.WithPort(9042),
|
|
||||||
cassandra.WithKeyspace("my_keyspace"),
|
|
||||||
cassandra.WithAuth("username", "password"),
|
|
||||||
cassandra.WithConsistency(gocql.Quorum),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// 創建 Repository
|
|
||||||
userRepo, err := cassandra.NewRepository[User](db, "my_keyspace")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// 使用 Repository...
|
|
||||||
_ = userRepo
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 詳細範例
|
|
||||||
|
|
||||||
### CRUD 操作
|
|
||||||
|
|
||||||
#### 插入資料
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 插入單筆資料
|
|
||||||
user := User{
|
|
||||||
ID: gocql.TimeUUID(),
|
|
||||||
Name: "Alice",
|
|
||||||
Email: "alice@example.com",
|
|
||||||
Age: 30,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
UpdatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := userRepo.Insert(ctx, user)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("插入失敗: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 批次插入
|
|
||||||
users := []User{
|
|
||||||
{ID: gocql.TimeUUID(), Name: "Bob", Email: "bob@example.com"},
|
|
||||||
{ID: gocql.TimeUUID(), Name: "Charlie", Email: "charlie@example.com"},
|
|
||||||
}
|
|
||||||
|
|
||||||
err = userRepo.InsertMany(ctx, users)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("批次插入失敗: %v", err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 查詢資料
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 根據主鍵查詢
|
|
||||||
userID := gocql.TimeUUID()
|
|
||||||
user, err := userRepo.Get(ctx, userID)
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
log.Println("用戶不存在")
|
|
||||||
} else {
|
|
||||||
log.Printf("查詢失敗: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("用戶: %+v\n", user)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 更新資料
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 更新資料(只更新非零值欄位)
|
|
||||||
user.Name = "Alice Updated"
|
|
||||||
user.Email = "alice.updated@example.com"
|
|
||||||
err = userRepo.Update(ctx, user)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("更新失敗: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新所有欄位(包括零值)
|
|
||||||
user.Age = 0 // 零值也會被更新
|
|
||||||
err = userRepo.UpdateAll(ctx, user)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("更新失敗: %v", err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 刪除資料
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 刪除資料
|
|
||||||
err = userRepo.Delete(ctx, userID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("刪除失敗: %v", err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 查詢構建器
|
|
||||||
|
|
||||||
#### 基本查詢
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 查詢所有符合條件的記錄
|
|
||||||
var users []User
|
|
||||||
err := userRepo.Query().
|
|
||||||
Where(cassandra.Eq("age", 30)).
|
|
||||||
OrderBy("created_at", cassandra.DESC).
|
|
||||||
Limit(10).
|
|
||||||
Scan(ctx, &users)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("查詢失敗: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢單筆記錄
|
|
||||||
user, err := userRepo.Query().
|
|
||||||
Where(cassandra.Eq("email", "alice@example.com")).
|
|
||||||
One(ctx)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
log.Println("用戶不存在")
|
|
||||||
} else {
|
|
||||||
log.Printf("查詢失敗: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 條件查詢
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 等於條件
|
|
||||||
userRepo.Query().Where(cassandra.Eq("name", "Alice"))
|
|
||||||
|
|
||||||
// IN 條件
|
|
||||||
userRepo.Query().Where(cassandra.In("id", []any{id1, id2, id3}))
|
|
||||||
|
|
||||||
// 大於條件
|
|
||||||
userRepo.Query().Where(cassandra.Gt("age", 18))
|
|
||||||
|
|
||||||
// 小於條件
|
|
||||||
userRepo.Query().Where(cassandra.Lt("age", 65))
|
|
||||||
|
|
||||||
// 組合多個條件
|
|
||||||
userRepo.Query().
|
|
||||||
Where(cassandra.Eq("status", "active")).
|
|
||||||
Where(cassandra.Gt("age", 18))
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 排序和限制
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 按建立時間降序排列,限制 20 筆
|
|
||||||
var users []User
|
|
||||||
err := userRepo.Query().
|
|
||||||
OrderBy("created_at", cassandra.DESC).
|
|
||||||
Limit(20).
|
|
||||||
Scan(ctx, &users)
|
|
||||||
|
|
||||||
// 多欄位排序
|
|
||||||
err = userRepo.Query().
|
|
||||||
OrderBy("status", cassandra.ASC).
|
|
||||||
OrderBy("created_at", cassandra.DESC).
|
|
||||||
Scan(ctx, &users)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 選擇特定欄位
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 只查詢特定欄位
|
|
||||||
var users []User
|
|
||||||
err := userRepo.Query().
|
|
||||||
Select("id", "name", "email").
|
|
||||||
Where(cassandra.Eq("status", "active")).
|
|
||||||
Scan(ctx, &users)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 計數查詢
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 計算符合條件的記錄數
|
|
||||||
count, err := userRepo.Query().
|
|
||||||
Where(cassandra.Eq("status", "active")).
|
|
||||||
Count(ctx)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("計數失敗: %v", err)
|
|
||||||
} else {
|
|
||||||
fmt.Printf("活躍用戶數: %d\n", count)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 分散式鎖
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 獲取鎖(預設 30 秒 TTL)
|
|
||||||
lockUser := User{ID: userID}
|
|
||||||
err := userRepo.TryLock(ctx, lockUser)
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsLockFailed(err) {
|
|
||||||
log.Println("獲取鎖失敗,資源已被鎖定")
|
|
||||||
} else {
|
|
||||||
log.Printf("鎖操作失敗: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 執行需要鎖定的操作
|
|
||||||
defer func() {
|
|
||||||
// 釋放鎖
|
|
||||||
if err := userRepo.UnLock(ctx, lockUser); err != nil {
|
|
||||||
log.Printf("釋放鎖失敗: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 執行業務邏輯...
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 自訂鎖 TTL
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 設定鎖的 TTL 為 60 秒
|
|
||||||
err := userRepo.TryLock(ctx, lockUser, cassandra.WithLockTTL(60*time.Second))
|
|
||||||
|
|
||||||
// 永不自動解鎖
|
|
||||||
err := userRepo.TryLock(ctx, lockUser, cassandra.WithNoLockExpire())
|
|
||||||
```
|
|
||||||
|
|
||||||
### 複雜主鍵
|
|
||||||
|
|
||||||
#### 複合主鍵(Partition Key + Clustering Key)
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 定義複合主鍵模型
|
|
||||||
type Order struct {
|
|
||||||
UserID gocql.UUID `db:"user_id" partition_key:"true"`
|
|
||||||
OrderID gocql.UUID `db:"order_id" clustering_key:"true"`
|
|
||||||
ProductID string `db:"product_id"`
|
|
||||||
Quantity int `db:"quantity"`
|
|
||||||
Price float64 `db:"price"`
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o Order) TableName() string {
|
|
||||||
return "orders"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢時需要提供完整的主鍵
|
|
||||||
order, err := orderRepo.Get(ctx, Order{
|
|
||||||
UserID: userID,
|
|
||||||
OrderID: orderID,
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 多欄位 Partition Key
|
|
||||||
|
|
||||||
```go
|
|
||||||
type Message struct {
|
|
||||||
ChatID gocql.UUID `db:"chat_id" partition_key:"true"`
|
|
||||||
MessageID gocql.UUID `db:"message_id" clustering_key:"true"`
|
|
||||||
UserID gocql.UUID `db:"user_id" partition_key:"true"`
|
|
||||||
Content string `db:"content"`
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m Message) TableName() string {
|
|
||||||
return "messages"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢時需要提供所有 Partition Key
|
|
||||||
message, err := messageRepo.Get(ctx, Message{
|
|
||||||
ChatID: chatID,
|
|
||||||
UserID: userID,
|
|
||||||
MessageID: messageID,
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
## 配置選項
|
|
||||||
|
|
||||||
### 連接選項
|
|
||||||
|
|
||||||
```go
|
|
||||||
db, err := cassandra.New(
|
|
||||||
// 主機列表
|
|
||||||
cassandra.WithHosts("127.0.0.1", "127.0.0.2", "127.0.0.3"),
|
|
||||||
|
|
||||||
// 連接埠
|
|
||||||
cassandra.WithPort(9042),
|
|
||||||
|
|
||||||
// Keyspace
|
|
||||||
cassandra.WithKeyspace("my_keyspace"),
|
|
||||||
|
|
||||||
// 認證
|
|
||||||
cassandra.WithAuth("username", "password"),
|
|
||||||
|
|
||||||
// 一致性級別
|
|
||||||
cassandra.WithConsistency(gocql.Quorum),
|
|
||||||
|
|
||||||
// 連接超時
|
|
||||||
cassandra.WithConnectTimeout(10 * time.Second),
|
|
||||||
|
|
||||||
// 每個節點的連接數
|
|
||||||
cassandra.WithNumConns(10),
|
|
||||||
|
|
||||||
// 重試次數
|
|
||||||
cassandra.WithMaxRetries(3),
|
|
||||||
|
|
||||||
// 重試間隔
|
|
||||||
cassandra.WithRetryInterval(100*time.Millisecond, 1*time.Second),
|
|
||||||
|
|
||||||
// 重連間隔
|
|
||||||
cassandra.WithReconnectInterval(1*time.Second, 10*time.Second),
|
|
||||||
|
|
||||||
// CQL 版本
|
|
||||||
cassandra.WithCQLVersion("3.0.0"),
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 錯誤處理
|
|
||||||
|
|
||||||
### 錯誤類型
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 檢查是否為特定錯誤
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
// 記錄不存在
|
|
||||||
}
|
|
||||||
|
|
||||||
if cassandra.IsConflict(err) {
|
|
||||||
// 衝突錯誤(如唯一鍵衝突)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cassandra.IsLockFailed(err) {
|
|
||||||
// 獲取鎖失敗
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 errors.As 獲取詳細錯誤資訊
|
|
||||||
var cassandraErr *cassandra.Error
|
|
||||||
if errors.As(err, &cassandraErr) {
|
|
||||||
fmt.Printf("錯誤代碼: %s\n", cassandraErr.Code)
|
|
||||||
fmt.Printf("錯誤訊息: %s\n", cassandraErr.Message)
|
|
||||||
fmt.Printf("資料表: %s\n", cassandraErr.Table)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 錯誤代碼
|
|
||||||
|
|
||||||
- `NOT_FOUND`: 記錄未找到
|
|
||||||
- `CONFLICT`: 衝突(如唯一鍵衝突、鎖獲取失敗)
|
|
||||||
- `INVALID_INPUT`: 輸入參數無效
|
|
||||||
- `MISSING_PARTITION_KEY`: 缺少 Partition Key
|
|
||||||
- `NO_FIELDS_TO_UPDATE`: 沒有欄位需要更新
|
|
||||||
- `MISSING_TABLE_NAME`: 缺少 TableName 方法
|
|
||||||
- `MISSING_WHERE_CONDITION`: 缺少 WHERE 條件
|
|
||||||
|
|
||||||
## 最佳實踐
|
|
||||||
|
|
||||||
### 1. 使用 Context
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 所有操作都應該傳入 context,以便支援超時和取消
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
user, err := userRepo.Get(ctx, userID)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 錯誤處理
|
|
||||||
|
|
||||||
```go
|
|
||||||
user, err := userRepo.Get(ctx, userID)
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
// 處理不存在的情況
|
|
||||||
return nil, ErrUserNotFound
|
|
||||||
}
|
|
||||||
// 處理其他錯誤
|
|
||||||
return nil, fmt.Errorf("查詢用戶失敗: %w", err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 批次操作
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 對於大量資料,使用批次插入
|
|
||||||
const batchSize = 100
|
|
||||||
for i := 0; i < len(users); i += batchSize {
|
|
||||||
end := i + batchSize
|
|
||||||
if end > len(users) {
|
|
||||||
end = len(users)
|
|
||||||
}
|
|
||||||
|
|
||||||
err := userRepo.InsertMany(ctx, users[i:end])
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("批次插入失敗 (索引 %d-%d): %v", i, end, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. 使用分散式鎖
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 在需要保證原子性的操作中使用鎖
|
|
||||||
err := userRepo.TryLock(ctx, lockUser, cassandra.WithLockTTL(30*time.Second))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("獲取鎖失敗: %w", err)
|
|
||||||
}
|
|
||||||
defer userRepo.UnLock(ctx, lockUser)
|
|
||||||
|
|
||||||
// 執行需要原子性的操作
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. 查詢優化
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 只選擇需要的欄位
|
|
||||||
var users []User
|
|
||||||
err := userRepo.Query().
|
|
||||||
Select("id", "name", "email"). // 只選擇需要的欄位
|
|
||||||
Where(cassandra.Eq("status", "active")).
|
|
||||||
Scan(ctx, &users)
|
|
||||||
|
|
||||||
// 使用適當的限制
|
|
||||||
err = userRepo.Query().
|
|
||||||
Where(cassandra.Eq("status", "active")).
|
|
||||||
Limit(100). // 限制結果數量
|
|
||||||
Scan(ctx, &users)
|
|
||||||
```
|
|
||||||
|
|
||||||
## SAI 索引管理
|
|
||||||
|
|
||||||
### 建立 SAI 索引
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 檢查是否支援 SAI
|
|
||||||
if !db.SaiSupported() {
|
|
||||||
log.Fatal("SAI is not supported in this Cassandra version")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立標準索引
|
|
||||||
err := db.CreateSAIIndex(ctx, "my_keyspace", "users", "email", "users_email_idx", nil)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("建立索引失敗: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立全文索引(不區分大小寫)
|
|
||||||
opts := &cassandra.SAIIndexOptions{
|
|
||||||
IndexType: cassandra.SAIIndexTypeFullText,
|
|
||||||
IsAsync: false,
|
|
||||||
CaseSensitive: false,
|
|
||||||
}
|
|
||||||
err = db.CreateSAIIndex(ctx, "my_keyspace", "posts", "content", "posts_content_ft_idx", opts)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 查詢 SAI 索引
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 列出資料表的所有 SAI 索引
|
|
||||||
indexes, err := db.ListSAIIndexes(ctx, "my_keyspace", "users")
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("查詢索引失敗: %v", err)
|
|
||||||
} else {
|
|
||||||
for _, idx := range indexes {
|
|
||||||
fmt.Printf("索引: %s, 欄位: %s, 類型: %s\n", idx.Name, idx.Column, idx.Type)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查索引是否存在
|
|
||||||
exists, err := db.CheckSAIIndexExists(ctx, "my_keyspace", "users_email_idx")
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("檢查索引失敗: %v", err)
|
|
||||||
} else if exists {
|
|
||||||
fmt.Println("索引存在")
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 刪除 SAI 索引
|
|
||||||
|
|
||||||
```go
|
|
||||||
// 刪除索引
|
|
||||||
err := db.DropSAIIndex(ctx, "my_keyspace", "users_email_idx")
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("刪除索引失敗: %v", err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### SAI 索引類型
|
|
||||||
|
|
||||||
- **SAIIndexTypeStandard**: 標準索引(等於查詢)
|
|
||||||
- **SAIIndexTypeCollection**: 集合索引(用於 list、set、map)
|
|
||||||
- **SAIIndexTypeFullText**: 全文索引
|
|
||||||
|
|
||||||
### SAI 索引選項
|
|
||||||
|
|
||||||
```go
|
|
||||||
opts := &cassandra.SAIIndexOptions{
|
|
||||||
IndexType: cassandra.SAIIndexTypeFullText, // 索引類型
|
|
||||||
IsAsync: false, // 是否異步建立
|
|
||||||
CaseSensitive: true, // 是否區分大小寫
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 注意事項
|
|
||||||
|
|
||||||
### 1. 主鍵要求
|
|
||||||
|
|
||||||
- `Get` 和 `Delete` 操作必須提供完整的主鍵(所有 Partition Key 和 Clustering Key)
|
|
||||||
- 單一主鍵值只適用於單一 Partition Key 且無 Clustering Key 的情況
|
|
||||||
|
|
||||||
### 2. 更新操作
|
|
||||||
|
|
||||||
- `Update` 只更新非零值欄位
|
|
||||||
- `UpdateAll` 更新所有欄位(包括零值)
|
|
||||||
- 更新操作必須包含主鍵欄位
|
|
||||||
|
|
||||||
### 3. 查詢限制
|
|
||||||
|
|
||||||
- Cassandra 的查詢必須包含所有 Partition Key
|
|
||||||
- 排序只能按 Clustering Key 進行
|
|
||||||
- 不支援 JOIN 操作
|
|
||||||
|
|
||||||
### 4. 分散式鎖
|
|
||||||
|
|
||||||
- 鎖使用 IF NOT EXISTS 實現,預設 30 秒 TTL
|
|
||||||
- 獲取鎖失敗時會返回 `CONFLICT` 錯誤
|
|
||||||
- 釋放鎖時會自動重試,最多 3 次
|
|
||||||
|
|
||||||
### 5. 批次操作
|
|
||||||
|
|
||||||
- 批次操作有大小限制(建議不超過 1000 筆)
|
|
||||||
- 批次操作中的所有操作必須屬於同一個 Partition Key
|
|
||||||
|
|
||||||
### 6. SAI 索引
|
|
||||||
|
|
||||||
- SAI 索引需要 Cassandra 4.0.9+ 版本(建議 5.0+)
|
|
||||||
- 建立索引前請先檢查 `db.SaiSupported()`
|
|
||||||
- 索引建立是異步操作,可能需要一些時間
|
|
||||||
- 刪除索引時使用 `IF EXISTS`,避免索引不存在時報錯
|
|
||||||
- 使用 SAI 索引可以大幅提升非主鍵欄位的查詢效能
|
|
||||||
- 全文索引支援不區分大小寫的搜尋
|
|
||||||
|
|
||||||
## 完整範例
|
|
||||||
|
|
||||||
```go
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID gocql.UUID `db:"id" partition_key:"true"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Email string `db:"email"`
|
|
||||||
Age int `db:"age"`
|
|
||||||
Status string `db:"status"`
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
|
||||||
UpdatedAt time.Time `db:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u User) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
// 初始化資料庫連接
|
|
||||||
db, err := cassandra.New(
|
|
||||||
cassandra.WithHosts("127.0.0.1"),
|
|
||||||
cassandra.WithPort(9042),
|
|
||||||
cassandra.WithKeyspace("my_keyspace"),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// 創建 Repository
|
|
||||||
userRepo, err := cassandra.NewRepository[User](db, "my_keyspace")
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// 插入用戶
|
|
||||||
user := User{
|
|
||||||
ID: gocql.TimeUUID(),
|
|
||||||
Name: "Alice",
|
|
||||||
Email: "alice@example.com",
|
|
||||||
Age: 30,
|
|
||||||
Status: "active",
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
UpdatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := userRepo.Insert(ctx, user); err != nil {
|
|
||||||
log.Printf("插入失敗: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢用戶
|
|
||||||
foundUser, err := userRepo.Get(ctx, user.ID)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("查詢失敗: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Printf("查詢到的用戶: %+v\n", foundUser)
|
|
||||||
|
|
||||||
// 更新用戶
|
|
||||||
user.Name = "Alice Updated"
|
|
||||||
user.Email = "alice.updated@example.com"
|
|
||||||
if err := userRepo.Update(ctx, user); err != nil {
|
|
||||||
log.Printf("更新失敗: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢活躍用戶
|
|
||||||
var activeUsers []User
|
|
||||||
if err := userRepo.Query().
|
|
||||||
Where(cassandra.Eq("status", "active")).
|
|
||||||
OrderBy("created_at", cassandra.DESC).
|
|
||||||
Limit(10).
|
|
||||||
Scan(ctx, &activeUsers); err != nil {
|
|
||||||
log.Printf("查詢失敗: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Printf("活躍用戶數: %d\n", len(activeUsers))
|
|
||||||
|
|
||||||
// 使用分散式鎖
|
|
||||||
if err := userRepo.TryLock(ctx, user, cassandra.WithLockTTL(30*time.Second)); err != nil {
|
|
||||||
if cassandra.IsLockFailed(err) {
|
|
||||||
log.Println("獲取鎖失敗")
|
|
||||||
} else {
|
|
||||||
log.Printf("鎖操作失敗: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer userRepo.UnLock(ctx, user)
|
|
||||||
|
|
||||||
// 執行需要鎖定的操作
|
|
||||||
fmt.Println("執行需要鎖定的操作...")
|
|
||||||
|
|
||||||
// 刪除用戶
|
|
||||||
if err := userRepo.Delete(ctx, user.ID); err != nil {
|
|
||||||
log.Printf("刪除失敗: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Println("操作完成")
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 測試
|
|
||||||
|
|
||||||
套件包含完整的測試覆蓋,包括:
|
|
||||||
|
|
||||||
- 單元測試(table-driven tests)
|
|
||||||
- 集成測試(使用 testcontainers)
|
|
||||||
|
|
||||||
運行測試:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go test ./pkg/library/cassandra/...
|
|
||||||
```
|
|
||||||
|
|
||||||
查看測試覆蓋率:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go test ./pkg/library/cassandra/... -cover
|
|
||||||
```
|
|
||||||
|
|
||||||
## 授權
|
|
||||||
|
|
||||||
本專案遵循專案的主要授權協議。
|
|
||||||
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// 預設設定常數
|
|
||||||
const (
|
|
||||||
defaultNumConns = 10 // 預設每個節點的連線數量
|
|
||||||
defaultTimeoutSec = 10 // 預設連線逾時秒數
|
|
||||||
defaultMaxRetries = 3 // 預設重試次數
|
|
||||||
defaultPort = 9042
|
|
||||||
defaultConsistency = gocql.Quorum
|
|
||||||
defaultRetryMinInterval = 1 * time.Second
|
|
||||||
defaultRetryMaxInterval = 30 * time.Second
|
|
||||||
defaultReconnectInitialInterval = 1 * time.Second
|
|
||||||
defaultReconnectMaxInterval = 60 * time.Second
|
|
||||||
defaultCqlVersion = "3.0.0"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
DBFiledName = "db"
|
|
||||||
Pk = "partition_key"
|
|
||||||
ClusterKey = "clustering_key"
|
|
||||||
)
|
|
||||||
|
|
@ -1,158 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
"github.com/scylladb/gocqlx/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
// DB 是 Cassandra 的核心資料庫連接
|
|
||||||
type DB struct {
|
|
||||||
session gocqlx.Session
|
|
||||||
defaultKeyspace string
|
|
||||||
version string
|
|
||||||
saiSupported bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// New 創建新的 DB 實例
|
|
||||||
func New(opts ...Option) (*DB, error) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(cfg)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cfg.Hosts) == 0 {
|
|
||||||
return nil, fmt.Errorf("at least one host is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立連線設定
|
|
||||||
cluster := gocql.NewCluster(cfg.Hosts...)
|
|
||||||
cluster.Port = cfg.Port
|
|
||||||
cluster.Consistency = cfg.Consistency
|
|
||||||
cluster.Timeout = time.Duration(cfg.ConnectTimeoutSec) * time.Second
|
|
||||||
cluster.NumConns = cfg.NumConns
|
|
||||||
cluster.RetryPolicy = &gocql.ExponentialBackoffRetryPolicy{
|
|
||||||
NumRetries: cfg.MaxRetries,
|
|
||||||
Min: cfg.RetryMinInterval,
|
|
||||||
Max: cfg.RetryMaxInterval,
|
|
||||||
}
|
|
||||||
|
|
||||||
cluster.ReconnectionPolicy = &gocql.ExponentialReconnectionPolicy{
|
|
||||||
MaxRetries: cfg.MaxRetries,
|
|
||||||
InitialInterval: cfg.ReconnectInitialInterval,
|
|
||||||
MaxInterval: cfg.ReconnectMaxInterval,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 若有提供 Keyspace 則指定
|
|
||||||
if cfg.Keyspace != "" {
|
|
||||||
cluster.Keyspace = cfg.Keyspace
|
|
||||||
}
|
|
||||||
|
|
||||||
// 若啟用驗證則設定帳號密碼
|
|
||||||
if cfg.UseAuth {
|
|
||||||
cluster.Authenticator = gocql.PasswordAuthenticator{
|
|
||||||
Username: cfg.Username,
|
|
||||||
Password: cfg.Password,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立 Session
|
|
||||||
session, err := gocqlx.WrapSession(cluster.CreateSession())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to connect to Cassandra cluster (hosts: %v, port: %d): %w", cfg.Hosts, cfg.Port, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
db := &DB{
|
|
||||||
session: session,
|
|
||||||
defaultKeyspace: cfg.Keyspace,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 初始化版本資訊
|
|
||||||
version, err := db.getVersion(context.Background())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get DB version: %w", err)
|
|
||||||
}
|
|
||||||
db.version = version
|
|
||||||
db.saiSupported = isSAISupported(version)
|
|
||||||
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close 關閉資料庫連線
|
|
||||||
func (db *DB) Close() {
|
|
||||||
db.session.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSession 返回底層的 gocqlx Session(用於進階操作)
|
|
||||||
func (db *DB) GetSession() gocqlx.Session {
|
|
||||||
return db.session
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDefaultKeyspace 返回預設的 keyspace
|
|
||||||
func (db *DB) GetDefaultKeyspace() string {
|
|
||||||
return db.defaultKeyspace
|
|
||||||
}
|
|
||||||
|
|
||||||
// Version 返回資料庫版本
|
|
||||||
func (db *DB) Version() string {
|
|
||||||
return db.version
|
|
||||||
}
|
|
||||||
|
|
||||||
// SaiSupported 返回是否支援 SAI
|
|
||||||
func (db *DB) SaiSupported() bool {
|
|
||||||
return db.saiSupported
|
|
||||||
}
|
|
||||||
|
|
||||||
// getVersion 獲取資料庫版本
|
|
||||||
func (db *DB) getVersion(ctx context.Context) (string, error) {
|
|
||||||
var version string
|
|
||||||
stmt := "SELECT release_version FROM system.local"
|
|
||||||
err := db.session.Query(stmt, []string{"release_version"}).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.One).
|
|
||||||
Scan(&version)
|
|
||||||
return version, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// isSAISupported 檢查版本是否支援 SAI
|
|
||||||
func isSAISupported(version string) bool {
|
|
||||||
// 只要 major >=5 就支援
|
|
||||||
// 4.0.9+ 才有 SAI,但不穩,強烈建議 5.0+
|
|
||||||
parts := strings.Split(version, ".")
|
|
||||||
if len(parts) < 2 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
major, _ := strconv.Atoi(parts[0])
|
|
||||||
minor, _ := strconv.Atoi(parts[1])
|
|
||||||
|
|
||||||
if major >= 5 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
if major == 4 {
|
|
||||||
if minor > 0 { // 4.1.x、4.2.x 直接支援
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if minor == 0 {
|
|
||||||
patch := 0
|
|
||||||
if len(parts) >= 3 {
|
|
||||||
patch, _ = strconv.Atoi(parts[2])
|
|
||||||
}
|
|
||||||
if patch >= 9 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// withContextAndTimestamp 為查詢添加 context 和時間戳
|
|
||||||
func (db *DB) withContextAndTimestamp(ctx context.Context, q *gocqlx.Queryx) *gocqlx.Queryx {
|
|
||||||
return q.WithContext(ctx).WithTimestamp(time.Now().UnixNano() / 1e3)
|
|
||||||
}
|
|
||||||
|
|
@ -1,545 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIsSAISupported(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "version 5.0.0 should support SAI",
|
|
||||||
version: "5.0.0",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 5.1.0 should support SAI",
|
|
||||||
version: "5.1.0",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 6.0.0 should support SAI",
|
|
||||||
version: "6.0.0",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.1.0 should support SAI",
|
|
||||||
version: "4.1.0",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.2.0 should support SAI",
|
|
||||||
version: "4.2.0",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.9 should support SAI",
|
|
||||||
version: "4.0.9",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.10 should support SAI",
|
|
||||||
version: "4.0.10",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.8 should not support SAI",
|
|
||||||
version: "4.0.8",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.0 should not support SAI",
|
|
||||||
version: "4.0.0",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 3.11.0 should not support SAI",
|
|
||||||
version: "3.11.0",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid version format should not support SAI",
|
|
||||||
version: "invalid",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty version should not support SAI",
|
|
||||||
version: "",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version with only major should not support SAI",
|
|
||||||
version: "5",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.9 with extra parts should support SAI",
|
|
||||||
version: "4.0.9.1",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := isSAISupported(tt.version)
|
|
||||||
assert.Equal(t, tt.expected, result, "version %s should have SAI support = %v", tt.version, tt.expected)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNew_Validation(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
opts []Option
|
|
||||||
wantErr bool
|
|
||||||
errMsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no hosts should return error",
|
|
||||||
opts: []Option{},
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "at least one host is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty hosts should return error",
|
|
||||||
opts: []Option{WithHosts()},
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "at least one host is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid hosts should not return error on validation",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple hosts should not return error on validation",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost", "127.0.0.1"),
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with keyspace should not return error on validation",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
WithKeyspace("test_keyspace"),
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with port should not return error on validation",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
WithPort(9042),
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with auth should not return error on validation",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
WithAuth("user", "pass"),
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with all options should not return error on validation",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
WithKeyspace("test_keyspace"),
|
|
||||||
WithPort(9042),
|
|
||||||
WithAuth("user", "pass"),
|
|
||||||
WithConsistency(gocql.Quorum),
|
|
||||||
WithConnectTimeoutSec(10),
|
|
||||||
WithNumConns(10),
|
|
||||||
WithMaxRetries(3),
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
db, err := New(tt.opts...)
|
|
||||||
|
|
||||||
if tt.wantErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
if tt.errMsg != "" {
|
|
||||||
assert.Contains(t, err.Error(), tt.errMsg)
|
|
||||||
}
|
|
||||||
assert.Nil(t, db)
|
|
||||||
} else {
|
|
||||||
// 注意:這裡可能會因為無法連接到真實的 Cassandra 而失敗
|
|
||||||
// 但至少驗證了配置驗證邏輯
|
|
||||||
if err != nil {
|
|
||||||
// 如果錯誤不是驗證錯誤,而是連接錯誤,這是可以接受的
|
|
||||||
assert.NotContains(t, err.Error(), "at least one host is required")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDB_GetDefaultKeyspace(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
keyspace string
|
|
||||||
expectedResult string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty keyspace should return empty string",
|
|
||||||
keyspace: "",
|
|
||||||
expectedResult: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-empty keyspace should return keyspace",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
expectedResult: "test_keyspace",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,可能需要 mock 或使用 testcontainers
|
|
||||||
// 這裡只是展示測試結構
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDB_Version(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "version 5.0.0",
|
|
||||||
version: "5.0.0",
|
|
||||||
expected: "5.0.0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.9",
|
|
||||||
version: "4.0.9",
|
|
||||||
expected: "4.0.9",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 3.11.0",
|
|
||||||
version: "3.11.0",
|
|
||||||
expected: "3.11.0",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,可能需要 mock 或使用 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDB_SaiSupported(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "version 5.0.0 should support SAI",
|
|
||||||
version: "5.0.0",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.9 should support SAI",
|
|
||||||
version: "4.0.9",
|
|
||||||
expected: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.8 should not support SAI",
|
|
||||||
version: "4.0.8",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 3.11.0 should not support SAI",
|
|
||||||
version: "3.11.0",
|
|
||||||
expected: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,可能需要 mock 或使用 testcontainers
|
|
||||||
// 這裡只是展示測試結構
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDB_GetSession(t *testing.T) {
|
|
||||||
t.Run("GetSession should return non-nil session", func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,可能需要 mock 或使用 testcontainers
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDB_Close(t *testing.T) {
|
|
||||||
t.Run("Close should not panic", func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,可能需要 mock 或使用 testcontainers
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDB_getVersion(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
queryErr error
|
|
||||||
wantErr bool
|
|
||||||
expectedVer string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "successful version query",
|
|
||||||
version: "5.0.0",
|
|
||||||
queryErr: nil,
|
|
||||||
wantErr: false,
|
|
||||||
expectedVer: "5.0.0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "query error should return error",
|
|
||||||
version: "",
|
|
||||||
queryErr: errors.New("connection failed"),
|
|
||||||
wantErr: true,
|
|
||||||
expectedVer: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDB_withContextAndTimestamp(t *testing.T) {
|
|
||||||
t.Run("withContextAndTimestamp should add context and timestamp", func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock query
|
|
||||||
// 在實際測試中,需要使用 mock
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDefaultConfig(t *testing.T) {
|
|
||||||
t.Run("defaultConfig should return valid config", func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
require.NotNil(t, cfg)
|
|
||||||
assert.Equal(t, defaultPort, cfg.Port)
|
|
||||||
assert.Equal(t, defaultConsistency, cfg.Consistency)
|
|
||||||
assert.Equal(t, defaultTimeoutSec, cfg.ConnectTimeoutSec)
|
|
||||||
assert.Equal(t, defaultNumConns, cfg.NumConns)
|
|
||||||
assert.Equal(t, defaultMaxRetries, cfg.MaxRetries)
|
|
||||||
assert.Equal(t, defaultRetryMinInterval, cfg.RetryMinInterval)
|
|
||||||
assert.Equal(t, defaultRetryMaxInterval, cfg.RetryMaxInterval)
|
|
||||||
assert.Equal(t, defaultReconnectInitialInterval, cfg.ReconnectInitialInterval)
|
|
||||||
assert.Equal(t, defaultReconnectMaxInterval, cfg.ReconnectMaxInterval)
|
|
||||||
assert.Equal(t, defaultCqlVersion, cfg.CQLVersion)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOptionFunctions(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
opt Option
|
|
||||||
validateConfig func(*testing.T, *config)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "WithHosts should set hosts",
|
|
||||||
opt: WithHosts("host1", "host2"),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, []string{"host1", "host2"}, c.Hosts)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithPort should set port",
|
|
||||||
opt: WithPort(9999),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, 9999, c.Port)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithKeyspace should set keyspace",
|
|
||||||
opt: WithKeyspace("test_keyspace"),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, "test_keyspace", c.Keyspace)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithAuth should set auth and enable UseAuth",
|
|
||||||
opt: WithAuth("user", "pass"),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, "user", c.Username)
|
|
||||||
assert.Equal(t, "pass", c.Password)
|
|
||||||
assert.True(t, c.UseAuth)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithConsistency should set consistency",
|
|
||||||
opt: WithConsistency(gocql.One),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, gocql.One, c.Consistency)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithConnectTimeoutSec should set timeout",
|
|
||||||
opt: WithConnectTimeoutSec(20),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, 20, c.ConnectTimeoutSec)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithConnectTimeoutSec with zero should use default",
|
|
||||||
opt: WithConnectTimeoutSec(0),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, defaultTimeoutSec, c.ConnectTimeoutSec)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithNumConns should set numConns",
|
|
||||||
opt: WithNumConns(20),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, 20, c.NumConns)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithNumConns with zero should use default",
|
|
||||||
opt: WithNumConns(0),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, defaultNumConns, c.NumConns)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithMaxRetries should set maxRetries",
|
|
||||||
opt: WithMaxRetries(5),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, 5, c.MaxRetries)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithMaxRetries with zero should use default",
|
|
||||||
opt: WithMaxRetries(0),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, defaultMaxRetries, c.MaxRetries)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithRetryMinInterval should set retryMinInterval",
|
|
||||||
opt: WithRetryMinInterval(2 * time.Second),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, 2*time.Second, c.RetryMinInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithRetryMinInterval with zero should use default",
|
|
||||||
opt: WithRetryMinInterval(0),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, defaultRetryMinInterval, c.RetryMinInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithRetryMaxInterval should set retryMaxInterval",
|
|
||||||
opt: WithRetryMaxInterval(60 * time.Second),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, 60*time.Second, c.RetryMaxInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithRetryMaxInterval with zero should use default",
|
|
||||||
opt: WithRetryMaxInterval(0),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, defaultRetryMaxInterval, c.RetryMaxInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithReconnectInitialInterval should set reconnectInitialInterval",
|
|
||||||
opt: WithReconnectInitialInterval(2 * time.Second),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, 2*time.Second, c.ReconnectInitialInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithReconnectInitialInterval with zero should use default",
|
|
||||||
opt: WithReconnectInitialInterval(0),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, defaultReconnectInitialInterval, c.ReconnectInitialInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithReconnectMaxInterval should set reconnectMaxInterval",
|
|
||||||
opt: WithReconnectMaxInterval(120 * time.Second),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, 120*time.Second, c.ReconnectMaxInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithReconnectMaxInterval with zero should use default",
|
|
||||||
opt: WithReconnectMaxInterval(0),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, defaultReconnectMaxInterval, c.ReconnectMaxInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithCQLVersion should set CQLVersion",
|
|
||||||
opt: WithCQLVersion("3.1.0"),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, "3.1.0", c.CQLVersion)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithCQLVersion with empty should use default",
|
|
||||||
opt: WithCQLVersion(""),
|
|
||||||
validateConfig: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, defaultCqlVersion, c.CQLVersion)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
tt.opt(cfg)
|
|
||||||
tt.validateConfig(t, cfg)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMultipleOptions(t *testing.T) {
|
|
||||||
t.Run("multiple options should be applied correctly", func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
WithHosts("host1", "host2")(cfg)
|
|
||||||
WithPort(9999)(cfg)
|
|
||||||
WithKeyspace("test")(cfg)
|
|
||||||
WithAuth("user", "pass")(cfg)
|
|
||||||
|
|
||||||
assert.Equal(t, []string{"host1", "host2"}, cfg.Hosts)
|
|
||||||
assert.Equal(t, 9999, cfg.Port)
|
|
||||||
assert.Equal(t, "test", cfg.Keyspace)
|
|
||||||
assert.Equal(t, "user", cfg.Username)
|
|
||||||
assert.Equal(t, "pass", cfg.Password)
|
|
||||||
assert.True(t, cfg.UseAuth)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,151 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ErrorCode 定義錯誤代碼
|
|
||||||
type ErrorCode string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ErrCodeNotFound 表示記錄未找到
|
|
||||||
ErrCodeNotFound ErrorCode = "NOT_FOUND"
|
|
||||||
// ErrCodeConflict 表示衝突(如唯一鍵衝突)
|
|
||||||
ErrCodeConflict ErrorCode = "CONFLICT"
|
|
||||||
// ErrCodeInvalidInput 表示輸入參數無效
|
|
||||||
ErrCodeInvalidInput ErrorCode = "INVALID_INPUT"
|
|
||||||
// ErrCodeMissingPartition 表示缺少 Partition Key
|
|
||||||
ErrCodeMissingPartition ErrorCode = "MISSING_PARTITION_KEY"
|
|
||||||
// ErrCodeNoFieldsToUpdate 表示沒有欄位需要更新
|
|
||||||
ErrCodeNoFieldsToUpdate ErrorCode = "NO_FIELDS_TO_UPDATE"
|
|
||||||
// ErrCodeMissingTableName 表示缺少 TableName 方法
|
|
||||||
ErrCodeMissingTableName ErrorCode = "MISSING_TABLE_NAME"
|
|
||||||
// ErrCodeMissingWhereCondition 表示缺少 WHERE 條件
|
|
||||||
ErrCodeMissingWhereCondition ErrorCode = "MISSING_WHERE_CONDITION"
|
|
||||||
// ErrCodeSAINotSupported 表示不支援 SAI
|
|
||||||
ErrCodeSAINotSupported ErrorCode = "SAI_NOT_SUPPORTED"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Error 是統一的錯誤類型
|
|
||||||
type Error struct {
|
|
||||||
Code ErrorCode
|
|
||||||
Message string
|
|
||||||
Table string
|
|
||||||
Err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error 實現 error 介面
|
|
||||||
func (e *Error) Error() string {
|
|
||||||
if e.Table != "" {
|
|
||||||
if e.Err != nil {
|
|
||||||
return fmt.Sprintf("cassandra[%s] (table: %s): %s: %v", e.Code, e.Table, e.Message, e.Err)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("cassandra[%s] (table: %s): %s", e.Code, e.Table, e.Message)
|
|
||||||
}
|
|
||||||
if e.Err != nil {
|
|
||||||
return fmt.Sprintf("cassandra[%s]: %s: %v", e.Code, e.Message, e.Err)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("cassandra[%s]: %s", e.Code, e.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unwrap 返回底層錯誤
|
|
||||||
func (e *Error) Unwrap() error {
|
|
||||||
return e.Err
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTable 為錯誤添加表名資訊
|
|
||||||
func (e *Error) WithTable(table string) *Error {
|
|
||||||
return &Error{
|
|
||||||
Code: e.Code,
|
|
||||||
Message: e.Message,
|
|
||||||
Table: table,
|
|
||||||
Err: e.Err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithError 為錯誤添加底層錯誤
|
|
||||||
func (e *Error) WithError(err error) *Error {
|
|
||||||
return &Error{
|
|
||||||
Code: e.Code,
|
|
||||||
Message: e.Message,
|
|
||||||
Table: e.Table,
|
|
||||||
Err: err,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewError 創建新的錯誤
|
|
||||||
func NewError(code ErrorCode, message string) *Error {
|
|
||||||
return &Error{
|
|
||||||
Code: code,
|
|
||||||
Message: message,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 預定義錯誤
|
|
||||||
var (
|
|
||||||
// ErrNotFound 表示記錄未找到
|
|
||||||
ErrNotFound = &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrInvalidInput 表示輸入參數無效
|
|
||||||
ErrInvalidInput = &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid input parameter",
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrNoPartitionKey 表示缺少 Partition Key
|
|
||||||
ErrNoPartitionKey = &Error{
|
|
||||||
Code: ErrCodeMissingPartition,
|
|
||||||
Message: "no partition key defined in struct",
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrMissingTableName 表示缺少 TableName 方法
|
|
||||||
ErrMissingTableName = &Error{
|
|
||||||
Code: ErrCodeMissingTableName,
|
|
||||||
Message: "struct must implement TableName() method",
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrNoFieldsToUpdate 表示沒有欄位需要更新
|
|
||||||
ErrNoFieldsToUpdate = &Error{
|
|
||||||
Code: ErrCodeNoFieldsToUpdate,
|
|
||||||
Message: "no fields to update",
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrMissingWhereCondition 表示缺少 WHERE 條件
|
|
||||||
ErrMissingWhereCondition = &Error{
|
|
||||||
Code: ErrCodeMissingWhereCondition,
|
|
||||||
Message: "operation requires at least one WHERE condition for safety",
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrMissingPartitionKey 表示 WHERE 條件中缺少 Partition Key
|
|
||||||
ErrMissingPartitionKey = &Error{
|
|
||||||
Code: ErrCodeMissingPartition,
|
|
||||||
Message: "operation requires all partition keys in WHERE clause",
|
|
||||||
}
|
|
||||||
// ErrSAINotSupported 表示不支援 SAI
|
|
||||||
ErrSAINotSupported = &Error{
|
|
||||||
Code: ErrCodeSAINotSupported,
|
|
||||||
Message: "SAI (Storage-Attached Indexing) is not supported in this Cassandra version",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsNotFound 檢查錯誤是否為 NotFound
|
|
||||||
func IsNotFound(err error) bool {
|
|
||||||
var e *Error
|
|
||||||
if errors.As(err, &e) {
|
|
||||||
return e.Code == ErrCodeNotFound
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsConflict 檢查錯誤是否為 Conflict
|
|
||||||
func IsConflict(err error) bool {
|
|
||||||
var e *Error
|
|
||||||
if errors.As(err, &e) {
|
|
||||||
return e.Code == ErrCodeConflict
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
@ -1,590 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestError_Error(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err *Error
|
|
||||||
want string
|
|
||||||
contains []string // 如果 want 為空,則檢查是否包含這些字串
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "error with code and message only",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
},
|
|
||||||
want: "cassandra[NOT_FOUND]: record not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error with code, message and table",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
Table: "users",
|
|
||||||
},
|
|
||||||
want: "cassandra[NOT_FOUND] (table: users): record not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error with code, message and underlying error",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid input parameter",
|
|
||||||
Err: errors.New("validation failed"),
|
|
||||||
},
|
|
||||||
contains: []string{
|
|
||||||
"cassandra[INVALID_INPUT]",
|
|
||||||
"invalid input parameter",
|
|
||||||
"validation failed",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error with all fields",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeConflict,
|
|
||||||
Message: "acquire lock failed",
|
|
||||||
Table: "locks",
|
|
||||||
Err: errors.New("lock already exists"),
|
|
||||||
},
|
|
||||||
contains: []string{
|
|
||||||
"cassandra[CONFLICT]",
|
|
||||||
"(table: locks)",
|
|
||||||
"acquire lock failed",
|
|
||||||
"lock already exists",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error with empty message",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
},
|
|
||||||
want: "cassandra[NOT_FOUND]: ",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.err.Error()
|
|
||||||
if tt.want != "" {
|
|
||||||
assert.Equal(t, tt.want, result)
|
|
||||||
} else {
|
|
||||||
for _, substr := range tt.contains {
|
|
||||||
assert.Contains(t, result, substr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestError_Unwrap(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err *Error
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "error with underlying error",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid input",
|
|
||||||
Err: errors.New("underlying error"),
|
|
||||||
},
|
|
||||||
wantErr: errors.New("underlying error"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error without underlying error",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "not found",
|
|
||||||
},
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error with nil underlying error",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "not found",
|
|
||||||
Err: nil,
|
|
||||||
},
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.err.Unwrap()
|
|
||||||
if tt.wantErr == nil {
|
|
||||||
assert.Nil(t, result)
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, tt.wantErr.Error(), result.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestError_WithTable(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err *Error
|
|
||||||
table string
|
|
||||||
wantCode ErrorCode
|
|
||||||
wantMsg string
|
|
||||||
wantTbl string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "add table to error without table",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
},
|
|
||||||
table: "users",
|
|
||||||
wantCode: ErrCodeNotFound,
|
|
||||||
wantMsg: "record not found",
|
|
||||||
wantTbl: "users",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "replace existing table",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
Table: "old_table",
|
|
||||||
},
|
|
||||||
table: "new_table",
|
|
||||||
wantCode: ErrCodeNotFound,
|
|
||||||
wantMsg: "record not found",
|
|
||||||
wantTbl: "new_table",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "add table to error with underlying error",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid input",
|
|
||||||
Err: errors.New("validation failed"),
|
|
||||||
},
|
|
||||||
table: "products",
|
|
||||||
wantCode: ErrCodeInvalidInput,
|
|
||||||
wantMsg: "invalid input",
|
|
||||||
wantTbl: "products",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "add empty table",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "not found",
|
|
||||||
},
|
|
||||||
table: "",
|
|
||||||
wantCode: ErrCodeNotFound,
|
|
||||||
wantMsg: "not found",
|
|
||||||
wantTbl: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.err.WithTable(tt.table)
|
|
||||||
assert.NotNil(t, result)
|
|
||||||
assert.Equal(t, tt.wantCode, result.Code)
|
|
||||||
assert.Equal(t, tt.wantMsg, result.Message)
|
|
||||||
assert.Equal(t, tt.wantTbl, result.Table)
|
|
||||||
// 確保是新的實例,不是修改原來的
|
|
||||||
assert.NotSame(t, tt.err, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestError_WithError(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err *Error
|
|
||||||
underlying error
|
|
||||||
wantCode ErrorCode
|
|
||||||
wantMsg string
|
|
||||||
wantErr error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "add underlying error to error without error",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid input",
|
|
||||||
},
|
|
||||||
underlying: errors.New("validation failed"),
|
|
||||||
wantCode: ErrCodeInvalidInput,
|
|
||||||
wantMsg: "invalid input",
|
|
||||||
wantErr: errors.New("validation failed"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "replace existing underlying error",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid input",
|
|
||||||
Err: errors.New("old error"),
|
|
||||||
},
|
|
||||||
underlying: errors.New("new error"),
|
|
||||||
wantCode: ErrCodeInvalidInput,
|
|
||||||
wantMsg: "invalid input",
|
|
||||||
wantErr: errors.New("new error"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "add nil underlying error",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "not found",
|
|
||||||
},
|
|
||||||
underlying: nil,
|
|
||||||
wantCode: ErrCodeNotFound,
|
|
||||||
wantMsg: "not found",
|
|
||||||
wantErr: nil,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "add error to error with table",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeConflict,
|
|
||||||
Message: "conflict",
|
|
||||||
Table: "locks",
|
|
||||||
},
|
|
||||||
underlying: errors.New("lock exists"),
|
|
||||||
wantCode: ErrCodeConflict,
|
|
||||||
wantMsg: "conflict",
|
|
||||||
wantErr: errors.New("lock exists"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.err.WithError(tt.underlying)
|
|
||||||
assert.NotNil(t, result)
|
|
||||||
assert.Equal(t, tt.wantCode, result.Code)
|
|
||||||
assert.Equal(t, tt.wantMsg, result.Message)
|
|
||||||
// 確保是新的實例
|
|
||||||
assert.NotSame(t, tt.err, result)
|
|
||||||
// 檢查 underlying error
|
|
||||||
if tt.wantErr == nil {
|
|
||||||
assert.Nil(t, result.Err)
|
|
||||||
} else {
|
|
||||||
require.NotNil(t, result.Err)
|
|
||||||
assert.Equal(t, tt.wantErr.Error(), result.Err.Error())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewError(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
code ErrorCode
|
|
||||||
message string
|
|
||||||
want *Error
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "create NOT_FOUND error",
|
|
||||||
code: ErrCodeNotFound,
|
|
||||||
message: "record not found",
|
|
||||||
want: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "create CONFLICT error",
|
|
||||||
code: ErrCodeConflict,
|
|
||||||
message: "lock acquisition failed",
|
|
||||||
want: &Error{
|
|
||||||
Code: ErrCodeConflict,
|
|
||||||
Message: "lock acquisition failed",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "create INVALID_INPUT error",
|
|
||||||
code: ErrCodeInvalidInput,
|
|
||||||
message: "invalid parameter",
|
|
||||||
want: &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid parameter",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "create error with empty message",
|
|
||||||
code: ErrCodeNotFound,
|
|
||||||
message: "",
|
|
||||||
want: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := NewError(tt.code, tt.message)
|
|
||||||
assert.NotNil(t, result)
|
|
||||||
assert.Equal(t, tt.want.Code, result.Code)
|
|
||||||
assert.Equal(t, tt.want.Message, result.Message)
|
|
||||||
assert.Empty(t, result.Table)
|
|
||||||
assert.Nil(t, result.Err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsNotFound(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Error with NOT_FOUND code",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeConflict,
|
|
||||||
Message: "conflict",
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with INVALID_INPUT code",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid input",
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "wrapped Error with NOT_FOUND code",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
Err: errors.New("underlying error"),
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "standard error",
|
|
||||||
err: errors.New("standard error"),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil error",
|
|
||||||
err: nil,
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "predefined ErrNotFound",
|
|
||||||
err: ErrNotFound,
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "predefined ErrNotFound with table",
|
|
||||||
err: ErrNotFound.WithTable("users"),
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := IsNotFound(tt.err)
|
|
||||||
assert.Equal(t, tt.want, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsConflict(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeConflict,
|
|
||||||
Message: "conflict",
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with NOT_FOUND code",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeNotFound,
|
|
||||||
Message: "record not found",
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with INVALID_INPUT code",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeInvalidInput,
|
|
||||||
Message: "invalid input",
|
|
||||||
},
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "wrapped Error with CONFLICT code",
|
|
||||||
err: &Error{
|
|
||||||
Code: ErrCodeConflict,
|
|
||||||
Message: "conflict",
|
|
||||||
Err: errors.New("underlying error"),
|
|
||||||
},
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "standard error",
|
|
||||||
err: errors.New("standard error"),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil error",
|
|
||||||
err: nil,
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "NewError with CONFLICT code",
|
|
||||||
err: NewError(ErrCodeConflict, "lock failed"),
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := IsConflict(tt.err)
|
|
||||||
assert.Equal(t, tt.want, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPredefinedErrors(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err *Error
|
|
||||||
wantCode ErrorCode
|
|
||||||
wantMsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ErrNotFound",
|
|
||||||
err: ErrNotFound,
|
|
||||||
wantCode: ErrCodeNotFound,
|
|
||||||
wantMsg: "record not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ErrInvalidInput",
|
|
||||||
err: ErrInvalidInput,
|
|
||||||
wantCode: ErrCodeInvalidInput,
|
|
||||||
wantMsg: "invalid input parameter",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ErrNoPartitionKey",
|
|
||||||
err: ErrNoPartitionKey,
|
|
||||||
wantCode: ErrCodeMissingPartition,
|
|
||||||
wantMsg: "no partition key defined in struct",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ErrMissingTableName",
|
|
||||||
err: ErrMissingTableName,
|
|
||||||
wantCode: ErrCodeMissingTableName,
|
|
||||||
wantMsg: "struct must implement TableName() method",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ErrNoFieldsToUpdate",
|
|
||||||
err: ErrNoFieldsToUpdate,
|
|
||||||
wantCode: ErrCodeNoFieldsToUpdate,
|
|
||||||
wantMsg: "no fields to update",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ErrMissingWhereCondition",
|
|
||||||
err: ErrMissingWhereCondition,
|
|
||||||
wantCode: ErrCodeMissingWhereCondition,
|
|
||||||
wantMsg: "operation requires at least one WHERE condition for safety",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ErrMissingPartitionKey",
|
|
||||||
err: ErrMissingPartitionKey,
|
|
||||||
wantCode: ErrCodeMissingPartition,
|
|
||||||
wantMsg: "operation requires all partition keys in WHERE clause",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
assert.NotNil(t, tt.err)
|
|
||||||
assert.Equal(t, tt.wantCode, tt.err.Code)
|
|
||||||
assert.Equal(t, tt.wantMsg, tt.err.Message)
|
|
||||||
assert.Empty(t, tt.err.Table)
|
|
||||||
assert.Nil(t, tt.err.Err)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestError_Chaining(t *testing.T) {
|
|
||||||
t.Run("chain WithTable and WithError", func(t *testing.T) {
|
|
||||||
err := NewError(ErrCodeNotFound, "record not found").
|
|
||||||
WithTable("users").
|
|
||||||
WithError(errors.New("database error"))
|
|
||||||
|
|
||||||
assert.Equal(t, ErrCodeNotFound, err.Code)
|
|
||||||
assert.Equal(t, "record not found", err.Message)
|
|
||||||
assert.Equal(t, "users", err.Table)
|
|
||||||
assert.NotNil(t, err.Err)
|
|
||||||
assert.Equal(t, "database error", err.Err.Error())
|
|
||||||
assert.True(t, IsNotFound(err))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("chain multiple WithTable calls", func(t *testing.T) {
|
|
||||||
err1 := ErrNotFound.WithTable("table1")
|
|
||||||
err2 := err1.WithTable("table2")
|
|
||||||
|
|
||||||
assert.Equal(t, "table1", err1.Table)
|
|
||||||
assert.Equal(t, "table2", err2.Table)
|
|
||||||
assert.NotSame(t, err1, err2)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("chain multiple WithError calls", func(t *testing.T) {
|
|
||||||
err1 := ErrInvalidInput.WithError(errors.New("error1"))
|
|
||||||
err2 := err1.WithError(errors.New("error2"))
|
|
||||||
|
|
||||||
assert.Equal(t, "error1", err1.Err.Error())
|
|
||||||
assert.Equal(t, "error2", err2.Err.Error())
|
|
||||||
assert.NotSame(t, err1, err2)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestError_ErrorsAs(t *testing.T) {
|
|
||||||
t.Run("errors.As works with Error", func(t *testing.T) {
|
|
||||||
err := ErrNotFound.WithTable("users")
|
|
||||||
var target *Error
|
|
||||||
ok := errors.As(err, &target)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.NotNil(t, target)
|
|
||||||
assert.Equal(t, ErrCodeNotFound, target.Code)
|
|
||||||
assert.Equal(t, "users", target.Table)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("errors.As works with wrapped Error", func(t *testing.T) {
|
|
||||||
underlying := errors.New("underlying error")
|
|
||||||
err := ErrInvalidInput.WithError(underlying)
|
|
||||||
var target *Error
|
|
||||||
ok := errors.As(err, &target)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.NotNil(t, target)
|
|
||||||
assert.Equal(t, ErrCodeInvalidInput, target.Code)
|
|
||||||
assert.Equal(t, underlying, target.Err)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("errors.Is works with Error", func(t *testing.T) {
|
|
||||||
err := ErrNotFound
|
|
||||||
assert.True(t, errors.Is(err, ErrNotFound))
|
|
||||||
assert.False(t, errors.Is(err, ErrInvalidInput))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
@ -1,120 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
"github.com/scylladb/gocqlx/v2/qb"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
defaultLockTTLSec = 30
|
|
||||||
defaultLockRetry = 3
|
|
||||||
lockBaseDelay = 100 * time.Millisecond
|
|
||||||
)
|
|
||||||
|
|
||||||
// LockOption 用來設定 TryLock 的 TTL 行為
|
|
||||||
type LockOption func(*lockOptions)
|
|
||||||
|
|
||||||
type lockOptions struct {
|
|
||||||
ttlSeconds int // TTL,單位秒;<=0 代表不 expire
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithLockTTL 設定鎖的 TTL
|
|
||||||
func WithLockTTL(d time.Duration) LockOption {
|
|
||||||
return func(o *lockOptions) {
|
|
||||||
o.ttlSeconds = int(d.Seconds())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithNoLockExpire 永不自動解鎖
|
|
||||||
func WithNoLockExpire() LockOption {
|
|
||||||
return func(o *lockOptions) {
|
|
||||||
o.ttlSeconds = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TryLock 嘗試在表上插入一筆唯一鍵(IF NOT EXISTS)作為鎖
|
|
||||||
// 預設 30 秒 TTL,可透過 option 調整或取消 TTL
|
|
||||||
func (r *repository[T]) TryLock(ctx context.Context, doc T, opts ...LockOption) error {
|
|
||||||
// 組合 option
|
|
||||||
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(options)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建 TTL 子句
|
|
||||||
builder := qb.Insert(r.table).
|
|
||||||
Unique(). // IF NOT EXISTS
|
|
||||||
Columns(r.metadata.Columns...)
|
|
||||||
|
|
||||||
if options.ttlSeconds > 0 {
|
|
||||||
ttl := time.Duration(options.ttlSeconds) * time.Second
|
|
||||||
builder = builder.TTL(ttl)
|
|
||||||
}
|
|
||||||
stmt, names := builder.ToCql()
|
|
||||||
|
|
||||||
// 執行 CAS
|
|
||||||
q := r.db.session.Query(stmt, names).BindStruct(doc).
|
|
||||||
WithContext(ctx).
|
|
||||||
WithTimestamp(time.Now().UnixNano() / 1e3).
|
|
||||||
SerialConsistency(gocql.Serial)
|
|
||||||
|
|
||||||
applied, err := q.ExecCASRelease()
|
|
||||||
if err != nil {
|
|
||||||
return ErrInvalidInput.WithTable(r.table).WithError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !applied {
|
|
||||||
return NewError(ErrCodeConflict, "acquire lock failed").WithTable(r.table)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnLock 釋放鎖,其實就是 Delete
|
|
||||||
func (r *repository[T]) UnLock(ctx context.Context, doc T) error {
|
|
||||||
var lastErr error
|
|
||||||
|
|
||||||
for i := 0; i < defaultLockRetry; i++ {
|
|
||||||
builder := qb.Delete(r.table).Existing()
|
|
||||||
|
|
||||||
// 動態添加 WHERE 條件(使用 Partition Key)
|
|
||||||
for _, key := range r.metadata.PartKey {
|
|
||||||
builder = builder.Where(qb.Eq(key))
|
|
||||||
}
|
|
||||||
stmt, names := builder.ToCql()
|
|
||||||
q := r.db.session.Query(stmt, names).BindStruct(doc).
|
|
||||||
WithContext(ctx).
|
|
||||||
WithTimestamp(time.Now().UnixNano() / 1e3).
|
|
||||||
SerialConsistency(gocql.Serial)
|
|
||||||
|
|
||||||
applied, err := q.ExecCASRelease()
|
|
||||||
if err == nil && applied {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
lastErr = fmt.Errorf("unlock error: %w", err)
|
|
||||||
} else if !applied {
|
|
||||||
lastErr = fmt.Errorf("unlock not applied: row not found or not visible yet")
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(lockBaseDelay * time.Duration(1<<i)) // 100ms → 200ms → 400ms
|
|
||||||
}
|
|
||||||
|
|
||||||
return ErrInvalidInput.WithTable(r.table).WithError(
|
|
||||||
fmt.Errorf("unlock failed after %d retries: %w", defaultLockRetry, lastErr),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsLockFailed 檢查錯誤是否為獲取鎖失敗
|
|
||||||
func IsLockFailed(err error) bool {
|
|
||||||
var e *Error
|
|
||||||
if errors.As(err, &e) {
|
|
||||||
return e.Code == ErrCodeConflict && e.Message == "acquire lock failed"
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
@ -1,503 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestWithLockTTL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
duration time.Duration
|
|
||||||
wantTTL int
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "30 seconds TTL",
|
|
||||||
duration: 30 * time.Second,
|
|
||||||
wantTTL: 30,
|
|
||||||
description: "should set TTL to 30 seconds",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "1 minute TTL",
|
|
||||||
duration: 1 * time.Minute,
|
|
||||||
wantTTL: 60,
|
|
||||||
description: "should set TTL to 60 seconds",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "5 minutes TTL",
|
|
||||||
duration: 5 * time.Minute,
|
|
||||||
wantTTL: 300,
|
|
||||||
description: "should set TTL to 300 seconds",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "1 hour TTL",
|
|
||||||
duration: 1 * time.Hour,
|
|
||||||
wantTTL: 3600,
|
|
||||||
description: "should set TTL to 3600 seconds",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero duration",
|
|
||||||
duration: 0,
|
|
||||||
wantTTL: 0,
|
|
||||||
description: "should set TTL to 0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative duration",
|
|
||||||
duration: -10 * time.Second,
|
|
||||||
wantTTL: -10,
|
|
||||||
description: "should set TTL to negative value",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "fractional seconds",
|
|
||||||
duration: 1500 * time.Millisecond,
|
|
||||||
wantTTL: 1,
|
|
||||||
description: "should round down fractional seconds",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
opt := WithLockTTL(tt.duration)
|
|
||||||
options := &lockOptions{}
|
|
||||||
opt(options)
|
|
||||||
assert.Equal(t, tt.wantTTL, options.ttlSeconds, tt.description)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithNoLockExpire(t *testing.T) {
|
|
||||||
t.Run("should set TTL to 0", func(t *testing.T) {
|
|
||||||
opt := WithNoLockExpire()
|
|
||||||
options := &lockOptions{ttlSeconds: 30} // 先設置一個值
|
|
||||||
opt(options)
|
|
||||||
assert.Equal(t, 0, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("should override existing TTL", func(t *testing.T) {
|
|
||||||
opt := WithNoLockExpire()
|
|
||||||
options := &lockOptions{ttlSeconds: 100}
|
|
||||||
opt(options)
|
|
||||||
assert.Equal(t, 0, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockOptions_Combination(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
opts []LockOption
|
|
||||||
wantTTL int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "WithLockTTL then WithNoLockExpire",
|
|
||||||
opts: []LockOption{WithLockTTL(60 * time.Second), WithNoLockExpire()},
|
|
||||||
wantTTL: 0, // WithNoLockExpire should override
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "WithNoLockExpire then WithLockTTL",
|
|
||||||
opts: []LockOption{WithNoLockExpire(), WithLockTTL(60 * time.Second)},
|
|
||||||
wantTTL: 60, // WithLockTTL should override
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple WithLockTTL calls",
|
|
||||||
opts: []LockOption{WithLockTTL(30 * time.Second), WithLockTTL(60 * time.Second)},
|
|
||||||
wantTTL: 60, // Last one wins
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple WithNoLockExpire calls",
|
|
||||||
opts: []LockOption{WithNoLockExpire(), WithNoLockExpire()},
|
|
||||||
wantTTL: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty options should use default",
|
|
||||||
opts: []LockOption{},
|
|
||||||
wantTTL: defaultLockTTLSec,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
|
|
||||||
for _, opt := range tt.opts {
|
|
||||||
opt(options)
|
|
||||||
}
|
|
||||||
assert.Equal(t, tt.wantTTL, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsLockFailed(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code and correct message",
|
|
||||||
err: NewError(ErrCodeConflict, "acquire lock failed"),
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code and correct message with table",
|
|
||||||
err: NewError(ErrCodeConflict, "acquire lock failed").WithTable("locks"),
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code but wrong message",
|
|
||||||
err: NewError(ErrCodeConflict, "different message"),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with NOT_FOUND code and correct message",
|
|
||||||
err: NewError(ErrCodeNotFound, "acquire lock failed"),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with INVALID_INPUT code",
|
|
||||||
err: ErrInvalidInput,
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "wrapped Error with CONFLICT code and correct message",
|
|
||||||
err: NewError(ErrCodeConflict, "acquire lock failed").
|
|
||||||
WithError(errors.New("underlying error")),
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "standard error",
|
|
||||||
err: errors.New("standard error"),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil error",
|
|
||||||
err: nil,
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code but empty message",
|
|
||||||
err: NewError(ErrCodeConflict, ""),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code and similar but different message",
|
|
||||||
err: NewError(ErrCodeConflict, "acquire lock failed!"),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := IsLockFailed(tt.err)
|
|
||||||
assert.Equal(t, tt.want, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockConstants(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
constant interface{}
|
|
||||||
expected interface{}
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "defaultLockTTLSec should be 30",
|
|
||||||
constant: defaultLockTTLSec,
|
|
||||||
expected: 30,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "defaultLockRetry should be 3",
|
|
||||||
constant: defaultLockRetry,
|
|
||||||
expected: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "lockBaseDelay should be 100ms",
|
|
||||||
constant: lockBaseDelay,
|
|
||||||
expected: 100 * time.Millisecond,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tt.expected, tt.constant)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockOptions_DefaultValues(t *testing.T) {
|
|
||||||
t.Run("default lockOptions should have default TTL", func(t *testing.T) {
|
|
||||||
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
|
|
||||||
assert.Equal(t, defaultLockTTLSec, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("lockOptions with zero TTL", func(t *testing.T) {
|
|
||||||
options := &lockOptions{ttlSeconds: 0}
|
|
||||||
assert.Equal(t, 0, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("lockOptions with negative TTL", func(t *testing.T) {
|
|
||||||
options := &lockOptions{ttlSeconds: -1}
|
|
||||||
assert.Equal(t, -1, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTryLock_ErrorScenarios(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
// 注意:實際的 TryLock 測試需要 mock session 或實際的資料庫連接
|
|
||||||
// 這裡只是定義測試結構
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "successful lock acquisition",
|
|
||||||
description: "should return nil when lock is successfully acquired",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "lock already exists",
|
|
||||||
description: "should return CONFLICT error when lock already exists",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "database error",
|
|
||||||
description: "should return INVALID_INPUT error with underlying error when database operation fails",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "context cancellation",
|
|
||||||
description: "should respect context cancellation",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with custom TTL",
|
|
||||||
description: "should use custom TTL when provided",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with no expire",
|
|
||||||
description: "should not set TTL when WithNoLockExpire is used",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnLock_ErrorScenarios(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
// 注意:實際的 UnLock 測試需要 mock session 或實際的資料庫連接
|
|
||||||
// 這裡只是定義測試結構
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "successful unlock",
|
|
||||||
description: "should return nil when lock is successfully released",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "lock not found",
|
|
||||||
description: "should retry when lock is not found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "database error",
|
|
||||||
description: "should retry on database error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "max retries exceeded",
|
|
||||||
description: "should return error after max retries",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "context cancellation",
|
|
||||||
description: "should respect context cancellation",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "exponential backoff",
|
|
||||||
description: "should use exponential backoff between retries",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockOption_Type(t *testing.T) {
|
|
||||||
t.Run("WithLockTTL should return LockOption", func(t *testing.T) {
|
|
||||||
opt := WithLockTTL(30 * time.Second)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
// 驗證它是一個函數
|
|
||||||
var lockOpt LockOption = opt
|
|
||||||
assert.NotNil(t, lockOpt)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("WithNoLockExpire should return LockOption", func(t *testing.T) {
|
|
||||||
opt := WithNoLockExpire()
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
// 驗證它是一個函數
|
|
||||||
var lockOpt LockOption = opt
|
|
||||||
assert.NotNil(t, lockOpt)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockOptions_ApplyOrder(t *testing.T) {
|
|
||||||
t.Run("last option should win", func(t *testing.T) {
|
|
||||||
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
|
|
||||||
|
|
||||||
WithLockTTL(60 * time.Second)(options)
|
|
||||||
assert.Equal(t, 60, options.ttlSeconds)
|
|
||||||
|
|
||||||
WithNoLockExpire()(options)
|
|
||||||
assert.Equal(t, 0, options.ttlSeconds)
|
|
||||||
|
|
||||||
WithLockTTL(120 * time.Second)(options)
|
|
||||||
assert.Equal(t, 120, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsLockFailed_EdgeCases(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code, correct message, and underlying error",
|
|
||||||
err: NewError(ErrCodeConflict, "acquire lock failed").
|
|
||||||
WithTable("locks").
|
|
||||||
WithError(errors.New("database error")),
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code but message with extra spaces",
|
|
||||||
err: NewError(ErrCodeConflict, " acquire lock failed "),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Error with CONFLICT code but message with different case",
|
|
||||||
err: NewError(ErrCodeConflict, "Acquire Lock Failed"),
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "chained errors with CONFLICT",
|
|
||||||
err: func() error {
|
|
||||||
err1 := NewError(ErrCodeConflict, "acquire lock failed")
|
|
||||||
err2 := errors.New("wrapped")
|
|
||||||
return errors.Join(err1, err2)
|
|
||||||
}(),
|
|
||||||
want: true, // errors.Join preserves Error type and errors.As can find it
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := IsLockFailed(tt.err)
|
|
||||||
assert.Equal(t, tt.want, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockOptions_ZeroValue(t *testing.T) {
|
|
||||||
t.Run("zero value lockOptions", func(t *testing.T) {
|
|
||||||
var options lockOptions
|
|
||||||
assert.Equal(t, 0, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("apply option to zero value", func(t *testing.T) {
|
|
||||||
var options lockOptions
|
|
||||||
WithLockTTL(30 * time.Second)(&options)
|
|
||||||
assert.Equal(t, 30, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockRetryDelay(t *testing.T) {
|
|
||||||
t.Run("verify exponential backoff calculation", func(t *testing.T) {
|
|
||||||
// 驗證重試延遲的計算邏輯
|
|
||||||
// 100ms → 200ms → 400ms
|
|
||||||
expectedDelays := []time.Duration{
|
|
||||||
lockBaseDelay * time.Duration(1<<0), // 100ms * 1 = 100ms
|
|
||||||
lockBaseDelay * time.Duration(1<<1), // 100ms * 2 = 200ms
|
|
||||||
lockBaseDelay * time.Duration(1<<2), // 100ms * 4 = 400ms
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.Equal(t, 100*time.Millisecond, expectedDelays[0])
|
|
||||||
assert.Equal(t, 200*time.Millisecond, expectedDelays[1])
|
|
||||||
assert.Equal(t, 400*time.Millisecond, expectedDelays[2])
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockOption_InterfaceCompliance(t *testing.T) {
|
|
||||||
t.Run("LockOption should be a function type", func(t *testing.T) {
|
|
||||||
// 驗證 LockOption 是一個函數類型
|
|
||||||
var fn func(*lockOptions) = WithLockTTL(30 * time.Second)
|
|
||||||
assert.NotNil(t, fn)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("LockOption can be assigned from WithLockTTL", func(t *testing.T) {
|
|
||||||
var opt LockOption = WithLockTTL(30 * time.Second)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("LockOption can be assigned from WithNoLockExpire", func(t *testing.T) {
|
|
||||||
var opt LockOption = WithNoLockExpire()
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLockOptions_RealWorldScenarios(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
scenario func(*lockOptions)
|
|
||||||
wantTTL int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "short-lived lock (5 seconds)",
|
|
||||||
scenario: func(o *lockOptions) {
|
|
||||||
WithLockTTL(5 * time.Second)(o)
|
|
||||||
},
|
|
||||||
wantTTL: 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "medium-lived lock (5 minutes)",
|
|
||||||
scenario: func(o *lockOptions) {
|
|
||||||
WithLockTTL(5 * time.Minute)(o)
|
|
||||||
},
|
|
||||||
wantTTL: 300,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "long-lived lock (1 hour)",
|
|
||||||
scenario: func(o *lockOptions) {
|
|
||||||
WithLockTTL(1 * time.Hour)(o)
|
|
||||||
},
|
|
||||||
wantTTL: 3600,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "permanent lock",
|
|
||||||
scenario: func(o *lockOptions) {
|
|
||||||
WithNoLockExpire()(o)
|
|
||||||
},
|
|
||||||
wantTTL: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "default lock",
|
|
||||||
scenario: func(o *lockOptions) {
|
|
||||||
// 不應用任何選項,使用預設值
|
|
||||||
},
|
|
||||||
wantTTL: defaultLockTTLSec,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
options := &lockOptions{ttlSeconds: defaultLockTTLSec}
|
|
||||||
tt.scenario(options)
|
|
||||||
assert.Equal(t, tt.wantTTL, options.ttlSeconds)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,136 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"sync"
|
|
||||||
"unicode"
|
|
||||||
|
|
||||||
"github.com/scylladb/gocqlx/v2/table"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// metadataCache 快取已生成的 Metadata,避免重複反射解析
|
|
||||||
// key: tableName + ":" + structType (不包含 keyspace,因為同一個 struct 在不同 keyspace 結構相同)
|
|
||||||
metadataCache sync.Map
|
|
||||||
)
|
|
||||||
|
|
||||||
type cachedMetadata struct {
|
|
||||||
columns []string
|
|
||||||
partKeys []string
|
|
||||||
sortKeys []string
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateMetadata 根據傳入的 struct 產生 table.Metadata
|
|
||||||
// 使用快取機制避免重複反射解析,提升效能
|
|
||||||
func generateMetadata[T Table](doc T, keyspace string) (table.Metadata, error) {
|
|
||||||
// 取得型別資訊
|
|
||||||
t := reflect.TypeOf(doc)
|
|
||||||
if t.Kind() == reflect.Ptr {
|
|
||||||
t = t.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 取得表名稱
|
|
||||||
tableName := doc.TableName()
|
|
||||||
if tableName == "" {
|
|
||||||
return table.Metadata{}, ErrMissingTableName
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建快取 key: tableName:structType (不包含 keyspace)
|
|
||||||
cacheKey := fmt.Sprintf("%s:%s", tableName, t.String())
|
|
||||||
|
|
||||||
// 檢查快取
|
|
||||||
if cached, ok := metadataCache.Load(cacheKey); ok {
|
|
||||||
cachedMeta := cached.(cachedMetadata)
|
|
||||||
if cachedMeta.err != nil {
|
|
||||||
return table.Metadata{}, cachedMeta.err
|
|
||||||
}
|
|
||||||
// 從快取構建 metadata,動態加上 keyspace
|
|
||||||
meta := table.Metadata{
|
|
||||||
Name: fmt.Sprintf("%s.%s", keyspace, tableName),
|
|
||||||
Columns: make([]string, len(cachedMeta.columns)),
|
|
||||||
PartKey: make([]string, len(cachedMeta.partKeys)),
|
|
||||||
SortKey: make([]string, len(cachedMeta.sortKeys)),
|
|
||||||
}
|
|
||||||
copy(meta.Columns, cachedMeta.columns)
|
|
||||||
copy(meta.PartKey, cachedMeta.partKeys)
|
|
||||||
copy(meta.SortKey, cachedMeta.sortKeys)
|
|
||||||
return meta, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 快取未命中,生成 metadata
|
|
||||||
columns := make([]string, 0, t.NumField())
|
|
||||||
partKeys := make([]string, 0, t.NumField())
|
|
||||||
sortKeys := make([]string, 0, t.NumField())
|
|
||||||
|
|
||||||
// 遍歷所有 exported 欄位
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
field := t.Field(i)
|
|
||||||
// 跳過 unexported 欄位
|
|
||||||
if field.PkgPath != "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 如果欄位有標記 db:"-" 則跳過
|
|
||||||
if tag := field.Tag.Get(DBFiledName); tag == "-" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// 取得欄位名稱
|
|
||||||
colName := field.Tag.Get(DBFiledName)
|
|
||||||
if colName == "" {
|
|
||||||
colName = toSnakeCase(field.Name)
|
|
||||||
}
|
|
||||||
columns = append(columns, colName)
|
|
||||||
// 若有 partition_key:"true" 標記,加入 PartKey
|
|
||||||
if field.Tag.Get(Pk) == "true" {
|
|
||||||
partKeys = append(partKeys, colName)
|
|
||||||
}
|
|
||||||
// 若有 clustering_key:"true" 標記,加入 SortKey
|
|
||||||
if field.Tag.Get(ClusterKey) == "true" {
|
|
||||||
sortKeys = append(sortKeys, colName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(partKeys) == 0 {
|
|
||||||
err := ErrNoPartitionKey
|
|
||||||
// 快取錯誤結果
|
|
||||||
metadataCache.Store(cacheKey, cachedMetadata{err: err})
|
|
||||||
return table.Metadata{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 快取成功結果(只存結構資訊,不包含 keyspace)
|
|
||||||
cachedMeta := cachedMetadata{
|
|
||||||
columns: make([]string, len(columns)),
|
|
||||||
partKeys: make([]string, len(partKeys)),
|
|
||||||
sortKeys: make([]string, len(sortKeys)),
|
|
||||||
}
|
|
||||||
copy(cachedMeta.columns, columns)
|
|
||||||
copy(cachedMeta.partKeys, partKeys)
|
|
||||||
copy(cachedMeta.sortKeys, sortKeys)
|
|
||||||
metadataCache.Store(cacheKey, cachedMeta)
|
|
||||||
|
|
||||||
// 組合並返回 Metadata(包含 keyspace)
|
|
||||||
meta := table.Metadata{
|
|
||||||
Name: fmt.Sprintf("%s.%s", keyspace, tableName),
|
|
||||||
Columns: columns,
|
|
||||||
PartKey: partKeys,
|
|
||||||
SortKey: sortKeys,
|
|
||||||
}
|
|
||||||
|
|
||||||
return meta, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// toSnakeCase 將 CamelCase 字串轉換為 snake_case
|
|
||||||
func toSnakeCase(s string) string {
|
|
||||||
var result []rune
|
|
||||||
for i, r := range s {
|
|
||||||
if unicode.IsUpper(r) {
|
|
||||||
if i > 0 {
|
|
||||||
result = append(result, '_')
|
|
||||||
}
|
|
||||||
result = append(result, unicode.ToLower(r))
|
|
||||||
} else {
|
|
||||||
result = append(result, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return string(result)
|
|
||||||
}
|
|
||||||
|
|
@ -1,500 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/scylladb/gocqlx/v2/table"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestToSnakeCase(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "simple CamelCase",
|
|
||||||
input: "UserName",
|
|
||||||
expected: "user_name",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single word",
|
|
||||||
input: "User",
|
|
||||||
expected: "user",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple words",
|
|
||||||
input: "UserAccountBalance",
|
|
||||||
expected: "user_account_balance",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "already lowercase",
|
|
||||||
input: "username",
|
|
||||||
expected: "username",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "all uppercase",
|
|
||||||
input: "USERNAME",
|
|
||||||
expected: "u_s_e_r_n_a_m_e",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed case",
|
|
||||||
input: "XMLParser",
|
|
||||||
expected: "x_m_l_parser",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty string",
|
|
||||||
input: "",
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single character",
|
|
||||||
input: "A",
|
|
||||||
expected: "a",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "with numbers",
|
|
||||||
input: "UserID123",
|
|
||||||
expected: "user_i_d123",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ID at end",
|
|
||||||
input: "UserID",
|
|
||||||
expected: "user_i_d",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ID at start",
|
|
||||||
input: "IDUser",
|
|
||||||
expected: "i_d_user",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := toSnakeCase(tt.input)
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 測試用的 struct 定義
|
|
||||||
type testUser struct {
|
|
||||||
ID string `db:"id" partition_key:"true"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Email string `db:"email"`
|
|
||||||
CreatedAt int64 `db:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUser) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserNoTableName struct {
|
|
||||||
ID string `db:"id" partition_key:"true"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUserNoTableName) TableName() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserNoPartitionKey struct {
|
|
||||||
ID string `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUserNoPartitionKey) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserWithClusteringKey struct {
|
|
||||||
ID string `db:"id" partition_key:"true"`
|
|
||||||
Timestamp int64 `db:"timestamp" clustering_key:"true"`
|
|
||||||
Data string `db:"data"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUserWithClusteringKey) TableName() string {
|
|
||||||
return "events"
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserWithMultiplePartitionKeys struct {
|
|
||||||
UserID string `db:"user_id" partition_key:"true"`
|
|
||||||
AccountID string `db:"account_id" partition_key:"true"`
|
|
||||||
Balance int64 `db:"balance"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUserWithMultiplePartitionKeys) TableName() string {
|
|
||||||
return "accounts"
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserWithAutoSnakeCase struct {
|
|
||||||
UserID string `db:"user_id" partition_key:"true"`
|
|
||||||
AccountName string // 沒有 db tag,應該自動轉換為 snake_case
|
|
||||||
EmailAddr string `db:"email_addr"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUserWithAutoSnakeCase) TableName() string {
|
|
||||||
return "profiles"
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserWithIgnoredField struct {
|
|
||||||
ID string `db:"id" partition_key:"true"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Password string `db:"-"` // 應該被忽略
|
|
||||||
CreatedAt int64 `db:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUserWithIgnoredField) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserUnexported struct {
|
|
||||||
ID string `db:"id" partition_key:"true"`
|
|
||||||
name string // unexported,應該被忽略
|
|
||||||
Email string `db:"email"`
|
|
||||||
createdAt int64 // unexported,應該被忽略
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUserUnexported) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
|
|
||||||
type testUserPointer struct {
|
|
||||||
ID *string `db:"id" partition_key:"true"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testUserPointer) TableName() string {
|
|
||||||
return "users"
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateMetadata_Basic(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
doc interface{}
|
|
||||||
keyspace string
|
|
||||||
wantErr bool
|
|
||||||
errCode ErrorCode
|
|
||||||
checkFunc func(*testing.T, table.Metadata, string)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid user struct",
|
|
||||||
doc: testUser{ID: "1", Name: "Alice"},
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
wantErr: false,
|
|
||||||
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
|
|
||||||
assert.Equal(t, keyspace+".users", meta.Name)
|
|
||||||
assert.Contains(t, meta.Columns, "id")
|
|
||||||
assert.Contains(t, meta.Columns, "name")
|
|
||||||
assert.Contains(t, meta.Columns, "email")
|
|
||||||
assert.Contains(t, meta.Columns, "created_at")
|
|
||||||
assert.Contains(t, meta.PartKey, "id")
|
|
||||||
assert.Empty(t, meta.SortKey)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user with clustering key",
|
|
||||||
doc: testUserWithClusteringKey{ID: "1", Timestamp: 1234567890},
|
|
||||||
keyspace: "events_db",
|
|
||||||
wantErr: false,
|
|
||||||
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
|
|
||||||
assert.Equal(t, keyspace+".events", meta.Name)
|
|
||||||
assert.Contains(t, meta.PartKey, "id")
|
|
||||||
assert.Contains(t, meta.SortKey, "timestamp")
|
|
||||||
assert.Contains(t, meta.Columns, "data")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user with multiple partition keys",
|
|
||||||
doc: testUserWithMultiplePartitionKeys{UserID: "1", AccountID: "2"},
|
|
||||||
keyspace: "finance",
|
|
||||||
wantErr: false,
|
|
||||||
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
|
|
||||||
assert.Equal(t, keyspace+".accounts", meta.Name)
|
|
||||||
assert.Contains(t, meta.PartKey, "user_id")
|
|
||||||
assert.Contains(t, meta.PartKey, "account_id")
|
|
||||||
assert.Len(t, meta.PartKey, 2)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user with auto snake_case conversion",
|
|
||||||
doc: testUserWithAutoSnakeCase{UserID: "1", AccountName: "test"},
|
|
||||||
keyspace: "test",
|
|
||||||
wantErr: false,
|
|
||||||
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
|
|
||||||
assert.Contains(t, meta.Columns, "account_name") // 自動轉換
|
|
||||||
assert.Contains(t, meta.Columns, "user_id")
|
|
||||||
assert.Contains(t, meta.Columns, "email_addr")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user with ignored field",
|
|
||||||
doc: testUserWithIgnoredField{ID: "1", Name: "Alice"},
|
|
||||||
keyspace: "test",
|
|
||||||
wantErr: false,
|
|
||||||
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
|
|
||||||
assert.Contains(t, meta.Columns, "id")
|
|
||||||
assert.Contains(t, meta.Columns, "name")
|
|
||||||
assert.Contains(t, meta.Columns, "created_at")
|
|
||||||
assert.NotContains(t, meta.Columns, "password") // 應該被忽略
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user with unexported fields",
|
|
||||||
doc: testUserUnexported{ID: "1", Email: "test@example.com"},
|
|
||||||
keyspace: "test",
|
|
||||||
wantErr: false,
|
|
||||||
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
|
|
||||||
assert.Contains(t, meta.Columns, "id")
|
|
||||||
assert.Contains(t, meta.Columns, "email")
|
|
||||||
assert.NotContains(t, meta.Columns, "name") // unexported
|
|
||||||
assert.NotContains(t, meta.Columns, "created_at") // unexported
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "user pointer type",
|
|
||||||
doc: &testUserPointer{ID: stringPtr("1"), Name: "Alice"},
|
|
||||||
keyspace: "test",
|
|
||||||
wantErr: false,
|
|
||||||
checkFunc: func(t *testing.T, meta table.Metadata, keyspace string) {
|
|
||||||
assert.Equal(t, keyspace+".users", meta.Name)
|
|
||||||
assert.Contains(t, meta.Columns, "id")
|
|
||||||
assert.Contains(t, meta.Columns, "name")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
var meta table.Metadata
|
|
||||||
var err error
|
|
||||||
|
|
||||||
switch doc := tt.doc.(type) {
|
|
||||||
case testUser:
|
|
||||||
meta, err = generateMetadata(doc, tt.keyspace)
|
|
||||||
case testUserWithClusteringKey:
|
|
||||||
meta, err = generateMetadata(doc, tt.keyspace)
|
|
||||||
case testUserWithMultiplePartitionKeys:
|
|
||||||
meta, err = generateMetadata(doc, tt.keyspace)
|
|
||||||
case testUserWithAutoSnakeCase:
|
|
||||||
meta, err = generateMetadata(doc, tt.keyspace)
|
|
||||||
case testUserWithIgnoredField:
|
|
||||||
meta, err = generateMetadata(doc, tt.keyspace)
|
|
||||||
case testUserUnexported:
|
|
||||||
meta, err = generateMetadata(doc, tt.keyspace)
|
|
||||||
case *testUserPointer:
|
|
||||||
meta, err = generateMetadata(*doc, tt.keyspace)
|
|
||||||
default:
|
|
||||||
t.Fatalf("unsupported type: %T", doc)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
if tt.errCode != "" {
|
|
||||||
var e *Error
|
|
||||||
if assert.ErrorAs(t, err, &e) {
|
|
||||||
assert.Equal(t, tt.errCode, e.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
if tt.checkFunc != nil {
|
|
||||||
tt.checkFunc(t, meta, tt.keyspace)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateMetadata_ErrorCases(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
doc interface{}
|
|
||||||
keyspace string
|
|
||||||
wantErr bool
|
|
||||||
errCode ErrorCode
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "missing table name",
|
|
||||||
doc: testUserNoTableName{ID: "1"},
|
|
||||||
keyspace: "test",
|
|
||||||
wantErr: true,
|
|
||||||
errCode: ErrCodeMissingTableName,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing partition key",
|
|
||||||
doc: testUserNoPartitionKey{ID: "1", Name: "Alice"},
|
|
||||||
keyspace: "test",
|
|
||||||
wantErr: true,
|
|
||||||
errCode: ErrCodeMissingPartition,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
var err error
|
|
||||||
switch doc := tt.doc.(type) {
|
|
||||||
case testUserNoTableName:
|
|
||||||
_, err = generateMetadata(doc, tt.keyspace)
|
|
||||||
case testUserNoPartitionKey:
|
|
||||||
_, err = generateMetadata(doc, tt.keyspace)
|
|
||||||
default:
|
|
||||||
t.Fatalf("unsupported type: %T", doc)
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
if tt.errCode != "" {
|
|
||||||
var e *Error
|
|
||||||
if assert.ErrorAs(t, err, &e) {
|
|
||||||
assert.Equal(t, tt.errCode, e.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateMetadata_Cache(t *testing.T) {
|
|
||||||
t.Run("cache hit for same struct type", func(t *testing.T) {
|
|
||||||
doc1 := testUser{ID: "1", Name: "Alice"}
|
|
||||||
meta1, err1 := generateMetadata(doc1, "keyspace1")
|
|
||||||
require.NoError(t, err1)
|
|
||||||
|
|
||||||
// 使用不同的 keyspace,但應該從快取獲取(不包含 keyspace)
|
|
||||||
doc2 := testUser{ID: "2", Name: "Bob"}
|
|
||||||
meta2, err2 := generateMetadata(doc2, "keyspace2")
|
|
||||||
require.NoError(t, err2)
|
|
||||||
|
|
||||||
// 驗證結構相同,但 keyspace 不同
|
|
||||||
assert.Equal(t, "keyspace1.users", meta1.Name)
|
|
||||||
assert.Equal(t, "keyspace2.users", meta2.Name)
|
|
||||||
assert.Equal(t, meta1.Columns, meta2.Columns)
|
|
||||||
assert.Equal(t, meta1.PartKey, meta2.PartKey)
|
|
||||||
assert.Equal(t, meta1.SortKey, meta2.SortKey)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("cache hit for error case", func(t *testing.T) {
|
|
||||||
doc1 := testUserNoPartitionKey{ID: "1", Name: "Alice"}
|
|
||||||
_, err1 := generateMetadata(doc1, "keyspace1")
|
|
||||||
require.Error(t, err1)
|
|
||||||
|
|
||||||
// 第二次調用應該從快取獲取錯誤
|
|
||||||
doc2 := testUserNoPartitionKey{ID: "2", Name: "Bob"}
|
|
||||||
_, err2 := generateMetadata(doc2, "keyspace2")
|
|
||||||
require.Error(t, err2)
|
|
||||||
|
|
||||||
// 錯誤應該相同
|
|
||||||
assert.Equal(t, err1.Error(), err2.Error())
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("cache miss for different struct type", func(t *testing.T) {
|
|
||||||
doc1 := testUser{ID: "1"}
|
|
||||||
meta1, err1 := generateMetadata(doc1, "test")
|
|
||||||
require.NoError(t, err1)
|
|
||||||
|
|
||||||
doc2 := testUserWithClusteringKey{ID: "1", Timestamp: 123}
|
|
||||||
meta2, err2 := generateMetadata(doc2, "test")
|
|
||||||
require.NoError(t, err2)
|
|
||||||
|
|
||||||
// 應該是不同的 metadata
|
|
||||||
assert.NotEqual(t, meta1.Name, meta2.Name)
|
|
||||||
assert.NotEqual(t, meta1.Columns, meta2.Columns)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateMetadata_DifferentKeyspaces(t *testing.T) {
|
|
||||||
t.Run("same struct with different keyspaces", func(t *testing.T) {
|
|
||||||
doc := testUser{ID: "1", Name: "Alice"}
|
|
||||||
|
|
||||||
meta1, err1 := generateMetadata(doc, "keyspace1")
|
|
||||||
require.NoError(t, err1)
|
|
||||||
|
|
||||||
meta2, err2 := generateMetadata(doc, "keyspace2")
|
|
||||||
require.NoError(t, err2)
|
|
||||||
|
|
||||||
// 結構應該相同,但 keyspace 不同
|
|
||||||
assert.Equal(t, "keyspace1.users", meta1.Name)
|
|
||||||
assert.Equal(t, "keyspace2.users", meta2.Name)
|
|
||||||
assert.Equal(t, meta1.Columns, meta2.Columns)
|
|
||||||
assert.Equal(t, meta1.PartKey, meta2.PartKey)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateMetadata_EmptyKeyspace(t *testing.T) {
|
|
||||||
t.Run("empty keyspace", func(t *testing.T) {
|
|
||||||
doc := testUser{ID: "1", Name: "Alice"}
|
|
||||||
meta, err := generateMetadata(doc, "")
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, ".users", meta.Name)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateMetadata_PointerVsValue(t *testing.T) {
|
|
||||||
t.Run("pointer and value should produce same metadata", func(t *testing.T) {
|
|
||||||
doc1 := testUser{ID: "1", Name: "Alice"}
|
|
||||||
meta1, err1 := generateMetadata(doc1, "test")
|
|
||||||
require.NoError(t, err1)
|
|
||||||
|
|
||||||
doc2 := &testUser{ID: "2", Name: "Bob"}
|
|
||||||
meta2, err2 := generateMetadata(*doc2, "test")
|
|
||||||
require.NoError(t, err2)
|
|
||||||
|
|
||||||
// 應該產生相同的 metadata(除了可能的值不同)
|
|
||||||
assert.Equal(t, meta1.Name, meta2.Name)
|
|
||||||
assert.Equal(t, meta1.Columns, meta2.Columns)
|
|
||||||
assert.Equal(t, meta1.PartKey, meta2.PartKey)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateMetadata_ColumnOrder(t *testing.T) {
|
|
||||||
t.Run("columns should maintain struct field order", func(t *testing.T) {
|
|
||||||
doc := testUser{ID: "1", Name: "Alice", Email: "alice@example.com"}
|
|
||||||
meta, err := generateMetadata(doc, "test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// 驗證欄位順序(根據 struct 定義)
|
|
||||||
assert.Equal(t, "id", meta.Columns[0])
|
|
||||||
assert.Equal(t, "name", meta.Columns[1])
|
|
||||||
assert.Equal(t, "email", meta.Columns[2])
|
|
||||||
assert.Equal(t, "created_at", meta.Columns[3])
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenerateMetadata_AllTagCombinations(t *testing.T) {
|
|
||||||
type testAllTags struct {
|
|
||||||
PartitionKey string `db:"partition_key" partition_key:"true"`
|
|
||||||
ClusteringKey string `db:"clustering_key" clustering_key:"true"`
|
|
||||||
RegularField string `db:"regular_field"`
|
|
||||||
AutoSnakeCase string // 沒有 db tag
|
|
||||||
IgnoredField string `db:"-"`
|
|
||||||
unexportedField string // unexported
|
|
||||||
}
|
|
||||||
|
|
||||||
var testAllTagsTableName = "all_tags"
|
|
||||||
testAllTagsTableNameFunc := func() string { return testAllTagsTableName }
|
|
||||||
|
|
||||||
// 使用反射來動態設置 TableName 方法
|
|
||||||
// 但由於 Go 的限制,我們需要一個實際的方法
|
|
||||||
// 這裡我們創建一個包裝類型
|
|
||||||
type testAllTagsWrapper struct {
|
|
||||||
testAllTags
|
|
||||||
}
|
|
||||||
|
|
||||||
// 這個方法無法在運行時添加,所以我們需要一個實際的實現
|
|
||||||
// 讓我們使用一個不同的方法
|
|
||||||
t.Run("all tag combinations", func(t *testing.T) {
|
|
||||||
// 由於無法動態添加方法,我們跳過這個測試
|
|
||||||
// 或者創建一個實際的 struct
|
|
||||||
_ = testAllTagsWrapper{}
|
|
||||||
_ = testAllTagsTableNameFunc
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// 輔助函數
|
|
||||||
func stringPtr(s string) *string {
|
|
||||||
return &s
|
|
||||||
}
|
|
||||||
|
|
@ -1,162 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// config 是初始化 DB 所需的內部設定(私有)
|
|
||||||
type config struct {
|
|
||||||
Hosts []string // Cassandra 主機列表
|
|
||||||
Port int // 連線埠
|
|
||||||
Keyspace string // 預設使用的 Keyspace
|
|
||||||
Username string // 認證用戶名
|
|
||||||
Password string // 認證密碼
|
|
||||||
Consistency gocql.Consistency // 一致性級別
|
|
||||||
ConnectTimeoutSec int // 連線逾時秒數
|
|
||||||
NumConns int // 每個節點連線數
|
|
||||||
MaxRetries int // 重試次數
|
|
||||||
UseAuth bool // 是否使用帳號密碼驗證
|
|
||||||
RetryMinInterval time.Duration // 重試間隔最小值
|
|
||||||
RetryMaxInterval time.Duration // 重試間隔最大值
|
|
||||||
ReconnectInitialInterval time.Duration // 重連初始間隔
|
|
||||||
ReconnectMaxInterval time.Duration // 重連最大間隔
|
|
||||||
CQLVersion string // 執行連線的CQL 版本號
|
|
||||||
}
|
|
||||||
|
|
||||||
// defaultConfig 返回預設配置
|
|
||||||
func defaultConfig() *config {
|
|
||||||
return &config{
|
|
||||||
Port: defaultPort,
|
|
||||||
Consistency: defaultConsistency,
|
|
||||||
ConnectTimeoutSec: defaultTimeoutSec,
|
|
||||||
NumConns: defaultNumConns,
|
|
||||||
MaxRetries: defaultMaxRetries,
|
|
||||||
RetryMinInterval: defaultRetryMinInterval,
|
|
||||||
RetryMaxInterval: defaultRetryMaxInterval,
|
|
||||||
ReconnectInitialInterval: defaultReconnectInitialInterval,
|
|
||||||
ReconnectMaxInterval: defaultReconnectMaxInterval,
|
|
||||||
CQLVersion: defaultCqlVersion,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Option 是設定選項的函數型別
|
|
||||||
type Option func(*config)
|
|
||||||
|
|
||||||
// WithHosts 設定 Cassandra 主機列表
|
|
||||||
func WithHosts(hosts ...string) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
c.Hosts = hosts
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithPort 設定連線埠
|
|
||||||
func WithPort(port int) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
c.Port = port
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithKeyspace 設定預設 keyspace
|
|
||||||
func WithKeyspace(keyspace string) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
c.Keyspace = keyspace
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithAuth 設定認證資訊
|
|
||||||
func WithAuth(username, password string) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
c.Username = username
|
|
||||||
c.Password = password
|
|
||||||
c.UseAuth = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithConsistency 設定一致性級別
|
|
||||||
func WithConsistency(consistency gocql.Consistency) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
c.Consistency = consistency
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithConnectTimeoutSec 設定連線逾時秒數
|
|
||||||
func WithConnectTimeoutSec(timeout int) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
if timeout <= 0 {
|
|
||||||
timeout = defaultTimeoutSec
|
|
||||||
}
|
|
||||||
c.ConnectTimeoutSec = timeout
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithNumConns 設定每個節點的連線數
|
|
||||||
func WithNumConns(numConns int) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
if numConns <= 0 {
|
|
||||||
numConns = defaultNumConns
|
|
||||||
}
|
|
||||||
c.NumConns = numConns
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithMaxRetries 設定最大重試次數
|
|
||||||
func WithMaxRetries(maxRetries int) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
if maxRetries <= 0 {
|
|
||||||
maxRetries = defaultMaxRetries
|
|
||||||
}
|
|
||||||
c.MaxRetries = maxRetries
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithRetryMinInterval 設定最小重試間隔
|
|
||||||
func WithRetryMinInterval(duration time.Duration) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
if duration <= 0 {
|
|
||||||
duration = defaultRetryMinInterval
|
|
||||||
}
|
|
||||||
c.RetryMinInterval = duration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithRetryMaxInterval 設定最大重試間隔
|
|
||||||
func WithRetryMaxInterval(duration time.Duration) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
if duration <= 0 {
|
|
||||||
duration = defaultRetryMaxInterval
|
|
||||||
}
|
|
||||||
c.RetryMaxInterval = duration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithReconnectInitialInterval 設定初始重連間隔
|
|
||||||
func WithReconnectInitialInterval(duration time.Duration) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
if duration <= 0 {
|
|
||||||
duration = defaultReconnectInitialInterval
|
|
||||||
}
|
|
||||||
c.ReconnectInitialInterval = duration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithReconnectMaxInterval 設定最大重連間隔
|
|
||||||
func WithReconnectMaxInterval(duration time.Duration) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
if duration <= 0 {
|
|
||||||
duration = defaultReconnectMaxInterval
|
|
||||||
}
|
|
||||||
c.ReconnectMaxInterval = duration
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithCQLVersion 設定 CQL 版本
|
|
||||||
func WithCQLVersion(version string) Option {
|
|
||||||
return func(c *config) {
|
|
||||||
if version == "" {
|
|
||||||
version = defaultCqlVersion
|
|
||||||
}
|
|
||||||
c.CQLVersion = version
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,963 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOption_DefaultConfig(t *testing.T) {
|
|
||||||
t.Run("defaultConfig should return valid config with all defaults", func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
require.NotNil(t, cfg)
|
|
||||||
assert.Equal(t, defaultPort, cfg.Port)
|
|
||||||
assert.Equal(t, defaultConsistency, cfg.Consistency)
|
|
||||||
assert.Equal(t, defaultTimeoutSec, cfg.ConnectTimeoutSec)
|
|
||||||
assert.Equal(t, defaultNumConns, cfg.NumConns)
|
|
||||||
assert.Equal(t, defaultMaxRetries, cfg.MaxRetries)
|
|
||||||
assert.Equal(t, defaultRetryMinInterval, cfg.RetryMinInterval)
|
|
||||||
assert.Equal(t, defaultRetryMaxInterval, cfg.RetryMaxInterval)
|
|
||||||
assert.Equal(t, defaultReconnectInitialInterval, cfg.ReconnectInitialInterval)
|
|
||||||
assert.Equal(t, defaultReconnectMaxInterval, cfg.ReconnectMaxInterval)
|
|
||||||
assert.Equal(t, defaultCqlVersion, cfg.CQLVersion)
|
|
||||||
assert.Empty(t, cfg.Hosts)
|
|
||||||
assert.Empty(t, cfg.Keyspace)
|
|
||||||
assert.Empty(t, cfg.Username)
|
|
||||||
assert.Empty(t, cfg.Password)
|
|
||||||
assert.False(t, cfg.UseAuth)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithHosts(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
hosts []string
|
|
||||||
expected []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "single host",
|
|
||||||
hosts: []string{"localhost"},
|
|
||||||
expected: []string{"localhost"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple hosts",
|
|
||||||
hosts: []string{"localhost", "127.0.0.1", "192.168.1.1"},
|
|
||||||
expected: []string{"localhost", "127.0.0.1", "192.168.1.1"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty hosts",
|
|
||||||
hosts: []string{},
|
|
||||||
expected: []string{},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "host with port",
|
|
||||||
hosts: []string{"localhost:9042"},
|
|
||||||
expected: []string{"localhost:9042"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "host with domain",
|
|
||||||
hosts: []string{"cassandra.example.com"},
|
|
||||||
expected: []string{"cassandra.example.com"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithHosts(tt.hosts...)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.Hosts)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithPort(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
port int
|
|
||||||
expected int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "default port",
|
|
||||||
port: 9042,
|
|
||||||
expected: 9042,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "custom port",
|
|
||||||
port: 9043,
|
|
||||||
expected: 9043,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero port",
|
|
||||||
port: 0,
|
|
||||||
expected: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative port",
|
|
||||||
port: -1,
|
|
||||||
expected: -1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "high port number",
|
|
||||||
port: 65535,
|
|
||||||
expected: 65535,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithPort(tt.port)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.Port)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithKeyspace(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
keyspace string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid keyspace",
|
|
||||||
keyspace: "my_keyspace",
|
|
||||||
expected: "my_keyspace",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty keyspace",
|
|
||||||
keyspace: "",
|
|
||||||
expected: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "keyspace with underscore",
|
|
||||||
keyspace: "test_keyspace_1",
|
|
||||||
expected: "test_keyspace_1",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "keyspace with numbers",
|
|
||||||
keyspace: "keyspace123",
|
|
||||||
expected: "keyspace123",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "long keyspace name",
|
|
||||||
keyspace: "very_long_keyspace_name_that_might_exist",
|
|
||||||
expected: "very_long_keyspace_name_that_might_exist",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithKeyspace(tt.keyspace)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.Keyspace)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithAuth(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
username string
|
|
||||||
password string
|
|
||||||
expectedUser string
|
|
||||||
expectedPass string
|
|
||||||
expectedUseAuth bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid credentials",
|
|
||||||
username: "admin",
|
|
||||||
password: "password123",
|
|
||||||
expectedUser: "admin",
|
|
||||||
expectedPass: "password123",
|
|
||||||
expectedUseAuth: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty username",
|
|
||||||
username: "",
|
|
||||||
password: "password",
|
|
||||||
expectedUser: "",
|
|
||||||
expectedPass: "password",
|
|
||||||
expectedUseAuth: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty password",
|
|
||||||
username: "admin",
|
|
||||||
password: "",
|
|
||||||
expectedUser: "admin",
|
|
||||||
expectedPass: "",
|
|
||||||
expectedUseAuth: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "both empty",
|
|
||||||
username: "",
|
|
||||||
password: "",
|
|
||||||
expectedUser: "",
|
|
||||||
expectedPass: "",
|
|
||||||
expectedUseAuth: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "special characters in password",
|
|
||||||
username: "user",
|
|
||||||
password: "p@ssw0rd!#$%",
|
|
||||||
expectedUser: "user",
|
|
||||||
expectedPass: "p@ssw0rd!#$%",
|
|
||||||
expectedUseAuth: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "long username and password",
|
|
||||||
username: "very_long_username_that_might_exist",
|
|
||||||
password: "very_long_password_that_might_exist",
|
|
||||||
expectedUser: "very_long_username_that_might_exist",
|
|
||||||
expectedPass: "very_long_password_that_might_exist",
|
|
||||||
expectedUseAuth: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithAuth(tt.username, tt.password)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expectedUser, cfg.Username)
|
|
||||||
assert.Equal(t, tt.expectedPass, cfg.Password)
|
|
||||||
assert.Equal(t, tt.expectedUseAuth, cfg.UseAuth)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithConsistency(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
consistency gocql.Consistency
|
|
||||||
expected gocql.Consistency
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Quorum consistency",
|
|
||||||
consistency: gocql.Quorum,
|
|
||||||
expected: gocql.Quorum,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "One consistency",
|
|
||||||
consistency: gocql.One,
|
|
||||||
expected: gocql.One,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "All consistency",
|
|
||||||
consistency: gocql.All,
|
|
||||||
expected: gocql.All,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Any consistency",
|
|
||||||
consistency: gocql.Any,
|
|
||||||
expected: gocql.Any,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "LocalQuorum consistency",
|
|
||||||
consistency: gocql.LocalQuorum,
|
|
||||||
expected: gocql.LocalQuorum,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "EachQuorum consistency",
|
|
||||||
consistency: gocql.EachQuorum,
|
|
||||||
expected: gocql.EachQuorum,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "LocalOne consistency",
|
|
||||||
consistency: gocql.LocalOne,
|
|
||||||
expected: gocql.LocalOne,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Two consistency",
|
|
||||||
consistency: gocql.Two,
|
|
||||||
expected: gocql.Two,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithConsistency(tt.consistency)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.Consistency)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithConnectTimeoutSec(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
timeout int
|
|
||||||
expected int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid timeout",
|
|
||||||
timeout: 10,
|
|
||||||
expected: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero timeout should use default",
|
|
||||||
timeout: 0,
|
|
||||||
expected: defaultTimeoutSec,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative timeout should use default",
|
|
||||||
timeout: -1,
|
|
||||||
expected: defaultTimeoutSec,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large timeout",
|
|
||||||
timeout: 300,
|
|
||||||
expected: 300,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "small timeout",
|
|
||||||
timeout: 1,
|
|
||||||
expected: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "very large timeout",
|
|
||||||
timeout: 3600,
|
|
||||||
expected: 3600,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithConnectTimeoutSec(tt.timeout)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.ConnectTimeoutSec)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithNumConns(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
numConns int
|
|
||||||
expected int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid numConns",
|
|
||||||
numConns: 10,
|
|
||||||
expected: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero numConns should use default",
|
|
||||||
numConns: 0,
|
|
||||||
expected: defaultNumConns,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative numConns should use default",
|
|
||||||
numConns: -1,
|
|
||||||
expected: defaultNumConns,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large numConns",
|
|
||||||
numConns: 100,
|
|
||||||
expected: 100,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "small numConns",
|
|
||||||
numConns: 1,
|
|
||||||
expected: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "very large numConns",
|
|
||||||
numConns: 1000,
|
|
||||||
expected: 1000,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithNumConns(tt.numConns)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.NumConns)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithMaxRetries(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
maxRetries int
|
|
||||||
expected int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid maxRetries",
|
|
||||||
maxRetries: 3,
|
|
||||||
expected: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero maxRetries should use default",
|
|
||||||
maxRetries: 0,
|
|
||||||
expected: defaultMaxRetries,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative maxRetries should use default",
|
|
||||||
maxRetries: -1,
|
|
||||||
expected: defaultMaxRetries,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large maxRetries",
|
|
||||||
maxRetries: 10,
|
|
||||||
expected: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "small maxRetries",
|
|
||||||
maxRetries: 1,
|
|
||||||
expected: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "very large maxRetries",
|
|
||||||
maxRetries: 100,
|
|
||||||
expected: 100,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithMaxRetries(tt.maxRetries)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.MaxRetries)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithRetryMinInterval(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
duration time.Duration
|
|
||||||
expected time.Duration
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid duration",
|
|
||||||
duration: 1 * time.Second,
|
|
||||||
expected: 1 * time.Second,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero duration should use default",
|
|
||||||
duration: 0,
|
|
||||||
expected: defaultRetryMinInterval,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative duration should use default",
|
|
||||||
duration: -1 * time.Second,
|
|
||||||
expected: defaultRetryMinInterval,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "milliseconds",
|
|
||||||
duration: 500 * time.Millisecond,
|
|
||||||
expected: 500 * time.Millisecond,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "minutes",
|
|
||||||
duration: 5 * time.Minute,
|
|
||||||
expected: 5 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hours",
|
|
||||||
duration: 1 * time.Hour,
|
|
||||||
expected: 1 * time.Hour,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithRetryMinInterval(tt.duration)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.RetryMinInterval)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithRetryMaxInterval(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
duration time.Duration
|
|
||||||
expected time.Duration
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid duration",
|
|
||||||
duration: 30 * time.Second,
|
|
||||||
expected: 30 * time.Second,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero duration should use default",
|
|
||||||
duration: 0,
|
|
||||||
expected: defaultRetryMaxInterval,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative duration should use default",
|
|
||||||
duration: -1 * time.Second,
|
|
||||||
expected: defaultRetryMaxInterval,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "milliseconds",
|
|
||||||
duration: 1000 * time.Millisecond,
|
|
||||||
expected: 1000 * time.Millisecond,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "minutes",
|
|
||||||
duration: 10 * time.Minute,
|
|
||||||
expected: 10 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hours",
|
|
||||||
duration: 2 * time.Hour,
|
|
||||||
expected: 2 * time.Hour,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithRetryMaxInterval(tt.duration)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.RetryMaxInterval)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithReconnectInitialInterval(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
duration time.Duration
|
|
||||||
expected time.Duration
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid duration",
|
|
||||||
duration: 1 * time.Second,
|
|
||||||
expected: 1 * time.Second,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero duration should use default",
|
|
||||||
duration: 0,
|
|
||||||
expected: defaultReconnectInitialInterval,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative duration should use default",
|
|
||||||
duration: -1 * time.Second,
|
|
||||||
expected: defaultReconnectInitialInterval,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "milliseconds",
|
|
||||||
duration: 500 * time.Millisecond,
|
|
||||||
expected: 500 * time.Millisecond,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "minutes",
|
|
||||||
duration: 2 * time.Minute,
|
|
||||||
expected: 2 * time.Minute,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithReconnectInitialInterval(tt.duration)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.ReconnectInitialInterval)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithReconnectMaxInterval(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
duration time.Duration
|
|
||||||
expected time.Duration
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid duration",
|
|
||||||
duration: 60 * time.Second,
|
|
||||||
expected: 60 * time.Second,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero duration should use default",
|
|
||||||
duration: 0,
|
|
||||||
expected: defaultReconnectMaxInterval,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative duration should use default",
|
|
||||||
duration: -1 * time.Second,
|
|
||||||
expected: defaultReconnectMaxInterval,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "milliseconds",
|
|
||||||
duration: 5000 * time.Millisecond,
|
|
||||||
expected: 5000 * time.Millisecond,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "minutes",
|
|
||||||
duration: 5 * time.Minute,
|
|
||||||
expected: 5 * time.Minute,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hours",
|
|
||||||
duration: 1 * time.Hour,
|
|
||||||
expected: 1 * time.Hour,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithReconnectMaxInterval(tt.duration)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.ReconnectMaxInterval)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithCQLVersion(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
version string
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid version",
|
|
||||||
version: "3.0.0",
|
|
||||||
expected: "3.0.0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty version should use default",
|
|
||||||
version: "",
|
|
||||||
expected: defaultCqlVersion,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 3.1.0",
|
|
||||||
version: "3.1.0",
|
|
||||||
expected: "3.1.0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 3.4.0",
|
|
||||||
version: "3.4.0",
|
|
||||||
expected: "3.4.0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version 4.0.0",
|
|
||||||
version: "4.0.0",
|
|
||||||
expected: "4.0.0",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version with build",
|
|
||||||
version: "3.0.0-beta",
|
|
||||||
expected: "3.0.0-beta",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "version with snapshot",
|
|
||||||
version: "3.0.0-SNAPSHOT",
|
|
||||||
expected: "3.0.0-SNAPSHOT",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opt := WithCQLVersion(tt.version)
|
|
||||||
opt(cfg)
|
|
||||||
assert.Equal(t, tt.expected, cfg.CQLVersion)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOption_Combination(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
opts []Option
|
|
||||||
validate func(*testing.T, *config)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "all options",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost", "127.0.0.1"),
|
|
||||||
WithPort(9042),
|
|
||||||
WithKeyspace("test_keyspace"),
|
|
||||||
WithAuth("user", "pass"),
|
|
||||||
WithConsistency(gocql.Quorum),
|
|
||||||
WithConnectTimeoutSec(10),
|
|
||||||
WithNumConns(10),
|
|
||||||
WithMaxRetries(3),
|
|
||||||
WithRetryMinInterval(1 * time.Second),
|
|
||||||
WithRetryMaxInterval(30 * time.Second),
|
|
||||||
WithReconnectInitialInterval(1 * time.Second),
|
|
||||||
WithReconnectMaxInterval(60 * time.Second),
|
|
||||||
WithCQLVersion("3.0.0"),
|
|
||||||
},
|
|
||||||
validate: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, []string{"localhost", "127.0.0.1"}, c.Hosts)
|
|
||||||
assert.Equal(t, 9042, c.Port)
|
|
||||||
assert.Equal(t, "test_keyspace", c.Keyspace)
|
|
||||||
assert.Equal(t, "user", c.Username)
|
|
||||||
assert.Equal(t, "pass", c.Password)
|
|
||||||
assert.True(t, c.UseAuth)
|
|
||||||
assert.Equal(t, gocql.Quorum, c.Consistency)
|
|
||||||
assert.Equal(t, 10, c.ConnectTimeoutSec)
|
|
||||||
assert.Equal(t, 10, c.NumConns)
|
|
||||||
assert.Equal(t, 3, c.MaxRetries)
|
|
||||||
assert.Equal(t, 1*time.Second, c.RetryMinInterval)
|
|
||||||
assert.Equal(t, 30*time.Second, c.RetryMaxInterval)
|
|
||||||
assert.Equal(t, 1*time.Second, c.ReconnectInitialInterval)
|
|
||||||
assert.Equal(t, 60*time.Second, c.ReconnectMaxInterval)
|
|
||||||
assert.Equal(t, "3.0.0", c.CQLVersion)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "minimal options",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
},
|
|
||||||
validate: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, []string{"localhost"}, c.Hosts)
|
|
||||||
// 其他應該使用預設值
|
|
||||||
assert.Equal(t, defaultPort, c.Port)
|
|
||||||
assert.Equal(t, defaultConsistency, c.Consistency)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "options with zero values should use defaults",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
WithConnectTimeoutSec(0),
|
|
||||||
WithNumConns(0),
|
|
||||||
WithMaxRetries(0),
|
|
||||||
WithRetryMinInterval(0),
|
|
||||||
WithRetryMaxInterval(0),
|
|
||||||
WithReconnectInitialInterval(0),
|
|
||||||
WithReconnectMaxInterval(0),
|
|
||||||
WithCQLVersion(""),
|
|
||||||
},
|
|
||||||
validate: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, []string{"localhost"}, c.Hosts)
|
|
||||||
assert.Equal(t, defaultTimeoutSec, c.ConnectTimeoutSec)
|
|
||||||
assert.Equal(t, defaultNumConns, c.NumConns)
|
|
||||||
assert.Equal(t, defaultMaxRetries, c.MaxRetries)
|
|
||||||
assert.Equal(t, defaultRetryMinInterval, c.RetryMinInterval)
|
|
||||||
assert.Equal(t, defaultRetryMaxInterval, c.RetryMaxInterval)
|
|
||||||
assert.Equal(t, defaultReconnectInitialInterval, c.ReconnectInitialInterval)
|
|
||||||
assert.Equal(t, defaultReconnectMaxInterval, c.ReconnectMaxInterval)
|
|
||||||
assert.Equal(t, defaultCqlVersion, c.CQLVersion)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "options with negative values should use defaults",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
WithConnectTimeoutSec(-1),
|
|
||||||
WithNumConns(-1),
|
|
||||||
WithMaxRetries(-1),
|
|
||||||
WithRetryMinInterval(-1 * time.Second),
|
|
||||||
WithRetryMaxInterval(-1 * time.Second),
|
|
||||||
WithReconnectInitialInterval(-1 * time.Second),
|
|
||||||
WithReconnectMaxInterval(-1 * time.Second),
|
|
||||||
},
|
|
||||||
validate: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, []string{"localhost"}, c.Hosts)
|
|
||||||
assert.Equal(t, defaultTimeoutSec, c.ConnectTimeoutSec)
|
|
||||||
assert.Equal(t, defaultNumConns, c.NumConns)
|
|
||||||
assert.Equal(t, defaultMaxRetries, c.MaxRetries)
|
|
||||||
assert.Equal(t, defaultRetryMinInterval, c.RetryMinInterval)
|
|
||||||
assert.Equal(t, defaultRetryMaxInterval, c.RetryMaxInterval)
|
|
||||||
assert.Equal(t, defaultReconnectInitialInterval, c.ReconnectInitialInterval)
|
|
||||||
assert.Equal(t, defaultReconnectMaxInterval, c.ReconnectMaxInterval)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple options applied in sequence",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("host1"),
|
|
||||||
WithHosts("host2", "host3"), // 應該覆蓋
|
|
||||||
WithPort(9042),
|
|
||||||
WithPort(9043), // 應該覆蓋
|
|
||||||
},
|
|
||||||
validate: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, []string{"host2", "host3"}, c.Hosts)
|
|
||||||
assert.Equal(t, 9043, c.Port)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
for _, opt := range tt.opts {
|
|
||||||
opt(cfg)
|
|
||||||
}
|
|
||||||
tt.validate(t, cfg)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOption_Type(t *testing.T) {
|
|
||||||
t.Run("all options should return Option type", func(t *testing.T) {
|
|
||||||
var opt Option
|
|
||||||
|
|
||||||
opt = WithHosts("localhost")
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithPort(9042)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithKeyspace("test")
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithAuth("user", "pass")
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithConsistency(gocql.Quorum)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithConnectTimeoutSec(10)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithNumConns(10)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithMaxRetries(3)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithRetryMinInterval(1 * time.Second)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithRetryMaxInterval(30 * time.Second)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithReconnectInitialInterval(1 * time.Second)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithReconnectMaxInterval(60 * time.Second)
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
|
|
||||||
opt = WithCQLVersion("3.0.0")
|
|
||||||
assert.NotNil(t, opt)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOption_EdgeCases(t *testing.T) {
|
|
||||||
t.Run("empty option slice", func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
opts := []Option{}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(cfg)
|
|
||||||
}
|
|
||||||
// 應該保持預設值
|
|
||||||
assert.Equal(t, defaultPort, cfg.Port)
|
|
||||||
assert.Equal(t, defaultConsistency, cfg.Consistency)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("zero value option function", func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
var opt Option
|
|
||||||
// 零值的 Option 是 nil,調用會 panic,所以不應該調用
|
|
||||||
// 這裡只是驗證零值不會影響配置
|
|
||||||
_ = opt
|
|
||||||
// 應該保持預設值
|
|
||||||
assert.Equal(t, defaultPort, cfg.Port)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("very long strings", func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
longString := string(make([]byte, 10000))
|
|
||||||
WithKeyspace(longString)(cfg)
|
|
||||||
assert.Equal(t, longString, cfg.Keyspace)
|
|
||||||
|
|
||||||
WithAuth(longString, longString)(cfg)
|
|
||||||
assert.Equal(t, longString, cfg.Username)
|
|
||||||
assert.Equal(t, longString, cfg.Password)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("special characters in strings", func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
specialChars := "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
|
||||||
WithKeyspace(specialChars)(cfg)
|
|
||||||
assert.Equal(t, specialChars, cfg.Keyspace)
|
|
||||||
|
|
||||||
WithAuth(specialChars, specialChars)(cfg)
|
|
||||||
assert.Equal(t, specialChars, cfg.Username)
|
|
||||||
assert.Equal(t, specialChars, cfg.Password)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOption_RealWorldScenarios(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
scenario string
|
|
||||||
opts []Option
|
|
||||||
validate func(*testing.T, *config)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "production-like configuration",
|
|
||||||
scenario: "typical production setup",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("cassandra1.example.com", "cassandra2.example.com", "cassandra3.example.com"),
|
|
||||||
WithPort(9042),
|
|
||||||
WithKeyspace("production_keyspace"),
|
|
||||||
WithAuth("prod_user", "secure_password"),
|
|
||||||
WithConsistency(gocql.Quorum),
|
|
||||||
WithConnectTimeoutSec(30),
|
|
||||||
WithNumConns(50),
|
|
||||||
WithMaxRetries(5),
|
|
||||||
},
|
|
||||||
validate: func(t *testing.T, c *config) {
|
|
||||||
assert.Len(t, c.Hosts, 3)
|
|
||||||
assert.Equal(t, 9042, c.Port)
|
|
||||||
assert.Equal(t, "production_keyspace", c.Keyspace)
|
|
||||||
assert.True(t, c.UseAuth)
|
|
||||||
assert.Equal(t, gocql.Quorum, c.Consistency)
|
|
||||||
assert.Equal(t, 30, c.ConnectTimeoutSec)
|
|
||||||
assert.Equal(t, 50, c.NumConns)
|
|
||||||
assert.Equal(t, 5, c.MaxRetries)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "development configuration",
|
|
||||||
scenario: "local development setup",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("localhost"),
|
|
||||||
WithKeyspace("dev_keyspace"),
|
|
||||||
},
|
|
||||||
validate: func(t *testing.T, c *config) {
|
|
||||||
assert.Equal(t, []string{"localhost"}, c.Hosts)
|
|
||||||
assert.Equal(t, "dev_keyspace", c.Keyspace)
|
|
||||||
assert.False(t, c.UseAuth)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "high availability configuration",
|
|
||||||
scenario: "HA setup with multiple hosts",
|
|
||||||
opts: []Option{
|
|
||||||
WithHosts("node1", "node2", "node3", "node4", "node5"),
|
|
||||||
WithConsistency(gocql.All),
|
|
||||||
WithMaxRetries(10),
|
|
||||||
},
|
|
||||||
validate: func(t *testing.T, c *config) {
|
|
||||||
assert.Len(t, c.Hosts, 5)
|
|
||||||
assert.Equal(t, gocql.All, c.Consistency)
|
|
||||||
assert.Equal(t, 10, c.MaxRetries)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cfg := defaultConfig()
|
|
||||||
for _, opt := range tt.opts {
|
|
||||||
opt(cfg)
|
|
||||||
}
|
|
||||||
tt.validate(t, cfg)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,226 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
"github.com/scylladb/gocqlx/v2/qb"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Condition 定義查詢條件介面
|
|
||||||
type Condition interface {
|
|
||||||
Build() (qb.Cmp, map[string]any)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Eq 等於條件
|
|
||||||
func Eq(column string, value any) Condition {
|
|
||||||
return &eqCondition{column: column, value: value}
|
|
||||||
}
|
|
||||||
|
|
||||||
type eqCondition struct {
|
|
||||||
column string
|
|
||||||
value any
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *eqCondition) Build() (qb.Cmp, map[string]any) {
|
|
||||||
return qb.Eq(c.column), map[string]any{c.column: c.value}
|
|
||||||
}
|
|
||||||
|
|
||||||
// In IN 條件
|
|
||||||
func In(column string, values []any) Condition {
|
|
||||||
return &inCondition{column: column, values: values}
|
|
||||||
}
|
|
||||||
|
|
||||||
type inCondition struct {
|
|
||||||
column string
|
|
||||||
values []any
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *inCondition) Build() (qb.Cmp, map[string]any) {
|
|
||||||
return qb.In(c.column), map[string]any{c.column: c.values}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Gt 大於條件
|
|
||||||
func Gt(column string, value any) Condition {
|
|
||||||
return >Condition{column: column, value: value}
|
|
||||||
}
|
|
||||||
|
|
||||||
type gtCondition struct {
|
|
||||||
column string
|
|
||||||
value any
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *gtCondition) Build() (qb.Cmp, map[string]any) {
|
|
||||||
return qb.Gt(c.column), map[string]any{c.column: c.value}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lt 小於條件
|
|
||||||
func Lt(column string, value any) Condition {
|
|
||||||
return <Condition{column: column, value: value}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ltCondition struct {
|
|
||||||
column string
|
|
||||||
value any
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ltCondition) Build() (qb.Cmp, map[string]any) {
|
|
||||||
return qb.Lt(c.column), map[string]any{c.column: c.value}
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryBuilder 定義查詢構建器介面
|
|
||||||
type QueryBuilder[T Table] interface {
|
|
||||||
Where(condition Condition) QueryBuilder[T]
|
|
||||||
OrderBy(column string, order Order) QueryBuilder[T]
|
|
||||||
Limit(n int) QueryBuilder[T]
|
|
||||||
Select(columns ...string) QueryBuilder[T]
|
|
||||||
Scan(ctx context.Context, dest *[]T) error
|
|
||||||
One(ctx context.Context) (T, error)
|
|
||||||
Count(ctx context.Context) (int64, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryBuilder 是 QueryBuilder 的具體實作
|
|
||||||
type queryBuilder[T Table] struct {
|
|
||||||
repo *repository[T]
|
|
||||||
conditions []Condition
|
|
||||||
orders []orderBy
|
|
||||||
limit int
|
|
||||||
columns []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type orderBy struct {
|
|
||||||
column string
|
|
||||||
order Order
|
|
||||||
}
|
|
||||||
|
|
||||||
// newQueryBuilder 創建新的查詢構建器
|
|
||||||
func newQueryBuilder[T Table](repo *repository[T]) QueryBuilder[T] {
|
|
||||||
return &queryBuilder[T]{
|
|
||||||
repo: repo,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Where 添加 WHERE 條件
|
|
||||||
func (q *queryBuilder[T]) Where(condition Condition) QueryBuilder[T] {
|
|
||||||
q.conditions = append(q.conditions, condition)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// OrderBy 添加排序
|
|
||||||
func (q *queryBuilder[T]) OrderBy(column string, order Order) QueryBuilder[T] {
|
|
||||||
q.orders = append(q.orders, orderBy{column: column, order: order})
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Limit 設置限制
|
|
||||||
func (q *queryBuilder[T]) Limit(n int) QueryBuilder[T] {
|
|
||||||
q.limit = n
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select 指定要查詢的欄位
|
|
||||||
func (q *queryBuilder[T]) Select(columns ...string) QueryBuilder[T] {
|
|
||||||
q.columns = append(q.columns, columns...)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan 執行查詢並將結果掃描到 dest
|
|
||||||
func (q *queryBuilder[T]) Scan(ctx context.Context, dest *[]T) error {
|
|
||||||
if dest == nil {
|
|
||||||
return ErrInvalidInput.WithTable(q.repo.table).WithError(
|
|
||||||
fmt.Errorf("destination cannot be nil"),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := qb.Select(q.repo.table)
|
|
||||||
|
|
||||||
// 添加欄位
|
|
||||||
if len(q.columns) > 0 {
|
|
||||||
builder = builder.Columns(q.columns...)
|
|
||||||
} else {
|
|
||||||
builder = builder.Columns(q.repo.metadata.Columns...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加條件
|
|
||||||
bindMap := make(map[string]any)
|
|
||||||
var cmps []qb.Cmp
|
|
||||||
for _, cond := range q.conditions {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
cmps = append(cmps, cmp)
|
|
||||||
for k, v := range binds {
|
|
||||||
bindMap[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(cmps) > 0 {
|
|
||||||
builder = builder.Where(cmps...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加排序
|
|
||||||
for _, o := range q.orders {
|
|
||||||
order := qb.ASC
|
|
||||||
if o.order == DESC {
|
|
||||||
order = qb.DESC
|
|
||||||
}
|
|
||||||
|
|
||||||
builder = builder.OrderBy(o.column, order)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加限制
|
|
||||||
if q.limit > 0 {
|
|
||||||
builder = builder.Limit(uint(q.limit))
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt, names := builder.ToCql()
|
|
||||||
query := q.repo.db.withContextAndTimestamp(ctx,
|
|
||||||
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
|
|
||||||
|
|
||||||
return query.SelectRelease(dest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// One 執行查詢並返回單筆結果
|
|
||||||
func (q *queryBuilder[T]) One(ctx context.Context) (T, error) {
|
|
||||||
var zero T
|
|
||||||
q.limit = 1
|
|
||||||
|
|
||||||
var results []T
|
|
||||||
if err := q.Scan(ctx, &results); err != nil {
|
|
||||||
return zero, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(results) == 0 {
|
|
||||||
return zero, ErrNotFound.WithTable(q.repo.table)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count 計算符合條件的記錄數
|
|
||||||
func (q *queryBuilder[T]) Count(ctx context.Context) (int64, error) {
|
|
||||||
builder := qb.Select(q.repo.table).Columns("COUNT(*)")
|
|
||||||
|
|
||||||
// 添加條件
|
|
||||||
bindMap := make(map[string]any)
|
|
||||||
var cmps []qb.Cmp
|
|
||||||
for _, cond := range q.conditions {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
cmps = append(cmps, cmp)
|
|
||||||
for k, v := range binds {
|
|
||||||
bindMap[k] = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(cmps) > 0 {
|
|
||||||
builder = builder.Where(cmps...)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt, names := builder.ToCql()
|
|
||||||
query := q.repo.db.withContextAndTimestamp(ctx,
|
|
||||||
q.repo.db.session.Query(stmt, names).BindMap(bindMap))
|
|
||||||
|
|
||||||
var count int64
|
|
||||||
err := query.GetRelease(&count)
|
|
||||||
if err == gocql.ErrNotFound {
|
|
||||||
return 0, nil // COUNT 查詢不會返回 ErrNotFound,但為了安全起見
|
|
||||||
}
|
|
||||||
return count, err
|
|
||||||
}
|
|
||||||
|
|
@ -1,520 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/scylladb/gocqlx/v2/qb"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEq(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
column string
|
|
||||||
value any
|
|
||||||
validate func(*testing.T, Condition)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "string value",
|
|
||||||
column: "name",
|
|
||||||
value: "Alice",
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, "Alice", binds["name"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "int value",
|
|
||||||
column: "age",
|
|
||||||
value: 25,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, 25, binds["age"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil value",
|
|
||||||
column: "description",
|
|
||||||
value: nil,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Nil(t, binds["description"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty string",
|
|
||||||
column: "email",
|
|
||||||
value: "",
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, "", binds["email"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "boolean value",
|
|
||||||
column: "active",
|
|
||||||
value: true,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, true, binds["active"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cond := Eq(tt.column, tt.value)
|
|
||||||
assert.NotNil(t, cond)
|
|
||||||
tt.validate(t, cond)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIn(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
column string
|
|
||||||
values []any
|
|
||||||
validate func(*testing.T, Condition)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "string values",
|
|
||||||
column: "status",
|
|
||||||
values: []any{"active", "pending", "completed"},
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, []any{"active", "pending", "completed"}, binds["status"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "int values",
|
|
||||||
column: "ids",
|
|
||||||
values: []any{1, 2, 3, 4, 5},
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, []any{1, 2, 3, 4, 5}, binds["ids"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty slice",
|
|
||||||
column: "tags",
|
|
||||||
values: []any{},
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, []any{}, binds["tags"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single value",
|
|
||||||
column: "id",
|
|
||||||
values: []any{1},
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, []any{1}, binds["id"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed types",
|
|
||||||
column: "values",
|
|
||||||
values: []any{"string", 123, true},
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, []any{"string", 123, true}, binds["values"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cond := In(tt.column, tt.values)
|
|
||||||
assert.NotNil(t, cond)
|
|
||||||
tt.validate(t, cond)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
column string
|
|
||||||
value any
|
|
||||||
validate func(*testing.T, Condition)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "int value",
|
|
||||||
column: "age",
|
|
||||||
value: 18,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, 18, binds["age"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "float value",
|
|
||||||
column: "price",
|
|
||||||
value: 99.99,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, 99.99, binds["price"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero value",
|
|
||||||
column: "count",
|
|
||||||
value: 0,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, 0, binds["count"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cond := Gt(tt.column, tt.value)
|
|
||||||
assert.NotNil(t, cond)
|
|
||||||
tt.validate(t, cond)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
column string
|
|
||||||
value any
|
|
||||||
validate func(*testing.T, Condition)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "int value",
|
|
||||||
column: "age",
|
|
||||||
value: 65,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, 65, binds["age"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "float value",
|
|
||||||
column: "price",
|
|
||||||
value: 199.99,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, 199.99, binds["price"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative value",
|
|
||||||
column: "balance",
|
|
||||||
value: -100,
|
|
||||||
validate: func(t *testing.T, cond Condition) {
|
|
||||||
cmp, binds := cond.Build()
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, -100, binds["balance"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cond := Lt(tt.column, tt.value)
|
|
||||||
assert.NotNil(t, cond)
|
|
||||||
tt.validate(t, cond)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCondition_Build(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cond Condition
|
|
||||||
validate func(*testing.T, qb.Cmp, map[string]any)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "Eq condition",
|
|
||||||
cond: Eq("name", "test"),
|
|
||||||
validate: func(t *testing.T, cmp qb.Cmp, binds map[string]any) {
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, "test", binds["name"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "In condition",
|
|
||||||
cond: In("ids", []any{1, 2, 3}),
|
|
||||||
validate: func(t *testing.T, cmp qb.Cmp, binds map[string]any) {
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, []any{1, 2, 3}, binds["ids"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Gt condition",
|
|
||||||
cond: Gt("age", 18),
|
|
||||||
validate: func(t *testing.T, cmp qb.Cmp, binds map[string]any) {
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, 18, binds["age"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Lt condition",
|
|
||||||
cond: Lt("price", 100),
|
|
||||||
validate: func(t *testing.T, cmp qb.Cmp, binds map[string]any) {
|
|
||||||
assert.NotNil(t, cmp)
|
|
||||||
assert.Equal(t, 100, binds["price"])
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
cmp, binds := tt.cond.Build()
|
|
||||||
tt.validate(t, cmp, binds)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Where(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
condition Condition
|
|
||||||
validate func(*testing.T, *queryBuilder[testUser])
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "single condition",
|
|
||||||
condition: Eq("name", "Alice"),
|
|
||||||
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
|
|
||||||
assert.Len(t, qb.conditions, 1)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple conditions",
|
|
||||||
condition: In("status", []any{"active", "pending"}),
|
|
||||||
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
|
|
||||||
// 添加多個條件
|
|
||||||
cond := In("status", []any{"active", "pending"})
|
|
||||||
qb.Where(Eq("name", "test"))
|
|
||||||
qb.Where(cond)
|
|
||||||
assert.Len(t, qb.conditions, 2)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 repository,但我們可以測試鏈式調用
|
|
||||||
// 實際的執行需要資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_OrderBy(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
column string
|
|
||||||
order Order
|
|
||||||
validate func(*testing.T, *queryBuilder[testUser])
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ASC order",
|
|
||||||
column: "created_at",
|
|
||||||
order: ASC,
|
|
||||||
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
|
|
||||||
assert.Len(t, qb.orders, 1)
|
|
||||||
assert.Equal(t, "created_at", qb.orders[0].column)
|
|
||||||
assert.Equal(t, ASC, qb.orders[0].order)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DESC order",
|
|
||||||
column: "updated_at",
|
|
||||||
order: DESC,
|
|
||||||
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
|
|
||||||
assert.Len(t, qb.orders, 1)
|
|
||||||
assert.Equal(t, "updated_at", qb.orders[0].column)
|
|
||||||
assert.Equal(t, DESC, qb.orders[0].order)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple orders",
|
|
||||||
column: "name",
|
|
||||||
order: ASC,
|
|
||||||
validate: func(t *testing.T, qb *queryBuilder[testUser]) {
|
|
||||||
qb.OrderBy("created_at", DESC)
|
|
||||||
qb.OrderBy("name", ASC)
|
|
||||||
assert.Len(t, qb.orders, 2)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 repository
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Limit(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
limit int
|
|
||||||
expected int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "positive limit",
|
|
||||||
limit: 10,
|
|
||||||
expected: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero limit",
|
|
||||||
limit: 0,
|
|
||||||
expected: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large limit",
|
|
||||||
limit: 1000,
|
|
||||||
expected: 1000,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative limit",
|
|
||||||
limit: -1,
|
|
||||||
expected: -1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 repository
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Select(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
columns []string
|
|
||||||
expected int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "single column",
|
|
||||||
columns: []string{"name"},
|
|
||||||
expected: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple columns",
|
|
||||||
columns: []string{"name", "email", "age"},
|
|
||||||
expected: 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty columns",
|
|
||||||
columns: []string{},
|
|
||||||
expected: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "duplicate columns",
|
|
||||||
columns: []string{"name", "name"},
|
|
||||||
expected: 2,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 repository
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Chaining(t *testing.T) {
|
|
||||||
t.Run("chain multiple methods", func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 repository
|
|
||||||
// 實際的執行需要資料庫連接
|
|
||||||
// 這裡只是展示測試結構
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Scan_ErrorCases(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "nil destination",
|
|
||||||
description: "should return error when destination is nil",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid query",
|
|
||||||
description: "should return error when query is invalid",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_One_ErrorCases(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "no results",
|
|
||||||
description: "should return ErrNotFound when no results found",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "query error",
|
|
||||||
description: "should return error when query fails",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Count_ErrorCases(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "query error",
|
|
||||||
description: "should return error when query fails",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "ErrNotFound should return 0",
|
|
||||||
description: "should return 0 when ErrNotFound",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,265 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
"github.com/scylladb/gocqlx/v2"
|
|
||||||
"github.com/scylladb/gocqlx/v2/qb"
|
|
||||||
"github.com/scylladb/gocqlx/v2/table"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Repository 定義資料存取介面(小介面,符合 M3)
|
|
||||||
type Repository[T Table] interface {
|
|
||||||
Insert(ctx context.Context, doc T) error
|
|
||||||
Get(ctx context.Context, pk any) (T, error)
|
|
||||||
Update(ctx context.Context, doc T) error
|
|
||||||
Delete(ctx context.Context, pk any) error
|
|
||||||
InsertMany(ctx context.Context, docs []T) error
|
|
||||||
Query() QueryBuilder[T]
|
|
||||||
TryLock(ctx context.Context, doc T, opts ...LockOption) error
|
|
||||||
UnLock(ctx context.Context, doc T) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// repository 是 Repository 的具體實作
|
|
||||||
type repository[T Table] struct {
|
|
||||||
db *DB
|
|
||||||
keyspace string
|
|
||||||
table string
|
|
||||||
metadata table.Metadata
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRepository 獲取指定類型的 Repository
|
|
||||||
// keyspace 如果為空,使用預設 keyspace
|
|
||||||
func NewRepository[T Table](db *DB, keyspace string) (Repository[T], error) {
|
|
||||||
if keyspace == "" {
|
|
||||||
keyspace = db.defaultKeyspace
|
|
||||||
}
|
|
||||||
|
|
||||||
var zero T
|
|
||||||
metadata, err := generateMetadata(zero, keyspace)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to generate metadata: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &repository[T]{
|
|
||||||
db: db,
|
|
||||||
keyspace: keyspace,
|
|
||||||
table: metadata.Name,
|
|
||||||
metadata: metadata,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert 插入單筆資料
|
|
||||||
func (r *repository[T]) Insert(ctx context.Context, doc T) error {
|
|
||||||
t := table.New(r.metadata)
|
|
||||||
q := r.db.withContextAndTimestamp(ctx,
|
|
||||||
r.db.session.Query(t.Insert()).BindStruct(doc))
|
|
||||||
return q.ExecRelease()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get 根據主鍵查詢單筆資料
|
|
||||||
// 注意:pk 必須是完整的 Primary Key(包含所有 Partition Key 和 Clustering Key)
|
|
||||||
// 如果主鍵是多欄位,需要傳入包含所有主鍵欄位的 struct
|
|
||||||
// pk 可以是:string, int, int64, gocql.UUID, []byte 或包含主鍵欄位的 struct
|
|
||||||
func (r *repository[T]) Get(ctx context.Context, pk any) (T, error) {
|
|
||||||
var zero T
|
|
||||||
t := table.New(r.metadata)
|
|
||||||
|
|
||||||
// 使用 table.Get() 方法,它會自動根據 metadata 構建主鍵查詢
|
|
||||||
// 如果 pk 是 struct,使用 BindStruct;否則使用 Bind
|
|
||||||
var q *gocqlx.Queryx
|
|
||||||
if reflect.TypeOf(pk).Kind() == reflect.Struct {
|
|
||||||
q = r.db.withContextAndTimestamp(ctx,
|
|
||||||
r.db.session.Query(t.Get()).BindStruct(pk))
|
|
||||||
} else {
|
|
||||||
// 單一主鍵欄位的情況
|
|
||||||
// 注意:這只適用於單一 Partition Key 且無 Clustering Key 的情況
|
|
||||||
if len(r.metadata.PartKey) != 1 || len(r.metadata.SortKey) > 0 {
|
|
||||||
return zero, ErrInvalidInput.WithTable(r.table).WithError(
|
|
||||||
fmt.Errorf("single value primary key only supported for single partition key without clustering key"),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
q = r.db.withContextAndTimestamp(ctx,
|
|
||||||
r.db.session.Query(t.Get()).Bind(pk))
|
|
||||||
}
|
|
||||||
|
|
||||||
var result T
|
|
||||||
err := q.GetRelease(&result)
|
|
||||||
if errors.Is(err, gocql.ErrNotFound) {
|
|
||||||
return zero, ErrNotFound.WithTable(r.table)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return zero, ErrInvalidInput.WithTable(r.table).WithError(err)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update 更新資料(只更新非零值欄位)
|
|
||||||
func (r *repository[T]) Update(ctx context.Context, doc T) error {
|
|
||||||
return r.updateSelective(ctx, doc, false)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateAll 更新所有欄位(包括零值)
|
|
||||||
func (r *repository[T]) UpdateAll(ctx context.Context, doc T) error {
|
|
||||||
return r.updateSelective(ctx, doc, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateSelective 選擇性更新
|
|
||||||
func (r *repository[T]) updateSelective(ctx context.Context, doc T, includeZero bool) error {
|
|
||||||
// 重用現有的 BuildUpdateFields 邏輯
|
|
||||||
// 由於在不同套件,我們需要重新實作或導入
|
|
||||||
fields, err := r.buildUpdateFields(doc, includeZero)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt, names := r.buildUpdateStatement(fields.setCols, fields.whereCols)
|
|
||||||
setVals := append(fields.setVals, fields.whereVals...)
|
|
||||||
q := r.db.withContextAndTimestamp(ctx,
|
|
||||||
r.db.session.Query(stmt, names).Bind(setVals...))
|
|
||||||
|
|
||||||
return q.ExecRelease()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 刪除資料
|
|
||||||
// pk 可以是:string, int, int64, gocql.UUID, []byte 或包含主鍵欄位的 struct
|
|
||||||
func (r *repository[T]) Delete(ctx context.Context, pk any) error {
|
|
||||||
t := table.New(r.metadata)
|
|
||||||
stmt, names := t.Delete()
|
|
||||||
q := r.db.withContextAndTimestamp(ctx,
|
|
||||||
r.db.session.Query(stmt, names).Bind(pk))
|
|
||||||
return q.ExecRelease()
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertMany 批次插入資料
|
|
||||||
func (r *repository[T]) InsertMany(ctx context.Context, docs []T) error {
|
|
||||||
if len(docs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 Batch 操作
|
|
||||||
batch := r.db.session.NewBatch(gocql.LoggedBatch).WithContext(ctx)
|
|
||||||
t := table.New(r.metadata)
|
|
||||||
stmt, names := t.Insert()
|
|
||||||
|
|
||||||
for _, doc := range docs {
|
|
||||||
// 在 v2 中,需要手動提取值
|
|
||||||
v := reflect.ValueOf(doc)
|
|
||||||
if v.Kind() == reflect.Ptr {
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
values := make([]interface{}, len(names))
|
|
||||||
for i, name := range names {
|
|
||||||
// 根據 metadata 找到對應的欄位
|
|
||||||
for j, col := range r.metadata.Columns {
|
|
||||||
if col == name {
|
|
||||||
fieldValue := v.Field(j)
|
|
||||||
values[i] = fieldValue.Interface()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
batch.Query(stmt, values...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.db.session.ExecuteBatch(batch)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query 返回查詢構建器
|
|
||||||
func (r *repository[T]) Query() QueryBuilder[T] {
|
|
||||||
return newQueryBuilder(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateFields 包含更新操作所需的欄位資訊
|
|
||||||
type updateFields struct {
|
|
||||||
setCols []string
|
|
||||||
setVals []any
|
|
||||||
whereCols []string
|
|
||||||
whereVals []any
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildUpdateFields 從 document 中提取更新所需的欄位資訊
|
|
||||||
func (r *repository[T]) buildUpdateFields(doc T, includeZero bool) (*updateFields, error) {
|
|
||||||
v := reflect.ValueOf(doc)
|
|
||||||
if v.Kind() == reflect.Ptr {
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
typ := v.Type()
|
|
||||||
|
|
||||||
setCols := make([]string, 0)
|
|
||||||
setVals := make([]any, 0)
|
|
||||||
whereCols := make([]string, 0)
|
|
||||||
whereVals := make([]any, 0)
|
|
||||||
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
|
||||||
field := typ.Field(i)
|
|
||||||
tag := field.Tag.Get(DBFiledName)
|
|
||||||
if tag == "" || tag == "-" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
val := v.Field(i)
|
|
||||||
if !val.IsValid() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 主鍵欄位放入 WHERE 條件
|
|
||||||
if contains(r.metadata.PartKey, tag) || contains(r.metadata.SortKey, tag) {
|
|
||||||
whereCols = append(whereCols, tag)
|
|
||||||
whereVals = append(whereVals, val.Interface())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根據 includeZero 決定是否包含零值欄位
|
|
||||||
if !includeZero && isZero(val) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
setCols = append(setCols, tag)
|
|
||||||
setVals = append(setVals, val.Interface())
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(setCols) == 0 {
|
|
||||||
return nil, ErrNoFieldsToUpdate.WithTable(r.table)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &updateFields{
|
|
||||||
setCols: setCols,
|
|
||||||
setVals: setVals,
|
|
||||||
whereCols: whereCols,
|
|
||||||
whereVals: whereVals,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildUpdateStatement 構建 UPDATE CQL 語句
|
|
||||||
func (r *repository[T]) buildUpdateStatement(setCols, whereCols []string) (string, []string) {
|
|
||||||
builder := qb.Update(r.table).Set(setCols...)
|
|
||||||
for _, col := range whereCols {
|
|
||||||
builder = builder.Where(qb.Eq(col))
|
|
||||||
}
|
|
||||||
return builder.ToCql()
|
|
||||||
}
|
|
||||||
|
|
||||||
// contains 判斷字串是否存在於 slice 中
|
|
||||||
func contains(list []string, target string) bool {
|
|
||||||
for _, item := range list {
|
|
||||||
if item == target {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// isZero 判斷欄位是否為零值或 nil
|
|
||||||
func isZero(v reflect.Value) bool {
|
|
||||||
switch v.Kind() {
|
|
||||||
case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice:
|
|
||||||
return v.IsNil()
|
|
||||||
default:
|
|
||||||
return reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,547 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestContains(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
list []string
|
|
||||||
target string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "target exists in list",
|
|
||||||
list: []string{"a", "b", "c"},
|
|
||||||
target: "b",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "target at beginning",
|
|
||||||
list: []string{"a", "b", "c"},
|
|
||||||
target: "a",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "target at end",
|
|
||||||
list: []string{"a", "b", "c"},
|
|
||||||
target: "c",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "target not in list",
|
|
||||||
list: []string{"a", "b", "c"},
|
|
||||||
target: "d",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty list",
|
|
||||||
list: []string{},
|
|
||||||
target: "a",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty target",
|
|
||||||
list: []string{"a", "b", "c"},
|
|
||||||
target: "",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "target in single element list",
|
|
||||||
list: []string{"a"},
|
|
||||||
target: "a",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "case sensitive",
|
|
||||||
list: []string{"A", "B", "C"},
|
|
||||||
target: "a",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "duplicate values",
|
|
||||||
list: []string{"a", "b", "a", "c"},
|
|
||||||
target: "a",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "long list",
|
|
||||||
list: []string{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"},
|
|
||||||
target: "j",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := contains(tt.list, tt.target)
|
|
||||||
assert.Equal(t, tt.want, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsZero(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
value any
|
|
||||||
expected bool
|
|
||||||
skip bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "nil pointer",
|
|
||||||
value: (*string)(nil),
|
|
||||||
expected: true,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-nil pointer",
|
|
||||||
value: stringPtr("test"),
|
|
||||||
expected: false,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil slice",
|
|
||||||
value: []string(nil),
|
|
||||||
expected: true,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty slice",
|
|
||||||
value: []string{},
|
|
||||||
expected: false, // 空 slice 不是 nil
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "nil map",
|
|
||||||
value: map[string]int(nil),
|
|
||||||
expected: true,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty map",
|
|
||||||
value: map[string]int{},
|
|
||||||
expected: false, // 空 map 不是 nil
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero int",
|
|
||||||
value: 0,
|
|
||||||
expected: true,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-zero int",
|
|
||||||
value: 42,
|
|
||||||
expected: false,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero int64",
|
|
||||||
value: int64(0),
|
|
||||||
expected: true,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-zero int64",
|
|
||||||
value: int64(42),
|
|
||||||
expected: false,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero float64",
|
|
||||||
value: 0.0,
|
|
||||||
expected: true,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-zero float64",
|
|
||||||
value: 3.14,
|
|
||||||
expected: false,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty string",
|
|
||||||
value: "",
|
|
||||||
expected: true,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "non-empty string",
|
|
||||||
value: "test",
|
|
||||||
expected: false,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "false bool",
|
|
||||||
value: false,
|
|
||||||
expected: true,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "true bool",
|
|
||||||
value: true,
|
|
||||||
expected: false,
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "struct with zero values",
|
|
||||||
value: testUser{},
|
|
||||||
expected: true, // 所有欄位都是零值,應該返回 true
|
|
||||||
skip: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if tt.skip {
|
|
||||||
t.Skip("Skipping test")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 使用 reflect.ValueOf 來獲取 reflect.Value
|
|
||||||
v := reflect.ValueOf(tt.value)
|
|
||||||
// 檢查是否為零值(nil interface 會導致 zero Value)
|
|
||||||
if !v.IsValid() {
|
|
||||||
// 對於 nil interface,直接返回 true
|
|
||||||
assert.True(t, tt.expected)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
result := isZero(v)
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewRepository(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
keyspace string
|
|
||||||
wantErr bool
|
|
||||||
validate func(*testing.T, Repository[testUser], *DB)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "valid keyspace",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, repo Repository[testUser], db *DB) {
|
|
||||||
assert.NotNil(t, repo)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty keyspace uses default",
|
|
||||||
keyspace: "",
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, repo Repository[testUser], db *DB) {
|
|
||||||
assert.NotNil(t, repo)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRepository_Insert(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "successful insert",
|
|
||||||
description: "should insert document successfully",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "duplicate key",
|
|
||||||
description: "should return error on duplicate key",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid document",
|
|
||||||
description: "should return error for invalid document",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRepository_Get(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
pk any
|
|
||||||
description string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "found with string key",
|
|
||||||
pk: "test-id",
|
|
||||||
description: "should return document when found",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not found",
|
|
||||||
pk: "non-existent",
|
|
||||||
description: "should return ErrNotFound when not found",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid primary key structure",
|
|
||||||
pk: "single-key",
|
|
||||||
description: "should return error for invalid key structure",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "struct primary key",
|
|
||||||
pk: testUser{ID: "test-id"},
|
|
||||||
description: "should work with struct primary key",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRepository_Update(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
description string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "successful update",
|
|
||||||
description: "should update document successfully",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not found",
|
|
||||||
description: "should return error when document not found",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no fields to update",
|
|
||||||
description: "should return error when no fields to update",
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRepository_Delete(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
pk any
|
|
||||||
description string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "successful delete",
|
|
||||||
pk: "test-id",
|
|
||||||
description: "should delete document successfully",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not found",
|
|
||||||
pk: "non-existent",
|
|
||||||
description: "should not return error when not found (idempotent)",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRepository_InsertMany(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
docs []testUser
|
|
||||||
description string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "empty slice",
|
|
||||||
docs: []testUser{},
|
|
||||||
description: "should return nil for empty slice",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single document",
|
|
||||||
docs: []testUser{{ID: "1", Name: "Alice"}},
|
|
||||||
description: "should insert single document",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple documents",
|
|
||||||
docs: []testUser{{ID: "1", Name: "Alice"}, {ID: "2", Name: "Bob"}},
|
|
||||||
description: "should insert multiple documents",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "large batch",
|
|
||||||
docs: make([]testUser, 100),
|
|
||||||
description: "should handle large batch",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要 mock session 或實際的資料庫連接
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestRepository_Query(t *testing.T) {
|
|
||||||
t.Run("should return QueryBuilder", func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 repository
|
|
||||||
// 實際的執行需要資料庫連接
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildUpdateStatement(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
setCols []string
|
|
||||||
whereCols []string
|
|
||||||
table string
|
|
||||||
validate func(*testing.T, string, []string)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "single set column, single where column",
|
|
||||||
setCols: []string{"name"},
|
|
||||||
whereCols: []string{"id"},
|
|
||||||
table: "users",
|
|
||||||
validate: func(t *testing.T, stmt string, names []string) {
|
|
||||||
assert.Contains(t, stmt, "UPDATE")
|
|
||||||
assert.Contains(t, stmt, "users")
|
|
||||||
assert.Contains(t, stmt, "SET")
|
|
||||||
assert.Contains(t, stmt, "WHERE")
|
|
||||||
assert.Len(t, names, 2) // name, id
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple set columns, single where column",
|
|
||||||
setCols: []string{"name", "email", "age"},
|
|
||||||
whereCols: []string{"id"},
|
|
||||||
table: "users",
|
|
||||||
validate: func(t *testing.T, stmt string, names []string) {
|
|
||||||
assert.Contains(t, stmt, "UPDATE")
|
|
||||||
assert.Contains(t, stmt, "users")
|
|
||||||
assert.Len(t, names, 4) // name, email, age, id
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "single set column, multiple where columns",
|
|
||||||
setCols: []string{"status"},
|
|
||||||
whereCols: []string{"user_id", "account_id"},
|
|
||||||
table: "accounts",
|
|
||||||
validate: func(t *testing.T, stmt string, names []string) {
|
|
||||||
assert.Contains(t, stmt, "UPDATE")
|
|
||||||
assert.Contains(t, stmt, "accounts")
|
|
||||||
assert.Len(t, names, 3) // status, user_id, account_id
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "multiple set and where columns",
|
|
||||||
setCols: []string{"name", "email"},
|
|
||||||
whereCols: []string{"id", "version"},
|
|
||||||
table: "users",
|
|
||||||
validate: func(t *testing.T, stmt string, names []string) {
|
|
||||||
assert.Contains(t, stmt, "UPDATE")
|
|
||||||
assert.Len(t, names, 4) // name, email, id, version
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 創建一個臨時的 repository 來測試 buildUpdateStatement
|
|
||||||
// 注意:這需要一個有效的 metadata
|
|
||||||
// 使用 testUser 的 metadata
|
|
||||||
var zero testUser
|
|
||||||
metadata, err := generateMetadata(zero, "test_keyspace")
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
repo := &repository[testUser]{
|
|
||||||
table: tt.table,
|
|
||||||
metadata: metadata,
|
|
||||||
}
|
|
||||||
stmt, names := repo.buildUpdateStatement(tt.setCols, tt.whereCols)
|
|
||||||
tt.validate(t, stmt, names)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildUpdateFields(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
doc testUser
|
|
||||||
includeZero bool
|
|
||||||
wantErr bool
|
|
||||||
validate func(*testing.T, *updateFields)
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "update with includeZero false",
|
|
||||||
doc: testUser{ID: "1", Name: "Alice", Email: "alice@example.com"},
|
|
||||||
includeZero: false,
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, fields *updateFields) {
|
|
||||||
assert.NotEmpty(t, fields.setCols)
|
|
||||||
assert.Contains(t, fields.whereCols, "id")
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "update with includeZero true",
|
|
||||||
doc: testUser{ID: "1", Name: "", Email: ""},
|
|
||||||
includeZero: true,
|
|
||||||
wantErr: false,
|
|
||||||
validate: func(t *testing.T, fields *updateFields) {
|
|
||||||
assert.NotEmpty(t, fields.setCols)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no fields to update",
|
|
||||||
doc: testUser{ID: "1"},
|
|
||||||
includeZero: false,
|
|
||||||
wantErr: true,
|
|
||||||
validate: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 repository 和 metadata
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,289 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SAIIndexType 定義 SAI 索引類型
|
|
||||||
type SAIIndexType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// SAIIndexTypeStandard 標準索引(等於查詢)
|
|
||||||
SAIIndexTypeStandard SAIIndexType = "STANDARD"
|
|
||||||
// SAIIndexTypeCollection 集合索引(用於 list、set、map)
|
|
||||||
SAIIndexTypeCollection SAIIndexType = "COLLECTION"
|
|
||||||
// SAIIndexTypeFullText 全文索引
|
|
||||||
SAIIndexTypeFullText SAIIndexType = "FULL_TEXT"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SAIIndexOptions 定義 SAI 索引選項
|
|
||||||
type SAIIndexOptions struct {
|
|
||||||
IndexType SAIIndexType // 索引類型
|
|
||||||
IsAsync bool // 是否異步建立索引
|
|
||||||
CaseSensitive bool // 是否區分大小寫(用於全文索引)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultSAIIndexOptions 返回預設的 SAI 索引選項
|
|
||||||
func DefaultSAIIndexOptions() *SAIIndexOptions {
|
|
||||||
return &SAIIndexOptions{
|
|
||||||
IndexType: SAIIndexTypeStandard,
|
|
||||||
IsAsync: false,
|
|
||||||
CaseSensitive: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateSAIIndex 建立 SAI 索引
|
|
||||||
// keyspace: keyspace 名稱
|
|
||||||
// table: 資料表名稱
|
|
||||||
// column: 欄位名稱
|
|
||||||
// indexName: 索引名稱(可選,如果為空則自動生成)
|
|
||||||
// opts: 索引選項(可選,如果為 nil 則使用預設選項)
|
|
||||||
func (db *DB) CreateSAIIndex(ctx context.Context, keyspace, table, column, indexName string, opts *SAIIndexOptions) error {
|
|
||||||
// 檢查是否支援 SAI
|
|
||||||
if !db.saiSupported {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("SAI is not supported in Cassandra version %s (requires 4.0.9+ or 5.0+)", db.version))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證參數
|
|
||||||
if keyspace == "" {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
|
||||||
}
|
|
||||||
if table == "" {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("table is required"))
|
|
||||||
}
|
|
||||||
if column == "" {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("column is required"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用預設選項如果未提供
|
|
||||||
if opts == nil {
|
|
||||||
opts = DefaultSAIIndexOptions()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 生成索引名稱如果未提供
|
|
||||||
if indexName == "" {
|
|
||||||
indexName = fmt.Sprintf("%s_%s_sai_idx", table, column)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建 CREATE INDEX 語句
|
|
||||||
var stmt strings.Builder
|
|
||||||
stmt.WriteString("CREATE CUSTOM INDEX IF NOT EXISTS ")
|
|
||||||
stmt.WriteString(indexName)
|
|
||||||
stmt.WriteString(" ON ")
|
|
||||||
stmt.WriteString(keyspace)
|
|
||||||
stmt.WriteString(".")
|
|
||||||
stmt.WriteString(table)
|
|
||||||
stmt.WriteString(" (")
|
|
||||||
stmt.WriteString(column)
|
|
||||||
stmt.WriteString(") USING 'StorageAttachedIndex'")
|
|
||||||
|
|
||||||
// 添加選項
|
|
||||||
var options []string
|
|
||||||
if opts.IsAsync {
|
|
||||||
options = append(options, "'async'='true'")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根據索引類型添加特定選項
|
|
||||||
switch opts.IndexType {
|
|
||||||
case SAIIndexTypeFullText:
|
|
||||||
if !opts.CaseSensitive {
|
|
||||||
options = append(options, "'case_sensitive'='false'")
|
|
||||||
} else {
|
|
||||||
options = append(options, "'case_sensitive'='true'")
|
|
||||||
}
|
|
||||||
case SAIIndexTypeCollection:
|
|
||||||
// Collection 索引不需要額外選項
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果有選項,添加到語句中
|
|
||||||
if len(options) > 0 {
|
|
||||||
stmt.WriteString(" WITH OPTIONS = {")
|
|
||||||
stmt.WriteString(strings.Join(options, ", "))
|
|
||||||
stmt.WriteString("}")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 執行建立索引語句
|
|
||||||
query := db.session.Query(stmt.String(), nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum)
|
|
||||||
|
|
||||||
err := query.ExecRelease()
|
|
||||||
if err != nil {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("failed to create SAI index: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DropSAIIndex 刪除 SAI 索引
|
|
||||||
// keyspace: keyspace 名稱
|
|
||||||
// indexName: 索引名稱
|
|
||||||
func (db *DB) DropSAIIndex(ctx context.Context, keyspace, indexName string) error {
|
|
||||||
// 驗證參數
|
|
||||||
if keyspace == "" {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
|
||||||
}
|
|
||||||
if indexName == "" {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建 DROP INDEX 語句
|
|
||||||
stmt := fmt.Sprintf("DROP INDEX IF EXISTS %s.%s", keyspace, indexName)
|
|
||||||
|
|
||||||
// 執行刪除索引語句
|
|
||||||
query := db.session.Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum)
|
|
||||||
|
|
||||||
err := query.ExecRelease()
|
|
||||||
if err != nil {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("failed to drop SAI index: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListSAIIndexes 列出指定資料表的所有 SAI 索引
|
|
||||||
// keyspace: keyspace 名稱
|
|
||||||
// table: 資料表名稱
|
|
||||||
func (db *DB) ListSAIIndexes(ctx context.Context, keyspace, table string) ([]SAIIndexInfo, error) {
|
|
||||||
// 驗證參數
|
|
||||||
if keyspace == "" {
|
|
||||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
|
||||||
}
|
|
||||||
if table == "" {
|
|
||||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("table is required"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢系統表獲取索引資訊
|
|
||||||
// system_schema.indexes 表的結構:keyspace_name, table_name, index_name, kind, options
|
|
||||||
stmt := `
|
|
||||||
SELECT index_name, kind, options
|
|
||||||
FROM system_schema.indexes
|
|
||||||
WHERE keyspace_name = ? AND table_name = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
var indexes []SAIIndexInfo
|
|
||||||
iter := db.session.Query(stmt, []string{"keyspace_name", "table_name"}).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.One).
|
|
||||||
Bind(keyspace, table).
|
|
||||||
Iter()
|
|
||||||
|
|
||||||
var indexName, kind string
|
|
||||||
var options map[string]string
|
|
||||||
for iter.Scan(&indexName, &kind, &options) {
|
|
||||||
// 檢查是否為 SAI 索引(kind = 'CUSTOM' 且 class_name 包含 StorageAttachedIndex)
|
|
||||||
if kind == "CUSTOM" {
|
|
||||||
if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") {
|
|
||||||
// 從 options 中提取 target(欄位名稱)
|
|
||||||
columnName := ""
|
|
||||||
if target, ok := options["target"]; ok {
|
|
||||||
columnName = strings.Trim(target, "()\"'")
|
|
||||||
}
|
|
||||||
indexes = append(indexes, SAIIndexInfo{
|
|
||||||
Name: indexName,
|
|
||||||
Type: "StorageAttachedIndex",
|
|
||||||
Options: options,
|
|
||||||
Column: columnName,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := iter.Close(); err != nil {
|
|
||||||
return nil, ErrInvalidInput.WithError(fmt.Errorf("failed to list SAI indexes: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return indexes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SAIIndexInfo 表示 SAI 索引資訊
|
|
||||||
type SAIIndexInfo struct {
|
|
||||||
Name string // 索引名稱
|
|
||||||
Type string // 索引類型
|
|
||||||
Options map[string]string // 索引選項
|
|
||||||
Column string // 索引欄位名稱
|
|
||||||
}
|
|
||||||
|
|
||||||
// CheckSAIIndexExists 檢查 SAI 索引是否存在
|
|
||||||
// keyspace: keyspace 名稱
|
|
||||||
// indexName: 索引名稱
|
|
||||||
func (db *DB) CheckSAIIndexExists(ctx context.Context, keyspace, indexName string) (bool, error) {
|
|
||||||
// 驗證參數
|
|
||||||
if keyspace == "" {
|
|
||||||
return false, ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
|
||||||
}
|
|
||||||
if indexName == "" {
|
|
||||||
return false, ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢系統表檢查索引是否存在
|
|
||||||
stmt := `
|
|
||||||
SELECT index_name, kind, options
|
|
||||||
FROM system_schema.indexes
|
|
||||||
WHERE keyspace_name = ? AND index_name = ?
|
|
||||||
LIMIT 1
|
|
||||||
`
|
|
||||||
|
|
||||||
var foundIndexName, kind string
|
|
||||||
var options map[string]string
|
|
||||||
err := db.session.Query(stmt, []string{"keyspace_name", "index_name"}).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.One).
|
|
||||||
Bind(keyspace, indexName).
|
|
||||||
Scan(&foundIndexName, &kind, &options)
|
|
||||||
|
|
||||||
if err == gocql.ErrNotFound {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return false, ErrInvalidInput.WithError(fmt.Errorf("failed to check SAI index existence: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查是否為 SAI 索引
|
|
||||||
if kind == "CUSTOM" {
|
|
||||||
if className, ok := options["class_name"]; ok && strings.Contains(className, "StorageAttachedIndex") {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WaitForSAIIndex 等待 SAI 索引建立完成(用於異步建立)
|
|
||||||
// keyspace: keyspace 名稱
|
|
||||||
// indexName: 索引名稱
|
|
||||||
// maxWaitTime: 最大等待時間(秒)
|
|
||||||
func (db *DB) WaitForSAIIndex(ctx context.Context, keyspace, indexName string, maxWaitTime int) error {
|
|
||||||
// 驗證參數
|
|
||||||
if keyspace == "" {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("keyspace is required"))
|
|
||||||
}
|
|
||||||
if indexName == "" {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("index name is required"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢索引狀態
|
|
||||||
// 注意:Cassandra 沒有直接的索引狀態查詢,這裡需要通過檢查索引是否可用來判斷
|
|
||||||
// 實際實作可能需要根據具體的 Cassandra 版本調整
|
|
||||||
|
|
||||||
// 簡單實作:檢查索引是否存在
|
|
||||||
exists, err := db.CheckSAIIndexExists(ctx, keyspace, indexName)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
return ErrInvalidInput.WithError(fmt.Errorf("index %s does not exist", indexName))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 注意:實際的等待邏輯可能需要查詢系統表或使用其他方法
|
|
||||||
// 這裡只是基本框架,實際使用時可能需要根據具體需求調整
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,267 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDefaultSAIIndexOptions(t *testing.T) {
|
|
||||||
opts := DefaultSAIIndexOptions()
|
|
||||||
assert.NotNil(t, opts)
|
|
||||||
assert.Equal(t, SAIIndexTypeStandard, opts.IndexType)
|
|
||||||
assert.False(t, opts.IsAsync)
|
|
||||||
assert.True(t, opts.CaseSensitive)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateSAIIndex_Validation(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
keyspace string
|
|
||||||
table string
|
|
||||||
column string
|
|
||||||
indexName string
|
|
||||||
opts *SAIIndexOptions
|
|
||||||
wantErr bool
|
|
||||||
errMsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "missing keyspace",
|
|
||||||
keyspace: "",
|
|
||||||
table: "test_table",
|
|
||||||
column: "test_column",
|
|
||||||
indexName: "test_idx",
|
|
||||||
opts: nil,
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "keyspace is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing table",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
table: "",
|
|
||||||
column: "test_column",
|
|
||||||
indexName: "test_idx",
|
|
||||||
opts: nil,
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "table is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing column",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
table: "test_table",
|
|
||||||
column: "",
|
|
||||||
indexName: "test_idx",
|
|
||||||
opts: nil,
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "column is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid parameters with default options",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
table: "test_table",
|
|
||||||
column: "test_column",
|
|
||||||
indexName: "test_idx",
|
|
||||||
opts: nil,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid parameters with custom options",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
table: "test_table",
|
|
||||||
column: "test_column",
|
|
||||||
indexName: "test_idx",
|
|
||||||
opts: &SAIIndexOptions{
|
|
||||||
IndexType: SAIIndexTypeFullText,
|
|
||||||
IsAsync: true,
|
|
||||||
CaseSensitive: false,
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "auto-generate index name",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
table: "test_table",
|
|
||||||
column: "test_column",
|
|
||||||
indexName: "",
|
|
||||||
opts: nil,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例和 SAI 支援
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDropSAIIndex_Validation(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
keyspace string
|
|
||||||
indexName string
|
|
||||||
wantErr bool
|
|
||||||
errMsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "missing keyspace",
|
|
||||||
keyspace: "",
|
|
||||||
indexName: "test_idx",
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "keyspace is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing index name",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
indexName: "",
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "index name is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid parameters",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
indexName: "test_idx",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestListSAIIndexes_Validation(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
keyspace string
|
|
||||||
table string
|
|
||||||
wantErr bool
|
|
||||||
errMsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "missing keyspace",
|
|
||||||
keyspace: "",
|
|
||||||
table: "test_table",
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "keyspace is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing table",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
table: "",
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "table is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid parameters",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
table: "test_table",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCheckSAIIndexExists_Validation(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
keyspace string
|
|
||||||
indexName string
|
|
||||||
wantErr bool
|
|
||||||
errMsg string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "missing keyspace",
|
|
||||||
keyspace: "",
|
|
||||||
indexName: "test_idx",
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "keyspace is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "missing index name",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
indexName: "",
|
|
||||||
wantErr: true,
|
|
||||||
errMsg: "index name is required",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "valid parameters",
|
|
||||||
keyspace: "test_keyspace",
|
|
||||||
indexName: "test_idx",
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
// 注意:這需要一個有效的 DB 實例
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
_ = tt
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSAIIndexType_Constants(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
indexType SAIIndexType
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "standard index type",
|
|
||||||
indexType: SAIIndexTypeStandard,
|
|
||||||
expected: "STANDARD",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "collection index type",
|
|
||||||
indexType: SAIIndexTypeCollection,
|
|
||||||
expected: "COLLECTION",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "full text index type",
|
|
||||||
indexType: SAIIndexTypeFullText,
|
|
||||||
expected: "FULL_TEXT",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tt.expected, string(tt.indexType))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateSAIIndex_NotSupported(t *testing.T) {
|
|
||||||
t.Run("should return error when SAI not supported", func(t *testing.T) {
|
|
||||||
// 注意:這需要一個不支援 SAI 的 DB 實例
|
|
||||||
// 在實際測試中,需要使用 mock 或 testcontainers
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreateSAIIndex_IndexNameGeneration(t *testing.T) {
|
|
||||||
t.Run("should generate index name when not provided", func(t *testing.T) {
|
|
||||||
// 測試自動生成索引名稱的邏輯
|
|
||||||
// 格式應該是: {table}_{column}_sai_idx
|
|
||||||
table := "users"
|
|
||||||
column := "email"
|
|
||||||
expected := "users_email_sai_idx"
|
|
||||||
|
|
||||||
// 這裡只是測試命名邏輯,實際建立需要 DB 實例
|
|
||||||
generated := fmt.Sprintf("%s_%s_sai_idx", table, column)
|
|
||||||
assert.Equal(t, expected, generated)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
@ -1,91 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/testcontainers/testcontainers-go"
|
|
||||||
"github.com/testcontainers/testcontainers-go/wait"
|
|
||||||
)
|
|
||||||
|
|
||||||
// startCassandraContainer 啟動 Cassandra 測試容器
|
|
||||||
func startCassandraContainer(ctx context.Context) (string, string, func(), error) {
|
|
||||||
req := testcontainers.ContainerRequest{
|
|
||||||
Image: "cassandra:4.1",
|
|
||||||
ExposedPorts: []string{"9042/tcp"},
|
|
||||||
WaitingFor: wait.ForListeningPort("9042/tcp"),
|
|
||||||
Env: map[string]string{
|
|
||||||
"CASSANDRA_CLUSTER_NAME": "test-cluster",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
cassandraC, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
|
||||||
ContainerRequest: req,
|
|
||||||
Started: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return "", "", nil, fmt.Errorf("failed to start Cassandra container: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
port, err := cassandraC.MappedPort(ctx, "9042")
|
|
||||||
if err != nil {
|
|
||||||
cassandraC.Terminate(ctx)
|
|
||||||
return "", "", nil, fmt.Errorf("failed to get mapped port: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
host, err := cassandraC.Host(ctx)
|
|
||||||
if err != nil {
|
|
||||||
cassandraC.Terminate(ctx)
|
|
||||||
return "", "", nil, fmt.Errorf("failed to get host: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
tearDown := func() {
|
|
||||||
_ = cassandraC.Terminate(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Cassandra test container started: %s:%s\n", host, port.Port())
|
|
||||||
|
|
||||||
return host, port.Port(), tearDown, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupTestDB 設置測試用的 DB 實例
|
|
||||||
func setupTestDB(t testing.TB) (*DB, func()) {
|
|
||||||
ctx := context.Background()
|
|
||||||
host, port, tearDown, err := startCassandraContainer(ctx)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to start Cassandra container: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
portInt, err := strconv.Atoi(port)
|
|
||||||
if err != nil {
|
|
||||||
tearDown()
|
|
||||||
t.Fatalf("Failed to convert port to int: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := New(
|
|
||||||
WithHosts(host),
|
|
||||||
WithPort(portInt),
|
|
||||||
WithKeyspace("test_keyspace"),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
tearDown()
|
|
||||||
t.Fatalf("Failed to create DB: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 創建 keyspace
|
|
||||||
createKeyspaceStmt := "CREATE KEYSPACE IF NOT EXISTS test_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1}"
|
|
||||||
if err := db.session.Query(createKeyspaceStmt, nil).Exec(); err != nil {
|
|
||||||
db.Close()
|
|
||||||
tearDown()
|
|
||||||
t.Fatalf("Failed to create keyspace: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cleanup := func() {
|
|
||||||
db.Close()
|
|
||||||
tearDown()
|
|
||||||
}
|
|
||||||
|
|
||||||
return db, cleanup
|
|
||||||
}
|
|
||||||
|
|
@ -1,32 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Table 定義資料表模型必須實作的介面
|
|
||||||
type Table interface {
|
|
||||||
TableName() string
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrimaryKey 定義主鍵類型(使用類型約束)
|
|
||||||
// 注意:Go 1.18+ 才支持類型約束,如果需要兼容舊版本,可以使用 interface{}
|
|
||||||
type PrimaryKey interface {
|
|
||||||
~string | ~int | ~int64 | gocql.UUID | []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// Order 定義排序順序
|
|
||||||
type Order int
|
|
||||||
|
|
||||||
const (
|
|
||||||
ASC Order = 0
|
|
||||||
DESC Order = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// 將 Order 轉換為 toGocqlX 的 Order
|
|
||||||
func (o Order) toGocqlX() string {
|
|
||||||
if o == DESC {
|
|
||||||
return "DESC"
|
|
||||||
}
|
|
||||||
return "ASC"
|
|
||||||
}
|
|
||||||
|
|
@ -1,140 +0,0 @@
|
||||||
package cassandra
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestOrder_ToGocqlX(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
order Order
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ASC order",
|
|
||||||
order: ASC,
|
|
||||||
expected: "ASC",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DESC order",
|
|
||||||
order: DESC,
|
|
||||||
expected: "DESC",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "zero value (defaults to ASC)",
|
|
||||||
order: Order(0),
|
|
||||||
expected: "ASC",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "invalid order value (defaults to ASC)",
|
|
||||||
order: Order(99),
|
|
||||||
expected: "ASC",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "negative order value (defaults to ASC)",
|
|
||||||
order: Order(-1),
|
|
||||||
expected: "ASC",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.order.toGocqlX()
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOrder_Constants(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
constant Order
|
|
||||||
expected int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ASC constant",
|
|
||||||
constant: ASC,
|
|
||||||
expected: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DESC constant",
|
|
||||||
constant: DESC,
|
|
||||||
expected: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
assert.Equal(t, tt.expected, int(tt.constant))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOrder_StringConversion(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
order Order
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "ASC to string",
|
|
||||||
order: ASC,
|
|
||||||
expected: "ASC",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "DESC to string",
|
|
||||||
order: DESC,
|
|
||||||
expected: "DESC",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.order.toGocqlX()
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOrder_Comparison(t *testing.T) {
|
|
||||||
t.Run("ASC should equal 0", func(t *testing.T) {
|
|
||||||
assert.Equal(t, Order(0), ASC)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("DESC should equal 1", func(t *testing.T) {
|
|
||||||
assert.Equal(t, Order(1), DESC)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ASC should not equal DESC", func(t *testing.T) {
|
|
||||||
assert.NotEqual(t, ASC, DESC)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOrder_EdgeCases(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
order Order
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "maximum int value",
|
|
||||||
order: Order(^int(0)),
|
|
||||||
expected: "ASC", // 不是 DESC,所以返回 ASC
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "minimum int value",
|
|
||||||
order: Order(-^int(0) - 1),
|
|
||||||
expected: "ASC",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := tt.order.toGocqlX()
|
|
||||||
assert.Equal(t, tt.expected, result)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,18 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NotificationCursor tracks the last seen notification for a user.
|
|
||||||
type NotificationCursor struct {
|
|
||||||
UID string `db:"user_id" partition_key:"true"`
|
|
||||||
LastSeenTS gocql.UUID `db:"last_seen_ts"`
|
|
||||||
UpdatedAt time.Time `db:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (uc *NotificationCursor) TableName() string {
|
|
||||||
return "notification_cursor"
|
|
||||||
}
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/notification/domain/notification"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NotificationEvent represents an event that triggers a notification.
|
|
||||||
type NotificationEvent struct {
|
|
||||||
EventID gocql.UUID `db:"event_id" partition_key:"true"` // 事件 ID
|
|
||||||
EventType string `db:"event_type"` // POST_PUBLISHED / COMMENT_ADDED / MENTIONED ...
|
|
||||||
ActorUID string `db:"actor_uid"` // 觸發者 UID
|
|
||||||
ObjectType string `db:"object_type"` // POST / COMMENT / USER ...
|
|
||||||
ObjectID string `db:"object_id"` // 對應物件 ID(post_id 等)
|
|
||||||
Title string `db:"title"` // 顯示用標題
|
|
||||||
Body string `db:"body"` // 顯示用內容 / 摘要
|
|
||||||
Payload string `db:"payload"` // JSON string(額外欄位,例如 {"postId": "..."})
|
|
||||||
Priority notification.NotifyPriority `db:"priority"` // 1=critical, 2=high, 3=normal, 4=low
|
|
||||||
CreatedAt time.Time `db:"created_at"` // 事件時間(方便做 cross table 查詢)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ue *NotificationEvent) TableName() string {
|
|
||||||
return "notification_event"
|
|
||||||
}
|
|
||||||
|
|
@ -1,22 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/notification/domain/notification"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UserNotification represents a notification for a specific user.
|
|
||||||
type UserNotification struct {
|
|
||||||
UserID string `db:"user_id" partition_key:"true"` // 收通知的人
|
|
||||||
Bucket string `db:"bucket" partition_key:"true"` // 分桶,例如 '2025-11' 或 '2025-11-17'
|
|
||||||
TS gocql.UUID `db:"ts" clustering_key:"true"` // 通知時間,用 now() 產生,排序用(UTC0)
|
|
||||||
EventID gocql.UUID `db:"event_id"` // 對應 notification_event.event_id
|
|
||||||
Status notification.NotifyStatus `db:"status"` // UNREAD / READ / ARCHIVED
|
|
||||||
ReadAt time.Time `db:"read_at"` // 已讀時間(非必填)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (un *UserNotification) TableName() string {
|
|
||||||
return "user_notification"
|
|
||||||
}
|
|
||||||
|
|
@ -1,31 +0,0 @@
|
||||||
package notification
|
|
||||||
|
|
||||||
type NotifyPriority int8
|
|
||||||
|
|
||||||
func (n NotifyPriority) ToString() string {
|
|
||||||
status, ok := priorityMap[n]
|
|
||||||
if !ok {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
|
|
||||||
return status
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
Critical NotifyPriority = 1
|
|
||||||
High NotifyPriority = 2
|
|
||||||
Normal NotifyPriority = 3
|
|
||||||
Low NotifyPriority = 4
|
|
||||||
|
|
||||||
CriticalStr = "critical"
|
|
||||||
HighStr = "high"
|
|
||||||
NormalStr = "normal"
|
|
||||||
LowStr = "low"
|
|
||||||
)
|
|
||||||
|
|
||||||
var priorityMap = map[NotifyPriority]string{
|
|
||||||
Critical: CriticalStr,
|
|
||||||
High: HighStr,
|
|
||||||
Normal: NormalStr,
|
|
||||||
Low: LowStr,
|
|
||||||
}
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
package notification
|
|
||||||
|
|
||||||
type NotifyStatus int8
|
|
||||||
|
|
||||||
func (n NotifyStatus) ToString() string {
|
|
||||||
status, ok := statusMap[n]
|
|
||||||
if !ok {
|
|
||||||
return "unknown"
|
|
||||||
}
|
|
||||||
|
|
||||||
return status
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
UNREAD NotifyStatus = 1
|
|
||||||
READ NotifyStatus = 2
|
|
||||||
ARCHIVED NotifyStatus = 3
|
|
||||||
|
|
||||||
UNREADStr = "UNREAD"
|
|
||||||
READStr = "READ"
|
|
||||||
ARCHIVEDStr = "ARCHIVED"
|
|
||||||
)
|
|
||||||
|
|
||||||
var statusMap = map[NotifyStatus]string{
|
|
||||||
UNREAD: UNREADStr,
|
|
||||||
READ: READStr,
|
|
||||||
ARCHIVED: ARCHIVEDStr,
|
|
||||||
}
|
|
||||||
|
|
@ -1,82 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/notification/domain/entity"
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
type NotificationRepository interface {
|
|
||||||
NotificationEventRepository
|
|
||||||
UserNotificationRepository
|
|
||||||
NotificationCursorRepository
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---- 1. Event ----
|
|
||||||
// 專心管「事件本體」,fan-out 前先寫這張。
|
|
||||||
// 通常由上游 domain event consumer 呼叫 Create。
|
|
||||||
|
|
||||||
type QueryNotificationEventParam struct {
|
|
||||||
ObjectID *string
|
|
||||||
ObjectType *string
|
|
||||||
Limit *int
|
|
||||||
}
|
|
||||||
|
|
||||||
type NotificationEventRepository interface {
|
|
||||||
// Create 建立一筆新的 NotificationEvent。
|
|
||||||
Create(ctx context.Context, e *entity.NotificationEvent) error
|
|
||||||
// GetByID 依 EventID 取得事件。
|
|
||||||
GetByID(ctx context.Context, id string) (*entity.NotificationEvent, error)
|
|
||||||
// ListByObject 依 object_type + object_id 查詢相關事件(選用,debug / 後台用)。
|
|
||||||
ListByObject(ctx context.Context, param QueryNotificationEventParam) ([]*entity.NotificationEvent, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---- 2. 使用者通知:user_notification ----
|
|
||||||
// 管使用者的小鈴鐺 row,fan-out 之後用這個寫入。
|
|
||||||
|
|
||||||
// ListLatestOptions 查列表用的參數
|
|
||||||
type ListLatestOptions struct {
|
|
||||||
UserID string
|
|
||||||
Buckets []string // e.g. []string{"202511", "202510"}
|
|
||||||
Limit int // 建議在 service 層限制最大值,例如 <= 100
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserNotificationRepository interface {
|
|
||||||
// CreateUserNotification 建立單一通知(針對某一個 user)。
|
|
||||||
// 由呼叫端決定 bucket 與 TTL 秒數。
|
|
||||||
CreateUserNotification(ctx context.Context, n *entity.UserNotification, ttlSeconds int) error
|
|
||||||
|
|
||||||
// BulkCreate 批次建立多筆通知(fan-out worker 使用)。
|
|
||||||
// 一般期望要嘛全部成功要嘛全部失敗。
|
|
||||||
BulkCreate(ctx context.Context, list []*entity.UserNotification, ttlSeconds int) error
|
|
||||||
|
|
||||||
// ListLatest 取得某 user 最新的通知列表(小鈴鐺拉下來用)。
|
|
||||||
ListLatest(ctx context.Context, opt ListLatestOptions) ([]*entity.UserNotification, error)
|
|
||||||
|
|
||||||
// MarkRead 將單一通知設為已讀。
|
|
||||||
// 用 (user_id, bucket, ts) 精準定位那一筆資料。
|
|
||||||
MarkRead(ctx context.Context, userID, bucket string, ts gocql.UUID) error
|
|
||||||
|
|
||||||
// MarkAllRead 將指定 buckets 範圍內的通知設為已讀。
|
|
||||||
// 常見用法:最近幾個 bucket(例如最近 30 天)全部標為已讀。
|
|
||||||
// Cassandra 不適合全表掃描,實作時可分批 select 再 update。
|
|
||||||
MarkAllRead(ctx context.Context, userID string, buckets []string) error
|
|
||||||
|
|
||||||
// CountUnreadApprox 回傳未讀數(允許是近似值)。
|
|
||||||
// 實作方式可以是:
|
|
||||||
// - 掃少量 buckets 中 status='UNREAD' 的 row,然後在應用端計算
|
|
||||||
// - 或讀取外部 counter(Redis / 另一張 counter table)
|
|
||||||
CountUnreadApprox(ctx context.Context, userID string, buckets []string) (int64, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---- 3. NotificationCursorRepository ----
|
|
||||||
// 管 last_seen 光標,用來減少大量「每一筆更新已讀」的成本。
|
|
||||||
|
|
||||||
type NotificationCursorRepository interface {
|
|
||||||
// GetCursor 取得某 user 的光標,如果不存在可以回傳 (nil, nil)。
|
|
||||||
GetCursor(ctx context.Context, userID string) (*entity.NotificationCursor, error)
|
|
||||||
// UpsertCursor 新增或更新光標。
|
|
||||||
// 一般在使用者打開通知列表、或捲到最上面時更新。
|
|
||||||
UpsertCursor(ctx context.Context, cursor *entity.NotificationCursor) error
|
|
||||||
}
|
|
||||||
|
|
@ -1,114 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
// Import necessary packages
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
)
|
|
||||||
|
|
||||||
type NotificationUseCase interface {
|
|
||||||
EventUseCase
|
|
||||||
UserNotificationUseCase
|
|
||||||
CursorUseCase
|
|
||||||
}
|
|
||||||
type NotificationEvent struct {
|
|
||||||
EventType string // POST_PUBLISHED / COMMENT_ADDED / MENTIONED ...
|
|
||||||
ActorUID string // 觸發者 UID
|
|
||||||
ObjectType string // POST / COMMENT / USER ...
|
|
||||||
ObjectID string // 對應物件 ID(post_id 等)
|
|
||||||
Title string // 顯示用標題
|
|
||||||
Body string // 顯示用內容 / 摘要
|
|
||||||
Payload string // JSON string(額外欄位,例如 {"postId": "..."})
|
|
||||||
Priority string // critical, high, normal, low
|
|
||||||
}
|
|
||||||
|
|
||||||
type NotificationEventResp struct {
|
|
||||||
EventID string `json:"event_id"`
|
|
||||||
EventType string `json:"event_type"`
|
|
||||||
ActorUID string `json:"actor_uid"`
|
|
||||||
ObjectType string `json:"object_type"`
|
|
||||||
ObjectID string `json:"object_id"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Body string `json:"body"`
|
|
||||||
Payload string `json:"payload"`
|
|
||||||
Priority string `json:"priority"`
|
|
||||||
CreatedAt string `json:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueryNotificationEventParam struct {
|
|
||||||
ObjectID *string
|
|
||||||
ObjectType *string
|
|
||||||
Limit *int
|
|
||||||
}
|
|
||||||
|
|
||||||
type EventUseCase interface {
|
|
||||||
// CreateEvent creates a new notification event.
|
|
||||||
CreateEvent(ctx context.Context, e *NotificationEvent) error
|
|
||||||
|
|
||||||
// GetEventByID retrieves an event by its ID.
|
|
||||||
GetEventByID(ctx context.Context, id string) (*NotificationEventResp, error)
|
|
||||||
|
|
||||||
// ListEventsByObject lists events related to a specific object.
|
|
||||||
ListEventsByObject(ctx context.Context, param QueryNotificationEventParam) ([]*NotificationEventResp, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserNotification struct {
|
|
||||||
UserID string `json:"user_id"` // 收通知的人
|
|
||||||
EventID string `json:"event_id"` // 對應 notification_event.event_id
|
|
||||||
TTL int `json:"ttl"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ListLatestOptions struct {
|
|
||||||
UserID string
|
|
||||||
Buckets []string // e.g. []string{"202511", "202510"}
|
|
||||||
Limit int // 建議在 service 層限制最大值,例如 <= 100
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserNotificationResponse struct {
|
|
||||||
UserID string `json:"user_id"` // 收通知的人
|
|
||||||
Bucket string `json:"bucket"` // 分桶,例如 '2025-11' 或 '2025-11-17'
|
|
||||||
TS string `json:"ts"` // 通知時間,用 now() 產生,排序用(UTC0)
|
|
||||||
EventID string `json:"event_id"` // 對應 notification_event.event_id
|
|
||||||
Status string `json:"status"` // UNREAD / READ / ARCHIVED
|
|
||||||
ReadAt *string `json:"read_at,omitempty"` // 已讀時間(非必填)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UserNotificationUseCase handles user-specific notification operations.
|
|
||||||
type UserNotificationUseCase interface {
|
|
||||||
// CreateUserNotification creates a notification for a single user.
|
|
||||||
CreateUserNotification(ctx context.Context, n *UserNotification) error
|
|
||||||
|
|
||||||
// BulkCreateNotifications creates multiple notifications in batch.
|
|
||||||
BulkCreateNotifications(ctx context.Context, list []*UserNotification) error
|
|
||||||
|
|
||||||
// ListLatestNotifications lists the latest notifications for a user.
|
|
||||||
ListLatestNotifications(ctx context.Context, opt ListLatestOptions) ([]*UserNotificationResponse, error)
|
|
||||||
|
|
||||||
// MarkAsRead marks a single notification as read.
|
|
||||||
MarkAsRead(ctx context.Context, userID, bucket string, ts string) error
|
|
||||||
|
|
||||||
// MarkAllAsRead marks all notifications in specified buckets as read.
|
|
||||||
MarkAllAsRead(ctx context.Context, userID string, buckets []string) error
|
|
||||||
|
|
||||||
// CountUnread approximates the count of unread notifications.
|
|
||||||
CountUnread(ctx context.Context, userID string, buckets []string) (int64, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type NotificationCursor struct {
|
|
||||||
UID string
|
|
||||||
LastSeenTS string
|
|
||||||
UpdatedAt string
|
|
||||||
}
|
|
||||||
|
|
||||||
type UpdateNotificationCursorParam struct {
|
|
||||||
UID string
|
|
||||||
LastSeenTS string
|
|
||||||
}
|
|
||||||
|
|
||||||
// CursorUseCase handles notification cursor operations for efficient reading.
|
|
||||||
type CursorUseCase interface {
|
|
||||||
// GetCursor retrieves the notification cursor for a user.
|
|
||||||
GetCursor(ctx context.Context, userID string) (*NotificationCursor, error)
|
|
||||||
|
|
||||||
// UpdateCursor updates or inserts the cursor for a user.
|
|
||||||
UpdateCursor(ctx context.Context, cursor *UpdateNotificationCursorParam) error
|
|
||||||
}
|
|
||||||
|
|
@ -1,603 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"backend/pkg/notification/domain/entity"
|
|
||||||
"backend/pkg/notification/domain/notification"
|
|
||||||
"backend/pkg/notification/domain/repository"
|
|
||||||
"backend/pkg/notification/domain/usecase"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
errs "backend/pkg/library/errors"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NotificationUseCaseParam 通知服務參數配置
|
|
||||||
type NotificationUseCaseParam struct {
|
|
||||||
Repo repository.NotificationRepository
|
|
||||||
Logger errs.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotificationUseCase 通知服務實現
|
|
||||||
type NotificationUseCase struct {
|
|
||||||
param NotificationUseCaseParam
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustNotificationUseCase 創建通知服務實例
|
|
||||||
func MustNotificationUseCase(param NotificationUseCaseParam) usecase.NotificationUseCase {
|
|
||||||
return &NotificationUseCase{
|
|
||||||
param: param,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== EventUseCase 實現 ====================
|
|
||||||
|
|
||||||
// CreateEvent 創建新的通知事件
|
|
||||||
func (uc *NotificationUseCase) CreateEvent(ctx context.Context, e *usecase.NotificationEvent) error {
|
|
||||||
// 驗證輸入
|
|
||||||
if err := uc.validateNotificationEvent(e); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換 priority
|
|
||||||
priority, err := uc.parsePriority(e.Priority)
|
|
||||||
if err != nil {
|
|
||||||
return errs.InputInvalidRangeError(fmt.Sprintf("invalid priority: %s", e.Priority)).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 創建 entity
|
|
||||||
event := &entity.NotificationEvent{
|
|
||||||
EventID: gocql.TimeUUID(),
|
|
||||||
EventType: e.EventType,
|
|
||||||
ActorUID: e.ActorUID,
|
|
||||||
ObjectType: e.ObjectType,
|
|
||||||
ObjectID: e.ObjectID,
|
|
||||||
Title: e.Title,
|
|
||||||
Body: e.Body,
|
|
||||||
Payload: e.Payload,
|
|
||||||
Priority: priority,
|
|
||||||
CreatedAt: time.Now().UTC(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存到資料庫
|
|
||||||
if err := uc.param.Repo.Create(ctx, event); err != nil {
|
|
||||||
return errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "event_type", Val: e.EventType},
|
|
||||||
{Key: "actor_uid", Val: e.ActorUID},
|
|
||||||
{Key: "func", Val: "NotificationRepository.Create"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to create notification event",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetEventByID 根據 ID 獲取事件
|
|
||||||
func (uc *NotificationUseCase) GetEventByID(ctx context.Context, id string) (*usecase.NotificationEventResp, error) {
|
|
||||||
// 驗證 UUID 格式
|
|
||||||
if _, err := gocql.ParseUUID(id); err != nil {
|
|
||||||
return nil, errs.InputInvalidRangeError(fmt.Sprintf("invalid event ID format: %s", id)).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 從資料庫獲取
|
|
||||||
event, err := uc.param.Repo.GetByID(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "event_id", Val: id},
|
|
||||||
{Key: "func", Val: "NotificationRepository.GetByID"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to get notification event by ID",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換為響應格式
|
|
||||||
return uc.entityToEventResp(event), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListEventsByObject 根據物件查詢事件列表
|
|
||||||
func (uc *NotificationUseCase) ListEventsByObject(ctx context.Context, param usecase.QueryNotificationEventParam) ([]*usecase.NotificationEventResp, error) {
|
|
||||||
// 驗證參數
|
|
||||||
if param.ObjectID == nil || param.ObjectType == nil || param.Limit == nil {
|
|
||||||
return nil, errs.InputInvalidRangeError("object_id and object_type are required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢參數
|
|
||||||
repoParam := repository.QueryNotificationEventParam{
|
|
||||||
ObjectID: param.ObjectID,
|
|
||||||
ObjectType: param.ObjectType,
|
|
||||||
Limit: param.Limit,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 從資料庫查詢
|
|
||||||
events, err := uc.param.Repo.ListByObject(ctx, repoParam)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "object_id", Val: *param.ObjectID},
|
|
||||||
{Key: "object_type", Val: *param.ObjectType},
|
|
||||||
{Key: "func", Val: "NotificationRepository.ListByObject"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to list notification events by object",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換為響應格式
|
|
||||||
result := make([]*usecase.NotificationEventResp, 0, len(events))
|
|
||||||
for _, event := range events {
|
|
||||||
result = append(result, uc.entityToEvent(event))
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== UserNotificationUseCase 實現 ====================
|
|
||||||
|
|
||||||
// CreateUserNotification 為單個用戶創建通知
|
|
||||||
func (uc *NotificationUseCase) CreateUserNotification(ctx context.Context, n *usecase.UserNotification) error {
|
|
||||||
// 驗證輸入
|
|
||||||
if err := uc.validateUserNotification(n); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 生成 bucket
|
|
||||||
bucket := uc.generateBucket(time.Now().UTC())
|
|
||||||
|
|
||||||
// 解析 EventID
|
|
||||||
eventID, err := gocql.ParseUUID(n.EventID)
|
|
||||||
if err != nil {
|
|
||||||
return errs.InputInvalidRangeError(fmt.Sprintf("invalid event ID format: %s", n.EventID)).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 創建 entity
|
|
||||||
userNotif := &entity.UserNotification{
|
|
||||||
UserID: n.UserID,
|
|
||||||
Bucket: bucket,
|
|
||||||
TS: gocql.TimeUUID(),
|
|
||||||
EventID: eventID,
|
|
||||||
Status: notification.UNREAD,
|
|
||||||
ReadAt: time.Time{},
|
|
||||||
}
|
|
||||||
|
|
||||||
// 計算 TTL(如果未提供,使用默認值)
|
|
||||||
ttlSeconds := n.TTL
|
|
||||||
if ttlSeconds == 0 {
|
|
||||||
ttlSeconds = uc.calculateDefaultTTL()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 保存到資料庫
|
|
||||||
if err := uc.param.Repo.CreateUserNotification(ctx, userNotif, ttlSeconds); err != nil {
|
|
||||||
return errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "user_id", Val: n.UserID},
|
|
||||||
{Key: "event_id", Val: n.EventID},
|
|
||||||
{Key: "func", Val: "NotificationRepository.CreateUserNotification"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to create user notification",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BulkCreateNotifications 批量創建通知
|
|
||||||
func (uc *NotificationUseCase) BulkCreateNotifications(ctx context.Context, list []*usecase.UserNotification) error {
|
|
||||||
if len(list) == 0 {
|
|
||||||
return errs.InputInvalidRangeError("notification list cannot be empty")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 生成 bucket
|
|
||||||
bucket := uc.generateBucket(time.Now().UTC())
|
|
||||||
|
|
||||||
// 轉換為 entity 列表
|
|
||||||
entities := make([]*entity.UserNotification, 0, len(list))
|
|
||||||
for _, n := range list {
|
|
||||||
// 驗證輸入
|
|
||||||
if err := uc.validateUserNotification(n); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析 EventID
|
|
||||||
eventID, err := gocql.ParseUUID(n.EventID)
|
|
||||||
if err != nil {
|
|
||||||
return errs.InputInvalidRangeError(fmt.Sprintf("invalid event ID format: %s", n.EventID)).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 計算 TTL
|
|
||||||
ttlSeconds := n.TTL
|
|
||||||
if ttlSeconds == 0 {
|
|
||||||
ttlSeconds = uc.calculateDefaultTTL()
|
|
||||||
}
|
|
||||||
|
|
||||||
e := &entity.UserNotification{
|
|
||||||
UserID: n.UserID,
|
|
||||||
Bucket: bucket,
|
|
||||||
TS: gocql.TimeUUID(),
|
|
||||||
EventID: eventID,
|
|
||||||
Status: notification.UNREAD,
|
|
||||||
ReadAt: time.Time{},
|
|
||||||
}
|
|
||||||
|
|
||||||
entities = append(entities, e)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用第一個通知的 TTL(假設批量通知使用相同的 TTL)
|
|
||||||
ttlSeconds := list[0].TTL
|
|
||||||
if ttlSeconds == 0 {
|
|
||||||
ttlSeconds = uc.calculateDefaultTTL()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 批量保存
|
|
||||||
if err := uc.param.Repo.BulkCreate(ctx, entities, ttlSeconds); err != nil {
|
|
||||||
return errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "count", Val: len(list)},
|
|
||||||
{Key: "func", Val: "NotificationRepository.BulkCreate"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to bulk create user notifications",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListLatestNotifications 獲取用戶最新的通知列表
|
|
||||||
func (uc *NotificationUseCase) ListLatestNotifications(ctx context.Context, opt usecase.ListLatestOptions) ([]*usecase.UserNotificationResponse, error) {
|
|
||||||
// 驗證參數
|
|
||||||
if opt.UserID == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("user_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 限制 Limit 最大值
|
|
||||||
if opt.Limit <= 0 {
|
|
||||||
opt.Limit = 20 // 默認值
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果未提供 buckets,生成默認的 buckets(最近 3 個月)
|
|
||||||
if len(opt.Buckets) == 0 {
|
|
||||||
opt.Buckets = uc.generateDefaultBuckets()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢參數
|
|
||||||
repoOpt := repository.ListLatestOptions{
|
|
||||||
UserID: opt.UserID,
|
|
||||||
Buckets: opt.Buckets,
|
|
||||||
Limit: opt.Limit,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 從資料庫查詢
|
|
||||||
notifications, err := uc.param.Repo.ListLatest(ctx, repoOpt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "user_id", Val: opt.UserID},
|
|
||||||
{Key: "buckets", Val: opt.Buckets},
|
|
||||||
{Key: "func", Val: "NotificationRepository.ListLatest"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to list latest notifications",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換為響應格式
|
|
||||||
result := make([]*usecase.UserNotificationResponse, 0, len(notifications))
|
|
||||||
for _, n := range notifications {
|
|
||||||
result = append(result, uc.entityToUserNotificationResp(n))
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkAsRead 標記單個通知為已讀
|
|
||||||
func (uc *NotificationUseCase) MarkAsRead(ctx context.Context, userID, bucket string, ts string) error {
|
|
||||||
// 驗證參數
|
|
||||||
if userID == "" || bucket == "" || ts == "" {
|
|
||||||
return errs.InputInvalidRangeError("user_id, bucket, and ts are required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析 TimeUUID
|
|
||||||
timeUUID, err := gocql.ParseUUID(ts)
|
|
||||||
if err != nil {
|
|
||||||
return errs.InputInvalidRangeError(fmt.Sprintf("invalid ts format: %s", ts)).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新資料庫
|
|
||||||
if err := uc.param.Repo.MarkRead(ctx, userID, bucket, timeUUID); err != nil {
|
|
||||||
return errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "user_id", Val: userID},
|
|
||||||
{Key: "bucket", Val: bucket},
|
|
||||||
{Key: "ts", Val: ts},
|
|
||||||
{Key: "func", Val: "NotificationRepository.MarkRead"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to mark notification as read",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MarkAllAsRead 標記指定 buckets 範圍內的所有通知為已讀
|
|
||||||
func (uc *NotificationUseCase) MarkAllAsRead(ctx context.Context, userID string, buckets []string) error {
|
|
||||||
// 驗證參數
|
|
||||||
if userID == "" {
|
|
||||||
return errs.InputInvalidRangeError("user_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果未提供 buckets,使用默認的 buckets
|
|
||||||
if len(buckets) == 0 {
|
|
||||||
buckets = uc.generateDefaultBuckets()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新資料庫
|
|
||||||
if err := uc.param.Repo.MarkAllRead(ctx, userID, buckets); err != nil {
|
|
||||||
return errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "user_id", Val: userID},
|
|
||||||
{Key: "buckets", Val: buckets},
|
|
||||||
{Key: "func", Val: "NotificationRepository.MarkAllRead"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to mark all notifications as read",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CountUnread 計算未讀通知數量(近似值)
|
|
||||||
func (uc *NotificationUseCase) CountUnread(ctx context.Context, userID string, buckets []string) (int64, error) {
|
|
||||||
// 驗證參數
|
|
||||||
if userID == "" {
|
|
||||||
return 0, errs.InputInvalidRangeError("user_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果未提供 buckets,使用默認的 buckets
|
|
||||||
if len(buckets) == 0 {
|
|
||||||
buckets = uc.generateDefaultBuckets()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 從資料庫查詢
|
|
||||||
count, err := uc.param.Repo.CountUnreadApprox(ctx, userID, buckets)
|
|
||||||
if err != nil {
|
|
||||||
return 0, errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "user_id", Val: userID},
|
|
||||||
{Key: "buckets", Val: buckets},
|
|
||||||
{Key: "func", Val: "NotificationRepository.CountUnreadApprox"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to count unread notifications",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== CursorUseCase 實現 ====================
|
|
||||||
|
|
||||||
// GetCursor 獲取用戶的通知光標
|
|
||||||
func (uc *NotificationUseCase) GetCursor(ctx context.Context, userID string) (*usecase.NotificationCursor, error) {
|
|
||||||
// 驗證參數
|
|
||||||
if userID == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("user_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 從資料庫查詢
|
|
||||||
cursor, err := uc.param.Repo.GetCursor(ctx, userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "user_id", Val: userID},
|
|
||||||
{Key: "func", Val: "NotificationRepository.GetCursor"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to get notification cursor",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果不存在,返回 nil
|
|
||||||
if cursor == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換為響應格式
|
|
||||||
return uc.entityToCursor(cursor), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateCursor 更新或插入通知光標
|
|
||||||
func (uc *NotificationUseCase) UpdateCursor(ctx context.Context, param *usecase.UpdateNotificationCursorParam) error {
|
|
||||||
// 驗證參數
|
|
||||||
if param == nil {
|
|
||||||
return errs.InputInvalidRangeError("cursor param is required")
|
|
||||||
}
|
|
||||||
if param.UID == "" {
|
|
||||||
return errs.InputInvalidRangeError("uid is required")
|
|
||||||
}
|
|
||||||
if param.LastSeenTS == "" {
|
|
||||||
return errs.InputInvalidRangeError("last_seen_ts is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 解析 TimeUUID
|
|
||||||
lastSeenTS, err := gocql.ParseUUID(param.LastSeenTS)
|
|
||||||
if err != nil {
|
|
||||||
return errs.InputInvalidRangeError(fmt.Sprintf("invalid last_seen_ts format: %s", param.LastSeenTS)).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 創建 entity
|
|
||||||
cursor := &entity.NotificationCursor{
|
|
||||||
UID: param.UID,
|
|
||||||
LastSeenTS: lastSeenTS,
|
|
||||||
UpdatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新資料庫
|
|
||||||
if err := uc.param.Repo.UpsertCursor(ctx, cursor); err != nil {
|
|
||||||
return errs.DBErrorErrorL(
|
|
||||||
uc.param.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "uid", Val: param.UID},
|
|
||||||
{Key: "last_seen_ts", Val: param.LastSeenTS},
|
|
||||||
{Key: "func", Val: "NotificationRepository.UpsertCursor"},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
"failed to update notification cursor",
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ==================== 輔助函數 ====================
|
|
||||||
|
|
||||||
// validateNotificationEvent 驗證通知事件
|
|
||||||
func (uc *NotificationUseCase) validateNotificationEvent(e *usecase.NotificationEvent) error {
|
|
||||||
if e == nil {
|
|
||||||
return errs.InputInvalidRangeError("notification event is required")
|
|
||||||
}
|
|
||||||
if e.EventType == "" {
|
|
||||||
return errs.InputInvalidRangeError("event_type is required")
|
|
||||||
}
|
|
||||||
if e.ActorUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("actor_uid is required")
|
|
||||||
}
|
|
||||||
if e.ObjectType == "" {
|
|
||||||
return errs.InputInvalidRangeError("object_type is required")
|
|
||||||
}
|
|
||||||
if e.ObjectID == "" {
|
|
||||||
return errs.InputInvalidRangeError("object_id is required")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateUserNotification 驗證用戶通知
|
|
||||||
func (uc *NotificationUseCase) validateUserNotification(n *usecase.UserNotification) error {
|
|
||||||
if n == nil {
|
|
||||||
return errs.InputInvalidRangeError("user notification is required")
|
|
||||||
}
|
|
||||||
if n.UserID == "" {
|
|
||||||
return errs.InputInvalidRangeError("user_id is required")
|
|
||||||
}
|
|
||||||
if n.EventID == "" {
|
|
||||||
return errs.InputInvalidRangeError("event_id is required")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parsePriority 解析優先級字符串
|
|
||||||
func (uc *NotificationUseCase) parsePriority(priorityStr string) (notification.NotifyPriority, error) {
|
|
||||||
switch priorityStr {
|
|
||||||
case "critical":
|
|
||||||
return notification.Critical, nil
|
|
||||||
case "high":
|
|
||||||
return notification.High, nil
|
|
||||||
case "normal":
|
|
||||||
return notification.Normal, nil
|
|
||||||
case "low":
|
|
||||||
return notification.Low, nil
|
|
||||||
default:
|
|
||||||
return notification.Normal, errors.New("invalid priority value")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateBucket 生成 bucket 字符串(格式:YYYYMM)
|
|
||||||
func (uc *NotificationUseCase) generateBucket(t time.Time) string {
|
|
||||||
return t.Format("200601")
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateDefaultBuckets 生成默認的 buckets(最近 3 個月)
|
|
||||||
func (uc *NotificationUseCase) generateDefaultBuckets() []string {
|
|
||||||
now := time.Now()
|
|
||||||
buckets := make([]string, 0, 3)
|
|
||||||
for i := 0; i < 3; i++ {
|
|
||||||
month := now.AddDate(0, -i, 0)
|
|
||||||
buckets = append(buckets, month.Format("200601"))
|
|
||||||
}
|
|
||||||
return buckets
|
|
||||||
}
|
|
||||||
|
|
||||||
// calculateDefaultTTL 計算默認 TTL(90 天)
|
|
||||||
func (uc *NotificationUseCase) calculateDefaultTTL() int {
|
|
||||||
return 90 * 24 * 60 * 60 // 90 天,單位:秒
|
|
||||||
}
|
|
||||||
|
|
||||||
// entityToEventResp 將 entity 轉換為 EventResp
|
|
||||||
func (uc *NotificationUseCase) entityToEventResp(e *entity.NotificationEvent) *usecase.NotificationEventResp {
|
|
||||||
return &usecase.NotificationEventResp{
|
|
||||||
EventID: e.EventID.String(),
|
|
||||||
EventType: e.EventType,
|
|
||||||
ActorUID: e.ActorUID,
|
|
||||||
ObjectType: e.ObjectType,
|
|
||||||
ObjectID: e.ObjectID,
|
|
||||||
Title: e.Title,
|
|
||||||
Body: e.Body,
|
|
||||||
Payload: e.Payload,
|
|
||||||
Priority: e.Priority.ToString(),
|
|
||||||
CreatedAt: e.CreatedAt.UTC().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// entityToEvent 將 entity 轉換為 Event
|
|
||||||
func (uc *NotificationUseCase) entityToEvent(e *entity.NotificationEvent) *usecase.NotificationEventResp {
|
|
||||||
return &usecase.NotificationEventResp{
|
|
||||||
EventID: e.EventID.String(),
|
|
||||||
EventType: e.EventType,
|
|
||||||
ActorUID: e.ActorUID,
|
|
||||||
ObjectType: e.ObjectType,
|
|
||||||
ObjectID: e.ObjectID,
|
|
||||||
Title: e.Title,
|
|
||||||
Body: e.Body,
|
|
||||||
Payload: e.Payload,
|
|
||||||
Priority: e.Priority.ToString(),
|
|
||||||
CreatedAt: e.CreatedAt.UTC().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// entityToUserNotificationResp 將 entity 轉換為 UserNotificationResponse
|
|
||||||
func (uc *NotificationUseCase) entityToUserNotificationResp(n *entity.UserNotification) *usecase.UserNotificationResponse {
|
|
||||||
resp := &usecase.UserNotificationResponse{
|
|
||||||
UserID: n.UserID,
|
|
||||||
Bucket: n.Bucket,
|
|
||||||
TS: n.TS.String(),
|
|
||||||
EventID: n.EventID.String(),
|
|
||||||
Status: n.Status.ToString(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果 ReadAt 不是零值,設置為字符串
|
|
||||||
if !n.ReadAt.IsZero() {
|
|
||||||
readAtStr := n.ReadAt.UTC().Format(time.RFC3339)
|
|
||||||
resp.ReadAt = &readAtStr
|
|
||||||
}
|
|
||||||
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
|
|
||||||
// entityToCursor 將 entity 轉換為 Cursor
|
|
||||||
func (uc *NotificationUseCase) entityToCursor(c *entity.NotificationCursor) *usecase.NotificationCursor {
|
|
||||||
return &usecase.NotificationCursor{
|
|
||||||
UID: c.UID,
|
|
||||||
LastSeenTS: c.LastSeenTS.String(),
|
|
||||||
UpdatedAt: c.UpdatedAt.UTC().Format(time.RFC3339),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,49 +0,0 @@
|
||||||
package domain
|
|
||||||
|
|
||||||
// Business constants for the post service
|
|
||||||
const (
|
|
||||||
// DefaultPageSize is the default page size for pagination
|
|
||||||
DefaultPageSize = 20
|
|
||||||
|
|
||||||
// MaxPageSize is the maximum allowed page size
|
|
||||||
MaxPageSize = 100
|
|
||||||
|
|
||||||
// MinPageSize is the minimum allowed page size
|
|
||||||
MinPageSize = 1
|
|
||||||
|
|
||||||
// MaxPostTitleLength is the maximum length for post title
|
|
||||||
MaxPostTitleLength = 200
|
|
||||||
|
|
||||||
// MinPostTitleLength is the minimum length for post title
|
|
||||||
MinPostTitleLength = 1
|
|
||||||
|
|
||||||
// MaxPostContentLength is the maximum length for post content
|
|
||||||
MaxPostContentLength = 10000
|
|
||||||
|
|
||||||
// MinPostContentLength is the minimum length for post content
|
|
||||||
MinPostContentLength = 1
|
|
||||||
|
|
||||||
// MaxCommentLength is the maximum length for comment
|
|
||||||
MaxCommentLength = 2000
|
|
||||||
|
|
||||||
// MinCommentLength is the minimum length for comment
|
|
||||||
MinCommentLength = 1
|
|
||||||
|
|
||||||
// MaxTagNameLength is the maximum length for tag name
|
|
||||||
MaxTagNameLength = 50
|
|
||||||
|
|
||||||
// MinTagNameLength is the minimum length for tag name
|
|
||||||
MinTagNameLength = 1
|
|
||||||
|
|
||||||
// MaxTagsPerPost is the maximum number of tags per post
|
|
||||||
MaxTagsPerPost = 10
|
|
||||||
|
|
||||||
// DefaultCacheExpiration is the default cache expiration time in seconds
|
|
||||||
DefaultCacheExpiration = 3600
|
|
||||||
|
|
||||||
// MaxRetryAttempts is the maximum number of retry attempts for operations
|
|
||||||
MaxRetryAttempts = 3
|
|
||||||
|
|
||||||
// DefaultLikeCacheExpiration is the default cache expiration for like counts
|
|
||||||
DefaultLikeCacheExpiration = 300 // 5 minutes
|
|
||||||
)
|
|
||||||
|
|
@ -1,85 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Category represents a category entity for organizing posts.
|
|
||||||
type Category struct {
|
|
||||||
ID gocql.UUID `db:"id" partition_key:"true"` // Category unique identifier
|
|
||||||
Slug string `db:"slug"` // URL-friendly slug (unique)
|
|
||||||
Name string `db:"name"` // Category name
|
|
||||||
Description *string `db:"description,omitempty"` // Category description (optional)
|
|
||||||
ParentID *gocql.UUID `db:"parent_id,omitempty"` // Parent category ID (for nested categories)
|
|
||||||
PostCount int64 `db:"post_count"` // Number of posts in this category
|
|
||||||
IsActive bool `db:"is_active"` // Whether the category is active
|
|
||||||
SortOrder int32 `db:"sort_order"` // Sort order for display
|
|
||||||
CreatedAt int64 `db:"created_at"` // Creation timestamp
|
|
||||||
UpdatedAt int64 `db:"updated_at"` // Last update timestamp
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName returns the Cassandra table name for Category entities.
|
|
||||||
func (c *Category) TableName() string {
|
|
||||||
return "categories"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate validates the Category entity
|
|
||||||
func (c *Category) Validate() error {
|
|
||||||
if c.Name == "" {
|
|
||||||
return errors.New("category name is required")
|
|
||||||
}
|
|
||||||
if c.Slug == "" {
|
|
||||||
return errors.New("category slug is required")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTimestamps sets the create and update timestamps
|
|
||||||
func (c *Category) SetTimestamps() {
|
|
||||||
now := time.Now().UTC().UnixNano() / 1e6 // milliseconds
|
|
||||||
if c.CreatedAt == 0 {
|
|
||||||
c.CreatedAt = now
|
|
||||||
}
|
|
||||||
c.UpdatedAt = now
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNew returns true if this is a new category (no ID set)
|
|
||||||
func (c *Category) IsNew() bool {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
return c.ID == zeroUUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsRoot returns true if this category has no parent
|
|
||||||
func (c *Category) IsRoot() bool {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
return c.ParentID == nil || *c.ParentID == zeroUUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementPostCount increments the post count
|
|
||||||
func (c *Category) IncrementPostCount() {
|
|
||||||
c.PostCount++
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementPostCount decrements the post count
|
|
||||||
func (c *Category) DecrementPostCount() {
|
|
||||||
if c.PostCount > 0 {
|
|
||||||
c.PostCount--
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Activate activates the category
|
|
||||||
func (c *Category) Activate() {
|
|
||||||
c.IsActive = true
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deactivate deactivates the category
|
|
||||||
func (c *Category) Deactivate() {
|
|
||||||
c.IsActive = false
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
@ -1,114 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Comment represents a comment entity on a post.
|
|
||||||
// Comments can be nested (replies to comments).
|
|
||||||
type Comment struct {
|
|
||||||
ID gocql.UUID `db:"id" partition_key:"true"` // Comment unique identifier
|
|
||||||
PostID gocql.UUID `db:"post_id" clustering_key:"true"` // Post ID (clustering key for sorting)
|
|
||||||
AuthorUID string `db:"author_uid"` // Author user UID
|
|
||||||
ParentID *gocql.UUID `db:"parent_id,omitempty" clustering_key:"true"` // Parent comment ID (for nested comments)
|
|
||||||
Content string `db:"content"` // Comment content
|
|
||||||
Status post.CommentStatus `db:"status"` // Comment status
|
|
||||||
LikeCount int64 `db:"like_count"` // Number of likes
|
|
||||||
ReplyCount int64 `db:"reply_count"` // Number of replies
|
|
||||||
CreatedAt int64 `db:"created_at" clustering_key:"true"` // Creation timestamp (for sorting)
|
|
||||||
UpdatedAt int64 `db:"updated_at"` // Last update timestamp
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName returns the Cassandra table name for Comment entities.
|
|
||||||
func (c *Comment) TableName() string {
|
|
||||||
return "comments"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate validates the Comment entity
|
|
||||||
func (c *Comment) Validate() error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if c.PostID == zeroUUID {
|
|
||||||
return errors.New("post_id is required")
|
|
||||||
}
|
|
||||||
if c.AuthorUID == "" {
|
|
||||||
return errors.New("author_uid is required")
|
|
||||||
}
|
|
||||||
if len(c.Content) < 1 || len(c.Content) > 2000 {
|
|
||||||
return errors.New("content length must be between 1 and 2000 characters")
|
|
||||||
}
|
|
||||||
if !c.Status.IsValid() {
|
|
||||||
return errors.New("invalid comment status")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTimestamps sets the create and update timestamps
|
|
||||||
func (c *Comment) SetTimestamps() {
|
|
||||||
now := time.Now().UTC().UnixNano() / 1e6 // milliseconds
|
|
||||||
if c.CreatedAt == 0 {
|
|
||||||
c.CreatedAt = now
|
|
||||||
}
|
|
||||||
c.UpdatedAt = now
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNew returns true if this is a new comment (no ID set)
|
|
||||||
func (c *Comment) IsNew() bool {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
return c.ID == zeroUUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsReply returns true if this comment is a reply to another comment
|
|
||||||
func (c *Comment) IsReply() bool {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
return c.ParentID != nil && *c.ParentID != zeroUUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete marks the comment as deleted (soft delete)
|
|
||||||
func (c *Comment) Delete() {
|
|
||||||
c.Status = post.CommentStatusDeleted
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hide hides the comment
|
|
||||||
func (c *Comment) Hide() {
|
|
||||||
c.Status = post.CommentStatusHidden
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsVisible returns true if the comment is visible to public
|
|
||||||
func (c *Comment) IsVisible() bool {
|
|
||||||
return c.Status.IsVisible()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementLikeCount increments the like count
|
|
||||||
func (c *Comment) IncrementLikeCount() {
|
|
||||||
c.LikeCount++
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementLikeCount decrements the like count
|
|
||||||
func (c *Comment) DecrementLikeCount() {
|
|
||||||
if c.LikeCount > 0 {
|
|
||||||
c.LikeCount--
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementReplyCount increments the reply count
|
|
||||||
func (c *Comment) IncrementReplyCount() {
|
|
||||||
c.ReplyCount++
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementReplyCount decrements the reply count
|
|
||||||
func (c *Comment) DecrementReplyCount() {
|
|
||||||
if c.ReplyCount > 0 {
|
|
||||||
c.ReplyCount--
|
|
||||||
c.SetTimestamps()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,61 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Like represents a like entity for posts or comments.
|
|
||||||
// Uses composite primary key: (target_id, user_uid) for uniqueness.
|
|
||||||
type Like struct {
|
|
||||||
ID gocql.UUID `db:"id" partition_key:"true"` // Like unique identifier
|
|
||||||
TargetID gocql.UUID `db:"target_id" clustering_key:"true"` // Target ID (post_id or comment_id)
|
|
||||||
UserUID string `db:"user_uid" clustering_key:"true"` // User UID who liked
|
|
||||||
TargetType string `db:"target_type"` // Target type: "post" or "comment"
|
|
||||||
CreatedAt int64 `db:"created_at"` // Creation timestamp
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName returns the Cassandra table name for Like entities.
|
|
||||||
func (l *Like) TableName() string {
|
|
||||||
return "likes"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate validates the Like entity
|
|
||||||
func (l *Like) Validate() error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if l.TargetID == zeroUUID {
|
|
||||||
return errors.New("target_id is required")
|
|
||||||
}
|
|
||||||
if l.UserUID == "" {
|
|
||||||
return errors.New("user_uid is required")
|
|
||||||
}
|
|
||||||
if l.TargetType != "post" && l.TargetType != "comment" {
|
|
||||||
return errors.New("target_type must be 'post' or 'comment'")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTimestamps sets the create timestamp
|
|
||||||
func (l *Like) SetTimestamps() {
|
|
||||||
if l.CreatedAt == 0 {
|
|
||||||
l.CreatedAt = time.Now().UTC().UnixNano() / 1e6 // milliseconds
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNew returns true if this is a new like (no ID set)
|
|
||||||
func (l *Like) IsNew() bool {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
return l.ID == zeroUUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsPostLike returns true if this like is for a post
|
|
||||||
func (l *Like) IsPostLike() bool {
|
|
||||||
return l.TargetType == "post"
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsCommentLike returns true if this like is for a comment
|
|
||||||
func (l *Like) IsCommentLike() bool {
|
|
||||||
return l.TargetType == "comment"
|
|
||||||
}
|
|
||||||
|
|
@ -1,156 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Post represents a post entity in the system.
|
|
||||||
// It contains the main content and metadata for user posts.
|
|
||||||
type Post struct {
|
|
||||||
ID gocql.UUID `db:"id" partition_key:"true"` // Post unique identifier
|
|
||||||
AuthorUID string `db:"author_uid"` // Author user UID
|
|
||||||
Title string `db:"title"` // Post title
|
|
||||||
Content string `db:"content"` // Post content
|
|
||||||
Type post.Type `db:"type"` // Post type (text, image, video, etc.)
|
|
||||||
Status post.Status `db:"status"` // Post status (draft, published, etc.)
|
|
||||||
CategoryID *gocql.UUID `db:"category_id,omitempty"` // Category ID (optional)
|
|
||||||
Tags []string `db:"tags,omitempty"` // Post tags
|
|
||||||
Images []string `db:"images,omitempty"` // Image URLs (optional)
|
|
||||||
VideoURL *string `db:"video_url,omitempty"` // Video URL (optional)
|
|
||||||
LinkURL *string `db:"link_url,omitempty"` // Link URL (optional)
|
|
||||||
LikeCount int64 `db:"like_count"` // Number of likes
|
|
||||||
CommentCount int64 `db:"comment_count"` // Number of comments
|
|
||||||
ViewCount int64 `db:"view_count"` // Number of views
|
|
||||||
IsPinned bool `db:"is_pinned"` // Whether the post is pinned
|
|
||||||
PinnedAt *int64 `db:"pinned_at,omitempty"` // Pinned timestamp (optional)
|
|
||||||
PublishedAt *int64 `db:"published_at,omitempty"` // Published timestamp (optional)
|
|
||||||
CreatedAt int64 `db:"created_at"` // Creation timestamp
|
|
||||||
UpdatedAt int64 `db:"updated_at"` // Last update timestamp
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName returns the Cassandra table name for Post entities.
|
|
||||||
func (p *Post) TableName() string {
|
|
||||||
return "posts"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate validates the Post entity
|
|
||||||
func (p *Post) Validate() error {
|
|
||||||
if p.AuthorUID == "" {
|
|
||||||
return errors.New("author_uid is required")
|
|
||||||
}
|
|
||||||
if len(p.Title) < 1 || len(p.Title) > 200 {
|
|
||||||
return errors.New("title length must be between 1 and 200 characters")
|
|
||||||
}
|
|
||||||
if len(p.Content) < 1 || len(p.Content) > 10000 {
|
|
||||||
return errors.New("content length must be between 1 and 10000 characters")
|
|
||||||
}
|
|
||||||
if !p.Type.IsValid() {
|
|
||||||
return errors.New("invalid post type")
|
|
||||||
}
|
|
||||||
if !p.Status.IsValid() {
|
|
||||||
return errors.New("invalid post status")
|
|
||||||
}
|
|
||||||
if len(p.Tags) > 10 {
|
|
||||||
return errors.New("maximum 10 tags allowed per post")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTimestamps sets the create and update timestamps
|
|
||||||
func (p *Post) SetTimestamps() {
|
|
||||||
now := time.Now().UTC().UnixNano() / 1e6 // milliseconds
|
|
||||||
if p.CreatedAt == 0 {
|
|
||||||
p.CreatedAt = now
|
|
||||||
}
|
|
||||||
p.UpdatedAt = now
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNew returns true if this is a new post (no ID set)
|
|
||||||
func (p *Post) IsNew() bool {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
return p.ID == zeroUUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// Publish marks the post as published
|
|
||||||
func (p *Post) Publish() {
|
|
||||||
p.Status = post.PostStatusPublished
|
|
||||||
now := time.Now().UTC().UnixNano() / 1e6
|
|
||||||
p.PublishedAt = &now
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Archive marks the post as archived
|
|
||||||
func (p *Post) Archive() {
|
|
||||||
p.Status = post.PostStatusArchived
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete marks the post as deleted (soft delete)
|
|
||||||
func (p *Post) Delete() {
|
|
||||||
p.Status = post.PostStatusDeleted
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsVisible returns true if the post is visible to public
|
|
||||||
func (p *Post) IsVisible() bool {
|
|
||||||
return p.Status.IsVisible()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsEditable returns true if the post can be edited
|
|
||||||
func (p *Post) IsEditable() bool {
|
|
||||||
return p.Status.IsEditable()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementLikeCount increments the like count
|
|
||||||
func (p *Post) IncrementLikeCount() {
|
|
||||||
p.LikeCount++
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementLikeCount decrements the like count
|
|
||||||
func (p *Post) DecrementLikeCount() {
|
|
||||||
if p.LikeCount > 0 {
|
|
||||||
p.LikeCount--
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementCommentCount increments the comment count
|
|
||||||
func (p *Post) IncrementCommentCount() {
|
|
||||||
p.CommentCount++
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementCommentCount decrements the comment count
|
|
||||||
func (p *Post) DecrementCommentCount() {
|
|
||||||
if p.CommentCount > 0 {
|
|
||||||
p.CommentCount--
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementViewCount increments the view count
|
|
||||||
func (p *Post) IncrementViewCount() {
|
|
||||||
p.ViewCount++
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pin pins the post
|
|
||||||
func (p *Post) Pin() {
|
|
||||||
p.IsPinned = true
|
|
||||||
now := time.Now().UTC().UnixNano() / 1e6
|
|
||||||
p.PinnedAt = &now
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unpin unpins the post
|
|
||||||
func (p *Post) Unpin() {
|
|
||||||
p.IsPinned = false
|
|
||||||
p.PinnedAt = nil
|
|
||||||
p.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
@ -1,60 +0,0 @@
|
||||||
package entity
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Tag represents a tag entity for categorizing posts.
|
|
||||||
type Tag struct {
|
|
||||||
ID gocql.UUID `db:"id" partition_key:"true"` // Tag unique identifier
|
|
||||||
Name string `db:"name"` // Tag name (unique)
|
|
||||||
Description *string `db:"description,omitempty"` // Tag description (optional)
|
|
||||||
PostCount int64 `db:"post_count"` // Number of posts using this tag
|
|
||||||
CreatedAt int64 `db:"created_at"` // Creation timestamp
|
|
||||||
UpdatedAt int64 `db:"updated_at"` // Last update timestamp
|
|
||||||
}
|
|
||||||
|
|
||||||
// TableName returns the Cassandra table name for Tag entities.
|
|
||||||
func (t *Tag) TableName() string {
|
|
||||||
return "tags"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate validates the Tag entity
|
|
||||||
func (t *Tag) Validate() error {
|
|
||||||
if len(t.Name) < 1 || len(t.Name) > 50 {
|
|
||||||
return errors.New("tag name length must be between 1 and 50 characters")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTimestamps sets the create and update timestamps
|
|
||||||
func (t *Tag) SetTimestamps() {
|
|
||||||
now := time.Now().UTC().UnixNano() / 1e6 // milliseconds
|
|
||||||
if t.CreatedAt == 0 {
|
|
||||||
t.CreatedAt = now
|
|
||||||
}
|
|
||||||
t.UpdatedAt = now
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNew returns true if this is a new tag (no ID set)
|
|
||||||
func (t *Tag) IsNew() bool {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
return t.ID == zeroUUID
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementPostCount increments the post count
|
|
||||||
func (t *Tag) IncrementPostCount() {
|
|
||||||
t.PostCount++
|
|
||||||
t.SetTimestamps()
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementPostCount decrements the post count
|
|
||||||
func (t *Tag) DecrementPostCount() {
|
|
||||||
if t.PostCount > 0 {
|
|
||||||
t.PostCount--
|
|
||||||
t.SetTimestamps()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
package post
|
|
||||||
|
|
||||||
// CommentStatus 評論狀態
|
|
||||||
type CommentStatus int32
|
|
||||||
|
|
||||||
func (s CommentStatus) CodeToString() string {
|
|
||||||
result, ok := commentStatusMap[s]
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
var commentStatusMap = map[CommentStatus]string{
|
|
||||||
CommentStatusPublished: "published", // 已發布
|
|
||||||
CommentStatusDeleted: "deleted", // 已刪除
|
|
||||||
CommentStatusHidden: "hidden", // 隱藏
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s CommentStatus) ToInt32() int32 {
|
|
||||||
return int32(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
CommentStatusPublished CommentStatus = 0 // 已發布
|
|
||||||
CommentStatusDeleted CommentStatus = 1 // 已刪除
|
|
||||||
CommentStatusHidden CommentStatus = 2 // 隱藏
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsValid returns true if the status is valid
|
|
||||||
func (s CommentStatus) IsValid() bool {
|
|
||||||
return s >= CommentStatusPublished && s <= CommentStatusHidden
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsVisible returns true if the comment is visible to public
|
|
||||||
func (s CommentStatus) IsVisible() bool {
|
|
||||||
return s == CommentStatusPublished
|
|
||||||
}
|
|
||||||
|
|
@ -1,47 +0,0 @@
|
||||||
package post
|
|
||||||
|
|
||||||
// Status 貼文狀態
|
|
||||||
type Status int32
|
|
||||||
|
|
||||||
func (s Status) CodeToString() string {
|
|
||||||
result, ok := postStatusMap[s]
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
var postStatusMap = map[Status]string{
|
|
||||||
PostStatusDraft: "draft", // 草稿
|
|
||||||
PostStatusPublished: "published", // 已發布
|
|
||||||
PostStatusArchived: "archived", // 已歸檔
|
|
||||||
PostStatusDeleted: "deleted", // 已刪除
|
|
||||||
PostStatusHidden: "hidden", // 隱藏
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s Status) ToInt32() int32 {
|
|
||||||
return int32(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
PostStatusDraft Status = 0 // 草稿
|
|
||||||
PostStatusPublished Status = 1 // 已發布
|
|
||||||
PostStatusArchived Status = 2 // 已歸檔
|
|
||||||
PostStatusDeleted Status = 3 // 已刪除
|
|
||||||
PostStatusHidden Status = 4 // 隱藏
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsValid returns true if the status is valid
|
|
||||||
func (s Status) IsValid() bool {
|
|
||||||
return s >= PostStatusDraft && s <= PostStatusHidden
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsVisible returns true if the post is visible to public
|
|
||||||
func (s Status) IsVisible() bool {
|
|
||||||
return s == PostStatusPublished
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsEditable returns true if the post can be edited
|
|
||||||
func (s Status) IsEditable() bool {
|
|
||||||
return s == PostStatusDraft || s == PostStatusPublished
|
|
||||||
}
|
|
||||||
|
|
@ -1,39 +0,0 @@
|
||||||
package post
|
|
||||||
|
|
||||||
// Type 貼文類型
|
|
||||||
type Type int32
|
|
||||||
|
|
||||||
func (t Type) CodeToString() string {
|
|
||||||
result, ok := postTypeMap[t]
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
var postTypeMap = map[Type]string{
|
|
||||||
TypeText: "text", // 純文字
|
|
||||||
TypeImage: "image", // 圖片
|
|
||||||
TypeVideo: "video", // 影片
|
|
||||||
TypeLink: "link", // 連結
|
|
||||||
TypePoll: "poll", // 投票
|
|
||||||
TypeArticle: "article", // 長文
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Type) ToInt32() int32 {
|
|
||||||
return int32(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeText Type = 0 // 純文字
|
|
||||||
TypeImage Type = 1 // 圖片
|
|
||||||
TypeVideo Type = 2 // 影片
|
|
||||||
TypeLink Type = 3 // 連結
|
|
||||||
TypePoll Type = 4 // 投票
|
|
||||||
TypeArticle Type = 5 // 長文
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsValid returns true if the type is valid
|
|
||||||
func (t Type) IsValid() bool {
|
|
||||||
return t >= TypeText && t <= TypeArticle
|
|
||||||
}
|
|
||||||
|
|
@ -1,26 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CategoryRepository defines the interface for category data access operations
|
|
||||||
type CategoryRepository interface {
|
|
||||||
BaseCategoryRepository
|
|
||||||
FindBySlug(ctx context.Context, slug string) (*entity.Category, error)
|
|
||||||
FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error)
|
|
||||||
FindRootCategories(ctx context.Context) ([]*entity.Category, error)
|
|
||||||
FindActive(ctx context.Context) ([]*entity.Category, error)
|
|
||||||
IncrementPostCount(ctx context.Context, categoryID string) error
|
|
||||||
DecrementPostCount(ctx context.Context, categoryID string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// BaseCategoryRepository defines basic CRUD operations for categories
|
|
||||||
type BaseCategoryRepository interface {
|
|
||||||
Insert(ctx context.Context, data *entity.Category) error
|
|
||||||
FindOne(ctx context.Context, id string) (*entity.Category, error)
|
|
||||||
Update(ctx context.Context, data *entity.Category) error
|
|
||||||
Delete(ctx context.Context, id string) error
|
|
||||||
}
|
|
||||||
|
|
@ -1,46 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CommentRepository defines the interface for comment data access operations
|
|
||||||
type CommentRepository interface {
|
|
||||||
BaseCommentRepository
|
|
||||||
FindByPostID(ctx context.Context, postID gocql.UUID, params *CommentQueryParams) ([]*entity.Comment, int64, error)
|
|
||||||
FindByParentID(ctx context.Context, parentID gocql.UUID, params *CommentQueryParams) ([]*entity.Comment, int64, error)
|
|
||||||
FindByAuthorUID(ctx context.Context, authorUID string, params *CommentQueryParams) ([]*entity.Comment, int64, error)
|
|
||||||
FindReplies(ctx context.Context, commentID gocql.UUID, params *CommentQueryParams) ([]*entity.Comment, int64, error)
|
|
||||||
IncrementLikeCount(ctx context.Context, commentID gocql.UUID) error
|
|
||||||
DecrementLikeCount(ctx context.Context, commentID gocql.UUID) error
|
|
||||||
IncrementReplyCount(ctx context.Context, commentID gocql.UUID) error
|
|
||||||
DecrementReplyCount(ctx context.Context, commentID gocql.UUID) error
|
|
||||||
UpdateStatus(ctx context.Context, commentID gocql.UUID, status post.CommentStatus) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// BaseCommentRepository defines basic CRUD operations for comments
|
|
||||||
type BaseCommentRepository interface {
|
|
||||||
Insert(ctx context.Context, data *entity.Comment) error
|
|
||||||
FindOne(ctx context.Context, id gocql.UUID) (*entity.Comment, error)
|
|
||||||
Update(ctx context.Context, data *entity.Comment) error
|
|
||||||
Delete(ctx context.Context, id gocql.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommentQueryParams defines query parameters for comment listing
|
|
||||||
type CommentQueryParams struct {
|
|
||||||
PostID *gocql.UUID
|
|
||||||
ParentID *gocql.UUID
|
|
||||||
AuthorUID *string
|
|
||||||
Status *post.CommentStatus
|
|
||||||
CreateStartTime *int64
|
|
||||||
CreateEndTime *int64
|
|
||||||
PageSize int64
|
|
||||||
PageIndex int64
|
|
||||||
OrderBy string // "created_at", "like_count"
|
|
||||||
OrderDirection string // "ASC", "DESC"
|
|
||||||
}
|
|
||||||
|
|
@ -1,37 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LikeRepository defines the interface for like data access operations
|
|
||||||
type LikeRepository interface {
|
|
||||||
BaseLikeRepository
|
|
||||||
FindByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) ([]*entity.Like, error)
|
|
||||||
FindByUserUID(ctx context.Context, userUID string, params *LikeQueryParams) ([]*entity.Like, int64, error)
|
|
||||||
FindByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) (*entity.Like, error)
|
|
||||||
CountByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) (int64, error)
|
|
||||||
DeleteByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// BaseLikeRepository defines basic CRUD operations for likes
|
|
||||||
type BaseLikeRepository interface {
|
|
||||||
Insert(ctx context.Context, data *entity.Like) error
|
|
||||||
FindOne(ctx context.Context, id gocql.UUID) (*entity.Like, error)
|
|
||||||
Delete(ctx context.Context, id gocql.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// LikeQueryParams defines query parameters for like listing
|
|
||||||
type LikeQueryParams struct {
|
|
||||||
TargetID *gocql.UUID
|
|
||||||
TargetType *string
|
|
||||||
UserUID *string
|
|
||||||
PageSize int64
|
|
||||||
PageIndex int64
|
|
||||||
OrderBy string // "created_at"
|
|
||||||
OrderDirection string // "ASC", "DESC"
|
|
||||||
}
|
|
||||||
|
|
@ -1,54 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PostRepository defines the interface for post data access operations
|
|
||||||
type PostRepository interface {
|
|
||||||
BasePostRepository
|
|
||||||
FindByAuthorUID(ctx context.Context, authorUID string, params *PostQueryParams) ([]*entity.Post, int64, error)
|
|
||||||
FindByCategoryID(ctx context.Context, categoryID gocql.UUID, params *PostQueryParams) ([]*entity.Post, int64, error)
|
|
||||||
FindByTag(ctx context.Context, tagName string, params *PostQueryParams) ([]*entity.Post, int64, error)
|
|
||||||
FindPinnedPosts(ctx context.Context, limit int64) ([]*entity.Post, error)
|
|
||||||
FindByStatus(ctx context.Context, status post.Status, params *PostQueryParams) ([]*entity.Post, int64, error)
|
|
||||||
IncrementLikeCount(ctx context.Context, postID gocql.UUID) error
|
|
||||||
DecrementLikeCount(ctx context.Context, postID gocql.UUID) error
|
|
||||||
IncrementCommentCount(ctx context.Context, postID gocql.UUID) error
|
|
||||||
DecrementCommentCount(ctx context.Context, postID gocql.UUID) error
|
|
||||||
IncrementViewCount(ctx context.Context, postID gocql.UUID) error
|
|
||||||
UpdateStatus(ctx context.Context, postID gocql.UUID, status post.Status) error
|
|
||||||
PinPost(ctx context.Context, postID gocql.UUID) error
|
|
||||||
UnpinPost(ctx context.Context, postID gocql.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// BasePostRepository defines basic CRUD operations for posts
|
|
||||||
type BasePostRepository interface {
|
|
||||||
Insert(ctx context.Context, data *entity.Post) error
|
|
||||||
FindOne(ctx context.Context, id gocql.UUID) (*entity.Post, error)
|
|
||||||
Update(ctx context.Context, data *entity.Post) error
|
|
||||||
Delete(ctx context.Context, id gocql.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostQueryParams defines query parameters for post listing
|
|
||||||
type PostQueryParams struct {
|
|
||||||
AuthorUID *string
|
|
||||||
CategoryID *gocql.UUID
|
|
||||||
Tag *string
|
|
||||||
Status *post.Status
|
|
||||||
Type *post.Type
|
|
||||||
IsPinned *bool
|
|
||||||
CreateStartTime *int64
|
|
||||||
CreateEndTime *int64
|
|
||||||
PublishedStartTime *int64
|
|
||||||
PublishedEndTime *int64
|
|
||||||
PageSize int64
|
|
||||||
PageIndex int64
|
|
||||||
OrderBy string // "created_at", "published_at", "like_count", "view_count"
|
|
||||||
OrderDirection string // "ASC", "DESC"
|
|
||||||
}
|
|
||||||
|
|
@ -1,28 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TagRepository defines the interface for tag data access operations
|
|
||||||
type TagRepository interface {
|
|
||||||
BaseTagRepository
|
|
||||||
FindByName(ctx context.Context, name string) (*entity.Tag, error)
|
|
||||||
FindByNames(ctx context.Context, names []string) ([]*entity.Tag, error)
|
|
||||||
FindPopular(ctx context.Context, limit int64) ([]*entity.Tag, error)
|
|
||||||
IncrementPostCount(ctx context.Context, tagID gocql.UUID) error
|
|
||||||
DecrementPostCount(ctx context.Context, tagID gocql.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// BaseTagRepository defines basic CRUD operations for tags
|
|
||||||
type BaseTagRepository interface {
|
|
||||||
Insert(ctx context.Context, data *entity.Tag) error
|
|
||||||
FindOne(ctx context.Context, id gocql.UUID) (*entity.Tag, error)
|
|
||||||
Update(ctx context.Context, data *entity.Tag) error
|
|
||||||
Delete(ctx context.Context, id gocql.UUID) error
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,128 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CommentUseCase defines the interface for comment business logic operations
|
|
||||||
type CommentUseCase interface {
|
|
||||||
CommentCRUDUseCase
|
|
||||||
CommentQueryUseCase
|
|
||||||
CommentInteractionUseCase
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommentCRUDUseCase defines CRUD operations for comments
|
|
||||||
type CommentCRUDUseCase interface {
|
|
||||||
// CreateComment creates a new comment
|
|
||||||
CreateComment(ctx context.Context, req CreateCommentRequest) (*CommentResponse, error)
|
|
||||||
// GetComment retrieves a comment by ID
|
|
||||||
GetComment(ctx context.Context, req GetCommentRequest) (*CommentResponse, error)
|
|
||||||
// UpdateComment updates an existing comment
|
|
||||||
UpdateComment(ctx context.Context, req UpdateCommentRequest) (*CommentResponse, error)
|
|
||||||
// DeleteComment deletes a comment (soft delete)
|
|
||||||
DeleteComment(ctx context.Context, req DeleteCommentRequest) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommentQueryUseCase defines query operations for comments
|
|
||||||
type CommentQueryUseCase interface {
|
|
||||||
// ListComments lists comments for a post
|
|
||||||
ListComments(ctx context.Context, req ListCommentsRequest) (*ListCommentsResponse, error)
|
|
||||||
// ListReplies lists replies to a comment
|
|
||||||
ListReplies(ctx context.Context, req ListRepliesRequest) (*ListCommentsResponse, error)
|
|
||||||
// ListCommentsByAuthor lists comments by author
|
|
||||||
ListCommentsByAuthor(ctx context.Context, req ListCommentsByAuthorRequest) (*ListCommentsResponse, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommentInteractionUseCase defines interaction operations for comments
|
|
||||||
type CommentInteractionUseCase interface {
|
|
||||||
// LikeComment likes a comment
|
|
||||||
LikeComment(ctx context.Context, req LikeCommentRequest) error
|
|
||||||
// UnlikeComment unlikes a comment
|
|
||||||
UnlikeComment(ctx context.Context, req UnlikeCommentRequest) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateCommentRequest represents a request to create a comment
|
|
||||||
type CreateCommentRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID
|
|
||||||
ParentID *gocql.UUID `json:"parent_id,omitempty"` // Parent comment ID (optional, for replies)
|
|
||||||
Content string `json:"content"` // Comment content
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateCommentRequest represents a request to update a comment
|
|
||||||
type UpdateCommentRequest struct {
|
|
||||||
CommentID gocql.UUID `json:"comment_id"` // Comment ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID (for authorization)
|
|
||||||
Content string `json:"content"` // Comment content
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetCommentRequest represents a request to get a comment
|
|
||||||
type GetCommentRequest struct {
|
|
||||||
CommentID gocql.UUID `json:"comment_id"` // Comment ID
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteCommentRequest represents a request to delete a comment
|
|
||||||
type DeleteCommentRequest struct {
|
|
||||||
CommentID gocql.UUID `json:"comment_id"` // Comment ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID (for authorization)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListCommentsRequest represents a request to list comments
|
|
||||||
type ListCommentsRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
ParentID *gocql.UUID `json:"parent_id,omitempty"` // Parent comment ID (optional, for replies only)
|
|
||||||
PageSize int64 `json:"page_size"` // Page size
|
|
||||||
PageIndex int64 `json:"page_index"` // Page index
|
|
||||||
OrderBy string `json:"order_by,omitempty"` // Order by field (default: "created_at")
|
|
||||||
OrderDirection string `json:"order_direction,omitempty"` // Order direction (ASC/DESC, default: ASC)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListRepliesRequest represents a request to list replies to a comment
|
|
||||||
type ListRepliesRequest struct {
|
|
||||||
CommentID gocql.UUID `json:"comment_id"` // Comment ID
|
|
||||||
PageSize int64 `json:"page_size"` // Page size
|
|
||||||
PageIndex int64 `json:"page_index"` // Page index
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListCommentsByAuthorRequest represents a request to list comments by author
|
|
||||||
type ListCommentsByAuthorRequest struct {
|
|
||||||
AuthorUID string `json:"author_uid"` // Author UID
|
|
||||||
PageSize int64 `json:"page_size"` // Page size
|
|
||||||
PageIndex int64 `json:"page_index"` // Page index
|
|
||||||
}
|
|
||||||
|
|
||||||
// LikeCommentRequest represents a request to like a comment
|
|
||||||
type LikeCommentRequest struct {
|
|
||||||
CommentID gocql.UUID `json:"comment_id"` // Comment ID
|
|
||||||
UserUID string `json:"user_uid"` // User UID
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnlikeCommentRequest represents a request to unlike a comment
|
|
||||||
type UnlikeCommentRequest struct {
|
|
||||||
CommentID gocql.UUID `json:"comment_id"` // Comment ID
|
|
||||||
UserUID string `json:"user_uid"` // User UID
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommentResponse represents a comment response
|
|
||||||
type CommentResponse struct {
|
|
||||||
ID gocql.UUID `json:"id"`
|
|
||||||
PostID gocql.UUID `json:"post_id"`
|
|
||||||
AuthorUID string `json:"author_uid"`
|
|
||||||
ParentID *gocql.UUID `json:"parent_id,omitempty"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
Status post.CommentStatus `json:"status"`
|
|
||||||
LikeCount int64 `json:"like_count"`
|
|
||||||
ReplyCount int64 `json:"reply_count"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
|
||||||
UpdatedAt int64 `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListCommentsResponse represents a list of comments response
|
|
||||||
type ListCommentsResponse struct {
|
|
||||||
Data []CommentResponse `json:"data"`
|
|
||||||
Page Pager `json:"page"`
|
|
||||||
}
|
|
||||||
|
|
@ -1,229 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PostUseCase defines the interface for post business logic operations
|
|
||||||
type PostUseCase interface {
|
|
||||||
PostCRUDUseCase
|
|
||||||
PostQueryUseCase
|
|
||||||
PostInteractionUseCase
|
|
||||||
PostManagementUseCase
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostCRUDUseCase defines CRUD operations for posts
|
|
||||||
type PostCRUDUseCase interface {
|
|
||||||
// CreatePost creates a new post
|
|
||||||
CreatePost(ctx context.Context, req CreatePostRequest) (*PostResponse, error)
|
|
||||||
// GetPost retrieves a post by ID
|
|
||||||
GetPost(ctx context.Context, req GetPostRequest) (*PostResponse, error)
|
|
||||||
// UpdatePost updates an existing post
|
|
||||||
UpdatePost(ctx context.Context, req UpdatePostRequest) (*PostResponse, error)
|
|
||||||
// DeletePost deletes a post (soft delete)
|
|
||||||
DeletePost(ctx context.Context, req DeletePostRequest) error
|
|
||||||
// PublishPost publishes a draft post
|
|
||||||
PublishPost(ctx context.Context, req PublishPostRequest) (*PostResponse, error)
|
|
||||||
// ArchivePost archives a post
|
|
||||||
ArchivePost(ctx context.Context, req ArchivePostRequest) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostQueryUseCase defines query operations for posts
|
|
||||||
type PostQueryUseCase interface {
|
|
||||||
// ListPosts lists posts with filters and pagination
|
|
||||||
ListPosts(ctx context.Context, req ListPostsRequest) (*ListPostsResponse, error)
|
|
||||||
// ListPostsByAuthor lists posts by author UID
|
|
||||||
ListPostsByAuthor(ctx context.Context, req ListPostsByAuthorRequest) (*ListPostsResponse, error)
|
|
||||||
// ListPostsByCategory lists posts by category
|
|
||||||
ListPostsByCategory(ctx context.Context, req ListPostsByCategoryRequest) (*ListPostsResponse, error)
|
|
||||||
// ListPostsByTag lists posts by tag
|
|
||||||
ListPostsByTag(ctx context.Context, req ListPostsByTagRequest) (*ListPostsResponse, error)
|
|
||||||
// GetPinnedPosts gets pinned posts
|
|
||||||
GetPinnedPosts(ctx context.Context, req GetPinnedPostsRequest) (*ListPostsResponse, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostInteractionUseCase defines interaction operations for posts
|
|
||||||
type PostInteractionUseCase interface {
|
|
||||||
// LikePost likes a post
|
|
||||||
LikePost(ctx context.Context, req LikePostRequest) error
|
|
||||||
// UnlikePost unlikes a post
|
|
||||||
UnlikePost(ctx context.Context, req UnlikePostRequest) error
|
|
||||||
// ViewPost increments view count
|
|
||||||
ViewPost(ctx context.Context, req ViewPostRequest) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostManagementUseCase defines management operations for posts
|
|
||||||
type PostManagementUseCase interface {
|
|
||||||
// PinPost pins a post
|
|
||||||
PinPost(ctx context.Context, req PinPostRequest) error
|
|
||||||
// UnpinPost unpins a post
|
|
||||||
UnpinPost(ctx context.Context, req UnpinPostRequest) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePostRequest represents a request to create a post
|
|
||||||
type CreatePostRequest struct {
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID
|
|
||||||
Title string `json:"title"` // Post title
|
|
||||||
Content string `json:"content"` // Post content
|
|
||||||
Type post.Type `json:"type"` // Post type
|
|
||||||
CategoryID *gocql.UUID `json:"category_id,omitempty"` // Category ID (optional)
|
|
||||||
Tags []string `json:"tags,omitempty"` // Post tags (optional)
|
|
||||||
Images []string `json:"images,omitempty"` // Image URLs (optional)
|
|
||||||
VideoURL *string `json:"video_url,omitempty"` // Video URL (optional)
|
|
||||||
LinkURL *string `json:"link_url,omitempty"` // Link URL (optional)
|
|
||||||
Status post.Status `json:"status,omitempty"` // Post status (default: draft)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePostRequest represents a request to update a post
|
|
||||||
type UpdatePostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID (for authorization)
|
|
||||||
Title *string `json:"title,omitempty"` // Post title (optional)
|
|
||||||
Content *string `json:"content,omitempty"` // Post content (optional)
|
|
||||||
Type *post.Type `json:"type,omitempty"` // Post type (optional)
|
|
||||||
CategoryID *gocql.UUID `json:"category_id,omitempty"` // Category ID (optional)
|
|
||||||
Tags []string `json:"tags,omitempty"` // Post tags (optional)
|
|
||||||
Images []string `json:"images,omitempty"` // Image URLs (optional)
|
|
||||||
VideoURL *string `json:"video_url,omitempty"` // Video URL (optional)
|
|
||||||
LinkURL *string `json:"link_url,omitempty"` // Link URL (optional)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPostRequest represents a request to get a post
|
|
||||||
type GetPostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
UserUID *string `json:"user_uid,omitempty"` // User UID (for view count increment)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePostRequest represents a request to delete a post
|
|
||||||
type DeletePostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID (for authorization)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublishPostRequest represents a request to publish a post
|
|
||||||
type PublishPostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID (for authorization)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ArchivePostRequest represents a request to archive a post
|
|
||||||
type ArchivePostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID (for authorization)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPostsRequest represents a request to list posts
|
|
||||||
type ListPostsRequest struct {
|
|
||||||
CategoryID *gocql.UUID `json:"category_id,omitempty"` // Category ID (optional)
|
|
||||||
Tag *string `json:"tag,omitempty"` // Tag name (optional)
|
|
||||||
Status *post.Status `json:"status,omitempty"` // Post status (optional)
|
|
||||||
Type *post.Type `json:"type,omitempty"` // Post type (optional)
|
|
||||||
AuthorUID *string `json:"author_uid,omitempty"` // Author UID (optional)
|
|
||||||
CreateStartTime *int64 `json:"create_start_time,omitempty"` // Create start time (optional)
|
|
||||||
CreateEndTime *int64 `json:"create_end_time,omitempty"` // Create end time (optional)
|
|
||||||
PageSize int64 `json:"page_size"` // Page size
|
|
||||||
PageIndex int64 `json:"page_index"` // Page index
|
|
||||||
OrderBy string `json:"order_by,omitempty"` // Order by field
|
|
||||||
OrderDirection string `json:"order_direction,omitempty"` // Order direction (ASC/DESC)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPostsByAuthorRequest represents a request to list posts by author
|
|
||||||
type ListPostsByAuthorRequest struct {
|
|
||||||
AuthorUID string `json:"author_uid"` // Author UID
|
|
||||||
Status *post.Status `json:"status,omitempty"` // Post status (optional)
|
|
||||||
PageSize int64 `json:"page_size"` // Page size
|
|
||||||
PageIndex int64 `json:"page_index"` // Page index
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPostsByCategoryRequest represents a request to list posts by category
|
|
||||||
type ListPostsByCategoryRequest struct {
|
|
||||||
CategoryID gocql.UUID `json:"category_id"` // Category ID
|
|
||||||
Status *post.Status `json:"status,omitempty"` // Post status (optional)
|
|
||||||
PageSize int64 `json:"page_size"` // Page size
|
|
||||||
PageIndex int64 `json:"page_index"` // Page index
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPostsByTagRequest represents a request to list posts by tag
|
|
||||||
type ListPostsByTagRequest struct {
|
|
||||||
Tag string `json:"tag"` // Tag name
|
|
||||||
Status *post.Status `json:"status,omitempty"` // Post status (optional)
|
|
||||||
PageSize int64 `json:"page_size"` // Page size
|
|
||||||
PageIndex int64 `json:"page_index"` // Page index
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPinnedPostsRequest represents a request to get pinned posts
|
|
||||||
type GetPinnedPostsRequest struct {
|
|
||||||
Limit int64 `json:"limit,omitempty"` // Limit (optional, default: 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LikePostRequest represents a request to like a post
|
|
||||||
type LikePostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
UserUID string `json:"user_uid"` // User UID
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnlikePostRequest represents a request to unlike a post
|
|
||||||
type UnlikePostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
UserUID string `json:"user_uid"` // User UID
|
|
||||||
}
|
|
||||||
|
|
||||||
// ViewPostRequest represents a request to view a post
|
|
||||||
type ViewPostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
UserUID *string `json:"user_uid,omitempty"` // User UID (optional)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PinPostRequest represents a request to pin a post
|
|
||||||
type PinPostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID (for authorization)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnpinPostRequest represents a request to unpin a post
|
|
||||||
type UnpinPostRequest struct {
|
|
||||||
PostID gocql.UUID `json:"post_id"` // Post ID
|
|
||||||
AuthorUID string `json:"author_uid"` // Author user UID (for authorization)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostResponse represents a post response
|
|
||||||
type PostResponse struct {
|
|
||||||
ID gocql.UUID `json:"id"`
|
|
||||||
AuthorUID string `json:"author_uid"`
|
|
||||||
Title string `json:"title"`
|
|
||||||
Content string `json:"content"`
|
|
||||||
Type post.Type `json:"type"`
|
|
||||||
Status post.Status `json:"status"`
|
|
||||||
CategoryID *gocql.UUID `json:"category_id,omitempty"`
|
|
||||||
Tags []string `json:"tags,omitempty"`
|
|
||||||
Images []string `json:"images,omitempty"`
|
|
||||||
VideoURL *string `json:"video_url,omitempty"`
|
|
||||||
LinkURL *string `json:"link_url,omitempty"`
|
|
||||||
LikeCount int64 `json:"like_count"`
|
|
||||||
CommentCount int64 `json:"comment_count"`
|
|
||||||
ViewCount int64 `json:"view_count"`
|
|
||||||
IsPinned bool `json:"is_pinned"`
|
|
||||||
PinnedAt *int64 `json:"pinned_at,omitempty"`
|
|
||||||
PublishedAt *int64 `json:"published_at,omitempty"`
|
|
||||||
CreatedAt int64 `json:"created_at"`
|
|
||||||
UpdatedAt int64 `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPostsResponse represents a list of posts response
|
|
||||||
type ListPostsResponse struct {
|
|
||||||
Data []PostResponse `json:"data"`
|
|
||||||
Page Pager `json:"page"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pager represents pagination information
|
|
||||||
type Pager struct {
|
|
||||||
PageIndex int64 `json:"page_index"`
|
|
||||||
PageSize int64 `json:"page_size"`
|
|
||||||
Total int64 `json:"total"`
|
|
||||||
TotalPage int64 `json:"total_page"`
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,263 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
domainRepo "backend/pkg/post/domain/repository"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CategoryRepositoryParam 定義 CategoryRepository 的初始化參數
|
|
||||||
type CategoryRepositoryParam struct {
|
|
||||||
DB *cassandra.DB
|
|
||||||
Keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// CategoryRepository 實作 domain repository 介面
|
|
||||||
type CategoryRepository struct {
|
|
||||||
repo cassandra.Repository[*entity.Category]
|
|
||||||
db *cassandra.DB
|
|
||||||
keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCategoryRepository 創建新的 CategoryRepository
|
|
||||||
func NewCategoryRepository(param CategoryRepositoryParam) domainRepo.CategoryRepository {
|
|
||||||
repo, err := cassandra.NewRepository[*entity.Category](param.DB, param.Keyspace)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("failed to create category repository: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
keyspace := param.Keyspace
|
|
||||||
if keyspace == "" {
|
|
||||||
keyspace = param.DB.GetDefaultKeyspace()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &CategoryRepository{
|
|
||||||
repo: repo,
|
|
||||||
db: param.DB,
|
|
||||||
keyspace: keyspace,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert 插入單筆分類
|
|
||||||
func (r *CategoryRepository) Insert(ctx context.Context, data *entity.Category) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if data.ParentID == nil {
|
|
||||||
data.ParentID = &gocql.UUID{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 設置時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
// 如果是新分類,生成 ID
|
|
||||||
if data.IsNew() {
|
|
||||||
data.ID = gocql.TimeUUID()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Slug 轉為小寫
|
|
||||||
data.Slug = strings.ToLower(strings.TrimSpace(data.Slug))
|
|
||||||
|
|
||||||
return r.repo.Insert(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindOne 根據 ID 查詢單筆分類
|
|
||||||
func (r *CategoryRepository) FindOne(ctx context.Context, id string) (*entity.Category, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
uuid, err := gocql.ParseUUID(id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if uuid == zeroUUID {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
category, err := r.repo.Get(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("failed to find category: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return category, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update 更新分類
|
|
||||||
func (r *CategoryRepository) Update(ctx context.Context, data *entity.Category) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
// Slug 轉為小寫
|
|
||||||
data.Slug = strings.ToLower(strings.TrimSpace(data.Slug))
|
|
||||||
|
|
||||||
return r.repo.Update(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 刪除分類
|
|
||||||
func (r *CategoryRepository) Delete(ctx context.Context, id string) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
uuid, err := gocql.ParseUUID(id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if uuid == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.repo.Delete(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindBySlug 根據 slug 查詢分類
|
|
||||||
func (r *CategoryRepository) FindBySlug(ctx context.Context, slug string) (*entity.Category, error) {
|
|
||||||
if slug == "" {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 標準化 slug
|
|
||||||
slug = strings.ToLower(strings.TrimSpace(slug))
|
|
||||||
|
|
||||||
// 構建查詢(要有 SAI 索引在 slug 欄位上)
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("slug", slug))
|
|
||||||
|
|
||||||
var categories []*entity.Category
|
|
||||||
if err := query.Scan(ctx, &categories); err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("failed to query category: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(categories) == 0 {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return categories[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByParentID 根據父分類 ID 查詢子分類
|
|
||||||
func (r *CategoryRepository) FindByParentID(ctx context.Context, parentID string) ([]*entity.Category, error) {
|
|
||||||
query := r.repo.Query()
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if parentID != "" {
|
|
||||||
// 構建查詢(有 SAI 索引在 parentID 欄位上)
|
|
||||||
uuid, err := gocql.ParseUUID(parentID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if uuid != zeroUUID {
|
|
||||||
query = query.Where(cassandra.Eq("parent_id", uuid))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
query = query.Where(cassandra.Eq("parent_id", zeroUUID))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 按 sort_order 排序
|
|
||||||
query = query.OrderBy("sort_order", cassandra.ASC)
|
|
||||||
|
|
||||||
var categories []*entity.Category
|
|
||||||
if err := query.Scan(ctx, &categories); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query categories: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return categories, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindRootCategories 查詢根分類
|
|
||||||
func (r *CategoryRepository) FindRootCategories(ctx context.Context) ([]*entity.Category, error) {
|
|
||||||
return r.FindByParentID(ctx, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindActive 查詢啟用的分類
|
|
||||||
func (r *CategoryRepository) FindActive(ctx context.Context) ([]*entity.Category, error) {
|
|
||||||
query := r.repo.Query().
|
|
||||||
Where(cassandra.Eq("is_active", true)).
|
|
||||||
OrderBy("sort_order", cassandra.ASC)
|
|
||||||
|
|
||||||
var categories []*entity.Category
|
|
||||||
if err := query.Scan(ctx, &categories); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query active categories: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := categories
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:post_count 欄位必須是 counter 類型
|
|
||||||
func (r *CategoryRepository) IncrementPostCount(ctx context.Context, categoryID string) error {
|
|
||||||
uuid, err := gocql.ParseUUID(categoryID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 counter 原子更新操作:UPDATE categories SET post_count = post_count + 1 WHERE id = ?
|
|
||||||
var zeroCategory entity.Category
|
|
||||||
tableName := zeroCategory.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(uuid)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to increment post count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:post_count 欄位必須是 counter 類型
|
|
||||||
func (r *CategoryRepository) DecrementPostCount(ctx context.Context, categoryID string) error {
|
|
||||||
uuid, err := gocql.ParseUUID(categoryID)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("%w: invalid category ID: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 counter 原子更新操作:UPDATE categories SET post_count = post_count - 1 WHERE id = ?
|
|
||||||
var zeroCategory entity.Category
|
|
||||||
tableName := zeroCategory.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(uuid)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to decrement post count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
@ -1,383 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
domainRepo "backend/pkg/post/domain/repository"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CommentRepositoryParam 定義 CommentRepository 的初始化參數
|
|
||||||
type CommentRepositoryParam struct {
|
|
||||||
DB *cassandra.DB
|
|
||||||
Keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommentRepository 實作 domain repository 介面
|
|
||||||
type CommentRepository struct {
|
|
||||||
repo cassandra.Repository[*entity.Comment]
|
|
||||||
db *cassandra.DB
|
|
||||||
keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewCommentRepository 創建新的 CommentRepository
|
|
||||||
func NewCommentRepository(param CommentRepositoryParam) domainRepo.CommentRepository {
|
|
||||||
repo, err := cassandra.NewRepository[*entity.Comment](param.DB, param.Keyspace)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("failed to create comment repository: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
keyspace := param.Keyspace
|
|
||||||
if keyspace == "" {
|
|
||||||
keyspace = param.DB.GetDefaultKeyspace()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &CommentRepository{
|
|
||||||
repo: repo,
|
|
||||||
db: param.DB,
|
|
||||||
keyspace: keyspace,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert 插入單筆評論
|
|
||||||
func (r *CommentRepository) Insert(ctx context.Context, data *entity.Comment) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 設置時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
// 如果是新評論,生成 ID
|
|
||||||
if data.IsNew() {
|
|
||||||
data.ID = gocql.TimeUUID()
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.repo.Insert(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindOne 根據 ID 查詢單筆評論
|
|
||||||
func (r *CommentRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Comment, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if id == zeroUUID {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
comment, err := r.repo.Get(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("failed to find comment: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return comment, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update 更新評論
|
|
||||||
func (r *CommentRepository) Update(ctx context.Context, data *entity.Comment) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
return r.repo.Update(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 刪除評論(軟刪除)
|
|
||||||
func (r *CommentRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if id == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 先查詢評論
|
|
||||||
comment, err := r.FindOne(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 軟刪除:標記為已刪除
|
|
||||||
comment.Delete()
|
|
||||||
return r.Update(ctx, comment)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByPostID 根據貼文 ID 查詢評論
|
|
||||||
func (r *CommentRepository) FindByPostID(ctx context.Context, postID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if postID == zeroUUID {
|
|
||||||
return nil, 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢(使用 PostID 作為 clustering key)
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("post_id", postID))
|
|
||||||
|
|
||||||
// 添加父評論過濾(如果指定,只查詢回覆)
|
|
||||||
if params != nil && params.ParentID != nil {
|
|
||||||
query = query.Where(cassandra.Eq("parent_id", *params.ParentID))
|
|
||||||
} else {
|
|
||||||
// 如果沒有指定 ParentID,只查詢頂層評論(parent_id 為 null)
|
|
||||||
// 注意:Cassandra 不支援直接查詢 null,需要特殊處理
|
|
||||||
// 這裡簡化處理,實際可能需要使用 Materialized View
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加狀態過濾
|
|
||||||
if params != nil && params.Status != nil {
|
|
||||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
|
||||||
} else {
|
|
||||||
// 預設只查詢已發布的評論
|
|
||||||
published := post.CommentStatusPublished
|
|
||||||
query = query.Where(cassandra.Eq("status", published))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加排序
|
|
||||||
orderBy := "created_at"
|
|
||||||
if params != nil && params.OrderBy != "" {
|
|
||||||
orderBy = params.OrderBy
|
|
||||||
}
|
|
||||||
order := cassandra.ASC
|
|
||||||
if params != nil && params.OrderDirection == "DESC" {
|
|
||||||
order = cassandra.DESC
|
|
||||||
}
|
|
||||||
query = query.OrderBy(orderBy, order)
|
|
||||||
|
|
||||||
// 添加分頁
|
|
||||||
pageSize := int64(20)
|
|
||||||
if params != nil && params.PageSize > 0 {
|
|
||||||
pageSize = params.PageSize
|
|
||||||
}
|
|
||||||
limit := int(pageSize)
|
|
||||||
query = query.Limit(limit)
|
|
||||||
|
|
||||||
// 執行查詢
|
|
||||||
var comments []*entity.Comment
|
|
||||||
if err := query.Scan(ctx, &comments); err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query comments: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := comments
|
|
||||||
|
|
||||||
total := int64(len(result))
|
|
||||||
return result, total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByParentID 根據父評論 ID 查詢回覆
|
|
||||||
func (r *CommentRepository) FindByParentID(ctx context.Context, parentID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if parentID == zeroUUID {
|
|
||||||
return nil, 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("parent_id", parentID))
|
|
||||||
|
|
||||||
// 添加狀態過濾
|
|
||||||
if params != nil && params.Status != nil {
|
|
||||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
|
||||||
} else {
|
|
||||||
published := post.CommentStatusPublished
|
|
||||||
query = query.Where(cassandra.Eq("status", published))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加排序和分頁
|
|
||||||
orderBy := "created_at"
|
|
||||||
if params != nil && params.OrderBy != "" {
|
|
||||||
orderBy = params.OrderBy
|
|
||||||
}
|
|
||||||
order := cassandra.ASC
|
|
||||||
if params != nil && params.OrderDirection == "DESC" {
|
|
||||||
order = cassandra.DESC
|
|
||||||
}
|
|
||||||
query = query.OrderBy(orderBy, order)
|
|
||||||
|
|
||||||
pageSize := int64(20)
|
|
||||||
if params != nil && params.PageSize > 0 {
|
|
||||||
pageSize = params.PageSize
|
|
||||||
}
|
|
||||||
query = query.Limit(int(pageSize))
|
|
||||||
|
|
||||||
var comments []*entity.Comment
|
|
||||||
if err := query.Scan(ctx, &comments); err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query replies: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return comments, int64(len(comments)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByAuthorUID 根據作者 UID 查詢評論
|
|
||||||
func (r *CommentRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
|
|
||||||
if authorUID == "" {
|
|
||||||
return nil, 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID))
|
|
||||||
|
|
||||||
// 添加狀態過濾
|
|
||||||
if params != nil && params.Status != nil {
|
|
||||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加排序和分頁
|
|
||||||
orderBy := "created_at"
|
|
||||||
if params != nil && params.OrderBy != "" {
|
|
||||||
orderBy = params.OrderBy
|
|
||||||
}
|
|
||||||
order := cassandra.DESC
|
|
||||||
if params != nil && params.OrderDirection == "ASC" {
|
|
||||||
order = cassandra.ASC
|
|
||||||
}
|
|
||||||
query = query.OrderBy(orderBy, order)
|
|
||||||
|
|
||||||
pageSize := int64(20)
|
|
||||||
if params != nil && params.PageSize > 0 {
|
|
||||||
pageSize = params.PageSize
|
|
||||||
}
|
|
||||||
query = query.Limit(int(pageSize))
|
|
||||||
|
|
||||||
var comments []*entity.Comment
|
|
||||||
if err := query.Scan(ctx, &comments); err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query comments: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return comments, int64(len(comments)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindReplies 查詢指定評論的回覆
|
|
||||||
func (r *CommentRepository) FindReplies(ctx context.Context, commentID gocql.UUID, params *domainRepo.CommentQueryParams) ([]*entity.Comment, int64, error) {
|
|
||||||
return r.FindByParentID(ctx, commentID, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:like_count 欄位必須是 counter 類型
|
|
||||||
func (r *CommentRepository) IncrementLikeCount(ctx context.Context, commentID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if commentID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroComment entity.Comment
|
|
||||||
tableName := zeroComment.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(commentID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to increment like count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:like_count 欄位必須是 counter 類型
|
|
||||||
func (r *CommentRepository) DecrementLikeCount(ctx context.Context, commentID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if commentID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroComment entity.Comment
|
|
||||||
tableName := zeroComment.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(commentID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to decrement like count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementReplyCount 增加回覆數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:reply_count 欄位必須是 counter 類型
|
|
||||||
func (r *CommentRepository) IncrementReplyCount(ctx context.Context, commentID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if commentID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroComment entity.Comment
|
|
||||||
tableName := zeroComment.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET reply_count = reply_count + 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(commentID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to increment reply count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementReplyCount 減少回覆數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:reply_count 欄位必須是 counter 類型
|
|
||||||
func (r *CommentRepository) DecrementReplyCount(ctx context.Context, commentID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if commentID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroComment entity.Comment
|
|
||||||
tableName := zeroComment.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET reply_count = reply_count - 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(commentID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to decrement reply count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateStatus 更新評論狀態
|
|
||||||
func (r *CommentRepository) UpdateStatus(ctx context.Context, commentID gocql.UUID, status post.CommentStatus) error {
|
|
||||||
comment, err := r.FindOne(ctx, commentID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
comment.Status = status
|
|
||||||
return r.Update(ctx, comment)
|
|
||||||
}
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Common repository errors
|
|
||||||
var (
|
|
||||||
// ErrNotFound is returned when a requested resource is not found
|
|
||||||
ErrNotFound = errors.New("resource not found")
|
|
||||||
|
|
||||||
// ErrInvalidInput is returned when input validation fails
|
|
||||||
ErrInvalidInput = errors.New("invalid input")
|
|
||||||
|
|
||||||
// ErrDuplicateKey is returned when attempting to insert a document with a duplicate key
|
|
||||||
ErrDuplicateKey = errors.New("duplicate key error")
|
|
||||||
)
|
|
||||||
|
|
||||||
// IsNotFound checks if the error is a not found error
|
|
||||||
func IsNotFound(err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err == ErrNotFound {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,228 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
domainRepo "backend/pkg/post/domain/repository"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// LikeRepositoryParam 定義 LikeRepository 的初始化參數
|
|
||||||
type LikeRepositoryParam struct {
|
|
||||||
DB *cassandra.DB
|
|
||||||
Keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// LikeRepository 實作 domain repository 介面
|
|
||||||
type LikeRepository struct {
|
|
||||||
repo cassandra.Repository[*entity.Like]
|
|
||||||
db *cassandra.DB
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewLikeRepository 創建新的 LikeRepository
|
|
||||||
func NewLikeRepository(param LikeRepositoryParam) domainRepo.LikeRepository {
|
|
||||||
repo, err := cassandra.NewRepository[*entity.Like](param.DB, param.Keyspace)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("failed to create like repository: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return &LikeRepository{
|
|
||||||
repo: repo,
|
|
||||||
db: param.DB,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert 插入單筆按讚
|
|
||||||
func (r *LikeRepository) Insert(ctx context.Context, data *entity.Like) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 設置時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
// 如果是新按讚,生成 ID
|
|
||||||
if data.IsNew() {
|
|
||||||
data.ID = gocql.TimeUUID()
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.repo.Insert(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindOne 根據 ID 查詢單筆按讚
|
|
||||||
func (r *LikeRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Like, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if id == zeroUUID {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
like, err := r.repo.Get(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("failed to find like: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return like, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 刪除按讚
|
|
||||||
func (r *LikeRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if id == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.repo.Delete(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByTargetID 根據目標 ID 查詢按讚列表
|
|
||||||
func (r *LikeRepository) FindByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) ([]*entity.Like, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if targetID == zeroUUID {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetType != "post" && targetType != "comment" {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢
|
|
||||||
query := r.repo.Query().
|
|
||||||
Where(cassandra.Eq("target_id", targetID)).
|
|
||||||
Where(cassandra.Eq("target_type", targetType)).
|
|
||||||
OrderBy("created_at", cassandra.DESC)
|
|
||||||
|
|
||||||
var likes []*entity.Like
|
|
||||||
if err := query.Scan(ctx, &likes); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query likes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return likes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByUserUID 根據用戶 UID 查詢按讚列表
|
|
||||||
func (r *LikeRepository) FindByUserUID(ctx context.Context, userUID string, params *domainRepo.LikeQueryParams) ([]*entity.Like, int64, error) {
|
|
||||||
if userUID == "" {
|
|
||||||
return nil, 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("user_uid", userUID))
|
|
||||||
|
|
||||||
// 添加目標類型過濾
|
|
||||||
if params != nil && params.TargetType != nil {
|
|
||||||
query = query.Where(cassandra.Eq("target_type", *params.TargetType))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加目標 ID 過濾
|
|
||||||
if params != nil && params.TargetID != nil {
|
|
||||||
query = query.Where(cassandra.Eq("target_id", *params.TargetID))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加排序
|
|
||||||
orderBy := "created_at"
|
|
||||||
if params != nil && params.OrderBy != "" {
|
|
||||||
orderBy = params.OrderBy
|
|
||||||
}
|
|
||||||
order := cassandra.DESC
|
|
||||||
if params != nil && params.OrderDirection == "ASC" {
|
|
||||||
order = cassandra.ASC
|
|
||||||
}
|
|
||||||
query = query.OrderBy(orderBy, order)
|
|
||||||
|
|
||||||
// 添加分頁
|
|
||||||
pageSize := int64(20)
|
|
||||||
if params != nil && params.PageSize > 0 {
|
|
||||||
pageSize = params.PageSize
|
|
||||||
}
|
|
||||||
query = query.Limit(int(pageSize))
|
|
||||||
|
|
||||||
var likes []*entity.Like
|
|
||||||
if err := query.Scan(ctx, &likes); err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query likes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := likes
|
|
||||||
|
|
||||||
return result, int64(len(result)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByTargetAndUser 根據目標和用戶查詢按讚
|
|
||||||
func (r *LikeRepository) FindByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) (*entity.Like, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if targetID == zeroUUID || userUID == "" {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetType != "post" && targetType != "comment" {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢
|
|
||||||
query := r.repo.Query().
|
|
||||||
Where(cassandra.Eq("target_id", targetID)).
|
|
||||||
Where(cassandra.Eq("user_uid", userUID)).
|
|
||||||
Where(cassandra.Eq("target_type", targetType)).
|
|
||||||
Limit(1)
|
|
||||||
|
|
||||||
var likes []*entity.Like
|
|
||||||
if err := query.Scan(ctx, &likes); err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("failed to query like: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(likes) == 0 {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return likes[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CountByTargetID 計算目標的按讚數
|
|
||||||
func (r *LikeRepository) CountByTargetID(ctx context.Context, targetID gocql.UUID, targetType string) (int64, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if targetID == zeroUUID {
|
|
||||||
return 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetType != "post" && targetType != "comment" {
|
|
||||||
return 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢
|
|
||||||
query := r.repo.Query().
|
|
||||||
Where(cassandra.Eq("target_id", targetID)).
|
|
||||||
Where(cassandra.Eq("target_type", targetType))
|
|
||||||
|
|
||||||
count, err := query.Count(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("failed to count likes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteByTargetAndUser 根據目標和用戶刪除按讚
|
|
||||||
func (r *LikeRepository) DeleteByTargetAndUser(ctx context.Context, targetID gocql.UUID, userUID string, targetType string) error {
|
|
||||||
// 先查詢按讚
|
|
||||||
like, err := r.FindByTargetAndUser(ctx, targetID, userUID, targetType)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 刪除按讚
|
|
||||||
return r.Delete(ctx, like.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,511 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
domainRepo "backend/pkg/post/domain/repository"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PostRepositoryParam 定義 PostRepository 的初始化參數
|
|
||||||
type PostRepositoryParam struct {
|
|
||||||
DB *cassandra.DB
|
|
||||||
Keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostRepository 實作 domain repository 介面
|
|
||||||
type PostRepository struct {
|
|
||||||
repo cassandra.Repository[*entity.Post]
|
|
||||||
db *cassandra.DB
|
|
||||||
keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPostRepository 創建新的 PostRepository
|
|
||||||
func NewPostRepository(param PostRepositoryParam) domainRepo.PostRepository {
|
|
||||||
repo, err := cassandra.NewRepository[*entity.Post](param.DB, param.Keyspace)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("failed to create post repository: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
keyspace := param.Keyspace
|
|
||||||
if keyspace == "" {
|
|
||||||
keyspace = param.DB.GetDefaultKeyspace()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PostRepository{
|
|
||||||
repo: repo,
|
|
||||||
db: param.DB,
|
|
||||||
keyspace: keyspace,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert 插入單筆貼文
|
|
||||||
func (r *PostRepository) Insert(ctx context.Context, data *entity.Post) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 設置時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
// 如果是新貼文,生成 ID
|
|
||||||
if data.IsNew() {
|
|
||||||
data.ID = gocql.TimeUUID()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果狀態是 published,設置發布時間
|
|
||||||
if data.Status == post.PostStatusPublished && data.PublishedAt == nil {
|
|
||||||
now := data.CreatedAt
|
|
||||||
data.PublishedAt = &now
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.repo.Insert(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindOne 根據 ID 查詢單筆貼文
|
|
||||||
func (r *PostRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Post, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if id == zeroUUID {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
post, err := r.repo.Get(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("failed to find post: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return post, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update 更新貼文
|
|
||||||
func (r *PostRepository) Update(ctx context.Context, data *entity.Post) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
return r.repo.Update(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 刪除貼文(軟刪除)
|
|
||||||
func (r *PostRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if id == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 先查詢貼文
|
|
||||||
post, err := r.FindOne(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 軟刪除:標記為已刪除
|
|
||||||
post.Delete()
|
|
||||||
return r.Update(ctx, post)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByAuthorUID 根據作者 UID 查詢貼文
|
|
||||||
func (r *PostRepository) FindByAuthorUID(ctx context.Context, authorUID string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
|
||||||
if authorUID == "" {
|
|
||||||
return nil, 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("author_uid", authorUID))
|
|
||||||
|
|
||||||
// 添加狀態過濾
|
|
||||||
if params != nil && params.Status != nil {
|
|
||||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加排序
|
|
||||||
orderBy := "created_at"
|
|
||||||
if params != nil && params.OrderBy != "" {
|
|
||||||
orderBy = params.OrderBy
|
|
||||||
}
|
|
||||||
order := cassandra.DESC
|
|
||||||
if params != nil && params.OrderDirection == "ASC" {
|
|
||||||
order = cassandra.ASC
|
|
||||||
}
|
|
||||||
query = query.OrderBy(orderBy, order)
|
|
||||||
|
|
||||||
// 添加分頁
|
|
||||||
pageSize := int64(20)
|
|
||||||
if params != nil && params.PageSize > 0 {
|
|
||||||
pageSize = params.PageSize
|
|
||||||
}
|
|
||||||
pageIndex := int64(1)
|
|
||||||
if params != nil && params.PageIndex > 0 {
|
|
||||||
pageIndex = params.PageIndex
|
|
||||||
}
|
|
||||||
limit := int(pageSize)
|
|
||||||
query = query.Limit(limit)
|
|
||||||
|
|
||||||
// 執行查詢
|
|
||||||
var posts []*entity.Post
|
|
||||||
if err := query.Scan(ctx, &posts); err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := posts
|
|
||||||
|
|
||||||
// 計算總數(簡化實作,實際應該使用 COUNT 查詢)
|
|
||||||
total := int64(len(posts))
|
|
||||||
if params != nil && params.PageIndex > 1 {
|
|
||||||
// 這裡應該執行 COUNT 查詢,但為了簡化,我們假設有更多結果
|
|
||||||
total = pageSize * pageIndex
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByCategoryID 根據分類 ID 查詢貼文
|
|
||||||
func (r *PostRepository) FindByCategoryID(ctx context.Context, categoryID gocql.UUID, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if categoryID == zeroUUID {
|
|
||||||
return nil, 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("category_id", categoryID))
|
|
||||||
|
|
||||||
// 添加狀態過濾
|
|
||||||
if params != nil && params.Status != nil {
|
|
||||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加排序和分頁(類似 FindByAuthorUID)
|
|
||||||
orderBy := "created_at"
|
|
||||||
if params != nil && params.OrderBy != "" {
|
|
||||||
orderBy = params.OrderBy
|
|
||||||
}
|
|
||||||
order := cassandra.DESC
|
|
||||||
if params != nil && params.OrderDirection == "ASC" {
|
|
||||||
order = cassandra.ASC
|
|
||||||
}
|
|
||||||
query = query.OrderBy(orderBy, order)
|
|
||||||
|
|
||||||
pageSize := int64(20)
|
|
||||||
if params != nil && params.PageSize > 0 {
|
|
||||||
pageSize = params.PageSize
|
|
||||||
}
|
|
||||||
limit := int(pageSize)
|
|
||||||
query = query.Limit(limit)
|
|
||||||
|
|
||||||
var posts []*entity.Post
|
|
||||||
if err := query.Scan(ctx, &posts); err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := posts
|
|
||||||
|
|
||||||
total := int64(len(posts))
|
|
||||||
return result, total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByTag 根據標籤查詢貼文
|
|
||||||
func (r *PostRepository) FindByTag(ctx context.Context, tagName string, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
|
||||||
if tagName == "" {
|
|
||||||
return nil, 0, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢(注意:Cassandra 的集合查詢需要使用 CONTAINS,這裡簡化處理)
|
|
||||||
// 實際實作中,可能需要使用 SAI 索引或 Materialized View
|
|
||||||
query := r.repo.Query()
|
|
||||||
|
|
||||||
// 添加狀態過濾
|
|
||||||
if params != nil && params.Status != nil {
|
|
||||||
query = query.Where(cassandra.Eq("status", *params.Status))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 添加排序和分頁
|
|
||||||
orderBy := "created_at"
|
|
||||||
if params != nil && params.OrderBy != "" {
|
|
||||||
orderBy = params.OrderBy
|
|
||||||
}
|
|
||||||
order := cassandra.DESC
|
|
||||||
if params != nil && params.OrderDirection == "ASC" {
|
|
||||||
order = cassandra.ASC
|
|
||||||
}
|
|
||||||
query = query.OrderBy(orderBy, order)
|
|
||||||
|
|
||||||
pageSize := int64(20)
|
|
||||||
if params != nil && params.PageSize > 0 {
|
|
||||||
pageSize = params.PageSize
|
|
||||||
}
|
|
||||||
limit := int(pageSize)
|
|
||||||
query = query.Limit(limit)
|
|
||||||
|
|
||||||
var posts []*entity.Post
|
|
||||||
if err := query.Scan(ctx, &posts); err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 過濾包含指定標籤的貼文
|
|
||||||
filtered := make([]*entity.Post, 0)
|
|
||||||
for _, p := range posts {
|
|
||||||
for _, tag := range p.Tags {
|
|
||||||
if tag == tagName {
|
|
||||||
filtered = append(filtered, p)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
total := int64(len(filtered))
|
|
||||||
return filtered, total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindPinnedPosts 查詢置頂貼文
|
|
||||||
func (r *PostRepository) FindPinnedPosts(ctx context.Context, limit int64) ([]*entity.Post, error) {
|
|
||||||
query := r.repo.Query().
|
|
||||||
Where(cassandra.Eq("is_pinned", true)).
|
|
||||||
Where(cassandra.Eq("status", post.PostStatusPublished)).
|
|
||||||
OrderBy("pinned_at", cassandra.DESC).
|
|
||||||
Limit(int(limit))
|
|
||||||
|
|
||||||
var posts []*entity.Post
|
|
||||||
if err := query.Scan(ctx, &posts); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query pinned posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return posts, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByStatus 根據狀態查詢貼文
|
|
||||||
func (r *PostRepository) FindByStatus(ctx context.Context, status post.Status, params *domainRepo.PostQueryParams) ([]*entity.Post, int64, error) {
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("status", status))
|
|
||||||
|
|
||||||
// 添加排序和分頁
|
|
||||||
orderBy := "created_at"
|
|
||||||
if params != nil && params.OrderBy != "" {
|
|
||||||
orderBy = params.OrderBy
|
|
||||||
}
|
|
||||||
order := cassandra.DESC
|
|
||||||
if params != nil && params.OrderDirection == "ASC" {
|
|
||||||
order = cassandra.ASC
|
|
||||||
}
|
|
||||||
query = query.OrderBy(orderBy, order)
|
|
||||||
|
|
||||||
pageSize := int64(20)
|
|
||||||
if params != nil && params.PageSize > 0 {
|
|
||||||
pageSize = params.PageSize
|
|
||||||
}
|
|
||||||
limit := int(pageSize)
|
|
||||||
query = query.Limit(limit)
|
|
||||||
|
|
||||||
var posts []*entity.Post
|
|
||||||
if err := query.Scan(ctx, &posts); err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query posts: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := posts
|
|
||||||
|
|
||||||
total := int64(len(posts))
|
|
||||||
return result, total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementLikeCount 增加按讚數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:like_count 欄位必須是 counter 類型
|
|
||||||
func (r *PostRepository) IncrementLikeCount(ctx context.Context, postID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if postID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroPost entity.Post
|
|
||||||
tableName := zeroPost.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count + 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(postID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to increment like count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementLikeCount 減少按讚數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:like_count 欄位必須是 counter 類型
|
|
||||||
func (r *PostRepository) DecrementLikeCount(ctx context.Context, postID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if postID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroPost entity.Post
|
|
||||||
tableName := zeroPost.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET like_count = like_count - 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(postID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to decrement like count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementCommentCount 增加評論數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:comment_count 欄位必須是 counter 類型
|
|
||||||
func (r *PostRepository) IncrementCommentCount(ctx context.Context, postID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if postID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroPost entity.Post
|
|
||||||
tableName := zeroPost.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count + 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(postID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to increment comment count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementCommentCount 減少評論數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:comment_count 欄位必須是 counter 類型
|
|
||||||
func (r *PostRepository) DecrementCommentCount(ctx context.Context, postID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if postID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroPost entity.Post
|
|
||||||
tableName := zeroPost.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET comment_count = comment_count - 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(postID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to decrement comment count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementViewCount 增加瀏覽數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:view_count 欄位必須是 counter 類型
|
|
||||||
func (r *PostRepository) IncrementViewCount(ctx context.Context, postID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if postID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroPost entity.Post
|
|
||||||
tableName := zeroPost.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET view_count = view_count + 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(postID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to increment view count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateStatus 更新貼文狀態
|
|
||||||
func (r *PostRepository) UpdateStatus(ctx context.Context, postID gocql.UUID, status post.Status) error {
|
|
||||||
postEntity, err := r.FindOne(ctx, postID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
postEntity.Status = status
|
|
||||||
publishedStatus := post.PostStatusPublished
|
|
||||||
if status == publishedStatus && postEntity.PublishedAt == nil {
|
|
||||||
now := postEntity.UpdatedAt
|
|
||||||
postEntity.PublishedAt = &now
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.Update(ctx, postEntity)
|
|
||||||
}
|
|
||||||
|
|
||||||
// PinPost 置頂貼文
|
|
||||||
func (r *PostRepository) PinPost(ctx context.Context, postID gocql.UUID) error {
|
|
||||||
post, err := r.FindOne(ctx, postID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
post.Pin()
|
|
||||||
return r.Update(ctx, post)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnpinPost 取消置頂
|
|
||||||
func (r *PostRepository) UnpinPost(ctx context.Context, postID gocql.UUID) error {
|
|
||||||
post, err := r.FindOne(ctx, postID)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
post.Unpin()
|
|
||||||
return r.Update(ctx, post)
|
|
||||||
}
|
|
||||||
|
|
||||||
// calculateTotalPages 計算總頁數
|
|
||||||
func calculateTotalPages(total, pageSize int64) int64 {
|
|
||||||
if pageSize <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return int64(math.Ceil(float64(total) / float64(pageSize)))
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,250 +0,0 @@
|
||||||
package repository
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"backend/pkg/library/cassandra"
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
domainRepo "backend/pkg/post/domain/repository"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TagRepositoryParam 定義 TagRepository 的初始化參數
|
|
||||||
type TagRepositoryParam struct {
|
|
||||||
DB *cassandra.DB
|
|
||||||
Keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// TagRepository 實作 domain repository 介面
|
|
||||||
type TagRepository struct {
|
|
||||||
repo cassandra.Repository[*entity.Tag]
|
|
||||||
db *cassandra.DB
|
|
||||||
keyspace string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTagRepository 創建新的 TagRepository
|
|
||||||
func NewTagRepository(param TagRepositoryParam) domainRepo.TagRepository {
|
|
||||||
repo, err := cassandra.NewRepository[*entity.Tag](param.DB, param.Keyspace)
|
|
||||||
if err != nil {
|
|
||||||
panic(fmt.Sprintf("failed to create tag repository: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
keyspace := param.Keyspace
|
|
||||||
if keyspace == "" {
|
|
||||||
keyspace = param.DB.GetDefaultKeyspace()
|
|
||||||
}
|
|
||||||
|
|
||||||
return &TagRepository{
|
|
||||||
repo: repo,
|
|
||||||
db: param.DB,
|
|
||||||
keyspace: keyspace,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert 插入單筆標籤
|
|
||||||
func (r *TagRepository) Insert(ctx context.Context, data *entity.Tag) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 設置時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
// 如果是新標籤,生成 ID
|
|
||||||
if data.IsNew() {
|
|
||||||
data.ID = gocql.TimeUUID()
|
|
||||||
}
|
|
||||||
|
|
||||||
// 標籤名稱轉為小寫(統一格式)
|
|
||||||
data.Name = strings.ToLower(strings.TrimSpace(data.Name))
|
|
||||||
|
|
||||||
return r.repo.Insert(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindOne 根據 ID 查詢單筆標籤
|
|
||||||
func (r *TagRepository) FindOne(ctx context.Context, id gocql.UUID) (*entity.Tag, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if id == zeroUUID {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
tag, err := r.repo.Get(ctx, id)
|
|
||||||
if err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("failed to find tag: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tag, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update 更新標籤
|
|
||||||
func (r *TagRepository) Update(ctx context.Context, data *entity.Tag) error {
|
|
||||||
if data == nil {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證資料
|
|
||||||
if err := data.Validate(); err != nil {
|
|
||||||
return fmt.Errorf("%w: %v", ErrInvalidInput, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新時間戳
|
|
||||||
data.SetTimestamps()
|
|
||||||
|
|
||||||
// 標籤名稱轉為小寫
|
|
||||||
data.Name = strings.ToLower(strings.TrimSpace(data.Name))
|
|
||||||
|
|
||||||
return r.repo.Update(ctx, data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete 刪除標籤
|
|
||||||
func (r *TagRepository) Delete(ctx context.Context, id gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if id == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.repo.Delete(ctx, id)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByName 根據名稱查詢標籤
|
|
||||||
func (r *TagRepository) FindByName(ctx context.Context, name string) (*entity.Tag, error) {
|
|
||||||
if name == "" {
|
|
||||||
return nil, ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 標準化名稱
|
|
||||||
name = strings.ToLower(strings.TrimSpace(name))
|
|
||||||
|
|
||||||
// 構建查詢(假設有 SAI 索引在 name 欄位上)
|
|
||||||
query := r.repo.Query().Where(cassandra.Eq("name", name))
|
|
||||||
|
|
||||||
var tags []*entity.Tag
|
|
||||||
if err := query.Scan(ctx, &tags); err != nil {
|
|
||||||
if cassandra.IsNotFound(err) {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("failed to query tag: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(tags) == 0 {
|
|
||||||
return nil, ErrNotFound
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags[0], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindByNames 根據名稱列表查詢標籤
|
|
||||||
func (r *TagRepository) FindByNames(ctx context.Context, names []string) ([]*entity.Tag, error) {
|
|
||||||
if len(names) == 0 {
|
|
||||||
return []*entity.Tag{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 標準化名稱
|
|
||||||
normalizedNames := make([]string, len(names))
|
|
||||||
for i, name := range names {
|
|
||||||
normalizedNames[i] = strings.ToLower(strings.TrimSpace(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢(使用 IN 條件)
|
|
||||||
query := r.repo.Query().Where(cassandra.In("name", toAnySlice(normalizedNames)))
|
|
||||||
|
|
||||||
var tags []*entity.Tag
|
|
||||||
if err := query.Scan(ctx, &tags); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query tags: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return tags, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindPopular 查詢熱門標籤
|
|
||||||
func (r *TagRepository) FindPopular(ctx context.Context, limit int64) ([]*entity.Tag, error) {
|
|
||||||
// 構建查詢,按 post_count 降序排列
|
|
||||||
query := r.repo.Query().
|
|
||||||
OrderBy("post_count", cassandra.DESC).
|
|
||||||
Limit(int(limit))
|
|
||||||
|
|
||||||
var tags []*entity.Tag
|
|
||||||
if err := query.Scan(ctx, &tags); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query popular tags: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
result := tags
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncrementPostCount 增加貼文數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:post_count 欄位必須是 counter 類型
|
|
||||||
func (r *TagRepository) IncrementPostCount(ctx context.Context, tagID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if tagID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 counter 原子更新操作:UPDATE tags SET post_count = post_count + 1 WHERE id = ?
|
|
||||||
var zeroTag entity.Tag
|
|
||||||
tableName := zeroTag.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count + 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(tagID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to increment post count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrementPostCount 減少貼文數(使用 counter 原子操作避免競爭條件)
|
|
||||||
// 注意:post_count 欄位必須是 counter 類型
|
|
||||||
func (r *TagRepository) DecrementPostCount(ctx context.Context, tagID gocql.UUID) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if tagID == zeroUUID {
|
|
||||||
return ErrInvalidInput
|
|
||||||
}
|
|
||||||
|
|
||||||
// 使用 counter 原子更新操作:UPDATE tags SET post_count = post_count - 1 WHERE id = ?
|
|
||||||
var zeroTag entity.Tag
|
|
||||||
tableName := zeroTag.TableName()
|
|
||||||
if r.keyspace == "" {
|
|
||||||
return fmt.Errorf("%w: keyspace is required", ErrInvalidInput)
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt := fmt.Sprintf("UPDATE %s.%s SET post_count = post_count - 1 WHERE id = ?", r.keyspace, tableName)
|
|
||||||
query := r.db.GetSession().Query(stmt, nil).
|
|
||||||
WithContext(ctx).
|
|
||||||
Consistency(gocql.Quorum).
|
|
||||||
Bind(tagID)
|
|
||||||
|
|
||||||
if err := query.ExecRelease(); err != nil {
|
|
||||||
return fmt.Errorf("failed to decrement post count: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// toAnySlice 將 string slice 轉換為 []any
|
|
||||||
func toAnySlice(strs []string) []any {
|
|
||||||
result := make([]any, len(strs))
|
|
||||||
for i, s := range strs {
|
|
||||||
result[i] = s
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
@ -1,455 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
errs "backend/pkg/library/errors"
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
domainRepo "backend/pkg/post/domain/repository"
|
|
||||||
domainUsecase "backend/pkg/post/domain/usecase"
|
|
||||||
"backend/pkg/post/repository"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CommentUseCaseParam 定義 CommentUseCase 的初始化參數
|
|
||||||
type CommentUseCaseParam struct {
|
|
||||||
Comment domainRepo.CommentRepository
|
|
||||||
Post domainRepo.PostRepository
|
|
||||||
Like domainRepo.LikeRepository
|
|
||||||
Logger errs.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// CommentUseCase 實作 domain usecase 介面
|
|
||||||
type CommentUseCase struct {
|
|
||||||
CommentUseCaseParam
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustCommentUseCase 創建新的 CommentUseCase(如果失敗會 panic)
|
|
||||||
func MustCommentUseCase(param CommentUseCaseParam) domainUsecase.CommentUseCase {
|
|
||||||
return &CommentUseCase{
|
|
||||||
CommentUseCaseParam: param,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreateComment 創建新評論
|
|
||||||
func (uc *CommentUseCase) CreateComment(ctx context.Context, req domainUsecase.CreateCommentRequest) (*domainUsecase.CommentResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
if err := uc.validateCreateCommentRequest(req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證貼文存在
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
|
||||||
}
|
|
||||||
return nil, uc.handleDBError("Post.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查貼文是否可見
|
|
||||||
if !post.IsVisible() {
|
|
||||||
return nil, errs.ResNotFoundError("cannot comment on non-visible post")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立評論實體
|
|
||||||
comment := &entity.Comment{
|
|
||||||
PostID: req.PostID,
|
|
||||||
AuthorUID: req.AuthorUID,
|
|
||||||
ParentID: req.ParentID,
|
|
||||||
Content: req.Content,
|
|
||||||
Status: post.CommentStatusPublished,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 插入資料庫
|
|
||||||
if err := uc.Comment.Insert(ctx, comment); err != nil {
|
|
||||||
return nil, uc.handleDBError("Comment.Insert", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果是回覆,增加父評論的回覆數
|
|
||||||
if req.ParentID != nil {
|
|
||||||
if err := uc.Comment.IncrementReplyCount(ctx, *req.ParentID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to increment reply count: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 增加貼文的評論數
|
|
||||||
if err := uc.Post.IncrementCommentCount(ctx, req.PostID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to increment comment count: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.mapCommentToResponse(comment), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetComment 取得評論
|
|
||||||
func (uc *CommentUseCase) GetComment(ctx context.Context, req domainUsecase.GetCommentRequest) (*domainUsecase.CommentResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.CommentID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("comment_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢評論
|
|
||||||
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return nil, errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
|
|
||||||
}
|
|
||||||
return nil, uc.handleDBError("Comment.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.mapCommentToResponse(comment), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateComment 更新評論
|
|
||||||
func (uc *CommentUseCase) UpdateComment(ctx context.Context, req domainUsecase.UpdateCommentRequest) (*domainUsecase.CommentResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.CommentID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("comment_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
if req.Content == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("content is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢現有評論
|
|
||||||
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return nil, errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
|
|
||||||
}
|
|
||||||
return nil, uc.handleDBError("Comment.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證權限
|
|
||||||
if comment.AuthorUID != req.AuthorUID {
|
|
||||||
return nil, errs.ResNotFoundError("not authorized to update this comment")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查是否可見
|
|
||||||
if !comment.IsVisible() {
|
|
||||||
return nil, errs.ResNotFoundError("comment is not visible")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新內容
|
|
||||||
comment.Content = req.Content
|
|
||||||
|
|
||||||
// 更新資料庫
|
|
||||||
if err := uc.Comment.Update(ctx, comment); err != nil {
|
|
||||||
return nil, uc.handleDBError("Comment.Update", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.mapCommentToResponse(comment), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteComment 刪除評論(軟刪除)
|
|
||||||
func (uc *CommentUseCase) DeleteComment(ctx context.Context, req domainUsecase.DeleteCommentRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.CommentID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("comment_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢評論
|
|
||||||
comment, err := uc.Comment.FindOne(ctx, req.CommentID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return errs.ResNotFoundError(fmt.Sprintf("comment not found: %s", req.CommentID))
|
|
||||||
}
|
|
||||||
return uc.handleDBError("Comment.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證權限
|
|
||||||
if comment.AuthorUID != req.AuthorUID {
|
|
||||||
return errs.ResNotFoundError("not authorized to delete this comment")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 刪除評論
|
|
||||||
if err := uc.Comment.Delete(ctx, req.CommentID); err != nil {
|
|
||||||
return uc.handleDBError("Comment.Delete", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果是回覆,減少父評論的回覆數
|
|
||||||
if comment.ParentID != nil {
|
|
||||||
if err := uc.Comment.DecrementReplyCount(ctx, *comment.ParentID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to decrement reply count: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 減少貼文的評論數
|
|
||||||
if err := uc.Post.DecrementCommentCount(ctx, comment.PostID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to decrement comment count: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListComments 列出評論
|
|
||||||
func (uc *CommentUseCase) ListComments(ctx context.Context, req domainUsecase.ListCommentsRequest) (*domainUsecase.ListCommentsResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.PageSize <= 0 {
|
|
||||||
req.PageSize = 20
|
|
||||||
}
|
|
||||||
if req.PageIndex <= 0 {
|
|
||||||
req.PageIndex = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢參數
|
|
||||||
params := &domainRepo.CommentQueryParams{
|
|
||||||
PostID: &req.PostID,
|
|
||||||
ParentID: req.ParentID,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
OrderBy: req.OrderBy,
|
|
||||||
OrderDirection: req.OrderDirection,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果 OrderBy 未指定,預設為 created_at
|
|
||||||
if params.OrderBy == "" {
|
|
||||||
params.OrderBy = "created_at"
|
|
||||||
}
|
|
||||||
// 如果 OrderDirection 未指定,預設為 ASC
|
|
||||||
if params.OrderDirection == "" {
|
|
||||||
params.OrderDirection = "ASC"
|
|
||||||
}
|
|
||||||
|
|
||||||
// 執行查詢
|
|
||||||
comments, total, err := uc.Comment.FindByPostID(ctx, req.PostID, params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, uc.handleDBError("Comment.FindByPostID", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換為 Response
|
|
||||||
responses := make([]domainUsecase.CommentResponse, len(comments))
|
|
||||||
for i, c := range comments {
|
|
||||||
responses[i] = *uc.mapCommentToResponse(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domainUsecase.ListCommentsResponse{
|
|
||||||
Data: responses,
|
|
||||||
Page: domainUsecase.Pager{
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
Total: total,
|
|
||||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListReplies 列出回覆
|
|
||||||
func (uc *CommentUseCase) ListReplies(ctx context.Context, req domainUsecase.ListRepliesRequest) (*domainUsecase.ListCommentsResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.CommentID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("comment_id is required")
|
|
||||||
}
|
|
||||||
if req.PageSize <= 0 {
|
|
||||||
req.PageSize = 20
|
|
||||||
}
|
|
||||||
if req.PageIndex <= 0 {
|
|
||||||
req.PageIndex = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢參數
|
|
||||||
params := &domainRepo.CommentQueryParams{
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
OrderBy: "created_at",
|
|
||||||
OrderDirection: "ASC",
|
|
||||||
}
|
|
||||||
|
|
||||||
// 執行查詢
|
|
||||||
comments, total, err := uc.Comment.FindReplies(ctx, req.CommentID, params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, uc.handleDBError("Comment.FindReplies", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換為 Response
|
|
||||||
responses := make([]domainUsecase.CommentResponse, len(comments))
|
|
||||||
for i, c := range comments {
|
|
||||||
responses[i] = *uc.mapCommentToResponse(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domainUsecase.ListCommentsResponse{
|
|
||||||
Data: responses,
|
|
||||||
Page: domainUsecase.Pager{
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
Total: total,
|
|
||||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListCommentsByAuthor 根據作者列出評論
|
|
||||||
func (uc *CommentUseCase) ListCommentsByAuthor(ctx context.Context, req domainUsecase.ListCommentsByAuthorRequest) (*domainUsecase.ListCommentsResponse, error) {
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
if req.PageSize <= 0 {
|
|
||||||
req.PageSize = 20
|
|
||||||
}
|
|
||||||
if req.PageIndex <= 0 {
|
|
||||||
req.PageIndex = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
params := &domainRepo.CommentQueryParams{
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
OrderBy: "created_at",
|
|
||||||
OrderDirection: "DESC",
|
|
||||||
}
|
|
||||||
|
|
||||||
comments, total, err := uc.Comment.FindByAuthorUID(ctx, req.AuthorUID, params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, uc.handleDBError("Comment.FindByAuthorUID", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := make([]domainUsecase.CommentResponse, len(comments))
|
|
||||||
for i, c := range comments {
|
|
||||||
responses[i] = *uc.mapCommentToResponse(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domainUsecase.ListCommentsResponse{
|
|
||||||
Data: responses,
|
|
||||||
Page: domainUsecase.Pager{
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
Total: total,
|
|
||||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LikeComment 按讚評論
|
|
||||||
func (uc *CommentUseCase) LikeComment(ctx context.Context, req domainUsecase.LikeCommentRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.CommentID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("comment_id is required")
|
|
||||||
}
|
|
||||||
if req.UserUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("user_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查是否已經按讚
|
|
||||||
existingLike, err := uc.Like.FindByTargetAndUser(ctx, req.CommentID, req.UserUID, "comment")
|
|
||||||
if err == nil && existingLike != nil {
|
|
||||||
// 已經按讚,直接返回成功
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil && !repository.IsNotFound(err) {
|
|
||||||
return uc.handleDBError("Like.FindByTargetAndUser", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立按讚記錄
|
|
||||||
like := &entity.Like{
|
|
||||||
TargetID: req.CommentID,
|
|
||||||
UserUID: req.UserUID,
|
|
||||||
TargetType: "comment",
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.Like.Insert(ctx, like); err != nil {
|
|
||||||
return uc.handleDBError("Like.Insert", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 增加評論的按讚數
|
|
||||||
if err := uc.Comment.IncrementLikeCount(ctx, req.CommentID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to increment like count: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnlikeComment 取消按讚評論
|
|
||||||
func (uc *CommentUseCase) UnlikeComment(ctx context.Context, req domainUsecase.UnlikeCommentRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.CommentID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("comment_id is required")
|
|
||||||
}
|
|
||||||
if req.UserUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("user_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 刪除按讚記錄
|
|
||||||
if err := uc.Like.DeleteByTargetAndUser(ctx, req.CommentID, req.UserUID, "comment"); err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
// 已經取消按讚,直接返回成功
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return uc.handleDBError("Like.DeleteByTargetAndUser", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 減少評論的按讚數
|
|
||||||
if err := uc.Comment.DecrementLikeCount(ctx, req.CommentID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to decrement like count: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateCreateCommentRequest 驗證建立評論請求
|
|
||||||
func (uc *CommentUseCase) validateCreateCommentRequest(req domainUsecase.CreateCommentRequest) error {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
if req.Content == "" {
|
|
||||||
return errs.InputInvalidRangeError("content is required")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// mapCommentToResponse 將 Comment 實體轉換為 CommentResponse
|
|
||||||
func (uc *CommentUseCase) mapCommentToResponse(comment *entity.Comment) *domainUsecase.CommentResponse {
|
|
||||||
return &domainUsecase.CommentResponse{
|
|
||||||
ID: comment.ID,
|
|
||||||
PostID: comment.PostID,
|
|
||||||
AuthorUID: comment.AuthorUID,
|
|
||||||
ParentID: comment.ParentID,
|
|
||||||
Content: comment.Content,
|
|
||||||
Status: comment.Status,
|
|
||||||
LikeCount: comment.LikeCount,
|
|
||||||
ReplyCount: comment.ReplyCount,
|
|
||||||
CreatedAt: comment.CreatedAt,
|
|
||||||
UpdatedAt: comment.UpdatedAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleDBError 處理資料庫錯誤
|
|
||||||
func (uc *CommentUseCase) handleDBError(funcName string, req any, err error) error {
|
|
||||||
return errs.DBErrorErrorL(
|
|
||||||
uc.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "func", Val: funcName},
|
|
||||||
{Key: "req", Val: req},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
fmt.Sprintf("database operation failed: %s", funcName),
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -1,801 +0,0 @@
|
||||||
package usecase
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math"
|
|
||||||
|
|
||||||
errs "backend/pkg/library/errors"
|
|
||||||
"backend/pkg/post/domain/entity"
|
|
||||||
"backend/pkg/post/domain/post"
|
|
||||||
domainRepo "backend/pkg/post/domain/repository"
|
|
||||||
domainUsecase "backend/pkg/post/domain/usecase"
|
|
||||||
"backend/pkg/post/repository"
|
|
||||||
|
|
||||||
"github.com/gocql/gocql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PostUseCaseParam 定義 PostUseCase 的初始化參數
|
|
||||||
type PostUseCaseParam struct {
|
|
||||||
Post domainRepo.PostRepository
|
|
||||||
Comment domainRepo.CommentRepository
|
|
||||||
Like domainRepo.LikeRepository
|
|
||||||
Tag domainRepo.TagRepository
|
|
||||||
Category domainRepo.CategoryRepository
|
|
||||||
Logger errs.Logger
|
|
||||||
}
|
|
||||||
|
|
||||||
// PostUseCase 實作 domain usecase 介面
|
|
||||||
type PostUseCase struct {
|
|
||||||
PostUseCaseParam
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustPostUseCase 創建新的 PostUseCase(如果失敗會 panic)
|
|
||||||
func MustPostUseCase(param PostUseCaseParam) domainUsecase.PostUseCase {
|
|
||||||
return &PostUseCase{
|
|
||||||
PostUseCaseParam: param,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CreatePost 創建新貼文
|
|
||||||
func (uc *PostUseCase) CreatePost(ctx context.Context, req domainUsecase.CreatePostRequest) (*domainUsecase.PostResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
if err := uc.validateCreatePostRequest(req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立貼文實體
|
|
||||||
post := &entity.Post{
|
|
||||||
AuthorUID: req.AuthorUID,
|
|
||||||
Title: req.Title,
|
|
||||||
Content: req.Content,
|
|
||||||
Type: req.Type,
|
|
||||||
CategoryID: req.CategoryID,
|
|
||||||
Tags: req.Tags,
|
|
||||||
Images: req.Images,
|
|
||||||
VideoURL: req.VideoURL,
|
|
||||||
LinkURL: req.LinkURL,
|
|
||||||
Status: req.Status,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果狀態未指定,預設為草稿
|
|
||||||
if post.Status == 0 {
|
|
||||||
post.Status = post.PostStatusDraft
|
|
||||||
}
|
|
||||||
|
|
||||||
// 插入資料庫
|
|
||||||
if err := uc.Post.Insert(ctx, post); err != nil {
|
|
||||||
return nil, uc.handleDBError("Post.Insert", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 處理標籤(更新標籤的貼文數)
|
|
||||||
if err := uc.updateTagPostCounts(ctx, req.Tags, true); err != nil {
|
|
||||||
// 記錄錯誤但不中斷流程
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 處理分類(更新分類的貼文數)
|
|
||||||
if req.CategoryID != nil {
|
|
||||||
if err := uc.Category.IncrementPostCount(ctx, *req.CategoryID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to increment category post count: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.mapPostToResponse(post), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPost 取得貼文
|
|
||||||
func (uc *PostUseCase) GetPost(ctx context.Context, req domainUsecase.GetPostRequest) (*domainUsecase.PostResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢貼文
|
|
||||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
|
||||||
}
|
|
||||||
return nil, uc.handleDBError("Post.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果提供了 UserUID,增加瀏覽數
|
|
||||||
if req.UserUID != nil {
|
|
||||||
if err := uc.Post.IncrementViewCount(ctx, req.PostID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to increment view count: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.mapPostToResponse(post), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePost 更新貼文
|
|
||||||
func (uc *PostUseCase) UpdatePost(ctx context.Context, req domainUsecase.UpdatePostRequest) (*domainUsecase.PostResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢現有貼文
|
|
||||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
|
||||||
}
|
|
||||||
return nil, uc.handleDBError("Post.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證權限
|
|
||||||
if post.AuthorUID != req.AuthorUID {
|
|
||||||
return nil, errs.ResNotFoundError("not authorized to update this post")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查是否可編輯
|
|
||||||
if !post.IsEditable() {
|
|
||||||
return nil, errs.ResNotFoundError("post is not editable")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新欄位
|
|
||||||
if req.Title != nil {
|
|
||||||
post.Title = *req.Title
|
|
||||||
}
|
|
||||||
if req.Content != nil {
|
|
||||||
post.Content = *req.Content
|
|
||||||
}
|
|
||||||
if req.Type != nil {
|
|
||||||
post.Type = *req.Type
|
|
||||||
}
|
|
||||||
if req.CategoryID != nil {
|
|
||||||
// 更新分類計數
|
|
||||||
if post.CategoryID != nil && *post.CategoryID != *req.CategoryID {
|
|
||||||
if err := uc.Category.DecrementPostCount(ctx, *post.CategoryID); err != nil {
|
|
||||||
uc.Logger.Error("failed to decrement category post count", errs.LogField{Key: "error", Val: err.Error()})
|
|
||||||
}
|
|
||||||
if err := uc.Category.IncrementPostCount(ctx, *req.CategoryID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to increment category post count: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
post.CategoryID = req.CategoryID
|
|
||||||
}
|
|
||||||
if req.Tags != nil {
|
|
||||||
// 更新標籤計數
|
|
||||||
oldTags := post.Tags
|
|
||||||
post.Tags = req.Tags
|
|
||||||
if err := uc.updateTagPostCountsDiff(ctx, oldTags, req.Tags); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if req.Images != nil {
|
|
||||||
post.Images = req.Images
|
|
||||||
}
|
|
||||||
if req.VideoURL != nil {
|
|
||||||
post.VideoURL = req.VideoURL
|
|
||||||
}
|
|
||||||
if req.LinkURL != nil {
|
|
||||||
post.LinkURL = req.LinkURL
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新資料庫
|
|
||||||
if err := uc.Post.Update(ctx, post); err != nil {
|
|
||||||
return nil, uc.handleDBError("Post.Update", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.mapPostToResponse(post), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeletePost 刪除貼文(軟刪除)
|
|
||||||
func (uc *PostUseCase) DeletePost(ctx context.Context, req domainUsecase.DeletePostRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢貼文
|
|
||||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
|
||||||
}
|
|
||||||
return uc.handleDBError("Post.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證權限
|
|
||||||
if post.AuthorUID != req.AuthorUID {
|
|
||||||
return errs.ResNotFoundError("not authorized to delete this post")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 刪除貼文
|
|
||||||
if err := uc.Post.Delete(ctx, req.PostID); err != nil {
|
|
||||||
return uc.handleDBError("Post.Delete", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新標籤和分類計數
|
|
||||||
if len(post.Tags) > 0 {
|
|
||||||
if err := uc.updateTagPostCounts(ctx, post.Tags, false); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to update tag post counts: %v", err))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if post.CategoryID != nil {
|
|
||||||
if err := uc.Category.DecrementPostCount(ctx, *post.CategoryID); err != nil {
|
|
||||||
uc.Logger.Error("failed to decrement category post count", errs.LogField{Key: "error", Val: err.Error()})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublishPost 發布貼文
|
|
||||||
func (uc *PostUseCase) PublishPost(ctx context.Context, req domainUsecase.PublishPostRequest) (*domainUsecase.PostResponse, error) {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢貼文
|
|
||||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return nil, errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
|
||||||
}
|
|
||||||
return nil, uc.handleDBError("Post.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證權限
|
|
||||||
if post.AuthorUID != req.AuthorUID {
|
|
||||||
return nil, errs.ResNotFoundError("not authorized to publish this post")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 發布貼文
|
|
||||||
post.Publish()
|
|
||||||
|
|
||||||
// 更新資料庫
|
|
||||||
if err := uc.Post.Update(ctx, post); err != nil {
|
|
||||||
return nil, uc.handleDBError("Post.Update", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return uc.mapPostToResponse(post), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ArchivePost 歸檔貼文
|
|
||||||
func (uc *PostUseCase) ArchivePost(ctx context.Context, req domainUsecase.ArchivePostRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢貼文
|
|
||||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
|
||||||
}
|
|
||||||
return uc.handleDBError("Post.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證權限
|
|
||||||
if post.AuthorUID != req.AuthorUID {
|
|
||||||
return errs.ResNotFoundError("not authorized to archive this post")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 歸檔貼文
|
|
||||||
post.Archive()
|
|
||||||
|
|
||||||
// 更新資料庫
|
|
||||||
return uc.Post.Update(ctx, post)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPosts 列出貼文
|
|
||||||
func (uc *PostUseCase) ListPosts(ctx context.Context, req domainUsecase.ListPostsRequest) (*domainUsecase.ListPostsResponse, error) {
|
|
||||||
// 驗證分頁參數
|
|
||||||
if req.PageSize <= 0 {
|
|
||||||
req.PageSize = 20
|
|
||||||
}
|
|
||||||
if req.PageIndex <= 0 {
|
|
||||||
req.PageIndex = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// 構建查詢參數
|
|
||||||
params := &domainRepo.PostQueryParams{
|
|
||||||
CategoryID: req.CategoryID,
|
|
||||||
Tag: req.Tag,
|
|
||||||
Status: req.Status,
|
|
||||||
Type: req.Type,
|
|
||||||
AuthorUID: req.AuthorUID,
|
|
||||||
CreateStartTime: req.CreateStartTime,
|
|
||||||
CreateEndTime: req.CreateEndTime,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
OrderBy: req.OrderBy,
|
|
||||||
OrderDirection: req.OrderDirection,
|
|
||||||
}
|
|
||||||
|
|
||||||
// 執行查詢
|
|
||||||
var posts []*entity.Post
|
|
||||||
var total int64
|
|
||||||
var err error
|
|
||||||
|
|
||||||
if req.CategoryID != nil {
|
|
||||||
posts, total, err = uc.Post.FindByCategoryID(ctx, *req.CategoryID, params)
|
|
||||||
} else if req.Tag != nil {
|
|
||||||
posts, total, err = uc.Post.FindByTag(ctx, *req.Tag, params)
|
|
||||||
} else if req.AuthorUID != nil {
|
|
||||||
posts, total, err = uc.Post.FindByAuthorUID(ctx, *req.AuthorUID, params)
|
|
||||||
} else if req.Status != nil {
|
|
||||||
posts, total, err = uc.Post.FindByStatus(ctx, *req.Status, params)
|
|
||||||
} else {
|
|
||||||
// 預設查詢所有已發布的貼文
|
|
||||||
published := post.PostStatusPublished
|
|
||||||
params.Status = &published
|
|
||||||
posts, total, err = uc.Post.FindByStatus(ctx, published, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, uc.handleDBError("Post.FindBy*", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 轉換為 Response
|
|
||||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
|
||||||
for i, p := range posts {
|
|
||||||
responses[i] = *uc.mapPostToResponse(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domainUsecase.ListPostsResponse{
|
|
||||||
Data: responses,
|
|
||||||
Page: domainUsecase.Pager{
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
Total: total,
|
|
||||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPostsByAuthor 根據作者列出貼文
|
|
||||||
func (uc *PostUseCase) ListPostsByAuthor(ctx context.Context, req domainUsecase.ListPostsByAuthorRequest) (*domainUsecase.ListPostsResponse, error) {
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
if req.PageSize <= 0 {
|
|
||||||
req.PageSize = 20
|
|
||||||
}
|
|
||||||
if req.PageIndex <= 0 {
|
|
||||||
req.PageIndex = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
params := &domainRepo.PostQueryParams{
|
|
||||||
Status: req.Status,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
OrderBy: "created_at",
|
|
||||||
OrderDirection: "DESC",
|
|
||||||
}
|
|
||||||
|
|
||||||
posts, total, err := uc.Post.FindByAuthorUID(ctx, req.AuthorUID, params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, uc.handleDBError("Post.FindByAuthorUID", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
|
||||||
for i, p := range posts {
|
|
||||||
responses[i] = *uc.mapPostToResponse(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domainUsecase.ListPostsResponse{
|
|
||||||
Data: responses,
|
|
||||||
Page: domainUsecase.Pager{
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
Total: total,
|
|
||||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPostsByCategory 根據分類列出貼文
|
|
||||||
func (uc *PostUseCase) ListPostsByCategory(ctx context.Context, req domainUsecase.ListPostsByCategoryRequest) (*domainUsecase.ListPostsResponse, error) {
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.CategoryID == zeroUUID {
|
|
||||||
return nil, errs.InputInvalidRangeError("category_id is required")
|
|
||||||
}
|
|
||||||
if req.PageSize <= 0 {
|
|
||||||
req.PageSize = 20
|
|
||||||
}
|
|
||||||
if req.PageIndex <= 0 {
|
|
||||||
req.PageIndex = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
params := &domainRepo.PostQueryParams{
|
|
||||||
Status: req.Status,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
OrderBy: "created_at",
|
|
||||||
OrderDirection: "DESC",
|
|
||||||
}
|
|
||||||
|
|
||||||
posts, total, err := uc.Post.FindByCategoryID(ctx, req.CategoryID, params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, uc.handleDBError("Post.FindByCategoryID", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
|
||||||
for i, p := range posts {
|
|
||||||
responses[i] = *uc.mapPostToResponse(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domainUsecase.ListPostsResponse{
|
|
||||||
Data: responses,
|
|
||||||
Page: domainUsecase.Pager{
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
Total: total,
|
|
||||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListPostsByTag 根據標籤列出貼文
|
|
||||||
func (uc *PostUseCase) ListPostsByTag(ctx context.Context, req domainUsecase.ListPostsByTagRequest) (*domainUsecase.ListPostsResponse, error) {
|
|
||||||
if req.Tag == "" {
|
|
||||||
return nil, errs.InputInvalidRangeError("tag is required")
|
|
||||||
}
|
|
||||||
if req.PageSize <= 0 {
|
|
||||||
req.PageSize = 20
|
|
||||||
}
|
|
||||||
if req.PageIndex <= 0 {
|
|
||||||
req.PageIndex = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
params := &domainRepo.PostQueryParams{
|
|
||||||
Status: req.Status,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
OrderBy: "created_at",
|
|
||||||
OrderDirection: "DESC",
|
|
||||||
}
|
|
||||||
|
|
||||||
posts, total, err := uc.Post.FindByTag(ctx, req.Tag, params)
|
|
||||||
if err != nil {
|
|
||||||
return nil, uc.handleDBError("Post.FindByTag", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
|
||||||
for i, p := range posts {
|
|
||||||
responses[i] = *uc.mapPostToResponse(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domainUsecase.ListPostsResponse{
|
|
||||||
Data: responses,
|
|
||||||
Page: domainUsecase.Pager{
|
|
||||||
PageIndex: req.PageIndex,
|
|
||||||
PageSize: req.PageSize,
|
|
||||||
Total: total,
|
|
||||||
TotalPage: calculateTotalPages(total, req.PageSize),
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPinnedPosts 取得置頂貼文
|
|
||||||
func (uc *PostUseCase) GetPinnedPosts(ctx context.Context, req domainUsecase.GetPinnedPostsRequest) (*domainUsecase.ListPostsResponse, error) {
|
|
||||||
limit := int64(10)
|
|
||||||
if req.Limit > 0 {
|
|
||||||
limit = req.Limit
|
|
||||||
}
|
|
||||||
|
|
||||||
posts, err := uc.Post.FindPinnedPosts(ctx, limit)
|
|
||||||
if err != nil {
|
|
||||||
return nil, uc.handleDBError("Post.FindPinnedPosts", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
responses := make([]domainUsecase.PostResponse, len(posts))
|
|
||||||
for i, p := range posts {
|
|
||||||
responses[i] = *uc.mapPostToResponse(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &domainUsecase.ListPostsResponse{
|
|
||||||
Data: responses,
|
|
||||||
Page: domainUsecase.Pager{
|
|
||||||
PageIndex: 1,
|
|
||||||
PageSize: limit,
|
|
||||||
Total: int64(len(responses)),
|
|
||||||
TotalPage: 1,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LikePost 按讚貼文
|
|
||||||
func (uc *PostUseCase) LikePost(ctx context.Context, req domainUsecase.LikePostRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.UserUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("user_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 檢查是否已經按讚
|
|
||||||
existingLike, err := uc.Like.FindByTargetAndUser(ctx, req.PostID, req.UserUID, "post")
|
|
||||||
if err == nil && existingLike != nil {
|
|
||||||
// 已經按讚,直接返回成功
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if err != nil && !repository.IsNotFound(err) {
|
|
||||||
return uc.handleDBError("Like.FindByTargetAndUser", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 建立按讚記錄
|
|
||||||
like := &entity.Like{
|
|
||||||
TargetID: req.PostID,
|
|
||||||
UserUID: req.UserUID,
|
|
||||||
TargetType: "post",
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := uc.Like.Insert(ctx, like); err != nil {
|
|
||||||
return uc.handleDBError("Like.Insert", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 增加貼文的按讚數
|
|
||||||
if err := uc.Post.IncrementLikeCount(ctx, req.PostID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to increment like count: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnlikePost 取消按讚
|
|
||||||
func (uc *PostUseCase) UnlikePost(ctx context.Context, req domainUsecase.UnlikePostRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.UserUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("user_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 刪除按讚記錄
|
|
||||||
if err := uc.Like.DeleteByTargetAndUser(ctx, req.PostID, req.UserUID, "post"); err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
// 已經取消按讚,直接返回成功
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return uc.handleDBError("Like.DeleteByTargetAndUser", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 減少貼文的按讚數
|
|
||||||
if err := uc.Post.DecrementLikeCount(ctx, req.PostID); err != nil {
|
|
||||||
uc.Logger.Error(fmt.Sprintf("failed to decrement like count: %v", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ViewPost 瀏覽貼文(增加瀏覽數)
|
|
||||||
func (uc *PostUseCase) ViewPost(ctx context.Context, req domainUsecase.ViewPostRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 增加瀏覽數
|
|
||||||
if err := uc.Post.IncrementViewCount(ctx, req.PostID); err != nil {
|
|
||||||
return uc.handleDBError("Post.IncrementViewCount", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PinPost 置頂貼文
|
|
||||||
func (uc *PostUseCase) PinPost(ctx context.Context, req domainUsecase.PinPostRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢貼文
|
|
||||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
|
||||||
}
|
|
||||||
return uc.handleDBError("Post.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證權限
|
|
||||||
if post.AuthorUID != req.AuthorUID {
|
|
||||||
return errs.ResNotFoundError("not authorized to pin this post")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 置頂貼文
|
|
||||||
return uc.Post.PinPost(ctx, req.PostID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UnpinPost 取消置頂
|
|
||||||
func (uc *PostUseCase) UnpinPost(ctx context.Context, req domainUsecase.UnpinPostRequest) error {
|
|
||||||
// 驗證輸入
|
|
||||||
var zeroUUID gocql.UUID
|
|
||||||
if req.PostID == zeroUUID {
|
|
||||||
return errs.InputInvalidRangeError("post_id is required")
|
|
||||||
}
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢貼文
|
|
||||||
post, err := uc.Post.FindOne(ctx, req.PostID)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
return errs.ResNotFoundError(fmt.Sprintf("post not found: %s", req.PostID))
|
|
||||||
}
|
|
||||||
return uc.handleDBError("Post.FindOne", req, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 驗證權限
|
|
||||||
if post.AuthorUID != req.AuthorUID {
|
|
||||||
return errs.ResNotFoundError("not authorized to unpin this post")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 取消置頂
|
|
||||||
return uc.Post.UnpinPost(ctx, req.PostID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// validateCreatePostRequest 驗證建立貼文請求
|
|
||||||
func (uc *PostUseCase) validateCreatePostRequest(req domainUsecase.CreatePostRequest) error {
|
|
||||||
if req.AuthorUID == "" {
|
|
||||||
return errs.InputInvalidRangeError("author_uid is required")
|
|
||||||
}
|
|
||||||
if req.Title == "" {
|
|
||||||
return errs.InputInvalidRangeError("title is required")
|
|
||||||
}
|
|
||||||
if req.Content == "" {
|
|
||||||
return errs.InputInvalidRangeError("content is required")
|
|
||||||
}
|
|
||||||
if !req.Type.IsValid() {
|
|
||||||
return errs.InputInvalidRangeError("invalid post type")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// mapPostToResponse 將 Post 實體轉換為 PostResponse
|
|
||||||
func (uc *PostUseCase) mapPostToResponse(post *entity.Post) *domainUsecase.PostResponse {
|
|
||||||
return &domainUsecase.PostResponse{
|
|
||||||
ID: post.ID,
|
|
||||||
AuthorUID: post.AuthorUID,
|
|
||||||
Title: post.Title,
|
|
||||||
Content: post.Content,
|
|
||||||
Type: post.Type,
|
|
||||||
Status: post.Status,
|
|
||||||
CategoryID: post.CategoryID,
|
|
||||||
Tags: post.Tags,
|
|
||||||
Images: post.Images,
|
|
||||||
VideoURL: post.VideoURL,
|
|
||||||
LinkURL: post.LinkURL,
|
|
||||||
LikeCount: post.LikeCount,
|
|
||||||
CommentCount: post.CommentCount,
|
|
||||||
ViewCount: post.ViewCount,
|
|
||||||
IsPinned: post.IsPinned,
|
|
||||||
PinnedAt: post.PinnedAt,
|
|
||||||
PublishedAt: post.PublishedAt,
|
|
||||||
CreatedAt: post.CreatedAt,
|
|
||||||
UpdatedAt: post.UpdatedAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleDBError 處理資料庫錯誤
|
|
||||||
func (uc *PostUseCase) handleDBError(funcName string, req any, err error) error {
|
|
||||||
return errs.DBErrorErrorL(
|
|
||||||
uc.Logger,
|
|
||||||
[]errs.LogField{
|
|
||||||
{Key: "func", Val: funcName},
|
|
||||||
{Key: "req", Val: req},
|
|
||||||
{Key: "error", Val: err.Error()},
|
|
||||||
},
|
|
||||||
fmt.Sprintf("database operation failed: %s", funcName),
|
|
||||||
).Wrap(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateTagPostCounts 更新標籤的貼文數
|
|
||||||
func (uc *PostUseCase) updateTagPostCounts(ctx context.Context, tags []string, increment bool) error {
|
|
||||||
if len(tags) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// 查詢或建立標籤
|
|
||||||
for _, tagName := range tags {
|
|
||||||
tag, err := uc.Tag.FindByName(ctx, tagName)
|
|
||||||
if err != nil {
|
|
||||||
if repository.IsNotFound(err) {
|
|
||||||
// 建立新標籤
|
|
||||||
newTag := &entity.Tag{
|
|
||||||
Name: tagName,
|
|
||||||
}
|
|
||||||
if err := uc.Tag.Insert(ctx, newTag); err != nil {
|
|
||||||
return fmt.Errorf("failed to create tag: %w", err)
|
|
||||||
}
|
|
||||||
tag = newTag
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("failed to find tag: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 更新計數
|
|
||||||
if increment {
|
|
||||||
if err := uc.Tag.IncrementPostCount(ctx, tag.ID); err != nil {
|
|
||||||
return fmt.Errorf("failed to increment tag count: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if err := uc.Tag.DecrementPostCount(ctx, tag.ID); err != nil {
|
|
||||||
return fmt.Errorf("failed to decrement tag count: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateTagPostCountsDiff 更新標籤計數(處理差異)
|
|
||||||
func (uc *PostUseCase) updateTagPostCountsDiff(ctx context.Context, oldTags, newTags []string) error {
|
|
||||||
// 找出新增和刪除的標籤
|
|
||||||
oldTagMap := make(map[string]bool)
|
|
||||||
for _, tag := range oldTags {
|
|
||||||
oldTagMap[tag] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
newTagMap := make(map[string]bool)
|
|
||||||
for _, tag := range newTags {
|
|
||||||
newTagMap[tag] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// 新增的標籤
|
|
||||||
for _, tag := range newTags {
|
|
||||||
if !oldTagMap[tag] {
|
|
||||||
if err := uc.updateTagPostCounts(ctx, []string{tag}, true); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 刪除的標籤
|
|
||||||
for _, tag := range oldTags {
|
|
||||||
if !newTagMap[tag] {
|
|
||||||
if err := uc.updateTagPostCounts(ctx, []string{tag}, false); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// calculateTotalPages 計算總頁數
|
|
||||||
func calculateTotalPages(total, pageSize int64) int64 {
|
|
||||||
if pageSize <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return int64(math.Ceil(float64(total) / float64(pageSize)))
|
|
||||||
}
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue