Skip to content

Commit f0240d4

Browse files
authored
Allow multiple targets for rustworkx.all_simple_paths (#1488)
* Enable multiple targets passed as an iterable to rustworkx.all_simple_paths * Fix formatting issues * Use enum to represent single/multiple target nodes, add a release note
1 parent 949d30b commit f0240d4

File tree

7 files changed

+212
-41
lines changed

7 files changed

+212
-41
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
---
2+
features:
3+
- |
4+
The :func:`~rustworkx.all_simple_paths` function now supports multiple
5+
target nodes passed as an iterable. You can now pass either a single target
6+
node (int) or multiple target nodes (iterable of ints) to find all simple
7+
paths from a source node to any of the specified targets. For example::
8+
9+
import rustworkx as rx
10+
11+
graph = rx.generators.path_graph(4)
12+
# Multiple targets - new functionality
13+
paths = rx.all_simple_paths(graph, 0, [2, 3])
14+
# paths: [[0, 1, 2], [0, 1, 2, 3]]
15+
16+
This enhancement maintains backward compatibility while providing more
17+
flexibility for pathfinding operations in both :class:`~.PyGraph` and
18+
:class:`~.PyDiGraph` objects.

rustworkx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def all_simple_paths(graph, from_, to, min_depth=None, cutoff=None):
263263
:param graph: The graph to find the path in. Can either be a
264264
class:`~rustworkx.PyGraph` or :class:`~rustworkx.PyDiGraph`
265265
:param int from_: The node index to find the paths from
266-
:param int to: The node index to find the paths to
266+
:param int | Iterable[int] to: The node index(es) to find the paths to
267267
:param int min_depth: The minimum depth of the path to include in the
268268
output list of paths. By default all paths are included regardless of
269269
depth, setting to 0 will behave like the default.

rustworkx/__init__.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ import numpy as np
1414
import numpy.typing as npt
1515

1616
from typing import Generic, Any, Callable, overload
17-
from collections.abc import Iterator, Sequence
17+
from collections.abc import Iterable, Iterator, Sequence
1818

1919
if sys.version_info >= (3, 13):
2020
from typing import TypeVar
@@ -325,7 +325,7 @@ def adjacency_matrix(
325325
def all_simple_paths(
326326
graph: PyGraph | PyDiGraph,
327327
from_: int,
328-
to: int,
328+
to: int | Iterable[int],
329329
min_depth: int | None = ...,
330330
cutoff: int | None = ...,
331331
) -> list[list[int]]: ...

rustworkx/rustworkx.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,15 +285,15 @@ def local_complement(
285285
def digraph_all_simple_paths(
286286
graph: PyDiGraph,
287287
origin: int,
288-
to: int,
288+
to: int | Iterable[int],
289289
/,
290290
min_depth: int | None = ...,
291291
cutoff: int | None = ...,
292292
) -> list[list[int]]: ...
293293
def graph_all_simple_paths(
294294
graph: PyGraph,
295295
origin: int,
296-
to: int,
296+
to: int | Iterable[int],
297297
/,
298298
min_depth: int | None = ...,
299299
cutoff: int | None = ...,

src/connectivity/mod.rs

Lines changed: 102 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -950,13 +950,37 @@ pub fn local_complement(
950950
Ok(complement_graph)
951951
}
952952

953+
/// Represents target nodes for path-finding operations.
954+
///
955+
/// This enum allows functions to accept either a single target node
956+
/// or multiple target nodes, providing flexibility in path-finding algorithms.
957+
pub enum TargetNodes {
958+
Single(NodeIndex),
959+
Multiple(HashSet<NodeIndex>),
960+
}
961+
962+
impl<'py> FromPyObject<'py> for TargetNodes {
963+
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
964+
if let Ok(int) = ob.extract::<usize>() {
965+
Ok(Self::Single(NodeIndex::new(int)))
966+
} else {
967+
let mut target_set: HashSet<NodeIndex> = HashSet::new();
968+
for target in ob.try_iter()? {
969+
let target_index = NodeIndex::new(target?.extract::<usize>()?);
970+
target_set.insert(target_index);
971+
}
972+
Ok(Self::Multiple(target_set))
973+
}
974+
}
975+
}
976+
953977
/// Return all simple paths between 2 nodes in a PyGraph object
954978
///
955979
/// A simple path is a path with no repeated nodes.
956980
///
957981
/// :param PyGraph graph: The graph to find the path in
958982
/// :param int origin: The node index to find the paths from
959-
/// :param int to: The node index to find the paths to
983+
/// :param int | iterable[int] to: The node index(es) to find the paths to
960984
/// :param int min_depth: The minimum depth of the path to include in the output
961985
/// list of paths. By default all paths are included regardless of depth,
962986
/// setting to 0 will behave like the default.
@@ -971,7 +995,7 @@ pub fn local_complement(
971995
pub fn graph_all_simple_paths(
972996
graph: &graph::PyGraph,
973997
origin: usize,
974-
to: usize,
998+
to: TargetNodes,
975999
min_depth: Option<usize>,
9761000
cutoff: Option<usize>,
9771001
) -> PyResult<Vec<Vec<usize>>> {
@@ -981,27 +1005,48 @@ pub fn graph_all_simple_paths(
9811005
"The input index for 'from' is not a valid node index",
9821006
));
9831007
}
984-
let to_index = NodeIndex::new(to);
985-
if !graph.graph.contains_node(to_index) {
986-
return Err(InvalidNode::new_err(
987-
"The input index for 'to' is not a valid node index",
988-
));
989-
}
9901008
let min_intermediate_nodes: usize = match min_depth {
9911009
Some(0) | None => 0,
9921010
Some(depth) => depth - 2,
9931011
};
9941012
let cutoff_petgraph: Option<usize> = cutoff.map(|depth| depth - 2);
995-
let result: Vec<Vec<usize>> = algo::all_simple_paths::<Vec<_>, _, foldhash::fast::RandomState>(
996-
&graph.graph,
997-
from_index,
998-
to_index,
999-
min_intermediate_nodes,
1000-
cutoff_petgraph,
1001-
)
1002-
.map(|v: Vec<NodeIndex>| v.into_iter().map(|i| i.index()).collect())
1003-
.collect();
1004-
Ok(result)
1013+
1014+
match to {
1015+
TargetNodes::Single(to_index) => {
1016+
if !graph.graph.contains_node(to_index) {
1017+
return Err(InvalidNode::new_err(
1018+
"The input index for 'to' is not a valid node index",
1019+
));
1020+
}
1021+
1022+
let result: Vec<Vec<usize>> =
1023+
algo::all_simple_paths::<Vec<_>, _, foldhash::fast::RandomState>(
1024+
&graph.graph,
1025+
from_index,
1026+
to_index,
1027+
min_intermediate_nodes,
1028+
cutoff_petgraph,
1029+
)
1030+
.map(|v: Vec<NodeIndex>| v.into_iter().map(|i| i.index()).collect())
1031+
.collect();
1032+
Ok(result)
1033+
}
1034+
TargetNodes::Multiple(target_set) => {
1035+
let result = connectivity::all_simple_paths_multiple_targets(
1036+
&graph.graph,
1037+
from_index,
1038+
&target_set,
1039+
min_intermediate_nodes,
1040+
cutoff_petgraph,
1041+
);
1042+
1043+
Ok(result
1044+
.into_values()
1045+
.flatten()
1046+
.map(|path| path.into_iter().map(|node| node.index()).collect())
1047+
.collect())
1048+
}
1049+
}
10051050
}
10061051

10071052
/// Return all simple paths between 2 nodes in a PyDiGraph object
@@ -1010,7 +1055,7 @@ pub fn graph_all_simple_paths(
10101055
///
10111056
/// :param PyDiGraph graph: The graph to find the path in
10121057
/// :param int origin: The node index to find the paths from
1013-
/// :param int to: The node index to find the paths to
1058+
/// :param int | iterable[int] to: The node index(es) to find the paths to
10141059
/// :param int min_depth: The minimum depth of the path to include in the output
10151060
/// list of paths. By default all paths are included regardless of depth,
10161061
/// setting to 0 will behave like the default.
@@ -1025,7 +1070,7 @@ pub fn graph_all_simple_paths(
10251070
pub fn digraph_all_simple_paths(
10261071
graph: &digraph::PyDiGraph,
10271072
origin: usize,
1028-
to: usize,
1073+
to: TargetNodes,
10291074
min_depth: Option<usize>,
10301075
cutoff: Option<usize>,
10311076
) -> PyResult<Vec<Vec<usize>>> {
@@ -1035,27 +1080,48 @@ pub fn digraph_all_simple_paths(
10351080
"The input index for 'from' is not a valid node index",
10361081
));
10371082
}
1038-
let to_index = NodeIndex::new(to);
1039-
if !graph.graph.contains_node(to_index) {
1040-
return Err(InvalidNode::new_err(
1041-
"The input index for 'to' is not a valid node index",
1042-
));
1043-
}
10441083
let min_intermediate_nodes: usize = match min_depth {
10451084
Some(0) | None => 0,
10461085
Some(depth) => depth - 2,
10471086
};
10481087
let cutoff_petgraph: Option<usize> = cutoff.map(|depth| depth - 2);
1049-
let result: Vec<Vec<usize>> = algo::all_simple_paths::<Vec<_>, _, foldhash::fast::RandomState>(
1050-
&graph.graph,
1051-
from_index,
1052-
to_index,
1053-
min_intermediate_nodes,
1054-
cutoff_petgraph,
1055-
)
1056-
.map(|v: Vec<NodeIndex>| v.into_iter().map(|i| i.index()).collect())
1057-
.collect();
1058-
Ok(result)
1088+
1089+
match to {
1090+
TargetNodes::Single(to_index) => {
1091+
if !graph.graph.contains_node(to_index) {
1092+
return Err(InvalidNode::new_err(
1093+
"The input index for 'to' is not a valid node index",
1094+
));
1095+
}
1096+
1097+
let result: Vec<Vec<usize>> =
1098+
algo::all_simple_paths::<Vec<_>, _, foldhash::fast::RandomState>(
1099+
&graph.graph,
1100+
from_index,
1101+
to_index,
1102+
min_intermediate_nodes,
1103+
cutoff_petgraph,
1104+
)
1105+
.map(|v: Vec<NodeIndex>| v.into_iter().map(|i| i.index()).collect())
1106+
.collect();
1107+
Ok(result)
1108+
}
1109+
TargetNodes::Multiple(target_set) => {
1110+
let result = connectivity::all_simple_paths_multiple_targets(
1111+
&graph.graph,
1112+
from_index,
1113+
&target_set,
1114+
min_intermediate_nodes,
1115+
cutoff_petgraph,
1116+
);
1117+
1118+
Ok(result
1119+
.into_values()
1120+
.flatten()
1121+
.map(|path| path.into_iter().map(|node| node.index()).collect())
1122+
.collect())
1123+
}
1124+
}
10591125
}
10601126

10611127
/// Return all the simple paths between all pairs of nodes in the graph

tests/digraph/test_all_simple_paths.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,45 @@ def test_graph_digraph_all_simple_paths(self):
131131
dag.add_node(1)
132132
self.assertRaises(TypeError, rustworkx.digraph_all_simple_paths, (dag, 0, 1))
133133

134+
def test_all_simple_paths_multiple_targets(self):
135+
graph = rustworkx.generators.directed_path_graph(4)
136+
graph.add_edge(1, 3, None)
137+
paths = rustworkx.digraph_all_simple_paths(graph, 0, [2, 3])
138+
expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 3]]
139+
self.assertEqual(len(expected), len(paths))
140+
for i in expected:
141+
self.assertIn(i, paths)
142+
143+
def test_all_simple_paths_multiple_targets_iterables(self):
144+
graph = rustworkx.generators.directed_path_graph(4)
145+
graph.add_edge(1, 3, None)
146+
paths = rustworkx.digraph_all_simple_paths(graph, 0, iter([2, 3]))
147+
expected = [[0, 1, 2], [0, 1, 2, 3], [0, 1, 3]]
148+
self.assertEqual(len(expected), len(paths))
149+
for i in expected:
150+
self.assertIn(i, paths)
151+
152+
def test_all_simple_paths_multiple_targets_invalid_type(self):
153+
graph = rustworkx.generators.directed_path_graph(4)
154+
with self.assertRaises(TypeError):
155+
rustworkx.digraph_all_simple_paths(graph, 0, [2, "a"])
156+
157+
def test_all_simple_paths_multiple_targets_invalid_index(self):
158+
graph = rustworkx.generators.directed_path_graph(4)
159+
paths = rustworkx.digraph_all_simple_paths(graph, 0, [3, 100])
160+
expected = [[0, 1, 2, 3]]
161+
self.assertEqual(expected, paths)
162+
163+
def test_all_simple_paths_on_nontrivial_graph(self):
164+
graph = rustworkx.PyDiGraph()
165+
graph.add_nodes_from(range(6))
166+
graph.add_edges_from_no_data([(0, 1), (0, 5), (1, 2), (1, 5), (2, 3), (3, 4), (4, 5)])
167+
paths = rustworkx.digraph_all_simple_paths(graph, 0, [2, 3])
168+
expected = [[0, 1, 2], [0, 1, 2, 3]]
169+
self.assertEqual(len(expected), len(paths))
170+
for i in expected:
171+
self.assertIn(i, paths)
172+
134173

