From 6bcc875cd6b157c794a4bd6e129e39077b0db723 Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Tue, 22 Mar 2022 10:39:21 +0000 Subject: [PATCH] [refactor] Import ABC from collections.abc for Python 3.10 compatibility --- mmf/common/report.py | 2 +- mmf/common/sample.py | 2 +- mmf/models/transformers/heads/utils.py | 4 ++-- mmf/utils/logger.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mmf/common/report.py b/mmf/common/report.py index 179fb6b3c..28526cf11 100644 --- a/mmf/common/report.py +++ b/mmf/common/report.py @@ -98,7 +98,7 @@ def apply_fn(self, fn: Callable, fields: Optional[List[str]] = None): if key not in fields: continue self[key] = fn(self[key]) - if isinstance(self[key], collections.MutableSequence): + if isinstance(self[key], collections.abc.MutableSequence): for idx, item in enumerate(self[key]): self[key][idx] = fn(item) elif isinstance(self[key], dict): diff --git a/mmf/common/sample.py b/mmf/common/sample.py index bafffa460..7b8f96966 100644 --- a/mmf/common/sample.py +++ b/mmf/common/sample.py @@ -424,7 +424,7 @@ def convert_batch_to_sample_list( def to_device( sample_list: Union[SampleList, Dict[str, Any]], device: device_type = "cuda" ) -> SampleList: - if isinstance(sample_list, collections.Mapping): + if isinstance(sample_list, collections.abc.Mapping): sample_list = convert_batch_to_sample_list(sample_list) # to_device is specifically for SampleList # if user is passing something custom built diff --git a/mmf/models/transformers/heads/utils.py b/mmf/models/transformers/heads/utils.py index 20ba23dce..314ad4162 100644 --- a/mmf/models/transformers/heads/utils.py +++ b/mmf/models/transformers/heads/utils.py @@ -147,10 +147,10 @@ def _process_head_output( head_name: str, sample_list: Dict[str, Tensor], ) -> Dict[str, Tensor]: - if isinstance(outputs, collections.MutableMapping) and "losses" in outputs: + if isinstance(outputs, collections.abc.MutableMapping) and "losses" in outputs: return outputs - if isinstance(outputs, collections.MutableMapping) and "scores" in outputs: + if isinstance(outputs, collections.abc.MutableMapping) and "scores" in outputs: logits = outputs["scores"] else: logits = outputs diff --git a/mmf/utils/logger.py b/mmf/utils/logger.py index 9fe03a508..2153f4c4a 100644 --- a/mmf/utils/logger.py +++ b/mmf/utils/logger.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. -import collections +import collections.abc import functools import json import logging @@ -288,7 +288,7 @@ def log_progress(info: Union[Dict, Any], log_format="simple"): caller, key = _find_caller() logger = logging.getLogger(caller) - if not isinstance(info, collections.Mapping): + if not isinstance(info, collections.abc.Mapping): logger.info(info) if log_format == "simple":