浏览代码

最佳价差版本

JiahengHe 1 年之前
父节点
当前提交
09968497cf

+ 0 - 17
backtest.py

@@ -2,27 +2,10 @@
     基于深度信息的异步事件触发回测框架
     作者 千千量化
 '''
-import os
-import sys
-import asyncio
-import json, ujson
-import requests
-import hmac
-import base64
-import zlib
-import datetime
-import time
 import strategy as strategy
-import traceback
-import joblib
-import traceback
-import configparser
-import signal
-import random
 import utils
 import model
 import time
-import copy
 import random
 
 def timeit(func):

+ 9 - 9
config.toml

@@ -1,23 +1,23 @@
 broker_id = "kucoin"
 account_name = "ku1"
-access_key = ""
-secret_key = ""
-pass_key = ""
+access_key = "6393f3565f0d4500011f846b"
+secret_key = "9c0df8b7-daaa-493e-a53a-82703067f7dd"
+pass_key = "b87d055f"
 exchange = "kucoin_usdt_swap"
-pair = "loom_usdt"
+pair = "meme_usdt"
 debug = "False"
-open = 0.02
+open = 0.01
 close = 0.0002
 leverrate = 0.5
 interval = 0.1
-refexchange = "['binance_usdt_swap']"
-refpair = "['loom_usdt']"
-used_pct = 1
+refexchange = "['kucoin_usdt_swap']"
+refpair = "['meme_usdt']"
+used_pct = 0.9
 index = 0
 save = 0
 hold_coin = 0.0
 log = 1
-stoploss = "0.02"
+stoploss = "0.01"
 gamma = 0.999
 grid = 1
 place_order_limit = 0

+ 13 - 9
exchange/binance_usdt_swap_ws.py

@@ -1,17 +1,21 @@
-import aiohttp
-import time
 import asyncio
-import zlib
-import json, ujson
-import zlib
+import csv
 import hashlib
 import hmac
-import base64
+import logging
+import logging.handlers
+import random
+import sys
+import time
 import traceback
-import random, csv, sys, utils
-import logging, logging.handlers
+import utils
+import zlib
+
+import aiohttp
+import ujson
+
 import model
-from loguru import logger
+
 
 def empty_call(msg):
     pass

+ 3 - 11
exchange/bybit_usdt_swap_rest.py

@@ -1,25 +1,17 @@
-import imp
-import random
-import re
 import aiohttp
 import time
 import asyncio
-import zlib
-import json, ujson
+import ujson
 import hmac
-import base64
 import hashlib
 import traceback
-import urllib
-from urllib import parse
 from urllib.parse import urljoin
-import datetime, sys
+import sys
 from urllib.parse import urlparse
-import logging, logging.handlers
+import logging.handlers
 import utils
 import logging, logging.handlers
 import model
-from decimal import Decimal
 
 def empty_call(msg):
     print(f'空的回调函数 {msg}')

+ 2 - 6
exchange/kucoin_usdt_swap_rest.py

@@ -1,19 +1,15 @@
-import random
 import aiohttp
 import time
 import asyncio
-import zlib
 import json
 import hmac
 import base64
 import hashlib
 import traceback
-import urllib
-from urllib import parse
 from urllib.parse import urljoin
-import datetime, sys
+import sys
 from urllib.parse import urlparse
-import logging, logging.handlers
+import logging.handlers
 import utils
 import logging, logging.handlers
 import model

+ 8 - 6
exchange/kucoin_usdt_swap_ws.py

@@ -1,21 +1,17 @@
 import aiohttp
 import time
 import asyncio
-import zlib
-import json, ujson
-import zlib
+import ujson
 import hashlib
 import hmac
 import base64
 import traceback
-import random
-import gzip, csv, sys
+import csv, sys
 from uuid import uuid4
 import logging, logging.handlers
 import utils
 import model
 from decimal import Decimal
-from loguru import logger
 
 def empty_call(msg):
     # print(msg)
@@ -45,6 +41,7 @@ class KucoinUsdtSwapWs:
             "onOrder":empty_call,
             "onTicker":empty_call,
             "onDepth":empty_call,
+            "onTrade":empty_call,
             "onExit":empty_call,
             }
         self.is_print = is_print
@@ -120,6 +117,7 @@ class KucoinUsdtSwapWs:
             self.update_t = msg['data']['sequence']
             self.ticker_info['bp'] = float(msg['data']['bids'][0][0])
             self.ticker_info['ap'] = float(msg['data']['asks'][0][0])
+            self.ticker_info['time'] = msg['data']['timestamp']
             self.callback['onTicker'](self.ticker_info)
             ##### 标准化深度
             mp = (self.ticker_info["bp"] + self.ticker_info["ap"])*0.5
@@ -167,6 +165,7 @@ class KucoinUsdtSwapWs:
             self.update_t = msg['data']['sequence']
             self.ticker_info['bp'] = float(msg['data']['bestBidPrice'])
             self.ticker_info['ap'] = float(msg['data']['bestAskPrice'])
+            self.ticker_info['time'] = msg['data']['ts']
             self.callback['onTicker'](self.ticker_info)
 
             bp = float(msg['data']['bestBidPrice'])
@@ -190,6 +189,8 @@ class KucoinUsdtSwapWs:
         elif side == 'sell':
             self.sell_q += amount
             self.sell_v += amount*price
+        self.callback['onTrade']({'timestamp': msg["data"]['ts'], 'price': price, 'amount': amount, 'side': side})
+
 
     def _update_position(self, msg):
         pos = model.Position()
@@ -348,6 +349,7 @@ class KucoinUsdtSwapWs:
                     channels=[
                         # f"/contractMarket/tickerV2:{self.symbol}",
                         f"/contractMarket/level2Depth50:{self.symbol}",
+                        f"/contractMarket/execution:{self.symbol}"
                         ]
                     if sub_trade:
                         channels += [f"/contractMarket/execution:{self.symbol}"]

+ 169 - 0
predictor_new.py

@@ -0,0 +1,169 @@
+import time
+from decimal import Decimal
+
+import utils
+from util.instant_volatility import InstantVolatilityIndicator
+from util.trading_intensity import TradingIntensityIndicator
+
+
+class PredictorNew:
+    '''
+        reference
+    '''
+
+    def __init__(self, ref_name=["Unknown Market"], alpha=[1.0 for _ in range(99)], gamma=0.999):
+        self.loop = 0
+        self.arr = []
+        self.trade_mp_series = []
+        self.ref_mp_series = []
+        self.ref_num = len(ref_name)
+        ### 定价
+        self.window = 10
+        ### 价格系数
+        self.alpha = alpha
+        # 参考
+        print('定价系数:', gamma)
+        self.gamma = gamma
+        self.avg_spread = [None for _ in range(self.ref_num)]
+        self.vol = InstantVolatilityIndicator(sampling_length=100, processing_length=1)
+        self.trading_intensity = TradingIntensityIndicator(sampling_length=20)
+
+
+    def processer(self):
+        '''
+            计算任务
+        '''
+        # update trade mp
+        bp = self.arr[-1][utils.BP_INDEX]
+        ap = self.arr[-1][utils.AP_INDEX]
+        mp = (bp + ap) * 0.5
+        self.trade_mp_series.append(mp)
+        # 更新参考盘口mp
+        ref_mp = []
+        for i in range(self.ref_num):
+            bp = self.arr[-1][utils.LEN * (1 + i) + utils.BP_INDEX]
+            ap = self.arr[-1][utils.LEN * (1 + i) + utils.AP_INDEX]
+            mp = (bp + ap) * 0.5
+            ref_mp.append(mp)
+        self.ref_mp_series.append(ref_mp)
+        # 偏差计算
+        self._update_avg_spread()
+
+    def _update_avg_spread(self):
+        '''
+            更新平均偏差
+        '''
+        # 计算偏差1
+        for i in range(self.ref_num):
+            bias = self.ref_mp_series[-1][i] * self.alpha[i] - self.trade_mp_series[-1]
+            # 如果是刚启动 gamma不能太大
+            if self.loop < 100:
+                gamma = 0.9
+            else:
+                gamma = self.gamma
+            if self.avg_spread[i] == None:
+                self.avg_spread[i] = bias
+            else:
+                self.avg_spread[i] = self.avg_spread[i] * gamma + bias * (1 - gamma)
+
+    def check_length(self):
+        # 行情缓存
+        if len(self.arr) > self.window: del (self.arr[0])
+        if len(self.trade_mp_series) > self.window: del (self.trade_mp_series[0])
+        if len(self.ref_mp_series) > self.window: del (self.ref_mp_series[0])
+
+    # @utils.timeit
+    def onTime(self, data):
+        if isinstance(data, list):
+            if len(data) > 0:
+                self.loop += 1
+                self.arr.append(data)
+                self.processer()
+                self.check_length()
+            else:
+                print("行情数据为空")
+        else:
+            print("行情数据为None")
+
+    def market_update(self, best_mid_price, time_num):
+        self.vol.add_sample(best_mid_price)
+        self.trading_intensity.calculate(best_mid_price, time_num)
+
+    # @utils.timeit
+    # def Get_ref(self, ref_ticker):
+    #     '''
+    #         get ref price
+    #     '''
+    #     ref_mid = []
+    #     for i in range(self.ref_num):
+    #         ref_mid.append(
+    #             [ref_ticker[i][0] * self.alpha[i] - self.avg_spread[i],
+    #              ref_ticker[i][1] * self.alpha[i] - self.avg_spread[i]]
+    #         )
+    #     return ref_mid
+
+    def Get_ref(self, ref_ticker):
+        '''
+            get ref price
+        '''
+        bp = ref_ticker[0][0]
+        ap = ref_ticker[0][1]
+        mp = Decimal((bp + ap) * 0.5)
+        std = self.vol.current_value
+        gamma = Decimal(0.5)
+        alpha, kappa = self.trading_intensity.current_value
+        ref_mid = []
+        if alpha != 0 and kappa > 0 and std != 0:
+            factor = Decimal(1) + gamma / kappa
+            _optimal_spread = gamma * std
+            _optimal_spread += 2 * factor.ln()
+            _optimal_ask = mp + _optimal_spread / 2
+            _optimal_bid = mp - _optimal_spread / 2
+            ref_mid.append(
+                [_optimal_ask,
+                 _optimal_bid]
+            )
+        return ref_mid
+
+if __name__ == "__main__":
+
+    import pandas as pd
+    import matplotlib
+
+    matplotlib.use('TkAgg')
+    import matplotlib.pyplot as plt
+
+    arr = pd.read_csv('history/ftm_usdt_binance_usdt_swap.csv').values.tolist()
+
+
+    def line_data_to_tickers(data, ref_num):
+        ref_tickers = []
+        for i in range(ref_num):
+            ref_tickers.append([data[utils.LEN * (i + 1) + utils.BP_INDEX], data[utils.LEN * (i + 1) + utils.AP_INDEX]])
+        return ref_tickers
+
+
+    ref_num = len(arr[0]) // utils.LEN - 1
+
+    p = PredictorNew(ref_name=["unkwon" for _ in range(ref_num)])
+    t = []
+    ref_index = 1
+
+    for data in arr:
+        p.onTime(data)
+        trade_mp = (data[utils.BP_INDEX] + data[utils.AP_INDEX]) * 0.5
+        ref_price = p.Get_ref(line_data_to_tickers(data, ref_num))
+        t.append([
+            trade_mp,
+            (ref_price[ref_index][0] + ref_price[ref_index][1]) * 0.5,
+        ])
+
+    t = pd.DataFrame(t, columns=['mp', 'ref'])
+
+    if 1:
+        plt.figure()
+        plt.plot(t['mp'], 'k')
+        plt.plot(t['ref'], 'g')
+        plt.grid()
+        plt.show()
+

