@@ -1638,6 +1638,8 @@ def _linspace_core(
16381638
16391639
16401640def _broadcast_inputs (* args ):
1641+ """Helper function to preprocess inputs to *space Ops"""
1642+
16411643 args = map (ptb .as_tensor_variable , args )
16421644 args = broadcast_arrays (* args )
16431645
@@ -1651,14 +1653,23 @@ def _broadcast_base_with_inputs(start, stop, base, axis):
16511653
16521654 Parameters
16531655 ----------
1654- start
1655- stop
1656- base
1657- axis
1656+ start: TensorVariable
1657+ The start value(s) of the sequence(s).
1658+ stop: TensorVariable
1659+ The end value(s) of the sequence(s)
1660+ base: TensorVariable
1661+ The log base value(s) of the sequence(s)
1662+ axis: int
1663+ The axis along which to generate samples.
16581664
16591665 Returns
16601666 -------
1661-
1667+ start: TensorVariable
1668+ The start value(s) of the sequence(s), broadcast with the base tensor if necessary.
1669+ stop: TensorVariable
1670+ The end value(s) of the sequence(s), broadcast with the base tensor if necessary.
1671+ base: TensorVariable
1672+ The log base value(s) of the sequence(s), broadcast with the start and stop tensors if necessary.
16621673 """
16631674 base = ptb .as_tensor_variable (base )
16641675 if base .ndim > 0 :
@@ -1839,10 +1850,9 @@ def geomspace(
18391850 )
18401851 result = base ** result
18411852
1842- if num > 0 :
1843- result = set_subtensor (result [0 , ...], start )
1844- if num > 1 and endpoint :
1845- result = set_subtensor (result [- 1 , ...], stop )
1853+ result = switch (gt (num , 0 ), set_subtensor (result [0 , ...], start ), result )
1854+ if endpoint :
1855+ result = switch (gt (num , 1 ), set_subtensor (result [- 1 , ...], stop ), result )
18461856
18471857 result = result * out_sign
18481858
0 commit comments