blockchain/internal/repository/data_source_binance.go

338 lines
8.4 KiB
Go

package repository
import (
"archive/zip"
"blockchain/internal/config"
"blockchain/internal/domain/blockchain"
"blockchain/internal/domain/entity"
"blockchain/internal/domain/repository"
"blockchain/internal/lib/cassandra"
"bytes"
"context"
"encoding/csv"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/panjf2000/ants/v2"
"github.com/adshao/go-binance/v2"
"github.com/jszwec/csvutil"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
"github.com/zeromicro/go-zero/core/syncx"
)
// BinanceRepositoryParam bundles the dependencies consumed by
// MustBinanceRepository.
type BinanceRepositoryParam struct {
	// Conf holds the Binance settings; MustBinanceRepository reads TestMode
	// and WorkerSize from it.
	Conf *config.Binance
	// Redis backs the symbol-list cache used by GetSymbols.
	Redis *redis.Redis
	// DB is the Cassandra handle SaveHistoryKline inserts klines into.
	DB *cassandra.CassandraDB
	// KeySpace is the Cassandra keyspace passed to DB.Insert.
	KeySpace string
}
// BinanceRepository is the implementation returned by MustBinanceRepository.
// Symbols come from the Binance REST API (cached in a Redis hash).
// Historical klines are downloaded from the binance.vision archive and can be
// persisted to Cassandra.
type BinanceRepository struct {
	// Client is the go-binance REST client used to query exchange info.
	Client *binance.Client
	// db is the Cassandra handle that SaveHistoryKline writes klines to.
	db *cassandra.CassandraDB
	// rds backs the symbol-list hash cache used by GetSymbols.
	rds *redis.Redis
	// barrier collapses concurrent symbol reloads into one upstream call.
	barrier syncx.SingleFlight
	// workers is the goroutine pool that runs FetchHistoryKline's per-day
	// downloads.
	workers *ants.Pool
	// workerSize is the configured Conf.WorkerSize. It sizes the pool and the
	// kline methods also use it for their own concurrency/buffer limits.
	workerSize int64
	// KeySpace is the Cassandra keyspace passed to db.Insert.
	KeySpace string
}
// MustBinanceRepository builds the Binance-backed
// repository.DataSourceRepository from param.
//
// Following the Must* convention, it panics if the worker pool cannot be
// created. Otherwise it would return a repository whose nil pool only fails
// later, with a nil-pointer panic on first use.
//
// When param.Conf.TestMode is true it sets the package-level
// binance.UseTestnet flag. That is process-wide state of the go-binance
// package, presumably affecting any other binance client in the process.
func MustBinanceRepository(param BinanceRepositoryParam) repository.DataSourceRepository {
	// The visible code only calls the public exchange-info endpoint, so the
	// client is created without API credentials.
	apiKey := ""
	secret := ""
	if param.Conf.TestMode {
		binance.UseTestnet = true
	}
	client := binance.NewClient(apiKey, secret)

	workers, err := ants.NewPool(int(param.Conf.WorkerSize))
	if err != nil {
		panic(fmt.Sprintf("binance repository: create worker pool (size %d): %v", param.Conf.WorkerSize, err))
	}

	return &BinanceRepository{
		Client:     client,
		db:         param.DB,
		rds:        param.Redis,
		barrier:    syncx.NewSingleFlight(),
		workerSize: param.Conf.WorkerSize,
		workers:    workers,
		KeySpace:   param.KeySpace,
	}
}
// GetSymbols returns the exchange's trading symbols, reading through a Redis
// hash cache. The cache lives at blockchain.RedisKeySymbolList and holds one
// JSON-encoded entity.Symbol per field, keyed by symbol name.
//
// A cache hit is returned in unspecified order (map iteration). On a miss, on
// a Redis read error, or when any cached field fails to decode, the list is
// reloaded from Binance. SingleFlight collapses concurrent reloads into one
// upstream request.
//
// NOTE(review): the reload runs with the ctx of whichever caller triggered it.
// Callers sharing that flight fail if that ctx is cancelled.
//
// Cache writes are best effort: failures are logged and never returned.
func (repo *BinanceRepository) GetSymbols(ctx context.Context) ([]*entity.Symbol, error) {
	// Prefer the Redis hash.
	cached, err := repo.rds.Hgetall(blockchain.RedisKeySymbolList)
	if err != nil {
		logx.Errorf("read symbol cache %s: %v", blockchain.RedisKeySymbolList, err)
	}
	if err == nil && len(cached) > 0 {
		symbols := make([]*entity.Symbol, 0, len(cached))
		canUseCache := true
		for _, v := range cached {
			var symbol entity.Symbol
			if err := json.Unmarshal([]byte(v), &symbol); err != nil {
				// A single undecodable field means the cache may be corrupted;
				// reload from the source instead.
				canUseCache = false
				break
			}
			symbols = append(symbols, &symbol)
		}
		if canUseCache {
			return symbols, nil
		}
	}

	// SingleFlight ensures only one request actually goes to Binance.
	val, err := repo.barrier.Do(blockchain.RedisKeySymbolList, func() (any, error) {
		srcSymbols, err := repo.getSymbolsFromSource(ctx)
		if err != nil {
			return nil, err
		}

		result := make([]*entity.Symbol, 0, len(srcSymbols))
		hashData := make(map[string]string, len(srcSymbols))
		complete := true
		for _, s := range srcSymbols {
			// Keep only the fields currently needed.
			symbolEntity := &entity.Symbol{
				Symbol:              s.Symbol,
				Status:              s.Status,
				BaseAsset:           s.BaseAsset,
				BaseAssetPrecision:  s.BaseAssetPrecision,
				QuoteAsset:          s.QuoteAsset,
				QuoteAssetPrecision: s.QuoteAssetPrecision,
			}
			result = append(result, symbolEntity)

			raw, err := json.Marshal(symbolEntity)
			if err != nil {
				logx.Errorf("failed to marshal symbol entity %s: %v", symbolEntity.Symbol, err)
				complete = false
				continue
			}
			hashData[symbolEntity.Symbol] = string(raw)
		}

		// Only cache a complete list. A partial hash would otherwise be served
		// as the full symbol list until it expires.
		if complete && len(hashData) > 0 {
			key := blockchain.RedisKeySymbolList
			// Drop the old hash first so the fields that forced this reload
			// (undecodable, or for delisted symbols) do not survive the rewrite.
			// Surviving fields would force a reload on every call and keep
			// re-arming the TTL.
			if _, err := repo.rds.Del(key); err != nil {
				logx.Errorf("reset symbol cache %s: %v", key, err)
			} else if err := repo.rds.Hmset(key, hashData); err != nil {
				logx.Errorf("write symbol cache %s: %v", key, err)
			} else if err := repo.rds.Expire(key, blockchain.SymbolExpire); err != nil {
				logx.Errorf("set expiry on symbol cache %s: %v", key, err)
			}
		}
		return result, nil
	})
	if err != nil {
		return nil, err
	}
	if symbols, ok := val.([]*entity.Symbol); ok {
		return symbols, nil
	}
	return nil, fmt.Errorf("invalid symbol type: %T", val)
}
// FetchHistoryKline downloads the daily kline archives for every UTC calendar
// day touched by [param.StartUnixNano, param.EndUnixNano]. It returns their
// klines concatenated in chronological (per-day) order.
//
// Days are fetched concurrently on the shared worker pool. A day whose archive
// is missing or fails to download is skipped (best effort). Returned klines
// cover whole days and are not trimmed to the exact start/end instants.
//
// It returns an error if a day cannot be scheduled on the pool, or if ctx is
// done by the time all scheduled days have finished.
func (repo *BinanceRepository) FetchHistoryKline(ctx context.Context, param repository.QueryKline) ([]*entity.Kline, error) {
	// Archives are published per UTC day. Normalising both bounds to UTC
	// midnight keeps the end day included regardless of the host time zone or
	// the bounds' time-of-day.
	firstDay := truncateToUTCDay(time.Unix(0, param.StartUnixNano))
	lastDay := truncateToUTCDay(time.Unix(0, param.EndUnixNano))

	var days []time.Time
	for d := firstDay; !d.After(lastDay); d = d.AddDate(0, 0, 1) {
		days = append(days, d)
	}

	// Each task writes only its own slot, so no lock or channel is needed and
	// the result keeps day order. A bounded channel that is drained only after
	// all submissions can deadlock against a blocking pool.
	perDay := make([][]*entity.Kline, len(days))
	var wg sync.WaitGroup
	for i := range days {
		idx := i
		date := days[i].Format(time.DateOnly)
		wg.Add(1)
		submitErr := repo.workers.Submit(func() {
			defer wg.Done()
			klines, err := repo.fetchHistoryKline(ctx, param.Symbol, param.Interval, date)
			if err != nil {
				return // archive missing or unavailable for this day: skip it
			}
			perDay[idx] = klines
		})
		if submitErr != nil {
			wg.Done() // the task never ran, so release its slot here
			wg.Wait() // let already-scheduled tasks finish before returning
			return nil, fmt.Errorf("schedule kline fetch for %s: %w", date, submitErr)
		}
	}
	wg.Wait()

	// Partial results after cancellation would look like a complete range.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	total := 0
	for _, klines := range perDay {
		total += len(klines)
	}
	if total == 0 {
		return nil, nil
	}
	allKlines := make([]*entity.Kline, 0, total)
	for _, klines := range perDay {
		allKlines = append(allKlines, klines...)
	}
	return allKlines, nil
}