135174
class TestDiGraphAllSimplePathsAllPairs(unittest.TestCase):
136175
def setUp(self):

tests/graph/test_all_simple_paths.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,54 @@ def test_digraph_graph_all_simple_paths(self):
232232
dag.add_node(1)
233233
self.assertRaises(TypeError, rustworkx.graph_all_simple_paths, (dag, 0, 1))
234234

235+
def test_all_simple_paths_multiple_targets(self):
236+
graph = rustworkx.generators.path_graph(4)
237+
graph.add_edge(1, 3, None)
238+
paths = rustworkx.graph_all_simple_paths(graph, 0, [2, 3])
239+
expected = [[0, 1, 2], [0, 1, 3], [0, 1, 2, 3], [0, 1, 3, 2]]
240+
self.assertEqual(len(expected), len(paths))
241+
for i in expected:
242+
self.assertIn(i, paths)
243+
244+
def test_all_simple_paths_multiple_targets_iterables(self):
245+
graph = rustworkx.generators.path_graph(4)
246+
graph.add_edge(1, 3, None)
247+
paths = rustworkx.graph_all_simple_paths(graph, 0, iter([2, 3]))
248+
expected = [[0, 1, 2], [0, 1, 3], [0, 1, 2, 3], [0, 1, 3, 2]]
249+
self.assertEqual(len(expected), len(paths))
250+
for i in expected:
251+
self.assertIn(i, paths)
252+
253+
def test_all_simple_paths_multiple_targets_invalid_type(self):
254+
graph = rustworkx.generators.path_graph(4)
255+
with self.assertRaises(TypeError):
256+
rustworkx.graph_all_simple_paths(graph, 0, [2, "a"])
257+
258+
def test_all_simple_paths_multiple_targets_invalid_index(self):
259+
graph = rustworkx.generators.path_graph(4)
260+
paths = rustworkx.graph_all_simple_paths(graph, 0, [3, 100])
261+
expected = [[0, 1, 2, 3]]
262+
self.assertEqual(expected, paths)
263+
264+
def test_all_simple_paths_on_nontrivial_graph(self):
265+
graph = rustworkx.PyGraph()
266+
graph.add_nodes_from(range(6))
267+
graph.add_edges_from_no_data([(0, 1), (0, 5), (1, 2), (1, 5), (2, 3), (3, 4), (4, 5)])
268+
paths = rustworkx.graph_all_simple_paths(graph, 0, [2, 3])
269+
expected = [
270+
[0, 1, 2],
271+
[0, 1, 2, 3],
272+
[0, 1, 5, 4, 3],
273+
[0, 1, 5, 4, 3, 2],
274+
[0, 5, 1, 2],
275+
[0, 5, 1, 2, 3],
276+
[0, 5, 4, 3],
277+
[0, 5, 4, 3, 2],
278+
]
279+
self.assertEqual(len(expected), len(paths))
280+
for i in expected:
281+
self.assertIn(i, paths)
282+
235283

236284
class TestGraphAllSimplePathsAllPairs(unittest.TestCase):
237285
def setUp(self):

0 commit comments

Comments
 (0)