TorchOpt
TorchOpt is an efficient library for differentiable optimization built upon PyTorch. TorchOpt is:
- Comprehensive: TorchOpt provides three differentiation modes (explicit differentiation, implicit differentiation, and zero-order differentiation) to handle different differentiable optimization scenarios.
- Flexible: TorchOpt provides both functional and object-oriented APIs to suit different user preferences. Users can implement differentiable optimization in a JAX-like or a PyTorch-like style.
- Efficient: TorchOpt provides (1) CPU/GPU-accelerated differentiable optimizers, (2) an RPC-based distributed training framework, and (3) fast tree operations, which together substantially improve training efficiency for bilevel optimization problems.
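To make "explicit differentiation" concrete, the sketch below differentiates an outer loss through an unrolled inner SGD loop, using only plain PyTorch autograd rather than TorchOpt's own API. The toy losses, the scalar meta-parameter `lam`, and the helper name `unrolled_meta_gradient` are hypothetical choices for illustration. The key detail is `create_graph=True`, which keeps each inner update differentiable with respect to the meta-parameter.

```python
import torch


def unrolled_meta_gradient(lam_value: float, inner_lr: float = 0.1, inner_steps: int = 1) -> float:
    """Hypothetical helper: d(outer loss)/d(lam) through an unrolled inner SGD loop."""
    # Meta-parameter (outer variable) whose gradient we want.
    lam = torch.tensor(lam_value, requires_grad=True)
    # Inner parameter, initialised at zero.
    theta = torch.zeros((), requires_grad=True)

    for _ in range(inner_steps):
        # Toy inner objective: pull theta towards lam.
        inner_loss = (theta - lam) ** 2
        # create_graph=True records the gradient computation itself,
        # so the SGD step below stays differentiable w.r.t. lam.
        (grad,) = torch.autograd.grad(inner_loss, theta, create_graph=True)
        theta = theta - inner_lr * grad

    # Toy outer objective evaluated at the adapted parameter.
    outer_loss = (theta - 1.0) ** 2
    # Backpropagate through every unrolled inner step to reach lam.
    (meta_grad,) = torch.autograd.grad(outer_loss, lam)
    return meta_grad.item()


print(unrolled_meta_gradient(2.0))
```

With one inner step, theta becomes 0.2 * lam, so the meta-gradient at lam = 2 is 2 * (0.4 - 1) * 0.2 = -0.24. Memory grows with the number of unrolled steps, which is one motivation for the implicit and zero-order modes listed above.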
Beyond differentiable optimization, TorchOpt can also serve as a functional optimizer library: it brings JAX-like, composable functional optimizers to PyTorch. With TorchOpt, users can optimize neural networks in PyTorch using functional-style optimizers, similar to Optax in JAX.
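The Optax-style pattern mentioned above treats an optimizer as a pair of pure functions, `init` and `update`, that can be chained. The following is a minimal, self-contained reimplementation of that idea in plain PyTorch, written only to illustrate the concept. It is not TorchOpt's source code, and the names `GradientTransformation`, `scale`, `trace`, `chain`, `apply_updates`, `sgd_momentum`, and `minimize_quadratic` are assumptions chosen to mirror the Optax vocabulary.

```python
from typing import Any, Callable, NamedTuple, Tuple

import torch

Params = Tuple[torch.Tensor, ...]


class GradientTransformation(NamedTuple):
    # Hypothetical mirror of the Optax interface: both members are pure functions.
    init: Callable[[Params], Any]
    update: Callable[[Params, Any], Tuple[Params, Any]]


def scale(step_size: float) -> GradientTransformation:
    # Stateless transform: multiply every update by step_size.
    def init(params):
        return ()

    def update(updates, state):
        return tuple(step_size * u for u in updates), state

    return GradientTransformation(init, update)


def trace(momentum: float) -> GradientTransformation:
    # Stateful transform: keeps a momentum buffer m <- momentum * m + g.
    def init(params):
        return tuple(torch.zeros_like(p) for p in params)

    def update(updates, state):
        new_state = tuple(momentum * m + g for m, g in zip(state, updates))
        return new_state, new_state

    return GradientTransformation(init, update)


def chain(*transforms: GradientTransformation) -> GradientTransformation:
    # Compose transforms left to right; the state is a tuple of sub-states.
    def init(params):
        return tuple(t.init(params) for t in transforms)

    def update(updates, state):
        new_states = []
        for t, s in zip(transforms, state):
            updates, s = t.update(updates, s)
            new_states.append(s)
        return updates, tuple(new_states)

    return GradientTransformation(init, update)


def apply_updates(params: Params, updates: Params) -> Params:
    # Out-of-place update: returns new tensors instead of mutating params.
    return tuple(p + u for p, u in zip(params, updates))


def sgd_momentum(lr: float, momentum: float) -> GradientTransformation:
    # SGD with momentum expressed as a composition of two small transforms.
    return chain(trace(momentum), scale(-lr))


def minimize_quadratic(steps: int, lr: float = 0.1, momentum: float = 0.0) -> float:
    # Minimise f(x) = x^2 starting from x = 1 with the functional optimizer.
    opt = sgd_momentum(lr, momentum)
    params = (torch.tensor(1.0, requires_grad=True),)
    state = opt.init(params)
    for _ in range(steps):
        loss = sum((p ** 2).sum() for p in params)
        grads = torch.autograd.grad(loss, params)
        updates, state = opt.update(grads, state)
        params = apply_updates(params, updates)
    return params[0].item()


print(minimize_quadratic(3))
```

Without momentum each step maps x to 0.8 * x, so three steps from x = 1 give 0.512. Because the optimizer state is passed in and returned explicitly, new optimizers are built by composing small transforms instead of subclassing an optimizer class.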