Merge pull request #36 from NREL/cleanup
rajeee authored Dec 14, 2023
2 parents ff8e507 + d6ec53d commit ccf04da
Showing 22 changed files with 574 additions and 3,227 deletions.
10 changes: 5 additions & 5 deletions buildstock_query/aggregate_query.py
@@ -27,7 +27,7 @@ def aggregate_annual(self, *,
weights = list(params.weights) if params.weights else []
restrict = list(params.restrict) if params.restrict else []

- [self._bsq.get_table(jl[0]) for jl in join_list] # ingress all tables in join list
+ [self._bsq._get_table(jl[0]) for jl in join_list] # ingress all tables in join list
if params.upgrade_id in {None, 0, '0'}:
enduse_cols = self._bsq._get_enduse_cols(params.enduses, table='baseline')
upgrade_id = None
@@ -65,10 +65,10 @@ def aggregate_annual(self, *,
self._bsq.up_table, sa.and_(self._bsq.bs_table.c[self._bsq.building_id_column_name] ==
self._bsq.up_table.c[self._bsq.building_id_column_name],
self._bsq.up_table.c["upgrade"] == str(upgrade_id),
- self._bsq.up_successful_condition))
+ self._bsq._up_successful_condition))
query = query.select_from(tbljoin)

- restrict = [(self._bsq.bs_completed_status_col, [self._bsq.db_schema.completion_values.success])] + restrict
+ restrict = [(self._bsq._bs_completed_status_col, [self._bsq.db_schema.completion_values.success])] + restrict
query = self._bsq._add_join(query, join_list)
query = self._bsq._add_restrict(query, restrict)
query = self._bsq._add_group_by(query, group_by_selection)
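For orientation, here is a minimal usage sketch of `aggregate_annual` after this change. It is a hypothetical sketch: the keyword names are inferred from the `params` fields referenced in the hunks above (`enduses`, `upgrade_id`), the aggregation object is assumed to be exposed as `bsq.agg`, and the column names are illustrative, not taken from this PR.

```python
# Hedged sketch, assuming a configured BuildStockQuery instance `bsq`.
df = bsq.agg.aggregate_annual(
    enduses=["fuel_use_electricity_total_m_btu"],  # assumed enduse column name
    group_by=["geometry_building_type"],           # assumed characteristic name
    upgrade_id=None,                               # None/0/'0' queries the baseline table
)
```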
@@ -132,7 +132,7 @@ def aggregate_timeseries(self, params: TSQuery):

if params.split_enduses:
return self._aggregate_timeseries_light(params)
- [self._bsq.get_table(jl[0]) for jl in params.join_list] # ingress all tables in join list
+ [self._bsq._get_table(jl[0]) for jl in params.join_list] # ingress all tables in join list
enduses_cols = self._bsq._get_enduse_cols(params.enduses, table='timeseries')
total_weight = self._bsq._get_weight(params.weights)

@@ -187,7 +187,7 @@ def aggregate_timeseries(self, params: TSQuery):
upgrade_in_restrict = any(entry[0] == 'upgrade' for entry in params.restrict)
if self._bsq.up_table is not None and not upgrade_in_restrict and 'upgrade' not in group_by_names:
logger.info(f"Restricting query to Upgrade {upgrade_id}.")
- params.restrict = list(params.restrict) + [(self._bsq.ts_upgrade_col, [upgrade_id])]
+ params.restrict = list(params.restrict) + [(self._bsq._ts_upgrade_col, [upgrade_id])]

query = self._bsq._add_restrict(query, params.restrict)
query = self._bsq._add_group_by(query, group_by_selection)
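Similarly, a hedged sketch of `aggregate_timeseries`. Per the logic above, the query is automatically restricted to the requested upgrade when 'upgrade' appears in neither `restrict` nor `group_by`. Argument names are inferred from the `params` fields visible in the hunks (`enduses`, `restrict`, `split_enduses`) and are not confirmed by this diff; the column names are illustrative.

```python
# Hedged sketch, assuming the same configured `bsq` instance.
ts_df = bsq.agg.aggregate_timeseries(
    enduses=["fuel_use_electricity_total_kwh"],  # assumed timeseries enduse name
    group_by=["build_existing_model.state"],     # assumed group-by column
    split_enduses=False,  # True routes to the "light" per-enduse path shown above
)
```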
91 changes: 47 additions & 44 deletions buildstock_query/main.py
@@ -56,13 +56,16 @@ def __init__(self,
say, 'mfm_run', then it must correspond to tables in Athena named mfm_run_baseline and optionally
mfm_run_timeseries and mfm_run_upgrades. Alternatively, a tuple of three elements can be provided for the
table names for baseline, timeseries and upgrade. Timeseries and upgrade can be None if no such tables exist.
timestamp_column_name (str, optional): The column name for the time column. Defaults to 'time'
building_id_column_name (str, optional): The column name for building_id. Defaults to 'building_id'
- sample_weight (str, optional): The column name to be used to get the sample weight. Pass floats/integer to
- use constant sample weight.. Defaults to "build_existing_model.sample_weight".
db_schema (str, optional): The database structure in Athena differs between ResStock and ComStock runs.
It also differs between the version in OEDI and the default version from BuildStockBatch. This argument
controls the assumed schema. Allowed values are 'resstock_default', 'resstock_oedi', 'comstock_default'
and 'comstock_oedi'. Defaults to 'resstock_default' for ResStock and 'comstock_default' for ComStock.
+ sample_weight (str, optional): Specify a custom sample_weight. Otherwise, the default is 1 for ComStock
+ and the sample_weight in the run for ResStock.
region_name (str, optional): The AWS region where the database exists. Defaults to 'us-west-2'.
- execution_history (str, optional): A temporary files to record which execution is run by the user,
- to help stop them. Will use .execution_history if not supplied.
+ execution_history (str, optional): A temporary file to record which executions are run by the user,
+ to help stop them. Will use .execution_history if not supplied. Generally, a custom filename is
+ not required.
skip_reports (bool, optional): If True, skips report printing during initialization. If False (default),
prints the report from `buildstock_query.report_query.BuildStockReport.get_success_report`.
athena_query_reuse (bool, optional): When True, Athena will make use of its built-in 7-day query cache.
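The parameter documentation above suggests a construction pattern like the following minimal sketch. All values are hypothetical, the import path is assumed, and connection arguments not shown in this hunk (such as the Athena database and workgroup names) are elided.

```python
# Minimal sketch, assuming the import path; connection arguments not shown
# in this diff are elided and must be filled in appropriately.
from buildstock_query import BuildStockQuery

bsq = BuildStockQuery(
    # ... connection arguments elided; not shown in this diff ...
    table_name="mfm_run",          # expects mfm_run_baseline and optionally
                                   # mfm_run_timeseries / mfm_run_upgrades in Athena
    db_schema="resstock_default",  # one of the four allowed schema values
    region_name="us-west-2",       # the documented default region
    skip_reports=False,            # print the success report on init
    athena_query_reuse=True,       # reuse Athena's built-in 7-day query cache
)
```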
@@ -81,13 +84,13 @@ def __init__(self,
execution_history=execution_history,
athena_query_reuse=athena_query_reuse
)
- self.run_params = self.params.get_run_params()
+ self._run_params = self.params.get_run_params()
from buildstock_query.report_query import BuildStockReport
from buildstock_query.aggregate_query import BuildStockAggregate
from buildstock_query.savings_query import BuildStockSavings
from buildstock_query.utility_query import BuildStockUtility

- super().__init__(params=self.run_params)
+ super().__init__(params=self._run_params)
#: `buildstock_query.report_query.BuildStockReport` object to perform report queries
self.report: BuildStockReport = BuildStockReport(self)
#: `buildstock_query.aggregate_query.BuildStockAggregate` object to perform aggregate queries
@@ -97,8 +100,8 @@ def __init__(self,
#: `buildstock_query.utility_query.BuildStockUtility` object to perform utility queries
self.utility = BuildStockUtility(self)

- self.char_prefix = self.db_schema.column_prefix.characteristics
- self.out_prefix = self.db_schema.column_prefix.output
+ self._char_prefix = self.db_schema.column_prefix.characteristics
+ self._out_prefix = self.db_schema.column_prefix.output

if not skip_reports:
logger.info("Getting Success counts...")
@@ -116,10 +119,10 @@ def get_buildstock_df(self) -> pd.DataFrame:
results_df = self.get_results_csv_full()
results_df = results_df[results_df[self.db_schema.column_names.completed_status].astype(str).str.lower() ==
self.db_schema.completion_values.success.lower()]
- buildstock_cols = [c for c in results_df.columns if c.startswith(self.char_prefix)]
+ buildstock_cols = [c for c in results_df.columns if c.startswith(self._char_prefix)]
buildstock_df = results_df[buildstock_cols]
buildstock_cols = [''.join(c.split(".")[1:]).replace("_", " ") for c in buildstock_df.columns
- if c.startswith(self.char_prefix)]
+ if c.startswith(self._char_prefix)]
buildstock_df.columns = buildstock_cols
return buildstock_df
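A standalone illustration of the column renaming done in `get_buildstock_df` above: strip the characteristics prefix and replace underscores with spaces. The prefix value is an assumption based on the `sample_weight` default shown earlier in this diff; the column names are illustrative.

```python
# Runnable sketch of the renaming logic, with an assumed prefix value.
char_prefix = "build_existing_model."  # assumed characteristics prefix
cols = ["build_existing_model.geometry_wall_type", "build_existing_model.vintage"]
renamed = [''.join(c.split(".")[1:]).replace("_", " ") for c in cols
           if c.startswith(char_prefix)]
print(renamed)  # ['geometry wall type', 'vintage']
```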

@@ -182,7 +185,7 @@ def get_distinct_vals(self, column: str, table_name: Optional[str],
pd.Series: The distinct vals.
"""
table_name = self.bs_table.name if table_name is None else table_name
- tbl = self.get_table(table_name)
+ tbl = self._get_table(table_name)
query = sa.select(tbl.c[column]).distinct()
if get_query_only:
return self._compile(query)
@@ -203,7 +206,7 @@ def get_distinct_count(self, column: str, table_name: Optional[str] = None,
Returns:
pd.Series: The distinct counts.
"""
- tbl = self.bs_table if table_name is None else self.get_table(table_name)
+ tbl = self.bs_table if table_name is None else self._get_table(table_name)
query = sa.select([tbl.c[column], safunc.sum(1).label("sample_count"),
safunc.sum(self.sample_wt).label("weighted_count")])
query = query.group_by(tbl.c[column]).order_by(tbl.c[column])
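A hedged usage sketch for the two distinct-value helpers above. The column name is hypothetical; `get_query_only` is shown in the `get_distinct_vals` hunk and is assumed (not confirmed here) to exist on `get_distinct_count` as well.

```python
# Assumes the `bsq` instance from the constructor sketch; column is illustrative.
vals = bsq.get_distinct_vals(column="build_existing_model.state", table_name=None)
counts = bsq.get_distinct_count(column="build_existing_model.state")
# Compile to SQL without executing (shown for get_distinct_vals in this diff):
sql = bsq.get_distinct_vals(column="build_existing_model.state", table_name=None,
                            get_query_only=True)
```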
@@ -478,14 +481,14 @@ def _get_simulation_info(self, get_query_only: Literal[True]) -> str:
@validate_arguments(config=dict(smart_union=True))
def _get_simulation_info(self, get_query_only: bool = False) -> Union[str, SimInfo]:
# find the simulation time interval
- query0 = sa.select([self.ts_bldgid_column, self.ts_upgrade_col]).limit(1) # get a building id and upgrade
+ query0 = sa.select([self.ts_bldgid_column, self._ts_upgrade_col]).limit(1) # get a building id and upgrade
bldg_df = self.execute(query0)
bldg_id = bldg_df.values[0][0]
upgrade_id = bldg_df.values[0][1]
query1 = sa.select([self.timestamp_column.distinct().label(
self.timestamp_column_name)]).where(self.ts_bldgid_column == bldg_id)
if self.up_table is not None:
- query1 = query1.where(self.ts_upgrade_col == upgrade_id)
+ query1 = query1.where(self._ts_upgrade_col == upgrade_id)
query1 = query1.order_by(self.timestamp_column).limit(2)
if get_query_only:
return self._compile(query1)
@@ -509,8 +512,8 @@ def _get_simulation_info(self, get_query_only: bool = False) -> Union[str, SimIn
assert offset in [0, interval]
return SimInfo(sim_year, interval, offset, unit)

- def get_special_column(self,
- column_type: Literal['month', 'day', 'hour', 'is_weekend', 'day_of_week']) -> DBColType:
+ def _get_special_column(self,
+ column_type: Literal['month', 'day', 'hour', 'is_weekend', 'day_of_week']) -> DBColType:
sim_info = self._get_simulation_info()
if sim_info.offset > 0:
# If timestamps are not period-beginning, we should make them so to get proper values of the special columns.
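A standalone sketch of the interval/offset logic suggested by `_get_simulation_info` above: take the first two distinct timestamps, infer the reporting interval, and treat a first timestamp that lands one interval into the year as period-ending. The timestamps are illustrative, not values from any real run.

```python
# Runnable sketch under assumed example timestamps (15-minute data).
from datetime import datetime

ts0 = datetime(2018, 1, 1, 0, 15)  # first distinct timestamp (assumed)
ts1 = datetime(2018, 1, 1, 0, 30)  # second distinct timestamp (assumed)
interval = int((ts1 - ts0).total_seconds())  # 900 s => 15-minute interval
offset = int((ts0 - datetime(ts0.year, 1, 1)).total_seconds())
assert offset in (0, interval)  # period-beginning or period-ending stamps
print(interval, offset)  # 900 900 => period-ending timestamps
```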
@@ -547,17 +550,17 @@ def _get_gcol(self, column) -> DBColType: # gcol => group by col

if isinstance(column, tuple):
try:
- return self.get_column(column[0]).label(column[1])
+ return self._get_column(column[0]).label(column[1])
except ValueError:
new_name = f"{self.char_prefix}{column[0]}"
return self.get_column(new_name).label(column[1])
new_name = f"{self._char_prefix}{column[0]}"
return self._get_column(new_name).label(column[1])
elif isinstance(column, str):
try:
- return self.get_column(column).label(self._simple_label(column))
+ return self._get_column(column).label(self._simple_label(column))
except ValueError as e:
- if not column.startswith(self.char_prefix):
- new_name = f"{self.char_prefix}{column}"
- return self.get_column(new_name).label(column)
+ if not column.startswith(self._char_prefix):
+ new_name = f"{self._char_prefix}{column}"
+ return self._get_column(new_name).label(column)
raise ValueError(f"Invalid column name {column}") from e
else:
raise ValueError(f"Invalid column name type {column}: {type(column)}")
@@ -577,7 +580,7 @@ def _get_enduse_cols(self, enduses: Sequence[AnyColType],
enduse_cols.append(tbl.c[enduse])
except KeyError as err:
if table in ['baseline', 'upgrade']:
enduse_cols.append(tbl.c[f"{self.out_prefix}{enduse}"])
enduse_cols.append(tbl.c[f"{self._out_prefix}{enduse}"])
else:
raise ValueError(f"Invalid enduse column names for {table} table") from err
elif isinstance(enduse, MappedColumn):
@@ -592,8 +595,8 @@ def get_groupby_cols(self) -> List[str]:
Returns:
List[str]: List of building characteristics.
"""
- cols = {y.removeprefix(self.char_prefix) for y in self.bs_table.c.keys()
- if y.startswith(self.char_prefix)}
+ cols = {y.removeprefix(self._char_prefix) for y in self.bs_table.c.keys()
+ if y.startswith(self._char_prefix)}
return list(cols)
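A short usage sketch for `get_groupby_cols`, assuming the `bsq` instance from the construction sketch earlier; the printed values are illustrative, not real output.

```python
# List the building characteristics usable in group_by (prefix stripped).
chars = bsq.get_groupby_cols()
print(sorted(chars)[:2])  # e.g. ['bedrooms', 'geometry_building_type']
```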

def _validate_group_by(self, group_by: Sequence[Union[str, tuple[str, str]]]):
Expand Down Expand Up @@ -676,10 +679,10 @@ def _process_groupby_cols(self, group_by, annual_only=False):
if annual_only:
new_group_by = []
for entry in group_by:
- if isinstance(entry, str) and not entry.startswith(self.char_prefix):
- new_group_by.append(f"{self.char_prefix}{entry}")
- elif isinstance(entry, tuple) and not entry[0].startswith(self.char_prefix):
- new_group_by.append((f"{self.char_prefix}{entry[0]}", entry[1]))
+ if isinstance(entry, str) and not entry.startswith(self._char_prefix):
+ new_group_by.append(f"{self._char_prefix}{entry}")
+ elif isinstance(entry, tuple) and not entry[0].startswith(self._char_prefix):
+ new_group_by.append((f"{self._char_prefix}{entry[0]}", entry[1]))
else:
new_group_by.append(entry)
group_by = new_group_by
@@ -728,23 +731,23 @@ def get_buildings_by_locations(self, location_col: str, locations: List[str],
"""
query = sa.select([self.bs_bldgid_column])
- query = query.where(self.get_column(location_col).in_(locations))
+ query = query.where(self._get_column(location_col).in_(locations))
query = self._add_order_by(query, [self.bs_bldgid_column])
if get_query_only:
return self._compile(query)
res = self.execute(query)
return res
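A hedged usage sketch for `get_buildings_by_locations` above; the location column and values are hypothetical.

```python
# Assumes the `bsq` instance from earlier; returns building ids ordered by id.
bldg_ids = bsq.get_buildings_by_locations(
    location_col="build_existing_model.state",  # assumed location column name
    locations=["CO", "WA"],                     # illustrative location values
)
```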

@property
- def bs_completed_status_col(self):
+ def _bs_completed_status_col(self):
if not isinstance(self.bs_table.c[self.db_schema.column_names.completed_status].type, sqltypes.String):
return sa.cast(self.bs_table.c[self.db_schema.column_names.completed_status],
sa.String).label('completed_status')
else:
return self.bs_table.c[self.db_schema.column_names.completed_status]

@property
- def up_completed_status_col(self):
+ def _up_completed_status_col(self):
if self.up_table is None:
raise ValueError("No upgrades table")
if not isinstance(self.up_table.c[self.db_schema.column_names.completed_status].type, sqltypes.String):
@@ -754,35 +757,35 @@ def up_completed_status_col(self):
return self.up_table.c[self.db_schema.column_names.completed_status]

@property
- def bs_successful_condition(self):
- return self.bs_completed_status_col == self.db_schema.completion_values.success
+ def _bs_successful_condition(self):
+ return self._bs_completed_status_col == self.db_schema.completion_values.success

@property
- def up_successful_condition(self):
- return self.up_completed_status_col == self.db_schema.completion_values.success
+ def _up_successful_condition(self):
+ return self._up_completed_status_col == self.db_schema.completion_values.success

@property
- def ts_upgrade_col(self):
+ def _ts_upgrade_col(self):
if not isinstance(self.ts_table.c['upgrade'].type, sqltypes.String):
return sa.cast(self.ts_table.c['upgrade'], sa.String).label('upgrade')
else:
return self.ts_table.c['upgrade']

@property
- def up_upgrade_col(self):
+ def _up_upgrade_col(self):
if self.up_table is None:
raise ValueError("No upgrades table")
if not isinstance(self.up_table.c['upgrade'].type, sqltypes.String):
return sa.cast(self.up_table.c['upgrade'], sa.String).label('upgrade')
else:
return self.up_table.c['upgrade']

- def get_completed_status_col(self, table: AnyTableType):
+ def _get_completed_status_col(self, table: AnyTableType):
if not isinstance(table.c[self.db_schema.column_names.completed_status].type, sqltypes.String):
return sa.cast(table.c[self.db_schema.column_names.completed_status],
sa.String).label('completed_status')
else:
return table.c[self.db_schema.column_names.completed_status]

- def get_success_condition(self, table: AnyTableType):
- return self.get_completed_status_col(table) == self.db_schema.completion_values.success
+ def _get_success_condition(self, table: AnyTableType):
+ return self._get_completed_status_col(table) == self.db_schema.completion_values.success
