#Importing Packages and dataset
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Input,Conv2D,MaxPooling2D,UpSampling2D,Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam,RMSprop
from tensorflow.keras import backend as K
import sklearn
from sklearn.model_selection import train_test_split
#Setting up the dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
#Reshaping the data as per the model requirement
x_train = x_train.reshape(-1, 28,28, 1)/255
x_test = x_test.reshape(-1, 28,28, 1)/255
#Splitting the data for validation
X,X_val,Y,Y_val = train_test_split(x_train,x_train,test_size=0.2)
#Define the autoencoder model
class AutoEncoder(tf.keras.Model):
def __init__(self):
super(AutoEncoder, self).__init__()
self.encoder = tf.keras.Sequential(
[
Conv2D(256, (3, 3), activation='relu', padding='same'),
Conv2D(128, (3, 3), activation='relu', padding='same'),
Conv2D(64, (3, 3), activation='relu', padding='same'),
MaxPooling2D(pool_size=(2, 2)),
Conv2D(32, (3, 3), activation='relu', padding='same'),
Conv2D(16, (3, 3), activation='relu', padding='same'),
Conv2D(8, (1, 1), activation='relu', padding='same'),
MaxPooling2D(pool_size=(2, 2)),
Conv2D(4, (3, 3), activation='relu', padding='same'),
Conv2D(2, (1, 1), activation='relu', padding='same'),
Conv2D(1, (1, 1), activation='relu', padding='same'),
]
)
self.decoder = tf.keras.Sequential(
[
Conv2D(2, (1, 1), activation='relu', padding='same'),
Conv2D(4, (3, 3), activation='relu', padding='same'),
Conv2D(8, (3, 3), activation='relu', padding='same'),
UpSampling2D(size=(2, 2)),
Conv2D(16, (3, 3), activation='relu', padding='same'),
Conv2D(32, (3, 3), activation='relu', padding='same'),
Conv2D(64, (3, 3), activation='relu', padding='same'),
UpSampling2D(size=(2, 2)),
Conv2D(128, (3, 3), activation='relu', padding='same'),
Conv2D(256, (3, 3), activation='relu', padding='same'),
Conv2D(1, (3, 3), activation='tanh', padding='same'),
]
)
self.train()
def call(self, inputs):
if self.e:
inputs = self.encoder(inputs)
if self.d:
inputs = self.decoder(inputs)
return inputs
def encode(self):
self.e = True
self.d = False
def decode(self):
self.e = False
self.d = True
def train(self):
self.e = True
self.d = True
#Create a object of model class
model = AutoEncoder()
#Compile the model with MSE loss and Adam optimizer
model.compile(loss='mean_squared_error', optimizer = Adam())
#Build the model
model.build(input_shape = (None,28,28,1))
#Print Summary
model.encoder.summary()
model.decoder.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 28, 28, 256) 2560
_________________________________________________________________
conv2d_1 (Conv2D) (None, 28, 28, 128) 295040
_________________________________________________________________
conv2d_2 (Conv2D) (None, 28, 28, 64) 73792
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 14, 14, 32) 18464
_________________________________________________________________
conv2d_4 (Conv2D) (None, 14, 14, 16) 4624
_________________________________________________________________
conv2d_5 (Conv2D) (None, 14, 14, 8) 136
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 8) 0
_________________________________________________________________
conv2d_6 (Conv2D) (None, 7, 7, 4) 292
_________________________________________________________________
conv2d_7 (Conv2D) (None, 7, 7, 2) 10
_________________________________________________________________
conv2d_8 (Conv2D) (None, 7, 7, 1) 3
=================================================================
Total params: 394,921
Trainable params: 394,921
Non-trainable params: 0
_________________________________________________________________
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_9 (Conv2D) (None, 7, 7, 2) 4
_________________________________________________________________
conv2d_10 (Conv2D) (None, 7, 7, 4) 76
_________________________________________________________________
conv2d_11 (Conv2D) (None, 7, 7, 8) 296
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 14, 14, 8) 0
_________________________________________________________________
conv2d_12 (Conv2D) (None, 14, 14, 16) 1168
_________________________________________________________________
conv2d_13 (Conv2D) (None, 14, 14, 32) 4640
_________________________________________________________________
conv2d_14 (Conv2D) (None, 14, 14, 64) 18496
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 28, 28, 64) 0
_________________________________________________________________
conv2d_15 (Conv2D) (None, 28, 28, 128) 73856
_________________________________________________________________
conv2d_16 (Conv2D) (None, 28, 28, 256) 295168
_________________________________________________________________
conv2d_17 (Conv2D) (None, 28, 28, 1) 2305
=================================================================
Total params: 396,009
Trainable params: 396,009
Non-trainable params: 0
_________________________________________________________________
model_saver = tf.keras.callbacks.ModelCheckpoint("model_weights.h5",
monitor='val_loss', verbose=1,
save_best_only=True,
save_weights_only=False,
mode='auto')
#Train the model
history = model.fit(X,Y,epochs = 3,validation_data = (X_val,Y_val), batch_size = 32,callbacks = [model_saver])
Epoch 1/3
2/1500 [..............................] - ETA: 41s - loss: 0.1102WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0149s vs `on_train_batch_end` time: 0.0375s). Check your callbacks.
1500/1500 [==============================] - ETA: 0s - loss: 0.0107WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0032s vs `on_test_batch_end` time: 0.0122s). Check your callbacks.
Epoch 00001: val_loss improved from inf to 0.00492, saving model to model_weights.h5
1500/1500 [==============================] - 83s 56ms/step - loss: 0.0107 - val_loss: 0.0049
Epoch 2/3
1500/1500 [==============================] - ETA: 0s - loss: 0.0038
Epoch 00002: val_loss improved from 0.00492 to 0.00326, saving model to model_weights.h5
1500/1500 [==============================] - 83s 55ms/step - loss: 0.0038 - val_loss: 0.0033
Epoch 3/3
1500/1500 [==============================] - ETA: 0s - loss: 0.0032
Epoch 00003: val_loss did not improve from 0.00326
1500/1500 [==============================] - 83s 55ms/step - loss: 0.0032 - val_loss: 0.0033
#Setup Encoding Mode
model.encode()
#Encode 10 sample images
encodings = model.predict(x_test[:10])
#Setup Decoding Mode
model.decode()
#Decode 10 sample images
decodings = model.predict(encodings)
#Plotting original and extracted images
w=10
h=10
fig=plt.figure(figsize=(10, 20))
columns = 2
rows = 10
j = 0
k = 0
for i in range(1, columns*rows +1):
if i%2 == 0:
img = decodings[j,...,0]
j+=1
else:
img = x_test[k,...,0]
k+=1
fig.add_subplot(rows, columns, i)
plt.imshow(img,cmap = 'gray')
plt.show()