浏览代码

最佳价差版本

JiahengHe 1 年之前
父节点
当前提交
09968497cf

+ 0 - 17
backtest.py

@@ -2,27 +2,10 @@
     基于深度信息的异步事件触发回测框架
     基于深度信息的异步事件触发回测框架
     作者 千千量化
     作者 千千量化
 '''
 '''
-import os
-import sys
-import asyncio
-import json, ujson
-import requests
-import hmac
-import base64
-import zlib
-import datetime
-import time
 import strategy as strategy
 import strategy as strategy
-import traceback
-import joblib
-import traceback
-import configparser
-import signal
-import random
 import utils
 import utils
 import model
 import model
 import time
 import time
-import copy
 import random
 import random
 
 
 def timeit(func):
 def timeit(func):

+ 9 - 9
config.toml

@@ -1,23 +1,23 @@
 broker_id = "kucoin"
 broker_id = "kucoin"
 account_name = "ku1"
 account_name = "ku1"
-access_key = ""
-secret_key = ""
-pass_key = ""
+access_key = "6393f3565f0d4500011f846b"
+secret_key = "9c0df8b7-daaa-493e-a53a-82703067f7dd"
+pass_key = "b87d055f"
 exchange = "kucoin_usdt_swap"
 exchange = "kucoin_usdt_swap"
-pair = "loom_usdt"
+pair = "meme_usdt"
 debug = "False"
 debug = "False"
-open = 0.02
+open = 0.01
 close = 0.0002
 close = 0.0002
 leverrate = 0.5
 leverrate = 0.5
 interval = 0.1
 interval = 0.1
-refexchange = "['binance_usdt_swap']"
-refpair = "['loom_usdt']"
-used_pct = 1
+refexchange = "['kucoin_usdt_swap']"
+refpair = "['meme_usdt']"
+used_pct = 0.9
 index = 0
 index = 0
 save = 0
 save = 0
 hold_coin = 0.0
 hold_coin = 0.0
 log = 1
 log = 1
-stoploss = "0.02"
+stoploss = "0.01"
 gamma = 0.999
 gamma = 0.999
 grid = 1
 grid = 1
 place_order_limit = 0
 place_order_limit = 0

+ 13 - 9
exchange/binance_usdt_swap_ws.py

@@ -1,17 +1,21 @@
-import aiohttp
-import time
 import asyncio
 import asyncio
-import zlib
-import json, ujson
-import zlib
+import csv
 import hashlib
 import hashlib
 import hmac
 import hmac
-import base64
+import logging
+import logging.handlers
+import random
+import sys
+import time
 import traceback
 import traceback
-import random, csv, sys, utils
-import logging, logging.handlers
+import utils
+import zlib
+
+import aiohttp
+import ujson
+
 import model
 import model
-from loguru import logger
+
 
 
 def empty_call(msg):
 def empty_call(msg):
     pass
     pass

+ 3 - 11
exchange/bybit_usdt_swap_rest.py

@@ -1,25 +1,17 @@
-import imp
-import random
-import re
 import aiohttp
 import aiohttp
 import time
 import time
 import asyncio
 import asyncio
-import zlib
-import json, ujson
+import ujson
 import hmac
 import hmac
-import base64
 import hashlib
 import hashlib
 import traceback
 import traceback
-import urllib
-from urllib import parse
 from urllib.parse import urljoin
 from urllib.parse import urljoin
-import datetime, sys
+import sys
 from urllib.parse import urlparse
 from urllib.parse import urlparse
-import logging, logging.handlers
+import logging.handlers
 import utils
 import utils
 import logging, logging.handlers
 import logging, logging.handlers
 import model
 import model
-from decimal import Decimal
 
 
 def empty_call(msg):
 def empty_call(msg):
     print(f'空的回调函数 {msg}')
     print(f'空的回调函数 {msg}')

+ 2 - 6
exchange/kucoin_usdt_swap_rest.py

@@ -1,19 +1,15 @@
-import random
 import aiohttp
 import aiohttp
 import time
 import time
 import asyncio
 import asyncio
-import zlib
 import json
 import json
 import hmac
 import hmac
 import base64
 import base64
 import hashlib
 import hashlib
 import traceback
 import traceback
-import urllib
-from urllib import parse
 from urllib.parse import urljoin
 from urllib.parse import urljoin
-import datetime, sys
+import sys
 from urllib.parse import urlparse
 from urllib.parse import urlparse
-import logging, logging.handlers
+import logging.handlers
 import utils
 import utils
 import logging, logging.handlers
 import logging, logging.handlers
 import model
 import model

+ 8 - 6
exchange/kucoin_usdt_swap_ws.py

@@ -1,21 +1,17 @@
 import aiohttp
 import aiohttp
 import time
 import time
 import asyncio
 import asyncio
-import zlib
-import json, ujson
-import zlib
+import ujson
 import hashlib
 import hashlib
 import hmac
 import hmac
 import base64
 import base64
 import traceback
 import traceback
-import random
-import gzip, csv, sys
+import csv, sys
 from uuid import uuid4
 from uuid import uuid4
 import logging, logging.handlers
 import logging, logging.handlers
 import utils
 import utils
 import model
 import model
 from decimal import Decimal
 from decimal import Decimal
-from loguru import logger
 
 
 def empty_call(msg):
 def empty_call(msg):
     # print(msg)
     # print(msg)
@@ -45,6 +41,7 @@ class KucoinUsdtSwapWs:
             "onOrder":empty_call,
             "onOrder":empty_call,
             "onTicker":empty_call,
             "onTicker":empty_call,
             "onDepth":empty_call,
             "onDepth":empty_call,
+            "onTrade":empty_call,
             "onExit":empty_call,
             "onExit":empty_call,
             }
             }
         self.is_print = is_print
         self.is_print = is_print
@@ -120,6 +117,7 @@ class KucoinUsdtSwapWs:
             self.update_t = msg['data']['sequence']
             self.update_t = msg['data']['sequence']
             self.ticker_info['bp'] = float(msg['data']['bids'][0][0])
             self.ticker_info['bp'] = float(msg['data']['bids'][0][0])
             self.ticker_info['ap'] = float(msg['data']['asks'][0][0])
             self.ticker_info['ap'] = float(msg['data']['asks'][0][0])
+            self.ticker_info['time'] = msg['data']['timestamp']
             self.callback['onTicker'](self.ticker_info)
             self.callback['onTicker'](self.ticker_info)
             ##### 标准化深度
             ##### 标准化深度
             mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
             mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
@@ -167,6 +165,7 @@ class KucoinUsdtSwapWs:
             self.update_t = msg['data']['sequence']
             self.update_t = msg['data']['sequence']
             self.ticker_info['bp'] = float(msg['data']['bestBidPrice'])
             self.ticker_info['bp'] = float(msg['data']['bestBidPrice'])
             self.ticker_info['ap'] = float(msg['data']['bestAskPrice'])
             self.ticker_info['ap'] = float(msg['data']['bestAskPrice'])
+            self.ticker_info['time'] = msg['data']['ts']
             self.callback['onTicker'](self.ticker_info)
             self.callback['onTicker'](self.ticker_info)
 
 
             bp = float(msg['data']['bestBidPrice'])
             bp = float(msg['data']['bestBidPrice'])
@@ -190,6 +189,8 @@ class KucoinUsdtSwapWs:
         elif side == 'sell':
         elif side == 'sell':
             self.sell_q += amount
             self.sell_q += amount
             self.sell_v += amount*price
             self.sell_v += amount*price
+        self.callback['onTrade']({'timestamp': msg["data"]['ts'], 'price': price, 'amount': amount, 'side': side})
+
 
 
     def _update_position(self, msg):
     def _update_position(self, msg):
         pos = model.Position()
         pos = model.Position()
@@ -348,6 +349,7 @@ class KucoinUsdtSwapWs:
                     channels=[
                     channels=[
                         # f"/contractMarket/tickerV2:{self.symbol}",
                         # f"/contractMarket/tickerV2:{self.symbol}",
                         f"/contractMarket/level2Depth50:{self.symbol}",
                         f"/contractMarket/level2Depth50:{self.symbol}",
+                        f"/contractMarket/execution:{self.symbol}"
                         ]
                         ]
                     if sub_trade:
                     if sub_trade:
                         channels += [f"/contractMarket/execution:{self.symbol}"]
                         channels += [f"/contractMarket/execution:{self.symbol}"]

