Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add categories and subdivisions support to substituted holidays #1558

Merged
merged 3 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion holidays/groups/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,9 @@ class StaticHolidays:

def __init__(self, cls) -> None:
for attribute_name in cls.__dict__.keys():
if attribute_name.startswith("special_") or attribute_name.startswith("substituted_"):
if attribute_name.startswith("special_"):
setattr(self, attribute_name, getattr(cls, attribute_name))
self._has_special = True
elif attribute_name.startswith("substituted_"):
setattr(self, attribute_name, getattr(cls, attribute_name))
self._has_substituted = True
33 changes: 20 additions & 13 deletions holidays/holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def _add_subdiv_holidays(self):

def _add_substituted_holidays(self):
"""Populate substituted holidays."""
if len(self.substituted_holidays) == 0:
if not hasattr(self, "_has_substituted"):
return None
if not hasattr(self, "substituted_label") or not hasattr(self, "substituted_date_format"):
raise ValueError(
Expand All @@ -675,11 +675,15 @@ def _add_substituted_holidays(self):
)
substituted_label = self.tr(self.substituted_label)
substituted_date_format = self.tr(self.substituted_date_format)
for hol in _normalize_tuple(self.substituted_holidays.get(self._year, ())):
from_year = hol[0] if len(hol) == 5 else self._year
from_month, from_day, to_month, to_day = hol[-4:]
from_date = date(from_year, from_month, from_day).strftime(substituted_date_format)
self._add_holiday(substituted_label % from_date, to_month, to_day)

for mapping_name in self._get_static_holiday_mapping_names():
for hol in _normalize_tuple(
getattr(self, f"substituted_{mapping_name}", {}).get(self._year, ())
):
from_year = hol[0] if len(hol) == 5 else self._year
from_month, from_day, to_month, to_day = hol[-4:]
from_date = date(from_year, from_month, from_day).strftime(substituted_date_format)
self._add_holiday(substituted_label % from_date, to_month, to_day)

def _check_weekday(self, weekday: int, *args) -> bool:
"""
Expand Down Expand Up @@ -751,27 +755,30 @@ def _populate(self, year: int) -> None:
# Populate substituted holidays.
self._add_substituted_holidays()

def _get_special_holiday_mapping_names(self):
def _get_static_holiday_mapping_names(self):
# Check for general special holidays.
mapping_names = ["special_holidays"]
mapping_names = ["holidays"]

# Check subdivision specific special holidays.
if self.subdiv is not None:
subdiv = self.subdiv.replace("-", "_").replace(" ", "_").lower()
mapping_names.append(f"special_{subdiv}_holidays")
mapping_names.append(f"{subdiv}_holidays")

# Check category specific special holidays (both general and per subdivision).
for category in sorted(self.categories):
mapping_names.append(f"special_{category}_holidays")
mapping_names.append(f"{category}_holidays")
if self.subdiv is not None:
mapping_names.append(f"special_{subdiv}_{category}_holidays")
mapping_names.append(f"{subdiv}_{category}_holidays")

return mapping_names

def _add_special_holidays(self):
for mapping_name in self._get_special_holiday_mapping_names():
if not hasattr(self, "_has_special"):
return None

for mapping_name in self._get_static_holiday_mapping_names():
for month, day, name in _normalize_tuple(
getattr(self, mapping_name, {}).get(self._year, ())
getattr(self, f"special_{mapping_name}", {}).get(self._year, ())
):
self._add_holiday(name, date(self._year, month, day))

Expand Down
4 changes: 2 additions & 2 deletions holidays/observed_holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def _add_special_holidays(self):
if not self.observed:
return None

for mapping_name in self._get_special_holiday_mapping_names():
for mapping_name in self._get_static_holiday_mapping_names():
for month, day, name in _normalize_tuple(
getattr(self, f"{mapping_name}_observed", {}).get(self._year, ())
getattr(self, f"special_{mapping_name}_observed", {}).get(self._year, ())
):
self._add_holiday(
self.tr(self.observed_label) % self.tr(name), date(self._year, month, day)
Expand Down
3 changes: 3 additions & 0 deletions tests/test_holiday_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@


class EntityStub(HolidayBase):
_has_special = True
_has_substituted = True
special_holidays = {
1111: (JAN, 1, "Test holiday"),
2222: (FEB, 2, "Test holiday"),
Expand Down Expand Up @@ -931,6 +933,7 @@ def test_market(self):

class TestSubstitutedHolidays(unittest.TestCase):
class SubstitutedHolidays(HolidayBase):
_has_substituted = True
country = "HB"
substituted_holidays = {
1991: (
Expand Down