blockchain/internal/lib/strategy/sma_test.go

121 lines
2.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package strategy
import (
"sync"
"testing"
"time"
"github.com/shopspring/decimal"
)
// d is a test shorthand that converts an int64 into a decimal.Decimal.
func d(i int64) decimal.Decimal {
	return decimal.NewFromInt(i)
}
// TestSMA_WarmupAndSliding feeds a sequence of inputs into an SMA and, after
// every Update, checks the returned value, the returned Ready flag, and that
// Load reports the same snapshot.
func TestSMA_WarmupAndSliding(t *testing.T) {
	// sample is a single Update call plus the result expected from it.
	type sample struct {
		in      int64
		wantVal string // compared as a string to avoid floating-point error (decimal is exact anyway)
		ready   bool
	}
	// scenario is one subtest: a window size and the samples pushed through it.
	type scenario struct {
		name   string
		window uint
		steps  []sample
	}

	scenarios := []scenario{
		{
			name:   "warmup_then_ready_and_slide",
			window: 3,
			steps: []sample{
				{in: 1, wantVal: "1", ready: false},   // [1] avg=1
				{in: 2, wantVal: "1.5", ready: false}, // [1,2] avg=1.5
				{in: 3, wantVal: "2", ready: true},    // [1,2,3] avg=2
				{in: 4, wantVal: "3", ready: true},    // [2,3,4] avg=3
				{in: 5, wantVal: "4", ready: true},    // [3,4,5] avg=4
				{in: 6, wantVal: "5", ready: true},    // [4,5,6] avg=5
				{in: 7, wantVal: "6", ready: true},    // [5,6,7] avg=6
			},
		},
		{
			name:   "window_1_behaves_as_latest_value",
			window: 1,
			steps: []sample{
				{in: 10, wantVal: "10", ready: true},
				{in: 11, wantVal: "11", ready: true},
				{in: 12, wantVal: "12", ready: true},
			},
		},
	}

	for _, sc := range scenarios {
		sc := sc // per-iteration copy for the closure below
		t.Run(sc.name, func(t *testing.T) {
			sma := NewSMA(sc.window)

			for idx := range sc.steps {
				want := sc.steps[idx]

				res := sma.Update(d(want.in))
				if res.Value.String() != want.wantVal {
					t.Fatalf("step %d: got value %s, want %s", idx, res.Value, want.wantVal)
				}
				if res.Ready != want.ready {
					t.Fatalf("step %d: ready mismatch, got %v, want %v", idx, res.Ready, want.ready)
				}

				// Load() should equal the latest snapshot.
				snap := sma.Load()
				sameAsLatest := snap.Value.String() == want.wantVal && snap.Ready == want.ready
				if !sameAsLatest {
					t.Fatalf("step %d: Load() not latest snapshot", idx)
				}
			}
		})
	}
}
// TestSMA_NewZeroWindowShouldPanic checks that constructing an SMA with a
// window of 0 panics; the test fails if NewSMA(0) returns normally.
func TestSMA_NewZeroWindowShouldPanic(t *testing.T) {
	defer func() {
		if recover() != nil {
			return // a panic occurred, which is what the test expects
		}
		t.Fatalf("expected panic when window=0")
	}()

	_ = NewSMA(0)
}
// TestSMA_MultiReadersSingleWriter runs several goroutines that call Load in
// a loop while the test goroutine, acting as the only writer, pushes prices
// through Update. After the readers are stopped it checks the final snapshot.
func TestSMA_MultiReadersSingleWriter(t *testing.T) {
	const readerCount = 8

	sma := NewSMA(3)
	done := make(chan struct{})

	// Start the concurrent readers. Each one polls Load until done is closed.
	var readers sync.WaitGroup
	readers.Add(readerCount)
	for n := 0; n < readerCount; n++ {
		go func() {
			defer readers.Done()
			for {
				select {
				case <-done:
					return
				default:
					_ = sma.Load() // result is not inspected; the point is that it must not panic
				}
			}
		}()
	}

	// Single writer: feed the prices one at a time, pausing briefly between
	// updates so the readers get to run in between.
	for _, price := range []int64{1, 2, 3, 4, 5, 6, 7} {
		sma.Update(d(price))
		time.Sleep(time.Millisecond)
	}

	// Stop the readers and wait for all of them to exit.
	close(done)
	readers.Wait()

	// The final snapshot should be the average of [5,6,7] = 6 with ready=true.
	final := sma.Load()
	if final.Ready && final.Value.String() == "6" {
		return
	}
	t.Fatalf("final snapshot mismatch: got value=%s ready=%v", final.Value, final.Ready)
}