-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
130 lines (113 loc) · 4.87 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import re
import os
import logging
import numpy as np
import arviz as az
import scipy.stats as stats
from cmdstanpy import CmdStanModel
from pathlib import Path
from shutil import copy, move
from datetime import datetime
class stan_model:
    """A thin wrapper around CmdStanModel to compile, sample, diagnose, and save.

    Intended call order: ``compile()`` -> ``sample()`` -> ``diagnose()`` ->
    ``save()``. While the instance is active, all records of the
    ``"cmdstanpy"`` logger (DEBUG and up) go to ``<stan_dir>/<model_name>.log``.

    Attributes set along the way:
        stan_file, model_name, stan_dir, log_file, logger, handler -- __init__
        model                                                      -- compile()
        data, fit, df                                              -- sample()
        diagnosis, az_data, loo, waic                              -- diagnose()
    """

    # Matches a whole line ``namespace <old>_model_namespace {``; group 1 is
    # the old model name (one or more non-whitespace characters).
    _NAMESPACE_RE = re.compile(r'^namespace (\S+)_model_namespace {$')

    def __init__(self, stan_file):
        """Create the wrapper and start a fresh debug log next to the model.

        Args:
            stan_file: Path (``pathlib.Path`` or ``str``) of the ``.stan`` file.
        """
        # Normalise to Path so .stem / .parent work for plain strings too.
        self.stan_file = Path(stan_file)
        self.model_name = self.stan_file.stem
        self.stan_dir = self.stan_file.parent
        self.log_file = self.stan_dir / f'{self.model_name}.log'
        self.logger, self.handler = self._setup_logger()

    def _setup_logger(self):
        """Route every ``cmdstanpy`` log record to ``self.log_file``.

        Existing handlers on the ``cmdstanpy`` logger are closed and removed,
        and any previous log file at the same path is deleted, so each
        instance starts with an empty log.

        Returns:
            tuple: ``(logger, handler)`` -- the ``cmdstanpy`` logger and the
            DEBUG-level ``FileHandler`` attached to it.
        """
        logger = logging.getLogger("cmdstanpy")
        # Close old handlers instead of merely dropping them: a FileHandler
        # left by an earlier instance would otherwise keep its file open
        # (leaked descriptor; on Windows it also blocks the os.remove below).
        for old_handler in list(logger.handlers):
            logger.removeHandler(old_handler)
            old_handler.close()
        if os.path.exists(self.log_file):
            os.remove(self.log_file)
        logger.setLevel(logging.DEBUG)
        handler = logging.FileHandler(self.log_file)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(
            logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                "%H:%M:%S",
            )
        )
        logger.addHandler(handler)
        return logger, handler

    def change_namespace(self, user_header, name):
        """Rename ``<old>_model_namespace`` to ``<name>_model_namespace`` in a file.

        The file is rewritten in place as UTF-8. Every line that is exactly
        ``namespace <old>_model_namespace {`` is replaced; all other lines are
        written back unchanged.

        Args:
            user_header: Path of the C++ file to edit.
            name: Model name to embed in the new namespace.
        """
        with open(user_header, 'r', encoding='utf-8') as file:
            lines = file.readlines()
        replacement = f'namespace {name}_model_namespace {{\n'
        with open(user_header, 'w', encoding='utf-8') as file:
            file.writelines(
                replacement if self._NAMESPACE_RE.match(line) else line
                for line in lines
            )

    def compile(self, user_header=None, stanc_options=None, cpp_options=None, **kwargs):
        """Compile the Stan model (always forcing a rebuild) into ``self.model``.

        When ``user_header`` is given, every ``*.hpp`` file found recursively
        under ``<directory of user_header>/bnb`` first has its
        ``*_model_namespace`` line renamed to this model's name (presumably so
        those helper headers match the namespace generated for this model).

        Args:
            user_header: Optional path (``str`` or ``Path``) to a C++ user header.
            stanc_options: Passed through to ``CmdStanModel``.
            cpp_options: Passed through to ``CmdStanModel``.
            **kwargs: Extra keyword arguments for ``CmdStanModel``.
        """
        # Previously this dereferenced user_header.parent unconditionally and
        # crashed when called with the default user_header=None.
        if user_header is not None:
            bnb_hpp_dir = Path(user_header).parent / 'bnb'
            for file in bnb_hpp_dir.rglob('*.hpp'):
                self.change_namespace(file, self.model_name)
        print(f"Start compiling the model {self.stan_file}")
        self.model = CmdStanModel(stan_file=self.stan_file,
                                  force_compile=True,
                                  user_header=user_header,
                                  stanc_options=stanc_options,
                                  cpp_options=cpp_options,
                                  **kwargs)

    def sample(self, data, **kwargs):
        """Sample from the compiled model and log CmdStan's console output.

        Stores ``data``, the fit (``self.model.sample(data=data, **kwargs)``)
        and its draws as a DataFrame (``fit.draws_pd()``), then appends the
        contents of every stdout file in the fit's runset to the debug log.
        """
        self.data = data
        self.fit = self.model.sample(data=data, **kwargs)
        self.df = self.fit.draws_pd()
        separator = "=" * 100
        self.logger.debug(separator)
        self.logger.debug("=" * 40 + " Sampling completed " + "=" * 40)
        self.logger.debug(separator)
        # NOTE: index is 1-based (start=1) even though the message renders it
        # as a list subscript.
        for index, file_path in enumerate(self.fit.runset.stdout_files, start=1):
            with open(file_path, 'r', encoding='utf-8') as out_file:
                content = out_file.read()
            self.logger.debug('Runset Stdout Files (self.fit.runset.stdout_files[%d]): %s', index, content)

    def diagnose(self):
        """Run CmdStan diagnostics and compute LOO / WAIC with ArviZ.

        Requires that the model output contains a ``log_lik`` variable (passed
        as ArviZ's ``log_likelihood``) and that ``self.data`` has a ``"y"``
        entry (used as observed data). Results are printed and logged.
        """
        self.diagnosis = self.fit.diagnose()
        print(self.diagnosis)
        self.az_data = az.from_cmdstanpy(
            posterior=self.fit,
            posterior_predictive=None,
            log_likelihood="log_lik",
            observed_data={"y": self.data["y"]},
            save_warmup=False,
        )
        self.loo = az.loo(self.az_data, pointwise=True, scale="log")
        self.waic = az.waic(self.az_data, pointwise=True, scale="log")
        print(self.loo)
        print(self.waic)
        self.logger.debug("DataFrame (self.df):\n%s", self.df)
        self.logger.debug("LOO (self.loo):\n%s", self.loo)
        self.logger.debug("WAIC (self.waic):\n%s", self.waic)
        self.logger.debug("Diagnosis (self.diagnosis):\n%s", self.diagnosis)

    def save(self):
        """Archive the run into a new timestamped directory next to the model.

        Detaches and closes this instance's log handler, then writes the
        CmdStan CSV files, a copy of the ``.stan`` file, and the log file
        (moved, not copied) into
        ``<stan_dir>/<model_name>_<YYYYmmddHHMMSS>``.
        """
        # Closing our own FileHandler is enough to flush and release the log
        # file before moving it; a process-wide logging.shutdown() would also
        # close handlers that belong to unrelated loggers.
        self.logger.removeHandler(self.handler)
        self.handler.close()
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        stan_save_dir = os.path.join(self.stan_dir, f"{self.model_name}_{timestamp}")
        os.makedirs(stan_save_dir, exist_ok=True)
        self.fit.save_csvfiles(dir=stan_save_dir)
        copy(self.stan_file, stan_save_dir)
        move(self.log_file, stan_save_dir)
