44
55from typing import Literal
66
7+ from .julia_helpers import KNOWN_CLUSTERMANAGER_BACKENDS
78from .julia_import import Pkg , jl
89from .julia_registry_helpers import try_with_registry_fallback
910from .logger_specs import AbstractLoggerSpec , TensorBoardLoggerSpec
1011
12+ PACKAGE_UUIDS = {
13+ "LoopVectorization" : "bdcacae8-1622-11e9-2a5c-532679323890" ,
14+ "Bumper" : "8ce10254-0962-460f-a3d8-1f77fea1446e" ,
15+ "Zygote" : "e88e6eb3-aa80-5325-afca-941959d7151f" ,
16+ "SlurmClusterManager" : "c82cd089-7bf7-41d7-976b-6b5d413cbe0a" ,
17+ "ClusterManagers" : "34f1f09b-3a8b-5176-ab39-66d58a4d544e" ,
18+ "TensorBoardLogger" : "899adc3e-224a-11e9-021f-63837185c80f" ,
19+ }
20+
1121
1222def load_required_packages (
1323 * ,
@@ -18,26 +28,24 @@ def load_required_packages(
1828 logger_spec : AbstractLoggerSpec | None = None ,
1929):
2030 if turbo :
21- load_package ("LoopVectorization" , "bdcacae8-1622-11e9-2a5c-532679323890" )
31+ load_package ("LoopVectorization" )
2232 if bumper :
23- load_package ("Bumper" , "8ce10254-0962-460f-a3d8-1f77fea1446e" )
33+ load_package ("Bumper" )
2434 if autodiff_backend is not None :
25- load_package ("Zygote" , "e88e6eb3-aa80-5325-afca-941959d7151f" )
35+ load_package ("Zygote" )
2636 if cluster_manager is not None :
27- load_package ("ClusterManagers" , "34f1f09b-3a8b-5176-ab39-66d58a4d544e" )
37+ if cluster_manager == "slurm_native" :
38+ load_package ("SlurmClusterManager" )
39+ elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS :
40+ load_package ("ClusterManagers" )
2841 if isinstance (logger_spec , TensorBoardLoggerSpec ):
29- load_package ("TensorBoardLogger" , "899adc3e-224a-11e9-021f-63837185c80f" )
42+ load_package ("TensorBoardLogger" )
3043
3144
3245def load_all_packages ():
3346 """Install and load all Julia extensions available to PySR."""
34- load_required_packages (
35- turbo = True ,
36- bumper = True ,
37- autodiff_backend = "Zygote" ,
38- cluster_manager = "slurm" ,
39- logger_spec = TensorBoardLoggerSpec (log_dir = "logs" ),
40- )
47+ for package_name , uuid_s in PACKAGE_UUIDS .items ():
48+ load_package (package_name , uuid_s )
4149
4250
4351# TODO: Refactor this file so we can install all packages at once using `juliapkg`,
@@ -48,7 +56,8 @@ def isinstalled(uuid_s: str):
4856 return jl .haskey (Pkg .dependencies (), jl .Base .UUID (uuid_s ))
4957
5058
51- def load_package (package_name : str , uuid_s : str ) -> None :
59+ def load_package (package_name : str , uuid_s : str | None = None ) -> None :
60+ uuid_s = uuid_s or PACKAGE_UUIDS [package_name ]
5261 if not isinstalled (uuid_s ):
5362
5463 def _add_package ():
0 commit comments