+ 169 - 0
predictor_new.py

@@ -0,0 +1,169 @@
+import time
+from decimal import Decimal
+
+import utils
+from util.instant_volatility import InstantVolatilityIndicator
+from util.trading_intensity import TradingIntensityIndicator
+
+
+class PredictorNew:
+    '''
+        reference
+    '''
+
+    def __init__(self, ref_name=["Unknown Market"], alpha=[1.0 for _ in range(99)], gamma=0.999):
+        self.loop = 0
+        self.arr = []
+        self.trade_mp_series = []
+        self.ref_mp_series = []
+        self.ref_num = len(ref_name)
+        ### 定价
+        self.window = 10
+        ### 价格系数
+        self.alpha = alpha
+        # 参考
+        print('定价系数:', gamma)
+        self.gamma = gamma
+        self.avg_spread = [None for _ in range(self.ref_num)]
+        self.vol = InstantVolatilityIndicator(sampling_length=100, processing_length=1)
+        self.trading_intensity = TradingIntensityIndicator(sampling_length=20)
+
+
+    def processer(self):
+        '''
+            计算任务
+        '''
+        # update trade mp
+        bp = self.arr[-1][utils.BP_INDEX]
+        ap = self.arr[-1][utils.AP_INDEX]
+        mp = (bp + ap) * 0.5
+        self.trade_mp_series.append(mp)
+        # 更新参考盘口mp
+        ref_mp = []
+        for i in range(self.ref_num):
+            bp = self.arr[-1][utils.LEN * (1 + i) + utils.BP_INDEX]
+            ap = self.arr[-1][utils.LEN * (1 + i) + utils.AP_INDEX]
+            mp = (bp + ap) * 0.5
+            ref_mp.append(mp)
+        self.ref_mp_series.append(ref_mp)
+        # 偏差计算
+        self._update_avg_spread()
+
+    def _update_avg_spread(self):
+        '''
+            更新平均偏差
+        '''
+        # 计算偏差1
+        for i in range(self.ref_num):
+            bias = self.ref_mp_series[-1][i] * self.alpha[i] - self.trade_mp_series[-1]
+            # 如果是刚启动 gamma不能太大
+            if self.loop < 100:
+                gamma = 0.9
+            else:
+                gamma = self.gamma
+            if self.avg_spread[i] == None:
+                self.avg_spread[i] = bias
+            else:
+                self.avg_spread[i] = self.avg_spread[i] * gamma + bias * (1 - gamma)
+
+    def check_length(self):
+        # 行情缓存
+        if len(self.arr) > self.window: del (self.arr[0])
+        if len(self.trade_mp_series) > self.window: del (self.trade_mp_series[0])
+        if len(self.ref_mp_series) > self.window: del (self.ref_mp_series[0])
+
+    # @utils.timeit
+    def onTime(self, data):
+        if isinstance(data, list):
+            if len(data) > 0:
+                self.loop += 1
+                self.arr.append(data)
+                self.processer()
+                self.check_length()
+            else:
+                print("行情数据为空")
+        else:
+            print("行情数据为None")
+
+    def market_update(self, best_mid_price, time_num):
+        self.vol.add_sample(best_mid_price)
+        self.trading_intensity.calculate(best_mid_price, time_num)
+
+    # @utils.timeit
+    # def Get_ref(self, ref_ticker):
+    #     '''
+    #         get ref price
+    #     '''
+    #     ref_mid = []
+    #     for i in range(self.ref_num):
+    #         ref_mid.append(
+    #             [ref_ticker[i][0] * self.alpha[i] - self.avg_spread[i],
+    #              ref_ticker[i][1] * self.alpha[i] - self.avg_spread[i]]
+    #         )
+    #     return ref_mid
+
+    def Get_ref(self, ref_ticker):
+        '''
+            get ref price
+        '''
+        bp = ref_ticker[0][0]
+        ap = ref_ticker[0][1]
+        mp = Decimal((bp + ap) * 0.5)
+        std = self.vol.current_value
+        gamma = Decimal(0.5)
+        alpha, kappa = self.trading_intensity.current_value
+        ref_mid = []
+        if alpha != 0 and kappa > 0 and std != 0:
+            factor = Decimal(1) + gamma / kappa
+            _optimal_spread = gamma * std
+            _optimal_spread += 2 * factor.ln()
+            _optimal_ask = mp + _optimal_spread / 2
+            _optimal_bid = mp - _optimal_spread / 2
+            ref_mid.append(
+                [_optimal_ask,
+                 _optimal_bid]
+            )
+        return ref_mid
+
+if __name__ == "__main__":
+
+    import pandas as pd
+    import matplotlib
+
+    matplotlib.use('TkAgg')
+    import matplotlib.pyplot as plt
+
+    arr = pd.read_csv('history/ftm_usdt_binance_usdt_swap.csv').values.tolist()
+
+
+    def line_data_to_tickers(data, ref_num):
+        ref_tickers = []
+        for i in range(ref_num):
+            ref_tickers.append([data[utils.LEN * (i + 1) + utils.BP_INDEX], data[utils.LEN * (i + 1) + utils.AP_INDEX]])
+        return ref_tickers
+
+
+    ref_num = len(arr[0]) // utils.LEN - 1
+
+    p = PredictorNew(ref_name=["unkwon" for _ in range(ref_num)])
+    t = []
+    ref_index = 1
+
+    for data in arr:
+        p.onTime(data)
+        trade_mp = (data[utils.BP_INDEX] + data[utils.AP_INDEX]) * 0.5
+        ref_price = p.Get_ref(line_data_to_tickers(data, ref_num))
+        t.append([
+            trade_mp,
+            (ref_price[ref_index][0] + ref_price[ref_index][1]) * 0.5,
+        ])
+
+    t = pd.DataFrame(t, columns=['mp', 'ref'])
+
+    if 1:
+        plt.figure()
+        plt.plot(t['mp'], 'k')
+        plt.plot(t['ref'], 'g')
+        plt.grid()
+        plt.show()
+

+ 26 - 11
quant.py

@@ -8,15 +8,13 @@ import model
 import logging, logging.handlers
 import logging, logging.handlers
 import signal
 import signal
 import os, json, sys
 import os, json, sys
-import predictor
+import predictor_new
 import backtest
 import backtest
-import multiprocessing
 import random
 import random
 import psutil
 import psutil
 import ujson
 import ujson
 import broker
 import broker
 from decimal import Decimal
 from decimal import Decimal
-from loguru import logger
 
 
 VERSION = utils.VERSION
 VERSION = utils.VERSION
 
 
@@ -202,6 +200,7 @@ class Quant:
             self.ws_ref[self.ref_name[i]] = broker.newWs(exchange)(cp, colo=_colo)
             self.ws_ref[self.ref_name[i]] = broker.newWs(exchange)(cp, colo=_colo)
             self.ws_ref[self.ref_name[i]].callback['onTicker']=self.update_ticker
             self.ws_ref[self.ref_name[i]].callback['onTicker']=self.update_ticker
             self.ws_ref[self.ref_name[i]].callback['onDepth']=self.update_depth
             self.ws_ref[self.ref_name[i]].callback['onDepth']=self.update_depth
