121 lines
2.7 KiB
Go
121 lines
2.7 KiB
Go
package strategy
|
||
|
||
import (
|
||
"sync"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/shopspring/decimal"
|
||
)
|
||
|
||
// d converts an integer price into a decimal.Decimal (exponent 0) so test
// tables can be written with plain integer literals.
func d(i int64) decimal.Decimal {
	return decimal.New(i, 0)
}
|
||
|
||
func TestSMA_WarmupAndSliding(t *testing.T) {
|
||
type step struct {
|
||
in int64
|
||
wantVal string // 用字串比對可避免浮點誤差(decimal 本就精準)
|
||
ready bool
|
||
}
|
||
tests := []struct {
|
||
name string
|
||
window uint
|
||
steps []step
|
||
}{
|
||
{
|
||
name: "warmup_then_ready_and_slide",
|
||
window: 3,
|
||
steps: []step{
|
||
{in: 1, wantVal: "1", ready: false}, // [1] avg=1
|
||
{in: 2, wantVal: "1.5", ready: false}, // [1,2] avg=1.5
|
||
{in: 3, wantVal: "2", ready: true}, // [1,2,3] avg=2
|
||
{in: 4, wantVal: "3", ready: true}, // [2,3,4] avg=3
|
||
{in: 5, wantVal: "4", ready: true}, // [3,4,5] avg=4
|
||
{in: 6, wantVal: "5", ready: true}, // [4,5,6] avg=5
|
||
{in: 7, wantVal: "6", ready: true}, // [5,6,7] avg=6
|
||
},
|
||
},
|
||
{
|
||
name: "window_1_behaves_as_latest_value",
|
||
window: 1,
|
||
steps: []step{
|
||
{in: 10, wantVal: "10", ready: true},
|
||
{in: 11, wantVal: "11", ready: true},
|
||
{in: 12, wantVal: "12", ready: true},
|
||
},
|
||
},
|
||
}
|
||
|
||
for _, tc := range tests {
|
||
tc := tc
|
||
t.Run(tc.name, func(t *testing.T) {
|
||
s := NewSMA(tc.window)
|
||
for i, st := range tc.steps {
|
||
got := s.Update(d(st.in))
|
||
if got.Value.String() != st.wantVal {
|
||
t.Fatalf("step %d: got value %s, want %s", i, got.Value, st.wantVal)
|
||
}
|
||
if got.Ready != st.ready {
|
||
t.Fatalf("step %d: ready mismatch, got %v, want %v", i, got.Ready, st.ready)
|
||
}
|
||
// Load() 應該等於最新快照
|
||
ld := s.Load()
|
||
if ld.Value.String() != st.wantVal || ld.Ready != st.ready {
|
||
t.Fatalf("step %d: Load() not latest snapshot", i)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestSMA_NewZeroWindowShouldPanic(t *testing.T) {
|
||
defer func() {
|
||
if r := recover(); r == nil {
|
||
t.Fatalf("expected panic when window=0")
|
||
}
|
||
}()
|
||
_ = NewSMA(0)
|
||
}
|
||
|
||
func TestSMA_MultiReadersSingleWriter(t *testing.T) {
|
||
s := NewSMA(3)
|
||
|
||
var wg sync.WaitGroup
|
||
stop := make(chan struct{})
|
||
|
||
// 多 reader 併發讀取(零鎖)
|
||
reader := func() {
|
||
defer wg.Done()
|
||
for {
|
||
select {
|
||
case <-stop:
|
||
return
|
||
default:
|
||
_ = s.Load() // 不做判斷,重點是不得 panic
|
||
}
|
||
}
|
||
}
|
||
|
||
// 啟動多個讀者
|
||
for i := 0; i < 8; i++ {
|
||
wg.Add(1)
|
||
go reader()
|
||
}
|
||
|
||
// 單 writer 更新
|
||
prices := []int64{1, 2, 3, 4, 5, 6, 7}
|
||
for _, p := range prices {
|
||
s.Update(d(p))
|
||
time.Sleep(1 * time.Millisecond)
|
||
}
|
||
|
||
// 停 reader
|
||
close(stop)
|
||
wg.Wait()
|
||
|
||
// 最終應為 [5,6,7] 的平均=6,且 ready=true
|
||
got := s.Load()
|
||
if got.Value.String() != "6" || !got.Ready {
|
||
t.Fatalf("final snapshot mismatch: got value=%s ready=%v", got.Value, got.Ready)
|
||
}
|
||
}
|