From 011d9369ff73c390ac53457138e067a18a5909c9 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sun, 30 Jun 2024 23:31:30 +0100 Subject: [PATCH] feat: improving Lre __repr__() --- linear_relational/Lre.py | 9 +++++++++ tests/test_Lre.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/linear_relational/Lre.py b/linear_relational/Lre.py index 8a5941f..252f017 100644 --- a/linear_relational/Lre.py +++ b/linear_relational/Lre.py @@ -72,6 +72,9 @@ def calculate_subject_activation( vec = vec / vec.norm() return vec + def __repr__(self) -> str: + return f"InvertedLre({self.relation}, rank {self.rank}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})" + class LowRankLre(nn.Module): """Low-rank approximation of a LRE""" @@ -140,6 +143,9 @@ def calculate_object_activation( vec = vec / vec.norm() return vec + def __repr__(self) -> str: + return f"LowRankLre({self.relation}, rank {self.rank}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})" + class Lre(nn.Module): """Linear Relational Embedding""" @@ -211,3 +217,6 @@ def _low_rank_svd( low_rank_v: torch.Tensor = v[:, :rank].to(self.weight.dtype) low_rank_s: torch.Tensor = s[:rank].to(self.weight.dtype) return low_rank_u, low_rank_s, low_rank_v + + def __repr__(self) -> str: + return f"Lre({self.relation}, layers {self.subject_layer} -> {self.object_layer}, {self.object_aggregation})" diff --git a/tests/test_Lre.py b/tests/test_Lre.py index 44982d3..8dd43b9 100644 --- a/tests/test_Lre.py +++ b/tests/test_Lre.py @@ -13,6 +13,7 @@ def test_Lre_invert() -> None: bias=bias, weight=torch.eye(3), ) + assert lre.__repr__() == "Lre(test, layers 5 -> 10, mean)" inv_lre = lre.invert(rank=2) assert inv_lre.relation == "test" assert inv_lre.subject_layer == 5 @@ -23,6 +24,7 @@ def test_Lre_invert() -> None: assert inv_lre.s.shape == (2,) assert inv_lre.v.shape == (3, 2) assert inv_lre.rank == 2 + assert inv_lre.__repr__() == "InvertedLre(test, rank 2, layers 5 -> 10, mean)" def test_Lre_to_low_rank() -> None: @@ -36,6 +38,7 @@ def test_Lre_to_low_rank() -> None: weight=torch.eye(3), ) low_rank_lre = lre.to_low_rank(rank=2) + assert lre.__repr__() == "Lre(test, layers 5 -> 10, mean)" assert low_rank_lre.relation == "test" assert low_rank_lre.subject_layer == 5 assert low_rank_lre.object_layer == 10 @@ -45,6 +48,7 @@ def test_Lre_to_low_rank() -> None: assert low_rank_lre.s.shape == (2,) assert low_rank_lre.v.shape == (3, 2) assert low_rank_lre.rank == 2 + assert low_rank_lre.__repr__() == "LowRankLre(test, rank 2, layers 5 -> 10, mean)" def test_LowRankLre_calculate_object_activation_unnormalized() -> None: