@junpenglao I'm taking a look at implementing the `run_inference_loop` from here. I'm running into a potential issue: it seems that some inference algorithms require more than `rng_key` and `state` as inputs to their step function. Take, for example, `sgld`, which requires a minibatch of data and a step size at each call to its `.step`.
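For concreteness, the mismatch looks roughly like this (a sketch based on the description above; the exact `sgld` signature is assumed here, not checked against the current API):

```python
# What run_inference_loop expects a step to look like:
new_state, info = inference_algorithm.step(rng_key, state)

# What sgld needs at each call: a minibatch and a step size are extra inputs
# (assumed signature, sketched from the description above).
new_state = sgld.step(rng_key, state, minibatch, step_size)
```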
I suspect this will also be the case for the variational inference algorithms once they are in a more final state. `run_inference_loop` cannot currently handle such cases.
Should I just leave these particular examples alone, and use `run_inference_loop` wherever I can?
One potential solution that allows batches to be passed in during `step` is to modify `run_inference_loop` like so:
```python
from __future__ import annotations  # lets the State/Info annotations stay lazy

import jax
from jax import lax
from jax.random import split


def run_inference_algorithm(
    rng_key,
    initial_state_or_position,
    inference_algorithm,
    batches,
    num_steps,
) -> tuple[State, State, Info]:
    try:
        initial_state = inference_algorithm.init(initial_state_or_position)
    except TypeError:
        # We assume initial_state is already in the right format.
        initial_state = initial_state_or_position

    keys = split(rng_key, num_steps)

    @jax.jit
    def one_step(state, rng_key):
        # Note: lax.scan traces one_step a single time, so next(batches) is
        # evaluated at trace time rather than once per step.
        batch = next(batches)
        state, info = inference_algorithm.step(rng_key, state, batch)
        return state, (state, info)

    final_state, (state_history, info_history) = lax.scan(one_step, initial_state, keys)
    return final_state, state_history, info_history
```
Where `batches` is any iterator (possibly a generator) over batches of data examples. However, if `batches` is a generator that uses any `jax` operations, I have run into issues with `scan` (I'm not exactly sure of the reason), but if `batches` is a generator that uses, say, `numpy`, then it does work.
An example of a numpy data generator:
```python
import numpy as np


def data_stream(seed, data, batch_size, data_size):
    """Return an iterator over batches of data."""
    rng = np.random.RandomState(seed)
    num_batches = int(np.ceil(data_size / batch_size))
    while True:
        # Reshuffle once per epoch, then yield consecutive minibatches.
        perm = rng.permutation(data_size)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield data[batch_idx]


batches = data_stream(...)
```
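Wiring the two together would then look something like the following (illustrative only; `X`, `initial_position`, and `inference_algorithm` are hypothetical placeholders):

```python
import jax
import numpy as np

# Hypothetical toy dataset and setup; only the wiring is the point here.
X = np.random.randn(10_000, 2)
batches = data_stream(seed=0, data=X, batch_size=128, data_size=len(X))

rng_key = jax.random.PRNGKey(0)
final_state, state_history, info_history = run_inference_algorithm(
    rng_key, initial_position, inference_algorithm, batches, num_steps=1_000
)
```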
This also works with, say, a Hugging Face `datasets` data loader. Something like the sketch below, which assumes the `datasets` library and a hypothetical `"image"` column:
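```python
# Hedged sketch (not from the original issue): assumes the Hugging Face
# `datasets` library, where Dataset.iter yields batches as dicts of numpy
# arrays once the format is set to "numpy". The "image" column is illustrative.
from datasets import load_dataset

dataset = load_dataset("mnist", split="train").with_format("numpy")


def hf_data_stream(dataset, batch_size):
    """Yield an endless stream of shuffled minibatches from a HF dataset."""
    while True:
        for batch in dataset.shuffle().iter(batch_size=batch_size):
            yield batch["image"]


batches = hf_data_stream(dataset, batch_size=128)
```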
I'm not sure this would be the preferred solution. In any case, I'll think about it some more.
Thanks!