@@ -1192,7 +1192,7 @@ def __init__(
1192
1192
max_feature_lengths : Optional [Dict [str , int ]] = None ,
1193
1193
feature_processor_modules : Optional [Dict [str , torch .nn .Module ]] = None ,
1194
1194
over_arch_clazz : Type [nn .Module ] = TestOverArch ,
1195
- preproc_module : Optional [nn .Module ] = None ,
1195
+ postproc_module : Optional [nn .Module ] = None ,
1196
1196
) -> None :
1197
1197
super ().__init__ (
1198
1198
tables = cast (List [BaseEmbeddingConfig ], tables ),
@@ -1229,7 +1229,7 @@ def __init__(
1229
1229
"dummy_ones" ,
1230
1230
torch .ones (1 , device = dense_device ),
1231
1231
)
1232
- self .preproc_module = preproc_module
1232
+ self .postproc_module = postproc_module
1233
1233
1234
1234
def sparse_forward (self , input : ModelInput ) -> KeyedTensor :
1235
1235
return self .sparse (
@@ -1256,8 +1256,8 @@ def forward(
1256
1256
self ,
1257
1257
input : ModelInput ,
1258
1258
) -> Union [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
1259
- if self .preproc_module :
1260
- input = self .preproc_module (input )
1259
+ if self .postproc_module :
1260
+ input = self .postproc_module (input )
1261
1261
return self .dense_forward (input , self .sparse_forward (input ))
1262
1262
1263
1263
@@ -1749,18 +1749,18 @@ def forward(self, kjt: KeyedJaggedTensor) -> List[KeyedJaggedTensor]:
1749
1749
1750
1750
class TestModelWithPreproc (nn .Module ):
1751
1751
"""
1752
- Basic module with up to 3 preproc modules:
1753
- - preproc on idlist_features for non-weighted EBC
1754
- - preproc on idscore_features for weighted EBC
1755
- - optional preproc on model input shared by both EBCs
1752
+ Basic module with up to 3 postproc modules:
1753
+ - postproc on idlist_features for non-weighted EBC
1754
+ - postproc on idscore_features for weighted EBC
1755
+ - optional postproc on model input shared by both EBCs
1756
1756
1757
1757
Args:
1758
1758
tables,
1759
1759
weighted_tables,
1760
1760
device,
1761
- preproc_module ,
1761
+ postproc_module ,
1762
1762
num_float_features,
1763
- run_preproc_inline ,
1763
+ run_postproc_inline ,
1764
1764
1765
1765
Example:
1766
1766
>>> TestModelWithPreproc(tables, weighted_tables, device)
@@ -1774,9 +1774,9 @@ def __init__(
1774
1774
tables : List [EmbeddingBagConfig ],
1775
1775
weighted_tables : List [EmbeddingBagConfig ],
1776
1776
device : torch .device ,
1777
- preproc_module : Optional [nn .Module ] = None ,
1777
+ postproc_module : Optional [nn .Module ] = None ,
1778
1778
num_float_features : int = 10 ,
1779
- run_preproc_inline : bool = False ,
1779
+ run_postproc_inline : bool = False ,
1780
1780
) -> None :
1781
1781
super ().__init__ ()
1782
1782
self .dense = TestDenseArch (num_float_features , device )
@@ -1790,17 +1790,17 @@ def __init__(
1790
1790
is_weighted = True ,
1791
1791
device = device ,
1792
1792
)
1793
- self .preproc_nonweighted = TestPreprocNonWeighted ()
1794
- self .preproc_weighted = TestPreprocWeighted ()
1795
- self ._preproc_module = preproc_module
1796
- self ._run_preproc_inline = run_preproc_inline
1793
+ self .postproc_nonweighted = TestPreprocNonWeighted ()
1794
+ self .postproc_weighted = TestPreprocWeighted ()
1795
+ self ._postproc_module = postproc_module
1796
+ self ._run_postproc_inline = run_postproc_inline
1797
1797
1798
1798
def forward (
1799
1799
self ,
1800
1800
input : ModelInput ,
1801
1801
) -> Tuple [torch .Tensor , torch .Tensor ]:
1802
1802
"""
1803
- Runs preprco for EBC and weighted EBC, optionally runs preproc for input
1803
+ Runs preprco for EBC and weighted EBC, optionally runs postproc for input
1804
1804
1805
1805
Args:
1806
1806
input
@@ -1809,20 +1809,20 @@ def forward(
1809
1809
"""
1810
1810
modified_input = input
1811
1811
1812
- if self ._preproc_module is not None :
1813
- modified_input = self ._preproc_module (modified_input )
1814
- elif self ._run_preproc_inline :
1812
+ if self ._postproc_module is not None :
1813
+ modified_input = self ._postproc_module (modified_input )
1814
+ elif self ._run_postproc_inline :
1815
1815
idlist_features = modified_input .idlist_features
1816
1816
modified_input .idlist_features = KeyedJaggedTensor .from_lengths_sync (
1817
1817
idlist_features .keys (), # pyre-ignore [6]
1818
1818
idlist_features .values (), # pyre-ignore [6]
1819
1819
idlist_features .lengths (), # pyre-ignore [16]
1820
1820
)
1821
1821
1822
- modified_idlist_features = self .preproc_nonweighted (
1822
+ modified_idlist_features = self .postproc_nonweighted (
1823
1823
modified_input .idlist_features
1824
1824
)
1825
- modified_idscore_features = self .preproc_weighted (
1825
+ modified_idscore_features = self .postproc_weighted (
1826
1826
modified_input .idscore_features
1827
1827
)
1828
1828
ebc_out = self .ebc (modified_idlist_features [0 ])
@@ -1834,15 +1834,15 @@ def forward(
1834
1834
1835
1835
class TestNegSamplingModule (torch .nn .Module ):
1836
1836
"""
1837
- Basic module to simulate feature augmentation preproc (e.g. neg sampling) for testing
1837
+ Basic module to simulate feature augmentation postproc (e.g. neg sampling) for testing
1838
1838
1839
1839
Args:
1840
1840
extra_input
1841
1841
has_params
1842
1842
1843
1843
Example:
1844
- >>> preproc = TestNegSamplingModule(extra_input)
1845
- >>> out = preproc (in)
1844
+ >>> postproc = TestNegSamplingModule(extra_input)
1845
+ >>> out = postproc (in)
1846
1846
1847
1847
Returns:
1848
1848
ModelInput
@@ -1906,8 +1906,8 @@ class TestPositionWeightedPreprocModule(torch.nn.Module):
1906
1906
1907
1907
Args: None
1908
1908
Example:
1909
- >>> preproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
1910
- >>> out = preproc (in)
1909
+ >>> postproc = TestPositionWeightedPreprocModule(max_feature_lengths, device)
1910
+ >>> out = postproc (in)
1911
1911
Returns:
1912
1912
ModelInput
1913
1913
"""
0 commit comments