From 777a5952b86ae2a510cb3f5deed822d82255ea65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=80=A7=E9=A9=8A?= Date: Fri, 15 Aug 2025 09:36:36 +0800 Subject: [PATCH] feat: add ema strategy --- internal/lib/strategy/ema.go | 57 ++++++++-- internal/lib/strategy/ema_test.go | 102 ++++++++++++++++++ internal/lib/strategy/ring_queue.go | 13 ++- internal/lib/strategy/ring_queue_test.go | 4 +- internal/lib/strategy/sma.go | 3 +- internal/lib/strategy/sma_test.go | 132 +++++++++++++++++++++++ 6 files changed, 295 insertions(+), 16 deletions(-) create mode 100644 internal/lib/strategy/ema_test.go create mode 100644 internal/lib/strategy/sma_test.go diff --git a/internal/lib/strategy/ema.go b/internal/lib/strategy/ema.go index e2c704c..62ed8be 100644 --- a/internal/lib/strategy/ema.go +++ b/internal/lib/strategy/ema.go @@ -2,26 +2,65 @@ package strategy import "github.com/shopspring/decimal" -/************** EMA 指數移動平均 **************/ +/* + EMA,全名為指數移動平均線(Exponential Moving Average),用於平滑價格波動,幫助識別市場趨勢。 + 它與簡單移動平均線(SMA)不同,EMA 更注重近期價格,因此對價格變動的反應更迅速,能更快地反映市場趨勢。 + +EMA 的主要特點和作用: + 更快速的反應: + EMA 比 SMA 更快地反映價格變動,因為它給予近期數據更高的權重。 + 識別趨勢: + 通過平滑價格波動,EMA 有助於識別市場的整體趨勢,判斷是上升趨勢還是下降趨勢。 + 輔助交易決策: + EMA 的使用可以幫助交易者判斷買入和賣出的時機,例如,當股價高於EMA 時,可能被視為買入信號;反之,則可能被視為賣出信號。 + 適合短線交易: + 由於EMA 對價格變動的敏感性,它更適合於短線交易者,能更快地捕捉市場的短期波動。 + EMA 的計算方法: + EMA 的計算涉及一個平滑因子和一個初始值,然後每天更新。 具體公式可以參考專業的金融網站或交易平台提供的資料。 + 總結: + EMA 是一種有用的技術分析工具,尤其適合於快速變動的市場,它可以幫助交易者更好地理解市場趨勢,並制定相應的交易策略 +*/ + type EMA struct { - n int + n uint alp decimal.Decimal // 平滑係數 α = 2 / (n + 1) val decimal.Decimal // 當前EMA值 - ok bool + ok bool // 內部旗標,用於判斷是否為第一筆資料 } -func NewEMA(n int) *EMA { - return &EMA{n: n, alp: decimal.NewFromFloat(2.0).Div(decimal.NewFromInt(int64(n + 1)))} +// NewEMA 建立EMA計算器 +func NewEMA(n uint) *EMA { + return &EMA{ + n: n, + alp: decimal.NewFromInt(2).Div(decimal.NewFromInt(int64(n + 1))), + ok: false, + } } +// Push 輸入收盤價,返回當前EMA值 func (e *EMA) Push(close decimal.Decimal) (decimal.Decimal, bool) { + // 如果 n 無效,永遠回傳無效狀態 + if e.n == 0 { + return decimal.Zero, false + } + if !e.ok { - // 第一筆資料直接當作EMA初始值 + // 第一筆資料直接當作EMA初始值,並將狀態設為 ok e.val = close e.ok = true - return e.val, false + } else { + // 後續資料使用 EMA 計算公式 + // EMA = α * close + (1 - α) * prev_EMA + e.val = e.alp.Mul(close).Add(decimal.NewFromInt(1).Sub(e.alp).Mul(e.val)) + } + // EMA 從第一筆資料開始就是有效的 + return e.val, true +} + +// GetEMA 取得目前 EMA 值 +func (e *EMA) GetEMA() (decimal.Decimal, bool) { + if !e.ok { + return decimal.Zero, false // 尚未初始化 } - // EMA計算公式 - e.val = e.alp.Mul(close).Add(decimal.NewFromInt(1).Sub(e.alp).Mul(e.val)) return e.val, true } diff --git a/internal/lib/strategy/ema_test.go b/internal/lib/strategy/ema_test.go new file mode 100644 index 0000000..8da532b --- /dev/null +++ b/internal/lib/strategy/ema_test.go @@ -0,0 +1,102 @@ +package strategy + +import ( + "github.com/shopspring/decimal" + "testing" +) + +// --- EMA 的表格式驅動測試 (新增) --- + +func TestEMA(t *testing.T) { + d10 := decimal.NewFromInt(10) + d11 := decimal.NewFromInt(11) + d12 := decimal.NewFromInt(12) + d13 := decimal.NewFromInt(13) + d20 := decimal.NewFromInt(20) + + type pushCheck struct { + wantEMA decimal.Decimal + wantOK bool + } + + testCases := []struct { + name string + n uint + inputs []decimal.Decimal + pushChecks []pushCheck + wantFinalEMA decimal.Decimal + wantFinalOK bool + }{ + { + name: "EMA-3 標準計算", + n: 3, // α = 2 / (3 + 1) = 0.5 + inputs: []decimal.Decimal{d10, d11, d12}, + pushChecks: []pushCheck{ + {d10, true}, // 第一次, EMA = 10 + {decimal.NewFromFloat(10.5), true}, // 第二次, 0.5*11 + (1-0.5)*10 = 5.5 + 5 = 10.5 + {decimal.NewFromFloat(11.25), true}, // 第三次, 0.5*12 + (1-0.5)*10.5 = 6 + 5.25 = 11.25 + }, + wantFinalEMA: decimal.NewFromFloat(11.25), + wantFinalOK: true, + }, + { + name: "EMA-1 邊界情況", + n: 1, // α = 2 / (1 + 1) = 1 + inputs: []decimal.Decimal{d10, d13, d11}, + pushChecks: []pushCheck{ + {d10, true}, // 第一次, EMA = 10 + {d13, true}, // 第二次, 1*13 + 0*10 = 13 + {d11, true}, // 第三次, 1*11 + 0*13 = 11 + }, + wantFinalEMA: d11, + wantFinalOK: true, + }, + { + name: "EMA-0 無效情況", + n: 0, + inputs: []decimal.Decimal{d10, d20}, + pushChecks: []pushCheck{ + {decimal.Zero, false}, + {decimal.Zero, false}, + }, + wantFinalEMA: decimal.Zero, + wantFinalOK: false, + }, + { + name: "在空實例上呼叫 GetEMA", + n: 5, + inputs: []decimal.Decimal{}, + pushChecks: []pushCheck{}, + wantFinalEMA: decimal.Zero, + wantFinalOK: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ema := NewEMA(tc.n) + + for i, input := range tc.inputs { + gotEMA, gotOK := ema.Push(input) + if i < len(tc.pushChecks) { + check := tc.pushChecks[i] + if gotOK != check.wantOK { + t.Errorf("Push #%d 的 OK 狀態錯誤: got %v, want %v", i+1, gotOK, check.wantOK) + } + // 使用 String() 進行比較,避免浮點數精度問題 + if gotEMA.String() != check.wantEMA.String() { + t.Errorf("Push #%d 的 EMA 值錯誤: got %s, want %s", i+1, gotEMA.String(), check.wantEMA.String()) + } + } + } + + finalEMA, finalOK := ema.GetEMA() + if finalOK != tc.wantFinalOK { + t.Errorf("最終 GetEMA 的 OK 狀態錯誤: got %v, want %v", finalOK, tc.wantFinalOK) + } + if finalEMA.String() != tc.wantFinalEMA.String() { + t.Errorf("最終 GetEMA 的 EMA 值錯誤: got %s, want %s", finalEMA.String(), tc.wantFinalEMA.String()) + } + }) + } +} diff --git a/internal/lib/strategy/ring_queue.go b/internal/lib/strategy/ring_queue.go index b7b8a02..72f2c8b 100644 --- a/internal/lib/strategy/ring_queue.go +++ b/internal/lib/strategy/ring_queue.go @@ -6,14 +6,16 @@ import ( ) /************** 基礎的固定長度隊列,用來計算移動平均 **************/ +// 請注意: 目前為非併發版本,使用場警還不需要 + type ringQD struct { - N int // 窗口大小(需要保留的資料數量) + N uint // 窗口大小(需要保留的資料數量)需要保留的數量沒有負數 l *list.List // 用於儲存資料的雙向鏈表 sum decimal.Decimal // 當前窗口的總和,方便快速計算平均值 } // 建立一個固定長度的隊列 -func newRingQD(n int) *ringQD { +func newRingQD(n uint) *ringQD { return &ringQD{N: n, l: list.New(), sum: decimal.Zero} } @@ -22,7 +24,8 @@ func (q *ringQD) push(x decimal.Decimal) { q.l.PushBack(x) q.sum = q.sum.Add(x) // 如果超出最大長度,移除最舊的數值 - if q.l.Len() > q.N { + + if uint(q.l.Len()) > q.N { f := q.l.Front() q.sum = q.sum.Sub(f.Value.(decimal.Decimal)) q.l.Remove(f) @@ -30,13 +33,14 @@ func (q *ringQD) push(x decimal.Decimal) { } // ready:判斷隊列是否已經填滿 -func (q *ringQD) ready() bool { return q.l.Len() == q.N } +func (q *ringQD) ready() bool { return q.N > 0 && uint(q.l.Len()) == q.N } // mean:計算平均值 func (q *ringQD) mean() decimal.Decimal { if q.l.Len() == 0 { return decimal.Zero } + return q.sum.Div(decimal.NewFromInt(int64(q.l.Len()))) } @@ -46,5 +50,6 @@ func (q *ringQD) values() []decimal.Decimal { for e := q.l.Front(); e != nil; e = e.Next() { out = append(out, e.Value.(decimal.Decimal)) } + return out } diff --git a/internal/lib/strategy/ring_queue_test.go b/internal/lib/strategy/ring_queue_test.go index f1f1094..17215bf 100644 --- a/internal/lib/strategy/ring_queue_test.go +++ b/internal/lib/strategy/ring_queue_test.go @@ -20,7 +20,7 @@ func TestRingQD(t *testing.T) { // 定義測試案例的結構 testCases := []struct { name string // 測試案例的名稱 - n int // ringQD 的大小 + n uint // ringQD 的大小 inputs []decimal.Decimal // 輸入的數值序列 wantSum decimal.Decimal // 預期的總和 wantMean decimal.Decimal // 預期的平均值 @@ -88,7 +88,7 @@ func TestRingQD(t *testing.T) { wantSum: decimal.Zero, wantMean: decimal.Zero, wantValues: []decimal.Decimal{}, - wantReady: true, + wantReady: false, }, } diff --git a/internal/lib/strategy/sma.go b/internal/lib/strategy/sma.go index 9aafdbb..bd25939 100644 --- a/internal/lib/strategy/sma.go +++ b/internal/lib/strategy/sma.go @@ -25,7 +25,7 @@ type SMA struct { } // NewSMA 建立SMA計算器 -func NewSMA(n int) *SMA { return &SMA{q: newRingQD(n)} } +func NewSMA(n uint) *SMA { return &SMA{q: newRingQD(n)} } // Push 輸入收盤價,返回當前SMA值 func (s *SMA) Push(close decimal.Decimal) (decimal.Decimal, bool) { @@ -41,5 +41,6 @@ func (s *SMA) GetSMA() (decimal.Decimal, bool) { if !s.q.ready() { return decimal.Zero, false // 尚未湊滿資料 } + return s.q.mean(), true } diff --git a/internal/lib/strategy/sma_test.go b/internal/lib/strategy/sma_test.go new file mode 100644 index 0000000..c144b05 --- /dev/null +++ b/internal/lib/strategy/sma_test.go @@ -0,0 +1,132 @@ +package strategy + +import ( + "github.com/shopspring/decimal" + "testing" +) + +// --- SMA 的表格式驅動測試 --- + +func TestSMA(t *testing.T) { + d10 := decimal.NewFromInt(10) + d20 := decimal.NewFromInt(20) + d30 := decimal.NewFromInt(30) + d40 := decimal.NewFromInt(40) + d50 := decimal.NewFromInt(50) + + // 定義 Push 過程中的檢查點結構 + type pushCheck struct { + wantSMA decimal.Decimal + wantOK bool + } + + testCases := []struct { + name string + n uint + inputs []decimal.Decimal + pushChecks []pushCheck // 驗證每一次 Push 的回傳值 + wantFinalSMA decimal.Decimal // 驗證最後 GetSMA 的回傳值 + wantFinalOK bool + }{ + { + name: "SMA-5 未滿載", + n: 5, + inputs: []decimal.Decimal{d10, d20, d30}, + pushChecks: []pushCheck{ + {decimal.Zero, false}, + {decimal.Zero, false}, + {decimal.Zero, false}, + }, + wantFinalSMA: decimal.Zero, + wantFinalOK: false, + }, + { + name: "SMA-3 剛好滿載", + n: 3, + inputs: []decimal.Decimal{d10, d20, d30}, + pushChecks: []pushCheck{ + {decimal.Zero, false}, + {decimal.Zero, false}, + {decimal.NewFromInt(20), true}, // (10+20+30)/3 + }, + wantFinalSMA: decimal.NewFromInt(20), + wantFinalOK: true, + }, + { + name: "SMA-3 滾動計算", + n: 3, + inputs: []decimal.Decimal{d10, d20, d30, d40, d50}, + pushChecks: []pushCheck{ + {decimal.Zero, false}, + {decimal.Zero, false}, + {decimal.NewFromInt(20), true}, // (10+20+30)/3 + {decimal.NewFromInt(30), true}, // (20+30+40)/3 + {decimal.NewFromInt(40), true}, // (30+40+50)/3 + }, + wantFinalSMA: decimal.NewFromInt(40), + wantFinalOK: true, + }, + { + name: "SMA-1 邊界情況", + n: 1, + inputs: []decimal.Decimal{d10, d20, d30}, + pushChecks: []pushCheck{ + {d10, true}, + {d20, true}, + {d30, true}, + }, + wantFinalSMA: d30, + wantFinalOK: true, + }, + { + name: "SMA-0 無效情況", + n: 0, + inputs: []decimal.Decimal{d10, d20, d30}, + pushChecks: []pushCheck{ + {decimal.Zero, false}, + {decimal.Zero, false}, + {decimal.Zero, false}, + }, + wantFinalSMA: decimal.Zero, + wantFinalOK: false, + }, + { + name: "在空實例上呼叫 GetSMA", + n: 5, + inputs: []decimal.Decimal{}, + pushChecks: []pushCheck{}, + wantFinalSMA: decimal.Zero, + wantFinalOK: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + sma := NewSMA(tc.n) + + // 驗證每一次 Push 的結果 + for i, input := range tc.inputs { + gotSMA, gotOK := sma.Push(input) + // 確保 pushChecks 陣列不會索引越界 + if i < len(tc.pushChecks) { + check := tc.pushChecks[i] + if gotOK != check.wantOK { + t.Errorf("Push #%d 的 OK 狀態錯誤: got %v, want %v", i+1, gotOK, check.wantOK) + } + if !gotSMA.Equals(check.wantSMA) { + t.Errorf("Push #%d 的 SMA 值錯誤: got %s, want %s", i+1, gotSMA.String(), check.wantSMA.String()) + } + } + } + + // 在所有 Push 操作完成後,驗證最終 GetSMA 的結果 + finalSMA, finalOK := sma.GetSMA() + if finalOK != tc.wantFinalOK { + t.Errorf("最終 GetSMA 的 OK 狀態錯誤: got %v, want %v", finalOK, tc.wantFinalOK) + } + if !finalSMA.Equals(tc.wantFinalSMA) { + t.Errorf("最終 GetSMA 的 SMA 值錯誤: got %s, want %s", finalSMA.String(), tc.wantFinalSMA.String()) + } + }) + } +}