+ 26 - 11
quant.py

@@ -8,15 +8,13 @@ import model
 import logging, logging.handlers
 import signal
 import os, json, sys
-import predictor
+import predictor_new
 import backtest
-import multiprocessing
 import random
 import psutil
 import ujson
 import broker
 from decimal import Decimal
-from loguru import logger
 
 VERSION = utils.VERSION
 
@@ -202,6 +200,7 @@ class Quant:
             self.ws_ref[self.ref_name[i]] = broker.newWs(exchange)(cp, colo=_colo)
             self.ws_ref[self.ref_name[i]].callback['onTicker']=self.update_ticker
             self.ws_ref[self.ref_name[i]].callback['onDepth']=self.update_depth
+            self.ws_ref[self.ref_name[i]].callback['onTrade'] = self.update_trade
             self.ws_ref[self.ref_name[i]].logger = self.logger
         # 添加回调
         self.ws.callback = {
@@ -210,6 +209,7 @@ class Quant:
             'onPosition':self.update_position,
             'onEquity':self.update_equity,
             'onOrder':self.update_order,
+            'onTrade':self.update_trade,
             'onExit':self.update_exit,
             }
         self.rest.callback = {
@@ -218,6 +218,7 @@ class Quant:
             'onPosition':self.update_position,
             'onEquity':self.update_equity,
             'onOrder':self.update_order,
+            'onTrade': self.update_trade,
             'onExit':self.update_exit,
             }
         self.rest.logger = self.logger
