The first step in training our own DBN is to construct our dataset. This section will show you how to transform the MNIST data into a convenient format that allows you to train a neural network, using some of TensorFlow 2's built-in functions for simplicity.
Let's start by loading the MNIST dataset in TensorFlow. Because the MNIST data has been used for many deep learning benchmarks, the TensorFlow 2 ecosystem already provides convenient utilities for loading and formatting it through the separate tensorflow-datasets library, which we first need to install:
pip install tensorflow-datasets
After installing the package, we need to import it along with the required dependencies:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
Now we can download the MNIST data locally from Google Cloud Storage (GCS) using the builder functionality:
mnist_builder = tfds.builder("mnist")
mnist_builder.download_and_prepare()
The dataset is now available on disk on our machine. As noted earlier, this data is divided into training and test splits, which you can verify by inspecting the builder's info attribute:
info = mnist_builder.info
print(info)
This gives the following output:
tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.1,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online], Available: http://yann.lecun.com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)
As you can see, the test dataset has 10,000 examples and the training dataset has 60,000 examples. Each image is 28 × 28 pixels with a single (grayscale) channel and carries a label from one of 10 classes (0 to 9).
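As a quick sanity check on these numbers, the two split sizes should add up to total_num_examples, and each image will later become a vector of 28 × 28 = 784 values. The snippet below only does that arithmetic in plain Python; the variable names are illustrative and are not part of the tfds API, and the byte count assumes one uint8 byte per pixel with labels excluded.

```python
# Split sizes as reported in the DatasetInfo printout above.
splits = {"train": 60000, "test": 10000}
height, width, channels = 28, 28, 1

# Should match total_num_examples in the printout.
total_examples = sum(splits.values())

# Length of each image once it is flattened into a vector.
pixels_per_image = height * width * channels

# Raw pixel storage, assuming one uint8 byte per pixel (labels not counted).
raw_bytes = total_examples * pixels_per_image

print(total_examples)    # 70000
print(pixels_per_image)  # 784
print(raw_bytes / 1e6)   # 54.88 (megabytes)
```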
Let's start by taking a look at the training dataset:
mnist_train = mnist_builder.as_dataset(split="train")
We can visually plot some examples using the tfds.show_examples function:
fig = tfds.show_examples(mnist_train, info)
This gives the following figure:
Here you can see more clearly the grayscale edges of the digits, where anti-aliasing was applied in the original dataset to make the edges look less jagged (the colors are also inverted relative to the original example in Figure 4.1).
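Both effects have a simple numeric meaning for 8-bit grayscale pixels: inverting the colors maps each value v to 255 - v, and anti-aliased edge pixels are the ones strictly between pure black (0) and pure white (255). The toy 4 × 4 array below is made up for illustration and is not taken from MNIST.

```python
import numpy as np

# Hypothetical 4x4 uint8 "image": 0 is black, 255 is white, and the
# in-between values mimic anti-aliased edge pixels.
toy = np.array([
    [0,   0,   0, 0],
    [0, 128, 255, 0],
    [0,  64, 255, 0],
    [0,   0,   0, 0],
], dtype=np.uint8)

# Inverting colors maps each value v to 255 - v (white <-> black).
# No underflow is possible because every uint8 value is at most 255.
inverted = 255 - toy

# Anti-aliased pixels are neither pure black nor pure white.
gray_mask = (toy > 0) & (toy < 255)

print(inverted[1, 2])        # 0 (white became black)
print(int(gray_mask.sum()))  # 2 (the 128 and 64 pixels)
```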
We can also plot an individual image by taking one element from the dataset, reshaping it into a 28 × 28 array, casting it to a 32-bit float, and plotting it in grayscale:
for example in mnist_train.take(1):
    # Each element is a dict with "image" (28x28x1 uint8) and "label" (int64)
    image = tf.dtypes.cast(tf.reshape(example["image"], (28, 28)), tf.float32)
    plt.imshow(image.numpy(), cmap=plt.get_cmap("gray"))
    print("Label: %d" % example["label"].numpy())
This gives the following figure:
This is nice for visual inspection, but for the experiments in this chapter we actually need to flatten these images into vectors. To do so, we can use the dataset's map() method and then verify that the dataset is flattened; note that we also need to cast the pixels to floats for use in the RBM later. The RBM also assumes binary (0 or 1) inputs, so we need to rescale the pixels, which range from 0 to 255, into the range 0 to 1 (dividing by 256 keeps every value just below 1):
def flatten_image(x, label=True):
    # Reshape the (28, 28, 1) uint8 image into a (1, 784) row vector,
    # cast it to float32, and scale the pixel values into [0, 1)
    image = tf.divide(tf.dtypes.cast(tf.reshape(x["image"], (1, 28*28)), tf.float32), 256.0)
    if label:
        return (image, x["label"])
    else:
        return image

for image, label in mnist_train.map(flatten_image).take(1):
    print("Shape:", image.shape)
    # image has shape (1, 784), so imshow draws it as a single row of pixels
    plt.imshow(image.numpy(), cmap=plt.get_cmap("gray"))
    print("Label: %d" % label.numpy())
This gives a 1 × 784 row vector, which is the "flattened" version of the pixels of the digit "4":
Now that we have the MNIST data as a series of vectors, we are ready to start implementing an RBM to process this data and ultimately create a model capable of generating new images.
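To make the flattening step concrete without downloading anything, here is a NumPy-only sketch of what flatten_image does to a single image. The random array stands in for an MNIST digit, and the function name flatten_image_np is hypothetical. The key point is that the 28 × 28 × 1 array is read in row-major order, so the pixel at row r and column c lands at index 28 * r + c of the 1 × 784 vector, and reshaping back to 28 × 28 recovers the original layout exactly.

```python
import numpy as np


def flatten_image_np(image: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy analogue of flatten_image (label handling omitted).

    Reshapes a (28, 28, 1) uint8 image into a (1, 784) float32 row vector
    and divides by 256, so every value lies in [0, 1).
    """
    return image.reshape(1, 28 * 28).astype(np.float32) / 256.0


# Stand-in for one MNIST digit: random uint8 pixels with the same shape.
rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(28, 28, 1), dtype=np.uint8)

vector = flatten_image_np(image)
print(vector.shape, vector.dtype)  # (1, 784) float32

# Row-major (C-order) flattening: pixel (r, c) ends up at index 28 * r + c.
# Dividing by 256 is exact in floating point, so equality is safe here.
r, c = 5, 17
assert vector[0, 28 * r + c] == np.float32(image[r, c, 0]) / 256.0

# Reshaping back to 28 x 28 (as done for plotting) is lossless.
restored = (vector.reshape(28, 28) * 256.0).astype(np.uint8)
assert np.array_equal(restored, image[:, :, 0])
```

Dividing by 256 rather than 255 means a pure-white pixel maps to 255/256 (about 0.996) instead of exactly 1; dividing by 255.0 would reach 1 exactly, if that distinction matters for your RBM.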