Skip to content

Commit c481fc6

Browse files
alexbarros authored and aquemy committed
fix: remove computation from missing data plots (#1294)
1 parent 421ee51 commit c481fc6

File tree

4 files changed

+109
-61
lines changed

4 files changed

+109
-61
lines changed

src/ydata_profiling/model/pandas/missing_pandas.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import pandas as pd
23

34
from ydata_profiling.config import Settings
@@ -11,14 +12,35 @@
1112

1213
@missing_bar.register
1314
def pandas_missing_bar(config: Settings, df: pd.DataFrame) -> str:
14-
return plot_missing_bar(config, df)
15+
notnull_counts = len(df) - df.isnull().sum()
16+
return plot_missing_bar(
17+
config,
18+
notnull_counts=notnull_counts,
19+
nrows=len(df),
20+
columns=list(df.columns),
21+
)
1522

1623

1724
@missing_matrix.register
1825
def pandas_missing_matrix(config: Settings, df: pd.DataFrame) -> str:
19-
return plot_missing_matrix(config, df)
26+
return plot_missing_matrix(
27+
config,
28+
columns=list(df.columns),
29+
notnull=df.notnull().values,
30+
nrows=len(df),
31+
)
2032

2133

2234
@missing_heatmap.register
2335
def pandas_missing_heatmap(config: Settings, df: pd.DataFrame) -> str:
24-
return plot_missing_heatmap(config, df)
36+
# Remove completely filled or completely empty variables.
37+
columns = [i for i, n in enumerate(np.var(df.isnull(), axis="rows")) if n > 0]
38+
df = df.iloc[:, columns]
39+
40+
# Create and mask the correlation matrix. Construct the base heatmap.
41+
corr_mat = df.isnull().corr()
42+
mask = np.zeros_like(corr_mat)
43+
mask[np.triu_indices_from(mask)] = True
44+
return plot_missing_heatmap(
45+
config, corr_mat=corr_mat, mask=mask, columns=list(df.columns)
46+
)

src/ydata_profiling/model/spark/missing_spark.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any, List, Optional
22

3+
import numpy as np
34
from pyspark.sql import DataFrame
45

56
from ydata_profiling.config import Settings
@@ -67,18 +68,33 @@ def spark_missing_bar(config: Settings, df: DataFrame) -> str:
6768
)
6869

6970
return plot_missing_bar(
70-
config,
71-
MissingnoBarSparkPatch(
72-
df=data_nan_counts, columns=df.columns, original_df_size=df.count()
73-
),
71+
config, notnull_counts=data_nan_counts, columns=df.columns, nrows=df.count()
7472
)
7573

7674

7775
@missing_matrix.register
7876
def spark_missing_matrix(config: Settings, df: DataFrame) -> str:
79-
return plot_missing_matrix(config, MissingnoBarSparkPatch(df))
77+
df = MissingnoBarSparkPatch(df, columns=df.columns, original_df_size=df.count())
78+
return plot_missing_matrix(
79+
config,
80+
columns=df.columns,
81+
notnull=df.notnull().values,
82+
nrows=len(df),
83+
)
8084

8185

8286
@missing_heatmap.register
8387
def spark_missing_heatmap(config: Settings, df: DataFrame) -> str:
84-
return plot_missing_heatmap(config, MissingnoBarSparkPatch(df))
88+
df = MissingnoBarSparkPatch(df, columns=df.columns, original_df_size=df.count())
89+
90+
# Remove completely filled or completely empty variables.
91+
columns = [i for i, n in enumerate(np.var(df.isnull(), axis="rows")) if n > 0]
92+
df = df.iloc[:, columns]
93+
94+
# Create and mask the correlation matrix. Construct the base heatmap.
95+
corr_mat = df.isnull().corr()
96+
mask = np.zeros_like(corr_mat)
97+
mask[np.triu_indices_from(mask)] = True
98+
return plot_missing_heatmap(
99+
config, corr_mat=corr_mat, mask=mask, columns=list(df.columns)
100+
)

