From 972058c6f053f534cc6f6daf3faefede04c44666 Mon Sep 17 00:00:00 2001 From: Leon Ericsson Date: Tue, 6 Feb 2024 21:42:07 +0100 Subject: [PATCH] updated calls to use loc & scale --- python/mlx/nn/init.py | 6 +++--- python/mlx/nn/layers/embedding.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mlx/nn/init.py b/python/mlx/nn/init.py index 5afc6170e..6596ba741 100644 --- a/python/mlx/nn/init.py +++ b/python/mlx/nn/init.py @@ -60,7 +60,7 @@ def normal( """ def initializer(a: mx.array) -> mx.array: - return std * mx.random.normal(shape=a.shape, dtype=dtype) + mean + return mx.random.normal(shape=a.shape, scale=std, loc=mean, dtype=dtype) return initializer @@ -184,7 +184,7 @@ def glorot_normal( def initializer(a: mx.array, gain: float = 1.0) -> mx.array: fan_in, fan_out = _calculate_fan_in_fan_out(a) std = gain * math.sqrt(2.0 / (fan_in + fan_out)) - return mx.random.normal(shape=a.shape, dtype=dtype) * std + return mx.random.normal(shape=a.shape, scale=std, dtype=dtype) return initializer @@ -285,7 +285,7 @@ def initializer( raise ValueError(f"Invalid mode: {mode}. Valid modes are: fan_in, fan_out") std = gain / math.sqrt(fan) - return mx.random.normal(shape=a.shape, dtype=dtype) * std + return mx.random.normal(shape=a.shape, scale=std, dtype=dtype) return initializer diff --git a/python/mlx/nn/layers/embedding.py b/python/mlx/nn/layers/embedding.py index c62b1206f..18482eddc 100644 --- a/python/mlx/nn/layers/embedding.py +++ b/python/mlx/nn/layers/embedding.py @@ -21,7 +21,7 @@ class Embedding(Module): def __init__(self, num_embeddings: int, dims: int): super().__init__() scale = math.sqrt(1 / dims) - self.weight = mx.random.normal((num_embeddings, dims)) * scale + self.weight = mx.random.normal(shape=(num_embeddings, dims), scale=scale) def _extra_repr(self): return f"{self.weight.shape[0]}, {self.weight.shape[1]}"