@@ -236,7 +237,7 @@ class Quant:
             # 交易shib 参考 shib
                 price_alpha.append(1.0)
         self.logger.info(f'价格系数{price_alpha}')
-        self.Predictor = predictor.Predictor(ref_name=self.ref_name, alpha=price_alpha, gamma=float(self.params.gamma))
+        self.Predictor = predictor_new.PredictorNew(ref_name=self.ref_name, alpha=price_alpha, gamma=float(self.params.gamma))
         # 初始化参数
         self.strategy.trade_open_dist = float(self.params.open)
         self.strategy.trade_close_dist = float(self.params.close)
@@ -574,6 +575,10 @@ class Quant:
             update ticker infomation
         '''
         name = data['name']
+        if name == self.ref_name[0]:
+            min_price = (data['bp'] + data['ap']) * 0.5
+            self.Predictor.market_update(min_price, data['time'])
+
         # 记录tick更新时间
         # self.market_update_time[name] = time.time()
         self.tickers[name] = data
@@ -585,6 +590,13 @@ class Quant:
         else:
             pass
 
+    def update_trade(self, data):
+        self.loop.create_task(self._update_trade(data))
+
+    async def _update_trade(self, data):
+        print(f"trade数据:{data}")
+        self.Predictor.trading_intensity.c_register_trade(data)
+
     # @utils.timeit
     async def _update_depth(self, data):
         '''
@@ -930,10 +942,12 @@ class Quant:
         ref_tickers = []
         for i in self.ref_name:
             ref_tickers.append([self.tickers[i]['bp'], self.tickers[i]['ap']])
-        self.tradeMsg.ref_price = self.Predictor.Get_ref(ref_tickers)
+        ref_price = self.Predictor.Get_ref(ref_tickers)
+        if len(ref_price) == 0:
+            return
         # logger.info('ref_price={}, market={}, predict={}'.format(
         #     self.tradeMsg.ref_price, self.tradeMsg.market, self.tradeMsg.predict))
-
+        self.tradeMsg.ref_price = ref_price
     async def server_handle(self, request):
         '''中控数据接口'''
         if 'spot' in self.exchange:
@@ -1112,11 +1126,12 @@ class Quant:
                 self.exit_msg = msg
                 self.stop()
         ###### 定价异常风控 ######
-        if abs(self.strategy.ref_price-self.strategy.mp)/self.strategy.mp > 0.03:
-            msg = f"{self.acct_name} 定价偏离过大 怀疑异常 退出"
-            self.logger.error(msg)
-            self.exit_msg = msg
-            self.stop()
+        if self.strategy.ref_price is not None and self.strategy.ref_price != 0:
+            if abs(self.strategy.ref_price-self.strategy.mp)/self.strategy.mp > 0.03:
+                msg = f"{self.acct_name} 定价偏离过大 怀疑异常 退出"
+                self.logger.error(msg)
+                self.exit_msg = msg
+                self.stop()
 
     async def exit(self, delay=0):
         '''退出操作'''

+ 3 - 5
strategy.py

@@ -4,7 +4,6 @@ import model
 import logging, logging.handlers
 from decimal import Decimal
 from decimal import ROUND_HALF_UP, ROUND_FLOOR
