Skip to content

Commit

Permalink
feat: add type hints to core lib (#71)
Browse files Browse the repository at this point in the history
* feat: add type hints to core lib

---------

Co-authored-by: Kim Gustyr <[email protected]>
Co-authored-by: Zach Aysan <[email protected]>
  • Loading branch information
3 people authored Mar 13, 2024
1 parent fe66a66 commit 21eb873
Show file tree
Hide file tree
Showing 16 changed files with 343 additions and 155 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ repos:
hooks:
- id: mypy
args: [--strict]
additional_dependencies: [pydantic, pytest, pytest_mock]
additional_dependencies:
[pydantic, pytest, pytest_mock, types-requests, flagsmith-flag-engine, responses, types-pytz, sseclient-py]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
Expand Down
4 changes: 3 additions & 1 deletion flagsmith/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .flagsmith import Flagsmith # noqa
from .flagsmith import Flagsmith

__all__ = ("Flagsmith",)
19 changes: 11 additions & 8 deletions flagsmith/analytics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
import typing
from datetime import datetime

from requests_futures.sessions import FuturesSession
from requests_futures.sessions import FuturesSession # type: ignore

ANALYTICS_ENDPOINT = "analytics/flags/"
ANALYTICS_ENDPOINT: typing.Final[str] = "analytics/flags/"

# Used to control how often we send data(in seconds)
ANALYTICS_TIMER = 10
ANALYTICS_TIMER: typing.Final[int] = 10

session = FuturesSession(max_workers=4)

Expand All @@ -17,7 +18,9 @@ class AnalyticsProcessor:
the Flagsmith SDK. Docs: https://docs.flagsmith.com/advanced-use/flag-analytics.
"""

def __init__(self, environment_key: str, base_api_url: str, timeout: int = 3):
def __init__(
self, environment_key: str, base_api_url: str, timeout: typing.Optional[int] = 3
):
"""
Initialise the AnalyticsProcessor to handle sending analytics on flag usage to
the Flagsmith API.
Expand All @@ -30,10 +33,10 @@ def __init__(self, environment_key: str, base_api_url: str, timeout: int = 3):
self.analytics_endpoint = base_api_url + ANALYTICS_ENDPOINT
self.environment_key = environment_key
self._last_flushed = datetime.now()
self.analytics_data = {}
self.timeout = timeout
self.analytics_data: typing.MutableMapping[str, typing.Any] = {}
self.timeout = timeout or 3

def flush(self):
def flush(self) -> None:
"""
Sends all the collected data to the api asynchronously and resets the timer
"""
Expand All @@ -53,7 +56,7 @@ def flush(self):
self.analytics_data.clear()
self._last_flushed = datetime.now()

def track_feature(self, feature_name: str):
def track_feature(self, feature_name: str) -> None:
self.analytics_data[feature_name] = self.analytics_data.get(feature_name, 0) + 1
if (datetime.now() - self._last_flushed).seconds > ANALYTICS_TIMER:
self.flush()
82 changes: 59 additions & 23 deletions flagsmith/flagsmith.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,23 @@
from flagsmith.offline_handlers import BaseOfflineHandler
from flagsmith.polling_manager import EnvironmentDataPollingManager
from flagsmith.streaming_manager import EventStreamManager, StreamEvent
from flagsmith.utils.identities import generate_identities_data
from flagsmith.utils.identities import Identity, generate_identities_data

logger = logging.getLogger(__name__)

DEFAULT_API_URL = "https://edge.api.flagsmith.com/api/v1/"
DEFAULT_REALTIME_API_URL = "https://realtime.flagsmith.com/"

JsonType = typing.Union[
None,
int,
str,
bool,
typing.List["JsonType"],
typing.List[typing.Mapping[str, "JsonType"]],
typing.Dict[str, "JsonType"],
]


class Flagsmith:
"""A Flagsmith client.
Expand All @@ -45,19 +55,21 @@ class Flagsmith:

def __init__(
self,
environment_key: str = None,
api_url: str = None,
environment_key: typing.Optional[str] = None,
api_url: typing.Optional[str] = None,
realtime_api_url: typing.Optional[str] = None,
custom_headers: typing.Dict[str, typing.Any] = None,
request_timeout_seconds: int = None,
custom_headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
request_timeout_seconds: typing.Optional[int] = None,
enable_local_evaluation: bool = False,
environment_refresh_interval_seconds: typing.Union[int, float] = 60,
retries: Retry = None,
retries: typing.Optional[Retry] = None,
enable_analytics: bool = False,
default_flag_handler: typing.Callable[[str], DefaultFlag] = None,
proxies: typing.Dict[str, str] = None,
default_flag_handler: typing.Optional[
typing.Callable[[str], DefaultFlag]
] = None,
proxies: typing.Optional[typing.Dict[str, str]] = None,
offline_mode: bool = False,
offline_handler: BaseOfflineHandler = None,
offline_handler: typing.Optional[BaseOfflineHandler] = None,
enable_realtime_updates: bool = False,
):
"""
Expand Down Expand Up @@ -94,8 +106,8 @@ def __init__(
self.offline_handler = offline_handler
self.default_flag_handler = default_flag_handler
self.enable_realtime_updates = enable_realtime_updates
self._analytics_processor = None
self._environment = None
self._analytics_processor: typing.Optional[AnalyticsProcessor] = None
self._environment: typing.Optional[EnvironmentModel] = None
self._identity_overrides_by_identifier: typing.Dict[str, IdentityModel] = {}

# argument validation
Expand Down Expand Up @@ -159,6 +171,9 @@ def __init__(
def _initialise_local_evaluation(self) -> None:
if self.enable_realtime_updates:
self.update_environment()
if not self._environment:
raise ValueError("Unable to get environment from API key")

stream_url = f"{self.realtime_api_url}sse/environments/{self._environment.api_key}/stream"

self.event_stream_thread = EventStreamManager(
Expand Down Expand Up @@ -196,6 +211,10 @@ def handle_stream_event(self, event: StreamEvent) -> None:
if stream_updated_at.tzinfo is None:
stream_updated_at = pytz.utc.localize(stream_updated_at)

if not self._environment:
raise ValueError(
"Unable to access environment. Environment should not be null"
)
environment_updated_at = self._environment.updated_at
if environment_updated_at.tzinfo is None:
environment_updated_at = pytz.utc.localize(environment_updated_at)
Expand All @@ -214,7 +233,9 @@ def get_environment_flags(self) -> Flags:
return self._get_environment_flags_from_api()

def get_identity_flags(
self, identifier: str, traits: typing.Dict[str, typing.Any] = None
self,
identifier: str,
traits: typing.Optional[typing.Mapping[str, TraitValue]] = None,
) -> Flags:
"""
Get all the flags for the current environment for a given identity. Will also
Expand All @@ -233,7 +254,9 @@ def get_identity_flags(
return self._get_identity_flags_from_api(identifier, traits)

def get_identity_segments(
self, identifier: str, traits: typing.Dict[str, typing.Any] = None
self,
identifier: str,
traits: typing.Optional[typing.Mapping[str, TraitValue]] = None,
) -> typing.List[Segment]:
"""
Get a list of segments that the given identity is in.
Expand All @@ -255,7 +278,7 @@ def get_identity_segments(
segment_models = get_identity_segments(self._environment, identity_model)
return [Segment(id=sm.id, name=sm.name) for sm in segment_models]

def update_environment(self):
def update_environment(self) -> None:
self._environment = self._get_environment_from_api()
self._update_overrides()

Expand All @@ -272,16 +295,20 @@ def _get_environment_from_api(self) -> EnvironmentModel:
return EnvironmentModel.model_validate(environment_data)

def _get_environment_flags_from_document(self) -> Flags:
if self._environment is None:
raise TypeError("No environment present")
return Flags.from_feature_state_models(
feature_states=engine.get_environment_feature_states(self._environment),
analytics_processor=self._analytics_processor,
default_flag_handler=self.default_flag_handler,
)

def _get_identity_flags_from_document(
self, identifier: str, traits: typing.Dict[str, typing.Any]
self, identifier: str, traits: typing.Mapping[str, TraitValue]
) -> Flags:
identity_model = self._get_identity_model(identifier, **traits)
if self._environment is None:
raise TypeError("No environment present")
feature_states = engine.get_identity_feature_states(
self._environment, identity_model
)
Expand All @@ -294,11 +321,11 @@ def _get_identity_flags_from_document(

def _get_environment_flags_from_api(self) -> Flags:
try:
api_flags = self._get_json_response(
url=self.environment_flags_url, method="GET"
)
json_response: typing.List[
typing.Mapping[str, JsonType]
] = self._get_json_response(url=self.environment_flags_url, method="GET")
return Flags.from_api_flags(
api_flags=api_flags,
api_flags=json_response,
analytics_processor=self._analytics_processor,
default_flag_handler=self.default_flag_handler,
)
Expand All @@ -310,11 +337,13 @@ def _get_environment_flags_from_api(self) -> Flags:
raise

def _get_identity_flags_from_api(
self, identifier: str, traits: typing.Dict[str, typing.Any]
self, identifier: str, traits: typing.Mapping[str, typing.Any]
) -> Flags:
try:
data = generate_identities_data(identifier, traits)
json_response = self._get_json_response(
json_response: typing.Dict[
str, typing.List[typing.Dict[str, JsonType]]
] = self._get_json_response(
url=self.identities_url, method="POST", body=data
)
return Flags.from_api_flags(
Expand All @@ -329,7 +358,14 @@ def _get_identity_flags_from_api(
return Flags(default_flag_handler=self.default_flag_handler)
raise

def _get_json_response(self, url: str, method: str, body: dict = None):
def _get_json_response(
self,
url: str,
method: str,
body: typing.Optional[
typing.Union[Identity, typing.Dict[str, JsonType]]
] = None,
) -> typing.Any:
try:
request_method = getattr(self.session, method.lower())
response = request_method(
Expand Down Expand Up @@ -371,7 +407,7 @@ def _get_identity_model(
identity_traits=trait_models,
)

def __del__(self):
def __del__(self) -> None:
if hasattr(self, "environment_data_polling_manager_thread"):
self.environment_data_polling_manager_thread.stop()

Expand Down
36 changes: 19 additions & 17 deletions flagsmith/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing
from dataclasses import dataclass, field

Expand All @@ -13,7 +15,7 @@ class BaseFlag:
value: typing.Union[str, int, float, bool, None]


@dataclass
@dataclass()
class DefaultFlag(BaseFlag):
is_default: bool = field(default=True)

Expand All @@ -28,8 +30,8 @@ class Flag(BaseFlag):
def from_feature_state_model(
cls,
feature_state_model: FeatureStateModel,
identity_id: typing.Union[str, int] = None,
) -> "Flag":
identity_id: typing.Optional[typing.Union[str, int]] = None,
) -> Flag:
return Flag(
enabled=feature_state_model.enabled,
value=feature_state_model.get_value(identity_id=identity_id),
Expand All @@ -38,7 +40,7 @@ def from_feature_state_model(
)

@classmethod
def from_api_flag(cls, flag_data: dict) -> "Flag":
def from_api_flag(cls, flag_data: typing.Mapping[str, typing.Any]) -> Flag:
return Flag(
enabled=flag_data["enabled"],
value=flag_data["feature_state_value"],
Expand All @@ -50,17 +52,17 @@ def from_api_flag(cls, flag_data: dict) -> "Flag":
@dataclass
class Flags:
flags: typing.Dict[str, Flag] = field(default_factory=dict)
default_flag_handler: typing.Callable[[str], DefaultFlag] = None
_analytics_processor: AnalyticsProcessor = None
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]] = None
_analytics_processor: typing.Optional[AnalyticsProcessor] = None

@classmethod
def from_feature_state_models(
cls,
feature_states: typing.List[FeatureStateModel],
analytics_processor: AnalyticsProcessor,
default_flag_handler: typing.Callable,
identity_id: typing.Union[str, int] = None,
) -> "Flags":
feature_states: typing.Sequence[FeatureStateModel],
analytics_processor: typing.Optional[AnalyticsProcessor],
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]],
identity_id: typing.Optional[typing.Union[str, int]] = None,
) -> Flags:
flags = {
feature_state.feature.name: Flag.from_feature_state_model(
feature_state, identity_id=identity_id
Expand All @@ -77,10 +79,10 @@ def from_feature_state_models(
@classmethod
def from_api_flags(
cls,
api_flags: typing.List[dict],
analytics_processor: AnalyticsProcessor,
default_flag_handler: typing.Callable,
) -> "Flags":
api_flags: typing.Sequence[typing.Mapping[str, typing.Any]],
analytics_processor: typing.Optional[AnalyticsProcessor],
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]],
) -> Flags:
flags = {
flag_data["feature"]["name"]: Flag.from_api_flag(flag_data)
for flag_data in api_flags
Expand Down Expand Up @@ -120,12 +122,12 @@ def get_feature_value(self, feature_name: str) -> typing.Any:
"""
return self.get_flag(feature_name).value

def get_flag(self, feature_name: str) -> BaseFlag:
def get_flag(self, feature_name: str) -> typing.Union[DefaultFlag, Flag]:
"""
Get a specific flag given the feature name.
:param feature_name: the name of the feature to retrieve the flag for.
:return: BaseFlag object.
:return: DefaultFlag | Flag object.
:raises FlagsmithClientError: if feature doesn't exist
"""
try:
Expand Down
10 changes: 6 additions & 4 deletions flagsmith/polling_manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
import threading
import time
Expand All @@ -16,10 +18,10 @@
class EnvironmentDataPollingManager(threading.Thread):
def __init__(
self,
*args,
main: "Flagsmith",
*args: typing.Any,
main: Flagsmith,
refresh_interval_seconds: typing.Union[int, float] = 10,
**kwargs
**kwargs: typing.Any,
):
super(EnvironmentDataPollingManager, self).__init__(*args, **kwargs)
self._stop_event = threading.Event()
Expand All @@ -37,5 +39,5 @@ def run(self) -> None:
def stop(self) -> None:
self._stop_event.set()

def __del__(self):
def __del__(self) -> None:
self._stop_event.set()
Loading

0 comments on commit 21eb873

Please sign in to comment.