diff --git a/adventofcode/d20.py b/adventofcode/d20.py index 1a2616a..8f3a93e 100644 --- a/adventofcode/d20.py +++ b/adventofcode/d20.py @@ -6,7 +6,7 @@ from abc import ABCMeta, abstractmethod from collections import Counter, deque from dataclasses import dataclass -from typing import Iterable, Iterator, Never, NewType, overload, override +from typing import Iterable, Iterator, Never, NewType, override _logger = logging.getLogger(__name__) @@ -31,99 +31,44 @@ class _PulseNew: @dataclass(frozen=True, kw_only=True, slots=True) class _Pulse: - button_presses: int - pulse_index: int value: _PulseValue from_: _Module to: _Module + button_presses: int @staticmethod - @overload - def new( - button_presses: int, - pulse_index: int, - value: _PulseValue, - from_: _Module, - to: _Module, - /, - ) -> _Pulse: ... - - @staticmethod - @overload - def new( - button_presses: int, pulse_index: int, new_pulse: _PulseNew, / - ) -> _Pulse: ... - - @staticmethod - def new( - button_presses: int, - pulse_index: int, - value_or_new_pulse: _PulseValue | _PulseNew, - from_: _Module | None = None, - to: _Module | None = None, - /, - ) -> _Pulse: - if isinstance(value_or_new_pulse, _PulseNew): - return _Pulse( - button_presses=button_presses, - pulse_index=pulse_index, - value=value_or_new_pulse.value, - from_=value_or_new_pulse.from_, - to=value_or_new_pulse.to, - ) - - if from_ is None: - raise ValueError("from_") - if to is None: - raise ValueError("to") + def create(new_pulse: _PulseNew, button_presses: int, /) -> _Pulse: return _Pulse( + value=new_pulse.value, + from_=new_pulse.from_, + to=new_pulse.to, button_presses=button_presses, - pulse_index=pulse_index, - value=value_or_new_pulse, - from_=from_, - to=to, ) class _Module(metaclass=ABCMeta): def __init__(self, name: _ModuleName) -> None: self._name = name - self._outputs: list[tuple[_Module, _PulseValue]] = [] + self._outputs: list[_Module] = [] @property def name(self) -> _ModuleName: return self._name - @property - @abstractmethod - def possible_output_values(self) -> tuple[_PulseValue, ...]: ... - - @property - @abstractmethod - def interesting_input_signal_values(self) -> tuple[_PulseValue, ...]: ... - def add_receiving_module(self, output: _Module) -> None: - interested_pulse_values = output.interesting_input_signal_values - assert interested_pulse_values - for value in interested_pulse_values: - if value not in self.possible_output_values: - continue - - self._outputs.append((output, value)) + self._outputs.append(output) @abstractmethod def process_pulse(self, pulse: _Pulse) -> Iterator[_PulseNew]: ... def _get_output_pulses(self, value: _PulseValue) -> Iterator[_PulseNew]: - for output, interested_value in self._outputs: - if interested_value == value: - yield _PulseNew(value=value, from_=self, to=output) + for output in self._outputs: + yield _PulseNew(value=value, from_=self, to=output) if __debug__: def _validate_incoming_pulse(self, pulse: _Pulse) -> None: assert pulse.to is self - assert pulse.value in self.interesting_input_signal_values class _Button(_Module): @@ -139,16 +84,6 @@ def process_button_press(self) -> Iterator[_PulseNew]: self._button_presses += 1 yield from self._get_output_pulses(_PulseLow) - @property - @override - def possible_output_values(self) -> tuple[_PulseValue, ...]: - return (_PulseLow,) - - @property - @override - def interesting_input_signal_values(self) -> tuple[_PulseValue, ...]: - return tuple() - @override def process_pulse(self, pulse: _Pulse) -> Never: if __debug__: @@ -157,16 +92,6 @@ def process_pulse(self, pulse: _Pulse) -> Never: class _Broadcast(_Module): - @property - @override - def possible_output_values(self) -> tuple[_PulseValue, ...]: - return (_PulseLow, _PulseHigh) - - @property - @override - def interesting_input_signal_values(self) -> tuple[_PulseValue, ...]: - return (_PulseLow, _PulseHigh) - @override def process_pulse(self, pulse: _Pulse) -> Iterator[_PulseNew]: if __debug__: @@ -177,21 +102,6 @@ def process_pulse(self, pulse: _Pulse) -> Iterator[_PulseNew]: class _Receiver(_Module): def __init__(self, name: _ModuleName) -> None: super().__init__(name=name) - self._received_low = False - - @property - def received_low(self) -> bool: - return self._received_low - - @property - @override - def possible_output_values(self) -> tuple[_PulseValue, ...]: - return tuple() - - @property - @override - def interesting_input_signal_values(self) -> tuple[_PulseValue, ...]: - return (_PulseLow,) @override def add_receiving_module(self, output: _Module) -> Never: @@ -202,9 +112,11 @@ def add_receiving_module(self, output: _Module) -> Never: def process_pulse(self, pulse: _Pulse) -> Iterator[_PulseNew]: if __debug__: self._validate_incoming_pulse(pulse) - _logger.info("Receiver received signal from %s: %s", pulse.from_, pulse.value) - assert pulse.value is _PulseLow - self._received_low = True + if pulse.value is _PulseLow: + _logger.info( + "Receiver received signal from %s: %s", pulse.from_, pulse.value + ) + self._received_low = True yield from [] @@ -213,22 +125,13 @@ def __init__(self, name: _ModuleName) -> None: super().__init__(name=name) self._state: _PulseValue = _PulseLow - @property - @override - def possible_output_values(self) -> tuple[_PulseValue, ...]: - return (_PulseLow, _PulseHigh) - - @property - @override - def interesting_input_signal_values(self) -> tuple[_PulseValue, ...]: - return (_PulseLow,) - @override def process_pulse(self, pulse: _Pulse) -> Iterator[_PulseNew]: if __debug__: self._validate_incoming_pulse(pulse) - assert pulse.value is _PulseLow + if pulse.value is _PulseHigh: + return self._state = _PulseValue(not self._state) yield from self._get_output_pulses(self._state) @@ -238,17 +141,6 @@ class _Conjunction(_Module): def __init__(self, name: _ModuleName) -> None: super().__init__(name=name) self._state: dict[_ModuleName, _PulseValue] = {} - self._input_signals: dict[_ModuleName, list[tuple[_PulseValue, int]]] = {} - - @property - @override - def possible_output_values(self) -> tuple[_PulseValue, ...]: - return (_PulseLow, _PulseHigh) - - @property - @override - def interesting_input_signal_values(self) -> tuple[_PulseValue, ...]: - return (_PulseLow, _PulseHigh) @override def process_pulse(self, pulse: _Pulse) -> Iterator[_PulseNew]: @@ -286,7 +178,10 @@ def button_press_period(self) -> int: return self.last_button_press - self.first_button_press + 1 def __repr__(self) -> str: - return f"({self.value}, {self.count}, {self.first_button_press}, {self.last_button_press}, {self.button_press_period})" + return ( + f"({self.value}, {self.count}, {self.first_button_press}, " + f"{self.last_button_press}, {self.button_press_period})" + ) class _GatewayConjuction(_Conjunction): @@ -427,7 +322,7 @@ def _parse_module(line: str) -> tuple[_AnyModule, list[_ModuleName]]: def _parse_modules( - lines: Iterable[str], + lines: Iterable[str], *, use_gateway: bool = False ) -> tuple[ _Receiver | None, _Button, @@ -455,23 +350,24 @@ def _parse_modules( for module, outputs in modules_with_output_names if receiver_name in outputs ] - assert ( - len(gateways_to_receiver) <= 1 - ), "Safety check: only 0-1 gateways are known" - if gateways_to_receiver: - gateway_to_receiver = gateways_to_receiver[0] - assert isinstance( - gateway_to_receiver, _Conjunction - ), "Safety check: only conjunctions are known" - modules_with_output_names = [ - ( - _GatewayConjuction(module.name) - if module is gateway_to_receiver - else module, - outputs, - ) - for module, outputs in modules_with_output_names - ] + if use_gateway: + assert ( + len(gateways_to_receiver) <= 1 + ), "Safety check: only 0-1 gateways are known" + if gateways_to_receiver: + gateway_to_receiver = gateways_to_receiver[0] + assert isinstance( + gateway_to_receiver, _Conjunction + ), "Safety check: only conjunctions are known" + modules_with_output_names = [ + ( + _GatewayConjuction(module.name) + if module is gateway_to_receiver + else module, + outputs, + ) + for module, outputs in modules_with_output_names + ] modules_by_name = {module.name: module for module, _ in modules_with_output_names} @@ -523,54 +419,50 @@ def _parse_modules( return receiver, button, gateway, modules_by_type +def _extend_pulses( + queue: deque[_Pulse], button: _Button, pulses: Iterable[_PulseNew] +) -> None: + queue.extend(map(lambda new: _Pulse.create(new, button.button_presses), pulses)) + + def p1(input_str: str) -> int: _, button, _, _ = _parse_modules(input_str.splitlines()) counts = Counter[_PulseValue]() + queue = deque[_Pulse]() + for _ in range(1000): - queue = deque[_Pulse](button.process_button_press()) + _extend_pulses(queue, button, button.process_button_press()) + while queue: pulse = queue.popleft() counts.update((pulse.value,)) - if _logger.isEnabledFor(logging.DEBUG): - _logger.debug("Pulse: %s (queue length: %d)", pulse, len(queue)) - queue.extend(pulse.to.process_pulse(pulse)) + _extend_pulses(queue, button, pulse.to.process_pulse(pulse)) _logger.info(f"Counts: {counts}") return math.prod(counts.values()) -def _process_p2(input_str: str, output_module_name: _ModuleName) -> int: - receiver, button, gateway, _ = _parse_modules(input_str.splitlines()) +def p2(input_str: str) -> int: + receiver_name = _ModuleName("rx") + receiver, button, gateway, _ = _parse_modules( + input_str.splitlines(), use_gateway=True + ) assert receiver is not None - assert receiver.name == output_module_name + assert receiver.name == receiver_name assert gateway is not None queue = deque[_Pulse]() - pulse_index = -1 - - def next_pulse_count() -> int: - nonlocal pulse_index - pulse_index += 1 - return pulse_index - - def new_pulse(new: _PulseNew) -> _Pulse: - return _Pulse.new(button.button_presses, next_pulse_count(), new) - - def extend_pulses(pulses: Iterable[_PulseNew]) -> None: - nonlocal queue - queue.extend(map(new_pulse, pulses)) while True: - if not queue: - extend_pulses(button.process_button_press()) - if button.button_presses % 100_000 == 0: - _logger.info(f"Button presses: {button.button_presses:_}") + _extend_pulses(queue, button, button.process_button_press()) + if button.button_presses % 100_000 == 0: + _logger.info(f"Button presses: {button.button_presses:_}") while queue: pulse = queue.popleft() - extend_pulses(pulse.to.process_pulse(pulse)) + _extend_pulses(queue, button, pulse.to.process_pulse(pulse)) if gateway.has_pattern_for_all: _logger.info( @@ -580,10 +472,3 @@ def extend_pulses(pulses: Iterable[_PulseNew]) -> None: break return math.lcm(*(pattern.period for pattern in gateway.patterns.values())) - - -def p2(input_str: str, output_module_name: _ModuleName | None = None) -> int: - if output_module_name is None: - output_module_name = _ModuleName("rx") - - return _process_p2(input_str, output_module_name)