FastML

Machine learning made easy

Introduction to pointer networks

Pointer networks are a variation of the sequence-to-sequence model with attention. Instead of translating one sequence into another, they yield a succession of pointers to the elements of the input series. The most basic use of this is ordering the elements of a variable-length sequence or set.

Basic seq2seq is an LSTM encoder coupled with an LSTM decoder. It’s most often heard of in the context of machine translation: given a sentence in one language, the encoder turns it into a fixed-size representation. The decoder transforms this into a sentence again, possibly of a different length than the source. For example, “como estas?” (two words) would be translated to “how are you?” (three words).

The model gives better results when augmented with attention. In practice, this means that the decoder can look back and forth over the input. Specifically, it has access to encoder states from each step, not just the last one. Consider how it may help with Spanish, in which adjectives usually go after nouns: “neural network” becomes “red neuronal”.

In technical terms, attention (at least this particular kind, content-based attention) boils down to weighted averages. In short, a weighted average of encoder states becomes the decoder state. Attention is just the distribution of weights.
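To make the "weighted average" idea concrete, here is a minimal NumPy sketch. It assumes plain dot-product scoring between the decoder state and each encoder state, which is a simplification; the function names (`softmax`, `attend`) and the random toy data are made up for illustration and are not taken from any particular implementation.

```python
import numpy as np

def softmax(x):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def attend(encoder_states, query):
    """Return the attention weights and the weighted average of encoder states."""
    scores = encoder_states @ query        # one score per input step
    weights = softmax(scores)              # the attention distribution
    context = weights @ encoder_states     # weighted average of the states
    return weights, context

rng = np.random.default_rng(0)
encoder_states = rng.normal(size=(4, 8))   # 4 input steps, 8 hidden units (toy data)
query = rng.normal(size=8)                 # stands in for the current decoder state

weights, context = attend(encoder_states, query)
print(weights)        # non-negative, sums to 1
print(context.shape)  # same size as a single encoder state
```

The weights form a probability distribution over the input steps, and the context vector is what gets passed on to the decoder.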

Here’s more on seq2seq and attention in Keras.

In pointer networks, attention is even simpler: instead of weighing input elements, it points at them probabilistically. In effect, you get a permutation of inputs. Refer to the paper for details and equations.
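As a rough sketch of the difference, the snippet below uses the attention distribution itself as the output and never computes the weighted average. It again assumes dot-product scoring for brevity, whereas the paper uses a learned scoring function; `point` and the toy inputs are hypothetical names and values chosen for illustration.

```python
import numpy as np

def softmax(x):
    # Subtract the max for numerical stability before exponentiating.
    e = np.exp(x - np.max(x))
    return e / e.sum()

def point(encoder_states, decoder_state):
    """Return a distribution over input positions and the most likely position."""
    scores = encoder_states @ decoder_state
    probs = softmax(scores)                # this distribution *is* the output
    return probs, int(np.argmax(probs))

encoder_states = np.eye(3)                 # three toy input elements
decoder_state = np.array([0.1, 2.0, 0.3])  # toy decoder state at one output step

probs, index = point(encoder_states, decoder_state)
print(probs, index)
```

Running one such step per output position yields a sequence of indices into the input, which is how the network expresses an ordering.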

Note that one doesn’t need to use all the pointers. For example, given a piece of text, a network could mark an excerpt by pointing at two elements: where it starts, and where it ends.

Experiments

Where do we start? Well, how about ordering numbers. In other words, a deep argsort:

In [3]: np.argsort([ 10, 30, 20 ])
Out[3]: array([0, 2, 1], dtype=int64)

In [4]: np.argsort([ 40, 10, 30, 20 ])
Out[4]: array([1, 3, 2, 0], dtype=int64)

Surprisingly, the authors don’t pursue the task in the paper. Instead, they use two fancy problems: traveling salesman and convex hull (see READMEs), admittedly with very good results. Why not sort numbers, though?


Let us dive right in

It turns out that numbers are hard. The authors address this in the follow-up paper, Order Matters: Sequence to sequence for sets. The main point is, make no mistake, that order matters. Specifically, we’re talking about the order of the input elements. The authors found that it strongly influences the results, which is not what we want. That’s because in essence we’re dealing with sets, not sequences, as input. Sets don’t have inherent order, so how elements are permuted ideally shouldn’t affect the outcome.

Hence the paper introduces an improved architecture, in which the LSTM encoder is replaced by a feed-forward network connected to another LSTM. That LSTM is said to run repeatedly in order to produce an embedding which is invariant to permutations of the inputs. The decoder is the same, a pointer network.
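The snippet below is not the architecture from the paper, only a small illustration of why an attention-style readout over a set of element embeddings can be insensitive to input order: the weights depend on content, not position, so shuffling the elements shuffles the weights along with them. The dot-product scoring, the name `attention_readout`, and the random data are assumptions made for this sketch.

```python
import numpy as np

def attention_readout(memory, query):
    """Content-based weighted average of the rows of `memory`."""
    scores = memory @ query
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ memory

rng = np.random.default_rng(0)
memory = rng.normal(size=(5, 4))   # 5 set elements, each embedded in 4 dimensions
query = rng.normal(size=4)
perm = rng.permutation(5)          # an arbitrary reordering of the elements

original = attention_readout(memory, query)
shuffled = attention_readout(memory[perm], query)
print(np.allclose(original, shuffled))
```

An LSTM reading the elements one by one offers no such guarantee, because its final state depends on the order in which it saw them.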

Back to sorting numbers. The longer the sequence, the harder it is to sort. For five numbers, they report an accuracy ranging from 81% to 94%, depending on the model (accuracy here refers to the percentage of correctly sorted sequences). When dealing with 15 numbers, the scores range from 0% to 10%.

In our experiments, we achieved nearly 100% accuracy with five numbers. Note that this is “categorical accuracy” as reported by Keras, meaning the percentage of elements in their right places. For instance, the following prediction would be 50% accurate: the first two elements are in place, but the last two are swapped:

4 3 2 1 -> 3 2 0 1
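The two notions of accuracy can be compared directly on this example. The sketch below computes both with NumPy; the variable names are made up for illustration, and the per-element figure is meant to mirror what Keras reports rather than to reproduce its code.

```python
import numpy as np

y_true = np.array([[3, 2, 1, 0]])   # argsort of 4 3 2 1
y_pred = np.array([[3, 2, 0, 1]])   # the prediction shown above

# Share of individual pointers that are correct.
per_element = (y_true == y_pred).mean()

# Share of sequences that are sorted correctly as a whole.
per_sequence = (y_true == y_pred).all(axis=1).mean()

print(per_element, per_sequence)
```

The per-sequence measure is the stricter one, which should be kept in mind when comparing the numbers here with those from the paper.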

For sequences with eight elements, the categorical accuracy drops to around 33%. We also tried a more challenging task, sorting a set of arrays by their sums:

[1 2] [3 4] [2 3] -> 0 2 1

The network handles this just as (un)easily as scalar numbers.

One unexpected thing we’ve noticed is that the network tends to duplicate pointers, especially early in training. This is disappointing: apparently it cannot remember what it predicted just a moment ago. “Oh yes, this element is going to be the second, and this next element is going to be the second. The next element, let’s see… It’s going to be the second, and the next…”

y_test: [2 0 1 4 3]
p:      [2 2 2 2 2]


Men gathered to visualize outputs of a pointer network in the early stage of training. No smiles at this point.

Later:

y_test: [2 0 1 4 3]
p:      [2 0 2 4 3]
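A quick way to spot such outputs is to check whether a prediction is a valid permutation at all. The helper below is a hypothetical utility written for this illustration, not part of the training code.

```python
import numpy as np

def is_permutation(p):
    """True if p contains each index 0..len(p)-1 exactly once."""
    p = np.asarray(p)
    return np.array_equal(np.sort(p), np.arange(len(p)))

target = [2, 0, 1, 4, 3]
early = [2, 2, 2, 2, 2]   # prediction early in training
later = [2, 0, 2, 4, 3]   # prediction later in training

print(is_permutation(target), is_permutation(early), is_permutation(later))
```

Both predictions fail the check, even though the later one gets four out of five pointers right.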

Training sometimes gets stuck at some level of accuracy. And a network trained on small numbers doesn’t generalize to bigger ones, like these:

981,66,673
856,10,438
884,808,241

To help the network with numbers, we tried adding an ID (1,2,3…) to each element of the sequence. The hypothesis was that since the attention is content-based, maybe it could use positions explicitly encoded in content. This ID is either a number (train_with_positions.py) or a one-hot vector (train_with_positions_categorical.py). It seems to help a little, but doesn’t remove the fundamental difficulty.
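The idea can be sketched as follows. This is not the code from the two scripts mentioned above; `add_position_ids` is a hypothetical helper that appends either a numeric ID or a one-hot position vector to the features of each element.

```python
import numpy as np

def add_position_ids(x, one_hot=False):
    """Append a position ID to each element of each sequence.

    x has shape (examples, steps, features).
    """
    n_examples, n_steps, _ = x.shape
    if one_hot:
        ids = np.eye(n_steps, dtype=x.dtype)                      # (steps, steps)
    else:
        ids = np.arange(1, n_steps + 1, dtype=x.dtype).reshape(n_steps, 1)
    # Repeat the same IDs for every example, then join along the feature axis.
    ids = np.broadcast_to(ids, (n_examples,) + ids.shape)
    return np.concatenate([x, ids], axis=2)

x = np.array([[[981], [66], [673]]], dtype=float)   # one sequence of three numbers

numeric = add_position_ids(x)                       # shape (1, 3, 2)
categorical = add_position_ids(x, one_hot=True)     # shape (1, 3, 4)
print(numeric[0])
print(categorical[0])
```

Either way, the position becomes part of the content that the attention mechanism sees.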

Code for the experiments is available on GitHub. Compared with the original repo, we added a data generation script and changed the training script to load data from generated files. We also changed the optimization algorithm to RMSProp, as it seems to converge reasonably well while handling the learning rate automatically.

Data structure

Data is in 3D arrays. The first dimension (rows) is examples, as usual. The second, columns, would normally be features (attributes), but with sequences the features go into the third dimension. The second dimension consists of elements of a given sequence. Below are three example sequences, each with three elements (steps), each with two features:

array([[[ 8,  2],
        [ 3,  3],
        [10,  3]],

       [[ 1,  4],
        [19, 12],
        [ 4, 10]],

       [[19,  0],
        [15, 12],
        [ 8,  6]],

The goal would be to sort the elements by the sum of the features, so the corresponding targets would be

array([[1, 0, 2],
       [0, 2, 1],
       [2, 0, 1],

These targets are encoded categorically:

array([[[ 0.,  1.,  0.],
        [ 1.,  0.,  0.],
        [ 0.,  0.,  1.]],

       [[ 1.,  0.,  0.],
        [ 0.,  0.,  1.],
        [ 0.,  1.,  0.]],

       [[ 0.,  0.,  1.],
        [ 1.,  0.,  0.],
        [ 0.,  1.,  0.]],
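For the three example sequences shown above, the targets and their categorical encoding can be derived with a few lines of NumPy. This is a sketch, not the data generation script itself; `make_targets` is a name chosen for illustration, and a stable sort is assumed so that ties, if any, keep their input order.

```python
import numpy as np

def make_targets(x):
    """Return argsort-by-sum targets and their one-hot encoding.

    x has shape (examples, steps, features).
    """
    sums = x.sum(axis=2)                               # (examples, steps)
    order = np.argsort(sums, axis=1, kind="stable")    # pointer targets
    one_hot = np.eye(x.shape[1])[order]                # (examples, steps, steps)
    return order, one_hot

x = np.array([[[8, 2], [3, 3], [10, 3]],
              [[1, 4], [19, 12], [4, 10]],
              [[19, 0], [15, 12], [8, 6]]])

order, y = make_targets(x)
print(order)
print(y[0])
```

Each row of the one-hot target has as many columns as there are elements in the sequence, since every output step points at one input position.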

One hairy thing here is that we’ve been talking all along about how recurrent networks can handle variable-length sequences, but in practice the data is in 3D arrays, as seen above. In other words, the sequence length is fixed.


“Fixed?” wonders the cat.

The way to deal with it is to fix the dimensionality at the maximum possible sequence length and pad the unused places with zeros.
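A minimal sketch of such padding is shown below. The helper `pad_sequences` here is written from scratch for illustration (it is not the Keras utility of the same name), and it also returns a boolean mask marking which steps hold real data.

```python
import numpy as np

def pad_sequences(sequences, n_features):
    """Zero-pad variable-length sequences into one 3D array, plus a mask."""
    max_len = max(len(s) for s in sequences)
    padded = np.zeros((len(sequences), max_len, n_features))
    mask = np.zeros((len(sequences), max_len), dtype=bool)
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = seq      # real elements go first
        mask[i, :len(seq)] = True       # the rest stays zero / False
    return padded, mask

sequences = [
    [[8, 2], [3, 3], [10, 3]],   # three elements
    [[1, 4], [19, 12]],          # two elements
]

padded, mask = pad_sequences(sequences, n_features=2)
print(padded.shape)
print(mask)
```

Note that an all-zero element is indistinguishable from padding in this scheme unless the mask is kept alongside the data.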

“Great”, you say, “but won’t it mess up the cost function?” It might, so we’d better mask those zeros so that they are omitted when calculating the loss. In Keras the official way to do this seems to be the embedding layer. The relevant parameter is mask_zero:

mask_zero: Whether or not the input value 0 is a special “padding” value that should be masked out. This is useful when using recurrent layers which may take variable length input. If this is True then all subsequent layers in the model need to support masking or an exception will be raised. If mask_zero is set to True, as a consequence, index 0 cannot be used in the vocabulary (input_dim should equal size of vocabulary + 1).
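To show what masking is supposed to achieve, here is a plain NumPy sketch of a cross-entropy loss that ignores padded steps. It illustrates the intended effect only and is not how Keras implements it; `masked_cross_entropy` and the toy arrays are assumptions made for this example.

```python
import numpy as np

def masked_cross_entropy(y_true, y_pred, mask, eps=1e-9):
    """Mean categorical cross-entropy over the unmasked steps only."""
    per_step = -(y_true * np.log(y_pred + eps)).sum(axis=-1)   # (examples, steps)
    return (per_step * mask).sum() / mask.sum()

# One sequence with two real steps and one padded step.
y_true = np.array([[[0., 1., 0.],
                    [1., 0., 0.],
                    [1., 0., 0.]]])    # the last row is a dummy target for padding
y_pred = np.array([[[0.1, 0.8, 0.1],
                    [0.7, 0.2, 0.1],
                    [0.3, 0.3, 0.4]]])
mask = np.array([[1., 1., 0.]])        # 1 = real step, 0 = padding

loss = masked_cross_entropy(y_true, y_pred, mask)
print(loss)
```

Whatever the network predicts at the padded step, the loss stays the same, so padding cannot distort the gradient.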

For more on masking, see Variable Sequence Lengths in TensorFlow, or this notebook-style alternative.

On implementations

We have used a Keras implementation of pointer networks. There are a few others on GitHub, mostly in TensorFlow. Depending on how you look at it, that’s slightly crazy, as people build everything from the ground up, while one just needs a slight modification of a normal seq2seq with attention. On the other hand, the model hasn’t yet found its way into the mainstream and Keras the way some others did, so it’s still about blazing trails.

One problem with all of that is you don’t know if an implementation you’re using is correct. Is the network failing to converge because of the task, the optimization method, or a bug? To be sure, you’d need to read and understand the source code line by line, which is just one step removed from writing it yourself. As the OpenAI blog puts it:

Results are tricky to reproduce: performance is very noisy, algorithms have many moving parts which allow for subtle bugs, and many papers don’t report all the required tricks.

Be wary of non-breaking bugs: when we looked through a sample of ten popular reinforcement learning algorithm reimplementations we noticed that six had subtle bugs found by a community member and confirmed by the author.

They’re talking about reinforcement learning, but the quote is widely applicable. Luckily, the official Order Matters implementation will be made available upon the publication of the paper. They promised. In the meantime, we salute you.


A picture of where we are in the grand scheme of things

Appendix A: implementations of Pointer Networks

But wait, there’s more:

Appendix B: some implementations of seq2seq with attention


P.S. Maybe, just maybe, we need to go DEEPER?