-from loguru import logger
 
 class Strategy:
     '''
@@ -194,10 +193,9 @@ class Strategy:
                 self.maker_mode = 'follow'
 
             ###### ref price ######
-            if data.ref_price == None:
-                self.ref_bp = self.bp
-                self.ref_ap = self.ap
-                self.ref_price = self.mp
+            if data.ref_price is None or len(data.ref_price) == 0:
+                print('参考价格还未预热完成,等待预热...')
+                return 0
             else:
                 self.ref_bp = data.ref_price[self.ref_index][0]
                 self.ref_ap = data.ref_price[self.ref_index][1]

+ 0 - 0
util/__init__.py


+ 66 - 0
util/base_trailing_indicator.py

@@ -0,0 +1,66 @@
+import logging
+from abc import ABC, abstractmethod
+from decimal import Decimal
+
+import numpy as np
+
+from util.ring_buffer import RingBuffer
+
+
+class BaseTrailingIndicator(ABC):
+
+    def __init__(self, sampling_length: int = 30, processing_length: int = 15):
+        self._sampling_buffer = RingBuffer(sampling_length)
+        self._processing_buffer = RingBuffer(processing_length)
+        self._samples_length = 0
+
+    def add_sample(self, value: float):
+        self._sampling_buffer.add_value(value)
+        indicator_value = self._indicator_calculation()
+        self._processing_buffer.add_value(indicator_value)
+
+    @abstractmethod
+    def _indicator_calculation(self) -> float:
+        raise NotImplementedError
+
+    def _processing_calculation(self) -> float:
+        """
+        Processing of the processing buffer to return final value.
+        Default behavior is buffer average
+        """
+        return np.mean(self._processing_buffer.get_as_numpy_array())
+
+    @property
+    def current_value(self) -> Decimal:
+        return Decimal(self._processing_calculation())
+
+    @property
+    def is_sampling_buffer_full(self) -> bool:
+        return self._sampling_buffer.is_full
+
+    @property
+    def is_processing_buffer_full(self) -> bool:
+        return self._processing_buffer.is_full
+
+    @property
+    def is_sampling_buffer_changed(self) -> bool:
+        buffer_len = len(self._sampling_buffer.get_as_numpy_array())
+        is_changed = self._samples_length != buffer_len
+        self._samples_length = buffer_len
+        return is_changed
+
+    @property
+    def sampling_length(self) -> int:
+        return self._sampling_buffer.length
+
+    @sampling_length.setter
+    def sampling_length(self, value):
+        self._sampling_buffer.length = value
+
+    @property
+    def processing_length(self) -> int:
+        return self._processing_buffer.length
+
+    @processing_length.setter
+    def processing_length(self, value):
+        self._processing_buffer.length = value

+ 19 - 0
util/instant_volatility.py

@@ -0,0 +1,19 @@
+from .base_trailing_indicator import BaseTrailingIndicator
+import numpy as np
+
+
+class InstantVolatilityIndicator(BaseTrailingIndicator):
+    def __init__(self, sampling_length: int = 30, processing_length: int = 15):
+        super().__init__(sampling_length, processing_length)
+
+    def _indicator_calculation(self) -> float:
+        # The standard deviation should be calculated between ticks and not with a mean of the whole buffer
+        # Otherwise if the asset is trending, changing the length of the buffer would result in a greater volatility as more ticks would be further away from the mean
+        # which is a nonsense result. If volatility of the underlying doesn't change in fact, changing the length of the buffer shouldn't change the result.
+        np_sampling_buffer = self._sampling_buffer.get_as_numpy_array()
+        vol = np.sqrt(np.sum(np.square(np.diff(np_sampling_buffer))) / np_sampling_buffer.size)
+        return vol
+
+    def _processing_calculation(self) -> float:
+        # Only the last calculated volatlity, not an average of multiple past volatilities
+        return self._processing_buffer.get_last_value()

+ 99 - 0
util/ring_buffer.py

@@ -0,0 +1,99 @@
+import numpy as np
+
+
+class RingBuffer:
+    def __init__(self, length=20):
+        self._length = length
+        self._buffer = np.zeros(length, dtype=np.float64)
+        self._delimiter = 0
+        self._is_full = False
+
+    def __dealloc__(self):
+        self._buffer = None
+
+    def c_add_value(self, value):
+        self._buffer[self._delimiter] = value
+        self.c_increment_delimiter()
+
+    def c_increment_delimiter(self):
+        self._delimiter = (self._delimiter + 1) % self._length
+        if not self._is_full and self._delimiter == 0:
+            self._is_full = True
+
+    def c_is_empty(self):
+        return (not self._is_full) and (0 == self._delimiter)
+
+    def c_get_last_value(self):
+        if self.c_is_empty():
+            return np.nan
+        return self._buffer[self._delimiter - 1]
+
+    def c_is_full(self):
+        return self._is_full
+
+    def c_mean_value(self):
+        result = np.nan
+        if self._is_full:
+            result = np.mean(self.c_get_as_numpy_array())
+        return result
+
+    def c_variance(self):
+        result = np.nan
+        if self._is_full:
+            result = np.var(self.c_get_as_numpy_array())
+        return result
+
+    def c_std_dev(self):
+        result = np.nan
+        if self._is_full:
+            result = np.std(self.c_get_as_numpy_array())
+        return result
+
+    def c_get_as_numpy_array(self):
+        if not self._is_full:
+            indexes = np.arange(0, stop=self._delimiter, dtype=np.int16)
+        else:
+            indexes = np.arange(self._delimiter, stop=self._delimiter + self._length,
+                                dtype=np.int16) % self._length
+        return np.asarray(self._buffer)[indexes]
+
+    def add_value(self, val):
+        self.c_add_value(val)
+
+    def get_as_numpy_array(self):
+        return self.c_get_as_numpy_array()
+
+    def get_last_value(self):
+        return self.c_get_last_value()
+
+    @property
+    def is_full(self):
+        return self.c_is_full()
+
+    @property
+    def mean_value(self):
+        return self.c_mean_value()
+
+    @property
+    def std_dev(self):
+        return self.c_std_dev()
+
+    @property
+    def variance(self):
+        return self.c_variance()
+
+    @property
+    def length(self) -> int:
+        return self._length
+
+    @length.setter
+    def length(self, value):
+        data = self.get_as_numpy_array()
+
+        self._length = value
+        self._buffer = np.zeros(value, dtype=np.float64)
+        self._delimiter = 0
+        self._is_full = False
+
+        for val in data[-value:]:
+            self.add_value(val)

+ 135 - 0
util/trading_intensity.py

@@ -0,0 +1,135 @@
+from decimal import Decimal
+from typing import Tuple
+
+import numpy as np
+from scipy.optimize import curve_fit
+
+
+class TradingIntensityIndicator:
+
+    def __init__(self, sampling_length: int = 30):
+        self._alpha = Decimal(0)
+        self._kappa = Decimal(0)
+        self._trade_samples = {}
+        self._current_trade_sample = []
+        self._sampling_length = sampling_length
+        self._samples_length = 0
+        self._last_quotes = []
+
+    @property
+    def current_value(self) -> Tuple[Decimal, Decimal]:
+        return self._alpha, self._kappa
+
+    @property
+    def is_sampling_buffer_full(self) -> bool:
+        return len(self._trade_samples.keys()) == self._sampling_length
+
+    @property
+    def is_sampling_buffer_changed(self) -> bool:
+        is_changed = self._samples_length != len(self._trade_samples.keys())
+        self._samples_length = len(self._trade_samples.keys())
+        return is_changed
+
+    @property
+    def sampling_length(self) -> int:
+        return self._sampling_length
+
+    @sampling_length.setter
+    def sampling_length(self, new_len: int):
+        self._sampling_length = new_len
+
+    @property
+    def last_quotes(self) -> list:
+        """A helper method to be used in unit tests"""
+        return self._last_quotes
+
+    @last_quotes.setter
+    def last_quotes(self, value):
+        """A helper method to be used in unit tests"""
+        self._last_quotes = value
+
+    def calculate(self, price, timestamp):
+        """A helper method to be used in unit tests"""
+        self.c_calculate(price, timestamp)
+
+    def c_calculate(self, price, timestamp):
+        # Descending order of price-timestamp quotes
+        self._last_quotes = [{'timestamp': timestamp, 'price': price}] + self._last_quotes
+
+        latest_processed_quote_idx = None
+        for trade in self._current_trade_sample:
+            for i, quote in enumerate(self._last_quotes):
+                if quote["timestamp"] < trade['timestamp']:
+                    if latest_processed_quote_idx is None or i < latest_processed_quote_idx:
+                        latest_processed_quote_idx = i
+                    trade = {"price_level": abs(trade["price"] - float(quote["price"])), "amount": trade["amount"]}
+
+                    if quote["timestamp"] + 1 not in self._trade_samples.keys():
+                        self._trade_samples[quote["timestamp"] + 1] = []
+
+                    self._trade_samples[quote["timestamp"] + 1] += [trade]
+                    break
+
+        # THere are no trades left to process
+        self._current_trade_sample = []
+        # Store quotes that happened after the latest trade + one before
+        if latest_processed_quote_idx is not None:
+            self._last_quotes = self._last_quotes[0:latest_processed_quote_idx + 1]
+
+        if len(self._trade_samples.keys()) > self._sampling_length:
+            timestamps = list(self._trade_samples.keys())
+            timestamps.sort()
+            timestamps = timestamps[-self._sampling_length:]
+
+            trade_samples = {}
+            for timestamp in timestamps:
+                trade_samples[timestamp] = self._trade_samples[timestamp]
+            self._trade_samples = trade_samples
+
+        if self.is_sampling_buffer_full:
+            self.c_estimate_intensity()
+
+    def register_trade(self, trade):
+        """A helper method to be used in unit tests"""
+        self.c_register_trade(trade)
+
+    def c_register_trade(self, trade):
+        self._current_trade_sample.append(trade)
+
+    def c_estimate_intensity(self):
+
+        # Calculate lambdas / trading intensities
+        lambdas = []
+
+        trades_consolidated = {}
+        price_levels = []
+        for timestamp in self._trade_samples.keys():
+            tick = self._trade_samples[timestamp]
+            for trade in tick:
+                if trade['price_level'] not in trades_consolidated.keys():
+                    trades_consolidated[trade['price_level']] = 0
+                    price_levels += [trade['price_level']]
+
+                trades_consolidated[trade['price_level']] += trade['amount']
+
+        price_levels = sorted(price_levels, reverse=True)
+
+        for price_level in price_levels:
+            lambdas += [trades_consolidated[price_level]]
+
+        # Adjust to be able to calculate log
+        lambdas_adj = [10 ** -10 if x == 0 else x for x in lambdas]
+
+        # Fit the probability density function; reuse previously calculated parameters as initial values
+        try:
+            params = curve_fit(lambda t, a, b: a * np.exp(-b * t),
+                               price_levels,
+                               lambdas_adj,
+                               p0=(self._alpha, self._kappa),
+                               method='dogbox',
+                               bounds=([0, 0], [np.inf, np.inf]))
+
+            self._kappa = Decimal(str(params[0][1]))
+            self._alpha = Decimal(str(params[0][0]))
+        except (RuntimeError, ValueError) as e:
+            pass

+ 159 - 159
utils.py

@@ -1,44 +1,35 @@
-import json
 import traceback
-import utils
 import model
 import toml, time, random
-import os, sys, asyncio, aiohttp
+import sys, aiohttp
 import socket
-import asyncio
 import requests
-import ujson
 from decimal import Decimal
-from decimal import ROUND_HALF_UP, ROUND_FLOOR
 import gzip
-import csv
-import os
-import base64
 from Crypto.Cipher import AES
-from Crypto import Random
 import os
 import base64
 import json
 
-parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 
-sys.path.insert(0,parentdir)  
+parentdir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.insert(0, parentdir)
 
 ############### 全局配置
 VERSION = "2022-04-18"
-CHILD_RUN_SECOND = 60 * 60 * 24 # child process max run time per loop 
-EARLY_STOP_SECOND = 60 * 60 * 2 # child early stop min check time
-BACKTEST_PREHOT_SECOND = 60 * 30 # backtest pre hot time
-DUMMY_RUN_SECOND = 60 * 60 * 12 # dummy process max run time per loop 
-DUMMY_EARLY_STOP_SECOND = 60 * 60 # dummy process max run time per loop 
-POST_SIDE_LIMIT = [0] # post side limit
-MARKET_DELAY_LIMIT = 30000 # market update delay limit threhold unit:ms
+CHILD_RUN_SECOND = 60 * 60 * 24  # child process max run time per loop
+EARLY_STOP_SECOND = 60 * 60 * 2  # child early stop min check time
+BACKTEST_PREHOT_SECOND = 60 * 30  # backtest pre hot time
+DUMMY_RUN_SECOND = 60 * 60 * 12  # dummy process max run time per loop
+DUMMY_EARLY_STOP_SECOND = 60 * 60  # dummy process max run time per loop
+POST_SIDE_LIMIT = [0]  # post side limit
+MARKET_DELAY_LIMIT = 30000  # market update delay limit threhold unit:ms
 GRID = 1
 STOPLOSS = 0.02
 GAMMA = 0.999
 ###### market行情数据长度 标准化n档深度+6档成交信息 ######
 LEVEL = 1
-TRADE_LEN = 2 # 最高 最低 成交价
-LEN = LEVEL * 4 + TRADE_LEN # 总长度
+TRADE_LEN = 2  # 最高 最低 成交价
+LEN = LEVEL * 4 + TRADE_LEN  # 总长度
 BP_INDEX = LEVEL * 0
 BQ_INDEX = LEVEL * 0 + 1
 AP_INDEX = LEVEL * 2
@@ -56,15 +47,18 @@ BACKTEST_DELAY = 0.15
 
 global base_cid
 base_cid = 0
+
+
 def get_cid(broker=None):
     global base_cid
     base_cid += 1
     if base_cid > 999:
-        base_cid=0
-    cid = str(time.time())[4:10]+str(random.randint(1,999))+str(base_cid)
+        base_cid = 0
+    cid = str(time.time())[4:10] + str(random.randint(1, 999)) + str(base_cid)
     if broker:
         cid = broker + cid
-    return cid      
+    return cid
+
 
 def csv_to_gz_and_remove():
     def List_files(filepath, substr):
@@ -78,12 +72,13 @@ def csv_to_gz_and_remove():
 
     for file in List_files('./', '.csv'):
         if '.gz' not in file:
-            data = open(file, 'rb' ).read()
+            data = open(file, 'rb').read()
             with gzip.open(file + '.gz', 'a') as zip:
                 zip.write(data)
                 zip.close()
             os.remove(file)
 
+
 def get_params(fname):
     # 读取配置
     try:
@@ -92,9 +87,9 @@ def get_params(fname):
         f = open(fname)
         data = f.read()
         text = base64.b64decode(data)
-        cryptor = AES.new(key =bytes("qFHFPv6MugrSTkEsWFs8wCDg3iC6!er%".encode()), mode=AES.MODE_ECB)
+        cryptor = AES.new(key=bytes("qFHFPv6MugrSTkEsWFs8wCDg3iC6!er%".encode()), mode=AES.MODE_ECB)
         plain_text = cryptor.decrypt(text)
-        paddingLen = plain_text[len(plain_text)-1]
+        paddingLen = plain_text[len(plain_text) - 1]
         msg = plain_text[0:-paddingLen]
         msg = msg.decode()
         params = toml.loads(msg)
@@ -107,7 +102,7 @@ def get_params(fname):
     p.pass_key = params['pass_key'].replace(" ", "") if 'pass_key' in params else 'qwer1234'
     # 经纪商id
     broker_id_from_config = params['broker_id'] if 'broker_id' in params else ""
-    p.broker_id = get_broker_id( broker_id_from_config, params['exchange'])
+    p.broker_id = get_broker_id(broker_id_from_config, params['exchange'])
     # 交易盘口
     p.exchange = params['exchange'] if 'exchange' in params else ""
     # 交易品种
@@ -123,15 +118,18 @@ def get_params(fname):
     # 杠杆大小
     p.leverrate = float(params['leverrate']) if 'leverrate' in params else 1.0
     # 参考盘口
-    p.refexchange = params['refexchange'].replace('[','').replace(']','').replace("'",'').replace(" ", "").split(',') if "refexchange" in params else ""
+    p.refexchange = params['refexchange'].replace('[', '').replace(']', '').replace("'", '').replace(" ", "").split(
+        ',') if "refexchange" in params else ""
     # 参考品种
-    p.refpair = params['refpair'].replace('[','').replace(']','').replace("'",'').replace(" ", "").split(',') if "refpair" in params else ""
+    p.refpair = params['refpair'].replace('[', '').replace(']', '').replace("'", '').replace(" ", "").split(
+        ',') if "refpair" in params else ""
     # 网络代理
-    p.proxy = params['proxy'] if 'proxy' in params else None # 仅在win下有效
+    p.proxy = params['proxy'] if 'proxy' in params else None  # 仅在win下有效
     # 账户资金使用比例
     p.used_pct = params['used_pct'] if 'used_pct' in params else "0.9"
     # discord播报地址
-    p.webhook = params['webhook'] if 'webhook' in params else "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"  
+    p.webhook = params[
+        'webhook'] if 'webhook' in params else "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"
     # 默认第n参考盘口
     p.index = params['index'] if 'index' in params else 0
     # 止损比例 0.02 = 2%
@@ -144,11 +142,11 @@ def get_params(fname):
     p.backtest = params['backtest'] if 'backtest' in params else 1
     # 保存实时行情 会有巨大性能损耗
     p.save = params['save'] if 'save' in params else 0
-    p.place_order_limit = params['place_order_limit'] if 'place_order_limit' in params else 0 # 允许的每秒下单次数
+    p.place_order_limit = params['place_order_limit'] if 'place_order_limit' in params else 0  # 允许的每秒下单次数
     # 是否启用colocation技术
-    p.colo = params['colo'] if 'colo' in params else 0 
+    p.colo = params['colo'] if 'colo' in params else 0
     # 是否启用fast行情 会增加性能开销
-    p.fast = params['fast'] if 'fast' in params else 1 
+    p.fast = params['fast'] if 'fast' in params else 1
     # 选择指定的私有ip进行网络通信 默认0 用于多网卡多ip的实例
     p.ip = params['ip'] if 'ip' in params else 0
     # 合约不允许holdcoin持有底仓币
@@ -173,7 +171,8 @@ def get_params(fname):
     print(p)
     return p
 
-def get_broker_id(broker_id , exchange_name):
+
+def get_broker_id(broker_id, exchange_name):
     '''处理brokerid特殊情况'''
     if 'binance' in exchange_name:
         return broker_id
@@ -182,6 +181,7 @@ def get_broker_id(broker_id , exchange_name):
     else:
         return ""
 
+
 # 报单频率限制等级
 BASIC_LIMIT = 100
 GATE_SPOT_LIMIT = 10.0
@@ -192,92 +192,96 @@ BINANCE_USDT_SWAP_LIMIT = 5.0
 BINANCE_SPOT_LIMIT = 2.0
 COINEX_SPOT_LIMIT = 20.0
 COINEX_USDT_SWAP_LIMIT = 20.0
-OKEX_USDT_SWAP_LIMIT= 30.0
+OKEX_USDT_SWAP_LIMIT = 30.0
 BITGET_USDT_SWAP_LIMIT = 10.0
 BYBIT_USDT_SWAP_LIMIT = 1.0
 MEXC_SPOT_LIMIT = 333
 RATIO = 4.0
 
+
 def get_limit_requests_num_per_second(exchange, limit=0):
     '''每秒请求频率'''
     if limit != 0:
-        return limit*RATIO
+        return limit * RATIO
     elif exchange == "gate_spot":
-        return GATE_SPOT_LIMIT*RATIO
-    elif exchange == "gate_usdt_swap": # 100/s
-        return GATE_USDT_SWAP_LIMIT*RATIO
-    elif exchange == "kucoin_spot": # 15/s
-        return KUCOIN_SPOT_LIMIT*RATIO
+        return GATE_SPOT_LIMIT * RATIO
+    elif exchange == "gate_usdt_swap":  # 100/s
+        return GATE_USDT_SWAP_LIMIT * RATIO
+    elif exchange == "kucoin_spot":  # 15/s
+        return KUCOIN_SPOT_LIMIT * RATIO
     elif exchange == "kucoin_usdt_swap":
-        return KUCOIN_USDT_SWAP_LIMIT*RATIO
+        return KUCOIN_USDT_SWAP_LIMIT * RATIO
     elif exchange == "binance_usdt_swap":
-        return BINANCE_USDT_SWAP_LIMIT*RATIO
+        return BINANCE_USDT_SWAP_LIMIT * RATIO
     elif exchange == "binance_spot":
-        return BINANCE_SPOT_LIMIT*RATIO
+        return BINANCE_SPOT_LIMIT * RATIO
     elif exchange == "coinex_spot":
-        return COINEX_SPOT_LIMIT*RATIO
+        return COINEX_SPOT_LIMIT * RATIO
     elif exchange == "coinex_usdt_swap":
-        return COINEX_USDT_SWAP_LIMIT*RATIO
+        return COINEX_USDT_SWAP_LIMIT * RATIO
     elif exchange == "okex_usdt_swap":
-        return OKEX_USDT_SWAP_LIMIT*RATIO
+        return OKEX_USDT_SWAP_LIMIT * RATIO
     elif exchange == "bitget_usdt_swap":
-        return BITGET_USDT_SWAP_LIMIT*RATIO
+        return BITGET_USDT_SWAP_LIMIT * RATIO
     elif exchange == "bybit_usdt_swap":
-        return BYBIT_USDT_SWAP_LIMIT*RATIO
+        return BYBIT_USDT_SWAP_LIMIT * RATIO
     elif exchange == "mexc_spot":
-        return MEXC_SPOT_LIMIT*RATIO
+        return MEXC_SPOT_LIMIT * RATIO
     else:
         print("限频规则未找到")
-    return BASIC_LIMIT*RATIO
+    return BASIC_LIMIT * RATIO
 
 
 def get_limit_order_requests_num_per_second(exchange, limit=0):
     '''每秒下单请求频率'''
     if limit != 0:
         return limit
-    elif exchange == "gate_spot": # 10/s
+    elif exchange == "gate_spot":  # 10/s
         return GATE_SPOT_LIMIT
-    elif exchange == "gate_usdt_swap": # 100/s
+    elif exchange == "gate_usdt_swap":  # 100/s
         return GATE_USDT_SWAP_LIMIT
-    elif exchange == "kucoin_spot": # 15/s
+    elif exchange == "kucoin_spot":  # 15/s
         return KUCOIN_SPOT_LIMIT
-    elif exchange == "kucoin_usdt_swap": # 10/s
+    elif exchange == "kucoin_usdt_swap":  # 10/s
         return KUCOIN_USDT_SWAP_LIMIT
-    elif exchange == "binance_usdt_swap": # 5/s
+    elif exchange == "binance_usdt_swap":  # 5/s
         return BINANCE_USDT_SWAP_LIMIT
-    elif exchange == "binance_spot": # 2/s
+    elif exchange == "binance_spot":  # 2/s
         return BINANCE_SPOT_LIMIT
-    elif exchange == "coinex_spot": # 20/s
+    elif exchange == "coinex_spot":  # 20/s
         return COINEX_SPOT_LIMIT
-    elif exchange == "coinex_usdt_swap": # 20/s
+    elif exchange == "coinex_usdt_swap":  # 20/s
         return COINEX_USDT_SWAP_LIMIT
-    elif exchange == "okex_usdt_swap": # 30/s
+    elif exchange == "okex_usdt_swap":  # 30/s
         return OKEX_USDT_SWAP_LIMIT
-    elif exchange == "bitget_usdt_swap": # 10/s
+    elif exchange == "bitget_usdt_swap":  # 10/s
         return BITGET_USDT_SWAP_LIMIT
-    elif exchange == "bybit_usdt_swap": # 2/s
+    elif exchange == "bybit_usdt_swap":  # 2/s
         return BYBIT_USDT_SWAP_LIMIT
-    elif exchange == "mexc_spot": # 2/s
+    elif exchange == "mexc_spot":  # 2/s
         return MEXC_SPOT_LIMIT
     else:
         print("限频规则未找到")
     return BASIC_LIMIT
 
+
 def dist_to_weight(price, mp, eff_range=EFF_RANGE):
     '''
         距离转换为权重
     '''
-    dist = abs(price-mp)/mp
-    weight = 1 - clip(dist/eff_range, 0.0, 0.95)
+    dist = abs(price - mp) / mp
+    weight = 1 - clip(dist / eff_range, 0.0, 0.95)
     weight = weight if weight > 0 else 0
     return weight
 
+
 def change_params(fname, params, changes):
     # 更改配置
     for i in changes:
         params[i[0]] = i[1]
-    with open(f"{fname}","w") as f:
-        toml.dump(params,f)
+    with open(f"{fname}", "w") as f:
+        toml.dump(params, f)
+
 
 def show_memory(unit='B', threshold=1024):
     '''查看变量占用内存情况
