import asyncio
import base64
import csv
import functools
import hashlib
import hmac
import json  # noqa: F401  (unused in this module; kept for backward compatibility)
import logging
import logging.handlers
import random
import sys
import time
import traceback
import urllib.parse  # `import urllib` alone does not guarantee urllib.parse is loaded
import zlib
from datetime import datetime, timezone
from itertools import zip_longest

import aiohttp
import ujson
from numpy import subtract  # noqa: F401  (unused in this module; kept for backward compatibility)

import utils
import model


def timeit(func):
    """Decorator that prints how long each call to *func* took, in microseconds."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        res = func(*args, **kwargs)
        spend_time = round((time.perf_counter() - start) * 1e6, 3)
        print(f'{func.__name__} 耗时 {spend_time} us')
        return res

    return wrapper


# OKX: the price-limit channel should be subscribed, otherwise orders may fail.
# NOTE(review): subscribe_public() does not currently subscribe to "price-limit",
# so _update_price_limit is never triggered.
def empty_call(msg):
    """No-op placeholder for callbacks that have not been registered."""


def _as_float(value, default=0.0):
    """Convert a numeric string to float, mapping ``""``/``None`` to *default*.

    NOTE(review): OKX may push ``""`` for numeric fields that are not set yet
    (e.g. an average fill price before any fill); without this guard a single
    such push would raise and tear down the whole websocket connection.
    """
    if value in ('', None):
        return default
    return float(value)


# -- incremental local order-book maintenance ------------------------------

def _merge_book_side(updates, levels, descending):
    """Apply incremental *updates* to one side of the book, in place.

    Each level is a list whose first two items are the price and size strings.
    A size of "0" deletes the price level; an unknown price with a non-zero
    size is inserted (the update list object itself is stored); a known price
    gets its size replaced in the existing level object.
    The list object *levels* is kept (slice assignment) because callers hold
    references to it, and it is re-sorted by numeric price.
    """
    by_price = {}
    for level in levels:
        by_price.setdefault(level[0], level)
    for upd in updates:
        price, size = upd[0], upd[1]
        current = by_price.get(price)
        if current is None:
            if size != '0':
                by_price[price] = upd
        elif size == '0':
            del by_price[price]
        else:
            current[1] = size
    levels[:] = sorted(by_price.values(), key=lambda lv: sort_num(lv[0]), reverse=descending)
    return levels


def update_bids(depth, bids_p):
    """Apply ``depth['bids']`` to *bids_p* in place; return it sorted high-to-low."""
    return _merge_book_side(depth['bids'], bids_p, descending=True)


def update_asks(depth, asks_p):
    """Apply ``depth['asks']`` to *asks_p* in place; return it sorted low-to-high."""
    return _merge_book_side(depth['asks'], asks_p, descending=False)


def sort_num(n):
    """Parse numeric string *n* as ``int`` when it is all digits, else as ``float``."""
    if n.isdigit():
        return int(n)
    return float(n)


# @timeit
def check(bids, asks):
    """Compute the order-book checksum of the top 25 levels per side.

    The string ``bidPx:bidSz:askPx:askSz:...`` is built by interleaving bids
    and asks level by level (the longer side's leftover levels are appended
    after the shorter side runs out), CRC32-hashed, and returned as a signed
    32-bit integer.
    """
    fields = []
    for bid, ask in zip_longest(bids[:25], asks[:25]):
        if bid is not None:
            fields.append(':'.join(bid[0:2]))
        if ask is not None:
            fields.append(':'.join(ask[0:2]))
    return change(zlib.crc32(':'.join(fields).encode()))


def change(num_old):
    """Reinterpret an unsigned 32-bit value as a signed 32-bit integer."""
    if num_old > 0x7FFFFFFF:
        return num_old - (1 << 32)
    return num_old


# --------------------------------------------------------------------------

class OkexUsdtSwapWs:
    """OKX (API v5) websocket client for one USDT-margined swap instrument.

    Public pushes (order book, trades) and private pushes (orders, positions,
    account) are parsed and forwarded to the handlers in ``self.callback``.
    """

    def __init__(self, params: model.ClientParams, colo=0, is_print=0):
        if colo:
            print('不支持colo高速线路 请修改hosts')
        # HK endpoints. AWS alternatives:
        #   wss://wsaws.okx.com:8443/ws/v5/public
        #   wss://wsaws.okx.com:8443/ws/v5/private
        #   https://aws.okx.com
        self.URL_PUBLIC = 'wss://ws.okx.com:8443/ws/v5/public'
        self.URL_PRIVATE = 'wss://ws.okx.com:8443/ws/v5/private'
        self.REST = 'https://www.okx.com'
        self.params = params
        self.name = self.params.name
        pair_parts = self.params.pair.split('_')
        self.base = pair_parts[0].upper()
        self.quote = pair_parts[1].upper()
        self.symbol = f"{self.base}-{self.quote}-SWAP"
        if len(pair_parts) > 2:
            self.delivery = pair_parts[2]  # e.g. 210924
            self.symbol += f"-{self.delivery}"
        self.data = dict()
        self.data['trade'] = []
        self.data['force'] = []
        self.callback = {
            "onMarket": self.save_market,
            "onPosition": empty_call,
            "onEquity": empty_call,
            "onOrder": empty_call,
            "onTicker": empty_call,
            "onDepth": empty_call,
            "onExit": empty_call,
        }
        self.depth_update = []
        self.need_flash = 1
        self.updata_u = None
        self.last_update_id = None
        self.depth = []
        self.is_print = is_print
        self.proxy = None
        # NOTE(review): 'win' is also a substring of 'darwin', so the proxy is
        # enabled on macOS as well as Windows — confirm that is intended.
        if 'win' in sys.platform:
            self.proxy = self.params.proxy
        self.logger = self.get_logger()
        self.ticker_info = {"name": self.name, 'bp': 0.0, 'ap': 0.0}
        self.stop_flag = 0
        self.public_update_time = time.time()
        self.private_update_time = time.time()
        self.expired_time = 300
        self.update_t = 0
        # Trade statistics accumulated between two _get_data() calls.
        self.max_buy = 0.0
        self.min_sell = 0.0
        self.buy_v = 0.0
        self.buy_q = 0.0
        self.sell_v = 0.0
        self.sell_q = 0.0
        self.getheader = self.make_header()
        self.postheader = self.make_header()
        self._detail_ob = {}  # local book maintained from incremental depth pushes
        self.stepSize = None
        self.tickSize = None
        self.ctVal = None   # contract value (coins per contract)
        self.ctMult = None  # contract multiplier
        self.sub_trade = 0
        self.sub_fast = 0
        # Keep-alive: send text "ping" after this many silent seconds, then
        # give up and reconnect if nothing arrives within pong_timeout.
        self.ping_interval = 20
        self.pong_timeout = 10
        self._tasks = []  # strong references to background tasks started by run()
        # Bind outgoing connections to a specific local IP.
        iplist = utils.get_local_ip_list()
        self.ip = iplist[int(self.params.ip)]

    def save_market(self, msg):
        """Default ``onMarket`` handler: append ``msg['data']`` as one CSV row.

        Rows go to ``./history/<name>_<symbol>_<interval>_<date>.csv``.
        """
        date = time.strftime('%Y-%m-%d', time.localtime())
        interval = self.params.interval
        if not msg:
            return
        exchange = msg['name']
        if len(msg['data']) > 1:
            path = f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv'
            with open(path, 'a', newline='', encoding='utf-8') as f:
                csv.writer(f, delimiter=',').writerow(msg['data'])
            if self.is_print:
                print(f'写入行情 {self.symbol}')

    # ------------------------------------------------------
    # -- core ----------------------------------------------

    async def get_instruments(self):
        """Fetch instrument metadata for ``self.symbol`` over REST."""
        params = {'instId': self.symbol, 'instType': 'SWAP'}
        async with aiohttp.ClientSession() as session:
            response = await self.http_get_request_pub(session, '/api/v5/public/instruments', params)
            async with response:
                data = await response.text()
        return ujson.loads(data)

    def update_instruments(self, data):
        """Store lot size, tick size, contract value and multiplier for our symbol."""
        for info in data['data']:
            if info['instId'] == self.symbol:
                self.stepSize = sort_num(info['minSz'])
                self.tickSize = sort_num(info['tickSz'])
                self.ctVal = sort_num(info['ctVal'])
                self.ctMult = sort_num(info['ctMult'])
                return

    async def get_depth_flash(self):
        """Fetch a 20-level order-book snapshot over REST (signed request)."""
        params = {'instId': self.symbol, 'sz': 20}
        async with aiohttp.ClientSession() as session:
            response = await self.http_get_request(session, '/api/v5/market/books', params)
            async with response:
                depth_flash = await response.text()
        return ujson.loads(depth_flash)

    def _get_data(self):
        """Return the current depth plus trade high/low, then reset trade stats."""
        market_data = self.depth + [self.max_buy, self.min_sell]
        self.max_buy = 0.0
        self.min_sell = 0.0
        self.buy_v = 0.0
        self.buy_q = 0.0
        self.sell_v = 0.0
        self.sell_q = 0.0
        return {'name': self.name, 'data': market_data}

    async def go(self):
        """Every ``params.interval`` seconds, pass a market snapshot to ``onMarket``."""
        interval = float(self.params.interval)
        if self.is_print:
            print(f'Ws循环器启动 interval {interval}')
        while True:
            try:
                self.callback['onMarket'](self._get_data())
            except Exception:
                # Keep the loop alive on handler errors, but let cancellation through.
                traceback.print_exc()
            await asyncio.sleep(interval)

    def subscribe_private(self):
        """Build the subscribe request for account, positions and our orders."""
        subs = [
            {'channel': 'balance_and_position'},
            {'channel': 'account'},
            {'channel': 'orders', 'instType': "SWAP", 'instId': self.symbol},
        ]
        return ujson.dumps({'op': 'subscribe', 'args': subs})

    def subscribe_public(self):
        """Build the subscribe request for the depth channel (and trades if enabled).

        Other channels available: "tickers" (100 ms, slower than books50),
        "books", "books5", "books-l2-tbt", "books50-l2-tbt", "price-limit".
        """
        channels = ["books50-l2-tbt" if self.sub_fast else "books5"]
        if self.sub_trade:
            channels.append("trades")
        subs = [{'instId': self.symbol, 'channel': channel} for channel in channels]
        return ujson.dumps({'op': 'subscribe', 'args': subs})

    def _new_session(self):
        """Create a ClientSession whose connections are bound to ``self.ip``."""
        connector = aiohttp.TCPConnector(
            limit=50,
            keepalive_timeout=120,
            # NOTE(review): TLS certificate verification is disabled (as in the
            # original verify_ssl=False); the private channel carries API
            # credentials, so consider enabling it.
            ssl=False,
            local_addr=(self.ip, 0),
        )
        return aiohttp.ClientSession(connector=connector)

    async def _receive_text(self, _ws, label):
        """Wait for the next text frame and return its payload.

        After ``ping_interval`` silent seconds a text "ping" is sent; if still
        nothing arrives within ``pong_timeout`` seconds, or the socket closes
        or errors, ``None`` is returned so the caller reconnects.
        Non-text frames are skipped.
        """
        pinged = False
        while True:
            timeout = self.pong_timeout if pinged else self.ping_interval
            try:
                msg = await _ws.receive(timeout=timeout)
            except asyncio.TimeoutError:
                if not pinged:
                    pinged = True
                    await _ws.send_str('ping')
                    continue
                info = f'{self.name} {label} ws长时间没有收到消息 准备重连...'
                print(info)
                self.logger.error(info)
                return None
            if msg.type == aiohttp.WSMsgType.TEXT:
                return msg.data
            if msg.type in (aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING,
                            aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.ERROR):
                info = f'{self.name} {label} ws连接已关闭 准备重连...'
                print(info)
                self.logger.error(info)
                return None

    async def _login(self, _ws):
        """Send the login request; subscribe to private channels on success.

        Returns True when the server answers with a login event and code "0".
        """
        await _ws.send_str(self.login_params())
        reply = await _ws.receive(timeout=30)
        msg = reply
        if reply.type == aiohttp.WSMsgType.TEXT:
            msg = ujson.loads(reply.data)
            if msg.get('event') == 'login' and msg.get('code') == "0":
                print(f"{self.name} private login success.")
                self.logger.debug(f"{self.name} private login success.")
                await _ws.send_str(self.subscribe_private())
                return True
        print(f"{self.name} private login failed. --> {msg}")
        self.logger.debug(f"{self.name} private login failed. --> {msg}")
        return False

    async def run_private(self):
        """Keep the private websocket logged in and connected; reconnect on failure."""
        while True:
            try:
                self.private_update_time = time.time()
                print(f"{self.name} private 尝试连接ws")
                async with self._new_session() as session:
                    async with session.ws_connect(
                        self.URL_PRIVATE,
                        proxy=self.proxy,
                        timeout=30,
                        receive_timeout=30,
                    ) as _ws:
                        print(f"{self.name} private ws连接成功")
                        self.logger.debug(f"{self.name} private ws连接成功")
                        if not await self._login(_ws):
                            await asyncio.sleep(3)
                            continue
                        while True:
                            if self.stop_flag:  # shutdown signal
                                await _ws.close()
                                return
                            msg = await self._receive_text(_ws, 'private')
                            if msg is None:
                                break
                            await self.on_message_private(_ws, msg)
            except asyncio.CancelledError:
                raise
            except Exception:
                traceback.print_exc()
                print(f'{self.name} ws private 连接失败 开始重连...')
                self.logger.error(f'{self.name} ws private 连接失败 开始重连...')
                self.logger.error(traceback.format_exc())
                await asyncio.sleep(1)

    async def run_public(self):
        """Keep the public websocket subscribed and connected; reconnect on failure."""
        while True:
            try:
                self.public_update_time = time.time()
                print(f"{self.name} public 尝试连接ws")
                async with self._new_session() as session:
                    async with session.ws_connect(
                        self.URL_PUBLIC,
                        proxy=self.proxy,
                        timeout=30,
                        receive_timeout=30,
                    ) as _ws:
                        print(f"{self.name} public ws连接成功")
                        self.logger.debug(f"{self.name} public ws连接成功")
                        await _ws.send_str(self.subscribe_public())
                        while True:
                            if self.stop_flag:  # shutdown signal
                                await _ws.close()
                                return
                            msg = await self._receive_text(_ws, 'public')
                            if msg is None:
                                break
                            await self.on_message_public(_ws, msg)
            except asyncio.CancelledError:
                raise
            except Exception:
                traceback.print_exc()
                print(f'{self.name} ws public 连接失败 开始重连...')
                self.logger.error(f'{self.name} ws public 连接失败 开始重连...')
                self.logger.error(traceback.format_exc())
                await asyncio.sleep(1)

    def _update_ticker(self, msg):
        """Handle a tickers push: publish best bid/ask as ticker and 1-level depth."""
        self.public_update_time = time.time()
        ticker = ujson.loads(msg)['data'][0]
        t = int(ticker['ts'])
        if t <= self.update_t:
            return
        self.update_t = t
        bp = float(ticker['bidPx'])
        ap = float(ticker['askPx'])
        bq = float(ticker['bidSz'])
        aq = float(ticker['askSz'])
        self.ticker_info["bp"] = bp
        self.ticker_info["ap"] = ap
        self.callback['onTicker'](self.ticker_info)
        self.depth = [bp, bq, ap, aq]
        self.callback['onDepth']({'name': self.name, 'data': self.depth})

    def _update_trade(self, msg):
        """Accumulate trade high/low price and buy/sell quantity and notional."""
        self.public_update_time = time.time()
        msg = ujson.loads(msg)
        for trade in msg['data']:
            side = trade['side']
            amount = float(trade['sz']) * self.ctVal  # contracts -> coins
            price = float(trade['px'])
            if price > self.max_buy or self.max_buy == 0.0:
                self.max_buy = price
            if price < self.min_sell or self.min_sell == 0.0:
                self.min_sell = price
            if side == 'buy':
                self.buy_q += amount
                self.buy_v += amount * price
            elif side == 'sell':
                self.sell_q += amount
                self.sell_v += amount * price

    async def _update_depth(self, _ws, msg):
        """Handle an order-book push.

        books5 pushes (no "action") are used directly as top-of-book. Channels
        with "action" maintain the local book; after a (sampled) checksum pass
        and a newer timestamp the ticker and normalised depth are published.
        A checksum mismatch triggers a resubscription.
        """
        self.public_update_time = time.time()
        msg = ujson.loads(msg)
        depth = msg['data'][0]
        if "action" not in msg:
            # books5 has no action; only 5 levels, pushed every 200 ms.
            bp = float(depth['bids'][0][0])
            bv = float(depth['bids'][0][1])
            ap = float(depth['asks'][0][0])
            av = float(depth['asks'][0][1])
            self.ticker_info["bp"] = bp
            self.ticker_info["ap"] = ap
            self.callback['onTicker'](self.ticker_info)
            self.depth = [bp, bv, ap, av]
            self.callback['onDepth']({'name': self.name, 'data': self.depth})
            return

        action = msg['action']
        if action == 'update':
            self._update_depth_update(depth)
        elif action == 'snapshot':
            self._update_depth_snapshot(depth)
        ob = self._detail_ob
        if not self.compare_checksum(ob, depth):
            await self.resubscribe_depth(_ws)
            return
        t = int(depth['ts'])
        if t > self.update_t:
            self.update_t = t
            self.ticker_info["bp"] = float(ob['bids'][0][0])
            self.ticker_info["ap"] = float(ob['asks'][0][0])
            self.callback['onTicker'](self.ticker_info)
            self.depth = self._normalized_depth(ob)
            self.callback['onDepth']({'name': self.name, 'data': self.depth})

    def _normalized_depth(self, ob):
        """Resample the book onto ``utils.LEVEL`` evenly spaced price buckets per side.

        Bucket width is ``mid * utils.EFF_RANGE / utils.LEVEL``. Returns
        ``bid_prices + bid_volumes + ask_prices + ask_volumes``.
        """
        best_bid = self.ticker_info["bp"]
        best_ask = self.ticker_info["ap"]
        n_levels = utils.LEVEL
        mp = (best_bid + best_ask) * 0.5
        step = mp * utils.EFF_RANGE / n_levels
        bp = [best_bid - step * i for i in range(n_levels)]
        ap = [best_ask + step * i for i in range(n_levels)]
        bv = self._bucket_volumes(ob['bids'], best_bid, step, n_levels, descending=True)
        av = self._bucket_volumes(ob['asks'], best_ask, step, n_levels, descending=False)
        return bp + bv + ap + av

    @staticmethod
    def _bucket_volumes(levels, top, step, n_buckets, descending):
        """Sum level sizes into *n_buckets* buckets of width *step* starting at *top*.

        Bucket 0 runs from *top* to the first threshold; a level at or beyond
        a threshold moves to the next bucket. The loop skips as many buckets
        as needed, so sparse books with price gaps land in the right bucket.
        Levels past the last bucket are ignored.
        """
        volumes = [0] * n_buckets
        if n_buckets <= 0:
            return volumes
        direction = -1 if descending else 1
        threshold = top + direction * step
        index = 0
        for level in levels:
            price = float(level[0])
            amount = float(level[1])
            # price at or beyond threshold (below for bids, above for asks)
            while direction * (price - threshold) >= 0:
                threshold += direction * step
                index += 1
                if index >= n_buckets:
                    return volumes
            volumes[index] += amount
        return volumes

    async def resubscribe_depth(self, _ws):
        """Unsubscribe and resubscribe the depth channel to get a fresh snapshot."""
        info = f"{self.name} checksum not correct!"
        print(info)
        self.logger.info(info)
        channel = 'books50-l2-tbt' if self.sub_fast else 'books5'
        sub_str = {'op': "unsubscribe", 'args': [{'channel': channel, 'instId': self.symbol}]}
        await _ws.send_str(ujson.dumps(sub_str))
        await asyncio.sleep(1)
        sub_str['op'] = 'subscribe'
        await _ws.send_str(ujson.dumps(sub_str))

    def _update_depth_update(self, depth):
        """Merge an incremental depth push into the local book (in place)."""
        ob = self._detail_ob
        ob['timestamp'] = depth['ts']
        update_bids(depth, ob['bids'])
        update_asks(depth, ob['asks'])

    def _update_depth_snapshot(self, depth):
        """Replace the local book with a full snapshot push."""
        self._detail_ob = depth

    def _update_order(self, msg):
        """Translate order pushes for our symbol into ``onOrder`` events.

        live/partially_filled -> 'NEW'; canceled/filled -> 'REMOVE'; other
        states are printed and skipped. Filled size is converted to coins.
        """
        msg = ujson.loads(msg)
        self.logger.debug(f"ws订单推送 {msg}")
        for order in msg['data']:
            if order['instId'] != self.symbol:
                continue
            status = order['state']
            if status in ["live", 'partially_filled']:
                local_status = 'NEW'
            elif status in ['canceled', 'filled']:
                local_status = 'REMOVE'
            else:
                print(f'未知订单状态 {order}')
                continue
            order_event = {
                'status': local_status,
                'filled_price': _as_float(order['avgPx']),
                # USDT swaps are quoted in contracts; convert to coins
                'filled': _as_float(order['accFillSz']) * self.ctVal,
                'client_id': order['clOrdId'],
                'order_id': order['ordId'],
            }
            if order['feeCcy'] == 'USDT':
                order_event['fee'] = -_as_float(order['fee'])
            self.callback['onOrder'](order_event)
            self.private_update_time = time.time()

    def _update_balance_position(self, msg):
        """Handle a balance_and_position push (only positions are used)."""
        msg = ujson.loads(msg)['data'][0]
        self._update_position(msg['posData'])

    def _update_position(self, positions):
        """Aggregate long/short size (in coins) and average price for our symbol."""
        long_pos, short_pos = 0, 0
        long_avg, short_avg = 0, 0
        is_update = 0
        for item in positions:
            if item['instId'] != self.symbol:
                continue
            is_update = 1
            if item['posSide'] == 'long':
                long_pos += abs(_as_float(item['pos']) * self.ctVal)
                long_avg = abs(_as_float(item['avgPx']))
            elif item['posSide'] == 'short':
                short_pos += abs(_as_float(item['pos']) * self.ctVal)
                short_avg = abs(_as_float(item['avgPx']))
        if is_update:
            pos = model.Position()
            pos.longPos = long_pos
            pos.longAvg = long_avg
            pos.shortPos = short_pos
            pos.shortAvg = short_avg
            self.callback['onPosition'](pos)
            self.private_update_time = time.time()

    def _update_account(self, accounts):
        """Publish the equity of the quote currency from an account push."""
        accounts = ujson.loads(accounts)
        for data in accounts['data']:
            for detail in data['details']:
                if detail['ccy'] == self.quote:
                    self.data['equity'] = float(detail['eq'])
                    self.callback['onEquity']({self.quote: self.data['equity']})
        self.private_update_time = time.time()

    def _update_price_limit(self, accounts):
        """Call ``onExit`` when the mid price is within 1% of the exchange price limits."""
        accounts = ujson.loads(accounts)
        buy_limit = float(accounts['data'][0]['buyLmt'])
        sell_limit = float(accounts['data'][0]['sellLmt'])
        if self.ticker_info['bp'] > 0 and self.ticker_info['ap'] > 0:
            mp = (self.ticker_info['bp'] + self.ticker_info['ap']) * 0.5
            upper = buy_limit * 0.99
            lower = sell_limit * 1.01
            if mp > upper or mp < lower:
                self.callback['onExit'](f"{self.name} 触发限价警告 准备停机")
        self.private_update_time = time.time()

    async def on_message_private(self, _ws, msg):
        """Dispatch a raw private-channel message by substring match."""
        if msg == 'pong':  # reply to our keep-alive ping
            return
        if "data" in msg:
            # Data pushes carry a "data" field and take priority.
            if "orders" in msg:
                self._update_order(msg)
            elif "balance_and_position" in msg:
                self._update_balance_position(msg)
            elif "account" in msg:
                self._update_account(msg)
        elif "event" in msg:
            # Event replies are mostly informational; only errors are reported.
            if "error" in msg:
                info = f'{self.name} on_message error! --> {msg}'
                print(info)
                self.logger.error(info)
        elif 'ping' in msg:
            await _ws.send_str('pong')
        else:
            print(msg)

    async def on_message_public(self, _ws, msg):
        """Dispatch a raw public-channel message by substring match."""
        if msg == 'pong':  # reply to our keep-alive ping
            return
        if "data" in msg:
            # Data pushes carry a "data" field and take priority.
            if "tickers" in msg:
                self._update_ticker(msg)
            elif "trades" in msg:
                self._update_trade(msg)
            elif "books" in msg:
                await self._update_depth(_ws, msg)
            elif "price-limit" in msg:
                self._update_price_limit(msg)
        elif "event" in msg:
            # Event replies are mostly informational; only errors are reported.
            if "error" in msg:
                info = f'{self.name} on_message error! --> {msg}'
                print(info)
                self.logger.error(info)
        elif 'ping' in msg:
            await _ws.send_str('pong')
        else:
            print(msg)

    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
        """Load instrument info, start the websocket tasks and block forever."""
        self.sub_fast = sub_fast
        self.sub_trade = sub_trade
        info = await self.get_instruments()
        self.update_instruments(info)
        print(f"{self.name} public 更新产品信息 ws {self.symbol} {self.ctVal} {self.ctMult}")
        # Keep strong references: the event loop only holds weak refs to tasks.
        self._tasks = [asyncio.create_task(self.run_public())]
        if is_auth:
            self._tasks.append(asyncio.create_task(self.run_private()))
        while True:
            await asyncio.sleep(5)

    # ------------------------------------------------------
    # -- utils ---------------------------------------------

    @staticmethod
    def compare_checksum(ob, depth):
        """Check the local book against ``depth['checksum']``.

        To save CPU the check only runs on ~1 in 11 calls (random sample);
        otherwise the book is assumed correct.
        """
        if random.randint(0, 10) == 0:
            return check(ob["bids"], ob['asks']) == depth['checksum']
        return True

    def get_logger(self):
        """Return the module logger, attaching a rotating ``log.log`` handler once."""
        logger = logging.getLogger(__name__)
        logger.setLevel(logging.DEBUG)
        # The logger is shared by all instances; adding a handler per instance
        # would duplicate every log line.
        if not any(getattr(h, '_okex_ws_file_handler', False) for h in logger.handlers):
            formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
            # backupCount must be > 0 or RotatingFileHandler never rolls over.
            handler = logging.handlers.RotatingFileHandler(
                "log.log", maxBytes=1024 * 1024, backupCount=3, encoding='utf-8')
            handler.setLevel(logging.DEBUG)
            handler.setFormatter(formatter)
            handler._okex_ws_file_handler = True
            logger.addHandler(handler)
        return logger

    async def http_get_request(self, session, method, query=None, **args):
        """Send a signed GET request to ``REST + method`` and return the response."""
        params = dict()
        params.update(**args)
        if query is not None:
            params.update(query)
        timestamp = self.timestamp()
        if params:
            path = '{method}?{params}'.format(method=method, params=urllib.parse.urlencode(params))
        else:
            path = method
        sign = self.__sign(timestamp, 'GET', path, None)
        self.getheader['OK-ACCESS-SIGN'] = sign
        self.getheader['OK-ACCESS-TIMESTAMP'] = timestamp
        url = f"{self.REST}{path}"
        return await session.get(url, headers=self.getheader, timeout=5, proxy=self.proxy)

    async def http_get_request_pub(self, session, method, query=None, **args):
        """Send an unsigned GET request to ``REST + method`` and return the response."""
        params = dict()
        params.update(**args)
        if query is not None:
            params.update(query)
        if params:
            path = '{method}?{params}'.format(method=method, params=urllib.parse.urlencode(params))
        else:
            path = method
        headers = {'Content-Type': 'application/json'}
        url = f"{self.REST}{path}"
        return await session.get(url, headers=headers, timeout=5, proxy=self.proxy)

    def timestamp(self):
        """Current UTC time as ``YYYY-MM-DDTHH:MM:SS.mmmZ``.

        Milliseconds are formatted explicitly; slicing isoformat() breaks when
        the microsecond field is 0 (isoformat then omits the fraction).
        """
        now = datetime.now(timezone.utc)
        return now.strftime('%Y-%m-%dT%H:%M:%S.') + f'{now.microsecond // 1000:03d}Z'

    def __sign(self, timestamp, method, path, data):
        """Base64 HMAC-SHA256 of ``timestamp + method + path (+ body)``."""
        if data:
            message = timestamp + method + path + data
        else:
            message = timestamp + method + path
        digest = hmac.new(self.params.secret_key.encode('utf8'),
                          message.encode('utf8'),
                          digestmod=hashlib.sha256).digest()
        return base64.b64encode(digest).decode('utf-8')

    def login_params(self):
        """Build the websocket login request (signs ``ts + 'GET/users/self/verify'``)."""
        timestamp = str(time.time())
        message = timestamp + 'GET' + '/users/self/verify'
        mac = hmac.new(self.params.secret_key.encode('utf8'),
                       message.encode('utf8'),
                       digestmod=hashlib.sha256).digest()
        sign = base64.b64encode(mac)
        login_dict = {}
        login_dict['apiKey'] = self.params.access_key
        login_dict['passphrase'] = self.params.pass_key
        login_dict['timestamp'] = timestamp
        login_dict['sign'] = sign.decode('utf-8')
        login_param = {'op': 'login', 'args': [login_dict]}
        return ujson.dumps(login_param)

    def make_header(self):
        """REST header template; SIGN and TIMESTAMP are filled in per request."""
        headers = {}
        headers['Content-Type'] = 'application/json'
        headers['OK-ACCESS-KEY'] = self.params.access_key
        headers['OK-ACCESS-SIGN'] = None
        headers['OK-ACCESS-TIMESTAMP'] = None
        headers['OK-ACCESS-PASSPHRASE'] = self.params.pass_key
        return headers