Ver Fonte

初始化

skyfffire há 2 anos atrás
commit
27c97201a4
49 ficheiros alterados com 20975 adições e 0 exclusões
  1. 2 0
      .gitignore
  2. 379 0
      backtest.py
  3. 636 0
      bak/center.py
  4. 487 0
      bak/dummy.py
  5. 138 0
      bak/final.py
  6. 118 0
      bak/seele.py
  7. 108 0
      broker.py
  8. 26 0
      config.toml
  9. 0 0
      exchange/__init__.py
  10. 540 0
      exchange/binance_coin_swap_rest.py
  11. 345 0
      exchange/binance_coin_swap_ws.py
  12. 573 0
      exchange/binance_spot_rest.py
  13. 521 0
      exchange/binance_spot_ws.py
  14. 592 0
      exchange/binance_usdt_swap_rest.py
  15. 562 0
      exchange/binance_usdt_swap_ws.py
  16. 588 0
      exchange/bitget_usdt_swap_rest.py
  17. 422 0
      exchange/bitget_usdt_swap_ws.py
  18. 658 0
      exchange/bybit_usdt_swap_rest.py
  19. 544 0
      exchange/bybit_usdt_swap_ws.py
  20. 602 0
      exchange/coinex_spot_rest.py
  21. 350 0
      exchange/coinex_spot_ws.py
  22. 542 0
      exchange/coinex_usdt_swap_rest.py
  23. 331 0
      exchange/coinex_usdt_swap_ws.py
  24. 369 0
      exchange/ftx_spot_ws.py
  25. 364 0
      exchange/ftx_usdt_swap_ws.py
  26. 528 0
      exchange/gate_spot_rest.py
  27. 389 0
      exchange/gate_spot_ws.py
  28. 525 0
      exchange/gate_usdt_swap_rest.py
  29. 464 0
      exchange/gate_usdt_swap_ws.py
  30. 309 0
      exchange/huobi_spot_ws.py
  31. 393 0
      exchange/huobi_usdt_swap_rest.py
  32. 407 0
      exchange/huobi_usdt_swap_ws.py
  33. 531 0
      exchange/kucoin_spot_rest.py
  34. 390 0
      exchange/kucoin_spot_ws.py
  35. 537 0
      exchange/kucoin_usdt_swap_rest.py
  36. 386 0
      exchange/kucoin_usdt_swap_ws.py
  37. 471 0
      exchange/mexc_spot_rest.py
  38. 340 0
      exchange/mexc_spot_ws.py
  39. 98 0
      exchange/model.py
  40. 712 0
      exchange/okex_usdt_swap_rest.py
  41. 811 0
      exchange/okex_usdt_swap_ws.py
  42. 18 0
      exchange/readme.md
  43. 520 0
      exchange/utils.py
  44. 99 0
      model.py
  45. 131 0
      predictor.py
  46. 1473 0
      quant.py
  47. 38 0
      readme.txt
  48. 1080 0
      strategy.py
  49. 528 0
      utils.py

+ 2 - 0
.gitignore

@@ -0,0 +1,2 @@
+.idea
+/venv

+ 379 - 0
backtest.py

@@ -0,0 +1,379 @@
+'''
+    基于深度信息的异步事件触发回测框架
+    作者 千千量化
+'''
+import os
+import sys
+import asyncio
+import json, ujson
+import requests
+import hmac
+import base64
+import zlib
+import datetime
+import time
+import strategy as strategy
+import traceback
+import joblib
+import traceback
+import configparser
+import signal
+import random
+import utils
+import model
+import time
+import copy
+import random
+
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        nowTime = time.time()
+        func(*args, **kwargs)
+        spend_time = time.time() - nowTime
+        spend_time = round(spend_time * 100, 5)
+        print(f'{func.__name__} 耗时 {spend_time} ms')
+
+    return wrapper
+
+class Backtest:
+    '''
+        回测框架
+    '''
+
+    def __init__(self, strategy:strategy.Strategy, data=None, is_plot=0):
+        '''
+            初始化实盘类
+        '''
+        # 获取策略对象
+        self.strategy = strategy
+        # 历史数据
+        self.data = data
+        # 交易品种
+        self.symbol = 'backtest_symbol'
+        # 虚拟挂单
+        self.localOrders = dict()
+        # 虚拟仓位
+        self.pos = model.Position()
+        # 中间价
+        self.mp = None
+        # 行情数据
+        self.reference_price = []
+        self.last_marketdata = []
+        # 费率 maker taker
+        self.fee = model.BacktestFee("v9")
+        # 是否画图
+        self._is_plot = is_plot
+        self.long_hold = []
+        self.short_hold = []
+        self.index = []
+        # 虚拟成交信息
+        self.trade_num = 0
+        # 记录持仓时间
+        self.kd_time = None
+        self.kk_time = None
+        self.hold_times = [0.0]
+        # 权益
+        self.start_cash = 1000000.0
+        self.equity = self.start_cash
+        self.equity_real = self.start_cash
+        self.equity_high = self.start_cash
+        self.balance = [self.start_cash]
+        self.strategy.start_cash = self.start_cash
+        self.strategy.local_start_time = 0.0
+        # 成交时间轴
+        self.priceMotion = []
+        self.kd_trade = []
+        self.kk_trade = []
+        self.pd_trade = []
+        self.pk_trade = []
+        # 挂单时间轴
+        self.localOrders_timeseries = []
+        # min max fill
+        self.fill_price = []
+        # 上次信号
+        self.signal_series = dict()
+        # 回测延迟 默认值
+        self.backtest_delay = utils.BACKTEST_DELAY
+        # 回测时间
+        self.backtest_time = time.time()
+        # 平均持仓时间
+        self.avg_hold_time = 0.0
+        # 最近一次处理的信号时间戳
+        self.last_signal_time = 0.0
+        
+    def get_available_close_amt(self):
+        a_pd, a_pk = self.pos.longPos, self.pos.shortPos
+        for i in self.localOrders:
+            if self.localOrders[i]["side"] == 'pd':
+                a_pd -= self.localOrders[i]["amount"]
+            if self.localOrders[i]["side"] == 'pk':
+                a_pk -= self.localOrders[i]["amount"]
+        return a_pd, a_pk
+
+    # @timeit
+    # 处理策略信号
+    def handle_signal(self, signals):
+        '''
+            当信号>1tick 且时间超过n ms 才允许进行处理
+        '''
+        keys = list(signals.keys())
+        for signal_time in keys:
+            # 只允许处理上一tick时间戳之前的信号
+            if signal_time < self.last_signal_time:
+                pass
+            else:
+                continue
+            if self.backtest_time - signal_time >= self.backtest_delay:
+                signal = signals[signal_time]
+                if signal == None:
+                    return
+                for i in signal:
+                    # 撤销虚拟订单
+                    if 'Cancel' in i:
+                        if signal[i][0] in self.localOrders:
+                            del(self.localOrders[signal[i][0]])
+                    # 执行虚拟下单
+                    elif 'Limits_open' in i:
+                        for j in signal[i]:
+                            order_event = dict()
+                            order_event['symbol'] = self.symbol
+                            order_event['status'] = "NEW"
+                            order_event['amount'] = float(j[0])
+                            order_event['side'] = j[1]
+                            order_event['price'] = float(j[2])
+                            order_event['filled_price'] = 0
+                            order_event['filled'] = 0
+                            order_event['client_id'] = str(random.randint(1,99999))
+                            order_event['order_id'] = str(random.randint(1,99999))
+                            order_event['localtime'] = signal_time
+                            order_event['createtime'] = signal_time
+                            self.localOrders[order_event['client_id']] = order_event
+                    elif 'Limits_close' in i:
+                        a_pd, a_pk = self.get_available_close_amt()
+                        for j in signal[i]:
+                            if j[1]  == "pd":
+                                if j[0] > a_pd:
+                                    break
+                            elif j[1]  == "pk":
+                                if j[0] > a_pk:
+                                    break
+                            order_event = dict()
+                            order_event['symbol'] = self.symbol
+                            order_event['status'] = "NEW"
+                            order_event['amount'] = float(j[0])
+                            order_event['side'] = j[1]
+                            order_event['price'] = float(j[2])
+                            order_event['filled_price'] = 0
+                            order_event['filled'] = 0
+                            order_event['client_id'] = str(random.randint(1,99999))
+                            order_event['order_id'] = str(random.randint(1,99999))
+                            order_event['localtime'] = signal_time
+                            order_event['createtime'] = signal_time
+                            self.localOrders[order_event['client_id']] = order_event
+                del(signals[signal_time])
+    
+    # @timeit
+    def matching(self, data):
+        max_fill = 0 if data[utils.MAX_FILL_INDEX] == 0 else data[utils.MAX_FILL_INDEX]            # 最高成交价
+        min_fill = 9999999999 if data[utils.MIN_FILL_INDEX] == 0 else data[utils.MIN_FILL_INDEX]  # 最低成交价
+        # 本tick产生的pnl
+        pnl = 0.0
+        # 保存成交时间轴信息
+        if self._is_plot:
+            self.kd_trade.append(0)
+            self.kk_trade.append(0)
+            self.pd_trade.append(0)
+            self.pk_trade.append(0)
+        # 检查订单是否符合撮合条件
+        max_kd = 0
+        max_pk = 0
+        min_kk = 9999999999
+        min_pd = 9999999999
+        # 本ticker
+        now_bp = data[utils.BP_INDEX]
+        now_ap = data[utils.AP_INDEX]
+        self.mp = (now_ap + now_bp)*0.5
+        match_bp = now_bp
+        match_ap = now_ap
+        if self.last_marketdata != []:
+            # 前ticker
+            bp = self.last_marketdata[utils.BP_INDEX]
+            ap = self.last_marketdata[utils.AP_INDEX]
+            localOrdersCid = list(self.localOrders)
+            # 初始化成交标记
+            filled_flag = 0
+            filled_fee = 0.0
+            filled_price = 0.0
+            # 检查所有挂单是否被成交
+            for cid in localOrdersCid:
+                # 为每个订单重置成交标记
+                filled_flag = 0
+                filled_fee = 0.0
+                filled_price = 0.0
+                # 获取订单
+                i = self.localOrders[cid]
+                # 买单
+                if i["side"] == "kd":
+                    if self._is_plot: max_kd = max(max_kd,i['price'])
+                    # 判断吃单成交还是挂单成交
+                    if i['price'] > ap: # 吃单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.taker
+                        filled_price = ap
+                    elif i["price"] > match_ap: # 挂单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.maker
+                        filled_price = i["price"]
+                    elif i["price"] > min_fill: # 一定概率被挂单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.maker
+                        filled_price = i["price"]
+                    if filled_flag:
+                        del(self.localOrders[cid])
+                        self.trade_num += 1
+                        value = i['amount'] * filled_price
+                        self.pos.longAvg = (self.pos.longPos * self.pos.longAvg + value)/(self.pos.longPos + i['amount'])
+                        self.pos.longPos += i['amount']
+                        pnl -= filled_fee * value
+                        self.kd_time = self.strategy.local_time
+                        if self._is_plot:
+                            self.kd_trade[-1] = filled_price
+                elif i["side"] == "pk":
+                    if self._is_plot: max_pk = max(max_pk,i['price'])
+                    # 判断吃单成交还是挂单成交
+                    if i["price"] > ap: # 吃单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.taker
+                        filled_price = ap
+                    elif i["price"] > match_ap: # 挂单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.maker
+                        filled_price = i["price"]
+                    elif i["price"] > min_fill: # 一定概率被挂单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.maker
+                        filled_price = i["price"]
+                    if filled_flag:
+                        del(self.localOrders[cid])
+                        self.trade_num += 1
+                        pct = (self.pos.shortAvg - filled_price)/self.pos.shortAvg
+                        value = i['amount'] * filled_price
+                        pnl += pct * value - filled_fee * value
+                        self.pos.shortPos -= i['amount']
+                        if self._is_plot:
+                            self.pk_trade[-1] = filled_price
+                            self.hold_times.append(self.strategy.local_time - self.kk_time)
+                        else:
+                            if self.avg_hold_time == 0.0:
+                                self.avg_hold_time = self.strategy.local_time - self.kk_time
+                            else:
+                                self.avg_hold_time = self.avg_hold_time * 0.9 + (self.strategy.local_time - self.kk_time) * 0.1
+                # 卖单
+                elif i["side"] == "kk":
+                    if self._is_plot: min_kk = min(min_kk,i['price'])
+                    # 判断吃单成交还是挂单成交
+                    if i["price"] < bp: # 吃单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.taker
+                        filled_price = i["price"]
+                    elif i["price"] < match_bp: # 挂单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.maker
+                        filled_price = i["price"]
+                    elif i["price"] < max_fill: # 一定概率挂单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.maker
+                        filled_price = i["price"]
+                    if filled_flag:
+                        del(self.localOrders[cid])
+                        self.trade_num += 1
+                        value = filled_price * i['amount']
+                        self.pos.shortAvg = (self.pos.shortAvg*self.pos.shortPos + value)/(self.pos.shortPos + i['amount'])
+                        self.pos.shortPos += i['amount']
+                        pnl -= filled_fee * value
+                        self.kk_time = self.strategy.local_time
+                        if self._is_plot:
+                            self.kk_trade[-1] = filled_price
+                elif i["side"] == "pd":
+                    if self._is_plot: min_pd = min(min_pd,i['price'])
+                    # 判断吃单成交还是挂单成交
+                    if i["price"] < bp: # 吃单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.taker
+                        filled_price = i["price"]
+                    elif i["price"] < match_bp: # 挂单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.maker
+                        filled_price = i["price"]
+                    elif i["price"] < max_fill: # 一定概率挂单成交
+                        filled_flag = 1
+                        filled_fee = self.fee.maker
+                        filled_price = i["price"]
+                    if filled_flag:
+                        del(self.localOrders[cid])
+                        self.trade_num += 1
+                        pct = (filled_price - self.pos.longAvg)/self.pos.longAvg
+                        value = i['amount'] * i['price']
+                        pnl += pct * value - filled_fee * value
+                        self.pos.longPos -= i['amount']
+                        if self._is_plot:
+                            self.pd_trade[-1] = filled_price
+                            self.hold_times.append(self.strategy.local_time - self.kd_time)
+                        else:
+                            if self.avg_hold_time == 0.0:
+                                self.avg_hold_time = self.strategy.local_time - self.kd_time
+                            else:
+                                self.avg_hold_time = self.avg_hold_time * 0.9 + (self.strategy.local_time - self.kd_time) * 0.1
+            # 当前浮动盈亏
+            unrealized_pnl = 0 
+            if self.pos.longPos > 0:
+                unrealized_pnl += (self.mp - self.pos.longAvg)/self.pos.longAvg * self.pos.longPos * self.mp
+            if self.pos.shortPos > 0:
+                unrealized_pnl += (self.pos.shortAvg - self.mp)/self.pos.shortAvg * self.pos.shortPos * self.mp
+            # 如果需要画图
+            if self._is_plot:
+                # 保存挂单信息时间轴
+                max_kd = 0 if max_kd == 0 else max_kd
+                max_pk = 0 if max_pk == 0 else max_pk
+                min_kk = 0 if min_kk == 9999999999 else min_kk
+                min_pd = 0 if min_pd == 9999999999 else min_pd
+                self.localOrders_timeseries.append([max_kd,max_pk,min_kk,min_pd])
+                self.priceMotion.append(self.strategy.mp)
+                # 参考价格
+                self.reference_price.append(self.strategy.ref_price)
+                # 成交情况
+                self.fill_price.append([max_fill,min_fill])
+                # 持仓情况
+                self.long_hold.append(self.pos.longPos)
+                self.short_hold.append(self.pos.shortPos)
+                # 净值
+                self.index.append(self.strategy.equity)
+                # 记录浮动净值
+                self.balance.append(self.equity + pnl + unrealized_pnl)
+            # 记录净值 包含未实现盈亏
+            self.equity_real = self.equity + pnl + unrealized_pnl
+            # 记录净值 已实现盈亏
+            self.equity = self.equity + pnl
+            # 记录最高净值记录
+            self.equity_high = max(self.equity_high, self.equity)
+        # 结束本次回测
+        # 更新上一次的行情记录
+        self.last_marketdata = data
+
+    # @timeit
+    def run_by_tick(self, tradeMsg:model.TraderMsg):
+        # 更新策略本地时间
+        self.strategy.local_time = self.backtest_time
+        # 执行信号
+        self.handle_signal(self.signal_series)
+        # 用最新行情进行撮合
+        self.matching(tradeMsg.market)
+        # 更新账户信息
+        tradeMsg.cash = self.equity
+        tradeMsg.orders = self.localOrders
+        tradeMsg.position = self.pos
+        # 产生本tick信号
+        signal = self.strategy.onTime(tradeMsg)
+        self.signal_series[self.backtest_time] = signal
+        self.last_signal_time = self.backtest_time
+        

+ 636 - 0
bak/center.py

@@ -0,0 +1,636 @@
+import asyncio
+from http.client import NON_AUTHORITATIVE_INFORMATION
+from aiohttp import web
+import traceback
+import time
+import utils
+import logging, logging.handlers
+import signal
+import broker
+import os, json, sys, random
+
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        nowTime = time.time()
+        res = func(*args, **kwargs)
+        spend_time = time.time() - nowTime
+        spend_time = round(spend_time * 100, 2)
+        print(f'{func.__name__} 耗时 {spend_time} ms')
+        return res
+    return wrapper
+
+def takeSecond(elem):
+    return elem[1]
+
+import ccxt.async_support as ccxt
+
+ex_list = [
+        ccxt.binanceusdm(),
+        ccxt.binance(),
+        # ccxt.okex5(),
+        # ccxt.kucoin(),
+        # ccxt.gateio(),
+        # ccxt.coinex(),
+    ]
+
+class Center:
+
+    def __init__(self, fname, logname=None):
+        print('###############   参数中心   ################')
+        print(f'>> {utils.VERSION} <<<')
+        print('*** 当前配置 ***')
+        self.fname = fname
+        self.params = utils.get_params(fname)
+        for p in self.params.__dict__:
+            print('***', p, ' => ', getattr(self.params, p))
+        print('##################################################')
+        pid = os.getpid()
+        print(f'交易程序正在启动 进程号{pid}...')
+        self.logger = self.get_logger(logname)
+        self.params_base = dict()
+        for i in broker.exchange_lists:self.params_base[i] = dict()
+        self.params_dummy = dict()
+        for i in broker.exchange_lists:self.params_dummy[i] = dict()
+        self.params_real = dict()
+        for i in broker.exchange_lists:self.params_real[i] = dict()
+        self.win_dict = dict()
+        for i in broker.exchange_lists:self.win_dict[i] = dict()
+        self.loss_dict = dict()
+        for i in broker.exchange_lists:self.loss_dict[i] = dict()
+        self.choose_dict = dict()
+        for i in broker.exchange_lists:self.choose_dict[i] = dict()
+        self.dummy_choose_dict = dict()
+        for i in broker.exchange_lists:self.dummy_choose_dict[i] = dict()
+        try:
+            with open('params_real.json','r') as f:
+                self.params_real = json.load(f)
+        except:
+            pass
+        try:
+            with open('choose_dict.json','r') as f:
+                self.choose_dict = json.load(f)
+        except:
+            pass
+        try:
+            with open('dummy_choose_dict.json','r') as f:
+                self.dummy_choose_dict = json.load(f)
+        except:
+            pass
+        try:
+            with open('params_dummy.json','r') as f:
+                self.params_dummy = json.load(f)
+        except:
+            pass
+        self.loop = asyncio.get_event_loop()
+        ###
+        self.market = dict()
+        self.score = dict()
+        self.info_msg = "加载中..."
+
+    def get_logger(self, logname):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        if logname == None: logname = "log"
+        handler = logging.handlers.RotatingFileHandler(f"{logname}.log",maxBytes=1024*1024*50,backupCount=10,encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.INFO)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        logger.info('开启日志记录')
+        return logger
+
+    async def exit(self, delay=0):
+        '''退出操作'''
+        print(f"开始退出操作 delay{delay}")
+        if delay > 0:
+            await asyncio.sleep(delay)
+        self.logger.info(f'停机退出')
+        await asyncio.sleep(1)
+        print("停机...")
+        # self.loop.create_task(utils.ding(f"参数中心程序停止", 1, self.params.webhook, self.params.proxy))
+        self.loop.stop()
+        os._exit(0) 
+
+    async def read_market(self):
+
+        market = dict()
+        
+        print("开始查询行情")
+
+        for exchange in ex_list:
+        
+            if (exchange.has['fetchTickers']):
+
+                tickers = await exchange.fetch_tickers()
+
+                for i in tickers:
+
+                    symbol = tickers[i]['symbol']
+                    if '/USDT' not in symbol:continue
+                    pair = symbol.split('/')[0].lower() + '_' + symbol.split('/')[1].lower()
+                    last = tickers[i]['last']
+
+                    if last > 0:
+                        if pair not in market:
+                            market[pair] = last
+        return market
+
+
+    async def update_score(self):
+        '''
+            按最近涨跌幅打分
+        '''
+        await asyncio.sleep(5)
+        while 1:
+            try:
+                market = await self.read_market()
+                self.score = dict()
+                for i in self.market:
+                    if i in market:
+                        score = abs(market[i] - self.market[i])/market[i]
+                        if score > 0:
+                            self.score[i] = score
+                score_list = []
+                for i in self.score:
+                    score_list.append(self.score[i])
+                if len(score_list) > 0:
+                    max_score = max(score_list)
+                    for i in self.score:
+                        self.score[i] = round(self.score[i]/max_score,3)
+                self.market = market
+                print("当前打分", self.score)
+                await asyncio.sleep(60)
+            except Exception as e:
+                print("定时循环系统出错"+str(e))
+                self.logger.error(traceback.print_exc())
+                await asyncio.sleep(10)
+
+    async def _on_timer(self):
+        '''定期触发系统逻辑'''
+        await asyncio.sleep(1)
+        while 1:
+            try:
+                print('整理参数池...')
+                # 重置info_msg
+                self.info_msg = ""
+                # 加载配置文件群
+                p_list = []
+                for root, dirs, files in os.walk(os.getcwd()):  
+                    for file_name in files:
+                        if 'config_dummy' in file_name:
+                            p_list.append(utils.get_params(os.path.join(root, file_name)))
+                # 打印本地参数数据
+                self.params_base = dict()
+                for i in broker.exchange_lists:self.params_base[i] = dict()
+                for p in p_list:
+                    name = p.exchange + '@' + p.pair
+                    self.params_base[p.exchange][name] = dict()
+                    self.params_base[p.exchange][name]['pair'] = p.pair
+                    self.params_base[p.exchange][name]['exchange'] = p.exchange
+                    self.params_base[p.exchange][name]['open'] = p.open
+                    self.params_base[p.exchange][name]['close'] = p.close
+                    self.params_base[p.exchange][name]['refexchange'] = p.refexchange
+                    self.params_base[p.exchange][name]['refpair'] = p.refpair
+                    self.params_base[p.exchange][name]['profit'] = 0.0
+                #### 整理params_real
+                # 移除太旧的信息
+                params_real = dict()
+                for ex in broker.exchange_lists:
+                    params_real[ex] = dict()
+                    for name in self.params_real[ex].keys():
+                        if time.time() - self.params_real[ex][name]['time'] > utils.EARLY_STOP_SECOND:
+                            print(f"params real 记录时间太久 移除{name}")
+                        else:
+                            params_real[ex][name] = self.params_real[ex][name]
+                self.params_real = params_real
+                # 移除异常信息
+                params_real = dict()
+                for ex in broker.exchange_lists:
+                    params_real[ex] = dict()
+                    for name in self.params_real[ex].keys():
+                        if abs(float(self.params_real[ex][name]['profit'])) > 10000.0:
+                            print(f"params real 收益率异常 移除{name}")
+                        else:
+                            params_real[ex][name] = self.params_real[ex][name]
+                self.params_real = params_real
+                #### 整理params_dummy
+                # 移除太旧的信息
+                params_dummy = dict()
+                for ex in broker.exchange_lists:
+                    params_dummy[ex] = dict()
+                    for name in self.params_dummy[ex].keys():
+                        if time.time() - self.params_dummy[ex][name]['time'] > utils.DUMMY_EARLY_STOP_SECOND:
+                            print(f"params dummy 记录时间太久 移除{name}")
+                        else:
+                            params_dummy[ex][name] = self.params_dummy[ex][name]
+                self.params_dummy = params_dummy
+                # 移除异常信息
+                params_dummy = dict()
+                for ex in broker.exchange_lists:
+                    params_dummy[ex] = dict()
+                    for name in self.params_dummy[ex].keys():
+                        if abs(float(self.params_dummy[ex][name]['profit'])) > 10000.0:
+                            print(f"params dummy 收益率异常 移除{name}")
+                        else:
+                            params_dummy[ex][name] = self.params_dummy[ex][name]
+                self.params_dummy = params_dummy
+                # 从params_real 提取 win_dict loss_dict
+                self.logger.info('='*10)
+                profit_thre = 0.001
+                for ex in broker.exchange_lists:
+                    self.logger.info('='*10)
+                    self.info_msg += '='*10 + '\r'
+                    # 更新windict lossdict
+                    self.win_dict[ex] = dict()
+                    self.loss_dict[ex] = dict()
+                    profits = []
+                    pairs = []
+                    for name in self.params_real[ex]:
+                        profit = self.params_real[ex][name]['profit']
+                        profits.append(profit)
+                        pairs.append(self.params_real[ex][name]['pair'])
+                        if profit > profit_thre:
+                            self.win_dict[ex][name] = self.params_real[ex][name]
+                            self.win_dict[ex][name]['leverrate'] = 2.0
+                        else:
+                            self.loss_dict[ex][name] = self.params_real[ex][name]
+                            self.loss_dict[ex][name]['leverrate'] = 0.0
+                    if len(profits) > 0:
+                        max_profit = max(profits)
+                        max_profit_name = pairs[profits.index(max_profit)]
+                        max_profit - round(max_profit, 5)
+                        min_profit = min(profits)
+                        min_profit_name = pairs[profits.index(min_profit)]
+                        min_profit - round(min_profit, 5)
+                        avg_profit = round(sum(profits)/len(profits),5)
+                        win_num = 0
+                        for i in profits:
+                            if i > profit_thre:win_num+=1
+                        total_num = len(profits)
+                        base_num = len(self.params_base[ex])
+                        msg = f"Real情况 盘口{ex} 最大{max_profit_name} {max_profit} 最小{min_profit_name} {min_profit} 平均{avg_profit} 盈利{win_num} 记录{total_num} 总数{base_num}"
+                        self.logger.info(msg)
+                        self.info_msg += msg + "\r"
+                        msg = ""
+                        for i in range(total_num):
+                            msg += pairs[i] + " "
+                            if (i+1)%10 == 0:
+                                self.logger.info(msg)
+                                self.info_msg += msg + "\r"
+                                msg = ""
+                        self.logger.info(msg)
+                        self.info_msg += msg + "\r"
+                    self.logger.info('-'*10)
+                    self.info_msg += '-'*10 + '\r'
+                    # 打印dummy_params情况
+                    profits = []
+                    pairs = []
+                    for name in self.params_dummy[ex]:
+                        profit = self.params_dummy[ex][name]['profit']
+                        profits.append(profit)
+                        pairs.append(self.params_dummy[ex][name]['pair'])
+                    if len(profits) > 0:
+                        max_profit = max(profits)
+                        max_profit_name = pairs[profits.index(max_profit)]
+                        max_profit - round(max_profit, 5)
+                        min_profit = min(profits)
+                        min_profit_name = pairs[profits.index(min_profit)]
+                        min_profit - round(min_profit, 5)
+                        avg_profit = round(sum(profits)/len(profits),5)
+                        win_num = 0
+                        for i in profits:
+                            if i > profit_thre:win_num+=1
+                        total_num = len(profits)
+                        base_num = len(self.params_base[ex])
+                        msg = f"Dummy情况 盘口{ex} 最大{max_profit_name} {max_profit} 最小{min_profit_name} {min_profit} 平均{avg_profit} 盈利{win_num} 记录{total_num} 总数{base_num}"
+                        self.logger.info(msg)
+                        self.info_msg += msg + "\r"
+                        msg = ""
+                        for i in range(total_num):
+                            msg += pairs[i] + " "
+                            if (i+1)%10 == 0:
+                                self.logger.info(msg)
+                                self.info_msg += msg + "\r"
+                                msg = ""
+                        self.logger.info(msg)
+                        self.info_msg += msg + "\r"
+                # 保存信息
+                with open('params_real.json', 'w') as f:
+                    json.dump(self.params_real,f)
+                with open('choose_dict.json', 'w') as f:
+                    json.dump(self.choose_dict,f)
+                with open('dummy_choose_dict.json', 'w') as f:
+                    json.dump(self.dummy_choose_dict,f)
+                with open('params_dummy.json', 'w') as f:
+                    json.dump(self.params_dummy,f)
+                # 休眠
+                await asyncio.sleep(60)
+            except Exception as e:
+                print("定时循环系统出错"+str(e))
+                self.logger.error(traceback.print_exc())
+                await asyncio.sleep(10)
+
+    def stop(self):
+        print(f'进入停机流程...')
+        self.loop.create_task(self.exit(delay=1))
+
+    async def get_info(self, request):
+        print(request.remote)
+        return web.Response(text=self.info_msg)
+
+    async def get_dummy_params(self, request):
+        '''
+            中控数据接口
+            从base参数池随机选参数
+        '''
+        data = await request.json()
+        if isinstance(data, str):
+            data = json.loads(data)
+        exchange = data['exchange']
+        res = dict()
+        # 全部dummy实例从base参数池随机选参数 尝试新品种
+        for _ in range(len(self.params_base[exchange])):
+            name = random.choice(list(self.params_base[exchange].keys()))
+            if name not in self.dummy_choose_dict[exchange]:
+                self.dummy_choose_dict[exchange][name] = 0 
+            if name not in self.loss_dict[exchange] and time.time() - self.dummy_choose_dict[exchange][name] > utils.DUMMY_EARLY_STOP_SECOND:
+                res[name] = self.params_base[exchange][name]
+                res[name]['leverrate'] = 0.5
+                self.dummy_choose_dict[exchange][name] = int(time.time())
+                break
+        ###### 如果没有找到满足条件的参数 随机一组参数
+        if res == dict():
+            name = random.choice(list(self.params_base[exchange].keys()))
+            res[name] = self.params_base[exchange][name]
+            res[name]['leverrate'] = 0.5
+            self.dummy_choose_dict[exchange][name] = int(time.time())
+        return web.Response(body=json.dumps(res))
+
+    async def post_dummy_params(self, request):
+        '''
+            中控数据接口
+            更新params_dummy参数池
+        '''
+        data = await request.json()
+        # ip = request.remote
+        # print(f'从{ip}更新dummy参数',data)
+        if isinstance(data, str):
+            data = json.loads(data)
+        exchange = data['exchange']
+        pair = data['pair']
+        name = exchange+"@"+pair
+        profit = round(float(data['profit']),4)
+        ####### 
+        if exchange in self.params_dummy:
+            # 如果已有记录
+            if name in self.params_dummy[exchange]:
+                self.params_dummy[exchange][name]['exchange'] = data['exchange']
+                self.params_dummy[exchange][name]['pair'] = data['pair']
+                self.params_dummy[exchange][name]['open'] = data['open']
+                self.params_dummy[exchange][name]['close'] = data['close']
+                self.params_dummy[exchange][name]['refexchange'] = data['refexchange']
+                self.params_dummy[exchange][name]['refpair'] = data['refpair']
+                self.params_dummy[exchange][name]['profit'] = \
+                    round( profit * 0.3 + self.params_dummy[exchange][name]['profit'] * 0.7, 4)
+                self.params_dummy[exchange][name]['time'] = int(time.time())
+            else:
+                # 如果没有记录
+                self.params_dummy[exchange][name] = dict()
+                self.params_dummy[exchange][name]['exchange'] = data['exchange']
+                self.params_dummy[exchange][name]['pair'] = data['pair']
+                self.params_dummy[exchange][name]['open'] = data['open']
+                self.params_dummy[exchange][name]['close'] = data['close']
+                self.params_dummy[exchange][name]['refexchange'] = data['refexchange']
+                self.params_dummy[exchange][name]['refpair'] = data['refpair']
+                self.params_dummy[exchange][name]['profit'] = profit
+                self.params_dummy[exchange][name]['time'] = int(time.time())
+        return web.Response(body=json.dumps({}))
+
+    async def get_params(self, request):
+        '''
+            中控数据接口
+            从base参数池随机选参数
+            从dummy参数池随机选参数
+        '''
+        data = await request.json()
+        if isinstance(data, str):
+            data = json.loads(data)
+        exchange = data['exchange']
+        res = dict()
+        # 本盘口盈利品种数量
+        win_num = len(self.win_dict[exchange])
+        # 计算盈利参数占比
+        if win_num == 0:
+            # 没有发现盈利品种的时候 全部实例尝试新品种
+            if random.randint(0,100) < 10:
+                # 小概率 从base参数池选参数
+                for _ in range(len(self.params_base[exchange])):
+                    name = random.choice(list(self.params_base[exchange].keys()))
+                    if name not in self.choose_dict[exchange]:
+                        self.choose_dict[exchange][name] = 0 
+                    if name not in self.loss_dict[exchange] and time.time() - self.choose_dict[exchange][name] > utils.EARLY_STOP_SECOND:
+                        res[name] = self.params_base[exchange][name]
+                        res[name]['leverrate'] = 0.5
+                        self.choose_dict[exchange][name] = int(time.time())
+                        break
+            else:
+                # 大概率 从dummy参数池选参数
+                # 计算概率数组
+                p_list = []
+                name_list = []
+                for name in self.params_dummy[exchange]:
+                    p = self.params_dummy[exchange][name]['profit']
+                    if p > 0.0:
+                        p_list.append(p)
+                        name_list.append(name)
+                if len(p_list) > 1:
+                    for _ in range(len(p_list)):
+                        name = random.choices(
+                            name_list,
+                            p_list,
+                            k=1,
+                        )[0]
+                        if name not in self.choose_dict[exchange]:
+                            self.choose_dict[exchange][name] = 0 
+                        if name not in self.loss_dict[exchange] and time.time() - self.choose_dict[exchange][name] > utils.EARLY_STOP_SECOND:
+                            res[name] = self.params_dummy[exchange][name]
+                            res[name]['leverrate'] = 0.5
+                            self.choose_dict[exchange][name] = int(time.time())
+                            break
+        else:
+            # 允许10%的实例去搜索新品种
+            if random.randint(0,100) < 90:
+                # 跑实盘盈利品种
+                res = self.win_dict[exchange]
+            else:
+                # 没有发现盈利品种的时候 全部实例尝试新品种
+                if random.randint(0,100) < 10:
+                    # 小概率 从base参数池选参数
+                    for _ in range(len(self.params_base[exchange])):
+                        name = random.choice(list(self.params_base[exchange].keys()))
+                        if name not in self.choose_dict[exchange]:
+                            self.choose_dict[exchange][name] = 0 
+                        if name not in self.loss_dict[exchange] and time.time() - self.choose_dict[exchange][name] > utils.EARLY_STOP_SECOND:
+                            res[name] = self.params_base[exchange][name]
+                            res[name]['leverrate'] = 0.5
+                            self.choose_dict[exchange][name] = int(time.time())
+                            break
+                else:
+                    # 大概率 从dummy参数池选参数
+                    # 计算概率数组
+                    p_list = []
+                    name_list = []
+                    for name in self.params_dummy[exchange]:
+                        p = self.params_dummy[exchange][name]['profit']
+                        if p > 0.0:
+                            p_list.append(p)
+                            name_list.append(name)
+                    if len(p_list) > 1:
+                        for _ in range(len(p_list)):
+                            name = random.choices(
+                                name_list,
+                                p_list,
+                                k=1,
+                            )[0]
+                            if name not in self.choose_dict[exchange]:
+                                self.choose_dict[exchange][name] = 0 
+                            if name not in self.loss_dict[exchange] and time.time() - self.choose_dict[exchange][name] > utils.EARLY_STOP_SECOND:
+                                res[name] = self.params_dummy[exchange][name]
+                                res[name]['leverrate'] = 0.5
+                                self.choose_dict[exchange][name] = int(time.time())
+                                break
+        ###### 如果没有找到满足条件的参数 随机一组参数
+        if res == dict():
+            name = random.choice(list(self.params_base[exchange].keys()))
+            res[name] = self.params_base[exchange][name]
+            res[name]['leverrate'] = 0.5
+            self.choose_dict[exchange][name] = int(time.time())
+        return web.Response(body=json.dumps(res))
+
+    async def post_params(self, request):
+        '''
+            中控数据接口
+            更新real参数池
+        '''
+        data = await request.json()
+        ip = request.remote
+        # print(f'从{ip}更新参数',data)
+        if isinstance(data, str):
+            data = json.loads(data)
+        exchange = data['exchange']
+        pair = data['pair']
+        name = exchange+"@"+pair
+        profit = round(float(data['profit']),4)
+        if exchange in self.params_real:
+            # 如果已有记录
+            if name in self.params_real[exchange]:
+                self.params_real[exchange][name]['exchange'] = data['exchange']
+                self.params_real[exchange][name]['pair'] = data['pair']
+                self.params_real[exchange][name]['open'] = data['open']
+                self.params_real[exchange][name]['close'] = data['close']
+                self.params_real[exchange][name]['refexchange'] = data['refexchange']
+                self.params_real[exchange][name]['refpair'] = data['refpair']
+                self.params_real[exchange][name]['profit'] = \
+                    round( profit * 0.3 + self.params_real[exchange][name]['profit'] * 0.7, 4)
+                self.params_real[exchange][name]['time'] = int(time.time())
+            else:
+                # 如果没有记录
+                self.params_real[exchange][name] = dict()
+                self.params_real[exchange][name]['exchange'] = data['exchange']
+                self.params_real[exchange][name]['pair'] = data['pair']
+                self.params_real[exchange][name]['open'] = data['open']
+                self.params_real[exchange][name]['close'] = data['close']
+                self.params_real[exchange][name]['refexchange'] = data['refexchange']
+                self.params_real[exchange][name]['refpair'] = data['refpair']
+                self.params_real[exchange][name]['profit'] = profit
+                self.params_real[exchange][name]['time'] = int(time.time())
+        return web.Response(body=json.dumps({}))
+
+    async def _run_server(self):
+        print('server正在启动...')
+        await asyncio.sleep(10)
+        app = web.Application()
+        app.router.add_route('*', f'/get_info', self.get_info)
+        app.router.add_route('POST', f'/get_params', self.get_params)
+        app.router.add_route('POST', f'/get_dummy_params', self.get_dummy_params)
+        app.router.add_route('POST', f'/post_params', self.post_params)
+        app.router.add_route('POST', f'/post_dummy_params', self.post_dummy_params)
+        try:
+            self.loop.create_task(web._run_app(app, host='0.0.0.0', port=self.params.server_port, handle_signals=False))
+        except:
+            self.logger.error(traceback.format_exc())
+
+    def run(self):
+        '''启动ws行情获取'''
+        
+        tasks = []
+        # 策略
+        for i in [
+                asyncio.ensure_future(self._run_server()),
+                asyncio.ensure_future(self._on_timer()),
+                # asyncio.ensure_future(self.update_score()),
+            ]:
+            tasks.append(i)
+
+        def keyboard_interrupt(s, f):
+            print("收到退出信号 准备关机")
+            self.logger.info("收到退出信号 准备关机")
+            self.stop()
+        try:
+            signal.signal(signal.SIGINT, keyboard_interrupt)
+            signal.signal(signal.SIGTERM, keyboard_interrupt)
+            if 'win' not in sys.platform:
+                signal.signal(signal.SIGKILL, keyboard_interrupt)
+                signal.signal(signal.SIGQUIT, keyboard_interrupt)
+        except:
+            pass
+
+        self.loop.run_until_complete(asyncio.wait(tasks))
+
+if __name__ == "__main__":
+
+    if 0:
+        utils.check_auth()
+
+    pnum = len(sys.argv)
+
+    if pnum > 0:
+        fname = None
+        log_file = None
+        pidnum = None
+        for i in range(pnum):
+            print(f"第{i}个参数为:{sys.argv[i]}")
+            if sys.argv[i] == '-c' or sys.argv[i] == '--c': 
+                fname = sys.argv[i+1]
+            elif sys.argv[i] == '-h': 
+                print("帮助文档")
+            elif sys.argv[i] == '-log_file' or sys.argv[i] == '--log_file':
+                log_file = sys.argv[i+1]
+            elif sys.argv[i] == '-num' or sys.argv[i] == '--num':
+                pidnum = sys.argv[i+1]
+            elif sys.argv[i] == '-v' or sys.argv[i] == '--v':
+                print("当前版本为 V4.1")
+        if fname and log_file and pidnum:
+            print(f"指定的配置为 fname:{fname} log_file:{log_file} pidnum:{pidnum}")
+            date = time.strftime("%Y%m%d", time.localtime()) 
+            logname = f"{log_file}-{date}"
+            quant = Center(fname, logname)
+            quant.run()
+        elif fname:
+            print(f"运行指定配置文件{fname}")
+            quant = Center(fname)
+            quant.run()
+        else:
+            print("缺少指定参数 运行默认配置文件")
+            fname = 'config.toml'
+            quant = Center(fname)
+            quant.run()
+    else:
+        fname = 'config.toml'
+        quant = Center(fname)
+        quant.run()

+ 487 - 0
bak/dummy.py

@@ -0,0 +1,487 @@
+import asyncio
+from aiohttp import web
+import traceback
+import time
+import strategy
+import backtest
+import utils
+import model
+import logging, logging.handlers
+import signal
+import os, json, sys
+import csv
+import predictor
+import subprocess
+from decimal import Decimal
+import gc
+import broker
+
+VERSION = utils.VERSION
+
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        nowTime = time.time()
+        res = func(*args, **kwargs)
+        spend_time = time.time() - nowTime
+        spend_time = round(spend_time * 100, 2)
+        print(f'{func.__name__} 耗时 {spend_time} ms')
+        return res
+    return wrapper
+
+class Dummy:
+
+    def __init__(self, params:model.Config, logname=None):
+        print('###############   Dummy System   ################')
+        print(f'>>> 版本号v{VERSION} <<<')
+        print('*** 当前配置')
+        self.params = params
+        for p in self.params.__dict__:
+            print('***', p, ' => ', getattr(self.params, p))
+        print('##################################################')
+        pid = os.getpid()
+        print(f'Dummpy System 正在启动 进程号{pid}...')
+        self.pid_start_time = time.time()
+        self.logger = self.get_logger(logname)
+        self.acct_name = self.params.account_name
+        self.symbol = self.params.pair
+        self.loop = asyncio.get_event_loop()
+        self.interval = float(self.params.interval)
+        self.exchange = self.params.exchange
+        self.tradeMsg = model.TraderMsg()
+        self.exit_msg = "正常退出"
+        # 现货特殊变量
+        self.is_first = 1
+        # 参考盘口名称列表
+        self.ref_names = []
+        self.tickers = dict()
+        self.tickers_update_time = dict()
+        for i in range(len(self.params.refexchange)):
+            refex = self.params.refexchange[i]
+            pair = self.params.refpair[i]
+            name = refex + '@' + pair
+            self.ref_names.append(name)
+            self.tickers[name] = dict()
+            self.tickers_update_time[name] = time.time()
+        # 参考盘口tick更新时间
+        # 创建ws实例
+        self.wss = dict()
+        name = self.exchange+'@'+self.params.pair
+        self.trade_name = name
+        cp = model.ClientParams()
+        cp.name = name
+        cp.pair = self.params.pair
+        cp.access_key = self.params.access_key
+        cp.secret_key = self.params.secret_key
+        cp.pass_key = self.params.pass_key
+        cp.interval = self.params.interval
+        cp.broker_id = self.params.broker_id
+        cp.debug = self.params.debug
+        cp.proxy = self.params.proxy
+        cp.interval = self.params.interval
+        self.ws = broker.newWs(self.exchange)(cp)
+        self.ws.logger = self.logger
+        self.ready = 0
+        # 参考盘口
+        for i,name in enumerate(self.ref_names):
+            cp = model.ClientParams()
+            cp.name = name
+            cp.pair = self.params.refpair[i]
+            cp.proxy = self.params.proxy
+            cp.interval = self.params.interval
+            self.wss[name] = broker.newWs(self.params.refexchange[i])(cp)
+            self.wss[name].callback = {
+                'onTicker':self.update_ticker,
+                'onDepth':self.update_depth,
+                }
+            self.wss[name].logger = self.logger
+        # 添加回调
+        self.ws.callback = {
+            'onTicker':self.update_ticker,
+            'onDepth':self.update_depth,
+            'onPosition':self.update_position,
+            'onAccount':self.update_account,
+            'onEquity':self.update_equity,
+            'onFreeEquity':self.update_free_equity,
+            'onOrder':self.update_order,
+            }
+        # 配置定价模型
+        self.Predictor = predictor.Predictor(ref_name=self.ref_names)
+        # 配置实时回测
+        # 基础参数 当找不到盈利参数时使用
+        self.base_open = float(self.params.open)
+        self.base_close = float(self.params.close)
+        self.base_index = 0
+        self.base_profit = 0.0
+        self.backtest_tasks = list()
+        self.backtest_start_equity = 1000000.0
+        for _open in [0.001,0.002,0.003]:
+            for _G in [0.2]:
+                for _index in range(len(self.ref_names)):
+                    # 采用虚拟合约交易策略进行实时回测
+                    _close = round(_open * _G, 5)
+                    task = dict()
+                    st = strategy.Strategy(self.params, is_print=0)
+                    st.leverrate = 1.0
+                    st.trade_open_dist = _open
+                    st.trade_close_dist = _close
+                    st.ref_index = _index
+                    st.exchange = 'dummy_usdt_swap'
+                    st.local_start_time = 0.0
+                    bt = backtest.Backtest(st, is_plot=0)
+                    bt.start_cash = self.backtest_start_equity
+                    task["backtest_engine"] = bt
+                    task["open"] = _open
+                    task["close"] = _close
+                    task["index"] = _index
+                    self.backtest_tasks.append(task)
+
+    def get_logger(self, logname):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        if logname == None: logname = "log"
+        handler = logging.handlers.RotatingFileHandler(f"{logname}.log",maxBytes=1024*1024*50,encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.INFO)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        logger.info('开启日志记录')
+        return logger
+
+    def update_order(self, data):
+        pass
+
+    def update_position(self, data):
+        pass
+
+    def update_ticker(self, data):
+        '''更新ticker信息'''
+        name = data['name']
+        # 记录深度更新时间
+        self.tickers_update_time[name] = time.time()
+        self.tickers[name] = data
+    
+    def update_depth(self, data):
+        '''更新depth信息'''
+        name = data['name']
+        # 记录深度更新时间
+        self.tickers_update_time[name] = time.time()
+
+    def update_equity(self, data):
+        pass
+
+    def update_free_equity(self, data):
+        pass
+
+    def update_account(self, data):
+        '''更新账户信息'''
+        pass
+
+    def update_trade_msg(self):
+        pass
+
+    def get_all_tickers(self):
+        '''
+            组合最新价格信息
+            有depth用mp
+            两depth之间用lp
+        '''
+        ref_tickers = []
+        for i in self.ref_names:
+            ref_tickers.append([self.tickers[i]['bp'], self.tickers[i]['ap']])
+        return ref_tickers
+
+    def real_time_back_test(self, data):
+        ''' 
+            按照长短期回测利润选择参数
+            优先按长期回测利润选参数 如果找不到就
+            再按短期回测利润选参数 如果还找不到就
+            使用默认参数 如果默认参数亏损就触发冷静期
+        '''
+        now_time = time.time()
+        for i in self.backtest_tasks:
+            i["backtest_engine"].backtest_time = now_time
+            i["backtest_engine"].run_by_tick(data)
+
+    def choose_params(self):
+        ''' 
+        按照长短期回测利润选择参数
+        优先按长期回测利润选参数 如果找不到就
+        再按短期回测利润选参数 如果还找不到就
+        使用默认参数 如果默认参数亏损就触发冷静期
+        '''       
+        profits = []
+        for i in self.backtest_tasks:
+            equity = i["backtest_engine"].equity
+            # 直接按收益排序 简单粗暴有效
+            # 转换为预估日化收益    
+            profit = (equity-self.backtest_start_equity)/self.backtest_start_equity/(time.time()-self.pid_start_time)*86400.0
+            profits.append(profit) # 利润      
+        # 按回测结果调整参数
+        # 查找利润
+        max_index = profits.index(max(profits))
+        max_profit = max(profits)
+        self.base_close = self.backtest_tasks[max_index]["close"]
+        self.base_open = self.backtest_tasks[max_index]["open"]
+        self.base_index = self.backtest_tasks[max_index]["index"]
+        self.base_profit = max_profit
+        return
+
+    def check_risk(self):
+        '''检查风控'''
+        ###### 行情更新异常风控 ######
+        for name in self.ref_names:
+            delay = round((time.time() - self.tickers_update_time[name]) * 1000, 3)
+            if delay > 60000: # 60s
+                msg = f"{self.acct_name} ref_name:{name} delay:{delay}ms 行情更新延迟过高 退出"
+                self.logger.error(msg)
+                # self.loop.create_task(utils.ding(msg, 1, self.params.webhook, self.params.proxy))
+                self.stop()
+
+    def print_backtest_results(self):
+        if self.base_profit > 0.0:
+            self.logger.info(f"exchange:{self.params.exchange} pair:{self.params.pair} open:{self.base_open} close:{self.base_close} index:{self.base_index} profit:{self.base_profit}")
+        else:
+            self.logger.info(f'无盈利结果 {self.base_profit}')
+
+    async def exit(self, delay=0):
+        '''退出操作'''
+        self.logger.info(f"开始退出操作 delay{delay}")
+        if delay > 0:
+            await asyncio.sleep(delay)
+        self.logger.info(f'停机退出 {self.exit_msg}')
+        await asyncio.sleep(1)
+        print("停机...")
+        # self.loop.create_task(utils.ding(f"{self.acct_name} Dummy System 停止", 1, self.params.webhook, self.params.proxy))
+        self.loop.stop()
+        os._exit(0) 
+
+    async def on_timer(self):
+        '''定期触发系统逻辑'''
+        # self.loop.create_task(utils.ding(f"{self.acct_name} Dummy System 启动", 1, self.params.webhook, self.params.proxy))
+        await asyncio.sleep(5)
+        push_time = utils.DUMMY_EARLY_STOP_SECOND * 0.5
+        start_time = time.time()
+        while 1:
+            try:
+                ####
+                await asyncio.sleep(60)
+                # 检查风控
+                self.check_risk()
+                # 打印回测结果
+                self.print_backtest_results()
+                # 参数调优
+                self.choose_params()
+                # 发送钉钉
+                if time.time() - start_time > push_time:
+                    await utils._post_params(
+                        "http://wwww.khods.com:8888/post_dummy_params", 
+                        self.params.proxy,
+                        json.dumps({
+                            "exchange":self.params.exchange,
+                            "pair":self.params.pair,
+                            "open":self.base_open,
+                            "close":self.base_close,
+                            "refexchange":self.params.refexchange,
+                            "refpair":self.params.refpair,
+                            "profit":self.base_profit
+                        })
+                    )
+            except Exception as e:
+                print("定时循环系统出错"+str(e))
+                self.logger.error(traceback.print_exc())
+                await asyncio.sleep(10)
+
+    async def early_stop_loop(self):
+        '''判断是否需要早停'''
+        while 1:
+            try:
+                # 1
+                await asyncio.sleep(utils.DUMMY_EARLY_STOP_SECOND)
+                # 2
+                if self.base_profit <= 0.0:
+                    self.exit_msg = "触发早停条件"
+                    self.stop()
+            except:
+                self.logger.error(traceback.format_exc())
+
+    def get_all_market_data(self):
+        '''
+            只能定时触发
+            组合市场信息=交易盘口+参考盘口1+参考盘口2...
+        '''
+        market = []
+        data = self.ws._get_data()["data"]
+        if data == []:return None
+        market += data
+        for name in self.ref_names:
+            data = self.wss[name]._get_data()["data"]
+            if data == []:return None
+            market += data
+        return market
+
+    async def run_stratey(self):
+        '''定期触发策略'''
+        print('定时触发器启动')
+        # 准备交易
+        try:
+            await asyncio.sleep(10)
+            while 1:
+                await asyncio.sleep(self.interval)
+                ### 是否准备充分
+                if self.ready:
+                    ### 更新市场数据
+                    all_market = self.get_all_market_data()
+                    ### 更新预测值
+                    self.Predictor.onTime(all_market)
+                    self.tradeMsg.market = all_market
+                    ### 更新交易数据
+                    self.update_trade_msg()
+                    ### 更新参考价格
+                    self.tradeMsg.ref_price = self.Predictor.Get_ref(self.get_all_tickers())
+                    self.real_time_back_test(self.tradeMsg)
+                else:
+                    self.check_ready()
+        except Exception as e:
+            print(e)
+            self.logger.error(e)
+            traceback.print_exc()
+            await asyncio.sleep(10)
+    
+    def check_ready(self):
+        '''
+            判断初始数据是否齐全
+        '''
+        ### 检查 ticker 行情
+        # for m in self.ref_names:
+        #     if m not in self.tickers:
+        #         return
+        #     else:
+        #         if self.tickers[m]['bp'] == 0 or self.tickers[m]['ap'] == 0:
+        #             return
+        #         else:
+        #             print('ref ticker 未准备好')
+        # if self.trade_name not in self.tickers:
+        #     return
+        # else:
+        #     if self.tickers[self.trade_name]['bp'] == 0 or self.tickers[self.trade_name]['ap'] == 0:
+        #         return
+        #     else:
+        #         print('trade ticker 未准备好')
+        ### 检查 market 行情
+        all_market = self.get_all_market_data()
+        if len(all_market) != utils.LEN*(1+len(self.ref_names)):
+            self.logger.error("聚合行情未准备好")
+            return
+        else:
+            # 如果行情已经就绪 预热trademsg和predictor
+            self.tradeMsg.market = all_market
+            self.Predictor.onTime(all_market)
+        self.ready = 1
+
+    async def server_handle(self, request):
+        '''中控数据接口'''
+        return web.Response(body=json.dumps({
+            "wallet_balance":1+self.base_profit,
+            "cross_wallet_balance":0,
+            "unrealized_pn_l":0,
+            "position_amount":0,
+            "entry_price":0,        
+            "accumulated_realized":0,
+            "now_price":(self.tickers[self.trade_name]['bp']+self.tickers[self.trade_name]['ap'])*0.5,
+        }))
+
+    async def _run_server(self):
+        print('server正在启动...')
+        app = web.Application()
+        app.router.add_route('GET', '/account', self.server_handle)
+        try:
+            self.loop.create_task(web._run_app(app, host='0.0.0.0', port=self.params.server_port, handle_signals=False))
+        except:
+            self.logger.error(f"Server启动失败")
+            self.logger.error(traceback.format_exc())
+            self.exit_msg = "服务启动失败 停机退出"
+            self.stop()
+
+    def stop(self):
+        self.logger.info(f'进入停机流程...')
+        self.loop.create_task(self.exit(delay=1))
+
+    def run(self):
+        '''启动ws行情获取'''
+        
+        tasks = []
+        # 使用全市场行情
+        for i in self.wss:
+            tasks.append(asyncio.ensure_future(self.wss[i].run()))
+        # 策略
+        for i in [
+                asyncio.ensure_future(self.ws.run(is_auth=0, sub_trade=1)),
+                asyncio.ensure_future(self.run_stratey()),
+                asyncio.ensure_future(self.on_timer()),
+                asyncio.ensure_future(self.early_stop_loop()),
+                asyncio.ensure_future(self._run_server()),
+            ]:
+            tasks.append(i)
+
+        def keyboard_interrupt(s, f):
+            self.logger.info("收到退出信号 准备关机")
+            self.stop()
+        try:
+            signal.signal(signal.SIGINT, keyboard_interrupt)
+            signal.signal(signal.SIGTERM, keyboard_interrupt)
+            if 'win' not in sys.platform:
+                signal.signal(signal.SIGKILL, keyboard_interrupt)
+                signal.signal(signal.SIGQUIT, keyboard_interrupt)
+        except:
+            pass
+
+        self.loop.run_until_complete(asyncio.wait(tasks))
+
+if __name__ == "__main__":
+
+    if 0:
+        utils.check_auth()
+
+    pnum = len(sys.argv)
+
+    if pnum > 0:
+        fname = None
+        log_file = None
+        pidnum = None
+        for i in range(pnum):
+            print(f"第{i}个参数为:{sys.argv[i]}")
+            if sys.argv[i] == '-c' or sys.argv[i] == '--c': 
+                fname = sys.argv[i+1]
+            elif sys.argv[i] == '-h': 
+                print("帮助文档")
+            elif sys.argv[i] == '-log_file' or sys.argv[i] == '--log_file':
+                log_file = sys.argv[i+1]
+            elif sys.argv[i] == '-num' or sys.argv[i] == '--num':
+                pidnum = sys.argv[i+1]
+            elif sys.argv[i] == '-v' or sys.argv[i] == '--v':
+                print(f"当前版本为 V{VERSION}")
+        if fname and log_file and pidnum:
+            print(f"指定的配置为 fname:{fname} log_file:{log_file} pidnum:{pidnum}")
+            date = time.strftime("%Y%m%d", time.localtime()) 
+            logname = f"{log_file}-{date}"
+            quant = Dummy(utils.get_params(fname), logname)
+            quant.run()
+        elif fname:
+            print(f"运行指定配置文件{fname}")
+            quant = Dummy(utils.get_params(fname))
+            quant.run()
+        else:
+            print("缺少指定参数 运行默认配置文件")
+            fname = 'config_dummy.toml'
+            quant = Dummy(utils.get_params(fname))
+            quant.run()
+    else:
+        fname = 'config_dummy.toml'
+        quant = Dummy(utils.get_params(fname))
+        quant.run()
+

+ 138 - 0
bak/final.py

@@ -0,0 +1,138 @@
+import utils
+import time
+import subprocess
+import signal
+import os, sys
+import toml, json
+import random
+
+if __name__ == "__main__":
+
+    exe = 0
+
+    for root, dirs, files in os.walk(os.getcwd()):  
+
+        for file_name in files:
+    
+            if 'final_quant' in file_name:
+    
+                exe = 1
+
+    quant = None
+
+    def keyboard_interrupt(s, f):
+        for _ in range(9):
+            try:
+                quant.terminate()
+                time.sleep(10)
+                break
+            except:
+                pass
+        os._exit(0)
+
+    try:
+        signal.signal(signal.SIGINT, keyboard_interrupt)
+        signal.signal(signal.SIGTERM, keyboard_interrupt)
+        if 'win' not in sys.platform:
+            signal.signal(signal.SIGKILL, keyboard_interrupt)
+            signal.signal(signal.SIGQUIT, keyboard_interrupt)
+    except:
+        pass
+
+    if 0:
+        utils.check_auth()
+
+    pnum = len(sys.argv)
+
+    fname = 'config.toml'
+    log_file = ""
+    pidnum = ""
+
+    if pnum > 0:
+        for i in range(pnum):
+            if sys.argv[i] == '-c' or sys.argv[i] == '--c': 
+                fname = sys.argv[i+1]
+            elif sys.argv[i] == '-log_file' or sys.argv[i] == '--log_file':
+                log_file = sys.argv[i+1]
+            elif sys.argv[i] == '-num' or sys.argv[i] == '--num':
+                pidnum = sys.argv[i+1]
+
+    while 1:
+        ### 加点意义不明的计算
+        num = 0
+        for i in range(99999):
+            num += i
+        ### 避免supervisor管理错误
+    
+        params = utils.get_params(fname)
+
+        params_pool = utils._get_params("http://wwww.khods.com:8888/get_params", params.proxy, {"exchange":params.exchange})
+
+        # 选出本交易所的数据
+        profits = []
+
+        for p_name in params_pool:
+            if params.exchange in p_name:
+                profit = float(params_pool[p_name]['profit'])
+                profits.append(max(profit,0.00001))
+
+        if len(params_pool) > 0:
+
+            random_params = params_pool[random.choices(list(params_pool.keys()),weights=profits,k=1)[0]]
+
+            # 切换参数
+            params.pair = random_params['pair']
+            params.refpair = str(random_params['refpair'])
+            params.refexchange = str(random_params['refexchange'])
+            params.open = random_params['open'] if 'open' in random_params else "0.003"
+            params.close = random_params['close'] if 'close' in random_params else "0.0001"
+            leverrate = random_params['leverrate'] if 'leverrate' in random_params else "1.0"
+            if float(leverrate) == 0.5:
+                # 随机分配的参数
+                params.leverrate = leverrate
+            elif float(leverrate) == 2.0:
+                # 来自盈利组的参数
+                min_profit = min(profits)
+                max_profit = max(profits)
+                if min_profit == max_profit:
+                    leverrate = 2.0
+                else:
+                    leverrate = round(0.5 + 1.5*(random_params['profit'] - min_profit)/(max_profit-min_profit),1)
+                params.leverrate = str(leverrate)
+            else:
+                # 其他参数
+                params.leverrate = leverrate
+
+
+            with open(fname, 'w+') as fp:
+                toml.dump(params.__dict__, fp)
+
+            if exe:
+
+                quant = subprocess.Popen(['./final_quant', '-c', fname, '-log_file', log_file, '-num', pidnum, '-child'])
+
+            else:
+
+                quant = subprocess.Popen(['python3', 'quant.py', '-c', fname, '-log_file', log_file, '-num', pidnum, '-child'])
+
+            wait_time = random.randint(utils.CHILD_RUN_SECOND*0.8,utils.CHILD_RUN_SECOND)
+
+            check_num = 10000
+
+            wait_time_per_loop = wait_time / check_num
+
+            for _ in range(check_num):
+
+                time.sleep(wait_time_per_loop)
+
+                if quant.poll() is not None:
+
+                    break
+
+            quant.terminate()
+
+        # 休息n秒进入下一轮
+        time.sleep(random.randint(10,15))
+
+            
+

+ 118 - 0
bak/seele.py

@@ -0,0 +1,118 @@
+
+import utils
+import time
+import subprocess
+import signal
+import os, sys
+import toml, json
+import random
+
+if __name__ == "__main__":
+
+    exe = 0
+
+    for root, dirs, files in os.walk(os.getcwd()):  
+
+        for file_name in files:
+    
+            if 'dummy_quant' in file_name:
+    
+                exe = 1
+
+    quant = None
+
+    def keyboard_interrupt(s, f):
+        for _ in range(9):
+            try:
+                quant.terminate()
+                time.sleep(10)
+                break
+            except:
+                pass
+        os._exit(0) 
+
+    try:
+        signal.signal(signal.SIGINT, keyboard_interrupt)
+        signal.signal(signal.SIGTERM, keyboard_interrupt)
+        if 'win' not in sys.platform:
+            signal.signal(signal.SIGKILL, keyboard_interrupt)
+            signal.signal(signal.SIGQUIT, keyboard_interrupt)
+    except:
+        pass
+
+    if 0:
+        utils.check_auth()
+
+    pnum = len(sys.argv)
+
+    fname = 'config.toml'
+    log_file = ""
+    pidnum = ""
+
+    if pnum > 0:
+        for i in range(pnum):
+            if sys.argv[i] == '-c' or sys.argv[i] == '--c': 
+                fname = sys.argv[i+1]
+            elif sys.argv[i] == '-log_file' or sys.argv[i] == '--log_file':
+                log_file = sys.argv[i+1]
+            elif sys.argv[i] == '-num' or sys.argv[i] == '--num':
+                pidnum = sys.argv[i+1]
+
+    while 1:
+    
+        params = utils.get_params(fname)
+
+        params_pool = utils._get_params("http://wwww.khods.com:8888/get_dummy_params", params.proxy, {"exchange":params.exchange})
+
+        # 选出本交易所的数据
+        params_temp = []
+
+        for p_name in params_pool:
+            if params.exchange in p_name:
+                profit = float(params_pool[p_name]['profit'])
+                for _ in range(1+int(profit//0.01)):
+                    params_temp.append(params_pool[p_name])
+
+        if len(params_temp) > 0:
+
+            max_params = random.choice(params_temp)
+
+            # 切换参数
+            params.pair = max_params['pair']
+            params.refpair = str(max_params['refpair'])
+            params.refexchange = str(max_params['refexchange'])
+            params.open = max_params['open'] if 'open' in max_params else "0.003"
+            params.close = max_params['close'] if 'close' in max_params else "0.0001"
+            params.leverrate = max_params['leverrate'] if 'leverrate' in max_params else "1.0"
+
+
+            with open(fname, 'w+') as fp:
+                toml.dump(params.__dict__, fp)
+
+            if exe:
+
+                quant = subprocess.Popen(['./dummy_quant', '-c', fname, '-log_file', log_file, '-num', pidnum, '-child'])
+
+            else:
+
+                quant = subprocess.Popen(['python3', 'dummy.py', '-c', fname, '-log_file', log_file, '-num', pidnum, '-child'])
+
+            wait_time = random.randint(utils.DUMMY_RUN_SECOND+60,utils.DUMMY_RUN_SECOND+120)
+
+            check_num = 10000
+
+            wait_time_per_loop = wait_time / check_num
+
+            for _ in range(check_num):
+
+                time.sleep(wait_time_per_loop)
+
+                if quant.poll() is not None:
+
+                    break
+
+            quant.terminate()
+
+        # 休息n秒进入下一轮
+        time.sleep(random.randint(10,15))
+

+ 108 - 0
broker.py

@@ -0,0 +1,108 @@
+
+############ WS ############
+import exchange.huobi_usdt_swap_ws as huobiusdtswapws
+import exchange.huobi_spot_ws as huobispotws
+########################  
+import exchange.okex_usdt_swap_ws as okexusdtswapws
+########################  
+import exchange.binance_coin_swap_ws as binancecoinswapws
+import exchange.binance_usdt_swap_ws as binanceusdtswapws
+import exchange.binance_spot_ws as binancespotws
+########################  
+import exchange.gate_spot_ws as gatespotws
+import exchange.gate_usdt_swap_ws as gateusdtswapws
+########################  
+import exchange.kucoin_spot_ws as kucoinspotws
+import exchange.kucoin_usdt_swap_ws as kucoinusdtswapws
+########################  
+import exchange.coinex_spot_ws as coinexspotws
+import exchange.coinex_usdt_swap_ws as coinexusdtswapws
+####################################
+import exchange.ftx_spot_ws as ftxspotws
+import exchange.ftx_usdt_swap_ws as ftxswapws
+####################################
+import exchange.bitget_usdt_swap_ws as bitgetusdtswapws
+####################################
+import exchange.bybit_usdt_swap_ws as bybitusdtswapws
+####################################
+####################################
+import exchange.mexc_spot_ws as mexcspotws
+####################################
+
+############ REST ############
+import exchange.gate_spot_rest as gate_spot_rest
+import exchange.gate_usdt_swap_rest as gate_usdt_swap_rest
+########################
+import exchange.binance_spot_rest as binance_spot_rest
+import exchange.binance_usdt_swap_rest as binance_usdt_swap_rest
+########################
+import exchange.kucoin_spot_rest as kucoin_spot_rest
+import exchange.kucoin_usdt_swap_rest as kucoin_usdt_swap_rest
+########################
+import exchange.coinex_spot_rest as coinex_spot_rest
+import exchange.coinex_usdt_swap_rest as coinex_usdt_swap_rest
+########################
+import exchange.okex_usdt_swap_rest as okex_usdt_swap_rest
+####################################
+import exchange.bitget_usdt_swap_rest as bitget_usdt_swap_rest
+####################################
+import exchange.bybit_usdt_swap_rest as bybit_usdt_swap_rest
+####################################
+import exchange.mexc_spot_rest as mexcspotrest
+####################################
+
+
+exchange_lists = [
+    "binance_spot","binance_usdt_swap","binance_coin_swap",
+    "huobi_spot","huobi_usdt_swap","huobi_coin_swap",
+    "okex_spot","okex_usdt_swap","okex_coin_swap",
+    "gate_spot","gate_usdt_swap",
+    "kucoin_spot","kucoin_usdt_swap",
+    "coinex_spot","coinex_usdt_swap",
+    "bitget_usdt_swap",
+    "bybit_usdt_swap",
+    "mexc_spot",
+]
+
+exchange_ws_clients = dict()
+exchange_ws_clients['binance_usdt_swap']= binanceusdtswapws.BinanceUsdtSwapWs
+exchange_ws_clients['binance_coin_swap']= binancecoinswapws.BinanceCoinSwapWs
+exchange_ws_clients['binance_spot']= binancespotws.BinanceSpotWs
+exchange_ws_clients['huobi_usdt_swap']=huobiusdtswapws.HuobiUsdtSwapWs
+exchange_ws_clients['huobi_spot']= huobispotws.HuobiSpotWs
+exchange_ws_clients['okex_usdt_swap']= okexusdtswapws.OkexUsdtSwapWs
+exchange_ws_clients['gate_spot']= gatespotws.GateSpotWs
+exchange_ws_clients['gate_usdt_swap']= gateusdtswapws.GateUsdtSwapWs
+exchange_ws_clients['kucoin_spot']= kucoinspotws.KucoinSpotWs
+exchange_ws_clients['kucoin_usdt_swap']= kucoinusdtswapws.KucoinUsdtSwapWs
+exchange_ws_clients['coinex_spot']= coinexspotws.CoinExSpotWs
+exchange_ws_clients['coinex_usdt_swap']= coinexusdtswapws.CoinExUsdtSwapWs
+exchange_ws_clients['ftx_spot'] = ftxspotws.FtxSpotWs
+exchange_ws_clients['ftx_usdt_swap'] = ftxswapws.FtxUsdtSwapWs
+exchange_ws_clients['bitget_usdt_swap'] = bitgetusdtswapws.BitgetUsdtSwapWs
+exchange_ws_clients['bybit_usdt_swap'] = bybitusdtswapws.BybitUsdtSwapWs
+exchange_ws_clients['mexc_spot'] = mexcspotws.MexcSpotWs
+
+
+exchange_rest_clients = dict()
+exchange_rest_clients['binance_usdt_swap'] = binance_usdt_swap_rest.BinanceUsdtSwapRest
+exchange_rest_clients['binance_spot'] = binance_spot_rest.BinanceSpotRest
+exchange_rest_clients['gate_spot'] = gate_spot_rest.GateSpotRest
+exchange_rest_clients['gate_usdt_swap'] = gate_usdt_swap_rest.GateUsdtSwapRest
+exchange_rest_clients['kucoin_spot'] = kucoin_spot_rest.KucoinSpotRest
+exchange_rest_clients['kucoin_usdt_swap'] = kucoin_usdt_swap_rest.KucoinUsdtSwapRest
+exchange_rest_clients['coinex_spot'] = coinex_spot_rest.CoinExSpotRest
+exchange_rest_clients['coinex_usdt_swap'] = coinex_usdt_swap_rest.CoinExUsdtSwapRest
+exchange_rest_clients['okex_usdt_swap'] = okex_usdt_swap_rest.OkexUsdtSwapRest
+exchange_rest_clients['bitget_usdt_swap'] = bitget_usdt_swap_rest.BitGetUsdtSwapRest
+exchange_rest_clients['bybit_usdt_swap'] = bybit_usdt_swap_rest.BybitUsdtSwapRest
+exchange_rest_clients['mexc_spot'] = mexcspotrest.MexcSpotRest
+
+def newWs(exchange):
+    if exchange in exchange_ws_clients:
+        return exchange_ws_clients[exchange]
+    else:
+        return None
+
+def newRest(exchange):
+    return exchange_rest_clients[exchange]

+ 26 - 0
config.toml

@@ -0,0 +1,26 @@
+broker_id = ""
+account_name = "test_account_001"
+access_key = "B5261D7FB987478FA47E14F00C4A653B"
+secret_key = "CAF69F5A23DE0EA32B7B15777AF47609C61B1EB17E5AC203"
+pass_key = ""
+exchange = "coinex_usdt_swap"
+pair = "rose_usdt"
+debug = "False"
+open = 0.001
+close = 0.0002
+server_port = 8003
+leverrate = "1.0"
+interval = 0.1
+refexchange = "['binance_usdt_swap']"
+refpair = "['rose_usdt']"
+webhook = ""
+used_pct = "0.9"
+index = 0
+save = 0
+hold_coin = 0.0
+log = 1
+stoploss = "0.02"
+gamma = 0.999
+grid = 1
+ip = 0
+backtest = 0

+ 0 - 0
exchange/__init__.py


+ 540 - 0
exchange/binance_coin_swap_rest.py

@@ -0,0 +1,540 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import hmac
+import base64
+import hashlib
+import traceback
+import random, sys
+from urllib.parse import urlparse
+import logging, logging.handlers
+import model, utils
+from decimal import Decimal
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+
+class BinanceCoinSwapRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://dapi.binance.com'
+        else:
+            self.HOST = 'https://dapi.binance.com'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        if len(self.params.pair.split('_')) > 2:
+            self.delivery = self.params.pair.split('_')[2] # 210924
+            self.symbol += f"_{self.delivery}"
+        else:
+            self.symbol += '_PERP'
+        self.data = {}
+        self._SESSIONS = dict()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.stepSize = None
+        self.tickSize = None
+        self.delays = []
+        self.avg_delay = 0
+        self.max_delay = 0
+        self.proxy = None
+        self.broker_id = self.params.broker_id
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.multiplier = None
+        self.mp = 0.0 # 初始mp 在check_position中更新
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    async def _request(self, method, uri, body=None, params=None, HOST=None):
+
+        headers = {}
+        headers["Content-Type"] = "application/json"
+        headers['X-MBX-APIKEY'] = self.params.access_key
+
+        params['timestamp']=int(time.time())*1000
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in params.keys()])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        if HOST == None:
+            url = self.HOST + uri
+        else:
+            url = HOST + uri
+        # 发起请求
+        timeout = aiohttp.ClientTimeout(10)
+        session = self._get_session(url)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json()
+            if code not in (200, 201, 202, 203, 204, 205, 206):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                self.logger.error(res)
+                return None, str(res)
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            return res, None
+        except Exception as e:
+            print('网络请求错误')
+            print(f'URL:{url} PARAMS:{params} ERROR:{e}')
+            self.logger.error(e)
+            self.logger.error(traceback.format_exc())
+            return None, str(e)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='LIMIT'):
+        '''
+            下单接口
+            input amount 单位为 币 需要转换为 张
+        '''
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+        # price = float(Decimal(str(price//self.tickSize))*Decimal(str(self.tickSize)))
+        # amount = amount * price / self.multiplier
+        # amount = float(Decimal(str(amount//self.stepSize))*Decimal(str(self.stepSize)))
+        # 转换为张
+        if origin_side =='kd':
+            side = 'BUY'
+            positionSide = 'LONG'
+        elif origin_side =='pd':
+            side = 'SELL'
+            positionSide = 'LONG'
+        elif origin_side =='kk':
+            side = 'SELL'
+            positionSide = 'SHORT'
+        elif origin_side =='pk':
+            side = 'BUY'
+            positionSide = 'SHORT'
+        else:
+            raise Exception(f'下单参数错误 side:{origin_side}')
+        if amount <= 0: 
+            self.logger.error(f'下单参数错误 amount:{amount}')
+            return None
+        if price <= 0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            return None
+        params = {
+            'symbol': symbol, 
+            'quantity': amount,
+            'side': side, 
+            'positionSide': positionSide, 
+            'type':order_type,
+            'newClientOrderId':cid,
+        }
+        if order_type in ['LIMIT','STOP','TAKE_PROFIT']:
+            params['price'] = price
+            params['timeInForce'] = 'GTC'
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None, None
+        else:
+            # 再报单
+            response, error = await self._request('POST', '/dapi/v1/order', params=params)
+            # 再更新
+            if response is not None:
+                if 'orderId' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = params["newClientOrderId"]
+                    order_event['order_id'] = response['orderId']
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0
+                order_event['filled'] = 0
+                order_event['client_id'] = params["newClientOrderId"]
+                self.callback["onOrder"](order_event)
+        return response, error
+    
+    async def cancel_order(self, order_id=None, client_id=None, symbol=None):
+        params = {
+            "symbol": self.symbol if symbol==None else symbol,
+        }
+        if order_id:
+            params["orderId"] = order_id
+        if client_id:
+            params["origClientOrderId"] = client_id
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None
+        else:
+            response, error = await self._request('DELETE', f'/dapi/v1/order', params=params)
+            if error:
+                print("撤单失败",error)
+                # if 'Unknown order sent.' in error: 
+                # 撤单失败 可能已经撤单 是否发生成交需要rest查
+                if client_id:await self.check_order(client_id=client_id)
+                if order_id:await self.check_order(order_id=order_id)
+                return error
+            if response:
+                if 'status' in response:
+                    if response['status'] in ['CANCELED','EXPIRED']:  # 已撤销 删除本地订单表
+                        order_event = dict()
+                        order_event['status'] = "REMOVE"
+                        order_event['client_id'] = response["clientOrderId"]
+                        order_event['order_id'] = response["orderId"]
+                        order_event['filled'] = float(response["executedQty"]) * self.multiplier / float(response["price"])
+                        order_event['filled_price'] = float(response["price"])
+                        self.callback['onOrder'](order_event)
+                return response
+
+    async def check_order(self, order_id=None, client_id=None, symbol=None):
+        params = {
+            "symbol": self.symbol if symbol==None else symbol,
+        }
+        if order_id:
+            params["orderId"] = order_id
+        if client_id:
+            params["origClientOrderId"] = client_id
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None
+        else:
+            response, error = await self._request('GET', f'/dapi/v1/order', params=params)
+            if error:
+                print("查单失败", error)
+                if 'Order does not exist' in error: 
+                    # 这种情况也可能还会有成交
+                    # 在订单从引擎到数据库的间隙查单会提示不存在 但实际有成交
+                    pass
+                return error
+            if response:
+                if 'status' in response:
+                    # 需要删除本地订单表的情况
+                    if response['status'] in ['CANCELED','EXPIRED','FILLED']:  
+                        order_event = dict()
+                        order_event['status'] = "REMOVE"
+                        order_event['client_id'] = response["clientOrderId"]
+                        order_event['order_id'] = response["orderId"]
+                        order_event['filled'] = float(response["executedQty"]) * self.multiplier / float(response["price"])
+                        order_event['filled_price'] = float(response["price"])
+                        self.callback['onOrder'](order_event)
+                    elif response['status'] in ["NEW"]: # 需要更新本地表的情况
+                        order_event = dict()
+                        order_event['status'] = "NEW"
+                        order_event['client_id'] = response["clientOrderId"]
+                        order_event['order_id'] = response['orderId']
+                        self.callback['onOrder'](order_event)
+                return response
+    
+    async def get_order_list(self):
+        '''
+            获取挂单表
+        '''
+        response, error = await self._request('GET', '/dapi/v1/openOrders', params={'symbol':self.symbol})
+        orders = [] # 查询当前挂单 只可能出现 new 和 partfill 默认成交为0 只有 done状态的订单才考虑是否有成交
+        if response:
+            for i in response:
+                order_event = dict()
+                order_event['status'] = "NEW"
+                order_event['filled'] = 0
+                order_event['filled_price'] = 0
+                order_event['client_id'] = i["clientOrderId"]
+                order_event['order_id'] = i['orderId']
+                self.callback["onOrder"](order_event)
+                orders.append(order_event)
+        if error:
+            print('查询列表出错',error)
+        return orders
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/dapi/v1/time', params=params)
+        return response
+
+    async def change_pos_side(self, dual='true'):
+        '''
+            获取仓位模式
+        '''
+        params = {'dualSidePosition':dual}
+        response = await self._request('POST', '/dapi/v1/positionSide/dual', params=params)
+        return response
+
+    async def get_ticker(self):
+        params = {'symbol': self.symbol}
+        response = await self._request('GET', '/dapi/v1/ticker/bookTicker', params=params)
+        ap = float(response[0][0]["askPrice"])
+        bp = float(response[0][0]["bidPrice"])
+        mp = (ap+bp)*0.5
+        d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+        return d
+
+    async def before_trade(self):
+        params = {}
+        response, error = await self._request('GET', '/dapi/v1/exchangeInfo', params=params)
+        if response:
+            for i in response['symbols']:
+                if self.symbol in i['symbol'].upper():
+                    self.tickSize = float(i['filters'][0]['tickSize'])
+                    self.stepSize = float(i['filters'][1]['stepSize'])
+                    self.multiplier = float(i['contractSize'])
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['symbol'].upper()
+                exchange_info.multiplier = float(i['contractSize'])
+                exchange_info.tickSize = float(i['filters'][0]['tickSize'])
+                exchange_info.stepSize = float(i['filters'][1]['stepSize'])
+                self.exchange_info[exchange_info.symbol] = exchange_info
+        if error:
+            print('获取市场信息错误',error)
+        ###
+        ticker = await self.get_ticker()
+        ###
+        res = await self.get_account()
+        if res:
+            for i in res:
+                if self.base == i['asset'].upper():
+                    self.data['equity'] = float(i['balance']) * ticker["mp"]
+                    self.callback['onEquity'](self.data['equity'])
+        if error:
+            print('获取账户信息错误',error)
+    
+    async def universalTransfer(self, _type='UMFUTURE_MAIN', asset='USDT', amount=0):
+        params = {}
+        params['type'] = _type
+        params['asset'] = asset
+        params['amount'] = amount
+        print('发起提现')
+        response = await self._request('POST', '/sapi/v1/asset/transfer', params=params, HOST='https://api.binance.com')
+        print(f'提现结果 {response}')
+        return response
+
+    async def futuresTransfer(self, _type='2', asset='USDT', amount=0):
+        '''
+            1: 现货账户向USDT合约账户划转
+            2: USDT合约账户向现货账户划转
+            3: 现货账户向币本位合约账户划转
+            4: 币本位合约账户向现货账户划转
+        '''
+        params = {}
+        params['type'] = _type
+        params['asset'] = asset
+        params['amount'] = amount
+        print('发起转账')
+        response = await self._request('POST', '/sapi/v1/futures/transfer', params=params, HOST='https://api.binance.com')
+        print(f'转账结果 {response}')
+        return response
+
+    async def get_account(self):
+        response, error = await self._request('GET','/dapi/v1/balance', params={})
+        return response
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+
+    async def check_position(self, hold_coin=0):
+        '''检查是否存在非运行币种的仓位并take平仓'''
+        self.logger.info('检查遗漏订单')
+        response, error = await self._request('GET', '/dapi/v1/openOrders', params={'symbol':self.symbol})
+        self.logger.info(response)
+        self.logger.info(error)
+        if error:self.logger.error(error)
+        if response:
+            for i in response:
+                symbol = i['symbol']
+                order_id = i['orderId']
+                if symbol == self.symbol: 
+                    res = await self.cancel_order(order_id=i['orderId'], symbol=i['symbol'])
+                    await asyncio.sleep(0.1)
+                    self.logger.info(res)
+        self.logger.info('检查遗漏仓位')
+        response, error = await self._request('GET','/dapi/v2/positionRisk', params={})
+        if error is not None:self.logger.error(error)
+        if response:
+            for i in response:
+                symbol = i['symbol']
+                if symbol == self.symbol:
+                    if i['positionSide'] == 'LONG':
+                        longPos = abs(float(i['positionAmt']))
+                        longAvg = abs(float(i['entryPrice']))
+                        if longPos > 0:
+                            self.logger.info('发现多头遗留仓位 进行平仓')
+                            res, error = await self.take_order(
+                                symbol,
+                                longPos,
+                                'pd',
+                                1,
+                                utils.get_cid(),
+                                order_type='MARKET',
+                            )
+                            self.logger.info(res)
+                            self.logger.info(error)
+                    if i['positionSide'] == 'SHORT':
+                        shortPos = abs(float(i['positionAmt']))
+                        shortAvg = abs(float(i['entryPrice']))
+                        if shortPos > 0:
+                            self.logger.info('发现空头遗留仓位 进行平仓')
+                            res, error = await self.take_order(
+                                symbol,
+                                shortPos,
+                                'pk',
+                                1,
+                                utils.get_cid(),
+                                order_type='MARKET',
+                            )
+                            self.logger.info(res)
+                            self.logger.info(error)
+        self.logger.info('遗留仓位检测完毕')
+        return
+
+    async def get_position(self):
+        '''
+            获取仓位信息
+        '''
+        response, error = await self._request('GET','/dapi/v2/positionRisk', params={'symbol':self.symbol})
+        longPos, shortPos = 0, 0
+        longAvg, shortAvg = 0, 0
+        position = {
+            "longPos":abs(longPos),
+            "shortPos":abs(shortPos), 
+            "longAvg":abs(longAvg), 
+            "shortAvg":abs(shortAvg)}
+        if response:
+            for i in response:
+                if i['symbol'] == self.symbol:
+                    if i['positionSide'] == 'LONG':
+                        longPos = float(i['positionAmt'])
+                        longAvg = float(i['entryPrice'])
+                    if i['positionSide'] == 'SHORT':
+                        shortPos = float(i['positionAmt'])
+                        shortAvg = float(i['entryPrice'])
+            position = model.Position()
+            position.longPos = abs(longPos)
+            position.longAvg = abs(longAvg)
+            position.shortPos = abs(shortPos)
+            position.shortAvg = abs(shortAvg)
+            self.callback['onPosition'](position)
+        return position
+
+    async def go(self):
+        '''
+            盘前
+            获取市场信息
+            获取账户信息
+            更改仓位模式(期货)
+            清空仓位和挂单
+            盘中
+            更新账户信息
+            更新挂单列表
+            更新仓位信息
+            更新延迟信息
+        '''
+        print('Rest循环器启动')
+        interval = 60  # 不能太快防止占用限频
+        ### beforeTrade
+        await self.before_trade()
+        await asyncio.sleep(1)
+        await self.change_pos_side()
+        await asyncio.sleep(1)
+        ### onTrade
+        loop = 0
+        while 1:
+            loop += 1
+            try:
+                # 更新账户
+                res = await self.get_account()
+                if res:
+                    for i in res:
+                        if self.quote == i['asset'].upper():
+                            self.data['equity'] = float(i['balance'])
+                            self.callback['onEquity'](self.data['equity'])
+                # 更新仓位
+                position = await self.get_position()
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(30)
+
+    def get_data(self):
+        return self.data
+    
+    async def handle_signals(self, orders):
+        '''
+            执行策略指令
+            撤销订单
+            检查订单
+            下达订单
+        '''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if cid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(client_id=cid))
+                    elif oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    cid = orders[order_name][0]
+                    # oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(client_id=cid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+

+ 345 - 0
exchange/binance_coin_swap_ws.py

@@ -0,0 +1,345 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys
+import logging, logging.handlers
+import utils
+import model
+
+def empty_call(msg):
+    pass
+
+
+def inflate(data):
+    '''
+        解压缩数据
+    '''
+    decompress = zlib.decompressobj(-zlib.MAX_WBITS)
+    inflated = decompress.decompress(data)
+    inflated += decompress.flush()
+    return inflated
+
+class BinanceCoinSwapWs:
+
+    def __init__(self, params: model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://dstream.binance.com/ws/'
+        else:
+            self.URL = 'wss://dstream.binance.com/ws/'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        if len(self.params.pair.split('_')) > 2:
+            self.delivery = self.params.pair.split('_')[2] # 210924
+            self.symbol += f"_{self.delivery}"
+        else:
+            self.symbol += '_PERP'
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.multiplier = None
+        self.stop_flag = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://dapi.binance.com/dapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        login_str = await response.text()
+        await session.close()
+        return ujson.loads(login_str)['listenKey']
+
+    async def get_depth_flash(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        url = f'https://dapi.binance.com/dapi/v1/depth?symbol={self.symbol}&limit=1000'
+        session = aiohttp.ClientSession()
+        response = await session.get(
+            url, 
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        depth_flash = await response.text()
+        await session.close()
+        return ujson.loads(depth_flash)
+
+    def _update_ticker(self, msg):
+        msg = ujson.loads(msg)
+        bp = float(msg['b'])
+        bq = float(msg['B'])
+        ap = float(msg['a'])
+        aq = float(msg['A'])
+        self.ticker_info['bp'] = bp
+        self.ticker_info['ap'] = ap
+        self.callback['onTicker'](self.ticker_info)
+        #### 标准化深度
+        self.depth = [bp,bq,ap,aq]
+        self.callback['onDepth']({'name':self.name,'data':self.depth})
+        
+    def _update_depth(self, msg):
+        msg = ujson.loads(msg)
+        ##### on ticker event
+        self.ticker_info['bp'] = float(msg['b'][0][0])
+        self.ticker_info['ap'] = float(msg['a'][0][0])
+        self.callback['onTicker'](self.ticker_info)
+        ##### 标准化深度
+        mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+        step = mp * utils.EFF_RANGE / utils.LEVEL
+        bp = []
+        ap = []
+        bv = [0 for _ in range(utils.LEVEL)]
+        av = [0 for _ in range(utils.LEVEL)]
+        for i in range(utils.LEVEL):
+            bp.append(self.ticker_info["bp"]-step*i)
+        for i in range(utils.LEVEL):
+            ap.append(self.ticker_info["ap"]+step*i)
+        # 
+        price_thre = self.ticker_info["bp"] - step
+        index = 0
+        for bid in msg['b']:
+            price = float(bid[0])
+            amount = float(bid[1])
+            if price > price_thre:
+                bv[index] += amount
+            else:
+                price_thre -= step
+                index += 1
+                if index == utils.LEVEL:
+                    break
+                bv[index] += amount
+        price_thre = self.ticker_info["ap"] + step
+        index = 0
+        for ask in msg['a']:
+            price = float(ask[0])
+            amount = float(ask[1])
+            if price < price_thre:
+                av[index] += amount
+            else:
+                price_thre += step
+                index += 1
+                if index == utils.LEVEL:
+                    break
+                av[index] += amount
+        self.depth = bp + bv + ap + av
+        self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    def _update_trade(self, msg):
+        msg = ujson.loads(msg)
+        side = 'sell' if msg['m'] else 'buy'
+        price = float(msg['p'])
+        amount = float(msg['q'])
+        if price > self.max_buy or self.max_buy == 0.0:
+            self.max_buy = price
+        if price < self.min_sell or self.min_sell == 0.0:
+            self.min_sell = price
+        if side == 'buy':
+            self.buy_q += amount
+            self.buy_v += amount*price
+        elif side == 'sell':
+            self.sell_q += amount
+            self.sell_v += amount*price
+
+    def _update_account(self, msg):
+        msg = ujson.loads(msg)
+        for i in msg['a']['B']:
+            if i['a'].lower() == self.base.lower(): 
+                self.callback['onEquity']({self.quote:float(i['wb'])})
+    
+    def _update_order(self, msg):
+        msg = ujson.loads(msg)
+        i = msg['o']
+        if i['s'] == self.symbol:
+            if i['X'] == 'NEW':  # 新增订单
+                pass
+            if i['X'] == 'FILLED':  # 删除订单
+                self.callback['onOrder']({"deleteOrder":i['i']})
+            if i['X'] == 'CANCELED':  # 删除订单
+                self.callback['onOrder']({"deleteOrder":i['i']})
+
+    def _update_position(self, msg):
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0 
+        msg = ujson.loads(msg)
+        for i in msg['a']['P']:
+            if i['s'] == self.symbol:
+                if i['ps'] == 'LONG':
+                    long_pos += abs(float(i['pa']))
+                    long_avg = abs(float(i['ep']))
+                if i['ps'] == 'SHORT':
+                    short_pos += abs(float(i['pa']))
+                    short_avg = abs(float(i['ep']))
+        pos = model.Position()
+        pos.longPos = long_pos
+        pos.shortPos = short_pos
+        pos.longAvg = long_avg
+        pos.shortAvg = short_avg
+        self.callback['onPosition'](pos)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0
+        self.min_sell = 0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def before_trade(self):
+        session = aiohttp.ClientSession()
+        response = await session.get(
+            "https://dapi.binance.com/dapi/v1/exchangeInfo", 
+            proxy=self.proxy
+            )
+        response = await response.json()
+        for i in response['symbols']:
+            if self.symbol in i['symbol'].upper():
+                self.multiplier = float(i['contractSize'])
+        
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        session = aiohttp.ClientSession()
+        while True:
+            try:
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                if is_auth:
+                    listenKey = await self.get_sign()
+                else:
+                    listenKey = 'qqlh'
+                # 更新
+                await self.before_trade()
+                async with session.ws_connect(
+                        self.URL+listenKey,
+                        proxy=self.proxy,
+                        timeout=30,
+                        receive_timeout=30,
+                        ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.info(f'{self.name} ws连接成功')
+                    # 订阅
+                    symbol = self.symbol.lower()
+                    if sub_fast:
+                        channels=[f"{symbol}@bookTicker",]
+                    else:
+                        channels=[f"{symbol}@depth20@100ms",]
+                    if sub_trade:
+                        channels.append(f"{symbol}@aggTrade")                        
+                    sub_str = ujson.dumps({"method": "SUBSCRIBE", "params": channels, "id":random.randint(1,1000)})
+                    await _ws.send_str(sub_str)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = msg.data
+                        # 处理消息
+                        if 'depthUpdate' in msg:self._update_depth(msg)
+                        elif 'aggTrade' in msg:self._update_trade(msg)
+                        elif 'bookTicker' in msg:self._update_ticker(msg)
+                        elif 'ACCOUNT_UPDATE' in msg:self._update_position(msg)
+                        elif 'ACCOUNT_UPDATE' in msg:self._update_account(msg)
+                        elif 'ORDER_TRADE_UPDATE' in msg:self._update_order(msg)
+                        elif 'ping' in msg:await _ws.send_str('pong')
+                        elif 'listenKeyExpired' in msg:raise Exception('key过期重连')
+            except:
+                _ws = None
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+
+
+

+ 573 - 0
exchange/binance_spot_rest.py

@@ -0,0 +1,573 @@
+
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import hmac
+import base64
+import hashlib
+import traceback
+import random, sys
+from urllib.parse import urlparse
+import logging, logging.handlers
+import utils
+import model
+from decimal import Decimal
+from decimal import ROUND_HALF_UP, ROUND_FLOOR
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+
+class BinanceSpotRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://api.binance.com'
+        else:
+            self.HOST = 'https://api.binance.com'
+        self.params = params
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        if len(self.params.pair.split('_')) > 2:
+            self.delivery = self.params.pair.split('_')[2] # 210924
+            self.symbol += f"_{self.delivery}"
+        self.name = self.params.name
+        self._SESSIONS = dict()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onTicker":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.stepSize = None
+        self.tickSize = None
+        self.delays = []
+        self.avg_delay = 0
+        self.max_delay = 0
+        self.proxy = None
+        self.broker_id = self.params.broker_id
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    async def _request(self, method, uri, body=None, params=None, HOST=None):
+
+        headers = {}
+        headers["Content-Type"] = "application/json"
+        headers['X-MBX-APIKEY'] = self.params.access_key
+
+        if params != None:
+            params['timestamp']=int(time.time())*1000
+            query_string = "&".join(["{}={}".format(k, params[k]) for k in params.keys()])
+            signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+            params['signature']=signature
+        if HOST == None:
+            url = self.HOST + uri
+        else:
+            url = HOST + uri
+        # 发起请求
+        timeout = aiohttp.ClientTimeout(10)
+        session = self._get_session(url)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json() 
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                if code == 429 or code == 418:
+                    self.callback['onExit'](f"{self.name} 即将触发限频封禁 紧急退出")
+                return None, str(res)
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            return res, None
+        except Exception as e:
+            print('网络请求错误')
+            print(f'URL:{url} PARAMS:{params} ERROR:{e}')
+            self.logger.error(e)
+            self.logger.error(traceback.format_exc())
+            return None, str(e)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='LIMIT'):
+        '''
+            下单接口
+        '''
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            # amount = float(Decimal(str(amount//self.exchange_info[symbol].stepSize))*Decimal(str(self.exchange_info[symbol].stepSize)))
+            # price = float(Decimal(str(price//self.exchange_info[symbol].tickSize))*Decimal(str(self.exchange_info[symbol].tickSize)))
+            amount = utils.fix_amount(amount, self.exchange_info[symbol].stepSize)
+            price = utils.fix_price(price, self.exchange_info[symbol].tickSize)
+        if origin_side =='kd':
+            side = 'BUY'
+        elif origin_side =='pd':
+            side = 'SELL'
+        elif origin_side =='kk':
+            side = 'SELL'
+        elif origin_side =='pk':
+            side = 'BUY'
+        else:
+            raise Exception(f'下单参数错误 side:{origin_side}')
+        if float(amount) <= 0.0:
+            self.logger.error(f'下单参数错误 amount:{amount}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None, 'amount error'
+        if float(price) <= 0.0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None, 'price error'
+        params = {
+            'symbol': symbol, 
+            'quantity': utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            'side': side, 
+            'type':order_type,
+            'newClientOrderId':cid,
+        }
+        if order_type in ['LIMIT','STOP','TAKE_PROFIT']:
+            params['price'] = utils.num_to_str(price, self.exchange_info[symbol].tickSize)
+            params['timeInForce'] = 'GTC'
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None, None
+        else:
+            # 再报单
+            response, error = await self._request('POST', '/api/v3/order', params=params)
+            # 再更新
+            if response is not None:
+                if 'orderId' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = params["newClientOrderId"]
+                    order_event['order_id'] = response['orderId']
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0
+                order_event['filled'] = 0
+                order_event['client_id'] = params["newClientOrderId"]
+                self.callback["onOrder"](order_event)
+        return response, error
+    
+    async def cancel_order(self, order_id=None, client_id=None, symbol=None):
+        params = {
+            "symbol": self.symbol if symbol==None else symbol,
+        }
+        if order_id:
+            params["orderId"] = order_id
+        if client_id:
+            params["origClientOrderId"] = client_id
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None
+        else:
+            response, error = await self._request('DELETE', f'/api/v3/order', params=params)
+            if error:
+                print("撤单失败",error)
+                # if 'Unknown order sent.' in error: 
+                # 撤单失败 可能已经撤单 是否发生成交需要rest查
+                # if client_id:await self.check_order(client_id=client_id)
+                # if order_id:await self.check_order(order_id=order_id)
+                return error
+            if response:
+                pass
+                # if 'status' in response:
+                #     if response['status'] in ['CANCELED','EXPIRED']:  # 已撤销 删除本地订单表
+                #         order_event = dict()
+                #         order_event['status'] = "REMOVE"
+                #         order_event['client_id'] = response["origClientOrderId"]
+                #         order_event['order_id'] = response["orderId"]
+                #         order_event['filled'] = float(response["executedQty"])
+                #         order_event['filled_price'] = float(response["cummulativeQuoteQty"])/float(response["executedQty"]) if float(response["executedQty"]) > 0 else 0
+                #         self.callback['onOrder'](order_event)
+                return response
+
+    async def check_order(self, order_id=None, client_id=None, symbol=None):
+        params = {
+            "symbol": self.symbol if symbol==None else symbol,
+        }
+        if order_id:
+            params["orderId"] = order_id
+        if client_id:
+            params["origClientOrderId"] = client_id
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None
+        else:
+            response, error = await self._request('GET', f'/api/v3/order', params=params)
+            if error:
+                print("查单失败", error)
+                if 'Order does not exist' in error: 
+                    # 这种情况也可能还会有成交
+                    # 在订单从引擎到数据库的间隙查单会提示不存在 但实际有成交
+                    pass
+                return error
+            if response:
+                if 'status' in response:
+                    # 需要删除本地订单表的情况
+                    if response['status'] in ['CANCELED','EXPIRED','FILLED']:  
+                        order_event = dict()
+                        order_event['status'] = "REMOVE"
+                        order_event['client_id'] = response["clientOrderId"]
+                        order_event['order_id'] = response["orderId"]
+                        order_event['fee'] = 0.0   # 查询订单信息中没有手续费信息
+                        order_event['filled'] = float(response["executedQty"])
+                        order_event['filled_price'] = float(response["cummulativeQuoteQty"])/float(response["executedQty"]) if float(response["executedQty"]) > 0 else 0
+                        self.callback['onOrder'](order_event)
+                    elif response['status'] in ["NEW"]: # 需要更新本地表的情况
+                        order_event = dict()
+                        order_event['status'] = "NEW"
+                        order_event['client_id'] = response["clientOrderId"]
+                        order_event['order_id'] = response['orderId']
+                        self.callback['onOrder'](order_event)
+                return response
+    
+    async def get_order_list(self):
+        '''
+            获取挂单表
+        '''
+        response, error = await self._request('GET', '/api/v3/openOrders', params={'symbol':self.symbol})
+        orders = [] # 查询当前挂单 只可能出现 new 和 partfill 默认成交为0 只有 done状态的订单才考虑是否有成交
+        if response:
+            for i in response:
+                order_event = dict()
+                order_event['status'] = "NEW"
+                order_event['filled'] = 0
+                order_event['filled_price'] = 0
+                order_event['client_id'] = i["clientOrderId"]
+                order_event['order_id'] = i['orderId']
+                self.callback["onOrder"](order_event)
+                orders.append(order_event)
+        if error:
+            print('查询列表出错',error)
+        return orders
+
+    async def get_history_order(self):
+        params = {
+            "symbol":self.symbol,
+            "limit":1000,
+            "startTime":1635815135000,
+            "endTime":1635901535000,
+        }
+        response, error = await self._request('GET', '/api/v3/allOrders', params=params)
+        import json
+        fp = open("123.csv", "w")
+        json.dump(response,fp)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/api/v3/time', params=params)
+        return response
+
+    async def before_trade(self):
+        response, error = await self._request('GET', '/api/v3/exchangeInfo', params=None)
+        if response:
+            for i in response['symbols']:
+                if self.symbol in i['symbol'].upper():
+                    self.tickSize = float(i['filters'][0]['tickSize'])
+                    self.stepSize = float(i['filters'][1]['stepSize'])
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['symbol'].upper()
+                exchange_info.multiplier = 1
+                exchange_info.tickSize = float(i['filters'][0]['tickSize'])
+                exchange_info.stepSize = float(i['filters'][1]['stepSize'])
+                self.exchange_info[exchange_info.symbol] = exchange_info
+        if error:
+            print('获取市场信息错误',error)
+
+    async def get_equity(self):
+        res, err = await self.get_account()
+        if res:
+            for i in res['balances']:
+                if self.quote == i['asset'].upper():
+                    cash = float(i['free']) + float(i['locked'])
+                    self.callback['onEquity']({
+                        self.quote:cash
+                    })
+                    self.cash_value = cash
+                if i['asset'].upper() == self.base:
+                    coin = float(i['free']) + float(i['locked'])
+                    self.callback['onEquity']({
+                        self.base:coin
+                    })
+                    self.coin_value = coin
+        if err:
+            print('获取账户信息错误',err)
+        await asyncio.sleep(1)
+    
+    async def universalTransfer(self, _type='UMFUTURE_MAIN', asset='USDT', amount=0):
+        params = {}
+        params['type'] = _type
+        params['asset'] = asset
+        params['amount'] = amount
+        print('发起提现')
+        response = await self._request('POST', '/sapi/v3/asset/transfer', params=params, HOST='https://api.binance.com')
+        print(f'提现结果 {response}')
+        return response
+
+    async def futuresTransfer(self, _type='2', asset='USDT', amount=0):
+        '''
+            1: 现货账户向USDT合约账户划转
+            2: USDT合约账户向现货账户划转
+            3: 现货账户向币本位合约账户划转
+            4: 币本位合约账户向现货账户划转
+        '''
+        params = {}
+        params['type'] = _type
+        params['asset'] = asset
+        params['amount'] = amount
+        print('发起转账')
+        response = await self._request('POST', '/sapi/v3/futures/transfer', params=params, HOST='https://api.binance.com')
+        print(f'转账结果 {response}')
+        return response
+
+    async def get_account(self):
+        return await self._request('GET','/api/v3/account', params={})
+
+    async def get_ticker(self):
+        res ,err = await self._request('GET', '/api/v3/ticker/bookTicker', params=None)
+        if res:
+            for i in res:
+                if i['symbol'] == self.symbol:
+                    ap = float(i['bidPrice'])
+                    bp = float(i['askPrice'])
+                    mp = (ap+bp)*0.5
+                    d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+                    self.callback['onTicker'](d)
+                    return d
+        if err:
+            self.logger.error(err)
+            return None 
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            重置账户挂单和仓位 已支持全品种
+        '''
+        try:
+            self.logger.info('检查遗漏订单')
+            response, error = await self._request('GET', '/api/v3/openOrders', params={})
+            self.logger.info(response)
+            self.logger.info(error)
+            if error:self.logger.error(error)
+            if response:
+                for i in response:
+                    res = await self.cancel_order(order_id=i['orderId'], symbol=i['symbol'])
+                    await asyncio.sleep(0.1)
+                    self.logger.info(res)
+            self.logger.info('检查遗漏仓位')
+            ###########
+            res ,err = await self._request('GET', '/api/v3/ticker/bookTicker', params=None)
+            tickers_mp = dict()
+            if res:
+                for i in res:
+                    ap = float(i['bidPrice'])
+                    bp = float(i['askPrice'])
+                    mp = (ap+bp)*0.5
+                    tickers_mp[i['symbol']] = mp 
+            if err:
+                self.logger.error(err)
+            ###########
+            if self.exchange_info == dict():
+                await self.before_trade()
+            ###########
+            response, error = await self._request('GET','/api/v3/account', params={})
+            if error is not None:self.logger.error(error)
+            if response:
+                for i in response['balances']:
+                    asset = i['asset']
+                    if asset in ['BNB','USDT', 'TUSD']:
+                        continue
+                    symbol = asset + 'USDT'
+                    if symbol not in tickers_mp:
+                        continue
+                    coin = abs(float(i['free']))+abs(float(i['locked']))
+                    if coin == 0.0:
+                        continue
+                    mp = tickers_mp[symbol]
+                    coin_value = coin * mp
+                    if symbol == self.symbol:
+                        _hold_coin = hold_coin
+                    else:
+                        _hold_coin = 0                
+                    diff = _hold_coin - coin_value
+                    diff *= 0.99 # 避免无法下单
+                    self.logger.info(f'需要调整现货仓位{diff}usd')
+                    if diff > 20.0:
+                        self.logger.info('买入现货')
+                        res, err = await self.take_order(
+                            symbol,
+                            diff/mp,
+                            'kd',
+                            1,
+                            utils.get_cid(),
+                            'MARKET'
+                        )
+                        self.logger.info(res)
+                        self.logger.info(err)
+                    elif diff < -20.0:
+                        self.logger.info('卖出现货')
+                        res, err = await self.take_order(
+                            symbol,
+                            -diff/mp,
+                            'kk',
+                            1,
+                            utils.get_cid(),
+                            'MARKET'
+                        )
+                        self.logger.info(res)
+                        self.logger.info(err)
+            self.logger.info('遗留仓位检测完毕')
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def go(self):
+        '''
+            盘前
+            获取市场信息
+            获取账户信息
+            更改仓位模式(期货)
+            清空仓位和挂单
+            盘中
+            更新账户信息
+            更新挂单列表
+            更新仓位信息
+            更新延迟信息
+        '''
+        print('Rest循环器启动')
+        interval = 60  # 不能太快防止占用限频
+        ### beforeTrade
+        await self.before_trade()
+        await asyncio.sleep(1)
+        ### onTrade
+        loop = 0
+        while 1:
+            loop += 1
+            try:
+                # 停机信号
+                if self.stop_flag:
+                    return
+                # 更新账户
+                res, err = await self.get_account()
+                if res:
+                    for i in res['balances']:
+                        if self.quote == i['asset'].upper():
+                            self.callback['onEquity']({
+                                self.quote: float(i['free']) + float(i['locked'])
+                            })
+                        if i['asset'].upper() == self.base:
+                            self.callback['onEquity']({
+                                self.base: float(i['free']) + float(i['locked'])
+                            })
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except asyncio.CancelledError:
+                return
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(30)
+    
+    async def handle_signals(self, orders):
+        '''
+            执行策略指令
+            撤销订单
+            检查订单
+            下达订单
+        '''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if cid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(client_id=cid))
+                    elif oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    cid = orders[order_name][0]
+                    # oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(client_id=cid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+
+

+ 521 - 0
exchange/binance_spot_ws.py

@@ -0,0 +1,521 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys, utils
+import logging, logging.handlers
+import model
+
+def empty_call(msg):
+    pass
+
+
+def inflate(data):
+    '''
+        解压缩数据
+    '''
+    decompress = zlib.decompressobj(-zlib.MAX_WBITS)
+    inflated = decompress.decompress(data)
+    inflated += decompress.flush()
+    return inflated
+
+
+class BinanceSpotWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://stream.binance.com:9443/ws'
+        else:
+            self.URL = 'wss://stream.binance.com:9443/ws'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.public_update_time = time.time()
+        self.private_update_time = time.time()
+        self.expired_time = 300
+        ### 更新id
+        self.update_flag_e = 0
+        self.update_flag_u = 0
+        ### 
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.depth = []
+        ####
+        self.depth_update = []
+        self.need_flash = 1
+        self.lastUpdateId = None # 就是小写u
+        self.depth_full = dict()
+        self.depth_full['bids'] = dict()
+        self.depth_full['asks'] = dict()
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        url = 'https://api.binance.com/api/v3/userDataStream'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=None,
+            headers=headers, 
+            timeout=10, 
+            proxy=self.proxy
+            )
+        self.logger.debug("申请key")
+        login_str = await response.text()
+        self.logger.debug(login_str)
+        await session.close()
+        return ujson.loads(login_str)['listenKey']
+
+    async def long_key(self,listenKey):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'listenKey':listenKey,
+        }
+        url = 'https://api.binance.com/api/v3/userDataStream'
+        session = aiohttp.ClientSession()
+        response = await session.put(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        self.logger.debug("续期key")
+        login_str = await response.text()
+        self.logger.debug(login_str)
+        await session.close()
+        return ujson.loads(login_str)
+
+    def _check_update_e(self, id):
+        if id > self.update_flag_e:
+            self.update_flag_e = id
+            return 0
+        else:
+            return 1
+
+    def _check_update_u(self, id):
+        if id > self.update_flag_u:
+            self.update_flag_u = id
+            return 0
+        else:
+            return 1
+
+    # @timeit
+    def _update_depth20(self, msg):
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        if self._check_update_u(msg['lastUpdateId']):
+            return
+        else:
+            # 更新ticker信息但不触发
+            self.ticker_info["bp"] = float(msg['bids'][0][0])
+            self.ticker_info["ap"] = float(msg['asks'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for i in msg['bids']:
+                price = float(i[0])
+                amount = float(i[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for i in msg['asks']:
+                price = float(i[0])
+                amount = float(i[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    def _update_depth(self, msg):
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        self.depth_update.append(msg)
+        if self.need_flash == 0: # 可以更新深度
+            for i in self.depth_update[:]:
+                u = i['u']
+                U = i['U']
+                # print(f'处理 {u}')
+                if u < self.lastUpdateId: # 丢弃过旧的信息
+                    self.depth_update.remove(i)
+                else:
+                    if u >= self.lastUpdateId+1 and U <= self.lastUpdateId+1: # 后续更新本地副本
+                        if U != self.lastUpdateId + 1:
+                            self.need_flash = 1
+                            self.logger.error('发现遗漏增量深度推送 重置绝对深度')
+                            return
+                        # print(f'符合要求 {u}')
+                        # 开始更新深度
+                        for j in i['b']:
+                            price = float(j[0])
+                            amount = float(j[1])
+                            if amount > 0:
+                                self.depth_full['bids'][price] = amount 
+                            else:
+                                if price in self.depth_full['bids']:del(self.depth_full['bids'][price])
+                        for j in i['a']:
+                            price = float(j[0])
+                            amount = float(j[1])
+                            if amount > 0:
+                                self.depth_full['asks'][price] = amount 
+                            else:
+                                if price in self.depth_full['asks']:del(self.depth_full['asks'][price])
+                        self.depth_update.remove(i)
+                        self.lastUpdateId = u
+                    else:
+                        self.logger.error('增量深度不满足文档要求的条件')
+            buyP = list(self.depth_full['bids'].keys())
+            buyP.sort(reverse=True) # 从大到小
+            sellP = list(self.depth_full['asks'].keys())
+            sellP.sort(reverse=False) # 从小到大
+            # update ticker
+            self.ticker_info["bp"] = float(buyP[0])
+            self.ticker_info["ap"] = float(sellP[0])
+            self.callback['onTicker'](self.ticker_info)
+            if self.ticker_info["bp"] > self.ticker_info["ap"]:
+                self.need_flash = 1
+            ##### normalized depth
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for price in buyP:
+                if price > price_thre:
+                    bv[index] += self.depth_full['bids'][price]
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += self.depth_full['bids'][price]
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for price in sellP:
+                if price < price_thre:
+                    av[index] += self.depth_full['asks'][price]
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += self.depth_full['asks'][price]
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    def _update_ticker(self, msg):
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        if self._check_update_u(msg['u']):
+            return
+        else:
+            bp = float(msg['b'])
+            bq = float(msg['B'])
+            ap = float(msg['a'])
+            aq = float(msg['A'])
+            self.ticker_info['bp'] = bp
+            self.ticker_info['ap'] = ap
+            self.callback['onTicker'](self.ticker_info)
+            #### 标准化深度
+            self.depth = [bp,bq,ap,aq]
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+    
+    def _update_trade(self, msg):
+        '''
+            binance spot 无法和depth比对时间戳 放弃修正depth
+        '''
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        price = float(msg['p'])
+        amount = float(msg['q'])
+        side = 'sell' if msg['m'] else 'buy'
+        if price > self.max_buy or self.max_buy == 0.0:
+            self.max_buy = price
+        if price < self.min_sell or self.min_sell == 0.0:
+            self.min_sell = price
+        if side == 'buy':
+            self.buy_q += amount
+            self.buy_v += amount*price
+        elif side == 'sell':
+            self.sell_q += amount
+            self.sell_v += amount*price
+        #### 修正ticker ####
+        # side = 'sell' if msg['m'] else 'buy'
+        # if side == 'buy' and price > self.ticker_info['ap']:
+        #     self.ticker_info['ap'] = price
+        #     self.callback['onTicker'](self.ticker_info)
+        # if side == 'sell' and price < self.ticker_info['bp']:
+        #     self.ticker_info['bp'] = price
+        #     self.callback['onTicker'](self.ticker_info)
+
+    def _update_account(self, msg):
+        msg = ujson.loads(msg)
+        for i in msg['B']:
+            if i['a'] == self.base:
+                coin = float(i['f'])+float(i['l'])
+                self.callback['onEquity'] = {
+                    self.base:coin
+                }
+            if i['a'] == self.quote:
+                cash = float(i['f'])+float(i['l'])
+                self.callback['onEquity'] = {
+                    self.quote:cash
+                }
+        self.private_update_time = time.time()
+    
+    def _update_order(self, msg):
+        '''将ws收到的订单信息触发quant'''
+        msg = ujson.loads(msg)
+        self.logger.debug(f"ws订单推送 {msg}")
+        data = msg
+        if self.symbol in data['s']:
+            order_event = dict()
+            status = data['X']
+            if status == "NEW": # 新增
+                local_status = "NEW"
+            elif status in ["CANCELED", "FILLED", "EXPIRED"]: # 删除 
+                local_status = "REMOVE"  
+            elif status in ["PARTIALLY_FILLED"]: # 忽略
+                return 
+            else:
+                print("未知订单状态",data)
+                return
+            order_event['status'] = local_status 
+            order_event['filled_price'] = float(data['Z'])/float(data['z']) if float(data['z']) > 0.0 else 0.0
+            order_event['filled'] = float(data['z'])
+            if data['C'] == '':
+                cid = data['c']
+            else:
+                cid = data['C']
+            order_event['client_id'] = cid
+            order_event['order_id'] = data['i']
+            order_event['fee'] = float(data['n'])
+            self.callback['onOrder'](order_event)
+        self.private_update_time = time.time()
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+    
+    async def get_depth_flash(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        url = f'https://api.binance.com/api/v3/depth?symbol={self.symbol}&limit=1000'
+        session = aiohttp.ClientSession()
+        response = await session.get(
+            url, 
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        depth_flash = await response.text()
+        await session.close()
+        return ujson.loads(depth_flash)
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                if self.stop_flag == 1:
+                    return
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                # 重置更新时间
+                self.public_update_time = time.time()
+                self.private_update_time = time.time()
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                ws_url = self.URL
+                if is_auth:
+                    listenKey = await self.get_sign()
+                    listenKeyTime = time.time()
+                    ws_url += '/'+listenKey
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.info(f'{self.name} ws连接成功')
+                    # 订阅 币安 现货 bbo没有事件标记 无法区分
+                    symbol = self.symbol.lower()
+                    if sub_fast:
+                        channels=[f"{symbol}@bookTicker",]
+                    else:
+                        channels=[
+                            # f"{symbol}@depth@100ms",
+                            f"{symbol}@depth20@100ms",
+                            ]
+                    if sub_trade:
+                        channels.append(f"{symbol}@aggTrade")
+                    sub_str = ujson.dumps({"method": "SUBSCRIBE", "params": channels, "id":random.randint(1,1000)})
+                    await _ws.send_str(sub_str)
+                    self.need_flash = 1
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=10)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = msg.data
+                        # 处理消息
+                        # if 'depthUpdate' in msg:self._update_depth(msg)
+                        if 'lastUpdateId' in msg:self._update_depth20(msg)
+                        elif 'aggTrade' in msg:self._update_trade(msg)
+                        elif 'A' in msg and 'B' in msg and 'e' not in msg:self._update_ticker(msg)
+                        elif 'outboundAccountPosition' in msg:self._update_account(msg)
+                        elif 'executionReport' in msg:self._update_order(msg)
+                        elif 'ping' in msg:await _ws.send_str('pong')
+                        elif 'listenKeyExpired' in msg or 'expired' in str(msg).lower():
+                            raise Exception('key过期重连')
+                        if is_auth:
+                            if time.time() - listenKeyTime > 60*15: # 每15分钟续一次
+                                print('续期listenKey')
+                                await self.long_key(listenKey)
+                                listenKeyTime = time.time()
+                            if time.time() - self.private_update_time > self.expired_time*5:
+                                raise Exception('长期未更新私有信息重连')
+                        if time.time() - self.public_update_time > self.expired_time:
+                            raise Exception('长期未更新公有信息重连')
+                        # if self.need_flash:
+                        #     print('rest获取绝对深度')
+                        #     depth_flash = await self.get_depth_flash()
+                        #     self.lastUpdateId = depth_flash['lastUpdateId']
+                        #     # 检查已有更新中是否包含
+                        #     self.depth_full['bids'] = dict()
+                        #     self.depth_full['asks'] = dict()
+                        #     for i in depth_flash['bids']:self.depth_full['bids'][float(i[0])] = float(i[1])
+                        #     for i in depth_flash['asks']:self.depth_full['asks'][float(i[0])] = float(i[1])
+                        #     self.need_flash = 0
+            except:
+                _ws = None
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                # await asyncio.sleep(1)
+

+ 592 - 0
exchange/binance_usdt_swap_rest.py

@@ -0,0 +1,592 @@
+
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import hmac
+import base64
+import hashlib
+import traceback
+import random, sys
+from urllib.parse import urlparse
+import logging, logging.handlers
+import utils
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+
+class BinanceUsdtSwapRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://fapi.binance.com'
+        else:
+            self.HOST = 'https://fapi.binance.com'
+        self.params = params
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        if self.base in ['shib','xec']:
+            print("请输入1000shib/1000xec")
+        if len(self.params.pair.split('_')) > 2:
+            self.delivery = self.params.pair.split('_')[2] # 210924
+            self.symbol += f"_{self.delivery}"  
+        self.name = self.params.name
+        self._SESSIONS = dict()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onExit":empty_call,
+            "onTicker":empty_call,
+            }
+        self.exchange_info = dict()
+        self.stepSize = None
+        self.tickSize = None
+        self.delays = []
+        self.avg_delay = 0
+        self.max_delay = 0
+        self.proxy = None
+        self.broker_id = self.params.broker_id
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    async def _request(self, method, uri, body=None, params=None, HOST=None):
+
+        headers = {}
+        headers["Content-Type"] = "application/json"
+        headers['X-MBX-APIKEY'] = self.params.access_key
+
+        params['timestamp']=int(time.time())*1000
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in params.keys()])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+
+        if HOST == None:
+            url = self.HOST + uri
+        else:
+            url = HOST + uri
+
+        # 发起请求
+        timeout = aiohttp.ClientTimeout(10)
+        session = self._get_session(url)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json() 
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                if code == 429:
+                    self.callback['onExit'](f"{self.name} 即将触发限频封禁 紧急退出")
+                return None, str(res)
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            self.get_delay_info()
+            return res, None
+        except Exception as e:
+            print('网络请求错误')
+            print(f'URL:{url} PARAMS:{params} ERROR:{e}')
+            self.logger.error(e)
+            self.logger.error(traceback.format_exc())
+            return None, str(e)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='LIMIT'):
+        '''
+            下单接口
+        '''
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            # amount = float(Decimal(str(amount//self.exchange_info[symbol].stepSize))*Decimal(str(self.exchange_info[symbol].stepSize)))
+            # price = float(Decimal(str(price//self.exchange_info[symbol].tickSize))*Decimal(str(self.exchange_info[symbol].tickSize)))
+            amount = utils.fix_amount(amount, self.exchange_info[symbol].stepSize)
+            price = utils.fix_price(price, self.exchange_info[symbol].tickSize)
+        if origin_side =='kd':
+            side = 'BUY'
+            positionSide = 'LONG'
+        elif origin_side =='pd':
+            side = 'SELL'
+            positionSide = 'LONG'
+        elif origin_side =='kk':
+            side = 'SELL'
+            positionSide = 'SHORT'
+        elif origin_side =='pk':
+            side = 'BUY'
+            positionSide = 'SHORT'
+        else:
+            raise Exception(f'下单参数错误 side:{origin_side}')
+        if amount <= 0.0:
+            self.logger.error(f'下单参数错误 amount:{amount}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        if price <= 0.0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        params = {
+            'symbol': symbol, 
+            'quantity': utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            'side': side, 
+            'positionSide': positionSide, 
+            'type':order_type,
+            'newClientOrderId':cid,
+        }
+        if order_type in ['LIMIT','STOP','TAKE_PROFIT']:
+            params['price'] = utils.num_to_str(price, self.exchange_info[symbol].tickSize)
+
+            # if origin_side in ['kd', 'kk']:
+            params['timeInForce'] = 'GTC'
+            # else:    
+            #     params['timeInForce'] = 'GTC'
+
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None, None
+        else:
+            # 再报单
+            response, error = await self._request('POST', '/fapi/v1/order', params=params)
+            # 再更新
+            if response is not None:
+                if 'orderId' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = params["newClientOrderId"]
+                    order_event['order_id'] = response['orderId']
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0
+                order_event['filled'] = 0
+                order_event['client_id'] = params["newClientOrderId"]
+                self.callback["onOrder"](order_event)
+        return response, error
+    
+    async def cancel_order(self, order_id=None, client_id=None, symbol=None):
+        params = {
+            "symbol": self.symbol if symbol==None else symbol,
+        }
+        if order_id:
+            params["orderId"] = order_id
+        if client_id:
+            params["origClientOrderId"] = client_id
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None
+        else:
+            response, error = await self._request('DELETE', f'/fapi/v1/order', params=params)
+            if error:
+                print("撤单失败",error)
+                # if 'Unknown order sent.' in error: 
+                # 撤单失败 可能已经撤单 是否发生成交需要rest查
+                # if client_id:await self.check_order(client_id=client_id)
+                # if order_id:await self.check_order(order_id=order_id)
+                return error
+            if response:
+                # if 'status' in response:
+                #     if response['status'] in ['CANCELED','EXPIRED']:  # 已撤销 删除本地订单表
+                #         order_event = dict()
+                #         order_event['status'] = "REMOVE"
+                #         order_event['client_id'] = response["clientOrderId"]
+                #         order_event['order_id'] = response["orderId"]
+                #         order_event['filled'] = float(response["executedQty"])
+                #         order_event['filled_price'] = float(response["cumQuote"])/float(response["executedQty"]) if float(response["executedQty"]) > 0 else 0
+                #         self.callback['onOrder'](order_event)
+                return response
+
+    async def check_order(self, order_id=None, client_id=None, symbol=None):
+        params = {
+            "symbol": self.symbol if symbol==None else symbol,
+        }
+        if order_id:
+            params["orderId"] = order_id
+        if client_id:
+            params["origClientOrderId"] = client_id
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None
+        else:
+            response, error = await self._request('GET', f'/fapi/v1/order', params=params)
+            if error:
+                print("查单失败", error)
+                if 'Order does not exist' in error: 
+                    # 这种情况也可能还会有成交
+                    # 在订单从引擎到数据库的间隙查单会提示不存在 但实际有成交
+                    pass
+                return error
+            if response:
+                if 'status' in response:
+                    # 需要删除本地订单表的情况
+                    if response['status'] in ['CANCELED','EXPIRED','FILLED']:  
+                        order_event = dict()
+                        order_event['status'] = "REMOVE"
+                        order_event['client_id'] = response["clientOrderId"]
+                        order_event['order_id'] = response["orderId"]
+                        order_event['filled'] = float(response["executedQty"])
+                        order_event['filled_price'] = float(response["cumQuote"])/float(response["executedQty"]) if float(response["executedQty"]) > 0 else 0
+                        self.callback['onOrder'](order_event)
+                    elif response['status'] in ["NEW"]: # 需要更新本地表的情况
+                        order_event = dict()
+                        order_event['status'] = "NEW"
+                        order_event['client_id'] = response["clientOrderId"]
+                        order_event['order_id'] = response['orderId']
+                        self.callback['onOrder'](order_event)
+                return response
+    
+    async def get_order_list(self):
+        '''
+            获取挂单表
+        '''
+        response, error = await self._request('GET', '/fapi/v1/openOrders', params={'symbol':self.symbol})
+        orders = [] # 查询当前挂单 只可能出现 new 和 partfill 默认成交为0 只有 done状态的订单才考虑是否有成交
+        if response:
+            for i in response:
+                order_event = dict()
+                order_event['status'] = "NEW"
+                order_event['filled'] = 0
+                order_event['filled_price'] = 0
+                order_event['client_id'] = i["clientOrderId"]
+                order_event['order_id'] = i['orderId']
+                self.callback["onOrder"](order_event)
+                orders.append(order_event)
+        if error:
+            print('查询列表出错',error)
+        return orders
+
+    async def get_history_order(self):
+        params = {
+            "limit":1000,
+            "startTime":1631260140000,
+            "endTime":1631260200000,
+        }
+        response, error = await self._request('GET', '/fapi/v1/allOrders', params=params)
+        import json
+        fp = open("123.csv", "w")
+        json.dump(response,fp)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/fapi/v1/time', params=params)
+        return response
+
+    async def change_pos_side(self, dual='true'):
+        '''
+            获取仓位模式
+        '''
+        params = {'dualSidePosition':dual}
+        response = await self._request('POST', '/fapi/v1/positionSide/dual', params=params)
+        return response
+
+    async def get_exchange_info(self):
+        response, error = await self._request('GET', '/fapi/v1/exchangeInfo', params={})
+        return response, error
+        
+    async def before_trade(self):
+        ##########
+        response, error = await self.get_exchange_info()
+        if response:
+            for i in response['symbols']:
+                if self.symbol in i['symbol'].upper():
+                    self.stepSize = float(i['filters'][1]['stepSize'])
+                    self.tickSize = float(i['filters'][0]['tickSize'])
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['symbol'].upper()
+                exchange_info.multiplier = 1
+                exchange_info.stepSize = float(i['filters'][1]['stepSize'])
+                exchange_info.tickSize = float(i['filters'][0]['tickSize'])
+                self.exchange_info[exchange_info.symbol] = exchange_info
+        if error:
+            print('获取市场信息错误',error)
+        await self.change_pos_side()
+        await asyncio.sleep(1)
+        await self.get_position()
+
+    async def get_equity(self):
+        ##########
+        res, err = await self.get_account()
+        if res:
+            for i in res:
+                if self.quote == i['asset'].upper():
+                    self.callback['onEquity']({self.quote:float(i['balance'])})
+                    self.cash_value = float(i['balance'])
+        if err:
+            print('获取账户信息错误', err)
+        ##########
+    
+    async def universalTransfer(self, _type='UMFUTURE_MAIN', asset='USDT', amount=0):
+        params = {}
+        params['type'] = _type
+        params['asset'] = asset
+        params['amount'] = amount
+        print('发起提现')
+        response = await self._request('POST', '/sapi/v1/asset/transfer', params=params, HOST='https://api.binance.com')
+        print(f'提现结果 {response}')
+        return response
+
+    async def futuresTransfer(self, _type='2', asset='USDT', amount=0):
+        '''
+            1: 现货账户向USDT合约账户划转
+            2: USDT合约账户向现货账户划转
+            3: 现货账户向币本位合约账户划转
+            4: 币本位合约账户向现货账户划转
+        '''
+        params = {}
+        params['type'] = _type
+        params['asset'] = asset
+        params['amount'] = amount
+        print('发起转账')
+        response = await self._request('POST', '/sapi/v1/futures/transfer', params=params, HOST='https://api.binance.com')
+        print(f'转账结果 {response}')
+        return response
+
+    async def get_account(self):
+        response, error = await self._request('GET','/fapi/v2/balance', params={})
+        return response, error
+
+    async def get_ticker(self):
+        response, error = await self._request('GET','/fapi/v1/ticker/bookTicker', params={"symbol":self.symbol})
+        if response:
+            if response['symbol'] == self.symbol:
+                ap = float(response['bidPrice'])
+                bp = float(response['askPrice'])
+                mp = (ap+bp)*0.5
+                d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+                self.callback['onTicker'](d)
+                return d
+        if error:
+            print(error)
+            return None
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            检查是否存在非运行币种的仓位并take平仓  已支持全品种
+        '''
+        try:
+            self.logger.info('检查遗漏订单')
+            response, error = await self._request('GET', '/fapi/v1/openOrders', params={})
+            self.logger.info(response)
+            self.logger.info(error)
+            if error:self.logger.error(error)
+            if response:
+                for i in response:
+                    res = await self.cancel_order(order_id=i['orderId'], symbol=i['symbol'])
+                    await asyncio.sleep(0.1)
+                    self.logger.info(res)
+            # 清空仓位
+            self.logger.info('检查遗漏仓位')
+            response, error = await self._request('GET','/fapi/v2/positionRisk', params={})
+            if error is not None:self.logger.error(error)
+            if response:
+                for i in response:
+                    symbol = i['symbol']
+                    if i['positionSide'] == 'LONG':
+                        longPos = abs(float(i['positionAmt']))
+                        longAvg = abs(float(i['entryPrice']))
+                        if longPos > 0:
+                            self.logger.info('发现多头遗留仓位 进行平仓')
+                            res, err = await self.take_order(
+                                symbol,
+                                longPos,
+                                'pd',
+                                longAvg*0.95,
+                                utils.get_cid(),
+                                order_type='MARKET',
+                            )
+                            self.logger.info(res)
+                            self.logger.info(err)
+                    if i['positionSide'] == 'SHORT':
+                        shortPos = abs(float(i['positionAmt']))
+                        shortAvg = abs(float(i['entryPrice']))
+                        if shortPos > 0:
+                            self.logger.info('发现空头遗留仓位 进行平仓')
+                            res, err = await self.take_order(
+                                symbol,
+                                shortPos,
+                                'pk',
+                                shortAvg*1.05,
+                                utils.get_cid(),
+                                order_type='MARKET',
+                            )
+                            self.logger.info(res)
+                            self.logger.info(err)
+            self.logger.info('遗留仓位检测完毕')
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def get_position(self):
+        '''
+            获取仓位信息
+        '''
+        response, error = await self._request('GET','/fapi/v2/positionRisk', params={'symbol':self.symbol})
+        longPos, shortPos = 0, 0
+        longAvg, shortAvg = 0, 0
+        position = {
+            "longPos":abs(longPos),
+            "shortPos":abs(shortPos), 
+            "longAvg":abs(longAvg), 
+            "shortAvg":abs(shortAvg)}
+        if response:
+            for i in response:
+                if i['symbol'] == self.symbol:
+                    if i['positionSide'] == 'LONG':
+                        longPos = float(i['positionAmt'])
+                        longAvg = float(i['entryPrice'])
+                    if i['positionSide'] == 'SHORT':
+                        shortPos = float(i['positionAmt'])
+                        shortAvg = float(i['entryPrice'])
+            position = model.Position()
+            position.longPos = abs(longPos)
+            position.longAvg = abs(longAvg)
+            position.shortPos = abs(shortPos)
+            position.shortAvg = abs(shortAvg)
+            self.callback['onPosition'](position)
+        return position
+
+    async def go(self):
+        '''
+            盘前
+            获取市场信息
+            获取账户信息
+            更改仓位模式(期货)
+            清空仓位和挂单
+            盘中
+            更新账户信息
+            更新挂单列表
+            更新仓位信息
+            更新延迟信息
+        '''
+        print('Rest循环器启动')
+        interval = 60  # 不能太快防止占用限频
+        ### beforeTrade
+        await self.before_trade()
+        await asyncio.sleep(1)
+        ### onTrade
+        loop = 0
+        while 1:
+            loop += 1
+            try:
+                # 停机信号
+                if self.stop_flag:
+                    return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:
+                    print(err)
+                if res:
+                    for i in res:
+                        if self.quote == i['asset'].upper():
+                            self.callback['onEquity']({self.quote:float(i['balance'])})
+                # 更新仓位
+                position = await self.get_position()
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except asyncio.CancelledError:
+                return
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(30)
+    
+    async def handle_signals(self, orders):
+        '''
+            执行策略指令
+            撤销订单
+            检查订单
+            下达订单
+        '''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if cid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(client_id=cid))
+                    elif oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    cid = orders[order_name][0]
+                    # oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(client_id=cid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+
+

+ 562 - 0
exchange/binance_usdt_swap_ws.py

@@ -0,0 +1,562 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys, utils
+import logging, logging.handlers
+import model
+
+def empty_call(msg):
+    pass
+
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        nowTime = time.time()
+        res = func(*args, **kwargs)
+        spend_time = time.time() - nowTime
+        spend_time = round(spend_time * 1000, 5)
+        print(f'{func.__name__} 耗时 {spend_time} ms')
+        return res
+    return wrapper
+
+
+def inflate(data):
+    '''
+        解压缩数据
+    '''
+    decompress = zlib.decompressobj(-zlib.MAX_WBITS)
+    inflated = decompress.decompress(data)
+    inflated += decompress.flush()
+    return inflated
+
+class BinanceUsdtSwapWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://fstream.binance.com/ws/'
+        else:
+            self.URL = 'wss://fstream.binance.com/ws/'
+        self.params = params
+        self.name = self.params.name
+        self.base = params.pair.split('_')[0].upper()
+        self.quote = params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        if len(self.params.pair.split('_')) > 2:
+            self.delivery = self.params.pair.split('_')[2] # 210924
+            self.symbol += f"_{self.delivery}" 
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.public_update_time = time.time()
+        self.private_update_time = time.time()
+        self.expired_time = 300
+        ### 更新id
+        self.update_flag_u = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.depth = []
+        ####
+        self.depth_update = []
+        self.need_flash = 1
+        self.lastUpdateId = None # 就是小写u
+        self.depth_full = dict()
+        self.depth_full['bids'] = dict()
+        self.depth_full['asks'] = dict()
+        self.decimal = 99
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://fapi.binance.com/fapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        self.logger.debug("申请key")
+        login_str = await response.text()
+        print(login_str)
+        self.logger.debug(login_str)
+        await session.close()
+        try:
+            return ujson.loads(login_str)['listenKey']
+        except:
+            self.logger.error('登录失败')
+            return 'qqlh'
+
+    async def long_key(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://fapi.binance.com/fapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.put(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        self.logger.debug("续期key")
+        login_str = await response.text()
+        self.logger.debug(login_str)
+        await session.close()
+        return ujson.loads(login_str)
+
+    def _check_update_u(self, id):
+        if id > self.update_flag_u:
+            self.update_flag_u = id
+            return 0
+        else:
+            return 1
+
+    def _update_ticker(self, msg):
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        if self._check_update_u(msg['u']):
+            return
+        else:
+            bp = float(msg['b'])
+            bq = float(msg['B'])
+            ap = float(msg['a'])
+            aq = float(msg['A'])
+            self.ticker_info["bp"] = bp
+            self.ticker_info["ap"] = ap
+            self.callback['onTicker'](self.ticker_info)
+            ### 标准化深度
+            self.depth = [bp,bq,ap,aq]
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    # @timeit
+    def _update_depth20(self, msg):
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        if self._check_update_u(msg['u']):
+            return
+        else:
+            # 更新ticker信息但不触发
+            self.ticker_info["bp"] = float(msg['b'][0][0])
+            self.ticker_info["ap"] = float(msg['a'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            if self.decimal == 99:self.decimal = utils.num_to_decimal(msg['b'][0][0])
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = round(mp * utils.EFF_RANGE / utils.LEVEL, self.decimal)
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(round(self.ticker_info["bp"]-step*i, self.decimal))
+            for i in range(utils.LEVEL):
+                ap.append(round(self.ticker_info["ap"]+step*i, self.decimal))
+            ###############################################
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for i in msg['b']:
+                price = float(i[0])
+                amount = float(i[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for i in msg['a']:
+                price = float(i[0])
+                amount = float(i[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    # @timeit
+    def _update_depth(self, msg):
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        self.depth_update.append(msg)
+        ### 检查是否有遗漏
+        for i in range(1,len(self.depth_update)):
+            if self.depth_update[i]['pu'] != self.depth_update[i]['u']:
+                self.need_flash = 1
+                self.logger.error('发现遗漏增量深度推送 重置绝对深度')
+                return
+        # print(len(self.depth_update))
+        if self.need_flash == 0: # 可以更新深度
+            for i in self.depth_update[:]:
+                u = i['u']
+                U = i['U']
+                pu = i['pu']
+                # print(f'处理 {u}')
+                if u < self.lastUpdateId: # 丢弃过旧的信息
+                    self.depth_update.remove(i)
+                else:
+                    if u >= self.lastUpdateId: # 后续更新本地副本
+                        # print(f'符合要求 {u}')
+                        # 开始更新深度
+                        for j in i['b']:
+                            price = float(j[0])
+                            amount = float(j[1])
+                            if amount > 0:
+                                self.depth_full['bids'][price] = amount 
+                            else:
+                                if price in self.depth_full['bids']:del(self.depth_full['bids'][price])
+                        for j in i['a']:
+                            price = float(j[0])
+                            amount = float(j[1])
+                            if amount > 0:
+                                self.depth_full['asks'][price] = amount 
+                            else:
+                                if price in self.depth_full['asks']:del(self.depth_full['asks'][price])
+                        self.depth_update.remove(i)
+                        self.lastUpdateId = u
+                    else:
+                        self.logger.error('增量深度不满足文档要求的条件')
+            buyP = list(self.depth_full['bids'].keys())
+            buyP.sort(reverse=True) # 从大到小
+            sellP = list(self.depth_full['asks'].keys())
+            sellP.sort(reverse=False) # 从小到大
+            # update ticker
+            self.ticker_info["bp"] = float(buyP[0])
+            self.ticker_info["ap"] = float(sellP[0])
+            self.callback['onTicker'](self.ticker_info)
+            if self.ticker_info["bp"] > self.ticker_info["ap"]:
+                self.need_flash = 1
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for price in buyP:
+                if price > price_thre:
+                    bv[index] += self.depth_full['bids'][price]
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += self.depth_full['bids'][price]
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for price in sellP:
+                if price < price_thre:
+                    av[index] += self.depth_full['asks'][price]
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += self.depth_full['asks'][price]
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+        
+    
+    def _update_trade(self, msg):
+        '''
+            根据trade修正depth对性能消耗很大
+        '''
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        price = float(msg['p'])
+        if self.decimal == 99:self.decimal=utils.num_to_decimal(price)
+        amount = float(msg['q'])
+        side = 'sell' if msg['m'] else 'buy'
+        if price > self.max_buy or self.max_buy == 0.0:
+            self.max_buy = price
+        if price < self.min_sell or self.min_sell == 0.0:
+            self.min_sell = price
+        if side == 'buy':
+            self.buy_q += amount
+            self.buy_v += amount*price
+        elif side == 'sell':
+            self.sell_q += amount
+            self.sell_v += amount*price
+        #### 修正ticker ####
+        # side = 'sell' if msg['m'] else 'buy'
+        # if side == 'buy' and price > self.ticker_info['ap']:
+        #     self.ticker_info['ap'] = price
+        #     self.callback['onTicker'](self.ticker_info)
+        # if side == 'sell' and price < self.ticker_info['bp']:
+        #     self.ticker_info['bp'] = price
+        #     self.callback['onTicker'](self.ticker_info)
+
+    def _update_account(self, msg):
+        self.private_update_time = time.time()
+        msg = ujson.loads(msg)
+        for i in msg['a']['B']:
+            if i['a'] == self.quote:
+                self.callback['onEquity']({self.quote:float(i['wb'])})
+    
+    def _update_order(self, msg):
+        '''将ws收到的订单信息触发quant'''
+        msg = ujson.loads(msg)
+        self.logger.debug(f"ws订单推送 {msg}")
+        data = msg['o']
+        if self.symbol in data['s']:
+            order_event = dict()
+            status = data['X']
+            if status == "NEW": # 新增
+                local_status = "NEW"
+            elif status in ["CANCELED", "FILLED", "EXPIRED"]: # 删除 
+                local_status = "REMOVE"  
+            elif status in ["PARTIALLY_FILLED"]: # 忽略
+                return 
+            else:
+                print("未知订单状态",data)
+                return 
+            order_event['status'] = local_status 
+            order_event['filled_price'] = float(data['ap'])
+            order_event['filled'] = float(data['z'])
+            order_event['client_id'] = data['c']
+            order_event['order_id'] = data['i']
+            self.callback['onOrder'](order_event)
+        self.private_update_time = time.time()
+
+    def _update_position(self, msg):
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0 
+        msg = ujson.loads(msg)
+        is_update = 0
+        for i in msg['a']['P']:
+            if i['s'] == self.symbol:
+                is_update = 1
+                if i['ps'] == 'LONG':
+                    long_pos += abs(float(i['pa']))
+                    long_avg = abs(float(i['ep']))
+                if i['ps'] == 'SHORT':
+                    short_pos += abs(float(i['pa']))
+                    short_avg = abs(float(i['ep']))
+        if is_update:
+            pos = model.Position()
+            pos.longPos = long_pos
+            pos.longAvg = long_avg
+            pos.shortPos = short_pos
+            pos.shortAvg = short_avg
+            self.callback['onPosition'](pos)
+        self.private_update_time = time.time()
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def get_depth_flash(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        url = f'https://fapi.binance.com/fapi/v1/depth?symbol={self.symbol}&limit=1000'
+        session = aiohttp.ClientSession()
+        response = await session.get(
+            url, 
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        depth_flash = await response.text()
+        await session.close()
+        return ujson.loads(depth_flash)
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                # 重置更新时间
+                self.public_update_time = time.time()
+                self.private_update_time = time.time()
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                if is_auth:
+                    listenKey = await self.get_sign()
+                    listenKeyTime = time.time()
+                else:
+                    listenKey = 'qqlh'
+                ws_url = self.URL+listenKey
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.debug(f'{self.name} ws连接成功')
+                    # 订阅
+                    symbol = self.symbol.lower()
+                    if sub_fast:
+                        channels=[
+                            f"{symbol}@bookTicker",
+                            ]
+                    else:
+                        channels=[
+                            # f"{symbol}@depth@100ms",
+                            f"{symbol}@depth20@100ms",
+                            ]
+                    if sub_trade:
+                        channels.append(f"{symbol}@aggTrade")
+                    sub_str = ujson.dumps({"method": "SUBSCRIBE", "params": channels, "id":random.randint(1,1000)})
+                    await _ws.send_str(sub_str)
+                    self.need_flash = 1
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = msg.data
+                        # 处理消息
+                        # if 'depthUpdate' in msg:self._update_depth(msg)
+                        if 'depthUpdate' in msg:self._update_depth20(msg)
+                        elif 'bookTicker' in msg:self._update_ticker(msg)
+                        elif 'aggTrade' in msg:self._update_trade(msg)
+                        elif 'ACCOUNT_UPDATE' in msg:
+                            self._update_position(msg)
+                            self._update_account(msg)
+                        elif 'ORDER_TRADE_UPDATE' in msg:self._update_order(msg)
+                        elif 'ping' in msg:await _ws.send_str('pong')
+                        elif 'listenKeyExpired' in msg:raise Exception('key过期重连')
+                        # 续期listenkey
+                        if is_auth:
+                            if time.time() - listenKeyTime > 60*15: # 每15分钟续一次
+                                print('续期listenKey')
+                                listenKeyTime = time.time()
+                                await self.long_key()
+                            if time.time() - self.private_update_time > self.expired_time*5:
+                                raise Exception('长期未更新私有信息重连')
+                        if time.time() - self.public_update_time > self.expired_time:
+                            raise Exception('长期未更新公有信息重连')
+                        # if self.need_flash:
+                        #     depth_flash = await self.get_depth_flash()
+                        #     self.lastUpdateId = depth_flash['lastUpdateId']
+                        #     print(f'更新绝对深度 {self.lastUpdateId}')
+                        #     # 检查已有更新中是否包含
+                        #     self.depth_full['bids'] = dict()
+                        #     self.depth_full['asks'] = dict()
+                        #     for i in depth_flash['bids']:self.depth_full['bids'][float(i[0])] = float(i[1])
+                        #     for i in depth_flash['asks']:self.depth_full['asks'][float(i[0])] = float(i[1])
+                        #     self.need_flash = 0
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                await asyncio.sleep(1)
+
+

+ 588 - 0
exchange/bitget_usdt_swap_rest.py

@@ -0,0 +1,588 @@
+import random
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import hmac
+import base64
+import hashlib
+import traceback
+import urllib
+from urllib import parse
+from urllib.parse import urljoin
+import datetime, sys
+from urllib.parse import urlparse
+import logging, logging.handlers
+import utils
+import logging, logging.handlers
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+def parse_params_to_str(params):
+    url = '?'
+    for key, value in params.items():
+        url = url + str(key) + '=' + str(value) + '&'
+    return url[0:-1]
+
+class BitGetUsdtSwapRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://capi.bitget.com'
+        else:
+            self.HOST = 'https://capi.bitget.com'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote + '_UMCBL'
+        self._SESSIONS = dict()
+        self.logger = self.get_logger()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.tickSize = None
+        self.stepSize = None
+        self.delays = []
+        self.max_delay = 0
+        self.avg_delay = 0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.mp_from_rest = None
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        self.min_trade_amount = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %( message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def get_logger(self):
+
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler("log.log",maxBytes=1024*1024,encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.WARNING)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    def get_sign(self, timestamp, method, request_path, params, body, secret_key):
+        if body == None:
+            body_str = ""
+        else:
+            body_str = ujson.dumps(body)
+        if params == None:
+            params_str = ""
+        else:
+            params_str = parse_params_to_str(params)
+        message = str(timestamp) + str.upper(method) + request_path + params_str + body_str
+        mac = hmac.new(bytes(secret_key, encoding='utf-8'), bytes(message, encoding='utf-8'), digestmod='sha256').digest()
+        return str(base64.b64encode(mac), 'utf-8')
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        headers = {}
+        if auth:
+            if method in ['GET','DELETE']:
+                timestamp = str(int(time.time()*1000))
+                headers = {
+                    'ACCESS-SIGN':self.get_sign(timestamp, method, uri, params, body, self.params.secret_key),
+                    'ACCESS-KEY':self.params.access_key,
+                    'ACCESS-PASSPHRASE':self.params.pass_key,
+                    'Content-Type': 'application/json',
+                    "ACCESS-TIMESTAMP":timestamp,
+                    # 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.90 Safari/537.36',
+                }
+            elif method == 'POST':
+                timestamp = str(int(time.time()*1000))
+                headers = {
+                    'ACCESS-SIGN':self.get_sign(timestamp, method, uri, params, body, self.params.secret_key),
+                    'ACCESS-KEY':self.params.access_key,
+                    'ACCESS-PASSPHRASE':self.params.pass_key,
+                    'Content-Type': 'application/json',
+                    "ACCESS-TIMESTAMP":timestamp,
+                    # 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.90 Safari/537.36',
+                }
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, data=ujson.dumps(body), headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json(content_type=None) 
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206) or res['code'] not in ('00000'):
+                self.logger.error(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                return None, str(res)
+            return res, None
+        except Exception as e:
+            print(f'{self.name} rest 请求出错', str(e))
+            self.logger.error(f'请求错误 {msg}'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+
+    async def get_position(self):
+        response, error = await self._request('GET', '/api/mix/v1/position/singlePosition', params={"symbol":self.symbol,'marginCoin':"USDT"}, auth=1)
+        if error:
+            print(error)
+        if response:
+            position = model.Position()
+            for i in response['data']:
+                if i['symbol'] == self.symbol:
+                    side = i['holdSide']
+                    pos = float(i['total'])
+                    price = float(i['averageOpenPrice'])
+                    if side == 'long':
+                        position.longPos = pos
+                        position.longAvg = price
+                    elif side == 'short':
+                        position.shortPos = pos
+                        position.shortAvg = price
+            return position
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            现货交易  已支持全品种
+        '''
+        try:
+            #######################
+            self.logger.info("清空挂单")
+            params = {
+                'symbol':self.symbol,
+            }
+            response, error = await self._request('GET', '/api/mix/v1/order/current', params=params, auth=1)
+            if error:
+                self.logger.info(error)
+            if response:
+                for i in response['data']:
+                    oid = i['orderId']
+                    res = await self.cancel_order(order_id=oid)
+                    print(f'取消挂单{oid}',res)
+            #############################
+            self.logger.info("清空仓位")
+            res, err = await self._request('GET', '/api/mix/v1/position/allPosition', params={"productType":"umcbl"}, auth=1)
+            if err:
+                self.logger.info(err)
+            if res:
+                for i in res['data']:
+                    if float(i['total']) > 0:
+                        ####
+                        response ,error = await self._request('GET',f'/api/mix/v1/market/ticker', params={"symbol":i['symbol']}, auth=0)
+                        if response:
+                            ap = float(response["data"]['bestAsk'])
+                            bp = float(response["data"]['bestBid'])
+                            mp = (ap+bp)*0.5
+                        if error:
+                            print(error)
+                        ####
+                        if self.exchange_info == dict():
+                            await self.before_trade()
+                        ####
+                        if i['holdSide'] == 'long':
+                            _side = "pd"
+                            _price = utils.fix_price(mp * 0.999, self.exchange_info[i['symbol']].tickSize)
+                        elif i['holdSide'] == 'short':
+                            _side = "pk"
+                            _price = utils.fix_price(mp * 1.001, self.exchange_info[i['symbol']].tickSize)
+                        ####
+                        response = await self.take_order(
+                            i['symbol'],
+                            float(i['total']),
+                            _side,
+                            _price,
+                            utils.get_cid()
+                        )
+                        if response:
+                            print(response)
+            #######################        
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='limit'):
+        ''' 
+            下单
+            price: 限价时 为 必填  
+            marginCoin: 保证金币种
+            size: 限价时 为 数量  市价买 为额度 卖为数量
+            side: open_long open_short close_long close_short
+            orderType: limit(限价)  market(市价)
+            timeInForceValue:normal(普通限价订单)   postOnly(只做maker,市价不允许使用这个)  ioc(立即成交并取消剩余)  fok(全部成交或立即取消)
+            presetTakeProfitPrice: 预设止盈价格
+            presetStopLossPrice: 预设止损价格
+        '''
+        if origin_side =='kd':
+            side = "open_long"
+        elif origin_side =='pd':
+            side = "close_long"
+        elif origin_side =='kk':
+            side = "open_short"
+        elif origin_side =='pk':
+            side = "close_short"
+        else:
+            print("合约不允许此交易方向")
+            return None
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            # amount = utils.fix_amount(amount, self.exchange_info[symbol].stepSize)
+            # price = utils.fix_price(price, self.exchange_info[symbol].tickSize)
+            # amount = float(Decimal(str(amount))//Decimal(str(self.exchange_info[symbol].stepSize))*Decimal(str(self.exchange_info[symbol].stepSize)))
+            # price = float(int(Decimal(str(price))/Decimal(str(self.exchange_info[symbol].tickSize)))*Decimal(str(self.exchange_info[symbol].tickSize)))
+        if float(price) <= 0.0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0                
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        params = {
+            "clientOid":cid,
+            "marginCoin":self.quote,
+            "symbol": symbol,
+            "size":utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            "side": side,
+        }
+        if order_type == 'limit':
+            params["orderType"] = "limit"
+            params["price"] = utils.num_to_str(price, self.exchange_info[symbol].tickSize)
+        elif order_type == 'market':
+            params["orderType"] = "market"
+        # params["timeInForceValue"] = "normal"
+
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            if order_type == 'limit':
+                response, error = await self._request('POST', '/api/mix/v1/order/placeOrder', body=params, auth=1)
+            elif order_type == 'market':
+                response, error = await self._request('POST', '/api/mix/v1/order/placeOrder', body=params, auth=1)
+            # 再更新
+            if response:
+                # 增加新的
+                if 'data' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = cid
+                    order_event['order_id'] = response['data']["orderId"]
+                    self.callback["onOrder"](order_event)
+            if error:
+                # coinex swap 有时候返回错误也下单成功 很危险
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0.0                
+                order_event['fee'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['client_id'] = cid
+                self.callback["onOrder"](order_event)
+        return response
+    
+    async def cancel_order(self, order_id=None, client_id=None):
+        '''
+            symbol	String	是	产品ID 必须大写
+            marginCoin	String	是	保证金币种 必须大写
+            orderId	String	是	订单号
+        '''
+        if order_id:
+            response, error = await self._request('POST', f'/api/mix/v1/order/cancel-order', body={'symbol':self.symbol,'orderId':order_id, 'marginCoin':self.quote}, auth=1)
+        elif client_id:
+            return 'not support cid cancel'
+        else:
+            raise Exception("撤单出错 没指定订单号")
+        if response:
+            # await self.check_order(order_id=response['data']['orderId'])
+            self.logger.debug(f'撤单回报 {response}')
+            # order_event = dict()
+            # order_event['status'] = "REMOVE"
+            # order_event['filled_price'] = float(response['data']['price'])            
+            # order_event['fee'] = float(response['data']["deal_fee"])
+            # order_event['filled'] = float(response['data']['amount']) - float(response['data']['left'])
+            # order_event['client_id'] = response['data']["client_id"]
+            # self.callback["onOrder"](order_event)
+        if error:
+            print("撤单失败",error)
+            self.logger.error(error)
+        return response
+    
+    async def check_order(self, order_id=None, client_id=None):
+        '''
+            symbol	String	是	产品ID 必须大写
+            orderId	String	是	订单号
+        '''
+        if order_id:
+            response, error = await self._request('GET', f'/api/mix/v1/order/detail', params={'symbol':self.symbol, 'orderId':order_id}, auth=1)
+        elif client_id:
+            return 'not support cid check order'
+        else:
+            return
+        if response:
+            self.logger.debug(f'查单回报 {response}')
+            if response['data']:
+                i = response['data']
+                if i["state"] == 'new':  # 新增订单
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['filled'] = 0
+                    order_event['filled_price'] = 0
+                    order_event['client_id'] = i["clientOid"] if "clientOid" in i else ""
+                    order_event['order_id'] = i['orderId']
+                    order_event['fee'] = 0.0
+                    self.callback["onOrder"](order_event)
+                    # print(order_event)
+                elif i["state"] in ['filled','canceled']:  # 删除订单
+                    # fee 负数是扣手续费 bitget没有返佣
+                    order_event = dict() 
+                    order_event['status'] = "REMOVE"
+                    order_event['client_id'] = i["clientOid"] if "clientOid" in i else ""
+                    order_event['order_id'] = i['orderId']
+                    order_event['filled'] = float(i["filledQty"])
+                    order_event['filled_price'] = float(i["priceAvg"]) if "priceAvg" in i else float(i['price'])
+                    order_event['fee'] = -float(i['fee'])
+                    self.callback["onOrder"](order_event)
+                    # print(order_event)
+        if error:
+            print("查单失败",error)
+            self.logger.error(error)
+        return response
+
+    async def get_order_list(self):
+        params = {
+            'market':self.symbol,
+            'offset':100,
+            "side":0,
+            'limit':100,
+        }
+        response, error = await self._request('GET', '/perpetual/v1/order/pending', params=params, auth=1)
+        if response is not None:
+            for i in response['data']['records']:
+                order_event = dict()
+                order_event['symbol'] = self.symbol
+                order_event['price'] = float(i["price"])
+                order_event['amount'] = float(i["amount"])
+                order_event['filled'] = float(i["amount"])-float(i["left"])
+                order_event['filled_price'] = float(i["avg_price"])
+                order_event['client_id'] = i["clientOid"]
+                order_event['order_id'] = i['id']
+                asset_fee = float(response['data']["asset_fee"])
+                money_fee = float(response['data']["money_fee"])
+                stock_fee = float(response['data']["stock_fee"])
+                if asset_fee > 0.0: # 非amm品种
+                    order_event['fee'] = asset_fee
+                else: # amm品种
+                    order_event['fee'] = money_fee if money_fee > 0.0 else stock_fee
+                if response["data"]['status'] == 'not_deal':
+                    order_event['status'] = "NEW"
+                elif response["data"]['status'] in ['cancel','done']:
+                    order_event['status'] = "REMOVE"
+                else:
+                    s = response["data"]['status']
+                    self.logger.error(f"错误的订单状态 {s}")
+                self.callback["onOrder"](order_event)
+        if error:
+            print(error)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/perpetual/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        '''
+            symbol	String	是	产品ID 必须大写
+            marginCoin	String	是	保证金币种
+        '''
+        return await self._request('GET','/api/mix/v1/account/account', params={'symbol':self.symbol,'marginCoin':'USDT'}, auth=1)
+
+    async def get_market_details(self):
+        return await self._request('GET',f'/api/mix/v1/market/contracts', params={'productType':'umcbl'}, auth=0)
+
+    async def get_ticker(self):
+        res ,err = await self._request('GET',f'/api/mix/v1/market/ticker', params={"symbol":self.symbol}, auth=0)
+        if res:
+            ap = float(res["data"]['bestAsk'])
+            bp = float(res["data"]['bestBid'])
+            mp = (ap+bp)*0.5
+            d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+            self.callback['onTicker'](d)
+            return d
+        if err:
+            self.logger.error(err)
+            return None 
+
+    async def before_trade(self):
+        # 切换杠杆
+        await self.change_position_side()
+        # 获取市场最新价格
+        res = await self.get_ticker()
+        ticker_price = res["mp"]
+        if isinstance(ticker_price, float):
+            self.mp_from_rest = ticker_price
+        # 获取市场基本情况
+        res, error = await self.get_market_details()
+        if error:
+            pass 
+        else:
+            for i in res['data']:
+                if i['symbol'] == self.symbol:
+                    self.stepSize = float(Decimal("0.1")**Decimal(i["volumePlace"]))
+                    self.tickSize = float(Decimal("0.1")**Decimal(i["pricePlace"])*Decimal(i["priceEndStep"]))
+                    self.min_trade_amount = 0.0
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['symbol']
+                exchange_info.multiplier = 1
+                exchange_info.stepSize = float(Decimal("0.1")**Decimal(i["volumePlace"]))
+                exchange_info.tickSize = float(Decimal("0.1")**Decimal(i["pricePlace"])*Decimal(i["priceEndStep"]))
+                self.exchange_info[exchange_info.symbol] = exchange_info
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:print(err)
+        if res:
+            if self.quote == res['data']['marginCoin']:
+                cash = float(res['data']['equity'])
+                self.callback['onEquity']({
+                    self.quote:cash
+                })
+                self.cash_value = cash
+
+    async def change_position_side(self):
+        res ,err = await self._request(
+            'POST',
+            '/api/mix/v1/account/setMarginMode',
+            body={"symbol":self.symbol,'marginCoin':"USDT",'marginMode':'crossed'},
+            auth=1
+            )
+        if err:print(err)
+        if res:print(res)
+
+    async def go(self):
+        interval = 60  # 不能太快防止占用限频
+        await self.before_trade()
+        await asyncio.sleep(1)
+        while 1:
+            try:
+                # 停机信号
+                if self.stop_flag:return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:print(err)
+                if res:
+                    if self.quote == res['data']['marginCoin']:
+                        cash = float(res['data']['equity'])
+                        self.callback['onEquity']({
+                            self.quote:cash
+                        })
+                        self.cash_value = cash
+                # 更新仓位
+                p = await self.get_position()
+                self.callback['onPosition'](p)
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+    
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # 只能用oid撤单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # 只能用oid查单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.check_order(order_id=oid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+
+

+ 422 - 0
exchange/bitget_usdt_swap_ws.py

@@ -0,0 +1,422 @@
+from os import access
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random
+import gzip, csv, sys
+from uuid import uuid4
+import logging, logging.handlers
+
+from yarl import URL
+import utils
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    # print(msg)
+    pass
+
+def sign(message, secret_key):
+    mac = hmac.new(bytes(secret_key, encoding='utf8'), bytes(message, encoding='utf-8'), digestmod='sha256')
+    d = mac.digest()
+    return str(base64.b64encode(d), 'utf8')
+
+def pre_hash(timestamp, method, request_path):
+    return str(timestamp) + str.upper(method) + str(request_path)
+
+class BitgetUsdtSwapWs:
+
+    def __init__(self, params: model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.BaseURL = "wss://ws.bitget.com/mix/v1/stream"  
+        else:
+            self.BaseURL = "wss://ws.bitget.com/mix/v1/stream"  
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.instId = self.symbol + '_UMCBL'
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.multiplier = None
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.update_t = 0.0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = float(self.params.interval)
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://fapi.binance.com/fapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        login_str = await response.text()
+        await session.close()
+        return ujson.loads(login_str)['listenKey']
+
+    def _update_depth(self, msg):
+        t = int(msg['data'][0]['ts'])
+        if t > self.update_t:
+            self.update_t = t
+            self.ticker_info['bp'] = float(msg['data'][0]['bids'][0][0])
+            self.ticker_info['ap'] = float(msg['data'][0]['asks'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in msg['data'][0]['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in msg['data'][0]['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    # def _update_ticker(self, msg):
+    #     if msg['data']['sequence'] > self.update_t:
+    #         self.update_t = msg['data']['sequence']
+    #         self.ticker_info['bp'] = float(msg['data']['bestBidPrice'])
+    #         self.ticker_info['ap'] = float(msg['data']['bestAskPrice'])
+    #         self.callback['onTicker'](self.ticker_info)
+    
+    def _update_trade(self, msg):
+        for i in msg['data']:
+            price = float(i[1])
+            side = i[3]
+            amount = float(i[2])
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+
+    def _update_position(self, msg):
+        pos = model.Position()
+        for i in msg['data']:
+            symbol = i['instName']
+            if symbol == self.symbol:
+                amt = float(i["total"])
+                side = i['holdSide']
+                ep = float(i["averageOpenPrice"])
+                if side == 'long':
+                    pos.longPos = amt
+                    pos.longAvg = ep
+                elif side == 'short':
+                    pos.shortPos = amt
+                    pos.shortAvg = ep
+                else:
+                    pass
+        self.callback["onPosition"](pos)
+
+    def _update_account(self, msg):
+        for i in msg['data']:
+            if i['marginCoin'] == 'USDT':
+                self.callback['onEquity'] = {self.quote:float(i['equity'])}
+    
+    def _update_order(self, msg):
+        self.logger.debug(f"ws订单推送 {msg}")
+        # print(msg)
+        for i in msg['data']:
+            if self.instId == i['instId']:
+                if i["status"] == 'new':  # 新增订单
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['filled'] = 0
+                    order_event['filled_price'] = 0
+                    order_event['client_id'] = i["clOrdId"] if "clOrdId" in i else ""
+                    order_event['order_id'] = i['ordId']
+                    order_event['fee'] = 0.0
+                    self.callback["onOrder"](order_event)
+                    # print('新建',order_event['client_id'])
+                elif i["status"] in ['full-fill','cancelled']:  # 删除订单
+                    # fee 负数是扣手续费 bitget没有返佣
+                    order_event = dict() 
+                    order_event['status'] = "REMOVE"
+                    order_event['client_id'] = i["clOrdId"] if "clOrdId" in i else ""
+                    order_event['order_id'] = i['ordId']
+                    order_event['filled'] = float(i["accFillSz"])
+                    order_event['filled_price'] = float(i["fillPx"]) if 'fillPx' in i else float(i['px'])
+                    for j in i['orderFee']:
+                        if j['feeCcy'] == 'USDT':
+                            order_event['fee'] = -float(j['fee'])
+                    self.callback["onOrder"](order_event)
+                    # print('移除',order_event['client_id'])
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket']({'name': self.name,'data':market_data})
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def get_token(self, is_auth):
+        # 获取 token
+        if is_auth:
+            uri = "/api/v1/bullet-private"
+        else:
+            uri = "/api/v1/bullet-public"
+        headers = {}
+        if is_auth:
+            now_time = int(time.time()) * 1000
+            str_to_sign = str(now_time) + "POST" + uri
+            sign = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), str_to_sign.encode('utf-8'), hashlib.sha256).digest())
+            passphrase = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), self.params.pass_key.encode('utf-8'), hashlib.sha256).digest())
+            headers = {
+                "KC-API-SIGN": sign.decode(),
+                "KC-API-TIMESTAMP": str(now_time),
+                "KC-API-KEY": self.params.access_key,
+                "KC-API-PASSPHRASE": passphrase.decode(),
+                "Content-Type": "application/json",
+                "KC-API-KEY-VERSION": "2"
+            }
+        headers["User-Agent"] = "kucoin-python-sdk/v1.0"
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            self.BaseURL+uri, 
+            timeout=5, 
+            headers=headers,
+            proxy=self.proxy
+            )
+        res = await response.text()
+        res = ujson.loads(res)
+        await session.close()
+        if res["code"] == "200000":
+            token = res["data"]["token"]
+            ws_connect_id = str(uuid4()).replace('-', '')
+            endpoint = res["data"]['instanceServers'][0]['endpoint']
+            ws_endpoint = f"{endpoint}?token={token}&connectId={ws_connect_id}"
+            encrypt = res["data"]['instanceServers'][0]['encrypt']
+            if is_auth:
+                ws_endpoint += '&acceptUserMessage=true'
+            return ws_endpoint, encrypt
+        else:
+            raise Exception("kucoin usdt swap 获取token错误")
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                ping_time = time.time()
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                ws_url = self.BaseURL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.info(f'{self.name} ws连接成功')
+                    # 订阅
+                    channels=[
+                        {
+                            "instType":"mc",
+                            "channel":"books5",
+                            "instId":self.symbol
+                        }
+                        ]
+                    if sub_trade:
+                        channels += [
+                            {
+                                "instType":"MC",
+                                "channel":"trade",
+                                "instId":self.symbol
+                            }
+                        ]
+                    if is_auth:
+                        # 先登录
+                        timestamp = int(time.time())
+                        sign_str = sign(pre_hash(timestamp, "GET", '/user/verify'), self.params.secret_key)
+                        await _ws.send_str(ujson.dumps({
+                            "op":"login",
+                                "args":[
+                                    {
+                                        "apiKey":self.params.access_key,
+                                        "passphrase":self.params.pass_key,
+                                        "timestamp":timestamp,
+                                        "sign":sign_str
+                                    }
+                                ]
+                        }))
+                        # 先登录
+                        channels += [
+                            {
+                                "instType": "UMCBL",
+                                "channel": "account",
+                                "instId": "default"
+                            },
+                            {
+                                "instType": "UMCBL",
+                                "channel": "positions",
+                                "instId": "default"
+                            },
+                            {
+                                "channel": "orders",
+                                "instType": "UMCBL",
+                                "instId": "default"
+                            }
+                        ]
+                    for i in channels:
+                        sub_str = ujson.dumps({"args": [i], "op":"subscribe"})
+                        await _ws.send_str(sub_str)
+                    while True:
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        # self.logger.debug(msg)
+                        try:
+                            msg = ujson.loads(msg.data)
+                        except:
+                            # self.logger.warning(f'非json格式string:{msg}')
+                            pass
+                        # 处理消息
+                        if 'data' in msg:
+                            if 'books5' in msg['arg']['channel']:self._update_depth(msg)
+                            # elif "tickerV2" in msg["subject"]:self._update_ticker(msg)
+                            elif 'trade' in msg['arg']['channel']:self._update_trade(msg)
+                            elif 'account' in msg['arg']['channel']:self._update_account(msg)
+                            elif 'orders' in msg['arg']['channel']:self._update_order(msg)
+                            elif 'positions' in msg['arg']['channel']:self._update_position(msg)
+                        # heartbeat
+                        if time.time() - ping_time > 15:
+                            await _ws.send_str("ping")
+                            ping_time = time.time()
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                # await asyncio.sleep(1)
+

+ 658 - 0
exchange/bybit_usdt_swap_rest.py

@@ -0,0 +1,658 @@
+import imp
+import random
+import re
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import hmac
+import base64
+import hashlib
+import traceback
+import urllib
+from urllib import parse
+from urllib.parse import urljoin
+import datetime, sys
+from urllib.parse import urlparse
+import logging, logging.handlers
+import utils
+import logging, logging.handlers
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+def parse_params_to_str(params):
+    url = ''
+    for key in sorted(params.keys()):
+        value = params[key]
+        if isinstance(value, bool):
+            if params[key]:
+                value = "true"
+            else:
+                value = "false"
+        url = url + str(key) + '=' + str(value) + '&'
+    return url[0:-1]
+
+class BybitUsdtSwapRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://api.bybit.com'
+        else:
+            self.HOST = 'https://api.bybit.com'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self._SESSIONS = dict()
+        self.logger = self.get_logger()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.tickSize = None
+        self.stepSize = None
+        self.delays = []
+        self.max_delay = 0
+        self.avg_delay = 0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.mp_from_rest = None
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        self.min_trade_amount = 0.0
+        self.rate_limit_left = 120
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %( message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def get_logger(self):
+
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler("log.log",maxBytes=1024*1024,encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.WARNING)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    def get_sign(self, params, secret_key):
+        message = parse_params_to_str(params)
+        # mac = hmac.new(bytes(secret_key, encoding='utf-8'), bytes(message, encoding='utf-8'), digestmod='sha256').digest()
+
+        hash = hmac.new(bytes(secret_key, encoding='utf-8'), bytes(message, encoding='utf-8'), hashlib.sha256)
+        signature = hash.hexdigest()
+        return signature
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        headers = {}
+        if auth:
+            if method in ['GET']:
+                timestamp = str(int(time.time()*1000))
+                params['timestamp'] = timestamp
+                params['api_key'] = self.params.access_key
+                params['recv_window'] = "15000"
+                params['sign'] = self.get_sign(params, self.params.secret_key)
+                headers = {}
+            elif method == 'POST':
+                timestamp = str(int(time.time()*1000))
+                params['timestamp'] = timestamp
+                params['api_key'] = self.params.access_key
+                params['recv_window'] = "15000"
+                params['sign'] = self.get_sign(params, self.params.secret_key)
+                headers = {}
+        url = url + '?' + parse_params_to_str(params)
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, data=ujson.dumps(params), headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json(content_type=None) 
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            #### 检查频率限制
+            if 'rate_limit_status' in res:
+                self.rate_limit_left = res['rate_limit_status']
+                if self.rate_limit_left < 20:
+                    self.callback['onExit'](f"{self.name} 即将触发限频封禁 紧急退出")
+            ####
+            if code not in (200, 201, 202, 203, 204, 205, 206) or res['ret_code'] not in [0]:
+                self.logger.error(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                return None, str(res)
+            return res, None
+        except Exception as e:
+            print(f'{self.name} rest 请求出错', str(e))
+            self.logger.error(f'请求错误 {msg}'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+
+    async def get_position(self):
+        response, error = await self._request('GET', '/private/linear/position/list', params={}, auth=1)
+        if error:
+            print(error)
+        if response:
+            position = model.Position()
+            for j in response['result']:
+                i = j['data']
+                if i['symbol'] == self.symbol:
+                    side = i['side'].lower()
+                    pos = float(i['size'])
+                    price = float(i['entry_price'])
+                    if side == 'buy':
+                        position.longPos = pos
+                        position.longAvg = price
+                    elif side == 'sell':
+                        position.shortPos = pos
+                        position.shortAvg = price
+            return position
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            现货交易  已支持全品种
+        '''
+        try:
+            #######################
+            self.logger.info("清空挂单")
+            params = {
+                'symbol':self.symbol,
+            }
+            response, error = await self._request('POST', '/private/linear/order/cancel-all', params=params, auth=1)
+            if error:
+                self.logger.info("全部撤单失败")
+            if response:
+                self.logger.info("全部撤单成功")
+            #############################
+            self.logger.info("清空仓位")
+            res, err = await self._request('GET', '/private/linear/position/list', params={}, auth=1)
+            if err:
+                self.logger.info(err)
+            if res:
+                for i in res['result']:
+                    if float(i['data']['size']) > 0:
+                        ####
+                        response ,error = await self._request('GET',f'/v2/public/orderBook/L2', params={"symbol":i['data']['symbol']}, auth=0)
+                        if response:
+                            bids = []
+                            asks = []
+                            for j in response['result']:
+                                if j['side'].lower() == 'buy':
+                                    bids.append(float(j['price']))
+                                if j['side'].lower() == 'sell':
+                                    asks.append(float(j['price'])) 
+                            ap = min(asks)
+                            bp = max(bids)
+                            mp = (ap+bp)*0.5
+                        if error:
+                            print(error)
+                        ####
+                        if self.exchange_info == dict():
+                            await self.before_trade()
+                        ####
+                        if i['data']['side'].lower() == 'buy':
+                            _side = "pd"
+                            _price = utils.fix_price(mp * 0.999, self.exchange_info[i['data']['symbol']].tickSize)
+                        elif i['data']['side'].lower() == 'sell':
+                            _side = "pk"
+                            _price = utils.fix_price(mp * 1.001, self.exchange_info[i['data']['symbol']].tickSize)
+                        ####
+                        response = await self.take_order(
+                            i['data']['symbol'],
+                            float(i['data']['size']),
+                            _side,
+                            _price,
+                            utils.get_cid()
+                        )
+                        if response:
+                            print(response)
+            #######################        
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='limit'):
+        ''' 
+            下单
+            price: 限价时 为 必填  
+            marginCoin: 保证金币种
+            size: 限价时 为 数量  市价买 为额度 卖为数量
+            side: open_long open_short close_long close_short
+            orderType: limit(限价)  market(市价)
+            timeInForceValue:normal(普通限价订单)   postOnly(只做maker,市价不允许使用这个)  ioc(立即成交并取消剩余)  fok(全部成交或立即取消)
+            presetTakeProfitPrice: 预设止盈价格
+            presetStopLossPrice: 预设止损价格
+        '''
+        if origin_side =='kd':
+            side = "Buy"
+            reduce_only = "false"
+        elif origin_side =='pd':
+            side = "Sell"
+            reduce_only = "true"
+        elif origin_side =='kk':
+            side = "Sell"
+            reduce_only = "false"
+        elif origin_side =='pk':
+            side = "Buy"
+            reduce_only = "true"
+        else:
+            print("合约不允许此交易方向")
+            return None
+        if self.rate_limit_left < 40 and origin_side in ['kd','kk']:
+            print("即将触发限频 停止开仓单发单")
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0                
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            # amount = utils.fix_amount(amount, self.exchange_info[symbol].stepSize)
+            # price = utils.fix_price(price, self.exchange_info[symbol].tickSize)
+            # amount = float(Decimal(str(amount))//Decimal(str(self.exchange_info[symbol].stepSize))*Decimal(str(self.exchange_info[symbol].stepSize)))
+            # price = float(int(Decimal(str(price))/Decimal(str(self.exchange_info[symbol].tickSize)))*Decimal(str(self.exchange_info[symbol].tickSize)))
+        if float(price) <= 0.0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0                
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        if float(amount) <=0:
+            self.logger.error(f'下单参数错误 amount:{amount}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0                
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        params = {
+            "order_link_id":cid,
+            "symbol": symbol,
+            "qty":utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            "side": side,
+            "reduce_only":reduce_only,
+            "close_on_trigger":False,
+            "time_in_force":"GoodTillCancel",
+        }
+        if order_type == 'limit':
+            params["order_type"] = "Limit"
+            params["price"] = utils.num_to_str(price, self.exchange_info[symbol].tickSize)
+        elif order_type == 'market':
+            params["order_type"] = "Market"
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            if order_type == 'limit':
+                response, error = await self._request('POST', '/private/linear/order/create', params=params, auth=1)
+            elif order_type == 'market':
+                response, error = await self._request('POST', '/private/linear/order/create', params=params, auth=1)
+            # 再更新
+            if response:
+                # 增加新的
+                if 'result' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = cid
+                    order_event['order_id'] = response['result']["order_id"]
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0.0                
+                order_event['fee'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['client_id'] = cid
+                self.callback["onOrder"](order_event)
+        return response
+    
+    async def cancel_order(self, order_id=None, client_id=None):
+        '''
+            symbol	String	是	产品ID 必须大写
+            marginCoin	String	是	保证金币种 必须大写
+            orderId	String	是	订单号
+        '''
+        if order_id:
+            response, error = await self._request('POST', f'/private/linear/order/cancel', params={'symbol':self.symbol,'order_id':order_id}, auth=1)
+        elif client_id:
+            response, error = await self._request('POST', f'/private/linear/order/cancel', params={'symbol':self.symbol,'order_link_id':client_id}, auth=1)
+        else:
+            raise Exception("撤单出错 没指定订单号")
+        if response:
+            pass
+            # await self.check_order(order_id=response['data']['orderId'])
+            # self.logger.debug(f'撤单回报 {response}')
+            # order_event = dict()
+            # order_event['status'] = "REMOVE"
+            # order_event['filled_price'] = float(response['data']['price'])            
+            # order_event['fee'] = float(response['data']["deal_fee"])
+            # order_event['filled'] = float(response['data']['amount']) - float(response['data']['left'])
+            # order_event['client_id'] = response['data']["client_id"]
+            # self.callback["onOrder"](order_event)
+        if error:
+            pass
+            # print("撤单失败",error)
+            # self.logger.error(error)
+        return response
+    
+    async def check_order(self, order_id=None, client_id=None):
+        '''
+            symbol	String	是	产品ID 必须大写
+            orderId	String	是	订单号
+        '''
+        if order_id:
+            response, error = await self._request('GET', f'/private/linear/order/search', params={'symbol':self.symbol, 'order_id':order_id}, auth=1)
+        elif client_id:
+            response, error = await self._request('GET', f'/private/linear/order/search', params={'symbol':self.symbol, 'order_link_id':client_id}, auth=1)
+        else:
+            return
+        if response:
+            self.logger.debug(f'查单回报 {response}')
+            if response['result']:
+                if response['result']["order_status"] == 'New':  # 新增订单
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['filled'] = 0
+                    order_event['filled_price'] = 0
+                    order_event['client_id'] = response['result']["order_link_id"] if "order_link_id" in response['result'] else ""
+                    order_event['order_id'] = response['result']['order_id']
+                    order_event['fee'] = 0.0
+                    self.callback["onOrder"](order_event)
+                    # print(order_event)
+                elif response['result']["order_status"] in ['Filled','Cancelled']:  # 删除订单
+                    # fee 负数是扣手续费 bitget没有返佣
+                    order_event = dict() 
+                    order_event['status'] = "REMOVE"
+                    order_event['client_id'] = response['result']["order_link_id"] if "order_link_id" in response['result'] else ""
+                    order_event['order_id'] = response['result']['order_id']
+                    order_event['filled'] = float(response['result']["cum_exec_qty"])
+                    order_event['filled_price'] = float(response['result']["last_exec_price"]) \
+                        if 'last_exec_price' in response['result'] else float(response['result']['price'])
+                    order_event['fee'] = float(response['result']['cum_exec_fee'])
+                    self.callback["onOrder"](order_event)
+                    # print(order_event)
+        if error:
+            print("查单失败",error)
+            self.logger.error(error)
+        return response
+
+    async def get_order_list(self):
+        params = {
+            'market':self.symbol,
+            'offset':100,
+            "side":0,
+            'limit':100,
+        }
+        response, error = await self._request('GET', '/perpetual/v1/order/pending', params=params, auth=1)
+        if response is not None:
+            for i in response['data']['records']:
+                order_event = dict()
+                order_event['symbol'] = self.symbol
+                order_event['price'] = float(i["price"])
+                order_event['amount'] = float(i["amount"])
+                order_event['filled'] = float(i["amount"])-float(i["left"])
+                order_event['filled_price'] = float(i["avg_price"])
+                order_event['client_id'] = i["clientOid"]
+                order_event['order_id'] = i['id']
+                asset_fee = float(response['data']["asset_fee"])
+                money_fee = float(response['data']["money_fee"])
+                stock_fee = float(response['data']["stock_fee"])
+                if asset_fee > 0.0: # 非amm品种
+                    order_event['fee'] = asset_fee
+                else: # amm品种
+                    order_event['fee'] = money_fee if money_fee > 0.0 else stock_fee
+                if response["data"]['status'] == 'not_deal':
+                    order_event['status'] = "NEW"
+                elif response["data"]['status'] in ['cancel','done']:
+                    order_event['status'] = "REMOVE"
+                else:
+                    s = response["data"]['status']
+                    self.logger.error(f"错误的订单状态 {s}")
+                self.callback["onOrder"](order_event)
+        if error:
+            print(error)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/perpetual/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        '''
+            symbol	String	是	产品ID 必须大写
+            marginCoin	String	是	保证金币种
+        '''
+        return await self._request('GET','/v2/private/wallet/balance', params={}, auth=1)
+
+    async def transfer(self):
+        '''
+            
+        '''
+        from uuid import uuid4
+        params = {
+            "transfer_id":str(uuid4()),
+            "coin":"USDT",
+            "amount":"200",
+            "from_account_type":"SPOT",
+            "to_account_type":"CONTRACT",
+        }
+        return await self._request('POST','/asset/v1/private/transfer', params=params, auth=1)
+
+    async def get_all_sub_account(self):
+        '''
+            
+        '''
+        params = {
+        }
+        return await self._request('GET','/asset/v1/private/sub-member/member-ids', params=params, auth=1)
+
+    async def get_market_details(self):
+        return await self._request('GET',f'/v2/public/symbols', params={}, auth=0)
+
+    async def get_ticker(self):
+        ####
+        response ,error = await self._request('GET',f'/v2/public/orderBook/L2', params={"symbol":self.symbol}, auth=0)
+        if response:
+            bids = []
+            asks = []
+            for j in response['result']:
+                if j['side'].lower() == 'buy':
+                    bids.append(float(j['price']))
+                if j['side'].lower() == 'sell':
+                    asks.append(float(j['price'])) 
+            ap = min(asks)
+            bp = max(bids)
+            mp = (ap+bp)*0.5
+            d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+            self.callback['onTicker'](d)
+            return d
+        if error:
+            self.logger.error(error)
+            return None 
+
+    async def before_trade(self):
+        # 切换杠杆
+        await self.change_position_side()
+        # 获取市场最新价格
+        res = await self.get_ticker()
+        ticker_price = res["mp"]
+        if isinstance(ticker_price, float):
+            self.mp_from_rest = ticker_price
+        # 获取市场基本情况
+        res, error = await self.get_market_details()
+        if error:
+            pass 
+        else:
+            for i in res['result']:
+                if i['name'] == self.symbol:
+                    self.stepSize = float(i['lot_size_filter']['qty_step'])
+                    self.tickSize = float(i['price_filter']['tick_size'])
+                    self.min_trade_amount = 0.0
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['name']
+                exchange_info.multiplier = 1
+                exchange_info.stepSize = float(i['lot_size_filter']['qty_step'])
+                exchange_info.tickSize = float(i['price_filter']['tick_size'])
+                self.exchange_info[exchange_info.symbol] = exchange_info
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:print(err)
+        if res:
+            for i in res['result']:
+                if self.quote == i:
+                    cash = float(res['result'][i]['equity'])
+                    self.callback['onEquity']({
+                        self.quote:cash
+                    })
+                    self.cash_value = cash
+
+    async def change_position_side(self):
+        '''切换到全仓'''
+        res ,err = await self._request(
+            'POST',
+            '/private/linear/position/switch-mode',
+            params={
+                'symbol':self.symbol,
+                'mode':"BothSide",},
+            auth=1
+            )
+        if err:print(err)
+        if res:print(res)
+
+    async def go(self):
+        interval = 60  # 不能太快防止占用限频
+        await self.before_trade()
+        await asyncio.sleep(1)
+        while 1:
+            try:
+                # 停机信号
+                if self.stop_flag:return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:print(err)
+                if res:
+                    for i in res['result']:
+                        if self.quote == i:
+                            cash = float(res['result'][i]['equity'])
+                            self.callback['onEquity']({
+                                self.quote:cash
+                            })
+                            self.cash_value = cash
+                # 更新仓位
+                p = await self.get_position()
+                self.callback['onPosition'](p)
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+    
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # 只能用oid撤单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # 只能用oid查单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.check_order(order_id=oid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+
+

+ 544 - 0
exchange/bybit_usdt_swap_ws.py

@@ -0,0 +1,544 @@
+from os import access
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random
+import gzip, csv, sys
+from uuid import uuid4
+import logging, logging.handlers
+
+from yarl import URL
+import utils
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    # print(msg)
+    pass
+
+
+class BybitUsdtSwapWs:
+
+    def __init__(self, params: model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.BaseURL_public = "wss://stream.bybit.com/realtime_public"
+            self.BaseURL_private = "wss://stream.bybit.com/realtime_private"
+        else:
+            self.BaseURL_public = "wss://stream.bybit.com/realtime_public"  
+            self.BaseURL_private = "wss://stream.bybit.com/realtime_private"  
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.multiplier = None
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.update_t = 0.0
+        self.depth = []
+        self.orderbook = dict()
+        self.orderbook['bid'] = dict()
+        self.orderbook['ask'] = dict()
+        self.last_on_depth_time = time.time()
+        self.sub_fast = 0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = float(self.params.interval)
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://fapi.binance.com/fapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        login_str = await response.text()
+        await session.close()
+        return ujson.loads(login_str)['listenKey']
+
+    def _update_depth(self, msg):
+        t = int(msg['timestamp_e6'])
+        if t > self.update_t:
+            self.update_t = t
+            ###### 维护orderbook
+            if msg['type'] == 'snapshot':
+                self.orderbook = dict()
+                self.orderbook['bid'] = dict()
+                self.orderbook['ask'] = dict()
+                for i in msg['data']['order_book']:
+                    if i['side'] == 'Buy':
+                        self.orderbook['bid'][float(i['price'])] = float(i['size'])
+                    elif i['side'] == 'Sell':
+                        self.orderbook['ask'][float(i['price'])] = float(i['size'])
+                    else:
+                        print('错误类型')
+            elif msg['type'] == 'delta':
+                for _type in msg['data']:
+                    for i in msg['data'][_type]:
+                        if _type == 'delete':
+                            if i['side'] == 'Buy':
+                                if float(i['price']) in self.orderbook['bid']:
+                                    del(self.orderbook['bid'][float(i['price'])])
+                            elif i['side'] == 'Sell':
+                                if float(i['price']) in self.orderbook['ask']:
+                                    del(self.orderbook['ask'][float(i['price'])])
+                            else:
+                                print('错误类型')
+                        elif _type == 'update':
+                            if i['side'] == 'Buy':
+                                if float(i['price']) in self.orderbook['bid']:
+                                    self.orderbook['bid'][float(i['price'])] = float(i['size'])
+                            elif i['side'] == 'Sell':
+                                if float(i['price']) in self.orderbook['ask']:
+                                    self.orderbook['ask'][float(i['price'])] = float(i['size'])
+                            else:
+                                print('错误类型')
+                        elif _type == 'insert':
+                            if i['side'] == 'Buy':
+                                self.orderbook['bid'][float(i['price'])] = float(i['size'])
+                            elif i['side'] == 'Sell':
+                                self.orderbook['ask'][float(i['price'])] = float(i['size'])
+                            else:
+                                print('错误类型')
+                        else:
+                            print('错误类型')
+            else:
+                print('未知depth类型')
+            ###### 限制回调频率
+            now_time = time.time()
+            if now_time - self.last_on_depth_time >= 0.2 or self.sub_fast:
+                self.last_on_depth_time = time.time()
+                ######
+                self.ticker_info['bp'] = max(self.orderbook['bid'].keys())
+                self.ticker_info['ap'] = min(self.orderbook['ask'].keys())
+                ######
+                if self.ticker_info['bp'] > self.ticker_info['ap']:
+                    raise Exception("增量深度出现错误")
+                ######
+                self.callback['onTicker'](self.ticker_info)
+                ##### 标准化深度
+                mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+                step = mp * utils.EFF_RANGE / utils.LEVEL
+                bp = []
+                ap = []
+                bv = [0 for _ in range(utils.LEVEL)]
+                av = [0 for _ in range(utils.LEVEL)]
+                for i in range(utils.LEVEL):
+                    bp.append(self.ticker_info["bp"]-step*i)
+                for i in range(utils.LEVEL):
+                    ap.append(self.ticker_info["ap"]+step*i)
+                # 
+                price_thre = self.ticker_info["bp"] - step
+                index = 0
+                for bid_price in self.orderbook['bid'].keys():
+                    price = bid_price
+                    amount = self.orderbook['bid'][bid_price]
+                    if price > price_thre:
+                        bv[index] += amount
+                    else:
+                        price_thre -= step
+                        index += 1
+                        if index == utils.LEVEL:
+                            break
+                        bv[index] += amount
+                price_thre = self.ticker_info["ap"] + step
+                index = 0
+                for ask_price in self.orderbook['ask'].keys():
+                    price = ask_price
+                    amount = self.orderbook['ask'][ask_price]
+                    if price < price_thre:
+                        av[index] += amount
+                    else:
+                        price_thre += step
+                        index += 1
+                        if index == utils.LEVEL:
+                            break
+                        av[index] += amount
+                self.depth = bp + bv + ap + av
+                self.callback['onDepth']({'name':self.name,'data':self.depth})
+                # print('更新深度', time.time(),self.depth)
+
+    # def _update_ticker(self, msg):
+    #     if msg['data']['sequence'] > self.update_t:
+    #         self.update_t = msg['data']['sequence']
+    #         self.ticker_info['bp'] = float(msg['data']['bestBidPrice'])
+    #         self.ticker_info['ap'] = float(msg['data']['bestAskPrice'])
+    #         self.callback['onTicker'](self.ticker_info)
+    
+    def _update_trade(self, msg):
+        for i in msg['data']:
+            price = float(i['price'])
+            side = i['side'].lower()
+            amount = float(i['size'])
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+
+    def _update_position(self, msg):
+        pos = model.Position()
+        for i in msg['data']:
+            symbol = i['symbol']
+            if symbol == self.symbol:
+                amt = float(i["size"])
+                side = i['side'].lower()
+                ep = float(i["entry_price"])
+                if side == 'buy':
+                    pos.longPos = amt
+                    pos.longAvg = ep
+                elif side == 'sell':
+                    pos.shortPos = amt
+                    pos.shortAvg = ep
+                else:
+                    pass
+        self.callback["onPosition"](pos)
+
+    def _update_account(self, msg):
+        for i in msg['data']:
+            self.callback['onEquity'] = {self.quote:float(i['wallet_balance'])}
+    
+    def _update_order(self, msg):
+        self.logger.debug(f"ws订单推送 {msg}")
+        # print(msg)
+        for i in msg['data']:
+            if self.symbol == i['symbol']:
+                if i["order_status"] == 'New':  # 新增订单
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['filled'] = 0
+                    order_event['filled_price'] = 0
+                    order_event['client_id'] = i["order_link_id"] if "order_link_id" in i else ""
+                    order_event['order_id'] = i['order_id']
+                    order_event['fee'] = 0.0
+                    self.callback["onOrder"](order_event)
+                    # print('新建',order_event['client_id'])
+                elif i["order_status"] in ['Filled','Cancelled']:  # 删除订单
+                    # fee 负数是扣手续费 bitget没有返佣
+                    order_event = dict() 
+                    order_event['status'] = "REMOVE"
+                    order_event['client_id'] = i["order_link_id"] if "order_link_id" in i else ""
+                    order_event['order_id'] = i['order_id']
+                    order_event['filled'] = float(i["cum_exec_qty"])
+                    order_event['filled_price'] = float(i["last_exec_price"]) if 'last_exec_price' in i else float(i['price'])
+                    order_event['fee'] = float(i['cum_exec_fee'])
+                    self.callback["onOrder"](order_event)
+                    # print('移除',order_event['client_id'])
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket']({'name': self.name,'data':market_data})
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def get_token(self, is_auth):
+        # 获取 token
+        if is_auth:
+            uri = "/api/v1/bullet-private"
+        else:
+            uri = "/api/v1/bullet-public"
+        headers = {}
+        if is_auth:
+            now_time = int(time.time()) * 1000
+            str_to_sign = str(now_time) + "POST" + uri
+            sign = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), str_to_sign.encode('utf-8'), hashlib.sha256).digest())
+            passphrase = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), self.params.pass_key.encode('utf-8'), hashlib.sha256).digest())
+            headers = {
+                "KC-API-SIGN": sign.decode(),
+                "KC-API-TIMESTAMP": str(now_time),
+                "KC-API-KEY": self.params.access_key,
+                "KC-API-PASSPHRASE": passphrase.decode(),
+                "Content-Type": "application/json",
+                "KC-API-KEY-VERSION": "2"
+            }
+        headers["User-Agent"] = "kucoin-python-sdk/v1.0"
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            self.BaseURL+uri, 
+            timeout=5, 
+            headers=headers,
+            proxy=self.proxy
+            )
+        res = await response.text()
+        res = ujson.loads(res)
+        await session.close()
+        if res["code"] == "200000":
+            token = res["data"]["token"]
+            ws_connect_id = str(uuid4()).replace('-', '')
+            endpoint = res["data"]['instanceServers'][0]['endpoint']
+            ws_endpoint = f"{endpoint}?token={token}&connectId={ws_connect_id}"
+            encrypt = res["data"]['instanceServers'][0]['encrypt']
+            if is_auth:
+                ws_endpoint += '&acceptUserMessage=true'
+            return ws_endpoint, encrypt
+        else:
+            raise Exception("kucoin usdt swap 获取token错误")
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        ''''''
+        asyncio.create_task(self.run_public(sub_trade, sub_fast))
+        if is_auth:
+            asyncio.create_task(self.run_private())
+        while True:
+            await asyncio.sleep(5)
+
+    async def run_private(self):
+        '''
+            订阅private频道
+        '''
+        while True:
+            try:
+                ping_time = time.time()
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                ws_url = self.BaseURL_private
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws private 连接成功')
+                    self.logger.info(f'{self.name} ws private 连接成功')
+                    # 先鉴权
+                    # Generate expires.
+                    expires = int((time.time() + 10000) * 1000)
+                    # Generate signature.
+                    signature = str(hmac.new(
+                        bytes(self.params.secret_key, "utf-8"),
+                        bytes(f"GET/realtime{expires}", "utf-8"), digestmod="sha256"
+                    ).hexdigest())
+                    await _ws.send_str(ujson.dumps({
+                        "op":"auth",
+                            "args":[
+                                self.params.access_key, expires, signature
+                            ]
+                    }))
+                    # 订阅              
+                    channels = [
+                        "position",
+                        "wallet",
+                        "order",
+                    ]
+                    for i in channels:
+                        sub_str = ujson.dumps({"args": [i], "op":"subscribe"})
+                        await _ws.send_str(sub_str)
+                    while True:
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=300)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 private 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 private 准备重连...')
+                            break
+                        # self.logger.debug(msg)
+                        try:
+                            msg = ujson.loads(msg.data)
+                        except:
+                            # self.logger.warning(f'非json格式string:{msg}')
+                            pass
+                        # print(msg)
+                        # 处理消息
+                        if 'topic' in msg:
+                            if 'wallet' in msg['topic']:self._update_account(msg)
+                            elif 'order' in msg['topic']:self._update_order(msg)
+                            elif 'position' in msg['topic']:self._update_position(msg)
+                        # heartbeat
+                        if time.time() - ping_time > 15:
+                            await _ws.send_str('{"op": "ping"}')
+                            ping_time = time.time()
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                # await asyncio.sleep(1)
+
+    async def run_public(self, sub_trade=0, sub_fast=0):
+        '''
+            订阅public频道
+        '''
+        self.sub_fast = sub_fast
+        while True:
+            try:
+                ping_time = time.time()
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                url = self.BaseURL_public
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws public 连接成功')
+                    self.logger.info(f'{self.name} ws public 连接成功')
+                    # 订阅
+                    channels=[
+                        f"orderBookL2_25.{self.symbol}" # 推送频率20ms
+                        ]
+                    if sub_trade:
+                        channels += [
+                            f"trade.{self.symbol}"
+                        ]
+                    for i in channels:
+                        sub_str = ujson.dumps({"args": [i], "op":"subscribe"})
+                        await _ws.send_str(sub_str)
+                    while True:
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 public 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 public 准备重连...')
+                            break
+                        # self.logger.debug(msg)
+                        try:
+                            msg = ujson.loads(msg.data)
+                        except:
+                            # self.logger.warning(f'非json格式string:{msg}')
+                            pass
+                        # print(msg)
+                        # 处理消息
+                        if 'data' in msg:
+                            if f'orderBookL2_25.{self.symbol}' == msg['topic']:self._update_depth(msg)
+                            elif 'trade' in msg['topic']:self._update_trade(msg)
+                        # heartbeat
+                        if time.time() - ping_time > 15:
+                            await _ws.send_str('{"op": "ping"}')
+                            ping_time = time.time()
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                # await asyncio.sleep(1)
+
+if __name__ == "__main__":
+
+    p = model.ClientParams()
+
+    p.name = ""
+    p.pair = "matic_usdt"
+    p.proxy = "http://127.0.0.1:7890"  
+    p.access_key = "nVrNVv0HQ9a1IgaDeC"
+    p.secret_key = "7zJpfh8rImdrtNO2GnKnGMscKdAJkVMnt6Jl"
+    p.pass_key = "qwer1234"
+    p.interval = "0.1"
+    p.broker_id = "x-nXtHr5jj"
+    p.debug = "False"
+
+    ws = BybitUsdtSwapWs(p, is_print=1)
+
+    loop = asyncio.get_event_loop()
+
+    tasks = [
+        asyncio.ensure_future(ws.run(is_auth=1, sub_trade=1)),
+        # asyncio.ensure_future(ws.go()),
+    ]
+    loop.run_until_complete(asyncio.wait(tasks))
+

+ 602 - 0
exchange/coinex_spot_rest.py

@@ -0,0 +1,602 @@
+import random
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import hmac
+import base64
+import hashlib
+import traceback
+import urllib
+from urllib import parse
+from urllib.parse import urljoin
+import datetime
+import sys
+from urllib.parse import urlparse
+import logging
+import logging.handlers
+import utils
+import logging
+import logging.handlers
+import model
+from decimal import Decimal
+
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+class CoinExSpotRest:
+
+    def __init__(self, params: model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://api.coinex.com'
+        else:
+            self.HOST = 'https://api.coinex.com'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.data = {}
+        self._SESSIONS = dict()
+        self.logger = self.get_logger()
+        self.data['account'] = {}
+        self.callback = {
+            "onMarket": empty_call,
+            "onPosition": empty_call,
+            "onOrder": empty_call,
+            "onEquity": empty_call,
+            "onTicker": empty_call,
+            "onExit": empty_call,
+        }
+        self.exchange_info = dict()
+        self.tickSize = None
+        self.stepSize = None
+        self.delays = []
+        self.max_delay = 0
+        self.avg_delay = 0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter(
+            '[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(
+            f"log.log", maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def get_logger(self):
+
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter(
+            '[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(
+            "log.log", maxBytes=1024*1024, encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.WARNING)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    def get_sign(self, params, secret_key):
+        sort_params = sorted(params)
+        data = []
+        for item in sort_params:
+            data.append(item + '=' + str(params[item]))
+        str_params = "{0}&secret_key={1}".format('&'.join(data), secret_key)
+        token = hashlib.md5(str_params.encode("utf8")).hexdigest().upper()
+        return token
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        headers = {}
+        if auth:
+            if method in ['GET', 'DELETE']:
+                params['access_id'] = self.params.access_key
+                params['tonce'] = int(time.time()*1000)
+                headers = {
+                    'AUTHORIZATION': self.get_sign(params, self.params.secret_key),
+                    'Content-Type': 'application/json; charset=utf-8',
+                    'Accept': 'application/json',
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.90 Safari/537.36'
+                }
+            elif method == 'POST':
+                body['access_id'] = self.params.access_key
+                body['tonce'] = int(time.time()*1000)
+                headers = {
+                    'AUTHORIZATION': self.get_sign(body, self.params.secret_key),
+                    'Content-Type': 'application/json; charset=utf-8',
+                    'Accept': 'application/json',
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.90 Safari/537.36'
+                }
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=None, data=json.dumps(body), headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json()
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206) or res['code'] not in (0, 200):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                return None, res
+            return res, None
+        except Exception as e:
+            print(f'{self.name} rest 请求出错', str(e))
+            self.logger.error('请求错误'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            现货交易 已支持全品种
+        '''
+        try:
+            #######################
+            self.logger.info("清空挂单")
+            params = {
+                'access_id': self.params.access_key,
+                'market': self.symbol,
+            }
+            response, error = await self._request('DELETE', '/v1/order/pending', params=params, auth=1)
+            if error:
+                self.logger.info(error)
+            if response:
+                self.logger.info(error)
+            #############################
+            res, err = await self.get_account()
+            if err:
+                self.logger.info(err)
+            if res:
+                for i in res['data']:
+                    if i in ['CET','USDT']:
+                        continue
+                    symbol = i + 'USDT'
+                    #######################
+                    ticker, error = await self._request('GET', f'/v1/market/depth', params={"market": symbol, 'merge': '0.00000001'}, auth=0)
+                    if ticker:
+                        ap = float(ticker["data"]['asks'][0][0])
+                        bp = float(ticker["data"]['bids'][0][0])
+                        mp = (ap+bp)*0.5
+                    if error:
+                        self.logger.error('ger ticker failed!')
+                        continue
+                    coin = float(res['data'][i]['available']) + \
+                        float(res['data'][i]['frozen'])
+                    coin_value = coin * mp
+                    if i == self.base:
+                        _hold_coin = hold_coin
+                    else:
+                        _hold_coin = 0
+                    diff = _hold_coin - coin_value
+                    diff *= 0.99  # 避免无法下单
+                    self.logger.info(f'{symbol}需要调整现货仓位{diff}usd')
+                    if diff > 20.0:
+                        #######################
+                        self.logger.info("清空挂单")
+                        params = {
+                            'access_id': self.params.access_key,
+                            'market': symbol,
+                        }
+                        response, error = await self._request('DELETE', '/v1/order/pending', params=params, auth=1)
+                        if error:
+                            self.logger.info(error)
+                        if response:
+                            self.logger.info(error)
+                        #############################
+                        info = await self.take_order(symbol, diff/mp, "kd", mp*1.001, "123", "limit")
+                        self.logger.info(info)
+                    elif diff < -20.0:
+                        #######################
+                        self.logger.info("清空挂单")
+                        params = {
+                            'access_id': self.params.access_key,
+                            'market': symbol,
+                        }
+                        response, error = await self._request('DELETE', '/v1/order/pending', params=params, auth=1)
+                        if error:
+                            self.logger.info(error)
+                        if response:
+                            self.logger.info(error)
+                        #############################
+                        info = await self.take_order(symbol, -diff/mp, "kk", mp*0.999, "123", "limit")
+                        self.logger.info(info)
+            #######################        
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='limit'):
+        if origin_side == 'kd':
+            side = 'buy'
+        elif origin_side == 'pd':
+            side = 'sell'
+        elif origin_side == 'kk':
+            side = 'sell'
+        elif origin_side == 'pk':
+            side = 'buy'
+        else:
+            print("现货不允许此交易方向")
+            return None
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            amount = float(Decimal(str(amount//self.exchange_info[symbol].stepSize))
+                           * Decimal(str(self.exchange_info[symbol].stepSize)))
+            price = float(Decimal(str(price//self.exchange_info[symbol].tickSize))
+                          * Decimal(str(self.exchange_info[symbol].tickSize)))
+        if amount <= 0:
+            self.logger.error(f'下单参数错误 amount:{amount}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        if price <= 0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        params = {
+            'access_id': self.params.access_key,
+            'client_id': cid,
+            'market': symbol,
+            'amount': utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            'type': side,
+            'price': utils.num_to_str(price, self.exchange_info[symbol].tickSize),
+        }
+        # logger.info(f'下单指令 {params}')
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            response, error = await self._request('POST', '/v1/order/limit', body=params, auth=1)
+            # 再更新
+            if response:
+                # logger.info(f'下单回报 {response}')
+                # 增加新的
+                if 'data' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = params["client_id"]
+                    order_event['order_id'] = response['data']["id"]
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0.0
+                order_event['fee'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['client_id'] = params["client_id"]
+                self.callback["onOrder"](order_event)
+                return error
+        return response
+
+    async def cancel_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('DELETE', f'/v1/order/pending', params={'market': self.symbol, 'id': order_id}, auth=1)
+        elif client_id:
+            response, error = await self._request('DELETE', f'/v1/order/pending', params={'market': self.symbol, 'id': client_id}, auth=1)
+        else:
+            raise Exception("撤单出错 没指定订单号")
+        if response:
+            self.logger.debug(f'撤单回报 {response}')
+            # if response["data"]['status'] in ['cancel','done']:
+            #     order_event = dict()
+            #     order_event['status'] = "REMOVE"
+            #     order_event['filled_price'] = float(response['data']['avg_price'])
+            #     asset_fee = float(response['data']["asset_fee"])
+            #     money_fee = float(response['data']["money_fee"])
+            #     stock_fee = float(response['data']["stock_fee"])
+            #     # 非amm品种 优先扣cet 其次u 再次b
+            #     # amm品种 买入收b 卖出收u
+            #     if response['data']['type'] == "sell":
+            #         # 卖出
+            #         order_event['fee'] = money_fee
+            #     elif response['data']['type'] == "buy":
+            #         # 买入
+            #         order_event['fee'] = stock_fee
+            #     order_event['filled'] = float(response['data']['amount']) - float(response['data']['left'])
+            #     order_event['client_id'] = response['data']["client_id"]
+            #     self.callback["onOrder"](order_event)
+        if error:
+            print("撤单失败", error)
+            self.logger.error(error)
+            # if client_id:await self.check_order(client_id=client_id)
+            # if order_id:await self.check_order(order_id=order_id)
+        return response
+
+    async def check_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('GET', f'/v1/order/status', params={'market': self.symbol, 'id': order_id}, auth=1)
+        elif client_id:
+            response, error = await self._request('GET', f'/v1/order/status', params={'market': self.symbol, 'id': client_id}, auth=1)
+        else:
+            return
+        if response:
+            self.logger.debug(f'查单回报 {response}')
+            order_event = dict()
+            if response["data"]['status'] in ['not_deal', 'part_deal']:
+                order_event['status'] = "NEW"
+            elif response["data"]['status'] in ['cancel', 'done']:
+                order_event['status'] = "REMOVE"
+            else:
+                self.logger.error("错误的订单状态")
+            order_event['price'] = float(response["data"]["price"])
+            order_event['amount'] = float(response["data"]["amount"])
+            order_event['filled'] = float(
+                response["data"]["amount"])-float(response["data"]["left"])
+            order_event['filled_price'] = float(response["data"]["avg_price"])
+            order_event['client_id'] = response["data"]["client_id"]
+            order_event['order_id'] = response["data"]['id']
+            asset_fee = float(response['data']["asset_fee"])
+            money_fee = float(response['data']["money_fee"])
+            stock_fee = float(response['data']["stock_fee"])
+            # 非amm品种 优先扣cet 其次u 再次b
+            # amm品种 买入收b 卖出收u
+            if response['data']['type'] == "sell":
+                # 卖出
+                order_event['fee'] = money_fee
+            elif response['data']['type'] == "buy":
+                # 买入
+                order_event['fee'] = stock_fee
+            self.callback["onOrder"](order_event)
+        if error:
+            print("查单失败", error)
+            self.logger.error(error)
+        return response
+
+    async def get_order_list(self):
+        params = {
+            'market': self.symbol,
+            'page': 1,
+            'limit': 100,
+        }
+        response, error = await self._request('GET', '/v1/order/pending', params=params, auth=1)
+        orders = []  # 重置本地订单列表
+        if response is not None:
+            for i in response['data']['data']:
+                order_event = dict()
+                order_event['symbol'] = self.symbol
+                order_event['price'] = float(i["price"])
+                order_event['amount'] = float(i["amount"])
+                order_event['filled'] = float(i["amount"])-float(i["left"])
+                order_event['filled_price'] = float(i["avg_price"])
+                order_event['client_id'] = i["client_id"] if 'client_id' in i else ""
+                order_event['order_id'] = i['id']
+                asset_fee = float(i["asset_fee"])
+                money_fee = float(i["money_fee"])
+                stock_fee = float(i["stock_fee"])
+                # 非amm品种 优先扣cet 其次u 再次b
+                # amm品种 买入收b 卖出收u
+                if i['type'] == "sell":
+                    # 卖出
+                    order_event['fee'] = money_fee
+                elif i['type'] == "buy":
+                    # 买入
+                    order_event['fee'] = stock_fee
+                if i['status'] in ['not_deal', 'part_deal']:
+                    order_event['status'] = "NEW"
+                elif i['status'] in ['cancel', 'done']:
+                    order_event['status'] = "REMOVE"
+                else:
+                    self.logger.error("错误的订单状态")
+                self.callback["onOrder"](order_event)
+        if error:
+            print(error)
+        return response
+
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        return await self._request('GET', '/v1/balance/info', params={"access_id": self.params.access_key}, auth=1)
+
+    async def get_market_details(self):
+        return await self._request('GET', f'/v1/market/info', params={}, auth=0)
+
+    async def get_ticker(self):
+        res, err = await self._request('GET', f'/v1/market/depth', params={"market": self.symbol, 'merge': '0.00000001'}, auth=0)
+        if res:
+            ap = float(res["data"]['asks'][0][0])
+            bp = float(res["data"]['bids'][0][0])
+            mp = (ap+bp)*0.5
+            d = {"name": self.name, 'mp': mp, 'bp': bp, 'ap': ap}
+            self.callback['onTicker'](d)
+            return d
+        if err:
+            self.logger.error(err)
+            return None
+
+    async def before_trade(self):
+        # 获取市场基本情况
+        res, error = await self.get_market_details()
+        if error:
+            pass
+        else:
+            for i in res['data']:
+                if res['data'][i]['name'] == self.symbol:
+                    self.stepSize = float(Decimal("0.1")**Decimal(res['data'][i]["trading_decimal"]))
+                    self.tickSize = float(Decimal("0.1")**Decimal(res['data'][i]["pricing_decimal"]))
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i
+                exchange_info.multiplier = 1
+                exchange_info.stepSize = float(Decimal("0.1")**Decimal(res['data'][i]["trading_decimal"]))
+                exchange_info.tickSize = float(Decimal("0.1")**Decimal(res['data'][i]["pricing_decimal"]))
+                self.exchange_info[exchange_info.symbol] = exchange_info
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:
+            print(err)
+        if res:
+            for i in res["data"]:
+                if self.quote == i:
+                    self.data['equity'] = float(
+                        res['data'][i]['available'])+float(res['data'][i]['frozen'])
+                    self.callback['onEquity']({
+                        self.quote: self.data['equity']
+                    })
+                    self.cash_value = self.data['equity']
+                elif self.base == i:
+                    coin = float(res['data'][i]['available']) + \
+                        float(res['data'][i]['frozen'])
+                    self.callback['onEquity']({
+                        self.base: coin
+                    })
+                    self.coin_value = coin
+
+    async def go(self):
+        await self.before_trade()
+        await asyncio.sleep(1)
+        ### 检查是否为AMMM品种
+        try:
+            async with aiohttp.ClientSession(connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )) as session:
+                response = await session.get(
+                    "https://api.coinex.com/v1/amm/market",
+                    proxy=self.proxy
+                )
+                res = await response.json()
+                amm_list = res['data']
+                print(f'AMM列表{amm_list}')
+                if self.symbol in amm_list:
+                    self.callback['onExit'](f"{self.name} coinex spot 禁止跑AMM品种")
+                else:
+                    print(f'不是AMM品种 正常运行')
+        except:
+            self.logger.error(traceback.format_exc())
+            self.callback['onExit'](f"{self.name} coinex spot AMM列表获取失败")
+        while 1:
+            try:
+                # 停机信号
+                if self.stop_flag:
+                    return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:
+                    print(err)
+                if res:
+                    for i in res["data"]:
+                        if self.quote == i:
+                            self.data['equity'] = float(
+                                res['data'][i]['available']) + float(res['data'][i]['frozen'])
+                            self.callback['onEquity']({
+                                self.quote: self.data['equity']
+                            })
+                        elif self.base == i:
+                            coin = float(res['data'][i]['available']) + \
+                                float(res['data'][i]['frozen'])
+                            self.callback['onEquity']({
+                                self.base: coin
+                            })
+                # 更新订单
+                # res = await self.get_order_list()
+                await asyncio.sleep(60)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_data(self):
+        return self.data
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # 只能用oid撤单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(order_id=oid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+

+ 350 - 0
exchange/coinex_spot_ws.py

@@ -0,0 +1,350 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys, utils
+import logging, logging.handlers
+import model
+
+def empty_call(msg):
+    pass
+
+
+class CoinExSpotWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://socket.coinex.com/'
+        else:
+            self.URL = 'wss://socket.coinex.com/'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.callback = {
+            "onMarket":self.save_market,
+            "onDepth":empty_call,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.update_t = 0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    async def get_depth_flash(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        url = f'https://api.binance.com/api/v1/depth?symbol={self.symbol}&limit=1000'
+        session = aiohttp.ClientSession()
+        response = await session.get(
+            url, 
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        depth_flash = await response.text()
+        await session.close()
+        return ujson.loads(depth_flash)
+
+    def _update_depth(self, msg):
+        msg = ujson.loads(msg)
+        t = float(msg['params'][1]['time'])
+        if t > self.update_t:
+            self.update_t = t
+            self.ticker_info["bp"] = float(msg['params'][1]['bids'][0][0])
+            self.ticker_info["ap"] = float(msg['params'][1]['asks'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            ##### normalize depth
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in msg['params'][1]['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in msg['params'][1]['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+        else:
+            self.logger.error("coienx ws推送过期信息")
+
+    def _update_trade(self, msg):
+        msg = json.loads(msg)
+        for i in msg['params'][1]:
+            side = i["type"]
+            price = float(i["price"])
+            amount = float(i['amount'])
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+
+    def _update_account(self, msg):
+        msg = json.loads(msg)
+        for i in msg['params'][0]:
+            if self.quote == i:
+                cash = float(msg['params'][0][self.quote]['available'])+float(msg['params'][0][self.quote]['frozen'])
+                self.callback['onEquity'] = {
+                    self.quote:cash
+                }
+            elif self.base == i:
+                coin = float(msg['params'][0][self.base]['available'])+float(msg['params'][0][self.base]['frozen'])
+                self.callback['onEquity'] = {
+                    self.base:coin
+                }
+    
+    def _update_order(self, msg):
+        self.logger.debug("ws推送订单"+msg)
+        msg = json.loads(msg)
+        event_type = msg['params'][0]
+        event = msg['params'][1]
+        if event_type == 1:  # 新增订单
+            order_event = dict()
+            order_event['filled'] = 0
+            order_event['filled_price'] = 0
+            order_event['client_id'] = event["client_id"]
+            order_event['order_id'] = event['id']
+            order_event['status'] = "NEW"
+            self.callback["onOrder"](order_event)
+        elif event_type == 3:  # 删除订单
+            order_event = dict()
+            order_event['filled'] = float(event["amount"]) - float(event["left"])
+            order_event['filled_price'] = float(event["price"]) 
+            # asset_fee = float(event["asset_fee"])
+            money_fee = float(event["money_fee"])
+            stock_fee = float(event["stock_fee"])
+            # 非amm品种 优先扣cet 其次u 再次b
+            # amm品种 买入收b 卖出收u
+            if event['side'] == 1:
+                # 卖出
+                order_event['fee'] = money_fee
+            elif event['side'] == 2:
+                # 买入
+                order_event['fee'] = stock_fee
+            order_event['client_id'] = event["client_id"]
+            order_event['order_id'] = event['id']
+            order_event['status'] = "REMOVE"
+            self.callback["onOrder"](order_event)
+
+    def _update_position(self, msg):
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0 
+        msg = ujson.loads(msg)
+        for i in msg['a']['P']:
+            if i['s'] == self.symbol:
+                if i['ps'] == 'LONG':
+                    long_pos += abs(float(i['pa']))
+                    long_avg = abs(float(i['ep']))
+                if i['ps'] == 'SHORT':
+                    short_pos += abs(float(i['pa']))
+                    short_avg = abs(float(i['ep']))
+        pos = model.Position()
+        pos.longPos = long_pos
+        pos.longAvg = long_avg
+        pos.shortPos = short_pos
+        pos.shortAvg = short_avg
+        self.callback['onPosition'](pos)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        ping_time = time.time()
+        while True:
+            try:
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                ws_url = self.URL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.info(f'{self.name} ws连接成功')
+                    # 订阅 coinex 现货
+                    symbol = self.symbol.upper()
+                    # 鉴权
+                    if is_auth:
+                        current_time = int(time.time()*1000)
+                        sign_str = f"access_id={self.params.access_key}&tonce={current_time}&secret_key={self.params.secret_key}"
+                        md5 = hashlib.md5(sign_str.encode())
+                        param = {
+                            "id": 1,
+                            "method": "server.sign",
+                            "params": [self.params.access_key, md5.hexdigest().upper(), current_time]
+                        }
+                        await _ws.send_str(ujson.dumps(param))   
+                        res = await _ws.receive(timeout=30)
+                        # 订阅资产
+                        sub_str = ujson.dumps({"id": 1, "method": "asset.subscribe","params": [self.base,self.quote]})
+                        await _ws.send_str(sub_str)
+                        # 订阅私有订单
+                        sub_str = ujson.dumps({"id": 1, "method": "order.subscribe","params": [symbol]})
+                        await _ws.send_str(sub_str)
+                    if sub_trade:
+                        # 订阅公开成交
+                        sub_str = ujson.dumps({"id": 1, "method": "deals.subscribe","params": [symbol]})
+                        await _ws.send_str(sub_str)
+                    # 订阅深度
+                    sub_str = ujson.dumps({"id": 1, "method": "depth.subscribe","params": [symbol, 50, "0.000000001", False]})
+                    await _ws.send_str(sub_str)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except asyncio.CancelledError:
+                            print('ws取消')
+                            return
+                        except asyncio.TimeoutError:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        except:
+                            print(f'{self.name} ws出现错误 准备重连...')
+                            self.logger.error(f'{self.name} ws出现错误 准备重连...')
+                            self.logger.error(traceback.format_exc())
+                            break
+                        msg = msg.data
+                        # 处理消息
+                        if 'depth.update' in msg:self._update_depth(msg)
+                        elif 'deals.update' in msg:self._update_trade(msg)
+                        elif 'asset.update' in msg:self._update_account(msg)
+                        elif 'order.update' in msg:self._update_order(msg)
+                        else:
+                            print(msg)
+                            pass
+                        if ping_time - time.time() > 60:
+                            ping_time = time.time()
+                            sub_str = ujson.dumps({"id": 1, "method": "server.ping","params": []})
+                            await _ws.send_str(sub_str)
+            except:
+                _ws = None
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                # await asyncio.sleep(1)
+
+
+

+ 542 - 0
exchange/coinex_usdt_swap_rest.py

@@ -0,0 +1,542 @@
+import random
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import hmac
+import base64
+import hashlib
+import traceback
+import urllib
+from urllib import parse
+from urllib.parse import urljoin
+import datetime, sys
+from urllib.parse import urlparse
+import logging, logging.handlers
+import utils
+import logging, logging.handlers
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+
+class CoinExUsdtSwapRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://api.coinex.com'
+        else:
+            self.HOST = 'https://api.coinex.com'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self._SESSIONS = dict()
+        self.logger = self.get_logger()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.tickSize = None
+        self.stepSize = None
+        self.delays = []
+        self.max_delay = 0
+        self.avg_delay = 0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.mp_from_rest = None
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        self.min_trade_amount = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %( message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def get_logger(self):
+
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler("log.log",maxBytes=1024*1024,encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.WARNING)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    def get_sign(self, params, secret_key):
+        sort_params = params
+        data = []
+        for item in sort_params:
+            data.append(item + '=' + str(params[item]))
+        str_params = "{0}&secret_key={1}".format('&'.join(data), secret_key)
+        token = hashlib.sha256(str_params.encode("utf8")).hexdigest()
+        return token
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        headers = {}
+        if auth:
+            if method in ['GET','DELETE']:
+                params['timestamp'] = int(time.time()*1000)
+                headers = {
+                    'Authorization':self.get_sign(params, self.params.secret_key),
+                    'Content-Type': 'application/json; charset=utf-8',
+                    'Accept': 'application/json',
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.90 Safari/537.36',
+                    "AccessId":self.params.access_key
+                }
+            elif method == 'POST':
+                body['timestamp'] = int(time.time()*1000)
+                headers = {
+                    'Authorization':self.get_sign(body, self.params.secret_key),
+                    'Content-Type': 'application/json; charset=utf-8',
+                    'Accept': 'application/json',
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/60.0.3112.90 Safari/537.36',
+                    "AccessId":self.params.access_key
+                }
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json(content_type=None) 
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            self.get_delay_info()
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206) or res['code'] not in (0,200):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                return None, res
+            return res, None
+        except Exception as e:
+            print(f'{self.name} rest 请求出错', str(e))
+            self.logger.error(f'请求错误 {msg}'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+
+    async def get_position(self):
+        response, error = await self._request('GET', '/perpetual/v1/position/pending', params={"market":self.symbol}, auth=1)
+        if error:
+            print(error)
+        if response:
+            position = model.Position()
+            for i in response['data']:
+                side = i['side']
+                pos = float(i['amount'])
+                price = float(i['open_price'])
+                if side == 1:
+                    position.shortPos = pos
+                    position.shortAvg = price
+                elif side == 2:
+                    position.longPos = pos
+                    position.longAvg = price
+            return position
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            现货交易  已支持全品种
+        '''
+        try:
+            #######################
+            self.logger.info("清空挂单")
+            params = {
+                'market':self.symbol,
+                'side':0,
+            }
+            response, error = await self._request('POST', '/perpetual/v1/order/cancel_all', body=params, auth=1)
+            if error:
+                self.logger.info(error)
+            #############################
+            self.logger.info("清空仓位")
+            res, err = await self._request('GET', '/perpetual/v1/position/pending', params={}, auth=1)
+            if err:
+                self.logger.info(err)
+            if res:
+                for i in res['data']:
+                    if abs(float(i['close_left'])) > 0.0:
+                        params = {
+                            "market":i['market'],
+                            "position_id":i['position_id'],
+                            # "amount":i["amount"]
+                        }
+                        response, error = await self._request('POST', '/perpetual/v1/order/close_market', body=params, auth=1)
+                        if response:
+                            self.logger.info(response)
+                        if error:
+                            self.logger.info(error)
+                            await asyncio.sleep(1)
+                            params = {
+                                "market":i['market'],
+                                "position_id":i['position_id'],
+                            }
+                            response, error = await self._request('POST', '/perpetual/v1/position/market_close', body=params, auth=1)
+                            self.logger.info(response)
+                            self.logger.info(error)
+            #######################        
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='limit'):
+        ''' 
+            coinex swap 平仓需考虑最小下单量 只能通过close_position和position_id来平仓
+        '''
+        if origin_side =='kd':
+            side = "2"
+        elif origin_side =='pd':
+            side = "1"
+        elif origin_side =='kk':
+            side = "1"
+        elif origin_side =='pk':
+            side = "2"
+        else:
+            print("合约不允许此交易方向")
+            return None
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            # amount = float(Decimal(str(amount//self.exchange_info[symbol].stepSize))*Decimal(str(self.exchange_info[symbol].stepSize)))
+            # price = float(Decimal(str(price//self.exchange_info[symbol].tickSize))*Decimal(str(self.exchange_info[symbol].tickSize)))
+            amount = utils.fix_amount(amount, self.exchange_info[symbol].stepSize)
+            price = utils.fix_price(price, self.exchange_info[symbol].tickSize)
+        if float(amount) <= self.min_trade_amount:
+            # self.logger.error(f'下单参数错误 amount:{amount}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0                
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        if float(price) <= 0.0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0                
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        params = {
+            'client_id':cid,
+            'market': symbol,
+            'amount':utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            'side': side,
+            'price':utils.num_to_str(price, self.exchange_info[symbol].tickSize), 
+        }
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            if order_type == 'limit':
+                response, error = await self._request('POST', '/perpetual/v1/order/put_limit', body=params, auth=1)
+            elif order_type == 'market':
+                response, error = await self._request('POST', '/perpetual/v1/order/put_market', body=params, auth=1)
+            # 再更新
+            if response:
+                # 增加新的
+                if 'data' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = params["client_id"]
+                    order_event['order_id'] = response['data']["order_id"]
+                    self.callback["onOrder"](order_event)
+            if error:
+                # coinex swap 有时候返回错误也下单成功 很危险
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0.0                
+                order_event['fee'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['client_id'] = params["client_id"]
+                self.callback["onOrder"](order_event)
+        return response
+    
+    async def cancel_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('POST', f'/perpetual/v1/order/cancel', body={'market':self.symbol,'order_id':order_id}, auth=1)
+        elif client_id:
+            response, error = await self._request('POST', f'/perpetual/v1/order/cancel', body={'market':self.symbol,'order_id':client_id}, auth=1)
+        else:
+            raise Exception("撤单出错 没指定订单号")
+        if response:
+            self.logger.debug(f'撤单回报 {response}')
+            # order_event = dict()
+            # order_event['status'] = "REMOVE"
+            # order_event['filled_price'] = float(response['data']['price'])            
+            # order_event['fee'] = float(response['data']["deal_fee"])
+            # order_event['filled'] = float(response['data']['amount']) - float(response['data']['left'])
+            # order_event['client_id'] = response['data']["client_id"]
+            # self.callback["onOrder"](order_event)
+        if error:
+            print("撤单失败",error)
+            self.logger.error(error)
+        return response
+    
+    async def check_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('GET', f'/perpetual/v1/order/status', params={'market':self.symbol, 'order_id':order_id}, auth=1)
+        elif client_id:
+            response, error = await self._request('GET', f'/perpetual/v1/order/status', params={'market':self.symbol, 'order_id':client_id}, auth=1)
+        else:
+            return
+        if response:
+            self.logger.debug(f'查单回报 {response}')
+            if response['data']:
+                order_event = dict()
+                if response["data"]['status'] in ['cancel','done']:
+                    order_event['status'] = "REMOVE"
+                else:
+                    order_event['status'] = "NEW"
+                order_event['price'] = float(response["data"]["price"])
+                order_event['amount'] = float(response["data"]["amount"])
+                order_event['filled'] = float(response["data"]["amount"])-float(response["data"]["left"])
+                order_event['filled_price'] = float(response["data"]["price"])
+                order_event['client_id'] = response["data"]["client_id"]
+                order_event['order_id'] = response["data"]['order_id']
+                order_event['fee'] = float(response["data"]['deal_fee'])
+                self.callback["onOrder"](order_event)
+        if error:
+            print("查单失败",error)
+            self.logger.error(error)
+        return response
+
+    async def get_order_list(self):
+        params = {
+            'market':self.symbol,
+            'offset':100,
+            "side":0,
+            'limit':100,
+        }
+        response, error = await self._request('GET', '/perpetual/v1/order/pending', params=params, auth=1)
+        if response is not None:
+            for i in response['data']['records']:
+                order_event = dict()
+                order_event['symbol'] = self.symbol
+                order_event['price'] = float(i["price"])
+                order_event['amount'] = float(i["amount"])
+                order_event['filled'] = float(i["amount"])-float(i["left"])
+                order_event['filled_price'] = float(i["avg_price"])
+                order_event['client_id'] = i["clientOid"]
+                order_event['order_id'] = i['id']
+                asset_fee = float(response['data']["asset_fee"])
+                money_fee = float(response['data']["money_fee"])
+                stock_fee = float(response['data']["stock_fee"])
+                if asset_fee > 0.0: # 非amm品种
+                    order_event['fee'] = asset_fee
+                else: # amm品种
+                    order_event['fee'] = money_fee if money_fee > 0.0 else stock_fee
+                if response["data"]['status'] == 'not_deal':
+                    order_event['status'] = "NEW"
+                elif response["data"]['status'] in ['cancel','done']:
+                    order_event['status'] = "REMOVE"
+                else:
+                    s = response["data"]['status']
+                    self.logger.error(f"错误的订单状态 {s}")
+                self.callback["onOrder"](order_event)
+        if error:
+            print(error)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/perpetual/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        return await self._request('GET','/perpetual/v1/asset/query', params={}, auth=1)
+
+    async def get_market_details(self):
+        return await self._request('GET',f'/perpetual/v1/market/list', params={}, auth=0)
+
+    async def get_ticker(self):
+        res ,err = await self._request('GET',f'/perpetual/v1/market/depth', params={"market":self.symbol,'merge':'0.00000001'}, auth=0)
+        if res:
+            ap = float(res["data"]['asks'][0][0])
+            bp = float(res["data"]['bids'][0][0])
+            mp = (ap+bp)*0.5
+            d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+            self.callback['onTicker'](d)
+            return d
+        if err:
+            self.logger.error(err)
+            return None 
+
+    async def before_trade(self):
+        # 切换杠杆
+        await self.change_position_side()
+        # 获取市场最新价格
+        res = await self.get_ticker()
+        ticker_price = res["mp"]
+        if isinstance(ticker_price, float):
+            self.mp_from_rest = ticker_price
+        # 获取市场基本情况
+        res, error = await self.get_market_details()
+        if error:
+            pass 
+        else:
+            for i in res['data']:
+                if i['name'] == self.symbol:
+                    self.stepSize = float(Decimal("0.1")**Decimal(i["amount_prec"]))
+                    self.tickSize = float(Decimal("0.1")**Decimal(i["money_prec"]))
+                    self.min_trade_amount = float(i['amount_min'])
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['name']
+                exchange_info.multiplier = 1
+                exchange_info.tickSize = float(Decimal("0.1")**Decimal(i["money_prec"]))
+                exchange_info.stepSize = float(Decimal("0.1")**Decimal(i["amount_prec"]))
+                self.exchange_info[exchange_info.symbol] = exchange_info
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:print(err)
+        if res:
+            for i in res["data"]:
+                if self.quote == i:
+                    cash = float(res['data'][i]['available'])+ float(res['data'][i]['frozen'])+ float(res['data'][i]['margin'])+float(res['data'][i]['profit_unreal'])
+                    self.callback['onEquity']({
+                        self.quote:cash
+                    })
+                    self.cash_value = cash
+
+    async def change_position_side(self):
+        res ,err = await self._request('POST',f'/perpetual/v1/market/adjust_leverage', body={"market":self.symbol,'leverage':10,"position_type":2}, auth=1)
+        if err:print(err)
+        if res:print(res)
+
+    async def go(self):
+        interval = 60  # 不能太快防止占用限频
+        await self.before_trade()
+        await asyncio.sleep(1)
+        while 1:
+            try:
+                # 停机信号
+                if self.stop_flag:return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:print(err)
+                if res:
+                    for i in res["data"]:
+                        if self.quote == i:
+                            cash = float(res['data'][i]['available'])+float(res['data'][i]['frozen'])+float(res['data'][i]['margin'])+ float(res['data'][i]['profit_unreal'])
+                            self.callback['onEquity']({
+                                self.quote:cash
+                            })
+                            self.cash_value = cash
+                # 更新仓位
+                res, err = await self._request('GET', '/perpetual/v1/position/pending', params={"market":self.symbol}, auth=1)
+                if err:
+                    self.logger.info(err)
+                if res:
+                    p = model.Position()
+                    for i in res['data']:
+                        if i['side'] == 1:# sell
+                            p.shortPos = float(i['amount'])
+                            p.shortAvg = float(i['open_price'])
+                        if i['side'] == 2:# buy
+                            p.longPos = float(i['amount'])
+                            p.longAvg = float(i['open_price'])
+                    self.callback['onPosition'](p)
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+    
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # 只能用oid撤单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # 只能用oid查单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.check_order(order_id=oid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+
+

+ 331 - 0
exchange/coinex_usdt_swap_ws.py

@@ -0,0 +1,331 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys, utils
+import logging, logging.handlers
+import model
+
+def empty_call(msg):
+    pass
+
+
+class CoinExUsdtSwapWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://perpetual.coinex.com/'
+        else:
+            self.URL = 'wss://perpetual.coinex.com/'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.data = dict()
+        self.data['trade'] = []
+        self.on_trade_ratio = 5.0
+        self.on_trade_value = 1000000.0
+        self.trade_mean = 0.0
+        self.data['force'] = []
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    def _update_depth(self, msg):
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        self.ticker_info["bp"] = float(msg['params'][1]['bids'][0][0])
+        self.ticker_info["ap"] = float(msg['params'][1]['asks'][0][0])
+        self.callback['onTicker'](self.ticker_info)
+        ##### 标准化深度
+        mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+        step = mp * utils.EFF_RANGE / utils.LEVEL
+        bp = []
+        ap = []
+        bv = [0 for _ in range(utils.LEVEL)]
+        av = [0 for _ in range(utils.LEVEL)]
+        for i in range(utils.LEVEL):
+            bp.append(self.ticker_info["bp"]-step*i)
+        for i in range(utils.LEVEL):
+            ap.append(self.ticker_info["ap"]+step*i)
+        # 
+        price_thre = self.ticker_info["bp"] - step
+        index = 0
+        for bid in msg['params'][1]['bids']:
+            price = float(bid[0])
+            amount = float(bid[1])
+            if price > price_thre:
+                bv[index] += amount
+            else:
+                price_thre -= step
+                index += 1
+                if index == utils.LEVEL:
+                    break
+                bv[index] += amount
+        price_thre = self.ticker_info["ap"] + step
+        index = 0
+        for ask in msg['params'][1]['asks']:
+            price = float(ask[0])
+            amount = float(ask[1])
+            if price < price_thre:
+                av[index] += amount
+            else:
+                price_thre += step
+                index += 1
+                if index == utils.LEVEL:
+                    break
+                av[index] += amount
+        self.depth = bp + bv + ap + av
+        self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    
+    def _update_trade(self, msg):
+        self.public_update_time = time.time()
+        msg = json.loads(msg)
+        for i in msg['params'][1]:
+            side = i["type"]
+            price = float(i["price"])
+            amount = float(i['amount'])
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+
+    def _update_account(self, msg):
+        msg = ujson.loads(msg)
+        for i in msg['params'][0]:
+            if self.quote == i:
+                cash = float(msg['params'][0][self.quote]['available'])+float(msg['params'][0][self.quote]['frozen'])+\
+                    float(msg['params'][0][self.quote]['margin'])+float(msg['params'][0][self.quote]['profit_unreal'])
+                self.callback['onEquity'] = {
+                    self.quote:cash
+                }
+    
+    def _update_order(self, msg):
+        self.logger.debug("ws推送订单"+msg)
+        msg = ujson.loads(msg)
+        event_type = msg['params'][0]
+        event = msg['params'][1]
+        if event_type == 1:  # 新增订单
+            order_event = dict()
+            order_event['filled'] = 0.0
+            order_event['filled_price'] = 0.0
+            order_event['client_id'] = event["client_id"]
+            order_event['order_id'] = event['order_id']
+            order_event['status'] = "NEW"
+            self.callback["onOrder"](order_event)
+        elif event_type == 3:  # 删除订单
+            order_event = dict()
+            order_event['filled'] = float(event["amount"]) - float(event["left"])
+            order_event['filled_price'] = float(event["price"]) 
+            order_event['fee'] = float(event["deal_fee"])
+            order_event['client_id'] = event["client_id"]
+            order_event['order_id'] = event['order_id']
+            order_event['status'] = "REMOVE"
+            self.callback["onOrder"](order_event)
+
+    def _update_position(self, msg):
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0 
+        msg = ujson.loads(msg)
+        msg = msg['params'][1]
+        if msg['market'] == self.symbol:
+            side = msg['side']
+            pos = float(msg['amount'])
+            price = float(msg['open_price'])
+            if side == 1:
+                short_pos = pos
+                short_avg = price
+            elif side == 2:
+                long_pos = pos
+                long_avg = price
+        pos = model.Position()
+        pos.longPos = long_pos
+        pos.longAvg = long_avg
+        pos.shortPos = short_pos
+        pos.shortAvg = short_avg
+        self.callback['onPosition'](pos)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        ping_time = time.time()
+        while True:
+            try:
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                ws_url = self.URL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.info(f'{self.name} ws连接成功')
+                    # 订阅 coinex 合约行情
+                    symbol = self.symbol.upper()
+                    # 鉴权
+                    if is_auth:
+                        current_time = int(time.time()*1000)
+                        sign_str = f"access_id={self.params.access_key}&timestamp={current_time}&secret_key={self.params.secret_key}"
+                        md5 = hashlib.sha256(sign_str.encode())
+                        param = {
+                            "id": 1,
+                            "method": "server.sign",
+                            "params": [self.params.access_key, md5.hexdigest().lower(), current_time]
+                        }
+                        await _ws.send_str(ujson.dumps(param))
+                        res = await _ws.receive(timeout=30)
+                        self.logger.info(res)
+                        # 订阅资产
+                        sub_str = ujson.dumps({"id": 1, "method": "asset.subscribe","params": [self.quote]})
+                        await _ws.send_str(sub_str)
+                        # 订阅订单
+                        sub_str = ujson.dumps({"id": 1, "method": "order.subscribe","params": [symbol]})
+                        await _ws.send_str(sub_str)
+                        # 订阅仓位
+                        sub_str = ujson.dumps({"id": 1, "method": "position.subscribe","params": [symbol]})
+                        await _ws.send_str(sub_str)
+                    if sub_trade:
+                        # 订阅公开成交
+                        sub_str = ujson.dumps({"id": 1, "method": "deals.subscribe","params": [symbol]})
+                        await _ws.send_str(sub_str)
+                    # 订阅深度
+                    sub_str = ujson.dumps({"id": 1, "method": "depth.subscribe","params": [symbol, 50, "0.000000001", False]})
+                    await _ws.send_str(sub_str)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except asyncio.CancelledError:
+                            print('ws取消')
+                            return
+                        except asyncio.TimeoutError:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        except:
+                            print(f'{self.name} ws出现错误 准备重连...')
+                            self.logger.error(f'{self.name} ws出现错误 准备重连...')
+                            self.logger.error(traceback.format_exc())
+                            break
+                        msg = msg.data
+                        # print(msg)
+                        # 处理消息
+                        if 'depth.update' in msg:self._update_depth(msg)
+                        elif 'deals.update' in msg:self._update_trade(msg)
+                        elif 'asset.update' in msg:self._update_account(msg)
+                        elif 'order.update' in msg:self._update_order(msg)
+                        elif 'position.update' in msg:self._update_position(msg)
+                        else:
+                            print(msg)
+                            pass
+                        if ping_time - time.time() > 60:
+                            ping_time = time.time()
+                            sub_str = ujson.dumps({"id": 1, "method": "server.ping","params": []})
+                            await _ws.send_str(sub_str)
+            except:
+                _ws = None
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                # await asyncio.sleep(1)
+
+
+

+ 369 - 0
exchange/ftx_spot_ws.py

@@ -0,0 +1,369 @@
+
+import aiohttp
+import time
+import asyncio
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys
+import logging, logging.handlers
+from itertools import zip_longest
+from datetime import datetime
+import urllib
+import utils
+import model
+from collections import defaultdict, deque
+
+
+
+
+def empty_call(msg):
+    pass
+
+ZERO = 1e-8
+
+
+class FtxSpotWs:
+    """"""
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://ftx.com/ws/'
+        else:
+            self.URL = 'wss://ftx.com/ws/'
+        self.params = params
+        self.name = self.params.name
+        #
+        self.base = params.pair.split('_')[0].upper()
+        self.quote = params.pair.split('_')[1].upper()
+        if self.quote == "USDT":
+            self.quote = "USD"
+        self.symbol = f"{self.base}/{self.quote}"
+        #print(self.symbol)
+        #
+        self.data = dict()
+        self.data['trade'] = []
+        self.callback = {
+            "onMarket":self.save_market,
+            "onDepth":empty_call,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.depth_update = []
+        self.need_flash = 1
+        self.updata_u = None
+        self.last_update_id = None
+        self.depth = dict()
+        self.depth['bids'] = dict()
+        self.depth['asks'] = dict()
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.public_update_time = time.time()
+        self.private_update_time = time.time()
+        self.expired_time = 300
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.stepSize = None
+        self.tickSize = None
+        self.ctVal    = None  # 合约乘数
+        self.ctMult   = None  # 合约面值
+        
+        self.depth = []
+
+        self._reset_orderbook()
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+
+    def _reset_orderbook(self) -> None:
+        self._orderbook_timestamp = 0
+        self._orderbook = {side: defaultdict(float) for side in ['bids','asks']}
+
+    def save_market(self, msg):
+        print(msg)
+        #pass
+        #date = time.strftime('%Y-%m-%d',time.localtime())
+        #interval = self.params.interval
+        #if msg:
+        #    exchange = msg['name']
+        #    if len(msg['data']) > 1:
+        #        with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+        #                    'a',
+        #                    newline='',
+        #                    encoding='utf-8') as f:
+        #            writer = csv.writer(f, delimiter=',')
+        #            writer.writerow(msg['data'])
+        #        if self.is_print:print(f'写入行情 {self.symbol}')
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    def subscribe_public(self, sub_trade=0):
+        channels = [
+            "orderbook",
+            # "ticker"
+        ]
+        if sub_trade:
+            channels.append("trades")
+        subs = [ujson.dumps({'op':'subscribe','market':self.symbol, 'channel':channel}) for channel in channels]
+        return subs
+
+    async def run_public(self, sub_trade=0):
+        """"""
+        while 1:
+            try:
+                self.public_update_time = time.time()
+                print(f"{self.name} public 尝试连接ws")
+                ws_url = self.URL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f"{self.name} public ws连接成功")
+                    self.logger.debug(f"{self.name} public ws连接成功")
+                    for sub in self.subscribe_public(sub_trade):
+                        await _ws.send_str(sub)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive()
+                        except:
+                            print(f'{self.name} public ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} public ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = msg.data
+                        await self.on_message_public(_ws, msg)
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                await asyncio.sleep(1)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        asyncio.create_task(self.run_public(sub_trade))
+        while True:
+            await asyncio.sleep(5)
+
+    async def on_message_public(self, _ws, msg):
+        """"""
+        #print(msg)
+        if "data" in msg:
+            # 推送数据时,有data字段,优先级也最高
+            if "ticker" in msg:
+                self._update_ticker(msg)
+            elif "trades" in msg:
+                self._update_trade(msg)
+            elif "orderbook" in msg:
+                await self._update_depth(_ws, msg)
+        elif "type" in msg:
+            # event常见于事件回报,一般都可以忽略,只需要看看是否有error
+            if "error" in msg:
+                info = f'{self.name} on_message error! --> {msg}'
+                print(info)
+                self.logger.error(info)
+        elif 'ping' in msg:
+            await _ws.send_str('pong')
+        else:
+            print(msg)
+
+    def _update_ticker(self, msg):
+        """"""
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        ticker = msg['data']
+        bp = float(ticker['bid']) if ticker['bid'] != 'null' else 0
+        ap = float(ticker['ask']) if ticker['ask'] != 'null' else 0
+        lp = float(ticker['last']) if ticker['last'] != 'null' else 0
+        self.ticker_info["bp"] = bp
+        self.ticker_info["ap"] = ap
+        self.callback['onTicker'](self.ticker_info)
+
+    def _update_trade(self, msg):
+        """"""
+        msg = ujson.loads(msg)
+        for trade in msg['data']:
+            price = float(trade['price'])
+            amount = float(trade['size'])
+            side = trade['side']
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+            self.data['trade'].append([
+                side,
+                amount,
+                price,
+                ])
+            self.public_update_time = time.time()
+            #print(msg)
+
+    async def _update_depth(self, _ws, msg):
+        """"""
+        msg = ujson.loads(msg)
+        if msg['market']!=self.symbol:
+            return
+        depth = msg['data']
+        action = msg['type']
+        if action == 'partial':
+            self._reset_orderbook()
+        for side in {'bids', 'asks'}:
+            book = self._orderbook[side]
+            for price, size in depth[side]:
+                if size:
+                    book[price] = size
+                else:
+                    del book[price]
+            self._orderbook_timestamp = depth['time']
+
+        ob = self.get_orderbook()
+        if self.compare_checksum(ob, depth):
+            self.public_update_time = time.time()
+            bp = self.depth[0]
+            ap = self.depth[40]
+            self.ticker_info["bp"] = bp
+            self.ticker_info["ap"] = ap
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in ob['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in ob['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+        else:
+            self._reset_orderbook()
+            await self.resubscribe_depth(_ws)
+
+    async def resubscribe_depth(self, _ws):
+        info = f"{self.name} checksum not correct!"
+        print(info)
+        self.logger.info(info)
+        sub_str = {'op':"unsubscribe",'market':self.symbol, 'channel':'orderbook'}
+        await _ws.send_str(ujson.dumps(sub_str))
+        await asyncio.sleep(1)
+        sub_str['op'] = 'subscribe'
+        await _ws.send_str(ujson.dumps(sub_str))
+
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def get_orderbook(self):
+        return {
+            side: sorted(
+                [(price, quantity) for price, quantity in list(self._orderbook[side].items())
+                 if quantity],
+                key=lambda order: order[0] * (-1 if side == 'bids' else 1)
+            )
+            for side in {'bids', 'asks'}
+        }
+
+    @staticmethod
+    def compare_checksum(ob, depth):
+        """计算深度的校验和"""
+        #t1 = time.time()
+        checksum_data = [
+            ':'.join([f'{float(order[0])}:{float(order[1])}' for order in (bid, offer) if order])
+            for (bid, offer) in zip_longest(ob['bids'][:100], ob['asks'][:100])
+        ]
+        cm = int(zlib.crc32(':'.join(checksum_data).encode()))
+        #t2 = time.time()
+        #print(cm, depth['checksum'], (t2-t1)*1000)
+        return cm==depth['checksum']
+
+
+
+
+
+

+ 364 - 0
exchange/ftx_usdt_swap_ws.py

@@ -0,0 +1,364 @@
+
+import aiohttp
+import time
+import asyncio
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys
+import logging, logging.handlers
+from itertools import zip_longest
+from datetime import datetime
+import urllib
+import utils
+import model
+from collections import defaultdict, deque
+
+
+
+
+def empty_call(msg):
+    pass
+
+ZERO = 1e-8
+
+
+class FtxUsdtSwapWs:
+    """"""
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://ftx.com/ws/'
+        else:
+            self.URL = 'wss://ftx.com/ws/'
+        self.params = params
+        self.name = self.params.name
+        #
+        self.base = params.pair.split('_')[0].upper()
+        self.quote = params.pair.split('_')[1].upper()
+        self.symbol = f"{self.base}-PERP"
+        #
+        self.data = dict()
+        self.data['trade'] = []
+        self.data['force'] = []
+        self.callback = {
+            "onMarket":self.save_market,
+            "onDepth":empty_call,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.depth_update = []
+        self.need_flash = 1
+        self.updata_u = None
+        self.last_update_id = None
+        self.depth = dict()
+        self.depth['bids'] = dict()
+        self.depth['asks'] = dict()
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.public_update_time = time.time()
+        self.private_update_time = time.time()
+        self.expired_time = 300
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.stepSize = None
+        self.tickSize = None
+        self.ctVal    = None  # 合约乘数
+        self.ctMult   = None  # 合约面值
+        self.depth = []
+
+        self._reset_orderbook()
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+
+    def _reset_orderbook(self) -> None:
+        self._orderbook_timestamp = 0
+        self._orderbook = {side: defaultdict(float) for side in ['bids','asks']}
+
+    def save_market(self, msg):
+        print(msg)
+        #pass
+        #date = time.strftime('%Y-%m-%d',time.localtime())
+        #interval = self.params.interval
+        #if msg:
+        #    exchange = msg['name']
+        #    if len(msg['data']) > 1:
+        #        with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+        #                    'a',
+        #                    newline='',
+        #                    encoding='utf-8') as f:
+        #            writer = csv.writer(f, delimiter=',')
+        #            writer.writerow(msg['data'])
+        #        if self.is_print:print(f'写入行情 {self.symbol}')
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    def subscribe_public(self, sub_trade = 0):
+        channels = [
+            "orderbook",
+            # "ticker"
+        ]
+        if sub_trade:
+            channels.append("trades")
+        subs = [ujson.dumps({'op':'subscribe','market':self.symbol, 'channel':channel}) for channel in channels]
+        return subs
+
+    async def run_public(self, sub_trade=0):
+        """"""
+        while 1:
+            try:
+                self.public_update_time = time.time()
+                print(f"{self.name} public 尝试连接ws")
+                ws_url = self.URL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f"{self.name} public ws连接成功")
+                    self.logger.debug(f"{self.name} public ws连接成功")
+                    for sub in self.subscribe_public(sub_trade):
+                        await _ws.send_str(sub)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive()
+                        except:
+                            print(f'{self.name} public ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} public ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = msg.data
+                        await self.on_message_public(_ws, msg)
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                await asyncio.sleep(1)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        asyncio.create_task(self.run_public(sub_trade))
+        while True:
+            await asyncio.sleep(5)
+
+    async def on_message_public(self, _ws, msg):
+        """"""
+        #print(msg)
+        if "data" in msg:
+            # 推送数据时,有data字段,优先级也最高
+            if "ticker" in msg:
+                self._update_ticker(msg)
+            elif "trades" in msg:
+                self._update_trade(msg)
+            elif "orderbook" in msg:
+                await self._update_depth(_ws, msg)
+        elif "type" in msg:
+            # event常见于事件回报,一般都可以忽略,只需要看看是否有error
+            if "error" in msg:
+                info = f'{self.name} on_message error! --> {msg}'
+                print(info)
+                self.logger.error(info)
+        elif 'ping' in msg:
+            await _ws.send_str('pong')
+        else:
+            print(msg)
+
+    def _update_ticker(self, msg):
+        """"""
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        ticker = msg['data']
+        bp = float(ticker['bid']) if ticker['bid'] != 'null' else 0
+        ap = float(ticker['ask']) if ticker['ask'] != 'null' else 0
+        self.ticker_info["bp"] = bp
+        self.ticker_info["ap"] = ap
+        self.callback['onTicker'](self.ticker_info)
+
+    def _update_trade(self, msg):
+        """"""
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        for trade in msg['data']:
+            price = float(trade['price'])
+            amount = float(trade['size'])
+            side = trade['side']
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+
+    async def _update_depth(self, _ws, msg):
+        """"""
+        msg = ujson.loads(msg)
+        if msg['market']!=self.symbol:
+            return
+        depth = msg['data']
+        action = msg['type']
+        if action == 'partial':
+            self._reset_orderbook()
+        for side in {'bids', 'asks'}:
+            book = self._orderbook[side]
+            for price, size in depth[side]:
+                if size:
+                    book[price] = size
+                else:
+                    del book[price]
+            self._orderbook_timestamp = depth['time']
+
+        ob = self.get_orderbook()
+        if self.compare_checksum(ob, depth):
+            self.public_update_time = time.time()
+            bp = float(ob['bids'][0][0])
+            ap = float(ob['asks'][0][0])
+            self.ticker_info["bp"] = bp
+            self.ticker_info["ap"] = ap
+            self.callback['onTicker'](self.ticker_info)
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in ob['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in ob['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+        else:
+            self._reset_orderbook()
+            await self.resubscribe_depth(_ws)
+
+    async def resubscribe_depth(self, _ws):
+        info = f"{self.name} checksum not correct!"
+        print(info)
+        self.logger.info(info)
+        sub_str = {'op':"unsubscribe",'market':self.symbol, 'channel':'orderbook'}
+        await _ws.send_str(ujson.dumps(sub_str))
+        await asyncio.sleep(1)
+        sub_str['op'] = 'subscribe'
+        await _ws.send_str(ujson.dumps(sub_str))
+
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def get_orderbook(self):
+        return {
+            side: sorted(
+                [(price, quantity) for price, quantity in list(self._orderbook[side].items())
+                 if quantity],
+                key=lambda order: order[0] * (-1 if side == 'bids' else 1)
+            )
+            for side in {'bids', 'asks'}
+        }
+
+    @staticmethod
+    def compare_checksum(ob, depth):
+        """计算深度的校验和"""
+        #t1 = time.time()
+        checksum_data = [
+            ':'.join([f'{float(order[0])}:{float(order[1])}' for order in (bid, offer) if order])
+            for (bid, offer) in zip_longest(ob['bids'][:100], ob['asks'][:100])
+        ]
+        cm = int(zlib.crc32(':'.join(checksum_data).encode()))
+        #t2 = time.time()
+        #print(cm, depth['checksum'], (t2-t1)*1000)
+        return cm==depth['checksum']
+
+
+
+
+
+

+ 528 - 0
exchange/gate_spot_rest.py

@@ -0,0 +1,528 @@
+import random
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import hmac
+import base64
+import hashlib
+import traceback
+import urllib
+from urllib import parse
+from urllib.parse import urljoin
+import datetime, sys, utils
+from urllib.parse import urlparse
+import logging, logging.handlers
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+
+class GateSpotRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('使用colo高速线路')
+            self.HOST = 'https://apiv4-private.gateapi.io'
+        else:
+            self.HOST = 'https://api.gateio.ws'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + '_' + self.quote
+        self._SESSIONS = dict()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.tickSize = None
+        self.stepSize = None
+        self.delays = []
+        self.max_delay = 0.0
+        self.avg_delay = 0.0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+            
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    def generate_signature(self, method, uri, query_param=None, body=None):
+        t = time.time()
+        m = hashlib.sha512()
+        m.update((body or "").encode('utf-8'))
+        hashed_payload = m.hexdigest()
+        s = '%s\n%s\n%s\n%s\n%s' % (method, uri, query_param or "", hashed_payload, t)
+        sign = hmac.new(self.params.secret_key.encode('utf-8'), s.encode('utf-8'), hashlib.sha512).hexdigest()
+        return {'KEY': self.params.access_key, 'Timestamp': str(t), 'SIGN': sign}
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        if method == "GET":
+            headers = {
+                "Content-type": "application/x-www-form-urlencoded",
+                "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) "
+                    "Chrome/39.0.2171.71 Safari/537.36"
+            }
+        else:
+            headers = {
+                "Accept": "application/json",
+                "Content-type": "application/json"
+            }
+        if auth:
+            if method == "POST":
+                body = json.dumps(params)
+                query_param = None
+                sign_headers = self.generate_signature(method, uri, query_param, body)
+                headers.update(sign_headers)
+            if method == "GET" or method == "DELETE":
+                query_param = ''
+                for i in params:
+                    query_param += f'{i}={params[i]}&'
+                query_param = query_param[:-1]
+                sign_headers = self.generate_signature(method, uri, query_param)
+                headers.update(sign_headers)
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=None, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json() 
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                return None, res
+            return res, None
+        except Exception as e:
+            print(f'{self.name} 请求出错', e)
+            self.logger.error('请求错误'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='limit'):
+        if origin_side =='kd':
+            side = 'buy'
+        elif origin_side =='pd':
+            side = 'sell'
+        elif origin_side =='kk':
+            side = 'sell'
+        elif origin_side =='pk':
+            side = 'buy'
+        else:
+            return None
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            # amount = float(Decimal(str(amount//self.exchange_info[symbol].stepSize))*Decimal(str(self.exchange_info[symbol].stepSize)))
+            # price = float(Decimal(str(price//self.exchange_info[symbol].tickSize))*Decimal(str(self.exchange_info[symbol].tickSize)))
+            amount = utils.fix_amount(amount, self.exchange_info[symbol].stepSize)
+            price = utils.fix_price(price, self.exchange_info[symbol].tickSize)
+        if amount <= 0.0 or price <= 0.0:
+            self.logger.error(f"下单参数错误 amount:{amount} price:{price}")
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['client_id'] = cid
+            order_event['filled_price'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['fee'] = 0.0
+            self.callback["onOrder"](order_event)
+        params = {
+            'text':cid,
+            'currency_pair': symbol, 
+            'amount': utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            'side': side, 
+            'account':"spot",
+            'price': utils.num_to_str(price, self.exchange_info[symbol].tickSize), 
+            'type':order_type,
+        }
+        # logger.info(f'下单指令 {params}')
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            response, error = await self._request('POST', '/api/v4/spot/orders', params=params, auth=1)
+            if response:
+                # 增加新的
+                order_event = dict()
+                order_event['status'] = "NEW"
+                order_event['client_id'] = response["text"]
+                order_event['order_id'] = response["id"]
+                self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['client_id'] = params["text"]
+                order_event['filled_price'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['fee'] = 0.0
+                self.callback["onOrder"](order_event)
+                return error
+        return response
+
+    async def check_order(self, order_id=None, client_id=None):
+        params = {
+            "currency_pair": self.symbol
+        }
+        if order_id:
+            response, error = await self._request('GET', f'/api/v4/spot/orders/{order_id}', params=params, auth=1)
+        elif client_id:
+            response, error = await self._request('GET', f'/api/v4/spot/orders/{client_id}', params=params, auth=1)
+        if response:
+            if response['status'] in ['cancelled','closed']:  # 已撤销 或 全部成交
+                order_event = dict()
+                order_event['client_id'] = response["text"]
+                order_event['order_id'] = response['id']
+                order_event['filled'] = float(response["amount"]) - float(response["left"])
+                order_event['filled_price'] = float(response["price"])
+                order_event['fee'] = float(response["fee"])
+                order_event['status'] = "REMOVE"
+                self.callback['onOrder'](order_event)
+            else: # 还在挂单中
+                order_event = dict()
+                order_event['client_id'] = response["text"]
+                order_event['order_id'] = response['id']
+                order_event['status'] = "NEW"
+                self.callback['onOrder'](order_event)
+        if error:
+            pass
+        return response
+
+    async def cancel_order(self, order_id=None, client_id=None):
+        params = {
+            "currency_pair": self.symbol
+        }
+        if order_id:
+            response, error = await self._request('DELETE', f'/api/v4/spot/orders/{order_id}', params=params, auth=1)
+        elif client_id:
+            response, error = await self._request('DELETE', f'/api/v4/spot/orders/{client_id}', params=params, auth=1)
+
+        if response:
+            # rest cancel 如果有回报 可能会和ws回报 产生重复处理
+            self.logger.debug(f'撤单回报 {response}')
+            # if response['status'] == 'cancelled':  # 已撤销
+            #     order_event = dict()
+            #     order_event['price'] = float(response["price"])
+            #     order_event['amount'] = float(response["amount"])
+            #     order_event['client_id'] = response["text"]
+            #     order_event['order_id'] = response['id']
+            #     order_event['filled'] = float(response["amount"]) - float(response["left"])
+            #     order_event['filled_price'] = float(response["price"])
+            #     order_event['fee'] = float(response["fee"])
+            #     order_event['status'] = "REMOVE"
+            #     self.callback['onOrder'](order_event)
+        if error:
+            return error
+        return response
+    
+    async def get_order_list(self):
+        params = {
+            'currency_pair':self.symbol,
+            'status':"open",
+        }
+        response, error = await self._request('GET', '/api/v4/spot/orders', params=params, auth=1)
+        orders = [] # 重置本地订单列表
+        if response is not None:
+            for i in response:
+                if i['side'] == 'buy':
+                    side = 'kd'
+                elif i['side'] == 'sell':
+                    side = 'pd'
+                else:
+                    raise Exception(f"{self.name} wrong side")
+                order_event = dict()
+                order_event['price'] = float(i["price"])
+                order_event['amount'] = float(i["amount"])
+                order_event['client_id'] = i["text"]
+                order_event['order_id'] = i['id']
+                order_event['status'] = "NEW"
+                self.callback['onOrder'](order_event)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        return await self._request('GET','/api/v4/spot/accounts', params={}, auth=1)
+
+    async def get_position(self):
+        '''获取持仓 symbol: BTC-USDT'''
+        return await self._request('POST','/linear-swap-api/v1/swap_position_info', params={'contract_code':self.symbol}, auth=1)
+
+    async def get_market_details(self):
+        return await self._request('GET',f'/api/v4/spot/currency_pairs', params={}, auth=1)
+
+    async def get_ticker(self):
+        res, err = await self._request('GET',f'/api/v4/spot/tickers', params={"currency_pair":self.symbol}, auth=1)
+        if res:
+            ap = float(res[0]["lowest_ask"])
+            bp = float(res[0]["highest_bid"])
+            mp = (ap+bp)*0.5
+            d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+            self.callback['onTicker'](d)
+            return d
+        if err:
+            self.logger.error(err)
+            return None 
+
+    async def before_trade(self):
+        res, error = await self.get_market_details()
+        if error:
+            pass
+        else:
+            for i in res:
+                if i['id'] == self.symbol:
+                    self.tickSize = float(Decimal("0.1")**Decimal(i['precision']))
+                    self.stepSize = float(Decimal("0.1")**Decimal(i['amount_precision']))
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['id']
+                exchange_info.multiplier = 1
+                exchange_info.tickSize = float(Decimal("0.1")**Decimal(i['precision']))
+                exchange_info.stepSize = float(Decimal("0.1")**Decimal(i['amount_precision']))
+                self.exchange_info[exchange_info.symbol] = exchange_info
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:print(err)
+        if res:
+            for i in res:
+                if self.quote == i['currency'].upper():
+                    cash = float(i['available']) + float(i['locked'])
+                    self.callback['onEquity']({
+                        self.quote:cash
+                        })
+                    self.cash_value = cash
+                if i['currency'].upper() == self.base:
+                    coin = float(i['available']) + float(i['locked'])
+                    self.callback['onEquity']({
+                        self.base:coin
+                        })
+                    self.coin_value = coin
+
+    async def buy_token(self):
+        '''买入平台币'''
+        # 获取u数量 平台币数量
+        # 更新账户
+        cash, token = 0.0, 0.0
+        res, err = await self.get_account()
+        if err:self.logger.error(err)
+        if res:
+            for i in res:
+                if 'USDT' == i['currency'].upper():
+                    cash = float(i['available']) + float(i['locked'])
+                if 'GT' == i['currency'].upper():
+                    token = float(i['available']) + float(i['locked'])
+        self.logger.info(f"持u{cash} 持GT{token}")
+        # 获取平台币价格
+        res, err = await self._request('GET',f'/api/v4/spot/tickers', params={"currency_pair":'GT_USDT'}, auth=1)
+        if err:print(err)
+        if res:
+            ap = float(res[0]["lowest_ask"])
+            bp = float(res[0]["highest_bid"])
+            mp = (ap+bp)*0.5
+        # 判断是否需要买入
+        token_value = token * mp
+        if token_value < 30:
+            self.logger.info(f"GT数量过少")
+            if cash > 200:
+                self.logger.info(f"准备买入GT")
+                # 下单买入50uGT
+                res = await self.take_order("GT_USDT", 50/mp, "kd", mp*1.001, "t-888", "limit")
+                self.logger.info(res)
+            else:
+                self.logger.warning(f"现金不足 无法买入GT")
+        else:
+            self.logger.info(f"GT数量充足")
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            清空挂单清空仓位
+        '''
+        try:
+            #############################
+            self.logger.info("获取挂单")
+            params = {
+            }
+            response, error = await self._request('GET', '/api/v4/spot/open_orders', params=params, auth=1)
+            if response is not None:
+                for i in response:
+                    params = {
+                        "currency_pair": i['currency_pair']
+                    }
+                    for j in i['orders']:
+                        oid = j['id']
+                        r, e = await self._request('DELETE', f'/api/v4/spot/orders/{oid}', params=params, auth=1)
+                        print(r,e)
+            if error:
+                self.logger.info(error)
+            #############################
+            res, err = await self.get_account()
+            if err:self.logger.info(err)
+            if res:
+                coin = 0.0
+                for i in res:
+                    coin_name = i['currency'].upper()
+                    if coin_name in ['GT','USDT', 'POINT']:
+                        continue
+                    symbol = coin_name + '_USDT'
+                    if coin_name == self.base:
+                        _hold_coin = hold_coin
+                    else:
+                        _hold_coin = 0
+                    coin = float(i['available']) + float(i['locked'])
+                    #################
+                    ticker, _ = await self._request('GET',f'/api/v4/spot/tickers', params={"currency_pair":symbol}, auth=1)
+                    if ticker:
+                        ap = float(ticker[0]["lowest_ask"])
+                        bp = float(ticker[0]["highest_bid"])
+                        mp = (ap+bp)*0.5
+                    #################
+                    coin_value = coin * mp
+                    diff = _hold_coin - coin_value
+                    diff *= 0.99 # 避免无法下单
+                    self.logger.info(f'{symbol}需要调整现货仓位{diff}usd')
+                    if diff > 20.0:
+                        res = await self.take_order(symbol, diff/mp, "kd", mp*1.001, "t-123", "limit")
+                    elif diff < -20.0:
+                        res = await self.take_order(symbol, -diff/mp, "kk", mp*0.999, "t-123", "limit")
+            #############################
+            params = {
+                'currency_pair':self.symbol,
+                'status':"open",
+            }
+            response, error = await self._request('GET', '/api/v4/spot/orders', params=params, auth=1)
+            if response is not None:
+                for i in response:
+                    await self.cancel_order(order_id=i['id'])
+                    await self.cancel_order(client_id=i['text'])
+            if error:
+                self.logger.info(error)        
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return 
+
+    async def go(self):
+        await self.before_trade()
+        await asyncio.sleep(1)
+        while 1:
+            try:
+                # 停机信号
+                if self.stop_flag:return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:print(err)
+                if res:
+                    for i in res:
+                        if self.quote == i['currency'].upper():
+                            cash = float(i['available']) + float(i['locked'])
+                            self.callback['onEquity']({
+                                self.quote:cash
+                            })
+                            self.cash_value = cash
+                        if i['currency'].upper() == self.base:
+                            coin = float(i['available']) + float(i['locked'])
+                            self.callback['onEquity']({
+                                self.base:coin
+                            })
+                            self.coin_value = coin
+                # 更新订单
+                # res = await self.get_order_list()
+                await asyncio.sleep(60)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'{self.name} rest报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+    
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # gate优先用oid撤单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+                    elif cid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(client_id=cid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3],
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    # gate优先用oid查单
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.check_order(order_id=oid))
+                    elif cid:
+                        asyncio.get_event_loop().create_task(self.check_order(client_id=cid))
+        except:
+            # traceback.print_exc()
+            await asyncio.sleep(0.1)
+
+

+ 389 - 0
exchange/gate_spot_ws.py

@@ -0,0 +1,389 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hmac, sys
+import base64, csv, random
+import traceback, hashlib
+import logging, logging.handlers
+import utils
+import model
+
+
+def inflate(data):
+    '''
+        解压缩数据
+    '''
+    decompress = zlib.decompressobj(-zlib.MAX_WBITS)
+    inflated = decompress.decompress(data)
+    inflated += decompress.flush()
+    return inflated
+
+def empty_call(msg):
+    pass
+
+def get_sign(secret_key, message):
+    h = (base64.b64encode(hmac.new(secret_key.encode('utf-8'), message.encode('utf-8'), hashlib.sha512).digest())).decode()
+    return h
+
+class GateSpotWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('启用colo高速线路')
+            self.URL = 'wss://spotws-private.gateapi.io/ws/v4/'
+        else:
+            self.URL = 'wss://api.gateio.ws/ws/v4/'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + '_' + self.quote
+        self.callback = {
+            "onMarket":self.save_market,
+            "onDepth":empty_call,
+            "onPosition":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTrade":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.update_t = 0.0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    def _update_ticker(self, msg):
+        self.ticker_info["bp"] = float(msg['highest_bid'])
+        self.ticker_info["ap"] = float(msg['lowest_ask'])
+        self.callback['onTicker'](self.ticker_info)
+
+    def _update_depth(self, msg):
+        if msg['t'] > self.update_t:
+            self.update_t = msg['t']
+            self.ticker_info["bp"] = float(msg['bids'][0][0])
+            self.ticker_info["ap"] = float(msg['asks'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in msg['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in msg['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+        else:
+            print("收到过期depth")
+    
+    def _update_trade(self, msg):
+        price = float(msg['price'])
+        amount = float(msg['amount'])
+        side = msg['side']
+        if price > self.max_buy or self.max_buy == 0.0:
+            self.max_buy = price
+        if price < self.min_sell or self.min_sell == 0.0:
+            self.min_sell = price
+        if side == 'buy':
+            self.buy_q += amount
+            self.buy_v += amount*price
+        elif side == 'sell':
+            self.sell_q += amount
+            self.sell_v += amount*price
+
+    def _update_account(self, msg):
+        for i in msg:
+            if i['currency'].upper() == self.quote:
+                cash = float(i['total'])
+                self.callback['onEquity'] = {
+                    self.quote:cash
+                }
+            elif i['currency'].upper() == self.base:
+                coin = float(i['total'])
+                self.callback['onEquity'] = {
+                    self.base:coin
+                }
+
+    def _update_order(self, msg):
+        self.logger.debug(f"ws订单推送 {msg}")
+        for i in msg:
+            if i['event'] == 'put':
+                order_event = dict()
+                order_event['filled'] = 0
+                order_event['filled_price'] = 0
+                order_event['client_id'] = i["text"]
+                order_event['order_id'] = i['id']
+                order_event['status'] = "NEW"
+                self.callback['onOrder'](order_event)
+            elif i['event'] == 'finish':
+                order_event = dict()
+                order_event['filled'] = float(i["amount"]) - float(i["left"])
+                if order_event['filled'] > 0:
+                    order_event['filled_price'] = float(i["filled_total"])/order_event['filled']
+                else:
+                    order_event['filled_price'] = 0
+                order_event['client_id'] = i["text"]
+                order_event['order_id'] = i['id']
+                order_event['fee'] = float(i["fee"])
+                order_event['status'] = "REMOVE"
+                self.callback['onOrder'](order_event)
+                # 根据成交信息更新仓位信息 因为账户信息推送有延迟
+                # 但订单信息和账户信息到达先后时间可能有前有后 可能平仓 账户先置零仓位 然后sell成交达到 导致仓位变成负数
+
+    def _update_usertrade(self, msg):
+        '''暂时不用'''
+        pass
+
+    def _update_position(self, msg):
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0 
+        for i in msg[0]['holding']:
+            if i['side'] == 'long':
+                long_pos += float(i['position'])
+                long_avg = float(i['avg_cost'])
+            if i['side'] == 'short':
+                short_pos += float(i['position'])
+                short_avg = float(i['avg_cost'])
+        pos = model.Position()
+        pos.longPos = long_pos
+        pos.longAvg = long_avg
+        pos.shortPos = short_pos
+        pos.shortAvg = short_avg
+        self.callback['onPosition'](pos)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    def get_sign(self, message):
+        h = hmac.new(self.params.secret_key.encode("utf8"), message.encode("utf8"), hashlib.sha512)
+        return h.hexdigest()
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                ping_time = time.time()
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                ws_url = self.URL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    self.is_print:print(f'{self.name} ws连接成功')
+                    # 登陆
+                    if is_auth:
+                        # userorders
+                        current_time = int(time.time())
+                        channel = "spot.orders"
+                        sub_str = {
+                            "time": current_time,
+                            "channel": channel,
+                            "event": "subscribe", 
+                            "payload": [self.symbol]
+                        }
+                        message = 'channel=%s&event=%s&time=%d' % (channel, "subscribe", current_time)
+                        sub_str["auth"] = {
+                            "method": "api_key",
+                            "KEY": self.params.access_key,
+                            "SIGN": self.get_sign(message)}
+                        await _ws.send_str(ujson.dumps(sub_str))
+                        # usertrades
+                        current_time = int(time.time())
+                        channel = "spot.usertrades"
+                        sub_str = {
+                            "time": current_time,
+                            "channel": channel,
+                            "event": "subscribe", 
+                            "payload": [self.symbol]
+                        }
+                        message = 'channel=%s&event=%s&time=%d' % (channel, "subscribe", current_time)
+                        sub_str["auth"] = {
+                            "method": "api_key",
+                            "KEY": self.params.access_key,
+                            "SIGN": self.get_sign(message)}
+                        await _ws.send_str(ujson.dumps(sub_str))
+                        # balance
+                        current_time = int(time.time())
+                        channel = "spot.balances"
+                        sub_str = {
+                            "time": current_time,
+                            "channel": channel,
+                            "event": "subscribe", 
+                            "payload": [self.symbol]
+                        }
+                        message = 'channel=%s&event=%s&time=%d' % (channel, "subscribe", current_time)
+                        sub_str["auth"] = {
+                            "method": "api_key",
+                            "KEY": self.params.access_key,
+                            "SIGN": self.get_sign(message)}
+                        await _ws.send_str(ujson.dumps(sub_str))
+                    if sub_trade:
+                        # public trade
+                        current_time = int(time.time())
+                        channel = "spot.trades"
+                        sub_str = {
+                            "time": current_time,
+                            "channel": channel,
+                            "event": "subscribe", 
+                            "payload": [self.symbol]
+                        }
+                        await _ws.send_str(ujson.dumps(sub_str))
+                    # 订阅public
+                    # tickers   太慢了
+                    # current_time = int(time.time())
+                    # channel = "spot.tickers"
+                    # sub_str = {
+                    #     "time": current_time,
+                    #     "channel": channel,
+                    #     "event": "subscribe", 
+                    #     "payload": [self.symbol]
+                    # }
+                    # await _ws.send_str(ujson.dumps(sub_str))
+                    # depth
+                    current_time = int(time.time())
+                    channel = "spot.order_book"
+                    sub_str = {
+                        "time": current_time,
+                        "channel": channel,
+                        "event": "subscribe", 
+                        "payload": [self.symbol,"20","100ms"]
+                    }
+                    await _ws.send_str(ujson.dumps(sub_str))
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=10)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = msg[1]
+                        # 处理消息
+                        if 'update' in msg:
+                            msg = ujson.loads(msg)
+                            # if msg['channel'] == 'spot.tickers':self._update_ticker(msg['result'])
+                            if msg['channel'] == 'spot.order_book':self._update_depth(msg['result'])
+                            elif msg['channel'] == 'spot.balances':self._update_account(msg['result'])
+                            elif msg['channel'] == 'spot.orders':self._update_order(msg['result'])
+                            # if msg['channel'] == 'spot.usertrades':self._update_usertrade(msg['result'])
+                            elif msg['channel'] == 'spot.trades':self._update_trade(msg['result'])
+                        else:
+                            # print(msg)
+                            pass
+                        # pong
+                        if time.time() - ping_time > 5:
+                            await _ws.send_str('{"time": %d, "channel" : "spot.ping"}' % int(time.time()))
+                            ping_time = time.time()
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                # await asyncio.sleep(1)
+

+ 525 - 0
exchange/gate_usdt_swap_rest.py

@@ -0,0 +1,525 @@
+import random
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import hmac
+import base64
+import hashlib
+import traceback
+import urllib
+from urllib import parse
+from urllib.parse import urljoin
+import datetime, sys, utils
+from urllib.parse import urlparse
+import logging, logging.handlers
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+
+class GateUsdtSwapRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('使用colo高速线路')
+            self.HOST = 'https://apiv4-private.gateapi.io'
+        else:
+            self.HOST = 'https://api.gateio.ws'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + '_' + self.quote
+        self._SESSIONS = dict()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.tickSize = None
+        self.stepSize = None
+        self.delays = []
+        self.max_delay = 0.0
+        self.avg_delay = 0.0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        self.multiplier = None
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+            
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    def generate_signature(self, method, uri, query_param=None, body=None):
+        t = time.time()
+        m = hashlib.sha512()
+        m.update((body or "").encode('utf-8'))
+        hashed_payload = m.hexdigest()
+        s = '%s\n%s\n%s\n%s\n%s' % (method, uri, query_param or "", hashed_payload, t)
+        sign = hmac.new(self.params.secret_key.encode('utf-8'), s.encode('utf-8'), hashlib.sha512).hexdigest()
+        return {'KEY': self.params.access_key, 'Timestamp': str(t), 'SIGN': sign}
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        if method == "GET":
+            headers = {
+                "Content-type": "application/x-www-form-urlencoded",
+                "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) "
+                    "Chrome/39.0.2171.71 Safari/537.36"
+            }
+        else:
+            headers = {
+                "Accept": "application/json",
+                "Content-type": "application/json"
+            }
+        if auth:
+            if method == "POST":
+                query_param = ''
+                if params:
+                    for i in params:
+                        query_param += f'{i}={params[i]}&'
+                    query_param = query_param[:-1]
+                    url += "?"+ query_param
+                if body:
+                    body = ujson.dumps(body)
+                sign_headers = self.generate_signature(method, uri, query_param, body)
+                headers.update(sign_headers)
+            if method == "GET" or method == "DELETE":
+                query_param = ''
+                for i in params:
+                    query_param += f'{i}={params[i]}&'
+                query_param = query_param[:-1]
+                sign_headers = self.generate_signature(method, uri, query_param)
+                headers.update(sign_headers)
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=None, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json() 
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            self.get_delay_info()
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                return None, res
+            return res, None
+        except Exception as e:
+            print(f'{self.name} 请求出错', e)
+            self.logger.error('请求错误'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='limit'):
+        '''
+            传入单位是张 内部转换为币
+        '''
+        if origin_side =='kd':
+            side = 'buy'
+            reduce_only = False
+        elif origin_side =='pd':
+            side = 'sell'
+            reduce_only = True
+        elif origin_side =='kk':
+            side = 'sell'
+            reduce_only = False
+        elif origin_side =='pk':
+            side = 'buy'
+            reduce_only = True
+        else:
+            return None
+        amount = int(amount/self.exchange_info[symbol].multiplier) # 币转换为张
+        if amount <= 0.0 or price <= 0.0:
+            self.logger.error(f"下单参数错误 amount:{amount} price:{price}")
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['client_id'] = cid
+            order_event['filled_price'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['fee'] = 0.0
+            self.callback["onOrder"](order_event)
+        if side == 'sell':
+            amount = -amount
+        params = {
+            'text':cid,
+            'contract': symbol, 
+            'size': amount,
+            'reduce_only':reduce_only,
+            'price': utils.num_to_str(price, self.exchange_info[symbol].tickSize), 
+            'type':order_type,
+        }
+        # logger.info(f'下单指令 {params}')
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            response, error = await self._request('POST', '/api/v4/futures/usdt/orders', body=params, auth=1)
+            if response:
+                # 增加新的
+                order_event = dict()
+                order_event['status'] = "NEW"
+                order_event['client_id'] = response["text"]
+                order_event['order_id'] = response["id"]
+                self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['client_id'] = params["text"]
+                order_event['filled_price'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['fee'] = 0.0
+                self.callback["onOrder"](order_event)
+                return error
+        return response
+
+    async def check_order(self, order_id=None, client_id=None):
+        params = {}
+        if order_id:
+            response, error = await self._request('GET', f'/api/v4/futures/usdt/orders/{order_id}', params=params, auth=1)
+        elif client_id:
+            response, error = await self._request('GET', f'/api/v4/futures/usdt/orders/{client_id}', params=params, auth=1)
+
+        if response:
+            if response['status'] in ['cancelled','closed','finished']:  # 已撤销 或 全部成交
+                order_event = dict()
+                order_event['client_id'] = response["text"]
+                order_event['order_id'] = response['id']
+                order_event['filled'] = (abs(float(response["size"])) - abs(float(response["left"])))*self.multiplier
+                order_event['filled_price'] = float(response["price"])
+                order_event['fee'] = 0.0
+                order_event['status'] = "REMOVE"
+                self.callback['onOrder'](order_event)
+            elif response['status'] in ['open']: # 还在挂单中
+                order_event = dict()
+                order_event['client_id'] = response["text"]
+                order_event['order_id'] = response['id']
+                order_event['status'] = "NEW"
+                self.callback['onOrder'](order_event)
+        if error:
+            pass
+        return response
+
+    async def cancel_order(self, order_id=None, client_id=None):
+        params = {
+            "currency_pair": self.symbol
+        }
+        if order_id:
+            response, error = await self._request('DELETE', f'/api/v4/futures/usdt/orders/{order_id}', params=params, auth=1)
+        elif client_id:
+            response, error = await self._request('DELETE', f'/api/v4/futures/usdt/orders/{client_id}', params=params, auth=1)
+        if response:
+            pass
+            # self.logger.info(f'撤单回报 {response}')
+            # if response['status'] == 'cancelled':  # 已撤销
+            #     order_event = dict()
+            #     order_event['price'] = float(response["price"])
+            #     order_event['amount'] = float(response["amount"])
+            #     order_event['client_id'] = response["text"]
+            #     order_event['order_id'] = response['id']
+            #     order_event['filled'] = float(response["amount"]) - float(response["left"])
+            #     order_event['filled_price'] = float(response["price"])
+            #     order_event['fee'] = float(response["fee"])
+            #     order_event['status'] = "REMOVE"
+            #     self.callback['onOrder'](order_event)
+        if error:
+            return error
+        return response
+    
+    async def get_order_list(self):
+        params = {
+            'currency_pair':self.symbol,
+            'status':"open",
+        }
+        response, error = await self._request('GET', '/api/v4/futures/usdt/orders', params=params, auth=1)
+        orders = [] # 重置本地订单列表
+        if response is not None:
+            for i in response:
+                if i['side'] == 'buy':
+                    side = 'kd'
+                elif i['side'] == 'sell':
+                    side = 'pd'
+                else:
+                    raise Exception(f"{self.name} wrong side")
+                order_event = dict()
+                order_event['price'] = float(i["price"])
+                order_event['amount'] = float(i["size"])*self.multiplier
+                order_event['client_id'] = i["text"]
+                order_event['order_id'] = i['id']
+                order_event['status'] = "NEW"
+                self.callback['onOrder'](order_event)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        return await self._request('GET','/api/v4/futures/usdt/accounts', params={}, auth=1)
+
+    async def get_position(self):
+        '''获取持仓 symbol: BTC-USDT'''
+        return await self._request('GET', f'/api/v4/futures/usdt/dual_comp/positions/{self.symbol}', params={}, auth=1)
+
+    async def get_market_details(self):
+        return await self._request('GET',f'/api/v4/futures/usdt/contracts', params={}, auth=1)
+
+    async def get_ticker(self):
+        res, err = await self._request('GET',f'/api/v4/futures/usdt/tickers', params={"contract":self.symbol}, auth=1)
+        if res:
+            ap = float(res[0]["last"])
+            bp = float(res[0]["last"])
+            mp = (ap+bp)*0.5
+            d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+            self.callback['onTicker'](d)
+            return d
+        if err:
+            self.logger.error(err)
+            return None 
+
+    async def before_trade(self):
+        ### 获取市场信息
+        res, err = await self.get_market_details()
+        if err:
+            pass
+        if res:
+            for i in res:
+                if self.symbol == i['name']:
+                    self.tickSize = float(i['order_price_round'])
+                    self.multiplier = float(i['quanto_multiplier'])
+                    self.stepSize = float(i['order_size_min'])*float(i['quanto_multiplier']) # 张 转换为 币
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['name']
+                exchange_info.multiplier = float(i['quanto_multiplier'])
+                exchange_info.tickSize = float(i['order_price_round'])
+                exchange_info.stepSize = float(i['order_size_min'])*float(i['quanto_multiplier'])
+                self.exchange_info[exchange_info.symbol] = exchange_info
+        ### 获取持仓模式
+        res, err = await self._request('POST', f'/api/v4/futures/usdt/dual_mode', params={'dual_mode':"true"}, auth=1)
+        if err:
+            print(err)
+        if res:
+            print(res)
+        ### 杠杆
+        res, err = await self._request('POST', f'/api/v4/futures/usdt/dual_comp/positions/{self.symbol}/leverage', params={'leverage':"20"}, auth=1)
+        if err:
+            print(err)
+        if res:
+            print(res)
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:print(err)
+        if res:
+            if res['currency'] == self.quote:
+                cash = float(res['total'])
+                self.callback['onEquity']({
+                    self.quote:cash
+                    })
+                self.cash_value = cash
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            清空挂单清空仓位  已支持全品种
+        '''
+        try:
+            #############################
+            self.logger.info("清空挂单")
+            params = {
+                'contract':self.symbol
+            }
+            response, error = await self._request('DELETE', '/api/v4/futures/usdt/orders', params=params, auth=1)
+            if response:
+                self.logger.info(response)
+            if error:
+                self.logger.info(error)
+            #############################
+            #############################
+            self.logger.info("检查遗漏仓位")
+            res, err = await self._request('GET', f'/api/v4/futures/usdt/positions', params={}, auth=1)
+            if err:
+                self.logger.info(err)
+            if res:
+                for i in res:
+                    symbol = i['contract']
+                    if symbol not in self.exchange_info:
+                        await self.before_trade()
+                    size = abs(float(i['size']))  # 单位张
+                    side = i['mode']
+                    if size == 0:
+                        pass
+                    else:
+                        #######
+                        res, err = await self._request('GET',f'/api/v4/futures/usdt/tickers', params={"contract":symbol}, auth=1)
+                        if res:
+                            ap = float(res[0]["last"])
+                            bp = float(res[0]["last"])
+                            mp = (ap+bp)*0.5
+                        if err:
+                            pass
+                        #######
+                        amount = abs(size)*self.exchange_info[symbol].multiplier
+                        #######
+                        if side == 'dual_short':
+                            # pk
+                            price = float(Decimal(str(mp*1.001//self.exchange_info[symbol].tickSize))*Decimal(str(self.exchange_info[symbol].tickSize)))
+                            res = await self.take_order(symbol, amount, "pk", price, "t-123", "limit")
+                            self.logger.info(res)
+                        if side == 'dual_long':
+                            # pd
+                            price = float(Decimal(str(mp*0.999//self.exchange_info[symbol].tickSize))*Decimal(str(self.exchange_info[symbol].tickSize)))
+                            res = await self.take_order(symbol, amount, "pd", price, "t-123", "limit")
+                            self.logger.info(res)            
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def go(self):
+        interval = 60
+        await self.before_trade()
+        await asyncio.sleep(1)
+        while 1:
+            try:
+                # 停机信号
+                if self.stop_flag:return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:print(err)
+                if res:
+                    if res['currency'] == self.quote:
+                        cash = float(res['total'])
+                        self.callback['onEquity']({
+                            self.quote:cash
+                            })
+                        self.cash_value = cash
+                        self.logger.debug(f"rest cash {cash}")
+                # 更新仓位
+                res, err = await self.get_position()
+                if err:
+                    self.logger.info(err)
+                if res:
+                    p = model.Position()
+                    for i in res:
+                        symbol = i['contract']
+                        size = abs(float(i['size']))*self.multiplier
+                        price = float(i['entry_price'])
+                        side = i['mode']
+                        if self.symbol == symbol:
+                            if size == 0:
+                                pass
+                            else:
+                                #######
+                                if side == 'dual_short':
+                                    p.shortAvg =  price
+                                    p.shortPos = size
+                                if side == 'dual_long':
+                                    p.longAvg = price
+                                    p.longPos = size
+                    self.callback['onPosition'](p)
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.logger.debug(f'{self.name} rest报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+    
+    async def transfor(self):
+        params = {
+            'currency':'USDT',
+            'from':'spot',
+            'to':'futures',
+            'amount':'400',
+            'settle':'USDT',
+        }
+        response, error = await self._request('POST', '/api/v4/wallet/transfers', body=params, auth=1)
+        if response:
+            print(response)
+        if error:
+            print(error)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+    
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if cid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(client_id=cid))
+                    elif oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3],
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    cid = orders[order_name][0]
+                    # oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(client_id=cid))
+        except:
+            # traceback.print_exc()
+            await asyncio.sleep(0.1)
+

+ 464 - 0
exchange/gate_usdt_swap_ws.py

@@ -0,0 +1,464 @@
+from re import sub
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hmac, sys
+import base64, csv, random
+import traceback, hashlib
+import logging, logging.handlers
+import utils
+import model
+from decimal import Decimal
+
+
+def inflate(data):
+    '''
+        解压缩数据
+    '''
+    decompress = zlib.decompressobj(-zlib.MAX_WBITS)
+    inflated = decompress.decompress(data)
+    inflated += decompress.flush()
+    return inflated
+
+def empty_call(msg):
+    pass
+
+class GateUsdtSwapWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('使用colo高速线路')
+            self.URL = 'wss://fxws-private.gateapi.io/v4/ws/usdt'
+        else:
+            self.URL = 'wss://fx-ws.gateio.ws/v4/ws/usdt'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + '_' + self.quote
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTrade":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.update_t = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        # 过期检查
+        self.public_update_time = time.time()
+        self.private_update_time = time.time()
+        self.expired_time = 300
+        self.multiplier = None
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def gen_signed(self, channel, event, timestamp):
+        # 为消息签名
+        api_key = self.params.access_key
+        api_secret = self.params.secret_key
+
+        s = 'channel=%s&event=%s&time=%d' % (channel, event, timestamp)
+        sign = hmac.new(api_secret.encode('utf-8'), s.encode('utf-8'), hashlib.sha512).hexdigest()
+        return {'method': 'api_key', 'KEY': api_key, 'SIGN': sign}
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    def _update_depth(self, msg):
+        self.public_update_time = time.time()
+        if msg['t'] > self.update_t:
+            self.update_t = msg['t']
+            self.ticker_info["bp"] = float(msg['bids'][0]['p'])
+            self.ticker_info["ap"] = float(msg['asks'][0]['p'])
+            self.callback['onTicker'](self.ticker_info)
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in msg['bids']:
+                price = float(bid['p'])
+                amount = float(bid['s'])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in msg['asks']:
+                price = float(ask['p'])
+                amount = float(ask['s'])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+        else:   
+            self.logger.error(f"收到过时的depth推送 {self.update_t}")
+    
+    def _update_trade(self, msg):
+        self.public_update_time = time.time()
+        for i in msg:
+            amount = float(i['size'])*self.multiplier
+            price = float(i['price'])
+            side = "buy" if amount > 0.0 else "sell"
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+
+    def _update_account(self, msg):
+        self.private_update_time = time.time()
+        for i in msg:
+            if self.symbol in i['text']:
+                cash = float(i['balance'])
+                self.callback['onEquity'] = {
+                    self.quote:cash
+                }
+                self.logger.debug(f"ws cash {cash}")
+
+    def _update_order(self, msg):
+        self.private_update_time = time.time()
+        self.logger.debug(f"ws订单推送 {msg}")
+        for i in msg:
+            if i['status'] in ['open']:
+                order_event = dict()
+                order_event['filled'] = 0
+                order_event['filled_price'] = 0
+                order_event['client_id'] = i["text"]
+                order_event['order_id'] = i['id']
+                order_event['status'] = "NEW"
+                self.callback['onOrder'](order_event)
+            elif i['status'] in ['finished']:
+                order_event = dict()
+                filled_paper = Decimal(abs(float(i["size"]))) - Decimal(abs(float(i["left"])))
+                filled_amount = filled_paper*Decimal(str(self.multiplier))
+                order_event['filled'] = float(filled_amount)
+                order_event['filled_price'] = float(i["fill_price"])
+                order_event['client_id'] = i["text"]
+                order_event['order_id'] = i['id']
+                order_event['fee'] = 0.0
+                order_event['status'] = "REMOVE"
+                self.callback['onOrder'](order_event)
+                # 根据成交信息更新仓位信息 因为账户信息推送有延迟
+                # 但订单信息和账户信息到达先后时间可能有前有后 可能平仓 账户先置零仓位 然后sell成交达到 导致仓位变成负数
+
+    def _update_usertrade(self, msg):
+        '''暂时不用'''
+        return
+
+    def _update_position(self, msg):
+        self.private_update_time = time.time()
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0 
+        for i in msg:
+            if i['contract'] == self.symbol:
+                size = float(i['size'])*self.multiplier
+                if size > 0:
+                    long_pos = abs(size)
+                    long_avg = float(i['entry_price'])
+                if size < 0:
+                    short_pos = abs(size)
+                    short_avg = float(i['entry_price'])
+        pos = model.Position()
+        pos.longPos = long_pos
+        pos.longAvg = long_avg
+        pos.shortPos = short_pos
+        pos.shortAvg = short_avg
+        self.callback['onPosition'](pos)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    def get_sign(self, message):
+        h = hmac.new(self.params.secret_key.encode("utf8"), message.encode("utf8"), hashlib.sha512)
+        return h.hexdigest()
+
+    def _get_uid(self):
+        pass
+
+    def generate_signature(self, method, uri, query_param=None, body=None):
+        t = time.time()
+        m = hashlib.sha512()
+        m.update((body or "").encode('utf-8'))
+        hashed_payload = m.hexdigest()
+        s = '%s\n%s\n%s\n%s\n%s' % (method, uri, query_param or "", hashed_payload, t)
+        sign = hmac.new(self.params.secret_key.encode('utf-8'), s.encode('utf-8'), hashlib.sha512).hexdigest()
+        return {'KEY': self.params.access_key, 'Timestamp': str(t), 'SIGN': sign}
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                # 重置更新时间
+                self.public_update_time = time.time()
+                self.private_update_time = time.time()
+                ping_time = time.time()
+                # 获取uid
+                headers = {
+                    "Accept": "application/json",
+                    "Content-type": "application/json"
+                }
+                if is_auth:
+                    user_id = ""
+                    uri = "/api/v4/wallet/fee"
+                    query_param = ''
+                    sign_headers = self.generate_signature('GET', uri, query_param)
+                    headers.update(sign_headers)
+                    async with aiohttp.ClientSession(connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )) as session:
+                        response = await session.get(
+                            "https://api.gateio.ws" + uri,  
+                            headers=headers,
+                            proxy=self.proxy
+                        )
+                        res = await response.json()
+                        user_id = str(res['user_id'])
+                        print(f"uid {user_id}")
+                # 获取合约乘数
+                async with aiohttp.ClientSession(connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )) as session:
+                    uri = "/api/v4/futures/usdt/contracts"
+                    response = await session.get(
+                        "https://api.gateio.ws" + uri,  
+                        headers=headers,
+                        proxy=self.proxy
+                    )
+                    res = await response.json()
+                    if res:
+                        for i in res:
+                            if self.symbol == i['name']:
+                                self.multiplier = float(i['quanto_multiplier'])
+                                print(f"contract multiplier {self.multiplier}")
+                    # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                ws_url = self.URL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    # 登陆
+                    if is_auth:
+                        # userorders
+                        current_time = int(time.time())
+                        channel = "futures.orders"
+                        sub_str = {
+                            "time": current_time,
+                            "channel": channel,
+                            "event": "subscribe", 
+                            "payload": [user_id,self.symbol]
+                            }
+                        sub_str["auth"] = self.gen_signed(sub_str['channel'], sub_str['event'], sub_str['time'])
+                        await _ws.send_str(ujson.dumps(sub_str))
+                        # positions
+                        current_time = int(time.time())
+                        channel = "futures.positions"
+                        sub_str = {
+                            "time": current_time,
+                            "channel": channel,
+                            "event": "subscribe", 
+                            "payload": [user_id,self.symbol]
+                            }
+                        sub_str["auth"] = self.gen_signed(sub_str['channel'], sub_str['event'], sub_str['time'])
+                        await _ws.send_str(ujson.dumps(sub_str))
+                        # usertrades
+                        # current_time = int(time.time())
+                        # channel = "futures.usertrades"
+                        # sub_str = {
+                        #     "time": current_time,
+                        #     "channel": channel,
+                        #     "event": "subscribe", 
+                        #     "payload": [self.symbol]
+                        # }
+                        # message = 'channel=%s&event=%s&time=%d' % (channel, "subscribe", current_time)
+                        # sub_str["auth"] = {
+                        #     "method": "api_key",
+                        #     "KEY": self.params.access_key,
+                        #     "SIGN": self.get_sign(message)}
+                        # await _ws.send_str(ujson.dumps(sub_str))
+                        # balance
+                        current_time = int(time.time())
+                        channel = "futures.balances"
+                        sub_str = {
+                            "time": current_time,
+                            "channel": channel,
+                            "event": "subscribe", 
+                            "payload": [user_id]
+                        }
+                        sub_str["auth"] = self.gen_signed(sub_str['channel'], sub_str['event'], sub_str['time'])
+                        await _ws.send_str(ujson.dumps(sub_str))
+                    if sub_trade:
+                        # public trade
+                        current_time = int(time.time())
+                        channel = "futures.trades"
+                        sub_str = {
+                            "time": current_time,
+                            "channel": channel,
+                            "event": "subscribe", 
+                            "payload": [self.symbol]
+                        }
+                        await _ws.send_str(ujson.dumps(sub_str))
+                    # 订阅
+                    # tickers 速度慢
+                    # current_time = int(time.time())
+                    # channel = "futures.tickers"
+                    # sub_str = {
+                    #     "time": current_time,
+                    #     "channel": channel,
+                    #     "event": "subscribe", 
+                    #     "payload": [self.symbol]
+                    # }
+                    # await _ws.send_str(ujson.dumps(sub_str))
+                    # depth
+                    current_time = int(time.time())
+                    channel = "futures.order_book"
+                    sub_str = {
+                        "time": current_time,
+                        "channel": channel,
+                        "event": "subscribe", 
+                        "payload": [self.symbol,"20","0"]
+                    }
+                    await _ws.send_str(ujson.dumps(sub_str))
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=10)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = ujson.loads(msg.data)
+                        # 处理消息
+                        if msg['event'] in ['update', 'all']:
+                            if msg['channel'] == 'futures.order_book':self._update_depth(msg['result'])
+                            elif msg['channel'] == 'futures.balances':self._update_account(msg['result'])
+                            elif msg['channel'] == 'futures.orders':self._update_order(msg['result'])
+                            # elif msg['channel'] == 'futures.usertrades':self._update_usertrade(msg['result'])
+                            elif msg['channel'] == 'futures.positions':self._update_position(msg['result'])
+                            elif msg['channel'] == 'futures.trades':self._update_trade(msg['result'])
+                        else:
+                            pass
+                        # pong
+                        if time.time() - ping_time > 5:
+                            await _ws.send_str('{"time": %d, "channel" : "futures.ping"}' % int(time.time()))
+                            ping_time = time.time()
+                        if is_auth:                            
+                            if time.time() - self.private_update_time > self.expired_time*5:
+                                raise Exception('长期未更新私有信息重连')
+                        if time.time() - self.public_update_time > self.expired_time:
+                            raise Exception('长期未更新公有信息重连')
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                await asyncio.sleep(1)
+

+ 309 - 0
exchange/huobi_spot_ws.py

@@ -0,0 +1,309 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random
+import gzip, csv, sys
+import logging, logging.handlers
+import utils
+import model
+
+def empty_call(msg):
+    pass
+
+
+
+
+class HuobiSpotWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://api.huobi.pro/ws'
+        else:
+            self.URL = 'wss://api.huobi.pro/ws'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.data = dict()
+        self.data['trade'] = []
+        self.data['force'] = []
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.public_update_time = time.time()
+        self.private_update_time = time.time()
+        self.expired_time = 300
+        self.update_t = 0.0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://fapi.binance.com/fapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        login_str = await response.text()
+        await session.close()
+        return eval(login_str)['listenKey']
+
+    def _update_depth(self, msg):
+        if msg['ts'] > self.update_t:
+            self.update_t = msg['ts']
+            ####
+            self.ticker_info["bp"] = float(msg['tick']['bids'][0][0])
+            self.ticker_info["ap"] = float(msg['tick']['asks'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in msg['tick']['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in msg['tick']['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+        else:
+            self.logger.debug(f"depth时间戳错误 {self.update_t}")
+    
+    def _update_trade(self, msg):
+        for i in msg['tick']['data']:
+            side = i['direction']
+            price = float(i['price'])
+            amount = float(i['amount'])
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+            #### 修正ticker ####
+            # if side == 'buy' and price > self.ticker_info['ap']:
+            #     self.ticker_info['ap'] = price
+            #     self.callback['onTicker'](self.ticker_info)
+            # if side == 'sell' and price < self.ticker_info['bp']:
+            #     self.ticker_info['bp'] = price
+            #     self.callback['onTicker'](self.ticker_info)
+
+    def _update_account(self, msg):
+        msg = eval(msg)
+        for i in msg['a']['B']:
+            if i['a'] == 'USDT':
+                self.data['equity'] = float(i['wb'])
+        self.callback['onEquity'](self.data['equity'])
+    
+    def _update_order(self, msg):
+        msg = json.loads(msg)
+        i = msg['o']
+        if i['s'] == self.symbol:
+            if i['X'] == 'NEW':  # 新增订单
+                pass
+                # self.callback['onOrder']({"newOrder":newOrder})
+            if i['X'] == 'FILLED':  # 删除订单
+                self.callback['onOrder']({"deleteOrder":i['i']})
+            if i['X'] == 'CANCELED':  # 删除订单
+                self.callback['onOrder']({"deleteOrder":i['i']})
+
+    def _update_position(self, msg):
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0 
+        msg = eval(msg)
+        for i in msg['a']['P']:
+            if i['s'] == self.symbol:
+                if i['ps'] == 'LONG':
+                    long_pos += float(i['pa'])
+                    long_avg = float(i['ep'])
+                if i['ps'] == 'SHORT':
+                    short_pos += float(i['pa'])
+                    short_avg = float(i['ep'])
+        pos = model.Position()
+        self.callback['onPosition'](pos)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                ws_url = self.URL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    # 订阅
+                    symbol = self.symbol.lower()
+                    channels=[
+                        f"market.{symbol}.depth.step0",                        
+                        ]
+                    if sub_trade:
+                        channels.append(f"market.{symbol}.trade.detail")
+                    for i in channels:
+                        sub_str = json.dumps({"sub": i})
+                        await _ws.send_str(sub_str)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=10)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = json.loads(gzip.decompress(msg.data).decode())
+                        # print(msg)
+                        # 处理消息
+                        if 'ch' in msg:
+                            if 'depth' in msg['ch']:self._update_depth(msg)
+                            if 'trade' in msg['ch']:self._update_trade(msg)
+                            # if 'ACCOUNT_UPDATE' in msg:self._update_position(msg)
+                            # if 'ACCOUNT_UPDATE' in msg:self._update_account(msg)
+                            # if 'ORDER_TRADE_UPDATE' in msg:self._update_order(msg)
+                        if 'ping' in msg:
+                            await _ws.send_str(json.dumps({"pong":int(time.time())*1000}))
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                # await asyncio.sleep(1)
+
+
+

+ 393 - 0
exchange/huobi_usdt_swap_rest.py

@@ -0,0 +1,393 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import hmac
+import base64
+import hashlib
+import traceback
+import logging, logging.handlers
+import urllib, sys
+from urllib import parse
+from urllib.parse import urljoin
+import datetime
+import logging, logging.handlers
+import model
+import utils
+
+def empty_call(msg):
+    print('空的回调函数')
+
+def decimal_amount(amount, d):
+    if int(d) == 0:
+        return str(int(amount))
+    elif int(d) > 0:
+        return str(round(float(amount), int(d)))
+
+def decimal_price(price, d):
+    if int(d) == 0:
+        return str(int(price))
+    elif int(d) > 0:
+        return str(round(float(price), int(d)))
+
+
+
+class HuobiUsdtSwapRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://api.hbdm.com'
+        else:
+            self.HOST = 'https://api.hbdm.com'
+        self.params = params
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + '-' + self.quote
+        self.data = {}
+        self._SESSIONS = dict()
+        self.data['account'] = {}
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            }
+        self.exchange_info = dict()
+        self.delays = []
+        self.max_delay = 0.0
+        self.avg_delay = 0.0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        self.multiplier = None
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def _get_session(self, url):
+        key = url
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    def generate_signature(self, method, params, host_url, request_path):
+        if request_path.startswith("http://") or request_path.startswith("https://"):
+            host_url = urllib.parse.urlparse(request_path).hostname.lower()
+            request_path = '/' + '/'.join(request_path.split('/')[3:])
+        else:
+            host_url = urllib.parse.urlparse(self.HOST).hostname.lower()
+        sorted_params = sorted(params.items(), key=lambda d: d[0], reverse=False)
+        encode_params = urllib.parse.urlencode(sorted_params)
+        payload = [method, host_url, request_path, encode_params]
+        payload = "\n".join(payload)
+        payload = payload.encode(encoding="UTF8")
+        secret_key = self.params.secret_key.encode(encoding="utf8")
+        digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest()
+        signature = base64.b64encode(digest)
+        signature = signature.decode()
+        return signature
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        if auth:
+            timestamp = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")
+            params = params if params else {}
+            params.update({"AccessKeyId": self.params.access_key,
+                "SignatureMethod": "HmacSHA256",
+                "SignatureVersion": "2",
+                "Timestamp": timestamp})
+
+            host_name = urllib.parse.urlparse(self.HOST).hostname.lower()
+            params["Signature"] = self.generate_signature(method, params, host_name, uri)
+        if method == "GET":
+            headers = {
+                "Content-type": "application/x-www-form-urlencoded",
+                "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) "
+                    "Chrome/39.0.2171.71 Safari/537.36"
+            }
+        else:
+            headers = {
+                "Accept": "application/json",
+                "Content-type": "application/json"
+            }
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json() 
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206):
+                self.logger.error(f'请求错误 {res}')
+                self.logger.error(code)
+                return None ,res
+            return json.loads(res), None
+        except Exception as e:
+            print('网络请求出错', e)
+            self.logger.error('请求错误')
+            self.logger.error(e)
+            return None, e
+
+    async def take_order(self, symbol, amount, origin_side, price, cid, order_type='LIMIT'):
+        if origin_side =='kd':
+            side = 'buy'
+            positionSide = 'open'
+        elif origin_side =='pd':
+            side = 'sell'
+            positionSide = 'close'
+        elif origin_side =='kk':
+            side = 'sell'
+            positionSide = 'open'
+        elif origin_side =='pk':
+            side = 'buy'
+            positionSide = 'close'
+        else:
+            raise Exception('下单参数错误')
+        params = {
+            'symbol': symbol, 
+            'quantity': decimal_amount(amount, self.params['decimal_amount']), 
+            'side': side, 
+            'positionSide': positionSide, 
+            'price': decimal_price(price, self.params['decimal_price'] ), 
+            'type':order_type,
+            'timeInForce':'GTC',
+        }
+        # logger.info(f'下单指令 {params}')
+        if self.params['debug'] == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            response, error = await self._request('POST', '/linear-swap-api/v1/swap_cross_order', params=params, auth=1)
+            # logger.info(f'下单回报 {response}')
+            if response:
+                order_event = dict()
+                order_event['status'] = "NEW"
+                order_event['client_id'] = cid
+                order_event['order_id'] = response["order_id"]
+                self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['client_id'] = cid
+                order_event['filled_price'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['fee'] = 0.0
+                self.callback["onOrder"](order_event)
+        return response
+
+    # async def take_orders(self, orders):
+    #     def change(side):
+    #         if side =='kd':
+    #             side = 'BUY'
+    #             positionSide = 'LONG'
+    #         elif side =='pd':
+    #             side = 'SELL'
+    #             positionSide = 'LONG'
+    #         elif side =='kk':
+    #             side = 'SELL'
+    #             positionSide = 'SHORT'
+    #         elif side =='pk':
+    #             side = 'BUY'
+    #             positionSide = 'SHORT'
+    #         else:
+    #             raise Exception('下单参数错误')
+    #         return side, positionSide
+    #     params = {}
+    #     data = []
+    #     for i in orders:
+    #         data.append({
+    #             'symbol': i[0], 
+    #             'quantity': i[1], 
+    #             'side': change(i[2])[0], 
+    #             'positionSide': change(i[2])[1], 
+    #             'price': i[3], 
+    #             'type':'LIMIT',
+    #             'timeInForce':'GTC',
+    #         })
+    #     params['batchOrders'] = json.dumps(data)
+    #     # logger.info(f'下单指令 {params}')
+    #     if self.params['debug'] == 'True':
+    #         return await asyncio.sleep(0.1)
+    #     else:
+    #         response = await self._request('POST', '/fapi/v1/batchOrders', params=params)
+    #         # logger.info(f'下单回报 {response}')
+    #     return response
+    
+    async def cancel_order(self, order_id=None, client_id=None):
+        if order_id:
+            params = {
+                "contract_code": self.symbol,
+                "order_id": order_id,
+            }
+            response, error = await self._request('POST', f'/linear-swap-api/v1/swap_cross_cancel', params=params, auth=1)
+        if response:
+            pass
+        if error:
+            pass
+        return None
+
+    async def cancel_all_orders(self):
+        params = {
+            "contract_code": self.symbol
+        }
+        return await self._request('POST', f'/linear-swap-api/v1/swap_cross_cancelall', params=params, auth=1)
+    
+    async def get_order_list(self):
+        params = {'contract_code':self.symbol}
+        response, error = await self._request('POST', '/linear-swap-api/v1/swap_cross_openorders', params=params, auth=1)
+        if response:
+            for i in response:
+                pass
+            #     if i['direction'] == 'buy' and i['offset'] == 'open':
+            #         side = 'kd'
+            #     elif i['direction'] == 'sell' and i['offset'] == 'close':
+            #         side = 'pd'
+            #     elif i['direction'] == 'sell' and i['offset'] == 'open':
+            #         side = 'kk'
+            #     elif i['direction'] == 'buy' and i['offset'] == 'close':
+            #         side = 'pk'
+            #     orders.append({
+            #         'order_id':i['order_id'],
+            #         'symbol':i['contract_code'], 
+            #         'amount':float(i['volume']), 
+            #         'side':side, 
+            #         'price':float(i['price']), 
+            #     })
+            # self.callback['onOrder']({"refresh":orders})
+        if error:
+            pass
+        return None
+    
+    async def get_server_time(self):
+        params = {}
+        response, error = await self._request('GET', '/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        return await self._request('POST','/linear-swap-api/v1/swap_cross_account_info', params={}, auth=1)
+
+    async def get_position(self):
+        '''获取持仓 symbol: BTC-USDT'''
+        return await self._request('POST','/linear-swap-api/v1/swap_position_info', params={'contract_code':self.symbol}, auth=1)
+    
+    async def before_trade(self):
+        '''获取市场信息'''
+        res, err = await self._request('GET',f'/linear-swap-api/v1/swap_contract_info', params={}, auth=1)
+        if err:
+            print(err)
+        if res:
+            for i in res['data']:
+                if self.symbol == i['name']:
+                    self.multiplier = float(i['contract_size'])
+                    self.tickSize = float(i['price_tick'])
+                    self.stepSize = 1*float(i['contract_size']) # 张 转换为 币
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['name']
+                exchange_info.multiplier = float(i['contract_size'])
+                exchange_info.tickSize = float(i['price_tick'])
+                exchange_info.stepSize = 1*float(i['contract_size'])
+                self.exchange_info[exchange_info.symbol] = exchange_info
+        pass
+
+    async def go(self):
+        while 1:
+            try:
+                # 更新账户
+                res, err = await self.get_account()
+                for i in res:
+                    self.data['account'][i['asset']] = i['balance']
+                    if self.quote == i['asset'].upper():
+                        cash = float(i['balance'])
+                self.callback['onEquity']({
+                    self.quote:cash
+                })
+                # 更新仓位
+                res, err = await self.get_position()
+                if res:
+                    p = model.Position()
+                    for i in res:
+                        if i['symbol'] == self.symbol:
+                            if i['positionSide'] == 'LONG':
+                                p.longPos = float(i['positionAmt'])
+                                p.longAvg = float(i['entryPrice'])
+                            if i['positionSide'] == 'SHORT':
+                                p.shortPos = float(i['positionAmt'])
+                                p.shortAvg = float(i['entryPrice'])
+                    self.callback['onPosition'](p)
+                await asyncio.sleep(1)
+                # 打印延迟
+                self.get_delay_info()
+            except:
+                # traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_data(self):
+        return self.data
+    
+    async def run(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if cid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(client_id=cid))
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3],
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    cid = orders[order_name][0]
+                    # oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(client_id=cid))
+        except:
+            # traceback.print_exc()
+            await asyncio.sleep(0.1)
+
+

+ 407 - 0
exchange/huobi_usdt_swap_ws.py

@@ -0,0 +1,407 @@
+from os import times
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random
+import gzip, sys
+import csv
+import logging, logging.handlers
+import utils
+import model
+import datetime
+import urllib
+
+def empty_call(msg):
+    pass
+
+
+class HuobiUsdtSwapWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL_public = 'wss://api.hbdm.com/linear-swap-ws'
+            self.URL_private = 'wss://api.hbdm.com/linear-swap-notification'
+        else:
+            self.URL_public = 'wss://api.hbdm.com/linear-swap-ws'
+            self.URL_private = 'wss://api.hbdm.com/linear-swap-notification'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + '-'+ self.quote
+        self.callback = {
+            "onMarket":self.save_market,
+            "onDepth":empty_call,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.public_update_time = time.time()
+        self.private_update_time = time.time()
+        self.expired_time = 300
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            name = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{name}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://fapi.binance.com/fapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        login_str = await response.text()
+        await session.close()
+        return eval(login_str)['listenKey']
+
+    def _update_depth(self, msg):
+        self.public_update_time = time.time()
+        self.ticker_info["bp"] = float(msg['tick']['bids'][0][0])
+        self.ticker_info["ap"] = float(msg['tick']['asks'][0][0])
+        self.callback['onTicker'](self.ticker_info)
+        ##### 标准化深度
+        mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+        step = mp * utils.EFF_RANGE / utils.LEVEL
+        bp = []
+        ap = []
+        bv = [0 for _ in range(utils.LEVEL)]
+        av = [0 for _ in range(utils.LEVEL)]
+        for i in range(utils.LEVEL):
+            bp.append(self.ticker_info["bp"]-step*i)
+        for i in range(utils.LEVEL):
+            ap.append(self.ticker_info["ap"]+step*i)
+        # 
+        price_thre = self.ticker_info["bp"] - step
+        index = 0
+        for bid in msg['tick']['bids']:
+            price = float(bid[0])
+            amount = float(bid[1])
+            if price > price_thre:
+                bv[index] += amount
+            else:
+                price_thre -= step
+                index += 1
+                if index == utils.LEVEL:
+                    break
+                bv[index] += amount
+        price_thre = self.ticker_info["ap"] + step
+        index = 0
+        for ask in msg['tick']['asks']:
+            price = float(ask[0])
+            amount = float(ask[1])
+            if price < price_thre:
+                av[index] += amount
+            else:
+                price_thre += step
+                index += 1
+                if index == utils.LEVEL:
+                    break
+                av[index] += amount
+        self.depth = bp + bv + ap + av
+        self.callback['onDepth']({'name':self.name,'data':self.depth})
+    
+    def _update_trade(self, msg):
+        self.public_update_time = time.time()
+        for i in msg['tick']['data']:
+            price = float(i['price'])
+            side = i['direction']
+            amount = float(i['amount'])
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+            #### 修正ticker ####
+            # if side == 'buy' and price > self.ticker_info['ap']:
+            #     self.ticker_info['ap'] = price
+            #     self.callback['onTicker'](self.ticker_info)
+            # if side == 'sell' and price < self.ticker_info['bp']:
+            #     self.ticker_info['bp'] = price
+            #     self.callback['onTicker'](self.ticker_info)
+
+    def _update_account(self, msg):
+        for i in msg['data']:
+            if i['margin_asset'] == self.quote:
+                cash = i['margin_balance']
+                self.callback['onEquity']({self.quote:cash})
+    
+    def _update_order(self, msg):
+        if msg['contract_code'] == self.symbol:
+            if msg['status'] in [3] :  # 新增订单
+                order_event = dict()
+                order_event['status'] = "NEW"
+                order_event['filled'] = 0
+                order_event['filled_price'] = 0
+                order_event['client_id'] = msg["client_order_id"] if "client_order_id" in msg else ""
+                order_event['order_id'] = msg['order_id']
+                order_event['fee'] = 0.0
+                self.callback["onOrder"](order_event)
+            elif msg['status'] in [5,6,7]:  # 删除订单
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled'] = float(msg['trade_volume'])
+                order_event['filled_price'] = float(msg['trade_price'])
+                order_event['client_id'] = msg["client_order_id"] if "client_order_id" in msg else ""
+                order_event['order_id'] = msg['order_id']
+                if msg['fee_asset'] == self.quote:
+                    order_event['fee'] = float(msg['trade_fee'])
+                self.callback["onOrder"](order_event)
+
+    def _update_position(self, msg):
+        p = model.Position()
+        for i in msg['data']:
+            if i['pair'] == self.symbol:
+                if i['direction'] == 'buy':
+                    p.longPos = float(i['volume'])
+                    p.longAvg = float(i['cost_hold'])
+                if i['direction'] == 'sell':
+                    p.shortPos = float(i['volume'])
+                    p.shortAvg = float(i['cost_hold'])
+        self.callback['onPosition'](p)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        # run
+        asyncio.create_task(self.run_public(sub_trade=0, sub_fast=0))
+        if is_auth:
+            asyncio.create_task(self.run_private())
+        while True:
+            await asyncio.sleep(5)
+
+    async def run_public(self, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws public')
+                # 登陆
+                ws_url = self.URL_public
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws public 连接成功')
+                    # 订阅
+                    symbol = self.symbol
+                    channels=[
+                        f"market.{symbol}.depth.step6",
+                        ]
+                    if sub_trade:
+                        channels.append(f"market.{symbol}.trade.detail")
+                    for i in channels:
+                        sub_str = json.dumps({"sub": i})
+                        await _ws.send_str(sub_str)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} ws public 长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws public 长时间没有收到消息 准备重连...')
+                            break
+                        msg = ujson.loads(gzip.decompress(msg.data).decode())
+                        # print(msg)
+                        # 处理消息
+                        if 'ch' in msg:
+                            if 'depth' in msg['ch']:self._update_depth(msg)
+                            if 'trade' in msg['ch']:self._update_trade(msg)
+                        if 'ping' in msg:
+                            await _ws.send_str(json.dumps({"pong":int(time.time())*1000}))
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws public 连接失败 开始重连...')
+                # await asyncio.sleep(1)
+
+    async def run_private(self):
+        while True:
+            try:
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws private')
+                # 登陆
+                ws_url = self.URL_private
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws private 连接成功')
+                    # 订阅
+
+                    def generate_signature(method, params, host, request_path, secret_key):
+                        host_url = urllib.parse.urlparse(host).hostname.lower()
+                        sorted_params = sorted(params.items(), key=lambda d: d[0], reverse=False)
+                        encode_params = urllib.parse.urlencode(sorted_params)
+                        payload = [method, host_url, request_path, encode_params]
+                        payload = '{}\n{}\n{}\n{}'.format(payload[0],payload[1],payload[2],payload[3])
+                        payload = payload.encode(encoding="utf8")
+                        secret_key = secret_key.encode(encoding="utf8")
+                        digest = hmac.new(secret_key, payload, digestmod=hashlib.sha256).digest()
+                        signature = base64.b64encode(digest)
+                        signature = signature.decode()
+                        # print(payload)
+                        # digest = hmac.new(secret_key.encode('utf8'), payload.encode(
+                        #     'utf8'), digestmod=hashlib.sha256).digest()
+                        # signature = base64.b64encode(digest).decode()
+                        # get Signature
+                        return signature
+                    
+                    timestamp = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S")
+                    suffix = 'AccessKeyId={}&SignatureMethod=HmacSHA256&SignatureVersion=2&Timestamp={}'.format(
+                        self.params.access_key, timestamp)
+                    payload = '{}\n{}\n{}\n{}'.format("GET", self.URL_private, "/linear-swap-notification", suffix)
+
+                    digest = hmac.new(self.params.secret_key.encode('utf8'), payload.encode(
+                        'utf8'), digestmod=hashlib.sha256).digest()
+                    signature = base64.b64encode(digest).decode()
+
+                    data = {
+                        "AccessKeyId": self.params.access_key,
+                        "SignatureMethod": "HmacSHA256",
+                        "SignatureVersion": "2",
+                        "Timestamp": timestamp
+                    }
+                    # signature = generate_signature("GET", data, self.URL_private, "/swap-notification", self.params.secret_key)
+                    data["op"] = "auth"
+                    data["type"] = "api"
+                    data["Signature"] = signature
+                    await _ws.send_str(ujson.dumps(data))
+                    # position positions_cross.$contract_code
+                    await _ws.send_str(ujson.dumps({"op":"sub","topic": f"positions_cross.{self.symbol.lower()}"}))
+                    # account accounts_cross.$contract_code
+                    await _ws.send_str(ujson.dumps({"op":"sub","topic": f"accounts_cross.{self.symbol.lower()}"}))
+                    # trade orders_cross.$contract_code
+                    await _ws.send_json(ujson.dumps({"op":"sub","topic": f"orders_cross.{self.symbol.lower()}"}))
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} ws private 长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws private 长时间没有收到消息 准备重连...')
+                            break
+                        msg = ujson.loads(gzip.decompress(msg.data).decode())
+                        print(msg)
+                        # 处理消息
+                        if 'ch' in msg:
+                            if 'positions_cross' in msg['topic']:self._update_position(msg)
+                            if 'accounts_cross' in msg['topic']:self._update_account(msg)
+                            if 'orders_cross' in msg['topic']:self._update_order(msg)
+                        if 'ping' in msg:
+                            await _ws.send_str(json.dumps({"pong":int(time.time())*1000}))
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws private 连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws private 连接失败 开始重连...')
+                # await asyncio.sleep(1)
+

+ 531 - 0
exchange/kucoin_spot_rest.py

@@ -0,0 +1,531 @@
+import random
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import hmac
+import base64
+import hashlib
+import traceback
+import urllib
+from urllib import parse
+from urllib.parse import urljoin
+import datetime, sys
+from urllib.parse import urlparse
+import logging, logging.handlers
+import utils
+import logging, logging.handlers
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+
+class KucoinSpotRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://api.kucoin.com'
+        else:
+            self.HOST = 'https://api.kucoin.com'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + '-' + self.quote
+        self.data = {}
+        self._SESSIONS = dict()
+        self.logger = self.get_logger()
+        self.data['account'] = {}
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.tickSize = None
+        self.stepSize = None
+        self.delays = []
+        self.max_delay = 0
+        self.avg_delay = 0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.mp_from_rest = None
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def get_logger(self):
+
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler("log.log",maxBytes=1024*1024,encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.WARNING)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        headers = {}
+        if auth:
+            now_time = int(time.time()) * 1000
+            str_to_sign = str(now_time) + method + uri
+            if method in ['GET', 'DELETE']:
+                data_json = ''
+                if params:
+                    strl = []
+                    for key in params:
+                        strl.append("{}={}".format(key, params[key]))
+                    data_json += '&'.join(strl)
+                    str_to_sign += '?' + data_json
+            else:
+                if body:str_to_sign += body
+            sign = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), str_to_sign.encode('utf-8'), hashlib.sha256).digest())
+            passphrase = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), self.params.pass_key.encode('utf-8'), hashlib.sha256).digest())
+            headers = {
+                "KC-API-SIGN": sign.decode(),
+                "KC-API-TIMESTAMP": str(now_time),
+                "KC-API-KEY": self.params.access_key,
+                "KC-API-PASSPHRASE": passphrase.decode(),
+                "Content-Type": "application/json",
+                "KC-API-KEY-VERSION": "2"
+            }
+        headers["User-Agent"] = "kucoin-python-sdk/v1.0"
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=None, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json() 
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            self.get_delay_info()
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206) or \
+                int(res['code']) not in (200, 201, 202, 203, 204, 205, 206, 200000):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                return None, res
+            return res, None
+        except Exception as e:
+            print(f'{self.name} rest 请求出错', str(e))
+            self.logger.error('请求错误'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            现货交易 已支持全品种
+        '''
+        try:
+            #######################
+            self.logger.info("清空挂单")
+            params = {
+                'status':"active",
+                'tradeType':'TRADE',
+                'type':'limit'
+            }
+            response, error = await self._request('GET', '/api/v1/orders', params=params, auth=1)
+            if response is not None:
+                for i in response['data']['items']:
+                    res = await self.cancel_order(order_id=i["id"])
+                    self.logger.info(res)
+            #######################
+            self.logger.info("现货全平仓位")
+            # 更新账户
+            res, err = await self.get_account()
+            if err:self.logger.info(err)
+            if res:
+                for i in res["data"]:
+                    if i['type'] != 'trade':
+                        continue
+                    coin_name = i['currency']
+                    symbol = coin_name + '-USDT'
+                    if coin_name in ['USDT','KCS']:
+                        continue
+                    if coin_name == self.base:
+                        _hold_coin = hold_coin
+                    else:
+                        _hold_coin = 0
+                    coin = float(i['balance'])
+                    #######################
+                    ticker ,_ = await self._request('GET',f'/api/v1/market/orderbook/level1', params={"symbol":symbol}, auth=1)
+                    if ticker:
+                        ap = float(ticker["data"]["bestAsk"])
+                        bp = float(ticker["data"]["bestBid"])
+                        mp = (ap+bp)*0.5
+                    else:
+                        continue
+                    coin_value = coin * mp
+                    diff = _hold_coin - coin_value
+                    diff *= 0.99 # 避免无法下单
+                    self.logger.info(f'需要调整现货仓位{diff}usd')
+                    if diff > 20.0:
+                        self.logger.info( await self.take_order(
+                                    symbol,
+                                    diff/mp,
+                                    "kd",
+                                    1,
+                                    utils.get_cid(),
+                                    "market"
+                                ))
+                    elif diff < -20.0:
+                        self.logger.info( await self.take_order(
+                                    symbol,
+                                    -diff/mp,
+                                    "kk",
+                                    1,
+                                    utils.get_cid(),
+                                    "market"
+                                ))
+                # #######################
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='limit'):
+        if origin_side =='kd':
+            side = 'buy'
+        elif origin_side =='pd':
+            side = 'sell'
+        elif origin_side =='kk':
+            side = 'sell'
+        elif origin_side =='pk':
+            side = 'buy'
+        else:
+            print("现货不允许此交易方向")
+            return None
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            # amount = float(Decimal(str(amount//self.exchange_info[symbol].stepSize))*Decimal(str(self.exchange_info[symbol].stepSize)))
+            # price = float(Decimal(str(price//self.exchange_info[symbol].tickSize))*Decimal(str(self.exchange_info[symbol].tickSize)))
+            amount = utils.fix_amount(amount, self.exchange_info[symbol].stepSize)
+            price = utils.fix_price(price, self.exchange_info[symbol].tickSize)
+        if amount <= 0: 
+            self.logger.error(f'下单参数错误 amount:{amount}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0                
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        if price <= 0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0                
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        params = {
+            'clientOid':cid,
+            'symbol': symbol, 
+            'size':utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            'side': side, 
+            'price':utils.num_to_str(price, self.exchange_info[symbol].tickSize), 
+            'type':order_type,
+        }
+        # logger.info(f'下单指令 {params}')
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            response, error = await self._request('POST', '/api/v1/orders', body=json.dumps(params), auth=1)
+            # 再更新
+            if response:
+                # logger.info(f'下单回报 {response}')
+                # 增加新的
+                if 'data' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = params["clientOid"]
+                    order_event['order_id'] = response['data']["orderId"]
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0.0                
+                order_event['fee'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['client_id'] = params["clientOid"]
+                self.callback["onOrder"](order_event)
+        return response
+    
+    async def cancel_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('DELETE', f'/api/v1/orders/{order_id}', auth=1)
+        elif client_id:
+            response, error = await self._request('DELETE', f'/api/v1/order/client-order/{client_id}', auth=1)
+        else:
+            raise Exception("撤单出错 没指定订单号")
+        if response:
+            self.logger.debug(f'撤单回报 {response}')
+            # 撤单成功不会返回成交信息 所以不触发回调
+        if error:
+            return error
+            # print("撤单失败",error)
+            # self.logger.error(error)
+            # if client_id:await self.check_order(client_id=client_id)
+            # if order_id:await self.check_order(order_id=order_id)
+        return response
+    
+    async def check_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('GET', f'/api/v1/orders/{order_id}', auth=1)
+        elif client_id:
+            response, error = await self._request('GET', f'/api/v1/order/client-order/{client_id}', auth=1)
+        else:
+            return
+        if response:
+            self.logger.debug(f'查单回报 {response}')
+            order_event = dict()
+            if response["data"]['isActive'] == True:
+                order_event['status'] = "NEW"
+            elif response["data"]['isActive'] == False:
+                order_event['status'] = "REMOVE"
+            else:
+                self.logger.error("错误的订单状态")
+            order_event['price'] = float(response["data"]["price"])
+            order_event['amount'] = float(response["data"]["size"])
+            order_event['filled'] = float(response["data"]["dealSize"])
+            order_event['filled_price'] = float(response["data"]["dealFunds"])/float(response["data"]["dealSize"]) if float(response["data"]["dealSize"]) > 0 else 0
+            order_event['client_id'] = response["data"]["clientOid"]
+            order_event['order_id'] = response["data"]['id']
+            order_event['fee'] = float(response["fee"]) if "fee" in response else 0.0
+            self.callback["onOrder"](order_event)
+        if error:
+            print("查单失败",error)
+            self.logger.error(error)
+        return response
+
+    async def get_order_list(self):
+        params = {
+            'symbol':self.symbol,
+            'status':"active",
+            'tradeType':'TRADE',
+            'type':'limit'
+        }
+        response, error = await self._request('GET', '/api/v1/orders', params=params, auth=1)
+        orders = [] # 重置本地订单列表
+        if response is not None:
+            for i in response['data']['items']:
+                order_event = dict()
+                order_event['symbol'] = self.symbol
+                order_event['price'] = float(i["price"])
+                order_event['amount'] = float(i["size"])
+                order_event['filled'] = float(i["dealSize"])
+                order_event['filled_price'] = float(i["dealFunds"])/float(i["dealSize"]) if float(i["dealSize"]) > 0 else 0
+                order_event['client_id'] = i["clientOid"]
+                order_event['order_id'] = i['id']
+                order_event['fee'] = float(i["fee"]) if "fee" in i else 0.0
+                if i['isActive'] == True:
+                    order_event['status'] = "NEW"
+                elif i['isActive'] == False:
+                    order_event['status'] = "REMOVE"
+                else:
+                    self.logger.error("错误的订单状态")
+                self.callback["onOrder"](order_event)
+        if error:
+            print(error)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        return await self._request('GET','/api/v1/accounts', body={"type":"trade","currency":self.base}, auth=1)
+
+    async def get_market_details(self):
+        return await self._request('GET',f'/api/v1/symbols', params={}, auth=1)
+
+    async def get_ticker(self):
+        res ,err = await self._request('GET',f'/api/v1/market/orderbook/level1', params={"symbol":self.symbol}, auth=1)
+        if res:
+            ap = float(res["data"]["bestAsk"])
+            bp = float(res["data"]["bestBid"])
+            mp = (ap+bp)*0.5
+            d = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+            self.callback['onTicker'](d)
+            return d
+        if err:
+            self.logger.error(err)
+            return None 
+
+    async def before_trade(self):
+        # 获取市场最新价格
+        res = await self.get_ticker()
+        ticker_price = res["mp"]
+        if isinstance(ticker_price, float):
+            self.mp_from_rest = ticker_price
+        # 获取市场基本情况
+        res, error = await self.get_market_details()
+        if error:
+            pass
+        else:
+            for i in res['data']:
+                if i['symbol'] == self.symbol:
+                    self.stepSize = float(i["baseIncrement"])
+                    self.tickSize = float(i["priceIncrement"])
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['symbol']
+                exchange_info.multiplier = 1
+                exchange_info.tickSize = float(i["priceIncrement"])
+                exchange_info.stepSize = float(i["baseIncrement"])
+                self.exchange_info[exchange_info.symbol] = exchange_info
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:print(err)
+        if res:
+            for i in res["data"]:
+                if 'USDT' == i['currency'] and i['type'] == 'trade':
+                    self.data['equity'] = float(i['balance'])
+                    self.callback['onEquity']({
+                        self.quote:self.data['equity']
+                    })
+                    self.cash_value = self.data['equity']
+                if i['currency'] == self.base and i['type'] == 'trade':
+                    coin = float(i['balance'])
+                    self.callback['onEquity']({
+                        self.base:coin
+                    })
+                    self.coin_value = coin
+
+    async def go(self):
+        await self.before_trade()
+        await asyncio.sleep(1)
+        while 1:
+            try:
+                # 停机信号
+                if self.stop_flag:return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:print(err)
+                if res:
+                    for i in res["data"]:
+                        if self.quote == i['currency'] and i['type'] == 'trade':
+                            self.data['equity'] = float(i['balance'])
+                            self.callback['onEquity']({
+                                self.quote:self.data['equity']
+                            })
+                        if i['currency'] == self.base and i['type'] == 'trade':
+                            coin = float(i['balance'])
+                            self.callback['onEquity']({
+                                self.base:coin
+                            })
+                # 更新订单
+                # res = await self.get_order_list()
+                await asyncio.sleep(60)
+                # 打印延迟
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+                # 更新rest最新价格 用于风控 开启可能会干扰ws推送的价格
+                # res = await self.get_ticker()
+                # ticker_price = res["mp"]
+                # if isinstance(ticker_price, float):
+                #     self.mp_from_rest = ticker_price
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_data(self):
+        return self.data
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+    
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if cid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(client_id=cid))
+                    elif oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(client_id=cid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+
+

+ 390 - 0
exchange/kucoin_spot_ws.py

@@ -0,0 +1,390 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random
+import gzip, csv, sys
+from uuid import uuid4
+import logging, logging.handlers
+import utils
+import model
+
+def empty_call(msg):
+    pass
+
+
+class KucoinSpotWs:
+
+    def __init__(self, params: model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.BaseURL = "https://api.kucoin.com"
+        else:
+            self.BaseURL = "https://api.kucoin.com"
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + '-' + self.quote
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.local_orders = dict()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.update_t = 0.0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = float(self.params.interval)
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://fapi.binance.com/fapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        login_str = await response.text()
+        await session.close()
+        return ujson.loads(login_str)['listenKey']
+
+    def _update_depth(self, msg):
+        if msg['data']['timestamp'] > self.update_t:
+            self.update_t = msg['data']['timestamp']
+            self.ticker_info["bp"] = float(msg['data']['bids'][0][0])
+            self.ticker_info["ap"] = float(msg['data']['asks'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in msg['data']['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in msg['data']['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+    
+    def _update_trade(self, msg):
+        price = float(msg["data"]['price'])
+        side = msg["data"]['side']
+        amount = float(msg["data"]['size'])
+        if price > self.max_buy or self.max_buy == 0.0:
+            self.max_buy = price
+        if price < self.min_sell or self.min_sell == 0.0:
+            self.min_sell = price
+        if side == 'buy':
+            self.buy_q += amount
+            self.buy_v += amount*price
+        elif side == 'sell':
+            self.sell_q += amount
+            self.sell_v += amount*price
+
+    def _update_account(self, msg):
+        if 'trade' in msg['data']["relationEvent"]:
+            if msg['data']["currency"].upper() == self.base:
+                coin = abs(float(msg['data']["total"]))
+                self.callback['onEquity']({
+                    self.base:coin
+                })
+            if msg['data']["currency"].upper() == self.quote:
+                self.callback['onEquity']({
+                    self.quote:float(msg['data']["total"])
+                })
+    
+    def _update_order(self, msg):
+        self.logger.debug(f"ws订单推送 {msg}")
+        if msg['topic'] == '/spotMarket/tradeOrders-batch':
+            for i in msg['data']:
+                if i["symbol"] == self.symbol:
+                    if i["status"] == 'open':  # 新增订单
+                        order_event = dict()
+                        order_event['filled'] = 0
+                        order_event['filled_price'] = 0
+                        order_event['client_id'] = i["clientOid"]
+                        order_event['order_id'] = i['orderId']
+                        order_event['status'] = "NEW"
+                        self.callback["onOrder"](order_event)
+                        self.local_orders[i["clientOid"]] = order_event
+                    elif i["status"] == 'done':  # 删除订单
+                        if "price" in i:
+                            price = float(i["price"])
+                        else:
+                            if i["clientOid"] in self.local_orders:
+                                price = self.local_orders[i["clientOid"]]["price"]
+                            else:
+                                # 应该是非本策略订单 忽略price
+                                price = 0
+                        order_event = dict()
+                        order_event['amount'] = float(i["size"])
+                        order_event['filled'] = float(i["filledSize"])
+                        order_event['filled_price'] = price
+                        order_event['client_id'] = i["clientOid"]
+                        order_event['order_id'] = i['orderId']
+                        order_event['status'] = "REMOVE"
+                        order_event['fee'] = float(i["fee"]) if "fee" in i["data"] else 0.0
+                        self.callback["onOrder"](order_event)
+                        if i["clientOid"] in self.local_orders:
+                            del(self.local_orders[i["clientOid"]])
+        else:
+            if msg["data"]["symbol"] == self.symbol:
+                if msg["data"]["status"] == 'open':  # 新增订单
+                    order_event = dict()
+                    order_event['filled'] = 0
+                    order_event['filled_price'] = 0
+                    order_event['client_id'] = msg["data"]["clientOid"]
+                    order_event['order_id'] = msg["data"]['orderId']
+                    order_event['status'] = "NEW"
+                    self.callback["onOrder"](order_event)
+                    self.local_orders[msg["data"]["clientOid"]] = order_event
+                elif msg["data"]["status"] == 'done':  # 删除订单
+                    if "price" in msg['data']:
+                        price = float(msg["data"]["price"])
+                    else:
+                        if msg["data"]["clientOid"] in self.local_orders:
+                            price = self.local_orders[msg["data"]["clientOid"]]["price"]
+                        else:
+                            # 应该是非本策略订单 忽略price
+                            price = 0
+                    order_event = dict()
+                    order_event['filled'] = float(msg["data"]["filledSize"])
+                    order_event['filled_price'] = price
+                    order_event['client_id'] = msg["data"]["clientOid"]
+                    order_event['order_id'] = msg["data"]['orderId']
+                    order_event['fee'] = float(msg["data"]["fee"]) if "fee" in msg["data"] else 0.0
+                    order_event['status'] = "REMOVE"
+                    self.callback["onOrder"](order_event)
+                    if msg["data"]["clientOid"] in self.local_orders:
+                        del(self.local_orders[msg["data"]["clientOid"]])
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def get_token(self, is_auth):
+        if is_auth:
+            uri = "/api/v1/bullet-private"
+        else:
+            uri = "/api/v1/bullet-public"
+        headers = {}
+        if is_auth:
+            now_time = int(time.time()) * 1000
+            str_to_sign = str(now_time) + "POST" + uri
+            sign = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), str_to_sign.encode('utf-8'), hashlib.sha256).digest())
+            passphrase = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), self.params.pass_key.encode('utf-8'), hashlib.sha256).digest())
+            headers = {
+                "KC-API-SIGN": sign.decode(),
+                "KC-API-TIMESTAMP": str(now_time),
+                "KC-API-KEY": self.params.access_key,
+                "KC-API-PASSPHRASE": passphrase.decode(),
+                "Content-Type": "application/json",
+                "KC-API-KEY-VERSION": "2"
+            }
+        headers["User-Agent"] = "kucoin-python-sdk/v1.0"
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            self.BaseURL+uri, 
+            timeout=5, 
+            headers=headers,
+            proxy=self.proxy
+            )
+        res = await response.text()
+        res = ujson.loads(res)
+        await session.close()
+        if res["code"] == "200000":
+            token = res["data"]["token"]
+            ws_connect_id = str(uuid4()).replace('-', '')
+            endpoint = res["data"]['instanceServers'][0]['endpoint']
+            ws_endpoint = f"{endpoint}?token={token}&connectId={ws_connect_id}"
+            encrypt = res["data"]['instanceServers'][0]['encrypt']
+            if is_auth:
+                ws_endpoint += '&acceptUserMessage=true'
+            return ws_endpoint, encrypt
+        else:
+            raise Exception("kucoin spot 获取token错误")
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                ping_time = time.time()
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 获取token
+                ws_endpoint, encrypt = await self.get_token(is_auth)
+                # 登陆
+                ws_url = ws_endpoint
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.info(f'{self.name} ws连接成功')
+                    # 订阅 ticker来的很慢
+                    channels=[
+                        # f"/market/ticker:{self.symbol}",
+                        f"/spotMarket/level2Depth50:{self.symbol}",
+                        ]
+                    if sub_trade:
+                        channels.append(f"/market/match:{self.symbol}")
+                    if is_auth:
+                        channels.append(f"/spotMarket/tradeOrders")
+                        channels.append(f"/account/balance")
+                    for i in channels:
+                        sub_str = ujson.dumps({"topic": i, "type":"subscribe"})
+                        await _ws.send_str(sub_str)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = ujson.loads(msg.data)
+                        # self.logger.debug(msg)
+                        # print(msg)
+                        # 处理消息
+                        if 'data' in msg:
+                            if 'level2' in msg['subject']:self._update_depth(msg)
+                            elif 'trade.l3match' in msg['subject']:self._update_trade(msg)
+                            elif 'account.balance' in msg['subject']:self._update_account(msg)
+                            elif 'orderChange' in msg['subject']:self._update_order(msg)
+                        # heartbeat
+                        if time.time() - ping_time > 30:
+                            msg = {
+                                'id': str(int(time.time() * 1000)),
+                                'type': 'ping'
+                            }
+                            await _ws.send_str(ujson.dumps(msg))
+                            ping_time = time.time()
+            except Exception as e:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                # await asyncio.sleep(1)
+

+ 537 - 0
exchange/kucoin_usdt_swap_rest.py

@@ -0,0 +1,537 @@
+import random
+import aiohttp
+import time
+import asyncio
+import zlib
+import json
+import hmac
+import base64
+import hashlib
+import traceback
+import urllib
+from urllib import parse
+from urllib.parse import urljoin
+import datetime, sys
+from urllib.parse import urlparse
+import logging, logging.handlers
+import utils
+import logging, logging.handlers
+import model
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+def decimal_amount(amount, d):
+    if int(d) == 0:
+        return str(int(amount))
+    elif int(d) > 0:
+        return str(round(int(amount*10**int(d))*0.1**int(d), int(d)))
+
+def decimal_price(price, d):
+    if int(d) == 0:
+        return str(int(price))
+    elif int(d) > 0:
+        return str(round(float(price), int(d)))
+
+
+class KucoinUsdtSwapRest:
+
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = "https://api-futures.kucoin.com"
+        else:
+            self.HOST = "https://api-futures.kucoin.com"
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        # 处理特殊情况
+        if self.base == "BTC":self.base = "XBT"
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote + "M"
+        self._SESSIONS = dict()
+        self.logger = self.get_logger()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.decimal_amount = 10
+        self.decimal_price = 10
+        self.delays = []
+        self.max_delay = 0
+        self.avg_delay = 0
+        self.multiplier = None
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        self.last_cash = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def get_logger(self):
+
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler("log.log",maxBytes=1024*1024,encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.WARNING)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        return logger
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+
+    async def _request(self, method, uri, body=None, params=None, auth=False):
+        url = urljoin(self.HOST, uri)
+        headers = {}
+        if auth:
+            now_time = int(time.time()) * 1000
+            str_to_sign = str(now_time) + method + uri
+            if method in ['GET', 'DELETE']:
+                data_json = ''
+                if params:
+                    strl = []
+                    for key in params:
+                        strl.append("{}={}".format(key, params[key]))
+                    data_json += '&'.join(strl)
+                    str_to_sign += '?' + data_json
+            else:
+                if body:str_to_sign += body
+            sign = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), str_to_sign.encode('utf-8'), hashlib.sha256).digest())
+            passphrase = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), self.params.pass_key.encode('utf-8'), hashlib.sha256).digest())
+            headers = {
+                "KC-API-SIGN": sign.decode(),
+                "KC-API-TIMESTAMP": str(now_time),
+                "KC-API-KEY": self.params.access_key,
+                "KC-API-PASSPHRASE": passphrase.decode(),
+                "Content-Type": "application/json",
+                "KC-API-KEY-VERSION": "2"
+            }
+        headers["User-Agent"] = "kucoin-python-sdk/v1.0"
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=None, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json() 
+            delay = int(1000*(time.time() - start_time))
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            self.delays.append(delay)
+            if code not in (200, 201, 202, 203, 204, 205, 206) or \
+                int(res['code']) not in (200, 201, 202, 203, 204, 205, 206, 200000):
+                print(f'URL:{url} METHOD:{method} PARAMS:{params} body:{body} ERROR:{res}')
+                self.logger.error('请求错误'+str(res))
+                self.logger.error(res)
+                return None, res
+            return res, None
+        except Exception as e:
+            print(f'{self.name} rest 请求出错', str(e))
+            self.logger.error('请求错误'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+    
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            两次执行check_position之间必须停留足够时间 避免position没及时更新 导致重复下单 已支持全品种
+        '''
+        try:
+            ##############################
+            self.logger.info("清空挂单")
+            params = {
+                'status':"active",
+                'tradeType':'TRADE',
+            }
+            response, error = await self._request('GET', '/api/v1/orders', params=params, auth=1)
+            if response is not None:
+                for i in response['data']['items']:
+                    res = await self.cancel_order(order_id=i["id"])
+                    self.logger.info(res)
+            self.logger.info(f"{self.name} 全平仓位")
+            ##############################
+            # 更新仓位
+            print("获取仓位")
+            res, err = await self._request('GET','/api/v1/positions', params={}, auth=1)
+            if err:self.logger.info(err)
+            if res:
+                for i in res['data']:
+                    symbol = i['symbol']
+                    amt = float(i['currentQty'])
+                    # 转换为单位:张
+                    if self.exchange_info == dict():
+                        await self.before_trade()
+                    if amt > 0:
+                        self.logger.info( await self.take_order(
+                                    symbol,
+                                    amt*self.exchange_info[symbol].multiplier,
+                                    "pd",
+                                    1,
+                                    utils.get_cid(),
+                                    "market"
+                                ))
+                    elif amt < 0:
+                        self.logger.info( await self.take_order(
+                                    symbol,
+                                    -amt*self.exchange_info[symbol].multiplier,
+                                    "pk",
+                                    1,
+                                    utils.get_cid(),
+                                    "market"
+                                ))
+                ##############################        
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='limit'):
+        '''
+            kumex会合并多空方向的仓位 不支持双向交易 所以reduce_only要小心使用
+        '''
+        reduceOnly = False
+        if origin_side =='kd':
+            side = 'buy'
+        elif origin_side =='pd':
+            side = 'sell'
+        elif origin_side =='kk':
+            side = 'sell'
+        elif origin_side =='pk':
+            side = 'buy'
+        else:
+            return None
+        # 转换为单位:张
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+        amount_paper = int(amount/self.exchange_info[symbol].multiplier)
+        if amount_paper == 0:
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['client_id'] = cid
+            order_event['filled'] = 0
+            order_event['filled_price'] = 0
+            self.callback["onOrder"](order_event)
+            return
+        params = {
+            'clientOid':cid,
+            'symbol': symbol, 
+            'size': amount_paper,
+            'side': side, 
+            'leverage': 10, #10, 没kyc老报错,##############################################
+            'reduceOnly':reduceOnly,
+            'price':utils.num_to_str(price, self.exchange_info[symbol].tickSize), 
+            'type':order_type,
+        }
+        # logger.info(f'下单指令 {params}')
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            response, error = await self._request('POST', '/api/v1/orders', body=json.dumps(params), auth=1)
+            # 更新
+            if response:
+                # 增加新的
+                if 'data' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = cid
+                    order_event['order_id'] = response['data']["orderId"]
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['client_id'] = cid
+                order_event['filled_price'] = 0
+                order_event['filled'] = 0
+                self.callback["onOrder"](order_event)
+        return response
+    
+    async def cancel_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('DELETE', f'/api/v1/orders/{order_id}', auth=1)
+        elif client_id:
+            response, error = await self._request('DELETE', f'/api/v1/orders/byClientOid?clientOid={client_id}', auth=1)
+        else:
+            raise Exception("撤单出错 没指定订单号")
+        if response:
+            self.logger.debug(f'撤单回报 {response}')
+            # 撤单成功不会返回成交信息 所以不触发回调
+        if error:
+            print("撤单失败",error)
+            return error
+        return response
+    
+    async def check_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('GET', f'/api/v1/orders/{order_id}', auth=1)
+        elif client_id:
+            response, error = await self._request('GET', f'/api/v1/orders/byClientOid?clientOid={client_id}', auth=1)
+        else:
+            return
+        if response:
+            self.logger.debug(f'查单回报 {response}')
+            order_event = dict()
+            if response["data"]['isActive'] == True:
+                order_event['status'] = "NEW"
+            elif response["data"]['isActive'] == False:
+                order_event['status'] = "REMOVE"
+            else:
+                self.logger.error("错误的订单状态")
+            if self.multiplier == None:
+                await self.before_trade()
+            order_event['filled'] = float(response["data"]["filledSize"])*self.multiplier
+            order_event['filled_price'] = float(response["data"]["filledValue"])/float(response["data"]["filledSize"])/self.multiplier if float(response["data"]["filledSize"]) > 0 else 0
+            order_event['client_id'] = response["data"]["clientOid"]
+            order_event['order_id'] = response["data"]['id']
+            self.callback["onOrder"](order_event)
+        if error:
+            print("查单失败",error)
+            self.logger.error(error)
+        return response
+
+    async def get_order_list(self):
+        params = {
+            'symbol':self.symbol,
+            'status':"active",
+            'tradeType':'TRADE',
+        }
+        response, error = await self._request('GET', '/api/v1/orders', params=params, auth=1)
+        orders = [] # 重置本地订单列表
+        if response is not None:
+            for i in response['data']['items']:
+                order_event = dict()
+                if i['isActive'] == True:
+                    order_event['status'] = "NEW"
+                elif i['isActive'] == False:
+                    order_event['status'] = "REMOVE"
+                else:
+                    self.logger.error("错误的订单状态")
+                order_event['filled'] = float(i["dealSize"])
+                order_event['filled_price'] = float(i["dealFunds"])/float(i["dealSize"]) if float(i["dealSize"]) > 0 else 0
+                order_event['client_id'] = i["clientOid"]
+                order_event['order_id'] = i['id']
+                self.callback["onOrder"](order_event)
+        if error:
+            print(error)
+        return response
+    
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/api/v1/timestamp', params=params)
+        return response
+
+    async def get_account(self):
+        return await self._request('GET','/api/v1/account-overview', params={"currency":self.quote}, auth=1)
+
+    async def get_position(self):
+        return await self._request('GET','/api/v1/position', params={"symbol":self.symbol}, auth=1)
+
+    async def get_market_details(self):
+        return await self._request('GET',f'api/v1/contracts/active', params={}, auth=1)
+
+    async def get_ticker(self):
+        res, err = await self._request('GET',f'api/v1/ticker', params={"symbol":self.symbol}, auth=0)
+        if res:
+            ap = float(res['data']['bestAskPrice']) 
+            bp = float(res['data']['bestBidPrice'])
+            mp = (ap+bp)/2
+            ticker = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+            return ticker
+        if err:
+            self.logger.error(err)
+        return
+
+    async def before_trade(self):
+        # 获取市场
+        res, error = await self.get_market_details()
+        if error:
+            pass
+        if res:
+            for i in res['data']:
+                if i['symbol'] == self.symbol:
+                    # 1张多少个币
+                    self.multiplier = float(i["multiplier"])
+                    # 1张 币的数量精度   张 转换成 币 需要乘以乘数
+                    self.stepSize =  i["lotSize"]*self.multiplier
+                    self.tickSize = i['tickSize']
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i['symbol']
+                exchange_info.multiplier = float(i["multiplier"])
+                exchange_info.tickSize = i['tickSize']
+                exchange_info.stepSize = i["lotSize"]*float(i["multiplier"])
+                self.exchange_info[exchange_info.symbol] = exchange_info
+        # 设置杠杆
+        await self._request('GET',f'api/v1/contracts/active', params={}, auth=1)
+        # 更新账户
+        res, err = await self.get_position()
+        if err:print(err)
+        if res:
+            amt = float(res['data']['currentQty']) * self.multiplier
+            ep = float(res['data']['avgEntryPrice'])
+            pos = model.Position()
+            if amt == 0.0:
+                pos.longPos = 0
+                pos.longAvg = 0
+                pos.shortPos = 0
+                pos.shortAvg = 0
+            elif amt > 0.0:
+                pos.longPos = amt
+                pos.longAvg = ep
+                pos.shortPos = 0
+                pos.shortAvg = 0
+            elif amt < 0.0:
+                pos.longPos = 0
+                pos.longAvg = 0
+                pos.shortPos = -amt
+                pos.shortAvg = ep
+            self.callback['onPosition'](pos)
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:print(err)
+        if res:
+            cash = float(res['data']['accountEquity'])
+            if self.last_cash == 0:
+                self.last_cash = cash
+            self.callback['onEquity']({self.quote:cash})
+            self.cash_value = cash
+
+    async def go(self):
+        interval = 60
+        await self.before_trade()
+        await asyncio.sleep(1)
+        while 1:
+            try:
+                # 更新账户
+                res, err = await self.get_account()
+                if err:print(err)
+                if res:
+                    if res['data']['currency'] == "USDT":
+                        cash = float(res['data']['accountEquity'])
+                        # rest有可能获取到中间态 为了避免得到错误的账户信息 需进行判断
+                        if self.last_cash == 0:
+                            self.last_cash = cash
+                            self.callback['onEquity']({self.quote:cash})
+                        else:
+                            # 判断净值是否出现大幅度偏离 两次更新之间的差值不应超过5%
+                            if abs(self.last_cash - cash)/cash < 0.05:
+                                self.last_cash = cash
+                                self.callback['onEquity']({self.quote:cash})
+                            else:
+                                # 否则舍弃本次更新
+                                pass
+                # 更新账户
+                res, err = await self.get_position()
+                if err:print(err)
+                if res:
+                    amt = float(res['data']['currentQty']) * self.multiplier
+                    ep = float(res['data']['avgEntryPrice'])
+                    pos = model.Position()
+                    if amt == 0.0:
+                        pos.longPos = 0
+                        pos.longAvg = 0
+                        pos.shortPos = 0
+                        pos.shortAvg = 0
+                    elif amt > 0.0:
+                        pos.longPos = amt
+                        pos.longAvg = ep
+                        pos.shortPos = 0
+                        pos.shortAvg = 0
+                    elif amt < 0.0:
+                        pos.longPos = 0
+                        pos.longAvg = 0
+                        pos.shortPos = -amt
+                        pos.shortAvg = ep
+                    self.callback['onPosition'](pos)
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+    
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # kumex 只能按oid撤单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.check_order(order_id=oid))
+        except:
+            # traceback.print_exc()
+            await asyncio.sleep(0.1)
+
+

+ 386 - 0
exchange/kucoin_usdt_swap_ws.py

@@ -0,0 +1,386 @@
+import aiohttp
+import time
+import asyncio
+import zlib
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random
+import gzip, csv, sys
+from uuid import uuid4
+import logging, logging.handlers
+import utils
+import model
+from decimal import Decimal
+
+def empty_call(msg):
+    # print(msg)
+    pass
+
+
+class KucoinUsdtSwapWs:
+
+    def __init__(self, params: model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.BaseURL = "https://api-futures.kucoin.com"
+        else:
+            self.BaseURL = "https://api-futures.kucoin.com"
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote + "M"
+        # 处理特殊情况
+        if self.symbol == 'BTCUSDTM':
+            self.symbol = 'XBTUSDTM'
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.multiplier = None
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.update_t = 0.0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = float(self.params.interval)
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    async def get_sign(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        params = {
+            'timestamp':int(time.time())*1000,
+            'recvWindow':5000,
+        }
+        query_string = "&".join(["{}={}".format(k, params[k]) for k in sorted(params.keys())])
+        signature = hmac.new(self.params.secret_key.encode(), msg=query_string.encode(), digestmod=hashlib.sha256).hexdigest()
+        params['signature']=signature
+        url = 'https://fapi.binance.com/fapi/v1/listenKey'
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            url, 
+            params=params,
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        login_str = await response.text()
+        await session.close()
+        return ujson.loads(login_str)['listenKey']
+
+    def _update_depth(self, msg):
+        if msg['data']['sequence'] > self.update_t:
+            self.update_t = msg['data']['sequence']
+            self.ticker_info['bp'] = float(msg['data']['bids'][0][0])
+            self.ticker_info['ap'] = float(msg['data']['asks'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            ##### 标准化深度
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in msg['data']['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in msg['data']['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    def _update_ticker(self, msg):
+        if msg['data']['sequence'] > self.update_t:
+            self.update_t = msg['data']['sequence']
+            self.ticker_info['bp'] = float(msg['data']['bestBidPrice'])
+            self.ticker_info['ap'] = float(msg['data']['bestAskPrice'])
+            self.callback['onTicker'](self.ticker_info)
+    
+    def _update_trade(self, msg):
+        price = float(msg["data"]['price'])
+        side = msg["data"]['side']
+        amount = float(msg["data"]['size'])*self.multiplier
+        if price > self.max_buy or self.max_buy == 0.0:
+            self.max_buy = price
+        if price < self.min_sell or self.min_sell == 0.0:
+            self.min_sell = price
+        if side == 'buy':
+            self.buy_q += amount
+            self.buy_v += amount*price
+        elif side == 'sell':
+            self.sell_q += amount
+            self.sell_v += amount*price
+
+    def _update_position(self, msg):
+        pos = model.Position()
+        if "currentQty" in msg['data']:
+            amt = float(msg["data"]["currentQty"]) * self.multiplier
+            ep = float(msg["data"]["avgEntryPrice"])
+            if amt == 0:
+                self.callback["onPosition"](pos)
+            elif amt > 0:
+                pos.longPos = amt
+                pos.longAvg = ep
+                self.callback["onPosition"](pos)
+            elif amt < 0:
+                pos.shortPos = -amt
+                pos.shortAvg = ep
+                self.callback["onPosition"](pos)
+
+    def _update_account(self, msg):
+        pass
+        # if msg['data']['currency'] == 'USDT' and msg['subject'] == "availableBalance.change":
+        #     cash = float(msg['data']['availableBalance']) + float(msg['data']['holdBalance'])
+        #     self.callback['onEquity'] = {self.quote:cash}
+    
+    def _update_order(self, msg):
+        self.logger.debug(f"ws订单推送 {msg}")
+        if '/contractMarket/tradeOrders' in msg['topic']:
+            if msg["data"]["symbol"] == self.symbol:
+                if msg["data"]["status"] == 'open':  # 新增订单
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['filled'] = 0
+                    order_event['filled_price'] = 0
+                    order_event['client_id'] = msg["data"]["clientOid"] if "clientOid" in msg["data"] else ""
+                    order_event['order_id'] = msg["data"]['orderId']
+                    self.callback["onOrder"](order_event)
+                elif msg["data"]["type"] in ['filled','canceled']:  # 删除订单
+                    order_event = dict()
+                    order_event['status'] = "REMOVE"
+                    order_event['client_id'] = msg["data"]["clientOid"] if "clientOid" in msg["data"] else ""
+                    order_event['order_id'] = msg["data"]['orderId']
+                    order_event['filled'] = float(Decimal(msg["data"]["filledSize"])*Decimal(str(self.multiplier)))
+                    if 'price' in msg["data"]:
+                        if msg['data']['price'] != '':
+                            order_event['filled_price'] = float(msg["data"]["price"])
+                        else:
+                            order_event['filled_price'] = 0 
+                    else:
+                        order_event['filled_price'] = 0 
+                    self.callback["onOrder"](order_event)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket']({'name': self.name,'data':market_data})
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    async def get_token(self, is_auth):
+        # 获取 合约系数
+        session = aiohttp.ClientSession()
+        response = await session.get(
+            "https://api-futures.kucoin.com/api/v1/contracts/active", 
+            proxy=self.proxy
+            )
+        res = await response.json()
+        for i in res['data']:
+            if i['symbol'] == self.symbol:
+                self.multiplier = float(i["multiplier"])
+        print(f"合约乘数为 {self.multiplier}")
+        self.logger.debug(f"合约乘数为 {self.multiplier}")
+        await session.close()
+        # 获取 token
+        if is_auth:
+            uri = "/api/v1/bullet-private"
+        else:
+            uri = "/api/v1/bullet-public"
+        headers = {}
+        if is_auth:
+            now_time = int(time.time()) * 1000
+            str_to_sign = str(now_time) + "POST" + uri
+            sign = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), str_to_sign.encode('utf-8'), hashlib.sha256).digest())
+            passphrase = base64.b64encode(hmac.new(self.params.secret_key.encode('utf-8'), self.params.pass_key.encode('utf-8'), hashlib.sha256).digest())
+            headers = {
+                "KC-API-SIGN": sign.decode(),
+                "KC-API-TIMESTAMP": str(now_time),
+                "KC-API-KEY": self.params.access_key,
+                "KC-API-PASSPHRASE": passphrase.decode(),
+                "Content-Type": "application/json",
+                "KC-API-KEY-VERSION": "2"
+            }
+        headers["User-Agent"] = "kucoin-python-sdk/v1.0"
+        session = aiohttp.ClientSession()
+        response = await session.post(
+            self.BaseURL+uri, 
+            timeout=5, 
+            headers=headers,
+            proxy=self.proxy
+            )
+        res = await response.text()
+        res = ujson.loads(res)
+        await session.close()
+        if res["code"] == "200000":
+            token = res["data"]["token"]
+            ws_connect_id = str(uuid4()).replace('-', '')
+            endpoint = res["data"]['instanceServers'][0]['endpoint']
+            ws_endpoint = f"{endpoint}?token={token}&connectId={ws_connect_id}"
+            encrypt = res["data"]['instanceServers'][0]['encrypt']
+            if is_auth:
+                ws_endpoint += '&acceptUserMessage=true'
+            return ws_endpoint, encrypt
+        else:
+            raise Exception("kucoin usdt swap 获取token错误")
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        while True:
+            try:
+                ping_time = time.time()
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 获取token
+                ws_endpoint, encrypt = await self.get_token(is_auth)
+                # 登陆
+                ws_url = ws_endpoint
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.info(f'{self.name} ws连接成功')
+                    # 订阅
+                    channels=[
+                        # f"/contractMarket/tickerV2:{self.symbol}",
+                        f"/contractMarket/level2Depth50:{self.symbol}",
+                        ]
+                    if sub_trade:
+                        channels += [f"/contractMarket/execution:{self.symbol}"]
+                    if is_auth:
+                        channels += [
+                            f"/contractAccount/wallet",
+                            f"/contract/position:{self.symbol}",
+                            f"/contractMarket/tradeOrders:{self.symbol}",
+                        ]
+                    for i in channels:
+                        sub_str = ujson.dumps({"topic": i, "type":"subscribe"})
+                        await _ws.send_str(sub_str)
+                    while True:
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = ujson.loads(msg.data)
+                        # print(msg)
+                        # 处理消息
+                        if 'data' in msg:
+                            if 'level2' in msg['subject']:self._update_depth(msg)
+                            elif "tickerV2" in msg["subject"]:self._update_ticker(msg)
+                            elif 'match' in msg['subject']:self._update_trade(msg)
+                            elif 'orderMargin.change' in msg['subject']:self._update_account(msg)
+                            elif 'symbolOrderChange' in msg['subject']:self._update_order(msg)
+                            elif 'position.change' in msg['subject']:self._update_position(msg)
+                        # heartbeat
+                        if time.time() - ping_time > 30:
+                            msg = {
+                                'id': str(int(time.time() * 1000)),
+                                'type': 'ping'
+                            }
+                            await _ws.send_str(ujson.dumps(msg))
+                            ping_time = time.time()
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                # await asyncio.sleep(1)
+

+ 471 - 0
exchange/mexc_spot_rest.py

@@ -0,0 +1,471 @@
+import model
+import sys
+import time
+import utils
+import logging
+import aiohttp
+import asyncio
+import traceback
+from urllib.parse import urlparse
+from urllib.parse import urljoin
+from decimal import Decimal
+from hashlib import sha256
+import hmac, base64
+import json
+
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+class MexcSpotRest:
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.HOST = 'https://api.mexc.com'
+        else:
+            self.HOST = 'https://api.mexc.com'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.data = {}
+        self._SESSIONS = dict()
+        self.logger = self.get_logger()
+        self.data['account'] = {}
+        self.callback = {
+            "onMarket": empty_call,
+            "onPosition": empty_call,
+            "onOrder": empty_call,
+            "onEquity": empty_call,
+            "onTicker": empty_call,
+            "onExit": empty_call,
+        }
+        self.exchange_info = dict()
+        self.tickSize = None
+        self.stepSize = None
+        self.delays = []
+        self.max_delay = 0
+        self.avg_delay = 0
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter(
+            '[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(
+            "log.log", maxBytes=1024*1024, encoding='utf-8')
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        # log to console
+        console = logging.StreamHandler()
+        console.setLevel(logging.WARNING)
+        logger.addHandler(handler)
+        logger.addHandler(console)
+        return logger
+    
+    def get_sign(self, params, secret_key):
+        key = secret_key.encode('utf-8')
+        params = params.encode('utf-8')
+        sign = base64.b64encode(hmac.new(key, params, digestmod=sha256).digest()).decode()
+        return sign
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+    
+    async def _request(self, method, uri, body=None, params=None, HOST=None, auth=False):
+        url = urljoin(HOST, uri)
+        headers = {'X-MEXC-APIKEY': self.access_key, 'Content-Type': 'application/json'}
+        if params != None:
+            params['timestamp'] = int(time.time())*1000
+            query_string = "&".join(["{}={}".format(k, params[k]) for k in params.keys()])
+            params['signature'] = self.get_sign(query_string, self.secret_key)
+        
+        if auth:
+            if method == 'GET':
+                headers = {
+                    'X-MEXC-APIKEY': self.access_key,
+                    'Content-Type': 'application/json'
+                    }
+            else:
+                headers = {
+                    'X-MEXC-APIKEY': self.access_key,
+                    'Content-Type': 'application/x-www-form-urlencoded'
+                }
+        # 发起请求
+        session = self._get_session(url)
+        timeout = aiohttp.ClientTimeout(10)
+        msg = "rest请求记录" + str(method) + str(url) + str(params) + str(body)
+        self.logger.debug(msg)
+        try:
+            start_time = time.time()
+            if method == "GET":
+                response = await session.get(url, params=params, headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "POST":
+                response = await session.post(url, params=None, data=json.dumps(body), headers=headers, timeout=timeout, proxy=self.proxy)
+            elif method == "DELETE":
+                response = await session.delete(url, params=params, data=body, headers=headers, timeout=timeout, proxy=self.proxy)
+            code = response.status
+            res = await response.json()
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            res_msg = msg + f' 回报 {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206) or res['code'] not in (0, 200):
+                print(f'URL:{url} PARAMS:{params} body:{body} ERROR:{res}')
+                return None, res
+            return res, None
+        except Exception as e:
+            print(f'{self.name} rest 请求出错', str(e))
+            self.logger.error('请求错误'+str(e))
+            self.logger.error(traceback.format_exc())
+            return None, e
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='LIMIT'):
+        if origin_side == 'kd':
+            side = 'buy'
+        elif origin_side == 'pd':
+            side = 'sell'
+        elif origin_side == 'kk':
+            side = 'sell'
+        elif origin_side == 'pk':
+            side = 'buy'
+        else:
+            print("现货不允许此交易方向")
+            return None
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            amount = float(Decimal(str(amount//self.exchange_info[symbol].stepSize))
+                           * Decimal(str(self.exchange_info[symbol].stepSize)))
+            price = float(Decimal(str(price//self.exchange_info[symbol].tickSize))
+                          * Decimal(str(self.exchange_info[symbol].tickSize)))
+        if amount <= 0:
+            self.logger.error(f'下单参数错误 amount:{amount}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        if price <= 0:
+            self.logger.error(f'下单参数错误 price:{price}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['fee'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None
+        # Mexc 参数可能有变 Galois
+        params = {
+            'access_id': self.params.access_key,
+            'client_id': cid,
+            'symbol': symbol,
+            'amount': utils.num_to_str(amount, self.exchange_info[symbol].stepSize),
+            'type': side,
+            'price': utils.num_to_str(price, self.exchange_info[symbol].tickSize),
+        }
+        # logger.info(f'下单指令 {params}')
+        if self.params.debug == 'True':
+            return await asyncio.sleep(0.1)
+        else:
+            # 发单
+            response, error = await self._request('POST', '/api/v3/order', body=params, auth=1)
+            # 再更新
+            if response:
+                # logger.info(f'下单回报 {response}')
+                # 增加新的
+                if 'data' in response:
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = params["client_id"]
+                    order_event['order_id'] = response['data']["id"]
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0.0
+                order_event['fee'] = 0.0
+                order_event['filled'] = 0.0
+                order_event['client_id'] = params["client_id"]
+                self.callback["onOrder"](order_event)
+                return error
+        return response
+    
+    async def cancel_order(self, order_id=None, client_id=None, symbol=None):
+        if order_id:
+            response, error = await self._request('DELETE', f'/api/v3/margin/order', params={'symbol': self.symbol, 'orderId': order_id}, auth=1)
+        elif client_id:
+            response, error = await self._request('DELETE', f'/api/v3/margin/order', params={'symbol': self.symbol, 'orderId': client_id}, auth=1)
+        else:
+            raise Exception("撤单出错 没指定订单号")
+        if response:
+            self.logger.debug(f'撤单回报 {response}')
+        if error:
+            print("撤单失败", error)
+            self.logger.error(error)
+        return response
+    
+    async def check_order(self, order_id=None, client_id=None):
+        if order_id:
+            response, error = await self._request('GET', f'/api/v3/order', params={'symbol': self.symbol, 'orderId': order_id}, auth=1)
+        elif client_id:
+            response, error = await self._request('GET', f'/api/v3/order', params={'symbol': self.symbol, 'orderId': client_id}, auth=1)
+        else:
+            return
+        if response:
+            self.logger.debug(f'查单回报 {response}')
+            order_event = dict()
+            if response["data"]['status'] in ['not_deal', 'part_deal']:
+                order_event['status'] = "NEW"
+            elif response["data"]['status'] in ['cancel', 'done']:
+                order_event['status'] = "REMOVE"
+            else:
+                self.logger.error("错误的订单状态")
+            order_event['price'] = float(response["data"]["price"])
+            order_event['amount'] = float(response["data"]["amount"])
+            order_event['filled'] = float(
+                response["data"]["amount"])-float(response["data"]["left"])
+            order_event['filled_price'] = float(response["data"]["avg_price"])
+            order_event['client_id'] = response["data"]["client_id"]
+            order_event['order_id'] = response["data"]['id']
+            asset_fee = float(response['data']["asset_fee"])
+            money_fee = float(response['data']["money_fee"])
+            stock_fee = float(response['data']["stock_fee"])
+            # 非amm品种 优先扣cet 其次u 再次b
+            # amm品种 买入收b 卖出收u
+            if response['data']['type'] == "sell":
+                # 卖出
+                order_event['fee'] = money_fee
+            elif response['data']['type'] == "buy":
+                # 买入
+                order_event['fee'] = stock_fee
+            self.callback["onOrder"](order_event)
+        if error:
+            print("查单失败", error)
+            self.logger.error(error)
+        return response
+    
+    async def get_order_list(self):
+        params = {
+            'symbol': self.symbol,
+            'limit': 100,
+        }
+        response, error = await self._request('GET', '/api/v3/allOrders', params=params, auth=1)
+        orders = []  # 重置本地订单列表
+        if response is not None:
+            for i in response['data']['data']:
+                order_event = dict()
+                order_event['symbol'] = self.symbol
+                order_event['price'] = float(i["price"])
+                order_event['amount'] = float(i["amount"])
+                order_event['filled'] = float(i["amount"])-float(i["left"])
+                order_event['filled_price'] = float(i["avg_price"])
+                order_event['client_id'] = i["client_id"] if 'client_id' in i else ""
+                order_event['order_id'] = i['id']
+                asset_fee = float(i["asset_fee"])
+                money_fee = float(i["money_fee"])
+                stock_fee = float(i["stock_fee"])
+                # 非amm品种 优先扣cet 其次u 再次b
+                # amm品种 买入收b 卖出收u
+                if i['type'] == "sell":
+                    # 卖出
+                    order_event['fee'] = money_fee
+                elif i['type'] == "buy":
+                    # 买入
+                    order_event['fee'] = stock_fee
+                if i['status'] in ['not_deal', 'part_deal']:
+                    order_event['status'] = "NEW"
+                elif i['status'] in ['cancel', 'done']:
+                    order_event['status'] = "REMOVE"
+                else:
+                    self.logger.error("错误的订单状态")
+                self.callback["onOrder"](order_event)
+        if error:
+            print(error)
+        return response
+
+    async def get_history_order(self):
+        pass
+
+    async def get_server_time(self):
+        params = {}
+        response = await self._request('GET', '/api/v3/time', params=params)
+        return response
+    
+    async def before_trade(self):
+        # 获取市场基本情况
+        res, error = await self.get_market_details()
+        if error:
+            pass
+        else:
+            for i in res['data']:
+                if res['data'][i]['name'] == self.symbol:
+                    self.stepSize = float(Decimal("0.1")**Decimal(res['data'][i]["trading_decimal"]))
+                    self.tickSize = float(Decimal("0.1")**Decimal(res['data'][i]["pricing_decimal"]))
+                #### 保存交易规则信息
+                exchange_info = model.ExchangeInfo()
+                exchange_info.symbol = i
+                exchange_info.multiplier = 1
+                exchange_info.stepSize = float(Decimal("0.1")**Decimal(res['data'][i]["trading_decimal"]))
+                exchange_info.tickSize = float(Decimal("0.1")**Decimal(res['data'][i]["pricing_decimal"]))
+                self.exchange_info[exchange_info.symbol] = exchange_info
+
+    async def get_equity(self):
+        # 更新账户
+        res, err = await self.get_account()
+        if err:
+            print(err)
+        if res:
+            for i in res["data"]:
+                if self.quote == i:
+                    self.data['equity'] = float(
+                        res['data'][i]['available'])+float(res['data'][i]['frozen'])
+                    self.callback['onEquity']({
+                        self.quote: self.data['equity']
+                    })
+                    self.cash_value = self.data['equity']
+                elif self.base == i:
+                    coin = float(res['data'][i]['available']) + \
+                        float(res['data'][i]['frozen'])
+                    self.callback['onEquity']({
+                        self.base: coin
+                    })
+                    self.coin_value = coin
+    async def universalTransfer(self, _type='UMFUTURE_MAIN', asset='USDT', amount=0):
+        pass
+    async def futuresTransfer(self, _type='2', asset='USDT', amount=0):
+        pass
+    async def get_account(self):
+        return await self._request('GET', '/api/v3/account', params={"access_id": self.params.access_key}, auth=1)
+    async def get_market_details(self):
+        return await self._request('GET', f'/api/v3/exchangeInfo', params={}, auth=0)
+    async def get_ticker(self):
+        ## 'merge' 参数可能需要去掉 Galois
+        res, err = await self._request('GET', f'/api/v3/depth', params={"symbol": self.symbol, 'merge': '0.00000001'}, auth=0)
+        if res:
+            ap = float(res["data"]['asks'][0][0])
+            bp = float(res["data"]['bids'][0][0])
+            mp = (ap+bp)*0.5
+            d = {"name": self.name, 'mp': mp, 'bp': bp, 'ap': ap}
+            self.callback['onTicker'](d)
+            return d
+        if err:
+            self.logger.error(err)
+            return None
+    async def buy_token(self):
+        pass
+    async def go(self):
+        await self.before_trade()
+        await asyncio.sleep(1)
+        ### Mexc无法检查是否为AMMM品种
+        # try:
+        #     async with aiohttp.ClientSession(connector = aiohttp.TCPConnector(
+        #                         limit=50,
+        #                         keepalive_timeout=120,
+        #                         verify_ssl=False,
+        #                         local_addr=(self.ip,0)
+        #                     )) as session:
+        #         response = await session.get(
+        #             "https://api.coinex.com/v1/amm/market",
+        #             proxy=self.proxy
+        #         )
+        #         res = await response.json()
+        #         amm_list = res['data']
+        #         print(f'AMM列表{amm_list}')
+        #         if self.symbol in amm_list:
+        #             self.callback['onExit'](f"{self.name} coinex spot 禁止跑AMM品种")
+        #         else:
+        #             print(f'不是AMM品种 正常运行')
+        # except:
+        #     self.logger.error(traceback.format_exc())
+        #     self.callback['onExit'](f"{self.name} coinex spot AMM列表获取失败")
+        while 1:
+            try:
+                # 停机信号
+                if self.stop_flag:
+                    return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:
+                    print(err)
+                if res:
+                    for i in res["data"]:
+                        if self.quote == i:
+                            self.data['equity'] = float(
+                                res['data'][i]['available']) + float(res['data'][i]['frozen'])
+                            self.callback['onEquity']({
+                                self.quote: self.data['equity']
+                            })
+                        elif self.base == i:
+                            coin = float(res['data'][i]['available']) + \
+                                float(res['data'][i]['frozen'])
+                            self.callback['onEquity']({
+                                self.base: coin
+                            })
+                # 更新订单
+                # res = await self.get_order_list()
+                await asyncio.sleep(60)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(10)
+
+    def get_data(self):
+        return self.data
+    
+    async def handle_signals(self, orders):
+        '''执行策略指令'''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    # 只能用oid撤单
+                    if oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    # cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(order_id=oid))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)

+ 340 - 0
exchange/mexc_spot_ws.py

@@ -0,0 +1,340 @@
+import random, csv, sys, utils
+import logging, logging.handlers
+import model
+import time
+import json, ujson
+import asyncio
+import aiohttp
+import traceback
+import hashlib
+def empty_call(msg):
+    print(f'空的回调函数 {msg}')
+
+class MexcSpotWs:
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.URL = 'wss://wbs.mexc.com/ws'
+        else:
+            self.URL = 'wss://wbs.mexc.com/ws'
+        self.params = params
+        self.name = self.params.name
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        self.symbol = self.base + self.quote
+        self.callback = {
+            "onMarket":self.save_market,
+            "onDepth":empty_call,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        self.update_t = 0
+        self.depth = []
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+    
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+    
+    async def get_depth_flash(self):
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['X-MBX-APIKEY'] = self.params.access_key
+        url = f'https://api.mexc.com/api/v3/depth?symbol={self.symbol}&limit=1000'
+        session = aiohttp.ClientSession()
+        response = await session.get(
+            url, 
+            headers=headers, 
+            timeout=5, 
+            proxy=self.proxy
+            )
+        depth_flash = await response.text()
+        await session.close()
+        return ujson.loads(depth_flash)
+    
+    def _update_depth(self, msg):
+        msg = ujson.loads(msg)
+        t = float(msg['params'][1]['time'])
+        if t > self.update_t:
+            self.update_t = t
+            self.ticker_info["bp"] = float(msg['params'][1]['bids'][0][0])
+            self.ticker_info["ap"] = float(msg['params'][1]['asks'][0][0])
+            self.callback['onTicker'](self.ticker_info)
+            ##### normalize depth
+            mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+            step = mp * utils.EFF_RANGE / utils.LEVEL
+            bp = []
+            ap = []
+            bv = [0 for _ in range(utils.LEVEL)]
+            av = [0 for _ in range(utils.LEVEL)]
+            for i in range(utils.LEVEL):
+                bp.append(self.ticker_info["bp"]-step*i)
+            for i in range(utils.LEVEL):
+                ap.append(self.ticker_info["ap"]+step*i)
+            # 
+            price_thre = self.ticker_info["bp"] - step
+            index = 0
+            for bid in msg['params'][1]['bids']:
+                price = float(bid[0])
+                amount = float(bid[1])
+                if price > price_thre:
+                    bv[index] += amount
+                else:
+                    price_thre -= step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    bv[index] += amount
+            price_thre = self.ticker_info["ap"] + step
+            index = 0
+            for ask in msg['params'][1]['asks']:
+                price = float(ask[0])
+                amount = float(ask[1])
+                if price < price_thre:
+                    av[index] += amount
+                else:
+                    price_thre += step
+                    index += 1
+                    if index == utils.LEVEL:
+                        break
+                    av[index] += amount
+            self.depth = bp + bv + ap + av
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+        else:
+            self.logger.error("mexc ws推送过期信息")
+    
+    def _update_trade(self, msg):
+        msg = json.loads(msg)
+        for i in msg['params'][1]:
+            side = i["type"]
+            price = float(i["price"])
+            amount = float(i['amount'])
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+    
+    def _update_account(self, msg):
+        msg = json.loads(msg)
+        for i in msg['params'][0]:
+            if self.quote == i:
+                cash = float(msg['params'][0][self.quote]['available'])+float(msg['params'][0][self.quote]['frozen'])
+                self.callback['onEquity'] = {
+                    self.quote:cash
+                }
+            elif self.base == i:
+                coin = float(msg['params'][0][self.base]['available'])+float(msg['params'][0][self.base]['frozen'])
+                self.callback['onEquity'] = {
+                    self.base:coin
+                }
+    
+    def _update_order(self, msg):
+        self.logger.debug("ws推送订单"+msg)
+        msg = json.loads(msg)
+        event_type = msg['params'][0]
+        event = msg['params'][1]
+        if event_type == 1:  # 新增订单
+            order_event = dict()
+            order_event['filled'] = 0
+            order_event['filled_price'] = 0
+            order_event['client_id'] = event["client_id"]
+            order_event['order_id'] = event['id']
+            order_event['status'] = "NEW"
+            self.callback["onOrder"](order_event)
+        elif event_type == 3:  # 删除订单
+            order_event = dict()
+            order_event['filled'] = float(event["amount"]) - float(event["left"])
+            order_event['filled_price'] = float(event["price"]) 
+            # asset_fee = float(event["asset_fee"])
+            money_fee = float(event["money_fee"])
+            stock_fee = float(event["stock_fee"])
+            # 非amm品种 优先扣cet 其次u 再次b
+            # amm品种 买入收b 卖出收u
+            if event['side'] == 1:
+                # 卖出
+                order_event['fee'] = money_fee
+            elif event['side'] == 2:
+                # 买入
+                order_event['fee'] = stock_fee
+            order_event['client_id'] = event["client_id"]
+            order_event['order_id'] = event['id']
+            order_event['status'] = "REMOVE"
+            self.callback["onOrder"](order_event)
+
+    def _update_position(self, msg):
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0 
+        msg = ujson.loads(msg)
+        for i in msg['a']['P']:
+            if i['s'] == self.symbol:
+                if i['ps'] == 'LONG':
+                    long_pos += abs(float(i['pa']))
+                    long_avg = abs(float(i['ep']))
+                if i['ps'] == 'SHORT':
+                    short_pos += abs(float(i['pa']))
+                    short_avg = abs(float(i['ep']))
+        pos = model.Position()
+        pos.longPos = long_pos
+        pos.longAvg = long_avg
+        pos.shortPos = short_pos
+        pos.shortAvg = short_avg
+        self.callback['onPosition'](pos)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+    
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+    
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        ping_time = time.time()
+        while True:
+            try:
+                # 尝试连接
+                print(f'{self.name} 尝试连接ws')
+                # 登陆
+                ws_url = self.URL
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f'{self.name} ws连接成功')
+                    self.logger.info(f'{self.name} ws连接成功')
+                    # 订阅 mexc 现货
+                    symbol = self.symbol.upper()
+                    # 鉴权
+                    if is_auth:
+                        current_time = int(time.time()*1000)
+                        sign_str = f"access_id={self.params.access_key}&tonce={current_time}&secret_key={self.params.secret_key}"
+                        md5 = hashlib.md5(sign_str.encode())
+                        param = {
+                            "id": 1,
+                            "method": "server.sign",
+                            "params": [self.params.access_key, md5.hexdigest().upper(), current_time]
+                        }
+                        await _ws.send_str(ujson.dumps(param))   
+                        res = await _ws.receive(timeout=30)
+                        # 订阅资产
+                        sub_str = ujson.dumps({"id": 1, "method": "asset.subscribe","params": [self.base,self.quote]})
+                        await _ws.send_str(sub_str)
+                        # 订阅私有订单
+                        sub_str = ujson.dumps({"id": 1, "method": "order.subscribe","params": [symbol]})
+                        await _ws.send_str(sub_str)
+                    if sub_trade:
+                        # 订阅公开成交
+                        sub_str = ujson.dumps({"id": 1, "method": "deals.subscribe","params": [symbol]})
+                        await _ws.send_str(sub_str)
+                    # 订阅深度
+                    sub_str = ujson.dumps({"id": 1, "method": "depth.subscribe","params": [symbol, 50, "0.000000001", False]})
+                    await _ws.send_str(sub_str)
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except asyncio.CancelledError:
+                            print('ws取消')
+                            return
+                        except asyncio.TimeoutError:
+                            print(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} ws长时间没有收到消息 准备重连...')
+                            break
+                        except:
+                            print(f'{self.name} ws出现错误 准备重连...')
+                            self.logger.error(f'{self.name} ws出现错误 准备重连...')
+                            self.logger.error(traceback.format_exc())
+                            break
+                        msg = msg.data
+                        # 处理消息
+                        if 'depth.update' in msg:self._update_depth(msg)
+                        elif 'deals.update' in msg:self._update_trade(msg)
+                        elif 'asset.update' in msg:self._update_account(msg)
+                        elif 'order.update' in msg:self._update_order(msg)
+                        else:
+                            print(msg)
+                            pass
+                        if ping_time - time.time() > 60:
+                            ping_time = time.time()
+                            sub_str = ujson.dumps({"id": 1, "method": "server.ping","params": []})
+                            await _ws.send_str(sub_str)
+            except:
+                _ws = None
+                traceback.print_exc()
+                print(f'{self.name} ws连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws连接失败 开始重连...')
+                await asyncio.sleep(1)

+ 98 - 0
exchange/model.py

@@ -0,0 +1,98 @@
+import utils
+
+class BacktestFee:
+    
+    def __init__(self, msg=None):
+        if msg == "v9":
+            self.maker = -0.00001
+            self.taker =  0.0002
+        elif msg == "v0":
+            self.maker = 0.0001
+            self.taker = 0.0005
+        else:
+            self.maker = 0
+            self.taker = 0
+
+class ExchangeInfo:
+    def __init__(self) -> None:
+        self.symbol = None
+        self.tickSize = None
+        self.stepSize = None
+        self.multiplier = None
+
+class Order:
+    def __init__(self):
+        self.symbol = None
+        self.order_id = None
+        self.amount = None
+        self.side = None
+        self.price = None
+
+class Position():
+
+    def __init__(self):
+        self.longPos = 0
+        self.shortPos = 0
+        self.longAvg = 0
+        self.shortAvg = 0
+
+class TraderMsg:
+
+    def __init__(self):
+        self.position = Position()
+        self.cash = 0.0
+        self.coin = 0.0
+        self.orders = dict()
+        self.ref_price = None
+        self.market = []
+        self.predict = 0.0
+
+class ClientParams:
+
+    def __init__(self):
+        self.name = None
+        self.pair = None
+        self.proxy = None
+        self.access_key = None
+        self.secret_key = None
+        self.pass_key = None
+        self.interval = None
+        self.broker_id = None
+        self.debug = None
+
+class Config:
+
+    def __init__(self):
+        self.broker_id = None
+        self.account_name = None
+        self.access_key = None
+        self.secret_key = None
+        self.pass_key = None
+        self.exchange = None
+        self.pair = None
+        self.debug = None
+        self.open = None
+        self.close = None
+        self.server_port = None
+        self.leverrate = None
+        self.interval = 0.1
+        self.close = None
+        self.open = None
+        self.refexchange = None
+        self.refpair = None
+        self.webhook = None
+        self.used_pct = None
+        self.place_order_limit = 0
+        # self.proxy = "http://127.0.0.1:4780" # 仅在win下有效
+        self.proxy = None # 仅在win下有效
+        self.index = 0
+        self.save = 0
+        self.hold_coin = 0.0
+        self.log = 0
+        self.stoploss = 0.05
+        self.gamma = 0.999
+        self.grid = 1
+        self.backtest = 0
+        self.colo = 0
+        self.fast = 1
+        self.ip = 0

+ 712 - 0
exchange/okex_usdt_swap_rest.py

@@ -0,0 +1,712 @@
+
+
+import aiohttp
+import time
+import asyncio
+import json, ujson
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys
+import logging, logging.handlers
+from datetime import datetime
+from urllib.parse import urlparse
+import urllib
+from decimal import Decimal
+import utils
+import model
+
+
+
+
+def empty_call(msg):
+    pass
+
+def sort_num(n):
+    if n.isdigit():
+        return int(n)
+    else:
+        return float(n)
+
+class OkexUsdtSwapRest:
+    """"""
+    def __init__(self, params:model.ClientParams, colo=0):
+        if colo:
+            print('不支持colo高速线路')
+            self.REST = 'https://www.okx.com' # hk
+            # REST = 'https://aws.okx.com' # aws
+        else:
+            self.REST = 'https://www.okx.com' # hk
+            # REST = 'https://aws.okx.com' # aws
+        self.params = params
+        self.name = self.params.name
+        self.base = params.pair.split('_')[0].upper()
+        self.quote = params.pair.split('_')[1].upper()
+        self.symbol = f"{self.base}-{self.quote}-SWAP"
+        if len(self.params.pair.split('_')) > 2:
+            self.delivery = self.params.pair.split('_')[2] # 210924
+            self.symbol += f"-{self.delivery}"
+        self.data = {}
+        self._SESSIONS = dict()
+        self.callback = {
+            "onMarket":empty_call,
+            "onPosition":empty_call,
+            "onOrder":empty_call,
+            "onEquity":empty_call,
+            "onExit":empty_call,
+            }
+        self.exchange_info = dict()
+        self.stepSize = None
+        self.tickSize = None
+        self.ctVal    = None
+        self.ctMult   = None
+        self.delays = []
+        self.avg_delay = 0
+        self.max_delay = 0
+        self.proxy = None
+        self.broker_id = self.params.broker_id
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.stop_flag = 0
+        self.coin_value = 0.0
+        self.cash_value = 0.0
+
+        self.getheader  = self.make_header()
+        self.postheader = self.make_header()
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+
+    async def take_order(self, symbol, amount, origin_side, price, cid="", order_type='LIMIT'):
+        '''
+            下单接口
+        '''
+        if symbol not in self.exchange_info:
+            await self.before_trade()
+            # amount = float(Decimal(str(amount//self.exchange_info[symbol].stepSize))*Decimal(str(self.exchange_info[symbol].stepSize)))
+            # price = float(Decimal(str(price//self.exchange_info[symbol].tickSize))*Decimal(str(self.exchange_info[symbol].tickSize)))
+            amount = utils.fix_amount(amount, self.exchange_info[symbol].stepSize)
+            price = utils.fix_price(price, self.exchange_info[symbol].tickSize)
+        amount = int(amount/self.ctVal) # 这里把币转成张 后续用张来下单
+        # 似乎有了num_to_str就不再需要下面两行
+        if origin_side =='kd':
+            side = 'buy'
+            positionSide = 'long'
+        elif origin_side =='pd':
+            side = 'sell'
+            positionSide = 'long'
+        elif origin_side =='kk':
+            side = 'sell'
+            positionSide = 'short'
+        elif origin_side =='pk':
+            side = 'buy'
+            positionSide = 'short'
+        else:
+            raise Exception(f'下单参数错误 side:{origin_side}')
+        if amount <= 0.0: 
+            # okex 因为 数量单位为张 很容易出现这个问题 避免频繁写入日志
+            # self.logger.error(f'下单参数错误 amount:{amount} side:{origin_side}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None, None
+        if price <= 0.0:
+            self.logger.error(f'下单参数错误 price:{price} side:{origin_side}')
+            order_event = dict()
+            order_event['status'] = "REMOVE"
+            order_event['filled_price'] = 0.0
+            order_event['filled'] = 0.0
+            order_event['client_id'] = cid
+            self.callback["onOrder"](order_event)
+            return None, None
+        params = {
+            'instId': symbol,
+            'tdMode': "cross",
+            'sz': amount,
+            'side': side,
+            'posSide': positionSide,
+            'ordType': "limit" if order_type!="MARKET" else 'market',
+            'clOrdId': cid,
+        }
+        if params['ordType'] == 'limit':
+            params['px'] = utils.num_to_str(price, self.exchange_info[symbol].tickSize)
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None, None
+        else:
+            # 再报单
+            response, error = await self.http_post_request('/api/v5/trade/order', params)
+            # 再更新
+            if response is not None:
+                if response:
+                    data = response['data'][0]
+                    order_event = dict()
+                    order_event['status'] = "NEW"
+                    order_event['client_id'] = params['clOrdId']
+                    order_event['order_id'] = data['ordId']
+                    self.callback["onOrder"](order_event)
+            if error:
+                order_event = dict()
+                order_event['status'] = "REMOVE"
+                order_event['filled_price'] = 0
+                order_event['filled'] = 0
+                order_event['client_id'] = params['clOrdId']
+                self.callback["onOrder"](order_event)
+        return response, error
+
+    async def cancel_order(self, order_id=None, client_id=None, symbol=None):
+        """注意,ok撤单不能更新订单状态,撤单成功也仅仅代表交易所收到了撤单请求"""
+        params = {
+            "instId": self.symbol if symbol==None else symbol,
+        }
+        if order_id:
+            params["ordId"] = order_id
+        if client_id:
+            params["clOrdId"] = client_id
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None
+        else:
+            response, error = await self.http_post_request('/api/v5/trade/cancel-order', params)
+            if error:
+                # print("撤单失败",error)
+                # 撤单失败 可能已经撤单 是否发生成交需要rest查
+                # if client_id:await self.check_order(client_id=client_id)
+                # if order_id:await self.check_order(order_id=order_id)
+                return error
+            return response
+
+    async def check_order(self, order_id=None, client_id=None, symbol=None):
+        params = {
+            "instId": self.symbol if symbol==None else symbol,
+        }
+        if order_id:
+            params["ordId"] = order_id
+        if client_id:
+            params["clOrdId"] = client_id
+        if self.params.debug == 'True':
+            await asyncio.sleep(0.1)
+            return None
+        else:
+            response, error = await self.http_get_request('/api/v5/trade/order', params)
+            if error:
+                print(f"{self.name} 查单失败 {error}")
+                return error
+            if response['code']:
+                for order in response['data']:
+                    if order['state'] in ['canceled', 'filled']:
+                        order_event = dict()
+                        order_event['status'] = 'REMOVE'
+                        order_event['client_id'] = order['clOrdId']
+                        order_event['order_id'] = order['ordId']
+                        order_event['filled'] = float(order['accFillSz'])*self.ctVal if order['accFillSz'] != '' else 0.0   # usdt永续需要考虑每张的单位
+                        order_event['filled_price'] = float(order['avgPx']) if order['avgPx'] != '' else 0.0
+                        self.callback['onOrder'](order_event)
+                    else:
+                        order_event = dict()
+                        order_event['status'] = "NEW"
+                        order_event['client_id'] = order['clOrdId']
+                        order_event['order_id'] = order['ordId']
+                        self.callback['onOrder'](order_event)
+            return response
+
+    async def get_order_list(self):
+        '''
+            获取挂单表
+        '''
+        response, error = await self.http_get_request('/api/v5/trade/orders-pending', {'instId':self.symbol})
+        # print(response)
+        orders = [] # 查询当前挂单 只可能出现 new 和 partfill 默认成交为0 只有 done状态的订单才考虑是否有成交
+        if response and response['code']:
+            for i in response['data']:
+                order_event = dict()
+                order_event['status'] = "NEW"
+                order_event['filled'] = 0
+                order_event['filled_price'] = 0
+                order_event['client_id'] = i["clOrdId"]
+                order_event['order_id'] = i['ordId']
+                self.callback["onOrder"](order_event)
+                orders.append(order_event)
+        if error:
+            print('查询列表出错',error)
+        return orders
+
+    async def get_server_time(self):
+        response = await self.http_get_request('/api/v5/public/time')
+        return response
+
+    async def get_equity(self):
+        ##########
+        res, err = await self.get_account()
+        if res:
+            for data in res['data']:
+                for i in data['details']:
+                    if self.quote == i['ccy']:
+                        self.data['equity'] = float(i['eq'])
+                        self.callback['onEquity']({self.quote:self.data['equity']})
+        if err:
+            print('获取账户信息错误', err)
+        ##########
+
+    async def universalTransfer(self, _type='UMFUTURE_MAIN', asset='USDT', amount=0):
+        """okex现在统一账户,没有钱包这个概念了,不实现"""
+        pass
+
+    async def futuresTransfer(self, _type='2', asset='USDT', amount=0):
+        """okex现在统一账户,没有钱包这个概念了,不实现"""
+        pass
+
+    async def buy_token(self):
+        '''买入平台币'''
+        pass
+
+
+    async def check_position(self, hold_coin=0.0):
+        '''
+            检查是否存在非运行币种的仓位并take平仓
+            已支持全品种
+        '''
+        try:
+            ###
+            self.logger.info(f'{self.name} 检查遗漏订单')
+            response, error = await self.http_get_request('/api/v5/trade/orders-pending', {})
+            if response:
+                for order in response['data']:
+                    params = {
+                        "instId": order['instId'],
+                        "ordId": order['ordId']
+                    }
+                    res, err = await self.http_post_request('/api/v5/trade/cancel-order', params)
+                    await asyncio.sleep(0.1)
+                    self.logger.info(f"{self.name} 清理遗漏订单 {res} {err}")
+            ###
+            if self.exchange_info == dict():
+                await self.before_trade()
+            ###
+            self.logger.info(f'{self.name} 检查遗漏仓位')
+            # 清空全部仓位
+            response, error = await self.http_get_request('/api/v5/account/positions', {"instType":"SWAP"})
+            if response:
+                for i in response['data']:
+                    symbol = i['instId']
+                    pos = float(i['pos'])
+                    posSide = i['posSide']
+                    ###
+                    ticker, err = await self.http_get_request('/api/v5/market/ticker', {'instId':symbol})
+                    if err:
+                        print(err)
+                    if ticker:
+                        ap = float(ticker['data'][0]['askPx']) 
+                        bp = float(ticker['data'][0]['bidPx'])
+                    ### 每个品种都要获取各自的精度
+                    trade_side = 'sell' if posSide == 'long' else "buy"
+                    trade_pos = abs(pos)
+                    trade_pos_side = posSide
+                    params = {
+                        'instId': symbol,
+                        'tdMode': "cross",
+                        'sz': int(trade_pos),
+                        'px': ap*1.001 if trade_side == 'buy' else bp*0.999,
+                        'side' : trade_side,
+                        'posSide': trade_pos_side,
+                        'ordType': "limit",
+                    }
+                    response, error = await self.http_post_request('/api/v5/trade/order', params)
+                    print("下单结果",response,error)
+            self.logger.info('遗留仓位检测完毕')        
+        except:
+            self.logger.error("清仓程序执行出错")
+            self.logger.error(traceback.format_exc())
+        return
+
+    async def before_trade(self):
+        """"""
+        response, error = await self.get_instruments()
+        if error:
+            print('获取市场信息错误',error)
+        else:
+            self.update_instruments(response)
+            print(f'before_trade get_instruments successed. {self.symbol} {self.stepSize} {self.tickSize} {self.ctVal} {self.ctMult}')
+        # 更新账户
+        res, err = await self.get_account()
+        if err:
+            print(err)
+        else:
+            for data in res['data']:
+                for i in data['details']:
+                    if self.quote == i['ccy']:
+                        self.data['equity'] = float(i['eq'])
+                        self.callback['onEquity']({self.quote:self.data['equity']})
+                        self.cash_value = float(i['eq'])
+                        print(f"{self.name} on_go {self.symbol} equity {self.quote} {self.data['equity']}")
+        await self.change_pos_side()
+        await asyncio.sleep(1)
+        await self.set_leverage(10)
+        await asyncio.sleep(1)
+        await self.get_position()
+
+    async def get_all_position(self):
+        '''
+            获取仓位信息
+        '''
+        response, error = await self.http_get_request('/api/v5/account/positions', {})
+        print(f'查看此账号全部仓位 {response} {error}')
+
+    async def get_position(self):
+        '''
+            获取仓位信息
+        '''
+        response, error = await self.http_get_request('/api/v5/account/positions', {'instId':self.symbol})
+        if error:
+            print(f"{self.name} get_position error {error}")
+            return None
+        longPos, shortPos = 0, 0
+        longAvg, shortAvg = 0, 0
+        for i in response['data']:
+            if i['instId'] == self.symbol and i['pos'] and i['avgPx']:
+                if i['posSide'] == 'long':
+                    longPos = float(i['pos'])*self.ctVal
+                    longAvg = float(i['avgPx'])
+                elif i['posSide'] == 'short':
+                    shortPos = float(i['pos'])*self.ctVal
+                    shortAvg = float(i['avgPx'])
+        position = model.Position()
+        position.longPos = abs(longPos)
+        position.longAvg = abs(longAvg)
+        position.shortPos = abs(shortPos)
+        position.shortAvg = abs(shortAvg)
+        self.callback['onPosition'](position)
+        return position
+
+    async def get_ticker(self):
+        res, err = await self.http_get_request('/api/v5/market/ticker', {'instId':self.symbol})
+        if res:
+            ap = float(res['data'][0]['askPx']) 
+            bp = float(res['data'][0]['bidPx'])
+            mp = (ap+bp)/2
+            ticker = {"name":self.name,'mp': mp, 'bp':bp, 'ap':ap}
+            return ticker
+        if err:
+            self.logger.debug(err)
+        return None
+
+    async def get_account(self):
+        response, error = await self.http_get_request('/api/v5/account/balance', {'ccy':self.quote})
+        return response, error
+
+    async def go(self):
+        '''
+            盘前
+            获取市场信息
+            获取账户信息
+            更改仓位模式(期货)
+            清空仓位和挂单
+            盘中
+            更新账户信息
+            更新挂单列表
+            更新仓位信息
+            更新延迟信息
+        '''
+        print('Rest循环器启动')
+        interval = 60  # 不能太快防止占用限频
+        ### beforeTrade
+        await self.before_trade()
+        await asyncio.sleep(1)
+        ### onTrade
+        loop = 0
+        while 1:
+            loop += 1
+            try:
+                # 停机信号
+                if self.stop_flag:
+                    return
+                # 更新账户
+                res, err = await self.get_account()
+                if err:
+                    print(err)
+                else:
+                    for data in res['data']:
+                        for i in data['details']:
+                            if self.quote == i['ccy']:
+                                self.data['equity'] = float(i['eq'])
+                                self.callback['onEquity']({self.quote:self.data['equity']})
+                                # print(f"{self.name} on_go {self.symbol} equity {self.quote} {self.data['equity']}")
+                # 更新仓位
+                await self.get_position()
+                # print(f"{self.name} on_go {self.symbol} position {position}")
+                await asyncio.sleep(interval)
+                # 打印延迟
+                self.get_delay_info()
+                self.logger.debug(f'报单延迟 平均{self.avg_delay}ms 最大{self.max_delay}ms')
+            except asyncio.CancelledError:
+                return
+            except:
+                traceback.print_exc()
+                await asyncio.sleep(30)
+
+    async def handle_signals(self, orders):
+        '''
+            执行策略指令
+            撤销订单
+            检查订单
+            下达订单
+        '''
+        try:
+            for order_name in orders:
+                if 'Cancel' in order_name:
+                    cid = orders[order_name][0]
+                    oid = orders[order_name][1]
+                    if cid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(client_id=cid))
+                    elif oid:
+                        asyncio.get_event_loop().create_task(self.cancel_order(order_id=oid))
+            for order_name in orders:
+                if 'Check' in order_name:
+                    cid = orders[order_name][0]
+                    # oid = orders[order_name][1]
+                    asyncio.get_event_loop().create_task(self.check_order(client_id=cid))
+            for order_name in orders:
+                if 'Limits' in order_name:
+                    for i in orders[order_name]:
+                        asyncio.get_event_loop().create_task(self.take_order(
+                            self.symbol,
+                            i[0],
+                            i[1],
+                            i[2],
+                            i[3]
+                        ))
+        except Exception as e:
+            traceback.print_exc()
+            self.logger.error("执行信号出错"+str(e))
+            await asyncio.sleep(0.1)
+
+    async def set_leverage(self, lever=10):
+        params = {'instId':self.symbol, 'lever':utils.num_to_str(lever, 1), 'mgnMode':'cross'}
+        res, error = await self.http_post_request('/api/v5/account/set-leverage', params)
+        if error:
+            print(f"{self.name} 设置杠杆倍数 {params} failed. -->{error}")
+            return None, error
+        if res['code'] == '0':
+            print(f"{self.name} 设置杠杆倍数 {params} success. -->{res}")
+        else:
+            print(f"{self.name} 设置杠杆倍数 {params} 暂时不能设置. -->{res}")
+        return res, error
+
+    async def change_pos_side(self, dual='true'):
+        ''''''
+        params = {'posMode': "long_short_mode"}
+        res, error = await self.http_post_request('/api/v5/account/set-position-mode', params)
+        if error:
+            print(f"{self.name} 设置双向持仓 failed. -->{error}")
+            return res, error
+        if res['code'] == '0':
+            print(f"{self.name} 设置双向持仓 success. -->{res}")
+        else:
+            print(f"{self.name} 设置双向持仓 暂时不能设置. -->{res}")
+        return res, error
+
+    async def get_instruments(self):
+        """从rest获取合约信息"""
+        params = {'instType': 'SWAP'}
+        res, error = await self.http_get_request('/api/v5/public/instruments', params)
+        return res, error
+
+    def update_instruments(self, data):
+        """根据信息调整合约信息"""
+        for info in data['data']:
+            if info['instId'] == self.symbol:
+                self.ctVal    = sort_num(info['ctVal'])
+                self.ctMult   = sort_num(info['ctMult'])
+                self.tickSize = sort_num(info['tickSz'])
+                self.stepSize = sort_num(info['minSz'])*self.ctVal
+            #### 保存交易规则信息
+            exchange_info = model.ExchangeInfo()
+            exchange_info.symbol = info['instId']
+            exchange_info.multiplier = sort_num(info['ctMult'])
+            exchange_info.tickSize = sort_num(info['tickSz'])
+            exchange_info.stepSize = sort_num(info['minSz'])*sort_num(info['ctMult'])
+            self.exchange_info[exchange_info.symbol] = exchange_info
+
+    def __sign(self, timestamp, method, path, data):
+        if data:
+            message = timestamp + method + path + data
+        else:
+            message = timestamp + method + path
+        digest = hmac.new(bytes(self.params.secret_key.encode('utf8')), bytes(message.encode('utf8')), digestmod=hashlib.sha256).digest()
+        return base64.b64encode(digest).decode('utf-8')
+
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    async def http_post_request(self, method, query=None, **args):
+        """"""
+        params = dict()
+        params.update(**args)
+        if query is not None: params.update(query)
+        data = json.dumps(params)
+        timestamp = self.timestamp()
+        sign = self.__sign(timestamp, 'POST', method, data)
+        self.postheader['OK-ACCESS-SIGN'] = sign
+        self.postheader['OK-ACCESS-TIMESTAMP'] = timestamp
+        url = f"{self.REST}{method}"
+        res, error = await self._request('POST', url, data)
+        return res, error
+
+    async def http_get_request(self, method, query=None, **args):
+        """"""
+        params = dict()
+        params.update(**args)
+        if query is not None: params.update(query)
+        timestamp = self.timestamp()
+        if params:
+            path = '{method}?{params}'.format(method=method, params=urllib.parse.urlencode(params))
+        else:
+            path = method
+        sign = self.__sign(timestamp, 'GET', path, None)
+        self.getheader['OK-ACCESS-SIGN'] = sign
+        self.getheader['OK-ACCESS-TIMESTAMP'] = timestamp
+        url = f"{self.REST}{path}"
+        res, error = await self._request('GET', url)
+        return res, error
+
+    async def _request(self, method, url, params=None):
+        """"""
+        try:
+            ######
+            msg = f"rest请求记录 {method} {url} {params}"
+            self.logger.debug(msg)
+            ######
+            session = self._get_session(url)
+            start_time = time.time()
+            if method == 'GET':
+                headers = self.getheader
+                response = await session.get(
+                    url,
+                    headers=headers,
+                    timeout=10,
+                    proxy=self.proxy
+                    )
+            elif method == 'POST':
+                headers = self.postheader
+                response = await session.post(
+                    url,
+                    data=params,
+                    headers=headers,
+                    timeout=10,
+                    proxy=self.proxy
+                    )
+            code = response.status
+            res = await response.json()
+            msg = f"rest请求记录 {method} {url} {headers} {params}"
+            res_msg = msg + f' {res}'
+            self.logger.debug(res_msg)
+            if code not in (200, 201, 202, 203, 204, 205, 206):
+                self.logger.error(f'METHOD:{method} URL:{url} PARAMS:{params} ERROR:{res}')
+                return None, str(res)
+            if 'code' in res:
+                if int(res['code']) not in (0,):
+                    if '51401' in str(res):
+                        pass
+                    else:
+                        self.logger.error(f'METHOD:{method} URL:{url} PARAMS:{params} ERROR:{res}')
+                    return None, str(res)
+            delay = int(1000*(time.time() - start_time))
+            self.delays.append(delay)
+            return res, None
+        except Exception as e:
+            print('网络请求错误')
+            print(f'URL:{url} PARAMS:{params} ERROR:{e}')
+            self.logger.error(e)
+            self.logger.error(traceback.format_exc())
+            return None, str(e)
+
+    async def get_history(self):
+        params = {
+            'instType': "SWAP",
+            'limit':"100"
+            }
+        res, error = await self.http_get_request('/api/v5/trade/fills', params)
+        b_id=res['data'][0]['billId']
+        ###
+        data = []
+        while 1:
+            params = {
+                'instType': "SWAP",
+                'limit':"100",
+                'after':b_id
+                }
+            await asyncio.sleep(0.3)
+            res, _ = await self.http_get_request('/api/v5/trade/fills', params)
+            if res:
+                if len(res['data']) == 0:
+                    break
+                for i in res['data']:
+                    data.append(i)
+            b_id = res['data'][-1]['billId']
+ 
+        # b_id_s = []
+        # for i in data:
+        #     if i['billId'] in b_id_s:
+        #         print(i['billId'])
+        #     b_id_s.append(i['billId'])
+
+        with open("data5.json", 'w+') as f:
+            f.write(json.dumps(data))
+
+    def get_delay_info(self):
+        if len(self.delays) > 100:
+            self.delays = self.delays[-100:]
+        if max(self.delays) > self.max_delay:self.max_delay = max(self.delays)
+        self.avg_delay = round(sum(self.delays)/len(self.delays),1)
+
+    def timestamp(self):
+        return datetime.utcnow().isoformat("T")[:-3] + 'Z'
+
+    def login_params(self):
+        """生成login字符"""
+        timestamp = str(time.time())
+        message = timestamp + 'GET' + '/users/self/verify'
+        mac = hmac.new(bytes(self.params.secret_key.encode('utf8')), bytes(message.encode('utf8')), digestmod=hashlib.sha256).digest()
+        sign = base64.b64encode(mac)
+        login_dict = {}
+        login_dict['apiKey'] = self.params.access_key
+        login_dict['passphrase'] = self.params.pass_key
+        login_dict['timestamp'] = timestamp
+        login_dict['sign'] =  sign.decode('utf-8')
+        login_param = {'op': 'login', 'args': [login_dict]}
+        login_str = ujson.dumps(login_param)
+        return login_str
+
+    def make_header(self):
+        """"""
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['OK-ACCESS-KEY'] = self.params.access_key
+        headers['OK-ACCESS-SIGN'] = None
+        headers['OK-ACCESS-TIMESTAMP'] = None
+        headers['OK-ACCESS-PASSPHRASE'] = self.params.pass_key
+        return headers
+
+    def _get_session(self, url):
+        parsed_url = urlparse(url)
+        key = parsed_url.netloc or parsed_url.hostname
+        if key not in self._SESSIONS:
+            tcp = aiohttp.TCPConnector(limit=50,keepalive_timeout=120,verify_ssl=False,local_addr=(self.ip,0))
+            session = aiohttp.ClientSession(connector=tcp)
+            self._SESSIONS[key] = session
+        return self._SESSIONS[key]
+

+ 811 - 0
exchange/okex_usdt_swap_ws.py

@@ -0,0 +1,811 @@
+
+import aiohttp
+import time
+import asyncio
+import json, ujson
+from numpy import subtract
+import zlib
+import hashlib
+import hmac
+import base64
+import traceback
+import random, csv, sys
+import logging, logging.handlers
+from datetime import datetime
+import urllib
+import utils
+import model
+
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        nowTime = time.time()
+        res = func(*args, **kwargs)
+        spend_time = time.time() - nowTime
+        spend_time = round(spend_time * 1e6, 3)
+        print(f'{func.__name__} 耗时 {spend_time} us')
+        return res
+    return wrapper
+
+# okex 必须要订阅限价频道 防止无法下单
+
+def empty_call(msg):
+    pass
+
+# -- 增量维护本地深度工具 ----------------------
+def update_bids(depth, bids_p):
+    bids_u = depth['bids']
+    for i in bids_u:
+        bid_price = i[0]
+        for j in bids_p:
+            if bid_price == j[0]:
+                if i[1] == '0':
+                    bids_p.remove(j)
+                    break
+                else:
+                    del j[1]
+                    j.insert(1, i[1])
+                    break
+        else:
+            if i[1] != "0":
+                bids_p.append(i)
+    else:
+        bids_p.sort(key=lambda price: sort_num(price[0]), reverse=True)
+    return bids_p
+
+def update_asks(depth, asks_p):
+    asks_u = depth['asks']
+    for i in asks_u:
+        ask_price = i[0]
+        for j in asks_p:
+            if ask_price == j[0]:
+                if i[1] == '0':
+                    asks_p.remove(j)
+                    break
+                else:
+                    del j[1]
+                    j.insert(1, i[1])
+                    break
+        else:
+            if i[1] != "0":
+                asks_p.append(i)
+    else:
+        asks_p.sort(key=lambda price: sort_num(price[0]))
+    return asks_p
+
+def sort_num(n):
+    if n.isdigit():
+        return int(n)
+    else:
+        return float(n)
+
+# @timeit
+def check(bids, asks):
+    # 获取bid档str
+    bids_l = []
+    bid_l = []
+    count_bid = 1
+    while count_bid <= 25:
+        if count_bid > len(bids):
+            break
+        bids_l.append(bids[count_bid-1])
+        count_bid += 1
+    for j in bids_l:
+        str_bid = ':'.join(j[0 : 2])
+        bid_l.append(str_bid)
+    # 获取ask档str
+    asks_l = []
+    ask_l = []
+    count_ask = 1
+    while count_ask <= 25:
+        if count_ask > len(asks):
+            break
+        asks_l.append(asks[count_ask-1])
+        count_ask += 1
+    for k in asks_l:
+        str_ask = ':'.join(k[0 : 2])
+        ask_l.append(str_ask)
+    # 拼接str
+    num = ''
+    if len(bid_l) == len(ask_l):
+        for m in range(len(bid_l)):
+            num += bid_l[m] + ':' + ask_l[m] + ':'
+    elif len(bid_l) > len(ask_l):
+        # bid档比ask档多
+        for n in range(len(ask_l)):
+            num += bid_l[n] + ':' + ask_l[n] + ':'
+        for l in range(len(ask_l), len(bid_l)):
+            num += bid_l[l] + ':'
+    elif len(bid_l) < len(ask_l):
+        # ask档比bid档多
+        for n in range(len(bid_l)):
+            num += bid_l[n] + ':' + ask_l[n] + ':'
+        for l in range(len(bid_l), len(ask_l)):
+            num += ask_l[l] + ':'
+    new_num = num[:-1]
+    int_checksum = zlib.crc32(new_num.encode())
+    fina = change(int_checksum)
+    return fina
+
+def change(num_old):
+    num = pow(2, 31) - 1
+    if num_old > num:
+        out = num_old - num * 2 - 2
+    else:
+        out = num_old
+    return out
+#--------------------------------------------------
+
+
+class OkexUsdtSwapWs:
+
+    def __init__(self, params:model.ClientParams, colo=0, is_print=0):
+        if colo:
+            print('不支持colo高速线路 请修改hosts')
+            #### hk
+            self.URL_PUBLIC  = 'wss://ws.okx.com:8443/ws/v5/public'
+            self.URL_PRIVATE = 'wss://ws.okx.com:8443/ws/v5/private'
+            self.REST = 'https://www.okx.com'
+            #### aws
+            # self.URL_PUBLIC  = 'wss://wsaws.okx.com:8443/ws/v5/public'
+            # self.URL_PRIVATE = 'wss://wsaws.okx.com:8443/ws/v5/private'
+            # self.REST = 'https://aws.okx.com'
+        else:
+            #### hk
+            self.URL_PUBLIC  = 'wss://ws.okx.com:8443/ws/v5/public'
+            self.URL_PRIVATE = 'wss://ws.okx.com:8443/ws/v5/private'
+            self.REST = 'https://www.okx.com'
+            #### aws
+            # self.URL_PUBLIC  = 'wss://wsaws.okx.com:8443/ws/v5/public'
+            # self.URL_PRIVATE = 'wss://wsaws.okx.com:8443/ws/v5/private'
+            # self.REST = 'https://aws.okx.com'
+        self.params = params
+        self.name = self.params.name
+        self.base = params.pair.split('_')[0].upper()
+        self.quote = params.pair.split('_')[1].upper()
+        self.symbol = f"{self.base}-{self.quote}-SWAP"
+        if len(self.params.pair.split('_')) > 2:
+            self.delivery = self.params.pair.split('_')[2] # 210924
+            self.symbol += f"-{self.delivery}"
+        self.data = dict()
+        self.data['trade'] = []
+        self.data['force'] = []
+        self.callback = {
+            "onMarket":self.save_market,
+            "onPosition":empty_call,
+            "onEquity":empty_call,
+            "onOrder":empty_call,
+            "onTicker":empty_call,
+            "onDepth":empty_call,
+            "onExit":empty_call,
+            }
+        self.depth_update = []
+        self.need_flash = 1
+        self.updata_u = None
+        self.last_update_id = None
+        self.depth = []
+        self.is_print = is_print
+        self.proxy = None
+        if 'win' in sys.platform:
+            self.proxy = self.params.proxy
+        self.logger = self.get_logger()
+        self.ticker_info = {"name":self.name,'bp':0.0,'ap':0.0}
+        self.stop_flag = 0
+        self.public_update_time = time.time()
+        self.private_update_time = time.time()
+        self.expired_time = 300
+
+        self.update_t = 0
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+
+        self.getheader  = self.make_header()
+        self.postheader = self.make_header()
+
+        self._detail_ob = {}  # 用于增量更新depth
+        self.stepSize = None
+        self.tickSize = None
+        self.ctVal    = None  # 合约乘数
+        self.ctMult   = None  # 合约面值
+
+        self.depth = []
+
+        self.sub_trade = 0
+        self.sub_fast = 0
+
+        #### 指定发包ip
+        iplist = utils.get_local_ip_list()
+        self.ip = iplist[int(self.params.ip)]
+
+    def save_market(self, msg):
+        date = time.strftime('%Y-%m-%d',time.localtime())
+        interval = self.params.interval
+        if msg:
+            exchange = msg['name']
+            if len(msg['data']) > 1:
+                with open(f'./history/{exchange}_{self.symbol}_{interval}_{date}.csv',
+                            'a',
+                            newline='',
+                            encoding='utf-8') as f:
+                    writer = csv.writer(f, delimiter=',')
+                    writer.writerow(msg['data'])
+                if self.is_print:print(f'写入行情 {self.symbol}')
+
+    # ------------------------------------------------------
+    # -- core ----------------------------------------------
+
+    async def get_instruments(self):
+        """从rest获取合约信息"""
+        params = {'instId': self.symbol, 'instType': 'SWAP'}
+        session = aiohttp.ClientSession()
+        response = await self.http_get_request_pub(session, '/api/v5/public/instruments', params)
+        data = await response.text()
+        await session.close()
+        return ujson.loads(data)
+
+    def update_instruments(self, data):
+        """根据信息调整合约信息"""
+        for info in data['data']:
+            if info['instId'] == self.symbol:
+                self.stepSize = sort_num(info['minSz'])
+                self.tickSize = sort_num(info['tickSz'])
+                self.ctVal    = sort_num(info['ctVal'])
+                self.ctMult   = sort_num(info['ctMult'])
+                return
+
+    async def get_depth_flash(self):
+        """rest获取深度信息"""
+        params = {'instId': self.symbol, 'sz':20}
+        session = aiohttp.ClientSession()
+        response = await self.http_get_request(session, '/api/v5/market/books', params)
+        depth_flash = await response.text()
+        await session.close()
+        return ujson.loads(depth_flash)
+
+    def _get_data(self):
+        market_data = self.depth + [self.max_buy, self.min_sell]
+        self.max_buy = 0.0
+        self.min_sell = 0.0
+        self.buy_v = 0.0
+        self.buy_q = 0.0
+        self.sell_v = 0.0
+        self.sell_q = 0.0
+        return {'name': self.name,'data':market_data}
+
+    async def go(self):
+        interval = float(self.params.interval)
+        if self.is_print:print(f'Ws循环器启动 interval {interval}')
+        ### onTrade
+        while 1:
+            try:
+                # 更新市场信息
+                market_data = self._get_data()
+                self.callback['onMarket'](market_data)
+            except:
+                traceback.print_exc()
+            await asyncio.sleep(interval)
+
+    def subscribe_private(self):
+        subs = [
+            {'channel':'balance_and_position'},
+            {'channel':'account'},
+            {'channel':'orders', 'instType':"SWAP", 'instId':self.symbol}
+        ]
+        return ujson.dumps({'op':'subscribe', 'args':subs})
+
+    def subscribe_public(self):
+        channels = []
+        if self.sub_fast:
+            channels.append("books50-l2-tbt")
+        else:
+            channels.append("books5")
+        # "tickers",  # 100ms 比book50慢
+        # "books",
+        # "books5",
+        # "books-l2-tbt",
+        # "books50-l2-tbt",
+        # "price-limit"
+        if self.sub_trade:
+            channels.append("trades")
+        subs = [{'instId':self.symbol, 'channel':channel} for channel in channels]
+        return ujson.dumps({'op':'subscribe', 'args':subs})
+
+    async def run_private(self):
+        """"""
+        while 1:
+            try:
+                self.private_update_time = time.time()
+                print(f"{self.name} private 尝试连接ws")
+                ws_url = self.URL_PRIVATE
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f"{self.name} private ws连接成功")
+                    self.logger.debug(f"{self.name} private ws连接成功")
+                    await _ws.send_str(self.login_params())
+                    msg = await _ws.receive(timeout=30)
+                    loggined = False
+                    if msg:
+                        msg = ujson.loads(msg.data)
+                        if msg['event'] == 'login' and msg['code'] == "0":
+                            loggined = True
+                            print(f"{self.name} private login success.")
+                            self.logger.debug(f"{self.name} private login success.")
+                            await _ws.send_str(self.subscribe_private()) # login成功就需要去订阅
+                    if not loggined:
+                        print(f"{self.name} private login failed. --> {msg}")
+                        self.logger.debug(f"{self.name} private login failed. --> {msg}")
+                        await asyncio.sleep(3)
+                    while loggined:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} private ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} private ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = msg.data
+                        await self.on_message_private(_ws, msg)
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                await asyncio.sleep(1)
+                
+
+    async def run_public(self):
+        """"""
+        while 1:
+            try:
+                self.public_update_time = time.time()
+                print(f"{self.name} public 尝试连接ws")
+                ws_url = self.URL_PUBLIC
+                async with aiohttp.ClientSession(
+                        connector = aiohttp.TCPConnector(
+                                limit=50,
+                                keepalive_timeout=120,
+                                verify_ssl=False,
+                                local_addr=(self.ip,0)
+                            )
+                        ).ws_connect(
+                            ws_url,
+                            proxy=self.proxy,
+                            timeout=30,
+                            receive_timeout=30,
+                            ) as _ws:
+                    print(f"{self.name} public ws连接成功")
+                    self.logger.debug(f"{self.name} public ws连接成功")
+                    await _ws.send_str(self.subscribe_public())
+                    while True:
+                        # 停机信号
+                        if self.stop_flag:
+                            await _ws.close()
+                            return
+                        # 接受消息
+                        try:
+                            msg = await _ws.receive(timeout=30)
+                        except:
+                            print(f'{self.name} public ws长时间没有收到消息 准备重连...')
+                            self.logger.error(f'{self.name} public ws长时间没有收到消息 准备重连...')
+                            break
+                        msg = msg.data
+                        await self.on_message_public(_ws, msg)
+            except:
+                traceback.print_exc()
+                print(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(f'{self.name} ws public 连接失败 开始重连...')
+                self.logger.error(traceback.format_exc())
+                await asyncio.sleep(1)
+
+    def _update_ticker(self, msg):
+        """"""
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        ticker = msg['data'][0]
+        t = int(ticker['ts'])
+        if t > self.update_t:
+            self.update_t = t
+            bp = float(ticker['bidPx'])
+            ap = float(ticker['askPx'])
+            bq = float(ticker['bidSz'])
+            aq = float(ticker['askSz'])
+            self.ticker_info["bp"] = bp 
+            self.ticker_info["ap"] = ap 
+            self.callback['onTicker'](self.ticker_info)
+            ####
+            self.depth = [bp, bq, ap, aq]
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+
+    def _update_trade(self, msg):
+        """"""
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        for i in msg['data']:
+            side = i['side']
+            amount = float(i['sz'])*self.ctVal # paper trans to coin
+            price = float(i['px'])
+            if price > self.max_buy or self.max_buy == 0.0:
+                self.max_buy = price
+            if price < self.min_sell or self.min_sell == 0.0:
+                self.min_sell = price
+            if side == 'buy':
+                self.buy_q += amount
+                self.buy_v += amount*price
+            elif side == 'sell':
+                self.sell_q += amount
+                self.sell_v += amount*price
+            #### 修正ticker ####
+            # if side == 'buy' and price > self.ticker_info['ap']:
+            #     self.ticker_info['ap'] = price
+            #     self.callback['onTicker'](self.ticker_info)
+            # if side == 'sell' and price < self.ticker_info['bp']:
+            #     self.ticker_info['bp'] = price
+            #     self.callback['onTicker'](self.ticker_info)
+
+    async def _update_depth(self, _ws, msg):
+        """"""
+        self.public_update_time = time.time()
+        msg = ujson.loads(msg)
+        if "action" not in msg:
+            # books5 就没有action,但5档实在不够用,而且间隔200ms
+            depth = msg['data'][0]
+            bp = float(depth['bids'][0][0])
+            bv = float(depth['bids'][0][1])
+            ap = float(depth['asks'][0][0])
+            av = float(depth['asks'][0][1])
+            self.ticker_info["bp"] = bp
+            self.ticker_info["ap"] = ap
+            self.callback['onTicker'](self.ticker_info)
+            self.depth = [bp,bv,ap,av]
+            self.callback['onDepth']({'name':self.name,'data':self.depth})
+            return
+        depth = msg['data'][0]
+        action = msg['action']
+        if action == 'update':
+            self._update_depth_update(depth)
+        elif action == 'snapshot':
+            self._update_depth_snapshot(depth)
+        ob = self._detail_ob
+        if self.compare_checksum(ob, depth):     
+            t = int(depth['ts'])
+            if t > self.update_t:
+                self.update_t = t
+                self.ticker_info["bp"] = float(ob['bids'][0][0])
+                self.ticker_info["ap"] = float(ob['asks'][0][0])
+                self.callback['onTicker'](self.ticker_info)
+                ##### 标准化深度
+                mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
+                step = mp * utils.EFF_RANGE / utils.LEVEL
+                bp = []
+                ap = []
+                bv = [0 for _ in range(utils.LEVEL)]
+                av = [0 for _ in range(utils.LEVEL)]
+                for i in range(utils.LEVEL):
+                    bp.append(self.ticker_info["bp"]-step*i)
+                for i in range(utils.LEVEL):
+                    ap.append(self.ticker_info["ap"]+step*i)
+                # 
+                price_thre = self.ticker_info["bp"] - step
+                index = 0
+                for bid in ob['bids']:
+                    price = float(bid[0])
+                    amount = float(bid[1])
+                    if price > price_thre:
+                        bv[index] += amount
+                    else:
+                        price_thre -= step
+                        index += 1
+                        if index == utils.LEVEL:
+                            break
+                        bv[index] += amount
+                price_thre = self.ticker_info["ap"] + step
+                index = 0
+                for ask in ob['asks']:
+                    price = float(ask[0])
+                    amount = float(ask[1])
+                    if price < price_thre:
+                        av[index] += amount
+                    else:
+                        price_thre += step
+                        index += 1
+                        if index == utils.LEVEL:
+                            break
+                        av[index] += amount
+                self.depth = bp + bv + ap + av
+                self.callback['onDepth']({'name':self.name,'data':self.depth})
+        else:
+            await self.resubscribe_depth(_ws)
+
+    async def resubscribe_depth(self, _ws):
+        info = f"{self.name} checksum not correct!"
+        print(info)
+        self.logger.info(info)
+        args = []
+        if self.sub_fast:
+            args.append({'channel':'books50-l2-tbt','instId':self.symbol})
+        else:
+            args.append({'channel':'books5','instId':self.symbol},)
+        # {'channel':'books','instId':self.symbol},
+        # {'channel':'books5','instId':self.symbol},
+        # {'channel':'books50-l2-tbt','instId':self.symbol},
+        # {'channel':'books-l2-tbt','instId':self.symbol},
+        sub_str = {'op':"unsubscribe",'args':args}
+        await _ws.send_str(ujson.dumps(sub_str))
+        await asyncio.sleep(1)
+        sub_str['op'] = 'subscribe'
+        await _ws.send_str(ujson.dumps(sub_str))
+
+    def _update_depth_update(self, depth):
+        ob = self._detail_ob
+        ob['timestamp'] = depth['ts']
+        bids_p = ob['bids']
+        asks_p = ob['asks']
+        bids_p = update_bids(depth, bids_p)
+        asks_p = update_asks(depth, asks_p)
+
+    def _update_depth_snapshot(self, depth):
+        self._detail_ob = depth
+
+    def _update_order(self, msg):
+        '''将ws收到的订单信息触发quant'''
+        msg = ujson.loads(msg)
+        self.logger.debug(f"ws订单推送 {msg}")
+        order = msg['data'][0]
+        if order['instId'] == self.symbol:
+            order_event = dict()
+            status = order['state']
+            if status in ["live", 'partially_filled']:
+                local_status = 'NEW'
+            elif status in ['canceled', 'filled']:
+                local_status = 'REMOVE'
+            else:
+                print(f'未知订单状态 {order}')
+                return
+            order_event['status'] = local_status
+            order_event['filled_price'] = float(order['avgPx'])
+            order_event['filled'] = float(order['accFillSz'])*self.ctVal  # usdt永续需要考虑每张的单位
+            order_event['client_id'] = order['clOrdId']
+            order_event['order_id'] = order['ordId']
+            if order['feeCcy'] == 'USDT':
+                order_event['fee'] = -float(order['fee'])
+            self.callback['onOrder'](order_event)
+            # print(order_event)
+        self.private_update_time = time.time()
+
+    def _update_balance_position(self, msg):
+        """"""
+        msg = ujson.loads(msg)
+        msg = msg['data'][0]
+        # accounts = msg['balData']
+        # self._update_account(accounts)
+        positions = msg['posData']
+        self._update_position(positions)
+
+    def _update_position(self, positions):
+        long_pos, short_pos = 0, 0
+        long_avg, short_avg = 0, 0
+        is_update = 0
+        for i in positions:
+            if i['instId'] == self.symbol:
+                is_update = 1
+                if i['posSide'] == 'long':
+                    long_pos += abs(float(i['pos'])*self.ctVal)
+                    long_avg = abs(float(i['avgPx']))
+                elif i['posSide'] == 'short':
+                    short_pos += abs(float(i['pos'])*self.ctVal)
+                    short_avg = abs(float(i['avgPx']))
+        if is_update:
+            pos = model.Position()
+            pos.longPos = long_pos
+            pos.longAvg = long_avg
+            pos.shortPos = short_pos
+            pos.shortAvg = short_avg
+            self.callback['onPosition'](pos)
+            #print(f'{self.symbol} {long_pos} {long_avg} {short_pos} {short_avg}')
+        self.private_update_time = time.time()
+
+    def _update_account(self, accounts):
+        accounts = ujson.loads(accounts)
+        for data in accounts['data']:
+            for i in data['details']:
+                if i['ccy'] == self.quote:
+                    self.data['equity'] = float(i['eq'])
+                    self.callback['onEquity']({self.quote:self.data['equity']})
+        self.private_update_time = time.time()
+
+    def _update_price_limit(self, accounts):
+        accounts = ujson.loads(accounts)
+        buy_limit = float(accounts['data'][0]['buyLmt'])
+        sell_limit = float(accounts['data'][0]['sellLmt'])
+        if self.ticker_info['bp'] > 0 and self.ticker_info['ap'] > 0:
+            mp = (self.ticker_info['bp']+self.ticker_info['ap'])*0.5
+            upper = buy_limit * 0.99
+            lower = sell_limit * 1.01
+            if mp > upper or mp < lower:
+                self.callback['onExit'](f"{self.name} 触发限价警告 准备停机")
+        self.private_update_time = time.time()
+
+    async def on_message_private(self, _ws, msg):
+        """"""
+        if "data" in msg:
+            # 推送数据时,有data字段,优先级也最高
+            if "orders" in msg:
+                self._update_order(msg)
+            elif "balance_and_position" in msg:
+                self._update_balance_position(msg)
+            elif "account" in msg:
+                self._update_account(msg)
+        elif "event" in msg:
+            # event常见于事件回报,一般都可以忽略,只需要看看是否有error
+            if "error" in msg:
+                info = f'{self.name} on_message error! --> {msg}'
+                print(info)
+                self.logger.error(info)
+        elif 'ping' in msg:
+            await _ws.send_str('pong')
+        else:
+            print(msg)
+
+    async def on_message_public(self, _ws, msg):
+        """"""
+        #print(msg)
+        if "data" in msg:
+            # 推送数据时,有data字段,优先级也最高
+            if "tickers" in msg:
+                self._update_ticker(msg)
+            elif "trades" in msg:
+                self._update_trade(msg)
+            elif "books" in msg:
+                await self._update_depth(_ws, msg)
+            elif "price-limit" in msg:
+                self._update_price_limit(msg)
+        elif "event" in msg:
+            # event常见于事件回报,一般都可以忽略,只需要看看是否有error
+            if "error" in msg:
+                info = f'{self.name} on_message error! --> {msg}'
+                print(info)
+                self.logger.error(info)
+        elif 'ping' in msg:
+            await _ws.send_str('pong')
+        else:
+            print(msg)
+
+    async def run(self, is_auth=0, sub_trade=0, sub_fast=0):
+        # update sub info
+        self.sub_fast = sub_fast
+        self.sub_trade = sub_trade
+        # exchange info
+        info = await self.get_instruments()
+        self.update_instruments(info)
+        print(f"{self.name} public 更新产品信息 ws {self.symbol} {self.ctVal} {self.ctMult}")
+        # run
+        asyncio.create_task(self.run_public())
+        if is_auth:
+            asyncio.create_task(self.run_private())
+        while True:
+            await asyncio.sleep(5)
+
+    # ------------------------------------------------------
+    # -- utils ---------------------------------------------
+
+    @staticmethod
+    def compare_checksum(ob, depth):
+        """计算深度的校验和"""
+        #t1 = time.time()
+        # 降低校验频率
+        if random.randint(0,10) == 0:
+            cm = check(ob["bids"], ob['asks'])
+            #t2 = time.time()
+            #print(cm, depth['checksum'], (t2-t1)*1000)
+            return cm==depth['checksum']
+        else:
+            return 1
+
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    async def http_get_request(self, session, method, query=None, **args):
+        """"""
+        params = dict()
+        params.update(**args)
+        if query is not None: params.update(query)
+        timestamp = self.timestamp()
+        if params:
+            path = '{method}?{params}'.format(method=method, params=urllib.parse.urlencode(params))
+        else:
+            path = method
+        sign = self.__sign(timestamp, 'GET', path, None)
+        self.getheader['OK-ACCESS-SIGN'] = sign
+        self.getheader['OK-ACCESS-TIMESTAMP'] = timestamp
+        url = f"{self.REST}{path}"
+        rst = await session.get(
+            url,
+            headers=self.getheader,
+            timeout=5,
+            proxy=self.proxy
+            )
+        return rst
+
+    async def http_get_request_pub(self, session, method, query=None, **args):
+        """"""
+        params = dict()
+        params.update(**args)
+        if query is not None: params.update(query)
+        timestamp = self.timestamp()
+        if params:
+            path = '{method}?{params}'.format(method=method, params=urllib.parse.urlencode(params))
+        else:
+            path = method
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        url = f"{self.REST}{path}"
+        rst = await session.get(
+            url,
+            headers=headers,
+            timeout=5,
+            proxy=self.proxy
+            )
+        return rst
+
+    def timestamp(self):
+        return datetime.utcnow().isoformat("T")[:-3] + 'Z'
+
+    def __sign(self, timestamp, method, path, data):
+        if data:
+            message = timestamp + method + path + data
+        else:
+            message = timestamp + method + path
+        digest = hmac.new(bytes(self.params.secret_key.encode('utf8')), bytes(message.encode('utf8')), digestmod=hashlib.sha256).digest()
+        return base64.b64encode(digest).decode('utf-8')
+
+    def login_params(self):
+        """生成login字符"""
+        timestamp = str(time.time())
+        message = timestamp + 'GET' + '/users/self/verify'
+        mac = hmac.new(bytes(self.params.secret_key.encode('utf8')), bytes(message.encode('utf8')), digestmod=hashlib.sha256).digest()
+        sign = base64.b64encode(mac)
+        login_dict = {}
+        login_dict['apiKey'] = self.params.access_key
+        login_dict['passphrase'] = self.params.pass_key
+        login_dict['timestamp'] = timestamp
+        login_dict['sign'] =  sign.decode('utf-8')
+        login_param = {'op': 'login', 'args': [login_dict]}
+        login_str = ujson.dumps(login_param)
+        return login_str
+
+    def make_header(self):
+        """"""
+        headers = {}
+        headers['Content-Type'] = 'application/json'
+        headers['OK-ACCESS-KEY'] = self.params.access_key
+        headers['OK-ACCESS-SIGN'] = None
+        headers['OK-ACCESS-TIMESTAMP'] = None
+        headers['OK-ACCESS-PASSPHRASE'] = self.params.pass_key
+        return headers
+

+ 18 - 0
exchange/readme.md

@@ -0,0 +1,18 @@
+# 接口说明
+
+# rest接口
+    1.account
+    2.position
+    3.order list
+    4.cancel order
+    5.create order
+    6.position side
+    7.exchange info
+
+# ws接口
+    1.bbo/ticker
+    2.depth
+    3.trade
+    4.account
+    5.position
+    6.order

+ 520 - 0
exchange/utils.py

@@ -0,0 +1,520 @@
+import json
+import traceback
+import utils
+import model
+import toml, time, random
+import os, sys, asyncio, aiohttp
+import socket
+import asyncio
+import requests
+import ujson
+from decimal import Decimal
+from decimal import ROUND_HALF_UP, ROUND_FLOOR
+import gzip
+import csv
+import os
+import base64
+from Crypto.Cipher import AES
+from Crypto import Random
+import os
+import base64
+import json
+
+parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 
+sys.path.insert(0,parentdir)  
+
+############### 全局配置
+VERSION = "2022-04-18"
+CHILD_RUN_SECOND = 60 * 60 * 24 # child process max run time per loop 
+EARLY_STOP_SECOND = 60 * 60 * 2 # child early stop min check time
+BACKTEST_PREHOT_SECOND = 60 * 30 # backtest pre hot time
+DUMMY_RUN_SECOND = 60 * 60 * 12 # dummy process max run time per loop 
+DUMMY_EARLY_STOP_SECOND = 60 * 60 # dummy process max run time per loop 
+POST_SIDE_LIMIT = [0] # post side limit
+MARKET_DELAY_LIMIT = 30000 # market update delay limit threhold unit:ms
+GRID = 1
+STOPLOSS = 0.02
+GAMMA = 0.999
+###### market行情数据长度 标准化n档深度+6档成交信息 ######
+LEVEL = 1
+TRADE_LEN = 2 # 最高 最低 成交价
+LEN = LEVEL * 4 + TRADE_LEN # 总长度
+BP_INDEX = LEVEL * 0
+BQ_INDEX = LEVEL * 0 + 1
+AP_INDEX = LEVEL * 2
+AQ_INDEX = LEVEL * 2 + 1
+MAX_FILL_INDEX = LEVEL * 4 + 0
+MIN_FILL_INDEX = LEVEL * 4 + 1
+# BUY_Q_INDEX = LEVEL * 4 + 2
+# BUY_V_INDEX = LEVEL * 4 + 3
+# SELL_Q_INDEX = LEVEL * 4 + 4
+# SELL_V_INDEX = LEVEL * 4 + 5
+#### depth/trade effient range #####
+EFF_RANGE = 0.001
+### init backtest delay ###
+BACKTEST_DELAY = 0.15
+
+global base_cid
+base_cid = 0
+def get_cid(broker=None):
+    global base_cid
+    base_cid += 1
+    if base_cid > 999:
+        base_cid=0
+    cid = str(time.time())[4:10]+str(random.randint(1,999))+str(base_cid)
+    if broker:
+        cid = broker + cid
+    return cid      
+
+def csv_to_gz_and_remove():
+    def List_files(filepath, substr):
+        X = []
+        Y = []
+        for path, subdirs, files in sorted(os.walk(filepath), reverse=True):
+            for name in files:
+                X.append(os.path.join(path, name))
+        Y = [line for line in X if substr in line]
+        return Y
+
+    for file in List_files('./', '.csv'):
+        if '.gz' not in file:
+            data = open(file, 'rb' ).read()
+            with gzip.open(file + '.gz', 'a') as zip:
+                zip.write(data)
+                zip.close()
+            os.remove(file)
+
+def get_params(fname):
+    # 读取配置
+    try:
+        params = toml.load(fname)
+    except:
+        f = open(fname)
+        data = f.read()
+        text = base64.b64decode(data)
+        cryptor = AES.new(key =bytes("qFHFPv6MugrSTkEsWFs8wCDg3iC6!er%".encode()), mode=AES.MODE_ECB)
+        plain_text = cryptor.decrypt(text)
+        paddingLen = plain_text[len(plain_text)-1]
+        msg = plain_text[0:-paddingLen]
+        msg = msg.decode()
+        params = toml.loads(msg)
+    p = model.Config()
+    # 账号昵称
+    p.account_name = params['account_name'] if 'account_name' in params else 'Unknown Account'
+    # api
+    p.access_key = params['access_key'].replace(" ", "") if 'access_key' in params else '***'
+    p.secret_key = params['secret_key'].replace(" ", "") if 'secret_key' in params else '***'
+    p.pass_key = params['pass_key'].replace(" ", "") if 'pass_key' in params else 'qwer1234'
+    # 经纪商id
+    broker_id_from_config = params['broker_id'] if 'broker_id' in params else ""
+    p.broker_id = get_broker_id( broker_id_from_config, params['exchange'])
+    # 交易盘口
+    p.exchange = params['exchange'] if 'exchange' in params else ""
+    # 交易品种
+    p.pair = params['pair'] if 'pair' in params else ""
+    # 调试模式开关
+    p.debug = params['debug'] if 'debug' in params else "False"
+    # 开仓
+    p.open = params['open'] if 'open' in params else "0.002"
+    # 平仓
+    p.close = params['close'] if 'close' in params else "0.0002"
+    # 监听端口
+    p.server_port = params['server_port'] if 'server_port' in params else 6000
+    # 杠杆大小
+    p.leverrate = float(params['leverrate']) if 'leverrate' in params else 1.0
+    # 参考盘口
+    p.refexchange = params['refexchange'].replace('[','').replace(']','').replace("'",'').replace(" ", "").split(',') if "refexchange" in params else ""
+    # 参考品种
+    p.refpair = params['refpair'].replace('[','').replace(']','').replace("'",'').replace(" ", "").split(',') if "refpair" in params else ""
+    # 网络代理
+    p.proxy = params['proxy'] if 'proxy' in params else None # 仅在win下有效
+    # 账户资金使用比例
+    p.used_pct = params['used_pct'] if 'used_pct' in params else "0.9"
+    # discord播报地址
+    p.webhook = params['webhook'] if 'webhook' in params else "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"  
+    # 默认第n参考盘口
+    p.index = params['index'] if 'index' in params else 0
+    # 止损比例 0.02 = 2%
+    p.stoploss = params['stoploss'] if 'stoploss' in params else STOPLOSS
+    # 平滑系数 默认0.999
+    p.gamma = params['gamma'] if 'gamma' in params else GAMMA
+    # 分批建仓功能 小资金建议1 大资金建议3
+    p.grid = params['grid'] if 'grid' in params else GRID
+    # 实时调参开关 会有巨大性能损耗
+    p.backtest = params['backtest'] if 'backtest' in params else 1
+    # 保存实时行情 会有巨大性能损耗
+    p.save = params['save'] if 'save' in params else 0
+    p.place_order_limit = params['place_order_limit'] if 'place_order_limit' in params else 0 # 允许的每秒下单次数
+    # 是否启用colocation技术
+    p.colo = params['colo'] if 'colo' in params else 0 
+    # 是否启用fast行情 会增加性能开销
+    p.fast = params['fast'] if 'fast' in params else 1 
+    # 选择指定的私有ip进行网络通信 默认0 用于多网卡多ip的实例
+    p.ip = params['ip'] if 'ip' in params else 0
+    # 合约不允许holdcoin持有底仓币
+    if "spot" in p.exchange:
+        p.hold_coin = params['hold_coin'] if 'hold_coin' in params else 0.0
+    else:
+        p.hold_coin = 0.0
+    # 是否开启日志记录 会有一定性能损耗
+    p.log = params['log'] if 'log' in params else 1
+    #### 特殊情况处理
+    if p.exchange == 'binance_usdt_swap':
+        if p.pair in ['shib_usdt', 'xec_usdt', 'bttc_usdt']:
+            p.pair = "1000" + p.pair
+    ref_num = len(p.refexchange)
+    for i in range(ref_num):
+        if p.refexchange[i] == 'binance_usdt_swap':
+            if p.refpair[i] in ['shib_usdt', 'xec_usdt', 'bttc_usdt']:
+                p.refpair[i] = "1000" + p.refpair[i]
+    ####
+    return p
+
+def get_broker_id(broker_id , exchange_name):
+    '''处理brokerid特殊情况'''
+    if 'binance' in exchange_name:
+        return broker_id
+    elif 'gate' in exchange_name:
+        return "t-"
+    else:
+        return ""
+
+# 报单频率限制等级
+BASIC_LIMIT = 100
+GATE_SPOT_LIMIT = 10.0
+GATE_USDT_SWAP_LIMIT = 100.0
+KUCOIN_SPOT_LIMIT = 15.0
+KUCOIN_USDT_SWAP_LIMIT = 10.0
+BINANCE_USDT_SWAP_LIMIT = 5.0
+BINANCE_SPOT_LIMIT = 2.0
+COINEX_SPOT_LIMIT = 40.0
+COINEX_USDT_SWAP_LIMIT = 100.0
+OKEX_USDT_SWAP_LIMIT= 30.0
+BITGET_USDT_SWAP_LIMIT = 10.0
+BYBIT_USDT_SWAP_LIMIT = 1.0
+RATIO = 4.0
+
+def get_limit_requests_num_per_second(exchange, limit=0):
+    '''每秒请求频率'''
+    if limit != 0:
+        return limit*RATIO
+    elif exchange == "gate_spot":
+        return GATE_SPOT_LIMIT*RATIO
+    elif exchange == "gate_usdt_swap": # 100/s
+        return GATE_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "kucoin_spot": # 15/s
+        return KUCOIN_SPOT_LIMIT*RATIO
+    elif exchange == "kucoin_usdt_swap":
+        return KUCOIN_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "binance_usdt_swap":
+        return BINANCE_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "binance_spot":
+        return BINANCE_SPOT_LIMIT*RATIO
+    elif exchange == "coinex_spot":
+        return COINEX_SPOT_LIMIT*RATIO
+    elif exchange == "coinex_usdt_swap":
+        return COINEX_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "okex_usdt_swap":
+        return OKEX_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "bitget_usdt_swap":
+        return BITGET_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "bybit_usdt_swap":
+        return BYBIT_USDT_SWAP_LIMIT*RATIO
+    else:
+        print("限频规则未找到")
+    return BASIC_LIMIT*RATIO
+
+
+def get_limit_order_requests_num_per_second(exchange, limit=0):
+    '''每秒下单请求频率'''
+    if limit != 0:
+        return limit
+    elif exchange == "gate_spot": # 10/s
+        return GATE_SPOT_LIMIT
+    elif exchange == "gate_usdt_swap": # 100/s
+        return GATE_USDT_SWAP_LIMIT
+    elif exchange == "kucoin_spot": # 15/s
+        return KUCOIN_SPOT_LIMIT
+    elif exchange == "kucoin_usdt_swap": # 10/s
+        return KUCOIN_USDT_SWAP_LIMIT
+    elif exchange == "binance_usdt_swap": # 5/s
+        return BINANCE_USDT_SWAP_LIMIT
+    elif exchange == "binance_spot": # 2/s
+        return BINANCE_SPOT_LIMIT
+    elif exchange == "coinex_spot": # 40/s
+        return COINEX_SPOT_LIMIT
+    elif exchange == "coinex_usdt_swap": # 100/s
+        return COINEX_USDT_SWAP_LIMIT
+    elif exchange == "okex_usdt_swap": # 30/s
+        return OKEX_USDT_SWAP_LIMIT
+    elif exchange == "bitget_usdt_swap": # 10/s
+        return BITGET_USDT_SWAP_LIMIT
+    elif exchange == "bybit_usdt_swap": # 2/s
+        return BYBIT_USDT_SWAP_LIMIT
+    else:
+        print("限频规则未找到")
+    return BASIC_LIMIT
+
+def dist_to_weight(price, mp, eff_range=EFF_RANGE):
+    '''
+        距离转换为权重
+    '''
+    dist = abs(price-mp)/mp
+    weight = 1 - clip(dist/eff_range, 0.0, 0.95)
+    weight = weight if weight > 0 else 0
+    return weight
+
+def change_params(fname, params, changes):
+    # 更改配置
+    for i in changes:
+        params[i[0]] = i[1]
+    with open(f"{fname}","w") as f:
+        toml.dump(params,f)
+
+def show_memory(unit='B', threshold=1024):
+    '''查看变量占用内存情况
+
+    :param unit: 显示的单位,可为`B`,`KB`,`MB`,`GB`
+    :param threshold: 仅显示内存数值大于等于threshold的变量
+    '''
+    from sys import getsizeof
+    scale = {'B': 1, 'KB': 1024, 'MB': 1048576, 'GB': 1073741824}[unit]
+    msg = '内存占用情况: \n'
+    for i in list(globals().keys()):
+        memory = eval("getsizeof({})".format(i)) // scale
+        if memory >= threshold:
+            msg += f'{i} {memory} {unit}\n'
+    print(msg)
+    return msg
+
+def clip(num, _min, _max):
+    if num > _max: num = _max
+    if num < _min: num = _min
+    return num
+
+async def ding(msg, at_all, webhook, proxy=None):
+    '''
+        发送钉钉消息
+    '''
+    header = {
+        "Content-Type": "application/json",
+        "Charset": "UTF-8"
+    }
+    embed = {
+        "title": "策略通知",
+        "description": msg
+    }
+    message = {
+    "content": "大吉大利 今晚吃鸡",
+    "username": "千千喵",
+    "embeds": [
+        embed
+            ],
+    }
+    message_json = json.dumps(message)
+    if 'win' in sys.platform:
+        proxy = proxy
+    else:
+        proxy = None
+    async with aiohttp.ClientSession() as session:
+        await session.post(url=webhook, data=message_json, headers=header, proxy=proxy, timeout = 10)
+
+def _get_params(url, proxy, params):
+    '''更新参数'''
+    import requests
+    try:
+        res = requests.post(url=url, json=params, timeout = 10)
+        return json.loads(res.text)
+    except:
+        traceback.print_exc()
+        return []
+
+async def _post_params(url, proxy, params):
+    '''更新参数'''
+    try:
+        if 'win' in sys.platform:
+            proxy = proxy
+        else:
+            proxy = None
+        async with aiohttp.ClientSession() as session:
+            res = await session.post(url=url, proxy=proxy, data=params, timeout = 10)
+            data = await res.text()
+            print(data)
+            return data
+    except:
+        print(traceback.format_exc())
+        return "post_params error"
+    return None
+
+def get_ip():
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        s.connect(('8.8.8.8', 80))
+        ip = s.getsockname()[0]
+    finally:
+        s.close()
+    return ip
+
+def check_auth():
+    print("*** 检查使用权限1 ***")
+    ip = get_ip()
+    print(f"当前IP {ip}")
+    white_list = requests.get(f"http://158.247.204.56:7777/ip_list")
+    if ip in white_list:
+        print("当前IP位于白名单中")
+    else:
+        print("@@@ 本版本仅限指定IP白名单运行 @@@")
+        os._exit(0)
+    print("*** 符合要求 ***")
+
+def check_time():
+    print("*** 检查使用权限2 ***")
+    if time.time() > int(time.mktime(time.strptime('2021-11-17 00:00:00', "%Y-%m-%d %H:%M:%S"))):
+        print("@@@ 此版本目前已过试用期 @@@")
+        os._exit(0)
+    print("*** 符合要求 ***")
+
+def num_to_str(num, d):
+    if d >= 1.0:return "%d"%num
+    elif d in [0.1, 0.5]:return "%.1f"%num
+    elif d in [0.01, 0.05]:return "%.2f"%num
+    elif d in [0.001, 0.005]:return "%.3f"%num
+    elif d in [0.0001, 0.0005]:return "%.4f"%num
+    elif d in [0.00001, 0.00005]:return "%.5f"%num
+    elif d in [0.000001, 0.000005]:return "%.6f"%num
+    elif d in [0.0000001, 0.0000005]:return "%.7f"%num
+    elif d in [0.00000001, 0.00000005]:return "%.8f"%num
+    elif d in [0.000000001, 0.000000005]:return "%.9f"%num
+    elif d in [0.0000000001, 0.0000000005]:return "%.10f"%num
+    else: return str(num)
+
+def num_to_decimal(num):
+    '''根据小数点位数获取精度'''
+    num = str(num)
+    if '.' not in num:return 0
+    elif '.' == num[-2]:return 1
+    elif '.' == num[-3]:return 2
+    elif '.' == num[-4]:return 3
+    elif '.' == num[-5]:return 4
+    elif '.' == num[-6]:return 5
+    elif '.' == num[-7]:return 6
+    elif '.' == num[-8]:return 7
+    elif '.' == num[-9]:return 8
+    elif '.' == num[-10]:return 9
+    elif '.' == num[-11]:return 10
+    else:return 11
+
+def fix_amount(amount, stepSize):
+    '''修补数量向下取整'''
+    return float(
+            Decimal(str(amount))//Decimal(str(stepSize)
+        ) \
+        * Decimal(str(stepSize)))
+    # return float(Decimal(str(amount)).quantize(Decimal(str(stepSize)), ROUND_FLOOR))
+
+
+def fix_price(price, tickSize):
+    '''修补价格四舍五入'''
+    return float(
+            round(Decimal(str(price))/Decimal(str(tickSize))
+        ) \
+        * Decimal(str(tickSize)))
+    # return float(Decimal(str(price)).quantize(Decimal(str(tickSize)), ROUND_HALF_UP))
+
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        nowTime = time.time()
+        res = func(*args, **kwargs)
+        spend_time = time.time() - nowTime
+        spend_time = round(spend_time * 1e6, 3)
+        print(f'{func.__name__} 耗时 {spend_time} us')
+        return res
+    return wrapper
+
+def get_backtest_set(base=""):
+    '''生成预设参数'''
+    # 开仓距离不能太近必须超过大部分价格tick运动的距离
+    open_list = [
+        0.0055,
+        0.0045,
+        0.0035,
+        0.0030,
+        0.0025,
+        0.0020,
+        0.0015,
+    ]
+    close_dict = dict()
+    for open in open_list:
+        close_dict[open] = [
+            open*0.1,
+            open*0.2,
+            ]
+    alpha_list = [0.0]
+    return open_list, close_dict, alpha_list
+
+def get_local_ip_list():
+    '''获取本地ip'''
+    import netifaces as ni
+    ipList = []
+    # print('检测服务器网络配置')
+    for dev in ni.interfaces():
+        print('dev:',dev)
+        if 'ens' in dev or 'eth' in dev or 'enp' in dev:
+            # print(ni.ifaddresses(dev))
+            for i in ni.ifaddresses(dev)[2]:
+                ip=i['addr']
+                print(f"检测到私有ip:{ip}")
+                if ip not in ipList:
+                    ipList.append(ip)
+    print(f"当前服务器私有ip为{ipList}")
+    return ipList
+    
+if __name__ == "__main__":
+
+    #########
+    if 0:
+        print(fix_amount(1.0, 0.1))
+        print(fix_amount(0.9, 0.05))
+        print(fix_amount(1.1, 0.1))
+        print(fix_amount(1.2, 0.5))
+        print(fix_amount(0.01, 0.05))
+    if 1:
+        print(fix_price(1.0, 0.1))
+        print(fix_price(0.9, 2.0))
+        print(fix_price(1.1, 0.1))
+        print(fix_price(1.2, 0.5))
+        print(fix_price(4999.99, 0.5))
+    #########
+    if 0:
+        # print(num_to_str(123.123))
+        print(get_backtest_set())
+
+    ####################
+    if 0:
+
+        p = get_params("config.toml")
+
+        loop = asyncio.get_event_loop()
+        
+        # loop.create_task(ding("123", 1, "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"))
+        
+        loop.create_task(
+            _post_params(
+                "http://wwww.khods.com:8888/post_params", 
+                None,
+                ujson.dumps({
+                    "exchange":"binance_usdt_swap",
+                    "pair":"eth_usdt",
+                    "open":"0.001",
+                    "close":"0.0001",
+                    "refexchange":"binance_spot",
+                    "refpair":"eth_usdt",
+                    "profit":0.1,
+                })
+            )
+        )
+
+        loop.run_forever()
+    
+    ####################
+

+ 99 - 0
model.py

@@ -0,0 +1,99 @@
+import utils
+
+class BacktestFee:
+    
+    def __init__(self, msg=None):
+        if msg == "v9":
+            self.maker = -0.00001
+            self.taker =  0.0002
+        elif msg == "v0":
+            self.maker = 0.0001
+            self.taker = 0.0005
+        else:
+            self.maker = 0
+            self.taker = 0
+
+class ExchangeInfo:
+    def __init__(self) -> None:
+        self.symbol = None
+        self.tickSize = None
+        self.stepSize = None
+        self.multiplier = None
+
+class Order:
+    def __init__(self):
+        self.symbol = None
+        self.order_id = None
+        self.amount = None
+        self.side = None
+        self.price = None
+
+class Position():
+
+    def __init__(self):
+        self.longPos = 0
+        self.shortPos = 0
+        self.longAvg = 0
+        self.shortAvg = 0
+
+class TraderMsg:
+
+    def __init__(self):
+        self.position = Position()
+        self.cash = 0.0
+        self.coin = 0.0
+        self.orders = dict()
+        self.ref_price = None
+        self.market = []
+        self.predict = 0.0
+
+class ClientParams:
+
+    def __init__(self):
+        self.name = None
+        self.pair = None
+        self.proxy = None
+        self.access_key = None
+        self.secret_key = None
+        self.pass_key = None
+        self.interval = None
+        self.broker_id = None
+        self.debug = None
+        self.ip = 0
+
+class Config:
+
+    def __init__(self):
+        self.broker_id = None
+        self.account_name = None
+        self.access_key = None
+        self.secret_key = None
+        self.pass_key = None
+        self.exchange = None
+        self.pair = None
+        self.debug = None
+        self.open = None
+        self.close = None
+        self.server_port = None
+        self.leverrate = None
+        self.interval = 0.1
+        self.close = None
+        self.open = None
+        self.refexchange = None
+        self.refpair = None
+        self.webhook = None
+        self.used_pct = None
+        self.place_order_limit = 0
+        # self.proxy = "http://127.0.0.1:4780" # 仅在win下有效
+        self.proxy = None # 仅在win下有效
+        self.index = 0
+        self.save = 0
+        self.hold_coin = 0.0
+        self.log = 0
+        self.stoploss = 0.05
+        self.gamma = 0.999
+        self.grid = 1
+        self.backtest = 0
+        self.colo = 0
+        self.fast = 1
+        self.ip = 0

+ 131 - 0
predictor.py

@@ -0,0 +1,131 @@
+import time
+import utils
+import numpy as np
+
+class Predictor:
+    '''
+        reference
+    '''
+    def __init__(self, ref_name = ["Unknown Market"], alpha = [1.0 for _ in range(99)], gamma = 0.999):
+        self.loop = 0
+        self.arr = []
+        self.trade_mp_series = []
+        self.ref_mp_series = []
+        self.ref_num = len(ref_name)
+        ### 定价
+        self.window = 10
+        ### 价格系数
+        self.alpha = alpha
+        # 参考
+        print('定价系数:', gamma)
+        self.gamma = gamma
+        self.avg_spread = [None for _ in range(self.ref_num)]
+    
+    def processer(self):
+        '''
+            计算任务
+        '''
+        # update trade mp
+        bp=self.arr[-1][utils.BP_INDEX]
+        ap=self.arr[-1][utils.AP_INDEX]
+        mp=(bp+ap)*0.5
+        self.trade_mp_series.append(mp)
+        # 更新参考盘口mp
+        ref_mp = []
+        for i in range(self.ref_num):
+            bp = self.arr[-1][utils.LEN*(1+i)+utils.BP_INDEX]
+            ap = self.arr[-1][utils.LEN*(1+i)+utils.AP_INDEX]
+            mp=(bp+ap)*0.5
+            ref_mp.append(mp)
+        self.ref_mp_series.append(ref_mp)
+        # 偏差计算
+        self._update_avg_spread()
+
+    def _update_avg_spread(self):
+        '''
+            更新平均偏差
+        '''
+        # 计算偏差1
+        for i in range(self.ref_num):
+            bias = self.ref_mp_series[-1][i]*self.alpha[i] - self.trade_mp_series[-1]
+            # 如果是刚启动 gamma不能太大
+            if self.loop < 100:
+                gamma = 0.9
+            else:
+                gamma = self.gamma
+            if self.avg_spread[i] == None:
+                self.avg_spread[i] = bias
+            else:
+                self.avg_spread[i] = self.avg_spread[i]*gamma + bias*(1-gamma)
+
+    def check_length(self):
+        # 行情缓存
+        if len(self.arr) > self.window:del(self.arr[0])
+        if len(self.trade_mp_series) > self.window:del(self.trade_mp_series[0])
+        if len(self.ref_mp_series) > self.window:del(self.ref_mp_series[0])
+
+    # @utils.timeit
+    def onTime(self, data):
+        if isinstance(data, list):
+            if len(data) > 0:
+                self.loop += 1
+                self.arr.append(data)
+                self.processer()
+                self.check_length()
+            else:
+                print("行情数据为空")
+        else:
+            print("行情数据为None")
+
+    # @utils.timeit
+    def Get_ref(self, ref_ticker):
+        '''
+            get ref price
+        '''
+        ref_mid = []
+        for i in range(self.ref_num):
+            ref_mid.append(
+                [ref_ticker[i][0]*self.alpha[i] - self.avg_spread[i], ref_ticker[i][1]*self.alpha[i] - self.avg_spread[i]]
+            )
+        return ref_mid
+
+if __name__ == "__main__":
+
+    import pandas as pd
+    import numpy as np
+    import matplotlib
+    matplotlib.use('TkAgg')
+    import matplotlib.pyplot as plt
+
+    arr = pd.read_csv('history/ftm_usdt_binance_usdt_swap.csv').values.tolist()
+
+    def line_data_to_tickers(data, ref_num):
+        ref_tickers = []
+        for i in range(ref_num):
+            ref_tickers.append([data[utils.LEN*(i+1)+utils.BP_INDEX], data[utils.LEN*(i+1)+utils.AP_INDEX]])
+        return ref_tickers
+
+    ref_num = len(arr[0])//utils.LEN - 1
+
+    p = Predictor(ref_name=["unkwon" for _ in range(ref_num)])
+    t = []
+    ref_index = 1
+
+    for data in arr:
+        p.onTime(data)
+        trade_mp = (data[utils.BP_INDEX] + data[utils.AP_INDEX])*0.5
+        ref_price = p.Get_ref(line_data_to_tickers(data, ref_num))
+        t.append([
+            trade_mp,
+            (ref_price[ref_index][0]+ref_price[ref_index][1])*0.5,
+        ])
+
+    t = pd.DataFrame(t,columns=['mp','ref'])
+
+    if 1:
+        plt.figure()
+        plt.plot(t['mp'],'k')
+        plt.plot(t['ref'],'g')
+        plt.grid()
+        plt.show()
+

+ 1473 - 0
quant.py

@@ -0,0 +1,1473 @@
+import asyncio
+from aiohttp import web
+import traceback
+import time, csv
+import strategy as strategy
+import utils
+import model
+import logging, logging.handlers
+import signal
+import os, json, sys
+import predictor
+import backtest
+import multiprocessing
+import random
+import psutil
+import ujson
+import broker
+from decimal import Decimal
+
+VERSION = utils.VERSION
+
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        nowTime = time.time()
+        res = func(*args, **kwargs)
+        spend_time = time.time() - nowTime
+        spend_time = round(spend_time * 1000, 5)
+        print(f'{func.__name__} 耗时 {spend_time} ms')
+        return res
+    return wrapper
+
+class Quant:
+
+    def __init__(self, params:model.Config, logname="test_logname", father=1):
+        print('###############   超级无敌韭菜收割机   ################')
+        print(f'>>> 版本 {VERSION} <<<')
+        print('*** 当前配置')
+        self.params = params
+        for p in self.params.__dict__:
+            print('***', p, ' => ', getattr(self.params, p))
+        print('##################################################')
+        self.logger = self.get_logger(logname)
+        self.csvname = logname + ' ' + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 
+        pid = os.getpid()
+        self.pid_start_time = time.time()
+        self.logger.info(f"进程号{pid} 启动时间{self.pid_start_time}")
+        ##### 绑定cpu
+        cpu_count = psutil.cpu_count()
+        print("检测cpu核心负载")
+        cpu_used_pct = [0.0 for _ in range(cpu_count)]
+        for _ in range(random.randint(5,15)):
+            r = psutil.cpu_times_percent(percpu=True)
+            for i in range(cpu_count):
+                cpu_used_pct[i] += int(r[i].user)
+            time.sleep(1)
+            print(cpu_used_pct)
+        cpu_id = cpu_used_pct.index(min(cpu_used_pct))
+        print(f"当前负载最低的cpu为:{cpu_id}")
+        self.process = psutil.Process(pid)
+        print(f"核心数{cpu_count} 目标绑定cpu:{cpu_id}")
+        os.system(f"taskset -cp {cpu_id} {pid}")
+        print("调整系统调度优先级为最高等级")
+        if 'win' not in sys.platform:
+            print(os.nice(-20))
+        #### cpu 内存 平均占用
+        self.cpu_ema = 0.0
+        self.mm_ema = 0.0
+        #####
+        self.acct_name = self.params.account_name
+        self.symbol = self.params.pair
+        self.base = self.params.pair.split('_')[0].upper()
+        self.quote = self.params.pair.split('_')[1].upper()
+        if 1:
+            ### 使用uvloop
+            if 'win' not in sys.platform:
+                print('采用高速事件循环库')
+                import uvloop
+                self.loop = uvloop.new_event_loop()
+            else:
+                print('采用普通事件循环库')
+                self.loop = asyncio.get_event_loop()
+        else:
+            ### 使用原生loop
+            self.loop = asyncio.get_event_loop()
+        self.strategy = strategy.Strategy(self.params, is_print=1)
+        ######  判断启动方式
+        self.father = father
+        print(f"父进程标识 {self.father}")
+        ##### 现货底仓
+        hold_coin = float(self.params.hold_coin)
+        self.hold_coin = utils.clip(hold_coin, 0.0, 10000.0)
+        ##### 本地状态量
+        self.data = dict()
+        self.total_equity = 0.0
+        self.local_orders = dict() # 本地挂单表
+        self.local_orders_backup = dict() # 本地订单缓存队列
+        self.local_orders_backup_cid = [] # 本地订单缓存cid队列
+        self.handled_orders_cid = [] # 本地已处理cid缓存队列
+        self.local_profit = 0.0
+        self.local_cash = 0.0 # 本地U保证金
+        self.local_coin = 0.0 # 本地币保证金
+        self.local_position = model.Position()
+        self.local_position_by_orders = model.Position()
+        self.local_buy_amount = 0.0
+        self.local_sell_amount = 0.0
+        self.local_buy_value = 0.0
+        self.local_sell_value = 0.0
+        self.local_cancel_log = dict()
+        self.interval = float(self.params.interval)
+        self.exchange = self.params.exchange
+        self.tradeMsg = model.TraderMsg()
+        self.exit_msg = "正常退出"
+        self.save = int(self.params.save) # 保存行情数据
+        self.logger.info(f"实时行情数据记录开关:{self.save}")
+        # 仓位检查结果序列
+        self.position_check_series = []
+        # 止损大小
+        self.stoploss = float(self.params.stoploss)
+        # 资金使用率
+        self.used_pct = float(self.params.used_pct) #使用资金比例
+        # 启停信号 0 表示运行 大于1开始倒计时 1时停机
+        self.mode_signal = 0 
+        # 交易盘口订单流更新时间
+        self.trade_order_update_time = time.time()
+        # onTick触发时间记录
+        self.on_tick_event_time = time.time()
+        # 盘口ticker depth信息
+        self.tickers = dict()
+        self.depths = dict()
+        # 行情更新延迟监控
+        self.market_update_time = dict()
+        self.market_update_interval = dict()
+        # 参考盘口名称
+        refex = self.params.refexchange
+        refpair = self.params.refpair
+        if len(refex) != len(refpair):
+            self.logger.error("参考盘口数不等于参考品种数 退出")
+            raise Exception("参考盘口数不等于参考品种数 退出")
+        self.ref_num = len(refex)
+        self.ref_name = []
+        for i in range(self.ref_num):
+            if refex[i] not in broker.exchange_lists:
+                self.logger.error("出现不支持的参考盘口")
+                raise Exception("出现不支持的参考盘口")
+            name = refex[i] + '@' + refpair[i] + '@ref'
+            self.ref_name.append(name)
+            self.tickers[name] = dict()
+            self.depths[name] = list()
+            self.market_update_time[name] = 0.0
+            self.market_update_interval[name] = 0.0
+        # 参考盘口tick更新时间
+        # 服务器私有ip地址检查
+        ipList = utils.get_local_ip_list()
+        ipListNum = len(ipList)
+        if int(self.params.ip) >= ipListNum:
+            raise Exception("指定私有ip地址序号不存在")
+        # 创建ws实例
+        name = self.exchange+'@'+self.params.pair
+        self.trade_name = name
+        self.market_update_time[name] = 0.0
+        self.market_update_interval[name] = 0.0
+        self.tickers[name] = dict()
+        self.depths[name] = list()
+        cp = model.ClientParams()
+        cp.name = self.trade_name
+        cp.pair = self.params.pair
+        cp.access_key = self.params.access_key
+        cp.secret_key = self.params.secret_key
+        cp.pass_key = self.params.pass_key
+        cp.interval = self.params.interval
+        cp.broker_id = self.params.broker_id
+        cp.debug = self.params.debug
+        cp.proxy = self.params.proxy
+        cp.interval = self.params.interval
+        cp.ip = int(self.params.ip)
+        self.ws = broker.newWs(self.exchange)(
+            params=cp, 
+            colo=int(self.params.colo),
+            is_print=0,
+        )
+        self.ws.logger = self.logger
+        self.ready = 0
+        # rest实例
+        self.rest = broker.newRest(self.exchange)(cp, colo=int(self.params.colo))
+        self.ws_ref = dict()
+        # 参考盘口 ws 实例
+        for i in range(self.ref_num):
+            cp = model.ClientParams()
+            cp.name = self.ref_name[i]
+            cp.pair = self.params.refpair[i]
+            cp.proxy = self.params.proxy
+            cp.interval = self.params.interval
+            cp.ip = int(self.params.ip)
+            exchange = self.params.refexchange[i]
+            if exchange not in broker.exchange_lists:
+                self.logger.error("参考盘口名称错误 退出")
+                return
+            _colo = 0
+            if self.params.refexchange[i] == self.params.exchange and \
+                self.params.refpair[i] == self.params.pair and int(self.params.colo):
+                _colo = 1
+            self.ws_ref[self.ref_name[i]] = broker.newWs(exchange)(cp, colo=_colo)
+            self.ws_ref[self.ref_name[i]].callback['onTicker']=self.update_ticker
+            self.ws_ref[self.ref_name[i]].callback['onDepth']=self.update_depth
+            self.ws_ref[self.ref_name[i]].logger = self.logger
+        # 添加回调
+        self.ws.callback = {
+            'onTicker':self.update_ticker,
+            'onDepth':self.update_depth,
+            'onPosition':self.update_position,
+            'onEquity':self.update_equity,
+            'onOrder':self.update_order,
+            'onExit':self.update_exit,
+            }
+        self.rest.callback = {
+            'onTicker':self.update_ticker,
+            'onDepth':self.update_depth,
+            'onPosition':self.update_position,
+            'onEquity':self.update_equity,
+            'onOrder':self.update_order,
+            'onExit':self.update_exit,
+            }
+        self.rest.logger = self.logger
+        # 配置策略
+        self.strategy.logger = self.logger
+        # 配置定价模型
+        price_alpha = []
+        for i in self.params.refpair:
+            # 交易1000shib 参考 shib
+            if '1000' in self.params.pair and '1000' not in i:
+                price_alpha.append(1000.0)
+            # 交易shib 参考 1000shib
+            elif '1000' not in self.params.pair and '1000' in i:
+                price_alpha.append(0.001)
+            else:
+            # 交易shib 参考 shib
+                price_alpha.append(1.0)
+        self.logger.info(f'价格系数{price_alpha}')
+        self.Predictor = predictor.Predictor(ref_name=self.ref_name, alpha=price_alpha, gamma=float(self.params.gamma))
+        # 初始化参数
+        self.strategy.trade_open_dist = float(self.params.open)
+        self.strategy.trade_close_dist = float(self.params.close)
+        # 在线训练
+        self.backtest = int(self.params.backtest)
+        self.logger.info(f'在线训练开关 {self.backtest}')
+        ####
+        time.sleep(3)
+
+    def get_logger(self, logname):
+        '''日志模块'''
+        logger = logging.getLogger(__name__)
+        # log flag
+        if int(self.params.log):
+            log_level = logging.DEBUG
+            logger.setLevel(log_level)
+            # log to txt
+            formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+            if logname == None: logname = "log"
+            handler = logging.handlers.RotatingFileHandler(f"{logname}.log",maxBytes=1024*1024*10,encoding='utf-8')
+            handler.setLevel(log_level)
+            handler.setFormatter(formatter)
+            # log to console
+            console = logging.StreamHandler()
+            console.setLevel(logging.INFO)
+            # add
+            logger.addHandler(handler)
+            logger.addHandler(console)
+        else:
+            log_level = logging.INFO
+            logger.setLevel(log_level)
+            # log to txt
+            formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+            if logname == None: logname = "log"
+            handler = logging.handlers.RotatingFileHandler(f"{logname}.log",maxBytes=1024*1024*10,encoding='utf-8')
+            handler.setLevel(log_level)
+            handler.setFormatter(formatter)
+            # add
+            logger.addHandler(handler)
+        logger.info('开启日志记录')
+        return logger
+
+    def update_order(self, data):
+        self.loop.create_task(self._update_order(data))
+
+    async def _update_order(self, data):
+        '''
+            更新订单
+            首先直接复写本地订单
+            1、如果是开仓单
+                如果新增: 增加本地订单
+                如果取消: 删除本地订单 查看是否完全成交 如果是部分成交 则按已成交量发送平仓订单 修改本地仓位
+                如果成交: 删除本地订单 发送平仓订单 修改本地仓位
+            2、如果是平仓单
+                如果新增: 增加本地订单
+                如果取消: 删除本地订单 查看是否完全成交 如果是部分成交 则按未成交量发送平仓订单 修改本地仓位
+                如果成交: 删除本地订单 修改本地仓位
+            NEW 可以从 ws / rest 来
+            REMOVE 主要从 ws 来 必须包含 filled 和 filled_price 用于本地仓位推算 定期rest查过旧订单
+            为了防止下单失败依然有订单成交 本地需要做一个缓存
+        '''
+        try:
+            # 触发订单更新
+            self.trade_order_update_time = time.time()
+            # 新增订单推送 仅需要cid oid信息
+            if data['status'] == 'NEW':
+                # 更新oid信息 更新订单 loceltime信息(尤其是查单返回new的情况 必须更新 否则会误触发风控)
+                if data['client_id'] in self.local_orders:
+                    self.local_orders[data['client_id']]["order_id"] = data['order_id']
+                    self.local_orders[data['client_id']]["localtime"] = time.time()
+            # 完成订单推送 仅需要cid filled filled_size信息
+            elif data['status'] == 'REMOVE':
+                # 如果在撤单记录中 说明此订单结束生命周期 可以移除记录
+                if data["client_id"] in self.local_cancel_log:
+                    del(self.local_cancel_log[data["client_id"]])
+                # 在cid缓存队列中 说明是本策略的订单 
+                if data["client_id"] in self.local_orders_backup:
+                    # 不在已处理cid缓存队列中 说明还没参与过仓位计算 则执行订单计算
+                    if data['client_id'] not in self.handled_orders_cid:
+                        # 添加进已处理队列
+                        self.handled_orders_cid.append(data["client_id"])
+                        # 提取成交信息 方向 价格 量
+                        filled = data["filled"]
+                        side = self.local_orders_backup[data['client_id']]["side"]
+                        if "filled_price" in data:
+                            if data["filled_price"] > 0.0:
+                                filled_price = data["filled_price"]
+                            else: 
+                                filled_price = self.local_orders_backup[data['client_id']]["price"]
+                        else:
+                            filled_price = self.local_orders_backup[data['client_id']]["price"]
+                        # 只有开仓成交才触发onPosition
+                        # 如果漏推送 rest补充的订单查询信息过来 可能会导致 kd kk 推送出现计算分母为0的情况
+                        if filled > 0:
+                            if "spot" in self.exchange:# 如果是现货交易 还需要修改equity
+                                ### 现货必须考虑fee 买入fee单位为币 卖出fee单位为u
+                                fee = data["fee"]
+                                ### 现货订单流仓位计算
+                                if side == "kd": # buy
+                                    self.local_buy_amount += filled - fee
+                                    self.local_buy_value += (filled - fee) * filled_price
+                                    new_long_pos = float(Decimal(str(self.local_position_by_orders.longPos)) + Decimal(str(filled)) - Decimal(str(fee)))
+                                    if new_long_pos == 0.0:
+                                        self.local_position_by_orders.longAvg = 0.0
+                                        self.local_position_by_orders.longPos = 0.0
+                                    else:
+                                        self.local_position_by_orders.longAvg = \
+                                            (self.local_position_by_orders.longPos * self.local_position_by_orders.longAvg + filled * filled_price) / new_long_pos
+                                        self.local_position_by_orders.longPos = new_long_pos
+                                    self.local_cash -= filled * filled_price
+                                    self.local_coin += filled - fee
+                                elif side == "pd": # sell
+                                    self.local_sell_amount += filled
+                                    self.local_sell_value += filled * filled_price
+                                    self.local_profit += filled * (filled_price - self.local_position_by_orders.longAvg)
+                                    new_long_pos = float(Decimal(str(self.local_position_by_orders.longPos)) - Decimal(str(filled)))
+                                    if new_long_pos == 0.0:
+                                        self.local_position_by_orders.longAvg = 0.0
+                                        self.local_position_by_orders.longPos = 0.0
+                                    else:
+                                        self.local_position_by_orders.longPos = new_long_pos
+                                    self.local_cash += filled * filled_price - fee
+                                    self.local_coin -= filled
+                                elif side == "pk": # buy
+                                    self.local_buy_amount += filled - fee
+                                    self.local_buy_value += (filled - fee) * filled_price
+                                    self.local_profit += filled * (self.local_position_by_orders.shortAvg - filled_price)
+                                    new_short_pos = float(Decimal(str(self.local_position_by_orders.shortPos)) - Decimal(str(filled)) - Decimal(str(fee)))
+                                    if new_short_pos == 0.0:
+                                        self.local_position_by_orders.shortAvg = 0.0
+                                        self.local_position_by_orders.shortPos = 0.0
+                                    else:
+                                        self.local_position_by_orders.shortPos = new_short_pos
+                                    self.local_cash -= filled * filled_price
+                                    self.local_coin += filled - fee
+                                elif side == "kk": # sell
+                                    self.local_sell_amount += filled
+                                    self.local_sell_value += filled * filled_price
+                                    new_short_pos = float(Decimal(str(self.local_position_by_orders.shortPos)) + Decimal(str(filled)))
+                                    if new_short_pos == 0.0:
+                                        self.local_position_by_orders.shortAvg = 0.0
+                                        self.local_position_by_orders.shortPos = 0.0
+                                    else:
+                                        self.local_position_by_orders.shortAvg = \
+                                            (self.local_position_by_orders.shortPos * self.local_position_by_orders.shortAvg + filled * filled_price) / new_short_pos
+                                        self.local_position_by_orders.shortPos = new_short_pos
+                                    self.local_cash += filled * filled_price - fee
+                                    self.local_coin -= filled
+                                else:
+                                    self.logger.error(f"错误的仓位方向{side}")
+                            else:
+                                ### 合约订单流仓位计算
+                                if side == "kd":
+                                    self.local_buy_amount += filled
+                                    self.local_buy_value += filled * filled_price
+                                    new_long_pos = (self.local_position_by_orders.longPos + filled)
+                                    if new_long_pos == 0.0:
+                                        self.local_position_by_orders.longAvg = 0
+                                        self.local_position_by_orders.longPos = 0
+                                    else:
+                                        self.local_position_by_orders.longAvg = \
+                                            (self.local_position_by_orders.longPos * self.local_position_by_orders.longAvg + filled * filled_price) / new_long_pos
+                                        self.local_position_by_orders.longPos = float(Decimal(str(self.local_position_by_orders.longPos)) + Decimal(str(filled)))
+                                elif side == "kk":
+                                    self.local_sell_amount += filled
+                                    self.local_sell_value += filled * filled_price
+                                    new_short_pos = (self.local_position_by_orders.shortPos + filled)
+                                    if new_short_pos == 0.0:
+                                        self.local_position_by_orders.shortAvg = 0
+                                        self.local_position_by_orders.shortPos = 0
+                                    else:
+                                        self.local_position_by_orders.shortAvg = \
+                                            (self.local_position_by_orders.shortPos * self.local_position_by_orders.shortAvg + filled * filled_price) / new_short_pos
+                                        self.local_position_by_orders.shortPos = float(Decimal(str(self.local_position_by_orders.shortPos)) + Decimal(str(filled)))
+                                elif side == "pd":
+                                    self.local_sell_amount += filled
+                                    self.local_sell_value += filled * filled_price
+                                    self.local_profit += filled * (filled_price - self.local_position_by_orders.longAvg)
+                                    self.local_position_by_orders.longPos = float(Decimal(str(self.local_position_by_orders.longPos)) - Decimal(str(filled)))
+                                    if self.local_position_by_orders.longPos == 0:self.local_position_by_orders.longAvg = 0
+                                elif side == "pk":
+                                    self.local_buy_amount += filled
+                                    self.local_buy_value += filled * filled_price
+                                    self.local_profit += filled * (self.local_position_by_orders.shortAvg-filled_price)
+                                    self.local_position_by_orders.shortPos = float(Decimal(str(self.local_position_by_orders.shortPos)) - Decimal(str(filled)))
+                                    if self.local_position_by_orders.shortPos == 0:self.local_position_by_orders.shortAvg = 0
+                                else:
+                                    self.logger.error(f"错误的仓位方向{side}")
+                                # 统计合约交易手续费 正fee为扣手续费 负fee为返佣
+                                if 'fee' in data:
+                                    self.local_profit -= data['fee']
+                            self.logger.debug('更新推算仓位'+str(self.local_position_by_orders.__dict__))
+                            ### 
+                            self._print_local_trades_summary()
+                        # 每次有订单变动就触发一次策略
+                        if self.mode_signal == 0 and self.ready:
+                            ### 更新交易数据
+                            self.update_trade_msg()
+                            ### 触发策略挂单逻辑
+                            # 更新策略时间
+                            self.strategy.local_time = time.time()
+                            orders = self.strategy.onTime(self.tradeMsg)
+                            ### 记录指令触发信息
+                            if self._not_empty(orders):
+                                self.logger.debug("触发onOrder")
+                                self._update_local_orders(orders)
+                                self.loop.create_task(self.rest.handle_signals(orders))
+                                self.logger.debug(orders)
+                    else:
+                        self.logger.debug(f"订单已经参与过仓位计算 拒绝重复进行计算{data['client_id']}")
+                else:
+                    self.logger.debug(f"订单不属于本策略 拒绝进行仓位计算{data['client_id']}")
+                # 移除本地订单
+                if data["client_id"] in self.local_orders:
+                    self.logger.debug(['删除本地订单', data["client_id"]])
+                    del(self.local_orders[data["client_id"]])
+                else:
+                    self.logger.debug(['该订单不在本地挂单表中', data["client_id"]])
+            else:
+                print(data)
+                self.logger.debug(f"未知的订单事件类型 {data}")
+        except Exception as e:
+            print("处理订单推送出错:"+str(e))
+            self.logger.error("处理订单推送出错:"+str(e))
+            self.logger.error(traceback.format_exc())
+            self.exit_msg="处理订单推送出错"
+            self.stop()
+
+    def _update_local_orders(self, orders):
+        """
+            本地记录所有报单信息
+        """
+        try:
+            for i in orders:
+                if "Limits" in i:
+                    for j in orders[i]:
+                        order_info = dict()
+                        order_info['symbol'] = self.symbol
+                        order_info['amount'] = float(j[0])
+                        order_info['side'] = j[1]
+                        order_info['price'] = float(j[2])
+                        order_info['client_id'] = j[3]
+                        order_info['filled_price'] = 0
+                        order_info['filled'] = 0
+                        order_info['order_id'] = ""
+                        order_info['localtime'] = self.strategy.local_time
+                        order_info['createtime'] = self.strategy.local_time
+                        self.local_orders[j[3]] = order_info # 本地挂单表
+                        self.logger.debug(['新增本地订单', order_info])
+                        self.local_orders_backup[j[3]] = order_info # 本地缓存表
+                        self.local_orders_backup_cid.append(j[3]) # 本地缓存cid表
+                if 'Cancel' in i:
+                    # 记录撤单次数
+                    cid = orders[i][0]
+                    if cid in self.local_cancel_log:
+                        self.local_cancel_log[cid] += 1
+                    else:
+                        self.local_cancel_log[cid] = 0
+            # 清除过于久远的历史记录
+            if len(self.local_orders_backup_cid) > 9999:
+                cid = self.local_orders_backup_cid[0]
+                # 判断是否超过1个小时 如果超过则移除历史记录
+                if cid in self.local_orders_backup:
+                    if time.time() - self.local_orders_backup[cid]["localtime"] > 3600:
+                        del(self.local_orders_backup[cid])
+                        del(self.local_orders_backup_cid[0])
+            if len(self.handled_orders_cid) > 9999:
+                del(self.handled_orders_cid[0])
+        except:
+            self.logger.error("本地记录订单信息出错")
+            self.logger.error(traceback.format_exc())
+            self.exit_msg="本地记录订单信息出错"
+            self.stop()
+
+    def _not_empty(self, orders):
+        '''检查指令是否不为空'''
+        if isinstance(orders, dict):
+            for order_name in orders:
+                if "Cancel" in order_name or "Check" in order_name:
+                    return 1
+                elif "Limits_open" in order_name:
+                    if len(orders["Limits_open"]) > 0:
+                        return 1
+                elif "Limits_close" in order_name:
+                    if len(orders["Limits_close"]) > 0:
+                        return 1
+        return 0
+
+    def _print_local_trades_summary(self):
+        '''计算本地累计利润'''
+        ###
+        local_buy_amount = round(self.local_buy_amount,5)
+        local_buy_value = round(self.local_buy_value,5)
+        local_sell_amount = round(self.local_sell_amount,5)
+        local_sell_value = round(self.local_sell_value,5)
+        local_profit = 0.0 
+        if isinstance(self.strategy.mp, float):
+            unrealized = (local_buy_amount - local_sell_amount) * self.strategy.mp
+            realized = local_sell_value - local_buy_value
+            local_profit = round(unrealized+realized,5)
+            self.strategy.local_profit = local_profit
+            ###
+            msg = f"买量{local_buy_amount} 卖量{local_sell_amount} 买额{local_buy_value} 卖额{local_sell_value} 利润 {local_profit}"
+            self.logger.info(msg)
+
+    def update_position(self, data):
+        '''
+            更新仓位信息
+        '''
+        if data != self.local_position:
+            self.local_position = data
+            self.logger.debug('更新本地仓位'+str(self.local_position.__dict__))
+
+    """
+    2023-2-22
+    用create_task去执行,会延迟,占用越大,延迟越大,可能会延迟100ms计算
+    """
+    def update_ticker(self, data):
+        '''
+            增加onticker撤单 可能会导致平仓难度加大
+        '''
+        self.loop.create_task(self._update_ticker(data))
+
+    def update_depth(self, data):
+        self.loop.create_task(self._update_depth(data))
+
+    async def _update_ticker(self, data):
+        '''
+            update ticker infomation
+        '''
+        name = data['name']
+        # 记录tick更新时间
+        # self.market_update_time[name] = time.time()
+        self.tickers[name] = data
+        ### 判断是否需要触发ontick
+        if name == self.ref_name[self.strategy.ref_index]:
+            pass
+        elif name == self.trade_name:
+            pass
+        else:
+            pass
+
+    # @utils.timeit
+    async def _update_depth(self, data):
+        '''
+            update orderbook infomation
+        '''
+        name = data['name']
+        now_time = time.time()
+
+        if self.market_update_time[name] == 0.0:
+            pass
+        else:
+            interval = now_time - self.market_update_time[name]
+            if self.market_update_interval[name] == 0.0:
+                self.market_update_interval[name] = interval
+            else:
+                self.market_update_interval[name] = self.market_update_interval[name]*0.999 + interval*0.001
+        self.market_update_time[name] = now_time
+        ### 初始化depths
+        if self.depths[name] == list():
+            self.depths[name] = data['data']
+        ### 判断是否需要触发ondepth
+        # 如果是交易盘口
+        if name == self.trade_name:
+            ### 更新depths
+            self.depths[name] = data['data']
+            # 允许交易
+            if self.mode_signal == 0 and self.ready:
+                ### 聚合行情处理
+                self.on_agg_market()
+        ### 判断是否为当前跟踪的盘口
+        elif name == self.ref_name[self.strategy.ref_index]:
+            ### 判断是否需要触发ontick 对行情进行过滤
+            ### 过滤条件 价格变化很大 时间间隔很长
+            flag = 0
+            if abs(data['data'][utils.BP_INDEX] - self.depths[name][utils.BP_INDEX])/data['data'][utils.BP_INDEX] > 0.0002 or \
+                abs(data['data'][utils.AP_INDEX] - self.depths[name][utils.AP_INDEX])/data['data'][utils.AP_INDEX] > 0.0002 or \
+                time.time() - self.on_tick_event_time > 0.05:
+                ### 允许交易
+                flag = 1
+                ### 更新ontick触发时间记录
+                self.on_tick_event_time = time.time()
+            ### 更新depths
+            self.depths[name] = data['data']
+            # 允许交易
+            if self.mode_signal == 0 and self.ready and flag:
+                ### 更新交易数据
+                self.update_trade_msg()
+                ### 触发事件撤单逻辑
+                # 更新策略时间
+                self.strategy.local_time = time.time()
+                # 产生交易信号
+                orders = self.strategy.onTime(self.tradeMsg)
+                ### 记录指令触发信息
+                if self._not_empty(orders):
+                    self.logger.debug("触发onTick")
+                    self._update_local_orders(orders)
+                    self.loop.create_task(self.rest.handle_signals(orders))
+                    self.logger.debug(orders)
+        else:
+            pass
+
+    # @timeit
+    async def real_time_back_test(self, data):
+        ''' 
+            按照长短期回测利润选择参数
+            优先按长期回测利润选参数 如果找不到就
+            再按短期回测利润选参数 如果还找不到就
+            使用默认参数 如果默认参数亏损就触发冷静期
+        '''
+        now_time = time.time()
+        await asyncio.sleep(0.005)
+        for i in self.backtest_tasks:
+            i["backtest_engine"].backtest_time = now_time
+            i["backtest_engine"].run_by_tick(data)
+
+    def choose_params(self):
+        ''' 
+        按照长短期回测利润选择参数
+        优先按长期回测利润选参数 如果找不到就
+        再按短期回测利润选参数 如果还找不到就
+        使用默认参数 如果默认参数亏损就触发冷静期
+        '''
+        profits = []
+        for i in self.backtest_tasks:
+            # 获取绩效信息
+            e = i["backtest_engine"].equity # 最终净值
+            # 计算标准化利润
+            p = (e-self.backtest_start_cash) / self.backtest_start_cash \
+                / self.backtest_look_length * self.tick_profit_to_daily
+            # 有一定成交次数的回测结果才有代表性 持仓太久的参数禁止使用
+            _trade_num = i['backtest_engine'].trade_num
+            _avg_hold_time = i['backtest_engine'].avg_hold_time
+            _equity_high = i['backtest_engine'].equity_high
+            # 排除交易次数太少的参数
+            if i['open'] <= 0.002:
+                if _trade_num < 10:
+                    p = 0.0
+            # 排除长期持仓的参数
+            if _avg_hold_time > 600:
+                p = 0.0 
+            # 排除近期回撤较大的参数
+            if _equity_high > e*1.01:
+                p = 0.0
+            profits.append(p) #利润
+            ############## 重置回测
+            # if _trade_num > 200:
+            #     i["backtest_engine"].trade_num = 0
+            #     i["backtest_engine"].equity = self.backtest_start_cash
+        # 盈利参数个数不能太少 防止孤岛参数
+        win_num = 0
+        for i in profits:
+            if i > 0.0:
+                win_num += 1
+        cond1 = win_num > self.backtest_num*0.1
+        cond2 = win_num > 2
+        cond_win = cond1 and cond2
+        if cond_win:
+            # 按最优回测结果调整参数
+            max_profit = max(profits)
+            max_index = profits.index(max_profit)
+            self.strategy.trade_open_dist = self.backtest_tasks[max_index]["open"]
+            self.strategy.trade_close_dist = self.backtest_tasks[max_index]["close"]
+            self.strategy.ref_index = self.backtest_tasks[max_index]["index"]
+            self.strategy.post_side = self.backtest_tasks[max_index]["side"]
+            self.strategy.predict_alpha = self.backtest_tasks[max_index]["alpha"]
+            # 检查是否需要关闭回测
+            # if self.strategy.ready == 1:
+            #     self.backtest = 0
+        else:
+            # 如果没有符合条件的盈利参数
+            self.strategy.trade_open_dist = 0.01
+            self.strategy.trade_close_dist = 0.00001
+            self.strategy.ref_index = 0
+            self.strategy.post_side = 0
+            self.strategy.predict_alpha = 0
+            # 检查是否需要关闭回测
+            # if self.strategy.ready == 1:
+            #     self.backtest = 0
+            #     self.exit_msg = "未找到合适参数 停机"
+            #     self.stop()
+        return
+
+
+    def update_equity(self, data):
+        '''
+            更新保证金信息
+            合约一直更新
+            现货只有当出现异常时更新
+        '''
+        if "spot" in self.exchange:
+            pass
+        else:
+            self.local_cash = data[self.quote] * self.used_pct
+
+    def update_exit(self, data):
+        '''
+            底层触发停机
+        '''
+        self.exit_msg = data
+        self.stop()
+
+    def get_all_market_data(self):
+        '''
+            只能定时触发
+            组合市场信息=交易盘口+参考盘口
+        '''
+        market = []
+        data = self.ws._get_data()["data"]
+        market += data
+        for i in self.ref_name:
+            data = self.ws_ref[i]._get_data()["data"]
+            market += data
+        # handle save real market data
+        if self.save:
+            with open(f'./{self.csvname}.csv',
+                        'a',
+                        newline='',
+                        encoding='utf-8') as f:
+                writer = csv.writer(f, delimiter=',')
+                writer.writerow(market)
+        return market
+
+    async def before_trade(self):
+        ####### 启动ws #######
+        # 启动交易ws
+        # 当开启回测时才订阅交易盘口的成交流
+        _sub_trade = int(self.params.backtest)
+        _sub_fast = int(self.params.fast)
+        self.loop.create_task(self.ws.run(is_auth=1, sub_trade=_sub_trade, sub_fast=0))
+        for i in self.ref_name:
+            # 启动参考ws 参考盘口使用fast行情性能消耗更大 使用普通行情可以节省性能
+            self.loop.create_task(self.ws_ref[i].run(is_auth=0, sub_trade=0, sub_fast=_sub_fast))
+        await asyncio.sleep(1)
+        ###### 做交易前准备工作 ######
+        # 买入平台币
+        await self.rest.buy_token()
+        await asyncio.sleep(1)
+        # 清空挂单和仓位
+        await self.rest.check_position(hold_coin=self.hold_coin)
+        await asyncio.sleep(1)
+        # 获取市场信息
+        await self.rest.before_trade()
+        await asyncio.sleep(1)
+        # 获取价格信息
+        ticker = await self.rest.get_ticker()
+        mp = ticker['mp']
+        # 获取账户信息
+        await asyncio.sleep(1)
+        await self.rest.get_equity()
+        # 初始资金
+        start_cash = self.rest.cash_value * self.used_pct
+        start_coin = self.rest.coin_value * self.used_pct
+        if start_cash == 0.0 and start_coin == 0.0:
+            self.exit_msg = f"初始为零 cash: {start_cash} coin: {start_coin}"
+            self.stop()
+        self.logger.info(f"初始cash: {start_cash} 初始coin: {start_coin}")
+        # 初始化策略基础信息
+        if isinstance(mp, float):
+            if mp <= 0.0:
+                self.exit_msg = f"初始价格获取错误 {mp}"
+                self.stop()
+            else:
+                print(f"初始价格为 {mp}")
+        else:
+            self.exit_msg = f"初始价格获取错误 {mp}"
+            self.stop()
+        self.strategy.mp = mp
+        self.strategy.start_cash = start_cash 
+        self.strategy.start_coin = start_coin
+        self.strategy.start_equity = start_cash + start_coin * mp
+        self.strategy.max_equity = self.strategy.start_equity
+        self.strategy.equity = self.strategy.start_equity
+        self.strategy.total_amount = self.strategy.equity * self.strategy.leverrate / self.strategy.mp
+        self.strategy.stepSize = self.rest.stepSize if self.rest.stepSize < 1.0 else int(self.rest.stepSize)
+        self.strategy.tickSize = self.rest.tickSize if self.rest.tickSize < 1.0 else int(self.rest.tickSize)
+        if self.strategy.stepSize == None or self.strategy.tickSize == None:
+            self.exit_msg = f"交易精度未正常获取 stepsize: {self.strategy.stepSize} ticksize: {self.strategy.tickSize}"
+            self.stop()
+        else:
+            self.logger.info(f"数量精度{self.strategy.stepSize}")
+            self.logger.info(f"价格精度{self.strategy.tickSize}")
+        grid = float(self.params.grid)
+        if "spot" in self.exchange:
+            long_one_hand_value = start_cash * float(self.params.leverrate) / grid
+            short_one_hand_value = start_coin * mp * float(self.params.leverrate) / grid
+            long_one_hand_amount = float(Decimal(str(long_one_hand_value / mp//self.strategy.stepSize))*Decimal(str(self.strategy.stepSize)))
+            short_one_hand_amount = float(Decimal(str(short_one_hand_value / mp//self.strategy.stepSize))*Decimal(str(self.strategy.stepSize)))
+        else:
+            long_one_hand_value = start_cash * float(self.params.leverrate) / grid
+            short_one_hand_value = start_cash * float(self.params.leverrate) / grid
+            long_one_hand_amount = float(Decimal(str(long_one_hand_value / mp//self.strategy.stepSize))*Decimal(str(self.strategy.stepSize)))
+            short_one_hand_amount = float(Decimal(str(short_one_hand_value / mp//self.strategy.stepSize))*Decimal(str(self.strategy.stepSize)))
+        # 检查是否满足最低交易要求
+        print(f"最低单手交易下单量为 buy: {long_one_hand_amount} sell: {short_one_hand_amount}")
+        if (long_one_hand_amount == 0 and short_one_hand_amount == 0) or (long_one_hand_value < 20 and short_one_hand_value < 20):
+            self.exit_msg = f"初始下单量太少 buy: {long_one_hand_amount} sell: {short_one_hand_amount}"
+            self.stop()
+        # 初始化调度器
+        self.local_cash = start_cash
+        self.local_coin = start_coin
+        # 配置在线训练
+        if self.backtest:
+            # 设置策略默认参数
+            self.strategy.trade_close_dist = 0.00001
+            self.strategy.trade_open_dist = 0.01
+            self.backtest_look_length = 86400 / self.interval  # 回测区间足够长
+            self.backtest_tasks = list()
+            self.tick_profit_to_daily = (86400/self.interval)
+            self.backtest_start_cash = 1000000.0
+            # 备选参数
+            open_list, close_list, alpha_list = utils.get_backtest_set(self.base)
+            if 'spot' in self.exchange:
+                side_list = []
+                if long_one_hand_amount > 0:
+                    side_list.append(1)
+                if short_one_hand_amount > 0:
+                    side_list.append(-1)
+                if 1 in side_list and -1 in side_list:
+                    side_list.append(0)
+            else:
+                side_list = [-1,0,1]
+            side_list_allow = []
+            for s in side_list:
+                if s in utils.POST_SIDE_LIMIT:
+                    side_list_allow.append(s)
+            side_list = side_list_allow
+            for _open in open_list:
+                for _side in side_list:
+                    for _close in close_list[_open]:
+                        for _index in range(self.ref_num):
+                            for _alpha in alpha_list:
+                                task = dict()
+                                st = strategy.Strategy(self.params, is_print=0)
+                                st.leverrate = 1.0
+                                st.trade_open_dist = _open
+                                st.trade_close_dist = _close
+                                st.predict_alpha = _alpha
+                                st.ref_index = _index
+                                st.post_side = _side
+                                st.exchange = "dummy_usdt_swap"
+                                st.local_start_time = 0.0
+                                bt = backtest.Backtest(st, is_plot=0)
+                                bt.start_cash = self.backtest_start_cash
+                                task["backtest_engine"] = bt
+                                task["open"] = _open
+                                task["close"] = _close
+                                task["index"] = _index
+                                task["side"] = _side
+                                task["alpha"] = _alpha
+                                self.backtest_tasks.append(task)
+            backtest_num = len(self.backtest_tasks)
+            self.backtest_num = backtest_num
+            self.logger.info(f'在线模拟撮合数量{backtest_num}')
+            self.logger.info(f'当前为在线训练模式 需预热{utils.BACKTEST_PREHOT_SECOND}秒 请耐心等候...')
+        else:
+            self.logger.info('当前为指定参数模式...')
+        ###### 交易前准备就绪 可以开始交易 ######
+        self.loop.create_task(self.rest.go())
+        self.loop.create_task(self.on_timer())
+        self.loop.create_task(self._run_server())
+        self.loop.create_task(self.run_stratey())
+        #self.loop.create_task(self.post_loop())    #改
+        self.loop.create_task(self.early_stop_loop())
+
+    def update_trade_msg(self):
+        # 更新保证金
+        self.tradeMsg.cash = round(self.local_cash,10)
+        self.tradeMsg.coin = round(self.local_coin,10)
+        # 使用本地推算仓位
+        self.tradeMsg.position = self.local_position_by_orders
+        # 更新订单
+        self.tradeMsg.orders = self.local_orders
+        ### 更新 ref
+        ref_tickers = []
+        for i in self.ref_name:
+            ref_tickers.append([self.tickers[i]['bp'], self.tickers[i]['ap']])
+        self.tradeMsg.ref_price = self.Predictor.Get_ref(ref_tickers)
+
+    async def server_handle(self, request):
+        '''中控数据接口'''
+        if 'spot' in self.exchange:
+            pos = self.local_position_by_orders.longPos - self.local_position_by_orders.shortPos
+        else:
+            pos = self.local_position.longPos - self.local_position.shortPos
+        if pos > 0.0:
+            entryPrice = self.local_position_by_orders.longAvg
+        elif pos < 0.0:
+            entryPrice = self.local_position_by_orders.shortAvg
+        else:
+            entryPrice = 0
+        return web.Response(body=json.dumps({
+            "now_balance": round(self.strategy.equity/self.used_pct, 4),    #钱包余额
+            "unrealized_pn_l": round(self.local_profit, 4),    #未实现盈利
+            "pos": round(pos, 8),        #持仓数量
+            "entry_price": round(entryPrice, 8),      #开仓价格
+            "now_price": round(self.strategy.mp, 8),           #当前价格
+        }))
+
+    async def change(self, request):
+        '''中控台修改参数'''
+        try:
+            data = await request.json()
+            if "stop" in data:
+                self.logger.warning('中控停机')
+                self.exit_msg = '中控停机'
+                self.stop()
+                return web.Response(text=f"停机成功")
+                
+            ip = request.remote
+            print(f'从{ip}收到更新参数请求',data)
+            if isinstance(data, str):
+                data = json.loads(data)
+
+            if self.backtest == 1:
+                return web.Response(text="自动调参模式不允许手动修改参数")
+            else:
+                open = float(data['open'])
+                close = float(data['close'])
+                self.strategy.trade_open_dist = open
+                self.strategy.trade_close_dist = close
+                return web.Response(text=f"参数修改成功 {open} {close}")
+        except Exception as e:
+            return web.Response(text=f"参数修改失败 {e}")
+
+    # @utils.timeit
+    def check_risk(self):
+        '''检查风控'''
+        if self.strategy.start_cash == 0.0:
+            print("请检查交易账户余额")
+            return 0
+        if isinstance(self.strategy.mp, float):
+            pass
+        else:
+            print("请检查最新价格")
+            return 0
+        ############
+        # print("当前线程数",self.process.num_threads())
+        ###### 资源风控0 ######
+        cpu_pct = psutil.cpu_times_percent().user
+        self.cpu_ema = self.cpu_ema * 0.8 + cpu_pct * 0.2
+        # print(f"cpu占用 {cpu_pct}")
+        if self.cpu_ema > 95:
+            msg = f"cpu占用过高 {self.cpu_ema} 准备停机"
+            print(msg)
+            self.logger.warning(msg)
+            self.exit_msg = msg
+            self.stop()
+        mm_pct = psutil.virtual_memory().percent
+        self.mm_ema = self.mm_ema * 0.8 + mm_pct * 0.2
+        # print(f"内存占用 {mm_pct}")
+        if self.mm_ema > 95:
+            msg = f"内存占用过高 {self.mm_ema} 准备停机"
+            print(msg)
+            self.logger.warning(msg)
+            self.exit_msg = msg
+            self.stop()
+        ###### 回撤风控1 ######
+        if "spot" not in self.exchange:
+            draw_back = 1-self.strategy.equity/self.strategy.max_equity
+            if draw_back > self.stoploss:
+                msg = f"{self.acct_name} 总资金吊灯回撤{draw_back} 当前{self.strategy.equity} 最高{self.strategy.max_equity} 触发止损 准备停机"
+                print(msg)
+                self.logger.warning(msg)
+                self.exit_msg = msg
+                self.stop()
+        ###### 回撤风控2 ######
+        draw_back = self.local_profit/self.strategy.start_equity
+        if draw_back < -self.stoploss:
+            msg = f"{self.acct_name} 交易亏损 触发止损 准备停机"
+            print(msg)
+            self.logger.warning(msg)
+            self.exit_msg = msg
+            self.stop()
+        ###### 报单延迟风控 ######
+        if self.rest.avg_delay > 5000: # 平均延迟允许上限 5000ms 
+            msg = f"{self.acct_name} 延迟爆表 触发风控 准备停机"
+            print(msg)
+            self.logger.warning(msg)
+            self.exit_msg = msg
+            self.stop()
+        ###### 仓位异常风控 ######
+        ### 合约60秒更新一次绝对仓位 ###
+        # 连续5分钟仓位不正确就停机
+        # 5 * 60 = 300   300/10 = 30
+        diff_pos = max(abs(self.local_position.longPos - self.local_position_by_orders.longPos),abs(self.local_position.shortPos - self.local_position_by_orders.shortPos))
+        if "spot" not in self.exchange:
+            diff_pos_value = diff_pos * self.strategy.mp
+            if diff_pos_value > self.strategy._min_amount_value:
+                msg = f"{self.acct_name} ***发现仓位异常*** 推算{self.local_position_by_orders.__dict__} 本地{self.local_position.__dict__}"
+                print(msg)
+                self.logger.warning(msg)
+                self.position_check_series.append(1)
+            else:
+                self.position_check_series.append(0)
+            if len(self.position_check_series) > 30:
+                del(self.position_check_series[0])
+            if sum(self.position_check_series) >= 30:
+                msg = f"{self.acct_name} 合约连续检查本地仓位和推算仓位不相符 退出"
+                print(msg)
+                self.logger.warning(msg)
+                self.exit_msg = msg
+                self.stop()    
+        ###### 下单异常风控 ######
+        if self.strategy.total_amount == 0.0:
+            msg = f"{self.acct_name} 开仓量为零 退出"
+            print(msg)
+            self.logger.warning(msg)
+            self.exit_msg = msg
+            self.stop()
+        ###### 行情更新异常风控 ######
+        for name in self.ref_name:
+            delay = round((time.time() - self.market_update_time[name]) * 1000, 3)
+            if delay > utils.MARKET_DELAY_LIMIT: # thre
+                msg = f"{self.acct_name} ticker_name:{name} delay:{delay}ms 行情更新延迟过高 退出"
+                self.logger.error(msg)
+                self.exit_msg = msg
+                self.stop()
+        for name in [self.trade_name]:
+            delay = round((time.time() - self.market_update_time[name]) * 1000, 3)
+            if delay > utils.MARKET_DELAY_LIMIT: # thre
+                msg = f"{self.acct_name} ticker_name:{name} delay:{delay}ms 行情更新延迟过高 退出"
+                self.logger.error(msg)
+                self.exit_msg = msg
+                self.stop()
+        ###### 订单异常风控 ######
+        for cid in self.local_orders:
+            if time.time() - self.local_orders[cid]["localtime"] > 300: # 订单长时间停留 怀疑漏单 但未必一定漏 5min
+                msg = f"{self.acct_name} cid:{cid} 订单停留过久 怀疑异常 退出"
+                self.logger.error(msg)
+                self.exit_msg = msg
+                self.stop()
+        ###### 持仓均价异常风控 ######
+        if isinstance(self.strategy.long_pos_bias, float):
+            # 偏离mp较大 且持仓较大 说明出现异常
+            if self.strategy.long_hold_value > 2*self.strategy._min_amount_value:
+                if self.strategy.long_pos_bias > 4.0 or self.strategy.long_pos_bias < -2.0:
+                    msg = f"{self.acct_name} long_pos_bias:{self.strategy.long_pos_bias} 持仓均价异常 退出"
+                    self.logger.error(msg)
+                    self.exit_msg = msg
+                    self.stop()
+        if isinstance(self.strategy.short_pos_bias, float):
+            # 偏离mp较大 且持仓较大 说明出现出现异常
+            if self.strategy.short_hold_value > 2*self.strategy._min_amount_value:
+                if self.strategy.short_pos_bias > 4.0 or self.strategy.short_pos_bias < -2.0: 
+                    msg = f"{self.acct_name} short_pos_bias:{self.strategy.short_pos_bias} 持仓均价异常 退出"
+                    self.logger.error(msg)
+                    self.exit_msg = msg
+                    self.stop()
+        ###### 订单撤单异常风控 ######
+        for cid in self.local_cancel_log:
+            if self.local_cancel_log[cid] > 300:
+                msg = f"{self.acct_name} 订单长时间无法撤销 退出"
+                self.logger.error(msg)
+                self.exit_msg = msg
+                self.stop()
+        ###### 定价异常风控 ######
+        if abs(self.strategy.ref_price-self.strategy.mp)/self.strategy.mp > 0.03:
+            msg = f"{self.acct_name} 定价偏离过大 怀疑异常 退出"
+            self.logger.error(msg)
+            self.exit_msg = msg
+            self.stop()
+
+    async def exit(self, delay=0):
+        '''退出操作'''
+        try:
+            self.logger.info(f"预约退出操作 delay:{delay}")
+            if delay > 0:
+                await asyncio.sleep(delay)
+            self.logger.info(f"开始退出操作")
+            self.logger.info("为避免api失效导致遗漏仓位 建议人工复查")
+            await self.rest.check_position(hold_coin=self.hold_coin)
+            # stop flag
+            self.rest.stop_flag = 1
+            self.ws.stop_flag = 1
+            for i in self.ref_name:
+                self.ws_ref[i].stop_flag = 1
+            # double check 需要延迟几秒以便等待更新数据
+            await asyncio.sleep(3)
+            self.logger.info("双重检查遗漏仓位")
+            await self.rest.check_position(hold_coin=self.hold_coin)
+            self.logger.info(f'停机退出 停机原因 {self.exit_msg}')
+            await asyncio.sleep(1)
+            # 发送交易状态
+            await self._post_params()
+            # 压缩行情文件
+            utils.csv_to_gz_and_remove()
+            # close pid
+            self.logger.info("退出进程")
+        except:
+            self.logger.error(traceback.format_exc())
+        finally:
+            os._exit(0)
+
+    async def on_timer(self):
+        '''定期触发系统逻辑'''
+        await asyncio.sleep(20)
+        while 1:
+            try:
+                # 10秒检查一次风控
+                await asyncio.sleep(10)
+                # 检查风控
+                self.check_risk()
+                # stop
+                if self.mode_signal == 1:return
+                # 计算预估成交额
+                total_trade_value = self.local_buy_value + self.local_sell_value
+                self.strategy.trade_vol_24h = round(total_trade_value / (time.time()-self.pid_start_time) * 86400 / 10000, 2)
+                # 打印
+                if int(self.params.log):
+                    self.strategy._print_summary()
+                    # 打印行情延迟监控
+                    self.logger.info('Rest 报单平均延迟 ' + str(self.rest.avg_delay) + 'ms ')
+                    self.logger.info('Rest 报单最高延迟 ' + str(self.rest.max_delay) + 'ms ')
+                    for name in self.market_update_interval:
+                        avg_interval = round(self.market_update_interval[name]*1e3, 2)
+                        self.logger.info(f'WS 盘口{name}行情 平均更新间隔 {avg_interval}ms')
+                # 选择参数
+                if self.backtest:
+                    self.choose_params()
+            except asyncio.CancelledError:
+                print('定期循环任务取消')
+            except:
+                print("定时循环系统出错")
+                self.logger.error(traceback.print_exc())
+                await asyncio.sleep(10)
+
+    async def _post_params(self):
+        '''推送交易信息'''
+        profit = round(self.strategy.daily_return/self.strategy.leverrate,4)
+        if time.time() - self.pid_start_time > utils.EARLY_STOP_SECOND * 0.5 or profit < 0.0:
+            await utils._post_params(
+                "http://wwww.khods.com:8888/post_params", 
+                self.params.proxy,
+                ujson.dumps({
+                    "pwd":"123456",
+                    "exchange":self.params.exchange,
+                    "pair":self.params.pair,
+                    "open":self.params.open,
+                    "close":self.params.close,
+                    "refexchange":self.params.refexchange[self.strategy.ref_index],
+                    "profit":profit,
+                })
+            )
+        else:
+            self.logger.info("不满足推送过滤条件 放弃推送参数")
+
+    async def post_loop(self):
+        '''定期触发交易信息推送'''
+        await asyncio.sleep(30)
+        _interval = 60 # 定期推送一次盈利情况
+        while 1:
+            try:
+                # 定期推送一次
+                await asyncio.sleep(_interval)
+                # 发送交易状态
+                await self._post_params()
+            except asyncio.CancelledError:
+                print('post loop 循环任务取消')
+            except:
+                print("post loop 循环系统出错")
+                self.logger.error(traceback.print_exc())
+                await asyncio.sleep(10)
+    
+    async def early_stop_loop(self):
+        '''定期触发交易信息推送'''
+        if self.father:
+            self.logger.info(f'以父进程方式启动 关闭早停检测')
+            return
+        else:
+            self.logger.info(f'以子进程方式启动 开启早停检测')
+        await asyncio.sleep(30)
+        _interval = utils.EARLY_STOP_SECOND
+        _last_equity = self.strategy.start_equity
+        _last_local_profit = 0.0
+        while 1:
+            try:
+                # 休眠
+                await asyncio.sleep(_interval)
+                ###### 子进场早停风控 ######
+                self.logger.info(f'当前净值{self.strategy.equity} 上次检测时净值{_last_equity} 当前累积利润{self.local_profit} 上次检测时利润{_last_local_profit}')
+                # 检查是否需要早停 没有成交 或者 亏损
+                if self.strategy.equity <= _last_equity or self.local_profit <= _last_local_profit:
+                    self.logger.info('触发早停条件 当零持仓时退出')
+                    # 没有持仓
+                    for _ in range(30):
+                        await asyncio.sleep(5)
+                        if self.strategy.long_hold_value < self.strategy._min_amount_value and \
+                            self.strategy.short_hold_value < self.strategy._min_amount_value:
+                            msg = f"{self.acct_name} 子进程盈利状况不理想 提前停机 退出"
+                            self.logger.error(msg)
+                            self.exit_msg = msg
+                            self.stop()
+                # 更新上一次检测的净值
+                _last_equity = self.strategy.equity
+                _last_local_profit = self.local_profit
+            except asyncio.CancelledError:
+                print('early stop 循环任务取消')
+            except:
+                print("early stop 循环系统出错")
+                self.logger.error(traceback.print_exc())
+                await asyncio.sleep(10)
+
+    def on_agg_market(self):
+        '''
+            处理聚合行情
+            1. 获取聚合行情
+            2. 更新预测器
+            3. 触发tick回测
+        '''
+        ### 更新聚合市场数据
+        agg_market = self.get_all_market_data()
+        ### 更新聚合市场信息
+        self.tradeMsg.market = agg_market
+        ### 更新预测器
+        self.Predictor.onTime(agg_market)
+        ### 触发回测
+        if self.backtest:
+            self.loop.create_task(self.real_time_back_test(self.tradeMsg))
+
+    async def run_stratey(self):
+        '''
+            定期触发策略
+        '''
+        print('定时触发器启动')
+        # 准备交易
+        try:
+            print('前期准备完成')
+            await asyncio.sleep(10)
+            while 1:
+                try:
+                    # 时间预设
+                    start_time = time.time()
+                    ### 是否准备充分
+                    if self.ready:
+                        ### 更新交易信息集合
+                        self.update_trade_msg()
+                        ### 触发策略
+                        if self.mode_signal == 0:
+                            pass
+                            # # 更新策略时间
+                            # self.strategy.local_time = time.time()
+                            # # 产生信号
+                            # orders = self.strategy.onTime(self.tradeMsg)
+                            # ### 记录指令触发信息
+                            # if self._not_empty(orders):
+                            #     self.logger.debug("触发onTime")
+                            #     self._update_local_orders(orders)
+                            #     self.loop.create_task(self.rest.handle_signals(orders))
+                            #     self.logger.debug(orders)
+                        else:
+                            if self.mode_signal > 1:self.mode_signal -= 1
+                            if self.mode_signal == 1:return
+                            # 触发策略
+                            # 更新策略时间
+                            self.strategy.local_time = time.time()
+                            # 获取信号
+                            if self.mode_signal > 20:
+                                # 先执行onExit
+                                orders = self.strategy.onExit(self.tradeMsg)
+                                ### 记录指令触发信息
+                                if self._not_empty(orders):
+                                    self.logger.debug("触发onExit")
+                                    self._update_local_orders(orders)
+                                    self.loop.create_task(self.rest.handle_signals(orders))
+                                    self.logger.debug(orders)
+                            else:
+                                # 再执行onSleep
+                                orders = self.strategy.onSleep(self.tradeMsg)
+                                ### 记录指令触发信息
+                                if self._not_empty(orders):
+                                    self.logger.debug("触发onSleep")
+                                    self._update_local_orders(orders)
+                                    self.loop.create_task(self.rest.handle_signals(orders))
+                                    self.logger.debug(orders)
+                        ############################################################
+                    else:
+                        self.check_ready()
+                    ### 计算耗时并进行休眠
+                    pass_time = time.time()-start_time
+                    await asyncio.sleep(utils.clip(self.interval-pass_time, 0.0, 1.0))
+                except asyncio.CancelledError:
+                    print('策略触发任务取消')
+                except:
+                    self.logger.error(traceback.format_exc())
+                    traceback.print_exc()
+                    await asyncio.sleep(10)
+        except asyncio.CancelledError:
+            print('策略触发任务取消')
+        except:
+            self.logger.error(traceback.format_exc())
+            traceback.print_exc()
+            await asyncio.sleep(10)
+    
+    def check_ready(self):
+        '''
+            判断初始数据是否齐全
+        '''
+        ### 检查 ticker 行情
+        for i in self.ref_name:
+            if i not in self.tickers or self.tickers[i] == {}:
+                print("参考盘口ticker未准备好")
+                return
+            else:
+                if self.tickers[i]['bp'] == 0 or self.tickers[i]['ap'] == 0:
+                    print("参考盘口ticker未准备好")
+                    return
+        if self.trade_name not in self.tickers or self.tickers[self.trade_name] == {}:
+            print("交易盘口ticker未准备好")
+            return
+        else:
+            if self.tickers[self.trade_name]['bp'] == 0 or self.tickers[self.trade_name]['ap'] == 0:
+                print("交易盘口ticker未准备好")
+                return
+        ### 检查 market 行情
+        all_market = self.get_all_market_data()
+        if len(all_market) != utils.LEN*(1+self.ref_num):
+            print("聚合行情未准备好")
+            return
+        else:
+            # 如果行情已经就绪 预热trademsg和predictor
+            print("聚合行情准备就绪")
+            self.tradeMsg.market = all_market
+            self.Predictor.onTime(all_market)
+        self.ready = 1
+
+    def stop(self):
+        '''
+            停机函数
+            mode_signal 不能小于80
+            前6秒用于maker平仓
+            后2秒用于撤maker平仓单
+            休眠2秒再执行check_position 避免卡单导致漏仓位
+        '''
+        self.logger.info(f'进入停机流程...')
+        self.mode_signal = 80
+        # 等strategy onExit 彻底执行完毕 进入沉默状态之后 再进入exit 否则可能导致多处同时操作订单
+        # 尽量减少大仓位直接take平
+        self.loop.create_task(self.exit(delay=10))
+
+    async def _run_server(self):
+        print('server正在启动...')
+        for _ in range(30):
+            await asyncio.sleep(5)
+            if self.strategy.equity > 0.0:break
+        app = web.Application()
+        app.router.add_route('GET', '/account', self.server_handle)
+        app.router.add_route('POST', '/change', self.change)
+        try:
+            self.loop.create_task(web._run_app(app, host='0.0.0.0', port=self.params.server_port, handle_signals=False))
+        except:
+            self.logger.error(f"Server启动失败")
+            self.logger.error(traceback.format_exc())
+            self.exit_msg = "服务启动失败 停机退出"
+            self.stop()
+
+    def run(self):
+        '''启动ws行情获取'''
+
+        def keyboard_interrupt(s, f):
+            self.logger.info("收到退出信号 准备关机")
+            self.stop()
+
+        try:
+            signal.signal(signal.SIGINT, keyboard_interrupt)
+            signal.signal(signal.SIGTERM, keyboard_interrupt)
+            if 'win' not in sys.platform:
+                signal.signal(signal.SIGKILL, keyboard_interrupt)
+                signal.signal(signal.SIGQUIT, keyboard_interrupt)
+        except:
+            pass
+
+        self.loop.create_task(self.before_trade())
+
+        print(f'判断启动方式...')
+        if self.father:
+            print('以父进程方式启动 最大允许运行时间为30天')
+            self.loop.create_task(self.exit(delay=60*60*24*30))
+        else:
+            print('以子进程方式启动 最大允许运行时间为60分钟')
+            self.loop.create_task(self.exit(delay=utils.CHILD_RUN_SECOND)) 
+        
+        self.loop.run_forever()
+
+if __name__ == "__main__":
+
+    if 0:
+        utils.check_auth()
+
+    if 0:
+        utils.check_time()
+
+    pnum = len(sys.argv)
+
+    if pnum > 0:
+        fname = None
+        log_file = None
+        pidnum = None
+        father = 1
+        for i in range(pnum):
+            print(f"第{i}个参数为:{sys.argv[i]}")
+            if sys.argv[i] == '-c' or sys.argv[i] == '--c': 
+                fname = sys.argv[i+1]
+            elif sys.argv[i] == '-h': 
+                print("帮助文档")
+            elif sys.argv[i] == '-log_file' or sys.argv[i] == '--log_file':
+                log_file = sys.argv[i+1]
+            elif sys.argv[i] == '-num' or sys.argv[i] == '--num':
+                pidnum = sys.argv[i+1]
+            elif sys.argv[i] == '-v' or sys.argv[i] == '--v':
+                print(f"当前版本为 V{VERSION}")
+            elif sys.argv[i] == '-child' or sys.argv[i] == '--child':
+                father = 0
+                print(f"当前以子进程方式启动")
+        if fname and log_file and pidnum:
+            print(f"指定的配置为 fname:{fname} log_file:{log_file} pidnum:{pidnum} father:{father}")
+            date = time.strftime("%Y%m%d", time.localtime()) 
+            logname = f"{log_file}-{date}"
+            quant = Quant(utils.get_params(fname), logname, father)
+            quant.run()
+        elif fname:
+            print(f"运行指定配置文件{fname}")
+            quant = Quant(utils.get_params(fname),father=father)
+            quant.run()
+        else:
+            print("缺少指定参数 运行默认配置文件")
+            fname = 'config.toml'
+            quant = Quant(utils.get_params(fname),father=father)
+            quant.run()
+    else:
+        fname = 'config.toml'
+        quant = Quant(utils.get_params(fname))
+        quant.run()

+ 38 - 0
readme.txt

@@ -0,0 +1,38 @@
+运行环境:
+	Aws20 t3.m 2h 1gb内存  10u一个月
+	东京 A区
+	Ubuntu20系统
+
+环境安装:
+	sudo sed -i 's/^#\?PermitRootLogin.*/PermitRootLogin yes/g' /etc/ssh/sshd_config;
+	sudo sed -i 's/^#\?PasswordAuthentication.*/PasswordAuthentication yes/g' /etc/ssh/sshd_config;
+	sudo service sshd restart
+
+	sudo apt update  -y
+	sudo apt list --upgradable
+	sudo apt upgrade -y
+	sudo apt install gcc -y
+	gcc -v
+
+	sudo apt install build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libbz2-dev liblzma-dev sqlite3 libsqlite3-dev tk-dev uuid-dev libgdbm-compat-dev -y
+
+	wget https://www.python.org/ftp/python/3.10.0/Python-3.10.0.tgz
+	tar xzf Python-3.10.0.tgz
+	cd Python-3.10.0
+
+	sudo ./configure --enable-optimizations
+	sudo make altinstall
+
+	python3.10 --version
+	pip3.10 --version
+
+	pip3.10 install numpy pandas toml ujson psutil aiohttp flask joblib sklearn uvloop requests_toolbelt pycrypto netifaces
+
+修改配置文件:
+	config.toml
+	utils.py   get_params函数定义了每个配置文件具体是什么作用
+
+	
+运行指令:
+	python3 quant.py
+	

+ 1080 - 0
strategy.py

@@ -0,0 +1,1080 @@
+import time
+import traceback, utils
+import model
+import logging, logging.handlers
+from decimal import Decimal
+from decimal import ROUND_HALF_UP, ROUND_FLOOR
+
+class Strategy:
+    '''
+        策略逻辑
+    '''
+    def __init__(self, params: model.Config, is_print=0):
+        self.params = params
+        self.exchange = self.params.exchange
+        self.broker_id = params.broker_id
+        self.logger = self.get_logger()
+        #### 实盘
+        ex = self.params.exchange
+        pair = self.params.pair
+        self.trade_name = ex + '@' + pair
+        #### 参考
+        refex = self.params.refexchange
+        refpair = self.params.refpair
+        if len(refex) != len(refpair):
+            print("参考盘口数不等于参考品种数 退出")
+            return
+        self.ref_num = len(refex)
+        self.ref_name = []
+        for i in range(self.ref_num):
+            name = refex[i] + '@' + refpair[i]
+            self.ref_name.append(name)
+        #### maker mode
+        self.maker_mode = 'free'
+        ####
+        self._print_time = 0
+        self.local_orders = dict()
+        self.pos = model.Position()
+        self.long_hold_value = 0.0
+        self.short_hold_value = 0.0
+        self.equity = 0.0
+        self.coin = 0.0
+        self.cash = 0.0
+        self.start_equity = 0.0
+        self.start_coin = 0.0
+        self.start_cash = 0.0
+        self.max_equity = 0.0
+        self.local_profit = 0.0
+        self.total_amount = 0.0
+        self.ready = 0
+        self._is_print = is_print
+        self._min_amount_value = 30.0  # 最小下单额 防止下单失败
+        self._max_amount_value = 10000.0  # 最大下单额 防止下单过重 平不掉就很悲剧
+        self.local_time = time.time()
+        self.local_start_time = time.time()
+        self.interval = float(self.params.interval)
+        self.mp = None
+        self.bp = None
+        self.ap = None
+        self.ref_price = 0.0
+        self.ref_bp = 0.0
+        self.ref_ap = 0.0
+        self.stepSize = 1e-10
+        self.tickSize = 1e-10
+        self.maxPos = 0.0
+        self.profit = 0.0
+        self.daily_return = 0.0
+        #### 
+        self.mp_ewma = None
+        self.adjust_leverrate = 1.0
+        #### 持仓偏差
+        self.long_pos_bias = None
+        self.short_pos_bias = None
+        self.long_hold_rate = 0.0
+        self.short_hold_rate = 0.0
+        #### 时间相关参数
+        self.leverrate = float(self.params.leverrate)  # 最大仓位
+        if "spot" in self.exchange:self.leverrate = min(self.leverrate, 1.0)
+        self._print_time = time.time()
+        self._start_time = time.time()
+        self.request_num = 0  # 记录请求次数
+        self.request_order_num = 0  # 记录下单次数
+        self._print_interval = 5  # 打印信息时间间隔
+        #### 距离范围
+        self.open_dist = None
+        self.close_dist = None
+        #### 查单频率    
+        self._check_local_orders_time = time.time()
+        self._check_local_orders_interval = 10.0
+        #### 内部限頻
+        try:
+            self.place_order_limit = float(self.params.place_order_limit)
+        except:
+            self.place_order_limit = 0
+        self.request_limit_check_time = time.time()
+        self.request_limit_check_interval = 10.0
+        self.limit_requests_num = utils.get_limit_requests_num_per_second(
+            self.params.exchange, self.place_order_limit) * self.request_limit_check_interval
+        self.limit_order_requests_num = utils.get_limit_order_requests_num_per_second(
+            self.params.exchange, self.place_order_limit) * self.request_limit_check_interval
+        #### 网络请求频率
+        self._req_num_per_window = 0  
+        # 开仓下单间隔 均匀下单机会
+        self.post_open_time = time.time()
+        self.post_open_interval = 1/utils.get_limit_order_requests_num_per_second(self.params.exchange)
+        #### 策略参数
+        # 距离类参数
+        self.trade_close_dist = 0.00001  # 基础挂单距离  
+        self.trade_open_dist = 0.01  # 基础挂单距离  
+        #### 时间类参数
+        # 撤单限頻队列 强制等待 防止频繁发起撤单
+        self.in_cancel = dict()
+        self.cancel_wait_interval = 0.2
+        # 查单限頻队列 强制等待
+        self.in_check = dict()
+        self.check_wait_interval = 10.0
+        # ref index
+        self.ref_index = 0
+        # predict
+        self.predict = 0.0
+        self.predict_alpha = 0.0
+        # post side 
+        self.post_side = 0
+        # trade vol
+        self.trade_vol_24h = 0.0
+        # grid num
+        self.grid = float(self.params.grid)
+    
+    def get_logger(self):
+        logger = logging.getLogger(__name__)
+        logger.setLevel(logging.DEBUG)
+        # log to txt
+        formatter = logging.Formatter('[%(asctime)s] - %(levelname)s - %(message)s')
+        handler = logging.handlers.RotatingFileHandler(f"log.log",maxBytes=1024*1024)
+        handler.setLevel(logging.DEBUG)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        return logger
+
+    # @utils.timeit
+    def _update_data(self, data:model.TraderMsg):
+        '''更新本地数据'''
+        try:
+            # 更新信息
+            # orders
+            self.local_orders.clear()
+            self.local_orders.update(data.orders)
+            # position
+            if self.pos.longPos != data.position.longPos:
+                self.pos.longPos = data.position.longPos
+                self.pos.longAvg = data.position.longAvg
+            if self.pos.shortPos != data.position.shortPos:
+                self.pos.shortPos = data.position.shortPos
+                self.pos.shortAvg = data.position.shortAvg
+            # bp ap 
+            self.bp = data.market[utils.BP_INDEX]
+            self.ap = data.market[utils.AP_INDEX]
+            # trade mp
+            self.mp = (self.bp+self.ap)*0.5
+            ########### 动态杠杆调节 ###########
+            if self.mp_ewma == None:
+                self.mp_ewma = self.mp
+            else:
+                self.mp_ewma = self.mp_ewma*0.999 + self.mp*0.001
+            if self.mp > self.mp_ewma:
+                # 增加杠杆
+                self.adjust_leverrate = 1.0
+            else:
+                # 降低杠杆
+                self.adjust_leverrate = 0.8
+            ########### 当前持仓价值 ###########
+            self.long_hold_value = self.pos.longPos * self.mp
+            self.short_hold_value = self.pos.shortPos * self.mp
+            ########### 现货 ###########
+            if 'spot' in self.exchange:
+                ### 计算总保证金情况
+                self.max_long_value = self.start_cash * self.leverrate * self.adjust_leverrate
+                self.max_short_value = self.start_coin * self.leverrate * self.adjust_leverrate * self.mp
+            ########### 合约 ###########
+            else:
+                ### 计算总保证金情况
+                self.max_long_value = self.equity * self.leverrate * self.adjust_leverrate
+                self.max_short_value = self.max_long_value
+            ###### maker mode ######
+            if self.ref_name[self.ref_index] == self.trade_name:
+                self.maker_mode = 'free'
+            else:
+                self.maker_mode = 'follow'
+            ###### ref price ######
+            if data.ref_price == None:
+                self.ref_bp = self.bp
+                self.ref_ap = self.ap
+                self.ref_price = self.mp
+            else:
+                self.ref_bp = data.ref_price[self.ref_index][0]
+                self.ref_ap = data.ref_price[self.ref_index][1]
+                self.ref_price = (self.ref_bp+self.ref_ap)*0.5
+            # spread
+            self.predict = utils.clip(data.predict*self.predict_alpha, -self.trade_open_dist, self.trade_open_dist)
+            # is base spread normal ? can take but can't move too far
+            # if abs(self.ref_price - self.mp)/self.mp > self.trade_open_dist*3:
+                # back to pure market making strategy
+                # self.ref_price = self.mp
+            # equity 当前账户可用cash和coin
+            self.coin = data.coin
+            self.cash = data.cash
+            if self.mp:
+                self.equity = data.cash + data.coin * self.mp
+            # max equity
+            if self.equity > self.max_equity:
+                self.max_equity = self.equity
+            self.total_amount = float(utils.fix_amount(self.equity * self.leverrate * self.adjust_leverrate / self.mp, self.stepSize))
+            if self.total_amount == 0.0:
+                if self._is_print:
+                    self.logger.error("总可开数量太少")
+            # max pos
+            maxPos = max([
+                    self.pos.longPos, self.pos.shortPos
+                ]) * self.mp // (self.equity if self.equity > 0 else 99999999)
+            if maxPos > self.maxPos:
+                self.maxPos = maxPos
+            return 1
+        except:
+            if self._is_print:
+                self.logger.error(traceback.format_exc())
+            return 0
+
+    # @utils.timeit
+    def _print_summary(self):
+        '''
+            打印状态信息
+            耗时700us
+        '''
+        msg = '>>> '
+        msg += '盘口 ' + self.exchange + ' '
+        msg += '品种 ' + self.params.pair + ' '
+        msg += '现价 ' + str(round(self.mp, 6)) + ' '
+        msg += '定价 ' + str(round(self.ref_price, 6)) + ' '
+        msg += '偏差 ' + str(round((self.ref_price-self.mp)/self.mp*100, 2)) + '% '
+        msg += '净值 ' + str(round(self.equity, 3)) + ' '
+        msg += 'Cash ' + str(round(self.cash, 3)) + ' '
+        msg += 'Coin ' + str(round(self.coin*self.mp, 3)) + ' '
+        msg += '推算利润 ' + str(self.local_profit) + ' '
+        self.profit = round(
+            (self.equity - self.start_equity) /
+            self.start_equity * 100, 3) if self.start_equity > 0 else 0
+        msg += '盈亏 ' + str(self.profit) + '% '
+        msg += '多杠杆' + str(
+            round(
+                self.pos.longPos * self.mp /
+                (self.equity if self.equity > 0 else 99999999), 3)) + ' '
+        self.long_pos_bias = None
+        if self.pos.longPos > 0.0:
+            self.long_pos_bias = round(100 - 100 * self.pos.longAvg / self.mp, 2)
+            msg += '浮盈' + str(self.long_pos_bias) + '% '
+        else:
+            msg += '浮盈 None '
+        msg += '空杠杆' + str(
+            round(
+                self.pos.shortPos * self.mp /
+                (self.equity if self.equity > 0 else 99999999), 3)) + ' '
+        self.short_pos_bias = None
+        if self.pos.shortPos > 0.0:
+            self.short_pos_bias = round(100 * self.pos.shortAvg / self.mp - 100, 2)
+            msg += '浮盈' + str(self.short_pos_bias) + '% '
+        else:
+            msg += '浮盈 None '
+        msg += '杠杆' + str(self.leverrate) + ' 动态 ' + str(self.adjust_leverrate) + ' '
+        msg += '最大' + str(self.maxPos) + ' '
+        msg += '请求' + str(self._req_num_per_window) + ' 上限 ' + str(self.limit_order_requests_num) + '次/10s '
+        run_time = time.time() - self._start_time
+        self.daily_return = round(self.profit / 100 / run_time * 86400, 5)
+        msg += '日化' + str(self.daily_return) + ' '
+        msg += '当前参数 ' + \
+            ' 开仓 ' + str(round(self.trade_open_dist,6)) + \
+            ' 平仓 ' + str(round(self.trade_close_dist,6)) + \
+            ' 方向 ' + str(self.post_side) + \
+            ' 参考 ' + self.ref_name[self.ref_index]  + \
+            ' 模式 ' + self.maker_mode + \
+            ' 预测 ' + str(round(self.predict,5)) + \
+            ' 预估24H成交额 ' + str(self.trade_vol_24h) + 'W' + \
+            ' 实时优化 ' + str(self.params.backtest)
+        # 写入日志
+        if self._is_print:
+            self.logger.info(msg)
+        #### 本地订单状态列表 ####
+        o_num = len(self.local_orders)
+        self.logger.info(f"挂单列表 共{o_num}单")
+        for cid in self.local_orders:
+            i = self.local_orders[cid]
+            msg = i['symbol'] + ' ' + str(i['client_id']) + ' ' + i['side'] + ' 杠杆: ' + \
+                str(round(i['amount'] * self.mp / (self.equity if self.equity > 0 else 99999999), 3)) + 'x 价值:' + \
+                str(round(i['amount'] * self.mp,2)) + 'u'    + ' 价格: ' + \
+                str(i['price']) + ' 偏离:' + \
+                str(round((i['price'] - self.mp) / self.mp * 100, 3)) + '%'
+            if self._is_print:
+                self.logger.info(msg)
+        self.logger.info("撤单列表")
+        if len(self.in_cancel) > 0:
+            if self._is_print:
+                self.logger.info(self.in_cancel)
+        self.logger.info("查单列表")
+        if len(self.in_check) > 0:
+            if self._is_print:
+                self.logger.info(self.in_check)
+
+    # def fix_amount(self, amount):
+    #     '''修补数量向下取整'''
+    #     return float(Decimal(str(amount)).quantize(Decimal(str(self.stepSize)), ROUND_FLOOR))
+
+    # def fix_price(self, price):
+    #     '''修补价格四舍五入'''
+    #     return float(Decimal(str(price)).quantize(Decimal(str(self.tickSize)), ROUND_HALF_UP))
+
+    def _cancel_targit_side_orders(self, order_side=["kd","kk","pd","pk"]):
+        '''清理指定类型挂单'''
+        signals = dict()
+        # 撤销指定类型挂单
+        for cid in self.local_orders:
+            i = self.local_orders[cid]
+            if i["side"] in order_side:
+                cid = i['client_id']
+                oid = i['order_id']
+                signals[f'Cancel{cid}'] = [cid, oid]
+        return signals
+
+    def _close_all(self):
+        '''
+            清空所有挂单和仓位保持休眠状态
+        '''
+        signals = dict()
+        # 撤销全部挂单
+        pd_amount = 0.0
+        pk_amount = 0.0
+        for cid in self.local_orders:
+            i = self.local_orders[cid]
+            cid = i['client_id']
+            oid = i['order_id']
+            if i['side'] == 'pk':
+                pk_amount += i['amount']
+            elif i['side'] == 'pd':
+                pd_amount += i['amount']
+            signals[f'Cancel{cid}'] = [cid, oid]
+        # 批量挂单
+        signals['Limits_close'] = []
+        need_close_long = self.pos.longPos - pd_amount
+        need_close_short = self.pos.shortPos - pk_amount
+        if "spot" in self.exchange:
+            if need_close_long * self.mp > self._min_amount_value:
+                amount = need_close_long
+                price = utils.fix_price(self.mp, self.tickSize)
+                amount = utils.fix_amount(amount, self.stepSize)
+                signals['Limits_close'].append([amount, 'pd', price, utils.get_cid(self.broker_id)])
+            if need_close_short * self.mp > self._min_amount_value:
+                amount = need_close_short
+                price = utils.fix_price(self.mp, self.tickSize)
+                amount = utils.fix_amount(amount, self.stepSize)
+                signals['Limits_close'].append([amount, 'pk', price, utils.get_cid(self.broker_id)])
+        else:
+            if need_close_long > 0:
+                # sell
+                price = utils.fix_price(self.mp, self.tickSize)
+                amount = need_close_long
+                if amount * self.mp > self._min_amount_value:
+                    signals['Limits_close'].append([amount, 'pd', price,utils.get_cid(self.broker_id)])
+            if need_close_short > 0:
+                # buy
+                price = utils.fix_price(self.mp, self.tickSize)
+                amount = need_close_short
+                if amount * self.mp > self._min_amount_value:
+                    signals['Limits_close'].append([amount, 'pk', price,utils.get_cid(self.broker_id)])
+        return signals
+
+    # def _cancel_close(self):
+    #     '''
+    #         取消平仓订单
+    #     '''
+    #     # 准备命令
+    #     signals = dict()
+    #     # 撤掉危险挂单
+    #     cond1 = self.close_dist[0]
+    #     cond2 = self.close_dist[1]
+    #     cond3 = self.close_dist[2]
+    #     cond4 = self.close_dist[3]
+    #     # # 获取当前挂单
+    #     for cid in self.local_orders:
+    #         i = self.local_orders[cid]
+    #         if i['side'] == 'pk':
+    #             if (i['price'] >= cond1 or i['price'] <= cond2):
+    #                 cid = i['client_id']
+    #                 oid = i['order_id']
+    #                 signals[f'Cancel{cid}'] = [cid, oid]
+    #         if i['side'] == 'pd':
+    #             if (i['price'] <= cond3 or i['price'] >= cond4):
+    #                 cid = i['client_id']
+    #                 oid = i['order_id']
+    #                 signals[f'Cancel{cid}'] = [cid, oid]
+    #     return signals
+
+    def _post_close(self):
+        '''
+            处理平仓
+        '''
+        # 准备命令
+        signals = dict()
+        signals['Limits_close'] = []
+        # 撤掉危险挂单
+        pdAmount = 0.0
+        pdOrderNum = 0
+        pkAmount = 0.0
+        pkOrderNum = 0
+        # 计算
+        cond1 = self.close_dist[0]
+        cond2 = self.close_dist[1]
+        cond3 = self.close_dist[2]
+        cond4 = self.close_dist[3]
+        # # 获取当前挂单
+        for cid in self.local_orders:
+            i = self.local_orders[cid]
+            if i['side'] == 'pk':
+                c1 = i['price'] > cond1
+                c2 = i['price'] < cond2
+                if c1 or c2:
+                    signals[f'Cancel{cid}'] = [i['client_id'], i['order_id']]
+                # else:
+                pkAmount += i['amount']
+                pkOrderNum += 1
+            elif i['side'] == 'pd':
+                c1 = i['price'] < cond3
+                c2 = i['price'] > cond4
+                if c1 or c2:
+                    signals[f'Cancel{cid}'] = [i['client_id'], i['order_id']]
+                # else:
+                pdAmount += i['amount']
+                pdOrderNum += 1
+        need_cancel_all_close = 0
+        if abs(pdAmount - self.pos.longPos) * self.mp > self._min_amount_value or \
+            abs(pkAmount - self.pos.shortPos) * self.mp > self._min_amount_value:
+            need_cancel_all_close = 1
+        if need_cancel_all_close:
+            for cid in self.local_orders:
+                i = self.local_orders[cid]
+                if i['side'] in ['pk','pd']:
+                    signals[f'Cancel{cid}'] = [i['client_id'], i['order_id']]
+        ####################### 检查是否需要挂平仓单
+        if 'spot' in self.exchange:
+            ### 需要平多的价值大于最小交易价值 执行平多逻辑 
+            if self.pos.longPos * self.mp > self._min_amount_value:  
+                if pdOrderNum == 0:  # 需要更新平仓挂单
+                    price = (cond3 + cond4)*0.5
+                    price = utils.clip(price, self.bp*0.9995, self.ap*1.03)
+                    price = utils.fix_price(price, self.tickSize)
+                    amount = self.pos.longPos
+                    amount = utils.fix_amount(amount, self.stepSize)
+                    if float(amount) * float(price) > self._min_amount_value:
+                        signals['Limits_close'].append([
+                            amount,
+                            'pd',
+                            price,
+                            utils.get_cid(self.broker_id)
+                        ])
+            if self.pos.shortPos > self._min_amount_value:  
+                if pkOrderNum == 0:  # 需要更新平仓挂单
+                    price = (cond1 + cond2)*0.5
+                    price = utils.clip(price, self.bp*0.97, self.ap*1.0005)
+                    price = utils.fix_price(price, self.tickSize)
+                    amount = self.pos.shortPos
+                    amount = utils.fix_amount(amount, self.stepSize)
+                    if float(amount) * float(price) > self._min_amount_value:
+                        signals['Limits_close'].append([
+                            amount,
+                            'pk',
+                            price,
+                            utils.get_cid(self.broker_id)
+                        ])
+        else:
+            if self.pos.longPos > 0.0:  # 正常平多
+                if pdOrderNum == 0:  # 需要更新平仓挂单
+                    price = cond3*0.5 + cond4*0.5
+                    price = utils.clip(price, self.bp*0.9995, self.ap*1.03)
+                    price = utils.fix_price(price, self.tickSize)
+                    signals['Limits_close'].append([
+                        self.pos.longPos,
+                        'pd',
+                        price,
+                        utils.get_cid(self.broker_id)
+                    ])
+            if self.pos.shortPos > 0.0:  # 正常平空
+                if pkOrderNum == 0:  # 需要更新平仓挂单
+                    price = cond1*0.5 + cond2*0.5
+                    price = utils.clip(price, self.bp*0.97, self.ap*1.0005)
+                    price = utils.fix_price(price, self.tickSize)
+                    signals['Limits_close'].append([
+                        self.pos.shortPos,
+                        'pk',
+                        price,
+                        utils.get_cid(self.broker_id)
+                    ])
+        return signals
+
+    def _cancel_open(self):
+        '''
+            撤销开仓
+        '''
+        # 准备命令
+        signals = dict()
+        # 计算挂单范围
+        cond1 = self.open_dist[0]
+        cond2 = self.open_dist[1]
+        cond3 = self.open_dist[2]
+        cond4 = self.open_dist[3]
+        # # 获取当前挂单
+        for cid in self.local_orders:
+            i = self.local_orders[cid]
+            if i['side'] == 'kd':
+                # 撤开多单
+                c1 = i['price'] > cond1
+                c2 = i['price'] < cond2
+                # c3 = i['price'] > cond1*(1-self.long_hold_rate) + cond2*self.long_hold_rate
+                # if c1 or c2 or c3:
+                if c1 or c2:
+                    signals[f'Cancel{cid}'] = [i['client_id'], i['order_id']]
+            elif i['side'] == 'kk':
+                # 撤开空单
+                c1 = i['price'] < cond3
+                c2 = i['price'] > cond4
+                # c3 = i['price'] < cond3*(1-self.short_hold_rate) + cond4*self.short_hold_rate
+                # if c1 or c2 or c3:
+                if c1 or c2:
+                    signals[f'Cancel{cid}'] = [i['client_id'], i['order_id']]
+        return signals
+
+    # @utils.timeit
+    def _post_open(self):
+        '''
+            处理开仓 开仓要一直挂
+        '''
+        # 准备命令
+        signals = dict()
+        signals['Limits_open'] = []
+        # 计算挂单范围
+        cond1 = self.open_dist[0]
+        cond2 = self.open_dist[1]
+        cond3 = self.open_dist[2]
+        cond4 = self.open_dist[3]
+        # # 获取当前挂单
+        buyP = []
+        sellP = []
+        buy_value = 0
+        sell_value = 0
+        for cid in self.local_orders:
+            i = self.local_orders[cid]
+            if i['side'] in ['kd']:
+                buyP.append(i['price'])
+                buy_value += i['amount'] * i['price']
+            if i['side'] in ['kk']:
+                sellP.append(i['price'])
+                sell_value += i['amount'] * i['price']
+        ########### 现货 ###########
+        if 'spot' in self.exchange:
+            ### 计算当前持币和持u ###
+            coin_value = self.coin * self.mp * self.leverrate * self.adjust_leverrate # 可卖价值
+            cash_value = self.cash * self.leverrate * self.adjust_leverrate # 可买价值
+            long_free_value = min(cash_value, self.max_long_value) - buy_value
+            short_free_value = min(coin_value, self.max_short_value) - sell_value
+        ########### 合约 ###########
+        else:
+            ### 合约只有已有仓位和开仓订单会占用保证金
+            long_free_value = self.max_long_value - self.long_hold_value - buy_value
+            short_free_value = self.max_short_value - self.short_hold_value - sell_value
+        #######################################
+        one_hand_long_value = self.max_long_value / self.grid * 0.99
+        one_hand_short_value = self.max_short_value / self.grid * 0.99
+        ############## 单层挂单 #################
+        ########### 挂多单 ###########
+        if self.post_side >= 0:
+            if len(buyP) == 0:
+                # 1
+                targit_buy_price = cond1 * 0.5 + cond2 * 0.5
+                targit_buy_price = utils.clip(targit_buy_price, self.bp*0.97, self.ap*1.0005)
+                targit_buy_price = utils.fix_price(targit_buy_price, self.tickSize)
+                value = min(one_hand_long_value, long_free_value)
+                amount = utils.fix_amount(value/self.mp, self.stepSize)
+                amount_value = float(amount) * self.mp
+                if amount_value >= self._min_amount_value and amount_value <= long_free_value:
+                    signals['Limits_open'].append([amount, 'kd', targit_buy_price, utils.get_cid(self.broker_id)])
+        ########### 挂空单 ###########
+        if self.post_side <= 0:
+            if len(sellP) == 0:
+                # 1
+                targit_sell_price = cond3 * 0.5 + cond4 * 0.5
+                targit_sell_price = utils.clip(targit_sell_price, self.bp*0.9995, self.ap*1.03)
+                targit_sell_price = utils.fix_price(targit_sell_price, self.tickSize)
+                value = min(one_hand_short_value, short_free_value)
+                amount = utils.fix_amount(value/self.mp, self.stepSize)
+                amount_value = float(amount) * self.mp
+                if amount_value >= self._min_amount_value and amount_value <= short_free_value:
+                    signals['Limits_open'].append([amount, 'kk', targit_sell_price, utils.get_cid(self.broker_id)])
+        ############## 多层挂单 #################
+        # step = (cond1-cond2) / self.grid
+        # ########### 挂多单 ###########
+        # if self.post_side >= 0:
+        #     if len(buyP) == 0:
+        #         # 1
+        #         targit_buy_price = cond1 * 0.1 + cond2 * 0.9
+        #         targit_buy_price = utils.clip(targit_buy_price, self.bp*0.97, self.ap*1.0005)
+        #         targit_buy_price = self.fix_price(targit_buy_price)
+        #         value = min(one_hand_long_value, long_free_value)
+        #         amount = self.fix_amount(value/self.mp)
+        #         amount_value = float(amount) * self.mp
+        #         if targit_buy_price < cond1*(1-self.long_hold_rate) + cond2*self.long_hold_rate:
+        #             if amount_value >= self._min_amount_value and amount_value <= long_free_value:
+        #                 signals['Limits_open'].append([amount, 'kd', targit_buy_price, utils.get_cid(self.broker_id)])
+        #     else:
+        #         # 2
+        #         targit_buy_price = max(buyP) + step
+        #         targit_buy_price = utils.clip(targit_buy_price, self.bp*0.97, self.ap*1.0005)
+        #         targit_buy_price = self.fix_price(targit_buy_price)
+        #         value = min(one_hand_long_value, long_free_value)
+        #         amount = self.fix_amount(value/self.mp)
+        #         amount_value = float(amount) * self.mp
+        #         if targit_buy_price < cond1*(1-self.long_hold_rate) + cond2*self.long_hold_rate:
+        #             if targit_buy_price < cond1 - 0.1*step and targit_buy_price > cond2 + 0.1*step:
+        #                 if amount_value >= self._min_amount_value and amount_value <= long_free_value:
+        #                     signals['Limits_open'].append([amount, 'kd', targit_buy_price, utils.get_cid(self.broker_id)])
+        #         # 3
+        #         targit_buy_price = min(buyP) - step
+        #         targit_buy_price = utils.clip(targit_buy_price, self.bp*0.97, self.ap*1.0005)
+        #         targit_buy_price = self.fix_price(targit_buy_price)
+        #         value = min(one_hand_long_value, long_free_value)
+        #         amount = self.fix_amount(value/self.mp)
+        #         amount_value = float(amount) * self.mp
+        #         if targit_buy_price < cond1*(1-self.long_hold_rate) + cond2*self.long_hold_rate:
+        #             if targit_buy_price < cond1 - 0.1*step and targit_buy_price > cond2 + 0.1*step:
+        #                 if amount_value >= self._min_amount_value and amount_value <= long_free_value:
+        #                     signals['Limits_open'].append([amount, 'kd', targit_buy_price, utils.get_cid(self.broker_id)])
+        # ########### 挂空单 ###########
+        # if self.post_side <= 0:
+        #     if len(sellP) == 0:
+        #         # 1
+        #         targit_sell_price = cond3 * 0.1 + cond4 * 0.9
+        #         targit_sell_price = utils.clip(targit_sell_price, self.bp*0.9995, self.ap*1.03)
+        #         targit_sell_price = self.fix_price(targit_sell_price)
+        #         value = min(one_hand_short_value, short_free_value)
+        #         amount = self.fix_amount(value/self.mp)
+        #         amount_value = float(amount) * self.mp
+        #         if targit_sell_price > cond3*(1-self.short_hold_rate) + cond4*self.short_hold_rate:
+        #             if amount_value >= self._min_amount_value and amount_value <= short_free_value:
+        #                 signals['Limits_open'].append([amount, 'kk', targit_sell_price, utils.get_cid(self.broker_id)])
+        #     else:
+        #         # 2
+        #         targit_sell_price = min(sellP) - step
+        #         targit_sell_price = utils.clip(targit_sell_price, self.bp*0.9995, self.ap*1.03)
+        #         targit_sell_price = self.fix_price(targit_sell_price)
+        #         value = min(one_hand_short_value, short_free_value)
+        #         amount = self.fix_amount(value/self.mp)
+        #         amount_value = float(amount) * self.mp
+        #         if targit_sell_price > cond3*(1-self.short_hold_rate) + cond4*self.short_hold_rate:
+        #             if targit_sell_price > cond3 + 0.1*step and targit_sell_price < cond4 - 0.1*step:
+        #                 if amount_value >= self._min_amount_value and amount_value <= short_free_value:
+        #                     signals['Limits_open'].append([amount, 'kk', targit_sell_price, utils.get_cid(self.broker_id)])
+        #         # 3
+        #         targit_sell_price = max(sellP) + step
+        #         targit_sell_price = utils.clip(targit_sell_price, self.bp*0.9995, self.ap*1.03)
+        #         targit_sell_price = self.fix_price(targit_sell_price)
+        #         value = min(one_hand_short_value, short_free_value)
+        #         amount = self.fix_amount(value/self.mp)
+        #         amount_value = float(amount) * self.mp
+        #         if targit_sell_price > cond3*(1-self.short_hold_rate) + cond4*self.short_hold_rate:
+        #             if targit_sell_price > cond3 + 0.1*step and targit_sell_price < cond4 - 0.1*step:
+        #                 if amount_value >= self._min_amount_value and amount_value <= short_free_value:
+        #                     signals['Limits_open'].append([amount, 'kk', targit_sell_price, utils.get_cid(self.broker_id)])
+        ############################################
+        return signals
+
+    # @utils.timeit
+    def gen_dist(self, open, close, pos_rate, ref_bp, ref_ap, predict, grid=1, mode='free'):
+        '''
+            Input: 
+                开仓距离
+                平仓距离
+                参考价格
+            产生挂单位置
+            4
+            3
+            1
+            2
+            用预测调近挂单距离很危险
+        '''
+        ###########################
+        mp = (ref_bp+ref_ap)*0.5
+        buy_start = mp
+        sell_start = mp
+        ###########################
+        # 持有多仓时 平仓sell更近  持有空仓时 平仓buy 更近
+        avoid = min(0.0005, close * 0.5) # 平仓位置偏移可以适当大一点
+        # 平仓位置
+        close_dist = [
+            buy_start * ( 1 + predict - close + avoid), # buy upper
+            buy_start * ( 1 + predict - close - avoid), # buy lower
+            sell_start * ( 1 + predict + close - avoid), # sell lower
+            sell_start * ( 1 + predict + close + avoid), # sell upper
+        ]
+        ######################################################
+        if mode == 'free':
+            # 自由做市
+            buy_start = ref_bp
+            sell_start = ref_ap
+        elif mode == 'follow':
+            # 跟随做市
+            mp = (ref_bp+ref_ap)*0.5
+            buy_start = mp
+            sell_start = mp
+        else:
+            # 跟随做市
+            mp = (ref_bp+ref_ap)*0.5
+            buy_start = mp
+            sell_start = mp
+        ###########################
+        ###########################
+        # 持有多仓时 开仓buy更远  持有空仓时 开仓sell 更远
+        avoid = min(0.001, open * 0.05) # 开仓位置偏移可以适当小一点
+        # 持仓偏移
+        buy_shift = 1 + pos_rate[0] * grid
+        sell_shift = 1 + pos_rate[1] * grid
+        # 保护窗口
+        open_dist = [
+            buy_start * ( 1 + predict - open * buy_shift  + avoid), # buy upper
+            buy_start * ( 1 + predict - open * buy_shift  - avoid), # buy lower
+            sell_start * ( 1 + predict + open * sell_shift - avoid), # sell lower
+            sell_start * ( 1 + predict + open * sell_shift + avoid), # sell upper
+        ]
+        ###########################
+        return open_dist, close_dist
+
+    def _update_request_num(self, signals):
+        '''统计请求次数'''
+        if 'Limits_open' in signals:
+            self.request_num += len(signals['Limits_open'])
+            self.request_order_num += len(signals['Limits_open'])
+        if 'Limits_close' in signals:
+            self.request_num += len(signals['Limits_close'])
+            self.request_order_num += len(signals['Limits_close'])
+        for i in signals:
+            if 'Cancel' in i:
+                self.request_num += 1
+            elif 'Check' in i:
+                self.request_num += 1
+
+    def _check_request_limit(self, signals):
+        '''根据平均请求次数限制开仓下单'''
+        if self.request_num > self.limit_requests_num:
+            return dict()
+        elif self.request_num >= self.limit_requests_num * 0.5 or self.request_order_num >= self.limit_order_requests_num*0.8:
+            new_signals = dict()
+            for order_name in signals:
+                if 'Limits_open' in order_name:
+                    pass
+                elif 'Limits_close' in order_name and self.request_order_num >= self.limit_order_requests_num:
+                    pass
+                else:
+                    new_signals[order_name] = signals[order_name]
+            return new_signals
+        else:
+            return signals
+
+    def _check_local_orders(self):
+        '''超过时间限制触发查单信号'''
+        signals = dict()
+        if self.local_time - self._check_local_orders_time >= self._check_local_orders_interval:
+            for cid in self.local_orders:
+                # 如果没在查单队列中
+                if cid not in self.in_check:
+                    # 超过10s没动的订单 进行检查
+                    if self.local_time - self.local_orders[cid]["localtime"] > self._check_local_orders_interval:
+                        signals[f"Check{cid}"] = [cid, self.local_orders[cid]['order_id']]
+                        self.in_check[cid] = self.local_time
+                        if self._is_print:
+                            self.logger.debug(f"查询订单 {cid}")
+            # 维护查单队列
+            self._release_in_check()
+            # 更新查单时间
+            self._check_local_orders_time = self.local_time
+        return signals
+
+    def _release_in_check(self):
+        '''检查是否正在撤单'''
+        new_dict = dict()
+        for cid in self.in_check:
+            # 等待超过后移除正在撤单队列
+            if self.local_time - self.in_check[cid] <= self.check_wait_interval:
+                new_dict[cid] = self.in_check[cid]
+        self.in_check = new_dict
+
+    def _release_in_cancel(self):
+        '''检查是否正在撤单'''
+        new_dict = dict()
+        for cid in self.in_cancel:
+            # 等待超过后移除正在撤单队列
+            if self.local_time - self.in_cancel[cid] <= self.cancel_wait_interval:
+                new_dict[cid] = self.in_cancel[cid]
+        self.in_cancel = new_dict
+
+    def _update_in_cancel(self, signals):
+        '''
+            新增正在撤单
+            检查撤单队列
+            释放过时限制
+        '''
+        new_signals = dict()
+        for i in signals:
+            if 'Cancel' in i:
+                cid = signals[i][0]
+                need_limit_cancel = 1
+                # 判断是否在挂单表中
+                if cid in self.local_orders:
+                    # 判断是否在订单创建100ms内
+                    if self.local_time - self.local_orders[cid]['createtime'] < 0.1:
+                        # 解除撤单限制
+                        need_limit_cancel = 0
+                if need_limit_cancel:
+                    # 增加撤单限制
+                    if cid not in self.in_cancel:
+                        self.in_cancel[cid] = self.local_time
+                        new_signals[i] = signals[i]
+            else:
+                new_signals[i] = signals[i]
+        ### 释放撤单限制
+        self._release_in_cancel()
+        return new_signals
+
+    def _refresh_request_limit(self):
+        if self.local_time - self.request_limit_check_time >= self.request_limit_check_interval:
+            self._req_num_per_window = self.request_num
+            self.request_num = 0
+            self.request_order_num = 0
+            self.request_limit_check_time = self.local_time
+
+    def _pos_rate(self):
+        '''获取持仓比例 0~1'''
+        long_hold_rate = 0.0
+        short_hold_rate = 0.0
+        if self.max_long_value > 0.0:
+            long_hold_rate = self.long_hold_value/self.max_long_value
+        if self.max_short_value > 0.0:
+            short_hold_rate = self.short_hold_value/self.max_short_value
+        # print(long_hold_rate, short_hold_rate)
+        self.long_hold_rate = long_hold_rate
+        self.short_hold_rate = short_hold_rate
+
+    def check_ready(self):
+        '''检查准备'''
+        pre_hot = 10
+        if int(self.params.backtest):
+            pre_hot = utils.BACKTEST_PREHOT_SECOND
+        if self.ready != 1:
+            if isinstance(self.mp, float) and self.local_time - self.local_start_time > pre_hot:
+                self.ready = 1
+                if self._is_print: print('预热完毕')
+            return 1
+        else:
+            return 0
+
+    def check_allow_post_open(self):
+        '''
+            检查是否允许报单
+        '''
+        ### 接近整点时刻 不允许报单 防止下单bug ###
+        diff_time = self.local_time % 3600
+        if diff_time < 30 or diff_time > 3570:
+            return 0
+        ########################################
+        return 1
+
+    def onExit(self, data):
+        '''
+            全撤全平 准备退出
+        '''
+        try:
+            # 更新状态
+            if self._update_data(data):
+                # 检查是否准备充分
+                if self.check_ready():
+                    return dict()
+                # 交易模式
+                signals = self._close_all()
+                # 更新撤单队列
+                signals = self._update_in_cancel(signals)
+                # 交易模式
+                signals = self._check_request_limit(signals)
+                # 统计请求频率
+                self._update_request_num(signals)
+                return signals
+        except:
+            traceback.print_exc()
+
+    def onSleep(self, data):
+        '''
+            全撤 不再下新订单了 防止影响check_position执行
+        '''
+        try:
+            # 更新状态
+            if self._update_data(data):
+                # 检查是否准备充分
+                if self.check_ready():
+                    return dict()
+                # 交易模式
+                signals = self._cancel_targit_side_orders()
+                # 更新撤单队列
+                signals = self._update_in_cancel(signals)
+                # 交易模式
+                signals = self._check_request_limit(signals)
+                # 统计请求频率
+                self._update_request_num(signals)
+                return signals
+        except:
+            traceback.print_exc()
+
+    # @timeit
+    # def onOrder(self, data):
+    #     '''
+    #         call on order update event
+    #     '''
+    #     try:
+    #         # 更新状态
+    #         if self._update_data(data):
+    #             # 检查是否准备充分
+    #             if self.check_ready():
+    #                 return dict()
+    #             ###### 关键操作 ######
+    #             # 更新挂单距离
+    #             self._pos_rate()
+    #             open_dist, close_dist = self.gen_dist(
+    #                 open=self.trade_open_dist,
+    #                 close=self.trade_close_dist, 
+    #                 pos_rate=[self.long_hold_rate, self.short_hold_rate],
+    #                 ref_bp=self.ref_bp,
+    #                 ref_ap=self.ref_ap,
+    #                 predict=self.predict,
+    #                 grid=self.grid,
+    #                 mode=self.maker_mode,
+    #             )
+    #             self.open_dist = \
+    #                 [
+    #                     self.fix_price(open_dist[0]),
+    #                     self.fix_price(open_dist[1]),
+    #                     self.fix_price(open_dist[2]),
+    #                     self.fix_price(open_dist[3]),
+    #                 ]
+    #             self.close_dist = \
+    #                 [
+    #                     self.fix_price(close_dist[0]),
+    #                     self.fix_price(close_dist[1]),
+    #                     self.fix_price(close_dist[2]),
+    #                     self.fix_price(close_dist[3]),
+    #                 ]
+    #             # 获取开平仓指令
+    #             signals = dict()
+    #             # 获取平仓信号
+    #             signals.update(self._post_close())
+    #             # 更新撤单队列
+    #             signals = self._update_in_cancel(signals)
+    #             # 限制频率
+    #             signals = self._check_request_limit(signals)
+    #             # 统计请求频率
+    #             self._update_request_num(signals)
+    #             ##########################
+    #             return signals
+    #     except:
+    #         traceback.print_exc()
+    #         if self._is_print:self.logger.error(traceback.format_exc())
+
+    # @timeit
+    # def onTick(self, data):
+    #     '''
+    #         call on ticker update event
+    #     '''
+    #     try:
+    #         # 更新状态
+    #         if self._update_data(data):
+    #             # 检查是否准备充分
+    #             if self.check_ready():
+    #                 return dict()
+    #             ###### 关键操作 ######
+    #             # 更新挂单距离
+    #             self.open_dist, self.close_dist = utils.gen_dist(
+    #                 self.trade_open_dist,
+    #                 self.trade_close_dist, 
+    #                 self._pos_rate(),
+                    # self.ref_bp,
+                    # self.ref_ap,
+    #                 self.predict
+    #             )
+    #             # 获取开平仓指令
+    #             signals = dict()
+    #             # 获取平仓信号
+    #             signals.update(self._cancel_open())
+    #             # 更新撤单队列
+    #             signals = self._update_in_cancel(signals)
+    #             # 限制频率
+    #             signals = self._check_request_limit(signals)
+    #             # 统计请求频率
+    #             self._update_request_num(signals)
+    #             ##########################
+    #             return signals
+    #     except:
+    #         traceback.print_exc()
+    #         if self._is_print:self.logger.error(traceback.format_exc())
+
+    # @utils.timeit
+    def onTime(self, data):
+        '''
+            call on time 
+        '''
+        try:
+            # 定时打印
+            if self.local_time - self._print_time > self._print_interval:
+                self._print_time = self.local_time
+                if self._is_print: 
+                    if self.ready:
+                        pass
+                    else:
+                        self.logger.info("预热中")
+            # 更新状态
+            if self._update_data(data):
+                # 检查是否准备充分
+                if self.check_ready():
+                    return dict()
+                ###### 关键操作 ######
+                # 更新挂单距离
+                self._pos_rate()
+                open_dist, close_dist = self.gen_dist(
+                    open=self.trade_open_dist,
+                    close=self.trade_close_dist, 
+                    pos_rate=[self.long_hold_rate, self.short_hold_rate],
+                    ref_bp=self.ref_bp,
+                    ref_ap=self.ref_ap,
+                    predict=self.predict,
+                    grid=self.grid,
+                    mode=self.maker_mode,
+                )
+                self.open_dist = \
+                    [
+                        utils.fix_price(open_dist[0], self.tickSize),
+                        utils.fix_price(open_dist[1], self.tickSize),
+                        utils.fix_price(open_dist[2], self.tickSize),
+                        utils.fix_price(open_dist[3], self.tickSize),
+                    ]
+                self.close_dist = \
+                    [
+                        utils.fix_price(close_dist[0], self.tickSize),
+                        utils.fix_price(close_dist[1], self.tickSize),
+                        utils.fix_price(close_dist[2], self.tickSize),
+                        utils.fix_price(close_dist[3], self.tickSize),
+                    ]
+                # 获取开平仓指令
+                signals = dict()
+                # 获取撤单信号
+                signals.update(self._cancel_open())
+                # 获取开仓信号 整点时刻前后不报单
+                if self.check_allow_post_open():
+                    if self.local_time - self.post_open_time > self.post_open_interval:
+                        self.post_open_time = self.local_time
+                        signals.update(self._post_open())
+                # 获取平仓信号
+                signals.update(self._post_close())
+                # 每隔固定时间检查超时订单
+                signals.update(self._check_local_orders())
+                # 更新撤单队列
+                signals = self._update_in_cancel(signals)
+                # 限制频率
+                signals = self._check_request_limit(signals)
+                # 刷新频率限制
+                self._refresh_request_limit()
+                # 统计请求频率
+                self._update_request_num(signals)
+                ##########################
+                return signals
+        except:
+            traceback.print_exc()
+            if self._is_print:self.logger.error(traceback.format_exc())

+ 528 - 0
utils.py

@@ -0,0 +1,528 @@
+import json
+import traceback
+import utils
+import model
+import toml, time, random
+import os, sys, asyncio, aiohttp
+import socket
+import asyncio
+import requests
+import ujson
+from decimal import Decimal
+from decimal import ROUND_HALF_UP, ROUND_FLOOR
+import gzip
+import csv
+import os
+import base64
+from Crypto.Cipher import AES
+from Crypto import Random
+import os
+import base64
+import json
+
+parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 
+sys.path.insert(0,parentdir)  
+
+############### 全局配置
+VERSION = "2022-04-18"
+CHILD_RUN_SECOND = 60 * 60 * 24 # child process max run time per loop 
+EARLY_STOP_SECOND = 60 * 60 * 2 # child early stop min check time
+BACKTEST_PREHOT_SECOND = 60 * 30 # backtest pre hot time
+DUMMY_RUN_SECOND = 60 * 60 * 12 # dummy process max run time per loop 
+DUMMY_EARLY_STOP_SECOND = 60 * 60 # dummy process max run time per loop 
+POST_SIDE_LIMIT = [0] # post side limit
+MARKET_DELAY_LIMIT = 30000 # market update delay limit threhold unit:ms
+GRID = 1
+STOPLOSS = 0.02
+GAMMA = 0.999
+###### market行情数据长度 标准化n档深度+6档成交信息 ######
+LEVEL = 1
+TRADE_LEN = 2 # 最高 最低 成交价
+LEN = LEVEL * 4 + TRADE_LEN # 总长度
+BP_INDEX = LEVEL * 0
+BQ_INDEX = LEVEL * 0 + 1
+AP_INDEX = LEVEL * 2
+AQ_INDEX = LEVEL * 2 + 1
+MAX_FILL_INDEX = LEVEL * 4 + 0
+MIN_FILL_INDEX = LEVEL * 4 + 1
+# BUY_Q_INDEX = LEVEL * 4 + 2
+# BUY_V_INDEX = LEVEL * 4 + 3
+# SELL_Q_INDEX = LEVEL * 4 + 4
+# SELL_V_INDEX = LEVEL * 4 + 5
+#### depth/trade effient range #####
+EFF_RANGE = 0.001
+### init backtest delay ###
+BACKTEST_DELAY = 0.15
+
+global base_cid
+base_cid = 0
+def get_cid(broker=None):
+    global base_cid
+    base_cid += 1
+    if base_cid > 999:
+        base_cid=0
+    cid = str(time.time())[4:10]+str(random.randint(1,999))+str(base_cid)
+    if broker:
+        cid = broker + cid
+    return cid      
+
+def csv_to_gz_and_remove():
+    def List_files(filepath, substr):
+        X = []
+        Y = []
+        for path, subdirs, files in sorted(os.walk(filepath), reverse=True):
+            for name in files:
+                X.append(os.path.join(path, name))
+        Y = [line for line in X if substr in line]
+        return Y
+
+    for file in List_files('./', '.csv'):
+        if '.gz' not in file:
+            data = open(file, 'rb' ).read()
+            with gzip.open(file + '.gz', 'a') as zip:
+                zip.write(data)
+                zip.close()
+            os.remove(file)
+
+def get_params(fname):
+    # 读取配置
+    try:
+        params = toml.load(fname)
+    except:
+        f = open(fname)
+        data = f.read()
+        text = base64.b64decode(data)
+        cryptor = AES.new(key =bytes("qFHFPv6MugrSTkEsWFs8wCDg3iC6!er%".encode()), mode=AES.MODE_ECB)
+        plain_text = cryptor.decrypt(text)
+        paddingLen = plain_text[len(plain_text)-1]
+        msg = plain_text[0:-paddingLen]
+        msg = msg.decode()
+        params = toml.loads(msg)
+    p = model.Config()
+    # 账号昵称
+    p.account_name = params['account_name'] if 'account_name' in params else 'Unknown Account'
+    # api
+    p.access_key = params['access_key'].replace(" ", "") if 'access_key' in params else '***'
+    p.secret_key = params['secret_key'].replace(" ", "") if 'secret_key' in params else '***'
+    p.pass_key = params['pass_key'].replace(" ", "") if 'pass_key' in params else 'qwer1234'
+    # 经纪商id
+    broker_id_from_config = params['broker_id'] if 'broker_id' in params else ""
+    p.broker_id = get_broker_id( broker_id_from_config, params['exchange'])
+    # 交易盘口
+    p.exchange = params['exchange'] if 'exchange' in params else ""
+    # 交易品种
+    p.pair = params['pair'] if 'pair' in params else ""
+    # 调试模式开关
+    p.debug = params['debug'] if 'debug' in params else "False"
+    # 开仓
+    p.open = params['open'] if 'open' in params else "0.002"
+    # 平仓
+    p.close = params['close'] if 'close' in params else "0.0002"
+    # 监听端口
+    p.server_port = params['server_port'] if 'server_port' in params else 6000
+    # 杠杆大小
+    p.leverrate = float(params['leverrate']) if 'leverrate' in params else 1.0
+    # 参考盘口
+    p.refexchange = params['refexchange'].replace('[','').replace(']','').replace("'",'').replace(" ", "").split(',') if "refexchange" in params else ""
+    # 参考品种
+    p.refpair = params['refpair'].replace('[','').replace(']','').replace("'",'').replace(" ", "").split(',') if "refpair" in params else ""
+    # 网络代理
+    p.proxy = params['proxy'] if 'proxy' in params else None # 仅在win下有效
+    # 账户资金使用比例
+    p.used_pct = params['used_pct'] if 'used_pct' in params else "0.9"
+    # discord播报地址
+    p.webhook = params['webhook'] if 'webhook' in params else "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"  
+    # 默认第n参考盘口
+    p.index = params['index'] if 'index' in params else 0
+    # 止损比例 0.02 = 2%
+    p.stoploss = params['stoploss'] if 'stoploss' in params else STOPLOSS
+    # 平滑系数 默认0.999
+    p.gamma = params['gamma'] if 'gamma' in params else GAMMA
+    # 分批建仓功能 小资金建议1 大资金建议3
+    p.grid = params['grid'] if 'grid' in params else GRID
+    # 实时调参开关 会有巨大性能损耗
+    p.backtest = params['backtest'] if 'backtest' in params else 1
+    # 保存实时行情 会有巨大性能损耗
+    p.save = params['save'] if 'save' in params else 0
+    p.place_order_limit = params['place_order_limit'] if 'place_order_limit' in params else 0 # 允许的每秒下单次数
+    # 是否启用colocation技术
+    p.colo = params['colo'] if 'colo' in params else 0 
+    # 是否启用fast行情 会增加性能开销
+    p.fast = params['fast'] if 'fast' in params else 1 
+    # 选择指定的私有ip进行网络通信 默认0 用于多网卡多ip的实例
+    p.ip = params['ip'] if 'ip' in params else 0
+    # 合约不允许holdcoin持有底仓币
+    if "spot" in p.exchange:
+        p.hold_coin = params['hold_coin'] if 'hold_coin' in params else 0.0
+    else:
+        p.hold_coin = 0.0
+    # 是否开启日志记录 会有一定性能损耗
+    p.log = params['log'] if 'log' in params else 1
+    #### 特殊情况处理
+    if p.exchange == 'binance_usdt_swap':
+        if p.pair in ['shib_usdt', 'xec_usdt', 'bttc_usdt']:
+            p.pair = "1000" + p.pair
+    ref_num = len(p.refexchange)
+    for i in range(ref_num):
+        if p.refexchange[i] == 'binance_usdt_swap':
+            if p.refpair[i] in ['shib_usdt', 'xec_usdt', 'bttc_usdt']:
+                p.refpair[i] = "1000" + p.refpair[i]
+    ####
+
+    print('debu11g')
+    print(p)
+    return p
+
+def get_broker_id(broker_id , exchange_name):
+    '''处理brokerid特殊情况'''
+    if 'binance' in exchange_name:
+        return broker_id
+    elif 'gate' in exchange_name:
+        return "t-"
+    else:
+        return ""
+
+# 报单频率限制等级
+BASIC_LIMIT = 100
+GATE_SPOT_LIMIT = 10.0
+GATE_USDT_SWAP_LIMIT = 100.0
+KUCOIN_SPOT_LIMIT = 15.0
+KUCOIN_USDT_SWAP_LIMIT = 10.0
+BINANCE_USDT_SWAP_LIMIT = 5.0
+BINANCE_SPOT_LIMIT = 2.0
+COINEX_SPOT_LIMIT = 20.0
+COINEX_USDT_SWAP_LIMIT = 20.0
+OKEX_USDT_SWAP_LIMIT= 30.0
+BITGET_USDT_SWAP_LIMIT = 10.0
+BYBIT_USDT_SWAP_LIMIT = 1.0
+MEXC_SPOT_LIMIT = 333
+RATIO = 4.0
+
+def get_limit_requests_num_per_second(exchange, limit=0):
+    '''每秒请求频率'''
+    if limit != 0:
+        return limit*RATIO
+    elif exchange == "gate_spot":
+        return GATE_SPOT_LIMIT*RATIO
+    elif exchange == "gate_usdt_swap": # 100/s
+        return GATE_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "kucoin_spot": # 15/s
+        return KUCOIN_SPOT_LIMIT*RATIO
+    elif exchange == "kucoin_usdt_swap":
+        return KUCOIN_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "binance_usdt_swap":
+        return BINANCE_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "binance_spot":
+        return BINANCE_SPOT_LIMIT*RATIO
+    elif exchange == "coinex_spot":
+        return COINEX_SPOT_LIMIT*RATIO
+    elif exchange == "coinex_usdt_swap":
+        return COINEX_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "okex_usdt_swap":
+        return OKEX_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "bitget_usdt_swap":
+        return BITGET_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "bybit_usdt_swap":
+        return BYBIT_USDT_SWAP_LIMIT*RATIO
+    elif exchange == "mexc_spot":
+        return MEXC_SPOT_LIMIT*RATIO
+    else:
+        print("限频规则未找到")
+    return BASIC_LIMIT*RATIO
+
+
+def get_limit_order_requests_num_per_second(exchange, limit=0):
+    '''每秒下单请求频率'''
+    if limit != 0:
+        return limit
+    elif exchange == "gate_spot": # 10/s
+        return GATE_SPOT_LIMIT
+    elif exchange == "gate_usdt_swap": # 100/s
+        return GATE_USDT_SWAP_LIMIT
+    elif exchange == "kucoin_spot": # 15/s
+        return KUCOIN_SPOT_LIMIT
+    elif exchange == "kucoin_usdt_swap": # 10/s
+        return KUCOIN_USDT_SWAP_LIMIT
+    elif exchange == "binance_usdt_swap": # 5/s
+        return BINANCE_USDT_SWAP_LIMIT
+    elif exchange == "binance_spot": # 2/s
+        return BINANCE_SPOT_LIMIT
+    elif exchange == "coinex_spot": # 20/s
+        return COINEX_SPOT_LIMIT
+    elif exchange == "coinex_usdt_swap": # 20/s
+        return COINEX_USDT_SWAP_LIMIT
+    elif exchange == "okex_usdt_swap": # 30/s
+        return OKEX_USDT_SWAP_LIMIT
+    elif exchange == "bitget_usdt_swap": # 10/s
+        return BITGET_USDT_SWAP_LIMIT
+    elif exchange == "bybit_usdt_swap": # 2/s
+        return BYBIT_USDT_SWAP_LIMIT
+    elif exchange == "mexc_spot": # 2/s
+        return MEXC_SPOT_LIMIT
+    else:
+        print("限频规则未找到")
+    return BASIC_LIMIT
+
+def dist_to_weight(price, mp, eff_range=EFF_RANGE):
+    '''
+        距离转换为权重
+    '''
+    dist = abs(price-mp)/mp
+    weight = 1 - clip(dist/eff_range, 0.0, 0.95)
+    weight = weight if weight > 0 else 0
+    return weight
+
+def change_params(fname, params, changes):
+    # 更改配置
+    for i in changes:
+        params[i[0]] = i[1]
+    with open(f"{fname}","w") as f:
+        toml.dump(params,f)
+
+def show_memory(unit='B', threshold=1024):
+    '''查看变量占用内存情况
+
+    :param unit: 显示的单位,可为`B`,`KB`,`MB`,`GB`
+    :param threshold: 仅显示内存数值大于等于threshold的变量
+    '''
+    from sys import getsizeof
+    scale = {'B': 1, 'KB': 1024, 'MB': 1048576, 'GB': 1073741824}[unit]
+    msg = '内存占用情况: \n'
+    for i in list(globals().keys()):
+        memory = eval("getsizeof({})".format(i)) // scale
+        if memory >= threshold:
+            msg += f'{i} {memory} {unit}\n'
+    print(msg)
+    return msg
+
+def clip(num, _min, _max):
+    if num > _max: num = _max
+    if num < _min: num = _min
+    return num
+
+async def ding(msg, at_all, webhook, proxy=None):
+    '''
+        发送钉钉消息
+    '''
+    header = {
+        "Content-Type": "application/json",
+        "Charset": "UTF-8"
+    }
+    embed = {
+        "title": "策略通知",
+        "description": msg
+    }
+    message = {
+    "content": "大吉大利 今晚吃鸡",
+    "username": "千千喵",
+    "embeds": [
+        embed
+            ],
+    }
+    message_json = json.dumps(message)
+    if 'win' in sys.platform:
+        proxy = proxy
+    else:
+        proxy = None
+    async with aiohttp.ClientSession() as session:
+        await session.post(url=webhook, data=message_json, headers=header, proxy=proxy, timeout = 10)
+
+def _get_params(url, proxy, params):
+    '''更新参数'''
+    import requests
+    try:
+        res = requests.post(url=url, json=params, timeout = 10)
+        return json.loads(res.text)
+    except:
+        traceback.print_exc()
+        return []
+
+async def _post_params(url, proxy, params):
+    '''更新参数'''
+    try:
+        if 'win' in sys.platform:
+            proxy = proxy
+        else:
+            proxy = None
+        async with aiohttp.ClientSession() as session:
+            res = await session.post(url=url, proxy=proxy, data=params, timeout = 10)
+            data = await res.text()
+            print(data)
+            return data
+    except:
+        print(traceback.format_exc())
+        return "post_params error"
+    return None
+
+def get_ip():
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        s.connect(('8.8.8.8', 80))
+        ip = s.getsockname()[0]
+    finally:
+        s.close()
+    return ip
+
+def check_auth():
+    print("*** 检查使用权限1 ***")
+    ip = get_ip()
+    print(f"当前IP {ip}")
+    white_list = requests.get(f"http://158.247.204.56:7777/ip_list")
+    if ip in white_list:
+        print("当前IP位于白名单中")
+    else:
+        print("@@@ 本版本仅限指定IP白名单运行 @@@")
+        os._exit(0)
+    print("*** 符合要求 ***")
+
+def check_time():
+    print("*** 检查使用权限2 ***")
+    if time.time() > int(time.mktime(time.strptime('2021-11-17 00:00:00', "%Y-%m-%d %H:%M:%S"))):
+        print("@@@ 此版本目前已过试用期 @@@")
+        os._exit(0)
+    print("*** 符合要求 ***")
+
+def num_to_str(num, d):
+    if d >= 1.0:return "%d"%num
+    elif d in [0.1, 0.5]:return "%.1f"%num
+    elif d in [0.01, 0.05]:return "%.2f"%num
+    elif d in [0.001, 0.005]:return "%.3f"%num
+    elif d in [0.0001, 0.0005]:return "%.4f"%num
+    elif d in [0.00001, 0.00005]:return "%.5f"%num
+    elif d in [0.000001, 0.000005]:return "%.6f"%num
+    elif d in [0.0000001, 0.0000005]:return "%.7f"%num
+    elif d in [0.00000001, 0.00000005]:return "%.8f"%num
+    elif d in [0.000000001, 0.000000005]:return "%.9f"%num
+    elif d in [0.0000000001, 0.0000000005]:return "%.10f"%num
+    else: return str(num)
+
+def num_to_decimal(num):
+    '''根据小数点位数获取精度'''
+    num = str(num)
+    if '.' not in num:return 0
+    elif '.' == num[-2]:return 1
+    elif '.' == num[-3]:return 2
+    elif '.' == num[-4]:return 3
+    elif '.' == num[-5]:return 4
+    elif '.' == num[-6]:return 5
+    elif '.' == num[-7]:return 6
+    elif '.' == num[-8]:return 7
+    elif '.' == num[-9]:return 8
+    elif '.' == num[-10]:return 9
+    elif '.' == num[-11]:return 10
+    else:return 11
+
+def fix_amount(amount, stepSize):
+    '''修补数量向下取整'''
+    return float(
+            Decimal(str(amount))//Decimal(str(stepSize)
+        ) \
+        * Decimal(str(stepSize)))
+    # return float(Decimal(str(amount)).quantize(Decimal(str(stepSize)), ROUND_FLOOR))
+
+
+def fix_price(price, tickSize):
+    '''修补价格四舍五入'''
+    return float(
+            round(Decimal(str(price))/Decimal(str(tickSize))
+        ) \
+        * Decimal(str(tickSize)))
+    # return float(Decimal(str(price)).quantize(Decimal(str(tickSize)), ROUND_HALF_UP))
+
+def timeit(func):
+    def wrapper(*args, **kwargs):
+        nowTime = time.time()
+        res = func(*args, **kwargs)
+        spend_time = time.time() - nowTime
+        spend_time = round(spend_time * 1e6, 3)
+        print(f'{func.__name__} 耗时 {spend_time} us')
+        return res
+    return wrapper
+
+def get_backtest_set(base=""):
+    '''生成预设参数'''
+    # 开仓距离不能太近必须超过大部分价格tick运动的距离
+    open_list = [
+        0.0055,
+        0.0045,
+        0.0035,
+        0.0030,
+        0.0025,
+        0.0020,
+        0.0015,
+    ]
+    close_dict = dict()
+    for open in open_list:
+        close_dict[open] = [
+            open*0.1,
+            open*0.2,
+            ]
+    alpha_list = [0.0]
+    return open_list, close_dict, alpha_list
+
+def get_local_ip_list():
+    '''获取本地ip'''
+    import netifaces as ni
+    ipList = []
+    # print('检测服务器网络配置')
+    for dev in ni.interfaces():
+        print('dev:',dev)
+        if 'ens' in dev or 'eth' in dev or 'enp' in dev:
+            # print(ni.ifaddresses(dev))
+            for i in ni.ifaddresses(dev)[2]:
+                ip=i['addr']
+                print(f"检测到私有ip:{ip}")
+                if ip not in ipList:
+                    ipList.append(ip)
+    print(f"当前服务器私有ip为{ipList}")
+    return ipList
+    
+if __name__ == "__main__":
+
+    #########
+    if 0:
+        print(fix_amount(1.0, 0.1))
+        print(fix_amount(0.9, 0.05))
+        print(fix_amount(1.1, 0.1))
+        print(fix_amount(1.2, 0.5))
+        print(fix_amount(0.01, 0.05))
+    if 1:
+        print(fix_price(1.0, 0.1))
+        print(fix_price(0.9, 2.0))
+        print(fix_price(1.1, 0.1))
+        print(fix_price(1.2, 0.5))
+        print(fix_price(4999.99, 0.5))
+    #########
+    if 0:
+        # print(num_to_str(123.123))
+        print(get_backtest_set())
+
+    ####################
+    if 0:
+
+        p = get_params("config.toml")
+
+        loop = asyncio.get_event_loop()
+        
+        # loop.create_task(ding("123", 1, "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"))
+        
+        loop.create_task(
+            _post_params(
+                "http://wwww.khods.com:8888/post_params", 
+                None,
+                ujson.dumps({
+                    "exchange":"binance_usdt_swap",
+                    "pair":"eth_usdt",
+                    "open":"0.001",
+                    "close":"0.0001",
+                    "refexchange":"binance_spot",
+                    "refpair":"eth_usdt",
+                    "profit":0.1,
+                })
+            )
+        )
+
+        loop.run_forever()
+    
+    ####################
+