diff --git a/.idea/dbt_linreg.iml b/.idea/dbt_linreg.iml index 1154642..5c0ed1b 100644 --- a/.idea/dbt_linreg.iml +++ b/.idea/dbt_linreg.iml @@ -1,7 +1,14 @@ - + + + + + + + + @@ -20,7 +27,7 @@ diff --git a/CHANGELOG.md b/CHANGELOG.md index daf067a..e0ecb78 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,11 @@ ### `0.1.2` +- Add `chol` method to `dbt_linreg.ols()`, and also set as the default method. (This method is significantly faster than `fwl`, and has a few other benefits.) +- Add standard error column in `long` format for `chol` method. + +### `0.1.2` + - Added the ability to turn off/on the constant term with `add_constant: bool = True` kwarg. - Fixed error that occurred when rendering a 1-variable ridge regression. diff --git a/README.md b/README.md index 613acc8..d156e3b 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ Add this the `packages:` list your dbt project's `packages.yml`: ```yaml - package: "dwreeves/dbt_linreg" - version: "0.1.2" + version: "0.2.0" ``` The full file will look something like this: @@ -69,14 +69,14 @@ select * from {{ Output: -|variable_name|coefficient| -|---|---| -|const|10.0| -|xa|3.0| -|xb|5.0| -|xc|7.0| +|variable_name|coefficient|standard_error|t_statistic| +|---|---|---|---| +|const|10.0|0.00462|2163.27883| +|xa|5.0|0.46226|10.81639| +|xb|7.0|0.46226|15.14295| +|xc|9.0|0.46226|19.46951| -Note: `simple_matrix` is one of the test cases. +Note: `simple_matrix` is one of the test cases, so you can try this yourself! Standard errors are constant across `xa`, `xb`, `xc`, because `simple_matrix` is orthonormal. ### Complex example @@ -188,7 +188,7 @@ def ols( format_options: Optional[dict[str, Any]] = None, group_by: Optional[Union[str, list[str]]] = None, alpha: Optional[Union[float, list[float]]] = None, - method: Literal['fwl'] = 'fwl' + method: Literal['chol', 'fwl'] = 'chol' ): ... ``` @@ -199,13 +199,14 @@ Where: - **endog**: The endogenous variable / y variable / target variable of the regression.
(You can also specify `y=...` instead of `endog=...` if you prefer.) - **exog**: The endogenous variable / y variable / target variable of the regression. (You can also specify `x=...` instead of `exog=...` if you prefer.) - **add_constant**: If true, a constant term is added automatically to the regression. -- **format**: Either "wide" or "long" format for coefficients. +- **format**: Either "wide" or "long" format for coefficients. See **Formats and format options** for more. - If `wide`, the variables span the columns with their original variable names, and the coefficients fill a single row. - If `long`, the coefficients are in a single column called `coefficient`, and the variable names are in a single column called `variable_name`. - **format_options**: See **Formats and format options** section for more. - **group_by**: If specified, the regression will be grouped by these variables, and individual regressions will run on each group. - **alpha**: If not null, the regression will be run as a ridge regression with a penalty of `alpha`. See **Notes** section for more information. -- **method**: The way the regression is calculated. Right now, only `'fwl'` is a valid option. See **FAQ** section for implementation details. +- **method**: The method used to calculate the regression. See **Methods and method options** for more. +- **method_options**: Options specific to the estimation method. See **Methods and method options** for more. # Formats and format options @@ -219,10 +220,16 @@ All formats have their own format options, which can be passed into the `format_ - **round** (default = `None`): If not None, round all coefficients to `round` number of digits. - **constant_name** (default = `'const'`): String name that refers to constant term. -- **variable_column_name** (default = `'variable_name'`): Column name storing strings of . +- **variable_column_name** (default = `'variable_name'`): Column name storing strings of variable names. 
- **coefficient_column_name** (default = `'coefficient'`): Column name storing model coefficients. - **strip_quotes** (default = `True`): If true, strip outer quotes from column names if provided; if false, always use string literals. +These options are only available when `method='chol'`: + +- **calculate_standard_error** (default = `True`): If true, provide the standard error in the output. +- **standard_error_column_name** (default = `'standard_error'`): Column name storing the standard error for the parameter. +- **t_statistic_column_name** (default = `'t_statistic'`): Column name storing the t-statistic for the parameter. + ### Options for `format='wide'` - **round** (default = `None`): If not None, round all coefficients to `round` number of digits. @@ -230,9 +237,43 @@ All formats have their own format options, which can be passed into the `format_ - **variable_column_prefix** (default = `None`): If not None, prefix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.) - **variable_column_suffix** (default = `None`): If not None, suffix all variable columns with this. (Does NOT delimit, so make sure to include your own underscore if you'd want that.) -# Notes +# Methods and method options + +There are currently two valid methods for calculating regression coefficients: + +- `chol`: Uses Cholesky decomposition to calculate the pseudo-inverse. +- `fwl`: Uses the Frisch-Waugh-Lovell theorem to build the multiple regression out of a series of univariate regressions. + +## `chol` method + +**👍 This is the suggested method (and the default) for calculating regressions!** + +This method calculates regression coefficients using the Moore-Penrose pseudo-inverse, and the inverse of **X'X** is calculated using Cholesky decomposition, hence it is referred to as `chol`.
+ +### Options for `method='chol'` + +Specify these in a dict using the `method_options=` kwarg: + +- **safe** (default = `True`): If True, returns null coefficients instead of an error when X is perfectly multicollinear. If False, a negative value will be passed into a SQRT(), and most SQL engines will raise an error when this happens. +- **subquery_optimization** (default = `True`): If True, nested subqueries are used during some of the steps to optimize the query speed. If false, the query is flattened. Note that turning this off can significantly degrade performance. + +## `fwl` method + +**This method is generally not recommended.** + +Simple univariate regression coefficients are simply `covar_pop(y, x) / var_pop(x)`. + +The multiple regression implementation uses a technique described in section `3.2.3 Multiple Regression from Simple Univariate Regression` of TEoSL ([source](https://hastie.su.domains/Papers/ESLII.pdf#page=71)). Econometricians know this as the Frisch-Waugh-Lovell theorem, hence the method is referred to as `fwl` internally in the code base. + +Ridge regression is implemented using the augmentation technique described in Exercise 12 of Chapter 3 of TEoSL ([source](https://hastie.su.domains/Papers/ESLII.pdf#page=115)). -- ⚠️ **Please be aware that this implementation is very inefficient for large numbers of columns!** I believe the time complexity of the Jinja templating is O(2^K). I would suggest not going over 5 or 6 features for a single regression. +There are a few reasons why this method is discouraged over the `chol` method: + +- 🐌 It tends to be much slower, and struggles to efficiently calculate a large number of columns. +- 📊 It does not calculate standard errors. +- 😕 For ridge regression, coefficients are not accurate; they tend to be off by a magnitude of ~0.01%.
+ +# Notes - ⚠️ **If your coefficients are null, it does not mean dbt_linreg is broken, it most likely means your feature columns are perfectly multicollinear.** If you are 100% sure that is not the issue, please file a bug report with a minimally reproducible example. @@ -241,8 +282,6 @@ All formats have their own format options, which can be passed into the `format_ - An array input (e.g. `alpha=[0.01, 0.02, 0.03, 0.04, 0.05]`) will apply an alpha of `0.01` to the first column, `0.02` to the second column, etc. - `alpha` is equivalent to what TEoSL refers to as "lambda," times the sample size N. That is to say: `α ≡ λ * N`. -- Ridge regression coefficients tend to be slightly off Statsmodels's ridge regression coefficients, but by no more than a 0.01% deviation in my experience (this 0.01% threshold is enforced in the integration tests). - ### Possible future features Some things I am thinking about working on down the line: @@ -255,11 +294,7 @@ Some things I am thinking about working on down the line: ### How does this work? -Simple univariate regression coefficients are simply `covar_pop(y, x) / var_pop(x)`. - -The multiple regression implementation uses a technique described in section `3.2.3 Multiple Regression from Simple Univariate Regression` of TEoSL ([source](https://hastie.su.domains/Papers/ESLII.pdf#page=71)). Econometricians know this as the Frisch-Waugh-Lowell theorem, hence the method is referred to as `'fwl'` internally in the code base. - -Ridge regression is implemented using the augmentation technique described in Exercise 12 of Chapter 3 of TEoSL ([source](https://hastie.su.domains/Papers/ESLII.pdf#page=115)). +See **Methods and method options** section for a full breakdown of each linear regression implementation. All approaches were validated using Statsmodels `sm.OLS()`. 
Note that the ridge regression coefficients differ very slightly from Statsmodels's outputs for currently unknown reasons, but the coefficients are very close (I enforce a `<0.01%` deviation from Statsmodels's ridge regression coefficients in my integration tests). @@ -287,6 +322,12 @@ Note that you couldn't simply add categorical variables in the same list as nume If you'd like to regress on a categorical variable, for now you'll need to do your own feature engineering, e.g. `(foo = 'bar')::int as foo_bar` +### Why are there no p-values? + +This is planned for the future, so stay tuned! P-values would require a lookup on a dimension table, which is a significant amount of work to manage nicely, but I hope to get to it soon. + +In the meanwhile, you can implement this yourself-- just create a dimension table that left joins a t-statistic on a half-open interval to lookup a p-value. + # Trademark & Copyright dbt is a trademark of dbt Labs. diff --git a/integration_tests/models/chol_decomp_v1.sql b/integration_tests/models/chol_decomp_v1.sql deleted file mode 100644 index 0e5bc0d..0000000 --- a/integration_tests/models/chol_decomp_v1.sql +++ /dev/null @@ -1,144 +0,0 @@ -{{ - config( - materialized="table", - tags=["perftest"] - ) -}} -{#{%- set exog_aliased = ['x1', 'x2', 'x3', 'x4'] %}#} -{%- set exog_aliased = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9'] %} -with base as ( - - select - y, - 1 as x0, - xa as x1, - xb as x2, - xc as x3, - xd as x4, - xe as x5, - xf as x6, - xg as x7, - xh as x8, - xi as x9, - xj as x10 - from {{ ref('simple_matrix') }} - -), - -xtx as ( - - select - {%- for i, j in modules.itertools.combinations_with_replacement(range(exog_aliased|length), 2) %} - sum(x{{ i }} * x{{ j }}) as x{{ i }}x{{ j }} - {%- if not loop.last -%} - , - {%- endif %} - {%- endfor %} - from base - -), - -chol as ( - - select - {%- set d = {} %} - {%- for i in range((exog_aliased | length)) %} - {%- for j in range(i + 1) %} - {%- if i == 0 and j == 0 %} - 
{%- do d.update({(0, 0): 'sqrt(x0x0)'}) %} - {%- else %} - {%- set ns = namespace() %} - {%- set ns.s = 'x'~j~'x'~i %} - {%- for k in range(j) %} - {%- set ns.s = ns.s~'-i'~i~'j'~k~'*i'~j~'j'~k %} -{#- {%- set ns.s = ns.s~'-'~d[(i,k)]~'*'~d[(j,k)] %}#} - {%- endfor %} - {%- if i == j %} - {%- do d.update({(i, j): 'sqrt('~ns.s~')'}) %} - {%- else %} - {%- do d.update({(i, j): '('~ns.s~')/i'~j~'j'~j}) %} -{#- {%- do d.update({(i, j): '('~ns.s~')/'~d[(j, j)]}) %}#} - {%- endif %} - {%- endif %} - {%- endfor %} - {%- endfor %} - {%- for k, v in d.items() %} - {{ v }} as {{ 'i'~k[0]~'j'~k[1] }} - {%- if not loop.last -%} - , - {%- endif %} - {%- endfor %} - from xtx - -), - -inverse_chol as ( - - select - {%- set d = {} %} - {%- for i, j in modules.itertools.combinations_with_replacement(range((exog_aliased | length)), 2) %} - {%- set ns = namespace() %} - {%- if i == j %} - {%- set ns.numerator = '1' %} - {%- else %} - {%- set ns.numerator = '(' %} - {%- for k in range(i, j) %} -{#- {%- set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~d[(i, k)] %}#} - {%- set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*inv_i'~i~'j'~k %} - {%- endfor %} - {%- set ns.numerator = ns.numerator~')' %} - {%- endif %} - {%- do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %} - {%- endfor %} - {%- for k, v in d.items() %} - {{ v }} as inv_{{ 'i'~k[0]~'j'~k[1] }} - {%- if not loop.last -%} - , - {%- endif %} - {%- endfor %} - from chol - -), - -inverse_xtx as ( - - select - {%- for i, j in modules.itertools.combinations_with_replacement(range((exog_aliased | length)), 2) %} - {%- for k in range(j, (exog_aliased | length)) %} - inv_i{{ i }}j{{ k }} * inv_i{{ j }}j{{ k }} - {%- if not loop.last %} + {% endif -%} - {%- endfor %} - as inv_x{{ i }}x{{ j }} - {%- if not loop.last -%} - , - {%- endif %} - {%- endfor %} - from inverse_chol - -), - -linreg as ( - - select - {%- for x1 in range(exog_aliased|length) %} - sum(( - {%- for x2 in range(exog_aliased|length) %} - {%- if x2 > x1 %} - x{{ 
x2 }} * inv_x{{ x1 }}x{{ x2 }} - {%- else %} - x{{ x2 }} * inv_x{{ x2 }}x{{ x1 }} - {%- endif %} - {%- if not loop.last %} + {% endif -%} - {%- endfor %} - ) * y) as x{{ x1 }}_coef - {%- if not loop.last -%} - , - {%- endif %} - {%- endfor %} - from - base, - inverse_xtx - -) - -select * from linreg diff --git a/integration_tests/models/chol_decomp_v2.sql b/integration_tests/models/chol_decomp_v2.sql deleted file mode 100644 index 9921155..0000000 --- a/integration_tests/models/chol_decomp_v2.sql +++ /dev/null @@ -1,217 +0,0 @@ -{{ - config( - materialized="table", - tags=["perftest"] - ) -}} -{#{%- set exog_aliased = ['x1', 'x2', 'x3', 'x4'] %}#} -{%- set exog_aliased = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9'] %} -with base as ( - - select - y, - 1 as x0, - xa as x1, - xb as x2, - xc as x3, - xd as x4, - xe as x5, - xf as x6, - xg as x7, - xh as x8, - xi as x9, - xj as x10 - from {{ ref('simple_matrix') }} - -), - -xtx as ( - - select - {%- for i, j in modules.itertools.combinations_with_replacement(range(exog_aliased|length), 2) %} - sum(x{{ i }} * x{{ j }}) as x{{ i }}x{{ j }} - {%- if not loop.last -%} - , - {%- endif %} - {%- endfor %} - from base - -), - -chol as ( - - select - *, - (x0x8)/i0j0 as i8j0, - (x1x8-i8j0*i1j0)/i1j1 as i8j1, - (x2x8-i8j0*i2j0-i8j1*i2j1)/i2j2 as i8j2, - (x3x8-i8j0*i3j0-i8j1*i3j1-i8j2*i3j2)/i3j3 as i8j3, - (x4x8-i8j0*i4j0-i8j1*i4j1-i8j2*i4j2-i8j3*i4j3)/i4j4 as i8j4, - (x5x8-i8j0*i5j0-i8j1*i5j1-i8j2*i5j2-i8j3*i5j3-i8j4*i5j4)/i5j5 as i8j5, - (x6x8-i8j0*i6j0-i8j1*i6j1-i8j2*i6j2-i8j3*i6j3-i8j4*i6j4-i8j5*i6j5)/i6j6 as i8j6, - (x7x8-i8j0*i7j0-i8j1*i7j1-i8j2*i7j2-i8j3*i7j3-i8j4*i7j4-i8j5*i7j5-i8j6*i7j6)/i7j7 as i8j7, - sqrt(x8x8-i8j0*i8j0-i8j1*i8j1-i8j2*i8j2-i8j3*i8j3-i8j4*i8j4-i8j5*i8j5-i8j6*i8j6-i8j7*i8j7) as i8j8 - from - (select *, - (x0x7)/i0j0 as i7j0, - (x1x7-i7j0*i1j0)/i1j1 as i7j1, - (x2x7-i7j0*i2j0-i7j1*i2j1)/i2j2 as i7j2, - (x3x7-i7j0*i3j0-i7j1*i3j1-i7j2*i3j2)/i3j3 as i7j3, - 
(x4x7-i7j0*i4j0-i7j1*i4j1-i7j2*i4j2-i7j3*i4j3)/i4j4 as i7j4, - (x5x7-i7j0*i5j0-i7j1*i5j1-i7j2*i5j2-i7j3*i5j3-i7j4*i5j4)/i5j5 as i7j5, - (x6x7-i7j0*i6j0-i7j1*i6j1-i7j2*i6j2-i7j3*i6j3-i7j4*i6j4-i7j5*i6j5)/i6j6 as i7j6, - sqrt(x7x7-i7j0*i7j0-i7j1*i7j1-i7j2*i7j2-i7j3*i7j3-i7j4*i7j4-i7j5*i7j5-i7j6*i7j6) as i7j7 - from - (select *, - (x0x6)/i0j0 as i6j0, - (x1x6-i6j0*i1j0)/i1j1 as i6j1, - (x2x6-i6j0*i2j0-i6j1*i2j1)/i2j2 as i6j2, - (x3x6-i6j0*i3j0-i6j1*i3j1-i6j2*i3j2)/i3j3 as i6j3, - (x4x6-i6j0*i4j0-i6j1*i4j1-i6j2*i4j2-i6j3*i4j3)/i4j4 as i6j4, - (x5x6-i6j0*i5j0-i6j1*i5j1-i6j2*i5j2-i6j3*i5j3-i6j4*i5j4)/i5j5 as i6j5, - sqrt(x6x6-i6j0*i6j0-i6j1*i6j1-i6j2*i6j2-i6j3*i6j3-i6j4*i6j4-i6j5*i6j5) as i6j6 - from - (select *, - (x0x5)/i0j0 as i5j0, - (x1x5-i5j0*i1j0)/i1j1 as i5j1, - (x2x5-i5j0*i2j0-i5j1*i2j1)/i2j2 as i5j2, - (x3x5-i5j0*i3j0-i5j1*i3j1-i5j2*i3j2)/i3j3 as i5j3, - (x4x5-i5j0*i4j0-i5j1*i4j1-i5j2*i4j2-i5j3*i4j3)/i4j4 as i5j4, - sqrt(x5x5-i5j0*i5j0-i5j1*i5j1-i5j2*i5j2-i5j3*i5j3-i5j4*i5j4) as i5j5, - from - (select *, - (x0x4)/i0j0 as i4j0, - (x1x4-i4j0*i1j0)/i1j1 as i4j1, - (x2x4-i4j0*i2j0-i4j1*i2j1)/i2j2 as i4j2, - (x3x4-i4j0*i3j0-i4j1*i3j1-i4j2*i3j2)/i3j3 as i4j3, - sqrt(x4x4-i4j0*i4j0-i4j1*i4j1-i4j2*i4j2-i4j3*i4j3) as i4j4 - from - (select *, - (x0x3)/i0j0 as i3j0, - (x1x3-i3j0*i1j0)/i1j1 as i3j1, - (x2x3-i3j0*i2j0-i3j1*i2j1)/i2j2 as i3j2, - sqrt(x3x3-i3j0*i3j0-i3j1*i3j1-i3j2*i3j2) as i3j3 - from - (select *, - (x0x2)/i0j0 as i2j0, - (x1x2-i2j0*i1j0)/i1j1 as i2j1, - sqrt(x2x2-i2j0*i2j0-i2j1*i2j1) as i2j2, - from - (select - *, - sqrt(x0x0) as i0j0, - (x0x1)/i0j0 as i1j0, - sqrt(x1x1-i1j0*i1j0) as i1j1 - from xtx))))))) - -), - -inverse_chol as ( - - select - *, - (1/i7j7) as inv_i7j7, - ((-i8j7*inv_i7j7)/i8j8) as inv_i7j8, - (1/i8j8) as inv_i8j8 - from ( - select *, - (1/i6j6) as inv_i6j6, - ((-i7j6*inv_i6j6)/i7j7) as inv_i6j7, - ((-i8j6*inv_i6j6-i8j7*inv_i6j7)/i8j8) as inv_i6j8, - from ( - select *, - (1/i5j5) as inv_i5j5, - ((-i6j5*inv_i5j5)/i6j6) as inv_i5j6, - 
((-i7j5*inv_i5j5-i7j6*inv_i5j6)/i7j7) as inv_i5j7, - ((-i8j5*inv_i5j5-i8j6*inv_i5j6-i8j7*inv_i5j7)/i8j8) as inv_i5j8, - from ( - select *, - (1/i4j4) as inv_i4j4, - ((-i5j4*inv_i4j4)/i5j5) as inv_i4j5, - ((-i6j4*inv_i4j4-i6j5*inv_i4j5)/i6j6) as inv_i4j6, - ((-i7j4*inv_i4j4-i7j5*inv_i4j5-i7j6*inv_i4j6)/i7j7) as inv_i4j7, - ((-i8j4*inv_i4j4-i8j5*inv_i4j5-i8j6*inv_i4j6-i8j7*inv_i4j7)/i8j8) as inv_i4j8, - from ( - select *, - (1/i3j3) as inv_i3j3, - ((-i4j3*inv_i3j3)/i4j4) as inv_i3j4, - ((-i5j3*inv_i3j3-i5j4*inv_i3j4)/i5j5) as inv_i3j5, - ((-i6j3*inv_i3j3-i6j4*inv_i3j4-i6j5*inv_i3j5)/i6j6) as inv_i3j6, - ((-i7j3*inv_i3j3-i7j4*inv_i3j4-i7j5*inv_i3j5-i7j6*inv_i3j6)/i7j7) as inv_i3j7, - ((-i8j3*inv_i3j3-i8j4*inv_i3j4-i8j5*inv_i3j5-i8j6*inv_i3j6-i8j7*inv_i3j7)/i8j8) as inv_i3j8, - from ( - select *, - (1/i2j2) as inv_i2j2, - ((-i3j2*inv_i2j2)/i3j3) as inv_i2j3, - ((-i4j2*inv_i2j2-i4j3*inv_i2j3)/i4j4) as inv_i2j4, - ((-i5j2*inv_i2j2-i5j3*inv_i2j3-i5j4*inv_i2j4)/i5j5) as inv_i2j5, - ((-i6j2*inv_i2j2-i6j3*inv_i2j3-i6j4*inv_i2j4-i6j5*inv_i2j5)/i6j6) as inv_i2j6, - ((-i7j2*inv_i2j2-i7j3*inv_i2j3-i7j4*inv_i2j4-i7j5*inv_i2j5-i7j6*inv_i2j6)/i7j7) as inv_i2j7, - ((-i8j2*inv_i2j2-i8j3*inv_i2j3-i8j4*inv_i2j4-i8j5*inv_i2j5-i8j6*inv_i2j6-i8j7*inv_i2j7)/i8j8) as inv_i2j8, - from ( - select *, - (1/i1j1) as inv_i1j1, - ((-i2j1*inv_i1j1)/i2j2) as inv_i1j2, - ((-i3j1*inv_i1j1-i3j2*inv_i1j2)/i3j3) as inv_i1j3, - ((-i4j1*inv_i1j1-i4j2*inv_i1j2-i4j3*inv_i1j3)/i4j4) as inv_i1j4, - ((-i5j1*inv_i1j1-i5j2*inv_i1j2-i5j3*inv_i1j3-i5j4*inv_i1j4)/i5j5) as inv_i1j5, - ((-i6j1*inv_i1j1-i6j2*inv_i1j2-i6j3*inv_i1j3-i6j4*inv_i1j4-i6j5*inv_i1j5)/i6j6) as inv_i1j6, - ((-i7j1*inv_i1j1-i7j2*inv_i1j2-i7j3*inv_i1j3-i7j4*inv_i1j4-i7j5*inv_i1j5-i7j6*inv_i1j6)/i7j7) as inv_i1j7, - ((-i8j1*inv_i1j1-i8j2*inv_i1j2-i8j3*inv_i1j3-i8j4*inv_i1j4-i8j5*inv_i1j5-i8j6*inv_i1j6-i8j7*inv_i1j7)/i8j8) as inv_i1j8, - from ( - select *, - (1/i0j0) as inv_i0j0, - ((-i1j0*inv_i0j0)/i1j1) as inv_i0j1, - 
((-i2j0*inv_i0j0-i2j1*inv_i0j1)/i2j2) as inv_i0j2, - ((-i3j0*inv_i0j0-i3j1*inv_i0j1-i3j2*inv_i0j2)/i3j3) as inv_i0j3, - ((-i4j0*inv_i0j0-i4j1*inv_i0j1-i4j2*inv_i0j2-i4j3*inv_i0j3)/i4j4) as inv_i0j4, - ((-i5j0*inv_i0j0-i5j1*inv_i0j1-i5j2*inv_i0j2-i5j3*inv_i0j3-i5j4*inv_i0j4)/i5j5) as inv_i0j5, - ((-i6j0*inv_i0j0-i6j1*inv_i0j1-i6j2*inv_i0j2-i6j3*inv_i0j3-i6j4*inv_i0j4-i6j5*inv_i0j5)/i6j6) as inv_i0j6, - ((-i7j0*inv_i0j0-i7j1*inv_i0j1-i7j2*inv_i0j2-i7j3*inv_i0j3-i7j4*inv_i0j4-i7j5*inv_i0j5-i7j6*inv_i0j6)/i7j7) as inv_i0j7, - ((-i8j0*inv_i0j0-i8j1*inv_i0j1-i8j2*inv_i0j2-i8j3*inv_i0j3-i8j4*inv_i0j4-i8j5*inv_i0j5-i8j6*inv_i0j6-i8j7*inv_i0j7)/i8j8) as inv_i0j8, - from chol - ))))))) -), - -inverse_xtx as ( - - select - {%- for i, j in modules.itertools.combinations_with_replacement(range((exog_aliased | length)), 2) %} - {%- for k in range(j, (exog_aliased | length)) %} - inv_i{{ i }}j{{ k }} * inv_i{{ j }}j{{ k }} - {%- if not loop.last %} + {% endif -%} - {%- endfor %} - as inv_x{{ i }}x{{ j }} - {%- if not loop.last -%} - , - {%- endif %} - {%- endfor %} - from inverse_chol - -), - -linreg as ( - - select - {%- for x1 in range(exog_aliased|length) %} - sum(( - {%- for x2 in range(exog_aliased|length) %} - {%- if x2 > x1 %} - x{{ x2 }} * inv_x{{ x1 }}x{{ x2 }} - {%- else %} - x{{ x2 }} * inv_x{{ x2 }}x{{ x1 }} - {%- endif %} - {%- if not loop.last %} + {% endif -%} - {%- endfor %} - ) * y) as x{{ x1 }}_coef - {%- if not loop.last -%} - , - {%- endif %} - {%- endfor %} - from - base, - inverse_xtx - -) - -select * from linreg diff --git a/integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql b/integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql new file mode 100644 index 0000000..c85c9aa --- /dev/null +++ b/integration_tests/models/collinear_matrix_regression_chol_unoptimized.sql @@ -0,0 +1,15 @@ +{{ + config( + materialized="table" + ) +}} +select * from {{ + dbt_linreg.ols( + table=ref('collinear_matrix'), + endog='y', + 
exog=['x1', 'x2', 'x3', 'x4', 'x5'], + format='long', + method='chol', + method_options={'subquery_optimization': False} + ) +}} diff --git a/integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql b/integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql new file mode 100644 index 0000000..63888e7 --- /dev/null +++ b/integration_tests/models/collinear_matrix_ridge_regression_chol_unoptimized.sql @@ -0,0 +1,16 @@ +{{ + config( + materialized="table" + ) +}} +select * from {{ + dbt_linreg.ols( + table=ref('collinear_matrix'), + endog='y', + exog=['x1', 'x2', 'x3', 'x4', 'x5'], + format='long', + alpha=0.01, + method='chol', + method_options={'subquery_optimization': False} + ) +}} diff --git a/integration_tests/models/groups_matrix_regression_chol_unoptimized.sql b/integration_tests/models/groups_matrix_regression_chol_unoptimized.sql new file mode 100644 index 0000000..e1026d8 --- /dev/null +++ b/integration_tests/models/groups_matrix_regression_chol_unoptimized.sql @@ -0,0 +1,17 @@ +{{ + config( + materialized="table" + ) +}} +select * from {{ + dbt_linreg.ols( + table=ref('groups_matrix'), + endog='y', + exog=['x1', 'x2', 'x3'], + group_by=['gb_var'], + format='long', + method='chol', + method_options={'subquery_optimization': False} + ) +}} +order by gb_var, variable_name diff --git a/integration_tests/tests/test_collinear_matrix_regression_chol_unoptimized.sql b/integration_tests/tests/test_collinear_matrix_regression_chol_unoptimized.sql new file mode 100644 index 0000000..aeb37db --- /dev/null +++ b/integration_tests/tests/test_collinear_matrix_regression_chol_unoptimized.sql @@ -0,0 +1,25 @@ +with + +expected as ( + + select 'const' as variable_name, 19.757104885315176 as coefficient + union all + select 'x1' as variable_name, 9.90708767581426 as coefficient + union all + select 'x2' as variable_name, 6.187473206056227 as coefficient + union all + select 'x3' as variable_name, 19.66874583168642 as coefficient + 
union all + select 'x4' as variable_name, 3.7192417102253468 as coefficient + union all + select 'x5' as variable_name, 13.444273483323244 as coefficient + +) + +select base.variable_name +from {{ ref('collinear_matrix_regression_chol_unoptimized') }} as base +full outer join expected +on base.variable_name = expected.variable_name +where + round(base.coefficient, 7) != round(expected.coefficient, 7) /* boolean comparison at 7 decimal places; any returned row fails the test */ + or base.coefficient is null diff --git a/integration_tests/tests/test_collinear_matrix_ridge_regression_chol_unoptimized.sql b/integration_tests/tests/test_collinear_matrix_ridge_regression_chol_unoptimized.sql new file mode 100644 index 0000000..926771b --- /dev/null +++ b/integration_tests/tests/test_collinear_matrix_ridge_regression_chol_unoptimized.sql @@ -0,0 +1,29 @@ +/* Ridge regression coefficients do not match exactly. + Instead, a threshold of no more than 0.01% deviation is enforced. */ +{% set THRESHOLD = 0.0001 %} +with + +expected as ( + + select 'const' as variable_name, 20.7548151107157 as coefficient + union all + select 'x1' as variable_name, 9.784064449021356 as coefficient + union all + select 'x2' as variable_name, 6.315640539781496 as coefficient + union all + select 'x3' as variable_name, 19.578696589513562 as coefficient + union all + select 'x4' as variable_name, 3.736823845978248 as coefficient + union all + select 'x5' as variable_name, 13.323547772767592 as coefficient + +) + +select base.variable_name +from {{ ref('collinear_matrix_ridge_regression_chol_unoptimized') }} as base +full outer join expected +on base.variable_name = expected.variable_name +where + abs(log(abs(base.coefficient)) - log(abs(expected.coefficient))) > {{ THRESHOLD }} + or sign(base.coefficient) != sign(expected.coefficient) + or base.coefficient is null diff --git a/integration_tests/tests/test_groups_matrix_regression_chol.sql b/integration_tests/tests/test_groups_matrix_regression_chol.sql index cc979e1..4d726b2 100644 ---
a/integration_tests/tests/test_groups_matrix_regression_chol.sql +++ b/integration_tests/tests/test_groups_matrix_regression_chol.sql @@ -1,6 +1,3 @@ -/* Ridge regression coefficients do not match exactly. - Instead, a threshold of no more than 0.01% deviation is enforced. */ -{% set THRESHOLD = 0.0001 %} with expected as ( diff --git a/integration_tests/tests/test_groups_matrix_regression_chol_unoptimized.sql b/integration_tests/tests/test_groups_matrix_regression_chol_unoptimized.sql new file mode 100644 index 0000000..1b6ea14 --- /dev/null +++ b/integration_tests/tests/test_groups_matrix_regression_chol_unoptimized.sql @@ -0,0 +1,31 @@ +with + +expected as ( + + select 'a' as gb_var, 'const' as variable_name, -0.06563066041472207 as coefficient + union all + select 'a' as gb_var, 'x1' as variable_name, 0.9905419281557593 as coefficient + union all + select 'a' as gb_var, 'x2' as variable_name, 4.948221700496285 as coefficient + union all + select 'a' as gb_var, 'x3' as variable_name, 0.031234030051974747 as coefficient + union all + select 'b' as gb_var, 'const' as variable_name, 2.0117130483709955 as coefficient + union all + select 'b' as gb_var, 'x1' as variable_name, 2.996331112245573 as coefficient + union all + select 'b' as gb_var, 'x2' as variable_name, 9.019683491736044 as coefficient + union all + select 'b' as gb_var, 'x3' as variable_name, 0.016151316166848173 as coefficient + +) + +select base.variable_name +from {{ ref('groups_matrix_regression_chol_unoptimized') }} as base +full outer join expected +on + base.gb_var = expected.gb_var + and base.variable_name = expected.variable_name +where + round(base.coefficient, 7) != round(expected.coefficient, 7) /* boolean comparison at 7 decimal places; any returned row fails the test */ + or base.coefficient is null diff --git a/integration_tests/tests/test_simple_0var_regression_long_chol.sql b/integration_tests/tests/test_simple_0var_regression_long_chol.sql index 8f1b945..a5aa18c 100644 --- a/integration_tests/tests/test_simple_0var_regression_long_chol.sql +++
b/integration_tests/tests/test_simple_0var_regression_long_chol.sql @@ -7,7 +7,7 @@ expected as ( ) select base.variable_name -from {{ ref('simple_2var_regression_long_chol') }} as base +from {{ ref('simple_0var_regression_long_chol') }} as base full outer join expected on base.variable_name = expected.variable_name where base.coefficient != expected.coefficient or base.coefficient is null diff --git a/integration_tests/tests/test_simple_0var_regression_long_fwl.sql b/integration_tests/tests/test_simple_0var_regression_long_fwl.sql index 4e91586..366d526 100644 --- a/integration_tests/tests/test_simple_0var_regression_long_fwl.sql +++ b/integration_tests/tests/test_simple_0var_regression_long_fwl.sql @@ -7,7 +7,7 @@ expected as ( ) select base.variable_name -from {{ ref('simple_2var_regression_long_fwl') }} as base +from {{ ref('simple_0var_regression_long_fwl') }} as base full outer join expected on base.variable_name = expected.variable_name where base.coefficient != expected.coefficient or base.coefficient is null diff --git a/integration_tests/tests/test_simple_0var_regression_wide.sql b/integration_tests/tests/test_simple_0var_regression_wide.sql index 69af7b1..bf7a58e 100644 --- a/integration_tests/tests/test_simple_0var_regression_wide.sql +++ b/integration_tests/tests/test_simple_0var_regression_wide.sql @@ -7,7 +7,7 @@ expected as ( ) select base.* -from {{ ref('simple_2var_regression_wide') }} as base, expected +from {{ ref('simple_0var_regression_wide') }} as base, expected where not ( base.const = expected.const ) diff --git a/macros/linear_regression/ols.sql b/macros/linear_regression/ols.sql index 58706ff..9f82acb 100644 --- a/macros/linear_regression/ols.sql +++ b/macros/linear_regression/ols.sql @@ -8,7 +8,8 @@ format_options=None, group_by=None, alpha=None, - method='fwl') -%} + method=None, + method_options=None) -%} {############################################################################# @@ -40,6 +41,10 @@ {% set format_options = {} %} {% 
endif %} + {% if method_options is none %} + {% set method_options = {} %} + {% endif %} + {% if y is not none and endog is none %} {% set endog = y %} {% elif y is not none and endog is not none %} @@ -69,6 +74,10 @@ {% set alpha = [alpha] * (exog | length) %} {% endif %} + {% if method is none %} + {% set method = 'chol' %} + {% endif %} + {# Check for user input errors #} {# --------------------------- #} @@ -124,7 +133,21 @@ ) }} {% endif %} - {% if method == 'fwl' %} + {% if method == 'chol' %} + {{ return( + dbt_linreg._ols_chol( + table=table, + endog=endog, + exog=exog, + add_constant=add_constant, + format=format, + format_options=format_options, + group_by=group_by, + alpha=alpha, + method_options=method_options + ) + ) }} + {% elif method == 'fwl' %} {{ return( dbt_linreg._ols_fwl( table=table, @@ -134,12 +157,13 @@ format=format, format_options=format_options, group_by=group_by, - alpha=alpha + alpha=alpha, + method_options=method_options ) ) }} {% else %} {{ exceptions.raise_compiler_error( - "Invalid method specified. The only currently valid method is 'fwl'" + "Invalid method specified. The only valid methods are 'chol' and 'fwl'" ) }} {% endif %} diff --git a/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql b/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql index 0e5bc0d..0c810f2 100644 --- a/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql +++ b/macros/linear_regression/ols_impl_chol/_ols_impl_chol.sql @@ -1,144 +1,306 @@ -{{ - config( - materialized="table", - tags=["perftest"] - ) -}} -{#{%- set exog_aliased = ['x1', 'x2', 'x3', 'x4'] %}#} -{%- set exog_aliased = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9'] %} -with base as ( +{# In some warehouses, you can reference newly created column aliases + in the query you wrote. + If that's not available, the previous calc will be in the dict. 
#} +{% macro _cell_or_alias(i, j, d) %} {# Dispatcher: returns SQL text referring to Cholesky cell (i, j) via the adapter-specific implementation. #} + {{ return( + adapter.dispatch('_cell_or_alias', 'dbt_linreg') + (i, j, d) + ) }} +{% endmacro %} - select - y, - 1 as x0, - xa as x1, - xb as x2, - xc as x3, - xd as x4, - xe as x5, - xf as x6, - xg as x7, - xh as x8, - xi as x9, - xj as x10 - from {{ ref('simple_matrix') }} +{% macro default___cell_or_alias(i, j, d) %} {# Default: inline the expression text previously stored in d[(i, j)]. #} + {{ return(d[(i, j)]) }} +{% endmacro %} -), +{% macro snowflake___cell_or_alias(i, j, d) %} {# Snowflake: refer to the cell by a column alias of the form i<i>j<j>; d is not read. #} + {{ return('i' ~ i ~ 'j' ~ j) }} +{% endmacro %} + +{% macro duckdb___cell_or_alias(i, j, d) %} {# DuckDB: same alias-based reference as the Snowflake implementation. #} + {{ return('i' ~ i ~ 'j' ~ j) }} +{% endmacro %} + +{% macro _safe_sqrt(x, safe=True) %} {# Dispatcher: returns SQL text for the square root of expression text x via the adapter-specific implementation. #} + {{ return( + adapter.dispatch('_safe_sqrt', 'dbt_linreg') + (x, safe) + ) }} +{% endmacro %} + +{% macro default___safe_sqrt(x, safe=True) %} {# When safe, wrap sqrt in a CASE without ELSE so a negative argument yields NULL; otherwise emit a plain sqrt(). #} + {% if safe %} + {{ return('case when ('~x~') >= 0 then sqrt('~x~') end') }} + {% endif %} + {{ return('sqrt('~x~')') }} +{% endmacro %} -xtx as ( +{% macro bigquery___safe_sqrt(x, safe=True) %} {# BigQuery: emit safe.sqrt() for the safe variant; otherwise a plain sqrt(). #} + {% if safe %} + {{ return('safe.sqrt('~x~')') }} + {% endif %} + {{ return('sqrt('~x~')') }} +{% endmacro %} +{% macro _cholesky_decomposition(li, subquery_optimization=True, safe=True) %} {# Returns a dict keyed by (i, j), for i in li and j from li[0] through i, whose values are SQL expression text for each cell, built from columns named x<j>x<i>. Diagonal cells go through _safe_sqrt; off-diagonal cells are divided by cell (j, j). #} + {% set d = {} %} + {% for i in li %} + {% for j in range(li[0], i + 1) %} + {% if i == li[0] and j == li[0] %} + {% do d.update({(i, j): dbt_linreg._safe_sqrt(x='x'~i~'x'~j, safe=safe)}) %} + {% else %} + {% set ns = namespace() %} + {% set ns.s = 'x'~j~'x'~i %} + {% for k in range(li[0], j) %} + {% if subquery_optimization and i != j %} {# off-diagonal with optimization: the row-j factor is referenced by its alias i<j>j<k> #} + {% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d)~'*i'~j~'j'~k %} + {% else %} + {% set ns.s = ns.s~'-'~dbt_linreg._cell_or_alias(i=i, j=k, d=d)~'*'~dbt_linreg._cell_or_alias(i=j, j=k, d=d) %} + {% endif %} + {% endfor %} + {% if i == j %} + {% do d.update({(i, j): dbt_linreg._safe_sqrt(x=ns.s, safe=safe)}) %} + {% else %} + {% do d.update({(i, j): '('~ns.s~')/'~dbt_linreg._cell_or_alias(i=j, j=j, d=d)}) %} + {% endif %} + {% endif %} + {% endfor %} + {% endfor
%} + {{ return(d) }} +{% endmacro %} + +{% macro _forward_substitution(li) %} {# Returns a dict keyed by each pair (i, j) produced by combinations_with_replacement(li, 2), whose values are SQL text of the form (numerator/i<j>j<j>); the numerator is 1 when i == j. NOTE(review): with the default _cell_or_alias, the text appended after 'inv_' is a stored expression rather than an alias - confirm this is intended for adapters without a dedicated implementation. #} + {% set d = {} %} + {% for i, j in modules.itertools.combinations_with_replacement(li, 2) %} + {% set ns = namespace() %} + {% if i == j %} + {% set ns.numerator = '1' %} + {% else %} + {% set ns.numerator = '(' %} + {% for k in range(i, j) %} + {% set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*inv_'~dbt_linreg._cell_or_alias(i=i, j=k, d=d) %} + {% endfor %} + {% set ns.numerator = ns.numerator~')' %} + {% endif %} + {% do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %} + {% endfor %} + {{ return(d) }} +{% endmacro %} + +{% macro _ols_chol(table, + endog, + exog, + add_constant=True, + format=None, + format_options=None, + group_by=None, + alpha=None, + method_options=None) -%} +{%- if (exog | length) == 0 %} + {% do log('Note: exog was empty; running regression on constant term only.') %} + {{ return(dbt_linreg._ols_0var( + table=table, + endog=endog, + exog=exog, + add_constant=add_constant, + format=format, + format_options=format_options, + group_by=group_by, + alpha=alpha + )) }} +{%- endif %} +{%- set subquery_optimization = method_options.get('subquery_optimization', True) %} +{%- set safe_sqrt = method_options.get('safe', True) %} +{%- set calculate_standard_error = format_options.get('calculate_standard_error', True) and format == 'long' %} +{%- if add_constant %} + {% set xmin = 0 %} +{%- else %} + {% set xmin = 1 %} +{%- endif %} +{%- set xcols = (range(xmin, (exog | length) + 1) | list) %} +{%- set exog_aliased = dbt_linreg._alias_exog(exog) %} +(with +_dbt_linreg_base as ( select - {%- for i, j in modules.itertools.combinations_with_replacement(range(exog_aliased|length), 2) %} - sum(x{{ i }} * x{{ j }}) as x{{ i }}x{{ j }} + {{ dbt_linreg._alias_gb_cols(group_by) | indent(4) }} + {{ endog }} as y, + {%- if add_constant %} + 1 as x0, + {%- endif %} + {%- for i in range(1, (exog | length) + 1) %} + b.{{ exog[loop.index0] }} as x{{ i }} {%- if not loop.last -%} , {%-
endif %} {%- endfor %} - from base - + from + {{ table }} as b ), - -chol as ( - +_dbt_linreg_xtx as ( select - {%- set d = {} %} - {%- for i in range((exog_aliased | length)) %} - {%- for j in range(i + 1) %} - {%- if i == 0 and j == 0 %} - {%- do d.update({(0, 0): 'sqrt(x0x0)'}) %} - {%- else %} - {%- set ns = namespace() %} - {%- set ns.s = 'x'~j~'x'~i %} - {%- for k in range(j) %} - {%- set ns.s = ns.s~'-i'~i~'j'~k~'*i'~j~'j'~k %} -{#- {%- set ns.s = ns.s~'-'~d[(i,k)]~'*'~d[(j,k)] %}#} - {%- endfor %} - {%- if i == j %} - {%- do d.update({(i, j): 'sqrt('~ns.s~')'}) %} + {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} + {%- for i, j in modules.itertools.combinations_with_replacement(xcols, 2) %} + {%- if alpha and i == j and i > 0 %} + sum(b.x{{ i }} * b.x{{ j }} + {{ alpha[i-1] }}) as x{{ i }}x{{ j }} {%- else %} - {%- do d.update({(i, j): '('~ns.s~')/i'~j~'j'~j}) %} -{#- {%- do d.update({(i, j): '('~ns.s~')/'~d[(j, j)]}) %}#} + sum(b.x{{ i }} * b.x{{ j }}) as x{{ i }}x{{ j }} {%- endif %} + {%- if not loop.last -%} + , {%- endif %} {%- endfor %} + from _dbt_linreg_base as b + {%- if group_by %} + group by + {{ dbt_linreg._gb_cols(group_by) | indent(4) }} + {%- endif %} +), +_dbt_linreg_chol as ( + + {%- set d = dbt_linreg._cholesky_decomposition(li=xcols, subquery_optimization=subquery_optimization, safe=safe_sqrt) %} + {%- if subquery_optimization %} + {%- for i in (xcols | reverse) %} + select + *, + {%- for j in range(xmin, i + 1) %} + {{ d[(i, j)] }} as i{{ i }}j{{ j }} + {%- if not loop.last -%} + , + {%- endif %} {%- endfor %} + {%- if not loop.last %} + from ( + {%- else %} + from _dbt_linreg_xtx{{ ')' * ((xcols | length) - 1) }} + {%- endif %} + {%- endfor %} + {%- else %} + select + {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} {%- for k, v in d.items() %} {{ v }} as {{ 'i'~k[0]~'j'~k[1] }} {%- if not loop.last -%} , {%- endif %} {%- endfor %} - from xtx - + from _dbt_linreg_xtx + {%- endif %} ), - 
-inverse_chol as ( - - select - {%- set d = {} %} - {%- for i, j in modules.itertools.combinations_with_replacement(range((exog_aliased | length)), 2) %} - {%- set ns = namespace() %} - {%- if i == j %} - {%- set ns.numerator = '1' %} - {%- else %} - {%- set ns.numerator = '(' %} - {%- for k in range(i, j) %} -{#- {%- set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*'~d[(i, k)] %}#} - {%- set ns.numerator = ns.numerator~'-i'~j~'j'~k~'*inv_i'~i~'j'~k %} - {%- endfor %} - {%- set ns.numerator = ns.numerator~')' %} +_dbt_linreg_inverse_chol as ( + {#- The optimal way to calculate is to do each diagonal at a time. #} + {%- set d = dbt_linreg._forward_substitution(li=xcols) %} + {%- if subquery_optimization %} + {%- set upto = (xcols | length) %} + {%- for gap in (range(0, upto) | reverse) %} + select *, + {%- for j in range(gap + xmin, upto + xmin) %} + {%- set i = j - gap %} + {{ d[(i, j)] }} as inv_i{{ i }}j{{ j }} + {%- if not loop.last -%} + , {%- endif %} - {%- do d.update({(i, j): '('~ns.numerator~'/i'~j~'j'~j~')'}) %} {%- endfor %} + {%- if not loop.last %} + from ( + {%- else %} + from _dbt_linreg_chol{{ ')' * (upto - 1) }} + {%- endif %} + {%- endfor %} + {%- else %} + select + {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} {%- for k, v in d.items() %} {{ v }} as inv_{{ 'i'~k[0]~'j'~k[1] }} {%- if not loop.last -%} , {%- endif %} {%- endfor %} - from chol - + from _dbt_linreg_chol + {%- endif %} ), - -inverse_xtx as ( - +_dbt_linreg_inverse_xtx as ( select - {%- for i, j in modules.itertools.combinations_with_replacement(range((exog_aliased | length)), 2) %} - {%- for k in range(j, (exog_aliased | length)) %} - inv_i{{ i }}j{{ k }} * inv_i{{ j }}j{{ k }} - {%- if not loop.last %} + {% endif -%} + {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(4) }} + {%- for i, j in modules.itertools.combinations_with_replacement(xcols, 2) %} + {%- set upto = (xcols | length) %} + {%- if not add_constant %} + {%- set upto = upto + 1 %} + {%- 
endif %} + {%- for k in range(j, upto) %} + inv_i{{ i }}j{{ k }} * inv_i{{ j }}j{{ k }}{%- if not loop.last %} + {% endif -%} {%- endfor %} as inv_x{{ i }}x{{ j }} {%- if not loop.last -%} , {%- endif %} {%- endfor %} - from inverse_chol - + from _dbt_linreg_inverse_chol ), - -linreg as ( - +_dbt_linreg_final_coefs as ( select - {%- for x1 in range(exog_aliased|length) %} + {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }} + {%- for x1 in xcols %} sum(( - {%- for x2 in range(exog_aliased|length) %} + {%- for x2 in xcols %} {%- if x2 > x1 %} - x{{ x2 }} * inv_x{{ x1 }}x{{ x2 }} + b.x{{ x2 }} * inv_x{{ x1 }}x{{ x2 }} {%- else %} - x{{ x2 }} * inv_x{{ x2 }}x{{ x1 }} + b.x{{ x2 }} * inv_x{{ x2 }}x{{ x1 }} {%- endif %} {%- if not loop.last %} + {% endif -%} {%- endfor %} - ) * y) as x{{ x1 }}_coef + ) * b.y) as x{{ x1 }}_coef {%- if not loop.last -%} , {%- endif %} {%- endfor %} from - base, - inverse_xtx - + _dbt_linreg_base as b + {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_inverse_xtx') | indent(2) }} + {%- if group_by %} + group by + {{ dbt_linreg._gb_cols(group_by, prefix='b') | indent(4) }} + {%- endif %} +){%- if calculate_standard_error %}, +_dbt_linreg_resid as ( + select + {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }} + var_pop(y + {%- for x in xcols %} + - x{{ x }} * x{{ x }}_coef + {%- endfor %} + ) as resid_var, + count(*) as n + from + _dbt_linreg_base as b + {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_final_coefs') | indent(2) }} + {%- if group_by %} + group by + {{ dbt_linreg._gb_cols(group_by, prefix='b') | indent(2) }} + {%- endif %} +), +_dbt_linreg_stderrs as ( + select + {{ dbt_linreg._gb_cols(group_by, trailing_comma=True, prefix='b') | indent(4) }} + {%- for x in xcols %} + sqrt(inv_x{{ x }}x{{ x }} * resid_var * n / (n - {{ xcols | length }})) as x{{ x }}_stderr + {%- if not loop.last -%} + , + {%- endif %} + {%- endfor %} + from + _dbt_linreg_resid as b 
+ {{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_inverse_xtx') | indent(2) }} ) - -select * from linreg +{%- endif %} +{{ + dbt_linreg.final_select( + exog=exog, + exog_aliased=exog_aliased, + add_constant=add_constant, + group_by=group_by, + format=format, + format_options=format_options, + calculate_standard_error=calculate_standard_error + ) +}}) +{% endmacro %} diff --git a/macros/linear_regression/ols_impl_fwl/_ols_impl_fwl.sql b/macros/linear_regression/ols_impl_fwl/_ols_impl_fwl.sql index fd4a858..755e693 100644 --- a/macros/linear_regression/ols_impl_fwl/_ols_impl_fwl.sql +++ b/macros/linear_regression/ols_impl_fwl/_ols_impl_fwl.sql @@ -99,7 +99,8 @@ format=None, format_options=None, group_by=None, - alpha=None) -%} + alpha=None, + method_options=None) -%} {%- if (exog | length) == 0 %} {% do log('Note: exog was empty; running regression on constant term only.') %} {{ return(dbt_linreg._ols_0var( @@ -185,16 +186,14 @@ _dbt_linreg_step0 as ( {% for step in range(1, (exog | length)) %} _dbt_linreg_step{{ step }} as ( with - _coefs as ( + __dbt_linreg_coefs as ( select {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) | indent(6) }} {#- Slope terms #} - {%- for _y, _x, _o in dbt_linreg._traverse_slopes(step, exog_aliased) %} {%- set _c = dbt_linreg._orth_x_slope(_x, _o) %} {{ dbt_linreg.regress(_y, _c, add_constant=add_constant) }} as {{ _y }}_{{ _c }}_coef, {%- endfor %} - {#- Constant terms #} {%- if add_constant %} {%- for _y, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %} @@ -238,7 +237,7 @@ _dbt_linreg_step{{ step }} as ( {%- endif %} {%- endfor %} from _dbt_linreg_step0 as b - {{ dbt_linreg._join_on_groups(group_by, 'b', '_coefs') | indent(2) }} + {{ dbt_linreg._join_on_groups(group_by, 'b', '__dbt_linreg_coefs') | indent(2) }} ), {%- if loop.last %} _dbt_linreg_final_coefs as ( @@ -249,7 +248,7 @@ _dbt_linreg_final_coefs as ( {%- for _x, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %} - avg({{ 
dbt_linreg._filter_and_center_if_alpha(_x, alpha, base_prefix='b.') }}) * {{ dbt_linreg.regress('b.y', dbt_linreg._orth_x_intercept('b.' ~ _x, _o)) }} {%- endfor %} - as const_coef, + as x0_coef, {%- endif %} {%- for _x, _o in dbt_linreg._traverse_intercepts(step, exog_aliased) %} {{ dbt_linreg.regress('b.y', dbt_linreg._orth_x_intercept(_x, _o), add_constant=add_constant) }} as {{ _x }}_coef @@ -275,7 +274,8 @@ _dbt_linreg_final_coefs as ( add_constant=add_constant, group_by=group_by, format=format, - format_options=format_options + format_options=format_options, + calculate_standard_error=False ) }} ) diff --git a/macros/linear_regression/ols_impl_special/_ols_0var.sql b/macros/linear_regression/ols_impl_special/_ols_0var.sql index cf3dd38..9e1cdf0 100644 --- a/macros/linear_regression/ols_impl_special/_ols_0var.sql +++ b/macros/linear_regression/ols_impl_special/_ols_0var.sql @@ -9,7 +9,7 @@ (with _dbt_linreg_final_coefs as ( select {{ dbt_linreg._gb_cols(group_by, trailing_comma=True) }} - avg({{ endog }}) as const_coef + avg({{ endog }}) as x0_coef from {{ table }} {%- if group_by %} group by @@ -23,7 +23,8 @@ add_constant=add_constant, group_by=group_by, format=format, - format_options=format_options + format_options=format_options, + calculate_standard_error=False ) }} ) diff --git a/macros/linear_regression/ols_impl_special/_ols_1var.sql b/macros/linear_regression/ols_impl_special/_ols_1var.sql index ea79aef..a8c301d 100644 --- a/macros/linear_regression/ols_impl_special/_ols_1var.sql +++ b/macros/linear_regression/ols_impl_special/_ols_1var.sql @@ -55,7 +55,7 @@ _dbt_linreg_final_coefs as ( {%- if add_constant %} avg({{ dbt_linreg._filter_and_center_if_alpha('b.y', alpha) }}) - avg({{ dbt_linreg._filter_and_center_if_alpha('b.x1', alpha) }}) * {{ dbt_linreg.regress('b.y', 'b.x1') }} - as const_coef, + as x0_coef, {%- endif %} {{ dbt_linreg.regress('b.y', 'b.x1', add_constant=add_constant) }} as x1_coef from _dbt_linreg_base as b @@ -74,7 +74,8 @@ 
_dbt_linreg_final_coefs as ( add_constant=add_constant, group_by=group_by, format=format, - format_options=format_options + format_options=format_options, + calculate_standard_error=False ) }} ) diff --git a/macros/linear_regression/utils.sql b/macros/linear_regression/utils.sql index 82f85af..e3f66eb 100644 --- a/macros/linear_regression/utils.sql +++ b/macros/linear_regression/utils.sql @@ -38,24 +38,36 @@ add_constant=True, format=None, format_options=None, - round_=None) -%} + calculate_standard_error=False) -%} {%- if format == 'long' %} {%- if add_constant %} select - {{ dbt_linreg._unalias_gb_cols(group_by) }} + {{ dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }} '{{ format_options.get('constant_name', 'const') }}' as {{ format_options.get('variable_column_name', 'variable_name') }}, - {{ dbt_linreg._fmt_final_coef('const', format_options.get('round')) }} as {{ format_options.get('coefficient_column_name', 'coefficient') }} -from _dbt_linreg_final_coefs + {{ dbt_linreg._maybe_round('x0_coef', format_options.get('round')) }} as {{ format_options.get('coefficient_column_name', 'coefficient') }}{% if calculate_standard_error %}, + {{ dbt_linreg._maybe_round('x0_stderr', format_options.get('round')) }} as {{ format_options.get('standard_error_column_name', 'standard_error') }}, + {{ dbt_linreg._maybe_round('x0_coef/x0_stderr', format_options.get('round')) }} as {{ format_options.get('t_statistic_column_name', 't_statistic') }} + {%- endif %} +from _dbt_linreg_final_coefs as b +{%- if calculate_standard_error %} +{{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_stderrs') }} +{%- endif %} {%- if exog_aliased %} union all {%- endif %} {%- endif %} {%- for i in exog_aliased %} select - {{ dbt_linreg._unalias_gb_cols(group_by) }} - '{{ dbt_linreg._strip_quotes(exog[loop.index0], format_options) }}' as variable_name, - {{ dbt_linreg._fmt_final_coef(i, format_options.get('round')) }} as coefficient -from _dbt_linreg_final_coefs + {{ 
dbt_linreg._unalias_gb_cols(group_by, prefix='b') | indent(2) }} + '{{ dbt_linreg._strip_quotes(exog[loop.index0], format_options) }}' as {{ format_options.get('variable_column_name', 'variable_name') }}, + {{ dbt_linreg._maybe_round(i~'_coef', format_options.get('round')) }} as {{ format_options.get('coefficient_column_name', 'coefficient') }}{% if calculate_standard_error %}, + {{ dbt_linreg._maybe_round(i~'_stderr', format_options.get('round')) }} as {{ format_options.get('standard_error_column_name', 'standard_error') }}, + {{ dbt_linreg._maybe_round(i~'_coef/'~i~'_stderr', format_options.get('round')) }} as {{ format_options.get('t_statistic_column_name', 't_statistic') }} + {%- endif %} +from _dbt_linreg_final_coefs as b +{%- if calculate_standard_error %} +{{ dbt_linreg._join_on_groups(group_by, 'b', '_dbt_linreg_stderrs') }} +{%- endif %} {%- if not loop.last %} union all {%- endif %} @@ -63,14 +75,14 @@ union all {%- elif format == 'wide' %} select {%- if add_constant -%} - {{ dbt_linreg._unalias_gb_cols(group_by) }} - {{ dbt_linreg._fmt_final_coef('const', format_options.get('round')) }} as {{ dbt_linreg._format_wide_variable_column(format_options.get('constant_name', 'const'), format_options) }} + {{ dbt_linreg._unalias_gb_cols(group_by) | indent(2) }} + {{ dbt_linreg._maybe_round('x0_coef', format_options.get('round')) }} as {{ dbt_linreg._format_wide_variable_column(format_options.get('constant_name', 'const'), format_options) }} {%- if exog_aliased -%} , {%- endif -%} {%- endif -%} {%- for i in exog_aliased %} - {{ dbt_linreg._fmt_final_coef(i, format_options.get('round')) }} as {{ dbt_linreg._format_wide_variable_column(exog[loop.index0], format_options) }} + {{ dbt_linreg._maybe_round(i~'_coef', format_options.get('round')) }} as {{ dbt_linreg._format_wide_variable_column(exog[loop.index0], format_options) }} {%- if not loop.last -%} , {%- endif %} @@ -128,21 +140,25 @@ select * from _dbt_linreg_final_coefs {%- endmacro %} {# This macros reverses gb 
column aliases at the end of an OLS query. #} -{% macro _unalias_gb_cols(group_by) -%} +{% macro _unalias_gb_cols(group_by, prefix=None) -%} {%- if group_by %} {%- for gb in group_by %} +{%- if prefix %} +{{ prefix }}.gb{{ loop.index }} as {{ gb }}, +{%- else %} gb{{ loop.index }} as {{ gb }}, +{%- endif %} {%- endfor %} {%- endif %} {%- endmacro %} {# Round the final coefficient if the user specifies the `round` format option. Otherwise, keep as is. #} -{% macro _fmt_final_coef(x, round_) %} +{% macro _maybe_round(x, round_) %} {% if round_ is not none %} - {{ return('round(' ~ x ~ '_coef, ' ~ round_ ~ ')') }} + {{ return('round(' ~ x ~ ', ' ~ round_ ~ ')') }} {% else %} - {{ return(x ~ '_coef') }} + {{ return(x) }} {% endif %} {% endmacro %} diff --git a/run b/run index bcbd709..a04d663 100755 --- a/run +++ b/run @@ -2,18 +2,30 @@ set -eo pipefail + + function setup { poetry install poetry run pre-commit install } +function testloc { + # rm -f integration_tests/dbt.duckdb + export DBT_PROFILES_DIR=./integration_tests/profiles + poetry run dbt deps --project-dir ./integration_tests + # poetry run dbt compile --project-dir ./integration_tests --select tag:perftest + poetry run dbt run --project-dir ./integration_tests --select tag:perftest +} + + function test { # rm -f integration_tests/dbt.duckdb - # poetry run python scripts.py gen-test-cases - poetry run dbt deps --project-dir ./integration_tests/ --profiles-dir ./integration_tests/profiles - # poetry run dbt seed --project-dir ./integration_tests/ --profiles-dir ./integration_tests/profiles - poetry run dbt run --project-dir ./integration_tests/ --profiles-dir ./integration_tests/profiles - poetry run dbt test --project-dir ./integration_tests/ --profiles-dir ./integration_tests/profiles + export DBT_PROFILES_DIR=./integration_tests/profiles + poetry run python scripts.py gen-test-cases + poetry run dbt deps --project-dir ./integration_tests + poetry run dbt seed --project-dir ./integration_tests + poetry run 
dbt run --project-dir ./integration_tests + poetry run dbt test --project-dir ./integration_tests } function lint { diff --git a/scripts.py b/scripts.py index d7670a4..66e6675 100644 --- a/scripts.py +++ b/scripts.py @@ -1,4 +1,5 @@ import os.path as op +import warnings from typing import NamedTuple from typing import Optional from typing import Protocol @@ -10,6 +11,13 @@ from tabulate import tabulate +# Suppress iteritems warning +warnings.simplefilter("ignore", category=FutureWarning) + +# No scientific notation +np.set_printoptions(suppress=True) + + DIR = op.dirname(__file__) DEFAULT_SIZE = 10_000 @@ -205,13 +213,13 @@ def cli(): def regress(table: str, const: bool, columns: int, alpha: float, size: int, seed: int): callback = ALL_TEST_CASES[table] - click.echo(click.style("=" * 80, fg="red")) + click.echo(click.style("=" * 80, fg="blue")) click.echo( - click.style("Test case: ", fg="red", bold=True) + click.style("Test case: ", fg="blue", bold=True) + - click.style(table, fg="red") + click.style(table, fg="blue") ) - click.echo(click.style("=" * 80, fg="red")) + click.echo(click.style("=" * 80, fg="blue")) test_case = callback(size, seed) @@ -219,10 +227,10 @@ def regress(table: str, const: bool, columns: int, alpha: float, size: int, seed x_cols = test_case.x_cols else: # K plus Constant (1) - x_cols = test_case.x_cols[:columns] + x_cols = test_case.x_cols[:columns+1] - if const: - x_cols = ["const"] + x_cols + if not const: + x_cols = [i for i in x_cols if i != "const"] def _run_model(cond=None): if cond is None: @@ -230,7 +238,10 @@ def _run_model(cond=None): y = test_case.df.loc[cond, test_case.y_col] x_mat = test_case.df.loc[cond, x_cols] if alpha: - alpha_arr = [0, *([alpha] * (len(x_mat.columns) - 1))] + if const: + alpha_arr = [0, *([alpha] * (len(x_mat.columns) - 1))] + else: + alpha_arr = [alpha] * len(x_mat.columns) model = sm.OLS( y, x_mat