data_processing.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import json
  2. import time
  3. import numpy as np
  4. import pandas as pd
  5. import queue
  6. import threading
  7. from collections import deque
  8. from logger_config import logger
  9. # 假设我们有一个数据流,订单簿和成交数据
  10. order_book_snapshots = deque(maxlen=100) # 存储过去100ms的订单簿快照
  11. trade_data = deque(maxlen=100) # 存储过去100ms的成交数据
  12. prediction_window = deque(maxlen=10) # 用于平滑预测结果的滑动窗口
  13. # 数据积累的阈值
  14. DATA_THRESHOLD = 20
  15. messages = queue.Queue() # 创建一个线程安全队列
  16. stop_event = threading.Event()
  17. def on_message_trade(_ws, message):
  18. global trade_data
  19. json_message = json.loads(message)
  20. trade = {
  21. 'price': float(json_message['data']['p']),
  22. 'qty': float(json_message['data']['q']),
  23. 'timestamp': pd.to_datetime(json_message['data']['T'], unit='ms'),
  24. 'side': 'sell' if json_message['data']['m'] else 'buy'
  25. }
  26. trade_data.append(trade)
  27. predict_market_direction()
  28. def on_message_depth(_ws, message):
  29. global order_book_snapshots
  30. json_message = json.loads(message)
  31. bids = json_message['data']['b'][:10] # Top 10 bids
  32. asks = json_message['data']['a'][:10] # Top 10 asks
  33. timestamp = pd.to_datetime(json_message['data']['E'], unit='ms')
  34. order_book_snapshots.append({
  35. 'bids': bids,
  36. 'asks': asks,
  37. 'timestamp': timestamp
  38. })
  39. predict_market_direction()
  40. def predict_market_direction():
  41. global prediction_window
  42. if len(trade_data) == 0:
  43. return
  44. # 统计过去100ms内的买卖交易数量
  45. buy_count = sum(trade['qty'] for trade in trade_data if trade['side'] == 'buy')
  46. sell_count = sum(trade['qty'] for trade in trade_data if trade['side'] == 'sell')
  47. # 简单的预测逻辑:买单多则预测上涨,卖单多则预测下跌
  48. prediction = 1 if buy_count > sell_count else 0
  49. # 将预测结果添加到滑动窗口中
  50. prediction_window.append(prediction)
  51. # 根据滑动窗口中的多数结果确定当前的预测方向
  52. if len(prediction_window) == prediction_window.maxlen:
  53. smoothed_prediction = 1 if sum(prediction_window) > len(prediction_window) / 2 else 0
  54. logger.info("Predicted Market Direction (smoothed): %s", "Up" if smoothed_prediction == 1 else "Down")
  55. show_message(smoothed_prediction)
  56. def show_message(market_direction):
  57. global order_book_snapshots, trade_data
  58. if len(order_book_snapshots) > 0 and len(trade_data) > 0:
  59. # 获取最新的订单簿数据和成交数据
  60. latest_order_book = order_book_snapshots[-1]
  61. latest_trade = trade_data[-1]
  62. # 提取asks和bids数据
  63. asks = [[float(price), 0] for price, qty in latest_order_book['asks']]
  64. bids = [[float(price), 0] for price, qty in latest_order_book['bids']]
  65. # 排序asks和bids数据
  66. asks_sorted = sorted(asks, key=lambda x: x[0])
  67. asks_sorted[-1] = [asks_sorted[-1][0], 1 if market_direction == 1 else 0]
  68. bids_sorted = sorted(bids, key=lambda x: x[0], reverse=True)
  69. bids_sorted[-1] = [bids_sorted[-1][0], 0 if market_direction == 1 else 1]
  70. last_price = latest_trade['price']
  71. # last_qty = latest_trade['qty']
  72. last_qty = 0
  73. side = latest_trade['side']
  74. # 生成数据字典
  75. data = {
  76. "asks": asks_sorted,
  77. "bids": bids_sorted,
  78. "last_price": last_price,
  79. "last_qty": last_qty,
  80. "side": side,
  81. "time": int(time.time() * 1000)
  82. }
  83. # 将数据放入消息队列
  84. messages.put(data)