diff --git a/buildstock_query/aggregate_query.py b/buildstock_query/aggregate_query.py index 222e1a6..e7f25a1 100644 --- a/buildstock_query/aggregate_query.py +++ b/buildstock_query/aggregate_query.py @@ -71,6 +71,7 @@ def aggregate_annual(self, *, restrict = [(self._bsq._bs_completed_status_col, [self._bsq.db_schema.completion_values.success])] + restrict query = self._bsq._add_join(query, join_list) query = self._bsq._add_restrict(query, restrict) + query = self._bsq._add_avoid(query, params.avoid) query = self._bsq._add_group_by(query, group_by_selection) query = self._bsq._add_order_by(query, group_by_selection if params.sort else []) @@ -190,6 +191,7 @@ def aggregate_timeseries(self, params: TSQuery): params.restrict = list(params.restrict) + [(self._bsq._ts_upgrade_col, [upgrade_id])] query = self._bsq._add_restrict(query, params.restrict) + query = self._bsq._add_avoid(query, params.avoid) query = self._bsq._add_group_by(query, group_by_selection) query = self._bsq._add_order_by(query, group_by_selection if params.sort else []) query = query.limit(params.limit) if params.limit else query diff --git a/buildstock_query/aggregate_query.pyi b/buildstock_query/aggregate_query.pyi index 90eec8e..8fdcf7c 100644 --- a/buildstock_query/aggregate_query.pyi +++ b/buildstock_query/aggregate_query.pyi @@ -21,6 +21,7 @@ class BuildStockAggregate: join_list: Sequence[tuple[AnyTableType, AnyColType, AnyColType]] = [], weights: Sequence[Union[str, tuple]] = [], restrict: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], + avoid: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], get_quartiles: bool = False, get_nonzero_count: bool = False, ) -> str: @@ -36,6 +37,7 @@ class BuildStockAggregate: join_list: Sequence[tuple[AnyTableType, AnyColType, AnyColType]] = [], weights: Sequence[Union[str, tuple]] = [], restrict: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], + avoid: 
Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], get_quartiles: bool = False, get_nonzero_count: bool = False, ) -> pd.DataFrame: @@ -51,6 +53,7 @@ class BuildStockAggregate: join_list: Sequence[tuple[AnyTableType, AnyColType, AnyColType]] = [], weights: Sequence[Union[str, tuple]] = [], restrict: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], + avoid: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], get_quartiles: bool = False, get_nonzero_count: bool = False, ) -> Union[pd.DataFrame, str]: @@ -79,6 +82,8 @@ class BuildStockAggregate: restrict: The list of where condition to restrict the results to. It should be specified as a list of tuple. Example: `[('state',['VA','AZ']), ("build_existing_model.lighting",['60% CFL']), ...]` + avoid: Just like restrict, but the opposite. It will only include rows that do not match (any of) the + conditions. get_quartiles: If true, return the following quartiles in addition to the sum for each enduses: [0, 0.02, .25, .5, .75, .98, 1]. The 0% quartile is the minimum and the 100% quartile is the maximum. 
@@ -109,6 +114,7 @@ class BuildStockAggregate: join_list: Sequence[tuple[AnyTableType, AnyColType, AnyColType]] = [], weights: Sequence[Union[str, tuple]] = [], restrict: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], + avoid: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], split_enduses: bool = False, collapse_ts: bool = False, timestamp_grouping_func: Optional[str] = None, @@ -125,6 +131,7 @@ class BuildStockAggregate: join_list: Sequence[tuple[AnyTableType, AnyColType, AnyColType]] = [], weights: Sequence[Union[str, tuple]] = [], restrict: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], + avoid: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], split_enduses: bool = False, collapse_ts: bool = False, timestamp_grouping_func: Optional[str] = None, @@ -143,6 +150,7 @@ class BuildStockAggregate: join_list: Sequence[tuple[AnyTableType, AnyColType, AnyColType]] = [], weights: Sequence[Union[str, tuple]] = [], restrict: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], + avoid: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = [], split_enduses: bool = False, collapse_ts: bool = False, timestamp_grouping_func: Optional[str] = None, diff --git a/buildstock_query/main.py b/buildstock_query/main.py index 017061d..b24c49b 100644 --- a/buildstock_query/main.py +++ b/buildstock_query/main.py @@ -40,11 +40,12 @@ def __init__(self, table_name: Union[str, tuple[str, Optional[str], Optional[str]]], db_schema: Optional[str] = None, buildstock_type: Literal['resstock', 'comstock'] = 'resstock', - sample_weight: Optional[Union[int, float]] = None, + sample_weight_override: Optional[Union[int, float]] = None, region_name: str = 'us-west-2', execution_history: Optional[str] = None, skip_reports: bool = False, athena_query_reuse: bool = True, + **kwargs, ) -> None: """A class to run Athena queries for BuildStock runs 
and download results as pandas DataFrame. @@ -60,8 +61,8 @@ def __init__(self, It is also different between the version in OEDI and default version from BuildStockBatch. This argument controls the assumed schema. Allowed values are 'resstock_default', 'resstock_oedi', 'comstock_default' and 'comstock_oedi'. Defaults to 'resstock_default' for resstock and 'comstock_default' for comstock. - sample_weight (str, optional): Specify a custom sample_weight. Otherwise, the default is 1 for ComStock and - uses sample_weight in the run for ResStock. + sample_weight_override (str, optional): Specify a custom sample_weight. Otherwise, the default is 1 for + ComStock and uses sample_weight in the run for ResStock. region_name (str, optional): the AWS region where the database exists. Defaults to 'us-west-2'. execution_history (str, optional): A temporary file to record which execution is run by the user, to help stop them. Will use .execution_history if not supplied. Generally, not required to supply a @@ -71,6 +72,7 @@ def __init__(self, athena_query_reuse (bool, optional): When true, Athena will make use of its built-in 7 day query cache. When false, it will not. Defaults to True. One use case to set this to False is when you have modified the underlying s3 data or glue schema and want to make sure you are not using the cached results. 
+ kwargs: Any other extra keyword argument supported by the QueryCore can be supplied here """ db_schema = db_schema or f"{buildstock_type}_default" self.params = BSQParams( @@ -79,7 +81,7 @@ def __init__(self, buildstock_type=buildstock_type, table_name=table_name, db_schema=db_schema, - sample_weight_override=sample_weight, + sample_weight_override=sample_weight_override, region_name=region_name, execution_history=execution_history, athena_query_reuse=athena_query_reuse diff --git a/buildstock_query/query_core.py b/buildstock_query/query_core.py index 49d7d15..ee285d1 100644 --- a/buildstock_query/query_core.py +++ b/buildstock_query/query_core.py @@ -972,7 +972,7 @@ def _simple_label(self, label: str): label = label.removeprefix(self.db_schema.column_prefix.output) return label - def _add_restrict(self, query, restrict, bs_only=False): + def _add_restrict(self, query, restrict, *, bs_only=False): if not restrict: return query where_clauses = [] @@ -988,6 +988,22 @@ def _add_restrict(self, query, restrict, bs_only=False): query = query.where(*where_clauses) return query + def _add_avoid(self, query, avoid, *, bs_only=False): + if not avoid: + return query + where_clauses = [] + for col_str, criteria in avoid: + col = self._get_column(col_str, table_name=self.bs_table) if bs_only else self._get_column(col_str) + if isinstance(criteria, (list, tuple)): + if len(criteria) > 1: + where_clauses.append(col.not_in(criteria)) + continue + else: + criteria = criteria[0] + where_clauses.append(col != criteria) + query = query.where(*where_clauses) + return query + def _get_name(self, col): if isinstance(col, tuple): return col[1] diff --git a/buildstock_query/schema/query_params.py b/buildstock_query/schema/query_params.py index 3644656..a4f1d4e 100644 --- a/buildstock_query/schema/query_params.py +++ b/buildstock_query/schema/query_params.py @@ -11,6 +11,7 @@ class AnnualQuery(BaseModel): sort: bool = True join_list: Sequence[tuple[AnyTableType,
AnyColType, AnyColType]] = Field(default_factory=list) restrict: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = Field(default_factory=list) + avoid: Sequence[tuple[AnyColType, Union[str, int, Sequence[Union[int, str]]]]] = Field(default_factory=list) weights: Sequence[Union[str, tuple, AnyColType]] = Field(default_factory=list) get_quartiles: bool = False get_nonzero_count: bool = False diff --git a/buildstock_query/schema/run_params.py b/buildstock_query/schema/run_params.py index aa49c8f..95fefca 100644 --- a/buildstock_query/schema/run_params.py +++ b/buildstock_query/schema/run_params.py @@ -8,7 +8,7 @@ class RunParams(BaseModel): db_name: str table_name: Union[str, tuple[str, Optional[str], Optional[str]]] buildstock_type: Literal['resstock', 'comstock'] = 'resstock' - db_schema: Optional[str] = 'resstock_raw' + db_schema: Optional[str] = None sample_weight_override: Optional[Union[int, float]] = None region_name: str = 'us-west-2' execution_history: Optional[str] = None diff --git a/buildstock_query/schema/utilities.py b/buildstock_query/schema/utilities.py index 381752d..b1860df 100644 --- a/buildstock_query/schema/utilities.py +++ b/buildstock_query/schema/utilities.py @@ -2,13 +2,13 @@ from typing import Union, Any, Sequence from pydantic import BaseModel import sqlalchemy as sa -from sqlalchemy.sql.elements import Label +from sqlalchemy.sql.elements import Label, ColumnElement from sqlalchemy.sql.selectable import Subquery # from buildstock_query import BuildStockQuery # can't import due to circular import -SACol = sa.Column +SACol = Union[sa.Column, ColumnElement] SALabel = Label DBColType = Union[SALabel, SACol] DBTableType = sa.Table diff --git a/buildstock_query/tools/upgrades_visualizer/upgrades_visualizer.py b/buildstock_query/tools/upgrades_visualizer/upgrades_visualizer.py index f3ff5ca..564b259 100644 --- a/buildstock_query/tools/upgrades_visualizer/upgrades_visualizer.py +++ 
b/buildstock_query/tools/upgrades_visualizer/upgrades_visualizer.py @@ -35,8 +35,6 @@ # ]) transforms = [MultiplexerTransform()] -# yaml_path = "/Users/radhikar/Documents/eulpda/EULP-data-analysis/notebooks/EUSS-project-file-example.yml" -yaml_path = "/Users/radhikar/Documents/largee/resstock/project_national/fact_sheets_category_1.yml" opt_sat_path = "/Users/radhikar/Downloads/options_saturations.csv" default_end_use = "fuel_use_electricity_total_m_btu" @@ -56,12 +54,40 @@ def filter_cols(all_columns, prefixes=[], suffixes=[]): return cols -def _get_app(yaml_path: str, opt_sat_path: str, db_name: str = 'euss-tests', +def get_int_set(input_str): + """ + Convert "1,2,3-6,8,9" to [1, 2, 3, 4, 5, 6, 8, 9] + """ + if not input_str: + return set() + + pattern = r'^(\d+(-\d+)?,)*(\d+(-\d+)?)$' + if not re.match(pattern, input_str): + raise ValueError(f"{input_str} is not a valid pattern for list") + + result = set() + segments = input_str.split(',') + for segment in segments: + if '-' in segment: + start, end = map(int, segment.split('-')) + result |= set(range(start, end + 1)) + else: + result.add(int(segment)) + + return result + + +def _get_app(opt_sat_path: str, db_name: str = 'euss-tests', table_name: str = 'res_test_03_2018_10k_20220607', workgroup: str = 'largeee', - buildstock_type: str = 'resstock'): - viz_data = VizData(yaml_path=yaml_path, opt_sat_path=opt_sat_path, db_name=db_name, - run=table_name, workgroup=workgroup, buildstock_type=buildstock_type) + buildstock_type: str = 'resstock', + include_monthly: bool = True, + upgrades_selection_str: str = ''): + viz_data = VizData(opt_sat_path=opt_sat_path, db_name=db_name, + run=table_name, workgroup=workgroup, buildstock_type=buildstock_type, + include_monthly=include_monthly, + upgrades_selection=get_int_set(upgrades_selection_str) + ) return get_app(viz_data) @@ -70,7 +96,6 @@ def get_app(viz_data: VizData): upgrade2res = viz_data.upgrade2res # upgrade2res_monthly = viz_data.upgrade2res_monthly upgrade2name = 
viz_data.upgrade2name - resolution = 'annual' all_cols = viz_data.upgrade2res[0].columns emissions_cols = filter_cols(all_cols, suffixes=['_lb']) # end_use_cols = filter_cols(all_cols, ["end_use_", "energy_use__", "fuel_use_"]) @@ -93,7 +118,8 @@ def get_buildings(upgrade): return upgrade2res[int(upgrade)]['building_id'].to_list() def get_plot(end_use, value_type='mean', savings_type='', change_type='', - sync_upgrade=None, filter_bldg=None, group_cols=None, report_upgrade=None): + sync_upgrade=None, filter_bldg=None, group_cols=None, report_upgrade=None, + resolution='annual'): filter_bldg = filter_bldg or [] group_cols = group_cols or [] sync_upgrade = sync_upgrade or 0 @@ -115,7 +141,7 @@ def get_plot(end_use, value_type='mean', savings_type='', change_type='', dbc.Row([dbc.Col(html.H1("Upgrades Visualizer"), width='auto'), dbc.Col(html.Sup("beta"))]), # Add a row for annual, vs monthly vs seasonal plot radio buttons dbc.Row([dbc.Col(dbc.Label("Resolution: "), width='auto'), - dbc.Col(dcc.RadioItems(["annual", "monthly"], "annual", + dbc.Col(dcc.RadioItems(["annual", "monthly"] if viz_data.include_monthly else ["annual"], "annual", inline=True, id="radio_resolution"))]), dbc.Row([dbc.Col(dbc.Label("Visualization Type: "), width='auto'), @@ -278,7 +304,7 @@ def download_char(n_clicks, bldg_id, bldg_options, bldg_options2, chk_chars): bdf = viz_data.upgrade2res[0].filter(pl.col("building_id").is_in(set(bldg_ids))).select(char_cols) return dcc.send_bytes(bdf.write_csv, f"chars_{n_clicks}.csv") - def get_elligible_output_columns(category, fuel): + def get_elligible_output_columns(category, fuel, resolution): if category == 'energy': elligible_cols = viz_data.get_cleaned_up_end_use_cols(resolution, fuel) elif category == 'water': @@ -304,15 +330,6 @@ def get_elligible_output_columns(category, fuel): raise ValueError(f"Invalid tab {category}") return elligible_cols - @app.callback( - Output('radio_resolution', 'options'), - Input('radio_resolution', 'value'), - ) - 
def update_resolution(res): - nonlocal resolution - resolution = res - return ['annual', 'monthly'] - @app.callback( Output('dropdown_enduse', "options"), Output('dropdown_enduse', "value"), @@ -322,7 +339,7 @@ def update_resolution(res): Input('radio_resolution', 'value') ) def update_enduse_options(view_tab, fuel_type, current_enduse, resolution): - elligible_cols = get_elligible_output_columns(view_tab, fuel_type) + elligible_cols = get_elligible_output_columns(view_tab, fuel_type, resolution) enduse = current_enduse if current_enduse in elligible_cols else elligible_cols[0] return sorted(elligible_cols), enduse @@ -772,11 +789,12 @@ def show_char_report(bldg_id, bldg_options, bldg_options2, inp_char: list[str], Input('input_building2', 'options'), Input('chk-graph', 'value'), State("uirevision", "data"), - State('report_upgrade', 'value') + State('report_upgrade', 'value'), + State('radio_resolution', 'value') ) def update_figure(view_tab, grp_by, fuel, enduse, graph_type, savings_type, chng_type, sync_upgrade, selected_bldg, bldg_options, bldg_options2, chk_graph, uirevision, - report_upgrade): + report_upgrade, resolution): nonlocal download_csv_df if dash.callback_context.triggered_id == 'input_building2' and "Graph" not in chk_graph: raise PreventUpdate() @@ -798,7 +816,8 @@ def update_figure(view_tab, grp_by, fuel, enduse, graph_type, savings_type, chng filter_bldg = [int(b) for b in bldg_options] new_figure, report_df = get_plot(full_name, graph_type, savings_type, - chng_type, sync_upgrade, filter_bldg, grp_by, report_upgrade) + chng_type, sync_upgrade, filter_bldg, grp_by, report_upgrade, + resolution) uirevision = uirevision or "default" new_figure.update_layout(uirevision=uirevision) @@ -813,30 +832,38 @@ def update_figure(view_tab, grp_by, fuel, enduse, graph_type, savings_type, chng def main(): print("Welcome to Upgrades Visualizer.") defaults = load_script_defaults("project_info") - yaml_file = inquirer.text(message="Please enter path to the 
buildstock configuration yml file: ", - default=defaults.get("yaml_file", "")).execute() - opt_sat_file = inquirer.text(message="Please enter path to the options saturation csv file: ", + opt_sat_file = inquirer.text(message="Please enter path to the options saturation csv file:", default=defaults.get("opt_sat_file", "")).execute() - workgroup = inquirer.text(message="Please Athena workgroup name: ", + workgroup = inquirer.text(message="Please enter Athena workgroup name:", default=defaults.get("workgroup", "")).execute() - db_name = inquirer.text(message="Please enter database_name " - "(found in postprocessing:aws:athena in the buildstock configuration file): ", + db_name = inquirer.text(message="Please enter database name " + "(found in postprocessing:aws:athena in the buildstock configuration file):", default=defaults.get("db_name", "")).execute() table_name = inquirer.text(message="Please enter table name (same as output folder name; found under " "output_directory in the buildstock configuration file). 
[Enter two names " - "separated by comma if baseline and upgrades are in different run] :", + "separated by comma if baseline and upgrades are in different run]:", default=defaults.get("table_name", "") ).execute() - defaults.update({"yaml_file": yaml_file, "opt_sat_file": opt_sat_file, "workgroup": workgroup, - "db_name": db_name, "table_name": table_name}) + monthly_default = defaults.get("include_monthly", True) + default_str = "Yes" if monthly_default else "No" + include_monthly = inquirer.confirm(f"Do you want to include monthly plots ({default_str})?", + default=monthly_default, + ).execute() + upgrades_selection = inquirer.text(message="Please enter upgrade ids separated by comma and dashes " + "(example: `1-3,5,7,8-9`) or leave empty to include all upgrades.", + default=defaults.get("upgrades_selection", "")).execute() + defaults.update({"opt_sat_file": opt_sat_file, "workgroup": workgroup, + "db_name": db_name, "table_name": table_name, "include_monthly": include_monthly, + "upgrades_selection": upgrades_selection}) save_script_defaults("project_info", defaults) if ',' in table_name: table_name = table_name.split(',') - app = _get_app(yaml_path=yaml_file, - opt_sat_path=opt_sat_file, + app = _get_app(opt_sat_path=opt_sat_file, workgroup=workgroup, db_name=db_name, - table_name=table_name) + table_name=table_name, + include_monthly=include_monthly, + upgrades_selection_str=upgrades_selection) app.run_server(debug=False, port=8005) diff --git a/buildstock_query/tools/upgrades_visualizer/viz_data.py b/buildstock_query/tools/upgrades_visualizer/viz_data.py index ea73cce..76a1b1a 100644 --- a/buildstock_query/tools/upgrades_visualizer/viz_data.py +++ b/buildstock_query/tools/upgrades_visualizer/viz_data.py @@ -3,6 +3,7 @@ import polars as pl from buildstock_query.tools.upgrades_visualizer.plot_utils import PlotParams from typing import Union +import datetime num2month = {1: "January", 2: "February", 3: "March", 4: "April", 5: "May", 6: "June", 7: "July", 8: 
"August", @@ -12,12 +13,14 @@ class VizData: @validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) - def __init__(self, yaml_path: str, opt_sat_path: str, + def __init__(self, opt_sat_path: str, db_name: str, run: Union[str, tuple[str, str]], workgroup: str = 'largeee', buildstock_type: str = 'resstock', - skip_init: bool = False): + skip_init: bool = False, + include_monthly: bool = True, + upgrades_selection: set = set()): if isinstance(run, tuple): # Allows for separate baseline and upgrade runs # In this case, run[0] is the baseline run and run[1] is the upgrade run @@ -39,24 +42,32 @@ def __init__(self, yaml_path: str, opt_sat_path: str, buildstock_type=buildstock_type, table_name=table, skip_reports=skip_init) - self.yaml_path = yaml_path self.opt_sat_path = opt_sat_path + self.upgrades_selection = upgrades_selection + self.include_monthly = include_monthly if not skip_init: self.initialize() def initialize(self): - self.ua = self.main_run.get_upgrades_analyzer(yaml_file=self.yaml_path, - opt_sat_file=self.opt_sat_path) + available_upgrades = self.main_run.get_available_upgrades() + available_upgrades = [int(u) for u in available_upgrades] + if not self.upgrades_selection: + self.upgrades_selection = set(available_upgrades) + if (unavailable_upgrades := self.upgrades_selection - set(available_upgrades)): + raise ValueError(f"Upgrades {unavailable_upgrades} is not available in the run") + available_upgrades = self.upgrades_selection self.report = pl.from_pandas(self.main_run.report.get_success_report(), include_index=True) - self.available_upgrades = list(sorted(set(self.report["upgrade"].unique()) - {0})) - self.upgrade2name = {indx+1: f"Upgrade {indx+1}: {upgrade['upgrade_name']}" for indx, - upgrade in enumerate(self.ua.cfg.get('upgrades', []))} - self.upgrade2name[0] = "Upgrade 0: Baseline" - self.upgrade2shortname = {indx+1: f"Upgrade {indx+1}" for indx, - upgrade in enumerate(self.ua.cfg.get('upgrades', []))} + 
self.available_upgrades = list(set([int(u) for u in available_upgrades]) - {0}) + self.upgrade2name = {0: "Upgrade 0: Baseline"} + if self.available_upgrades: + upgrade_names = self.main_run.get_upgrade_names() + self.upgrade2name |= upgrade_names + + self.upgrade2shortname = {upgrade: f"Upgrade {upgrade}" for upgrade in self.available_upgrades} self.chng2bldg = self.get_change2bldgs() self.init_annual_results() - self.init_monthly_results(self.metadata_df) + if self.include_monthly: + self.init_monthly_results(self.metadata_df) self.all_upgrade_plotting_df = None def run_obj(self, upgrade: int) -> BuildStockQuery: @@ -122,10 +133,23 @@ def init_monthly_results(self, metadata_df): ts_cols = self._get_ts_enduse_cols(upgrade) print(f"Getting monthly results for {upgrade}") run_obj = self.run_obj(upgrade) - monthly_vals = run_obj.agg.aggregate_timeseries(enduses=ts_cols, - group_by=[run_obj.bs_bldgid_column], - upgrade_id=upgrade, - timestamp_grouping_func='month') + monthly_vals_query = run_obj.agg.aggregate_timeseries(get_query_only=True, + enduses=ts_cols, + group_by=[run_obj.bs_bldgid_column], + upgrade_id=upgrade, + timestamp_grouping_func='month', + ) + if monthly_vals_query in run_obj._query_cache: + monthly_vals = run_obj._query_cache[monthly_vals_query].copy() + else: + month_year = f"{datetime.datetime.now().strftime('%b%Y')}" + s3_unload_path = f"s3://resstock-core/athena_unload_results/{month_year}/" + pd_cursor = run_obj._conn.cursor(unload=True, s3_staging_dir=s3_unload_path).execute( + monthly_vals_query, + result_reuse_enable=True, + result_reuse_minutes=60 * 24 * 7) + monthly_vals = pd_cursor.as_pandas() + run_obj._query_cache[monthly_vals_query] = monthly_vals run_obj.save_cache() monthly_df = pl.from_pandas(monthly_vals, include_index=True) monthly_df = monthly_df.with_columns(pl.col('time').dt.month().alias("month")) @@ -193,7 +217,7 @@ def get_plotting_df(self, upgrade: int, .then(0) .otherwise(pl.col("value")) .alias("value") - ) + )
return up_df def get_all_cols(self, resolution: str) -> list[str]: diff --git a/buildstock_query/utility_query.py b/buildstock_query/utility_query.py index db767ad..a2f68d6 100644 --- a/buildstock_query/utility_query.py +++ b/buildstock_query/utility_query.py @@ -165,14 +165,14 @@ def aggregate_ts_by_eiaid(self, params: UtilityTSQuery): params=params) @validate_arguments(config=dict(arbitrary_types_allowed=True, smart_union=True)) - def aggregate_unit_counts_by_eiaid(self, *, eiaid_list: list[str], + def aggregate_unit_counts_by_eiaid(self, *, eiaid_list: Optional[list[str]] = None, group_by: list[Union[AnyColType, tuple[str, str]]] = [], get_query_only: bool = False): """ Returns the counts of the number of dwelling units, grouping by eiaid and other additional group_by columns if provided. Args: - eiaid_list: The list of utility ids (EIAID) to aggregate for + eiaid_list: The list of utility ids (EIAID) to aggregate for. If not provided, all the eiaids will be used. group_by: Additional columns to group by mapping_version: Version of eiaid mapping to use. 
After the spatial refactor upgrade, version two should be used @@ -185,7 +185,7 @@ def aggregate_unit_counts_by_eiaid(self, *, eiaid_list: list[str], group_by = group_by or [] eiaid_map_table_name, map_baseline_column, map_eiaid_column = self.get_eiaid_map() group_by = [] if group_by is None else group_by - restrict = [('eiaid', eiaid_list)] + restrict = [('eiaid', eiaid_list)] if eiaid_list else [] eiaid_col = self._bsq._get_column("eiaid", eiaid_map_table_name) result = self._agg.aggregate_annual(enduses=[], group_by=[eiaid_col] + group_by, sort=True, diff --git a/tests/generate_reference_viz_data_files.py b/tests/generate_reference_viz_data_files.py index fd74591..52b671c 100644 --- a/tests/generate_reference_viz_data_files.py +++ b/tests/generate_reference_viz_data_files.py @@ -18,10 +18,8 @@ def save_bsq_obj(bsq_obj: BuildStockQuery, cache_name=None): def save_viz_data_reference_data(): folder_path = pathlib.Path(__file__).parent.resolve() - yaml_path = str(folder_path / "reference_files" / "example_category_1.yml") opt_sat_path = str(folder_path / "reference_files" / "options_saturations.csv") viz_data = VizData( - yaml_path=yaml_path, opt_sat_path=opt_sat_path, workgroup='largeee', db_name='largeee_test_runs', diff --git a/tests/reference_files/c2e8c98cee7aca046d23eaea93afcb1393eda217d2f98a99ed74388852ac9b8f_query_cache.pkl b/tests/reference_files/c2e8c98cee7aca046d23eaea93afcb1393eda217d2f98a99ed74388852ac9b8f_query_cache.pkl index 418d92c..e447ad3 100644 Binary files a/tests/reference_files/c2e8c98cee7aca046d23eaea93afcb1393eda217d2f98a99ed74388852ac9b8f_query_cache.pkl and b/tests/reference_files/c2e8c98cee7aca046d23eaea93afcb1393eda217d2f98a99ed74388852ac9b8f_query_cache.pkl differ diff --git a/tests/reference_files/small_run_baseline_20230810_100_baseline.pkl b/tests/reference_files/small_run_baseline_20230810_100_baseline.pkl index 9df5da7..8c430b4 100644 Binary files a/tests/reference_files/small_run_baseline_20230810_100_baseline.pkl and 
b/tests/reference_files/small_run_baseline_20230810_100_baseline.pkl differ diff --git a/tests/reference_files/small_run_baseline_20230810_100_timeseries.pkl b/tests/reference_files/small_run_baseline_20230810_100_timeseries.pkl index a420564..c02d4a2 100644 Binary files a/tests/reference_files/small_run_baseline_20230810_100_timeseries.pkl and b/tests/reference_files/small_run_baseline_20230810_100_timeseries.pkl differ diff --git a/tests/reference_files/small_run_category_1_20230616_timeseries.pkl b/tests/reference_files/small_run_category_1_20230616_timeseries.pkl index 75e0621..8e3150e 100644 Binary files a/tests/reference_files/small_run_category_1_20230616_timeseries.pkl and b/tests/reference_files/small_run_category_1_20230616_timeseries.pkl differ diff --git a/tests/reference_files/small_run_category_1_20230616_upgrades.pkl b/tests/reference_files/small_run_category_1_20230616_upgrades.pkl index 5739188..8d8aaeb 100644 Binary files a/tests/reference_files/small_run_category_1_20230616_upgrades.pkl and b/tests/reference_files/small_run_category_1_20230616_upgrades.pkl differ diff --git a/tests/test_BuildStockQuery.py b/tests/test_BuildStockQuery.py index b051385..6231c26 100644 --- a/tests/test_BuildStockQuery.py +++ b/tests/test_BuildStockQuery.py @@ -247,7 +247,7 @@ def test_aggregate_annual(temp_history_file): db_name='buildstock_testing', buildstock_type='resstock', table_name='res_n250_hrly_v1', - sample_weight=29.1, + sample_weight_override=29.1, execution_history=temp_history_file, skip_reports=True ) @@ -390,7 +390,7 @@ def test_aggregate_ts(temp_history_file): buildstock_type='resstock', table_name='res_n250_hrly_v1', execution_history=temp_history_file, - sample_weight=29.1, + sample_weight_override=29.1, skip_reports=True ) my_athena2.get_available_upgrades = lambda: ['0'] diff --git a/tests/test_Viz.py b/tests/test_Viz.py index 93d7d24..e040278 100644 --- a/tests/test_Viz.py +++ b/tests/test_Viz.py @@ -20,10 +20,8 @@ class TestViz: 
@pytest.fixture(scope='class') def viz_data(self): folder_path = pathlib.Path(__file__).parent.resolve() - yaml_path = str(folder_path / "reference_files" / "example_category_1.yml") opt_sat_path = str(folder_path / "reference_files" / "options_saturations.csv") mydata = VizData( - yaml_path=yaml_path, opt_sat_path=opt_sat_path, workgroup='largeee', db_name='largeee_test_runs',