Teaching a Neural Network to Solve Jigsaw Puzzles
While working with graph-structured data, I have been thinking about permutation invariance recently. Permutation invariance arises naturally in graph-related problems: two graphs are still considered the same even if we index their nodes in a different order. Many functions are not permutation invariant. For instance, an affine transform $f(\mathbf{x}) = A\mathbf{x} + \mathbf{b}$ is generally not. On the other hand, the sum of the inputs, $f(\mathbf{x}) = \sum_i x_i$, is permutation invariant, as the ordering of the input does not affect the result.
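To make the distinction concrete, here is a minimal sketch that evaluates two functions on every ordering of the same input. The function names, weights, and bias are hypothetical values chosen only for illustration.

```python
from itertools import permutations


def total(xs):
    """Sum of the inputs: a permutation-invariant function."""
    return sum(xs)


def affine(xs, weights=(1.0, 2.0, 3.0), bias=0.5):
    """Scalar affine map w.x + b with unequal weights (hypothetical values)."""
    return sum(w * x for w, x in zip(weights, xs)) + bias


x = (4.0, 5.0, 6.0)

# Collect the distinct outputs over all 3! = 6 orderings of x.
sum_values = {total(p) for p in permutations(x)}
affine_values = {affine(p) for p in permutations(x)}

print(len(sum_values))         # 1: every ordering gives the same sum
print(len(affine_values) > 1)  # True: the affine output depends on the ordering
```

An affine map with identical weights would also be invariant, which is why the statement holds only in general and not for every choice of parameters.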
I decided to conduct some experiments on neural networks to check if they can learn about permutation invariance. The experiment is extremely simple: split an image into puzzle pieces and feed them into the neural network, and let the neural network learn to output the correct ordering of these pieces that can be used to restore the original image. The neural network must be able to handle all possible permutations of the puzzle pieces and restore the order correctly. For simplicity, I use the CIFAR-100 dataset in this article and only consider 2x2 jigsaw puzzles.
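The data generation step is not shown in this article, so the following is only a sketch of how a puzzle and its target ordering could be produced. It uses plain nested lists instead of image tensors, numbers the locations in row-major order, and the helper names `split_into_pieces` and `make_puzzle` are hypothetical.

```python
import random


def split_into_pieces(image, grid=2):
    """Split a square 2-D image (list of rows) into grid x grid pieces.

    Pieces are returned in row-major order of their original location.
    """
    size = len(image) // grid
    pieces = []
    for r in range(grid):
        for c in range(grid):
            rows = image[r * size:(r + 1) * size]
            pieces.append([row[c * size:(c + 1) * size] for row in rows])
    return pieces


def make_puzzle(image, grid=2, rng=random):
    """Shuffle the pieces and return them with their correct locations.

    order[k] is the original location of the k-th shuffled piece, which is
    exactly the permutation the network is asked to predict.
    """
    pieces = split_into_pieces(image, grid)
    order = list(range(len(pieces)))
    rng.shuffle(order)
    shuffled = [pieces[i] for i in order]
    return shuffled, order


# Toy 4x4 "image" whose pixel values are 0..15, split into four 2x2 pieces.
image = [[r * 4 + c for c in range(4)] for r in range(4)]
shuffled, order = make_puzzle(image, rng=random.Random(0))

# Placing each shuffled piece back at its recorded location restores the image.
restored = [None] * len(order)
for k, location in enumerate(order):
    restored[location] = shuffled[k]
print(restored == split_into_pieces(image))  # True
```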
The Neural Network
The neural network structure is very simple. Each puzzle piece is a 16x16 color image, which is first fed into the same convolutional neural network to produce a feature vector. The four feature vectors are then concatenated into a single column vector. Finally, this vector is fed into fully connected layers to obtain the final output.
The tricky part is the encoding of the output. The output should encode a four-element permutation vector which assigns each piece to its correct location. Did the word "assign" give you any idea? Yes, a permutation vector can be encoded as an assignment matrix $A$, where $A_{ij} = 1$ if and only if the $i$-th piece is assigned to the $j$-th location. For instance, if the permutation vector is $(2, 3, 1, 4)$, its assignment matrix representation is then given by

$$A = \begin{bmatrix} 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix}.$$
Another nice property of the assignment matrix is that its elements are either zero or one. This enables us to use the sigmoid function as the activation function of the output layer and binary cross entropy loss as the loss function. Now, the output layer of the neural network will be a 16x1 vector representing the vectorized assignment matrix.
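As a rough illustration of this encoding, the sketch below turns a permutation vector into a vectorized 16-element target and evaluates a mean binary cross entropy against two hypothetical predictions. It uses zero-based indices and plain Python instead of PyTorch; the helper names and the `eps` clamp are assumptions made for the example, not the code used in the experiment.

```python
import math


def to_assignment_matrix(perm):
    """A[i][j] = 1 iff piece i is assigned to location j (zero-based)."""
    n = len(perm)
    return [[1 if perm[i] == j else 0 for j in range(n)] for i in range(n)]


def flatten(matrix):
    """Vectorize the matrix row by row."""
    return [value for row in matrix for value in row]


def binary_cross_entropy(pred, target, eps=1e-7):
    """Mean binary cross entropy; predictions are clamped away from 0 and 1."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)
        total -= t * math.log(p) + (1 - t) * math.log(1.0 - p)
    return total / len(pred)


perm = [1, 2, 0, 3]                  # piece 0 -> location 1, piece 1 -> location 2, ...
target = flatten(to_assignment_matrix(perm))

confident = [0.9 if t == 1 else 0.05 for t in target]  # close to the target
unsure = [0.5] * len(target)                           # no information at all

print(len(target))                                   # 16
print(binary_cross_entropy(confident, target)
      < binary_cross_entropy(unsure, target))        # True
print(round(binary_cross_entropy(unsure, target), 3))  # 0.693, i.e. log(2)
```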
There is another issue. While the sigmoid function ensures that the outputs are between zero and one, it cannot enforce the constraints on the rows and columns of an assignment matrix: each row and each column must contain exactly one 1. Unfortunately, it is not easy to force each element to be exactly 0 or 1. Nevertheless, it is possible to relax the constraints and ask for a "soft" assignment matrix in which each row and each column sums to one. Even with this relaxation, the raw output of our neural network cannot guarantee this. Thankfully, because I was working on optimal transport related problems before, I happen to know a fix: Sinkhorn iterations (Sinkhorn and Knopp, 1967; if you are interested, definitely check out the lightspeed optimal transport distance computation paper by Cuturi, 2013). The idea behind the Sinkhorn iteration is very simple: alternately normalize the rows and columns until convergence. Under certain conditions, a few iterations are sufficient. Therefore, a fixed number (e.g., 5) of Sinkhorn iterations is applied after the sigmoid activation of the neural network. We then recover the permutation by taking the argmax of each row.
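The decoding step can be sketched in a few lines. The example below is plain Python with made-up matrices, and the helper names are hypothetical. It also shows a caveat worth keeping in mind: a row-wise argmax is only guaranteed to give a valid permutation when no two rows pick the same column.

```python
def decode_permutation(soft_assignment):
    """Row-wise argmax: piece i goes to the column with the largest score."""
    return [max(range(len(row)), key=row.__getitem__) for row in soft_assignment]


def is_valid_permutation(perm):
    """True if every location is used exactly once."""
    return sorted(perm) == list(range(len(perm)))


# A made-up soft assignment matrix whose rows and columns all sum to one.
soft = [
    [0.05, 0.80, 0.10, 0.05],
    [0.10, 0.05, 0.70, 0.15],
    [0.75, 0.10, 0.05, 0.10],
    [0.10, 0.05, 0.15, 0.70],
]
perm = decode_permutation(soft)
print(perm)                        # [1, 2, 0, 3]
print(is_valid_permutation(perm))  # True

# A matrix with normalized rows only: both rows prefer column 0.
ambiguous = [
    [0.60, 0.40],
    [0.55, 0.45],
]
print(is_valid_permutation(decode_permutation(ambiguous)))  # False
```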
The neural network is implemented using PyTorch.
The implementation of the Sinkhorn iterations is straightforward:
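    def sinkhorn(A, n_iter=4):
        """
        Sinkhorn iterations.

        :param A: (n_batches, d, d) tensor
        :param n_iter: Number of iterations.
        """
        for _ in range(n_iter):
            # Divide out of place: an in-place update would overwrite the
            # sigmoid output that autograd needs for the backward pass.
            A = A / A.sum(dim=1, keepdim=True)  # each column sums to one
            A = A / A.sum(dim=2, keepdim=True)  # each row sums to one
        return A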
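To see what the alternating normalization does numerically, here is a dependency-free sketch on a single 3x3 matrix stored as nested lists. It is not the PyTorch function above, the name `sinkhorn_lists` is hypothetical, and the input scores are made up for the example.

```python
def sinkhorn_lists(matrix, n_iter=5):
    """Alternately normalize columns and rows of a positive square matrix."""
    a = [row[:] for row in matrix]
    n = len(a)
    for _ in range(n_iter):
        # Normalize columns: divide each entry by its column sum.
        col_sums = [sum(a[i][j] for i in range(n)) for j in range(n)]
        a = [[a[i][j] / col_sums[j] for j in range(n)] for i in range(n)]
        # Normalize rows: divide each entry by its row sum.
        a = [[value / sum(row) for value in row] for row in a]
    return a


scores = [
    [0.9, 0.2, 0.1],
    [0.3, 0.8, 0.4],
    [0.2, 0.3, 0.7],
]
balanced = sinkhorn_lists(scores, n_iter=20)

row_sums = [sum(row) for row in balanced]
col_sums = [sum(row[j] for row in balanced) for j in range(3)]
print([round(s, 3) for s in row_sums])  # [1.0, 1.0, 1.0]
print([round(s, 3) for s in col_sums])  # [1.0, 1.0, 1.0]
```

Because the last step of each iteration normalizes the rows, the row sums are one up to rounding, while the column sums only approach one as the number of iterations grows.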
The code for the neural network is given below:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class SimpleConvNet(nn.Module):
        """
        A simple convolutional neural network shared among all pieces.
        """

        def __init__(self):
            super().__init__()
            # 3 x 16 x 16 input
            self.conv1 = nn.Conv2d(3, 8, 3)
            # 8 x 14 x 14
            self.conv2 = nn.Conv2d(8, 8, 3)
            self.conv2_bn = nn.BatchNorm2d(8)
            # 8 x 12 x 12
            self.pool1 = nn.MaxPool2d(2, 2)
            # 8 x 6 x 6
            self.conv3 = nn.Conv2d(8, 16, 3)
            self.conv3_bn = nn.BatchNorm2d(16)
            # 16 x 4 x 4
            self.fc1 = nn.Linear(16 * 4 * 4, 128)
            self.fc1_bn = nn.BatchNorm1d(128)
            # 128-d features
            self.fc2 = nn.Linear(128, 128)
            self.fc2_bn = nn.BatchNorm1d(128)

        def forward(self, x):
            x = F.relu(self.conv1(x))
            x = F.relu(self.conv2_bn(self.conv2(x)))
            x = self.pool1(x)
            x = F.relu(self.conv3_bn(self.conv3(x)))
            x = x.view(-1, 16 * 4 * 4)
            x = F.relu(self.fc1_bn(self.fc1(x)))
            x = F.relu(self.fc2_bn(self.fc2(x)))
            return x


    class JigsawNet(nn.Module):
        """
        A neural network that solves 2x2 jigsaw puzzles.
        """

        def __init__(self, sinkhorn_iter=0):
            super().__init__()
            self.conv_net = SimpleConvNet()
            self.fc1 = nn.Linear(128 * 4, 256)
            self.fc1_bn = nn.BatchNorm1d(256)
            # 4 x 4 assignment matrix
            self.fc2 = nn.Linear(256, 16)
            self.sinkhorn_iter = sinkhorn_iter

        def forward(self, x):
            # Split input into four pieces and pass them into the
            # same convolutional neural network.
            x0 = self.conv_net(x[:, :, 0:16, 0:16])    # top-left
            x1 = self.conv_net(x[:, :, 16:32, 0:16])   # bottom-left
            x2 = self.conv_net(x[:, :, 16:32, 16:32])  # bottom-right
            x3 = self.conv_net(x[:, :, 0:16, 16:32])   # top-right
            # Cat
            x = torch.cat([x0, x1, x2, x3], dim=1)
            # Dense layer
            x = F.dropout(x, p=0.1, training=self.training)
            x = F.relu(self.fc1_bn(self.fc1(x)))
            # torch.sigmoid replaces the deprecated F.sigmoid.
            x = torch.sigmoid(self.fc2(x))
            if self.sinkhorn_iter > 0:
                x = x.view(-1, 4, 4)
                x = sinkhorn(x, self.sinkhorn_iter)
                x = x.view(-1, 16)
            return x
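The shape comments in `SimpleConvNet` can be checked with the usual output size formula for convolution and pooling layers, assuming the PyTorch defaults used in the code (stride 1 and no padding for the convolutions, stride 2 for the 2x2 max pooling). The helper `conv_out` is a hypothetical name for this sketch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


size = 16                           # each puzzle piece is 16 x 16
size = conv_out(size, 3)            # conv1 -> 14
size = conv_out(size, 3)            # conv2 -> 12
size = conv_out(size, 2, stride=2)  # pool1 -> 6
size = conv_out(size, 3)            # conv3 -> 4

flat_features = 16 * size * size    # 16 channels after conv3
print(size, flat_features)          # 4 256
```

The result matches the `16 * 4 * 4` input size of the first fully connected layer.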
After training for 100 epochs using the Adam optimizer with batch size set to 32, the test accuracy is around 80% (the prediction of the ordering of all four pieces must be correct), which is not that bad. Here are the results sampled from the test set:
It can be observed that most of the puzzles are solved.
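The accuracy quoted above counts a puzzle as solved only when all four pieces are placed correctly, which is stricter than counting correctly placed pieces. The sketch below contrasts the two measures on made-up predictions; the function names are hypothetical and this is not the evaluation code used for the experiment.

```python
def puzzle_accuracy(predicted, target):
    """Fraction of puzzles whose entire permutation is predicted correctly."""
    solved = sum(1 for p, t in zip(predicted, target) if list(p) == list(t))
    return solved / len(target)


def piece_accuracy(predicted, target):
    """Fraction of individual pieces assigned to the correct location."""
    correct = sum(
        1
        for p, t in zip(predicted, target)
        for p_i, t_i in zip(p, t)
        if p_i == t_i
    )
    total = sum(len(t) for t in target)
    return correct / total


# Three made-up puzzles; the second one has two pieces swapped.
predicted = [[0, 1, 2, 3], [1, 0, 2, 3], [3, 2, 1, 0]]
target = [[0, 1, 2, 3], [1, 0, 3, 2], [3, 2, 1, 0]]

print(round(puzzle_accuracy(predicted, target), 3))  # 0.667 (2 of 3 puzzles)
print(round(piece_accuracy(predicted, target), 3))   # 0.833 (10 of 12 pieces)
```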
I also tried to train the neural network to solve 3x3 jigsaw puzzles on the CelebA dataset (the output of the network is a 9x9 assignment matrix). Here are some samples from the test set:
Overall the neural network did learn to handle the permutations to some extent.
There are indeed many real-world problems that require the neural network to handle inputs with different permutations. As mentioned at the beginning of this article, graphs are such inputs. In addition to graphs, such inputs can also be sets, where the ordering does not matter. For instance, the inputs can be point clouds, which may be produced by LIDARs (which can be used to aid the navigation of autonomous vehicles) or other 3D scanning devices. The jigsaw puzzle problem introduced above is just a simple experiment. For further reading on set-like inputs, you can check the following papers:
- O. Vinyals, S. Bengio, and M. Kudlur, "Order matters: Sequence to sequence for sets."
- C. R. Qi, H. Su, K. Mo, and L. J. Guibas, "PointNet: deep learning on point sets for 3D classification and segmentation."
- S. Ravanbakhsh, J. Schneider, and B. Poczos, "Deep learning with sets and point clouds."
As for combining the Sinkhorn operator with deep learning to handle permutations/matchings, I discovered a recent paper by Mena et al., which will be presented at ICLR 2018:
- G. Mena, D. Belanger, S. Linderman, and J. Snoek, "Learning latent permutations with Gumbel-Sinkhorn networks," arXiv:1802.08665 [cs, stat], Feb. 2018.
- R. Sinkhorn and P. Knopp, "Concerning nonnegative matrices and doubly stochastic matrices," Pacific J. Math., vol. 21, no. 2, pp. 343–348, 1967.
- M. Cuturi, "Sinkhorn Distances: Lightspeed Computation of Optimal Transport," in Advances in Neural Information Processing Systems 26, C. J. C. Burges, L. Bottou, M. Welling, Z. Ghahramani, and K. Q. Weinberger, Eds. Curran Associates, Inc., 2013, pp. 2292–2300.