# demo.git: vnfs/DAaaS/applications/sample-horovod-app/keras_mnist_advanced_modified.py
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
import tensorflow as tf
import horovod.keras as hvd

# Horovod: initialize Horovod.
hvd.init()
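
# After hvd.init(), hvd.size() reports the total number of worker processes,
# hvd.rank() this worker's global index, and hvd.local_rank() its index on the
# local node; these values drive the learning-rate scaling, data sharding and
# checkpointing logic below.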

# Horovod: pin each process to one GPU via its local rank (one GPU per process).
# In this modified version the GPU pinning lines are commented out, so the
# script runs on CPU; uncomment them to pin GPUs.
config = tf.ConfigProto()
# config.gpu_options.allow_growth = True
# config.gpu_options.visible_device_list = str(hvd.local_rank())
K.set_session(tf.Session(config=config))

batch_size = 128
num_classes = 10

# Enough epochs to demonstrate learning rate warmup and the reduction of the
# learning rate when training plateaus.
epochs = 24

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Determine how many batches there are in the train and test sets
train_batches = len(x_train) // batch_size
test_batches = len(x_test) // batch_size
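
# MNIST ships 60,000 training and 10,000 test images, so with batch_size = 128
# this gives train_batches = 468 and test_batches = 78.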

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

# Scale pixel values from the [0, 255] byte range to [0, 1] floats.
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
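
# For example, the label 3 becomes the one-hot vector
# [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].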

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# Horovod: adjust learning rate based on the number of workers.
opt = keras.optimizers.Adadelta(lr=1.0 * hvd.size())

# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt)
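
# DistributedOptimizer wraps the underlying optimizer so that gradients are
# averaged across all workers (via allreduce) before each weight update; the
# lr scaling by hvd.size() above compensates for the larger effective batch.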

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=opt,
              metrics=['accuracy'])

callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: This callback must be in the list before the ReduceLROnPlateau,
    # TensorBoard or other metrics-based callbacks.
    hvd.callbacks.MetricAverageCallback(),

    # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
    # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
    # the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
    hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, verbose=1),

    # Reduce the learning rate if training plateaus.
    keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1),
]

# Horovod: save checkpoints only on worker 0 to prevent other workers from
# corrupting them.
if hvd.rank() == 0:
    callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

# Set up ImageDataGenerators to perform data augmentation on the training images.
train_gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,
                               height_shift_range=0.08, zoom_range=0.08)
test_gen = ImageDataGenerator()
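# Note: the validation generator applies no augmentation, so the model is
# validated on unmodified test images.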

# Train the model.
# Horovod: the training will randomly sample 1 / N batches of training data and
# 3 / N batches of validation data on every worker, where N is the number of workers.
# Over-sampling the validation data helps increase the probability that every
# validation example will be evaluated.
model.fit_generator(train_gen.flow(x_train, y_train, batch_size=batch_size),
                    steps_per_epoch=train_batches // hvd.size(),
                    callbacks=callbacks,
                    epochs=epochs,
                    verbose=1,
                    validation_data=test_gen.flow(x_test, y_test, batch_size=batch_size),
                    validation_steps=3 * test_batches // hvd.size())
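
# For example, with 4 workers each worker runs steps_per_epoch = 468 // 4 = 117
# training batches and validation_steps = 3 * 78 // 4 = 58 validation batches
# per epoch, so collectively the workers still cover roughly one full pass over
# the training set.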

# Evaluate the model on the full test set.
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
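
# A typical launch, assuming Horovod's horovodrun CLI is installed and four
# worker processes are desired (adjust -np to the available slots):
#
#     horovodrun -np 4 python keras_mnist_advanced_modified.py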