Submitted by martenlienen t3_zfvb8h in MachineLearning
fhchl t1_izi5ecp wrote
Nice work! Though diffrax is mentioned, it would be interesting to see a direct comparison between diffrax and torchode. Can you give some more details on how they differ in features and performance, apart from the framework they are built on?
martenlienen OP t1_izi7n9z wrote
Diffrax is an excellent project and, feature-wise, a superset of torchode: torchode solves only ODEs, while diffrax handles ODEs, CDEs and SDEs (maybe more?) in the same framework. We created torchode because we wanted to bring diffrax's structure and flexibility to the PyTorch ecosystem, which is still more popular than JAX. We also set out to build an optimized implementation, and as far as I am concerned we succeeded: even though tracking a separate solver for each ODE in a batch is inherently more complex than solving the whole batch jointly as one big ODE, torchode was as fast as or faster than the other PyTorch ODE solvers in our experiments.
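To make the joint-versus-separate distinction concrete, here is a minimal NumPy sketch of a single step-size-control decision for a batch of three ODEs. It is not torchode's code: the function names `joint_update` and `independent_update`, the controller constants and the error ratios are all hypothetical. `err` is each instance's error estimate divided by its tolerance after one trial step. Solved jointly, the batch shares one step size driven by its worst member, so even the easy instance has to reject its step; with one solver per instance, each instance accepts or rejects and rescales its own step.

```python
import numpy as np

# Hypothetical controller constants for a 2nd-order method (assumptions, not torchode's).
SAFETY, ORDER = 0.9, 2
MIN_FACTOR, MAX_FACTOR = 0.2, 5.0

def joint_update(dt, err):
    """Batch solved as one big ODE: one shared step size, driven by the worst instance."""
    e = float(np.max(err))
    accept = e <= 1.0  # the whole batch accepts or rejects together
    factor = SAFETY * e ** (-1.0 / ORDER) if e > 0 else MAX_FACTOR
    return accept, dt * min(MAX_FACTOR, max(MIN_FACTOR, factor))

def independent_update(dt, err):
    """One solver per instance: its own step size and its own accept/reject decision."""
    accept = err <= 1.0
    with np.errstate(divide="ignore"):  # err == 0 gives inf, which the clip turns into MAX_FACTOR
        factor = SAFETY * err ** (-1.0 / ORDER)
    return accept, dt * np.clip(factor, MIN_FACTOR, MAX_FACTOR)

err = np.array([0.01, 0.5, 4.0])  # hypothetical per-instance error ratios
joint_acc, joint_dt = joint_update(0.1, err)  # rejected for everyone, shared dt shrinks to about 0.045
sep_acc, sep_dt = independent_update(np.full(3, 0.1), err)  # only the third instance rejects
print(joint_acc, joint_dt)
print(sep_acc, sep_dt)
```

The per-instance version is where the extra bookkeeping comes from: once every instance has its own step size, it also has its own current time, so everything downstream has to be batched as well.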
fhchl t1_izif55h wrote
Is this feature of torchode, solving multiple ODEs at once over some batch dimension, comparable to jax.vmap-ing over that dimension in diffrax?
martenlienen OP t1_iziq1xp wrote
Yes, it is the same thing. Unfortunately, functorch is not yet advanced enough to translate diffrax to PyTorch directly. Instead, we had to handle batching explicitly everywhere, for example to decide how long to keep looping.
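Conceptually, vmapping a single-ODE solver should give the same results as running that solver once per instance. The NumPy sketch below illustrates what batching this by hand involves; it is not torchode's implementation, and the names, the Heun/Euler embedded pair and the controller constants are assumptions chosen for brevity. Each instance carries its own time, step size and end time, finished instances are masked out, and the loop keeps running until the slowest instance is done. The last lines check that the batched result matches solving each instance on its own.

```python
import numpy as np

def heun_euler_step(f, t, y, dt):
    """Embedded Heun (order 2) / Euler (order 1) pair: new state and local error estimate."""
    k1 = f(t, y)
    k2 = f(t + dt, y + dt * k1)
    y_high = y + 0.5 * dt * (k1 + k2)
    return y_high, np.abs(y_high - (y + dt * k1))

def solve_batched(f, y0, t0, t1, dt0=0.1, rtol=1e-6, atol=1e-8, max_iters=100_000):
    """Hypothetical sketch: independent scalar ODEs, each with its own time, step size and end time."""
    y = np.array(y0, dtype=float)
    t = np.full_like(y, t0)
    t1 = np.broadcast_to(np.asarray(t1, dtype=float), y.shape)
    dt = np.full_like(y, dt0)
    n_steps = np.zeros(y.shape, dtype=int)
    for _ in range(max_iters):
        active = t < t1
        if not active.any():  # loop length is decided by the slowest instance
            break
        last = dt >= t1 - t  # this step would reach (or pass) the instance's end time
        dt = np.where(active, np.minimum(dt, t1 - t), dt)
        # Finished instances are still evaluated (wasted work), but their results are masked out.
        y_new, err_abs = heun_euler_step(f, t, y, dt)
        err = err_abs / (atol + rtol * np.maximum(np.abs(y), np.abs(y_new)))
        accept = active & (err <= 1.0)
        y = np.where(accept, y_new, y)
        t = np.where(accept, np.where(last, t1, t + dt), t)
        n_steps += accept
        # Assumed controller: safety 0.9, exponent -1/2, step change clipped to [0.2, 5].
        with np.errstate(divide="ignore"):
            factor = np.clip(0.9 * err ** -0.5, 0.2, 5.0)
        dt = np.where(active, dt * factor, dt)
    else:
        raise RuntimeError("max_iters reached before all instances finished")
    return y, n_steps

f = lambda t, y: -y  # y' = -y, exact solution y0 * exp(-t)
y0 = np.array([1.0, 2.0, 0.5])
t1 = np.array([1.0, 3.0, 0.1])  # every instance has its own end time
y_batch, steps = solve_batched(f, y0, 0.0, t1)
# "vmap semantics": solve every instance separately and compare.
y_each = np.array([solve_batched(f, [a], 0.0, [b])[0][0] for a, b in zip(y0, t1)])
print(steps)  # the instances need different numbers of accepted steps
print(np.allclose(y_batch, y_each), np.allclose(y_batch, y0 * np.exp(-t1), rtol=1e-4))
```

Masking keeps every array the same shape throughout, which is what lets the whole batch advance in one loop; the cost is that instances that finish early still ride along until the last one is done.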
fhchl t1_izjb0tf wrote
Aight! Thanks for the nice answers! I wish you a good conference :)