Skip to content

Commit

Permalink
Merge pull request #108 from HIIT/wrapper_fix
Browse files Browse the repository at this point in the history
Wrapper fix
  • Loading branch information
vuolleko authored Jan 23, 2017
2 parents ad7190a + e223a3d commit 6db3dfb
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 13 deletions.
1 change: 1 addition & 0 deletions elfi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from elfi.storage import *
from elfi.visualization import *
from elfi.inference_task import InferenceTask
from elfi.wrapper import *
from elfi.env import client, inference_task, new_inference_task
from elfi import tools

Expand Down
1 change: 1 addition & 0 deletions elfi/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from elfi.bo.acquisition import LCBAcquisition, SecondDerivativeNoiseMixin, RbfAtPendingPointsMixin

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

"""Implementations of some ABC algorithms.
Expand Down
25 changes: 14 additions & 11 deletions elfi/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
class Wrapper():
""" Wraps an external command to work as a callable operation for a node.
Currently only supports sequential operations (not vectorized).
Currently only supports sequential operations (not vectorized). This should be
enforced in ELFI by setting batch_size=1 in the ABC method.
Parameters
----------
Expand All @@ -26,28 +27,31 @@ def __init__(self, command_template="", post=None, pre=None):

@staticmethod
def process_elfi_internals(command_template, args, kwargs):
""" Replace 'prng' in kwargs with a seed from the generator if present in template """
""" Replace 'random_state' in kwargs with a seed from the generator if present in template """
proc_args = list()
for a in args:
if isinstance(a, np.ndarray):
if a.shape == (1,):
if a.shape == (1, 1):
# take single values out of array
proc_args.append(a[0])
proc_args.append(a.item())
else:
raise NotImplementedError("Wrapper does not yet support array arguments")
else:
proc_args.append(a)
if "prng" in kwargs.keys():
if "random_state" in kwargs.keys():
if "{seed}" in command_template:
if isinstance(kwargs["prng"], np.random.RandomState):
kwargs["seed"] = str(kwargs["prng"].randint(np.iinfo(np.uint32).max))
del kwargs["prng"]
if isinstance(kwargs["random_state"], np.random.RandomState):
kwargs["seed"] = str(kwargs["random_state"].randint(np.iinfo(np.uint32).max))
del kwargs["random_state"]
return command_template, proc_args, kwargs

@staticmethod
def read_nparray(stdout):
""" Interpret the stdout as a space-separated numpy array """
return np.fromstring(stdout, sep=" ")
""" Interpret the stdout as a space-separated numpy array.
"""
arr = np.fromstring(stdout, sep=" ")
arr = arr[None, :] # for compatibility with core
return arr

def __call__(self, *args, **kwargs):
""" Executes the wrapped command, with additional arguments and keyword arguments.
Expand All @@ -65,4 +69,3 @@ def __call__(self, *args, **kwargs):
argv = command.split(" ")
stdout = check_output(argv, universal_newlines=True)
return self.post(stdout)

4 changes: 2 additions & 2 deletions tests/unit/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def test_echo_non_string_args(self):
ret = wrapper(1)
assert ret == 1

def test_echo_1d_array_args(self):
def test_echo_2d_array_args(self):
command = "echo {0}"
wrapper = Wrapper(command, post=int)
ret = wrapper(np.array([1]))
ret = wrapper(np.array([[1]]))
assert ret == 1

0 comments on commit 6db3dfb

Please sign in to comment.