Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: strict typing #70

Merged
merged 3 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
name: Formatting and Tests
name: Linting and Tests

on:
- pull_request

jobs:
test:
runs-on: ubuntu-latest
name: Pytest and Black formatting
name: Linting and Tests

strategy:
max-parallel: 4
Expand All @@ -28,13 +28,16 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install poetry
poetry install
poetry install --with dev

- name: Check Formatting
run: |
poetry run black --check .
poetry run flake8 .
poetry run isort --check .

- name: Check Typing
run: poetry run mypy --strict .

- name: Run Tests
run: poetry run pytest
20 changes: 14 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
repos:
- repo: https://github.com/asottile/seed-isort-config
rev: v1.9.3
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.5.1
hooks:
- id: seed-isort-config
- repo: https://github.com/pre-commit/mirrors-isort
rev: v4.3.21
- id: mypy
args: [--strict]
additional_dependencies:
[pydantic, pytest, pytest_mock, types-requests, flagsmith-flag-engine, responses, types-pytz, sseclient-py]
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 23.7.0
hooks:
- id: black
language_version: python3
- repo: https://github.com/pycqa/flake8
rev: 6.1.0
hooks:
- id: flake8
name: flake8
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
Expand Down
4 changes: 3 additions & 1 deletion flagsmith/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .flagsmith import Flagsmith # noqa
from .flagsmith import Flagsmith

__all__ = ("Flagsmith",)
19 changes: 11 additions & 8 deletions flagsmith/analytics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import json
import typing
from datetime import datetime

from requests_futures.sessions import FuturesSession
from requests_futures.sessions import FuturesSession # type: ignore

ANALYTICS_ENDPOINT = "analytics/flags/"
ANALYTICS_ENDPOINT: typing.Final[str] = "analytics/flags/"

# Used to control how often we send data(in seconds)
ANALYTICS_TIMER = 10
ANALYTICS_TIMER: typing.Final[int] = 10

session = FuturesSession(max_workers=4)

Expand All @@ -17,7 +18,9 @@ class AnalyticsProcessor:
the Flagsmith SDK. Docs: https://docs.flagsmith.com/advanced-use/flag-analytics.
"""

def __init__(self, environment_key: str, base_api_url: str, timeout: int = 3):
def __init__(
self, environment_key: str, base_api_url: str, timeout: typing.Optional[int] = 3
):
"""
Initialise the AnalyticsProcessor to handle sending analytics on flag usage to
the Flagsmith API.
Expand All @@ -30,10 +33,10 @@ def __init__(self, environment_key: str, base_api_url: str, timeout: int = 3):
self.analytics_endpoint = base_api_url + ANALYTICS_ENDPOINT
self.environment_key = environment_key
self._last_flushed = datetime.now()
self.analytics_data = {}
self.timeout = timeout
self.analytics_data: typing.MutableMapping[str, typing.Any] = {}
self.timeout = timeout or 3

def flush(self):
def flush(self) -> None:
"""
Sends all the collected data to the api asynchronously and resets the timer
"""
Expand All @@ -53,7 +56,7 @@ def flush(self):
self.analytics_data.clear()
self._last_flushed = datetime.now()

def track_feature(self, feature_name: str):
def track_feature(self, feature_name: str) -> None:
self.analytics_data[feature_name] = self.analytics_data.get(feature_name, 0) + 1
if (datetime.now() - self._last_flushed).seconds > ANALYTICS_TIMER:
self.flush()
82 changes: 59 additions & 23 deletions flagsmith/flagsmith.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,23 @@
from flagsmith.offline_handlers import BaseOfflineHandler
from flagsmith.polling_manager import EnvironmentDataPollingManager
from flagsmith.streaming_manager import EventStreamManager, StreamEvent
from flagsmith.utils.identities import generate_identities_data
from flagsmith.utils.identities import Identity, generate_identities_data

logger = logging.getLogger(__name__)

DEFAULT_API_URL = "https://edge.api.flagsmith.com/api/v1/"
DEFAULT_REALTIME_API_URL = "https://realtime.flagsmith.com/"

JsonType = typing.Union[
None,
int,
str,
bool,
typing.List["JsonType"],
typing.List[typing.Mapping[str, "JsonType"]],
typing.Dict[str, "JsonType"],
]