src/ydata_profiling/visualisation/missing.py

Lines changed: 40 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Plotting functions for the missing values diagrams"""
2-
import pandas as pd
2+
from typing import Any, List
3+
34
from matplotlib import pyplot as plt
45

56
from ydata_profiling.config import Settings
@@ -12,22 +13,22 @@
1213
from ydata_profiling.visualisation.utils import hex_to_rgb, plot_360_n0sc0pe
1314

1415

15-
def get_font_size(data: pd.DataFrame) -> float:
16+
def get_font_size(columns: List[str]) -> float:
1617
"""Calculate font size based on number of columns
1718
1819
Args:
19-
data: DataFrame
20+
columns: List of column names.
2021
2122
Returns:
2223
Font size for missing values plots.
2324
"""
24-
max_label_length = max(len(label) for label in data.columns)
25+
max_label_length = max(len(label) for label in columns)
2526

26-
if len(data.columns) < 20:
27+
if len(columns) < 20:
2728
font_size = 13.0
28-
elif 20 <= len(data.columns) < 40:
29+
elif 20 <= len(columns) < 40:
2930
font_size = 12.0
30-
elif 40 <= len(data.columns) < 60:
31+
elif 40 <= len(columns) < 60:
3132
font_size = 10.0
3233
else:
3334
font_size = 8.0
@@ -37,21 +38,27 @@ def get_font_size(data: pd.DataFrame) -> float:
3738

3839

3940
@manage_matplotlib_context()
40-
def plot_missing_matrix(config: Settings, data: pd.DataFrame) -> str:
41+
def plot_missing_matrix(
42+
config: Settings, notnull: Any, columns: List[str], nrows: int
43+
) -> str:
4144
"""Generate missing values matrix plot
4245
4346
Args:
4447
config: report Settings object
45-
data: Pandas DataFrame to generate missing values matrix from.
48+
notnull: Missing data indicator matrix.
49+
columns: List of column names.
50+
nrows: Number of rows in the dataframe.
4651
4752
Returns:
4853
The resulting missing values matrix encoded as a string.
4954
"""
5055

5156
missing_matrix(
52-
data,
57+
notnull=notnull,
58+
height=nrows,
59+
columns=columns,
5360
figsize=(10, 4),
54-
fontsize=get_font_size(data) / 20 * 16,
61+
fontsize=get_font_size(columns) / 20 * 16,
5562
color=hex_to_rgb(config.html.style.primary_colors[0]),
5663
labels=config.plot.missing.force_labels,
5764
)
@@ -60,20 +67,25 @@ def plot_missing_matrix(config: Settings, data: pd.DataFrame) -> str:
6067

6168

6269
@manage_matplotlib_context()
63-
def plot_missing_bar(config: Settings, data: pd.DataFrame) -> str:
70+
def plot_missing_bar(
71+
config: Settings, notnull_counts: list, nrows: int, columns: List[str]
72+
) -> str:
6473
"""Generate missing values bar plot.
6574
6675
Args:
6776
config: report Settings object
68-
data: Pandas DataFrame to generate missing values bar plot from.
77+
notnull_counts: Number of nonnull values per column.
78+
nrows: Number of rows in the dataframe.
79+
columns: List of column names.
6980
7081
Returns:
7182
The resulting missing values bar plot encoded as a string.
7283
"""
7384
missing_bar(
74-
data,
85+
notnull_counts=notnull_counts,
86+
nrows=nrows,
7587
figsize=(10, 5),
76-
fontsize=get_font_size(data),
88+
fontsize=get_font_size(columns),
7789
color=hex_to_rgb(config.html.style.primary_colors[0]),
7890
labels=config.plot.missing.force_labels,
7991
)
@@ -85,35 +97,40 @@ def plot_missing_bar(config: Settings, data: pd.DataFrame) -> str:
8597

8698

8799
@manage_matplotlib_context()
88-
def plot_missing_heatmap(config: Settings, data: pd.DataFrame) -> str:
100+
def plot_missing_heatmap(
101+
config: Settings, corr_mat: Any, mask: Any, columns: List[str]
102+
) -> str:
89103
"""Generate missing values heatmap plot.
90104
91105
Args:
92106
config: report Settings object
93-
data: Pandas DataFrame to generate missing values heatmap plot from.
107+
corr_mat: Correlation matrix.
108+
mask: Upper-triangle mask.
109+
columns: List of column names.
94110
95111
Returns:
96112
The resulting missing values heatmap plot encoded as a string.
97113
"""
98114

99115
height = 4
100-
if len(data.columns) > 10:
101-
height += int((len(data.columns) - 10) / 5)
116+
if len(columns) > 10:
117+
height += int((len(columns) - 10) / 5)
102118
height = min(height, 10)
103119

104-
font_size = get_font_size(data)
105-
if len(data.columns) > 40:
120+
font_size = get_font_size(columns)
121+
if len(columns) > 40:
106122
font_size /= 1.4
107123

108124
missing_heatmap(
109-
data,
125+
corr_mat=corr_mat,
126+
mask=mask,
110127
figsize=(10, height),
111128
fontsize=font_size,
112129
cmap=config.plot.missing.cmap,
113130
labels=config.plot.missing.force_labels,
114131
)
115132

116-
if len(data.columns) > 40:
133+
if len(columns) > 40:
117134
plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.3)
118135
else:
119136
plt.subplots_adjust(left=0.2, right=0.9, top=0.8, bottom=0.3)

src/ydata_profiling/visualisation/plot.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,8 @@ def _set_visibility(
761761

762762

763763
def missing_bar(
764-
data: pd.DataFrame,
764+
notnull_counts: pd.Series,
765+
nrows: int,
765766
figsize: Tuple[float, float] = (25, 10),
766767
fontsize: float = 16,
767768
labels: bool = True,
@@ -774,7 +775,8 @@ def missing_bar(
774775
Inspired by https://github.com/ResidentMario/missingno
775776
776777
Args:
777-
data: The input DataFrame.
778+
notnull_counts: Number of nonnull values per column.
779+
nrows: Number of rows in the dataframe.
778780
figsize: The size of the figure to display.
779781
fontsize: The figure's font size. This defaults to 16.
780782
labels: Whether or not to display the column names. Would need to be turned off on particularly large
@@ -784,12 +786,10 @@ def missing_bar(
784786
Returns:
785787
The plot axis.
786788
"""
787-
null_counts = len(data) - data.isnull().sum()
788-
values = null_counts.values
789-
null_counts = null_counts / len(data)
789+
percentage = notnull_counts / nrows
790790

