In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

from tqdm import tqdm_notebook
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
torch.manual_seed(42)

Out[1]:
<torch._C.Generator at 0x1acb5e36bf0>
In [2]:
# CIFAR images are 32x32.
# We consider a simple 2x2 jigsaw puzzle made of four 16x16 pieces.
# Each entry of perm_inds is the (row, col) offset of a piece's top-left
# corner: the helpers below use it to slice dims 2 and 3 of the image
# batches, and imshow() displays dim 2 as the vertical axis. With that
# convention the piece indices are laid out as follows on screen:
# 0 3
# 1 2
perm_inds = [(0, 0), (16, 0), (16, 16), (0, 16)]
# Mean and std handed to transforms.Normalize and undone again in imshow().
# For inputs in [0, 1] this simply maps each pixel to [-1, 1].
img_mean = 0.5
img_std = 0.5

def permute2x2(images):
    """
    Splits the images into 2x2 = 4 pieces and randomly permutes the pieces.

    Args:
        images: batch of images indexed as [sample, channel, row, col]. The
            slicing below works on 16x16 pieces at the offsets listed in
            perm_inds, i.e. it expects 32x32 images.

    Returns:
        A tuple (p_images, perms). p_images has the same shape, dtype and
        device as images. perms is an (n, 4) LongTensor in which
        perms[i, j] is the slot that received the piece taken from slot j.
    """
    n = images.size(0)
    # empty_like keeps the dtype/device of the input instead of forcing a
    # CPU float32 buffer. Every quadrant is overwritten below because p is
    # a full permutation of the four slots.
    p_images = torch.empty_like(images)
    perms = torch.empty(n, 4, dtype=torch.long)
    for i in range(n):
        p = torch.randperm(4)
        # tolist() yields plain ints, so the Python list perm_inds is not
        # indexed with 0-dim tensors.
        for j, dst in enumerate(p.tolist()):
            sr, sc = perm_inds[j]
            tr, tc = perm_inds[dst]
            p_images[i, :, tr:tr+16, tc:tc+16] = images[i, :, sr:sr+16, sc:sc+16]
        perms[i, :] = p
    return (p_images, perms)

def restore2x2(p_images, perms):
    """
    Restores the original image from the pieces and the given permutation.

    Args:
        p_images: batch of shuffled images indexed as
            [sample, channel, row, col].
        perms: (n, 4) integer tensor; perms[i, j] is the slot of p_images[i]
            that holds the piece belonging to slot j.

    Returns:
        A tensor with the same shape, dtype and device as p_images in which
        slot j of every image is filled from slot perms[i, j].
    """
    # empty_like keeps the dtype/device of the input instead of forcing a
    # CPU float32 buffer. All four destination quadrants are written below,
    # so no uninitialised memory is exposed.
    images = torch.empty_like(p_images)
    # Plain Python ints: avoids indexing the list perm_inds with tensors.
    slots = perms.tolist()
    for i in range(images.size(0)):
        for j in range(4):
            sr, sc = perm_inds[j]
            tr, tc = perm_inds[slots[i][j]]
            images[i, :, sr:sr+16, sc:sc+16] = p_images[i, :, tr:tr+16, tc:tc+16]
    return images

def perm2vecmat2x2(perms):
    """
    Converts permutation vectors to vectorized assignment matrices.

    Args:
        perms: (n, 4) integer tensor; perms[i, k] is the slot assigned to
            piece k of sample i.

    Returns:
        An (n, 16) float tensor. Row i is the row-major flattening of the
        4x4 one-hot matrix m with m[k][perms[i, k]] = 1, i.e. m[k][j] = 1
        means that k is assigned to j.
    """
    n = perms.size()[0]
    mat = torch.zeros(n, 4, 4)
    # scatter_ needs an int64 index that lives on the same device as mat.
    index = perms.to(device=mat.device, dtype=torch.long).unsqueeze(2)
    # One vectorized write replaces the per-element Python loop:
    # mat[i, k, index[i, k, 0]] = 1 for every i and k.
    mat.scatter_(2, index, 1.)
    # Spell out 16 instead of -1: view(n, -1) is ambiguous, and therefore
    # raises, when the batch is empty.
    return mat.view(n, 16)

