Parallel nested graph generation attempt #28341
Hello, I'm trying to speed up this simple function (with parallel scalar injection):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


@Partial(jax.jit, static_argnums=0)
def FVconnect(n: int = 7):  # n is the number of vertices
    """Fully connect the vertices with each other (densest option)."""
    Edges = [(i, j) for i in range(n) for j in range(i + 1, n)]
    return Edges


Output = FVconnect()  # conceptually a 2nd-degree tensor (only j matters in the end; i is repeated)


def generate_edges1(n: int = 2):  # n is the number of vertices
    # Does not work: row a would have length n-1-a, which depends on the traced value a.
    Edges = jax.pmap(lambda a: jnp.arange(a + 1, n), 0, None)(jnp.arange(n - 1))
    return Edges


def generate_edges2(n: int = 2):  # n is the number of vertices
    # Does not work for the same reason; scan's body must also return a (carry, y) pair.
    Edges = jax.lax.scan(lambda b, a: jnp.arange(a + 1, n), init=0, xs=jnp.arange(n - 1))
    return Edges
```

Ideally, with jax.vmap the result would be a ragged array whose row i holds the indices i+1 through n-1. However, like NumPy arrays, jax.Array does not support non-uniform (ragged) shapes, so I'm fine with flattening/concatenating the rows instead. Does anyone here know how to make this work? Thanks in advance!
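A minimal sketch of two static-shape workarounds, assuming n is a Python int known at trace time (marked static via `static_argnums`). The function names `edges_triu` and `edges_padded_vmap` are hypothetical, and this is not necessarily the approach proposed in the linked topic. The first uses `jnp.triu_indices(n, k=1)` to produce the same (i, j) pairs as `FVconnect` in one call; the second keeps the `jax.vmap` idea but gives every row a fixed length n, masks out entries with j <= i, and flattens with `jnp.nonzero(..., size=...)`, whose `size` argument must be static under `jit`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def edges_triu(n: int):
    # Strictly upper triangle (k=1) of an n x n matrix = all pairs (i, j) with i < j.
    # n is static, so the output shape (n*(n-1)//2, 2) is known at trace time.
    i, j = jnp.triu_indices(n, k=1)
    return jnp.stack([i, j], axis=1)


@partial(jax.jit, static_argnums=0)
def edges_padded_vmap(n: int):
    cols = jnp.arange(n)
    # vmap over source vertex a: each row has a fixed length n, True where j > a.
    mask = jax.vmap(lambda a: cols > a)(jnp.arange(n))
    # Flatten by keeping only the True entries (row-major order);
    # `size` must be a static int under jit, here exactly n*(n-1)//2.
    i, j = jnp.nonzero(mask, size=n * (n - 1) // 2)
    return jnp.stack([i, j], axis=1)


edges = edges_triu(7)
print(edges.shape)  # (21, 2)
print(edges[:, 1])  # only the j column
```

Both return an int array of shape (n*(n-1)//2, 2) in the same order as the list comprehension in `FVconnect`. The padded version spends O(n²) memory on the mask, which is the usual price for replacing a ragged shape with a static one.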
Update: see the comment in topic #32993, where a solution along these lines is proposed (for those interested).