// truncateToUTCDay returns midnight UTC of the calendar day that contains t.
func truncateToUTCDay(t time.Time) time.Time {
	y, m, d := t.UTC().Date()
	return time.Date(y, m, d, 0, 0, 0, 0, time.UTC)
}
// SaveHistoryKline inserts every kline in data into Cassandra under
// repo.KeySpace. It runs at most repo.workerSize inserts at a time, or one at
// a time when workerSize is not positive.
//
// Each failed insert is logged, and all failures are returned together as one
// error. If ctx is done, no further inserts are started, in-flight ones are
// awaited, and ctx.Err() is included in the returned error.
func (repo *BinanceRepository) SaveHistoryKline(ctx context.Context, data []*entity.Kline) error {
	// A zero-capacity semaphore blocks forever on the first acquire (nothing
	// can release before the first goroutine starts), and a negative one panics
	// in make. Fall back to sequential inserts instead.
	limit := repo.workerSize
	if limit <= 0 {
		limit = 1
	}
	sem := make(chan struct{}, limit)

	var (
		wg      sync.WaitGroup
		mu      sync.Mutex
		errList []error
	)
dispatch:
	for _, item := range data {
		select {
		case <-ctx.Done():
			mu.Lock()
			errList = append(errList, ctx.Err())
			mu.Unlock()
			break dispatch
		case sem <- struct{}{}: // blocks while `limit` inserts are in flight
		}
		wg.Add(1)
		go func(k *entity.Kline) {
			defer wg.Done()
			defer func() { <-sem }()
			if err := repo.db.Insert(ctx, k, repo.KeySpace); err != nil {
				mu.Lock()
				errList = append(errList, err)
				mu.Unlock()
				logx.Errorf("failed to insert data: %v", err)
			}
		}(item)
	}
	wg.Wait()

	if len(errList) > 0 {
		return fmt.Errorf("insert errors: %v", errList)
	}
	return nil
}
// =============
// getSymbolsFromSource asks Binance's exchange-info endpoint for the full,
// unfiltered symbol list. It fails immediately if the repository was built
// without a client.
func (repo *BinanceRepository) getSymbolsFromSource(ctx context.Context) ([]binance.Symbol, error) {
	client := repo.Client
	if client == nil {
		return nil, fmt.Errorf("binance client not initialized")
	}

	svc := client.NewExchangeInfoService()
	exchangeInfo, err := svc.Do(ctx)
	if err != nil {
		return nil, err
	}
	return exchangeInfo.Symbols, nil
}
// binanceArchiveClient downloads kline archives. Its Timeout bounds a download
// even when the caller's ctx has no deadline. Its nil Transport means
// http.DefaultTransport, so connections are pooled across downloads.
var binanceArchiveClient = &http.Client{Timeout: 5 * time.Minute}

