Skip to content

Commit

Permalink
Clean up AKShare data source
Browse files Browse the repository at this point in the history
- Improve timeframe parsing.
- Remove unneeded constructor.
- Add unit tests for timeframes.
  • Loading branch information
edtechre committed Dec 10, 2023
1 parent 2e8815b commit f2f3a5d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 54 deletions.
65 changes: 13 additions & 52 deletions src/pybroker/ext/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from datetime import datetime
from typing import Iterable, Optional, Union
from typing import Optional

import akshare
import pandas as pd
Expand All @@ -17,44 +17,13 @@


class AKShare(DataSource):
r"""Retrieves data from `AKShare <https://akshare.akfamily.xyz/>`_.
r"""Retrieves data from `AKShare <https://akshare.akfamily.xyz/>`_."""

Args:
adjust: The type of adjustment to make.
timeframe: Timeframe of the data to query.
"""

def __init__(
self, adjust: Optional[str] = "", timeframe: Optional[str] = "1d"
):
super().__init__()
self.adjust = adjust
self.timeframe = timeframe

def query(
self,
symbols: Union[str, Iterable[str]],
start_date: Union[str, datetime],
end_date: Union[str, datetime],
timeframe: Optional[str] = "1d",
adjust: Optional[str] = "",
) -> pd.DataFrame:
r"""Queries data from `AKShare <https://akshare.akfamily.xyz/>`_\ .
The timeframe of the data is limited to per daily, weekly and monthly.
Args:
symbols: Ticker symbols of the data to query.
start_date: Start date of the data to query (inclusive).
end_date: End date of the data to query (inclusive).
timeframe: Timeframe of the data to query.
adjust: The type of adjustment to make.
Returns:
:class:`pandas.DataFrame` containing the queried data.
"""
timeframe = timeframe if timeframe != "1d" else self.timeframe
adjust = adjust if adjust != "" else self.adjust
return super().query(symbols, start_date, end_date, timeframe, adjust)
_tf_to_period = {
"": "daily",
"1day": "daily",
"1week": "weekly",
}

def _fetch_data(
self,
Expand All @@ -70,28 +39,20 @@ def _fetch_data(
symbols_list = list(symbols)
symbols_simple = [item.split(".")[0] for item in symbols_list]
result = pd.DataFrame()
period_timeframe_map = {
"": "daily",
"1day": "daily",
"1week": "weekly",
"1month": "monthly",
}
for i in range(len(symbols_list)):
try:
formatted_tf = self._format_timeframe(timeframe)
if formatted_tf in AKShare._tf_to_period:
period = AKShare._tf_to_period[formatted_tf]
for i in range(len(symbols_list)):
temp_df = akshare.stock_zh_a_hist(
symbol=symbols_simple[i],
start_date=start_date_str,
end_date=end_date_str,
period=period_timeframe_map[timeframe]
if timeframe
else "daily",
period=period,
adjust=adjust if adjust is not None else "",
)
if not temp_df.columns.empty:
temp_df["symbol"] = symbols_list[i]
except KeyError:
temp_df = pd.DataFrame()
result = pd.concat([result, temp_df], ignore_index=True)
result = pd.concat([result, temp_df], ignore_index=True)
if result.columns.empty:
return pd.DataFrame(
columns=[
Expand Down
37 changes: 35 additions & 2 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,8 @@ def test_query_when_empty_result(self, symbols, columns):

class TestAKShare:
@pytest.mark.usefixtures("setup_ds_cache")
def test_query(self):
@pytest.mark.parametrize("timeframe", [None, "", "1d", "1w"])
def test_query(self, timeframe):
symbols = ["A"]
ak = AKShare()
expected_df = pd.DataFrame(
Expand All @@ -588,7 +589,7 @@ def test_query(self):
with mock.patch.object(
akshare, "stock_zh_a_hist", return_value=expected_df
):
df = ak.query(symbols, START_DATE, END_DATE)
df = ak.query(symbols, START_DATE, END_DATE, timeframe)
assert set(df.columns) == {
"date",
"open",
Expand Down Expand Up @@ -638,3 +639,35 @@ def test_query_when_empty_result(self, columns):
"symbol",
)
)

@pytest.mark.usefixtures("setup_ds_cache")
def test_query_when_unsupported_timeframe_then_empty(self):
symbols = ["A"]
ak = AKShare()
expected_df = pd.DataFrame(
{
"日期": [END_DATE],
"开盘": [1],
"收盘": [2],
"最高": [3],
"最低": [4],
"成交量": [5],
"symbol": symbols,
}
)
with mock.patch.object(
akshare, "stock_zh_a_hist", return_value=expected_df
):
df = ak.query(symbols, START_DATE, END_DATE, "2d")
assert df.empty
assert set(df.columns) == set(
(
"date",
"open",
"high",
"low",
"close",
"volume",
"symbol",
)
)

0 comments on commit f2f3a5d

Please sign in to comment.