+            self.ws_ref[self.ref_name[i]].callback['onTrade'] = self.update_trade
             self.ws_ref[self.ref_name[i]].logger = self.logger
             self.ws_ref[self.ref_name[i]].logger = self.logger
         # 添加回调
         # 添加回调
         self.ws.callback = {
         self.ws.callback = {
@@ -210,6 +209,7 @@ class Quant:
             'onPosition':self.update_position,
             'onPosition':self.update_position,
             'onEquity':self.update_equity,
             'onEquity':self.update_equity,
             'onOrder':self.update_order,
             'onOrder':self.update_order,
+            'onTrade':self.update_trade,
             'onExit':self.update_exit,
             'onExit':self.update_exit,
             }
             }
         self.rest.callback = {
         self.rest.callback = {
@@ -218,6 +218,7 @@ class Quant:
             'onPosition':self.update_position,
             'onPosition':self.update_position,
             'onEquity':self.update_equity,
             'onEquity':self.update_equity,
             'onOrder':self.update_order,
             'onOrder':self.update_order,
+            'onTrade': self.update_trade,
             'onExit':self.update_exit,
             'onExit':self.update_exit,
             }
             }
         self.rest.logger = self.logger
         self.rest.logger = self.logger
@@ -236,7 +237,7 @@ class Quant:
             # 交易shib 参考 shib
             # 交易shib 参考 shib
                 price_alpha.append(1.0)
                 price_alpha.append(1.0)
         self.logger.info(f'价格系数{price_alpha}')
         self.logger.info(f'价格系数{price_alpha}')
-        self.Predictor = predictor.Predictor(ref_name=self.ref_name, alpha=price_alpha, gamma=float(self.params.gamma))
+        self.Predictor = predictor_new.PredictorNew(ref_name=self.ref_name, alpha=price_alpha, gamma=float(self.params.gamma))
         # 初始化参数
         # 初始化参数
         self.strategy.trade_open_dist = float(self.params.open)
         self.strategy.trade_open_dist = float(self.params.open)
         self.strategy.trade_close_dist = float(self.params.close)
         self.strategy.trade_close_dist = float(self.params.close)
@@ -574,6 +575,10 @@ class Quant:
             update ticker infomation
             update ticker infomation
         '''
         '''
         name = data['name']
         name = data['name']
+        if name == self.ref_name[0]:
+            min_price = (data['bp'] + data['ap']) * 0.5
+            self.Predictor.market_update(min_price, data['time'])
+
         # 记录tick更新时间
         # 记录tick更新时间
         # self.market_update_time[name] = time.time()
         # self.market_update_time[name] = time.time()
         self.tickers[name] = data
         self.tickers[name] = data
@@ -585,6 +590,13 @@ class Quant:
         else:
         else:
             pass
             pass
 
 
+    def update_trade(self, data):
+        self.loop.create_task(self._update_trade(data))
+
+    async def _update_trade(self, data):
+        print(f"trade数据:{data}")
+        self.Predictor.trading_intensity.c_register_trade(data)
+
     # @utils.timeit
     # @utils.timeit
     async def _update_depth(self, data):
     async def _update_depth(self, data):
         '''
         '''
@@ -930,10 +942,12 @@ class Quant:
         ref_tickers = []
         ref_tickers = []
         for i in self.ref_name:
         for i in self.ref_name:
             ref_tickers.append([self.tickers[i]['bp'], self.tickers[i]['ap']])
             ref_tickers.append([self.tickers[i]['bp'], self.tickers[i]['ap']])
-        self.tradeMsg.ref_price = self.Predictor.Get_ref(ref_tickers)
+        ref_price = self.Predictor.Get_ref(ref_tickers)
+        if len(ref_price) == 0:
+            return
         # logger.info('ref_price={}, market={}, predict={}'.format(
         # logger.info('ref_price={}, market={}, predict={}'.format(
         #     self.tradeMsg.ref_price, self.tradeMsg.market, self.tradeMsg.predict))
         #     self.tradeMsg.ref_price, self.tradeMsg.market, self.tradeMsg.predict))
-
+        self.tradeMsg.ref_price = ref_price
     async def server_handle(self, request):
     async def server_handle(self, request):
         '''中控数据接口'''
         '''中控数据接口'''
         if 'spot' in self.exchange:
         if 'spot' in self.exchange:
@@ -1112,11 +1126,12 @@ class Quant:
                 self.exit_msg = msg
                 self.exit_msg = msg
                 self.stop()
                 self.stop()
         ###### 定价异常风控 ######
         ###### 定价异常风控 ######
-        if abs(self.strategy.ref_price-self.strategy.mp)/self.strategy.mp > 0.03:
-            msg = f"{self.acct_name} 定价偏离过大 怀疑异常 退出"
-            self.logger.error(msg)
-            self.exit_msg = msg
-            self.stop()
+        if self.strategy.ref_price is not None and self.strategy.ref_price != 0:
+            if abs(self.strategy.ref_price-self.strategy.mp)/self.strategy.mp > 0.03:
+                msg = f"{self.acct_name} 定价偏离过大 怀疑异常 退出"
+                self.logger.error(msg)
+                self.exit_msg = msg
+                self.stop()
 
 
     async def exit(self, delay=0):
     async def exit(self, delay=0):
         '''退出操作'''
         '''退出操作'''

+ 3 - 5
strategy.py

@@ -4,7 +4,6 @@ import model
 import logging, logging.handlers
 import logging, logging.handlers
 from decimal import Decimal
 from decimal import Decimal
 from decimal import ROUND_HALF_UP, ROUND_FLOOR
 from decimal import ROUND_HALF_UP, ROUND_FLOOR
-from loguru import logger
 
 
 class Strategy:
 class Strategy:
     '''
     '''
@@ -194,10 +193,9 @@ class Strategy:
                 self.maker_mode = 'follow'
                 self.maker_mode = 'follow'
 
 
             ###### ref price ######
             ###### ref price ######
-            if data.ref_price == None:
-                self.ref_bp = self.bp
-                self.ref_ap = self.ap
-                self.ref_price = self.mp
+            if data.ref_price is None or len(data.ref_price) == 0:
+                print('参考价格还未预热完成,等待预热...')
+                return 0
             else:
             else:
                 self.ref_bp = data.ref_price[self.ref_index][0]
                 self.ref_bp = data.ref_price[self.ref_index][0]
                 self.ref_ap = data.ref_price[self.ref_index][1]
                 self.ref_ap = data.ref_price[self.ref_index][1]

+ 0 - 0
util/__init__.py


+ 66 - 0
util/base_trailing_indicator.py

@@ -0,0 +1,66 @@
+import logging
+from abc import ABC, abstractmethod
+from decimal import Decimal
+
+import numpy as np
+
+from util.ring_buffer import RingBuffer
+
+
+class BaseTrailingIndicator(ABC):
+
+    def __init__(self, sampling_length: int = 30, processing_length: int = 15):
+        self._sampling_buffer = RingBuffer(sampling_length)
+        self._processing_buffer = RingBuffer(processing_length)
+        self._samples_length = 0
+
+    def add_sample(self, value: float):
+        self._sampling_buffer.add_value(value)
+        indicator_value = self._indicator_calculation()
+        self._processing_buffer.add_value(indicator_value)
+
+    @abstractmethod
+    def _indicator_calculation(self) -> float:
+        raise NotImplementedError
+
+    def _processing_calculation(self) -> float:
+        """
+        Processing of the processing buffer to return final value.
+        Default behavior is buffer average
+        """
+        return np.mean(self._processing_buffer.get_as_numpy_array())
+
+    @property
+    def current_value(self) -> Decimal:
+        return Decimal(self._processing_calculation())
+
+    @property
+    def is_sampling_buffer_full(self) -> bool:
+        return self._sampling_buffer.is_full
+
+    @property
+    def is_processing_buffer_full(self) -> bool:
+        return self._processing_buffer.is_full
+
+    @property
+    def is_sampling_buffer_changed(self) -> bool:
+        buffer_len = len(self._sampling_buffer.get_as_numpy_array())
+        is_changed = self._samples_length != buffer_len
+        self._samples_length = buffer_len
+        return is_changed
+
+    @property
+    def sampling_length(self) -> int:
+        return self._sampling_buffer.length
+
+    @sampling_length.setter
+    def sampling_length(self, value):
+        self._sampling_buffer.length = value
+
+    @property
+    def processing_length(self) -> int:
+        return self._processing_buffer.length
+
+    @processing_length.setter
+    def processing_length(self, value):
+        self._processing_buffer.length = value

+ 19 - 0
util/instant_volatility.py

