Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions gt_extras/_utils_column2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable

from great_tables._gt_data import GTData
from great_tables._styles import FromColumn
from great_tables._tbl_data import PlExpr, _get_cell, eval_transform, n_rows

if TYPE_CHECKING:
import polars as pl

PlExpr = pl.Expr


@dataclass
class ColumnExtractor:
data: GTData
attr: str | FromColumn | PlExpr | Callable

def _check_attr(self):
attr = self.attr
if not isinstance(attr, (str, FromColumn, PlExpr)) and not callable(attr):
raise TypeError(
f"{attr=} must be one of: str, FromColumn, PlExpr, or a callable"
)

def _eval_exprs_to_get_values(self) -> list[Any]:
attr, data = self.attr, self.data
n_row = n_rows(data)

if isinstance(attr, str):
vals = [attr for _ in range(n_row)]
elif isinstance(attr, FromColumn):
vals = []
for row in range(n_row):
val = _get_cell(data, row, attr.column)
if attr.fn is not None:
vals.append(attr.fn(val))
else:
vals.append(val)
else:
vals = eval_transform(data, attr)
return vals

def resolve(self) -> list[Any]:
self._check_attr()
return self._eval_exprs_to_get_values()
14 changes: 11 additions & 3 deletions gt_extras/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_scale_numeric_column,
_validate_and_get_single_column,
)
from gt_extras._utils_column2 import ColumnExtractor

__all__ = [
"gt_plt_bar",
Expand Down Expand Up @@ -149,7 +150,9 @@ def gt_plt_bar(
if stroke_color is None:
stroke_color = "transparent"

def _make_bar(scaled_val: float, original_val: int | float) -> str:
def _make_bar(
scaled_val: float, original_val: int | float, label_color: str
) -> str:
svg = _make_bar_svg(
scaled_val=scaled_val,
original_val=original_val,
Expand All @@ -167,6 +170,8 @@ def _make_bar(scaled_val: float, original_val: int | float) -> str:
# Get names of columns
columns_resolved = resolve_cols_c(data=gt, expr=columns)

_label_colors = ColumnExtractor(gt._tbl_data, label_color).resolve()

res = gt
for column in columns_resolved:
# Validate this is a single column and get values
Expand All @@ -193,11 +198,14 @@ def _make_bar(scaled_val: float, original_val: int | float) -> str:
col_name = col_name + " plot"

# Apply the scaled value for each row, so the bar is proportional
for i, scaled_val in enumerate(scaled_vals):
for i, (scaled_val, _label_color) in enumerate(zip(scaled_vals, _label_colors)):
res = res.fmt(
lambda original_val, scaled_val=scaled_val: _make_bar(
lambda original_val,
scaled_val=scaled_val,
label_color=_label_color: _make_bar(
original_val=original_val,
scaled_val=scaled_val,
label_color=label_color,
),
columns=col_name,
rows=[i],
Expand Down
Loading