diff --git a/custom/fixeddataengine.py b/custom/fixeddataengine.py index 36ed927..aebfa85 100644 --- a/custom/fixeddataengine.py +++ b/custom/fixeddataengine.py @@ -25,7 +25,7 @@ def __init__(self, event_engine, clock_engine, watch_stocks=None, s='sina'): self.source = None self.__queue = mp.Queue(1000) self.is_pause = not clock_engine.is_tradetime_now() - self._control_thread = Thread(target=self._process_control) + self._control_thread = Thread(target=self._process_control, name="FixedDataEngine._control_thread") self._control_thread.start() super(FixedDataEngine, self).__init__(event_engine, clock_engine) diff --git a/custom/fixedmainengine.py b/custom/fixedmainengine.py index ca0560a..db43f01 100644 --- a/custom/fixedmainengine.py +++ b/custom/fixedmainengine.py @@ -32,7 +32,7 @@ def __init__(self, broker, need_data='ht.json', quotation_engines=[FixedDataEngi # 加载锁 self.lock = Lock() # 加载线程 - self._watch_thread = Thread(target=self._load_strategy) + self._watch_thread = Thread(target=self._load_strategy, name="FixedMainEngine.watch_reload_strategy") positions = [p['stock_code'] for p in self.user.position] positions.extend(ext_stocks) for quotation_engine in quotation_engines: diff --git a/easyquant/event_engine.py b/easyquant/event_engine.py index 051f7bc..42965b3 100644 --- a/easyquant/event_engine.py +++ b/easyquant/event_engine.py @@ -23,7 +23,7 @@ def __init__(self): self.__active = False # 事件引擎处理线程 - self.__thread = Thread(target=self.__run) + self.__thread = Thread(target=self.__run, name="EventEngine.__thread") # 事件字典,key 为时间, value 为对应监听事件函数的列表 self.__handlers = defaultdict(list) @@ -33,7 +33,7 @@ def __run(self): while self.__active: try: event = self.__queue.get(block=True, timeout=1) - handle_thread = Thread(target=self.__process, args=(event,)) + handle_thread = Thread(target=self.__process, name="EventEngine.__process", args=(event,)) handle_thread.start() except Empty: pass diff --git a/easyquant/main_engine.py b/easyquant/main_engine.py index 4a04386..839120d 100644 --- a/easyquant/main_engine.py +++ b/easyquant/main_engine.py @@ -5,6 +5,7 @@ import time from collections import OrderedDict import dill +from threading import Thread, Lock import easytrader from logbook import Logger, StreamHandler @@ -23,6 +24,7 @@ ACCOUNT_OBJECT_FILE = 'account.session' + class MainEngine: """主引擎,负责行情 / 事件驱动引擎 / 交易""" @@ -62,6 +64,20 @@ def __init__(self, broker=None, need_data=None, quotation_engines=None, self.strategies = OrderedDict() self.strategy_list = list() + # 是否要动态重载策略 + self.is_watch_strategy = False + # 修改时间缓存 + self._cache = {} + # # 文件进程映射 + # self._process_map = {} + # 文件模块映射 + self._modules = {} + self._names = None + # 加载锁 + self.lock = Lock() + # 加载线程 + self._watch_thread = Thread(target=self._load_strategy, name="MainEngine.watch_reload_strategy") + self.log.info('启动主引擎') def start(self): @@ -74,24 +90,98 @@ def start(self): quotation_engine.start() self.clock_engine.start() + def load(self, names, strategy_file): + with self.lock: + mtime = os.path.getmtime(os.path.join('strategies', strategy_file)) + + # 是否需要重新加载 + reload = False + + strategy_module_name = os.path.basename(strategy_file)[:-3] + new_module = lambda strategy_module_name: importlib.import_module('.' + strategy_module_name, 'strategies') + strategy_module = self._modules.get( + strategy_file, # 从缓存中获取 module 实例 + new_module(strategy_module_name) # 创建新的 module 实例 + ) + + if self._cache.get(strategy_file, None) == mtime: + # 检查最后改动时间 + return + elif self._cache.get(strategy_file, None) is not None: + # 注销策略的监听 + old_strategy = self.get_strategy(strategy_module.Strategy.name) + if old_strategy is None: + print(18181818, strategy_module_name) + for s in self.strategy_list: + print(s.name) + self.log.warn(u'卸载策略: %s' % old_strategy.name) + self.strategy_listen_event(old_strategy, "unlisten") + time.sleep(2) + reload = True + # 重新加载 + if reload: + strategy_module = importlib.reload(strategy_module) + + self._modules[strategy_file] = strategy_module + + strategy_class = getattr(strategy_module, 'Strategy') + if names is None or strategy_class.name in names: + self.strategies[strategy_module_name] = strategy_class + # 缓存加载信息 + new_strategy = strategy_class(log_handler=self.log, main_engine=self) + self.strategy_list.append(new_strategy) + self._cache[strategy_file] = mtime + self.strategy_listen_event(new_strategy, "listen") + self.log.info(u'加载策略: %s' % strategy_module_name) + + def strategy_listen_event(self, strategy, _type="listen"): + """ + 所有策略要监听的事件都绑定到这里 + :param strategy: Strategy() + :param _type: "listen" OR "unlisten" + :return: + """ + func = { + "listen": self.event_engine.register, + "unlisten": self.event_engine.unregister, + }.get(_type) + + # 行情引擎的事件 + for quotation_engine in self.quotation_engines: + func(quotation_engine.EventType, strategy.run) + + # 时钟事件 + func(ClockEngine.EventType, strategy.clock) + def load_strategy(self, names=None): """动态加载策略 :param names: 策略名列表,元素为策略的 name 属性""" s_folder = 'strategies' + self._names = names strategies = os.listdir(s_folder) strategies = filter(lambda file: file.endswith('.py') and file != '__init__.py', strategies) importlib.import_module(s_folder) for strategy_file in strategies: - strategy_module_name = os.path.basename(strategy_file)[:-3] - strategy_module = importlib.import_module('.' + strategy_module_name, 'strategies') - strategy_class = getattr(strategy_module, 'Strategy') - - if names is None or strategy_class.name in names: - self.strategies[strategy_module_name] = strategy_class - self.strategy_list.append(strategy_class(log_handler=self.log, main_engine=self)) - self.log.info('加载策略: %s' % strategy_module_name) + self.load(self._names, strategy_file) + # 如果线程没有启动,就启动策略监视线程 + if self.is_watch_strategy and not self._watch_thread.is_alive(): + self.log.warn("启用了动态加载策略功能") + self._watch_thread.start() + + def _load_strategy(self): + while True: + try: + self.load_strategy(self._names) + time.sleep(2) + except Exception as e: + print(e) + + def get_strategy(self, name): + """ + :param name: + :return: + """ for strategy in self.strategy_list: - self.event_engine.register(ClockEngine.EventType, strategy.clock) - for quotation_engine in self.quotation_engines: - self.event_engine.register(quotation_engine.EventType, strategy.run) - self.log.info('加载策略完毕') + if strategy.name == name: + return strategy + return None diff --git a/easyquant/multiprocess/strategy_wrapper.py b/easyquant/multiprocess/strategy_wrapper.py index 2418a7e..baffc96 100644 --- a/easyquant/multiprocess/strategy_wrapper.py +++ b/easyquant/multiprocess/strategy_wrapper.py @@ -73,9 +73,9 @@ def _process(self): """ 启动进程 """ - event_thread = Thread(target=self._process_event) + event_thread = Thread(target=self._process_event, name="ProcessWrapper._process_event") event_thread.start() - clock_thread = Thread(target=self._process_clock) + clock_thread = Thread(target=self._process_clock, name="ProcessWrapper._process_clock") clock_thread.start() event_thread.join() diff --git a/easyquant/push_engine/base_engine.py b/easyquant/push_engine/base_engine.py index 3bbc7f5..47c7ac1 100644 --- a/easyquant/push_engine/base_engine.py +++ b/easyquant/push_engine/base_engine.py @@ -21,7 +21,7 @@ def __init__(self, event_engine, clock_engine): self.event_engine = event_engine self.clock_engine = clock_engine self.is_active = True - self.quotation_thread = Thread(target=self.push_quotation) + self.quotation_thread = Thread(target=self.push_quotation, name="QuotationEngine.%s" % self.EventType) self.quotation_thread.setDaemon(False) self.init() diff --git a/easyquant/push_engine/clock_engine.py b/easyquant/push_engine/clock_engine.py index a1155d2..ea6d358 100644 --- a/easyquant/push_engine/clock_engine.py +++ b/easyquant/push_engine/clock_engine.py @@ -112,7 +112,7 @@ def __init__(self, event_engine, tzinfo=None): self.event_engine = event_engine self.is_active = True - self.clock_engine_thread = Thread(target=self.clocktick) + self.clock_engine_thread = Thread(target=self.clocktick, name="ClockEngine.clocktick") self.sleep_time = 1 self.trading_state = True if (etime.is_tradetime(datetime.datetime.now()) and etime.is_trade_date(datetime.datetime.now())) else False self.clock_moment_handlers = deque() diff --git "a/strategies/\347\255\226\347\225\2451_Demo.py" "b/strategies/\347\255\226\347\225\2451_Demo.py" index 7daac74..db560cd 100644 --- "a/strategies/\347\255\226\347\225\2451_Demo.py" +++ "b/strategies/\347\255\226\347\225\2451_Demo.py" @@ -1,14 +1,17 @@ +import time import datetime as dt from dateutil import tz from easyquant import DefaultLogHandler from easyquant import StrategyTemplate - class Strategy(StrategyTemplate): name = '测试策略1' def init(self): - now = self.clock_engine.now_dt + # 通过下面的方式来获取时间戳 + now_dt = self.clock_engine.now_dt + now = self.clock_engine.now + now = time.time() # 注册时钟事件 clock_type = "盘尾" diff --git a/test.py b/test.py index 9738c99..868e314 100644 --- a/test.py +++ b/test.py @@ -1,5 +1,4 @@ import easyquotation -from easyquant.push_engine.clock_engine import ClockEngine import easyquant from easyquant import DefaultQuotationEngine, DefaultLogHandler, PushBaseEngine @@ -38,6 +37,7 @@ def init(self): def fetch_quotation(self): return self.source.stocks(['162411', '000002']) + quotation_choose = input('请输入使用行情引擎 1: sina 2: leverfun 十档 行情(目前只选择了 162411, 000002)\n:') quotation_engine = DefaultQuotationEngine if quotation_choose == '1' else LFEngine @@ -52,7 +52,7 @@ def fetch_quotation(self): log_handler = DefaultLogHandler(name='测试', log_type=log_type, filepath=log_filepath) - m = easyquant.MainEngine(broker, need_data, quotation_engines=[quotation_engine], log_handler=log_handler) +m.is_watch_strategy = True # 策略文件出现改动时,自动重载,不建议在生产环境下使用 m.load_strategy() m.start() diff --git a/unitest_demo.py b/unitest_demo.py index c3b73c2..d0459b2 100644 --- a/unitest_demo.py +++ b/unitest_demo.py @@ -104,16 +104,16 @@ def test_set_now(self): clock_engien = ClockEngine(EventEngine(), tzinfo) # 去掉微秒误差后验证其数值 - self.assertEqual(clock_engien.now, now.timestamp()) # time.time 时间戳 - self.assertEqual(clock_engien.now_dt, now) # datetime 时间戳 + self.assertEqual(clock_engien.now, now.timestamp()) # time.time 时间戳 + self.assertEqual(clock_engien.now_dt, now) # datetime 时间戳 # 据此可以模拟一段时间内各个闹钟事件的触发,比如模拟开市9:00一直到休市15:00 for _ in range(60): clock_engien.tock() - now += datetime.timedelta(seconds=1) # 每秒触发一次 tick_tock + now += datetime.timedelta(seconds=1) # 每秒触发一次 tick_tock time.time = mock.Mock(return_value=now.timestamp()) - self.assertEqual(clock_engien.now, now.timestamp()) # time.time 时间戳 - self.assertEqual(clock_engien.now_dt, now) # datetime 时间戳 + self.assertEqual(clock_engien.now, now.timestamp()) # time.time 时间戳 + self.assertEqual(clock_engien.now_dt, now) # datetime 时间戳 def test_clock_moment_is_active(self): # 设置时间 @@ -325,7 +325,6 @@ def count(event): self.assertEqual(len(counts[15]), (15 - 9) * 4 + 1 - len(["9:00"])) - def test_tick_moment_event(self): """ 测试 tick 中的时刻时钟事件 @@ -357,7 +356,7 @@ def count(event): # 预估时间事件触发次数, 每个交易日触发一次 actived_times = 0 - for date in pd.date_range(begin.date(), periods=days+1): + for date in pd.date_range(begin.date(), periods=days + 1): if is_trade_date(date): actived_times += 1 @@ -384,4 +383,3 @@ def count(event): self.assertEqual(len(counts['pause']), actived_times) self.assertEqual(len(counts['continue']), actived_times) self.assertEqual(len(counts['close']), actived_times) -