I've been a huge fan of PyTorch since last year, especially as it quickly gained all the functionality needed for sophisticated computer vision models, without the added complexity of TF. But today the PyTorch team announced the production-ready release of PyTorch, and here are just a few things to be excited about.
This is by far the biggest one. In the TF ecosystem there are a few tools that make deployment to production very simple, be it to a cluster of servers or to a smartphone. Even NVIDIA, with their rapidly developing TensorRT library that performs a whole bunch of optimizations out of the box and compiles to a native binary, is mostly oriented towards TF/Caffe.
The biggest issue with running PyTorch in production has been that it's still Python, so no real HPC for you. All "tweaks" were limited to switching to "eval" mode and disabling gradients during inference:
```python
net.eval()
v = Variable(input_image, volatile=True).cuda(async=True)
# and that's about it
```
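For context, later PyTorch releases express the same inference-only setup with the `torch.no_grad()` context manager instead of `volatile=True`. A minimal sketch (the toy network here is made up for illustration):

```python
import torch
import torch.nn as nn

# Toy network standing in for a real model.
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
net.eval()  # switch BatchNorm/Dropout layers to inference behavior

input_image = torch.randn(1, 3, 32, 32)
with torch.no_grad():  # no graph is recorded, no gradient buffers allocated
    out = net(input_image)
```

The effect is the same: no autograd bookkeeping happens during the forward pass, so `out.requires_grad` is `False`.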
With the new JIT compiler this issue goes away: since you no longer need to compute gradients, you can get significant speedups from the core features:
- Compiling to a native binary.
- Fusing some operations together; most notably, Convolution + BatchNorm + ReLU becomes one operation, just as TensorRT does.
- Ditching all the additional operations required for computing gradients; skipping that memory allocation is a massive improvement, especially in systems with high throughput requirements like a self-driving car.
- Built-in weight quantization - a very big win for smartphones and embedded systems.
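To see why Conv + BatchNorm fusion helps, here is a sketch of the algebra behind it: in eval mode BatchNorm is just an affine transform, so its scale and shift can be folded into the convolution's weights and bias ahead of time. `fuse_conv_bn` is a hypothetical helper for illustration, not the JIT's actual implementation:

```python
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold an eval-mode BatchNorm into the preceding convolution.

    In eval mode BN computes y = (x - mean) / sqrt(var + eps) * gamma + beta,
    which is affine, so it can be baked into the conv's weights and bias.
    """
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding, bias=True)
    scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    bias = conv.bias.data if conv.bias is not None \
        else torch.zeros(conv.out_channels)
    fused.bias.data = (bias - bn.running_mean) * scale + bn.bias.data
    return fused

# Sanity check: the fused conv matches Conv -> BN in eval mode.
conv = nn.Conv2d(3, 8, 3, padding=1)
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 1.5)
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    same = torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-5)
```

One kernel launch instead of two, and the intermediate activation tensor never materializes.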
One of the difficulties with dynamic computational graphs, the computational model that serves as the foundation of PyTorch and Chainer, has been the question of how to trace the operations written inside your model in Python and compile them correctly (preferably with optimizations):
```python
...
    def forward(self, input_imgs, additional_features):
        # how the hell do you compile that?
        feat_1, feat_2 = additional_features
        feats = self.encoder(input_imgs)
        fusion = torch.cat([feats, feat_1], 1) + feat_2
        return self.decoder(fusion)
```
The announced JIT compiler includes a trace module, already partially implemented, that takes a constructed network and converts it into a function that can be compiled further:
```python
# This will run your nn.Module or regular Python function with the example
# input that you provided. The returned callable can be used to re-execute
# all operations that happened during the example run, but it will no longer
# use the Python interpreter.
from torch.jit import trace
traced_model = trace(model, example_input=input)
traced_fn = trace(fn, example_input=input)

# The training loop doesn't change. Traced model behaves exactly like an
# nn.Module, except that you can't edit what it does or change its attributes.
# Think of it as a "frozen module".
for input, target in data_loader:
    loss = loss_fn(traced_model(input), target)
```
Code from the official announcement
We can expect minor issues early on with some seq2seq implementations that have non-trivial operations inside them, but here comes another great thing:
To make sure that the slightly more complicated parts of your code are executed correctly after compilation, the @script annotation gives you an explicit way to control the workflow. Most of such code depends not on tensor operations in the Torch backend but on Python control flow, which is the main source of slowdowns; once these sections are compiled, that overhead goes away. This is the feature I'm most excited about.
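As a sketch of what such a compiled section could look like (the function below is a made-up example, not from the announcement): a data-dependent loop that a plain trace can't capture, since tracing would unroll it for one fixed input, while the script compiler keeps it as a real loop in its own IR:

```python
import torch
from torch.jit import script

@script
def rnn_like_loop(x, h):
    # The trip count depends on the input's time dimension, so this is real
    # control flow, compiled out of the Python interpreter by @script.
    for i in range(x.size(0)):
        h = torch.tanh(x[i] + h)
    return h

h = rnn_like_loop(torch.randn(5, 4), torch.zeros(4))
```

A recurrent cell like this is exactly the kind of code where per-step Python dispatch, not the math, dominates the runtime.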
It takes the ONNX import into Caffe2 that's available right now to a new level: you can hack on a network, then convert it into a production-ready version almost instantly, avoiding the conflicts between the two libraries that still happen from time to time.
From what I've learned profiling different parts of rather heavy networks (object detection, segmentation, depth perception), weight-heavy convolution layers are not the main source of latency, as one might think; thanks to fast cuDNN implementations they are actually among the fastest operations after really lightweight things like elementwise multiplication or addition. The main culprits were complex control-flow modules, like the Region Proposal Network in Faster R-CNN, which can still be too slow; even without additional optimization, getting them as C++ runtimes would be a significant boost.
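A crude way to see this for yourself (the workloads here are made up and the numbers are entirely hardware-dependent, so this is an illustration, not a benchmark): compare one heavy conv launched as a single backend op against many tiny ops driven from the Python interpreter.

```python
import time
import torch
import torch.nn as nn

conv = nn.Conv2d(64, 64, 3, padding=1)
x = torch.randn(1, 64, 64, 64)

def avg_time(fn, repeats=20):
    # Wall-clock average over several runs; crude but enough to see the trend.
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats

def python_heavy(t):
    # Many trivial ops: the per-op Python dispatch overhead dominates here.
    for _ in range(200):
        t = t + 0.0
    return t

with torch.no_grad():
    t_conv = avg_time(lambda: conv(x))
    t_loop = avg_time(lambda: python_heavy(x))
print(f"conv: {t_conv * 1e3:.2f} ms, python-driven loop: {t_loop * 1e3:.2f} ms")
```

The second timing is almost all interpreter overhead, which is precisely what compiling control flow to a C++ runtime eliminates.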