Skip to content

Commit

Permalink
fix CI error
Browse files Browse the repository at this point in the history
  • Loading branch information
Linlang committed Mar 5, 2024
1 parent 6ea921b commit 8cf7bb3
Show file tree
Hide file tree
Showing 12 changed files with 24 additions and 35 deletions.
3 changes: 0 additions & 3 deletions examples/benchmarks/TRA/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ def predict(self, dataset, segment="test"):


class LSTM(nn.Module):

"""LSTM Model
Args:
Expand Down Expand Up @@ -414,7 +413,6 @@ def forward(self, x):


class Transformer(nn.Module):

"""Transformer Model
Args:
Expand Down Expand Up @@ -475,7 +473,6 @@ def forward(self, x):


class TRA(nn.Module):

"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,
Expand Down
16 changes: 9 additions & 7 deletions qlib/backtest/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,15 @@ def create_account_instance(
init_cash=init_cash,
position_dict=position_dict,
pos_type=pos_type,
benchmark_config={}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
},
benchmark_config=(
{}
if benchmark is None
else {
"benchmark": benchmark,
"start_time": start_time,
"end_time": end_time,
}
),
)


Expand Down
8 changes: 5 additions & 3 deletions qlib/backtest/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,9 +622,11 @@ def cal_trade_indicators(
print(
"[Indicator({}) {}]: FFR: {}, PA: {}, POS: {}".format(
freq,
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S"),
(
trade_start_time
if isinstance(trade_start_time, str)
else trade_start_time.strftime("%Y-%m-%d %H:%M:%S")
),
fulfill_rate,
price_advantage,
positive_rate,
Expand Down
1 change: 1 addition & 0 deletions qlib/contrib/eva/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
The interface should be redesigned carefully in the future.
"""

import pandas as pd
from typing import Tuple
from qlib import get_module_logger
Expand Down
3 changes: 0 additions & 3 deletions qlib/contrib/model/pytorch_tra.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,6 @@ def predict(self, dataset, segment="test"):


class RNN(nn.Module):

"""RNN Model
Args:
Expand Down Expand Up @@ -601,7 +600,6 @@ def forward(self, x):


class Transformer(nn.Module):

"""Transformer Model
Args:
Expand Down Expand Up @@ -649,7 +647,6 @@ def forward(self, x):


class TRA(nn.Module):

"""Temporal Routing Adaptor (TRA)
TRA takes historical prediction errors & latent representation as inputs,
Expand Down
1 change: 0 additions & 1 deletion qlib/contrib/strategy/signal_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,6 @@ def generate_trade_decision(self, execute_result=None):


class EnhancedIndexingStrategy(WeightStrategyBase):

"""Enhanced Indexing Strategy
Enhanced indexing combines the arts of active management and passive management,
Expand Down
8 changes: 2 additions & 6 deletions qlib/data/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,11 @@ def fetch_df_by_index(
if fetch_orig:
for slc in idx_slc:
if slc != slice(None, None):
return df.loc[
pd.IndexSlice[idx_slc],
] # noqa: E231
return df.loc[pd.IndexSlice[idx_slc],] # noqa: E231
else: # pylint: disable=W0120
return df
else:
return df.loc[
pd.IndexSlice[idx_slc],
] # noqa: E231
return df.loc[pd.IndexSlice[idx_slc],] # noqa: E231


def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
Expand Down
2 changes: 0 additions & 2 deletions qlib/model/ens/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __call__(self, ensemble_dict: dict, *args, **kwargs):


class SingleKeyEnsemble(Ensemble):

"""
Extract the object if there is only one key and value in the dict. Make the result more readable.
{Only key: Only value} -> Only value
Expand Down Expand Up @@ -64,7 +63,6 @@ def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -


class RollingEnsemble(Ensemble):

"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
Expand Down
4 changes: 1 addition & 3 deletions qlib/model/riskmodel/shrink.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,7 @@ def _get_shrink_param_lw_single_factor(self, X: np.ndarray, S: np.ndarray, F: np
v1 = y.T.dot(z) / t - cov_mkt[:, None] * S
roff1 = np.sum(v1 * cov_mkt[:, None].T) / var_mkt - np.sum(np.diag(v1) * cov_mkt) / var_mkt
v3 = z.T.dot(z) / t - var_mkt * S
roff3 = (
np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
)
roff3 = np.sum(v3 * np.outer(cov_mkt, cov_mkt)) / var_mkt**2 - np.sum(np.diag(v3) * cov_mkt**2) / var_mkt**2
roff = 2 * roff1 - roff3
rho = rdiag + roff

Expand Down
1 change: 0 additions & 1 deletion qlib/workflow/online/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def get_collector(self) -> Collector:


class RollingStrategy(OnlineStrategy):

"""
This example strategy always uses the latest rolling model sas online models.
"""
Expand Down
4 changes: 1 addition & 3 deletions scripts/dump_bin.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,7 @@ def get_dump_fields(self, df_columns: Iterable[str]) -> Iterable[str]:
return (
self._include_fields
if self._include_fields
else set(df_columns) - set(self._exclude_fields)
if self._exclude_fields
else df_columns
else set(df_columns) - set(self._exclude_fields) if self._exclude_fields else df_columns
)

@staticmethod
Expand Down
8 changes: 5 additions & 3 deletions scripts/dump_pit.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,11 @@ def get_dump_fields(self, df: Iterable[str]) -> Iterable[str]:
return (
set(self._include_fields)
if self._include_fields
else set(df[self.field_column_name]) - set(self._exclude_fields)
if self._exclude_fields
else set(df[self.field_column_name])
else (
set(df[self.field_column_name]) - set(self._exclude_fields)
if self._exclude_fields
else set(df[self.field_column_name])
)
)

def get_filenames(self, symbol, field, interval):
Expand Down

0 comments on commit 8cf7bb3

Please sign in to comment.