@@ -0,0 +1,19 @@
+from .base_trailing_indicator import BaseTrailingIndicator
+import numpy as np
+
+
+class InstantVolatilityIndicator(BaseTrailingIndicator):
+    def __init__(self, sampling_length: int = 30, processing_length: int = 15):
+        super().__init__(sampling_length, processing_length)
+
+    def _indicator_calculation(self) -> float:
+        # The standard deviation should be calculated between ticks and not with a mean of the whole buffer
+        # Otherwise if the asset is trending, changing the length of the buffer would result in a greater volatility as more ticks would be further away from the mean
+        # which is a nonsense result. If volatility of the underlying doesn't change in fact, changing the length of the buffer shouldn't change the result.
+        np_sampling_buffer = self._sampling_buffer.get_as_numpy_array()
+        vol = np.sqrt(np.sum(np.square(np.diff(np_sampling_buffer))) / np_sampling_buffer.size)
+        return vol
+
+    def _processing_calculation(self) -> float:
+        # Only the last calculated volatlity, not an average of multiple past volatilities
+        return self._processing_buffer.get_last_value()

+ 99 - 0
util/ring_buffer.py

@@ -0,0 +1,99 @@
+import numpy as np
+
+
+class RingBuffer:
+    def __init__(self, length=20):
+        self._length = length
+        self._buffer = np.zeros(length, dtype=np.float64)
+        self._delimiter = 0
+        self._is_full = False
+
+    def __dealloc__(self):
+        self._buffer = None
+
+    def c_add_value(self, value):
+        self._buffer[self._delimiter] = value
+        self.c_increment_delimiter()
+
+    def c_increment_delimiter(self):
+        self._delimiter = (self._delimiter + 1) % self._length
+        if not self._is_full and self._delimiter == 0:
+            self._is_full = True
+
+    def c_is_empty(self):
+        return (not self._is_full) and (0 == self._delimiter)
+
+    def c_get_last_value(self):
+        if self.c_is_empty():
+            return np.nan
+        return self._buffer[self._delimiter - 1]
+
+    def c_is_full(self):
+        return self._is_full
+
+    def c_mean_value(self):
+        result = np.nan
+        if self._is_full:
+            result = np.mean(self.c_get_as_numpy_array())
+        return result
+
+    def c_variance(self):
+        result = np.nan
+        if self._is_full:
+            result = np.var(self.c_get_as_numpy_array())
+        return result
+
+    def c_std_dev(self):
+        result = np.nan
+        if self._is_full:
+            result = np.std(self.c_get_as_numpy_array())
+        return result
+
+    def c_get_as_numpy_array(self):
+        if not self._is_full:
+            indexes = np.arange(0, stop=self._delimiter, dtype=np.int16)
+        else:
+            indexes = np.arange(self._delimiter, stop=self._delimiter + self._length,
+                                dtype=np.int16) % self._length
+        return np.asarray(self._buffer)[indexes]
+
+    def add_value(self, val):
+        self.c_add_value(val)
+
+    def get_as_numpy_array(self):
+        return self.c_get_as_numpy_array()
+
+    def get_last_value(self):
+        return self.c_get_last_value()
+
+    @property
+    def is_full(self):
+        return self.c_is_full()
+
+    @property
+    def mean_value(self):
+        return self.c_mean_value()
+
+    @property
+    def std_dev(self):
+        return self.c_std_dev()
+
+    @property
+    def variance(self):
+        return self.c_variance()
+
+    @property
+    def length(self) -> int:
+        return self._length
+
+    @length.setter
+    def length(self, value):
+        data = self.get_as_numpy_array()
+
+        self._length = value
+        self._buffer = np.zeros(value, dtype=np.float64)
+        self._delimiter = 0
+        self._is_full = False
+
+        for val in data[-value:]:
+            self.add_value(val)

+ 135 - 0
util/trading_intensity.py

@@ -0,0 +1,135 @@
+from decimal import Decimal
+from typing import Tuple
+
+import numpy as np
+from scipy.optimize import curve_fit
+
+
+class TradingIntensityIndicator:
+
+    def __init__(self, sampling_length: int = 30):
+        self._alpha = Decimal(0)
+        self._kappa = Decimal(0)
+        self._trade_samples = {}
+        self._current_trade_sample = []
+        self._sampling_length = sampling_length
+        self._samples_length = 0
+        self._last_quotes = []
+
+    @property
+    def current_value(self) -> Tuple[Decimal, Decimal]:
+        return self._alpha, self._kappa
+
+    @property
+    def is_sampling_buffer_full(self) -> bool:
+        return len(self._trade_samples.keys()) == self._sampling_length
+
+    @property
+    def is_sampling_buffer_changed(self) -> bool:
+        is_changed = self._samples_length != len(self._trade_samples.keys())
+        self._samples_length = len(self._trade_samples.keys())
+        return is_changed
+
+    @property
+    def sampling_length(self) -> int:
+        return self._sampling_length
+
+    @sampling_length.setter
+    def sampling_length(self, new_len: int):
+        self._sampling_length = new_len
+
+    @property
+    def last_quotes(self) -> list:
+        """A helper method to be used in unit tests"""
+        return self._last_quotes
+
+    @last_quotes.setter
+    def last_quotes(self, value):
+        """A helper method to be used in unit tests"""
+        self._last_quotes = value
+
+    def calculate(self, price, timestamp):
+        """A helper method to be used in unit tests"""
+        self.c_calculate(price, timestamp)
+
+    def c_calculate(self, price, timestamp):
+        # Descending order of price-timestamp quotes
+        self._last_quotes = [{'timestamp': timestamp, 'price': price}] + self._last_quotes
+
+        latest_processed_quote_idx = None
+        for trade in self._current_trade_sample:
+            for i, quote in enumerate(self._last_quotes):
+                if quote["timestamp"] < trade['timestamp']:
+                    if latest_processed_quote_idx is None or i < latest_processed_quote_idx:
+                        latest_processed_quote_idx = i
+                    trade = {"price_level": abs(trade["price"] - float(quote["price"])), "amount": trade["amount"]}
+
+                    if quote["timestamp"] + 1 not in self._trade_samples.keys():
+                        self._trade_samples[quote["timestamp"] + 1] = []
+
+                    self._trade_samples[quote["timestamp"] + 1] += [trade]
+                    break
+
+        # THere are no trades left to process
+        self._current_trade_sample = []
+        # Store quotes that happened after the latest trade + one before
+        if latest_processed_quote_idx is not None:
+            self._last_quotes = self._last_quotes[0:latest_processed_quote_idx + 1]
+
+        if len(self._trade_samples.keys()) > self._sampling_length:
+            timestamps = list(self._trade_samples.keys())
+            timestamps.sort()
+            timestamps = timestamps[-self._sampling_length:]
+
+            trade_samples = {}
+            for timestamp in timestamps:
+                trade_samples[timestamp] = self._trade_samples[timestamp]
+            self._trade_samples = trade_samples
+
+        if self.is_sampling_buffer_full:
+            self.c_estimate_intensity()
+
+    def register_trade(self, trade):
+        """A helper method to be used in unit tests"""
+        self.c_register_trade(trade)
+
+    def c_register_trade(self, trade):
+        self._current_trade_sample.append(trade)
+
+    def c_estimate_intensity(self):
+
+        # Calculate lambdas / trading intensities
+        lambdas = []
+
+        trades_consolidated = {}
+        price_levels = []
+        for timestamp in self._trade_samples.keys():
+            tick = self._trade_samples[timestamp]
+            for trade in tick:
+                if trade['price_level'] not in trades_consolidated.keys():
+                    trades_consolidated[trade['price_level']] = 0
+                    price_levels += [trade['price_level']]
+
+                trades_consolidated[trade['price_level']] += trade['amount']
+
+        price_levels = sorted(price_levels, reverse=True)
+
+        for price_level in price_levels:
+            lambdas += [trades_consolidated[price_level]]
+
+        # Adjust to be able to calculate log
+        lambdas_adj = [10 ** -10 if x == 0 else x for x in lambdas]
+
+        # Fit the probability density function; reuse previously calculated parameters as initial values
+        try:
+            params = curve_fit(lambda t, a, b: a * np.exp(-b * t),
+                               price_levels,
+                               lambdas_adj,
+                               p0=(self._alpha, self._kappa),
+                               method='dogbox',
+                               bounds=([0, 0], [np.inf, np.inf]))
+
+            self._kappa = Decimal(str(params[0][1]))
+            self._alpha = Decimal(str(params[0][0]))
+        except (RuntimeError, ValueError) as e:
+            pass

