File tree Expand file tree Collapse file tree 1 file changed +7
-6
lines changed
torchprime/torch_xla_models Expand file tree Collapse file tree 1 file changed +7
-6
lines changed Original file line number Diff line number Diff line change @@ -204,13 +204,14 @@ def _log_shapes(self, batch):
204204 )
205205 logger .info (f"[{ self .name } ] data shapes: { shapes } " )
206206
207- # Visualize one tensor.
208- import click
209- from torch_xla .distributed .spmd .debugging import visualize_tensor_sharding
210-
207+ # Visualize one example tensor.
211208 t = next (iter (pytree .tree_iter (batch )))
212- generated_table = visualize_tensor_sharding (t , use_color = False )
213- click .echo (generated_table )
209+ if t .device .type == "xla" :
210+ import click
211+ from torch_xla .distributed .spmd .debugging import visualize_tensor_sharding
212+
213+ generated_table = visualize_tensor_sharding (t , use_color = False )
214+ click .echo (generated_table )
214215
215216 def __len__ (self ):
216217 return len (self .dataloader )
You can’t perform that action at this time.
0 commit comments