From 2c400f293b6cb193148552dc4a4a245efdee7fe2 Mon Sep 17 00:00:00 2001 From: Ryo Kitagawa Date: Tue, 16 Dec 2025 00:06:09 +0900 Subject: [PATCH] feat: add automatic processor selection based on DataFrame type parameter --- gokart/task.py | 38 +++++++++++++++- gokart/utils.py | 47 ++++++++++++++++++- test/test_utils.py | 110 ++++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 192 insertions(+), 3 deletions(-) diff --git a/gokart/task.py b/gokart/task.py index a9db1b69..4728d595 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -25,7 +25,7 @@ from gokart.required_task_output import RequiredTaskOutput from gokart.target import TargetOnKart from gokart.task_complete_check import task_complete_check_wrapper -from gokart.utils import FlattenableItems, flatten, map_flattenable_items +from gokart.utils import FlattenableItems, flatten, get_dataframe_type_from_task, map_flattenable_items logger = getLogger(__name__) @@ -219,6 +219,10 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool file_path = os.path.join(self.workspace_directory, formatted_relative_file_path) unique_id = self.make_unique_id() if use_unique_id else None + # Auto-select processor based on type parameter if not provided + if processor is None and relative_file_path is not None: + processor = self._create_processor_for_dataframe_type(file_path) + task_lock_params = make_task_lock_params( file_path=file_path, unique_id=unique_id, @@ -232,6 +236,38 @@ def make_target(self, relative_file_path: str | None = None, use_unique_id: bool file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather ) + def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor | None: + """ + Create a file processor with appropriate return_type based on task's type parameter. + + Args: + file_path: Path to the file + + Returns: + FileProcessor with return_type set, or None to use default processor + """ + from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor, ParquetFileProcessor + + extension = os.path.splitext(file_path)[1] + df_type = get_dataframe_type_from_task(self) + + # Create custom processor for DataFrame-supporting file types with type parameter + if extension == '.csv': + return CsvFileProcessor(sep=',', dataframe_type=df_type) + elif extension == '.tsv': + return CsvFileProcessor(sep='\t', dataframe_type=df_type) + elif extension == '.json': + return JsonFileProcessor(orient=None, dataframe_type=df_type) + elif extension == '.ndjson': + return JsonFileProcessor(orient='records', dataframe_type=df_type) + elif extension == '.parquet': + return ParquetFileProcessor(dataframe_type=df_type) + elif extension == '.feather': + return FeatherFileProcessor(store_index_in_feather=self.store_index_in_feather, dataframe_type=df_type) + + # For other file types, use default processor selection + return None + def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart: formatted_relative_file_path = ( relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip') diff --git a/gokart/utils.py b/gokart/utils.py index 510db5c9..5296502b 100644 --- a/gokart/utils.py +++ b/gokart/utils.py @@ -3,7 +3,7 @@ import os from collections.abc import Callable, Iterable from io import BytesIO -from typing import Any, Protocol, TypeAlias, TypeVar +from typing import Any, Literal, Protocol, TypeAlias, TypeVar, get_args, get_origin import dill import luigi @@ -92,3 +92,48 @@ def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> An assert file.seekable(), f'{file} is not seekable.' file.seek(0) return pd.read_pickle(file) + + +def get_dataframe_type_from_task(task: Any) -> Literal['pandas', 'polars']: + """ + Extract DataFrame type from TaskOnKart[T] type parameter. + + Examines the type parameter T of a TaskOnKart subclass to determine + whether it uses pandas or polars DataFrames. + + Args: + task: A TaskOnKart instance or class + + Returns: + 'pandas' or 'polars' (defaults to 'pandas' if type cannot be determined) + + Examples: + >>> class MyTask(TaskOnKart[pd.DataFrame]): pass + >>> get_dataframe_type_from_task(MyTask()) + 'pandas' + + >>> class MyPolarsTask(TaskOnKart[pl.DataFrame]): pass + >>> get_dataframe_type_from_task(MyPolarsTask()) + 'polars' + """ + task_class = task if isinstance(task, type) else task.__class__ + + if not hasattr(task_class, '__orig_bases__'): + return 'pandas' + + for base in task_class.__orig_bases__: + origin = get_origin(base) + # Check if this is a TaskOnKart subclass + if origin and hasattr(origin, '__name__') and origin.__name__ == 'TaskOnKart': + args = get_args(base) + if args: + df_type = args[0] + module = getattr(df_type, '__module__', '') + + # Check module name to determine DataFrame type + if 'polars' in module: + return 'polars' + elif 'pandas' in module: + return 'pandas' + + return 'pandas' # Default to pandas for backward compatibility diff --git a/test/test_utils.py b/test/test_utils.py index 9b49d330..20e525d8 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,6 +1,17 @@ import unittest -from gokart.utils import flatten, map_flattenable_items +import pandas as pd +import pytest + +from gokart.task import TaskOnKart +from gokart.utils import flatten, get_dataframe_type_from_task, map_flattenable_items + +try: + import polars as pl + + HAS_POLARS = True +except ImportError: + HAS_POLARS = False class TestFlatten(unittest.TestCase): @@ -34,3 +45,100 @@ def test_map_flattenable_items(self): ), {'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}}, ) + + +class TestGetDataFrameTypeFromTask(unittest.TestCase): + """Tests for get_dataframe_type_from_task function.""" + + def test_pandas_dataframe_from_instance(self): + """Test detecting pandas DataFrame from task instance.""" + + class PandasTask(TaskOnKart[pd.DataFrame]): + pass + + task = PandasTask() + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + def test_pandas_dataframe_from_class(self): + """Test detecting pandas DataFrame from task class.""" + + class PandasTask(TaskOnKart[pd.DataFrame]): + pass + + self.assertEqual(get_dataframe_type_from_task(PandasTask), 'pandas') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_polars_dataframe_from_instance(self): + """Test detecting polars DataFrame from task instance.""" + + class PolarsTask(TaskOnKart[pl.DataFrame]): + pass + + task = PolarsTask() + self.assertEqual(get_dataframe_type_from_task(task), 'polars') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_polars_dataframe_from_class(self): + """Test detecting polars DataFrame from task class.""" + + class PolarsTask(TaskOnKart[pl.DataFrame]): + pass + + self.assertEqual(get_dataframe_type_from_task(PolarsTask), 'polars') + + def test_no_type_parameter_defaults_to_pandas(self): + """Test that tasks without type parameter default to pandas.""" + + # Create a class without __orig_bases__ by not using type parameters + class PlainTask: + pass + + task = PlainTask() + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + def test_non_taskonkart_class_defaults_to_pandas(self): + """Test that non-TaskOnKart classes default to pandas.""" + + class RegularClass: + pass + + task = RegularClass() + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + def test_taskonkart_with_non_dataframe_type(self): + """Test TaskOnKart with non-DataFrame type parameter defaults to pandas.""" + + class StringTask(TaskOnKart[str]): + pass + + task = StringTask() + # Should default to pandas since str module is not 'pandas' or 'polars' + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + def test_nested_inheritance_pandas(self): + """Test that nested inheritance without direct type parameter defaults to pandas.""" + + class BasePandasTask(TaskOnKart[pd.DataFrame]): + pass + + class DerivedPandasTask(BasePandasTask): + pass + + task = DerivedPandasTask() + # DerivedPandasTask doesn't have its own __orig_bases__ with type parameter, + # so it defaults to 'pandas' + self.assertEqual(get_dataframe_type_from_task(task), 'pandas') + + @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') + def test_nested_inheritance_polars(self): + """Test detecting polars DataFrame type through nested inheritance.""" + + class BasePolarsTask(TaskOnKart[pl.DataFrame]): + pass + + class DerivedPolarsTask(BasePolarsTask): + pass + + task = DerivedPolarsTask() + # Function should detect 'polars' through the inheritance chain + self.assertEqual(get_dataframe_type_from_task(task), 'polars')