+ 159 - 159
utils.py

@@ -1,44 +1,35 @@
-import json
 import traceback
 import traceback
-import utils
 import model
 import model
 import toml, time, random
 import toml, time, random
-import os, sys, asyncio, aiohttp
+import sys, aiohttp
 import socket
 import socket
-import asyncio
 import requests
 import requests
-import ujson
 from decimal import Decimal
 from decimal import Decimal
-from decimal import ROUND_HALF_UP, ROUND_FLOOR
 import gzip
 import gzip
-import csv
-import os
-import base64
 from Crypto.Cipher import AES
 from Crypto.Cipher import AES
-from Crypto import Random
 import os
 import os
 import base64
 import base64
 import json
 import json
 
 
-parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 
-sys.path.insert(0,parentdir)  
+parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, parentdir)
 
 
 ############### 全局配置
 ############### 全局配置
 VERSION = "2022-04-18"
 VERSION = "2022-04-18"
-CHILD_RUN_SECOND = 60 * 60 * 24 # child process max run time per loop 
-EARLY_STOP_SECOND = 60 * 60 * 2 # child early stop min check time
-BACKTEST_PREHOT_SECOND = 60 * 30 # backtest pre hot time
-DUMMY_RUN_SECOND = 60 * 60 * 12 # dummy process max run time per loop 
-DUMMY_EARLY_STOP_SECOND = 60 * 60 # dummy process max run time per loop 
-POST_SIDE_LIMIT = [0] # post side limit
-MARKET_DELAY_LIMIT = 30000 # market update delay limit threhold unit:ms
+CHILD_RUN_SECOND = 60 * 60 * 24  # child process max run time per loop
+EARLY_STOP_SECOND = 60 * 60 * 2  # child early stop min check time
+BACKTEST_PREHOT_SECOND = 60 * 30  # backtest pre hot time
+DUMMY_RUN_SECOND = 60 * 60 * 12  # dummy process max run time per loop
+DUMMY_EARLY_STOP_SECOND = 60 * 60  # dummy process max run time per loop
+POST_SIDE_LIMIT = [0]  # post side limit
+MARKET_DELAY_LIMIT = 30000  # market update delay limit threhold unit:ms
 GRID = 1
 GRID = 1
 STOPLOSS = 0.02
 STOPLOSS = 0.02
 GAMMA = 0.999
 GAMMA = 0.999
 ###### market行情数据长度 标准化n档深度+6档成交信息 ######
 ###### market行情数据长度 标准化n档深度+6档成交信息 ######
 LEVEL = 1
 LEVEL = 1
-TRADE_LEN = 2 # 最高 最低 成交价
-LEN = LEVEL * 4 + TRADE_LEN # 总长度
+TRADE_LEN = 2  # 最高 最低 成交价
+LEN = LEVEL * 4 + TRADE_LEN  # 总长度
 BP_INDEX = LEVEL * 0
 BP_INDEX = LEVEL * 0
 BQ_INDEX = LEVEL * 0 + 1
 BQ_INDEX = LEVEL * 0 + 1
 AP_INDEX = LEVEL * 2
 AP_INDEX = LEVEL * 2
@@ -56,15 +47,18 @@ BACKTEST_DELAY = 0.15
 
 
 global base_cid
 global base_cid
 base_cid = 0
 base_cid = 0
+
+
 def get_cid(broker=None):
 def get_cid(broker=None):
     global base_cid
     global base_cid
     base_cid += 1
     base_cid += 1
     if base_cid > 999:
     if base_cid > 999:
-        base_cid=0
-    cid = str(time.time())[4:10]+str(random.randint(1,999))+str(base_cid)
+        base_cid = 0
+    cid = str(time.time())[4:10] + str(random.randint(1, 999)) + str(base_cid)
     if broker:
     if broker:
         cid = broker + cid
         cid = broker + cid
-    return cid      
+    return cid
+
 
 
 def csv_to_gz_and_remove():
 def csv_to_gz_and_remove():
     def List_files(filepath, substr):
     def List_files(filepath, substr):
@@ -78,12 +72,13 @@ def csv_to_gz_and_remove():
 
 
     for file in List_files('./', '.csv'):
     for file in List_files('./', '.csv'):
         if '.gz' not in file:
         if '.gz' not in file:
-            data = open(file, 'rb' ).read()
+            data = open(file, 'rb').read()
             with gzip.open(file + '.gz', 'a') as zip:
             with gzip.open(file + '.gz', 'a') as zip:
                 zip.write(data)
                 zip.write(data)
                 zip.close()
                 zip.close()
             os.remove(file)
             os.remove(file)
 
 
+
 def get_params(fname):
 def get_params(fname):
     # 读取配置
     # 读取配置
     try:
     try:
@@ -92,9 +87,9 @@ def get_params(fname):
         f = open(fname)
         f = open(fname)
         data = f.read()
         data = f.read()
         text = base64.b64decode(data)
         text = base64.b64decode(data)
-        cryptor = AES.new(key =bytes("qFHFPv6MugrSTkEsWFs8wCDg3iC6!er%".encode()), mode=AES.MODE_ECB)
+        cryptor = AES.new(key=bytes("qFHFPv6MugrSTkEsWFs8wCDg3iC6!er%".encode()), mode=AES.MODE_ECB)
         plain_text = cryptor.decrypt(text)
         plain_text = cryptor.decrypt(text)
-        paddingLen = plain_text[len(plain_text)-1]
+        paddingLen = plain_text[len(plain_text) - 1]
         msg = plain_text[0:-paddingLen]
         msg = plain_text[0:-paddingLen]
         msg = msg.decode()
         msg = msg.decode()
         params = toml.loads(msg)
         params = toml.loads(msg)
@@ -107,7 +102,7 @@ def get_params(fname):
     p.pass_key = params['pass_key'].replace(" ", "") if 'pass_key' in params else 'qwer1234'
     p.pass_key = params['pass_key'].replace(" ", "") if 'pass_key' in params else 'qwer1234'
     # 经纪商id
     # 经纪商id
     broker_id_from_config = params['broker_id'] if 'broker_id' in params else ""
     broker_id_from_config = params['broker_id'] if 'broker_id' in params else ""
-    p.broker_id = get_broker_id( broker_id_from_config, params['exchange'])
+    p.broker_id = get_broker_id(broker_id_from_config, params['exchange'])
     # 交易盘口
     # 交易盘口
     p.exchange = params['exchange'] if 'exchange' in params else ""
     p.exchange = params['exchange'] if 'exchange' in params else ""
     # 交易品种
     # 交易品种
@@ -123,15 +118,18 @@ def get_params(fname):
     # 杠杆大小
     # 杠杆大小
     p.leverrate = float(params['leverrate']) if 'leverrate' in params else 1.0
     p.leverrate = float(params['leverrate']) if 'leverrate' in params else 1.0
     # 参考盘口
     # 参考盘口
-    p.refexchange = params['refexchange'].replace('[','').replace(']','').replace("'",'').replace(" ", "").split(',') if "refexchange" in params else ""
+    p.refexchange = params['refexchange'].replace('[', '').replace(']', '').replace("'", '').replace(" ", "").split(
+        ',') if "refexchange" in params else ""
     # 参考品种
     # 参考品种
-    p.refpair = params['refpair'].replace('[','').replace(']','').replace("'",'').replace(" ", "").split(',') if "refpair" in params else ""
+    p.refpair = params['refpair'].replace('[', '').replace(']', '').replace("'", '').replace(" ", "").split(
+        ',') if "refpair" in params else ""
     # 网络代理
     # 网络代理