def vecmat2perm2x2(x):
    """
    Recovers permutation vectors from vectorized 4x4 assignment matrices.

    For every row of each matrix, the column holding the largest entry is
    reported as the assigned slot.
    Note: nothing here moves data between devices, so GPU tensors work too.
    """
    batch = x.size(0)
    matrices = x.view(batch, 4, 4)
    # max along the last axis yields (values, indices); keep the indices.
    return matrices.max(dim=2)[1]

def imshow(img, title=None):
    """
    Displays a torch image.

    Args:
        img: normalized image tensor in channel-first layout (the array is
            transposed to channel-last before plotting), for example the
            output of torchvision.utils.make_grid.
        title: optional title for the current axes; omitted when None.
    """
    # Undo the normalization: x * std + mean.
    img = img * img_std + img_mean
    # detach()/cpu() so tensors that live on the GPU or are attached to the
    # autograd graph can be converted to NumPy as well.
    npimg = img.detach().cpu().numpy()
    # Move the channel axis last, as plt.imshow expects.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    # Identity check is the idiomatic way to test for None.
    if title is not None:
        plt.title(title)

In [3]:
# Mini-batch size (presumably consumed by the data loaders — the loader
# construction is not visible in this cell; confirm).
batch_size = 32
# Dataset location (presumably the download/cache root — confirm).
dataset_dir = './data'

# Preprocessing pipeline: random horizontal and vertical flips, conversion
# to a tensor, then per-channel normalization with img_mean / img_std.
transform = transforms.Compose(
[transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.ToTensor(),
transforms.Normalize((img_mean, img_mean, img_mean), (img_std, img_std, img_std))])

# Plot some training samples.
shuffle=True, num_workers=0)
images, labels = next(dataiter)

# Shuffle the four pieces of every image in the sampled batch.
p_images, perms = permute2x2(images)

# Check the implementation of perm2vecmat2x2 and vecmat2perm2x2: converting
# the permutations to assignment matrices and back must give them back.
assert(vecmat2perm2x2(perm2vecmat2x2(perms)).equal(perms))

# Show permuted images.
plt.figure()
imshow(torchvision.utils.make_grid(p_images))
# Show restored images (undoing the shuffle with the known permutations).
plt.figure()
imshow(torchvision.utils.make_grid(restore2x2(p_images, perms)))

Files already downloaded and verified

In [4]:
# Prepare training, validation, and test samples.
# The split is positional, not random: the first (1 - validation_ratio) of
# the indices go to training and the remaining ones to validation.
validation_ratio = 0.1
total = len(train_set)
ind = list(range(total))
n_train = int(np.floor((1. - validation_ratio) * total))
train_ind, validation_ind = ind[:n_train], ind[n_train:]
# Each sampler draws only from its own index list, in random order.
train_subsampler = torch.utils.data.sampler.SubsetRandomSampler(train_ind)
validation_subsampler = torch.utils.data.sampler.SubsetRandomSampler(validation_ind)

sampler=train_subsampler, num_workers=0)
sampler=validation_subsampler, num_workers=0)


Number of training batches: 1407
Number of validation batches: 157

In [5]:
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

def sinkhorn(A, n_iter=4):
    """
    Sinkhorn iterations.

    Alternately rescales A so that its sums over dim 1 and then over dim 2
    equal one, pushing every matrix of the batch towards a doubly
    stochastic matrix.

    Args:
        A: tensor with at least three dimensions, e.g. (batch, 4, 4). The
            entries should be positive so that the sums are non-zero.
        n_iter: number of (dim 1, dim 2) normalization rounds.

    Returns:
        The normalized tensor, with the same shape as A. The argument is
        never modified; when n_iter is 0 it is returned unchanged.
    """
    # Out-of-place division on purpose: an in-place `/=` would overwrite the
    # caller's tensor and values that autograd has saved for the backward
    # pass (for example when A is a view of an activation output).
    for _ in range(n_iter):
        A = A / A.sum(dim=1, keepdim=True)
        A = A / A.sum(dim=2, keepdim=True)
    return A

