
Commit bf70b5a

Code update
PiperOrigin-RevId: 600586503
zhong1wan authored and The swirl_dynamics Authors committed Jan 22, 2024
1 parent 9c958e7 commit bf70b5a
Showing 29 changed files with 58 additions and 58 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -273,7 +273,7 @@ generated-members=
 # Maximum number of characters on a single line.
 max-line-length=80
 
-# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt
+# TODO: Direct pylint to exempt
 # lines made too long by directives to pytype.
 
 # Regexp for a line that is allowed to be longer than the limit.
8 changes: 4 additions & 4 deletions swirl_dynamics/lib/diffusion/vivit.py
@@ -373,7 +373,7 @@ def __call__(self, inputs: Array, *, train: bool) -> Array:
         name='conv_transpose_temporal_decoder',
     )(x)
 
-    # TODO(lzepedanunez): Use unets.depth_to_space here instead.
+    # TODO: Use unets.depth_to_space here instead.
     x = jnp.reshape(
         x, (batch_size, *self.encoded_shapes, t, h, w, self.features_out)
     )
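The TODO above refers to a depth-to-space (pixel-shuffle) rearrangement. A minimal JAX sketch of that operation, assuming NHWC layout; the actual `unets.depth_to_space` in swirl_dynamics may differ in signature and layout:

```python
import jax.numpy as jnp

def depth_to_space(x: jnp.ndarray, block_size: int) -> jnp.ndarray:
  """Moves channel blocks into spatial positions (pixel shuffle)."""
  b, h, w, c = x.shape
  assert c % (block_size**2) == 0, 'channels must divide block_size**2'
  c_out = c // block_size**2
  x = x.reshape(b, h, w, block_size, block_size, c_out)
  x = x.transpose(0, 1, 3, 2, 4, 5)  # interleave block rows/cols with h, w
  return x.reshape(b, h * block_size, w * block_size, c_out)
```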
@@ -665,7 +665,7 @@ class TransformerBlock(nn.Module):
   mlp_dim: int
   num_layers: int
   num_heads: int
-  # TODO(lzepedanunez): encapsulate the configurations in its own container.
+  # TODO: encapsulate the configurations in its own container.
   attention_config: ml_collections.ConfigDict | None = None
   dropout_rate: float = 0.1
   attention_dropout_rate: float = 0.1
@@ -682,10 +682,10 @@ def __call__(self, inputs: Array, *, train: bool) -> Array:
     dtype = jax.dtypes.canonicalize_dtype(self.dtype)
 
     # Computing positional embeddings.
-    # TODO(lzepedanunez): Introduce more types of positional encoding.
+    # TODO: Introduce more types of positional encoding.
     if self.positional_embedding == 'sinusoidal_3d':
       batch, num_tokens, hidden_dim = inputs.shape
-      # TODO(lzepedanunez): change this one to handle non-square domains.
+      # TODO: change this one to handle non-square domains.
       height = width = int(np.sqrt(num_tokens // self.temporal_dims))
       if height * width * self.temporal_dims != num_tokens:
         raise ValueError('Input is assumed to be square for sinusoidal init.')
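The non-square TODO above stems from inferring height and width via a square root. One hypothetical fix is to pass the spatial shape explicitly and validate the token count against it (`spatial_shape` is an illustrative argument, not part of swirl_dynamics):

```python
def factor_tokens(
    num_tokens: int, temporal_dims: int, spatial_shape: tuple[int, int]
) -> tuple[int, int]:
  """Checks that tokens tile as (temporal_dims, height, width)."""
  height, width = spatial_shape
  if height * width * temporal_dims != num_tokens:
    raise ValueError(
        f'{num_tokens} tokens do not factor as '
        f'{temporal_dims} x {height} x {width}.'
    )
  return height, width
```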
4 changes: 2 additions & 2 deletions swirl_dynamics/lib/diffusion/vivit_diffusion.py
@@ -291,7 +291,7 @@ def __call__(self, inputs: Array, emb: Array, *, train: bool) -> Array:
     dtype = jax.dtypes.canonicalize_dtype(self.dtype)
 
     # Choosing the type of embedding.
-    # TODO(lzepedanunez): add more embeddings in here.
+    # TODO: add more embeddings in here.
     if self.positional_embedding == 'sinusoidal_3d':
       batch, num_tokens, hidden_dim = inputs.shape
       height = width = int(np.sqrt(num_tokens // self.temporal_dims))
@@ -324,7 +324,7 @@ def __call__(self, inputs: Array, emb: Array, *, train: bool) -> Array:
           self.attention_config.get('attention_kernel_init_method',
                                     'xavier')],  # pytype: disable=attribute-error
           temporal_dims=self.temporal_dims)
-      # TODO(lzepedanunez): implement factorized_dot_product_attention.
+      # TODO: implement factorized_dot_product_attention.
     else:
       raise ValueError(f'Unknown attention type {self.attention_config.type}')  # pytype: disable=attribute-error

6 changes: 3 additions & 3 deletions swirl_dynamics/lib/networks/cycle_gan.py
@@ -199,7 +199,7 @@ class Generator(nn.Module):
   use_skips: bool = True
   use_global_skip: bool = True
   dtype: jnp.dtype = jnp.float32
-  padding: str = "CIRCULAR"  # TODO(lzepedanunez): Add one adapted for ERA5.
+  padding: str = "CIRCULAR"  # TODO: Add one adapted for ERA5.
   padding_transpose: str = "CIRCULAR"
   use_weight_global_skip: bool = False
   weight_skip: bool = False
@@ -288,7 +288,7 @@ def __call__(self, x: Array, is_training: bool) -> Array:
     )(x)
 
     # Use a transformer core.
-    # TODO(lzepedanunez) add a conformer model.
+    # TODO add a conformer model.
     if self.use_attention:
       b, *hw, c = x.shape
       # Adding positional encoding.
@@ -339,7 +339,7 @@ def __call__(self, x: Array, is_training: bool) -> Array:
       )(x)
 
     elif self.upsample_mode == "deconv":
-      # TODO(lzepedanunez): use channel unrolling for the upsampling.
+      # TODO: use channel unrolling for the upsampling.
       x = nn.ConvTranspose(
           features=(self.ngf * mult) // 2,
           kernel_size=self.kernel_size_upsampling,
2 changes: 1 addition & 1 deletion swirl_dynamics/lib/networks/nonlinear_fourier.py
@@ -245,7 +245,7 @@ def __call__(self, inputs: Array) -> Array:
     # shape : (2, num_freqs, 2) for sin-cos, \omega, and x-y.
     y = omega * (x_i.reshape((1, 1, 2)) + a)
 
-    # TODO(lzepedanunez): create a funcion that creates the periodic features.
+    # TODO: create a funcion that creates the periodic features.
     # Applying the trigonometric functions, which can be written as:
     # [[1, 1],
     #  [sin(ω₁ x), sin(ω₁ y)],
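The TODO above asks for a helper that builds the periodic features sketched in the comment. A hedged sketch consistent with the stated shape convention, where `y` has shape `(2, num_freqs, 2)` for the sin/cos branch, frequency, and x/y axes (note this groups all sines before all cosines, which may differ from the interleaving shown):

```python
import jax.numpy as jnp

def periodic_features(y: jnp.ndarray) -> jnp.ndarray:
  """Stacks [1, sin(ω_i ·), cos(ω_i ·)] rows for each coordinate."""
  sin_feats = jnp.sin(y[0])          # (num_freqs, 2)
  cos_feats = jnp.cos(y[1])          # (num_freqs, 2)
  ones = jnp.ones((1, y.shape[-1]))  # the constant row [1, 1]
  return jnp.concatenate([ones, sin_feats, cos_feats], axis=0)
```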
2 changes: 1 addition & 1 deletion swirl_dynamics/lib/networks/rational_networks.py
@@ -162,7 +162,7 @@ class RationalMLP(nn.Module):
   dtype: Any = jnp.float32
   multi_rational: bool = False
   use_bias: bool = True
-  # TODO(lzepedanunez): add precision flag to have more granular control
+  # TODO: add precision flag to have more granular control
 
   @nn.compact
   def __call__(self, inputs: Array) -> Array:
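For context on `RationalMLP`: a rational network replaces fixed activations with a trainable ratio of polynomials, e.g. degree (3, 2) as in Padé activation units. A minimal Flax sketch under that assumption; the coefficients, initialization, and safeguarded denominator here are illustrative, not the library's implementation:

```python
import flax.linen as nn
import jax.numpy as jnp

class RationalActivation(nn.Module):
  """Trainable p(x)/q(x) activation of degree (3, 2)."""

  @nn.compact
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    p = self.param('p', nn.initializers.ones, (4,))  # numerator coefficients
    q = self.param('q', nn.initializers.ones, (2,))  # denominator coefficients
    num = p[0] + p[1] * x + p[2] * x**2 + p[3] * x**3
    den = 1.0 + jnp.abs(q[0] * x + q[1] * x**2)  # keeps the denominator >= 1
    return num / den
```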
4 changes: 2 additions & 2 deletions swirl_dynamics/projects/debiasing/cycle_gan/models.py
@@ -175,7 +175,7 @@ def run_generator_forward(
     A tuple containing the generated samples.
   """
 
-  # TODO(lzepedanunez): perhaps use dictionaries instead of positional tuples.
+  # TODO: perhaps use dictionaries instead of positional tuples.
   params_gen_a2b = params_gen[0]
   params_gen_b2a = params_gen[1]

@@ -480,7 +480,7 @@ def loss_fn(
       to be real data.
   """
   # Split the States.
-  # TODO(lzepedanunez): specify how to split the parameters.
+  # TODO: specify how to split the parameters.
 
   params_gen_a2b = params[0]
   params_gen_b2a = params[1]
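Both TODOs in this file point at the same refactor: replacing positional tuples with keyed containers. A hypothetical illustration with placeholder pytrees (names and values are illustrative only):

```python
# Placeholder pytrees; in models.py these would be the Flax variable
# collections of the two generators.
params_gen_a2b = {'dense': {'kernel': [[1.0]]}}
params_gen_b2a = {'dense': {'kernel': [[2.0]]}}

# Keyed access is order-independent and self-documenting, unlike
# params_gen[0] / params_gen[1].
params_gen = {'a2b': params_gen_a2b, 'b2a': params_gen_b2a}
assert params_gen['a2b'] is params_gen_a2b
```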
@@ -125,7 +125,7 @@ def create_loader_from_hdf5(
       mean and std stats (if normalize=True, else dict
       contains NoneType values).
   """
-  # TODO(lzepedanunez): create the data arrays following a similar convention.
+  # TODO: create the data arrays following a similar convention.
   snapshots = hdf5_utils.read_single_array(
       dataset_path,
       f"{split}/u",
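For reference, `hdf5_utils.read_single_array` presumably loads one dataset from an HDF5 file into memory; a sketch of the equivalent using h5py (an assumption about the helper's behavior, not its actual implementation):

```python
import h5py
import numpy as np

def read_single_array(path: str, key: str) -> np.ndarray:
  """Reads a single dataset, e.g. key='train/u', into memory."""
  with h5py.File(path, 'r') as f:
    return f[key][()]
```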
4 changes: 2 additions & 2 deletions swirl_dynamics/projects/debiasing/rectified_flow/main.py
@@ -119,7 +119,7 @@ def main(argv):
   )
 
   model = models.ReFlowModel(
-      # TODO(lzepedanunez): clean this part.
+      # TODO: clean this part.
       input_shape=(
           config.input_shapes[0][1] // config.spatial_downsample_factor[0],
           config.input_shapes[0][2] // config.spatial_downsample_factor[0],
@@ -157,7 +157,7 @@ def main(argv):
           base_dir=workdir,
           options=ckpt_options,
       ),
-      # TODO(lzepedanunez) add a plot callback.
+      # TODO add a plot callback.
   ),
 )

2 changes: 1 addition & 1 deletion swirl_dynamics/projects/debiasing/rectified_flow/models.py
@@ -99,7 +99,7 @@ class ReFlowModel(models.BaseModel):
   num_eval_time_levels: ClassVar[int] = 10
 
   def initialize(self, rng: Array):
-    # TODO(lzepedanunez): Add a dtype object to ensure consistency of types.
+    # TODO: Add a dtype object to ensure consistency of types.
     x = jnp.ones((1,) + self.input_shape)
 
     return self.flow_model.init(
@@ -98,7 +98,7 @@ class DistributedReFlowTrainer(
 ):
   """Multi-device trainer for rectified flow models."""
 
-  # TODO(lzepedanunez): Write a test for this trainer.
+  # TODO: Write a test for this trainer.
 
   # MRO: ReFlowTrainer > BasicDistributedTrainer > BasicTrainer
   ...
2 changes: 1 addition & 1 deletion swirl_dynamics/projects/ergodic/choices.py
@@ -67,7 +67,7 @@ def dispatch(
   Returns:
     ScanOdeSolver | MultiStepScanOdeSolver
   """
-  # TODO(yairschiff): Profile if the moveaxis call required here introduces a
+  # TODO: Profile if the moveaxis call required here introduces a
   # bottleneck
   return {
       "ExplicitEuler": ode.ExplicitEuler(time_axis_pos=1),
6 changes: 3 additions & 3 deletions swirl_dynamics/projects/ergodic/configs/ks_1d.py
@@ -53,7 +53,7 @@ def get_config():

   # Model params
   config.model = 'PeriodicConvNetModel'  # 'Fno'
-  # TODO(yairschiff): Split CNN and FNO into separate configs
+  # TODO: Split CNN and FNO into separate configs
   ########### PeriodicConvNetModel ################
   config.latent_dim = 48
   config.num_levels = 4
@@ -120,7 +120,7 @@ def skip(
   return False
 
 
-# TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
+# TODO: Refactor sweeps and experiment definition to use gin.
 # use option --sweep=False in the command line to avoid sweeping
 def sweep(add):
   """Define param sweep."""
@@ -165,7 +165,7 @@ def sweep(add):
   )
 
 
-# TODO(yairschiff): Ablation!
+# TODO: Ablation!
 # def sweep(add):
 #   """Define param sweep."""
 #   # pylint: disable=line-too-long
2 changes: 1 addition & 1 deletion swirl_dynamics/projects/ergodic/configs/ks_1d_dist.py
@@ -113,7 +113,7 @@ def skip(
   return False
 
 
-# TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
+# TODO: Refactor sweeps and experiment definition to use gin.
 # use option --sweep=False in the command line to avoid sweeping
 def sweep(add):
   """Define param sweep."""
2 changes: 1 addition & 1 deletion swirl_dynamics/projects/ergodic/configs/lorenz63.py
@@ -96,7 +96,7 @@ def skip(
   return False
 
 
-# TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
+# TODO: Refactor sweeps and experiment definition to use gin.
 def sweep(add):
   """Define param sweep."""
   # pylint: disable=line-too-long
4 changes: 2 additions & 2 deletions swirl_dynamics/projects/ergodic/configs/ns_2d.py
@@ -54,7 +54,7 @@ def get_config():
   config.noise_level = 0.0
 
   # Model params
-  # TODO(yairschiff): Split CNN and FNO into separate configs
+  # TODO: Split CNN and FNO into separate configs
   config.model = 'PeriodicConvNetModel'  # 'Fno' 'Fno2d'
   ########### PeriodicConvNetModel ################
   config.latent_dim = 16
@@ -131,7 +131,7 @@ def skip(


 # pylint: disable=line-too-long
-# TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
+# TODO: Refactor sweeps and experiment definition to use gin.
 # use option --sweep=False in the command line to avoid sweeping
 def sweep(add):
   """Define param sweep."""
2 changes: 1 addition & 1 deletion swirl_dynamics/projects/ergodic/configs/ns_2d_dist.py
@@ -89,7 +89,7 @@ def get_config():
   return config
 
 
-# TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
+# TODO: Refactor sweeps and experiment definition to use gin.
 def sweep(add):
   """Define param sweep."""
   for seed in [42]:
4 changes: 2 additions & 2 deletions swirl_dynamics/projects/ergodic/main.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 r"""The main entry point for running training loops."""
-# TODO(yairschiff): Consider enabling float64 for Lorenz63 experiment
+# TODO: Consider enabling float64 for Lorenz63 experiment
 
 import json
 from os import path as osp
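Enabling float64, as the TODO above contemplates, is a single configuration switch in JAX (standard API, shown for reference):

```python
import jax

# Must run before arrays are created; JAX defaults to float32.
jax.config.update('jax_enable_x64', True)
```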
@@ -82,7 +82,7 @@ def main(argv):

   elif experiment == choices.Experiment.NS_2D:
     fig_callback_cls = ns_2d.NS2dPlotFigures
-    # TODO(yairschiff): This state dim is temporary for FNO data, should be 256
+    # TODO: This state dim is temporary for FNO data, should be 256
     state_dims = (
         64 // config.spatial_downsample_factor,
         64 // config.spatial_downsample_factor,
4 changes: 2 additions & 2 deletions swirl_dynamics/projects/ergodic/measure_distances.py
@@ -70,15 +70,15 @@ def mmd(x: Array, y: Array) -> Array:
   xx, yy, xy = (jnp.zeros_like(xx), jnp.zeros_like(xx), jnp.zeros_like(xx))
 
   # Multiscale
-  # TODO(yairschiff): We may need to experiment with these bandwidths to have
+  # TODO: We may need to experiment with these bandwidths to have
   # MMD loss better distinguish distributions, especially for high dim data
   bandwidth_range = [0.2, 0.5, 0.9, 1.3]
   for a in bandwidth_range:
     xx += a**2 * (a**2 + dxx) ** -1
     yy += a**2 * (a**2 + dyy) ** -1
     xy += a**2 * (a**2 + dxy) ** -1
 
-  # TODO(yairschiff): We may want to use jnp.sqrt(...) here; see:
+  # TODO: We may want to use jnp.sqrt(...) here; see:
   # https://arxiv.org/abs/1502.02761
   return jnp.mean(xx + yy - 2.0 * xy)

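For reference, a self-contained sketch of the multiscale-kernel MMD computed above. The rational-quadratic kernel a**2 / (a**2 + d) and the bandwidths match the snippet; the pairwise squared distances dxx/dyy/dxy are my assumption about how the elided part of mmd() computes them:

```python
import jax.numpy as jnp

def mmd_multiscale(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """MMD between samples x, y of shape (n, dim), multiscale RQ kernels."""

  def sq_dists(a, b):
    return jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)

  dxx, dyy, dxy = sq_dists(x, x), sq_dists(y, y), sq_dists(x, y)
  xx = yy = xy = 0.0
  for a in [0.2, 0.5, 0.9, 1.3]:  # bandwidths from the snippet above
    xx += a**2 / (a**2 + dxx)
    yy += a**2 / (a**2 + dyy)
    xy += a**2 / (a**2 + dxy)
  return jnp.mean(xx + yy - 2.0 * xy)
```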
16 changes: 8 additions & 8 deletions swirl_dynamics/projects/ergodic/stable_ar.py
@@ -71,7 +71,7 @@ def __post_init__(self):
     self.pred_integrator = functools.partial(
         pred_integrator, ode.nn_module_to_dynamics(self.conf.dynamics_model)
     )
-    # TODO(lzepedanunez): check if this is compatible with distributed training.
+    # TODO: check if this is compatible with distributed training.
     self.vmapped_measure_dist = jax.vmap(self.conf.measure_dist, in_axes=(1, 1))
 
   def initialize(self, rng):
@@ -97,7 +97,7 @@ def loss_fn(
     tspan = batch["tspan"].reshape((-1,))
     rollout_weight = batch["rollout_weight"].reshape((-1,))
 
-    # TODO(lzepedanunez): implement the logic in the Neural Markov paper.
+    # TODO: implement the logic in the Neural Markov paper.
     if self.conf.add_noise:
       noise = self.conf.noise_level + jax.random.normal(rng, x0.shape)
       x0 += noise
@@ -131,7 +131,7 @@ def loss_fn(

     # Compare to true trajectory last step.
     if self.conf.use_sobolev_norm:
-      # TODO(yairschiff): Rollout weighting not implemented for this case!
+      # TODO: Rollout weighting not implemented for this case!
       # The spatial dimension is the length of the shape minus 2,
       # which accounts for the batch, frame, and channel dimensions.
       dim = len(pred.shape) - 2
@@ -163,9 +163,9 @@ def loss_fn(
       )
 
     # Compare to full reference trajectory.
-    # TODO(lzepedanunez): this is code is repeated.
+    # TODO: this is code is repeated.
     if self.conf.use_sobolev_norm:
-      # TODO(yairschiff): Rollout weighting not implemented for this case!
+      # TODO: Rollout weighting not implemented for this case!
       dim = len(pred.shape) - 3
       l2 = ergodic_utils.sobolev_norm(
           pred - true[:, 1:, ...],
@@ -221,7 +221,7 @@ def eval_fn(
       pred_trajs *= self.conf.normalize_stats["std"]
       pred_trajs += self.conf.normalize_stats["mean"]
 
-    # TODO(lzepedanunez): this only computes the local sinkhorn distance.
+    # TODO: this only computes the local sinkhorn distance.
     sd = measure_distances.sinkhorn_div(
         pred_trajs[:, -1, ...], trajs[:, -1, ...]
     )
@@ -356,7 +356,7 @@ def preprocess_train_batch(
       num_time_steps += self.conf.num_rollout_steps + 1
     else:
       num_time_steps = self.conf.num_rollout_steps + 1
-    # TODO(yairschiff): Should we remove this random sampling?
+    # TODO: Should we remove this random sampling?
     if self.conf.use_pushfwd and num_time_steps > 2:
       num_time_steps = jax.random.randint(
           rng, (1,), minval=2, maxval=num_time_steps + 1
@@ -478,7 +478,7 @@ def preprocess_train_batch(
       num_time_steps += self.conf.num_rollout_steps + 1
     else:
       num_time_steps = self.conf.num_rollout_steps + 1
-    # TODO(yairschiff): Should we remove this random sampling?
+    # TODO: Should we remove this random sampling?
     if self.conf.use_pushfwd and num_time_steps > 2:
      num_time_steps = jax.random.randint(
          rng, (1,), minval=2, maxval=num_time_steps + 1
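The random sampling questioned in the two TODOs above draws the rollout horizon for the pushforward trick uniformly from {2, ..., num_rollout_steps + 1}. A standalone sketch of just that step, with made-up values:

```python
import jax

rng = jax.random.PRNGKey(0)
num_rollout_steps = 8  # illustrative value
num_time_steps = num_rollout_steps + 1
# Same call pattern as the snippet above: jax.random.randint's maxval is
# exclusive, so this samples an integer in {2, ..., num_time_steps}.
sampled = jax.random.randint(rng, (1,), minval=2, maxval=num_time_steps + 1)
```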
8 changes: 4 additions & 4 deletions swirl_dynamics/projects/ergodic/utils.py
@@ -32,7 +32,7 @@
 DynamicsFn = Callable[[Array, Array, PyTree], Array]
 
 
-# TODO(yairschiff): Move this method to swirl_dynamics.data.utils
+# TODO: Move this method to swirl_dynamics.data.utils
 def generate_data_from_known_dynamcics(
     integrator: ode.ScanOdeSolver,
     dynamics: DynamicsFn,
@@ -47,7 +47,7 @@ def generate_data_from_known_dynamcics(
   return integrator(dynamics, x0, tspan, {})[warmup:]
 
 
-# TODO(yairschiff): Move this method to swirl_dynamics.data.utils
+# TODO: Move this method to swirl_dynamics.data.utils
 def create_loader_from_hdf5(
     num_time_steps: int,
     time_stride: int,
@@ -180,7 +180,7 @@ def create_loader_from_tfds(
"""Load pre-computed trajectories dumped to hdf5 file.
This loader has fewer options that the one from hdf5, in particular, it has
no normalization. TODO(lzepedanunez): Add normalization.
no normalization. TODO: Add normalization.
Arguments:
num_time_steps: Number of time steps to include in each trajectory.
@@ -241,7 +241,7 @@ def create_loader_from_tfds(
   return loader, {"mean": None, "std": None}
 
 
-# TODO(lzepedanunez): find a better place for this function and refactor with
+# TODO: find a better place for this function and refactor with
 # vmap.
 def sobolev_norm(
     u: Array, s: int = 1, dim: int = 2, length: float = 1.0
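For context, `sobolev_norm` above computes an H^s norm on a periodic grid. A hedged 1D sketch of the underlying quantity, ||u||_{H^s}^2 = Σ_k (1 + |k|^2)^s |û(k)|^2, via the FFT; the real function also takes `dim` and may differ in normalization conventions:

```python
import jax.numpy as jnp

def sobolev_norm_1d(u: jnp.ndarray, s: int = 1, length: float = 1.0):
  """Squared H^s norm of u sampled on n uniform points of [0, length)."""
  n = u.shape[-1]
  u_hat = jnp.fft.fft(u, norm='ortho')
  k = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=length / n)  # angular wavenumbers
  return jnp.sum((1.0 + k**2) ** s * jnp.abs(u_hat) ** 2, axis=-1)
```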
2 changes: 1 addition & 1 deletion swirl_dynamics/projects/probabilistic_diffusion/models.py
@@ -137,7 +137,7 @@ def loss_fn(
         sigma=sigma,
         cond=cond,
         is_training=True,
-        rngs={"dropout": rng3},  # TODO(lzepedanunez): refactor this.
+        rngs={"dropout": rng3},  # TODO: refactor this.
     )
     loss = jnp.mean(vmapped_mult(weights, jnp.square(denoised - batch["x"])))
     metric = dict(loss=loss)
