CNN-Example(R)


Deep Learning Using R with keras (CNN)

In this notebook, we walk through how to use the keras R package on a toy deep learning example with the handwritten digits image dataset (i.e. MNIST). The purpose of the notebook is to gain hands-on experience and become familiar with the Convolutional Neural Network part of the training course. Please check the keras R package website for the most recent developments: https://keras.rstudio.com/

We are using the Databricks Community Edition with R as the interface for this deep learning training, which is aimed at an audience with a statistics background, for the following reasons:

  • Minimal language barrier in coding for most statisticians
  • Zero setup, which saves time, by using a cloud environment
  • Familiarity with the current trend of cloud computing in corporate settings

Last updated and tested: 2021-04-24, running on an 8.1 ML cluster instance in a Community Edition account.

1 Packages Download and Installation

In this example notebook, we need the keras R package. Because it has many dependent packages to download and install, it takes a few minutes to finish. Be patient! In a production cloud environment such as the paid version of Databricks, you can save your setup and resume where you left off. But in the free Community Edition of Databricks, we have to download and install the needed packages every time a new cluster is created or after every login.

1.1 Download keras

The keras package was published to CRAN on 2018-04-29, so we can get it from CRAN by calling install.packages("keras"). Because it is still under fast development, we can also install it directly from GitHub to get the most recent updates, which might not have been pushed to CRAN yet. Again, be patient: the following cell may take a few minutes to install all dependencies.

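## Download and install the keras package only if it is not already installed
packages_installed <- rownames(installed.packages())
if ("keras" %in% packages_installed) {
  print("keras package is already installed, skip the installation step.")
} else {
  # Development version from GitHub (requires devtools);
  # install.packages("keras") would install the CRAN release instead
  devtools::install_github("rstudio/keras")
}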
[1] "keras package is already installed, skip the installation step."

1.2 Load the keras package and the required TensorFlow backend

Because keras is only an interface to popular deep learning frameworks, we also need to install a specific deep learning backend. The default and recommended backend is TensorFlow. The following cell takes around one minute to run. The keras RStudio home page recommends calling install_keras() after library(keras). With an earlier runtime version (as of 05/29/2020), including install_keras() produced an error message, so that line was commented out at the time; in the run shown below, install_keras() completes without error.

library(keras)
install_keras()
Requirement already up-to-date: tensorflow==2.6.* in /databricks/python3/lib/python3.8/site-packages (2.6.0)
Requirement already up-to-date: tensorflow-hub in /databricks/python3/lib/python3.8/site-packages (0.12.0)
Requirement already up-to-date: scipy in /databricks/python3/lib/python3.8/site-packages (1.7.1)
Requirement already up-to-date: requests in /databricks/python3/lib/python3.8/site-packages (2.26.0)
Requirement already up-to-date: Pillow in /databricks/python3/lib/python3.8/site-packages (8.3.1)
Requirement already up-to-date: h5py in /databricks/python3/lib/python3.8/site-packages (3.3.0)
Requirement already up-to-date: pandas in /databricks/python3/lib/python3.8/site-packages (1.3.2)
Requirement already satisfied, skipping upgrade: flatbuffers~=1.12.0 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (1.12)
Requirement already satisfied, skipping upgrade: gast==0.4.0 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (0.4.0)
Requirement already satisfied, skipping upgrade: opt-einsum~=3.3.0 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (3.3.0)
Requirement already satisfied, skipping upgrade: termcolor~=1.1.0 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (1.1.0)
Requirement already satisfied, skipping upgrade: clang~=5.0 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (5.0)
Requirement already satisfied, skipping upgrade: protobuf>=3.9.2 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (3.17.3)
Requirement already satisfied, skipping upgrade: grpcio<2.0,>=1.37.0 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (1.39.0)
Requirement already satisfied, skipping upgrade: typing-extensions~=3.7.4 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (3.7.4.3)
Requirement already satisfied, skipping upgrade: tensorboard~=2.6 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (2.6.0)
Requirement already satisfied, skipping upgrade: keras-preprocessing~=1.1.2 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (1.1.2)
Requirement already satisfied, skipping upgrade: absl-py~=0.10 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (0.13.0)
Requirement already satisfied, skipping upgrade: wheel~=0.35 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (0.35.1)
Requirement already satisfied, skipping upgrade: tensorflow-estimator~=2.6 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (2.6.0)
Requirement already satisfied, skipping upgrade: google-pasta~=0.2 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (0.2.0)
Requirement already satisfied, skipping upgrade: keras~=2.6 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (2.6.0)
Requirement already satisfied, skipping upgrade: astunparse~=1.6.3 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (1.6.3)
Requirement already satisfied, skipping upgrade: numpy~=1.19.2 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (1.19.2)
Requirement already satisfied, skipping upgrade: six~=1.15.0 in /usr/local/lib/python3.8/site-packages (from tensorflow==2.6.*) (1.15.0)
Requirement already satisfied, skipping upgrade: wrapt~=1.12.1 in /databricks/python3/lib/python3.8/site-packages (from tensorflow==2.6.*) (1.12.1)
Requirement already satisfied, skipping upgrade: certifi>=2017.4.17 in /databricks/python3/lib/python3.8/site-packages (from requests) (2020.12.5)
Requirement already satisfied, skipping upgrade: charset-normalizer~=2.0.0; python_version >= "3" in /databricks/python3/lib/python3.8/site-packages (from requests) (2.0.4)
Requirement already satisfied, skipping upgrade: urllib3<1.27,>=1.21.1 in /databricks/python3/lib/python3.8/site-packages (from requests) (1.25.11)
Requirement already satisfied, skipping upgrade: idna<4,>=2.5; python_version >= "3" in /databricks/python3/lib/python3.8/site-packages (from requests) (2.10)
Requirement already satisfied, skipping upgrade: python-dateutil>=2.7.3 in /databricks/python3/lib/python3.8/site-packages (from pandas) (2.8.1)
Requirement already satisfied, skipping upgrade: pytz>=2017.3 in /databricks/python3/lib/python3.8/site-packages (from pandas) (2020.5)
Requirement already satisfied, skipping upgrade: setuptools>=41.0.0 in /databricks/python3/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow==2.6.*) (50.3.1)
Requirement already satisfied, skipping upgrade: tensorboard-data-server<0.7.0,>=0.6.0 in /databricks/python3/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow==2.6.*) (0.6.1)
Requirement already satisfied, skipping upgrade: tensorboard-plugin-wit>=1.6.0 in /databricks/python3/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow==2.6.*) (1.8.0)
Requirement already satisfied, skipping upgrade: markdown>=2.6.8 in /databricks/python3/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow==2.6.*) (3.3.4)
Requirement already satisfied, skipping upgrade: werkzeug>=0.11.15 in /databricks/python3/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow==2.6.*) (2.0.1)
Requirement already satisfied, skipping upgrade: google-auth<2,>=1.6.3 in /databricks/python3/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow==2.6.*) (1.35.0)
Requirement already satisfied, skipping upgrade: google-auth-oauthlib<0.5,>=0.4.1 in /databricks/python3/lib/python3.8/site-packages (from tensorboard~=2.6->tensorflow==2.6.*) (0.4.5)
Requirement already satisfied, skipping upgrade: pyasn1-modules>=0.2.1 in /databricks/python3/lib/python3.8/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow==2.6.*) (0.2.8)
Requirement already satisfied, skipping upgrade: rsa<5,>=3.1.4; python_version >= "3.6" in /databricks/python3/lib/python3.8/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow==2.6.*) (4.7.2)
Requirement already satisfied, skipping upgrade: cachetools<5.0,>=2.0.0 in /databricks/python3/lib/python3.8/site-packages (from google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow==2.6.*) (4.2.2)
Requirement already satisfied, skipping upgrade: requests-oauthlib>=0.7.0 in /databricks/python3/lib/python3.8/site-packages (from google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow==2.6.*) (1.3.0)
Requirement already satisfied, skipping upgrade: pyasn1<0.5.0,>=0.4.6 in /databricks/python3/lib/python3.8/site-packages (from pyasn1-modules>=0.2.1->google-auth<2,>=1.6.3->tensorboard~=2.6->tensorflow==2.6.*) (0.4.8)
Requirement already satisfied, skipping upgrade: oauthlib>=3.0.0 in /databricks/python3/lib/python3.8/site-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tensorboard~=2.6->tensorflow==2.6.*) (3.1.1)
WARNING: You are using pip version 20.2.4; however, version 21.2.4 is available.
You should consider upgrading via the '/databricks/python3/bin/python -m pip install --upgrade pip' command.
Using virtual environment '/databricks/python3' ...

Installation complete.

Now we are all set to explore deep learning! It takes as few as three lines of R code, but quite a lot is going on behind the scenes. One advantage of a cloud environment is that we do not need to worry about this behind-the-scenes setup and maintenance.

2 Overview for MNIST Dataset

In deep learning, one of the first successful applications to outperform traditional machine learning was image recognition. We will use the widely used MNIST handwritten digit image dataset for this tutorial. More information about the dataset and benchmark results from various machine learning methods can be found at: http://yann.lecun.com/exdb/mnist/ and https://en.wikipedia.org/wiki/MNIST_database

2.1 Load MNIST dataset

keras provides a built-in loader for this dataset, dataset_mnist(), so we can load it as in the following cell (the data files are downloaded on first use and cached). It takes less than a minute to load the dataset.

mnist <- dataset_mnist()
2021-08-20 20:29:32.304643: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /databricks/python3/lib:/databricks/python3/lib:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/zulu8-ca-amd64/jre//lib/server:/opt/simba/sparkodbc/lib/64/
2021-08-20 20:29:32.304705: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.

2.2 Training and testing datasets

The MNIST dataset has a straightforward structure that is ready to use in R. It has two pieces: (1) the training set: x (i.e. features) is a 60000x28x28 tensor holding 60000 greyscale images of 28x28 pixels (every value in each 28x28 matrix is an integer between 0 and 255), and y (i.e. responses) is a length-60000 vector with the corresponding digits, integers between 0 and 9; (2) the testing set: the same structure as the training set, but with only 10000 images and responses. The detailed structure of the dataset can be seen with str(mnist) below.

str(mnist)
List of 2
 $ train:List of 2
  ..$ x: int [1:60000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
  ..$ y: int [1:60000(1d)] 5 0 4 1 9 2 1 3 1 4 ...
 $ test :List of 2
  ..$ x: int [1:10000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
  ..$ y: int [1:10000(1d)] 7 2 1 0 4 1 4 9 5 9 ...

Now we extract the features (x) and the response variable (y) for both the training and testing datasets, and check the structure of x_train and y_train with the str() function.

x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y
 
str(x_train)
str(y_train)
 int [1:60000, 1:28, 1:28] 0 0 0 0 0 0 0 0 0 0 ...
 int [1:60000(1d)] 5 0 4 1 9 2 1 3 1 4 ...

2.3 Plot an image

Now let's plot one of the 28x28 matrices as an image using R's image() function. image() draws a matrix rotated 90 degrees counterclockwise relative to how the matrix is printed, so we need additional steps to rearrange the matrix so that image() shows the digit in its actual orientation.

index_image <- 28  ## change this index to see a different image
input_matrix <- x_train[index_image, 1:28, 1:28]
# Reverse each column (flip the rows top-to-bottom), then transpose,
# so that image() draws the digit upright
output_matrix <- apply(input_matrix, 2, rev)
output_matrix <- t(output_matrix)
image(1:28, 1:28, output_matrix, col = gray.colors(256),
      xlab = paste('Image for digit of: ', y_train[index_image]), ylab = "")

Here is the original 28x28 matrix for the above image:

input_matrix
      [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10] [,11] [,12] [,13]
 [1,]    0    0    0    0    0    0    0    0    0     0     0     0     0
 [2,]    0    0    0    0    0    0    0    0    0     0     0     0     0
 [3,]    0    0    0    0    0    0    0    0    0     0     0     0     0
 [4,]    0    0    0    0    0    0    0    0    0     0     0     0     0
 [5,]    0    0    0    0    0    0    0    0    0     0     0     0     0
 [6,]    0    0    0    0    0    0    0    0    0     0     9    80   207
 [7,]    0    0    0    0    0    0   39  158  158   158   168   253   253
 [8,]    0    0    0    0    0    0  226  253  253   253   253   253   253
 [9,]    0    0    0    0    0    0  139  253  253   253   238   113   215
[10,]    0    0    0    0    0    0   39   34   34    34    30     0    31
[11,]    0    0    0    0    0    0   91    0    0     0     0     0     0
[12,]    0    0    0    0    0    0    0    0    0     0     0    11    33
[13,]    0    0    0    0    0    0    0    0    0     0    11   167   253
[14,]    0    0    0    0    0    0    0    0    0     0    27   253   253
[15,]    0    0    0    0    0    0    0    0    0     0    18   201   253
[16,]    0    0    0    0    0    0    0    0    0     0     0    36    87
[17,]    0    0    0    0    0    0    0    0    0     0     0     0     0
[18,]    0    0    0    0    0    0    0    0    0     0     0     0     0
[19,]    0    0    0    0    0    0    0    0    0     0     0     0     0
[20,]    0    0    0    0    0    0    0    0    0     0     0     0     0
[21,]    0    0    0    0   66  211  211  211   59    36    36    21    26
[22,]    0    0    0    0   80  253  253  253  253   253   253   195   215
[23,]    0    0    0    0   80  253  253  253  253   253   253   253   253
[24,]    0    0    0    0   49  156  247  253  253   253   253   253   253
[25,]    0    0    0    0    0    0  116  253  253   253   253   253   126
[26,]    0    0    0    0    0    0    0    0    0     0     0     0     0
[27,]    0    0    0    0    0    0    0    0    0     0     0     0     0
[28,]    0    0    0    0    0    0    0    0    0     0     0     0     0
      [,14] [,15] [,16] [,17] [,18] [,19] [,20] [,21] [,22] [,23] [,24] [,25]
 [1,]     0     0     0     0     0     0     0     0     0     0     0     0
 [2,]     0     0     0     0     0     0     0     0     0     0     0     0
 [3,]     0     0     0     0     0     0     0     0     0     0     0     0
 [4,]     0     0     0     0     0     0     0     0     0     0     0     0
 [5,]     0     0     0     0     0     0     0     0     0     0     0     0
 [6,]   255   254   254   254    97    80    80    44     0     0     0     0
 [7,]   253   253   253   253   253   253   253   210    38     0     0     0
 [8,]   253   253   253   253   253   253   253   253   241   146     0     0
 [9,]   253   253   253   253   253   253   253   253   253   210    43     0
[10,]   148    34   204   235   253   253   253   253   253   236    64     0
[11,]     0     0    35   199   253   253   253   253   244    81     0     0
[12,]   202   202   216   253   253   253   253   241    89     0     0     0
[13,]   253   253   253   253   253   253   238    82     0     0     0     0
[14,]   253   253   253   253   253   253    96     0     0     0     0     0
[15,]   253   253   253   253   253   253   230    49     0     0     0     0
[16,]    87    87   248   253   253   253   253   138     0     0     0     0
[17,]     0     0     7   152   253   253   253   250    59     0     0     0
[18,]     0     0     0    62   238   253   253   253    60     0     0     0
[19,]     0     0     0    32   233   253   253   150     6     0     0     0
[20,]     0     0    37   203   253   253   253   138     0     0     0     0
[21,]    36   151   222   253   253   253   253   138     0     0     0     0
[22,]   253   253   253   253   253   253   157    77     0     0     0     0
[23,]   253   253   253   253   237   235    40     0     0     0     0     0
[24,]   253   253   159   156    16     0     0     0     0     0     0     0
[25,]    78    78     3     0     0     0     0     0     0     0     0     0
[26,]     0     0     0     0     0     0     0     0     0     0     0     0
[27,]     0     0     0     0     0     0     0     0     0     0     0     0
[28,]     0     0     0     0     0     0     0     0     0     0     0     0
      [,26] [,27] [,28]
 [1,]     0     0     0
 [2,]     0     0     0
 [3,]     0     0     0
 [4,]     0     0     0
 [5,]     0     0     0
 [6,]     0     0     0
 [7,]     0     0     0
 [8,]     0     0     0
 [9,]     0     0     0
[10,]     0     0     0
[11,]     0     0     0
[12,]     0     0     0
[13,]     0     0     0
[14,]     0     0     0
[15,]     0     0     0
[16,]     0     0     0
[17,]     0     0     0
[18,]     0     0     0
[19,]     0     0     0
[20,]     0     0     0
[21,]     0     0     0
[22,]     0     0     0
[23,]     0     0     0
[24,]     0     0     0
[25,]     0     0     0
[26,]     0     0     0
[27,]     0     0     0
[28,]     0     0     0
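The reverse-then-transpose step used for plotting can be checked on a tiny matrix. The sketch below is a base-R illustration under these assumptions: the 2x3 toy matrix and the helper name rotate_for_image are hypothetical, and it relies on image() drawing z[i, j] at horizontal position i and vertical position j, with j increasing upward.

```r
# Hypothetical helper: reorient a matrix whose row 1 is the top of the picture
# so that graphics::image() draws it upright.
rotate_for_image <- function(m) {
  flipped <- apply(m, 2, rev)  # reverse each column: the last row becomes the first
  t(flipped)                   # transpose: original columns become image() x positions
}

# Toy 2x3 "image": top row 1 2 3, bottom row 4 5 6
m <- matrix(1:6, nrow = 2, byrow = TRUE)
out <- rotate_for_image(m)
print(out)

# The top-left pixel m[1, 1] now sits at x = 1 (left) and y = nrow(m) (top)
out[1, nrow(m)]
```

For matrices with more than one row, apply(m, 2, rev) is the same as m[nrow(m):1, ], so t(m[nrow(m):1, ]) is an equivalent one-liner.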

3 Convolutional Neural Network Model

In this section, we show how to use a Convolutional Neural Network (CNN) to classify the MNIST handwritten images into digits. It is exactly the same problem as the one we just studied, but a CNN is a much better deep learning method for image recognition than a generic deep neural network. A CNN leverages the relationships among neighboring pixels in the 2D image for better performance. It also avoids treating every pixel of a high-resolution, full-color image as a separate input feature, which would mean thousands or millions of features. This example is described at: https://keras.rstudio.com/articles/examples/mnist_cnn.html

3.1 Dataset import and parameter setup

Now let's import the MNIST dataset from scratch again, because in the last section we did some preprocessing specifically for a deep neural network model, and a CNN involves different preprocessing steps. We also define a few parameters to be used later.

# Load the mnist data's training and testing dataset
mnist <- dataset_mnist()
x_train <- mnist$train$x
y_train <- mnist$train$y
x_test <- mnist$test$x
y_test <- mnist$test$y
# Define a few parameters to be used in the CNN model
batch_size <- 128
num_classes <- 10
epochs <- 10
 
# Input image dimensions
img_rows <- 28
img_cols <- 28

3.2 Data preprocessing

In general, a CNN takes an MxN image as an MxNxK 3D array with K channels. For example, a greyscale MxN image has only one channel, so the input is an MxNx1 tensor. An 8-bit-per-channel RGB image of size MxN has three channels, i.e. three MxN arrays with values between 0 and 255, so the input is an MxNx3 tensor. Our images are greyscale, but we still need to state explicitly that there is one channel by reshaping each 2D array into a 3D tensor with array_reshape(). The input_shape variable will be used in the CNN model later. For RGB color images, the number of channels is 3, and the "1" in the code cell below needs to be replaced by "3".

3.2.1 Add channel into the dimension

x_train <- array_reshape(x_train, c(nrow(x_train), img_rows, img_cols, 1))
x_test <- array_reshape(x_test, c(nrow(x_test), img_rows, img_cols, 1))
input_shape <- c(img_rows, img_cols, 1)

Here is the structure of the reshaped data: the first dimension is the image index, and dimensions 2 to 4 form a 3D tensor (rows, columns, channel) for each image, even though there is only one channel.

str(x_train)
 int [1:60000, 1:28, 1:28, 1] 0 0 0 0 0 0 0 0 0 0 ...
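For a single greyscale channel, appending the channel dimension does not move any pixel: x_train[i, r, c] before the reshape and x_train[i, r, c, 1] after it are the same value. The base-R sketch below checks this on a small stand-in array (2 images of 3x3 pixels, hypothetical sizes) using dim<- instead of array_reshape(). That shortcut is only safe for appending a size-1 dimension: array_reshape() fills elements in row-major (C-style) order by default, while dim<- keeps R's column-major order, so the two differ for general reshapes.

```r
set.seed(1)
# Stand-in for x_train: 2 "images" of 3x3 integer pixels (sizes are illustrative)
x <- array(sample(0:255, 2 * 3 * 3, replace = TRUE), dim = c(2, 3, 3))

# Append a channel dimension of size 1 (greyscale)
x4 <- x
dim(x4) <- c(dim(x), 1)
str(x4)

# Every pixel keeps its position: x[i, r, c] equals x4[i, r, c, 1]
identical(x4[, , , 1], x)
```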

3.2.2 Scaling

As with the DNN model, we scale the input values to lie between 0 and 1, for the same numerical stability reasons in the optimization process.

x_train <- x_train / 255
x_test <- x_test / 255
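Dividing by 255 maps the integer pixel range 0 to 255 onto 0 to 1 and turns the integer array into doubles. A quick base-R check on a few pixel values copied from the 28x28 matrix printed in Section 2.3:

```r
# A few pixel values taken from the 28x28 matrix printed in Section 2.3
px <- c(0L, 80L, 207L, 253L, 255L)
scaled <- px / 255        # same operation as x_train / 255

typeof(px)                # integer input
typeof(scaled)            # double output
range(scaled)             # from 0 to 1
round(scaled, 3)
```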

3.2.3 Convert response to categorical variable

As with the DNN model, the response variable is converted into a categorical (one-hot encoded) matrix with to_categorical().

# Convert class vectors to binary class matrices
y_train <- to_categorical(y_train, num_classes)
y_test <- to_categorical(y_test, num_classes)
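to_categorical() turns each digit label into a row of 10 indicators. A minimal base-R sketch of the same idea (one_hot is a hypothetical helper, not a keras function) makes the R indexing explicit: because R counts from 1, digit d is stored in column d + 1.

```r
# Hypothetical base-R equivalent of to_categorical() for labels 0..(num_classes - 1)
one_hot <- function(y, num_classes) {
  m <- matrix(0, nrow = length(y), ncol = num_classes)
  # R is 1-indexed, so digit d goes to column d + 1
  m[cbind(seq_along(y), y + 1)] <- 1
  m
}

labels <- c(5L, 0L, 4L)          # first three training labels shown by str(mnist)
y_onehot <- one_hot(labels, 10)
y_onehot[1, ]                    # a single 1 in column 6, i.e. digit 5

# Recover the digits: position of the 1 in each row, minus 1
max.col(y_onehot) - 1
```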

3.3 Fit a CNN model

As we discussed, a CNN model contains a series of 2D convolutional layers, each with a few parameters: (1) kernel_size, which is typically 3x3 or 5x5; (2) the number of filters, which corresponds to the number of channels (i.e. the 3rd dimension) of the output tensor; (3) the activation function. The first layer also takes an input_shape parameter, which gives the input image size and number of channels. To prevent overfitting and speed up computation, a pooling layer is usually applied after one or a few 2D convolutional layers. A typical pooling layer returns the maximum of each 2x2 block (pool_size) as the new value in the output, which halves each spatial dimension. Dropout can be used in addition to pooling. After the 2D convolutional layers, we need to flatten the 3D tensor output into a 1D tensor, and then add one or a couple of dense layers to connect the convolutional output to the target response classes.
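The 2x2 max pooling described above can be sketched in base R. The helper max_pool_2x2 is hypothetical; it assumes stride 2 and no padding, which is what layer_max_pooling_2d(pool_size = c(2, 2)) uses by default (strides default to pool_size, padding to "valid"). It also shows why an odd size is rounded down: an 11x11 input gives 5x5, matching the second pooling layer in the model summary.

```r
# Hypothetical helper: 2x2 max pooling with stride 2 and no padding
max_pool_2x2 <- function(m) {
  out_rows <- nrow(m) %/% 2  # a trailing odd row is dropped
  out_cols <- ncol(m) %/% 2  # a trailing odd column is dropped
  out <- matrix(0, out_rows, out_cols)
  for (i in seq_len(out_rows)) {
    for (j in seq_len(out_cols)) {
      block <- m[(2 * i - 1):(2 * i), (2 * j - 1):(2 * j)]  # one 2x2 window
      out[i, j] <- max(block)
    }
  }
  out
}

m <- matrix(c(1, 3, 2, 0,
              4, 2, 1, 5,
              0, 1, 7, 2,
              6, 2, 3, 1), nrow = 4, byrow = TRUE)
res <- max_pool_2x2(m)                 # 4x4 -> 2x2
res
dim(max_pool_2x2(matrix(0, 11, 11)))  # odd size: 11x11 -> 5x5
```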

3.3.1 Define a CNN model structure

Now we define a CNN model with two 2D convolutional layers, each followed by max pooling, with additional dropout after the second pooling layer to prevent overfitting. We then flatten the output and use two dense layers, with dropout between them, to connect to the image categories.

# define model structure 
cnn_model <- keras_model_sequential() %>%
  layer_conv_2d(filters = 32, kernel_size = c(3,3), activation = 'relu', input_shape = input_shape) %>% 
  layer_max_pooling_2d(pool_size = c(2, 2)) %>% 
  layer_conv_2d(filters = 64, kernel_size = c(3,3), activation = 'relu') %>% 
  layer_max_pooling_2d(pool_size = c(2, 2)) %>% 
  layer_dropout(rate = 0.25) %>% 
  layer_flatten() %>% 
  layer_dense(units = 128, activation = 'relu') %>% 
  layer_dropout(rate = 0.5) %>% 
  layer_dense(units = num_classes, activation = 'softmax')
2021-08-20 20:29:40.357696: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /databricks/python3/lib:/databricks/python3/lib:/usr/lib/R/lib:/usr/lib/x86_64-linux-gnu:/usr/lib/jvm/zulu8-ca-amd64/jre//lib/server:/opt/simba/sparkodbc/lib/64/
2021-08-20 20:29:40.357756: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)
2021-08-20 20:29:40.357799: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (0820-162012-all288-10-172-173-42): /proc/driver/nvidia/version does not exist
2021-08-20 20:29:40.358213: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
summary(cnn_model)
Model: "sequential"
________________________________________________________________________________
Layer (type)                        Output Shape                    Param #     
================================================================================
conv2d_1 (Conv2D)                   (None, 26, 26, 32)              320         
________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)      (None, 13, 13, 32)              0           
________________________________________________________________________________
conv2d (Conv2D)                     (None, 11, 11, 64)              18496       
________________________________________________________________________________
max_pooling2d (MaxPooling2D)        (None, 5, 5, 64)                0           
________________________________________________________________________________
dropout_1 (Dropout)                 (None, 5, 5, 64)                0           
________________________________________________________________________________
flatten (Flatten)                   (None, 1600)                    0           
________________________________________________________________________________
dense_1 (Dense)                     (None, 128)                     204928      
________________________________________________________________________________
dropout (Dropout)                   (None, 128)                     0           
________________________________________________________________________________
dense (Dense)                       (None, 10)                      1290        
================================================================================
Total params: 225,034
Trainable params: 225,034
Non-trainable params: 0
________________________________________________________________________________
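Every number in this summary can be reproduced by hand. The sketch below assumes the layer_conv_2d() defaults used here (stride 1, "valid" padding, one bias per filter) and 2x2 max pooling with stride 2; the helper functions are hypothetical. A convolution with a k x k kernel shrinks each side by k - 1, pooling halves it (rounding down), and each layer has (inputs per unit + 1 bias) x units parameters.

```r
# Output sizes (one side of the square feature map)
conv_out <- function(size, kernel) size - kernel + 1  # valid convolution, stride 1
pool_out <- function(size, pool) size %/% pool        # valid pooling, stride = pool

# Parameter counts: weights plus one bias per filter / unit
conv_params  <- function(kernel, in_ch, filters) (kernel * kernel * in_ch + 1) * filters
dense_params <- function(inputs, units) (inputs + 1) * units

s1 <- conv_out(28, 3)   # 26
s2 <- pool_out(s1, 2)   # 13
s3 <- conv_out(s2, 3)   # 11
s4 <- pool_out(s3, 2)   # 5
flat <- s4 * s4 * 64    # 1600 inputs to the first dense layer

params <- c(
  conv2d_1 = conv_params(3, 1, 32),
  conv2d   = conv_params(3, 32, 64),
  dense_1  = dense_params(flat, 128),
  dense    = dense_params(128, 10)
)
params
sum(params)             # total trainable parameters
```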

3.3.2 Compile the model

As with the DNN model, we need to compile the CNN model, specifying the loss function, the optimizer, and the metric to track.

# Compile model
cnn_model %>% compile(
  loss = loss_categorical_crossentropy,
  optimizer = optimizer_adadelta(),
  metrics = c('accuracy')
)
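loss_categorical_crossentropy compares the one-hot target row with the class probabilities from the softmax output layer: for one image the loss is -sum(y_true * log(y_pred)), which reduces to minus the log of the probability assigned to the correct digit. A small base-R sketch (made-up probabilities, hypothetical helper name) shows two cases. A model that spreads probability evenly over the 10 classes has loss log(10), about 2.30, which is close to where the first training batches in the log below start.

```r
# Hypothetical helper: categorical cross-entropy for a single sample
cross_entropy <- function(y_true, y_pred) -sum(y_true * log(y_pred))

y_true <- c(0, 0, 0, 0, 0, 1, 0, 0, 0, 0)        # one-hot target for digit 5

uniform <- rep(0.1, 10)                           # equal odds for every digit
cross_entropy(y_true, uniform)                    # log(10), about 2.30

confident <- c(rep(0.01, 5), 0.91, rep(0.01, 4))  # 91% on the correct digit
cross_entropy(y_true, confident)                  # -log(0.91), about 0.094
```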


3.3.3 Train the model

Now we can train the model on the processed data, saving each epoch's history to track progress. Note that because we are not using a GPU, training takes a few minutes, so please be patient while waiting for the results. The training time can be reduced significantly by running on a GPU.
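The 375 in the progress bar (e.g. 1/375) follows from the fit() arguments: validation_split = 0.2 holds out 20% of the 60000 training images for validation (keras takes the last samples, before shuffling), and the remaining images are processed in mini-batches of batch_size = 128. The arithmetic as a small R sketch:

```r
# Reproduce the number of steps per epoch shown in the training log
n_train   <- 60000   # images in x_train
val_split <- 0.2     # validation_split passed to fit()
batch     <- 128     # batch_size defined in Section 3.1

n_fit <- round(n_train * (1 - val_split))  # images used for gradient updates
n_val <- n_train - n_fit                   # images held out for validation
steps <- ceiling(n_fit / batch)            # a final partial batch still counts as a step

c(fit = n_fit, validation = n_val, steps_per_epoch = steps)
```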

# Train model
cnn_history <- cnn_model %>% fit(
  x_train, y_train,
  batch_size = batch_size,
  epochs = epochs,
  validation_split = 0.2
)
2021-08-20 20:29:41.309432: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 150528000 exceeds 10% of free system memory.
2021-08-20 20:29:41.490871: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/10

  1/375 [..............................] - ETA: 4:49 - loss: 2.2940 - accuracy: 0.1094
  2/375 [..............................] - ETA: 38s - loss: 2.3015 - accuracy: 0.1133 
  3/375 [..............................] - ETA: 48s - loss: 2.2820 - accuracy: 0.1484
  4/375 [..............................] - ETA: 44s - loss: 2.2747 - accuracy: 0.1602
  5/375 [..............................] - ETA: 42s - loss: 2.2536 - accuracy: 0.1859
  6/375 [..............................] - ETA: 41s - loss: 2.2419 - accuracy: 0.2018
  7/375 [..............................] - ETA: 40s - loss: 2.2218 - accuracy: 0.2266
  8/375 [..............................] - ETA: 40s - loss: 2.2055 - accuracy: 0.2275
  9/375 [..............................] - ETA: 39s - loss: 2.1821 - accuracy: 0.2439
 10/375 [..............................] - ETA: 39s - loss: 2.1597 - accuracy: 0.2516
 11/375 [..............................] - ETA: 38s - loss: 2.1402 - accuracy: 0.2614
 12/375 [..............................] - ETA: 38s - loss: 2.1149 - accuracy: 0.2695
 13/375 [>.............................] - ETA: 38s - loss: 2.0924 - accuracy: 0.2758
 14/375 [>.............................] - ETA: 38s - loss: 2.0649 - accuracy: 0.2896
 15/375 [>.............................] - ETA: 37s - loss: 2.0272 - accuracy: 0.3063
 16/375 [>.............................] - ETA: 37s - loss: 1.9894 - accuracy: 0.3203
 17/375 [>.............................] - ETA: 37s - loss: 1.9497 - accuracy: 0.3373
 18/375 [>.............................] - ETA: 37s - loss: 1.9168 - accuracy: 0.3503
 19/375 [>.............................] - ETA: 37s - loss: 1.8905 - accuracy: 0.3577
 20/375 [>.............................] - ETA: 37s - loss: 1.8617 - accuracy: 0.3711
 21/375 [>.............................] - ETA: 36s - loss: 1.8335 - accuracy: 0.3813
 22/375 [>.............................] - ETA: 36s - loss: 1.7978 - accuracy: 0.3938
 23/375 [>.............................] - ETA: 36s - loss: 1.7652 - accuracy: 0.4066
 24/375 [>.............................] - ETA: 36s - loss: 1.7337 - accuracy: 0.4186
 25/375 [=>............................] - ETA: 36s - loss: 1.7041 - accuracy: 0.4294
 26/375 [=>............................] - ETA: 36s - loss: 1.6840 - accuracy: 0.4366
 27/375 [=>............................] - ETA: 35s - loss: 1.6631 - accuracy: 0.4439
 28/375 [=>............................] - ETA: 35s - loss: 1.6370 - accuracy: 0.4523
375/375 [==============================] - 41s 109ms/step - loss: 0.0386 - accuracy: 0.9883 - val_loss: 0.0347 - val_accuracy: 0.9902
plot(cnn_history)

The accuracy of the trained model can be evaluated on the testing dataset, and it turns out to be very good: about 99.25% of the test images are classified correctly.

cnn_model %>% evaluate(x_test, y_test)
2021-08-20 20:36:35.463646: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 31360000 exceeds 10% of free system memory.

313/313 [==============================] - 3s 9ms/step - loss: 0.0242 - accuracy: 0.9925
      loss   accuracy 
0.02415877 0.99250001 
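The reported test accuracy is simply the fraction of test images whose predicted digit matches the true label. The minimal base-R sketch below shows that calculation on two short, made-up label vectors (`toy_truth` and `toy_pred` are hypothetical, not MNIST output). The last lines check the arithmetic behind the numbers above, assuming `evaluate()` was called without a `batch_size` argument so the Keras default of 32 applies: 10,000 test images in batches of 32 give the 313 steps in the progress bar, and an accuracy of 0.9925 on 10,000 images corresponds to 75 misclassified images.

```r
# Accuracy = share of images whose predicted digit equals the true digit.
# toy_truth / toy_pred are hypothetical example labels, not model output.
toy_truth <- c(7, 2, 1, 0, 4, 1, 4, 9, 5, 9)
toy_pred  <- c(7, 2, 1, 0, 4, 1, 4, 4, 5, 9)   # one mistake: a 9 predicted as 4
toy_accuracy <- mean(toy_pred == toy_truth)
print(toy_accuracy)   # 0.9

# Arithmetic behind the MNIST numbers reported above
# (assumes the default evaluation batch size of 32)
n_test <- 10000
eval_steps <- ceiling(n_test / 32)          # 312.5 rounds up to 313 steps
n_errors <- round(n_test * (1 - 0.9925))    # 75 misclassified images
print(c(eval_steps = eval_steps, n_errors = n_errors))
```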

3.4 Model prediction

For any new image that has gone through the same preprocessing as the training data, we can use the trained model to predict which digit it shows.
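As a minimal sketch, assuming the earlier preprocessing rescaled pixel values from 0-255 to [0, 1] and gave each image the shape 28 x 28 x 1 (check this against the data-preparation step of this notebook), a single new 28 x 28 image could be turned into a one-image batch as follows. `new_image` and `prepare_image()` are hypothetical: the image is filled with random pixel values as a stand-in for a real scanned digit, and only base R is used, so the resulting array would still have to be passed to `predict()` as in the cell below.

```r
# Hypothetical new image: a 28 x 28 matrix of raw pixel intensities (0-255),
# standing in for a real scanned digit.
set.seed(1)
new_image <- matrix(sample(0:255, 28 * 28, replace = TRUE), nrow = 28, ncol = 28)

# Assumed preprocessing: rescale to [0, 1] and add the batch and channel
# dimensions so the array has shape (1, 28, 28, 1).
prepare_image <- function(img) {
  stopifnot(all(dim(img) == c(28, 28)))
  array(img / 255, dim = c(1, 28, 28, 1))
}

x_new <- prepare_image(new_image)
print(dim(x_new))   # 1 28 28 1
# x_new could then be scored with, e.g., cnn_model %>% predict(x_new)
```

Because the batch and channel dimensions both have length 1, R's column-major filling keeps pixel `[i, j]` of the matrix at `x_new[1, i, j, 1]`, so no pixels are shuffled by the reshape.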

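# model prediction: predict() gives one score per digit class for each test
# image, and k_argmax() keeps the class with the highest score
tf_output = cnn_model %>% predict(x_test) %>% k_argmax()
# convert the TensorFlow tensor into an ordinary R vector of digits
cnn_pred = as.array(tf_output)
head(cnn_pred, n=50)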
2021-08-20 20:37:10.980312: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 31360000 exceeds 10% of free system memory.
 [1] 7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4 9 6 6 5 4 0 7 4 0 1 3 1 3 4 7 2 7 1
[39] 2 1 1 7 4 2 3 5 1 2 4 4
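`k_argmax()` returns, for each row of the prediction matrix, the position of the largest score counted from 0, which for MNIST is the digit itself. The base-R sketch below illustrates the same idea on a made-up 3 x 10 probability matrix (hypothetical values, not model output). It uses `max.col()`, which counts columns from 1, so 1 is subtracted to get the digit label.

```r
# Hypothetical softmax-style output for 3 images: one row per image,
# one column per digit 0-9; each row sums to 1.
probs <- rbind(
  c(0.01, 0.01, 0.02, 0.01, 0.01, 0.01, 0.01, 0.90, 0.01, 0.01),  # most likely 7
  c(0.05, 0.05, 0.70, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02),  # most likely 2
  c(0.02, 0.80, 0.02, 0.02, 0.02, 0.02, 0.02, 0.04, 0.02, 0.02)   # most likely 1
)
colnames(probs) <- 0:9

# max.col() gives the 1-based column of the largest value in each row;
# subtracting 1 turns it into the 0-based digit label, like k_argmax().
pred_digit <- max.col(probs, ties.method = "first") - 1
print(pred_digit)   # 7 2 1
```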

3.5 Check a few misclassified images

Now let's look at a few misclassified images to see whether a human could do a better job than this simple CNN model.

## number of misclassified images
sum(cnn_pred != mnist$test$y)
[1] 75
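The error count says how many images were missed, but not which digits get confused with which. One way to see that is a confusion table built with base R's `table()`. The sketch below uses hypothetical toy labels rather than the notebook's data; in this notebook the analogous call would be `table(Actual = mnist$test$y, Predicted = cnn_pred)`.

```r
# Hypothetical true labels and predictions (toy data, not MNIST results)
truth <- c(4, 9, 4, 9, 7, 1, 7, 2, 4, 9)
pred  <- c(4, 4, 4, 9, 7, 7, 7, 2, 9, 9)

# Using the same factor levels 0-9 for both keeps the table square.
# Rows = true digit, columns = predicted digit; off-diagonal cells are errors.
conf_tab <- table(Actual = factor(truth, levels = 0:9),
                  Predicted = factor(pred, levels = 0:9))
print(conf_tab)

# Number of misclassified images, the same idea as sum(cnn_pred != mnist$test$y)
n_wrong <- sum(pred != truth)
print(n_wrong)   # 3
```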
## keep only the images, true labels and predictions the model got wrong
missed_image = mnist$test$x[cnn_pred != mnist$test$y,,]
missed_digit = mnist$test$y[cnn_pred != mnist$test$y]
missed_pred = cnn_pred[cnn_pred != mnist$test$y]
index_image = 6 ## change this index to see a different image.
input_matrix <- missed_image[index_image,1:28,1:28]
## reverse each column and transpose so that image() draws the digit upright
output_matrix <- apply(input_matrix, 2, rev)
output_matrix <- t(output_matrix)
image(1:28, 1:28, output_matrix, col=gray.colors(256),
      xlab=paste0('Image for digit ', missed_digit[index_image],
                  ', wrongly predicted as ', missed_pred[index_image]),
      ylab="")
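Why the `apply(..., 2, rev)` and `t()` steps are needed: base R's `image()` draws `z[i, j]` at x = i, y = j with y increasing upward, whereas a pixel matrix stores its first row at the top. Reversing each column and then transposing rearranges the matrix so the digit appears upright. The small check below wraps those two steps in a hypothetical helper, `rotate_for_image()`, and applies it to a 2 x 3 matrix so the rearrangement can be inspected by eye.

```r
# Hypothetical helper wrapping the two steps used in the plotting cell above
rotate_for_image <- function(m) t(apply(m, 2, rev))

m <- matrix(1:6, nrow = 2)   # pixel rows 1-2 (top to bottom), columns 1-3
print(m)
#      [,1] [,2] [,3]
# [1,]    1    3    5
# [2,]    2    4    6

z <- rotate_for_image(m)
print(z)
#      [,1] [,2]
# [1,]    2    1
# [2,]    4    3
# [3,]    6    5

# z[i, j] equals m[nrow(m) + 1 - j, i]: picture column i becomes x = i and the
# top pixel row ends up at the largest y, so image(z) shows m the right way up.
```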