-    p.proxy = params['proxy'] if 'proxy' in params else None # 仅在win下有效
+    p.proxy = params['proxy'] if 'proxy' in params else None  # 仅在win下有效
     # 账户资金使用比例
     # 账户资金使用比例
     p.used_pct = params['used_pct'] if 'used_pct' in params else "0.9"
     p.used_pct = params['used_pct'] if 'used_pct' in params else "0.9"
     # discord播报地址
     # discord播报地址
-    p.webhook = params['webhook'] if 'webhook' in params else "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"  
+    p.webhook = params[
+        'webhook'] if 'webhook' in params else "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"
     # 默认第n参考盘口
     # 默认第n参考盘口
     p.index = params['index'] if 'index' in params else 0
     p.index = params['index'] if 'index' in params else 0
     # 止损比例 0.02 = 2%
     # 止损比例 0.02 = 2%
@@ -144,11 +142,11 @@ def get_params(fname):
     p.backtest = params['backtest'] if 'backtest' in params else 1
     p.backtest = params['backtest'] if 'backtest' in params else 1
     # 保存实时行情 会有巨大性能损耗
     # 保存实时行情 会有巨大性能损耗
     p.save = params['save'] if 'save' in params else 0
     p.save = params['save'] if 'save' in params else 0
-    p.place_order_limit = params['place_order_limit'] if 'place_order_limit' in params else 0 # 允许的每秒下单次数
+    p.place_order_limit = params['place_order_limit'] if 'place_order_limit' in params else 0  # 允许的每秒下单次数
     # 是否启用colocation技术
     # 是否启用colocation技术
-    p.colo = params['colo'] if 'colo' in params else 0 
+    p.colo = params['colo'] if 'colo' in params else 0
     # 是否启用fast行情 会增加性能开销
     # 是否启用fast行情 会增加性能开销
-    p.fast = params['fast'] if 'fast' in params else 1 
+    p.fast = params['fast'] if 'fast' in params else 1
     # 选择指定的私有ip进行网络通信 默认0 用于多网卡多ip的实例
     # 选择指定的私有ip进行网络通信 默认0 用于多网卡多ip的实例
     p.ip = params['ip'] if 'ip' in params else 0
     p.ip = params['ip'] if 'ip' in params else 0
     # 合约不允许holdcoin持有底仓币
     # 合约不允许holdcoin持有底仓币
@@ -173,7 +171,8 @@ def get_params(fname):
     print(p)
     print(p)
     return p
     return p
 
 
-def get_broker_id(broker_id , exchange_name):
+
+def get_broker_id(broker_id, exchange_name):
     '''处理brokerid特殊情况'''
     '''处理brokerid特殊情况'''
     if 'binance' in exchange_name:
     if 'binance' in exchange_name:
         return broker_id
         return broker_id
@@ -182,6 +181,7 @@ def get_broker_id(broker_id , exchange_name):
     else:
     else:
         return ""
         return ""
 
 
+
 # 报单频率限制等级
 # 报单频率限制等级
 BASIC_LIMIT = 100
 BASIC_LIMIT = 100
 GATE_SPOT_LIMIT = 10.0
 GATE_SPOT_LIMIT = 10.0
@@ -192,92 +192,96 @@ BINANCE_USDT_SWAP_LIMIT = 5.0
 BINANCE_SPOT_LIMIT = 2.0
 BINANCE_SPOT_LIMIT = 2.0
 COINEX_SPOT_LIMIT = 20.0
 COINEX_SPOT_LIMIT = 20.0
 COINEX_USDT_SWAP_LIMIT = 20.0
 COINEX_USDT_SWAP_LIMIT = 20.0
-OKEX_USDT_SWAP_LIMIT= 30.0
+OKEX_USDT_SWAP_LIMIT = 30.0
 BITGET_USDT_SWAP_LIMIT = 10.0
 BITGET_USDT_SWAP_LIMIT = 10.0
 BYBIT_USDT_SWAP_LIMIT = 1.0
 BYBIT_USDT_SWAP_LIMIT = 1.0
 MEXC_SPOT_LIMIT = 333
 MEXC_SPOT_LIMIT = 333
 RATIO = 4.0
 RATIO = 4.0
 
 
+
 def get_limit_requests_num_per_second(exchange, limit=0):
 def get_limit_requests_num_per_second(exchange, limit=0):
     '''每秒请求频率'''
     '''每秒请求频率'''
     if limit != 0:
     if limit != 0:
-        return limit*RATIO
+        return limit * RATIO
     elif exchange == "gate_spot":
     elif exchange == "gate_spot":
-        return GATE_SPOT_LIMIT*RATIO
-    elif exchange == "gate_usdt_swap": # 100/s
-        return GATE_USDT_SWAP_LIMIT*RATIO
-    elif exchange == "kucoin_spot": # 15/s
-        return KUCOIN_SPOT_LIMIT*RATIO
+        return GATE_SPOT_LIMIT * RATIO
+    elif exchange == "gate_usdt_swap":  # 100/s
+        return GATE_USDT_SWAP_LIMIT * RATIO
+    elif exchange == "kucoin_spot":  # 15/s
+        return KUCOIN_SPOT_LIMIT * RATIO
     elif exchange == "kucoin_usdt_swap":
     elif exchange == "kucoin_usdt_swap":
-        return KUCOIN_USDT_SWAP_LIMIT*RATIO
+        return KUCOIN_USDT_SWAP_LIMIT * RATIO
     elif exchange == "binance_usdt_swap":
     elif exchange == "binance_usdt_swap":
-        return BINANCE_USDT_SWAP_LIMIT*RATIO
+        return BINANCE_USDT_SWAP_LIMIT * RATIO
     elif exchange == "binance_spot":
     elif exchange == "binance_spot":
-        return BINANCE_SPOT_LIMIT*RATIO
+        return BINANCE_SPOT_LIMIT * RATIO
     elif exchange == "coinex_spot":
     elif exchange == "coinex_spot":
-        return COINEX_SPOT_LIMIT*RATIO
+        return COINEX_SPOT_LIMIT * RATIO
     elif exchange == "coinex_usdt_swap":
     elif exchange == "coinex_usdt_swap":
-        return COINEX_USDT_SWAP_LIMIT*RATIO
+        return COINEX_USDT_SWAP_LIMIT * RATIO
     elif exchange == "okex_usdt_swap":
     elif exchange == "okex_usdt_swap":
-        return OKEX_USDT_SWAP_LIMIT*RATIO
+        return OKEX_USDT_SWAP_LIMIT * RATIO
     elif exchange == "bitget_usdt_swap":
     elif exchange == "bitget_usdt_swap":
-        return BITGET_USDT_SWAP_LIMIT*RATIO
+        return BITGET_USDT_SWAP_LIMIT * RATIO
     elif exchange == "bybit_usdt_swap":
     elif exchange == "bybit_usdt_swap":
-        return BYBIT_USDT_SWAP_LIMIT*RATIO
+        return BYBIT_USDT_SWAP_LIMIT * RATIO
     elif exchange == "mexc_spot":
     elif exchange == "mexc_spot":
-        return MEXC_SPOT_LIMIT*RATIO
+        return MEXC_SPOT_LIMIT * RATIO
     else:
     else:
         print("限频规则未找到")
         print("限频规则未找到")
-    return BASIC_LIMIT*RATIO
+    return BASIC_LIMIT * RATIO
 
 
 
 
 def get_limit_order_requests_num_per_second(exchange, limit=0):
 def get_limit_order_requests_num_per_second(exchange, limit=0):
     '''每秒下单请求频率'''
     '''每秒下单请求频率'''
     if limit != 0:
     if limit != 0:
         return limit
         return limit
-    elif exchange == "gate_spot": # 10/s
+    elif exchange == "gate_spot":  # 10/s
         return GATE_SPOT_LIMIT
         return GATE_SPOT_LIMIT
-    elif exchange == "gate_usdt_swap": # 100/s
+    elif exchange == "gate_usdt_swap":  # 100/s
         return GATE_USDT_SWAP_LIMIT
         return GATE_USDT_SWAP_LIMIT
