Skip to content

Commit

Permalink
Add types to a lot of functions
Browse files Browse the repository at this point in the history
  • Loading branch information
dragon-dxw committed Dec 12, 2024
1 parent f647e4d commit 8b982c1
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions ds-caselaw-ingester/lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, metadata):
self.parameters = metadata.get("parameters", {})

@property
def is_tdr(self):
def is_tdr(self) -> bool:
return "TDR" in self.parameters.keys()

@property
Expand Down Expand Up @@ -86,7 +86,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@property
def originator(self):
def originator(self) -> str:
return self.message.get("parameters", {}).get("originator")

def get_consignment_reference(self):
Expand All @@ -100,7 +100,7 @@ def get_consignment_reference(self):

raise InvalidMessageException("Malformed v2 message, please supply a reference")

def save_s3_response(self, sqs_client, s3_client):
def save_s3_response(self, sqs_client, s3_client) -> str:
s3_bucket = self.message.get("parameters", {}).get("s3Bucket")
s3_key = self.message.get("parameters", {}).get("s3Key")
reference = self.get_consignment_reference()
Expand Down Expand Up @@ -195,7 +195,7 @@ def modify_filename(original: str, addition: str) -> str:
return os.path.join(path, new_basename)


def all_messages(event) -> List[Message]:
def all_messages(event) -> list[Message]:
"""All the messages in the SNS event, as Message subclasses"""
decoder = json.decoder.JSONDecoder()
messages_as_decoded_json = [decoder.decode(record["Sns"]["Message"]) for record in event["Records"]]
Expand Down Expand Up @@ -249,7 +249,7 @@ def extract_docx_filename(metadata: dict, consignment_reference: str) -> str:
)


def extract_lambda_versions(versions: List[Dict[str, str]]) -> List[Tuple[str, str]]:
def extract_lambda_versions(versions: list[dict[str, str]]) -> list[tuple[str, str]]:
version_tuples = []
for d in versions:
version_tuples += list(d.items())
Expand Down Expand Up @@ -501,7 +501,7 @@ def store_metadata(self) -> None:
value=tdr_metadata["Consignment-Completed-Datetime"],
)

def save_files_to_s3(self):
def save_files_to_s3(self) -> None:
sqs_client, s3_client = aws_clients()
# Determine if there's a word document -- we need to know before we save the tar.gz file
docx_filename = extract_docx_filename(self.metadata, self.consignment_reference)
Expand Down Expand Up @@ -555,7 +555,7 @@ def save_files_to_s3(self):
)

@property
def metadata_object(self):
def metadata_object(self) -> Metadata:
return Metadata(self.metadata)

def will_publish(self) -> bool:
Expand All @@ -574,7 +574,7 @@ def will_publish(self) -> bool:

raise RuntimeError(f"Didn't recognise originator {originator!r}")

def send_email(self):
def send_email(self) -> None:
originator = self.message.originator
if originator == "FCL":
return None
Expand All @@ -587,10 +587,10 @@ def send_email(self):

raise RuntimeError(f"Didn't recognise originator {originator!r}")

def close_tar(self):
def close_tar(self) -> None:
self.tar.close()

def upload_xml(self):
def upload_xml(self) -> None:
self.updated = self.update_document_xml()
self.inserted = False if self.updated else self.insert_document_xml()
if not self.updated and not self.inserted:
Expand All @@ -599,7 +599,7 @@ def upload_xml(self):
)

@property
def upload_state(self):
def upload_state(self) -> str:
return "updated" if self.updated else "inserted"


Expand Down

0 comments on commit 8b982c1

Please sign in to comment.