@@ -950,13 +950,37 @@ pub fn local_complement(
950
950
Ok ( complement_graph)
951
951
}
952
952
953
+ /// Represents target nodes for path-finding operations.
954
+ ///
955
+ /// This enum allows functions to accept either a single target node
956
+ /// or multiple target nodes, providing flexibility in path-finding algorithms.
957
+ pub enum TargetNodes {
958
+ Single ( NodeIndex ) ,
959
+ Multiple ( HashSet < NodeIndex > ) ,
960
+ }
961
+
962
+ impl < ' py > FromPyObject < ' py > for TargetNodes {
963
+ fn extract_bound ( ob : & Bound < ' py , PyAny > ) -> PyResult < Self > {
964
+ if let Ok ( int) = ob. extract :: < usize > ( ) {
965
+ Ok ( Self :: Single ( NodeIndex :: new ( int) ) )
966
+ } else {
967
+ let mut target_set: HashSet < NodeIndex > = HashSet :: new ( ) ;
968
+ for target in ob. try_iter ( ) ? {
969
+ let target_index = NodeIndex :: new ( target?. extract :: < usize > ( ) ?) ;
970
+ target_set. insert ( target_index) ;
971
+ }
972
+ Ok ( Self :: Multiple ( target_set) )
973
+ }
974
+ }
975
+ }
976
+
953
977
/// Return all simple paths between 2 nodes in a PyGraph object
954
978
///
955
979
/// A simple path is a path with no repeated nodes.
956
980
///
957
981
/// :param PyGraph graph: The graph to find the path in
958
982
/// :param int origin: The node index to find the paths from
959
- /// :param int to: The node index to find the paths to
983
+ /// :param int | iterable[int] to: The node index(es) to find the paths to
960
984
/// :param int min_depth: The minimum depth of the path to include in the output
961
985
/// list of paths. By default all paths are included regardless of depth,
962
986
/// setting to 0 will behave like the default.
@@ -971,7 +995,7 @@ pub fn local_complement(
971
995
pub fn graph_all_simple_paths (
972
996
graph : & graph:: PyGraph ,
973
997
origin : usize ,
974
- to : usize ,
998
+ to : TargetNodes ,
975
999
min_depth : Option < usize > ,
976
1000
cutoff : Option < usize > ,
977
1001
) -> PyResult < Vec < Vec < usize > > > {
@@ -981,27 +1005,48 @@ pub fn graph_all_simple_paths(
981
1005
"The input index for 'from' is not a valid node index" ,
982
1006
) ) ;
983
1007
}
984
- let to_index = NodeIndex :: new ( to) ;
985
- if !graph. graph . contains_node ( to_index) {
986
- return Err ( InvalidNode :: new_err (
987
- "The input index for 'to' is not a valid node index" ,
988
- ) ) ;
989
- }
990
1008
let min_intermediate_nodes: usize = match min_depth {
991
1009
Some ( 0 ) | None => 0 ,
992
1010
Some ( depth) => depth - 2 ,
993
1011
} ;
994
1012
let cutoff_petgraph: Option < usize > = cutoff. map ( |depth| depth - 2 ) ;
995
- let result: Vec < Vec < usize > > = algo:: all_simple_paths :: < Vec < _ > , _ , foldhash:: fast:: RandomState > (
996
- & graph. graph ,
997
- from_index,
998
- to_index,
999
- min_intermediate_nodes,
1000
- cutoff_petgraph,
1001
- )
1002
- . map ( |v : Vec < NodeIndex > | v. into_iter ( ) . map ( |i| i. index ( ) ) . collect ( ) )
1003
- . collect ( ) ;
1004
- Ok ( result)
1013
+
1014
+ match to {
1015
+ TargetNodes :: Single ( to_index) => {
1016
+ if !graph. graph . contains_node ( to_index) {
1017
+ return Err ( InvalidNode :: new_err (
1018
+ "The input index for 'to' is not a valid node index" ,
1019
+ ) ) ;
1020
+ }
1021
+
1022
+ let result: Vec < Vec < usize > > =
1023
+ algo:: all_simple_paths :: < Vec < _ > , _ , foldhash:: fast:: RandomState > (
1024
+ & graph. graph ,
1025
+ from_index,
1026
+ to_index,
1027
+ min_intermediate_nodes,
1028
+ cutoff_petgraph,
1029
+ )
1030
+ . map ( |v : Vec < NodeIndex > | v. into_iter ( ) . map ( |i| i. index ( ) ) . collect ( ) )
1031
+ . collect ( ) ;
1032
+ Ok ( result)
1033
+ }
1034
+ TargetNodes :: Multiple ( target_set) => {
1035
+ let result = connectivity:: all_simple_paths_multiple_targets (
1036
+ & graph. graph ,
1037
+ from_index,
1038
+ & target_set,
1039
+ min_intermediate_nodes,
1040
+ cutoff_petgraph,
1041
+ ) ;
1042
+
1043
+ Ok ( result
1044
+ . into_values ( )
1045
+ . flatten ( )
1046
+ . map ( |path| path. into_iter ( ) . map ( |node| node. index ( ) ) . collect ( ) )
1047
+ . collect ( ) )
1048
+ }
1049
+ }
1005
1050
}
1006
1051
1007
1052
/// Return all simple paths between 2 nodes in a PyDiGraph object
@@ -1010,7 +1055,7 @@ pub fn graph_all_simple_paths(
1010
1055
///
1011
1056
/// :param PyDiGraph graph: The graph to find the path in
1012
1057
/// :param int origin: The node index to find the paths from
1013
- /// :param int to: The node index to find the paths to
1058
+ /// :param int | iterable[int] to: The node index(es) to find the paths to
1014
1059
/// :param int min_depth: The minimum depth of the path to include in the output
1015
1060
/// list of paths. By default all paths are included regardless of depth,
1016
1061
/// setting to 0 will behave like the default.
@@ -1025,7 +1070,7 @@ pub fn graph_all_simple_paths(
1025
1070
pub fn digraph_all_simple_paths (
1026
1071
graph : & digraph:: PyDiGraph ,
1027
1072
origin : usize ,
1028
- to : usize ,
1073
+ to : TargetNodes ,
1029
1074
min_depth : Option < usize > ,
1030
1075
cutoff : Option < usize > ,
1031
1076
) -> PyResult < Vec < Vec < usize > > > {
@@ -1035,27 +1080,48 @@ pub fn digraph_all_simple_paths(
1035
1080
"The input index for 'from' is not a valid node index" ,
1036
1081
) ) ;
1037
1082
}
1038
- let to_index = NodeIndex :: new ( to) ;
1039
- if !graph. graph . contains_node ( to_index) {
1040
- return Err ( InvalidNode :: new_err (
1041
- "The input index for 'to' is not a valid node index" ,
1042
- ) ) ;
1043
- }
1044
1083
let min_intermediate_nodes: usize = match min_depth {
1045
1084
Some ( 0 ) | None => 0 ,
1046
1085
Some ( depth) => depth - 2 ,
1047
1086
} ;
1048
1087
let cutoff_petgraph: Option < usize > = cutoff. map ( |depth| depth - 2 ) ;
1049
- let result: Vec < Vec < usize > > = algo:: all_simple_paths :: < Vec < _ > , _ , foldhash:: fast:: RandomState > (
1050
- & graph. graph ,
1051
- from_index,
1052
- to_index,
1053
- min_intermediate_nodes,
1054
- cutoff_petgraph,
1055
- )
1056
- . map ( |v : Vec < NodeIndex > | v. into_iter ( ) . map ( |i| i. index ( ) ) . collect ( ) )
1057
- . collect ( ) ;
1058
- Ok ( result)
1088
+
1089
+ match to {
1090
+ TargetNodes :: Single ( to_index) => {
1091
+ if !graph. graph . contains_node ( to_index) {
1092
+ return Err ( InvalidNode :: new_err (
1093
+ "The input index for 'to' is not a valid node index" ,
1094
+ ) ) ;
1095
+ }
1096
+
1097
+ let result: Vec < Vec < usize > > =
1098
+ algo:: all_simple_paths :: < Vec < _ > , _ , foldhash:: fast:: RandomState > (
1099
+ & graph. graph ,
1100
+ from_index,
1101
+ to_index,
1102
+ min_intermediate_nodes,
1103
+ cutoff_petgraph,
1104
+ )
1105
+ . map ( |v : Vec < NodeIndex > | v. into_iter ( ) . map ( |i| i. index ( ) ) . collect ( ) )
1106
+ . collect ( ) ;
1107
+ Ok ( result)
1108
+ }
1109
+ TargetNodes :: Multiple ( target_set) => {
1110
+ let result = connectivity:: all_simple_paths_multiple_targets (
1111
+ & graph. graph ,
1112
+ from_index,
1113
+ & target_set,
1114
+ min_intermediate_nodes,
1115
+ cutoff_petgraph,
1116
+ ) ;
1117
+
1118
+ Ok ( result
1119
+ . into_values ( )
1120
+ . flatten ( )
1121
+ . map ( |path| path. into_iter ( ) . map ( |node| node. index ( ) ) . collect ( ) )
1122
+ . collect ( ) )
1123
+ }
1124
+ }
1059
1125
}
1060
1126
1061
1127
/// Return all the simple paths between all pairs of nodes in the graph
0 commit comments