-    elif exchange == "kucoin_spot": # 15/s
+    elif exchange == "kucoin_spot":  # 15/s
         return KUCOIN_SPOT_LIMIT
         return KUCOIN_SPOT_LIMIT
-    elif exchange == "kucoin_usdt_swap": # 10/s
+    elif exchange == "kucoin_usdt_swap":  # 10/s
         return KUCOIN_USDT_SWAP_LIMIT
         return KUCOIN_USDT_SWAP_LIMIT
-    elif exchange == "binance_usdt_swap": # 5/s
+    elif exchange == "binance_usdt_swap":  # 5/s
         return BINANCE_USDT_SWAP_LIMIT
         return BINANCE_USDT_SWAP_LIMIT
-    elif exchange == "binance_spot": # 2/s
+    elif exchange == "binance_spot":  # 2/s
         return BINANCE_SPOT_LIMIT
         return BINANCE_SPOT_LIMIT
-    elif exchange == "coinex_spot": # 20/s
+    elif exchange == "coinex_spot":  # 20/s
         return COINEX_SPOT_LIMIT
         return COINEX_SPOT_LIMIT
-    elif exchange == "coinex_usdt_swap": # 20/s
+    elif exchange == "coinex_usdt_swap":  # 20/s
         return COINEX_USDT_SWAP_LIMIT
         return COINEX_USDT_SWAP_LIMIT
-    elif exchange == "okex_usdt_swap": # 30/s
+    elif exchange == "okex_usdt_swap":  # 30/s
         return OKEX_USDT_SWAP_LIMIT
         return OKEX_USDT_SWAP_LIMIT
-    elif exchange == "bitget_usdt_swap": # 10/s
+    elif exchange == "bitget_usdt_swap":  # 10/s
         return BITGET_USDT_SWAP_LIMIT
         return BITGET_USDT_SWAP_LIMIT
-    elif exchange == "bybit_usdt_swap": # 2/s
+    elif exchange == "bybit_usdt_swap":  # 2/s
         return BYBIT_USDT_SWAP_LIMIT
         return BYBIT_USDT_SWAP_LIMIT
-    elif exchange == "mexc_spot": # 2/s
+    elif exchange == "mexc_spot":  # 2/s
         return MEXC_SPOT_LIMIT
         return MEXC_SPOT_LIMIT
     else:
     else:
         print("限频规则未找到")
         print("限频规则未找到")
     return BASIC_LIMIT
     return BASIC_LIMIT
 
 
+
 def dist_to_weight(price, mp, eff_range=EFF_RANGE):
 def dist_to_weight(price, mp, eff_range=EFF_RANGE):
     '''
     '''
         距离转换为权重
         距离转换为权重
     '''
     '''
-    dist = abs(price-mp)/mp
-    weight = 1 - clip(dist/eff_range, 0.0, 0.95)
+    dist = abs(price - mp) / mp
+    weight = 1 - clip(dist / eff_range, 0.0, 0.95)
     weight = weight if weight > 0 else 0
     weight = weight if weight > 0 else 0
     return weight
     return weight
 
 
+
 def change_params(fname, params, changes):
 def change_params(fname, params, changes):
     # 更改配置
     # 更改配置
     for i in changes:
     for i in changes:
         params[i[0]] = i[1]
         params[i[0]] = i[1]
-    with open(f"{fname}","w") as f:
-        toml.dump(params,f)
+    with open(f"{fname}", "w") as f:
+        toml.dump(params, f)
+
 
 
 def show_memory(unit='B', threshold=1024):
 def show_memory(unit='B', threshold=1024):
     '''查看变量占用内存情况
     '''查看变量占用内存情况
@@ -295,11 +299,13 @@ def show_memory(unit='B', threshold=1024):
     print(msg)
     print(msg)
     return msg
     return msg
 
 
+
 def clip(num, _min, _max):
 def clip(num, _min, _max):
     if num > _max: num = _max
     if num > _max: num = _max
     if num < _min: num = _min
     if num < _min: num = _min
     return num
     return num
 
 
+
 async def ding(msg, at_all, webhook, proxy=None):
 async def ding(msg, at_all, webhook, proxy=None):
     '''
     '''
         发送钉钉消息
         发送钉钉消息
@@ -313,11 +319,11 @@ async def ding(msg, at_all, webhook, proxy=None):
         "description": msg
         "description": msg
     }
     }
     message = {
     message = {
-    "content": "大吉大利 今晚吃鸡",
-    "username": "千千喵",
-    "embeds": [
-        embed
-            ],
+        "content": "大吉大利 今晚吃鸡",
+        "username": "千千喵",
+        "embeds": [
+            embed
+        ],
     }
     }
     message_json = json.dumps(message)
     message_json = json.dumps(message)
     if 'win' in sys.platform:
     if 'win' in sys.platform:
@@ -325,18 +331,20 @@ async def ding(msg, at_all, webhook, proxy=None):
     else:
     else:
         proxy = None
         proxy = None
     async with aiohttp.ClientSession() as session:
     async with aiohttp.ClientSession() as session:
-        await session.post(url=webhook, data=message_json, headers=header, proxy=proxy, timeout = 10)
+        await session.post(url=webhook, data=message_json, headers=header, proxy=proxy, timeout=10)
+
 
 
 def _get_params(url, proxy, params):
 def _get_params(url, proxy, params):
     '''更新参数'''
     '''更新参数'''
     import requests
     import requests
     try:
     try:
-        res = requests.post(url=url, json=params, timeout = 10)
+        res = requests.post(url=url, json=params, timeout=10)
         return json.loads(res.text)
         return json.loads(res.text)
     except:
     except:
         traceback.print_exc()
         traceback.print_exc()
         return []
         return []
 
 
+
 async def _post_params(url, proxy, params):
 async def _post_params(url, proxy, params):
     '''更新参数'''
     '''更新参数'''
     try:
     try:
@@ -345,7 +353,7 @@ async def _post_params(url, proxy, params):
         else:
         else:
             proxy = None
             proxy = None
         async with aiohttp.ClientSession() as session:
         async with aiohttp.ClientSession() as session:
-            res = await session.post(url=url, proxy=proxy, data=params, timeout = 10)
+            res = await session.post(url=url, proxy=proxy, data=params, timeout=10)
             data = await res.text()
             data = await res.text()
             print(data)
             print(data)
             return data
             return data
@@ -354,6 +362,7 @@ async def _post_params(url, proxy, params):
         return "post_params error"
         return "post_params error"
     return None
     return None
 
 
+
 def get_ip():
 def get_ip():
     try:
     try:
         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -363,6 +372,7 @@ def get_ip():
         s.close()
         s.close()
     return ip
     return ip
 
 
+
 def check_auth():
 def check_auth():
     print("*** 检查使用权限1 ***")
     print("*** 检查使用权限1 ***")
     ip = get_ip()
     ip = get_ip()
@@ -375,6 +385,7 @@ def check_auth():
         os._exit(0)
         os._exit(0)
     print("*** 符合要求 ***")
     print("*** 符合要求 ***")
 
 
+
 def check_time():
 def check_time():
     print("*** 检查使用权限2 ***")
     print("*** 检查使用权限2 ***")
     if time.time() > int(time.mktime(time.strptime('2021-11-17 00:00:00', "%Y-%m-%d %H:%M:%S"))):
     if time.time() > int(time.mktime(time.strptime('2021-11-17 00:00:00', "%Y-%m-%d %H:%M:%S"))):
@@ -382,41 +393,68 @@ def check_time():
         os._exit(0)
         os._exit(0)
     print("*** 符合要求 ***")
     print("*** 符合要求 ***")
 
 
+
 def num_to_str(num, d):
 def num_to_str(num, d):
