Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support fetching of multiple accessories at once #105

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 102 additions & 13 deletions findmy/reports/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
TYPE_CHECKING,
Any,
Callable,
Sequence,
TypedDict,
TypeVar,
cast,
Expand Down Expand Up @@ -49,6 +48,8 @@
)

if TYPE_CHECKING:
from collections.abc import Sequence

from findmy.accessory import RollingKeyPairSource
from findmy.keys import HasHashedPublicKey
from findmy.util.types import MaybeCoro
Expand Down Expand Up @@ -248,13 +249,28 @@ def fetch_reports(
date_to: datetime | None,
) -> MaybeCoro[list[LocationReport]]: ...

@overload
def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ...

@abstractmethod
def fetch_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]:
) -> MaybeCoro[
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
]:
"""
Fetch location reports for `HasHashedPublicKey`s between `date_from` and `date_end`.

Expand Down Expand Up @@ -286,12 +302,27 @@ def fetch_last_reports(
hours: int = 7 * 24,
) -> MaybeCoro[list[LocationReport]]: ...

@overload
@abstractmethod
def fetch_last_reports(
self,
keys: Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> MaybeCoro[dict[RollingKeyPairSource, list[LocationReport]]]: ...

@abstractmethod
def fetch_last_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> MaybeCoro[list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]]:
) -> MaybeCoro[
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
]:
"""
Fetch location reports for a sequence of `HasHashedPublicKey`s for the last `hours` hours.

Expand Down Expand Up @@ -641,14 +672,29 @@ async def fetch_reports(
date_to: datetime | None,
) -> list[LocationReport]: ...

@overload
async def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...

@require_login_state(LoginState.LOGGED_IN)
@override
async def fetch_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""See `BaseAppleAccount.fetch_reports`."""
date_to = date_to or datetime.now().astimezone()

Expand Down Expand Up @@ -679,13 +725,27 @@ async def fetch_last_reports(
hours: int = 7 * 24,
) -> list[LocationReport]: ...

@overload
async def fetch_last_reports(
self,
keys: Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...

@require_login_state(LoginState.LOGGED_IN)
@override
async def fetch_last_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""See `BaseAppleAccount.fetch_last_reports`."""
end = datetime.now(tz=timezone.utc)
start = end - timedelta(hours=hours)
Expand Down Expand Up @@ -1041,13 +1101,28 @@ def fetch_reports(
date_to: datetime | None,
) -> list[LocationReport]: ...

@overload
def fetch_reports(
self,
keys: Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...

@override
def fetch_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
date_from: datetime,
date_to: datetime | None,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""See `AsyncAppleAccount.fetch_reports`."""
coro = self._asyncacc.fetch_reports(keys, date_from, date_to)
return self._evt_loop.run_until_complete(coro)
Expand All @@ -1073,12 +1148,26 @@ def fetch_last_reports(
hours: int = 7 * 24,
) -> list[LocationReport]: ...

@overload
def fetch_last_reports(
self,
keys: Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...

@override
def fetch_last_reports(
self,
keys: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
keys: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
hours: int = 7 * 24,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""See `AsyncAppleAccount.fetch_last_reports`."""
coro = self._asyncacc.fetch_last_reports(keys, hours)
return self._evt_loop.run_until_complete(coro)
Expand Down
104 changes: 73 additions & 31 deletions findmy/reports/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import hashlib
import logging
import struct
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, overload
from typing import TYPE_CHECKING, cast, overload

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
Expand Down Expand Up @@ -260,12 +261,27 @@ async def fetch_reports(
device: RollingKeyPairSource,
) -> list[LocationReport]: ...

@overload
async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: Sequence[RollingKeyPairSource],
) -> dict[RollingKeyPairSource, list[LocationReport]]: ...