def beta_neg_binomial_rng(r, alpha, beta, y_max, size=1, seed=None):
    """Draw beta-negative-binomial samples, capped at ``y_max``.

    Draws are generated hierarchically: ``p ~ Beta(alpha, beta)``, then
    ``y ~ scipy.stats.nbinom(n=r, p=p)`` (SciPy counts failures before the
    ``n``-th success, with success probability ``p``). Values above ``y_max``
    are replaced by ``y_max`` via ``np.minimum``.

    Args:
        r: Number of successes (SciPy's ``n``).
        alpha: First shape parameter of the Beta distribution of ``p``.
        beta: Second shape parameter of the Beta distribution of ``p``.
        y_max: Upper cap applied elementwise.
        size: Number (or shape) of ``p`` draws; the result broadcasts
            ``p`` with ``r``.
        seed: Optional seed. When given, a private ``RandomState`` seeded with
            it is used, which yields the same draws as ``np.random.seed(seed)``
            would, but without reseeding NumPy's global generator. When
            ``None``, NumPy's global generator is used.

    Returns:
        Integer draws (a NumPy array for array-valued ``size``).
    """
    # A local RandomState keeps seeded calls reproducible without the side
    # effect of resetting the process-wide NumPy random state.
    random_state = np.random.RandomState(seed) if seed is not None else None
    p = stats.beta.rvs(alpha, beta, size=size, random_state=random_state)
    y = stats.nbinom.rvs(n=r, p=p, random_state=random_state)
    return np.minimum(y, y_max)