diff --git a/torchrec/metrics/cpu_offloaded_metric_module.py b/torchrec/metrics/cpu_offloaded_metric_module.py
index 83a9c6e48..f25afa859 100644
--- a/torchrec/metrics/cpu_offloaded_metric_module.py
+++ b/torchrec/metrics/cpu_offloaded_metric_module.py
@@ -22,7 +22,7 @@
     MetricUpdateJob,
     SynchronizationMarker,
 )
-from torchrec.metrics.metric_module import MetricValue, RecMetricModule
+from torchrec.metrics.metric_module import MetricsFuture, MetricValue, RecMetricModule
 from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
 from torchrec.metrics.model_utils import parse_task_model_outputs
 from torchrec.metrics.rec_metric import RecMetricException
@@ -254,9 +254,7 @@ def compute(self) -> Dict[str, MetricValue]:
         )
 
     @override
-    def async_compute(
-        self, future: concurrent.futures.Future[Dict[str, MetricValue]]
-    ) -> None:
+    def async_compute(self) -> MetricsFuture:
         """
         Entry point for asynchronous metric compute. It enqueues a synchronization
         marker to the update queue.
@@ -264,14 +262,16 @@ def async_compute(
-        Args:
-            future: Pre-created future where the computed metrics will be set.
+        Returns:
+            MetricsFuture: future that will be fulfilled with the computed metrics.
         """
+        metrics_future = concurrent.futures.Future()
         if self._shutdown_event.is_set():
-            future.set_exception(
+            metrics_future.set_exception(
                 RecMetricException("metric processor thread is shut down.")
             )
-            return
+            return metrics_future
 
-        self.update_queue.put_nowait(SynchronizationMarker(future))
+        self.update_queue.put_nowait(SynchronizationMarker(metrics_future))
         self.update_queue_size_logger.add(self.update_queue.qsize())
+        return metrics_future
 
     def _process_synchronization_marker(
         self, synchronization_marker: SynchronizationMarker
diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py
index 4ab3bb162..759af94ea 100644
--- a/torchrec/metrics/metric_module.py
+++ b/torchrec/metrics/metric_module.py
@@ -117,6 +117,7 @@
 
 
 MetricValue = Union[torch.Tensor, float]
+MetricsFuture = concurrent.futures.Future[Dict[str, MetricValue]]
 
 
 class StateMetric(abc.ABC):
@@ -492,9 +493,7 @@ def load_pre_compute_states(
     def shutdown(self) -> None:
         logger.info("Initiating graceful shutdown...")
 
-    def async_compute(
-        self, future: concurrent.futures.Future[Dict[str, MetricValue]]
-    ) -> None:
+    def async_compute(self) -> MetricsFuture:
         raise RecMetricException("async_compute is not supported in RecMetricModule")
 
 
diff --git a/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py b/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py
index 3c3d557e3..dcb6accaf 100644
--- a/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py
+++ b/torchrec/metrics/tests/test_cpu_offloaded_metric_module.py
@@ -207,10 +207,6 @@ def test_async_compute_synchronization_marker(self) -> None:
         Note that the comms module's metrics are actually the ones that are computed.
         """
 
-        future: concurrent.futures.Future[Dict[str, MetricValue]] = (
-            concurrent.futures.Future()
-        )
-
         model_out = {
             "task1-prediction": torch.tensor([0.5]),
             "task1-label": torch.tensor([0.7]),
@@ -220,7 +216,7 @@ def test_async_compute_synchronization_marker(self) -> None:
         for _ in range(10):
             self.cpu_module.update(model_out)
 
-        self.cpu_module.async_compute(future)
+        self.cpu_module.async_compute()
 
         comms_mock_metric = cast(
             MockRecMetric, self.cpu_module.comms_module.rec_metrics.rec_metrics[0]
@@ -234,10 +230,7 @@ def test_async_compute_synchronization_marker(self) -> None:
 
     def test_async_compute_after_shutdown(self) -> None:
         self.cpu_module.shutdown()
-        future: concurrent.futures.Future[Dict[str, MetricValue]] = (
-            concurrent.futures.Future()
-        )
-        self.cpu_module.async_compute(future)
+        future = self.cpu_module.async_compute()
 
         self.assertRaisesRegex(
             RecMetricException, "metric processor thread is shut down.", future.result
@@ -275,7 +268,7 @@ def test_wait_until_queue_is_empty(self) -> None:
             "task1-weight": torch.tensor([1.0]),
         }
         self.cpu_module.update(model_out)
-        self.cpu_module.async_compute(concurrent.futures.Future())
+        self.cpu_module.async_compute()
 
         self.cpu_module.wait_until_queue_is_empty(self.cpu_module.update_queue)
         self.cpu_module.wait_until_queue_is_empty(self.cpu_module.compute_queue)
@@ -576,10 +569,7 @@ def _compare_metric_results_worker(
 
         standard_results = standard_module.compute()
 
-        future: concurrent.futures.Future[Dict[str, MetricValue]] = (
-            concurrent.futures.Future()
-        )
-        cpu_offloaded_module.async_compute(future)
+        future = cpu_offloaded_module.async_compute()
 
         # Wait for async compute to finish. Compare the input to each update()
         offloaded_results = future.result(timeout=10.0)
diff --git a/torchrec/metrics/tests/test_metric_module.py b/torchrec/metrics/tests/test_metric_module.py
index b71cae543..d5d96b945 100644
--- a/torchrec/metrics/tests/test_metric_module.py
+++ b/torchrec/metrics/tests/test_metric_module.py
@@ -662,7 +662,7 @@ def test_async_compute_raises_exception(self) -> None:
             RecMetricException,
             "async_compute is not supported in RecMetricModule",
         ):
-            metric_module.async_compute(concurrent.futures.Future())
+            metric_module.async_compute()
 
 
 def metric_module_gather_state(