Skip to content

Commit a172b7a

Browse files
committed
feat: added md5sum check in case of existing config file for a given experiment
1 parent 06d6ccc commit a172b7a

File tree

3 files changed

+45
-1
lines changed

3 files changed

+45
-1
lines changed

src/modalities/main.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from modalities.config.config import load_app_config_dict
1313
from modalities.config.instantiation_models import TrainingComponentsInstantiationModel, TrainingReportGenerator
1414
from modalities.evaluator import Evaluator
15+
from modalities.exceptions import RunningEnvError
1516
from modalities.gym import Gym
1617
from modalities.logging_broker.message_broker import MessageBroker
1718
from modalities.logging_broker.messages import MessageTypes, ProgressUpdate
@@ -21,6 +22,7 @@
2122
from modalities.registry.registry import Registry
2223
from modalities.trainer import Trainer
2324
from modalities.util import get_synced_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0
25+
from modalities.utils.file_ops import get_file_md5sum
2426
from modalities.utils.logger_utils import get_logger
2527

2628
logger = get_logger(name="main")
@@ -102,7 +104,15 @@ def run(self, components: TrainingComponentsInstantiationModel):
102104
if not (experiment_path / self.config_path.name).exists():
103105
shutil.copy(self.config_path, experiment_path / self.config_path.name)
104106
else:
105-
logger.warning(f"Config file {self.config_path.name} already exists in {experiment_path}. Overwriting.")
107+
logger.warning(f"Config file {self.config_path.name} already exists in {experiment_path}.")
108+
# compare md5 hashes of the two files
109+
existing_config_path = experiment_path / self.config_path.name
110+
if get_file_md5sum(existing_config_path) != get_file_md5sum(self.config_path):
111+
raise RunningEnvError(
112+
f"Config file {self.config_path.name} already exists in {experiment_path}, "
113+
"but the content is different. Please remove the existing config file or "
114+
"create a new experiment ID."
115+
)
106116

107117
resolved_config_path = (experiment_path / self.config_path.name).with_suffix(".yaml.resolved")
108118
with open(resolved_config_path, "w", encoding="utf-8") as f:

src/modalities/utils/file_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import hashlib
2+
from pathlib import Path
3+
4+
5+
def get_file_md5sum(path: Path, chunk_size: int = 8192) -> str:
6+
hash_md5 = hashlib.md5()
7+
with path.open("rb") as f:
8+
for chunk in iter(lambda: f.read(chunk_size), b""):
9+
hash_md5.update(chunk)
10+
return hash_md5.hexdigest()

tests/utils/test_file_ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from pathlib import Path
2+
3+
from modalities.utils.file_ops import get_file_md5sum
4+
5+
6+
def test_md5sum_identical(tmp_path: Path):
7+
file1 = tmp_path / "file1.txt"
8+
file2 = tmp_path / "file2.txt"
9+
10+
content = b"Hello, world!\n"
11+
file1.write_bytes(content)
12+
file2.write_bytes(content)
13+
14+
assert get_file_md5sum(file1) == get_file_md5sum(file2)
15+
16+
17+
def test_md5sum_different(tmp_path: Path):
18+
file1 = tmp_path / "file1.txt"
19+
file2 = tmp_path / "file2.txt"
20+
21+
file1.write_bytes(b"Hello, world!\n")
22+
file2.write_bytes(b"Goodbye, world!\n")
23+
24+
assert get_file_md5sum(file1) != get_file_md5sum(file2)

0 commit comments

Comments
 (0)