In [1]:
# code for loading the format for the notebook
import os

# path : store the current path to convert back to it later
path = os.getcwd()
os.chdir(os.path.join('..', '..', 'notebook_format'))

from formats import load_style
load_style(plot_style=False)
Out[1]:
In [2]:
os.chdir(path)

import cv2
import numpy as np
import matplotlib.pyplot as plt
import keras.layers as layers
from keras.models import Model
from keras.preprocessing import image

# 1. magic so that the notebook will reload external python modules
# 2. magic for inline plot
# 3. magic to enable retina (high resolution) plots
# https://gist.github.com/minrk/3301035
# 4. magic to print version
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format='retina'
%load_ext watermark

%watermark -a 'Ethen' -d -t -v -p cv2,keras,numpy,matplotlib,tensorflow
Using TensorFlow backend.
Ethen 2018-09-04 20:57:58 

CPython 3.6.4
IPython 6.4.0

cv2 3.4.0
keras 2.2.2
numpy 1.14.1
matplotlib 2.2.2
tensorflow 1.7.0

ResNet (Residual Network)

Background

In deep learning, we often want to add more nodes or more layers to our model to increase its learning capacity (while of course also introducing other methods to prevent overfitting). But this raises a question: is building a better network as easy as stacking more layers? Or is there a limit to how deep we should go?

As you might imagine, we raise this question because in practice, when we push the depth of a network to its extreme, we do observe that making the network too deep can actually be detrimental. The image below shows the training and test error of a 20-layer and a 56-layer convnet. As the chart shows, the deeper network actually has a higher training error.

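Since it is the training error (not only the test error) that gets worse, this is not a case of overfitting. It indicates that not all networks are equally easy to optimize.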

In this notebook, we will introduce ResNet (Residual Network). When it was released in 2015, this type of network won the ImageNet classification, detection, and localization challenges. As depicted in the chart below, this type of network benefits from increased depth and enjoys the resulting accuracy gain, while its "plain" counterpart degrades in performance as we increase its depth.

Introduction

Before jumping straight into ResNets, let's do a quick recap of the concept of a residual.

A residual is the error in a result. Say we were asked to predict a person's age just by looking at them. If their actual age is 20 and we predicted 18, we are off by 2, so our residual is 20 - 18 = 2. Had we predicted 21, our residual would have been 20 - 21 = -1. In essence, the residual is what we should have added to our prediction to match the actual (expected) value.
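The arithmetic above fits in a one-line function. This is only a sketch; `residual` is a name chosen here for illustration and is not used elsewhere in the notebook.

```python
def residual(actual, predicted):
    """What we should have added to our prediction to match the actual value."""
    return actual - predicted


print(residual(20, 18))  # 2
print(residual(20, 21))  # -1
print(residual(20, 20))  # 0, nothing to correct
# adding the residual back to the prediction recovers the actual value
print(18 + residual(20, 18))  # 20
```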

What is important to understand here is that, if the residual is 0, we should not be performing any action since the prediction already matched the actual output. This idea can be depicted with the following diagram.

In the diagram above, x is our prediction and we want it to equal Actual. When it is off by some margin, our residual function residual() kicks in and tries to correct our prediction to match the actual value. On the other hand, when x == Actual, residual(x) will be 0. The Identity function simply copies x as is.

Now back to ResNets. This architecture's main highlight is the use of residual blocks, which are built around skip connections (a.k.a. shortcuts).

Let us consider $H(x)$ as an underlying mapping to be fit by a few stacked layers (not necessarily the entire net), with $x$ denoting the inputs to the first of these layers. If one hypothesizes that multiple nonlinear layers can asymptotically approximate complicated functions, then it is equivalent to hypothesize that they can asymptotically approximate the residual functions, i.e., $H(x) - x$ (assuming the input and output are of the same dimensions). So rather than expecting our stacked layers to approximate $H(x)$, we explicitly let these layers approximate a residual function $F(x) := H(x) - x$. The original function thus becomes $F(x) + x$.

The paragraph above, adapted from the original paper, can be summarized in the following diagram:

In the diagram above, the main path on the left consists of multiple nonlinear layers. On the right is the shortcut connection, which skips one or more layers and performs an identity mapping (i.e. passes the original input through unchanged); its output is added to the output of the stacked layers.

The idea behind this design is that if an identity mapping were optimal, the solver could simply push the weights of the stacked nonlinear layers toward 0 during training. No correction would then be needed: $F(x)$ would essentially be 0, and the network would perform an identity mapping through the shortcut path. In practice, identity mappings are unlikely to be optimal, but this formulation may help precondition the problem: it should be easier for the solver to find suitable weights with reference to an identity mapping than to learn a completely new function from scratch.
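To make this argument concrete, here is a minimal NumPy sketch of a residual unit. It assumes fully connected weight matrices `W1` and `W2` stand in for the convolutional layers and leaves out batch normalization; `plain_unit` and `residual_unit` are hypothetical names for this illustration. With all weights at 0, the residual unit reproduces its (non-negative, post-ReLU) input, while a plain stack of the same layers collapses to 0.

```python
import numpy as np


def relu(z):
    return np.maximum(z, 0)


def plain_unit(x, W1, W2):
    """Two stacked layers without a shortcut: they must learn H(x) directly."""
    return relu(W2 @ relu(W1 @ x))


def residual_unit(x, W1, W2):
    """Same layers plus an identity shortcut: they only learn F(x) = H(x) - x."""
    F = W2 @ relu(W1 @ x)
    # the shortcut requires F(x) and x to have the same dimension
    return relu(F + x)


rng = np.random.RandomState(0)
# inputs coming out of a previous ReLU are non-negative
x = relu(rng.randn(4))

# suppose the identity mapping is optimal and the solver pushed the weights to 0
W1 = np.zeros((8, 4))
W2 = np.zeros((4, 8))

print(np.allclose(residual_unit(x, W1, W2), x))  # True: identity mapping
print(np.allclose(plain_unit(x, W1, W2), 0))     # True: the plain stack outputs 0
```

A plain stack would instead have to learn weights that make `relu(W2 @ relu(W1 @ x))` equal `x`, which is a much more specific target than "all zeros".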

Identity Block

Let's take a look at how we can implement this skip connection with Keras.

The identity block is one of the standard building blocks of ResNets.

We have a "main path" (the lower path) and a "shortcut path". The identity block here skips over three layers, each composed of a convolutional layer and a batch normalization layer followed by a ReLU activation. Right before the end of the third layer, i.e. before its ReLU activation, we add the shortcut path back to the main path. Since the shortcut is an unchanged copy of the input, this block only works when the input and the main path's output have the same shape.

In [3]:
from keras.initializers import glorot_uniform


def identity_block(input_tensor, kernel_size, filters, stage, block):
    """ 
    An identity block.

    Parameters
    ----------
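    input_tensor: Keras tensor
        Input tensor to the block, in channels-last format. Its number of
        channels must equal filters[2] so the shortcut can be added back.

    kernel_size: int
        The kernel size of the middle conv layer on the main path.

    filters: list[int]
        The number of filters of the 3 conv layers on the main path.

    stage: int
        Current stage label, used for generating layer names.

    block: str
        'a', 'b', ..., current block label, used for generating layer names.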

    Returns
    -------
    Output tensor for the block.
    """
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    
    # for batch normalization layer, we assume
    # the input data is in channel last format
    bn_axis = 3

    filters1, filters2, filters3 = filters
  
    # main path; note that the seed in kernel_initializer is only set here
    # for reproducibility, we technically don't need it
    x = layers.Conv2D(filters1, kernel_size=(1, 1), strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='valid', name=conv_name_base + '2a')(input_tensor)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters2, kernel_size, strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='same', name=conv_name_base + '2b')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters3, kernel_size=(1, 1), strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='valid', name=conv_name_base + '2c')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)

    # this line is the core component of resnet: the skip connection, i.e.
    # adding the shortcut to the main path before the final activation.
    # when adding the outputs of convolutional layers, the addition is
    # element-wise on their feature maps, i.e. channel by channel
    x = layers.add([x, input_tensor])
    x = layers.Activation('relu')(x)
    return x
In [4]:
# generate some fake data to work with
np.random.seed(0)
X = np.random.randn(3, 4, 4, 6)
print('original data shape:', X.shape)
  

stage = 1
block = 'a'
inputs = layers.Input(shape=X.shape[1:])

outputs = identity_block(inputs, kernel_size=2, filters=[2, 4, 6], stage=stage, block=block)

model = Model(inputs=inputs, outputs=outputs)
prediction = model.predict(X)
print('identity block output shape:', prediction.shape)
prediction[1, 1, 0]
original data shape: (3, 4, 4, 6)
identity block output shape: (3, 4, 4, 6)
Out[4]:
array([0.7089733 , 0.        , 1.1173227 , 1.5701073 , 0.8489223 ,
       0.05919211], dtype=float32)
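Because the shortcut is added element-wise, the identity block only works when `filters[-1]` equals the number of input channels (6 for `X` above). The NumPy sketch below uses random arrays as stand-ins for the feature maps and a hypothetical `filters=[2, 4, 8]` main path to show both the channel-by-channel addition and what goes wrong when the channels don't match. Keras's `add` layer likewise rejects mismatched shapes; its error message simply differs from NumPy's.

```python
import numpy as np

rng = np.random.RandomState(0)
block_input = rng.randn(3, 4, 4, 6)    # same shape as X: 6 channels
main_path_ok = rng.randn(3, 4, 4, 6)   # output of a main path with filters=[2, 4, 6]
main_path_bad = rng.randn(3, 4, 4, 8)  # hypothetical main path with filters=[2, 4, 8]

summed = main_path_ok + block_input
# element-wise addition: channel c of the sum only involves channel c of each input
print(np.allclose(summed[..., 2], main_path_ok[..., 2] + block_input[..., 2]))  # True

try:
    main_path_bad + block_input
except ValueError as err:
    # 8 channels cannot be added to 6 channels
    print('shape mismatch:', err)
```

This is exactly the situation the convolutional block below is designed to handle.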

Convolutional Block

The convolutional block is the other type of ResNet block. It is used when the input and output dimensions don't match up, e.g. because the number of channels changes or the height and width are downsampled. For example, to reduce the height and width of the activations by a factor of 2, we can use a $1 \times 1$ convolution with a stride of 2.
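The effect of the stride on height and width follows the usual output-size formula for 'valid' convolutions and pooling, $\lfloor (n + 2p - k) / s \rfloor + 1$, for input size $n$, padding $p$ on each side, kernel size $k$ and stride $s$. The helper below is a sketch written for this notebook, not a Keras function; its `padding` argument mirrors the explicit ZeroPadding2D layers used in the MNIST ResNet built later.

```python
def conv_output_size(n, kernel_size, stride=1, padding=0):
    """Spatial output size along one axis of a 'valid' convolution or pooling.

    padding is the number of zeros added on each side beforehand
    (e.g. by a ZeroPadding2D layer).
    """
    return (n + 2 * padding - kernel_size) // stride + 1


# the conv block demo: 1x1 conv with stride 2 on a 4x4 input
print(conv_output_size(4, kernel_size=1, stride=2))  # 2

# stem of the MNIST ResNet: conv1_pad + conv1, then pool1_pad + max pooling
n = conv_output_size(28, kernel_size=7, stride=2, padding=3)
print(n)  # 14
n = conv_output_size(n, kernel_size=3, stride=2, padding=1)
print(n)  # 7
# stage 3's conv block halves it again with a 1x1, stride-2 convolution
print(conv_output_size(n, kernel_size=1, stride=2))  # 4
```

These numbers agree with the output shapes reported by `model.summary()` for that network.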

The Conv2D layer on the shortcut path does not use any non-linear activation function. Its main role is to apply a (learned) linear function that changes the shape of the input (its number of channels and, with a stride greater than 1, its height and width) so that it matches the main path's output for the later addition step.
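A $1 \times 1$ convolution is exactly such a linear function applied independently at every pixel: each output channel is a weighted sum of the input channels at that location. The NumPy sketch below assumes 'valid' padding and leaves out the bias and the batch normalization that follow the shortcut conv in the Keras code; `conv1x1` is a hypothetical helper, not a Keras API.

```python
import numpy as np


def conv1x1(x, kernel, stride=1):
    """1x1 convolution with 'valid' padding and no bias.

    x: (batch, height, width, in_channels), channels-last.
    kernel: (in_channels, out_channels).
    """
    # a 1x1 kernel only sees one pixel, so a stride of s keeps every s-th row and column
    x = x[:, ::stride, ::stride, :]
    # the same (in_channels, out_channels) linear map is applied at every remaining pixel
    return x @ kernel


rng = np.random.RandomState(0)
x = rng.randn(3, 4, 4, 6)
kernel = rng.randn(6, 6)  # 6 -> 6 channels, matching filters3=6 in the conv block demo

shortcut = conv1x1(x, kernel, stride=2)
print(shortcut.shape)  # (3, 2, 2, 6)

# no activation: the shortcut is linear in its input
print(np.allclose(conv1x1(2 * x, kernel, stride=2), 2 * shortcut))  # True
```

Changing the kernel's second dimension (e.g. `(64, 256)`) is how the same operation changes the number of channels.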

In [5]:
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """ 
    A block that has a conv layer at shortcut.

    Parameters
    ----------
    input_tensor: Keras tensor
        Input tensor to the block, in channels-last format.

    kernel_size: int
        The kernel size of the middle conv layer on the main path.

    filters: list[int]
        The number of filters of the 3 conv layers on the main path.

    stage: int
        Current stage label, used for generating layer names.

    block: str
        'a', 'b', ..., current block label, used for generating layer names.

    strides: tuple, default (2, 2)
        Strides for the first conv layer on the main path and for the
        conv layer on the shortcut path.

    Returns
    -------
    Output tensor for the block.
    """
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    
    # for batch normalization layer, we assume
    # the input data is in channel last format,
    # which is the case if we are using the default
    # keras' backend tensorflow
    bn_axis = 3

    filters1, filters2, filters3 = filters
  
    # main path; note that the seed in kernel_initializer is only set here
    # for reproducibility, we technically don't need it
    x = layers.Conv2D(filters1, kernel_size=(1, 1), strides=strides,
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='valid', name=conv_name_base + '2a')(input_tensor)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters2, kernel_size, strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='same', name=conv_name_base + '2b')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
    x = layers.Activation('relu')(x)

    x = layers.Conv2D(filters3, kernel_size=(1, 1), strides=(1, 1),
                      kernel_initializer=glorot_uniform(seed=0),
                      padding='valid', name=conv_name_base + '2c')(x)
    x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
    
    # we resize the input so its dimension will match the output dimension
    # of the main path
    shortcut = layers.Conv2D(filters3, kernel_size=(1, 1), strides=strides,
                             kernel_initializer=glorot_uniform(seed=0),
                             padding='valid', name=conv_name_base + '1')(input_tensor)
    shortcut = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1')(shortcut) 

    # this line is the core component of resnet, the skip connection, i.e.
    # having a shortcut to the main path before the activation
    x = layers.add([x, shortcut])
    x = layers.Activation('relu')(x)
    return x
In [6]:
stage = 1
block = 'a'
inputs = layers.Input(shape=X.shape[1:])

outputs = conv_block(inputs, kernel_size=2, filters=[2, 4, 6], stage=stage, block=block)

model = Model(inputs=inputs, outputs=outputs)
prediction = model.predict(X)
print('conv block output shape:', prediction.shape)
prediction[1, 1, 0]
conv block output shape: (3, 2, 2, 6)
Out[6]:
array([0.        , 1.0165374 , 0.        , 0.        , 0.        ,
       0.48622143], dtype=float32)

ResNet In Action

Now that we have a basic understanding of how a ResNet is defined, we will build one and train it on the MNIST dataset.

In [7]:
from keras.datasets import mnist
from keras.utils import np_utils


(X_train, y_train), (X_test, y_test) = mnist.load_data()
print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
X_train shape: (60000, 28, 28)
60000 train samples
10000 test samples
In [8]:
n_classes = 10
img_rows, img_cols = 28, 28

# mnist images are grayscale, so the last dimension (the number of channels) is 1
X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = img_rows, img_cols, 1

X_train = X_train.astype('float32')
X_test  = X_test.astype('float32')

# pixel values range from 0 to 255, we normalize them to [0, 1]
# by dividing every value by 255
X_train /= 255
X_test /= 255
print('train shape:', X_train.shape)

# one-hot encode the class (target) vectors
Y_train = np_utils.to_categorical(y_train, n_classes)
Y_test = np_utils.to_categorical(y_test, n_classes)
print('y_train shape:', Y_train.shape)
train shape: (60000, 28, 28, 1)
y_train shape: (60000, 10)
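One-hot encoding places a single 1 at the class index of each row. As a rough NumPy stand-in for what `to_categorical` produces (not its actual implementation), indexing into an identity matrix gives the same 0/1 layout; `one_hot` is a hypothetical helper for this illustration.

```python
import numpy as np


def one_hot(labels, n_classes):
    """Row i of the identity matrix is the one-hot vector for class i."""
    return np.eye(n_classes, dtype='float32')[labels]


labels = np.array([5, 0, 4])
encoded = one_hot(labels, n_classes=10)
print(encoded.shape)  # (3, 10)
print(encoded[0])     # [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
# argmax recovers the original labels
print(encoded.argmax(axis=1))  # [5 0 4]
```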
In [9]:
def ResNet(input_shape, n_classes):
    """
    A truncated version of ResNet50 that only builds stages 2 and 3.
    
    References
    ----------
    https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py
    """
    img_input = layers.Input(shape=input_shape)
    
    bn_axis = 3
    
    x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(img_input)
    x = layers.Conv2D(64, (7, 7),
                      strides=(2, 2),
                      padding='valid',
                      kernel_initializer='he_normal',
                      name='conv1')(x)
    x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
    x = layers.Activation('relu')(x)
    x = layers.ZeroPadding2D(padding=(1, 1), name='pool1_pad')(x)
    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)

    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')
    
    # the commented out blocks are what's needed to build out the
    # full ResNet50 (a ResNet with 50 layers), we won't be needing
    # the complexity here
    # x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    # x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    # x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    # x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    # x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')
    
    x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    img_output = layers.Dense(n_classes, activation='softmax', name='fc' + str(n_classes))(x)

    model = Model(inputs=img_input, outputs=img_output, name='resnet')
    return model


model = ResNet(input_shape, n_classes)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 34, 34, 1)    0           input_3[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 14, 14, 64)   3200        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 14, 14, 64)   256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 14, 14, 64)   0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, 16, 16, 64)   0           activation_7[0][0]               
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 7, 7, 64)     0           pool1_pad[0][0]                  
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 7, 7, 64)     4160        max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 7, 7, 64)     256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 7, 7, 64)     0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 7, 7, 64)     36928       activation_8[0][0]               
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 7, 7, 64)     256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 7, 7, 64)     0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 7, 7, 256)    16640       activation_9[0][0]               
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 7, 7, 256)    16640       max_pooling2d_1[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 7, 7, 256)    1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 7, 7, 256)    1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_3 (Add)                     (None, 7, 7, 256)    0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_10 (Activation)      (None, 7, 7, 256)    0           add_3[0][0]                      
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 7, 7, 64)     16448       activation_10[0][0]              
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 7, 7, 64)     256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_11 (Activation)      (None, 7, 7, 64)     0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 7, 7, 64)     36928       activation_11[0][0]              
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 7, 7, 64)     256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_12 (Activation)      (None, 7, 7, 64)     0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 7, 7, 256)    16640       activation_12[0][0]              
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 7, 7, 256)    1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_4 (Add)                     (None, 7, 7, 256)    0           bn2b_branch2c[0][0]              
                                                                 activation_10[0][0]              
__________________________________________________________________________________________________
activation_13 (Activation)      (None, 7, 7, 256)    0           add_4[0][0]                      
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 7, 7, 64)     16448       activation_13[0][0]              
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 7, 7, 64)     256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_14 (Activation)      (None, 7, 7, 64)     0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 7, 7, 64)     36928       activation_14[0][0]              
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 7, 7, 64)     256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_15 (Activation)      (None, 7, 7, 64)     0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 7, 7, 256)    16640       activation_15[0][0]              
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 7, 7, 256)    1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_5 (Add)                     (None, 7, 7, 256)    0           bn2c_branch2c[0][0]              
                                                                 activation_13[0][0]              
__________________________________________________________________________________________________
activation_16 (Activation)      (None, 7, 7, 256)    0           add_5[0][0]                      
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 4, 4, 128)    32896       activation_16[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 4, 4, 128)    512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_17 (Activation)      (None, 4, 4, 128)    0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 4, 4, 128)    147584      activation_17[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 4, 4, 128)    512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_18 (Activation)      (None, 4, 4, 128)    0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 4, 4, 512)    66048       activation_18[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 4, 4, 512)    131584      activation_16[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 4, 4, 512)    2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 4, 4, 512)    2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_6 (Add)                     (None, 4, 4, 512)    0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_19 (Activation)      (None, 4, 4, 512)    0           add_6[0][0]                      
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 4, 4, 128)    65664       activation_19[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 4, 4, 128)    512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_20 (Activation)      (None, 4, 4, 128)    0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 4, 4, 128)    147584      activation_20[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 4, 4, 128)    512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_21 (Activation)      (None, 4, 4, 128)    0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 4, 4, 512)    66048       activation_21[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 4, 4, 512)    2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_7 (Add)                     (None, 4, 4, 512)    0           bn3b_branch2c[0][0]              
                                                                 activation_19[0][0]              
__________________________________________________________________________________________________
activation_22 (Activation)      (None, 4, 4, 512)    0           add_7[0][0]                      
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 4, 4, 128)    65664       activation_22[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 4, 4, 128)    512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_23 (Activation)      (None, 4, 4, 128)    0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 4, 4, 128)    147584      activation_23[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 4, 4, 128)    512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_24 (Activation)      (None, 4, 4, 128)    0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 4, 4, 512)    66048       activation_24[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 4, 4, 512)    2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_8 (Add)                     (None, 4, 4, 512)    0           bn3c_branch2c[0][0]              
                                                                 activation_22[0][0]              
__________________________________________________________________________________________________
activation_25 (Activation)      (None, 4, 4, 512)    0           add_8[0][0]                      
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 4, 4, 128)    65664       activation_25[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 4, 4, 128)    512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_26 (Activation)      (None, 4, 4, 128)    0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 4, 4, 128)    147584      activation_26[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 4, 4, 128)    512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_27 (Activation)      (None, 4, 4, 128)    0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 4, 4, 512)    66048       activation_27[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 4, 4, 512)    2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_9 (Add)                     (None, 4, 4, 512)    0           bn3d_branch2c[0][0]              
                                                                 activation_25[0][0]              
__________________________________________________________________________________________________
activation_28 (Activation)      (None, 4, 4, 512)    0           add_9[0][0]                      
__________________________________________________________________________________________________
avg_pool (GlobalAveragePooling2 (None, 512)          0           activation_28[0][0]              
__________________________________________________________________________________________________
fc10 (Dense)                    (None, 10)           5130        avg_pool[0][0]                   
==================================================================================================
Total params: 1,458,954
Trainable params: 1,448,842
Non-trainable params: 10,112
__________________________________________________________________________________________________
In [10]:
history = model.fit(X_train, Y_train, epochs=3, batch_size=32)
Epoch 1/3
60000/60000 [==============================] - 168s 3ms/step - loss: 0.1413 - acc: 0.9580
Epoch 2/3
60000/60000 [==============================] - 149s 2ms/step - loss: 0.0635 - acc: 0.9816
Epoch 3/3
60000/60000 [==============================] - 149s 2ms/step - loss: 0.0489 - acc: 0.9859
In [11]:
loss, accuracy = model.evaluate(X_test, Y_test)
print('Loss = ' + str(loss))
print('Test Accuracy = ' + str(accuracy))
10000/10000 [==============================] - 5s 536us/step
Loss = 0.044752538148895835
Test Accuracy = 0.9876

Takeaways:

  • Very deep "plain" networks don't work in practice because they are hard to train due to vanishing gradients.
  • Residual blocks, which wrap a skip connection (shortcut) around a few layers, aim to address this vanishing gradient issue by making it easier for a network to learn an identity function.
  • There are two main types of blocks: the identity block and the convolutional block. Very deep Residual Networks are built by stacking these blocks together. The convolutional block is used when the input and output dimensions differ: its shortcut applies a convolution so that the shortcut's dimensions match the main path's, allowing the two paths to be added element-wise.
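To make the difference between the two block types concrete, here is a minimal NumPy sketch. It is a deliberate simplification and an assumption on our part, not the Keras implementation used above: every convolution is a 1x1 convolution (written as a matrix product over the channel axis), there is no batch normalization, and the middle 3x3 convolution of the real bottleneck block is omitted. The helper names `conv1x1`, `identity_block` and `convolutional_block` are hypothetical.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def conv1x1(x, weights, stride=1):
    """1x1 convolution on a channels-last tensor (h, w, c_in) -> (h', w', c_out).

    A 1x1 convolution is a per-pixel linear map over channels, so it reduces to a
    matrix product; a stride > 1 simply subsamples the spatial grid first.
    """
    return x[::stride, ::stride, :] @ weights


def identity_block(x, w1, w2):
    """Main path keeps the input shape, so the shortcut is x itself."""
    main = conv1x1(relu(conv1x1(x, w1)), w2)
    return relu(main + x)


def convolutional_block(x, w1, w2, w_shortcut, stride=2):
    """Main path changes the shape; a strided 1x1 projection on the shortcut matches it."""
    main = conv1x1(relu(conv1x1(x, w1, stride)), w2)
    shortcut = conv1x1(x, w_shortcut, stride)
    return relu(main + shortcut)


rng = np.random.RandomState(0)
x = rng.randn(8, 8, 16)

# identity block: 16 -> 4 -> 16 channels, spatial size unchanged
out_id = identity_block(x, rng.randn(16, 4), rng.randn(4, 16))
print(out_id.shape)  # (8, 8, 16)

# convolutional block: 16 -> 8 -> 32 channels, spatial size halved
out_conv = convolutional_block(x, rng.randn(16, 8), rng.randn(8, 32), rng.randn(16, 32))
print(out_conv.shape)  # (4, 4, 32)

# with all-zero main-path weights the identity block outputs relu(x): the block
# only has to push its weights towards zero to (almost) learn the identity mapping
zero = identity_block(x, np.zeros((16, 4)), np.zeros((4, 16)))
print(np.allclose(zero, relu(x)))  # True
```

Since the input to an identity block is usually itself the output of a ReLU (and therefore non-negative), `relu(x)` equals `x` in that setting, so the block reduces exactly to the identity.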

CAM (Class Activation Map)

When working with machine learning models, two common questions we would like to address are: 1. how to prevent overfitting, and 2. how to explain why the model generated a given prediction.

In the context of convnets, one way to minimize the chance of overfitting is to use Global Average Pooling (GAP). Similar to max pooling layers, GAP layers are used to reduce the spatial dimensions. However, GAP layers perform a more extreme type of dimensionality reduction, where a tensor with dimensions $h \times w \times d$ is reduced to size $1 \times 1 \times d$, i.e. GAP layers reduce each $h \times w$ feature map to a single number by taking the average of all $hw$ values. Because GAP itself has no trainable parameters, using it in place of large fully connected layers cuts down the number of parameters the model has to learn. This operation is depicted in the following diagram.
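As a small numerical illustration (a sketch, with a toy tensor made up for this purpose), GAP is just a mean over the two spatial axes of a channels-last tensor. It is the same reduction the `avg_pool` layer in the model summary above performs when it turns a `(4, 4, 512)` tensor into a `(512,)` vector.

```python
import numpy as np


def global_average_pooling(feature_maps):
    """Reduce an (h, w, d) tensor to a length-d vector by averaging each h x w map."""
    return feature_maps.mean(axis=(0, 1))


# toy 2 x 2 x 3 tensor: channel k holds the values 4k + [0, 1, 2, 3]
feature_maps = np.arange(12, dtype=float).reshape(3, 2, 2).transpose(1, 2, 0)
pooled = global_average_pooling(feature_maps)
print(pooled)  # [1.5 5.5 9.5]
```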

The other interesting property of GAP is that its advantages extend beyond acting as a regularizer. With a little tweaking, it allows us to identify which regions of an image the network relied on when making its prediction for the predicted class. That is, it tells us not only what object is contained in the image, but also where the object is in the image. The localization is presented as a heat map (referred to as a class activation map in the original paper), where the color-coding scheme identifies regions that are relatively important for the network to perform the object identification task. The following snapshot shows this localization on some sample images:
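The "little tweaking" boils down to one weighted sum: if the last convolutional layer produces feature maps $f_k(x, y)$ and the dense layer after GAP has weight $w_k^c$ connecting channel $k$ to class $c$, the class activation map is $M_c(x, y) = \sum_k w_k^c f_k(x, y)$. The NumPy sketch below uses made-up sizes and random weights (all hypothetical) to show the computation. It also shows why the map is meaningful: because GAP and the dense layer are both linear, averaging $M_c$ over all positions recovers the class score (before softmax, ignoring the bias). The map is upsampled with simple nearest-neighbour repetition for brevity; smoother interpolation, e.g. `cv2.resize`, is the more common choice when overlaying it on an image.

```python
import numpy as np


def class_activation_map(feature_maps, dense_weights, class_idx):
    """CAM for one class: weighted sum of the last conv layer's feature maps.

    feature_maps:  (h, w, d) output of the last convolutional layer
    dense_weights: (d, n_classes) weights of the dense layer that follows GAP
    """
    return feature_maps @ dense_weights[:, class_idx]


def upsample_nearest(cam, factor):
    """Nearest-neighbour upsampling so the map can be overlaid on the input image."""
    return np.repeat(np.repeat(cam, factor, axis=0), factor, axis=1)


rng = np.random.RandomState(42)
h, w, d, n_classes = 7, 7, 16, 5  # hypothetical sizes
feature_maps = np.maximum(rng.randn(h, w, d), 0.0)  # post-ReLU activations
dense_weights = rng.randn(d, n_classes)

# class scores the network would compute: GAP followed by the dense layer (bias ignored)
scores = feature_maps.mean(axis=(0, 1)) @ dense_weights
predicted = int(np.argmax(scores))

cam = class_activation_map(feature_maps, dense_weights, predicted)
# averaging the CAM recovers the predicted class's score,
# so the map shows where in the image that score comes from
print(np.isclose(cam.mean(), scores[predicted]))  # True

heatmap = upsample_nearest(cam, 32)  # 7 x 7 -> 224 x 224
# rescale to [0, 1] for display as a heat map
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())
print(heatmap.shape)  # (224, 224)
```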

This approach of highlighting which regions of an image are important for the classification provides another way of interpreting the inner workings of our convnets.

We will leverage the pre-trained ResNet50 model from Keras to see CAM in action.

In [12]:
from keras.applications.resnet50 import ResNet50, preprocess_input


model = ResNet50(weights='imagenet')
model.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_4 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_4[0][0]                    
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
__________________________________________________________________________________________________
activation_29 (Activation)      (None, 112, 112, 64) 0           bn_conv1[0][0]                   
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 55, 55, 64)   0           activation_29[0][0]              
__________________________________________________________________________________________________
res2a_branch2a (Conv2D)         (None, 55, 55, 64)   4160        max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_30 (Activation)      (None, 55, 55, 64)   0           bn2a_branch2a[0][0]              
__________________________________________________________________________________________________
res2a_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_30[0][0]              
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_31 (Activation)      (None, 55, 55, 64)   0           bn2a_branch2b[0][0]              
__________________________________________________________________________________________________
res2a_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_31[0][0]              
__________________________________________________________________________________________________
res2a_branch1 (Conv2D)          (None, 55, 55, 256)  16640       max_pooling2d_2[0][0]            
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2a_branch2c[0][0]             
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256)  1024        res2a_branch1[0][0]              
__________________________________________________________________________________________________
add_10 (Add)                    (None, 55, 55, 256)  0           bn2a_branch2c[0][0]              
                                                                 bn2a_branch1[0][0]               
__________________________________________________________________________________________________
activation_32 (Activation)      (None, 55, 55, 256)  0           add_10[0][0]                     
__________________________________________________________________________________________________
res2b_branch2a (Conv2D)         (None, 55, 55, 64)   16448       activation_32[0][0]              
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_33 (Activation)      (None, 55, 55, 64)   0           bn2b_branch2a[0][0]              
__________________________________________________________________________________________________
res2b_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_33[0][0]              
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_34 (Activation)      (None, 55, 55, 64)   0           bn2b_branch2b[0][0]              
__________________________________________________________________________________________________
res2b_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_34[0][0]              
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2b_branch2c[0][0]             
__________________________________________________________________________________________________
add_11 (Add)                    (None, 55, 55, 256)  0           bn2b_branch2c[0][0]              
                                                                 activation_32[0][0]              
__________________________________________________________________________________________________
activation_35 (Activation)      (None, 55, 55, 256)  0           add_11[0][0]                     
__________________________________________________________________________________________________
res2c_branch2a (Conv2D)         (None, 55, 55, 64)   16448       activation_35[0][0]              
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_36 (Activation)      (None, 55, 55, 64)   0           bn2c_branch2a[0][0]              
__________________________________________________________________________________________________
res2c_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_36[0][0]              
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_37 (Activation)      (None, 55, 55, 64)   0           bn2c_branch2b[0][0]              
__________________________________________________________________________________________________
res2c_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_37[0][0]              
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2c_branch2c[0][0]             
__________________________________________________________________________________________________
add_12 (Add)                    (None, 55, 55, 256)  0           bn2c_branch2c[0][0]              
                                                                 activation_35[0][0]              
__________________________________________________________________________________________________
activation_38 (Activation)      (None, 55, 55, 256)  0           add_12[0][0]                     
__________________________________________________________________________________________________
res3a_branch2a (Conv2D)         (None, 28, 28, 128)  32896       activation_38[0][0]              
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_39 (Activation)      (None, 28, 28, 128)  0           bn3a_branch2a[0][0]              
__________________________________________________________________________________________________
res3a_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_39[0][0]              
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_40 (Activation)      (None, 28, 28, 128)  0           bn3a_branch2b[0][0]              
__________________________________________________________________________________________________
res3a_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_40[0][0]              
__________________________________________________________________________________________________
res3a_branch1 (Conv2D)          (None, 28, 28, 512)  131584      activation_38[0][0]              
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3a_branch2c[0][0]             
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512)  2048        res3a_branch1[0][0]              
__________________________________________________________________________________________________
add_13 (Add)                    (None, 28, 28, 512)  0           bn3a_branch2c[0][0]              
                                                                 bn3a_branch1[0][0]               
__________________________________________________________________________________________________
activation_41 (Activation)      (None, 28, 28, 512)  0           add_13[0][0]                     
__________________________________________________________________________________________________
res3b_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_41[0][0]              
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_42 (Activation)      (None, 28, 28, 128)  0           bn3b_branch2a[0][0]              
__________________________________________________________________________________________________
res3b_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_42[0][0]              
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_43 (Activation)      (None, 28, 28, 128)  0           bn3b_branch2b[0][0]              
__________________________________________________________________________________________________
res3b_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_43[0][0]              
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3b_branch2c[0][0]             
__________________________________________________________________________________________________
add_14 (Add)                    (None, 28, 28, 512)  0           bn3b_branch2c[0][0]              
                                                                 activation_41[0][0]              
__________________________________________________________________________________________________
activation_44 (Activation)      (None, 28, 28, 512)  0           add_14[0][0]                     
__________________________________________________________________________________________________
res3c_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_44[0][0]              
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_45 (Activation)      (None, 28, 28, 128)  0           bn3c_branch2a[0][0]              
__________________________________________________________________________________________________
res3c_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_45[0][0]              
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_46 (Activation)      (None, 28, 28, 128)  0           bn3c_branch2b[0][0]              
__________________________________________________________________________________________________
res3c_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_46[0][0]              
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3c_branch2c[0][0]             
__________________________________________________________________________________________________
add_15 (Add)                    (None, 28, 28, 512)  0           bn3c_branch2c[0][0]              
                                                                 activation_44[0][0]              
__________________________________________________________________________________________________
activation_47 (Activation)      (None, 28, 28, 512)  0           add_15[0][0]                     
__________________________________________________________________________________________________
res3d_branch2a (Conv2D)         (None, 28, 28, 128)  65664       activation_47[0][0]              
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_48 (Activation)      (None, 28, 28, 128)  0           bn3d_branch2a[0][0]              
__________________________________________________________________________________________________
res3d_branch2b (Conv2D)         (None, 28, 28, 128)  147584      activation_48[0][0]              
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128)  512         res3d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_49 (Activation)      (None, 28, 28, 128)  0           bn3d_branch2b[0][0]              
__________________________________________________________________________________________________
res3d_branch2c (Conv2D)         (None, 28, 28, 512)  66048       activation_49[0][0]              
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512)  2048        res3d_branch2c[0][0]             
__________________________________________________________________________________________________
add_16 (Add)                    (None, 28, 28, 512)  0           bn3d_branch2c[0][0]              
                                                                 activation_47[0][0]              
__________________________________________________________________________________________________
activation_50 (Activation)      (None, 28, 28, 512)  0           add_16[0][0]                     
__________________________________________________________________________________________________
res4a_branch2a (Conv2D)         (None, 14, 14, 256)  131328      activation_50[0][0]              
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_51 (Activation)      (None, 14, 14, 256)  0           bn4a_branch2a[0][0]              
__________________________________________________________________________________________________
res4a_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_51[0][0]              
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_52 (Activation)      (None, 14, 14, 256)  0           bn4a_branch2b[0][0]              
__________________________________________________________________________________________________
res4a_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_52[0][0]              
__________________________________________________________________________________________________
res4a_branch1 (Conv2D)          (None, 14, 14, 1024) 525312      activation_50[0][0]              
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4a_branch2c[0][0]             
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096        res4a_branch1[0][0]              
__________________________________________________________________________________________________
add_17 (Add)                    (None, 14, 14, 1024) 0           bn4a_branch2c[0][0]              
                                                                 bn4a_branch1[0][0]               
__________________________________________________________________________________________________
activation_53 (Activation)      (None, 14, 14, 1024) 0           add_17[0][0]                     
__________________________________________________________________________________________________
res4b_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_53[0][0]              
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_54 (Activation)      (None, 14, 14, 256)  0           bn4b_branch2a[0][0]              
__________________________________________________________________________________________________
res4b_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_54[0][0]              
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_55 (Activation)      (None, 14, 14, 256)  0           bn4b_branch2b[0][0]              
__________________________________________________________________________________________________
res4b_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_55[0][0]              
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4b_branch2c[0][0]             
__________________________________________________________________________________________________
add_18 (Add)                    (None, 14, 14, 1024) 0           bn4b_branch2c[0][0]              
                                                                 activation_53[0][0]              
__________________________________________________________________________________________________
activation_56 (Activation)      (None, 14, 14, 1024) 0           add_18[0][0]                     
__________________________________________________________________________________________________
res4c_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_56[0][0]              
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_57 (Activation)      (None, 14, 14, 256)  0           bn4c_branch2a[0][0]              
__________________________________________________________________________________________________
res4c_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_57[0][0]              
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_58 (Activation)      (None, 14, 14, 256)  0           bn4c_branch2b[0][0]              
__________________________________________________________________________________________________
res4c_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_58[0][0]              
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4c_branch2c[0][0]             
__________________________________________________________________________________________________
add_19 (Add)                    (None, 14, 14, 1024) 0           bn4c_branch2c[0][0]              
                                                                 activation_56[0][0]              
__________________________________________________________________________________________________
activation_59 (Activation)      (None, 14, 14, 1024) 0           add_19[0][0]                     
__________________________________________________________________________________________________
res4d_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_59[0][0]              
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2a[0][0]             
__________________________________________________________________________________________________
activation_60 (Activation)      (None, 14, 14, 256)  0           bn4d_branch2a[0][0]              
__________________________________________________________________________________________________
res4d_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_60[0][0]              
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4d_branch2b[0][0]             
__________________________________________________________________________________________________
activation_61 (Activation)      (None, 14, 14, 256)  0           bn4d_branch2b[0][0]              
__________________________________________________________________________________________________
res4d_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_61[0][0]              
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4d_branch2c[0][0]             
__________________________________________________________________________________________________
add_20 (Add)                    (None, 14, 14, 1024) 0           bn4d_branch2c[0][0]              
                                                                 activation_59[0][0]              
__________________________________________________________________________________________________
activation_62 (Activation)      (None, 14, 14, 1024) 0           add_20[0][0]                     
__________________________________________________________________________________________________
res4e_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_62[0][0]              
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2a[0][0]             
__________________________________________________________________________________________________
activation_63 (Activation)      (None, 14, 14, 256)  0           bn4e_branch2a[0][0]              
__________________________________________________________________________________________________
res4e_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_63[0][0]              
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4e_branch2b[0][0]             
__________________________________________________________________________________________________
activation_64 (Activation)      (None, 14, 14, 256)  0           bn4e_branch2b[0][0]              
__________________________________________________________________________________________________
res4e_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_64[0][0]              
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4e_branch2c[0][0]             
__________________________________________________________________________________________________
add_21 (Add)                    (None, 14, 14, 1024) 0           bn4e_branch2c[0][0]              
                                                                 activation_62[0][0]              
__________________________________________________________________________________________________
activation_65 (Activation)      (None, 14, 14, 1024) 0           add_21[0][0]                     
__________________________________________________________________________________________________
res4f_branch2a (Conv2D)         (None, 14, 14, 256)  262400      activation_65[0][0]              
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2a[0][0]             
__________________________________________________________________________________________________
activation_66 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2a[0][0]              
__________________________________________________________________________________________________
res4f_branch2b (Conv2D)         (None, 14, 14, 256)  590080      activation_66[0][0]              
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256)  1024        res4f_branch2b[0][0]             
__________________________________________________________________________________________________
activation_67 (Activation)      (None, 14, 14, 256)  0           bn4f_branch2b[0][0]              
__________________________________________________________________________________________________
res4f_branch2c (Conv2D)         (None, 14, 14, 1024) 263168      activation_67[0][0]              
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096        res4f_branch2c[0][0]             
__________________________________________________________________________________________________
add_22 (Add)                    (None, 14, 14, 1024) 0           bn4f_branch2c[0][0]              
                                                                 activation_65[0][0]              
__________________________________________________________________________________________________
activation_68 (Activation)      (None, 14, 14, 1024) 0           add_22[0][0]                     
__________________________________________________________________________________________________
res5a_branch2a (Conv2D)         (None, 7, 7, 512)    524800      activation_68[0][0]              
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2a[0][0]             
__________________________________________________________________________________________________
activation_69 (Activation)      (None, 7, 7, 512)    0           bn5a_branch2a[0][0]              
__________________________________________________________________________________________________
res5a_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_69[0][0]              
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5a_branch2b[0][0]             
__________________________________________________________________________________________________
activation_70 (Activation)      (None, 7, 7, 512)    0           bn5a_branch2b[0][0]              
__________________________________________________________________________________________________
res5a_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_70[0][0]              
__________________________________________________________________________________________________
res5a_branch1 (Conv2D)          (None, 7, 7, 2048)   2099200     activation_68[0][0]              
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5a_branch2c[0][0]             
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048)   8192        res5a_branch1[0][0]              
__________________________________________________________________________________________________
add_23 (Add)                    (None, 7, 7, 2048)   0           bn5a_branch2c[0][0]              
                                                                 bn5a_branch1[0][0]               
__________________________________________________________________________________________________
activation_71 (Activation)      (None, 7, 7, 2048)   0           add_23[0][0]                     
__________________________________________________________________________________________________
res5b_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_71[0][0]              
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2a[0][0]             
__________________________________________________________________________________________________
activation_72 (Activation)      (None, 7, 7, 512)    0           bn5b_branch2a[0][0]              
__________________________________________________________________________________________________
res5b_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_72[0][0]              
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5b_branch2b[0][0]             
__________________________________________________________________________________________________
activation_73 (Activation)      (None, 7, 7, 512)    0           bn5b_branch2b[0][0]              
__________________________________________________________________________________________________
res5b_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_73[0][0]              
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5b_branch2c[0][0]             
__________________________________________________________________________________________________
add_24 (Add)                    (None, 7, 7, 2048)   0           bn5b_branch2c[0][0]              
                                                                 activation_71[0][0]              
__________________________________________________________________________________________________
activation_74 (Activation)      (None, 7, 7, 2048)   0           add_24[0][0]                     
__________________________________________________________________________________________________
res5c_branch2a (Conv2D)         (None, 7, 7, 512)    1049088     activation_74[0][0]              
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2a[0][0]             
__________________________________________________________________________________________________
activation_75 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2a[0][0]              
__________________________________________________________________________________________________
res5c_branch2b (Conv2D)         (None, 7, 7, 512)    2359808     activation_75[0][0]              
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512)    2048        res5c_branch2b[0][0]             
__________________________________________________________________________________________________
activation_76 (Activation)      (None, 7, 7, 512)    0           bn5c_branch2b[0][0]              
__________________________________________________________________________________________________
res5c_branch2c (Conv2D)         (None, 7, 7, 2048)   1050624     activation_76[0][0]              
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048)   8192        res5c_branch2c[0][0]             
__________________________________________________________________________________________________
add_25 (Add)                    (None, 7, 7, 2048)   0           bn5c_branch2c[0][0]              
                                                                 activation_74[0][0]              
__________________________________________________________________________________________________
activation_77 (Activation)      (None, 7, 7, 2048)   0           add_25[0][0]                     
__________________________________________________________________________________________________
avg_pool (AveragePooling2D)     (None, 1, 1, 2048)   0           activation_77[0][0]              
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 2048)         0           avg_pool[0][0]                   
__________________________________________________________________________________________________
fc1000 (Dense)                  (None, 1000)         2049000     flatten_1[0][0]                  
==================================================================================================
Total params: 25,636,712
Trainable params: 25,583,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
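The `Param #` column in the summary can be checked by hand. A Conv2D layer stores `kernel_size × kernel_size × in_channels × out_channels` weights plus one bias per filter. A BatchNormalization layer stores four values per channel: gamma and beta, which are trainable, and the moving mean and variance, which are not. A Dense layer stores its weight matrix plus one bias per unit. The small helpers below are our own, written only for this arithmetic. They reproduce a few rows of the table above:

```python
def conv2d_params(kernel_size, in_channels, out_channels, use_bias=True):
    """Weights of a square Conv2D kernel plus one bias per output filter."""
    weights = kernel_size * kernel_size * in_channels * out_channels
    return weights + (out_channels if use_bias else 0)


def batchnorm_params(channels):
    """gamma and beta (trainable) plus moving mean and variance (non-trainable)."""
    return 4 * channels


def dense_params(in_features, out_features):
    """Weight matrix plus one bias per output unit."""
    return in_features * out_features + out_features


# layer shapes copied from the summary above
print(conv2d_params(1, 1024, 256))   # res4e_branch2a -> 262400
print(conv2d_params(3, 256, 256))    # res4e_branch2b -> 590080
print(conv2d_params(1, 256, 1024))   # res4e_branch2c -> 263168
print(conv2d_params(1, 1024, 2048))  # res5a_branch1  -> 2099200
print(conv2d_params(3, 512, 512))    # res5a_branch2b -> 2359808
print(batchnorm_params(2048))        # bn5a_branch2c  -> 8192
print(dense_params(2048, 1000))      # fc1000         -> 2049000
```

The moving statistics of the BatchNormalization layers account for the "Non-trainable params" line. They are updated during training but are never touched by gradient descent.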

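We chose ResNet50 because its top layers are a GAP layer immediately followed by a classifier. The `avg_pool` layer averages each 7×7 feature map coming out of `activation_77` down to a single value, which makes it a global average pooling (GAP) layer. After a `Flatten`, the pooled vector goes straight into `fc1000`, a fully connected layer with a softmax activation that predicts the class of the input image. As we will soon see, this is essentially what CAM requires.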
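Why does "GAP followed by a dense layer" suit CAM? Both operations are linear, so the class score before the softmax equals the spatial average of a weighted sum of the feature maps. That same weighted sum, kept at full spatial resolution, is the class activation map. The NumPy sketch below checks this identity on random stand-in values. The shapes are shrunk for illustration; in ResNet50 they would be 7×7×2048 for the feature maps and 2048×1000 for the `fc1000` weights.

```python
import numpy as np

rng = np.random.default_rng(0)

# random stand-ins for the last activation block (H x W x K) and the
# dense layer's weights (K x n_classes) and biases; shapes are shrunk
height, width, n_channels, n_classes = 7, 7, 16, 5
feature_maps = rng.normal(size=(height, width, n_channels))
dense_weights = rng.normal(size=(n_channels, n_classes))
dense_bias = rng.normal(size=n_classes)

# forward pass of the head: global average pooling, then the dense layer
pooled = feature_maps.mean(axis=(0, 1))          # shape (K,)
logits = pooled @ dense_weights + dense_bias     # shape (n_classes,)

# class activation map for the top class: weight each channel's map by
# that class's dense weights and sum over the channel axis
target_class = int(np.argmax(logits))
cam = feature_maps @ dense_weights[:, target_class]  # shape (H, W)

# averaging the CAM over space recovers the class logit (minus the bias)
print(cam.shape)  # (7, 7)
print(np.allclose(cam.mean() + dense_bias[target_class], logits[target_class]))  # True
```

Because of this identity, each cell of `cam` can be read as that location's contribution to the class score. This is what lets CAM point at the image regions behind a prediction.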

Given this pre-trained model, we will load some images to test it out. Note that for CAM to work, the model should be reasonably accurate to begin with. We can't expect a model that fails to recognize the dog in an image to show us the region it used to decide that the image contains a dog.
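One simple way to act on this caveat is to look at the model's top prediction before computing any CAM, and to skip images where the model is unsure. The helper below is hypothetical; neither the notebook nor Keras defines it. The 0.5 threshold is an arbitrary assumption. With Keras, `probabilities` would be one row of `model.predict(...)`, e.g. `preds[0]`.

```python
import numpy as np


def confident_top_class(probabilities, min_prob=0.5):
    """Hypothetical helper: return (class_index, probability) for the top
    prediction, or None when its probability is below min_prob."""
    probabilities = np.asarray(probabilities)
    top = int(np.argmax(probabilities))
    if probabilities[top] < min_prob:
        return None
    return top, float(probabilities[top])


# softmax outputs for two hypothetical images
print(confident_top_class([0.05, 0.85, 0.10]))  # (1, 0.85)
print(confident_top_class([0.40, 0.35, 0.25]))  # None
```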

In [13]:
# change default figure and font size
plt.rcParams['figure.figsize'] = 8, 6
plt.rcParams['font.size'] = 12

# visualize the image that we'll be working with
img_dir = 'images'
img_path = os.path.join(img_dir, 'bmw.png')
img = image.load_img(img_path)
plt.imshow(img)
plt.show()
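The cell above only displays the image. To feed it to ResNet50, it would normally be loaded at the network's 224×224 input size (`image.load_img(img_path, target_size=(224, 224))`), converted to an array, given a batch dimension, and normalized. Keras provides `keras.applications.resnet50.preprocess_input` for the normalization step. As a sketch of what that function does in its default "caffe" mode, the NumPy-only version below flips RGB to BGR and subtracts the per-channel ImageNet means, without any scaling. It runs on a random stand-in image rather than `bmw.png`, and the function name is our own.

```python
import numpy as np

# per-channel ImageNet means in BGR order, as used by the "caffe" style
# preprocessing that Keras applies for ResNet50
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68])


def preprocess_caffe_style(rgb_image):
    """RGB image (H, W, 3) -> float batch (1, H, W, 3), BGR, mean-centered."""
    x = np.asarray(rgb_image, dtype=np.float32)
    x = x[..., ::-1]                   # RGB -> BGR
    x = x - IMAGENET_BGR_MEAN          # zero-center each channel, no scaling
    return np.expand_dims(x, axis=0)   # add the batch dimension


# random stand-in for a 224 x 224 RGB image
rng = np.random.default_rng(0)
fake_img = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
batch = preprocess_caffe_style(fake_img)
print(batch.shape)  # (1, 224, 224, 3)
```

In the notebook itself, calling the Keras function is preferable, since it stays in sync with whatever preprocessing the pre-trained weights expect.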