
Commit 0ce560d

feat: add SlurmClusterManager support

1 parent a23b847

File tree: 3 files changed, +45 -26 lines


pysr/julia_extensions.py

Lines changed: 22 additions & 13 deletions
@@ -4,10 +4,20 @@
 
 from typing import Literal
 
+from .julia_helpers import KNOWN_CLUSTERMANAGER_BACKENDS
 from .julia_import import Pkg, jl
 from .julia_registry_helpers import try_with_registry_fallback
 from .logger_specs import AbstractLoggerSpec, TensorBoardLoggerSpec
 
+PACKAGE_UUIDS = {
+    "LoopVectorization": "bdcacae8-1622-11e9-2a5c-532679323890",
+    "Bumper": "8ce10254-0962-460f-a3d8-1f77fea1446e",
+    "Zygote": "e88e6eb3-aa80-5325-afca-941959d7151f",
+    "SlurmClusterManager": "c82cd089-7bf7-41d7-976b-6b5d413cbe0a",
+    "ClusterManagers": "34f1f09b-3a8b-5176-ab39-66d58a4d544e",
+    "TensorBoardLogger": "899adc3e-224a-11e9-021f-63837185c80f",
+}
+
 
 def load_required_packages(
     *,
@@ -18,26 +28,24 @@ def load_required_packages(
     logger_spec: AbstractLoggerSpec | None = None,
 ):
     if turbo:
-        load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890")
+        load_package("LoopVectorization")
     if bumper:
-        load_package("Bumper", "8ce10254-0962-460f-a3d8-1f77fea1446e")
+        load_package("Bumper")
     if autodiff_backend is not None:
-        load_package("Zygote", "e88e6eb3-aa80-5325-afca-941959d7151f")
+        load_package("Zygote")
     if cluster_manager is not None:
-        load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e")
+        if cluster_manager == "slurm_native":
+            load_package("SlurmClusterManager")
+        elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS:
+            load_package("ClusterManagers")
     if isinstance(logger_spec, TensorBoardLoggerSpec):
-        load_package("TensorBoardLogger", "899adc3e-224a-11e9-021f-63837185c80f")
+        load_package("TensorBoardLogger")
 
 
 def load_all_packages():
     """Install and load all Julia extensions available to PySR."""
-    load_required_packages(
-        turbo=True,
-        bumper=True,
-        autodiff_backend="Zygote",
-        cluster_manager="slurm",
-        logger_spec=TensorBoardLoggerSpec(log_dir="logs"),
-    )
+    for package_name, uuid_s in PACKAGE_UUIDS.items():
+        load_package(package_name, uuid_s)
 
 
 # TODO: Refactor this file so we can install all packages at once using `juliapkg`,
@@ -48,7 +56,8 @@ def isinstalled(uuid_s: str):
     return jl.haskey(Pkg.dependencies(), jl.Base.UUID(uuid_s))
 
 
-def load_package(package_name: str, uuid_s: str) -> None:
+def load_package(package_name: str, uuid_s: str | None = None) -> None:
+    uuid_s = uuid_s or PACKAGE_UUIDS[package_name]
     if not isinstalled(uuid_s):
 
         def _add_package():
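The dispatch above maps the `cluster_manager` value onto a Julia extension: "slurm_native" pulls in SlurmClusterManager, the legacy backends keep using ClusterManagers, and any other value loads nothing extra. A minimal standalone sketch of that selection logic (not part of the diff; the helper name `required_cluster_package` is hypothetical and the Julia-side `load_package` call is omitted):

# Illustrative sketch only: mirrors the new dispatch in load_required_packages
# without touching Julia.
KNOWN_CLUSTERMANAGER_BACKENDS = ["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"]


def required_cluster_package(cluster_manager: str | None) -> str | None:
    """Return the Julia extension implied by a given cluster_manager value."""
    if cluster_manager is None:
        return None
    if cluster_manager == "slurm_native":
        return "SlurmClusterManager"  # native Julia Slurm manager
    if cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS:
        return "ClusterManagers"  # legacy addprocs_* helpers
    return None  # assumed to be a user-supplied Julia function; nothing extra to load


assert required_cluster_package("slurm_native") == "SlurmClusterManager"
assert required_cluster_package("pbs") == "ClusterManagers"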

pysr/julia_helpers.py

Lines changed: 15 additions & 3 deletions
@@ -29,9 +29,21 @@ def _escape_filename(filename):
     return str_repr
 
 
-def _load_cluster_manager(cluster_manager: str):
-    jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
-    return jl.seval(f"addprocs_{cluster_manager}")
+KNOWN_CLUSTERMANAGER_BACKENDS = ["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"]
+
+
+def load_cluster_manager(cluster_manager: str) -> AnyValue:
+    if cluster_manager == "slurm_native":
+        jl.seval("using SlurmClusterManager: SlurmManager")
+        # TODO: Is this the right way to do this?
+        jl.seval("addprocs_slurm_native(; _...) = addprocs(SlurmManager())")
+        return jl.addprocs_slurm_native
+    elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS:
+        jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
+        return jl.seval(f"addprocs_{cluster_manager}")
+    else:
+        # Assume it's a function
+        return jl.seval(cluster_manager)
 
 
 def jl_array(x, dtype=None):
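A rough usage sketch of the three branches in `load_cluster_manager` above, assuming a configured PySR/juliacall environment with the relevant Julia packages installed; `my_addprocs` is a hypothetical Julia function, not something PySR provides:

# Hypothetical usage of the new load_cluster_manager (post-commit pysr.julia_helpers).
from pysr.julia_helpers import load_cluster_manager

addprocs_native = load_cluster_manager("slurm_native")  # SlurmClusterManager.SlurmManager
addprocs_pbs = load_cluster_manager("pbs")              # ClusterManagers.addprocs_pbs
# Any other string falls through to jl.seval and is assumed to evaluate to an
# addprocs-like Julia function, e.g. one already defined on the Julia side:
addprocs_custom = load_cluster_manager("my_addprocs")   # hypothetical Julia function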

pysr/sr.py

Lines changed: 8 additions & 10 deletions
@@ -45,11 +45,11 @@
 from .julia_extensions import load_required_packages
 from .julia_helpers import (
     _escape_filename,
-    _load_cluster_manager,
     jl_array,
     jl_deserialize,
     jl_is_function,
     jl_serialize,
+    load_cluster_manager,
 )
 from .julia_import import AnyValue, SymbolicRegression, VectorValue, jl
 from .logger_specs import AbstractLoggerSpec
@@ -574,8 +574,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
         Default is `None`.
     cluster_manager : str
         For distributed computing, this sets the job queue system. Set
-        to one of "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", or
-        "htc". If set to one of these, PySR will run in distributed
+        to one of "slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld",
+        or "htc". If set to one of these, PySR will run in distributed
         mode, and use `procs` to figure out how many processes to launch.
         Default is `None`.
     heap_size_hint_in_bytes : int
@@ -876,13 +876,11 @@ def __init__(
         probability_negate_constant: float = 0.00743,
         tournament_selection_n: int = 15,
         tournament_selection_p: float = 0.982,
-        parallelism: (
-            Literal["serial", "multithreading", "multiprocessing"] | None
-        ) = None,
+        # fmt: off
+        parallelism: Literal["serial", "multithreading", "multiprocessing"] | None = None,
         procs: int | None = None,
-        cluster_manager: (
-            Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | None
-        ) = None,
+        cluster_manager: Literal["slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | str | None = None,
+        # fmt: on
         heap_size_hint_in_bytes: int | None = None,
         batching: bool = False,
         batch_size: int = 50,
@@ -1880,7 +1878,7 @@ def _run(
                 raise ValueError(
                     "To use cluster managers, you must set `parallelism='multiprocessing'`."
                 )
-            cluster_manager = _load_cluster_manager(cluster_manager)
+            cluster_manager = load_cluster_manager(cluster_manager)
 
         # TODO(mcranmer): These functions should be part of this class.
         binary_operators, unary_operators = _maybe_create_inline_operators(
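End to end, the new backend is selected from the estimator as in the sketch below. This is a hedged example, not from the diff: the data and search settings are placeholders, and the script is assumed to run inside a Slurm allocation so SlurmManager can discover its workers (the `addprocs_slurm_native` shim above discards keyword arguments, so the worker count comes from the allocation rather than `procs`):

# Hypothetical end-to-end usage of the new "slurm_native" cluster manager.
import numpy as np
from pysr import PySRRegressor

X = np.random.randn(100, 3)
y = 2.0 * np.cos(X[:, 0]) + X[:, 1] ** 2

model = PySRRegressor(
    niterations=40,                  # placeholder search budget
    binary_operators=["+", "*", "-"],
    unary_operators=["cos"],
    parallelism="multiprocessing",   # required whenever cluster_manager is set
    cluster_manager="slurm_native",  # new: uses SlurmClusterManager.SlurmManager
)
model.fit(X, y)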
