@@ -433,6 +433,11 @@ def rng_fn(cls, rng, r, alpha, time_covariate_vector, size):
433433 lam_covar = lam * exp_time_covar
434434
435435 p = 1 - np .exp (- lam_covar )
436+ # TODO: This is a hack to ensure valid probability in (0, 1]
437+ # We should find a better way to do this.
438+ # Ensure valid probability in (0, 1]
439+ tiny = np .finfo (p .dtype ).tiny
440+ p = np .clip (p , tiny , 1.0 )
436441 samples = rng .geometric (p )
437442 # samples = np.ceil(np.log(1 - rng.uniform(size=size)) / (-lam_covar))
438443
@@ -576,12 +581,11 @@ def support_point(rv, size, r, alpha, time_covariate_vector):
576581 1.0 / (1.0 - pt .exp (- base_lambda )), # Full expression for larger lambda
577582 )
578583
579- # Apply time covariates if provided
584+ # Apply time covariates if provided: multiply by exp(sum over axis=0)
585+ # This yields a scalar for 1D covariates and a time-length vector for 2D (features x time)
580586 tcv = pt .as_tensor_variable (time_covariate_vector )
581587 if tcv .ndim != 0 :
582- # If 1D, treat as per-time vector; if 2D+, sum features while preserving time axis
583- cov_time = tcv if tcv .ndim == 1 else tcv .sum (axis = 0 )
584- mean = mean * pt .exp (cov_time )
588+ mean = mean * pt .exp (tcv .sum (axis = 0 ))
585589
586590 # Round up to nearest integer and ensure >= 1
587591 mean = pt .maximum (pt .ceil (mean ), 1.0 )
@@ -603,8 +607,8 @@ def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.Te
603607 if time_covariate_vector .ndim == 1 :
604608 per_time_sum = pt .exp (time_covariate_vector )
605609 else :
606- feature_axes = tuple ( range ( time_covariate_vector . ndim - 1 ) )
607- per_time_sum = pt .sum (pt .exp (time_covariate_vector ), axis = feature_axes )
610+ # If axis=0 is time and axis>0 are features, sum over features (axis>0 )
611+ per_time_sum = pt .sum (pt .exp (time_covariate_vector ), axis = 0 )
608612
609613 # Build cumulative sum up to each t without advanced indexing
610614 time_length = pt .shape (per_time_sum )[0 ]
@@ -617,9 +621,5 @@ def C_t(t: pt.TensorVariable, time_covariate_vector: pt.TensorVariable) -> pt.Te
617621 mask = pt .lt (time_idx , pt .shape_padright (t_vec , 1 ))
618622 # Sum per-time contributions over time axis
619623 base_sum = pt .sum (pt .shape_padleft (per_time_sum ) * mask , axis = - 1 )
620- # Carry-forward last per-time value for t beyond time_length
621- last_value = per_time_sum [- 1 ]
622- excess_steps = pt .maximum (t_vec - time_length , 0 )
623- carried = base_sum + excess_steps * last_value
624- # If original t was scalar, return scalar
625- return pt .squeeze (carried )
624+ # If original t was scalar, return scalar (saturate at last time step)
625+ return pt .squeeze (base_sum )
0 commit comments