diff --git a/flair/datasets/__init__.py b/flair/datasets/__init__.py
index f1c1279355..31300ee326 100644
--- a/flair/datasets/__init__.py
+++ b/flair/datasets/__init__.py
@@ -297,6 +297,11 @@
     UD_KAZAKH,
     UD_KOREAN,
     UD_LATIN,
+    UD_LATIN_CIRCSE,
+    UD_LATIN_ITTB,
+    UD_LATIN_PERSEUS,
+    UD_LATIN_PROIEL,
+    UD_LATIN_UDANTE,
     UD_LATVIAN,
     UD_LITHUANIAN,
     UD_LIVVI,
@@ -536,6 +541,11 @@
     "UD_KAZAKH",
     "UD_KOREAN",
     "UD_LATIN",
+    "UD_LATIN_CIRCSE",
+    "UD_LATIN_ITTB",
+    "UD_LATIN_PERSEUS",
+    "UD_LATIN_PROIEL",
+    "UD_LATIN_UDANTE",
     "UD_LATVIAN",
     "UD_LITHUANIAN",
     "UD_LIVVI",
diff --git a/flair/datasets/treebanks.py b/flair/datasets/treebanks.py
index 21ae327691..37c6147652 100644
--- a/flair/datasets/treebanks.py
+++ b/flair/datasets/treebanks.py
@@ -19,6 +19,7 @@ def __init__(
         test_file=None,
         dev_file=None,
         in_memory: bool = True,
+        sample_missing_splits: bool = True,
         split_multiwords: bool = True,
     ) -> None:
         """Instantiates a Corpus from CoNLL-U column-formatted task data such as the UD corpora.
@@ -28,6 +29,7 @@ def __init__(
         :param test_file: the name of the test file
         :param dev_file: the name of the dev file, if None, dev data is sampled from train
         :param in_memory: If set to True, keeps full dataset in memory, otherwise does disk reads
+        :param sample_missing_splits: If set to True, missing splits will be randomly sampled from the training split
         :param split_multiwords: If set to True, multiwords are split (default), otherwise kept as single tokens
         :return: a Corpus with annotated train, dev and test data
         """
@@ -55,7 +57,7 @@ def __init__(
             else None
         )

-        super().__init__(train, dev, test, name=str(data_folder))
+        super().__init__(train, dev, test, name=str(data_folder), sample_missing_splits=sample_missing_splits)


 class UniversalDependenciesDataset(FlairDataset):
@@ -581,6 +583,128 @@ def __init__(
         super().__init__(data_folder, in_memory=in_memory, split_multiwords=split_multiwords)


+class UD_LATIN_CIRCSE(UniversalDependenciesCorpus):
+    def __init__(
+        self,
+        base_path: Optional[Union[str, Path]] = None,
+        in_memory: bool = True,
+        split_multiwords: bool = True,
+        revision: str = "master",
+    ) -> None:
+        base_path = Path(flair.cache_root) / "datasets" if not base_path else Path(base_path)
+
+        # this dataset name
+        dataset_name = self.__class__.__name__.lower()
+
+        data_folder = base_path / dataset_name
+
+        # download data if necessary (CIRCSE only ships a test split)
+        web_path = f"https://raw.githubusercontent.com/UniversalDependencies/UD_Latin-CIRCSE/{revision}"
+        cached_path(f"{web_path}/la_circse-ud-test.conllu", Path("datasets") / dataset_name)
+
+        super().__init__(data_folder, in_memory=in_memory, split_multiwords=split_multiwords)
+
+
+class UD_LATIN_ITTB(UniversalDependenciesCorpus):
+    def __init__(
+        self,
+        base_path: Optional[Union[str, Path]] = None,
+        in_memory: bool = True,
+        split_multiwords: bool = True,
+        revision: str = "master",
+    ) -> None:
+        base_path = Path(flair.cache_root) / "datasets" if not base_path else Path(base_path)
+
+        # this dataset name
+        dataset_name = self.__class__.__name__.lower()
+
+        data_folder = base_path / dataset_name
+
+        # download data if necessary
+        web_path = f"https://raw.githubusercontent.com/UniversalDependencies/UD_Latin-ITTB/{revision}"
+
+        for split in ["train", "dev", "test"]:
+            cached_path(f"{web_path}/la_ittb-ud-{split}.conllu", Path("datasets") / dataset_name)
+
+        super().__init__(data_folder, in_memory=in_memory, split_multiwords=split_multiwords)
+
+
+class UD_LATIN_UDANTE(UniversalDependenciesCorpus):
+    def __init__(
+        self,
+        base_path: Optional[Union[str, Path]] = None,
+        in_memory: bool = True,
+        split_multiwords: bool = True,
+        revision: str = "master",
+    ) -> None:
+        base_path = Path(flair.cache_root) / "datasets" if not base_path else Path(base_path)
+
+        # this dataset name
+        dataset_name = self.__class__.__name__.lower()
+
+        data_folder = base_path / dataset_name
+
+        # download data if necessary
+        web_path = f"https://raw.githubusercontent.com/UniversalDependencies/UD_Latin-UDante/{revision}"
+
+        for split in ["train", "dev", "test"]:
+            cached_path(f"{web_path}/la_udante-ud-{split}.conllu", Path("datasets") / dataset_name)
+
+        super().__init__(data_folder, in_memory=in_memory, split_multiwords=split_multiwords)
+
+
+class UD_LATIN_PERSEUS(UniversalDependenciesCorpus):
+    def __init__(
+        self,
+        base_path: Optional[Union[str, Path]] = None,
+        in_memory: bool = True,
+        split_multiwords: bool = True,
+        revision: str = "master",
+    ) -> None:
+        base_path = Path(flair.cache_root) / "datasets" if not base_path else Path(base_path)
+
+        # this dataset name
+        dataset_name = self.__class__.__name__.lower()
+
+        data_folder = base_path / dataset_name
+
+        # download data if necessary (Perseus has no dev split)
+        web_path = f"https://raw.githubusercontent.com/UniversalDependencies/UD_Latin-Perseus/{revision}"
+
+        for split in ["train", "test"]:
+            cached_path(f"{web_path}/la_perseus-ud-{split}.conllu", Path("datasets") / dataset_name)
+
+        super().__init__(
+            data_folder, in_memory=in_memory, sample_missing_splits=False, split_multiwords=split_multiwords
+        )
+
+
+class UD_LATIN_PROIEL(UniversalDependenciesCorpus):
+    def __init__(
+        self,
+        base_path: Optional[Union[str, Path]] = None,
+        in_memory: bool = True,
+        split_multiwords: bool = True,
+        revision: str = "master",
+    ) -> None:
+        base_path = Path(flair.cache_root) / "datasets" if not base_path else Path(base_path)
+
+        # this dataset name
+        dataset_name = self.__class__.__name__.lower()
+
+        data_folder = base_path / dataset_name
+
+        # download data if necessary
+        web_path = f"https://raw.githubusercontent.com/UniversalDependencies/UD_Latin-PROIEL/{revision}"
+
+        for split in ["train", "dev", "test"]:
+            cached_path(f"{web_path}/la_proiel-ud-{split}.conllu", Path("datasets") / dataset_name)
+
+        super().__init__(
+            data_folder, in_memory=in_memory, sample_missing_splits=False, split_multiwords=split_multiwords
+        )
+
+
 class UD_SPANISH(UniversalDependenciesCorpus):
     def __init__(
         self,
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 121a8521bb..f78f3dfd13 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -1018,6 +1018,90 @@ def test_bavarian_wiki(tasks_base_path):
     ), f"Number of parsed tokens ({actual_tokens}) does not match with reported number of tokens ({ref_tokens})!"


+@pytest.mark.skip()
+def test_ud_latin():
+    revision = "f16caaa"
+    corpus = flair.datasets.UD_LATIN(revision=revision)
+
+    # Taken from: https://github.com/UniversalDependencies/UD_Latin-LLCT/blob/f16caaa3b0c57e3319396a1af74ee2bc7c9b4323/stats.xml#L8
+    ref_sentences = 9023
+    actual_sentences = len(corpus.train) + len(corpus.dev) + len(corpus.test)
+
+    assert (
+        ref_sentences == actual_sentences
+    ), f"Number of parsed sentences ({actual_sentences}) does not match the reported number of sentences ({ref_sentences})!"
+
+
+@pytest.mark.skip()
+def test_ud_latin_circse():
+    revision = "13cc204"
+    corpus = flair.datasets.UD_LATIN_CIRCSE(revision=revision)
+
+    # Taken from: https://github.com/UniversalDependencies/UD_Latin-CIRCSE/blob/13cc204a1d8910d7f95fd78b23aec93ccb64be5c/stats.xml#L8
+    ref_sentences = 1263
+    actual_sentences = len(corpus.test)
+
+    assert (
+        ref_sentences == actual_sentences
+    ), f"Number of parsed sentences ({actual_sentences}) does not match the reported number of sentences ({ref_sentences})!"
+
+
+@pytest.mark.skip()
+def test_ud_latin_ittb():
+    revision = "9991421"
+    corpus = flair.datasets.UD_LATIN_ITTB(revision=revision)
+
+    # Taken from: https://github.com/UniversalDependencies/UD_Latin-ITTB/blob/9991421cd858f6603b4f27b26c1f11d4619fc8cc/stats.xml#L8
+    ref_sentences = 26977
+    actual_sentences = len(corpus.train) + len(corpus.dev) + len(corpus.test)
+
+    assert (
+        ref_sentences == actual_sentences
+    ), f"Number of parsed sentences ({actual_sentences}) does not match the reported number of sentences ({ref_sentences})!"
+
+
+@pytest.mark.skip()
+def test_ud_latin_udante():
+    revision = "f817abd"
+    corpus = flair.datasets.UD_LATIN_UDANTE(revision=revision)
+
+    # Taken from: https://github.com/UniversalDependencies/UD_Latin-UDante/blob/f817abdaeaf3b40250b65d1a6bbbd5c7dcee7836/stats.xml#L8
+    ref_sentences = 1723
+    actual_sentences = len(corpus.train) + len(corpus.dev) + len(corpus.test)
+
+    assert (
+        ref_sentences == actual_sentences
+    ), f"Number of parsed sentences ({actual_sentences}) does not match the reported number of sentences ({ref_sentences})!"
+
+
+@pytest.mark.skip()
+def test_ud_latin_perseus():
+    revision = "b3c7f9b"
+    corpus = flair.datasets.UD_LATIN_PERSEUS(revision=revision)
+
+    # Taken from: https://github.com/UniversalDependencies/UD_Latin-Perseus/blob/b3c7f9b6751c404db3b1f9e436ba4557d8b945c5/stats.xml#L8
+    ref_sentences = 2273
+    actual_sentences = len(corpus.train) + len(corpus.test)
+
+    assert (
+        ref_sentences == actual_sentences
+    ), f"Number of parsed sentences ({actual_sentences}) does not match the reported number of sentences ({ref_sentences})!"
+
+
+@pytest.mark.skip()
+def test_ud_latin_proiel():
+    revision = "6d7c717"
+    corpus = flair.datasets.UD_LATIN_PROIEL(revision=revision)
+
+    # Taken from: https://github.com/UniversalDependencies/UD_Latin-PROIEL/blob/6d7c717f6c9fa971c312fa2071016cbd5f2e6a41/stats.xml#L8
+    ref_sentences = 18689
+    actual_sentences = len(corpus.train) + len(corpus.dev) + len(corpus.test)
+
+    assert (
+        ref_sentences == actual_sentences
+    ), f"Number of parsed sentences ({actual_sentences}) does not match the reported number of sentences ({ref_sentences})!"
+
+
 def test_multi_file_jsonl_corpus_should_use_label_type(tasks_base_path):
     corpus = MultiFileJsonlCorpus(
         train_files=[tasks_base_path / "jsonl/train.jsonl"],
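
For context, a minimal usage sketch of the corpora added by this patch (not part of the diff itself; it assumes a flair installation with this change applied and network access for the first download — `revision` accepts any tag, branch, or commit of the corresponding UD repository):

```python
# Hypothetical usage sketch for the new Latin UD corpora (illustration only).
from flair.datasets import UD_LATIN_ITTB, UD_LATIN_PERSEUS

# Pin a specific commit for reproducible downloads; "master" (the default)
# tracks the latest revision of the UD repository.
corpus = UD_LATIN_ITTB(revision="master")
print(corpus)  # flair's Corpus summary: train/dev/test sentence counts

# Perseus ships only train and test splits; because the class passes
# sample_missing_splits=False internally, the dev split stays empty
# instead of being sampled from the training data.
perseus = UD_LATIN_PERSEUS()
print(len(perseus.train), len(perseus.test))
```

Passing `sample_missing_splits=False` for `UD_LATIN_PERSEUS` and `UD_LATIN_PROIEL` preserves the official UD splits rather than carving a dev set out of train, which matters when comparing against published treebank results.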