Skip to content

Commit

Permalink
Fix pos_list api for force atlas
Browse files Browse the repository at this point in the history
  • Loading branch information
hlinsen committed Mar 6, 2024
1 parent 37d5d48 commit 85cc611
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions python/cugraph/cugraph/layout/force_atlas2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from cugraph.layout import force_atlas2_wrapper
from cugraph.utilities import ensure_cugraph_obj_for_nx
import cudf


def force_atlas2(
Expand Down Expand Up @@ -55,8 +56,8 @@ def force_atlas2(
Above 1000 iterations is discouraged.
pos_list: cudf.DataFrame, optional (default=None)
Data frame with initial vertex positions containing two columns:
'x' and 'y' positions.
Data frame with initial vertex positions containing three columns:
'vertex', 'x' and 'y' positions.
outbound_attraction_distribution: bool, optional (default=True)
Distributes attraction along outbound edges.
Expand Down Expand Up @@ -131,6 +132,10 @@ def on_train_end(self, positions):
input_graph, isNx = ensure_cugraph_obj_for_nx(input_graph)

if pos_list is not None:
if not isinstance(pos_list, cudf.DataFrame):
raise TypeError('pos_list should be a cudf.DataFrame')
if set(pos_list.columns) != set(['x', 'y', 'vertex']):
raise ValueError('pos_list has wrong column names')
if input_graph.renumbered is True:
if input_graph.vertex_column_size() > 1:
cols = pos_list.columns[:-2].to_list()
Expand Down

0 comments on commit 85cc611

Please sign in to comment.