// binanceKlineCSVHeader names the 12 columns of a Binance kline CSV row. It is
// prepended to each archive member so csvutil can map columns onto
// entity.Kline. The names presumably must match entity.Kline's csv tags
// (not visible here).
var binanceKlineCSVHeader = []string{
	"open_time", "open", "high", "low", "close", "volume", "close_time",
	"quote_asset_volume", "number_of_trades", "taker_buy_base_asset_volume",
	"taker_buy_quote_asset_volume", "ignore",
}

// fetchHistoryKline downloads and decodes the archive
// <SYMBOL>-<interval>-<date>.zip from the Binance history-data site.
//
// symbol is upper-cased before use, and every returned kline has Symbol and
// Interval set. It returns an error if the names are unsafe, the archive is
// missing or cannot be downloaded, or it is not a readable zip. Individual CSV
// members that fail to parse are logged and skipped.
func (repo *BinanceRepository) fetchHistoryKline(ctx context.Context, symbol string, interval string, date string) ([]*entity.Kline, error) {
	baseURL := fmt.Sprintf("%s%s", blockchain.BinanceHistoryDataBase, blockchain.BinanceHistoryDataKlines)
	symbol = strings.ToUpper(symbol)
	zipFile := fmt.Sprintf("%s-%s-%s.zip", symbol, interval, date)
	// symbol and interval come from the caller. Reject path separators so they
	// can alter neither the URL path nor the local temp-file location.
	if filepath.Base(zipFile) != zipFile {
		return nil, fmt.Errorf("invalid archive name %q", zipFile)
	}
	url := fmt.Sprintf("%s/%s/%s/%s", baseURL, symbol, interval, zipFile)
	if err := check(ctx, url); err != nil {
		return nil, err
	}

	tmpPath, err := downloadBinanceArchive(ctx, url, zipFile)
	if err != nil {
		return nil, err
	}
	defer os.Remove(tmpPath)

	return parseBinanceKlineArchive(tmpPath, symbol, interval)
}

