Skip to content

Commit 156598e

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 8f61c2a commit 156598e

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

43 files changed

+114
-121
lines changed

ac_dc/anonymization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def apply_regex_anonymization(
3030
tag_type=tag_type,
3131
)
3232
if anonymize_condition:
33-
for (ent, start, end, tag) in ner:
33+
for ent, start, end, tag in ner:
3434
# we need to actually walk through and replace by start, end span.
3535
sentence = sentence.replace(ent, f" <{tag}> ")
3636
return sentence, ner

ac_dc/deduplicate/self_deduplicate.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# @Date : 2022-01-08 22:39:29
43
# @Author : Chenghao Mou ([email protected])
54
# @Description: Self-deduplication with `datasets`
@@ -28,7 +27,7 @@
2827

2928
def main(conf: str) -> None:
3029

31-
with open(conf, "r") as f:
30+
with open(conf) as f:
3231
conf = yaml.safe_load(f.read())
3332

3433
if conf["load_from_disk"]["path"]:

ac_dc/visualization/get_data_for_visualization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ def compute_stats(self):
9090
)
9191
for n in range(2, 16)
9292
}
93-
stats_document[
94-
"character_repetition_ratio"
95-
] = character_repetition_ratios
93+
stats_document["character_repetition_ratio"] = (
94+
character_repetition_ratios
95+
)
9696

