TubasAreFun t1_jbapwmb wrote
Reply to comment by hcarlens in [R] Analysis of 200+ ML competitions in 2022 by hcarlens
JAX does offer some general-purpose matrix math that can be more useful or faster than torch alone. I often do the deep learning in torch and then use JAX on top to train statistical models (e.g. fusing features from multiple models, raw features, etc. into a single regression/inference step).
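
To make the "fuse features into a single regression" idea concrete, here is a rough sketch of one way it could look in JAX. Everything in it is an assumption for illustration: the feature blocks are random stand-ins for what would really be torch model outputs (which you could bring over with something like `tensor.detach().cpu().numpy()`), the model is a plain ridge regression, and names like `fuse`, `loss_fn`, and `fit` are made up here, not from any library.

```python
import jax
import jax.numpy as jnp


def fuse(*feature_blocks):
    # Concatenate per-sample feature blocks column-wise into one design matrix.
    return jnp.concatenate(feature_blocks, axis=1)


def loss_fn(params, X, y, l2=1e-3):
    # Mean squared error plus an L2 penalty on the weights (ridge regression).
    w, b = params
    pred = X @ w + b
    return jnp.mean((pred - y) ** 2) + l2 * jnp.sum(w ** 2)


@jax.jit
def step(params, X, y, lr=0.1):
    # One plain gradient-descent update; jax.grad differentiates w.r.t. params.
    grads = jax.grad(loss_fn)(params, X, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def fit(X, y, n_steps=500):
    # Start from zero weights and zero bias, then run a fixed number of steps.
    params = (jnp.zeros(X.shape[1]), jnp.zeros(()))
    for _ in range(n_steps):
        params = step(params, X, y)
    return params


# Synthetic stand-ins (assumption): in practice these would be embeddings
# exported from torch models plus whatever raw features you have.
key = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(key, 3)
emb_a = jax.random.normal(k1, (256, 8))  # hypothetical features from model A
emb_b = jax.random.normal(k2, (256, 4))  # hypothetical features from model B
raw = jax.random.normal(k3, (256, 3))    # hypothetical raw features

X = fuse(emb_a, emb_b, raw)
true_w = jnp.arange(1.0, 16.0) / 15.0    # made-up ground-truth weights
y = X @ true_w + 0.5

initial_loss = loss_fn((jnp.zeros(X.shape[1]), jnp.zeros(())), X, y)
params = fit(X, y)
final_loss = loss_fn(params, X, y)
```

The nice part is that `jax.grad` and `jax.jit` work on ordinary array code, so swapping the ridge loss for some other statistical model mostly means rewriting `loss_fn`.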