791-
if len(values) <= 50:
792-
ax0 = null_counts.plot.bar(figsize=figsize, fontsize=fontsize, color=color)
791+
if len(notnull_counts) <= 50:
792+
ax0 = percentage.plot.bar(figsize=figsize, fontsize=fontsize, color=color)
793793
ax0.set_xticklabels(
794794
ax0.get_xticklabels(),
795795
ha="right",
@@ -801,17 +801,17 @@ def missing_bar(
801801
ax1.set_xticks(ax0.get_xticks())
802802
ax1.set_xlim(ax0.get_xlim())
803803
ax1.set_xticklabels(
804-
values, ha="left", fontsize=fontsize, rotation=label_rotation
804+
notnull_counts, ha="left", fontsize=fontsize, rotation=label_rotation
805805
)
806806
else:
807-
ax0 = null_counts.plot.barh(figsize=figsize, fontsize=fontsize, color=color)
807+
ax0 = percentage.plot.barh(figsize=figsize, fontsize=fontsize, color=color)
808808
ylabels = ax0.get_yticklabels() if labels else []
809809
ax0.set_yticklabels(ylabels, fontsize=fontsize)
810810

811811
ax1 = ax0.twinx()
812812
ax1.set_yticks(ax0.get_yticks())
813813
ax1.set_ylim(ax0.get_ylim())
814-
ax1.set_yticklabels(values, fontsize=fontsize)
814+
ax1.set_yticklabels(notnull_counts, fontsize=fontsize)
815815

816816
for ax in [ax0, ax1]:
817817
ax = _set_visibility(ax)
@@ -820,7 +820,9 @@ def missing_bar(
820820

821821

822822
def missing_matrix(
823-
data: pd.DataFrame,
823+
notnull: Any,
824+
columns: List[str],
825+
height: int,
824826
figsize: Tuple[float, float] = (25, 10),
825827
color: Tuple[float, ...] = (0.41, 0.41, 0.41),
826828
fontsize: float = 16,
@@ -833,7 +835,9 @@ def missing_matrix(
833835
Inspired by https://github.com/ResidentMario/missingno
834836
835837
Args:
836-
data: The input DataFrame.
838+
notnull: Missing data indicator matrix.
839+
columns: List of column names.
840+
height: Number of rows in the dataframe.
837841
figsize: The size of the figure to display.
838842
fontsize: The figure's font size. Default to 16.
839843
labels: Whether or not to display the column names when there are more than 50 columns.
@@ -842,9 +846,7 @@ def missing_matrix(
842846
Returns:
843847
The plot axis.
844848
"""
845-
height, width = data.shape
846-
847-
notnull = data.notnull().values
849+
width = len(columns)
848850
missing_grid = np.zeros((height, width, 3), dtype=np.float32)
849851

850852
missing_grid[notnull] = color
@@ -860,9 +862,7 @@ def missing_matrix(
860862

861863
ha = "left"
862864
ax.set_xticks(list(range(0, width)))
863-
ax.set_xticklabels(
864-
list(data.columns), rotation=label_rotation, ha=ha, fontsize=fontsize
865-
)
865+
ax.set_xticklabels(columns, rotation=label_rotation, ha=ha, fontsize=fontsize)
866866
ax.set_yticks([0, height - 1])
867867
ax.set_yticklabels([1, height], fontsize=fontsize)
868868

@@ -878,7 +878,8 @@ def missing_matrix(
878878

879879

880880
def missing_heatmap(
881-
data: pd.DataFrame,
881+
corr_mat: Any,
882+
mask: Any,
882883
figsize: Tuple[float, float] = (20, 12),
883884
fontsize: float = 16,
884885
labels: bool = True,
@@ -895,7 +896,8 @@ def missing_heatmap(
895896
Inspired by https://github.com/ResidentMario/missingno
896897
897898
Args:
898-
data: The input DataFrame.
899+
corr_mat: correlation matrix.
900+
mask: Upper-triangle mask.
899901
figsize: The size of the figure to display. Defaults to (20, 12).
900902
fontsize: The figure's font size.
901903
labels: Whether or not to label each matrix entry with its correlation (default is True).
@@ -906,15 +908,6 @@ def missing_heatmap(
906908
The plot axis.
907909
"""
908910
_, ax = plt.subplots(1, 1, figsize=figsize)
909-
910-
# Remove completely filled or completely empty variables.
911-
columns = [i for i, n in enumerate(np.var(data.isnull(), axis="rows")) if n > 0]
912-
data = data.iloc[:, columns]
913-
914-
# Create and mask the correlation matrix. Construct the base heatmap.
915-
corr_mat = data.isnull().corr()
916-
mask = np.zeros_like(corr_mat)
917-
mask[np.triu_indices_from(mask)] = True
918911
norm_args = {"vmin": -1, "vmax": 1} if normalized_cmap else {}
919912

920913
if labels:

0 commit comments

Comments
 (0)