Skip to content

Commit

Permalink
fixed all TODO LG
Browse files Browse the repository at this point in the history
  • Loading branch information
lcgraham committed May 17, 2016
1 parent dd2450c commit dd9d6c7
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ You will need to run sphinx-apidoc AND reinstall BET anytime a new module or met

Useful scripts are contained in ``examples/``

Tests
-----

To run tests in serial call::

nosetests tests

To run tests in parallel call::

mpirun -np NPROC nosetests tests

Dependencies
------------

Expand Down
10 changes: 6 additions & 4 deletions bet/sampling/adaptiveSampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ def generalized_chains(self, input_obj, t_set, kern,
hot_start=0):
"""
Basic adaptive sampling algorithm using generalized chains.
.. todo::
Test HOTSTART from parallel files using different and same num proc
:param string initial_sample_type: type of initial sample random (or r),
latin hypercube(lhs), or space-filling curve(TBD)
Expand Down Expand Up @@ -331,7 +335,7 @@ def generalized_chains(self, input_obj, t_set, kern,
# be the one with the matching processor number (doesn't
# really matter)
mdat = sio.loadmat(mdat_files[comm.rank])
disc = sample.load_discretization(savefile)
disc = sample.load_discretization(mdat_files[comm.rank])
kern_old = np.squeeze(mdat['kern_old'])
all_step_ratios = np.squeeze(mdat['step_ratios'])
elif hot_start == 1 and len(mdat_files) != comm.size:
Expand Down Expand Up @@ -388,7 +392,6 @@ def generalized_chains(self, input_obj, t_set, kern,
kern_old = np.squeeze(mdat['kern_old'])
all_step_ratios = np.squeeze(mdat['step_ratios'])
chain_length = disc.check_nums()/self.num_chains
#mdat_files = []
# reshape if parallel
if comm.size > 1:
temp_input = np.reshape(disc._input_sample_set.\
Expand All @@ -397,7 +400,6 @@ def generalized_chains(self, input_obj, t_set, kern,
temp_output = np.reshape(disc._output_sample_set.\
get_values(), (self.num_chains, chain_length,
-1), 'F')

all_step_ratios = np.reshape(all_step_ratios,
(self.num_chains, chain_length), 'F')
# SPLIT DATA IF NECESSARY
Expand Down Expand Up @@ -427,7 +429,7 @@ def generalized_chains(self, input_obj, t_set, kern,
get_values_local()[-self.num_chains_pproc:, :])

# Determine how many batches have been run
start_ind = disc.check_nums()/self.num_chains_pproc
start_ind = disc._input_sample_set.get_values_local().shape[0]/self.num_chains_pproc

mdat = dict()
self.update_mdict(mdat)
Expand Down
3 changes: 2 additions & 1 deletion test/test_sampling/test_adaptiveSampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def ifun(outputs):
assert np.all(all_step_ratios <= t_set.max_ratio)

# did the savefiles get created? (proper number, contain proper keys)
comm.barrier()
mdat = dict()
if comm.rank == 0:
mdat = sio.loadmat(savefile)
Expand Down Expand Up @@ -229,6 +230,7 @@ def map_10t4(x):


def tearDown(self):
comm.barrier()
for f in self.savefiles:
if comm.rank == 0 and os.path.exists(f+".mat"):
os.remove(f+".mat")
Expand Down Expand Up @@ -424,7 +426,6 @@ def ifun(outputs):
assert asr > t_set.min_ratio
assert asr < t_set.max_ratio

#TODO: LG Fix
def test_generalized_chains(self):
"""
Test :meth:`bet.sampling.adaptiveSampling.sampler.generalized_chains`
Expand Down

0 comments on commit dd9d6c7

Please sign in to comment.