Transfer Learning in Keras

Image credit: Raw&Rendered

Using Transfer Learning and Bottlenecking to Capitalize on State of the Art DNNs

As a data scientist, I’m really interested in how deep learning networks can be deployed in industry. Often the problems are too niche, there isn’t enough training data, or there isn’t enough computing power. Transfer learning is a way to make deep neural networks more accessible by overcoming some of these challenges.

Transfer learning is the modification of a pre-trained neural network so that it performs well on a new dataset it was not trained on. In this project specifically, the pre-trained networks are InceptionV3, VGG16, and ResNet50. These are all convolutional neural networks that were trained on ImageNet, and each was extremely successful in the ILSVRC competition at one time or another. The new dataset is CIFAR-10.
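
As a rough sketch of what this looks like in Keras, the snippet below loads one of the pre-trained bases (ResNet50 here), freezes its ImageNet weights, upsamples the CIFAR-10 images to a resolution the base was trained on, and trains a small new classification head. The upsampling factor, optimizer, and head architecture are illustrative choices, not necessarily the exact ones used in the project notebook.

```python
from keras.applications.resnet50 import ResNet50, preprocess_input
from keras.datasets import cifar10
from keras.layers import Dense, GlobalAveragePooling2D, Input, UpSampling2D
from keras.models import Model
from keras.utils import to_categorical

# Load CIFAR-10 and put it in the format the pre-trained network expects.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = preprocess_input(x_train.astype("float32"))
x_test = preprocess_input(x_test.astype("float32"))
y_train, y_test = to_categorical(y_train, 10), to_categorical(y_test, 10)

# CIFAR-10 images are 32x32, much smaller than the 224x224 ImageNet inputs
# ResNet50 was trained on, so upsample them inside the model (32 * 7 = 224).
inputs = Input(shape=(32, 32, 3))
upsampled = UpSampling2D(size=(7, 7))(inputs)

# Pre-trained convolutional base, with the ImageNet classifier removed.
base = ResNet50(weights="imagenet", include_top=False, input_tensor=upsampled)
for layer in base.layers:
    layer.trainable = False  # freeze the pre-trained weights

# New classification head for CIFAR-10's ten classes.
x = GlobalAveragePooling2D()(base.output)
outputs = Dense(10, activation="softmax")(x)

model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="adam", loss="categorical_crossentropy",
              metrics=["accuracy"])
model.fit(x_train, y_train, validation_data=(x_test, y_test),
          batch_size=64, epochs=5)
```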

If you are interested in a more general description of transfer learning, check out my article on Medium about this project. I write more about the situations in which transfer learning is effective, how to implement it, and how to optimize it using bottlenecking.
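
In short, bottlenecking means running the frozen convolutional base over the dataset exactly once, caching its outputs (the "bottleneck features"), and then training only a small classifier on those cached features. The sketch below illustrates the idea with VGG16 and placeholder data; the array shapes, file name, and head architecture are assumptions for illustration rather than the project's exact code.

```python
import numpy as np
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.layers import Dense, Dropout, Flatten, Input
from keras.models import Model

# Hypothetical arrays standing in for images already resized to the
# resolution the pre-trained network expects.
x_train = np.random.rand(128, 224, 224, 3).astype("float32") * 255
y_train = np.eye(10)[np.random.randint(0, 10, 128)]

# 1. Run the frozen convolutional base over the data once and cache the
#    outputs ("bottleneck features"). This is the expensive step.
base = VGG16(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
bottleneck_train = base.predict(preprocess_input(x_train), batch_size=32)
np.save("bottleneck_train.npy", bottleneck_train)

# 2. Train only a small classifier on the cached features. Because the base
#    never has to be re-run, each epoch is cheap.
features = Input(shape=bottleneck_train.shape[1:])
x = Flatten()(features)
x = Dense(256, activation="relu")(x)
x = Dropout(0.5)(x)
outputs = Dense(10, activation="softmax")(x)

head = Model(features, outputs)
head.compile(optimizer="adam", loss="categorical_crossentropy",
             metrics=["accuracy"])
head.fit(np.load("bottleneck_train.npy"), y_train, batch_size=32, epochs=10)
```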

If you are interested in the actual implementation and workflow of transfer learning, definitely take a look at the project repository. There is a detailed Jupyter Notebook that walks through each step of the process and also contains everything found in this article and the one on Medium.

As a base case, this project includes a simple CNN that was trained for just a few epochs to show how much more quickly pre-trained networks converge. The base case network was written in Keras and is fashioned after the original LeNet architecture, with the addition of dropout as a regularization technique.
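
For reference, a LeNet-style Keras baseline with dropout looks roughly like the following; the exact filter counts, dense layer sizes, and dropout rates here are illustrative and may differ from the project notebook.

```python
from keras.layers import Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from keras.models import Sequential

# LeNet-style baseline for 32x32x3 CIFAR-10 images, with dropout added
# to the fully connected layers for regularization.
model = Sequential([
    Conv2D(6, (5, 5), activation="relu", input_shape=(32, 32, 3)),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(16, (5, 5), activation="relu"),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(120, activation="relu"),
    Dropout(0.5),
    Dense(84, activation="relu"),
    Dropout(0.5),
    Dense(10, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy",
              metrics=["accuracy"])
```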

Below is a d3.js visualization of training epoch vs validation loss/accuracy.

Click on the labels beneath the graph to toggle the lines.

As the graph shows, all three pre-trained networks start off better than the base case. The rate at which model accuracy increases seems to be fairly consistent through all training epochs, across all models.

This is the most surprising result from the project. I think it may be due in part to the simple nature of the CIFAR-10 dataset. Each CIFAR-10 image is 32x32x3, for a total of 3,072 values. The ImageNet inputs these networks were trained on are typically 224x224x3, or roughly 150,000 values per image, so the dimensionality of the data is well over an order of magnitude larger. It is most likely the case that the differences in network architecture become more apparent with higher-dimensional data. Likewise, the training rates of the different networks may vary more when trained on data more complex than CIFAR-10.

From an industry perspective, however, the takeaway is that these networks (especially ResNet50 in this case) required very little training time and were relatively easy to implement. Once there is a proof of concept, it is much easier to write an optimized network that suits your needs (and perhaps mimics the network you transferred learning from) than it is to both design and train one from scratch.

In a similar vein, DeepMind recently published a paper that demonstrates transferring the learning of reinforcement agents from environment to environment. Although the implementation is very different, the idea of taking advantage of previous learning is fundamental, and I suspect it will become a pillar of deep learning as we move forward.