TorchOpt

TorchOpt is an efficient library for differentiable optimization built upon PyTorch. TorchOpt is:

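  • Comprehensive: TorchOpt provides three differentiation modes to handle different differentiable optimization scenarios: explicit differentiation through unrolled inner-loop updates, implicit differentiation through the optimality conditions of the inner problem, and zero-order differentiation from function evaluations alone.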

  • Flexible: TorchOpt provides both a functional and an object-oriented API to suit different user preferences. Users can implement differentiable optimization in a JAX-like or a PyTorch-like style.

  • Efficient: TorchOpt provides (1) CPU/GPU-accelerated differentiable optimizers, (2) an RPC-based distributed training framework, and (3) fast tree operations, which together substantially improve training efficiency for bi-level optimization problems.
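
To make the three differentiation modes concrete, here is a minimal pure-Python sketch on a hypothetical scalar bi-level toy problem. The objective, the constants, and every function name below are illustrative assumptions, not TorchOpt's API. The inner problem minimizes g(x, θ) = 0.5(x - θ)² + 0.5λx² by gradient descent, and the outer loss is L(θ) = 0.5(x*(θ) - target)². Each function computes dL/dθ in one of the three ways: explicit differentiation carries derivatives through every unrolled inner step, implicit differentiation applies the implicit function theorem to the optimality condition ∂g/∂x = 0 at the final inner solution, and the zero-order version uses only loss values under Gaussian perturbations of θ (an antithetic, evolution-strategies-style estimator).

```python
import random

# Hypothetical scalar bi-level toy problem; not TorchOpt's API.
# Inner: x*(theta) = argmin_x g(x, theta) = 0.5*(x - theta)**2 + 0.5*LAM*x**2
# Outer: L(theta) = 0.5*(x*(theta) - TARGET)**2
LAM = 0.5     # inner regularisation strength
TARGET = 2.0  # outer target value
ALPHA = 0.1   # inner gradient-descent step size
STEPS = 100   # number of unrolled inner steps


def inner_solve(theta):
    """Run STEPS of gradient descent on g, returning x and dx/dtheta.

    dx/dtheta is propagated forward through every update (forward mode for
    brevity); reverse mode through the same unrolled loop gives the same value.
    """
    x, dx_dtheta = 0.0, 0.0
    for _ in range(STEPS):
        grad_x = (x - theta) + LAM * x  # dg/dx
        # Differentiate the update x <- x - ALPHA * grad_x with respect to theta.
        dx_dtheta = dx_dtheta - ALPHA * ((dx_dtheta - 1.0) + LAM * dx_dtheta)
        x = x - ALPHA * grad_x
    return x, dx_dtheta


def outer_loss(x):
    return 0.5 * (x - TARGET) ** 2


def explicit_grad(theta):
    """Explicit: chain rule through the unrolled inner optimization."""
    x, dx_dtheta = inner_solve(theta)
    return (x - TARGET) * dx_dtheta


def implicit_grad(theta):
    """Implicit: implicit function theorem on the optimality condition dg/dx = 0.

    dx*/dtheta = -(d2g/dx2)^-1 * d2g/dxdtheta, evaluated at the final solution only.
    """
    x, _ = inner_solve(theta)
    d2g_dx2 = 1.0 + LAM
    d2g_dxdtheta = -1.0
    dx_dtheta = -d2g_dxdtheta / d2g_dx2
    return (x - TARGET) * dx_dtheta


def zero_order_grad(theta, sigma=0.1, samples=1000, seed=0):
    """Zero-order: antithetic Gaussian-smoothing estimate from loss values only."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(samples):
        u = rng.gauss(0.0, 1.0)
        loss_plus = outer_loss(inner_solve(theta + sigma * u)[0])
        loss_minus = outer_loss(inner_solve(theta - sigma * u)[0])
        total += (loss_plus - loss_minus) / (2.0 * sigma) * u
    return total / samples


for name, fn in [("explicit", explicit_grad), ("implicit", implicit_grad),
                 ("zero-order", zero_order_grad)]:
    print(f"{name:>10}: {fn(1.0):.4f}")
```

For θ = 1 and these constants, the exact gradient is (2/3 - 2)·(2/3) = -8/9 ≈ -0.889. The explicit and implicit values should match it closely, while the zero-order value is a noisy estimate whose accuracy depends on the number of samples. The implicit version needs only the final inner solution, so in principle its memory cost does not grow with the number of inner steps.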

Beyond differentiable optimization, TorchOpt can also serve as a JAX-like library of composable functional optimizers for PyTorch. With TorchOpt, users can optimize neural networks in PyTorch using functional-style optimizers, similar to Optax in JAX.
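
The functional style follows the (init, update) pattern used by Optax: an optimizer is a pair of pure functions, its state is passed around explicitly, and larger optimizers are built by chaining small transformations. The minimal pure-Python sketch below illustrates that idea on plain lists of floats. The names `GradientTransformation`, `scale`, `trace`, `chain`, `apply_updates`, `sgd_momentum`, and `MomentumSGD` are hypothetical stand-ins chosen to mirror Optax's vocabulary; they are not TorchOpt's actual API, which operates on PyTorch tensors. `MomentumSGD` wraps the same transformation behind a PyTorch-like `step()` method to contrast the two styles.

```python
from typing import Any, Callable, List, NamedTuple, Tuple

# Hypothetical Optax-style sketch on plain lists of floats; not TorchOpt's API.
Params = List[float]
State = Any


class GradientTransformation(NamedTuple):
    init: Callable[[Params], State]  # build the initial optimizer state
    update: Callable[[Params, State], Tuple[Params, State]]  # (grads, state) -> (updates, new state)


def scale(factor: float) -> GradientTransformation:
    """Stateless: multiply every update by a constant."""
    def init(params):
        return ()

    def update(updates, state):
        return [factor * u for u in updates], state

    return GradientTransformation(init, update)


def trace(decay: float) -> GradientTransformation:
    """Momentum buffer m <- decay * m + g; the buffer becomes the update."""
    def init(params):
        return [0.0] * len(params)

    def update(updates, state):
        momentum = [decay * m + g for m, g in zip(state, updates)]
        return momentum, momentum

    return GradientTransformation(init, update)


def chain(*transforms: GradientTransformation) -> GradientTransformation:
    """Apply transformations left to right; the state is a tuple of sub-states."""
    def init(params):
        return tuple(t.init(params) for t in transforms)

    def update(updates, state):
        new_state = []
        for t, s in zip(transforms, state):
            updates, s = t.update(updates, s)
            new_state.append(s)
        return updates, tuple(new_state)

    return GradientTransformation(init, update)


def apply_updates(params: Params, updates: Params) -> Params:
    """Return new parameters instead of mutating the old ones."""
    return [p + u for p, u in zip(params, updates)]


def sgd_momentum(lr: float, momentum: float) -> GradientTransformation:
    """SGD with momentum, built purely by composition."""
    return chain(trace(momentum), scale(-lr))


class MomentumSGD:
    """PyTorch-like object-oriented wrapper: the state lives inside the object."""

    def __init__(self, params: Params, lr: float, momentum: float):
        self.params = list(params)
        self.tx = sgd_momentum(lr, momentum)
        self.state = self.tx.init(self.params)

    def step(self, grads: Params) -> None:
        updates, self.state = self.tx.update(grads, self.state)
        self.params = apply_updates(self.params, updates)


def grad(params: Params) -> Params:
    """Gradient of the toy loss 0.5 * sum(w ** 2)."""
    return list(params)


# Functional (JAX/Optax-like) style: the state is threaded through explicitly.
tx = sgd_momentum(lr=0.1, momentum=0.9)
params = [1.0, -2.0]
state = tx.init(params)
for _ in range(200):
    updates, state = tx.update(grad(params), state)
    params = apply_updates(params, updates)

# Object-oriented (PyTorch-like) style: the same arithmetic behind step().
opt = MomentumSGD([1.0, -2.0], lr=0.1, momentum=0.9)
for _ in range(200):
    opt.step(grad(opt.params))
```

Both loops perform identical arithmetic and finish at the same parameters near zero. The difference is where the state lives. In the functional version it is an ordinary value that can be stored, copied, or passed to other functions. In a tensor library, producing new parameters out of place in this way is also what lets inner-loop updates remain part of the autograd graph.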

Luo Mai
Assistant Professor

My research interests include computer systems, machine learning and data management.