class Flagsmith:
"""A Flagsmith client.
Expand All @@ -45,19 +55,21 @@ class Flagsmith:

def __init__(
self,
environment_key: str = None,
api_url: str = None,
environment_key: typing.Optional[str] = None,
api_url: typing.Optional[str] = None,
realtime_api_url: typing.Optional[str] = None,
custom_headers: typing.Dict[str, typing.Any] = None,
request_timeout_seconds: int = None,
custom_headers: typing.Optional[typing.Dict[str, typing.Any]] = None,
request_timeout_seconds: typing.Optional[int] = None,
enable_local_evaluation: bool = False,
environment_refresh_interval_seconds: typing.Union[int, float] = 60,
retries: Retry = None,
retries: typing.Optional[Retry] = None,
enable_analytics: bool = False,
default_flag_handler: typing.Callable[[str], DefaultFlag] = None,
proxies: typing.Dict[str, str] = None,
default_flag_handler: typing.Optional[
typing.Callable[[str], DefaultFlag]
] = None,
proxies: typing.Optional[typing.Dict[str, str]] = None,
offline_mode: bool = False,
offline_handler: BaseOfflineHandler = None,
offline_handler: typing.Optional[BaseOfflineHandler] = None,
enable_realtime_updates: bool = False,
):
"""
Expand Down Expand Up @@ -94,8 +106,8 @@ def __init__(
self.offline_handler = offline_handler
self.default_flag_handler = default_flag_handler
self.enable_realtime_updates = enable_realtime_updates
self._analytics_processor = None
self._environment = None
self._analytics_processor: typing.Optional[AnalyticsProcessor] = None
self._environment: typing.Optional[EnvironmentModel] = None
self._identity_overrides_by_identifier: typing.Dict[str, IdentityModel] = {}

# argument validation
Expand Down Expand Up @@ -159,6 +171,9 @@ def __init__(
def _initialise_local_evaluation(self) -> None:
if self.enable_realtime_updates:
self.update_environment()
if not self._environment:
raise ValueError("Unable to get environment from API key")

stream_url = f"{self.realtime_api_url}sse/environments/{self._environment.api_key}/stream"

self.event_stream_thread = EventStreamManager(
Expand Down Expand Up @@ -196,6 +211,10 @@ def handle_stream_event(self, event: StreamEvent) -> None:
if stream_updated_at.tzinfo is None:
stream_updated_at = pytz.utc.localize(stream_updated_at)

if not self._environment:
raise ValueError(
"Unable to access environment. Environment should not be null"
)
environment_updated_at = self._environment.updated_at
if environment_updated_at.tzinfo is None:
environment_updated_at = pytz.utc.localize(environment_updated_at)
Expand All @@ -214,7 +233,9 @@ def get_environment_flags(self) -> Flags:
return self._get_environment_flags_from_api()

def get_identity_flags(
self, identifier: str, traits: typing.Dict[str, typing.Any] = None
self,
identifier: str,
traits: typing.Optional[typing.Mapping[str, TraitValue]] = None,
) -> Flags:
"""
Get all the flags for the current environment for a given identity. Will also
Expand All @@ -233,7 +254,9 @@ def get_identity_flags(
return self._get_identity_flags_from_api(identifier, traits)

def get_identity_segments(
self, identifier: str, traits: typing.Dict[str, typing.Any] = None
self,
identifier: str,
traits: typing.Optional[typing.Mapping[str, TraitValue]] = None,
) -> typing.List[Segment]:
"""
Get a list of segments that the given identity is in.
Expand All @@ -255,7 +278,7 @@ def get_identity_segments(
segment_models = get_identity_segments(self._environment, identity_model)
return [Segment(id=sm.id, name=sm.name) for sm in segment_models]

def update_environment(self):
def update_environment(self) -> None:
self._environment = self._get_environment_from_api()
self._update_overrides()

Expand All @@ -272,16 +295,20 @@ def _get_environment_from_api(self) -> EnvironmentModel:
return EnvironmentModel.model_validate(environment_data)

def _get_environment_flags_from_document(self) -> Flags:
if self._environment is None:
raise TypeError("No environment present")
return Flags.from_feature_state_models(
feature_states=engine.get_environment_feature_states(self._environment),
analytics_processor=self._analytics_processor,
default_flag_handler=self.default_flag_handler,
)

def _get_identity_flags_from_document(
self, identifier: str, traits: typing.Dict[str, typing.Any]
self, identifier: str, traits: typing.Mapping[str, TraitValue]
) -> Flags:
identity_model = self._get_identity_model(identifier, **traits)
if self._environment is None:
raise TypeError("No environment present")
feature_states = engine.get_identity_feature_states(
self._environment, identity_model
)
Expand All @@ -294,11 +321,11 @@ def _get_identity_flags_from_document(

def _get_environment_flags_from_api(self) -> Flags:
try:
api_flags = self._get_json_response(
url=self.environment_flags_url, method="GET"
)
json_response: typing.List[
typing.Mapping[str, JsonType]
] = self._get_json_response(url=self.environment_flags_url, method="GET")
return Flags.from_api_flags(
api_flags=api_flags,
api_flags=json_response,
analytics_processor=self._analytics_processor,
default_flag_handler=self.default_flag_handler,
)
Expand All @@ -310,11 +337,13 @@ def _get_environment_flags_from_api(self) -> Flags:
raise

def _get_identity_flags_from_api(
self, identifier: str, traits: typing.Dict[str, typing.Any]
self, identifier: str, traits: typing.Mapping[str, typing.Any]
) -> Flags:
try:
data = generate_identities_data(identifier, traits)
json_response = self._get_json_response(
json_response: typing.Dict[
str, typing.List[typing.Dict[str, JsonType]]
] = self._get_json_response(
url=self.identities_url, method="POST", body=data
)
return Flags.from_api_flags(
Expand All @@ -329,7 +358,14 @@ def _get_identity_flags_from_api(
return Flags(default_flag_handler=self.default_flag_handler)
raise

def _get_json_response(self, url: str, method: str, body: dict = None):
def _get_json_response(
self,
url: str,
method: str,
body: typing.Optional[
typing.Union[Identity, typing.Dict[str, JsonType]]
] = None,
) -> typing.Any:
try:
request_method = getattr(self.session, method.lower())
response = request_method(
Expand Down Expand Up @@ -371,7 +407,7 @@ def _get_identity_model(
identity_traits=trait_models,
)

def __del__(self):
def __del__(self) -> None:
if hasattr(self, "environment_data_polling_manager_thread"):
self.environment_data_polling_manager_thread.stop()

Expand Down
34 changes: 18 additions & 16 deletions flagsmith/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import typing
from dataclasses import dataclass, field

Expand Down Expand Up @@ -28,8 +30,8 @@ class Flag(BaseFlag):
def from_feature_state_model(
cls,
feature_state_model: FeatureStateModel,
identity_id: typing.Union[str, int] = None,
) -> "Flag":
identity_id: typing.Optional[typing.Union[str, int]] = None,
) -> Flag:
return Flag(
enabled=feature_state_model.enabled,
value=feature_state_model.get_value(identity_id=identity_id),
Expand All @@ -38,7 +40,7 @@ def from_feature_state_model(
)

@classmethod
def from_api_flag(cls, flag_data: dict) -> "Flag":
def from_api_flag(cls, flag_data: typing.Mapping[str, typing.Any]) -> Flag:
return Flag(
enabled=flag_data["enabled"],
value=flag_data["feature_state_value"],
Expand All @@ -50,17 +52,17 @@ def from_api_flag(cls, flag_data: dict) -> "Flag":
@dataclass
class Flags:
flags: typing.Dict[str, Flag] = field(default_factory=dict)
default_flag_handler: typing.Callable[[str], DefaultFlag] = None
_analytics_processor: AnalyticsProcessor = None
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]] = None
_analytics_processor: typing.Optional[AnalyticsProcessor] = None

@classmethod
def from_feature_state_models(
cls,
feature_states: typing.List[FeatureStateModel],
analytics_processor: AnalyticsProcessor,
default_flag_handler: typing.Callable,
identity_id: typing.Union[str, int] = None,
) -> "Flags":
feature_states: typing.Sequence[FeatureStateModel],
analytics_processor: typing.Optional[AnalyticsProcessor],
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]],
identity_id: typing.Optional[typing.Union[str, int]] = None,
) -> Flags:
flags = {
feature_state.feature.name: Flag.from_feature_state_model(
feature_state, identity_id=identity_id
Expand All @@ -77,10 +79,10 @@ def from_feature_state_models(
@classmethod
def from_api_flags(
cls,
api_flags: typing.List[dict],
analytics_processor: AnalyticsProcessor,
default_flag_handler: typing.Callable,
) -> "Flags":
api_flags: typing.Sequence[typing.Mapping[str, typing.Any]],
analytics_processor: typing.Optional[AnalyticsProcessor],
default_flag_handler: typing.Optional[typing.Callable[[str], DefaultFlag]],
) -> Flags:
flags = {
flag_data["feature"]["name"]: Flag.from_api_flag(flag_data)
for flag_data in api_flags
Expand Down Expand Up @@ -120,12 +122,12 @@ def get_feature_value(self, feature_name: str) -> typing.Any:
"""
return self.get_flag(feature_name).value

def get_flag(self, feature_name: str) -> BaseFlag:
def get_flag(self, feature_name: str) -> typing.Union[DefaultFlag, Flag]:
"""
Get a specific flag given the feature name.

:param feature_name: the name of the feature to retrieve the flag for.
:return: BaseFlag object.
:return: DefaultFlag | Flag object.
:raises FlagsmithClientError: if feature doesn't exist
"""
try:
Expand Down
Loading
Loading