
Commit

Fixes after more testing
rajeee committed Dec 14, 2023
1 parent 285e439 commit b625782
Showing 3 changed files with 2,254 additions and 1,967 deletions.
48 changes: 34 additions & 14 deletions buildstock_query/main.py
@@ -114,8 +114,8 @@ def get_buildstock_df(self) -> pd.DataFrame:
             pd.DataFrame: The buildstock.csv dataframe.
         """
         results_df = self.get_results_csv_full()
-        results_df = results_df[results_df[self.db_schema.column_names.completed_status].astype(str) ==
-                                self.db_schema.completion_values.success]
+        results_df = results_df[results_df[self.db_schema.column_names.completed_status].astype(str).str.lower() ==
+                                self.db_schema.completion_values.success.lower()]
         buildstock_cols = [c for c in results_df.columns if c.startswith(self.char_prefix)]
         buildstock_df = results_df[buildstock_cols]
         buildstock_cols = [''.join(c.split(".")[1:]).replace("_", " ") for c in buildstock_df.columns
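This hunk makes the completed-status filter case-insensitive, so successful rows are kept regardless of how the completion value was capitalized when the run was recorded. A minimal sketch of the difference, using an illustrative column name and completion value rather than the actual db_schema fields:

    import pandas as pd

    # Illustrative stand-ins for db_schema.column_names.completed_status and
    # db_schema.completion_values.success; the real names come from the schema.
    results_df = pd.DataFrame({"completed_status": ["Success", "success", "Fail"],
                               "bldg_id": [1, 2, 3]})

    exact = results_df[results_df["completed_status"].astype(str) == "success"]
    relaxed = results_df[results_df["completed_status"].astype(str).str.lower() == "Success".lower()]

    print(len(exact), len(relaxed))  # 1 2 -- the exact match silently drops "Success"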
@@ -274,19 +274,29 @@ def _download_results_csv(self) -> str:
         if os.path.exists(local_copy_path):
             return local_copy_path
 
+        if isinstance(self.table_name, str):
+            db_table_name = f'{self.table_name}{self.db_schema.table_suffix.baseline}'
+        else:
+            db_table_name = self.table_name[0]
         baseline_path = self._aws_glue.get_table(DatabaseName=self.db_name,
-                                                 Name=self.bs_table.name)['Table']['StorageDescriptor']['Location']
+                                                 Name=db_table_name)['Table']['StorageDescriptor']['Location']
         bucket = baseline_path.split('/')[2]
         key = '/'.join(baseline_path.split('/')[3:])
         s3_data = self._aws_s3.list_objects(Bucket=bucket, Prefix=key)
 
         if 'Contents' not in s3_data:
             raise ValueError(f"Results parquet not found in s3 at {baseline_path}")
-        if len(s3_data['Contents']) > 1:
-            raise ValueError(f"Multiple results parquet found in s3 at {baseline_path}")
+        matching_files = [path['Key'] for path in s3_data['Contents']
+                          if "up00.parquet" in path['Key'] or 'baseline' in path['Key']]
 
-        baseline_parquet_path = s3_data['Contents'][0]['Key']
-        self._aws_s3.download_file(bucket, baseline_parquet_path, local_copy_path)
+        if len(matching_files) > 1:
+            raise ValueError(f"Multiple results parquet found in s3 at {baseline_path} for baseline. "
+                             f"These files matched: {matching_files}")
+        if len(matching_files) == 0:
+            raise ValueError(f"No results parquet found in s3 at {baseline_path} for baseline. "
+                             f"Here are the files: {[content['Key'] for content in s3_data['Contents']]}")
+
+        self._aws_s3.download_file(bucket, matching_files[0], local_copy_path)
         return local_copy_path
 
     def get_results_csv_full(self) -> pd.DataFrame:
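Two behaviors change in _download_results_csv: the Glue table is looked up by a name derived from self.table_name (a plain string gets the baseline suffix appended; a tuple names the tables explicitly, baseline first), and the baseline parquet is selected by filename pattern instead of assuming the S3 prefix contains exactly one object. A standalone sketch of the name resolution; the function name, example table names, and the "_baseline" suffix value are illustrative assumptions, not the package's API:

    from typing import Union

    def resolve_baseline_table(table_name: Union[str, tuple], suffix: str = "_baseline") -> str:
        # Mirrors the new branch: append the baseline suffix to a plain string,
        # or take the first element of an explicitly listed tuple of tables.
        if isinstance(table_name, str):
            return f"{table_name}{suffix}"
        return table_name[0]

    print(resolve_baseline_table("resstock_run1"))                 # resstock_run1_baseline
    print(resolve_baseline_table(("bs_tbl", "ts_tbl", "up_tbl")))  # bs_tbl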
@@ -296,7 +306,10 @@ def get_results_csv_full(self) -> pd.DataFrame:
             pd.DataFrame: The full results csv.
         """
         local_copy_path = self._download_results_csv()
-        return pd.read_parquet(local_copy_path).set_index(self.bs_bldgid_column.name)
+        df = pd.read_parquet(local_copy_path)
+        if df.index.name != self.bs_bldgid_column.name:
+            df = df.set_index(self.bs_bldgid_column.name)
+        return df
 
     @typing.overload
     def get_upgrades_csv(self, *, get_query_only: Literal[False] = False, upgrade_id: Union[int, str] = '0',
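The guard around set_index makes the read idempotent: a parquet written with the building-id column already in the index would otherwise make set_index raise a KeyError. A self-contained sketch, with "building_id" standing in for self.bs_bldgid_column.name:

    import pandas as pd

    def with_building_index(df: pd.DataFrame, index_col: str = "building_id") -> pd.DataFrame:
        # Only set the index when the column is not already the index;
        # calling set_index on an already-indexed frame raises KeyError.
        if df.index.name != index_col:
            df = df.set_index(index_col)
        return df

    flat = pd.DataFrame({"building_id": [1, 2], "fuel_use": [10.0, 12.5]})
    indexed = flat.set_index("building_id")

    assert with_building_index(flat).index.name == "building_id"
    assert with_building_index(indexed).index.name == "building_id"  # no KeyError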
@@ -366,18 +379,22 @@ def _download_upgrades_csv(self, upgrade_id: int) -> str:
         if os.path.exists(local_copy_path):
             return local_copy_path
 
+        if isinstance(self.table_name, str):
+            db_table_name = f'{self.table_name}{self.db_schema.table_suffix.upgrades}'
+        else:
+            db_table_name = self.table_name[2]
         upgrades_path = self._aws_glue.get_table(DatabaseName=self.db_name,
-                                                 Name=self.up_table.name)['Table']['StorageDescriptor']['Location']
+                                                 Name=db_table_name)['Table']['StorageDescriptor']['Location']
         bucket = upgrades_path.split('/')[2]
         key = '/'.join(upgrades_path.split('/')[3:])
         s3_data = self._aws_s3.list_objects(Bucket=bucket, Prefix=key)
 
         if 'Contents' not in s3_data:
             raise ValueError(f"Results parquet not found in s3 at {upgrades_path}")
-        if len(s3_data['Contents']) != len(available_upgrades):
-            raise ValueError(f"Number of parquet found in s3 at {upgrades_path} is not equal to number of upgrades")
         # out of the contents, find the key whose name matches the pattern results_up{upgrade_id}.parquet
-        matching_files = [path['Key'] for path in s3_data['Contents'] if f"up{upgrade_id:02}.parquet" in path['Key']]
+        matching_files = [path['Key'] for path in s3_data['Contents']
+                          if f"up{upgrade_id:02}.parquet" in path['Key'] or
+                          f"upgrade{upgrade_id:02}.parquet" in path['Key']]
         if len(matching_files) > 1:
             raise ValueError(f"Multiple results parquet found in s3 at {upgrades_path} for upgrade {upgrade_id}. "
                              f"These files matched: {matching_files}")
@@ -394,8 +411,11 @@ def get_upgrades_csv_full(self, upgrade_id: int) -> pd.DataFrame:
             athena.
         """
         local_copy_path = self._download_upgrades_csv(upgrade_id)
-        df = pd.read_parquet(local_copy_path).set_index(self.up_bldgid_column.name)
-        df.insert(0, 'upgrade', upgrade_id)
+        df = pd.read_parquet(local_copy_path)
+        if df.index.name != self.up_bldgid_column.name:
+            df = df.set_index(self.up_bldgid_column.name)
+        if 'upgrade' not in df.columns:
+            df.insert(0, 'upgrade', upgrade_id)
         return df
 
     @typing.overload
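As with the baseline reader, both guards make get_upgrades_csv_full tolerant of parquet files that already carry the building-id index or an 'upgrade' column, since df.insert raises a ValueError on a duplicate column. A sketch of the idempotent column insert:

    import pandas as pd

    def ensure_upgrade_column(df: pd.DataFrame, upgrade_id: int) -> pd.DataFrame:
        # df.insert raises ValueError if 'upgrade' already exists, so only
        # add the column when the parquet was written without it.
        if "upgrade" not in df.columns:
            df.insert(0, "upgrade", upgrade_id)
        return df

    df = pd.DataFrame({"fuel_use": [10.0]}, index=pd.Index([1], name="building_id"))
    ensure_upgrade_column(df, 3)
    ensure_upgrade_column(df, 3)  # second call is a no-op, no ValueError
    print(df.columns.tolist())    # ['upgrade', 'fuel_use']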

0 comments on commit b625782
