martenlienen
martenlienen OP t1_iziq1xp wrote
Reply to comment by fhchl in [R] torchode: A Parallel ODE Solver for PyTorch by martenlienen
Yes, it is the same thing. Unfortunately, functorch is not yet advanced enough to translate diffrax to PyTorch directly. Instead, we had to handle batching explicitly everywhere, e.g. to decide how long to loop.
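Here is what "deciding how long to loop" means in practice. When the samples of a batch finish at different times, a hand-batched solver has to keep looping until the last sample is done and freeze the samples that have already finished. The following is a minimal pure-Python sketch with assumed simplifications: scalar states, fixed-step explicit Euler, and per-sample end times. The name `euler_batch` is hypothetical and this is not torchode's code, which does this bookkeeping with batched tensors and masks.

```python
def euler_batch(f, y0, t_end, dt):
    """Fixed-step explicit Euler for a batch of scalar ODEs dy/dt = f(t, y).

    y0, t_end and dt hold one entry per sample. Returns the final states and
    the number of loop iterations the whole batch needed.
    """
    y = list(y0)
    t = [0.0] * len(y0)
    iterations = 0
    # Keep looping while at least one sample has not reached its own end time.
    while any(ti < te for ti, te in zip(t, t_end)):
        iterations += 1
        for i in range(len(y)):
            if t[i] >= t_end[i]:
                continue  # finished sample: masked out, state stays frozen
            remaining = t_end[i] - t[i]
            h = min(dt[i], remaining)  # never step past the end time
            y[i] += h * f(t[i], y[i])
            # Snap to the end time on the last step to avoid floating-point drift.
            t[i] = t_end[i] if h == remaining else t[i] + h
    return y, iterations


y, iterations = euler_batch(lambda t, y: -y, y0=[1.0, 1.0], t_end=[1.0, 3.0], dt=[0.25, 0.25])
print(iterations)  # 12
```

The loop runs 12 times, as many iterations as the slowest sample needs, even though the first sample is finished after 4.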
martenlienen OP t1_izic55x wrote
Reply to comment by cheecheepong in [R] torchode: A Parallel ODE Solver for PyTorch by martenlienen
ಠ_ಠ
martenlienen OP t1_izi8k3g wrote
Reply to comment by MathChief in [R] torchode: A Parallel ODE Solver for PyTorch by martenlienen
First, the difference in steps is probably due to different tolerances in the step size controller.
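To see why tolerances change the step count, consider how a typical adaptive controller works. It scales the local error estimate by a mixed tolerance `atol + rtol * |y|` and accepts a step only if the scaled norm is at most 1. Tighter tolerances therefore reject more steps and propose smaller ones. The sketch below is a generic textbook integral controller in pure Python. The function names, safety factor and clamping bounds are illustrative assumptions, not the defaults of torchode, diffrax or MATLAB.

```python
import math


def error_norm(err, y, rtol, atol):
    """RMS norm of the error estimate, scaled elementwise by atol + rtol * |y|."""
    scaled = [e / (atol + rtol * abs(yi)) for e, yi in zip(err, y)]
    return math.sqrt(sum(s * s for s in scaled) / len(scaled))


def control_step(h, norm, order, safety=0.9, fmin=0.2, fmax=10.0):
    """Integral (I) controller: accept if norm <= 1, then rescale the step size.

    `order` is the order of the error estimator (4 for Dormand-Prince 5(4)).
    """
    accept = norm <= 1.0
    factor = fmax if norm == 0.0 else safety * norm ** (-1.0 / (order + 1))
    return accept, h * min(fmax, max(fmin, factor))


# Same error estimate, two tolerance settings.
err, y, h = [1e-5], [1.0], 0.1
loose = control_step(h, error_norm(err, y, rtol=1e-3, atol=1e-6), order=4)
tight = control_step(h, error_norm(err, y, rtol=1e-6, atol=1e-9), order=4)
print(loose, tight)  # loose accepts and grows h; tight rejects and shrinks h
```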
The loop time is measured in milliseconds. Of course, that is much slower than what you got in MATLAB. The difference is that we ran all benchmarks on a GPU, because that is the usual mode for deep learning, even though a GPU is certainly inappropriate for the VdP equation if you are interested in it for anything other than benchmarking the inner loop of an ODE solver on a GPU. I think you can get numbers similar to your MATLAB code with diffrax and JIT compilation on a CPU. However, you won't get them with torchode, because PyTorch's JIT is not as good as JAX's and, specifically, this line is really slow on CPUs. Nonetheless, after comparing several alternatives we chose this approach because, as I said, in most of deep learning only GPU performance matters in practice.
martenlienen OP t1_izi7n9z wrote
Reply to comment by fhchl in [R] torchode: A Parallel ODE Solver for PyTorch by martenlienen
Diffrax is an excellent project and a superset of torchode: torchode solves only ODEs, while diffrax combines ODEs, CDEs and SDEs (maybe more?) in the same framework. We created torchode to bring diffrax's structure and flexibility into the PyTorch ecosystem, which is still more popular than JAX. In addition, we wanted to create an optimized implementation, and as far as I am concerned we succeeded. Even though tracking multiple ODE solvers at once is inherently more complex than solving a batch of ODEs jointly, torchode is as fast as or faster than the other PyTorch ODE solvers in our experiments.
martenlienen OP t1_izi71ti wrote
Reply to comment by dopadelic in [R] torchode: A Parallel ODE Solver for PyTorch by martenlienen
You are correct, except for parallel-in-time integration methods, which we also mention in the paper. But the "parallel" in the title refers to solving multiple ODEs independently in parallel, which is contrary to what is currently done in ML. At the moment, training on a batch of ODEs means that you treat the batch as one large ODE that is solved jointly. torchode instead solves them independently of each other, but still in parallel, by tracking a separate current state, step size, etc. for each sample.
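The difference between the two batching strategies can be sketched in pure Python. In `solve_joint`, the batch is one ODE with a shared time and step size, and a step is accepted only if every sample's error is within tolerance. In `solve_independent`, every sample keeps its own time, step size and accept/reject decision. Everything here is an illustrative assumption rather than torchode's actual implementation: scalar linear test problems dy/dt = -k y, a simple embedded Heun/Euler pair instead of Dormand-Prince, a max norm over the batch for the joint solver, and hypothetical function names.

```python
import math


def heun_euler_step(f, t, y, h):
    """Embedded Heun (order 2) / Euler (order 1) step: returns (y_new, error estimate)."""
    k1 = f(t, y)
    k2 = f(t + h, y + h * k1)
    y_high = y + 0.5 * h * (k1 + k2)
    y_low = y + h * k1
    return y_high, y_high - y_low


def scaled_error(err, y, rtol, atol):
    return abs(err) / (atol + rtol * abs(y))


def next_h(h, s, safety=0.9, fmin=0.2, fmax=5.0):
    """I-controller for an order-1 error estimate (local error ~ h**2)."""
    factor = fmax if s == 0.0 else safety * s ** -0.5
    return h * min(fmax, max(fmin, factor))


def solve_joint(fs, y0, t_end, rtol=1e-3, atol=1e-6, h=0.01):
    """Batch as ONE ODE: shared t and h; accept only if every sample is within tolerance."""
    y, t, steps = list(y0), 0.0, 0
    while t < t_end:
        h = min(h, t_end - t)
        results = [heun_euler_step(f, t, yi, h) for f, yi in zip(fs, y)]
        s = max(scaled_error(e, yi, rtol, atol) for (_, e), yi in zip(results, y))
        steps += 1
        if s <= 1.0:
            y = [y_new for y_new, _ in results]
            t = t_end if h == t_end - t else t + h
        h = next_h(h, s)
    return y, steps


def solve_independent(fs, y0, t_end, rtol=1e-3, atol=1e-6, h=0.01):
    """Each sample tracks its own t, h and accept/reject decision."""
    n = len(y0)
    y, t, hs, steps = list(y0), [0.0] * n, [h] * n, [0] * n
    while any(ti < t_end for ti in t):
        for i in range(n):
            if t[i] >= t_end:
                continue  # this sample is done; the others keep stepping
            hi = min(hs[i], t_end - t[i])
            y_new, e = heun_euler_step(fs[i], t[i], y[i], hi)
            s = scaled_error(e, y[i], rtol, atol)
            steps[i] += 1
            if s <= 1.0:
                y[i] = y_new
                t[i] = t_end if hi == t_end - t[i] else t[i] + hi
            hs[i] = next_h(hi, s)
    return y, steps


# A slow (k = 1) and a fast-decaying (k = 100) sample of dy/dt = -k * y.
fs = [lambda t, y, k=k: -k * y for k in (1.0, 100.0)]
y_joint, joint_steps = solve_joint(fs, [1.0, 1.0], t_end=2.0)
y_indep, indep_steps = solve_independent(fs, [1.0, 1.0], t_end=2.0)
print(joint_steps, indep_steps)
```

With these settings, the joint solver has to follow the fast-decaying sample over the whole interval. It therefore needs more steps than the slow sample needs when that sample is solved on its own.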
martenlienen OP t1_izejs85 wrote
Reply to comment by MathChief in [R] torchode: A Parallel ODE Solver for PyTorch by martenlienen
In these benchmarks we compare the same Runge-Kutta solver (5th-order Dormand-Prince) implemented in each of these libraries. None of these libraries actually proposes new stepping methods. The point is to make ODE solvers available in popular deep learning frameworks to enable deep continuous models such as neural ODEs and continuous normalizing flows. The particular appeal of torchode is its optimized implementation and the fact that it runs multiple independent instances of an ODE solver in parallel when you train on batches, i.e. each instance is solved with its own step size and its own step accept/reject decisions. This avoids a performance pitfall where the usual batching approach can lead to many unnecessary solver steps in batched training of models with varying stiffness, as we show in the Van der Pol experiment.
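For reference, the 5th-order Dormand-Prince method is an embedded Runge-Kutta pair. Seven stages give a 5th-order solution plus a 4th-order companion, and the difference between the two is used as the error estimate for step size control. Below is a pure-Python sketch of one scalar step using the standard published coefficients. The function name is hypothetical, and this is not code from torchode or any of the benchmarked libraries, which operate on batched tensors.

```python
import math

# Dormand-Prince 5(4) Butcher tableau: nodes C, stage weights A (lower triangular),
# 5th-order weights B5 and embedded 4th-order weights B4.
C = [0.0, 1 / 5, 3 / 10, 4 / 5, 8 / 9, 1.0, 1.0]
A = [
    [],
    [1 / 5],
    [3 / 40, 9 / 40],
    [44 / 45, -56 / 15, 32 / 9],
    [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729],
    [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656],
    [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84],
]
B5 = [35 / 384, 0.0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0.0]
B4 = [5179 / 57600, 0.0, 7571 / 16695, 393 / 640, -92097 / 339200, 187 / 2100, 1 / 40]


def dopri5_step(f, t, y, h):
    """One Dormand-Prince step for a scalar ODE dy/dt = f(t, y).

    Returns the 5th-order solution and the error estimate y5 - y4.
    """
    k = []
    for c_i, a_i in zip(C, A):
        # Stage i only uses the previously computed stage derivatives k[0..i-1].
        y_stage = y + h * sum(a * k_j for a, k_j in zip(a_i, k))
        k.append(f(t + c_i * h, y_stage))
    y5 = y + h * sum(b * k_i for b, k_i in zip(B5, k))
    y4 = y + h * sum(b * k_i for b, k_i in zip(B4, k))
    return y5, y5 - y4


y5, err = dopri5_step(lambda t, y: y, t=0.0, y=1.0, h=0.1)
print(abs(y5 - math.exp(0.1)), abs(err))
```

The last row of `A` equals `B5`, so the final stage is evaluated at the 5th-order solution itself (the "first same as last" property). Practical implementations reuse that evaluation as the first stage of the next step.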
martenlienen OP t1_izk7bk3 wrote
Reply to comment by fhchl in [R] torchode: A Parallel ODE Solver for PyTorch by martenlienen
I am not aware of any, but I would be very interested if you find one.