From f2f3a5d2547003f0394d102ad1efd7e8586fb5d8 Mon Sep 17 00:00:00 2001 From: edtechre Date: Sat, 9 Dec 2023 18:04:47 -0800 Subject: [PATCH] Clean up AKShare data source - Improve timeframe parsing. - Remove unneeded constructor. - Add unit tests for timeframes. --- src/pybroker/ext/data.py | 65 ++++++++-------------------------------- tests/test_data.py | 37 +++++++++++++++++++++-- 2 files changed, 48 insertions(+), 54 deletions(-) diff --git a/src/pybroker/ext/data.py b/src/pybroker/ext/data.py index dbd02c5..6585cc1 100644 --- a/src/pybroker/ext/data.py +++ b/src/pybroker/ext/data.py @@ -7,7 +7,7 @@ """ from datetime import datetime -from typing import Iterable, Optional, Union +from typing import Optional import akshare import pandas as pd @@ -17,44 +17,13 @@ class AKShare(DataSource): - r"""Retrieves data from `AKShare `_. + r"""Retrieves data from `AKShare `_.""" - Args: - adjust: The type of adjustment to make. - timeframe: Timeframe of the data to query. - """ - - def __init__( - self, adjust: Optional[str] = "", timeframe: Optional[str] = "1d" - ): - super().__init__() - self.adjust = adjust - self.timeframe = timeframe - - def query( - self, - symbols: Union[str, Iterable[str]], - start_date: Union[str, datetime], - end_date: Union[str, datetime], - timeframe: Optional[str] = "1d", - adjust: Optional[str] = "", - ) -> pd.DataFrame: - r"""Queries data from `AKShare `_\ . - The timeframe of the data is limited to per daily, weekly and monthly. - - Args: - symbols: Ticker symbols of the data to query. - start_date: Start date of the data to query (inclusive). - end_date: End date of the data to query (inclusive). - timeframe: Timeframe of the data to query. - adjust: The type of adjustment to make. - - Returns: - :class:`pandas.DataFrame` containing the queried data. - """ - timeframe = timeframe if timeframe != "1d" else self.timeframe - adjust = adjust if adjust != "" else self.adjust - return super().query(symbols, start_date, end_date, timeframe, adjust) + _tf_to_period = { + "": "daily", + "1day": "daily", + "1week": "weekly", + } def _fetch_data( self, @@ -70,28 +39,20 @@ def _fetch_data( symbols_list = list(symbols) symbols_simple = [item.split(".")[0] for item in symbols_list] result = pd.DataFrame() - period_timeframe_map = { - "": "daily", - "1day": "daily", - "1week": "weekly", - "1month": "monthly", - } - for i in range(len(symbols_list)): - try: + formatted_tf = self._format_timeframe(timeframe) + if formatted_tf in AKShare._tf_to_period: + period = AKShare._tf_to_period[formatted_tf] + for i in range(len(symbols_list)): temp_df = akshare.stock_zh_a_hist( symbol=symbols_simple[i], start_date=start_date_str, end_date=end_date_str, - period=period_timeframe_map[timeframe] - if timeframe - else "daily", + period=period, adjust=adjust if adjust is not None else "", ) if not temp_df.columns.empty: temp_df["symbol"] = symbols_list[i] - except KeyError: - temp_df = pd.DataFrame() - result = pd.concat([result, temp_df], ignore_index=True) + result = pd.concat([result, temp_df], ignore_index=True) if result.columns.empty: return pd.DataFrame( columns=[ diff --git a/tests/test_data.py b/tests/test_data.py index bae368d..be668bc 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -571,7 +571,8 @@ def test_query_when_empty_result(self, symbols, columns): class TestAKShare: @pytest.mark.usefixtures("setup_ds_cache") - def test_query(self): + @pytest.mark.parametrize("timeframe", [None, "", "1d", "1w"]) + def test_query(self, timeframe): symbols = ["A"] ak = AKShare() expected_df = pd.DataFrame( @@ -588,7 +589,7 @@ def test_query(self): with mock.patch.object( akshare, "stock_zh_a_hist", return_value=expected_df ): - df = ak.query(symbols, START_DATE, END_DATE) + df = ak.query(symbols, START_DATE, END_DATE, timeframe) assert set(df.columns) == { "date", "open", @@ -638,3 +639,35 @@ def test_query_when_empty_result(self, columns): "symbol", ) ) + + @pytest.mark.usefixtures("setup_ds_cache") + def test_query_when_unsupported_timeframe_then_empty(self): + symbols = ["A"] + ak = AKShare() + expected_df = pd.DataFrame( + { + "日期": [END_DATE], + "开盘": [1], + "收盘": [2], + "最高": [3], + "最低": [4], + "成交量": [5], + "symbol": symbols, + } + ) + with mock.patch.object( + akshare, "stock_zh_a_hist", return_value=expected_df + ): + df = ak.query(symbols, START_DATE, END_DATE, "2d") + assert df.empty + assert set(df.columns) == set( + ( + "date", + "open", + "high", + "low", + "close", + "volume", + "symbol", + ) + )