remove deprecated way of loading location-specific parameters
henryaddison committed Feb 13, 2024
1 parent cda0606 commit 9e87e01
Showing 2 changed files with 2 additions and 16 deletions.
Changed file 1 of 2
@@ -71,10 +71,6 @@ def __init__(self, config):
         combine_method = config.model.progressive_combine.lower()
         combiner = functools.partial(Combine, method=combine_method)
 
-        # DEPRECATED: the old way to include a learnable feature map of location-specific parameters
-        if config.model.map_features > 0:
-            self.map = nn.Parameter(torch.zeros(config.model.map_features, config.data.image_size, config.data.image_size))
-
         modules = []
         # timestep/noise_level embedding; only for continuous training
         if embedding_type == 'fourier':
@@ -150,7 +146,7 @@ def __init__(self, config):
         else:
             cond_time_channels = 0
 
-        channels = cond_var_channels + cond_time_channels + output_channels + config.model.map_features + config.model.loc_spec_channels
+        channels = cond_var_channels + cond_time_channels + output_channels + config.model.loc_spec_channels
         if progressive_input != 'none':
             input_pyramid_ch = channels
 
Expand Down Expand Up @@ -257,9 +253,6 @@ def forward(self, x, cond, time_cond):

# combine the modelled data and the conditioning inputs
x = torch.cat([x, cond], dim=1)
# DEPRECATED: old way to add a map of location-specific features to input
if self.config.model.map_features > 0:
x = torch.cat([x, self.map.broadcast_to((x.shape[0], *self.map.shape))], dim=1)
# timestep/noise_level embedding; only for continuous training
modules = self.all_modules
m_idx = 0
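For context, the deprecated mechanism deleted above registered a learnable per-pixel parameter map and appended it to the network input by broadcasting it across the batch. A minimal PyTorch sketch of that pattern, with illustrative sizes standing in for config.model.map_features and config.data.image_size:

import torch
import torch.nn as nn

# Illustrative sizes; in the repository these come from config.model.map_features
# and config.data.image_size.
map_features, image_size = 2, 64

# Deprecated idea: one trainable value per (feature, row, column),
# shared by every sample in the batch.
loc_map = nn.Parameter(torch.zeros(map_features, image_size, image_size))

# A batch of modelled data already concatenated with its conditioning inputs.
x = torch.randn(8, 5, image_size, image_size)  # (batch, channels, height, width)

# Broadcast the (F, H, W) map to (B, F, H, W) and append it as extra channels,
# mirroring the removed line: torch.cat([x, self.map.broadcast_to(...)], dim=1)
x = torch.cat([x, loc_map.broadcast_to((x.shape[0], *loc_map.shape))], dim=1)

print(x.shape)  # torch.Size([8, 7, 64, 64]) -> 5 original channels + 2 map features

Because those extra channels were counted into the input width via config.model.map_features, removing the mechanism also removes that term from the channel arithmetic in both files.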
Changed file 2 of 2
@@ -52,14 +52,10 @@ def __init__(self, config):
         marginal_prob_std=None
         cond_var_channels, output_channels = list(map(len, get_variables(config.data.dataset_name)))
         cond_time_channels = 3
-        input_channels = output_channels + cond_var_channels + cond_time_channels + config.model.map_features + config.model.loc_spec_channels
+        input_channels = output_channels + cond_var_channels + cond_time_channels + config.model.loc_spec_channels
         channels=[32, 64, 128, 256]
         embed_dim=256
 
-        # DEPRECATED: the old way to include a learnable feature map of location-specific parameters
-        if config.model.map_features > 0:
-            self.map = nn.Parameter(torch.zeros(config.model.map_features, USABLE_IMAGE_SIZE, USABLE_IMAGE_SIZE))
-
         # Gaussian random feature embedding layer for time
         self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
                                    nn.Linear(embed_dim, embed_dim))
@@ -96,9 +92,6 @@ def forward(self, x, cond, t):
         output_shape = x.shape
         # combine the modelled data and the conditioning inputs
         x = torch.cat([x, cond], dim=1)[..., :USABLE_IMAGE_SIZE, :USABLE_IMAGE_SIZE]
-        # DEPRECATED: old way to add a map of location-specific features to input
-        if self.config.model.map_features > 0:
-            x = torch.cat([x, self.map.broadcast_to((x.shape[0], *self.map.shape))], dim=1)
         # Obtain the Gaussian random feature embedding for t
         embed = self.act(self.embed(t))
         # Encoding path
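With the deprecated map gone, both constructors derive the first-layer width from the conditioning variables, the time channels, the target channels, and config.model.loc_spec_channels only. A quick sketch of the revised channel arithmetic, using hypothetical counts in place of the values returned by get_variables and the config:

# Hypothetical counts; the real values come from get_variables(config.data.dataset_name)
# and the model config.
cond_var_channels = 4    # number of conditioning variables
output_channels = 1      # number of target variables
cond_time_channels = 3   # as hard-coded in the second file
loc_spec_channels = 2    # location-specific input channels kept by this commit

# Old: ... + map_features + loc_spec_channels
# New: map_features no longer contributes.
input_channels = output_channels + cond_var_channels + cond_time_channels + loc_spec_channels
print(input_channels)  # 10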
