# Mianzhi Wang

Ph.D. in Electrical Engineering

# Teaching a Neural Network to Solve Jigsaw Puzzles

While working with graph-structured data, I have recently been thinking about permutation invariance, which arises naturally in graph-related problems: two graphs are still considered the same even if we index their nodes in a different order. Many functions are not permutation invariant. For instance, the affine transform $f(\mathbf{x}) = \mathbf{A}\mathbf{x} + \mathbf{b}$ is not. On the other hand, $f(x,y,z) = xyz$ is permutation invariant because the ordering of the inputs does not affect the result.
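The difference between the two examples can be spot-checked numerically. The sketch below is a hypothetical helper, `is_permutation_invariant`, that evaluates a function on every reordering of one input vector; the particular $\mathbf{A}$ and $\mathbf{b}$ are arbitrary choices for illustration. Passing for a single input does not prove invariance in general, but one failure is enough to show a function is not invariant.

```python
import itertools

import numpy as np


def is_permutation_invariant(f, x):
    """Spot check: does f return the same value for every reordering of x?

    Passing for one x does not prove invariance for all inputs,
    but a single failure proves f is not permutation invariant.
    """
    x = np.asarray(x, dtype=float)
    reference = f(x)
    return all(
        np.allclose(f(x[list(p)]), reference)
        for p in itertools.permutations(range(len(x)))
    )


def product(v):
    # f(x, y, z) = xyz
    return np.prod(v)


# Arbitrary (hypothetical) A and b for the affine map f(x) = Ax + b
A = np.array([[1.0, 2.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 3.0]])
b = np.zeros(3)


def affine(v):
    return A @ v + b


print(is_permutation_invariant(product, [1, 2, 3]))  # True
print(is_permutation_invariant(affine, [1, 2, 3]))   # False
```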

I decided to run some experiments to check whether neural networks can learn about permutation invariance. The experiment is extremely simple: split an image into puzzle pieces, feed them into the neural network, and let the network learn to output the ordering of these pieces that restores the original image. The network must handle all possible permutations of the puzzle pieces and restore the order correctly. For simplicity, I use the CIFAR-100 dataset in this article and only consider 2x2 jigsaw puzzles.
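A minimal sketch of how such a training sample could be generated, assuming images stored as NumPy arrays in (channels, height, width) layout; CIFAR-100 images are 32x32, so each piece is 16x16. The helpers `split_pieces`, `assemble`, `make_puzzle`, and `restore` are hypothetical names, not the author's code. The slot order (top-left, bottom-left, bottom-right, top-right) is chosen to match the slicing used in `JigsawNet.forward` later in this article.

```python
import numpy as np

# (row block, column block) of each slot: top-left, bottom-left,
# bottom-right, top-right (same order as the slicing in JigsawNet.forward)
SLOTS = [(0, 0), (1, 0), (1, 1), (0, 1)]


def split_pieces(img, size=16):
    """Return the four (C, size, size) pieces of a (C, 2*size, 2*size) image in slot order."""
    return [img[:, r * size:(r + 1) * size, c * size:(c + 1) * size] for r, c in SLOTS]


def assemble(pieces, size=16):
    """Place pieces[k] into slot k and return the full image."""
    channels = pieces[0].shape[0]
    out = np.empty((channels, 2 * size, 2 * size), dtype=pieces[0].dtype)
    for k, (r, c) in enumerate(SLOTS):
        out[:, r * size:(r + 1) * size, c * size:(c + 1) * size] = pieces[k]
    return out


def make_puzzle(img, rng, size=16):
    """Shuffle the pieces; slot i of the result holds the piece whose correct location is perm[i]."""
    pieces = split_pieces(img, size)
    perm = rng.permutation(4)
    return assemble([pieces[p] for p in perm], size), perm


def restore(shuffled, perm, size=16):
    """Undo make_puzzle: move the piece in slot i back to location perm[i]."""
    pieces = split_pieces(shuffled, size)
    ordered = [None] * 4
    for i, loc in enumerate(perm):
        ordered[loc] = pieces[i]
    return assemble(ordered, size)


rng = np.random.default_rng(0)
img = rng.random((3, 32, 32))  # stand-in for a CIFAR-100 image
shuffled, perm = make_puzzle(img, rng)
print(perm, np.array_equal(restore(shuffled, perm), img))
```

The vector `perm` is exactly the permutation vector the network is asked to predict.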


## The Neural Network

The neural network structure is very simple. Each puzzle piece is a 16x16 color image, which is first fed into the same convolutional neural network (shared by all four pieces) to produce a feature vector. The four feature vectors are then concatenated into a single vector. Finally, this vector is fed into fully connected layers to obtain the final output.

The tricky part is the encoding of the output. The output should encode a four-element permutation vector that assigns each piece to its correct location. Did the word "assign" give you any ideas? Yes, a permutation vector can be encoded as an assignment matrix $\mathbf{A}$, where $A_{ij} = 1$ if and only if the $i$-th piece is assigned to the $j$-th location (i.e., the $i$-th entry of the permutation vector is $j$). For instance, if the permutation vector is $[0, 3, 2, 1]$, its assignment matrix representation is given by

$\mathbf{A} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 1 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix}.$
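The conversion in both directions is a one-liner each way. The sketch below uses hypothetical helper names, `perm_to_assignment` and `assignment_to_perm`, and reproduces the matrix above from $[0, 3, 2, 1]$; going back takes the argmax of each row.

```python
import numpy as np


def perm_to_assignment(perm):
    """A[i, perm[i]] = 1: piece i is assigned to location perm[i]."""
    n = len(perm)
    A = np.zeros((n, n))
    A[np.arange(n), perm] = 1.0
    return A


def assignment_to_perm(A):
    """Recover the permutation vector by taking the argmax of each row."""
    return np.argmax(A, axis=1)


A = perm_to_assignment([0, 3, 2, 1])
print(A.astype(int))
print(assignment_to_perm(A))  # [0 3 2 1]
```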

Another nice property of the assignment matrix is that its elements are either zero or one. This allows us to use the sigmoid function as the activation function of the output layer and binary cross entropy as the loss function. The output layer of the neural network is therefore a 16x1 vector representing the vectorized (flattened) 4x4 assignment matrix.
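To make the loss concrete, here is a NumPy sketch of mean binary cross entropy over the 16 outputs, assuming the target assignment matrix is flattened row by row (which matches the `x.view(-1, 4, 4)` reshape in the network code). The logits are made-up values for illustration. In PyTorch, `torch.nn.BCELoss` with its default `reduction='mean'` computes the same average over all elements.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def binary_cross_entropy(pred, target, eps=1e-7):
    """Mean binary cross entropy over all entries (predictions clipped for stability)."""
    pred = np.clip(pred, eps, 1.0 - eps)
    return -np.mean(target * np.log(pred) + (1.0 - target) * np.log(1.0 - pred))


# Target: assignment matrix of the permutation [0, 3, 2, 1], flattened row by row
target = np.zeros((4, 4))
target[np.arange(4), [0, 3, 2, 1]] = 1.0
target = target.reshape(16)

# Hypothetical logits from the last linear layer
uninformed = np.zeros(16)                      # sigmoid gives 0.5 everywhere
confident = np.where(target == 1, 4.0, -4.0)   # large where A_ij = 1, small elsewhere

print(binary_cross_entropy(sigmoid(uninformed), target))  # log(2), about 0.693
print(binary_cross_entropy(sigmoid(confident), target))   # about 0.018
```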

There is another issue. While the sigmoid function ensures that the outputs are between zero and one, it cannot enforce the row and column constraints of assignment matrices: each row and each column must contain exactly one 1. Unfortunately, it is not easy to force each element to be exactly 0 or 1. Nevertheless, it is possible to relax the constraints and look for a "soft" assignment matrix whose rows and columns each sum to one. Even with this relaxation, the raw output of our neural network cannot guarantee this. Thankfully, because I had worked on optimal transport related problems before, I happened to know a fix: Sinkhorn iterations [1] (if you are interested, definitely check out the lightspeed optimal transport distance computation paper [2]). The idea behind Sinkhorn iterations is very simple: alternately normalize the rows and columns until convergence. For a square matrix with strictly positive entries (which the sigmoid guarantees), this procedure converges, and a few iterations are often sufficient in practice. Therefore, a fixed number (e.g., 5) of Sinkhorn iterations is applied after the sigmoid activation of the neural network. We then recover the permutation by taking the argmax of each row.
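A small NumPy sketch of the alternating normalization on a single matrix (the hypothetical `sinkhorn_np` mirrors the batched PyTorch version in the next section). The random matrix stands in for sigmoid outputs; after enough iterations both row and column sums approach one. The second part shows a noisy version of the assignment matrix for $[0, 3, 2, 1]$, from which the row-wise argmax recovers the permutation. Note that row-wise argmax on a soft matrix is not guaranteed to produce a valid permutation in general.

```python
import numpy as np


def sinkhorn_np(M, n_iter=5):
    """Alternately normalize the columns and rows of a positive square matrix."""
    M = np.array(M, dtype=float)
    for _ in range(n_iter):
        M = M / M.sum(axis=0, keepdims=True)  # each column sums to one
        M = M / M.sum(axis=1, keepdims=True)  # each row sums to one
    return M


rng = np.random.default_rng(0)
raw = rng.uniform(0.05, 0.95, size=(4, 4))  # stand-in for sigmoid outputs
soft = sinkhorn_np(raw, n_iter=200)
print(soft.sum(axis=0), soft.sum(axis=1))   # both close to all ones

# Noisy version of the assignment matrix for the permutation [0, 3, 2, 1]
P = np.eye(4)[[0, 3, 2, 1]]
noisy = 0.2 + 0.7 * P
print(sinkhorn_np(noisy).argmax(axis=1))    # [0 3 2 1]
```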

## Implementation Details

The neural network is implemented using PyTorch.

The implementation of the Sinkhorn iterations is straightforward:

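```python
import torch


def sinkhorn(A, n_iter=4):
    """
    Sinkhorn iterations.

    :param A: (n_batches, d, d) tensor with positive entries
    :param n_iter: Number of iterations.
    """
    for i in range(n_iter):
        # Out-of-place division: dividing the sigmoid output in place
        # would break autograd during the backward pass.
        # Normalize each column to sum to one (sum over the row index).
        A = A / A.sum(dim=1, keepdim=True)
        # Normalize each row to sum to one (sum over the column index).
        A = A / A.sum(dim=2, keepdim=True)
    return A
```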


The code for the neural network is given below:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleConvNet(nn.Module):
    """
    A simple convolutional neural network shared among all pieces.
    """
    def __init__(self):
        super().__init__()
        # 3 x 16 x 16 input
        self.conv1 = nn.Conv2d(3, 8, 3)
        # 8 x 14 x 14
        self.conv2 = nn.Conv2d(8, 8, 3)
        self.conv2_bn = nn.BatchNorm2d(8)
        # 8 x 12 x 12
        self.pool1 = nn.MaxPool2d(2, 2)
        # 8 x 6 x 6
        self.conv3 = nn.Conv2d(8, 16, 3)
        self.conv3_bn = nn.BatchNorm2d(16)
        # 16 x 4 x 4
        self.fc1 = nn.Linear(16 * 4 * 4, 128)
        self.fc1_bn = nn.BatchNorm1d(128)
        # 128-d features
        self.fc2 = nn.Linear(128, 128)
        self.fc2_bn = nn.BatchNorm1d(128)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2_bn(self.conv2(x)))
        x = self.pool1(x)
        x = F.relu(self.conv3_bn(self.conv3(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1_bn(self.fc1(x)))
        x = F.relu(self.fc2_bn(self.fc2(x)))
        return x


class JigsawNet(nn.Module):
    """
    A neural network that solves 2x2 jigsaw puzzles.
    """
    def __init__(self, sinkhorn_iter=0):
        super().__init__()
        self.conv_net = SimpleConvNet()
        self.fc1 = nn.Linear(128 * 4, 256)
        self.fc1_bn = nn.BatchNorm1d(256)
        # 4 x 4 assignment matrix
        self.fc2 = nn.Linear(256, 16)
        self.sinkhorn_iter = sinkhorn_iter

    def forward(self, x):
        # Split the 32 x 32 input into four 16 x 16 pieces (top-left,
        # bottom-left, bottom-right, top-right) and pass each one into
        # the same convolutional neural network.
        x0 = self.conv_net(x[:, :, 0:16, 0:16])
        x1 = self.conv_net(x[:, :, 16:32, 0:16])
        x2 = self.conv_net(x[:, :, 16:32, 16:32])
        x3 = self.conv_net(x[:, :, 0:16, 16:32])
        # Concatenate the four 128-d feature vectors into one 512-d vector
        x = torch.cat([x0, x1, x2, x3], dim=1)
        # Dense layers
        x = F.dropout(x, p=0.1, training=self.training)
        x = F.relu(self.fc1_bn(self.fc1(x)))
        # torch.sigmoid replaces the deprecated F.sigmoid
        x = torch.sigmoid(self.fc2(x))
        if self.sinkhorn_iter > 0:
            # Reshape to (batch, 4, 4), normalize, then flatten back
            x = x.view(-1, 4, 4)
            x = sinkhorn(x, self.sinkhorn_iter)
            x = x.view(-1, 16)
        return x
```
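The shape comments in `SimpleConvNet` follow from the standard output-size formula for unpadded, unit-dilation convolution and pooling layers, $\lfloor (s + 2p - k)/\text{stride} \rfloor + 1$. The hypothetical helper `conv_out` below walks a 16x16 piece through the layers and confirms the 256 inputs of the shared `fc1` and the 512 inputs of `JigsawNet.fc1`.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d / MaxPool2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


s = 16                          # 16 x 16 piece
s = conv_out(s, 3)              # conv1 -> 14
s = conv_out(s, 3)              # conv2 -> 12
s = conv_out(s, 2, stride=2)    # pool1 -> 6
s = conv_out(s, 3)              # conv3 -> 4
flat = 16 * s * s               # 16 channels -> inputs to SimpleConvNet.fc1
print(s, flat, 4 * 128)         # 4 256 512
```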


## The Results

After training for 100 epochs using the Adam optimizer with a batch size of 32, the test accuracy is around 80% (a puzzle counts as solved only if the positions of all four pieces are predicted correctly), which is not that bad. Here are the results sampled from the test set:

It can be observed that most of the puzzles are solved.
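The accuracy above is an all-or-nothing metric per puzzle. A NumPy sketch of that metric, assuming the network outputs are vectorized assignment matrices and the labels are permutation vectors; `puzzle_accuracy` is a hypothetical name. Each prediction is decoded with a row-wise argmax, and a puzzle counts only if every piece lands in the right place.

```python
import numpy as np


def puzzle_accuracy(pred, target_perm, n_pieces=4):
    """Fraction of puzzles whose predicted ordering is entirely correct.

    pred: (batch, n_pieces * n_pieces) outputs (vectorized assignment matrices)
    target_perm: (batch, n_pieces) ground-truth permutation vectors
    """
    pred_perm = np.asarray(pred).reshape(-1, n_pieces, n_pieces).argmax(axis=2)
    return np.mean(np.all(pred_perm == np.asarray(target_perm), axis=1))


targets = np.array([[0, 3, 2, 1], [0, 1, 2, 3]])
# First puzzle predicted perfectly, second one with two pieces swapped
preds = np.stack([np.eye(4)[[0, 3, 2, 1]], np.eye(4)[[1, 0, 2, 3]]]).reshape(2, 16)
print(puzzle_accuracy(preds, targets))  # 0.5
```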

I also tried to train the neural network to solve 3x3 jigsaw puzzles on the CelebA dataset (the output of the network is a 9x9 assignment matrix). Here are some samples from the test set:

Overall, the neural network did learn to handle the permutations to some extent.
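To put the 3x3 experiment in perspective, a quick calculation shows how the problem grows with the grid size $n$: there are $n^2$ pieces, the assignment matrix has $(n^2)^2$ entries, and there are $(n^2)!$ possible orderings. The helper `puzzle_stats` is a hypothetical name for illustration.

```python
import math


def puzzle_stats(n):
    """For an n x n puzzle: (pieces, assignment matrix entries, possible orderings)."""
    pieces = n * n
    return pieces, pieces * pieces, math.factorial(pieces)


for n in (2, 3):
    print(n, puzzle_stats(n))
# 2 (4, 16, 24)
# 3 (9, 81, 362880)
```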