-    if d >= 1.0:return "%d"%num
-    elif d in [0.1, 0.5]:return "%.1f"%num
-    elif d in [0.01, 0.05]:return "%.2f"%num
-    elif d in [0.001, 0.005]:return "%.3f"%num
-    elif d in [0.0001, 0.0005]:return "%.4f"%num
-    elif d in [0.00001, 0.00005]:return "%.5f"%num
-    elif d in [0.000001, 0.000005]:return "%.6f"%num
-    elif d in [0.0000001, 0.0000005]:return "%.7f"%num
-    elif d in [0.00000001, 0.00000005]:return "%.8f"%num
-    elif d in [0.000000001, 0.000000005]:return "%.9f"%num
-    elif d in [0.0000000001, 0.0000000005]:return "%.10f"%num
-    else: return str(num)
+    if d >= 1.0:
+        return "%d" % num
+    elif d in [0.1, 0.5]:
+        return "%.1f" % num
+    elif d in [0.01, 0.05]:
+        return "%.2f" % num
+    elif d in [0.001, 0.005]:
+        return "%.3f" % num
+    elif d in [0.0001, 0.0005]:
+        return "%.4f" % num
+    elif d in [0.00001, 0.00005]:
+        return "%.5f" % num
+    elif d in [0.000001, 0.000005]:
+        return "%.6f" % num
+    elif d in [0.0000001, 0.0000005]:
+        return "%.7f" % num
+    elif d in [0.00000001, 0.00000005]:
+        return "%.8f" % num
+    elif d in [0.000000001, 0.000000005]:
+        return "%.9f" % num
+    elif d in [0.0000000001, 0.0000000005]:
+        return "%.10f" % num
+    else:
+        return str(num)
+
 
 
 def num_to_decimal(num):
 def num_to_decimal(num):
     '''根据小数点位数获取精度'''
     '''根据小数点位数获取精度'''
     num = str(num)
     num = str(num)
-    if '.' not in num:return 0
-    elif '.' == num[-2]:return 1
-    elif '.' == num[-3]:return 2
-    elif '.' == num[-4]:return 3
-    elif '.' == num[-5]:return 4
-    elif '.' == num[-6]:return 5
-    elif '.' == num[-7]:return 6
-    elif '.' == num[-8]:return 7
-    elif '.' == num[-9]:return 8
-    elif '.' == num[-10]:return 9
-    elif '.' == num[-11]:return 10
-    else:return 11
+    if '.' not in num:
+        return 0
+    elif '.' == num[-2]:
+        return 1
+    elif '.' == num[-3]:
+        return 2
+    elif '.' == num[-4]:
+        return 3
+    elif '.' == num[-5]:
+        return 4
+    elif '.' == num[-6]:
+        return 5
+    elif '.' == num[-7]:
+        return 6
+    elif '.' == num[-8]:
+        return 7
+    elif '.' == num[-9]:
+        return 8
+    elif '.' == num[-10]:
+        return 9
+    elif '.' == num[-11]:
+        return 10
+    else:
+        return 11
+
 
 
 def fix_amount(amount, stepSize):
 def fix_amount(amount, stepSize):
     '''修补数量向下取整'''
     '''修补数量向下取整'''
     return float(
     return float(
-            Decimal(str(amount))//Decimal(str(stepSize)
-        ) \
+        Decimal(str(amount)) // Decimal(str(stepSize)
+                                        ) \
         * Decimal(str(stepSize)))
         * Decimal(str(stepSize)))
     # return float(Decimal(str(amount)).quantize(Decimal(str(stepSize)), ROUND_FLOOR))
     # return float(Decimal(str(amount)).quantize(Decimal(str(stepSize)), ROUND_FLOOR))
 
 
@@ -424,11 +462,12 @@ def fix_amount(amount, stepSize):
 def fix_price(price, tickSize):
 def fix_price(price, tickSize):
     '''修补价格四舍五入'''
     '''修补价格四舍五入'''
     return float(
     return float(
-            round(Decimal(str(price))/Decimal(str(tickSize))
-        ) \
+        round(Decimal(str(price)) / Decimal(str(tickSize))
+              ) \
         * Decimal(str(tickSize)))
         * Decimal(str(tickSize)))
     # return float(Decimal(str(price)).quantize(Decimal(str(tickSize)), ROUND_HALF_UP))
     # return float(Decimal(str(price)).quantize(Decimal(str(tickSize)), ROUND_HALF_UP))
 
 
+
 def timeit(func):
 def timeit(func):
     def wrapper(*args, **kwargs):
     def wrapper(*args, **kwargs):
         nowTime = time.time()
         nowTime = time.time()
@@ -437,8 +476,10 @@ def timeit(func):
         spend_time = round(spend_time * 1e6, 3)
         spend_time = round(spend_time * 1e6, 3)
         print(f'{func.__name__} 耗时 {spend_time} us')
         print(f'{func.__name__} 耗时 {spend_time} us')
         return res
         return res
+
     return wrapper
     return wrapper
 
 
+
 def get_backtest_set(base=""):
 def get_backtest_set(base=""):
     '''生成预设参数'''
     '''生成预设参数'''
     # 开仓距离不能太近必须超过大部分价格tick运动的距离
     # 开仓距离不能太近必须超过大部分价格tick运动的距离
@@ -454,75 +495,34 @@ def get_backtest_set(base=""):
     close_dict = dict()
     close_dict = dict()
     for open in open_list:
     for open in open_list:
         close_dict[open] = [
         close_dict[open] = [
-            open*0.1,
-            open*0.2,
-            ]
+            open * 0.1,
+            open * 0.2,
+        ]
     alpha_list = [0.0]
     alpha_list = [0.0]
     return open_list, close_dict, alpha_list
     return open_list, close_dict, alpha_list
 
 
+
 def get_local_ip_list():
 def get_local_ip_list():
     '''获取本地ip'''
     '''获取本地ip'''
     import netifaces as ni
     import netifaces as ni
     ipList = []
     ipList = []
     # print('检测服务器网络配置')
     # print('检测服务器网络配置')
     for dev in ni.interfaces():
     for dev in ni.interfaces():
-        print('dev:',dev)
+        print('dev:', dev)
         if 'ens' in dev or 'eth' in dev or 'enp' in dev:
         if 'ens' in dev or 'eth' in dev or 'enp' in dev:
             # print(ni.ifaddresses(dev))
             # print(ni.ifaddresses(dev))
             for i in ni.ifaddresses(dev)[2]:
             for i in ni.ifaddresses(dev)[2]:
-                ip=i['addr']
+                ip = i['addr']
                 print(f"检测到私有ip:{ip}")
                 print(f"检测到私有ip:{ip}")
                 if ip not in ipList:
                 if ip not in ipList:
                     ipList.append(ip)
                     ipList.append(ip)
     print(f"当前服务器私有ip为{ipList}")
     print(f"当前服务器私有ip为{ipList}")
     return ['127.0.0.1']
     return ['127.0.0.1']
-    
-if __name__ == "__main__":
-
-    #########
-    if 0:
-        print(fix_amount(1.0, 0.1))
-        print(fix_amount(0.9, 0.05))
-        print(fix_amount(1.1, 0.1))
-        print(fix_amount(1.2, 0.5))
-        print(fix_amount(0.01, 0.05))
-    if 1:
-        print(fix_price(1.0, 0.1))
-        print(fix_price(0.9, 2.0))
-        print(fix_price(1.1, 0.1))
-        print(fix_price(1.2, 0.5))
-        print(fix_price(4999.99, 0.5))
-    #########
-    if 0:
-        # print(num_to_str(123.123))
-        print(get_backtest_set())
-
-    ####################
-    if 0:
-
-        p = get_params("config.toml")
-
-        loop = asyncio.get_event_loop()
-        
-        # loop.create_task(ding("123", 1, "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"))
-        
-        loop.create_task(
-            _post_params(
-                "http://wwww.khods.com:8888/post_params", 
-                None,
-                ujson.dumps({
-                    "exchange":"binance_usdt_swap",
-                    "pair":"eth_usdt",
-                    "open":"0.001",
-                    "close":"0.0001",
-                    "refexchange":"binance_spot",
-                    "refpair":"eth_usdt",
-                    "profit":0.1,
-                })
-            )
-        )
-
-        loop.run_forever()
-    
-    ####################
 
 
+
+if __name__ == '__main__':
+    gamma = float(0.5)
+    kappa = float(2)
+    factor = 1 + gamma / kappa
+    _optimal_spread = 2 * Decimal(factor).ln()
+    print(_optimal_spread)