Last month I wrote about how you can use the cuda-convnet wrappers in pylearn2 to get up to 3x faster GPU convolutions in Theano. Since then I’ve been working on an FFT-based convolution implementation for Theano. Preliminary tests indicate that this approach is again 2-4x faster than the cuda-convnet wrappers.

I wrote the code in pure Python, using scikits.cuda and PyCUDA to do the heavy lifting. The Theano team is currently working on integrating this code into Theano. They also plan to create a proper C/CUDA implementation to guarantee the best performance.

I put everything up on GitHub; you can find the code there, or clone it and try it yourself:

FFT-based convolution

The Fourier transform of a convolution of two functions is the product of the Fourier transforms of those functions. This is the convolution theorem. This result can be used to quickly compute convolutions in the Fourier domain, since an elementwise product is much less computationally intensive than a convolution.

However, there is a price to be paid: the inputs need to be transformed using the Fast Fourier Transform (FFT), and the product of these transformed inputs needs to be transformed again using the inverse FFT. Depending on the sizes of the inputs, these costs can be pretty significant, so sometimes it is a better idea to just compute the convolution in the original domain.
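
To make this concrete, here is a minimal 1D numpy example (not part of the Theano code) comparing a direct convolution with the same convolution computed via FFTs:

import numpy as np

x = np.random.randn(128)
w = np.random.randn(16)
n = len(x) + len(w) - 1                      # length of the full convolution

direct = np.convolve(x, w)                   # direct (quadratic-time) convolution
via_fft = np.fft.irfft(np.fft.rfft(x, n) * np.fft.rfft(w, n), n)

assert np.allclose(direct, via_fft)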

I was somewhat surprised to learn that all popular implementations of convolutional neural networks (CNNs) use the latter approach, including that of Theano and cuda-convnet. The reason is that typically, convolutions in CNNs involve relatively small filters, so I think people just assumed it wasn’t worth it.

However, a paper published at ICLR 2014 recently caught my eye: Fast Training of Convolutional Networks through FFTs by Mathieu, Henaff and LeCun. They implemented the FFT-based approach in the Torch7 framework and compared its performance to Torch7’s own ‘classical’ implementation. They concluded that it is actually advantageous to use FFT-based convolutions in CNNs in many cases.

The reason is actually quite straightforward: compared to the general case, the overhead of computing the FFTs of the inputs is drastically reduced. We need to compute the convolution of each input example in a given minibatch with each filter. If there are m examples in the minibatch with k input channels, and n filters, this means we need to compute m * n * k convolutions. In the Fourier domain, this turns into m * n * k elementwise products. However, we only need to compute the FFT of each input example and each filter once. So the total number of FFTs to compute is not 2 * m * n * k, but (m + n) * k.

But that’s not everything: the output of a convolutional layer in a CNN is actually a sum of convolutions across all k input channels. Because the FFT is a linear operator, we can compute this sum in the Fourier domain, and then take the IFFT of this sum (instead of the other way around). This means we only need to compute m * n IFFTs, instead of m * n * k. It turns out that these savings can be very significant.
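
Here is a rough numpy sketch of this bookkeeping, with made-up sizes, using circular convolutions and ignoring the padding and cropping a “valid” convolution would need: (m + n) * k forward FFTs, a multiply-and-sum over the input channels, and only m * n inverse FFTs.

import numpy as np

m, k, n, h, w = 4, 3, 8, 16, 16
inputs = np.random.randn(m, k, h, w)
filters = np.random.randn(n, k, h, w)        # filters assumed zero-padded to the input size

input_f = np.fft.rfft2(inputs)               # m * k forward FFTs
filters_f = np.fft.rfft2(filters)            # n * k forward FFTs

out_f = np.einsum('mkij,nkij->mnij', input_f, filters_f)   # multiply and sum over k
out = np.fft.irfft2(out_f, s=(h, w))         # only m * n inverse FFTs, shape (m, n, h, w)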

A CUDA/C-less Theano implementation

So this got me thinking that it should be possible to do the same thing in Theano. Theano already intelligently replaces convolution operators in computational graphs with their GPU-based counterparts in the optimization phase. If an FFT-based implementation was added, it could do the same with that version instead.

I set out to implement this, but unfortunately my knowledge of CUDA is nonexistent, and my knowledge of C can be called rusty at best. So I sought to avoid both. Enter scikits.cuda, which offers all the necessary primitives: forward and inverse FFTs, and complex products (the FFT of a real signal is complex and symmetric).

Luckily, scikits.cuda is built on top of PyCUDA, and the Theano docs have some examples of how to implement PyCUDA-based operators. Essentially I just had to glue everything together.

Implementation details

As mentioned earlier, an FFT-based convolution can be broken up into 3 parts: an FFT of the input images and the filters, a bunch of elementwise products followed by a sum across input channels, and then an IFFT of the outputs. I decided to implement each of these as a separate Theano operator. That way, the optimizer could detect if the same inputs or filters are used in multiple convolutions, and only compute them once. At the moment I’m still unsure whether this is beneficial - perhaps some additional performance could be gained by combining everything into a single, monolithic FFT-convolution operator. But that’s a discussion for another time.

The FFT and IFFT operators were the easiest. scikits.cuda exposes a nice API to perform batched FFTs. This allows for GPU-parallelism to be exploited when many FFTs of the same size have to be computed. This is precisely our use case. The API uses the cuFFT implementation internally, which is a part of CUDA.

Interestingly, the authors of the paper I mentioned earlier claim that using cuFFT is not an option because it does not allow this type of parallelism to be exploited, so they made their own CUDA FFT implementation instead. However, I got pretty good results using cuFFT, so I don’t know what led them to make this claim. Perhaps the batched FFT is a recent addition to cuFFT. The same batched approach can be used for the IFFT.
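
For illustration, a batched real-to-complex FFT call with scikits.cuda looks something like this - a hedged sketch with made-up sizes, not the code from fftconv.py:

import numpy as np
import pycuda.autoinit
import pycuda.gpuarray as gpuarray
from scikits.cuda import fft

batch, h, w = 64 * 3, 96, 96                              # e.g. all input channels of a minibatch
x_gpu = gpuarray.to_gpu(np.random.randn(batch, h, w).astype(np.float32))
xf_gpu = gpuarray.empty((batch, h, w // 2 + 1), np.complex64)

plan = fft.Plan((h, w), np.float32, np.complex64, batch=batch)
fft.fft(x_gpu, xf_gpu, plan)                              # all FFTs in a single cuFFT call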

The tough part was performing the actual convolution in the Fourier domain, by computing the complex elementwise products and summing across the input channels. Theano does not have support for complex numbers, so some trickery was required to convert complex arrays into real arrays with an extra trailing dimension of size 2, to contain the real and imaginary parts of the numbers.
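
A numpy sketch of that trick (the actual code does this on GPU arrays): a complex array can be reinterpreted as a real array with an extra trailing dimension of size 2 holding the real and imaginary parts.

import numpy as np

x = (np.random.randn(4, 5) + 1j * np.random.randn(4, 5)).astype(np.complex64)
x_real = x.view(np.float32).reshape(x.shape + (2,))       # shape (4, 5, 2)

assert np.allclose(x_real[..., 0], x.real)
assert np.allclose(x_real[..., 1], x.imag)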

I tried a number of different approaches, but what worked best in the end is interpreting the operation as a dot product. A dot product is precisely that: an elementwise product with some broadcasting, followed by summing out a particular dimension. So by reshaping the Fourier-transformed inputs and filters, the multiply-and-sum operation could be translated into a set of dot products. This is great, because GPUs are really good at computing dot products quickly.
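
A numpy stand-in for this reshaping (with made-up sizes) makes the structure explicit: one (m x k) times (k x n) complex dot product per Fourier coefficient, which is exactly the shape of work a batched gemm performs.

import numpy as np

m, k, n, nf = 4, 3, 8, 16 * 9                # nf = number of Fourier coefficients per image
input_f = np.random.randn(m, k, nf) + 1j * np.random.randn(m, k, nf)
filters_f = np.random.randn(n, k, nf) + 1j * np.random.randn(n, k, nf)

a = input_f.transpose(2, 0, 1)               # (nf, m, k)
b = filters_f.transpose(2, 1, 0)             # (nf, k, n)
out_f = np.matmul(a, b)                      # nf small dot products in one batched call: (nf, m, n)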

It turns out that recent versions of cuBLAS also support batched dot products, which offer the same performance advantages as batched FFTs. Since we need to perform a large number of dot products with the same shapes, this was again a perfect match for our use case. The particular function I needed to compute a batched complex-valued dot product is cublasCgemmBatched. Unfortunately this wasn’t available through scikits.cuda yet, but it wasn’t hard to add the necessary wrappers. I sent a pull request and it is now included (so make sure to get the latest version of scikits.cuda from git if you want to try this).

Proof of concept

So far I’ve only implemented the valid convolution. Using the implementation in the context of a CNN will also require support for full convolutions - but this is easy to mimic by padding the input with zeros. I have not implemented an optimization that swaps out Theano’s own convolution operator with the FFT-based version, but that is something the Theano team is currently working on.

Preliminary benchmarks show that this implementation is typically faster than cuda-convnet. The table below shows the duration of a single valid convolution computation with the given input and filter shapes, measured on a GeForce GTX 680, averaged across 10 runs, and not taking into account the warmup that the FFT-based implementation requires (the first run will be a bit slower because the FFT plans need to be created).

Following Theano conventions, the input shape is given as (batch size, number of input channels, width, height) and the filter shape is given as (number of filters, number of input channels, width, height). Durations are given for Theano’s own conv2d implementation, the cuda-convnet wrappers from pylearn2, and the FFT-based implementation. The speedup of the FFT-based implementation over the cuda-convnet wrappers is also given.

input shape           filter shape          Theano’s own   cuda-convnet   FFT-based   speedup
(64, 3, 96, 96)       (128, 3, 16, 16)      388.9 ms       156.9 ms       117.3 ms    1.34x
(64, 128, 32, 32)     (64, 128, 8, 8)       233.9 ms       87.4 ms        27.1 ms     3.23x
(128, 32, 54, 54)     (64, 32, 6, 6)        457.5 ms       107.6 ms       52.2 ms     2.06x
(128, 128, 16, 16)    (128, 128, 8, 8)      133.4 ms       43.5 ms        18.6 ms     2.34x
(128, 1024, 32, 32)   (128, 1024, 4, 4)     6246.2 ms      1283.5 ms      357.8 ms    3.59x

In all cases we get a nice speedup. This approach seems to be the most beneficial when the number of input channels is large - this makes sense, as this is the dimension that is summed over in the batched dot product. But even when this number is small (e.g. 3) it’s still faster.

Try it out

As mentioned in the introduction, you can grab the code for this at:

All the relevant code is in the file fftconv.py. The file cufftop.py was mainly used for experimentation, and contains some alternative implementations of the multiply-and-sum step.

Note that the latest revision of scikits.cuda is required, to ensure that the cublasCgemmBatched function is available. You’ll also need a working installation of PyCUDA, as this is a dependency of scikits.cuda. And of course, you’ll need Theano and a working CUDA installation.

If you’re patient, you can also wait until the code is available in Theano. Chances are you’ll be able to use it without modifying your existing code, as they are also building an optimization that will replace Theano’s own convolutions with the FFT-based implementation. And if you’re very patient, you can wait until they build the CUDA/C version, which will eliminate the scikits.cuda and PyCUDA dependencies, and hopefully it will be a bit faster as well due to the reduced overhead.

The code to compute the numbers in the table above is in the file speedtest.py. This script also checks whether the output of all three implementations is the same (up to a given tolerance). More numbers for different input/filter shapes and different GPUs are welcome, so if you run this script on your own machine(s), feel free to send me the results.

Feedback is welcome, and if you’d like to help with integrating this into Theano, join the conversation at the theano-users group!

Galaxy Zoo Challenge: code published

Some two weeks ago I posted about my solution for the Galaxy Zoo challenge on Kaggle. Today I’ve published the code with documentation on GitHub.

Get it here:

git clone git://github.com/benanne/kaggle-galaxies.git

Have a look at the README file for instructions on how to generate the winning solution, or check out doc/documentation.pdf for more information.

The code is available under a BSD 3-clause licence. If you’ve found it useful, dropping me a line to let me know how you used it is appreciated. Have fun with it!

My solution for the Galaxy Zoo challenge

The Galaxy Zoo challenge on Kaggle has just finished. The goal of the competition was to predict how Galaxy Zoo users (zooites) would classify images of galaxies from the Sloan Digital Sky Survey. I finished in 1st place and in this post I’m going to explain how my solution works.

Introduction

The problem

Galaxy Zoo is a crowdsourcing project, where users are asked to describe the morphology of galaxies based on images. They are asked questions such as “How rounded is the galaxy?” and “Does it have a central bulge?”, and the users’ answers determine which question will be asked next. The questions form a decision tree which is shown in the figure below, taken from Willett et al. 2013.

The Galaxy Zoo decision tree, taken from Willett et al. 2013.

When many users have classified the same image, their answers can be aggregated into a set of probabilities for each answer. Often, not all users will agree on all of their answers, so it’s useful to quantify this uncertainty.

The goal of the Galaxy Zoo challenge is to predict these probabilities from the galaxy images that are shown to the users. In other words, build a model of how “the crowd” perceive and classify these images.

This means that we’re looking at a regression problem, not a classification problem: we don’t have to determine which classes the galaxies belong to, but rather the fraction of people who would classify them as such.

My solution: convnets

I suppose this won’t surprise anyone: my solution is based around convolutional neural networks (convnets). I believe they’re an excellent match for this problem: it’s image data, but it is different enough from typical image data (i.e. “natural” images such as those used in object recognition, scene parsing, etc.) for traditional features from computer vision to be suboptimal. Learning the features just seems to make sense.

Transfer learning - pre-training a deep neural network on another dataset (say, ImageNet), chopping off the top layer and training a new classifier on top - was a popular approach for the recently finished Dogs vs. Cats competition, but it is not really viable here either. Nobody requested permission to use external data in the competition forums (a prerequisite for being allowed to use it), so I guess nobody tried this approach.

During the contest, I frequently referred to Krizhevsky et al.’s seminal 2012 paper on ImageNet classification for guidance. Asking myself “What would Krizhevsky do?” usually resulted in improved performance.

Overfitting

As Geoffrey Hinton has been known to say, if you’re not overfitting, your network isn’t big enough. My main objective during the competition was avoiding overfitting. My models were significantly overfitting throughout the entire competition, and most of the progress I attained came from finding new ways to mitigate that problem.

I tackled this problem with three orthogonal approaches:

  • data augmentation
  • dropout and weight norm constraints
  • modifying the network architecture to increase parameter sharing

The best model I found has about 42 million parameters. It overfits significantly, but it’s still the best despite that. There seems to be a lot of room for improvement there.

As is customary in Kaggle competitions, I also improved my score quite a bit by averaging the predictions of a number of different models. Please refer to the “Model averaging” section below for more details.

Software and hardware

I used Python, NumPy and Theano to implement my solution. I also used the Theano wrappers for the cuda-convnet convolution implementation that are part of pylearn2. They provided me with a speed boost of almost 3x over Theano’s own implementation. I wrote a guide on how to use them, because their documentation is limited.

I used scikit-image for preprocessing and augmentation. I also used sextractor and pysex to extract some parameters of the galaxies from the images.

The networks were trained on workstations with a hexacore CPU, 32GB RAM and two NVIDIA GeForce GTX 680 GPUs each.

Preprocessing and data augmentation

Cropping and downsampling

The data consisted of 424x424 colour JPEG images, along with 37 weighted probabilities that have to be predicted for each image (for details on the weighting scheme, please refer to this page).

For almost all of the images, the interesting part was in the center. The void of space around the galaxies was not very discriminative, so I cropped all images to 207x207. I then downsampled them 3x to 69x69, to keep the input size of the network manageable.

Exploiting spatial invariances

Images of galaxies are rotation invariant: there is no up or down in space. They are also scale invariant and translation invariant to a limited extent. All of these invariances could be exploited to do data augmentation: creating new training data by perturbing the existing data points.

Each training example was perturbed before presenting it to the network by randomly scaling it, rotating it, translating it and optionally flipping it. I used the following parameter ranges:

  • rotation: random with angle between 0° and 360° (uniform)
  • translation: random with shift between -4 and 4 pixels (relative to the original image size of 424x424) in the x and y direction (uniform)
  • zoom: random with scale factor between 1/1.3 and 1.3 (log-uniform)
  • flip: yes or no (Bernoulli)

Because both the initial downsampling to 69x69 and the random perturbation are affine transforms, they could be combined into one affine transformation step (I used scikit-image for this). This sped up things significantly and reduced information loss.
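
As a rough sketch of the idea (not the actual competition code; centering and flipping are omitted), the random perturbation and the 3x downsampling can be folded into a single scikit-image warp:

import numpy as np
from skimage import transform

image = np.random.rand(424, 424, 3)                       # stand-in for a galaxy image

rotation = np.random.uniform(0, 2 * np.pi)                # uniform rotation
shift = np.random.uniform(-4, 4, 2)                       # uniform translation in pixels
zoom = np.exp(np.random.uniform(np.log(1 / 1.3), np.log(1.3)))   # log-uniform zoom

perturbation = transform.AffineTransform(scale=(1 / zoom, 1 / zoom),
                                         rotation=rotation, translation=shift)
downsampling = transform.AffineTransform(scale=(3, 3))    # 3x downsampling

# AffineTransforms compose with '+', so both steps amount to a single warp call.
patch = transform.warp(image, perturbation + downsampling, output_shape=(69, 69))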

Colour perturbation

After this, the colour of the images was changed as described in Krizhevsky et al. 2012, with two differences: the first component had a much larger eigenvalue than the other two, so only this one was used, and the standard deviation for the scale factor alpha was set to 0.5.

“Realtime” augmentation

Combining downsampling and perturbation into a single affine transform made it possible to do data augmentation in realtime, i.e. during training. This significantly reduced overfitting because the network would never see the exact same image twice. While the network was being trained on a chunk of data on the GPU, the next chunk would be generated on the CPU in multiple processes, to ensure that all the available cores were used.

Centering and rescaling

I experimented with centering and rescaling the galaxy images based on parameters extracted with sextractor. Although this didn’t improve performance, including a few models that used it in the final ensemble helped to increase variance (see “Model averaging” for more information).

I extracted the center of the galaxies, as well as the Petrosian radius. A number of different radii can be extracted, but the Petrosian radius seemed to give the best size estimate. I then centered each image by shifting the estimated center pixel to (212, 212), and rescaled it so that its Petrosian radius would be equal to 160 pixels. The scale factor was limited to the range (1/1.5, 1.5), because there were some outliers.

This rescaling and centering could also be collapsed into the affine transform doing downsampling and perturbation, so it did not slow things down at all.

Input = raw pixels

With these pre-processing and augmentation steps, the network input still consisted of raw pixels. No features were extracted apart from those learned by the network itself.

Network architecture

Exploiting rotation invariance to increase parameter sharing

I increased parameter sharing in the network by cutting the galaxy images into multiple parts that could be treated in the same fashion, i.e. processed by the same convolutional architecture. For this I exploited the rotation invariance of the images.

As mentioned before, the images were cropped to 207x207 and downsampled by a factor of 3. This was done with two different orientations: a regular crop, as well as one that is rotated 45°. Both of these crops were also flipped horizontally, resulting in four 69x69 “views” of the image. This is visualised below.

Four different views were extracted from each image: a regular view (red), a 45° rotated view (blue), and mirrored versions of both.

Each of the four views was again split into four partially overlapping “parts” of size 45x45. Each part was rotated so that they are all aligned, with the galaxy in the bottom right corner. This is visualised below. In total, 16 parts were extracted from the original image.

Each view was then split into four partially overlapping parts. Each part was rotated so that they are all aligned, with the galaxy in the bottom right corner. In total, 16 parts were extracted from the original image.

This results in 16 smaller 45x45 images which appear very similar. They can be expected to have the same topological structure due to rotation invariance, so they can be processed by the same convolutional architecture, which results in a 16x increase in parameter sharing, and thus less overfitting. At the top of the network, the features extracted from these 16 parts are concatenated and connected to one or more dense layers, so the information can be aggregated.

A nice side-effect of this approach is that the effective minibatch size for the convolutional part of the network increases 16-fold because the 16 parts are stacked on top of each other, which makes it easier to exploit GPU parallelism.
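
Schematically, with a hypothetical conv_features function standing in for the shared convolutional layers, the stacking looks like this:

import numpy as np

# conv_features is a hypothetical stand-in for the shared convolutional layers
conv_features = lambda parts: parts.reshape(len(parts), -1)   # placeholder only

m = 16                                          # examples per minibatch
parts = np.random.randn(m * 16, 3, 45, 45)      # 16 parts per example, stacked along the batch axis
features = conv_features(parts)                 # effective conv minibatch size: m * 16 = 256
dense_input = features.reshape(m, -1)           # (m, 16 * num_features), fed to the dense layers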

Due to the overlap of the parts, a lot of information is available about the center of the galaxy, because it is processed by the convnet in 16 different orientations. This is useful because a few important properties of the galaxies are expected to be in the center of the image (the presence of a bar or a bulge, for example). Reducing this overlap typically resulted in reduced performance. I chose not to make the parts fully overlap, because it would slow down training too much.

Incorporating output constraints

As described on this page, the 37 outputs to be predicted are weighted probabilities, adhering to a number of constraints. Incorporating these constraints into the model turned out to be quite useful.

In essence, the answers to each question should form a categorical distribution. Additionally, they are scaled by the probability of the question being asked, i.e. the total probability of answers given that would lead to this question being asked.

My initial reflex was to use a softmax output for each question, and then apply the scaling factors. This didn’t make much of a difference. I believe this is because the softmax function has difficulty predicting hard zeros and ones, of which there were quite a few in the training data (its input would have to be very large in magnitude).

If cross-entropy is the error metric, this is not a big issue, but for this competition, the metric by which submissions were judged was the root mean squared error (RMSE). As a result, being able to predict very low and very high probabilities was quite useful.

In the end I normalised the distribution for each question by adding a rectification nonlinearity in the top layer instead of the softmax functions, and then just using divisive normalisation. For example, if the raw, linear outputs of the top layer of the network for question one were z1, z2, z3, then the actual output for question one was given by max(z1, 0) / (max(z1, 0) + max(z2, 0) + max(z3, 0) + epsilon). The epsilon is a very small constant that prevents division-by-zero errors; I set it to 1e-12. This approach allowed the network to predict hard zeros more easily.
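
In Theano this normalisation is only a couple of lines. Here is a small sketch (not the exact competition code) for the outputs of a single question:

import theano.tensor as T

epsilon = 1e-12
z = T.matrix('z')                                         # raw linear outputs for one question
z_rectified = T.maximum(z, 0)
answers = z_rectified / (z_rectified.sum(axis=1, keepdims=True) + epsilon)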

This is really where Theano shines: I could incorporate these constraints into the model simply by writing out what they were - no need to manually compute all the changed gradients. A big time saver! If I had to compute the gradients manually, I probably wouldn’t even have bothered to try incorporating the constraints in the first place.

Architecture of the best model

The best model I found is shown below in the form of a Krizhevsky-style diagram. All other models included in the final ensemble I submitted are slight variations of this model.

Krizhevsky-style diagram of the architecture of the best performing network.

The input is presented to the model in the form of RGB coloured 45x45 image parts.

The model has 7 layers: 4 convolutional layers and 3 dense layers. All convolutional layers include a ReLU nonlinearity (i.e. f(x) = max(x, 0)). The first, second and fourth convolutional layers are followed by 2x2 max-pooling. The sizes of the layers, as well as the sizes of the filters, are indicated in the figure.

As mentioned before, the convolutional part of the network is applied to 16 different parts of the input image. The extracted features for all these parts are then aggregated and connected to the dense part of the network.

The dense part consists of two maxout layers with 2048 units (Goodfellow et al. 2013), both of which take the maximum over pairs of linear filters (so 4096 linear filters in total). Using maxout here instead of regular dense layers with ReLUs helped to reduce overfitting a lot, compared to dense layers with 4096 linear filters. Using maxout in the convolutional part of the network as well proved too computationally intensive.
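
A maxout unit over pairs of linear filters is easy to express in Theano. As a sketch (shapes and variable names are assumptions, not the actual model code):

import theano.tensor as T

x = T.matrix('x')                                         # input to the dense layer
W = T.matrix('W')                                         # (num_inputs, 4096) linear filters
b = T.vector('b')
h_linear = T.dot(x, W) + b                                # (batch size, 4096)
h_maxout = T.maximum(h_linear[:, 0::2], h_linear[:, 1::2])   # max over pairs: (batch size, 2048)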

Training this model took 67 hours.

Variants

Variants of the best model were included in the final ensemble I submitted, to increase variance (see “Model averaging”). They include:

  • a network with two dense layers instead of three (just one maxout layer)
  • a network with one of the dense layers reduced in size and applied individually to each part (resulting in 16-way parameter sharing for this layer as well)
  • a network with a different filter size configuration: 8/4/3/3 instead of 6/5/3/3 (from bottom to top)
  • a network with centered and rescaled input images
  • a network with a ReLU dense layer instead of maxout
  • a network with 192 filters instead of 128 for the topmost convolutional layer
  • a network with 256 filters instead of 128 for the topmost convolutional layer
  • a network with norm constraint regularisation applied to the two maxout layers (as in Hinton et al. 2012)
  • combinations of the above variations

Training

Validation

For validation purposes, I split the training set in two parts. I used the first 90% for training, and the remainder for validation. I noticed quite early on that the estimates on my validation set matched the public leaderboard pretty well. This implied that submitting frequently was unnecessary - but nevertheless I couldn’t resist :)

Near the end of the competition I tried retraining a model on the entire training set, including the validation data I split off, but I noticed no increase in performance on the public leaderboard, so I left it at that. The separate validation set came in handy for model averaging anyway.

Training algorithm

I trained the networks with stochastic gradient descent (SGD) and Nesterov momentum (fixed at 0.9). I used a minibatch size of 16 examples. This meant that the effective minibatch size for the convolutional part was 256 (see above). This worked well because the cuda-convnet convolution implementation is optimised for minibatch sizes that are multiples of 128.

I trained the networks for about 1.5 million gradient steps. I used a learning rate schedule with two discrete decreases. Initially it was set to 0.04. It was decreased tenfold to 0.004 after about 1.1 million steps, and again to 0.0004 after about 1.4 million steps.

For the first ~600 gradient steps, the divisive normalisation in the output layer was disabled. This was necessary to ensure convergence (otherwise it would get stuck at the start sometimes).

Initialisation

Some fiddling with the parameter initialisation was required to get the network to train properly. Most of the layer weights were initialised from a Gaussian distribution with mean zero and a standard deviation of 0.01, with biases initialised to 0.1. For the topmost convolutional layer, I increased the standard deviation to 0.1. For the dense layers, I reduced it to 0.001 and the biases were initialised to 0.01. These modifications were presumably necessary because these layers are, respectively, much smaller and much larger than the others.

Regularisation

Dropout was used in all three dense layers, with a dropout probability of 0.5. This was absolutely essential to be able to train the network at all.

Near the very end of the competition I also experimented with norm constraint regularisation for the maxout layers. I chose the maximal norm for each layer based on a histogram of the norms of a network trained without norm constraint regularisation (I chose it so the tail of the histogram would be chopped off). I’m not entirely sure if this helped or not, since I was only able to do two runs with this setup.

Model averaging

Averaging across transformed images

For each individual model, I computed predictions for 60 affine transformations of the test set images: a combination of 10 rotations, spaced by 36°, 3 rescalings (with scale factors 1/1.2, 1 and 1.2) and flipping / no flipping. These were uniformly averaged. Even though the model architecture already incorporated a lot of invariances, this still helped quite a bit.
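
For reference, the 60 transformations are just the Cartesian product of these settings:

import itertools
import numpy as np

rotations = np.arange(0, 360, 36)               # 10 rotations, 36 degrees apart
scales = [1 / 1.2, 1.0, 1.2]                    # 3 rescalings
flips = [False, True]                           # flipped or not
tta_params = list(itertools.product(rotations, scales, flips))
assert len(tta_params) == 60
# the 60 resulting sets of predictions per test image are averaged uniformly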

Computing these averaged test set predictions for a single model took just over 4 hours.

Averaging across architectures

The averaged predictions for each model were then uniformly blended again, across a number of different models (variants of the model described under “Architecture of the best model”). I also experimented with a weighted blend, optimised on the validation set I split off, but this turned out not to make a significant difference. However, I did use the learned weights to identify sets of predictions that were not contributing at all, and I removed those from the uniform blend as well.

My final submission was a blend of predictions from 17 different models, each of which were themselves blended across 60 transformations of the input. So in the end, I blended 1020 predictions for each test set image.

For comparison: my best single model achieved a score of 0.07671 on the public leaderboard. After averaging, I achieved a final score of 0.07467 on the public leaderboard. This resulted in a score of 0.07492 on the private leaderboard.

Miscellany

Below are a bunch of things that I tried but ended up not using - either because they didn’t help my score, or because they slowed down training too much.

  • Adding Gaussian noise to the input images during training to reduce overfitting. This didn’t help.
  • Extracting crops from the input images at different scales, and training a multiscale convnet on them. It turned out that only the part of the network for the most detailed scale was actually learning anything. The other parts received no gradient and weren’t really learning.
  • Overlapping pooling. This seemed to help a little bit, but it slowed things down too much.
  • Downsampling the input images less (1.5x instead of 3x) and using a strided convolution in the first layer (with stride 2). This did not improve results, and dramatically increased memory usage.
  • Adding shearing to the data augmentation step. This didn’t help, but it didn’t hurt performance either. I assumed that it would hurt performance because question 7 pertains to the ellipticity of the galaxy (shearing would of course change this), but this didn’t seem to be the case.

Near the end of the competition I also toyed with a polar coordinate representation of the images. I suppose this could work well because rotations turn into translations in polar space, so the convnet’s inherent translation invariance would amount to rotation invariance in the original input space. Unfortunately I didn’t have enough time left to properly explore this approach, so I decided to focus on my initial approach instead.

I would also have liked to find a way to incorporate the test set images somehow (i.e. a transduction setup). Unsupervised pre-training seemed pointless, because it tends not to be beneficial when rectified linear units and dropout are used, except maybe when labeled training data is very scarce. I really like the pseudo-label approach of Dong-Hyun Lee (2013), but it could not be applied here because it only makes sense for classification problems, not regression problems. If anyone has ideas for this, I’m still interested!

Conclusion

This was a really cool competition, and even though I had some prior experience training convnets, I learnt a lot of new things about them.

If this problem interests you, be sure to check out the competition forum. Many of the participants will be posting overviews of their approaches in the coming days.

I would like to thank the organisers of the competition, as well as the authors of Theano, cuda-convnet and pylearn2 for providing me with the necessary tools.

I will clean up my code and I’ll put it on GitHub soon. If you have any questions or feedback about this post, feel free to leave a comment.

Update (April 6th): the code for this solution has now been published on GitHub - see the “Galaxy Zoo Challenge: code published” post above.

Convolutional neural networks (convnets) are all the rage right now. Training a convnet on any reasonably sized dataset is very computationally intensive, so GPU acceleration is indispensable. In this post I’ll show how you can use the blazing fast convolution implementation from Alex Krizhevsky’s cuda-convnet in Theano.

As an example, I’ll also show how the LeNet deep learning tutorial on convolutional neural networks can be modified to use this convolution implementation instead of Theano’s own, resulting in a 3x speedup.

Introduction

Quite a few libraries offering GPU-accelerated convnet training have sprung up in recent years: cuda-convnet, Caffe and Torch7 are just a few. I’ve been using Theano for anything deep learning-related, because it offers a few advantages:

  • it allows you to specify your models symbolically, and compiles this representation to optimised code for both CPU and GPU. It (almost) eliminates the need to deal with implementation details;
  • it does symbolic differentiation: given an expression, it can compute gradients for you.

The combination of these two advantages makes Theano ideal for rapid prototyping of machine learning models trained with gradient descent. For example, if you want to try a different objective function, just change that one line of code where you define it, and Theano takes care of the rest. No need to recompute the gradients, and no tedious optimisation to get it to run fast enough. This is a huge time saver.

Performance vs. flexibility

Theano comes with a 2D convolution operator out of the box, but its GPU implementation hasn’t been the most efficient for a while now, and other libraries have surpassed it in performance. Unfortunately, they don’t typically offer the flexibility that Theano offers.

Luckily, we no longer need to choose between performance and flexibility: the team behind pylearn2, a machine learning research library built on top of Theano, has wrapped the blazing fast convolution implementation from Alex Krizhevsky’s cuda-convnet library so that it can be used in Theano.

This wrapper can be used directly from Theano, without any dependencies on other pylearn2 components. It is not a drop-in replacement for Theano’s own conv2d, and unfortunately its documentation is limited, so in this post I’m going to try and describe how to use it. I’ve seen speedups of 2x-3x after replacing Theano’s own implementation with this one in some of my own code, so if you’re doing convolutions in Theano this is definitely worth trying out.

Why not just use cuda-convnet? cuda-convnet is an impressive piece of software, and while it does implement a lot of state-of-the-art techniques, it does not offer the same degree of flexibility that Theano offers.

Why not just use pylearn2 then? Although pylearn2 is specifically aimed at researchers, it has a fairly steep learning curve due to its emphasis on modularity and code reuse. Of course these are desirable qualities, but when writing research code, I personally prefer to keep the cognitive overhead minimal, and using Theano affords me that luxury.

Requirements

For this post I will assume that Python, numpy and Theano are installed and working, and that you have access to a CUDA-enabled GPU.

Make sure to configure Theano to use the GPU: set device=gpu and floatX=float32 in your .theanorc file, or in the THEANO_FLAGS environment variable (more info in the Theano documentation).
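
For example, a minimal .theanorc could look like this (the same settings can go in THEANO_FLAGS as device=gpu,floatX=float32):

[global]
device = gpu
floatX = float32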

You will also need to get pylearn2:

git clone git://github.com/lisa-lab/pylearn2.git

Add the resulting directory to your PYTHONPATH, so the pylearn2 module can be imported:

export PYTHONPATH=$PYTHONPATH:/path/to/pylearn2

Detailed installation instructions can be found in the pylearn2 documentation. However, note that some dependencies (PIL, PyYAML) will not be necessary if you are only going to use the cuda-convnet wrappers.

The convnet from Krizhevsky et al.'s NIPS 2012 ImageNet classification paper.

Usage

Overview

Assume the following imports and definitions:

import theano.tensor as T
input = T.tensor4('input')
filters = T.tensor4('filters')

We have defined two 4-tensors: one for the input data, and one for the filters that will be convolved with it. A 2D convolution in Theano is normally implemented as follows:

from theano.tensor.nnet import conv

out = conv.conv2d(input, filters, filter_shape=filter_shape,
                  image_shape=image_shape)

To use the cuda-convnet wrappers from pylearn2 instead, use the following code:

from pylearn2.sandbox.cuda_convnet.filter_acts import FilterActs
from theano.sandbox.cuda.basic_ops import gpu_contiguous

conv_op = FilterActs()
contiguous_input = gpu_contiguous(input)
contiguous_filters = gpu_contiguous(filters)
out = conv_op(contiguous_input, contiguous_filters)

This is a little wordier. An interesting peculiarity is that the FilterActs wrapper needs to be instantiated before it can be used in a Theano expression (line 4).

Next, we need to make sure that the inputs are laid out correctly in memory (they must be C-contiguous arrays). This is what the gpu_contiguous helper function achieves (lines 5 and 6). It will make a copy of its input if the layout is not correct. Otherwise it does nothing. Wrapping the inputs in this way is not always necessary, but even if it isn’t, the performance overhead seems to be minimal anyway, so I recommend always doing it just to be sure.

The convolution can then be applied to the contiguous inputs (line 7).

Different input arrangement: bc01 vs. c01b

An important difference with Theano’s own implementation is that FilterActs expects a different arrangement of the input. Theano’s conv2d expects its input to have the following shapes:

  • input: (batch size, channels, rows, columns)
  • filters: (number of filters, channels, rows, columns)

In pylearn2, this input arrangement is referred to as bc01. In cuda-convnet, the following shapes are expected instead:

  • input: (channels, rows, columns, batch_size)
  • filters: (channels, rows, columns, number of filters)

This is referred to as c01b.

If you have an existing codebase which assumes the bc01 arrangement everywhere, the simplest way to deal with this is to use Theano’s dimshuffle method to change the order of the dimensions as appropriate:

conv_op = FilterActs()
input_shuffled = input.dimshuffle(1, 2, 3, 0) # bc01 to c01b
filters_shuffled = filters.dimshuffle(1, 2, 3, 0) # bc01 to c01b
contiguous_input = gpu_contiguous(input_shuffled)
contiguous_filters = gpu_contiguous(filters_shuffled)
out_shuffled = conv_op(contiguous_input, contiguous_filters)
out = out_shuffled.dimshuffle(3, 0, 1, 2) # c01b to bc01

However, this may incur a performance penalty because it requires making a copy of the data. This negates some of the performance gained by using the cuda-convnet implementation in the first place.

Contrary to what the Theano documentation says, the negative effect on performance of adding these dimshuffle calls is not necessarily that profound, in my experience. Nevertheless, using the c01b arrangement everywhere will result in faster execution.

Convolution vs. correlation

The code fragment above still isn’t a drop-in replacement for Theano’s conv2d, because of another subtle difference: FilterActs technically implements a correlation, not a convolution. In a convolution, the filters are flipped before they are slid across the input; in a correlation, they are not. So to perform an operation that is equivalent to Theano’s conv2d, we have to flip the filters manually (line 4):

conv_op = FilterActs()
input_shuffled = input.dimshuffle(1, 2, 3, 0) # bc01 to c01b
filters_shuffled = filters.dimshuffle(1, 2, 3, 0) # bc01 to c01b
filters_flipped = filters_shuffled[:, ::-1, ::-1, :] # flip rows and columns
contiguous_input = gpu_contiguous(input_shuffled)
contiguous_filters = gpu_contiguous(filters_flipped)
out_shuffled = conv_op(contiguous_input, contiguous_filters)
out = out_shuffled.dimshuffle(3, 0, 1, 2) # c01b to bc01

However, when the filters are being learned from data, it doesn’t really matter how they are oriented, as long as they are always oriented in the same way. So in practice, it is rarely necessary to flip the filters.

Limitations

FilterActs has several limitations compared to conv2d:

  • The number of channels must be even, or less than or equal to 3. If you want to compute the gradient, it should be divisible by 4. If you’re training a convnet, that means valid numbers of input channels are 1, 2, 3, 4, 8, 12, 16, …
  • Filters must be square: the number of rows and columns must be equal. For images, square filters are usually what you want anyway, but this can be a serious limitation when working with non-image data.
  • The number of filters must be a multiple of 16.
  • All minibatch sizes are supported, but the best performance is achieved when the minibatch size is a multiple of 128.
  • Only “valid” convolutions are supported. If you want to perform a “full” convolution, you will need to use zero-padding (more on this later).
  • FilterActs only works on the GPU. You cannot run your Theano code on the CPU if you use it.

Tuning the time-memory trade-off with partial_sum

When instantiating FilterActs, we can specify the partial_sum argument to control the trade-off between memory usage and performance. From the cuda-convnet documentation:

partialSum is a parameter that affects the performance of the weight gradient computation. It’s a bit hard to predict what value will result in the best performance (it’s problem-specific), but it’s worth trying a few. Valid values are ones that divide the area of the output grid in this convolutional layer. For example if this layer produces 32-channel 20x20 output grid, valid values of partialSum are ones which divide 20*20 = 400.

By default, partial_sum is set to None, which is the most conservative setting in terms of memory usage. To speed things up, the value can be tuned as described above, at the expense of higher memory usage. In practice, setting it to 1 tends to work very well (and it’s always a valid value, regardless of the size of the output grid):

conv_op = FilterActs(partial_sum=1)
contiguous_input = gpu_contiguous(input)
contiguous_filters = gpu_contiguous(filters)
out = conv_op(contiguous_input, contiguous_filters)

I recommend setting partial_sum to 1 and leaving it at that. In most cases this will work well enough, and it saves you the trouble of having to recompute the divisors of the output grid area every time you change the filter size. I have observed only very minimal performance gains from optimising this setting.

If you don’t have a lot of GPU memory to spare, leaving this setting at None will reduce performance, but it will likely still be quite a bit faster than Theano’s implementation.

Strided convolutions

Although Theano’s conv2d allows for convolutions with strides different from 1 through the subsample parameter, the performance tends to be a bit disappointing, in my experience. FilterActs has much better support for strided convolutions. The stride argument can be specified when instantiating FilterActs (it defaults to 1):

conv_op = FilterActs(partial_sum=1, stride=2)
contiguous_input = gpu_contiguous(input)
contiguous_filters = gpu_contiguous(filters)
out = conv_op(contiguous_input, contiguous_filters)

stride should be an integer, not a tuple, so this implies that the stride has to be the same in both dimensions, just like the filter size.

This is very useful for large input images, since it is a lot cheaper than computing a full convolution and then pooling the result. In the ImageNet classification paper, Krizhevsky et al. used a convolution with stride 4 in the first layer of their convnet.

Zero-padding

FilterActs supports another optional argument pad, which defaults to 0. Setting this to another value p will implicitly pad the input with a border of p zeros on all sides. This does not use extra memory, so it is much cheaper than adding the padding yourself.

This argument can be used to implement a “full” convolution instead of a “valid” one, by padding the input with filter_size - 1 zeros:

conv_op = FilterActs(partial_sum=1, pad=filter_size - 1)
contiguous_input = gpu_contiguous(input)
contiguous_filters = gpu_contiguous(filters)
out = conv_op(contiguous_input, contiguous_filters)

Let n be the input size and f the filter size, then padding the input with f - 1 zeros on all sides changes the input size to n + 2f - 2. Applying the convolution then results in an output size of (n + 2f - 2) - (f - 1) = n + f - 1, which corresponds to a “full” convolution.

Max-pooling

In addition to FilterActs, there is also a MaxPool wrapper. In Theano, you would implement 2D max-pooling as follows:

from theano.tensor.signal import downsample

out = downsample.max_pool_2d(input, ds=(2, 2))

To use the wrappers instead:

from pylearn2.sandbox.cuda_convnet.pool import MaxPool
from theano.sandbox.cuda.basic_ops import gpu_contiguous

pool_op = MaxPool(ds=2, stride=2)
contiguous_input = gpu_contiguous(input)
out = pool_op(contiguous_input)

Once again we need to ensure C-contiguousness with gpu_contiguous. The input should be in c01b format as before.

Note that the MaxPool op accepts both a ds and a stride argument. If you set both to the same value, you get traditional max-pooling. If you make ds larger than stride, you get overlapping pooling regions. This was also used in the ImageNet classification paper mentioned earlier: they used a pool size of 3 and a stride of 2, so each pool overlaps with the next by 1 pixel.

ds and stride should be integers, not tuples, so this implies that pooling regions should be square, and the strides should be the same in both dimensions.

Another important limitation is that MaxPool only works for square input images. No such limitation applies for FilterActs. If you run into problems with this, you could use FilterActs in combination with Theano’s own max_pool_2d implementation - it’s a bit slower this way, but max-pooling is not the bottleneck in a convnet anyway, the convolutions are.
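
A minimal sketch of that fallback, reusing out_shuffled from the snippets above and assuming a 2x2 pool:

from theano.tensor.signal import downsample

out_bc01 = out_shuffled.dimshuffle(3, 0, 1, 2)                      # c01b to bc01
pooled = downsample.max_pool_2d(input=out_bc01, ds=(2, 2), ignore_border=True)
pooled_shuffled = pooled.dimshuffle(1, 2, 3, 0)                     # bc01 back to c01b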

Other wrappers

There are a few other wrappers for cuda-convnet code in pylearn2: ProbMaxPool (probabilistic max-pooling, Lee et al. 2009), StochasticMaxPool, WeightedMaxPool (stochastic max-pooling, Zeiler et al. 2013) and CrossMapNorm (cross-channel normalisation, Krizhevsky et al., 2012). I will not discuss these in detail, but many of the same remarks and restrictions apply.

More information about these wrappers can be found in the pylearn2 documentation. Any missing information can usually be found in the cuda-convnet documentation.

The stochastic max-pooling implementation is not from cuda-convnet itself, but was built on top of it. As a result, it’s actually pretty slow. If you need this, implementing stochastic max-pooling yourself in Theano may be faster.

Modifying the LeNet deep learning tutorial

To wrap up this post, let’s modify the deep learning tutorial on convolutional neural networks to use these wrappers instead of Theano’s own implementations. The tutorial explains how to train a convolutional neural network on the MNIST dataset with Theano. If you’re not familiar with it, have a look at the tutorial before continuing.

The convnet from the LeNet deep learning tutorial.

You can download the necessary files below (place them in the same directory):

All the code we’ll have to modify is in convolutional_mlp.py. First, let’s add and replace the necessary imports. In what follows, all replaced code is commented. On line 34:

# from theano.tensor.signal import downsample
# from theano.tensor.nnet import conv
from theano.sandbox.cuda.basic_ops import gpu_contiguous
from pylearn2.sandbox.cuda_convnet.filter_acts import FilterActs
from pylearn2.sandbox.cuda_convnet.pool import MaxPool

Next, we’ll need to modify the LeNetConvPoolLayer class to use FilterActs and MaxPool instead. On line 88:

# convolve input feature maps with filters
# conv_out = conv.conv2d(input=input, filters=self.W,
#         filter_shape=filter_shape, image_shape=image_shape)
input_shuffled = input.dimshuffle(1, 2, 3, 0) # bc01 to c01b
filters_shuffled = self.W.dimshuffle(1, 2, 3, 0) # bc01 to c01b
conv_op = FilterActs(stride=1, partial_sum=1)
contiguous_input = gpu_contiguous(input_shuffled)
contiguous_filters = gpu_contiguous(filters_shuffled)
conv_out_shuffled = conv_op(contiguous_input, contiguous_filters)

And on line 92:

# downsample each feature map individually, using maxpooling
# pooled_out = downsample.max_pool_2d(input=conv_out,
#                                     ds=poolsize, ignore_border=True)
pool_op = MaxPool(ds=poolsize[0], stride=poolsize[0])
pooled_out_shuffled = pool_op(conv_out_shuffled)
pooled_out = pooled_out_shuffled.dimshuffle(3, 0, 1, 2) # c01b to bc01

Note that we’re using plenty of dimshuffle calls here, so we can keep using the bc01 input arrangement in the rest of the code and no further changes are necessary. Also note that we did not flip the filters: there is no point because the weights are being learned.

Just one more change is necessary: the tutorial specifies a convnet with two convolutional layers, with 20 and 50 filters respectively. This is not going to work with FilterActs, which expects the number of filters to be a multiple of 16. So we’ll have to change the number of filters on line 106:

# def evaluate_lenet5(learning_rate=0.1, n_epochs=200,
#                     dataset='mnist.pkl.gz',
#                     nkerns=[20, 50], batch_size=500):
def evaluate_lenet5(learning_rate=0.1, n_epochs=200,
                    dataset='mnist.pkl.gz',
                    nkerns=[32, 64], batch_size=500):

Now it should work. You can download the modified file here (place it in the same directory):

Running the unmodified code for 50 epochs with 32 and 64 filters respectively takes 110 minutes on the GeForce GT 540M in my laptop:

Optimization complete.
Best validation score of 1.120000 % obtained at iteration 5000,
with test performance 1.000000 %
The code for file convolutional_mlp.py ran for 109.47m

With FilterActs and MaxPool instead of the Theano implementation, it only takes 34 minutes:

Optimization complete.
Best validation score of 1.120000 % obtained at iteration 5000,
with test performance 0.940000 %
The code for file convolutional_mlp_cc.py ran for 33.76m

A 3.2x speedup!

On a workstation with a GeForce GTX 680, the unmodified code takes 13.15 minutes for 50 epochs. Using FilterActs, it takes 4.75 minutes, which amounts to a 2.7x speedup.