gate_spot_ws.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. import aiohttp
  2. import time
  3. import asyncio
  4. import zlib
  5. import json, ujson
  6. import zlib
  7. import hmac, sys
  8. import base64, csv, random
  9. import traceback, hashlib
  10. import logging, logging.handlers
  11. import utils
  12. import model
  13. def inflate(data):
  14. '''
  15. 解压缩数据
  16. '''
  17. decompress = zlib.decompressobj(-zlib.MAX_WBITS)
  18. inflated = decompress.decompress(data)
  19. inflated += decompress.flush()
  20. return inflated
  21. def empty_call(msg):
  22. pass
  23. def get_sign(secret_key, message):
  24. h = (base64.b64encode(hmac.new(secret_key.encode('utf-8'), message.encode('utf-8'), hashlib.sha512).digest())).decode()
  25. return h
  26. class GateSpotWs:
  27. def __init__(self, params:model.ClientParams, colo=0, is_print=0):
  28. if colo:
  29. print('启用colo高速线路')
  30. self.URL = 'wss://spotws-private.gateapi.io/ws/v4/'
  31. else:
  32. self.URL = 'wss://api.gateio.ws/ws/v4/'
  33. self.params = params
  34. self.name = self.params.name
  35. self.base = self.params.pair.split('_')[0].upper()
  36. self.quote = self.params.pair.split('_')[1].upper()
  37. self.symbol = self.base + '_' + self.quote
  38. self.callback = {
  39. "onMarket":self.save_market,
  40. "onDepth":empty_call,
  41. "onPosition":empty_call,
  42. "onTicker":empty_call,
  43. "onDepth":empty_call,
  44. "onEquity":empty_call,
  45. "onOrder":empty_call,
  46. "onTrade":empty_call,
  47. "onExit":empty_call,
  48. }
  49. self.is_print = is_print
  50. self.proxy = None
  51. if 'win' in sys.platform:
  52. self.proxy = self.params.proxy
  53. self.logger = self.get_logger()
  54. self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
  55. self.stop_flag = 0
  56. self.max_buy = 0.0
  57. self.min_sell = 0.0
  58. self.buy_v = 0.0
  59. self.buy_q = 0.0
  60. self.sell_v = 0.0
  61. self.sell_q = 0.0
  62. self.update_t = 0.0
  63. self.depth = []
  64. #### 指定发包ip
  65. iplist = utils.get_local_ip_list()
  66. self.ip = iplist[int(self.params.ip)]
  67. def get_logger(self):
  68. logger = logging.getLogger(__name__)
  69. logger.setLevel(logging.DEBUG)
  70. # log to txt
  71. formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
  72. handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
  73. handler.setLevel(logging.DEBUG)
  74. handler.setFormatter(formatter)
  75. logger.addHandler(handler)
  76. return logger
  77. def save_market(self, msg):
  78. date = time.strftime('%Y-%m-%d',time.localtime())
  79. interval = self.params.interval
  80. if msg:
  81. exchange = msg['name']
  82. if len(msg['data']) > 1:
  83. with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
  84. 'a',
  85. newline='',
  86. encoding='utf-8') as f:
  87. writer = csv.writer(f, delimiter=',')
  88. writer.writerow(msg['data'])
  89. if self.is_print:print(f'写入行情 {self.symbol}')
  90. def _update_ticker(self, msg):
  91. self.ticker_info["bp"] = float(msg['highest_bid'])
  92. self.ticker_info["ap"] = float(msg['lowest_ask'])
  93. self.callback['onTicker'](self.ticker_info)
  94. def _update_depth(self, msg):
  95. if msg['t'] > self.update_t:
  96. self.update_t = msg['t']
  97. self.ticker_info["bp"] = float(msg['bids'][0][0])
  98. self.ticker_info["ap"] = float(msg['asks'][0][0])
  99. self.callback['onTicker'](self.ticker_info)
  100. ##### 标准化深度
  101. mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
  102. step = mp * utils.EFF_RANGE / utils.LEVEL
  103. bp = []
  104. ap = []
  105. bv = [0 for _ in range(utils.LEVEL)]
  106. av = [0 for _ in range(utils.LEVEL)]
  107. for i in range(utils.LEVEL):
  108. bp.append(self.ticker_info["bp"]-step*i)
  109. for i in range(utils.LEVEL):
  110. ap.append(self.ticker_info["ap"]+step*i)
  111. #
  112. price_thre = self.ticker_info["bp"] - step
  113. index = 0
  114. for bid in msg['bids']:
  115. price = float(bid[0])
  116. amount = float(bid[1])
  117. if price > price_thre:
  118. bv[index] += amount
  119. else:
  120. price_thre -= step
  121. index += 1
  122. if index == utils.LEVEL:
  123. break
  124. bv[index] += amount
  125. price_thre = self.ticker_info["ap"] + step
  126. index = 0
  127. for ask in msg['asks']:
  128. price = float(ask[0])
  129. amount = float(ask[1])
  130. if price < price_thre:
  131. av[index] += amount
  132. else:
  133. price_thre += step
  134. index += 1
  135. if index == utils.LEVEL:
  136. break
  137. av[index] += amount
  138. self.depth = bp + bv + ap + av
  139. self.callback['onDepth']({'name':self.name,'data':self.depth})
  140. else:
  141. print("收到过期depth")
  142. def _update_trade(self, msg):
  143. price = float(msg['price'])
  144. amount = float(msg['amount'])
  145. side = msg['side']
  146. if price > self.max_buy or self.max_buy == 0.0:
  147. self.max_buy = price
  148. if price < self.min_sell or self.min_sell == 0.0:
  149. self.min_sell = price
  150. if side == 'buy':
  151. self.buy_q += amount
  152. self.buy_v += amount*price
  153. elif side == 'sell':
  154. self.sell_q += amount
  155. self.sell_v += amount*price
  156. def _update_account(self, msg):
  157. for i in msg:
  158. if i['currency'].upper() == self.quote:
  159. cash = float(i['total'])
  160. self.callback['onEquity'] = {
  161. self.quote:cash
  162. }
  163. elif i['currency'].upper() == self.base:
  164. coin = float(i['total'])
  165. self.callback['onEquity'] = {
  166. self.base:coin
  167. }
  168. def _update_order(self, msg):
  169. self.logger.debug(f"ws订单推送 {msg}")
  170. for i in msg:
  171. if i['event'] == 'put':
  172. order_event = dict()
  173. order_event['filled'] = 0
  174. order_event['filled_price'] = 0
  175. order_event['client_id'] = i["text"]
  176. order_event['order_id'] = i['id']
  177. order_event['status'] = "NEW"
  178. self.callback['onOrder'](order_event)
  179. elif i['event'] == 'finish':
  180. order_event = dict()
  181. order_event['filled'] = float(i["amount"]) - float(i["left"])
  182. if order_event['filled'] > 0:
  183. order_event['filled_price'] = float(i["filled_total"])/order_event['filled']
  184. else:
  185. order_event['filled_price'] = 0
  186. order_event['client_id'] = i["text"]
  187. order_event['order_id'] = i['id']
  188. order_event['fee'] = float(i["fee"])
  189. order_event['status'] = "REMOVE"
  190. self.callback['onOrder'](order_event)
  191. # 根据成交信息更新仓位信息 因为账户信息推送有延迟
  192. # 但订单信息和账户信息到达先后时间可能有前有后 可能平仓 账户先置零仓位 然后sell成交达到 导致仓位变成负数
  193. def _update_usertrade(self, msg):
  194. '''暂时不用'''
  195. pass
  196. def _update_position(self, msg):
  197. long_pos, short_pos = 0, 0
  198. long_avg, short_avg = 0, 0
  199. for i in msg[0]['holding']:
  200. if i['side'] == 'long':
  201. long_pos += float(i['position'])
  202. long_avg = float(i['avg_cost'])
  203. if i['side'] == 'short':
  204. short_pos += float(i['position'])
  205. short_avg = float(i['avg_cost'])
  206. pos = model.Position()
  207. pos.longPos = long_pos
  208. pos.longAvg = long_avg
  209. pos.shortPos = short_pos
  210. pos.shortAvg = short_avg
  211. self.callback['onPosition'](pos)
  212. def _get_data(self):
  213. market_data = self.depth + [self.max_buy, self.min_sell]
  214. self.max_buy = 0.0
  215. self.min_sell = 0.0
  216. self.buy_v = 0.0
  217. self.buy_q = 0.0
  218. self.sell_v = 0.0
  219. self.sell_q = 0.0
  220. return {'name': self.name,'data':market_data}
  221. async def go(self):
  222. interval = float(self.params.interval)
  223. if self.is_print:print(f'Ws循环器启动 interval {interval}')
  224. ### onTrade
  225. while 1:
  226. try:
  227. # 更新市场信息
  228. market_data = self._get_data()
  229. self.callback['onMarket'](market_data)
  230. except:
  231. traceback.print_exc()
  232. await asyncio.sleep(interval)
  233. def get_sign(self, message):
  234. h = hmac.new(self.params.secret_key.encode("utf8"), message.encode("utf8"), hashlib.sha512)
  235. return h.hexdigest()
  236. async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
  237. while True:
  238. try:
  239. ping_time = time.time()
  240. # 尝试连接
  241. print(f'{self.name} 尝试连接ws')
  242. ws_url = self.URL
  243. async with aiohttp.ClientSession(
  244. connector = aiohttp.TCPConnector(
  245. limit=50,
  246. keepalive_timeout=120,
  247. verify_ssl=False,
  248. local_addr=(self.ip,0)
  249. )
  250. ).ws_connect(
  251. ws_url,
  252. proxy=self.proxy,
  253. timeout=30,
  254. receive_timeout=30,
  255. ) as _ws:
  256. self.is_print:print(f'{self.name} ws连接成功')
  257. # 登陆
  258. if is_auth:
  259. # userorders
  260. current_time = int(time.time())
  261. channel = "spot.orders"
  262. sub_str = {
  263. "time": current_time,
  264. "channel": channel,
  265. "event": "subscribe",
  266. "payload": [self.symbol]
  267. }
  268. message = 'channel=%s&event=%s&time=%d' % (channel, "subscribe", current_time)
  269. sub_str["auth"] = {
  270. "method": "api_key",
  271. "KEY": self.params.access_key,
  272. "SIGN": self.get_sign(message)}
  273. await _ws.send_str(ujson.dumps(sub_str))
  274. # usertrades
  275. current_time = int(time.time())
  276. channel = "spot.usertrades"
  277. sub_str = {
  278. "time": current_time,
  279. "channel": channel,
  280. "event": "subscribe",
  281. "payload": [self.symbol]
  282. }
  283. message = 'channel=%s&event=%s&time=%d' % (channel, "subscribe", current_time)
  284. sub_str["auth"] = {
  285. "method": "api_key",
  286. "KEY": self.params.access_key,
  287. "SIGN": self.get_sign(message)}
  288. await _ws.send_str(ujson.dumps(sub_str))
  289. # balance
  290. current_time = int(time.time())
  291. channel = "spot.balances"
  292. sub_str = {
  293. "time": current_time,
  294. "channel": channel,
  295. "event": "subscribe",
  296. "payload": [self.symbol]
  297. }
  298. message = 'channel=%s&event=%s&time=%d' % (channel, "subscribe", current_time)
  299. sub_str["auth"] = {
  300. "method": "api_key",
  301. "KEY": self.params.access_key,
  302. "SIGN": self.get_sign(message)}
  303. await _ws.send_str(ujson.dumps(sub_str))
  304. if sub_trade:
  305. # public trade
  306. current_time = int(time.time())
  307. channel = "spot.trades"
  308. sub_str = {
  309. "time": current_time,
  310. "channel": channel,
  311. "event": "subscribe",
  312. "payload": [self.symbol]
  313. }
  314. await _ws.send_str(ujson.dumps(sub_str))
  315. # 订阅public
  316. # tickers 太慢了
  317. # current_time = int(time.time())
  318. # channel = "spot.tickers"
  319. # sub_str = {
  320. # "time": current_time,
  321. # "channel": channel,
  322. # "event": "subscribe",
  323. # "payload": [self.symbol]
  324. # }
  325. # await _ws.send_str(ujson.dumps(sub_str))
  326. # depth
  327. current_time = int(time.time())
  328. channel = "spot.order_book"
  329. sub_str = {
  330. "time": current_time,
  331. "channel": channel,
  332. "event": "subscribe",
  333. "payload": [self.symbol,"20","100ms"]
  334. }
  335. await _ws.send_str(ujson.dumps(sub_str))
  336. while True:
  337. # 停机信号
  338. if self.stop_flag:
  339. await _ws.close()
  340. return
  341. # 接受消息
  342. try:
  343. msg = await _ws.receive(timeout=10)
  344. except:
  345. print(f'{self.name} ws长时间没有收到消息 准备重连...')
  346. self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
  347. break
  348. msg = msg[1]
  349. # 处理消息
  350. if 'update' in msg:
  351. msg = ujson.loads(msg)
  352. # if msg['channel'] == 'spot.tickers':self._update_ticker(msg['result'])
  353. if msg['channel'] == 'spot.order_book':self._update_depth(msg['result'])
  354. elif msg['channel'] == 'spot.balances':self._update_account(msg['result'])
  355. elif msg['channel'] == 'spot.orders':self._update_order(msg['result'])
  356. # if msg['channel'] == 'spot.usertrades':self._update_usertrade(msg['result'])
  357. elif msg['channel'] == 'spot.trades':self._update_trade(msg['result'])
  358. else:
  359. # print(msg)
  360. pass
  361. # pong
  362. if time.time() - ping_time > 5:
  363. await _ws.send_str('{"time": %d, "channel" : "spot.ping"}' % int(time.time()))
  364. ping_time = time.time()
  365. except:
  366. traceback.print_exc()
  367. print(f'{self.name} ws连接失败 开始重连...')
  368. self.logger.error(f'{self.name} ws连接失败 开始重连...')
  369. # await asyncio.sleep(1)