From 9b14fa7bf0a1ceaf223d00c095197f639099ea57 Mon Sep 17 00:00:00 2001
From: Jerry Wu
Date: Wed, 13 Aug 2025 21:15:47 +0800
Subject: [PATCH] Add draft implementation of `ColumnExtractor`

---
Review notes (free-form text in this position is not part of the commit
message and is not part of the diff):

* Line structure of this patch was restored from a whitespace-collapsed
  copy. Added/removed lines are token-identical to that copy; indentation
  was re-derived from Python syntax and line counts were checked against
  every hunk header (46 new-file lines; 6->7, 7->9, 6->8, 11->14).
* `ColumnExtractor.data` is annotated `GTData`, but the only call site in
  this patch passes `gt._tbl_data`, and the value is handed straight to
  `n_rows`, `_get_cell` and `eval_transform` from `great_tables._tbl_data`.
  Presumably the annotation should name the table-data type instead --
  TODO confirm against great_tables.
* `_check_attr()` runs only from `resolve()`, so an invalid `attr` is
  reported when values are resolved, not when the dataclass is constructed.
* `zip(scaled_vals, _label_colors)` stops at the shorter input. If the two
  lists could ever differ in length, trailing rows would silently get no
  formatter; `zip(..., strict=True)` would surface that (the new module
  already relies on runtime `X | Y` annotations, i.e. Python 3.10+).

 gt_extras/_utils_column2.py | 46 +++++++++++++++++++++++++++++++++++++
 gt_extras/plotting.py       | 14 ++++++++---
 2 files changed, 57 insertions(+), 3 deletions(-)
 create mode 100644 gt_extras/_utils_column2.py

diff --git a/gt_extras/_utils_column2.py b/gt_extras/_utils_column2.py
new file mode 100644
index 00000000..b0113ba4
--- /dev/null
+++ b/gt_extras/_utils_column2.py
@@ -0,0 +1,46 @@
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable
+
+from great_tables._gt_data import GTData
+from great_tables._styles import FromColumn
+from great_tables._tbl_data import PlExpr, _get_cell, eval_transform, n_rows
+
+if TYPE_CHECKING:
+    import polars as pl
+
+    PlExpr = pl.Expr
+
+
+@dataclass
+class ColumnExtractor:
+    data: GTData
+    attr: str | FromColumn | PlExpr | Callable
+
+    def _check_attr(self):
+        attr = self.attr
+        if not isinstance(attr, (str, FromColumn, PlExpr)) and not callable(attr):
+            raise TypeError(
+                f"{attr=} must be one of: str, FromColumn, PlExpr, or a callable"
+            )
+
+    def _eval_exprs_to_get_values(self) -> list[Any]:
+        attr, data = self.attr, self.data
+        n_row = n_rows(data)
+
+        if isinstance(attr, str):
+            vals = [attr for _ in range(n_row)]
+        elif isinstance(attr, FromColumn):
+            vals = []
+            for row in range(n_row):
+                val = _get_cell(data, row, attr.column)
+                if attr.fn is not None:
+                    vals.append(attr.fn(val))
+                else:
+                    vals.append(val)
+        else:
+            vals = eval_transform(data, attr)
+        return vals
+
+    def resolve(self) -> list[Any]:
+        self._check_attr()
+        return self._eval_exprs_to_get_values()
diff --git a/gt_extras/plotting.py b/gt_extras/plotting.py
index 0dfe228b..7fcb3ccc 100644
--- a/gt_extras/plotting.py
+++ b/gt_extras/plotting.py
@@ -20,6 +20,7 @@
     _scale_numeric_column,
     _validate_and_get_single_column,
 )
+from gt_extras._utils_column2 import ColumnExtractor
 
 __all__ = [
     "gt_plt_bar",
@@ -149,7 +150,9 @@ def gt_plt_bar(
     if stroke_color is None:
         stroke_color = "transparent"
 
-    def _make_bar(scaled_val: float, original_val: int | float) -> str:
+    def _make_bar(
+        scaled_val: float, original_val: int | float, label_color: str
+    ) -> str:
         svg = _make_bar_svg(
             scaled_val=scaled_val,
             original_val=original_val,
@@ -167,6 +170,8 @@ def _make_bar(scaled_val: float, original_val: int | float) -> str:
     # Get names of columns
     columns_resolved = resolve_cols_c(data=gt, expr=columns)
 
+    _label_colors = ColumnExtractor(gt._tbl_data, label_color).resolve()
+
     res = gt
     for column in columns_resolved:
         # Validate this is a single column and get values
@@ -193,11 +198,14 @@ def _make_bar(scaled_val: float, original_val: int | float) -> str:
             col_name = col_name + " plot"
 
         # Apply the scaled value for each row, so the bar is proportional
-        for i, scaled_val in enumerate(scaled_vals):
+        for i, (scaled_val, _label_color) in enumerate(zip(scaled_vals, _label_colors)):
             res = res.fmt(
-                lambda original_val, scaled_val=scaled_val: _make_bar(
+                lambda original_val,
+                scaled_val=scaled_val,
+                label_color=_label_color: _make_bar(
                     original_val=original_val,
                     scaled_val=scaled_val,
+                    label_color=label_color,
                 ),
                 columns=col_name,
                 rows=[i],