class SimpleConvNet(nn.Module):
    """
    Small convolutional feature extractor whose weights are shared by all
    puzzle pieces.

    Turns a batch of 3 x 16 x 16 pieces into 128-dimensional feature vectors.
    """

    def __init__(self):
        super().__init__()
        # Shape comments are per sample: channels x height x width.
        # 3 x 16 x 16 -> 8 x 14 x 14
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
        # 8 x 14 x 14 -> 8 x 12 x 12
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=3)
        self.conv2_bn = nn.BatchNorm2d(num_features=8)
        # 8 x 12 x 12 -> 8 x 6 x 6
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # 8 x 6 x 6 -> 16 x 4 x 4
        self.conv3 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
        self.conv3_bn = nn.BatchNorm2d(num_features=16)
        # Flattened 16 * 4 * 4 values -> 128-d features
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=128)
        self.fc1_bn = nn.BatchNorm1d(num_features=128)
        # 128-d -> 128-d features
        self.fc2 = nn.Linear(in_features=128, out_features=128)
        self.fc2_bn = nn.BatchNorm1d(num_features=128)

    def forward(self, x):
        # Convolutional stage.
        maps = F.relu(self.conv1(x))
        maps = self.pool1(F.relu(self.conv2_bn(self.conv2(maps))))
        maps = F.relu(self.conv3_bn(self.conv3(maps)))
        # Flatten every sample before the dense stage.
        flat = maps.view(-1, 16 * 4 * 4)
        hidden = F.relu(self.fc1_bn(self.fc1(flat)))
        return F.relu(self.fc2_bn(self.fc2(hidden)))

class JigsawNet(nn.Module):
    """
    Network that solves 2x2 jigsaw puzzles.

    Every 16x16 piece goes through the same SimpleConvNet; the four feature
    vectors are concatenated and mapped to a vectorized 4 x 4 assignment
    matrix. When sinkhorn_iter > 0 the matrix is additionally normalized
    with that many Sinkhorn iterations.
    """

    def __init__(self, sinkhorn_iter=0):
        super().__init__()
        # Feature extractor shared by the four pieces.
        self.conv_net = SimpleConvNet()
        # Four 128-d feature vectors -> 256 hidden units.
        self.fc1 = nn.Linear(in_features=128 * 4, out_features=256)
        self.fc1_bn = nn.BatchNorm1d(num_features=256)
        # 256 hidden units -> 16 entries of the 4 x 4 assignment matrix.
        self.fc2 = nn.Linear(in_features=256, out_features=16)
        self.sinkhorn_iter = sinkhorn_iter

    def forward(self, x):
        # (row, col) offsets of the four pieces, in slot order 0..3.
        corners = ((0, 0), (16, 0), (16, 16), (0, 16))
        # Run each piece through the shared convolutional network.
        piece_features = [
            self.conv_net(x[:, :, r:r + 16, c:c + 16]) for r, c in corners
        ]
        merged = torch.cat(piece_features, dim=1)
        # Dense head; dropout is only active in training mode.
        merged = F.dropout(merged, p=0.1, training=self.training)
        hidden = F.relu(self.fc1_bn(self.fc1(merged)))
        scores = F.sigmoid(self.fc2(hidden))
        if self.sinkhorn_iter <= 0:
            return scores
        # Normalize as a batch of 4 x 4 matrices, then flatten again.
        balanced = sinkhorn(scores.view(-1, 4, 4), self.sinkhorn_iter)
        return balanced.view(-1, 16)

