@@ -415,10 +415,7 @@ def medfilt1(x, L=3):
415415 >>> L = 103
416416 >>> xout = medfilt1(x=x, L=L)
417417 >>> ax = plt.subplot(212)
418- >>> (
419- ... l1,
420- ... l2,
421- ... ) = ax.plot(
418+ >>> (l1, l2,) = ax.plot(
422419 ... x
423420 ... ), ax.plot(xout)
424421 >>> ax.grid(True)
@@ -570,7 +567,7 @@ def md_trenberth(x):
570567 return y
571568
572569
573- def pl33tn (x , dt = 1.0 , T = 33.0 , mode = "valid" ):
570+ def pl33tn (x , dt = 1.0 , T = 33.0 , mode = "valid" , t = None ):
574571 """
575572 Computes low-passed series from `x` using pl33 filter, with optional
576573 sample interval `dt` (hours) and filter half-amplitude period T (hours)
@@ -608,14 +605,25 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
608605 """
609606
610607 import cf_xarray # noqa: F401
608+ import pandas as pd
611609 import xarray as xr
612610
613- if isinstance (x , xr .Dataset ):
614- raise TypeError ("Input a DataArray not a Dataset." )
611+ if isinstance (x , ( xr .Dataset , pd . DataFrame ) ):
612+ raise TypeError ("Input a DataArray not a Dataset, or a Series not a DataFrame ." )
615613
614+ if isinstance (x , pd .Series ) and not isinstance (
615+ x .index ,
616+ pd .core .indexes .datetimes .DatetimeIndex ,
617+ ):
618+ raise TypeError ("Input Series needs to have parsed datetime indices." )
619+
620+ # find dt in units of hours
616621 if isinstance (x , xr .DataArray ):
617- # find dt in units of hours
618- dt = (x .cf ["T" ][1 ] - x .cf ["T" ][0 ]) * 1e-9 / 3600
622+ dt = (x .cf ["T" ][1 ] - x .cf ["T" ][0 ]) / np .timedelta64 (
623+ 360_000_000_000 ,
624+ )
625+ elif isinstance (x , pd .Series ):
626+ dt = (x .index [1 ] - x .index [0 ]) / pd .Timedelta ("1H" )
619627
620628 pl33 = np .array (
621629 [
@@ -694,18 +702,20 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
694702 dt = float (dt ) * (33.0 / T )
695703
696704 filter_time = np .arange (0.0 , 33.0 , dt , dtype = "d" )
697- # N = len(filter_time)
705+ Nt = len (filter_time )
698706 filter_time = np .hstack ((- filter_time [- 1 :0 :- 1 ], filter_time ))
699707
700708 pl33 = np .interp (filter_time , _dt , pl33 )
701709 pl33 /= pl33 .sum ()
702710
703711 if isinstance (x , xr .DataArray ):
712+ x = x .interpolate_na (dim = x .cf ["T" ].name )
713+
704714 weight = xr .DataArray (pl33 , dims = ["window" ])
705715 xf = (
706716 x .rolling ({x .cf ["T" ].name : len (pl33 )}, center = True )
707717 .construct ({x .cf ["T" ].name : "window" })
708- .dot (weight )
718+ .dot (weight , dims = "window" )
709719 )
710720 # update attrs
711721 attrs = {
@@ -715,7 +725,26 @@ def pl33tn(x, dt=1.0, T=33.0, mode="valid"):
715725 }
716726 xf .attrs = attrs
717727
728+ elif isinstance (x , pd .Series ):
729+ xf = x .to_frame ().apply (np .convolve , v = pl33 , mode = mode )
730+
731+ # nan out edges which are not good values anyway
732+ if mode == "same" :
733+ xf [: Nt - 1 ] = np .nan
734+ xf [- Nt + 2 :] = np .nan
735+
718736 else : # use numpy
719737 xf = np .convolve (x , pl33 , mode = mode )
720738
739+ # times to match xf
740+ if t is not None :
741+ # Nt = len(filter_time)
742+ tf = t [Nt - 1 : - Nt + 1 ]
743+ return xf , tf
744+
745+ # nan out edges which are not good values anyway
746+ if mode == "same" :
747+ xf [: Nt - 1 ] = np .nan
748+ xf [- Nt + 2 :] = np .nan
749+
721750 return xf
0 commit comments