Skip to content

Commit

Permalink
Merge pull request #394 from smattis/v3-steve
Browse files Browse the repository at this point in the history
Updates from review
  • Loading branch information
smattis authored Jul 8, 2020
2 parents fb9aa62 + 613e250 commit 03e52c0
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 9 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ install:
- pip install .
- pip install codecov pytest-cov Sphinx sphinx_rtd_theme
- pip install git+https://github.com/CU-Denver-UQ/LUQ
- pip install mpi4py

script:
- pytest --cov=./bet/ ./test/
Expand Down
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ Another option is to clone the repository and install BET using


## Dependencies
BET is tested on Python 3.6 and 3.7 (but should work on most recent Python 3 versions) and depends on [NumPy](http://www.numpy.org/), [SciPy](http://www.scipy.org/), [matplotlib](http://matplotlib.org/), [pyDOE](https://pythonhosted.org/pyDOE/), [pytest](https://docs.pytest.org/), and [mpi4py](https://mpi4py.readthedocs.io/en/stable/) (optional) (see [requirements.txt](requirements.txt) for version information). For some optional features [LUQ](https://github.com/CU-Denver-UQ/LUQ) is also required.
BET is tested on Python 3.6 and 3.7 (but should work on most recent Python 3 versions) and depends on [NumPy](http://www.numpy.org/), [SciPy](http://www.scipy.org/), [matplotlib](http://matplotlib.org/), [pyDOE](https://pythonhosted.org/pyDOE/), [pytest](https://docs.pytest.org/), and [mpi4py](https://mpi4py.readthedocs.io/en/stable/) (optional) (see [requirements.txt](requirements.txt) for version information). For some optional features [LUQ](https://github.com/CU-Denver-UQ/LUQ) is also required. mpi4py is required to take advantage of parallel features and requires an mpi implementation. It can be installed by:

pip install mpi4py


## License
[GNU Lesser General Public License (LGPL)](LICENSE.txt)
Expand Down
33 changes: 26 additions & 7 deletions bet/sampling/useLUQ.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class useLUQ:
from LUQ output.
"""

def __init__(self, predict_set, obs_set, lb_model, times):
def __init__(self, predict_set, lb_model, times, obs_set=None):
"""
Initialize the object.
:param predict_set: Sample set defining input prediction samples.
Expand Down Expand Up @@ -80,7 +80,16 @@ def get_obs(self):
"""
self.obs_time_series = self.lb_model(self.obs_set.get_values(), self.times)

def initialize(self, predicted_time_series, obs_time_series, times):
def set_observed_time_series(self, obs_time_series):
"""
Set observed time series data manually.
:param obs_time_series: time series data
:type obs_time_series:
:return: :class:`numpy.ndarray` with shape (num_obs, num_times)
"""
self.obs_time_series = obs_time_series

def initialize(self, predicted_time_series=None, obs_time_series=None, times=None):
"""
Initialize the LUQ object. This can be used manually if time series are pre-computed.
Expand All @@ -96,6 +105,13 @@ def initialize(self, predicted_time_series, obs_time_series, times):
except ImportError:
raise missing_module("luq cannot be imported")

if predicted_time_series is None:
predicted_time_series = self.predicted_time_series
if obs_time_series is None:
obs_time_series = self.obs_time_series
if times is None:
times = self.times

self.learn = LUQ(predicted_time_series, obs_time_series, times)

def setup(self):
Expand Down Expand Up @@ -128,7 +144,7 @@ def make_disc(self):
"""
Construct `bet.sample.discretization` objects for predict and obs sets.
:return: predict_disc, obs_disc
:rtype: `bet.sample.discretization`, `bet.sample.discretization`
:rtype: `bet.sample.discretization`, `bet.sample.discretization` or None if no observation set.
"""
out_dim = self.learn.num_pcs[0]

Expand All @@ -144,12 +160,15 @@ def make_disc(self):
disc1 = sample.discretization(input_sample_set=self.predict_set,
output_sample_set=predict_output,
output_observed_set=obs_output)
disc1.local_to_global()

# Observation discretization
disc2 = sample.discretization(input_sample_set=self.obs_set,
output_sample_set=obs_output)
disc1.local_to_global()
disc2.local_to_global()
if self.obs_set is None:
disc2 = None
else:
disc2 = sample.discretization(input_sample_set=self.obs_set,
output_sample_set=obs_output)
disc2.local_to_global()

return disc1, disc2

Expand Down
1 change: 1 addition & 0 deletions doc/examples_overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ Examples
Examples
------------
Documented examples can be found `here <https://github.com/UT-CHG/BET/tree/master/examples>`_.
Jupyter notebooks of examples can be found `here <https://github.com/CU-Denver-UQ/BET-notebooks>`_.
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@ scipy>=1.3.1
matplotlib>=3.0
pyDOE
pytest
mpi4py
6 changes: 6 additions & 0 deletions test/test_postProcess/test_plotP.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ def test_plot_marginals_1D(self):
plotP.plot_1D_marginal_probs(marginals, bins, self.samples,
filename="file", interactive=False)
go = True
if os.path.exists("file_1D_0.png") and comm.rank == 0:
os.remove("file_1D_0.png")
if os.path.exists("file_1D_1.png") and comm.rank == 0:
os.remove("file_1D_1.png")
except (RuntimeError, TypeError, NameError):
go = False
nptest.assert_equal(go, True)
Expand All @@ -222,6 +226,8 @@ def test_plot_marginals_2D(self):
go = True
if os.path.exists("file_2D_0_1.png") and comm.rank == 0:
os.remove("file_2D_0_1.png")
if os.path.exists("file_surf_0_1.png") and comm.rank == 0:
os.remove("file_surf_0_1.png")
except (RuntimeError, TypeError, NameError):
go = False
nptest.assert_equal(go, True)
Expand Down

0 comments on commit 03e52c0

Please sign in to comment.