9797
word_repetition_ratios = {
9898
n: round(

ac_dc/visualization/visualization.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,16 @@ def get_cond(key, cutoff, max_cutoff):
290290
"stopwords_ratio"
291291
]
292292
for i in range(len(self.docs["stopwords_ratio"])):
293-
self.docs["stopwords_ratio"].iloc[
294-
i
295-
] = Filtering.compute_stopwords_ratio(
296-
self.docs["text"].iloc[i],
297-
self.sentencepiece_model_tok,
298-
self.param["strip_characters"],
299-
self.param["cond_words_augmentation"],
300-
self.param["words_augmentation_group_sizes"],
301-
self.param["words_augmentation_join_char"],
302-
new_stopwords,
293+
self.docs["stopwords_ratio"].iloc[i] = (
294+
Filtering.compute_stopwords_ratio(
295+
self.docs["text"].iloc[i],
296+
self.sentencepiece_model_tok,
297+
self.param["strip_characters"],
298+
self.param["cond_words_augmentation"],
299+
self.param["words_augmentation_group_sizes"],
300+
self.param["words_augmentation_join_char"],
301+
new_stopwords,
302+
)
303303
)
304304
cutoff_def = "If the stop words ratio of a document is lower than this number, the document is removed."
305305
cutoff_stopwords_ratio = st.slider(
@@ -326,16 +326,16 @@ def get_cond(key, cutoff, max_cutoff):
326326
"flagged_words_ratio"
327327
]
328328
for i in range(len(self.docs["flagged_words_ratio"])):
329-
self.docs["flagged_words_ratio"].iloc[
330-
i
331-
] = Filtering.compute_flagged_words_ratio(
332-
self.docs["text"].iloc[i],
333-
self.sentencepiece_model_tok,
334-
self.param["strip_characters"],
335-
self.param["cond_words_augmentation"],
336-
self.param["words_augmentation_group_sizes"],
337-
self.param["words_augmentation_join_char"],
338-
new_flagged_words,
329+
self.docs["flagged_words_ratio"].iloc[i] = (
330+
Filtering.compute_flagged_words_ratio(
331+
self.docs["text"].iloc[i],
332+
self.sentencepiece_model_tok,
333+
self.param["strip_characters"],
334+
self.param["cond_words_augmentation"],
335+
self.param["words_augmentation_group_sizes"],
336+
self.param["words_augmentation_join_char"],
337+
new_flagged_words,
338+
)
339339
)
340340
cutoff_def = "If the flagged words ratio of a document is higher than this number, the document is removed."
341341
max_fwr = np.max(self.docs["flagged_words_ratio"])

bertin/evaluation/run_glue.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# coding=utf-8
32
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
43
#
54
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -384,19 +383,23 @@ def main():
384383
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
385384
# download model & vocab.
386385
config = AutoConfig.from_pretrained(
387-
model_args.config_name
388-
if model_args.config_name
389-
else model_args.model_name_or_path,
386+
(
387+
model_args.config_name
388+
if model_args.config_name
389+
else model_args.model_name_or_path
390+
),
390391
num_labels=num_labels,
391392
finetuning_task=data_args.task_name,
392393
cache_dir=model_args.cache_dir,
393394
revision=model_args.model_revision,
394395
use_auth_token=True if model_args.use_auth_token else None,
395396
)
396397
tokenizer = AutoTokenizer.from_pretrained(
397-
model_args.tokenizer_name
398-
if model_args.tokenizer_name
399-
else model_args.model_name_or_path,
398+
(
399+
model_args.tokenizer_name
400+
if model_args.tokenizer_name
401+
else model_args.model_name_or_path
402+
),
400403
cache_dir=model_args.cache_dir,
401404
use_fast=model_args.use_fast_tokenizer,
402405
revision=model_args.model_revision,

bertin/evaluation/run_ner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# coding=utf-8
32
# Copyright 2020 The HuggingFace Team All rights reserved.
43
#
54
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -364,9 +363,11 @@ def get_label_list(labels):
364363
# The .from_pretrained methods guarantee that only one local process can concurrently
365364
# download model & vocab.
366365
config = AutoConfig.from_pretrained(
367-
model_args.config_name
368-
if model_args.config_name
369-
else model_args.model_name_or_path,
366+
(
367+
model_args.config_name
368+
if model_args.config_name
369+
else model_args.model_name_or_path
370+
),
370371
num_labels=num_labels,
371372
label2id=label_to_id,
372373
id2label={i: l for l, i in label_to_id.items()},
@@ -636,9 +637,9 @@ def compute_metrics(p):
636637
kwargs["dataset_tags"] = data_args.dataset_name
637638
if data_args.dataset_config_name is not None:
638639
kwargs["dataset_args"] = data_args.dataset_config_name
639-
kwargs[
640-
"dataset"
641-
] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
640+
kwargs["dataset"] = (
641+
f"{data_args.dataset_name} {data_args.dataset_config_name}"
642+
)
642643
else:
643644
kwargs["dataset"] = data_args.dataset_name
644645

bertin/mc4/mc4.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Perplexity Sampled mC4 dataset based on Common Crawl."""
22

3-
43
import gzip
54
import json
65

@@ -404,7 +403,7 @@ def _generate_examples(self, filepaths):
404403
for filepath in filepaths:
405404
logger.info("generating examples from = %s", filepath)
406405
if filepath.endswith("jsonl"):
407-
with open(filepath, "r", encoding="utf-8") as f:
406+
with open(filepath, encoding="utf-8") as f:
408407
for line in f:
409408
if line:
410409
example = json.loads(line)

bertin/run_mlm_flax.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# coding=utf-8
32
# Copyright 2021 The HuggingFace Team All rights reserved.
43
#
54
# Licensed under the Apache License, Version 2.0 (the "License");

bertin/run_mlm_flax_stream.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# coding=utf-8
32
# Copyright 2021 The HuggingFace Team All rights reserved.
43
#
54
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -446,7 +445,7 @@ def restore_checkpoint(save_dir, state):
446445
args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
447446
data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
448447

449-
with open(os.path.join(save_dir, "training_state.json"), "r") as f:
448+
with open(os.path.join(save_dir, "training_state.json")) as f:
450449
training_state = json.load(f)
451450
step = training_state["step"]
452451

bertin/utils/dataset_perplexity.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_perplexity(doc):
1717

1818

1919
with open("mc4-es-train-50M-stats.csv", "w") as csv:
20-
with open("mc4-es-train-50M-steps.jsonl", "r") as data:
20+
with open("mc4-es-train-50M-steps.jsonl") as data:
2121
for line in tqdm(data):
2222
text = json.loads(line)["text"]
2323
csv.write(f"{len(text.split())},{get_perplexity(text)}\n")

0 commit comments

Comments (0)