From 284e16c563034a4d5247b74e0391ae4bc8fa9598 Mon Sep 17 00:00:00 2001 From: Tingyu Wang Date: Fri, 8 Mar 2024 10:03:05 -0500 Subject: [PATCH] pass in explicit args --- .../cugraph_dgl/nn/conv/gatconv.py | 30 +++++++++++++++---- .../cugraph_dgl/nn/conv/gatv2conv.py | 18 +++++++---- .../cugraph_pyg/nn/conv/gat_conv.py | 30 +++++++++++++++---- .../cugraph_pyg/nn/conv/gatv2_conv.py | 18 +++++++---- 4 files changed, 72 insertions(+), 24 deletions(-) diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py index 88bed90e198..e8813271fd8 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatconv.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Optional, Union from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph from cugraph.utilities.utils import import_optional @@ -186,7 +186,10 @@ def forward( nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], efeat: Optional[torch.Tensor] = None, max_in_degree: Optional[int] = None, - **kwargs: Any, + deterministic_dgrad: bool = False, + deterministic_wgrad: bool = False, + high_precision_dgrad: bool = False, + high_precision_wgrad: bool = False, ) -> torch.Tensor: r"""Forward computation. @@ -205,8 +208,20 @@ def forward( from a neighbor sampler, the value should be set to the corresponding :attr:`fanout`. This option is used to invoke the MFG-variant of cugraph-ops kernel. - **kwargs : Any - Additional arguments of `pylibcugraphops.pytorch.operators.mha_gat_n2n`. + deterministic_dgrad : bool, default=False + Optional flag indicating whether the feature gradients + are computed deterministically using a dedicated workspace buffer. + deterministic_wgrad: bool, default=False + Optional flag indicating whether the weight gradients + are computed deterministically using a dedicated workspace buffer. + high_precision_dgrad: bool, default=False + Optional flag indicating whether gradients for inputs in half precision + are kept in single precision as long as possible and only casted to + the corresponding input type at the very end. + high_precision_wgrad: bool, default=False + Optional flag indicating whether gradients for weights in half precision + are kept in single precision as long as possible and only casted to + the corresponding input type at the very end. Returns ------- @@ -235,7 +250,7 @@ def forward( _graph = self.get_cugraph_ops_CSC( g, is_bipartite=bipartite, max_in_degree=max_in_degree ) - if kwargs.get("deterministic_dgrad", False): + if deterministic_dgrad: _graph.add_reverse_graph() if bipartite: @@ -278,7 +293,10 @@ def forward( negative_slope=self.negative_slope, concat_heads=self.concat, edge_feat=efeat, - **kwargs, + deterministic_dgrad=deterministic_dgrad, + deterministic_wgrad=deterministic_wgrad, + high_precision_dgrad=high_precision_dgrad, + high_precision_wgrad=high_precision_wgrad, )[: g.num_dst_nodes()] if self.concat: diff --git a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py index 6e7e3d172d2..4f47005f8ee 100644 --- a/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py +++ b/python/cugraph-dgl/cugraph_dgl/nn/conv/gatv2conv.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Union +from typing import Optional, Union from cugraph_dgl.nn.conv.base import BaseConv, SparseGraph from cugraph.utilities.utils import import_optional @@ -150,7 +150,8 @@ def forward( nfeat: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], efeat: Optional[torch.Tensor] = None, max_in_degree: Optional[int] = None, - **kwargs: Any, + deterministic_dgrad: bool = False, + deterministic_wgrad: bool = False, ) -> torch.Tensor: r"""Forward computation. @@ -167,8 +168,12 @@ def forward( from a neighbor sampler, the value should be set to the corresponding :attr:`fanout`. This option is used to invoke the MFG-variant of cugraph-ops kernel. - **kwargs : Any - Additional arguments of `pylibcugraphops.pytorch.operators.mha_gat_v2_n2n`. + deterministic_dgrad : bool, default=False + Optional flag indicating whether the feature gradients + are computed deterministically using a dedicated workspace buffer. + deterministic_wgrad: bool, default=False + Optional flag indicating whether the weight gradients + are computed deterministically using a dedicated workspace buffer. Returns ------- @@ -199,7 +204,7 @@ def forward( _graph = self.get_cugraph_ops_CSC( g, is_bipartite=graph_bipartite, max_in_degree=max_in_degree ) - if kwargs.get("deterministic_dgrad", False): + if deterministic_dgrad: _graph.add_reverse_graph() if nfeat_bipartite: @@ -233,7 +238,8 @@ def forward( negative_slope=self.negative_slope, concat_heads=self.concat, edge_feat=efeat, - **kwargs, + deterministic_dgrad=deterministic_dgrad, + deterministic_wgrad=deterministic_wgrad, )[: g.num_dst_nodes()] if self.concat: diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py index 9c0f507f9e7..d1785f2bef8 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gat_conv.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple, Union +from typing import Optional, Tuple, Union from cugraph.utilities.utils import import_optional from pylibcugraphops.pytorch.operators import mha_gat_n2n @@ -162,7 +162,10 @@ def forward( csc: Tuple[torch.Tensor, torch.Tensor, int], edge_attr: Optional[torch.Tensor] = None, max_num_neighbors: Optional[int] = None, - **kwargs: Any, + deterministic_dgrad: bool = False, + deterministic_wgrad: bool = False, + high_precision_dgrad: bool = False, + high_precision_wgrad: bool = False, ) -> torch.Tensor: r"""Runs the forward pass of the module. @@ -179,14 +182,26 @@ def forward( of a destination node. When enabled, it allows models to use the message-flow-graph primitives in cugraph-ops. (default: :obj:`None`) - **kwargs : Additional arguments of - `pylibcugraphops.pytorch.operators.mha_gat_n2n`. + deterministic_dgrad : bool, default=False + Optional flag indicating whether the feature gradients + are computed deterministically using a dedicated workspace buffer. + deterministic_wgrad: bool, default=False + Optional flag indicating whether the weight gradients + are computed deterministically using a dedicated workspace buffer. + high_precision_dgrad: bool, default=False + Optional flag indicating whether gradients for inputs in half precision + are kept in single precision as long as possible and only casted to + the corresponding input type at the very end. + high_precision_wgrad: bool, default=False + Optional flag indicating whether gradients for weights in half precision + are kept in single precision as long as possible and only casted to + the corresponding input type at the very end. """ bipartite = not isinstance(x, torch.Tensor) graph = self.get_cugraph( csc, bipartite=bipartite, max_num_neighbors=max_num_neighbors ) - if kwargs.get("deterministic_dgrad", False): + if deterministic_dgrad: graph.add_reverse_graph() if edge_attr is not None: @@ -225,7 +240,10 @@ def forward( negative_slope=self.negative_slope, concat_heads=self.concat, edge_feat=edge_attr, - **kwargs, + deterministic_dgrad=deterministic_dgrad, + deterministic_wgrad=deterministic_wgrad, + high_precision_dgrad=high_precision_dgrad, + high_precision_wgrad=high_precision_wgrad, ) if self.bias is not None: diff --git a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py index ffe1d846cd5..33865898816 100644 --- a/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py +++ b/python/cugraph-pyg/cugraph_pyg/nn/conv/gatv2_conv.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Tuple, Union +from typing import Optional, Tuple, Union from cugraph.utilities.utils import import_optional from pylibcugraphops.pytorch.operators import mha_gat_v2_n2n @@ -174,7 +174,8 @@ def forward( x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], csc: Tuple[torch.Tensor, torch.Tensor, int], edge_attr: Optional[torch.Tensor] = None, - **kwargs: Any, + deterministic_dgrad: bool = False, + deterministic_wgrad: bool = False, ) -> torch.Tensor: r"""Runs the forward pass of the module. @@ -187,12 +188,16 @@ def forward( :meth:`to_csc` method to convert an :obj:`edge_index` representation to the desired format. edge_attr: (torch.Tensor, optional) The edge features. - **kwargs : Additional arguments of - `pylibcugraphops.pytorch.operators.mha_gat_v2_n2n`. + deterministic_dgrad : bool, default=False + Optional flag indicating whether the feature gradients + are computed deterministically using a dedicated workspace buffer. + deterministic_wgrad: bool, default=False + Optional flag indicating whether the weight gradients + are computed deterministically using a dedicated workspace buffer. """ bipartite = not isinstance(x, torch.Tensor) or not self.share_weights graph = self.get_cugraph(csc, bipartite=bipartite) - if kwargs.get("deterministic_dgrad", False): + if deterministic_dgrad: graph.add_reverse_graph() if edge_attr is not None: @@ -222,7 +227,8 @@ def forward( negative_slope=self.negative_slope, concat_heads=self.concat, edge_feat=edge_attr, - **kwargs, + deterministic_dgrad=deterministic_dgrad, + deterministic_wgrad=deterministic_wgrad, ) if self.bias is not None: