Hi,
I am a robotics researcher working on model-based and learning-based locomotion. I am combining a classical model-based method (optimization-based whole-body control) with reinforcement learning, and I use MJX to parallelize the learning framework.

However, I am running into performance issues with JAXopt when solving the whole-body optimization. With a single (jitted) environment, one step takes about 20 ms to solve. When I scale up to 1000 environments, the solve time rises to about 2 seconds, which is 100 times slower, so it does not seem to scale properly. I use jax.jit and jax.vmap. With 100 environments, the solve time stays almost the same as with a single environment. I thought it might be a memory issue, but I have a 16 GB GPU and JAX uses only 12 GB.

In general, I would like to know what the possible ways are to speed up the QP solvers. I am currently using BoxOSQP. Also, my QP can usually be solved in less than 1 ms on the CPU. Are the solvers in JAXopt in general slower than existing CPU solvers (such as OSQP, ProxSuite, PIQP, ...)? Thank you very much.