@@ -295,11 +299,13 @@ def show_memory(unit='B', threshold=1024):
     print(msg)
     return msg
 
+
 def clip(num, _min, _max):
     if num > _max: num = _max
     if num < _min: num = _min
     return num
 
+
 async def ding(msg, at_all, webhook, proxy=None):
     '''
         发送钉钉消息
@@ -313,11 +319,11 @@ async def ding(msg, at_all, webhook, proxy=None):
         "description": msg
     }
     message = {
-    "content": "大吉大利 今晚吃鸡",
-    "username": "千千喵",
-    "embeds": [
-        embed
-            ],
+        "content": "大吉大利 今晚吃鸡",
+        "username": "千千喵",
+        "embeds": [
+            embed
+        ],
     }
     message_json = json.dumps(message)
     if 'win' in sys.platform:
@@ -325,18 +331,20 @@ async def ding(msg, at_all, webhook, proxy=None):
     else:
         proxy = None
     async with aiohttp.ClientSession() as session:
-        await session.post(url=webhook, data=message_json, headers=header, proxy=proxy, timeout = 10)
+        await session.post(url=webhook, data=message_json, headers=header, proxy=proxy, timeout=10)
+
 
 def _get_params(url, proxy, params):
     '''更新参数'''
     import requests
     try:
-        res = requests.post(url=url, json=params, timeout = 10)
+        res = requests.post(url=url, json=params, timeout=10)
         return json.loads(res.text)
     except:
         traceback.print_exc()
         return []
 
+
 async def _post_params(url, proxy, params):
     '''更新参数'''
     try:
@@ -345,7 +353,7 @@ async def _post_params(url, proxy, params):
         else:
             proxy = None
         async with aiohttp.ClientSession() as session:
-            res = await session.post(url=url, proxy=proxy, data=params, timeout = 10)
+            res = await session.post(url=url, proxy=proxy, data=params, timeout=10)
             data = await res.text()
             print(data)
             return data
@@ -354,6 +362,7 @@ async def _post_params(url, proxy, params):
         return "post_params error"
     return None
 
+
 def get_ip():
     try:
         s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -363,6 +372,7 @@ def get_ip():
         s.close()
     return ip
 
+
 def check_auth():
     print("*** 检查使用权限1 ***")
     ip = get_ip()
@@ -375,6 +385,7 @@ def check_auth():
         os._exit(0)
     print("*** 符合要求 ***")
 
+
 def check_time():
     print("*** 检查使用权限2 ***")
     if time.time() > int(time.mktime(time.strptime('2021-11-17 00:00:00', "%Y-%m-%d %H:%M:%S"))):
@@ -382,41 +393,68 @@ def check_time():
         os._exit(0)
     print("*** 符合要求 ***")
 
+
 def num_to_str(num, d):
-    if d >= 1.0:return "%d"%num
-    elif d in [0.1, 0.5]:return "%.1f"%num
-    elif d in [0.01, 0.05]:return "%.2f"%num
-    elif d in [0.001, 0.005]:return "%.3f"%num
-    elif d in [0.0001, 0.0005]:return "%.4f"%num
-    elif d in [0.00001, 0.00005]:return "%.5f"%num
-    elif d in [0.000001, 0.000005]:return "%.6f"%num
-    elif d in [0.0000001, 0.0000005]:return "%.7f"%num
-    elif d in [0.00000001, 0.00000005]:return "%.8f"%num
-    elif d in [0.000000001, 0.000000005]:return "%.9f"%num
-    elif d in [0.0000000001, 0.0000000005]:return "%.10f"%num
-    else: return str(num)
+    if d >= 1.0:
+        return "%d" % num
+    elif d in [0.1, 0.5]:
+        return "%.1f" % num
+    elif d in [0.01, 0.05]:
+        return "%.2f" % num
+    elif d in [0.001, 0.005]:
+        return "%.3f" % num
+    elif d in [0.0001, 0.0005]:
+        return "%.4f" % num
+    elif d in [0.00001, 0.00005]:
+        return "%.5f" % num
+    elif d in [0.000001, 0.000005]:
+        return "%.6f" % num
+    elif d in [0.0000001, 0.0000005]:
+        return "%.7f" % num
+    elif d in [0.00000001, 0.00000005]:
+        return "%.8f" % num
+    elif d in [0.000000001, 0.000000005]:
+        return "%.9f" % num
+    elif d in [0.0000000001, 0.0000000005]:
+        return "%.10f" % num
+    else:
+        return str(num)
+
 
 def num_to_decimal(num):
     '''根据小数点位数获取精度'''
     num = str(num)
-    if '.' not in num:return 0
-    elif '.' == num[-2]:return 1
-    elif '.' == num[-3]:return 2
-    elif '.' == num[-4]:return 3
-    elif '.' == num[-5]:return 4
-    elif '.' == num[-6]:return 5
-    elif '.' == num[-7]:return 6
-    elif '.' == num[-8]:return 7
-    elif '.' == num[-9]:return 8
-    elif '.' == num[-10]:return 9
-    elif '.' == num[-11]:return 10
-    else:return 11
+    if '.' not in num:
+        return 0
+    elif '.' == num[-2]:
+        return 1
+    elif '.' == num[-3]:
+        return 2
+    elif '.' == num[-4]:
+        return 3
+    elif '.' == num[-5]:
+        return 4
+    elif '.' == num[-6]:
+        return 5
+    elif '.' == num[-7]:
+        return 6
+    elif '.' == num[-8]:
+        return 7
+    elif '.' == num[-9]:
+        return 8
+    elif '.' == num[-10]:
+        return 9
+    elif '.' == num[-11]:
+        return 10
+    else:
+        return 11
+
 
 def fix_amount(amount, stepSize):
     '''修补数量向下取整'''
     return float(
-            Decimal(str(amount))//Decimal(str(stepSize)
-        ) \
+        Decimal(str(amount)) // Decimal(str(stepSize)
+                                        ) \
         * Decimal(str(stepSize)))
     # return float(Decimal(str(amount)).quantize(Decimal(str(stepSize)), ROUND_FLOOR))
 
@@ -424,11 +462,12 @@ def fix_amount(amount, stepSize):
 def fix_price(price, tickSize):
     '''修补价格四舍五入'''
     return float(
-            round(Decimal(str(price))/Decimal(str(tickSize))
-        ) \
+        round(Decimal(str(price)) / Decimal(str(tickSize))
+              ) \
         * Decimal(str(tickSize)))
     # return float(Decimal(str(price)).quantize(Decimal(str(tickSize)), ROUND_HALF_UP))
 
+
 def timeit(func):
     def wrapper(*args, **kwargs):
         nowTime = time.time()
@@ -437,8 +476,10 @@ def timeit(func):
         spend_time = round(spend_time * 1e6, 3)
         print(f'{func.__name__} 耗时 {spend_time} us')
         return res
+
     return wrapper
 
+
 def get_backtest_set(base=""):
     '''生成预设参数'''
     # 开仓距离不能太近必须超过大部分价格tick运动的距离
@@ -454,75 +495,34 @@ def get_backtest_set(base=""):
     close_dict = dict()
     for open in open_list:
         close_dict[open] = [
-            open*0.1,
-            open*0.2,
-            ]
+            open * 0.1,
+            open * 0.2,
+        ]
     alpha_list = [0.0]
     return open_list, close_dict, alpha_list
 
+
 def get_local_ip_list():
     '''获取本地ip'''
     import netifaces as ni
     ipList = []
     # print('检测服务器网络配置')
     for dev in ni.interfaces():
-        print('dev:',dev)
+        print('dev:', dev)
         if 'ens' in dev or 'eth' in dev or 'enp' in dev:
             # print(ni.ifaddresses(dev))
             for i in ni.ifaddresses(dev)[2]:
-                ip=i['addr']
+                ip = i['addr']
                 print(f"检测到私有ip:{ip}")
                 if ip not in ipList:
                     ipList.append(ip)
     print(f"当前服务器私有ip为{ipList}")
     return ['127.0.0.1']
-    
-if __name__ == "__main__":
-
-    #########
-    if 0:
-        print(fix_amount(1.0, 0.1))
-        print(fix_amount(0.9, 0.05))
-        print(fix_amount(1.1, 0.1))
-        print(fix_amount(1.2, 0.5))
-        print(fix_amount(0.01, 0.05))
-    if 1:
-        print(fix_price(1.0, 0.1))
-        print(fix_price(0.9, 2.0))
-        print(fix_price(1.1, 0.1))
-        print(fix_price(1.2, 0.5))
-        print(fix_price(4999.99, 0.5))
-    #########
-    if 0:
-        # print(num_to_str(123.123))
-        print(get_backtest_set())
-
-    ####################
-    if 0:
-
-        p = get_params("config.toml")
-
-        loop = asyncio.get_event_loop()
-        
-        # loop.create_task(ding("123", 1, "https://discord.com/api/webhooks/907870708481265675/IfN4GqH4fj8HWS_FecH3Lrc2qtRyqsCHsSJVLFHlxY8ioHprfdxIMUNAfqkZZ6opzVEP"))
-        
-        loop.create_task(
-            _post_params(
-                "http://wwww.khods.com:8888/post_params", 
-                None,
-                ujson.dumps({
-                    "exchange":"binance_usdt_swap",
-                    "pair":"eth_usdt",
-                    "open":"0.001",
-                    "close":"0.0001",
-                    "refexchange":"binance_spot",
-                    "refpair":"eth_usdt",
-                    "profit":0.1,
-                })
-            )
-        )
-
-        loop.run_forever()
-    
-    ####################
 
+
+if __name__ == '__main__':
+    gamma = float(0.5)
+    kappa = float(2)
+    factor = 1 + gamma / kappa
+    _optimal_spread = 2 * Decimal(factor).ln()
+    print(_optimal_spread)