@@ -99,13 +99,13 @@ def __init__(
9999 # causal impact pre (ie the residuals of the model fit to observed)
100100 pre_data = xr .DataArray (self .pre_y [:, 0 ], dims = ["obs_ind" ])
101101 self .pre_impact = (
102- pre_data - self .pre_pred ["posterior_predictive" ].y_hat
102+ pre_data - self .pre_pred ["posterior_predictive" ].mu
103103 ).transpose (..., "obs_ind" )
104104
105105 # causal impact post (ie the residuals of the model fit to observed)
106106 post_data = xr .DataArray (self .post_y [:, 0 ], dims = ["obs_ind" ])
107107 self .post_impact = (
108- post_data - self .post_pred ["posterior_predictive" ].y_hat
108+ post_data - self .post_pred ["posterior_predictive" ].mu
109109 ).transpose (..., "obs_ind" )
110110
111111 # cumulative impact post
@@ -118,31 +118,43 @@ def plot(self):
118118
119119 # TOP PLOT --------------------------------------------------
120120 # pre-intervention period
121- plot_xY (
121+ h_line , h_patch = plot_xY (
122122 self .datapre .index ,
123- self .pre_pred ["posterior_predictive" ].y_hat ,
123+ self .pre_pred ["posterior_predictive" ].mu ,
124124 ax = ax [0 ],
125+ plot_hdi_kwargs = {"color" : "C0" },
125126 )
126- ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
127+ handles = [(h_line , h_patch )]
128+ labels = ["Pre-intervention period" ]
129+
130+ (h ,) = ax [0 ].plot (self .datapre .index , self .pre_y , "k." , label = "Observations" )
131+ handles .append (h )
132+ labels .append ("Observations" )
133+
127134 # post intervention period
128- plot_xY (
135+ h_line , h_patch = plot_xY (
129136 self .datapost .index ,
130- self .post_pred ["posterior_predictive" ].y_hat ,
137+ self .post_pred ["posterior_predictive" ].mu ,
131138 ax = ax [0 ],
132- include_label = False ,
139+ plot_hdi_kwargs = { "color" : "C1" } ,
133140 )
141+ handles .append ((h_line , h_patch ))
142+ labels .append ("Synthetic control" )
143+
134144 ax [0 ].plot (self .datapost .index , self .post_y , "k." )
135145 # Shaded causal effect
136- ax [0 ].fill_between (
146+ h = ax [0 ].fill_between (
137147 self .datapost .index ,
138148 y1 = az .extract (
139- self .post_pred , group = "posterior_predictive" , var_names = "y_hat "
149+ self .post_pred , group = "posterior_predictive" , var_names = "mu "
140150 ).mean ("sample" ),
141151 y2 = np .squeeze (self .post_y ),
142152 color = "C0" ,
143153 alpha = 0.25 ,
144- label = "Causal impact" ,
145154 )
155+ handles .append (h )
156+ labels .append ("Causal impact" )
157+
146158 ax [0 ].set (
147159 title = f"""
148160 Pre-intervention Bayesian $R^2$: { self .score .r2 :.3f}
@@ -155,12 +167,13 @@ def plot(self):
155167 self .datapre .index ,
156168 self .pre_impact ,
157169 ax = ax [1 ],
170+ plot_hdi_kwargs = {"color" : "C0" },
158171 )
159172 plot_xY (
160173 self .datapost .index ,
161174 self .post_impact ,
162175 ax = ax [1 ],
163- include_label = False ,
176+ plot_hdi_kwargs = { "color" : "C1" } ,
164177 )
165178 ax [1 ].axhline (y = 0 , c = "k" )
166179 ax [1 ].fill_between (
@@ -173,12 +186,12 @@ def plot(self):
173186 ax [1 ].set (title = "Causal Impact" )
174187
175188 # BOTTOM PLOT -----------------------------------------------
176-
177189 ax [2 ].set (title = "Cumulative Causal Impact" )
178190 plot_xY (
179191 self .datapost .index ,
180192 self .post_impact_cumulative ,
181193 ax = ax [2 ],
194+ plot_hdi_kwargs = {"color" : "C1" },
182195 )
183196 ax [2 ].axhline (y = 0 , c = "k" )
184197
@@ -189,10 +202,13 @@ def plot(self):
189202 ls = "-" ,
190203 lw = 3 ,
191204 color = "r" ,
192- label = "Treatment time" ,
193205 )
194206
195- ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
207+ ax [0 ].legend (
208+ handles = (h_tuple for h_tuple in handles ),
209+ labels = labels ,
210+ fontsize = LEGEND_FONT_SIZE ,
211+ )
196212
197213 return (fig , ax )
198214
@@ -353,39 +369,46 @@ def __init__(
353369 )
354370
355371 def plot (self ):
356- """Plot the results"""
372+ """Plot the results.
373+ Creating the combined mean + HDI legend entries is a bit involved.
374+ """
357375 fig , ax = plt .subplots ()
358376
359377 # Plot raw data
360- # NOTE: This will not work when there is just ONE unit in each group
361- sns .lineplot (
378+ sns .scatterplot (
362379 self .data ,
363380 x = self .time_variable_name ,
364381 y = self .outcome_variable_name ,
365382 hue = self .group_variable_name ,
366- units = "unit" , # NOTE: assumes we have a `unit` predictor variable
367- estimator = None ,
368- alpha = 0.5 ,
383+ alpha = 1 ,
384+ legend = False ,
385+ markers = True ,
369386 ax = ax ,
370387 )
371388
372389 # Plot model fit to control group
373390 time_points = self .x_pred_control [self .time_variable_name ].values
374- plot_xY (
391+ h_line , h_patch = plot_xY (
375392 time_points ,
376- self .y_pred_control .posterior_predictive .y_hat ,
393+ self .y_pred_control .posterior_predictive .mu ,
377394 ax = ax ,
378395 plot_hdi_kwargs = {"color" : "C0" },
396+ label = "Control group" ,
379397 )
398+ handles = [(h_line , h_patch )]
399+ labels = ["Control group" ]
380400
381401 # Plot model fit to treatment group
382402 time_points = self .x_pred_control [self .time_variable_name ].values
383- plot_xY (
403+ h_line , h_patch = plot_xY (
384404 time_points ,
385- self .y_pred_treatment .posterior_predictive .y_hat ,
405+ self .y_pred_treatment .posterior_predictive .mu ,
386406 ax = ax ,
387407 plot_hdi_kwargs = {"color" : "C1" },
408+ label = "Treatment group" ,
388409 )
410+ handles .append ((h_line , h_patch ))
411+ labels .append ("Treatment group" )
389412
390413 # Plot counterfactual - post-test for treatment group IF no treatment
391414 # had occurred.
@@ -403,26 +426,34 @@ def plot(self):
403426 widths = 0.2 ,
404427 )
405428 for pc in parts ["bodies" ]:
406- pc .set_facecolor ("C2 " )
429+ pc .set_facecolor ("C0 " )
407430 pc .set_edgecolor ("None" )
408431 pc .set_alpha (0.5 )
409432 else :
410- plot_xY (
433+ h_line , h_patch = plot_xY (
411434 time_points ,
412- self .y_pred_counterfactual .posterior_predictive .y_hat ,
435+ self .y_pred_counterfactual .posterior_predictive .mu ,
413436 ax = ax ,
414437 plot_hdi_kwargs = {"color" : "C2" },
438+ label = "Counterfactual" ,
415439 )
440+ handles .append ((h_line , h_patch ))
441+ labels .append ("Counterfactual" )
416442
417443 # arrow to label the causal impact
418444 self ._plot_causal_impact_arrow (ax )
445+
419446 # formatting
420447 ax .set (
421448 xticks = self .x_pred_treatment [self .time_variable_name ].values ,
422449 title = self ._causal_impact_summary_stat (),
423450 )
424- ax .legend (fontsize = LEGEND_FONT_SIZE )
425- return (fig , ax )
451+ ax .legend (
452+ handles = (h_tuple for h_tuple in handles ),
453+ labels = labels ,
454+ fontsize = LEGEND_FONT_SIZE ,
455+ )
456+ return fig , ax
426457
427458 def _plot_causal_impact_arrow (self , ax ):
428459 """
@@ -582,12 +613,17 @@ def plot(self):
582613 c = "k" , # hue="treated",
583614 ax = ax ,
584615 )
616+
585617 # Plot model fit to data
586- plot_xY (
618+ h_line , h_patch = plot_xY (
587619 self .x_pred [self .running_variable_name ],
588620 self .pred ["posterior_predictive" ].mu ,
589621 ax = ax ,
622+ plot_hdi_kwargs = {"color" : "C1" },
590623 )
624+ handles = [(h_line , h_patch )]
625+ labels = ["Posterior mean" ]
626+
591627 # create strings to compose title
592628 title_info = f"{ self .score .r2 :.3f} (std = { self .score .r2_std :.3f} )"
593629 r2 = f"Bayesian $R^2$ on all data = { title_info } "
@@ -605,7 +641,11 @@ def plot(self):
605641 color = "r" ,
606642 label = "treatment threshold" ,
607643 )
608- ax .legend (fontsize = LEGEND_FONT_SIZE )
644+ ax .legend (
645+ handles = (h_tuple for h_tuple in handles ),
646+ labels = labels ,
647+ fontsize = LEGEND_FONT_SIZE ,
648+ )
609649 return (fig , ax )
610650
611651 def summary (self ):
@@ -710,27 +750,38 @@ def plot(self):
710750 hue = "group" ,
711751 alpha = 0.5 ,
712752 data = self .data ,
753+ legend = True ,
713754 ax = ax [0 ],
714755 )
715756 ax [0 ].set (xlabel = "Pretest" , ylabel = "Posttest" )
716757
717758 # plot posterior predictive of untreated
718- plot_xY (
759+ h_line , h_patch = plot_xY (
719760 self .pred_xi ,
720- self .pred_untreated ["posterior_predictive" ].y_hat ,
761+ self .pred_untreated ["posterior_predictive" ].mu ,
721762 ax = ax [0 ],
722763 plot_hdi_kwargs = {"color" : "C0" },
764+ label = "Control group" ,
723765 )
766+ handles = [(h_line , h_patch )]
767+ labels = ["Control group" ]
724768
725769 # plot posterior predictive of treated
726- plot_xY (
770+ h_line , h_patch = plot_xY (
727771 self .pred_xi ,
728- self .pred_treated ["posterior_predictive" ].y_hat ,
772+ self .pred_treated ["posterior_predictive" ].mu ,
729773 ax = ax [0 ],
730774 plot_hdi_kwargs = {"color" : "C1" },
775+ label = "Treatment group" ,
731776 )
777+ handles .append ((h_line , h_patch ))
778+ labels .append ("Treatment group" )
732779
733- ax [0 ].legend (fontsize = LEGEND_FONT_SIZE )
780+ ax [0 ].legend (
781+ handles = (h_tuple for h_tuple in handles ),
782+ labels = labels ,
783+ fontsize = LEGEND_FONT_SIZE ,
784+ )
734785
735786 # Plot estimated caual impact / treatment effect
736787 az .plot_posterior (self .causal_impact , ref_val = 0 , ax = ax [1 ])
0 commit comments