In [6]:
# Test helper
def compute_acc(p_pred, p_true, average=True):
    """
    Counts the samples whose permutation is predicted exactly.

    We require that the locations of all pieces are correctly predicted;
    the number of pieces is taken from the tensors instead of being fixed
    to four.
    Note: this function is compatible with GPU tensors.

    Args:
        p_pred: (n, k) tensor of predicted slots.
        p_true: (n, k) tensor of true slots.
        average: when True return the fraction of exact matches, otherwise
            their count.

    Returns:
        A 0-dim float tensor holding the fraction or the count.
    """
    # A sample is correct only if every position matches.
    exact = (p_pred == p_true).all(dim=1)
    # Remember to cast to float.
    n = torch.sum(exact.float())
    if average:
        return n / p_pred.size()[0]
    else:
        return n

# Training process
# For every epoch: one pass over train_loader with parameter updates, then
# one pass over validation_loader without updates; per-epoch loss and
# accuracy of both passes are recorded and returned.
# NOTE(review): the line that opens this definition is not visible here;
# model, criterion, optimizer, train_loader and validation_loader are
# presumably its leading parameters — confirm.
n_epochs=40, save_file_name=None):
# Per-epoch metric histories.
loss_history = []
val_loss_history = []
acc_history = []
val_acc_history = []
for epoch in range(n_epochs):
# The progress bar advances once per training batch.
with tqdm_notebook(total=len(train_loader), desc="Epoch {}".format(epoch + 1), unit='b', leave=False) as pbar:
# Training phase
model.train()
running_loss = 0.
n_correct_pred = 0
n_samples = 0
for i, data in enumerate(train_loader, 0):
# Class labels are ignored; the target is the random permutation.
inputs, _ = data
x_in, perms = permute2x2(inputs)
y_in = perm2vecmat2x2(perms)
n_samples += inputs.size()[0]
if is_cuda_available:
x_in, y_in = Variable(x_in.cuda()), Variable(y_in.cuda())
perms = Variable(perms.cuda())
else:
x_in, y_in = Variable(x_in), Variable(y_in)
perms = Variable(perms)
outputs = model(x_in)
# NOTE(review): `.data[0]` indexes the result tensor; if compute_acc and the
# criterion return 0-dim tensors on the target PyTorch version this needs
# `.item()` instead — confirm.
n_correct_pred += compute_acc(vecmat2perm2x2(outputs), perms, False).data[0]
loss = criterion(outputs, y_in)
# NOTE(review): no optimizer.zero_grad() call is visible in this loop, so
# gradients would accumulate across iterations — confirm it was not lost.
loss.backward()
optimizer.step()
# Weight the batch loss by the batch size so that dividing by n_samples
# gives a per-sample figure (assumes the criterion returns a batch mean).
running_loss += loss.data[0] * x_in.size()[0]
pbar.update(1)
loss_history.append(running_loss / n_samples)
acc_history.append(n_correct_pred / n_samples)

# Validation phase
# NOTE(review): this pass is not wrapped in torch.no_grad(), so autograd
# graphs are presumably still built during evaluation — confirm.
model.eval()
running_loss = 0.
n_correct_pred = 0
n_samples = 0
for i, data in enumerate(validation_loader, 0):
inputs, _ = data
# A fresh random permutation is drawn for every validation batch.
x_in, perms = permute2x2(inputs)
y_in = perm2vecmat2x2(perms)
n_samples += inputs.size()[0]
if is_cuda_available:
x_in, y_in = Variable(x_in.cuda()), Variable(y_in.cuda())
perms = Variable(perms.cuda())
else:
x_in, y_in = Variable(x_in), Variable(y_in)
perms = Variable(perms)
outputs = model(x_in)
n_correct_pred += compute_acc(vecmat2perm2x2(outputs), perms, False).data[0]
loss = criterion(outputs, y_in)
running_loss += loss.data[0] * x_in.size()[0]
val_loss_history.append(running_loss / n_samples)
val_acc_history.append(n_correct_pred / n_samples)

