From 1bc26db607073643b50eeaa592fc39aa66f1eca9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=A7=E9=A9=8A?= Date: Wed, 20 Aug 2025 20:35:31 +0800 Subject: [PATCH] add sma ema macd --- go.mod | 3 +- go.sum | 2 + internal/lib/strategy/bool.go | 136 ++++++++++++++++++++ internal/lib/strategy/bool_test.go | 65 ++++++++++ internal/lib/strategy/ema.go | 1 + internal/lib/strategy/ema_sma_talib_test.go | 88 +++++++++++++ internal/lib/strategy/rsi.go | 62 +++++++++ internal/lib/strategy/rsi_test.go | 71 ++++++++++ 8 files changed, 427 insertions(+), 1 deletion(-) create mode 100644 internal/lib/strategy/bool.go create mode 100644 internal/lib/strategy/bool_test.go create mode 100644 internal/lib/strategy/ema_sma_talib_test.go create mode 100644 internal/lib/strategy/rsi.go create mode 100644 internal/lib/strategy/rsi_test.go diff --git a/go.mod b/go.mod index 4a948d8..8cbafb4 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,10 @@ require ( github.com/goccy/go-json v0.10.5 github.com/gocql/gocql v1.5.0 github.com/lxzan/gws v1.8.9 + github.com/markcheno/go-talib v0.0.0-20250114000313-ec55a20c902f github.com/panjf2000/ants/v2 v2.11.3 github.com/scylladb/gocqlx/v3 v3.0.1 + github.com/shopspring/decimal v1.4.0 github.com/stretchr/testify v1.10.0 github.com/testcontainers/testcontainers-go v0.38.0 github.com/zeromicro/go-zero v1.8.5 @@ -55,7 +57,6 @@ require ( github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/scylladb/go-reflectx v1.0.1 // indirect github.com/shirou/gopsutil/v4 v4.25.5 // indirect - github.com/shopspring/decimal v1.4.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/go.sum b/go.sum index 096efff..acbead6 100644 --- a/go.sum +++ b/go.sum @@ -156,6 +156,8 @@ github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8S github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= +github.com/markcheno/go-talib v0.0.0-20250114000313-ec55a20c902f h1:iKq//xEUUaeRoXNcAshpK4W8eSm7HtgI0aNznWtX7lk= +github.com/markcheno/go-talib v0.0.0-20250114000313-ec55a20c902f/go.mod h1:3YUtoVrKWu2ql+iAeRyepSz3fy6a+19hJzGS88+u4u0= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= diff --git a/internal/lib/strategy/bool.go b/internal/lib/strategy/bool.go new file mode 100644 index 0000000..51b9057 --- /dev/null +++ b/internal/lib/strategy/bool.go @@ -0,0 +1,136 @@ +package strategy + +import ( + "github.com/shopspring/decimal" + "math" +) + +/************** O(1) 滾動標準差:Welford + ring buffer **************/ + +// 滾動統計(固定窗寬):維持 mean 與 M2(平方離差和) +// 支援「移除最舊樣本」+「新增最新樣本」的 O(1) 更新 +type rollingWelford struct { + n int // 窗寬 + buf []decimal.Decimal // 環狀緩衝(固定長度 n) + head int // 下一個覆寫位置(最舊元素位置) + size int // 當前已填入數量(<= n) + mean decimal.Decimal // 目前視窗內的平均值 + m2 decimal.Decimal // 目前視窗內的平方離差和(∑(x-mean)^2) + ready bool // 是否已填滿 n +} + +func newRollingWelford(n int) *rollingWelford { + if n <= 0 { + panic("rollingWelford window must be > 0") + } + return &rollingWelford{ + n: n, + buf: make([]decimal.Decimal, n), + } +} + +// add 一個樣本(標準 Welford 加法) +func (rw *rollingWelford) add(x decimal.Decimal) { + // n_old = size, n_new = size+1 + nOld := decimal.NewFromInt(int64(rw.size)) + nNew := nOld.Add(decimal.NewFromInt(1)) + + delta := x.Sub(rw.mean) + meanNew := rw.mean.Add(delta.Div(nNew)) + delta2 := x.Sub(meanNew) + rw.m2 = rw.m2.Add(delta.Mul(delta2)) + rw.mean = meanNew + + if rw.size < rw.n { + rw.size++ + if rw.size == rw.n { + rw.ready = true + } + } +} + +// remove 一個樣本(Welford 反向移除公式) +func (rw *rollingWelford) remove(x decimal.Decimal) { + if rw.size <= 0 { + return + } + // n_old = size, n_new = size-1 + nOld := decimal.NewFromInt(int64(rw.size)) + nNew := nOld.Sub(decimal.NewFromInt(1)) + if nNew.LessThanOrEqual(decimal.Zero) { + // 視窗將清空 + rw.size = 0 + rw.mean = decimal.Zero + rw.m2 = decimal.Zero + rw.ready = false + return + } + delta := x.Sub(rw.mean) // (x - mean_old) + meanNew := rw.mean.Sub(delta.Div(nNew)) // mean_new = mean_old - delta / n_new + // M2_new = M2_old - delta*(x - mean_new) + rw.m2 = rw.m2.Sub(delta.Mul(x.Sub(meanNew))) + rw.mean = meanNew + rw.size-- + rw.ready = rw.size == rw.n // 正常來說 size==n-1,移除後會在 add 後再回到 n +} + +// push:O(1) 更新(移除最舊 + 新增最新) +func (rw *rollingWelford) push(x decimal.Decimal) { + if rw.size == rw.n { + old := rw.buf[rw.head] + rw.remove(old) + } + rw.buf[rw.head] = x + rw.head = (rw.head + 1) % rw.n + rw.add(x) +} + +// mid() = 平均;std() = sqrt(variance) +// 這裡採「母體方差」分母 = n(與很多圖表預設一致;如需樣本用 n-1) +func (rw *rollingWelford) mid() decimal.Decimal { return rw.mean } + +func (rw *rollingWelford) std() decimal.Decimal { + if !rw.ready || rw.size == 0 { + return decimal.Zero + } + nDec := decimal.NewFromInt(int64(rw.n)) + variance := rw.m2.Div(nDec) + f, _ := variance.Float64() + return decimal.NewFromFloat(math.Sqrt(f)) +} + +/************** Boll(Welford 版本) **************/ + +type BollW struct { + rw *rollingWelford +} + +type BollWOut struct { + Mid, Upper, Lower, Std decimal.Decimal + Ready bool +} + +func NewBollW(n int) *BollW { + return &BollW{rw: newRollingWelford(n)} +} + +// Push 輸入收盤價 close 與 k 倍標準差(典型 2.0) +// 回傳:上下軌、均線、標準差、是否就緒(前 n 根為 false) +func (b *BollW) Push(close, k decimal.Decimal) BollWOut { + b.rw.push(close) + if !b.rw.ready { + return BollWOut{Ready: false} + } + mid := b.rw.mid() + std := b.rw.std() + + upper := mid.Add(k.Mul(std)) + lower := mid.Sub(k.Mul(std)) + return BollWOut{ + Mid: mid, + Upper: upper, + Lower: lower, + Std: std, + Ready: true, + } +} diff --git a/internal/lib/strategy/bool_test.go b/internal/lib/strategy/bool_test.go new file mode 100644 index 0000000..5c32633 --- /dev/null +++ b/internal/lib/strategy/bool_test.go @@ -0,0 +1,65 @@ +package strategy + +import ( + "github.com/shopspring/decimal" + "testing" +) + +func TestBoll_WarmupAndSliding(t *testing.T) { + n := 5 + k := d(2.0) + b := NewBollW(n) + + // 前 n-1 根:not ready + for i, px := range []float64{10, 11, 12, 13} { + out := b.Push(d(int64(px)), k) + if out.Ready { + t.Fatalf("i=%d: should not be ready yet", i) + } + } + + // 第 n 根開始:ready + out := b.Push(d(14), k) + if !out.Ready { + t.Fatalf("should be ready at %d-th push", n) + } + // 中線應為 (10+11+12+13+14)/5 = 12 + if out.Mid.StringFixed(6) != d(12).StringFixed(6) { + t.Fatalf("mid expect 12, got %s", out.Mid) + } + // 標準差 > 0;上軌 > 中線 > 下軌 + if !out.Std.GreaterThan(decimal.Zero) || + !out.Upper.GreaterThan(out.Mid) || + !out.Mid.GreaterThan(out.Lower) { + t.Fatalf("band ordering violated: U=%s M=%s L=%s Std=%s", out.Upper, out.Mid, out.Lower, out.Std) + } + + // 再推兩根 -> 視窗滑到 [12,13,14,15,16],中線=14 + for _, px := range []float64{15, 16} { + out = b.Push(d(int64(px)), k) + } + if out.Mid.StringFixed(6) != d(14).StringFixed(6) { + t.Fatalf("sliding mid expect 14, got %s", out.Mid) + } + if !out.Std.GreaterThan(decimal.Zero) { + t.Fatalf("std should remain > 0 after slide, got %s", out.Std) + } +} + +func TestBoll_FlatSeries(t *testing.T) { + n := 4 + k := d(2.0) + b := NewBollW(n) + + // 全部相同價格 -> 標準差=0、三條線重合 + for i := 0; i < n; i++ { + _ = b.Push(d(10), k) + } + out := b.Push(d(10), k) // 視窗仍為相同值 + if !out.Std.Equal(decimal.Zero) { + t.Fatalf("std should be 0 on flat series, got %s", out.Std) + } + if !(out.Upper.Equal(out.Mid) && out.Mid.Equal(out.Lower) && out.Mid.Equal(d(10))) { + t.Fatalf("bands should coincide at 10: U=%s M=%s L=%s", out.Upper, out.Mid, out.Lower) + } +} diff --git a/internal/lib/strategy/ema.go b/internal/lib/strategy/ema.go index 7cdb0fa..682d683 100644 --- a/internal/lib/strategy/ema.go +++ b/internal/lib/strategy/ema.go @@ -32,6 +32,7 @@ func newEMACore(period uint) *emaCore { if period == 0 { panic("EMA period must be > 0") } + return &emaCore{ period: period, alpha: alphaFromPeriod(period), diff --git a/internal/lib/strategy/ema_sma_talib_test.go b/internal/lib/strategy/ema_sma_talib_test.go new file mode 100644 index 0000000..f62f074 --- /dev/null +++ b/internal/lib/strategy/ema_sma_talib_test.go @@ -0,0 +1,88 @@ +package strategy + +import ( + "github.com/markcheno/go-talib" + "github.com/shopspring/decimal" + "math" + "testing" +) + +func dv(v float64) decimal.Decimal { return decimal.NewFromFloat(v) } + +func genPrices(n int) []float64 { + out := make([]float64, n) + base := 10.0 + for i := 0; i < n; i++ { + out[i] = base + float64(i)*0.5 + math.Sin(float64(i)/3.0)*0.7 + } + return out +} + +func almostEqualDecFloat(dec decimal.Decimal, f float64, tol float64) bool { + df, _ := dec.Float64() + return math.Abs(df-f) <= tol +} + +// 注意:使用 NewEMAForTalib(period)(First-Price seed)來比對 TA-Lib +func TestEMA_MatchesGoTalib(t *testing.T) { + prices := genPrices(300) + tol := 1e-7 // decimal<->float64 轉換微誤差 + + periods := []uint{3, 5, 12, 26} + for _, p := range periods { + t.Run("EMA_p="+decimal.NewFromInt(int64(p)).String(), func(t *testing.T) { + ema := NewEMA(p) // 重點:用 TA-Lib 兼容 seed + + our := make([]decimal.Decimal, len(prices)) + ready := make([]bool, len(prices)) + for i, px := range prices { + out := ema.Update(dv(px)) + our[i] = out.Value + ready[i] = out.Ready + } + ref := talib.Ema(prices, int(p)) + + start := p // talib 第一個有效值在 index = period-1 + + for i := start; i < uint(len(prices)); i++ { + if !ready[i] { + t.Fatalf("i=%d: our not ready but talib has value", i) + } + if !almostEqualDecFloat(our[i], ref[i], tol) { + t.Fatalf("i=%d: EMA mismatch: our=%s ref=%f", i, our[i], ref[i]) + } + } + }) + } +} + +func TestSMA_MatchesGoTalib(t *testing.T) { + prices := genPrices(300) + tol := 1e-9 + + windows := []uint{3, 5, 20, 50} + for _, w := range windows { + t.Run("SMA_w="+decimal.NewFromInt(int64(w)).String(), func(t *testing.T) { + sma := NewSMA(w) + + our := make([]decimal.Decimal, len(prices)) + ready := make([]bool, len(prices)) + for i, px := range prices { + out := sma.Update(dv(px)) + our[i] = out.Value + ready[i] = out.Ready + } + + ref := talib.Sma(prices, int(w)) + start := w + for i := start; i < uint(len(prices)); i++ { + if !ready[i] { + t.Fatalf("i=%d: our not ready but talib has value", i) + } + if !almostEqualDecFloat(our[i], ref[i], tol) { + t.Fatalf("i=%d: SMA mismatch: our=%s ref=%f", i, our[i], ref[i]) + } + } + }) + } +} diff --git a/internal/lib/strategy/rsi.go b/internal/lib/strategy/rsi.go new file mode 100644 index 0000000..e532d46 --- /dev/null +++ b/internal/lib/strategy/rsi.go @@ -0,0 +1,62 @@ +package strategy + +import "github.com/shopspring/decimal" + +// RSI 使用 Wilder 的平均漲跌(非簡單平均),更貼近交易軟體常見計法 +type RSI struct { + n int + prevC decimal.Decimal // 前一根收盤價 + initCount int // 初始化用:先累積前 n 根的總漲/總跌 + avgGain decimal.Decimal // 平滑後的平均上漲 + avgLoss decimal.Decimal // 平滑後的平均下跌 + ok bool // 是否有 prevC +} + +func NewRSI(n int) *RSI { return &RSI{n: n} } + +// Push 餵入一根K線,回傳 (RSI值, 是否就緒) +// 注意:前 n 根會回傳就緒=false;之後才可信 +func (r *RSI) Push(c CandleForStrategy) (decimal.Decimal, bool) { + if !r.ok { + r.prevC = c.C + r.ok = true + return decimal.Zero, false + } + // 價差 + chg := c.C.Sub(r.prevC) + r.prevC = c.C + + // 區分上漲與下跌 + gain := decimal.Max(chg, decimal.Zero) + loss := decimal.Max(chg.Neg(), decimal.Zero) + + // 初始化階段:先把前 n 根的平均值建好 + if r.initCount < r.n { + r.avgGain = r.avgGain.Add(gain) + r.avgLoss = r.avgLoss.Add(loss) + r.initCount++ + if r.initCount == r.n { + r.avgGain = r.avgGain.Div(decimal.NewFromInt(int64(r.n))) + r.avgLoss = r.avgLoss.Div(decimal.NewFromInt(int64(r.n))) + return r.calc(), true + } + return decimal.Zero, false + } + + // Wilder 平滑:新的平均 = (舊平均*(n-1) + 當期值) / n + nDec := decimal.NewFromInt(int64(r.n)) + r.avgGain = (r.avgGain.Mul(nDec.Sub(decimal.NewFromInt(1))).Add(gain)).Div(nDec) + r.avgLoss = (r.avgLoss.Mul(nDec.Sub(decimal.NewFromInt(1))).Add(loss)).Div(nDec) + + return r.calc(), true +} + +func (r *RSI) calc() decimal.Decimal { + if r.avgLoss.IsZero() { + return decimal.NewFromInt(100) // 沒有下跌時,RSI=100 + } + rs := r.avgGain.Div(r.avgLoss) + one := decimal.NewFromInt(1) + hundred := decimal.NewFromInt(100) + return hundred.Sub(hundred.Div(one.Add(rs))) +} diff --git a/internal/lib/strategy/rsi_test.go b/internal/lib/strategy/rsi_test.go new file mode 100644 index 0000000..0d72cbf --- /dev/null +++ b/internal/lib/strategy/rsi_test.go @@ -0,0 +1,71 @@ +package strategy + +import ( + "github.com/shopspring/decimal" + "testing" +) + +func TestRSI_WarmupAndRange(t *testing.T) { + n := 2 + r := NewRSI(n) + + type step struct { + close float64 + wantReady bool + } + steps := []step{ + {10, false}, // 第一根,只建立 prevC + {11, false}, // 第二根,累積 + {12, true}, // 第三根,完成 seed -> ready + {13, true}, + {12, true}, + {11, true}, + {12, true}, + {13, true}, + } + + for i, st := range steps { + val, ready := r.Push(CandleForStrategy{C: dv(st.close)}) + if ready != st.wantReady { + t.Fatalf("step %d: ready got=%v want=%v", i, ready, st.wantReady) + } + if ready { + // RSI 應落在 0..100 + if val.LessThan(decimal.Zero) || val.GreaterThan(decimal.NewFromInt(100)) { + t.Fatalf("step %d: RSI out of [0,100], got %s", i, val) + } + } + } +} + +func TestRSI_PerfectUpAndDown(t *testing.T) { + n := 2 + r := NewRSI(n) + + // 先 seed + _, _ = r.Push(CandleForStrategy{C: d(10)}) + _, _ = r.Push(CandleForStrategy{C: d(11)}) + val, ready := r.Push(CandleForStrategy{C: d(12)}) + if !ready { + t.Fatalf("should be ready after %d candles", n) + } + + // 連續上漲:理想情況 avgLoss -> 0,RSI 逼近 100 + val, ready = r.Push(CandleForStrategy{C: d(13)}) + if !ready || !val.LessThanOrEqual(decimal.NewFromInt(100)) { + t.Fatalf("uptrend: ready=%v val=%s", ready, val) + } + + // 連續下跌:理想情況 avgGain -> 0,RSI 逼近 0 + r = NewRSI(n) + _, _ = r.Push(CandleForStrategy{C: d(13)}) + _, _ = r.Push(CandleForStrategy{C: d(12)}) + val, ready = r.Push(CandleForStrategy{C: d(11)}) + if !ready { + t.Fatalf("should be ready after %d candles", n) + } + val, ready = r.Push(CandleForStrategy{C: d(10)}) + if !ready || !val.GreaterThanOrEqual(decimal.Zero) { + t.Fatalf("downtrend: ready=%v val=%s", ready, val) + } +}