Skip to content

Commit

Permalink
compute ms ssim in test steps
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidtronix committed Jan 21, 2025
1 parent 122daa6 commit 7e58456
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ml4h/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -842,7 +842,7 @@ def __init__(self, name="multi_scale_ssim", **kwargs):

def update_state(self, y_true, y_pred, max_val, sample_weight=None):
# Calculate MS-SSIM for the batch
ssim = tf.image.ssim_multiscale(y_true, y_pred, max_val=max_val, power_factors=[0.1, 0.2, 0.4, 0.3])
ssim = tf.image.ssim_multiscale(y_true, y_pred, max_val=max_val, power_factors=[0.25, 0.25, 0.25, 0.25])
if sample_weight is not None:
ssim = tf.multiply(ssim, sample_weight)

Expand Down
6 changes: 3 additions & 3 deletions ml4h/models/diffusion_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,14 @@ def compile(self, **kwargs):
self.mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
if self.tensor_map.axes() == 3 and self.inspect_model:
self.kid = KernelInceptionDistance(name = "kid", input_shape = self.tensor_map.shape, kernel_image_size=299)
self.inception_score = InceptionScore(name = "is", input_shape = self.tensor_map.shape, kernel_image_size=299)
self.ms_ssim = MultiScaleSSIM()

@property
def metrics(self):
m = [self.noise_loss_tracker, self.image_loss_tracker, self.mse_metric, self.mae_metric]
if self.tensor_map.axes() == 3 and self.inspect_model:
m.append(self.kid)
m.append(self.inception_score)
m.append(self.ms_ssim)
return m

def denormalize(self, images):
Expand Down Expand Up @@ -488,7 +488,7 @@ def test_step(self, images_original):
num_images=self.batch_size, diffusion_steps=20
)
self.kid.update_state(images, generated_images)
self.inception_score.update_state(images, generated_images)
self.ms_ssim.update_state(images, generated_images, 255)

return {m.name: m.result() for m in self.metrics}

Expand Down

0 comments on commit 7e58456

Please sign in to comment.