Cifar-10 is a common dataset for object classification. The task is to classify a 32x32 RGB image (thus 32x32x3 = 3072 dimensions) into 10 classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, and truck.
This problem is more complex than the MNIST classification task, so network architectures for Cifar-10 tend to be larger (and/or deeper) than those for MNIST. (If you are a machine learning beginner, I recommend visiting the MNIST example before this page.)
Prerequisites for this example
- Download and locate Cifar-10 binary version dataset
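The binary version is distributed as cifar-10-binary.tar.gz on the official CIFAR-10 page (https://www.cs.toronto.edu/~kriz/cifar.html). As a sketch of the expected layout (the ../data location is just an example; adjust it to wherever you extracted the archive), the loader used below reads these files:

```shell
# Sketch: the files the loading code expects under the data directory.
# DATA_DIR is a hypothetical location; adjust to your setup.
DATA_DIR=../data
for i in 1 2 3 4 5; do
  echo "$DATA_DIR/data_batch_$i.bin"   # five training batches
done
echo "$DATA_DIR/test_batch.bin"        # one test batch
```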
Constructing Model
network<cross_entropy, adam> nn;

typedef convolutional_layer<activation::identity> conv;
typedef max_pooling_layer<relu> pool;

const int n_fmaps = 32;   // number of feature maps in the first two conv layers
const int n_fmaps2 = 64;  // number of feature maps in the third conv layer
const int n_fc = 64;      // number of hidden units in the fully-connected layer

nn << conv(32, 32, 5, 3, n_fmaps, padding::same)
   << pool(32, 32, n_fmaps, 2)
   << conv(16, 16, 5, n_fmaps, n_fmaps, padding::same)
   << pool(16, 16, n_fmaps, 2)
   << conv(8, 8, 5, n_fmaps, n_fmaps2, padding::same)
   << pool(8, 8, n_fmaps2, 2)
   << fully_connected_layer<activation::identity>(4 * 4 * n_fmaps2, n_fc)
   << fully_connected_layer<softmax>(n_fc, 10);
Loading Dataset
vector<label_t> train_labels, test_labels;
vector<vec_t> train_images, test_images;

// five training batches: data_batch_1.bin ... data_batch_5.bin
for (int i = 1; i <= 5; i++) {
    parse_cifar10(data_dir_path + "/data_batch_" + to_string(i) + ".bin",
                  &train_images, &train_labels, -1.0, 1.0, 0, 0);
}
parse_cifar10(data_dir_path + "/test_batch.bin",
              &test_images, &test_labels, -1.0, 1.0, 0, 0);
Grid Search
One of the most important hyperparameters in deep learning is the learning rate. To get a stable and better result, let's run a grid search over the learning rate. The entire code for training Cifar-10 is as follows:
#include <iostream>
#include "tiny_dnn/tiny_dnn.h"

using namespace std;
using namespace tiny_dnn;
using namespace tiny_dnn::activation;

template <typename N>
void construct_net(N& nn) {
    typedef convolutional_layer<activation::identity> conv;
    typedef max_pooling_layer<relu> pool;
    const int n_fmaps = 32;   // feature maps in the first two conv layers
    const int n_fmaps2 = 64;  // feature maps in the third conv layer
    const int n_fc = 64;      // hidden units in the fully-connected layer

    nn << conv(32, 32, 5, 3, n_fmaps, padding::same)
       << pool(32, 32, n_fmaps, 2)
       << conv(16, 16, 5, n_fmaps, n_fmaps, padding::same)
       << pool(16, 16, n_fmaps, 2)
       << conv(8, 8, 5, n_fmaps, n_fmaps2, padding::same)
       << pool(8, 8, n_fmaps2, 2)
       << fully_connected_layer<activation::identity>(4 * 4 * n_fmaps2, n_fc)
       << fully_connected_layer<softmax>(n_fc, 10);
}

void train_cifar10(string data_dir_path, double learning_rate, ostream& log) {
    network<cross_entropy, adam> nn;
    construct_net(nn);

    // load cifar-10 dataset
    vector<label_t> train_labels, test_labels;
    vector<vec_t> train_images, test_images;
    for (int i = 1; i <= 5; i++) {
        parse_cifar10(data_dir_path + "/data_batch_" + to_string(i) + ".bin",
                      &train_images, &train_labels, -1.0, 1.0, 0, 0);
    }
    parse_cifar10(data_dir_path + "/test_batch.bin",
                  &test_images, &test_labels, -1.0, 1.0, 0, 0);

    const int n_minibatch = 10;
    const int n_training_epochs = 30;
    nn.optimizer().alpha *= sqrt(n_minibatch) * learning_rate;

    // report test accuracy after every epoch
    auto on_enumerate_epoch = [&]() {
        result res = nn.test(test_images, test_labels);
        log << res.num_success << "/" << res.num_total << endl;
    };
    auto on_enumerate_minibatch = [&]() {};

    nn.train(train_images, train_labels, n_minibatch, n_training_epochs,
             on_enumerate_minibatch, on_enumerate_epoch);
}

int main(int argc, char** argv) {
    if (argc != 3) {
        cerr << "Usage : " << argv[0]
             << " arg[0]: path_to_data (example:../data)"
             << " arg[1]: learning rate (example:0.01)" << endl;
        return -1;
    }
    train_cifar10(argv[1], stod(argv[2]), cout);
}
Compile this file and try various learning rates:
./train your-cifar-10-data-directory 10.0
./train your-cifar-10-data-directory 1.0
./train your-cifar-10-data-directory 0.1
./train your-cifar-10-data-directory 0.01
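These runs can be scripted as a simple grid-search loop; a minimal sketch (the echo stands in for the actual ./train invocation, and the directory name is a placeholder for your setup):

```shell
DATA_DIR=your-cifar-10-data-directory   # placeholder path
for lr in 10.0 1.0 0.1 0.01; do
  # replace 'echo' with nothing to actually launch training:
  echo ./train "$DATA_DIR" "$lr"
done
```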
Note: If training is too slow, change the n_training_epochs, n_fmaps, and n_fmaps2 variables to smaller values.
If you see the following message, some network weights became infinite during training. This usually implies the learning rate is too large.
[Warning]Detected infinite value in weight. stop learning.
You will get about 70% accuracy with learning rate = 0.01.