Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
dwreeves committed Jan 7, 2025
1 parent d36580f commit 197b51b
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 89 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
sudo apt-get install
chmod +x ./run
uv venv
uv sync --group python-dev
uv sync --extra python-dev
uv pip install -U "dbt-core==$DBT_CORE_VERSION" "dbt-${DBT_TARGET}==$DBT_CORE_VERSION"
env:
UV_NO_SYNC: true
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Official support for Clickhouse!
- Rename `format=` and `format_options=` to `output=` and `output_options=` to make the API consistent with **dbt_pca**.
- Allow for setting method and output options globally with `vars:`

### `0.2.6`

Expand Down
64 changes: 48 additions & 16 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ group by
- Snowflake
- DuckDB
- Clickhouse
- Redshift
- Postgres\*

If **dbt_linreg** does not work in your database tool, please let me know in a bug report.
Expand Down Expand Up @@ -226,24 +227,38 @@ This has been deprecated to make **dbt_linreg**'s API more consistent with **dbt

### Options for `output='long'`

- **round** (default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (default = `'const'`): String name that refers to constant term.
- **variable_column_name** (default = `'variable_name'`): Column name storing strings of variable names.
- **coefficient_column_name** (default = `'coefficient'`): Column name storing model coefficients.
- **strip_quotes** (default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals.
- **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (`string`; default = `'const'`): String name that refers to constant term.
- **variable_column_name** (`string`; default = `'variable_name'`): Column name storing strings of variable names.
- **coefficient_column_name** (`string`; default = `'coefficient'`): Column name storing model coefficients.
- **strip_quotes** (`bool`; default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals.

These options are available for `output='long'` only when `method='chol'`:

- **calculate_standard_error** (default = `True if not alpha else False`): If true, provide the standard error in the output.
- **standard_error_column_name** (default = `'standard_error'`): Column name storing the standard error for the parameter.
- **t_statistic_column_name** (default = `'t_statistic'`): Column name storing the t-statistic for the parameter.
- **calculate_standard_error** (`bool`; default = `True if not alpha else False`): If true, provide the standard error in the output.
- **standard_error_column_name** (`string`; default = `'standard_error'`): Column name storing the standard error for the parameter.
- **t_statistic_column_name** (`string`; default = `'t_statistic'`): Column name storing the t-statistic for the parameter.

### Options for `output='wide'`

- **round** (default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (default = `'const'`): String name that refers to constant term.
- **variable_column_prefix** (default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **variable_column_suffix** (default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **round** (`int`; default = `None`): If not None, round all coefficients to `round` number of digits.
- **constant_name** (`string`; default = `'const'`): String name that refers to constant term.
- **variable_column_prefix** (`string`; default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)
- **variable_column_suffix** (`string`; default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.)

## Setting output options globally

Output options can be set globally via `vars`, e.g. in your `dbt_project.yml`:

```yaml
# dbt_project.yml
vars:
dbt_linreg:
output_options:
round: 5
```

Output options passed via `ols()` always take precedence over globally set output options.

# Methods and method options

Expand All @@ -262,8 +277,9 @@ This method calculates regression coefficients using the Moore-Penrose pseudo-in

Specify these in a dict using the `method_options=` kwarg:

- **safe** (default = `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value will be passed into a SQRT(), and most SQL engines will raise an error when this happens.
- **subquery_optimization** (default: `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened.
- **safe** (`bool`; default: `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value will be passed into a SQRT(), and most SQL engines will raise an error when this happens.
- **subquery_optimization** (`bool`; default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened.
- **intra_select_aliasing** (`bool`; default = `[depends on db]`): If True, within a single select statement, column aliases are used to refer to other columns created during that select. This can significantly reduce the text of a SQL query, but not all SQL engines support this. By default, for all databases officially supported by **dbt_linreg**, the best option is already selected. For unsupported databases, the default is `False` for broad compatibility, so if you are running **dbt_linreg** in an officially unsupported database engine which supports this feature, you may want to modify this option globally in your `vars` to be `true`.

## `fwl` method

Expand Down Expand Up @@ -299,11 +315,27 @@ So when should you use `fwl`? The main use case is in OLTP systems (e.g. Postgre

- Regression coefficients in Postgres are always `numeric` types.

### Possible future features
## Setting method options globally

Method options can be set globally via `vars`, e.g. in your `dbt_project.yml`. Each `method` gets its own config, e.g. `dbt_linreg: chol: ...`. Here is an example:

```yaml
# dbt_project.yml
vars:
dbt_linreg:
method_options:
chol:
intra_select_aliasing: true
```

Method options passed via `ols()` always take precedence over globally set method options.

# Possible future features

Some things that could happen in the future:

- Weighted least squares (WLS)
- Efficient multivariate regression (i.e. multiple endogenous vectors sharing a single design matrix)
- P-values
- Heteroskedasticity robust standard errors
- Recursive CTE implementations + long formatted inputs
Expand Down Expand Up @@ -332,7 +364,7 @@ There is no closed-form solution to L1 regularization, which makes it very very

### Is the `group_by=[...]` argument like categorical variables / one-hot encodings?

No. You should think of the group by more as a [seemingly unrelated regressions](https://en.wikipedia.org/wiki/Seemingly_unrelated_regressions) implementation than as a categorical variable implementation. It's running multiple regressions and each individual partition is its own `y` vector and `X` matrix. This is _not_ a replacement for dummy variables.
No. The `group_by` runs a linear regressions within each group, and each individual partition is its own `y` vector and `X` matrix. This is _not_ a replacement for dummy variables.

### Why aren't categorical variables / one-hot encodings supported?

Expand Down
2 changes: 1 addition & 1 deletion integration_tests/tests/test_long_format_options.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ with
base as (

select strip_quotes, vname, co
from {{ ref("long_output_options") }}
from {{ ref("long_format_options") }}

),

Expand Down
2 changes: 1 addition & 1 deletion integration_tests/tests/test_wide_format_options.sql
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ with base as (
"fooxa_bar",
fooxb_bar
from
{{ ref("wide_output_options") }}
{{ ref("wide_format_options") }}

)

Expand Down
46 changes: 27 additions & 19 deletions macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,33 @@
in the query you wrote.
If that's not available, the previous calc will be in the dict. #}
{% macro _cell_or_alias(i, j, d, prefix=none) %}
{% macro _cell_or_alias(i, j, d, prefix=none, isa=none) %}
{% if isa is not none %}
{% if isa %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% else %}
{{ return(d[(i, j)]) }}
{% endif %}
{% endif %}
{{ return(
adapter.dispatch('_cell_or_alias', 'dbt_linreg')
(i, j, d, prefix)
(i, j, d, prefix, isa)
) }}
{% endmacro %}
{% macro default___cell_or_alias(i, j, d, prefix=none) %}
{% macro default___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return(d[(i, j)]) }}
{% endmacro %}
{% macro snowflake___cell_or_alias(i, j, d, prefix=none) %}
{% macro snowflake___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}
{% macro duckdb___cell_or_alias(i, j, d, prefix=none) %}
{% macro duckdb___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}
{% macro clickhouse___cell_or_alias(i, j, d, prefix=none) %}
{% macro clickhouse___cell_or_alias(i, j, d, prefix=none, isa=none) %}
{{ return((prefix if prefix is not none else '') ~ 'i' ~ i ~ 'j' ~ j) }}
{% endmacro %}
Expand All @@ -46,7 +53,7 @@
{{ return('sqrt('~x~')') }}
{% endmacro %}
{% macro _cholesky_decomposition(li, subquery_optimization=True, safe=True) %}
{% macro _cholesky_decomposition(li, subquery_optimization=true, safe=true, isa=none) %}
{% set d = {} %}
{% for i in li %}
{% for j in range(li[0], i + 1) %}
Expand All @@ -57,18 +64,18 @@
{% set ns.s = 'x'~j~'x'~i %}
{% for k in range(li[0], j) %}
{% if subquery_optimization and i != j %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d)~'*i'~j~'j'~k %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*i'~j~'j'~k %}
{% else %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d) %}
{% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, isa=isa)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d, isa=isa) %}
{% endif %}
{% endfor %}
{% if i == j %}
{% do d.update({(i, j): dbt_linreg._safe_sqrt(x=ns.s, safe=safe)}) %}
{% else %}
{% if adapter.type() == "postgres" %}
{% do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d) ~ ', 0)'}) %}
{% if safe %}
{% do d.update({(i, j): '('~ns.s~')/nullif('~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa) ~ ', 0)'}) %}
{% else %}
{% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d)}) %}
{% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d, isa=isa)}) %}
{% endif %}
{% endif %}
{% endif %}
Expand All @@ -77,7 +84,7 @@
{{ return(d) }}
{% endmacro %}
{% macro _forward_substitution(li, safe=true) %}
{% macro _forward_substitution(li, safe=true, isa=none) %}
{% set d = {} %}
{% for i, j in modules.itertools.combinations_with_replacement(li, 2) %}
{% set ns = namespace() %}
Expand All @@ -86,7 +93,7 @@
{% else %}
{% set ns.numerator = '(' %}
{% for k in range(i, j) %}
{% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, prefix="inv_") %}
{% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~dbt_linreg._cell_or_alias(i=i, j=k, d=d, prefix="inv_", isa=isa) %}
{% endfor %}
{% set ns.numerator = ns.numerator~')' %}
{% endif %}
Expand Down Expand Up @@ -121,9 +128,10 @@
alpha=alpha
)) }}
{%- endif %}
{%- set subquery_optimization = method_options.get('subquery_optimization', True) %}
{%- set safe_mode = method_options.get('safe', True) %}
{%- set calculate_standard_error = output_options.get('calculate_standard_error', (not alpha)) and output == 'long' %}
{%- set subquery_optimization = dbt_linreg._get_method_option('chol', 'subquery_optimization', method_options, true) %}
{%- set safe_mode = dbt_linreg._get_method_option('chol', 'safe', method_options, true) %}
{% set isa = dbt_linreg._get_method_option('chol', 'intra_select_aliasing', method_options) %}
{%- set calculate_standard_error = dbt_linreg._get_output_option('calculate_standard_error', output_options, (not alpha) and output == 'long') %}
{%- if alpha and calculate_standard_error %}
{% do log(
'Warning: Standard errors are NOT designed to take into account ridge regression regularization.'
Expand Down Expand Up @@ -175,7 +183,7 @@ _dbt_linreg_xtx as (
),
_dbt_linreg_chol as (
{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode) %}
{%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_mode, isa=isa) %}
{%- if subquery_optimization %}
{%- for i in (xcols | reverse) %}
select
Expand Down Expand Up @@ -206,7 +214,7 @@ _dbt_linreg_chol as (
),
_dbt_linreg_inverse_chol as (
{#- The optimal way to calculate is to do each diagonal at a time. #}
{%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode) %}
{%- set d = dbt_linreg._forward_substitution(li=xcols, safe=safe_mode, isa=isa) %}
{%- if subquery_optimization %}
{%- for gap in (range(0, upto) | reverse) %}
select *,
Expand Down
Loading

0 comments on commit 197b51b

Please sign in to comment.