diff --git a/textworld/render/graph.py b/textworld/render/graph.py index 75e09ad7..88db1c70 100644 --- a/textworld/render/graph.py +++ b/textworld/render/graph.py @@ -88,7 +88,10 @@ def show_graph(facts: Iterable[Proposition], G = build_graph_from_facts(facts) plt.figure(figsize=(16, 9)) - pos = nx.drawing.nx_pydot.pydot_layout(G, prog="fdp") + # To avoid name issue with nx_pydot, we convert node labels to integers. + H = nx.convert_node_labels_to_integers(G, label_attribute='node_label') + H_layout = nx.drawing.nx_pydot.pydot_layout(H, prog="fdp") + pos = {H.nodes[n]['node_label']: p for n, p in H_layout.items()} edge_labels_pos = {} trace3_list = []