@@ -1640,6 +1640,8 @@ def _linspace_core(
16401640
16411641
16421642def _broadcast_inputs (* args ):
1643+ """Helper function to preprocess inputs to *space Ops"""
1644+
16431645 args = map (ptb .as_tensor_variable , args )
16441646 args = broadcast_arrays (* args )
16451647
@@ -1653,14 +1655,23 @@ def _broadcast_base_with_inputs(start, stop, base, axis):
16531655
16541656 Parameters
16551657 ----------
1656- start
1657- stop
1658- base
1659- axis
1658+ start: TensorVariable
1659+ The start value(s) of the sequence(s).
1660+ stop: TensorVariable
1661+ The end value(s) of the sequence(s)
1662+ base: TensorVariable
1663+ The log base value(s) of the sequence(s)
1664+ axis: int
1665+ The axis along which to generate samples.
16601666
16611667 Returns
16621668 -------
1663-
1669+ start: TensorVariable
1670+ The start value(s) of the sequence(s), broadcast with the base tensor if necessary.
1671+ stop: TensorVariable
1672+ The end value(s) of the sequence(s), broadcast with the base tensor if necessary.
1673+ base: TensorVariable
1674+ The log base value(s) of the sequence(s), broadcast with the start and stop tensors if necessary.
16641675 """
16651676 base = ptb .as_tensor_variable (base )
16661677 if base .ndim > 0 :
@@ -1841,10 +1852,9 @@ def geomspace(
18411852 )
18421853 result = base ** result
18431854
1844- if num > 0 :
1845- result = set_subtensor (result [0 , ...], start )
1846- if num > 1 and endpoint :
1847- result = set_subtensor (result [- 1 , ...], stop )
1855+ result = switch (gt (num , 0 ), set_subtensor (result [0 , ...], start ), result )
1856+ if endpoint :
1857+ result = switch (gt (num , 1 ), set_subtensor (result [- 1 , ...], stop ), result )
18481858
18491859 result = result * out_sign
18501860
0 commit comments