For the past 10 years, Computer Vision models have been growing exponentially in size... for a very simple reason: if enough data is fed into the model, increasing its size and complexity usually increases its performance.
However, bigger models need more memory and computing power at inference time. In almost any computer vision use-case, inference time matters, and it can make or break the user experience.
Fortunately, there is a cheat code to address this seemingly inflexible tradeoff between performance and resource intensity: Pruning.
What is Pruning?
Pruning builds on the observation that most of a trained model’s parameters are useless. Yes, you read that right. In your usual Computer Vision CNN, a very small fraction of its weights does all the heavy lifting while the rest are very close to zero.
Pruning algorithms allow you to keep the useful connections in your model and kill the useless ones. In our experiment with LeNet (more on this below), we achieved almost the same level of performance with only 10% of the weights. 10x smaller and almost as accurate!
Only the fittest will survive: Computer Vision Darwinism
Pruning works in cycles. Each cycle consists of:
- Removing the least important weights in your Computer Vision model.
- Fine-tuning the rest of the network for a few epochs.
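A minimal sketch of one such cycle with PyTorch's built-in torch.nn.utils.prune utilities (the model, train_loader and loss_fn below are assumed to exist, and fine-tuning is reduced to a bare training loop):

```python
import torch
import torch.nn.utils.prune as prune

def pruning_cycle(model, train_loader, loss_fn, amount=0.2, epochs=2, lr=1e-3):
    """One cycle: prune the smallest weights, then fine-tune the survivors."""
    # 1. Zero out the `amount` fraction of the (remaining) weights with the
    #    lowest absolute value, layer by layer.
    for module in model.modules():
        if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
            prune.l1_unstructured(module, name="weight", amount=amount)

    # 2. Fine-tune the rest of the network for a few epochs.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
    return model
```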
As you may have guessed, the way you select these “least important” connections and the schedule of these cycles constitute the two main degrees of freedom of pruning techniques.
First, you have to choose a criterion for deciding which weights to set to zero. There are two main families of criteria:
- Structured pruning, which removes entire structures at once, such as channels or filters of a convolutional layer.
- Unstructured pruning, which removes individual weights anywhere in the model, regardless of their position.
Within each family, the algorithm ranks the weights (or groups of weights) by their L1 norm, L2 norm, or any other norm, and prunes the least important ones.
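As a hedged illustration of the two families, here is how they look with PyTorch's pruning utilities (the two toy convolutional layers are placeholders for layers of a real model):

```python
import torch
import torch.nn.utils.prune as prune

conv_a = torch.nn.Conv2d(16, 32, kernel_size=3)
conv_b = torch.nn.Conv2d(16, 32, kernel_size=3)

# Unstructured: zero out the 30% of individual weights with the smallest absolute value (L1 norm).
prune.l1_unstructured(conv_a, name="weight", amount=0.3)

# Structured: remove 30% of the output channels (dim=0), ranked by their L2 norm (n=2).
prune.ln_structured(conv_b, name="weight", amount=0.3, n=2, dim=0)
```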
Second, you need to determine a pruning schedule. This describes how the algorithm alternates between pruning and fine-tuning to reach your target pruning rate (or sparsity).
The most common pruning schedules are:
- One-shot pruning, where there is only one pruning step and one fine-tuning step.
- Linear pruning, where every step prunes the same number of weights until the target sparsity is reached.
- Automated Gradual Pruning (AGP), a more sophisticated schedule that prunes many weights at the beginning of the process, then fewer and fewer as the network approaches the target sparsity.
Such gradual pruning gives the different layers of the Computer Vision model more time to adapt to the increasing sparsity through fine-tuning.
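For reference, the original AGP paper defines the target sparsity at a given training step with a cubic decay. A small helper computing it could look like this sketch (the function and argument names are ours, not a library API):

```python
def agp_target_sparsity(step, begin_step, end_step, initial_sparsity=0.0, final_sparsity=0.9):
    """AGP-style schedule: prune aggressively at first, then slower and slower."""
    if step <= begin_step:
        return initial_sparsity
    if step >= end_step:
        return final_sparsity
    progress = (step - begin_step) / (end_step - begin_step)
    # Cubic interpolation between the initial and final sparsity.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3
```

Plotted over the training steps, this gives exactly the behaviour described above: a steep increase in sparsity at the start that flattens out as it approaches final_sparsity.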
How easy is it to implement?
Pruning is very easy to implement.
Both PyTorch and TensorFlow, the two major deep learning frameworks, have ready-to-use pruning implementations. Here is an example of how to prune a layer of your network with one line of code in PyTorch.
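(The layer name conv1 below is just a placeholder for any convolutional or linear layer in your model.)

```python
import torch.nn.utils.prune as prune

# Zero out the 40% of weights in model.conv1 with the smallest absolute value.
prune.l1_unstructured(model.conv1, name="weight", amount=0.4)
```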
Pruning in TensorFlow is a bit less straightforward, as the pruned network has to be recompiled. However, Google's framework offers ready-made implementations of a few pruning schedules, whereas in PyTorch you would have to implement them yourself.
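Here is a sketch using the TensorFlow Model Optimization toolkit (it assumes an existing Keras model and training data x_train, y_train; the step counts and hyperparameters are placeholders):

```python
import tensorflow_model_optimization as tfmot

# Wrap the model so that low-magnitude weights are pruned during training,
# following a polynomial (AGP-style) schedule from 0% to 90% sparsity.
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.9, begin_step=0, end_step=10_000
)
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
    model, pruning_schedule=pruning_schedule
)

# The pruned model has to be recompiled before fine-tuning.
pruned_model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)

# This callback keeps the pruning schedule in sync with the training steps.
pruned_model.fit(
    x_train, y_train, epochs=5,
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()],
)
```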
What’s the catch?
As you may have guessed, there had to be a catch: these out-of-the-box implementations sadly achieve no real gain in memory usage or inference time.
In fact, under the hood, the two Python frameworks simply apply a mask to the model's parameters. The values of the mask are either ones or zeros, depending on whether or not the parameter has been pruned.
This allows you to prune and fine-tune your model, and to experiment with various pruning techniques. But it is not meant to actually remove the pruned parameters.
As the pruned parameters are not removed from the computational graph, the model will still use the same amount of resources, even if it is 100% pruned.
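You can verify this yourself in PyTorch: after pruning, the layer still carries its full dense weight tensor, plus a binary mask (a small sketch on a toy layer):

```python
import torch
import torch.nn.utils.prune as prune

linear = torch.nn.Linear(8, 4)
prune.l1_unstructured(linear, name="weight", amount=0.5)

# The original parameter is kept as `weight_orig`, and a 0/1 mask is stored as
# `weight_mask`; `weight` is simply recomputed as weight_orig * weight_mask.
print(linear.weight_orig.shape)             # still a full dense tensor
print(linear.weight_mask.unique())          # tensor([0., 1.])
print((linear.weight == 0).float().mean())  # ~0.5: half the weights are zeroed, not removed

# prune.remove() makes the pruning permanent, but the tensor stays dense.
prune.remove(linear, "weight")
```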
Same performance, 2x faster
There is no free lunch in Computer Vision, it seems! Fortunately, there are a few solutions to reap the benefits of pruning, and they mainly rely on the idea of sparse matrix encoding. In simple terms, this means storing and manipulating the pruned weight matrices as lists of indices and their corresponding values.
These sparse matrices are a lot trickier to manipulate, especially with common data structures like NumPy arrays or Torch tensors, which are designed for dense data. In fact, SciPy's sparse matrix multiplication is 50 to 200 times slower than its dense counterpart!
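As a rough sketch of what that encoding looks like in practice, here is a 90%-sparse matrix stored in SciPy's CSR format (the matrix is random, purely for illustration):

```python
import numpy as np
from scipy import sparse

# A "pruned" weight matrix: roughly 90% of the entries are exactly zero.
dense = np.random.randn(1000, 1000)
dense[np.random.rand(1000, 1000) < 0.9] = 0.0

# CSR keeps only the non-zero values plus their indices.
csr = sparse.csr_matrix(dense)
print(dense.nbytes)  # ~8 MB for the dense float64 matrix
print(csr.data.nbytes + csr.indices.nbytes + csr.indptr.nbytes)  # a fraction of that
```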
Luckily, researchers are working on optimizing the inference of sparse models (such as pruned ones). By exploiting some properties of sparse networks, this paper from Google and Stanford achieves “1.2–2.1× speedups and up to 12.8× memory savings without sacrificing accuracy”.
This 2x speedup roughly corresponds to 90% sparsity on the paper's figure. This is also the level of sparsity at which our earlier LeNet experiment retained the model's accuracy. It all comes together quite nicely!
Computer Vision and pruning: Final Words
Pruning is a very powerful tool, as it can dramatically speed up and lighten computer vision models: up to 10x smaller, 2x faster, with the same performance!
Today, Computer Vision models are getting orders of magnitude bigger, as shown by Google's latest Vision Transformer, which has more than 2 billion parameters!
This means that in most practical use-cases, these huge models will have to be scaled down. Pruning allows you to train a sizeable and highly complex model, then prune it down to an acceptable size while retaining most of the original performance, which is one heck of a cheat code! If you're looking for another one, you should check out how to speed up your Python code with JAX.
After reading this, you must be feeling an urge to experiment with pruning in your projects. What better way to do it than with DVC + Makefile?
If you’re looking for Computer Vision experts who will know how to deliver tailor-made and efficient models at scale, you’re in luck! Contact us here.