# Update the progress bar.
print("Epoch {0:03d}: loss={1:.4f}, val_loss={2:.4f}, acc={3:.2%}, val_acc={4:.2%}".format(
epoch + 1, loss_history[-1], val_loss_history[-1], acc_history[-1], val_acc_history[-1]))
print('Training completed')
# Collect the per-epoch histories under the keys used by the plotting cell.
history = {
'loss': loss_history,
'val_loss': val_loss_history,
'acc': acc_history,
'val_acc': val_acc_history
}
# Save the model when requested.
# The checkpoint holds the history plus the model and optimizer state dicts.
if save_file_name is not None:
torch.save({
'history': history,
'model': model.state_dict(),
'optimizer': optimizer.state_dict()
}, save_file_name)
return history

# Test process
# NOTE(review): the line that opens this definition is not visible here;
# model and test_loader are presumably its parameters — confirm.
# Compute the accuracy
# running_acc accumulates the number of exactly solved puzzles and n the
# number of samples seen; their ratio is returned.
running_acc = 0.
n = 0
model.eval()
for i, data in enumerate(test_loader, 0):
inputs, _ = data
# A fresh random permutation is drawn for every batch.
x_in, perms = permute2x2(inputs)
# y_in is built and moved to the device but not read again in this routine.
y_in = perm2vecmat2x2(perms)
if is_cuda_available:
x_in, y_in = Variable(x_in.cuda()), Variable(y_in.cuda())
else:
x_in, y_in = Variable(x_in), Variable(y_in)
pred = model(x_in)
# Predictions are brought back with .cpu() before being compared with perms.
perms_pred = vecmat2perm2x2(pred.cpu().data)
running_acc += compute_acc(perms_pred, perms, False)
n += x_in.size()[0]
acc = running_acc / n
return acc

In [ ]:
# Hyper-parameters of this run.
n_epochs = 100
sinkhorn_iter = 5

# Create the neural network.
model = JigsawNet(sinkhorn_iter=sinkhorn_iter)
# Move the model to the GPU when one is available; the same flag is read by
# the training, test and demo code to decide where to put the inputs.
is_cuda_available = torch.cuda.is_available();
if is_cuda_available:
model.cuda()

# Count the scalar entries of all parameter tensors.
n_params = 0
for p in model.parameters():
n_params += np.prod(p.size())
print('# of parameters: {}'.format(n_params))

# We use binary cross-entropy loss here.
criterion = nn.BCELoss()

# Train
# The checkpoint name encodes the epoch count and the number of Sinkhorn
# iterations.
save_file_name = 'jigsaw_cifar100_e{}_s{}.pk'.format(n_epochs, sinkhorn_iter)
n_epochs=n_epochs, save_file_name=save_file_name)

In [8]:
# One figure per metric, each showing the training and validation curves
# recorded in `history`.
for train_key, val_key, y_label in (('loss', 'val_loss', 'Loss'),
                                    ('acc', 'val_acc', 'Accuracy')):
    plt.figure()
    plt.plot(history[train_key])
    plt.plot(history[val_key])
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.legend(['Train', 'Validation'])
    plt.show()

In [9]:
# Calculate accuracy
# Here training accuracy will be higher because dropout is disabled

Training accuracy: 0.8245777777777777
Validation accuracy: 0.7924
Test accuracy: 0.8001

In [10]:
# Let us try some test images.
# Use the built-in next(): iterators are only required to implement
# __next__, and the sampling cell earlier in this notebook already uses
# next(dataiter).
test_images, _ = next(test_data_iter)
p_images, perms = permute2x2(test_images)

# Show permuted images.
plt.figure()
imshow(torchvision.utils.make_grid(p_images))
plt.title('Inputs')
plt.show()

# Predict the permutation of every shuffled image with the trained model.
model.eval()
if is_cuda_available:
    pred = model(Variable(p_images.cuda()))
else:
    pred = model(Variable(p_images))
# Bring the prediction back to the CPU before decoding the permutations.
perms_pred = vecmat2perm2x2(pred.cpu().data)

# Show restored images (rebuilt from the predicted permutations).
plt.figure()
imshow(torchvision.utils.make_grid(restore2x2(p_images, perms_pred)))
plt.title('Restored')
plt.show()