async def fetch_reports(
self,
date_from: datetime,
date_to: datetime,
device: HasHashedPublicKey | Sequence[HasHashedPublicKey] | RollingKeyPairSource,
) -> list[LocationReport] | dict[HasHashedPublicKey, list[LocationReport]]:
device: HasHashedPublicKey
| Sequence[HasHashedPublicKey]
| RollingKeyPairSource
| Sequence[RollingKeyPairSource],
) -> (
list[LocationReport]
| dict[HasHashedPublicKey, list[LocationReport]]
| dict[RollingKeyPairSource, list[LocationReport]]
):
"""
Fetch location reports for a certain device.

Expand All @@ -276,45 +292,71 @@ async def fetch_reports(
When ``device`` is a :class:`.RollingKeyPairSource`, it will return a list of
location reports corresponding to that source.
"""
# single key
key_devs: (
dict[HasHashedPublicKey, HasHashedPublicKey]
| dict[HasHashedPublicKey, RollingKeyPairSource]
) = {}
if isinstance(device, HasHashedPublicKey):
return await self._fetch_reports(date_from, date_to, [device])

# key generator
# add 12h margin to the generator
if isinstance(device, RollingKeyPairSource):
keys = list(
device.keys_between(
# single key
key_devs = {device: device}
elif isinstance(device, list) and all(isinstance(x, HasHashedPublicKey) for x in device):
# multiple static keys
device = cast(list[HasHashedPublicKey], device)
key_devs = {key: key for key in device}
elif isinstance(device, RollingKeyPairSource):
# key generator
# add 12h margin to the generator
key_devs = {
key: device
for key in device.keys_between(
date_from - timedelta(hours=12),
date_to + timedelta(hours=12),
)
}
elif isinstance(device, list) and all(isinstance(x, RollingKeyPairSource) for x in device):
# multiple key generators
# add 12h margin to each generator
device = cast(list[RollingKeyPairSource], device)
key_devs = {
key: dev
for dev in device
for key in dev.keys_between(
date_from - timedelta(hours=12),
date_to + timedelta(hours=12),
),
)
)
}
else:
keys = device
msg = "Unknown device type: %s"
raise ValueError(msg, type(device))

# sequence of keys (fetch 256 max at a time)
reports: list[LocationReport] = []
key_reports: dict[HasHashedPublicKey, list[LocationReport]] = {}
keys = list(key_devs.keys())
for key_offset in range(0, len(keys), 256):
chunk = keys[key_offset : key_offset + 256]
reports.extend(await self._fetch_reports(date_from, date_to, chunk))

if isinstance(device, RollingKeyPairSource):
return reports

res: dict[HasHashedPublicKey, list[LocationReport]] = {key: [] for key in keys}
for report in reports:
for key in res:
if key.hashed_adv_key_bytes == report.hashed_adv_key_bytes:
res[key].append(report)
break
return res
chunk_keys = keys[key_offset : key_offset + 256]
chunk_reports = await self._fetch_reports(date_from, date_to, chunk_keys)
key_reports |= chunk_reports

# combine (key -> list[report]) and (key -> device) into (device -> list[report])
device_reports = defaultdict(list)
for key, reports in key_reports.items():
device_reports[key_devs[key]].extend(reports)
for dev in device_reports:
device_reports[dev] = sorted(device_reports[dev])

# result
if isinstance(device, (HasHashedPublicKey, RollingKeyPairSource)):
# single key or generator
return device_reports[device]
# multiple static keys or key generators
return device_reports

async def _fetch_reports(
self,
date_from: datetime,
date_to: datetime,
keys: Sequence[HasHashedPublicKey],
) -> list[LocationReport]:
) -> dict[HasHashedPublicKey, list[LocationReport]]:
logging.debug("Fetching reports for %s keys", len(keys))

# lock requested time range to the past 7 days, +- 12 hours, then filter the response.
Expand All @@ -327,7 +369,7 @@ async def _fetch_reports(
data = await self._account.fetch_raw_reports(start_date, end_date, ids)

id_to_key: dict[bytes, HasHashedPublicKey] = {key.hashed_adv_key_bytes: key for key in keys}
reports: list[LocationReport] = []
reports: dict[HasHashedPublicKey, list[LocationReport]] = defaultdict(list)
for report in data.get("results", []):
payload = base64.b64decode(report["payload"])
hashed_adv_key = base64.b64decode(report["id"])
Expand All @@ -347,6 +389,6 @@ async def _fetch_reports(
if loc_report.timestamp < date_from or loc_report.timestamp > date_to:
continue

reports.append(loc_report)
reports[key].append(loc_report)

return reports
Loading