Skip to content

Commit

Permalink
V0.9.30 更新一批代码 (#170)
Browse files Browse the repository at this point in the history
* 0.9.30 first commit

* update

* 0.9.30 优化 factor、event

* 0.9.30 update

* 0.9.30 优化 RedisWeightsClient

* fix docs

* 0.9.30 优化部分功能代码

* 0.9.30 update RedisWeightsClient

* 0.9.30 update

* 0.9.30 update

* 0.9.30 新增几个 streamlit 组件

* 0.9.30 RedisWeightsClient 支持权限管控
  • Loading branch information
zengbin93 authored Oct 7, 2023
1 parent 7bc88ac commit 941369e
Show file tree
Hide file tree
Showing 19 changed files with 523 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ name: Python package

on:
push:
branches: [ master, V0.9.29 ]
branches: [ master, V0.9.30 ]
pull_request:
branches: [ master ]

Expand Down
16 changes: 14 additions & 2 deletions czsc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@
BarGenerator,
freq_end_time,
resample_bars,
is_trading_time,
get_intraday_times,
check_freq_and_market,

dill_dump,
dill_load,
read_json,
Expand Down Expand Up @@ -81,18 +85,26 @@
# streamlit 量化分析组件
from czsc.utils.st_components import (
show_daily_return,
show_correlation,
show_sectional_ic,
show_factor_returns,
show_factor_layering,
)

from czsc.utils.bi_info import (
calculate_bi_info,
symbols_bi_infos,
)

from czsc.utils.features import (
normalize_feature,
)


__version__ = "0.9.29"
__version__ = "0.9.30"
__author__ = "zengbin93"
__email__ = "[email protected]"
__date__ = "20230904"
__date__ = "20230925"


def welcome():
Expand Down
19 changes: 14 additions & 5 deletions czsc/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,8 +510,9 @@ class Factor:
def __post_init__(self):
if not self.signals_all:
raise ValueError("signals_all 不能为空")
str_signals = str(self.dump())
sha256 = hashlib.sha256(str_signals.encode("utf-8")).hexdigest().upper()[:8]
_fatcor = self.dump()
_fatcor.pop("name")
sha256 = hashlib.sha256(str(_fatcor).encode("utf-8")).hexdigest().upper()[:8]
self.name = f"{self.name}#{sha256}" if self.name else sha256

@property
Expand Down Expand Up @@ -552,6 +553,7 @@ def dump(self) -> dict:
signals_not = [x.signal for x in self.signals_not] if self.signals_not else []

raw = {
"name": self.name,
"signals_all": signals_all,
"signals_any": signals_any,
"signals_not": signals_not,
Expand Down Expand Up @@ -604,9 +606,14 @@ class Event:
def __post_init__(self):
if not self.factors:
raise ValueError("factors 不能为空")
str_factors = str(self.dump())
sha256 = hashlib.sha256(str_factors.encode("utf-8")).hexdigest().upper()[:8]
self.name = f"{self.operate.value}#{sha256}"
_event = self.dump()
_event.pop("name")
sha256 = hashlib.sha256(str(_event).encode("utf-8")).hexdigest().upper()[:8]
if self.name:
self.name = f"{self.name}#{sha256}"
else:
self.name = f"{self.operate.value}#{sha256}"
self.sha256 = sha256

@property
def unique_signals(self) -> List[str]:
Expand Down Expand Up @@ -681,6 +688,7 @@ def dump(self) -> dict:
factors = [x.dump() for x in self.factors]

raw = {
"name": self.name,
"operate": self.operate.value,
"signals_all": signals_all,
"signals_any": signals_any,
Expand Down Expand Up @@ -712,6 +720,7 @@ def load(cls, raw: dict):
assert raw["factors"], "factors can not be empty"

e = Event(
name=raw.get("name", ""),
operate=Operate.__dict__["_value2member_map_"][raw["operate"]],
factors=[Factor.load(x) for x in raw["factors"]],
signals_all=[Signal(x) for x in raw.get("signals_all", [])],
Expand Down
44 changes: 29 additions & 15 deletions czsc/traders/rwc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
class RedisWeightsClient:
"""策略持仓权重收发客户端"""

version = "V231006"

def __init__(self, strategy_name, redis_url, **kwargs):
"""
:param strategy_name: str, 策略名
Expand All @@ -39,6 +41,7 @@ def __init__(self, strategy_name, redis_url, **kwargs):
"""
self.strategy_name = strategy_name
self.redis_url = redis_url
self.key_prefix = kwargs.get("key_prefix", "Weights")

self.heartbeat_client = redis.from_url(redis_url, decode_responses=True)
self.heartbeat_prefix = kwargs.get("heartbeat_prefix", "heartbeat")
Expand All @@ -47,22 +50,33 @@ def __init__(self, strategy_name, redis_url, **kwargs):
self.r = redis.Redis(connection_pool=thread_safe_pool)
self.lua_publish = RedisWeightsClient.register_lua_publish(self.r)

self.heartbeat_thread = threading.Thread(target=self.__heartbeat, daemon=True)
self.heartbeat_thread.start()
if kwargs.get('send_heartbeat', True):
self.heartbeat_thread = threading.Thread(target=self.__heartbeat, daemon=True)
self.heartbeat_thread.start()

def set_metadata(self, base_freq, description, author, outsample_sdt, **kwargs):
"""设置策略元数据"""
key = f'{self.key_prefix}:META:{self.strategy_name}'
if self.r.exists(key):
if not kwargs.pop('overwrite', False):
logger.warning(f'已存在 {self.strategy_name} 的元数据,如需覆盖请设置 overwrite=True')
return
else:
self.r.delete(key)
logger.warning(f'删除 {self.strategy_name} 的元数据,重新写入')

outsample_sdt = pd.to_datetime(outsample_sdt).strftime('%Y%m%d')
meta = {'name': self.strategy_name, 'base_freq': base_freq,
meta = {'name': self.strategy_name, 'base_freq': base_freq, 'key_prefix': self.key_prefix,
'description': description, 'author': author, 'outsample_sdt': outsample_sdt,
'update_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'kwargs': json.dumps(kwargs)}
self.r.hset(f'{self.strategy_name}:meta', mapping=meta)
self.r.hset(key, mapping=meta)

@property
def metadata(self):
"""获取策略元数据"""
return self.r.hgetall(f'{self.strategy_name}:meta')
key = f'{self.key_prefix}:META:{self.strategy_name}'
return self.r.hgetall(key)

def publish(self, symbol, dt, weight, price=0, ref=None, overwrite=False):
"""发布单个策略持仓权重
Expand All @@ -79,7 +93,7 @@ def publish(self, symbol, dt, weight, price=0, ref=None, overwrite=False):
dt = pd.to_datetime(dt)

udt = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
key = f'Weights:{self.strategy_name}:{symbol}:{dt.strftime("%Y%m%d%H%M%S")}'
key = f'{self.key_prefix}:{self.strategy_name}:{symbol}:{dt.strftime("%Y%m%d%H%M%S")}'
ref = ref if ref else '{}'
ref_str = json.dumps(ref) if isinstance(ref, dict) else ref
return self.lua_publish(keys=[key], args=[1 if overwrite else 0, udt, weight, price, ref_str])
Expand All @@ -103,7 +117,7 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000):

keys, args = [], []
for row in df[['symbol', 'dt', 'weight', 'price', 'ref']].to_numpy():
key = f'Weights:{self.strategy_name}:{row[0]}:{row[1].strftime("%Y%m%d%H%M%S")}'
key = f'{self.key_prefix}:{self.strategy_name}:{row[0]}:{row[1].strftime("%Y%m%d%H%M%S")}'
keys.append(key)

args.append(row[2])
Expand All @@ -130,21 +144,21 @@ def publish_dataframe(self, df, overwrite=False, batch_size=10000):

def __heartbeat(self):
while True:
key = f'{self.heartbeat_prefix}:{self.strategy_name}'
key = f'{self.key_prefix}:{self.heartbeat_prefix}:{self.strategy_name}'
try:
self.heartbeat_client.set(key, datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
except Exception:
continue
time.sleep(15)

def get_keys(self, pattern):
"""使用 lua 获取 redis 中指定 pattern 的 keys"""
return self.r.eval('''local pattern = ARGV[1]\nreturn redis.call('KEYS', pattern)''', 0, pattern)
"""获取 redis 中指定 pattern 的 keys"""
return self.r.keys(pattern)

def clear_all(self):
"""删除该策略所有记录"""
self.r.delete(f"{self.strategy_name}:meta")
keys = self.get_keys(f'Weights:{self.strategy_name}*')
self.r.delete(f'{self.key_prefix}:META:{self.strategy_name}')
keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*')
if keys is not None and len(keys) > 0:
self.r.delete(*keys)

Expand Down Expand Up @@ -195,7 +209,7 @@ def register_lua_publish(client):

def get_symbols(self):
"""获取策略交易的品种列表"""
keys = self.get_keys(f'Weights:{self.strategy_name}*')
keys = self.get_keys(f'{self.key_prefix}:{self.strategy_name}*')
symbols = {x.split(":")[2] for x in keys}
return list(symbols)

Expand All @@ -204,7 +218,7 @@ def get_last_weights(self, symbols=None):
symbols = symbols if symbols else self.get_symbols()
with self.r.pipeline() as pipe:
for symbol in symbols:
pipe.hgetall(f"Weights:{self.strategy_name}:{symbol}:LAST")
pipe.hgetall(f'{self.key_prefix}:{self.strategy_name}:{symbol}:LAST')
rows = pipe.execute()

dfw = pd.DataFrame(rows)
Expand All @@ -216,7 +230,7 @@ def get_hist_weights(self, symbol, sdt, edt) -> pd.DataFrame:
"""获取单个品种的持仓权重历史数据"""
start_score = pd.to_datetime(sdt).strftime('%Y%m%d%H%M%S')
end_score = pd.to_datetime(edt).strftime('%Y%m%d%H%M%S')
model_key = f'Weights:{self.strategy_name}:{symbol}'
model_key = f'{self.key_prefix}:{self.strategy_name}:{symbol}'
key_list = self.r.zrangebyscore(model_key, start_score, end_score)

if len(key_list) == 0:
Expand Down
67 changes: 41 additions & 26 deletions czsc/traders/weight_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ def get_ensemble_weight(trader: CzscTrader, method: Union[AnyStr, Callable] = 'm


class WeightBacktest:
"""持仓权重回测"""
"""持仓权重回测
飞书文档:https://s0cqcxuy3p.feishu.cn/wiki/Pf1fw1woQi4iJikbKJmcYToznxb
"""
version = "V231005"

def __init__(self, dfw, digits=2, **kwargs) -> None:
"""持仓权重回测
Expand All @@ -142,7 +146,7 @@ def __init__(self, dfw, digits=2, **kwargs) -> None:
dt 为K线结束时间,必须是连续的交易时间序列,不允许有时间断层
symbol 为合约代码,
weight 为K线结束时间对应的持仓权重,
weight 为K线结束时间对应的持仓权重,品种之间的权重是独立的,不会互相影响
price 为结束时间对应的交易价格,可以是当前K线的收盘价,或者下一根K线的开盘价,或者未来N根K线的TWAP、VWAP等
数据样例如下:
Expand All @@ -169,10 +173,7 @@ def __init__(self, dfw, digits=2, **kwargs) -> None:
self.fee_rate = kwargs.get('fee_rate', 0.0002)
self.dfw['weight'] = self.dfw['weight'].round(digits)
self.symbols = list(self.dfw['symbol'].unique().tolist())
self.res_path = Path(kwargs.get('res_path', "weight_backtest"))
self.res_path.mkdir(exist_ok=True, parents=True)
logger.add(self.res_path.joinpath("weight_backtest.log"), rotation="1 week")
logger.info(f"持仓权重回测参数:digits={digits}, fee_rate={self.fee_rate},res_path={self.res_path},kwargs={kwargs}")
self.results = self.backtest()

def get_symbol_daily(self, symbol):
"""获取某个合约的每日收益率
Expand Down Expand Up @@ -285,39 +286,53 @@ def __add_operate(dt, bar_id, volume, price, operate):

def backtest(self):
"""回测所有合约的收益率"""
symbols = self.symbols
res = {}
for symbol in self.symbols:
for symbol in symbols:
daily = self.get_symbol_daily(symbol)
pairs = self.get_symbol_pairs(symbol)
res[symbol] = {"daily": daily, "pairs": pairs}

pd.to_pickle(res, self.res_path.joinpath("res.pkl"))
logger.info(f"回测结果已保存到 {self.res_path.joinpath('res.pkl')}")

# 品种等权费后日收益率
dret = pd.concat([v['daily'] for v in res.values()], ignore_index=True)
dret = pd.concat([v['daily'] for k, v in res.items() if k in symbols], ignore_index=True)
dret = pd.pivot_table(dret, index='date', columns='symbol', values='return').fillna(0)
dret['total'] = dret[list(res.keys())].mean(axis=1)
res['品种等权日收益'] = dret

stats = {"开始日期": dret.index.min().strftime("%Y%m%d"), "结束日期": dret.index.max().strftime("%Y%m%d")}
stats.update(daily_performance(dret['total']))
logger.info(f"品种等权费后日收益率:{stats}")
dret.to_excel(self.res_path.joinpath("daily_return.xlsx"), index=True)
logger.info(f"品种等权费后日收益率已保存到 {self.res_path.joinpath('daily_return.xlsx')}")
dfp = pd.concat([v['pairs'] for k, v in res.items() if k in symbols], ignore_index=True)
pairs_stats = evaluate_pairs(dfp)
pairs_stats = {k: v for k, v in pairs_stats.items() if k in ['单笔收益', '持仓K线数', '交易胜率', '持仓天数']}
stats.update(pairs_stats)

res['绩效评价'] = stats
return res

def report(self, res_path):
"""回测报告"""
res_path = Path(res_path)
res_path.mkdir(exist_ok=True, parents=True)
logger.add(res_path.joinpath("weight_backtest.log"), rotation="1 week")
logger.info(f"持仓权重回测参数:digits={self.digits}, fee_rate={self.fee_rate},res_path={res_path}")

res = self.results
pd.to_pickle(res, res_path.joinpath("res.pkl"))
logger.info(f"回测结果已保存到 {res_path.joinpath('res.pkl')}")

# 品种等权费后日收益率
dret = res['品种等权日收益'].copy()
dret.to_excel(res_path.joinpath("daily_return.xlsx"), index=True)
logger.info(f"品种等权费后日收益率已保存到 {res_path.joinpath('daily_return.xlsx')}")

# 品种等权费后日收益率资金曲线绘制
dret = dret.cumsum()
fig = px.line(dret, y=dret.columns.to_list(), title="费后日收益率资金曲线")
fig.for_each_trace(lambda trace: trace.update(visible=True if trace.name == 'total' else 'legendonly'))
fig.write_html(self.res_path.joinpath("daily_return.html"))
logger.info(f"费后日收益率资金曲线已保存到 {self.res_path.joinpath('daily_return.html')}")
fig.write_html(res_path.joinpath("daily_return.html"))
logger.info(f"费后日收益率资金曲线已保存到 {res_path.joinpath('daily_return.html')}")

# 所有开平交易记录的表现
dfp = pd.concat([v['pairs'] for v in res.values()], ignore_index=True)
pairs_stats = evaluate_pairs(dfp)
pairs_stats = {k: v for k, v in pairs_stats.items() if k in ['单笔收益', '持仓K线数', '交易胜率', '持仓天数']}
logger.info(f"所有开平交易记录的表现:{pairs_stats}")
stats.update(pairs_stats)
logger.info(f"策略评价:{stats}")
save_json(stats, self.res_path.joinpath("stats.json"))
res['stats'] = stats
return res
stats = res['绩效评价'].copy()
logger.info(f"绩效评价:{stats}")
save_json(stats, res_path.joinpath("stats.json"))
logger.info(f"绩效评价已保存到 {res_path.joinpath('stats.json')}")
1 change: 1 addition & 0 deletions czsc/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .word_writer import WordWriter
from .corr import nmi_matrix, single_linear, cross_sectional_ic
from .bar_generator import BarGenerator, freq_end_time, resample_bars
from .bar_generator import is_trading_time, get_intraday_times, check_freq_and_market
from .io import dill_dump, dill_load, read_json, save_json
from .sig import check_pressure_support, check_gap_info, is_bis_down, is_bis_up, get_sub_elements
from .sig import same_dir_counts, fast_slow_cross, count_last_same, create_single_signal
Expand Down
18 changes: 18 additions & 0 deletions czsc/utils/bar_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@
freq_edt_map[f"{_f}_{_m}"] = {k: v for k, v in dfg[["time", _f]].values}


def is_trading_time(dt: datetime = datetime.now(), market="A股"):
"""判断指定时间是否是交易时间"""
hm = dt.strftime("%H:%M")
times = freq_market_times[f"1分钟_{market}"]
return True if hm in times else False


def get_intraday_times(freq='1分钟', market="A股"):
"""获取指定市场的交易时间段
:param market: 市场名称,可选值:A股、期货、默认
:return: 交易时间段列表
"""
assert market in ['A股', '期货', '默认'], "market 参数必须为 A股 或 期货 或 默认"
assert freq.endswith("分钟"), "freq 参数必须为分钟级别的K线周期"
return freq_market_times[f"{freq}_{market}"]


def check_freq_and_market(time_seq: List[AnyStr]):
"""检查时间序列是否为同一周期,是否为同一市场
Expand Down
Loading

0 comments on commit 941369e

Please sign in to comment.