Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make library type-checkable #16

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion restcountries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# simpler import as described in the readme
from restcountries.base import RestCountryApiV2
from restcountries.base import Country, RestCountryApiV2
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually this class is not needed for imports - why add it here?

Copy link
Author

@iron3oxide iron3oxide Oct 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This allows for something like

countries: list[Country] = rapi.get_all()
print(countries[0].alpha_code)

in order to pass pyright or a similar type checker, as said checker can now verify that alpha_code is a field that the first object of that list is guaranteed/supposed to have.

CORRECTION: Since I added the type hint for result_list, this is not strictly needed to pass pyright as it can simply infer the type of countries in the example above. Country is however needed to "properly" type hint it, but I guess countries: list will work for most people so if there are good reasons not to make Country importable, this change can be removed.


__version__ = "2.0.0"
14 changes: 6 additions & 8 deletions restcountries/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import requests
import json

import requests


class RestCountryApiV2:
BASE_URI = "https://restcountries.com/v2"
Expand Down Expand Up @@ -36,7 +37,7 @@ def _get_country_list(cls, resource, term="", filters=None):

response = requests.get(uri)
if response.status_code == 200:
result_list = []
result_list: list[Country] = []
data = json.loads(response.text) # parse json to dict
if type(data) == list:
for (
Expand All @@ -47,7 +48,7 @@ def _get_country_list(cls, resource, term="", filters=None):
country = Country(country_data)
result_list.append(country)
else:
return Country(data)
result_list.append(Country(data))
return result_list
elif response.status_code == 404:
raise requests.exceptions.InvalidURL
Expand All @@ -58,7 +59,7 @@ def _get_country_list(cls, resource, term="", filters=None):
def get_all(cls, filters=None):
"""Returns all countries provided by restcountries.eu.

:param filters - a list of fields to filter the output of the request to include only the specified fields.
:param filters - a list of fields to filter the output of the request to include only the specified fields.
"""
resource = "/all"
return cls._get_country_list(resource, filters=filters)
Expand Down Expand Up @@ -106,7 +107,7 @@ def get_country_by_country_code(cls, alpha, filters=None):
You can look those up at wikipedia: https://en.wikipedia.org/wiki/ISO_3166-1
"""
resource = "/alpha"
return cls._get_country_list(resource, alpha, filters=filters)
return cls._get_country_list(resource, alpha, filters=filters)[0]

@classmethod
def get_countries_by_country_codes(cls, codes, filters=None):
Expand Down Expand Up @@ -167,9 +168,6 @@ def get_countries_by_capital(cls, capital, filters=None):


class Country:
def __str__(self):
SteinRobert marked this conversation as resolved.
Show resolved Hide resolved
return "{}".format(self.name)

def __init__(self, country_data):
self.top_level_domain = country_data.get("topLevelDomain")
self.alpha2_code = country_data.get("alpha2Code")
Expand Down