Skip to content

Commit 03d8132

Browse files
fix recompile in distributed mode with given path
1 parent 31757ba commit 03d8132

File tree

4 files changed

+18
-16
lines changed

4 files changed

+18
-16
lines changed

examples/multi_host/multicontroller_vqe_with_path.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,17 @@ def run_vqe_main(coordinator_address: str, num_processes: int, process_id: int):
9898
# The contractor will use this concrete array to run its (now internal)
9999
# "find path on 0 and broadcast" logic.},
100100

101+
# Shard the parameters onto devices for the actual GPU/TPU computation.
102+
params_sharding = NamedSharding(global_mesh, P(*([None] * len(params_shape))))
103+
params = jax.device_put(params_cpu, params_sharding)
104+
101105
DC = DistributedContractor.from_path(
102106
filepath="tree.pkl",
103107
nodes_fn=nodes_fn,
108+
params=params,
104109
mesh=global_mesh,
105110
)
106111

107-
# Shard the parameters onto devices for the actual GPU/TPU computation.
108-
params_sharding = NamedSharding(global_mesh, P(*([None] * len(params_shape))))
109-
params = jax.device_put(params_cpu, params_sharding)
110-
111112
# Initialize the optimizer and its state.
112113
optimizer = optax.adam(2e-2)
113114
opt_state = optimizer.init(params) # Can init directly with sharded params

examples/multi_host/slurm_vqe_with_path.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,17 @@ def run_vqe_main():
8383
# Broadcast the CPU array. Now all processes have a concrete `params_cpu`.
8484
# This is CRITICAL to prevent the NoneType error upon contractor initialization.
8585
params_cpu = broadcast_py_object(params_cpu)
86+
# Shard the parameters onto devices for the actual GPU/TPU computation.
87+
params_sharding = NamedSharding(global_mesh, P(*([None] * len(params_shape))))
88+
params = jax.device_put(params_cpu, params_sharding)
8689

8790
DC = DistributedContractor.from_path(
8891
filepath="tree.pkl",
8992
nodes_fn=nodes_fn,
9093
mesh=global_mesh,
94+
params=params,
9195
)
9296

93-
# Shard the parameters onto devices for the actual GPU/TPU computation.
94-
params_sharding = NamedSharding(global_mesh, P(*([None] * len(params_shape))))
95-
params = jax.device_put(params_cpu, params_sharding)
96-
9797
# Initialize the optimizer and its state.
9898
optimizer = optax.adam(2e-2)
9999
opt_state = optimizer.init(params) # Can init directly with sharded params

llm.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,6 @@ pip install -r requirements/requirements-types.txt
142142

143143
### Branch Strategy
144144

145-
- main/master branch for stable releases
145+
- master branch for stable releases
146146
- beta branch for nightly builds (as seen in nightly_release.yml)
147147
- pull requests for feature development

tensorcircuit/experimental.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,10 @@ def __init__(
780780
logger.info("DistributedContractor is running on a single device.")
781781

782782
self._params_template = params
783+
self.params_sharding = jaxlib.tree_util.tree_map(
784+
lambda x: NamedSharding(self.mesh, P(*((None,) * x.ndim))),
785+
self._params_template,
786+
)
783787
self._backend = "jax"
784788
self._compiled_v_fns: Dict[
785789
Tuple[Callable[[Tensor], Tensor], str],
@@ -932,6 +936,7 @@ def from_path(
932936
nodes_fn: Callable[[Tensor], List[Gate]],
933937
devices: Optional[List[Any]] = None, # backward compatibility
934938
mesh: Optional[Any] = None,
939+
params: Any = None,
935940
) -> "DistributedContractor":
936941
with open(filepath, "rb") as f:
937942
tree_data = pickle.load(f)
@@ -940,7 +945,7 @@ def from_path(
940945
# We pass the loaded `tree_data` directly to __init__ to trigger the second workflow.
941946
return cls(
942947
nodes_fn=nodes_fn,
943-
params=None,
948+
params=params,
944949
mesh=mesh,
945950
devices=devices,
946951
tree_data=tree_data,
@@ -1107,19 +1112,15 @@ def global_aggregated_fn(
11071112

11081113
# Compile the global function with jax.jit and specify shardings.
11091114
# `params` are replicated (available everywhere).
1110-
params_sharding = jaxlib.tree_util.tree_map(
1111-
lambda x: NamedSharding(self.mesh, P(*((None,) * x.ndim))),
1112-
self._params_template,
1113-
)
11141115

1115-
in_shardings = (params_sharding, self.sharding)
1116+
in_shardings = (self.params_sharding, self.sharding)
11161117

11171118
if is_grad_fn:
11181119
# Returns (value, grad), so out_sharding must be a 2-tuple.
11191120
# `value` is a replicated scalar -> P()
11201121
sharding_for_value = NamedSharding(self.mesh, P())
11211122
# `grad` is a replicated PyTree with the same structure as params.
1122-
sharding_for_grad = params_sharding
1123+
sharding_for_grad = self.params_sharding
11231124
out_shardings = (sharding_for_value, sharding_for_grad)
11241125
else:
11251126
# Returns a single scalar value -> P()

0 commit comments

Comments
 (0)