easy conversion to draws from rstantools format #251

Open
wds15 opened this issue Aug 4, 2022 · 2 comments
Labels
documentation Improvements or additions to documentation feature New feature or request interface Interface improvements or changes

Comments

@wds15
wds15 commented Aug 4, 2022

I am struggling to correctly convert the posteriors returned by rstantools-style functions such as posterior_linpred into a draws object. The problem is that the chain information gets dropped. Here is an example illustrating what I'd like to have:

library(posterior)
#> Warning: package 'posterior' was built under R version 4.1.2
#> This is posterior version 1.2.2
#> 
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#> 
#>     mad, sd, var


samp <- as_draws_matrix(example_draws())

## posterior_* functions from rstantools return matrices like

rstantools_samp <- matrix(as.matrix(samp), niterations(samp)*nchains(samp), nvariables(samp))
colnames(rstantools_samp) <- variables(samp)
head(rstantools_samp)
#>            mu      tau   theta[1]    theta[2]   theta[3] theta[4]    theta[5]
#> [1,] 2.005831 2.767367  3.9617520  0.27123540 -0.7431706 2.104805  0.92348879
#> [2,] 1.458316 6.979976  0.1237101 -0.06901539  0.9518270 7.281225 -0.06195211
#> [3,] 5.814947 9.677075 21.2510465 14.93055775  1.8290945 1.381443  0.53106337
#> [4,] 6.849586 4.788366 14.6996540  8.58618604  2.6749150 4.393232  4.75807198
#> [5,] 1.805168 2.848165  5.9600546  1.15573721  3.1088628 1.994890  0.76885094
#> [6,] 3.841243 4.083357  5.7601096  9.90920447 -0.9956266 5.328625  5.88894271
#>       theta[6]  theta[7]  theta[8]
#> [1,]  1.650237  3.320019  4.848542
#> [2,] 11.257502  9.621128 -8.640446
#> [3,]  7.155371 14.802013 -1.736363
#> [4,]  8.101547  9.491277  5.281551
#> [5,]  4.656270  1.208251 -4.540236
#> [6,] -1.701463  2.780403  7.075855

dim(rstantools_samp)
#> [1] 400  10

## things are ordered by chain, so we have

all(rstantools_samp[1:100,1] == subset_draws(samp, variable="mu", chain=1))
#> [1] TRUE
all(rstantools_samp[101:200,1] == subset_draws(samp, variable="mu", chain=2))
#> [1] TRUE

## now posterior should offer a function that creates, from
## rstantools_samp, a draws object which knows the number of
## chains. This does not work:

as_draws_matrix(rstantools_samp, .nchains=4)
#> # A draws_matrix: 400 iterations, 1 chains, and 10 variables
#>     variable
#> draw   mu tau theta[1] theta[2] theta[3] theta[4] theta[5] theta[6]
#>   1  2.01 2.8     3.96    0.271    -0.74      2.1    0.923      1.7
#>   2  1.46 7.0     0.12   -0.069     0.95      7.3   -0.062     11.3
#>   3  5.81 9.7    21.25   14.931     1.83      1.4    0.531      7.2
#>   4  6.85 4.8    14.70    8.586     2.67      4.4    4.758      8.1
#>   5  1.81 2.8     5.96    1.156     3.11      2.0    0.769      4.7
#>   6  3.84 4.1     5.76    9.909    -1.00      5.3    5.889     -1.7
#>   7  5.47 4.0     4.03    4.151    10.15      6.6    3.741     -2.2
#>   8  1.20 1.5    -0.28    1.846     0.47      4.3    1.467      3.3
#>   9  0.15 3.9     1.81    0.661     0.86      4.5   -1.025      1.1
#>   10 7.17 1.8     6.08    8.102     7.68      5.6    7.106      8.5
#> # ... with 390 more draws, and 2 more variables

Created on 2022-08-04 by the reprex package (v2.0.1)

What I can do is crudely set the nchains attribute to the number of chains. So I think the call above should just work and give me a draws object with 4 chains. This obviously requires documenting that the input draws must be sorted by chain (column-major order).
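The crude workaround mentioned above can be sketched as follows. This is a minimal sketch, assuming that a draws_matrix keeps its chain count in an attribute named "nchains" (as the comment above implies) and that the rows are ordered chain by chain. Overwriting an attribute by hand bypasses any validation the package performs, so it is a stopgap rather than a recommended interface.

```r
library(posterior)

# Rebuild the flat, chain-ordered matrix from the example above.
samp <- as_draws_matrix(example_draws())
flat <- matrix(as.matrix(samp), niterations(samp) * nchains(samp), nvariables(samp))
colnames(flat) <- variables(samp)

# as_draws_matrix() treats the plain matrix as a single chain ...
d <- as_draws_matrix(flat)

# ... so the chain count is overwritten by hand.
# Assumption: the attribute is called "nchains"; rows must be sorted by chain.
attr(d, "nchains") <- 4L

nchains(d)
niterations(d)
```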

@mjskay
Collaborator

mjskay commented Aug 4, 2022

One option is to ingest posterior_...() function output via rvar(), as the internal format of rvar() is (by design) the same as the output of those functions. rvar() lets you set the number of chains:

library(posterior)
library(rstanarm)

mtcars_subset = mtcars[, c("hp", "cyl", "mpg")]

m = stan_glm(mpg ~ hp*cyl, data = mtcars_subset, chains = 4)

epred = rvar(posterior_epred(m), nchains = 4)
epred
#> rvar<1000,4>[32] mean ± sd:
#> 
#>           Mazda RX4       Mazda RX4 Wag          Datsun 710      Hornet 4 Drive 
#>          20 ± 0.79           20 ± 0.79           26 ± 0.93           20 ± 0.79  
#>   Hornet Sportabout             Valiant          Duster 360           Merc 240D 
#>          16 ± 0.88           21 ± 0.80           15 ± 1.01           28 ± 1.18  
#>            Merc 230            Merc 280           Merc 280C          Merc 450SE 
#>          26 ± 0.94           20 ± 0.81           20 ± 0.81           15 ± 0.85  
#>          Merc 450SL         Merc 450SLC  Cadillac Fleetwood Lincoln Continental 
#>          15 ± 0.85           15 ± 0.85           15 ± 0.79           15 ± 0.81  
#>   Chrysler Imperial            Fiat 128         Honda Civic      Toyota Corolla 
#>          15 ± 0.89           28 ± 1.10           29 ± 1.41           28 ± 1.12  
#>       Toyota Corona    Dodge Challenger         AMC Javelin          Camaro Z28 
#>          26 ± 0.97           16 ± 1.09           16 ± 1.09           15 ± 1.01  
#>    Pontiac Firebird           Fiat X1-9       Porsche 914-2        Lotus Europa 
#>          16 ± 0.88           28 ± 1.10           26 ± 0.91           24 ± 1.25  
#>      Ford Pantera L        Ferrari Dino       Maserati Bora          Volvo 142E 
#>          14 ± 1.21           18 ± 1.44           13 ± 2.11           25 ± 1.17

This can be especially useful for the posterior_...() functions, since you can put the resulting rvars in a data frame alongside the data used to make the predictions:

cbind(mtcars_subset, epred = epred)
#>                      hp cyl  mpg                epred
#> Mazda RX4           110   6 21.0 20.48962 ± 0.7908150
#> Mazda RX4 Wag       110   6 21.0 20.48962 ± 0.7908150
#> Datsun 710           93   4 22.8 25.80881 ± 0.9256434
#> Hornet 4 Drive      110   6 21.4 20.48962 ± 0.7908150
#> Hornet Sportabout   175   8 18.7 15.51820 ± 0.8760158
#> Valiant             105   6 18.1 20.70694 ± 0.8036316
#> Duster 360          245   8 14.3 14.55358 ± 1.0112446
#> Merc 240D            62   4 24.4 28.07635 ± 1.1825719
#> Merc 230             95   4 22.8 25.66252 ± 0.9437801
#> Merc 280            123   6 19.2 19.92459 ± 0.8124373
#> Merc 280C           123   6 17.8 19.92459 ± 0.8124373
#> Merc 450SE          180   8 16.4 15.44930 ± 0.8459401
#> Merc 450SL          180   8 17.3 15.44930 ± 0.8459401
#> Merc 450SLC         180   8 15.2 15.44930 ± 0.8459401
#> Cadillac Fleetwood  205   8 10.4 15.10479 ± 0.7862891
#> Lincoln Continental 215   8 10.4 14.96699 ± 0.8091412
#> Chrysler Imperial   230   8 14.7 14.76028 ± 0.8889254
#> Fiat 128             66   4 32.4 27.78376 ± 1.1026873
#> Honda Civic          52   4 30.4 28.80781 ± 1.4145373
#> Toyota Corolla       65   4 33.9 27.85691 ± 1.1217972
#> Toyota Corona        97   4 21.5 25.51622 ± 0.9659045
#> Dodge Challenger    150   8 15.5 15.86271 ± 1.0899219
#> AMC Javelin         150   8 15.2 15.86271 ± 1.0899219
#> Camaro Z28          245   8 13.3 14.55358 ± 1.0112446
#> Pontiac Firebird    175   8 19.2 15.51820 ± 0.8760158
#> Fiat X1-9            66   4 27.3 27.78376 ± 1.1026873
#> Porsche 914-2        91   4 26.0 25.95510 ± 0.9117325
#> Lotus Europa        113   4 30.4 24.34588 ± 1.2535618
#> Ford Pantera L      264   8 15.8 14.29175 ± 1.2067307
#> Ferrari Dino        175   6 19.7 17.66450 ± 1.4376777
#> Maserati Bora       335   8 15.0 13.31335 ± 2.1102372
#> Volvo 142E          109   4 21.4 24.63846 ± 1.1669329

And if you do want it as a draws_matrix, you can use as_draws_matrix():

as_draws_matrix(epred)
#> # A draws_matrix: 1000 iterations, 4 chains, and 32 variables
#>     variable
#> draw x[Mazda RX4] x[Mazda RX4 Wag] x[Datsun 710] x[Hornet 4 Drive]
#>   1            21               21            26                21
#>   2            21               21            25                21
#>   3            18               18            25                18
#>   4            21               21            24                21
#>   5            21               21            25                21
#>   6            21               21            24                21
#>   7            23               23            28                23
#>   8            21               21            26                21
#>   9            21               21            28                21
#>   10           22               22            25                22
#>     variable
#> draw x[Hornet Sportabout] x[Valiant] x[Duster 360] x[Merc 240D]
#>   1                    15         21            13           28
#>   2                    15         21            13           28
#>   3                    14         19            15           27
#>   4                    17         21            17           26
#>   5                    17         21            16           27
#>   6                    17         21            17           26
#>   7                    15         23            12           30
#>   8                    16         21            16           28
#>   9                    15         22            13           30
#>   10                   18         22            17           27
#> # ... with 3990 more draws, and 24 more variables

That said, since draws_matrix() has an .nchains argument, perhaps as_draws_matrix() should have one too?

@wds15
Author

wds15 commented Aug 5, 2022

Hi!

Indeed, as_draws_matrix(rvar(rstantools_samp, nchains=4)) gives me what I want for the example I quoted. Maybe all of the as_draws_* functions should have a .nchains argument? Certainly as_draws_matrix needs it, and the format posterior expects the input to be in (column-major, i.e. sorted by chain) needs to be documented.
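Until such an argument exists, the rvar() route can be wrapped in a small helper. The sketch below is only an illustration: draws_matrix_with_chains is a hypothetical name, not a posterior function, and it assumes the rows of the input matrix are ordered chain by chain, as in the example from the first comment.

```r
library(posterior)

# Hypothetical helper (not part of posterior): turn a draws-by-variables
# matrix, with rows ordered chain by chain, into a multi-chain draws_matrix.
draws_matrix_with_chains <- function(x, nchains) {
  stopifnot(is.matrix(x), nrow(x) %% nchains == 0)
  as_draws_matrix(rvar(x, nchains = nchains))
}

# Same flat matrix as in the first comment.
samp <- as_draws_matrix(example_draws())
flat <- matrix(as.matrix(samp), niterations(samp) * nchains(samp), nvariables(samp))
colnames(flat) <- variables(samp)

d <- draws_matrix_with_chains(flat, nchains = 4)
nchains(d)
niterations(d)
```

Because the conversion goes through an rvar, the variable names in the result may come out wrapped in the same way as the x[...] names in the as_draws_matrix(epred) output above.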

Thanks!

@paul-buerkner paul-buerkner added documentation Improvements or additions to documentation feature New feature or request interface Interface improvements or changes labels Nov 29, 2022