// downloadBinanceArchive streams url into a uniquely named temp file and
// returns its path. The caller owns the file and must remove it.
func downloadBinanceArchive(ctx context.Context, url, name string) (path string, err error) {
	// The URL is assembled by fetchHistoryKline from binance.vision constants.
	// #nosec G107
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return "", err
	}
	resp, err := binanceArchiveClient.Do(req)
	if err != nil {
		return "", fmt.Errorf("failed to fetch file %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to fetch file %s: status %d", url, resp.StatusCode)
	}

	// A unique name keeps concurrent downloads of the same archive (e.g.
	// overlapping FetchHistoryKline calls) from clobbering each other.
	out, err := os.CreateTemp("", "binance-*-"+name)
	if err != nil {
		return "", err
	}
	defer func() {
		if err != nil {
			os.Remove(out.Name())
		}
	}()
	if _, err = io.Copy(out, resp.Body); err != nil {
		out.Close()
		return "", fmt.Errorf("download %s: %w", url, err)
	}
	if err = out.Close(); err != nil {
		return "", fmt.Errorf("write %s: %w", out.Name(), err)
	}
	return out.Name(), nil
}

// parseBinanceKlineArchive decodes every CSV member of the zip at path into
// klines tagged with symbol and interval. Members that fail are logged and
// skipped.
func parseBinanceKlineArchive(path, symbol, interval string) ([]*entity.Kline, error) {
	r, err := zip.OpenReader(path)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	var result []*entity.Kline
	for _, f := range r.File {
		klines, err := parseBinanceKlineCSV(f)
		if err != nil {
			logx.Errorf("skip %s in %s: %v", f.Name, path, err)
			continue
		}
		for _, k := range klines {
			k.Symbol = symbol
			k.Interval = interval
		}
		result = append(result, klines...)
	}
	return result, nil
}

// parseBinanceKlineCSV decodes one kline CSV member. It prepends
// binanceKlineCSVHeader and decodes the result with csvutil.
func parseBinanceKlineCSV(f *zip.File) ([]*entity.Kline, error) {
	rc, err := f.Open()
	if err != nil {
		return nil, err
	}
	defer rc.Close()

	width := len(binanceKlineCSVHeader)
	var buf bytes.Buffer
	writer := csv.NewWriter(&buf)
	if err := writer.Write(binanceKlineCSVHeader); err != nil {
		return nil, err
	}

	reader := csv.NewReader(rc)
	reader.FieldsPerRecord = -1 // row width is checked explicitly below
	for {
		record, err := reader.Read()
		if err == io.EOF {
			break
		}
		if err != nil {
			if _, ok := err.(*csv.ParseError); ok {
				continue // malformed row: skip it and keep reading
			}
			// I/O or decompression failure: retrying would repeat the error forever.
			return nil, err
		}
		if len(record) < width {
			continue
		}
		// Some archives carry their own header row, which would fail numeric decoding.
		if record[0] == binanceKlineCSVHeader[0] {
			continue
		}
		// Extra trailing columns would mismatch the header and fail the whole member.
		if err := writer.Write(record[:width]); err != nil {
			return nil, err
		}
	}
	writer.Flush()
	if err := writer.Error(); err != nil {
		return nil, err
	}

	var klines []*entity.Kline
	if err := csvutil.Unmarshal(buf.Bytes(), &klines); err != nil {
		return nil, err
	}
	return klines, nil
}
// check sends a HEAD request for url to confirm the archive exists before it
// is downloaded.
//
// It returns nil only for a 200 response. Any other status yields
// "file not found: <url>". A transport failure (DNS, timeout, cancelled ctx)
// is returned wrapped, so it is not mistaken for a missing file and
// errors.Is(err, context.Canceled) still works.
func check(ctx context.Context, url string) error {
	// In this file the URL is assembled by fetchHistoryKline from binance.vision
	// constants.
	// #nosec G107
	req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
	if err != nil {
		return err
	}
	// The client timeout bounds the probe even when ctx has no deadline. A nil
	// Transport means http.DefaultTransport, so connections are still pooled.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(req)
	if err != nil {
		return fmt.Errorf("check file %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("file not found: %s", url)
	}
	return nil
}