Paper Review: Set Transformers
A Framework for Attention-based Permutation-Invariant Neural Networks
When working with unordered data, standard neural networks often falter, as they are designed to process structured, sequential, or grid-like inputs such as images or text. But what about problems where the input is a set, an inherently unordered collection of elements? This is where Set Transformers come into play. In this blog, we analyze the paper “Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks” by Lee et al., diving into its motivation, background, technical approach, experiments, and results. It was published in the Proceedings of the 36th International Conference on Machine Learning (ICML) in 2019.
Motivation: Why Do We Need Set Transformers?
Let’s start with a simple clustering problem: Given a set of n points X = {x₁, x₂, x₃,…, xₙ} where these points are generated from k different Gaussian distributions, our task is to identify the means and covariances of these k Gaussians.
A common approach to solving this problem is the Expectation-Maximization (EM) algorithm. However, EM is an iterative procedure that has to be run from scratch for every new dataset, and its cost grows quickly with the size of the set n. Wouldn’t it be incredible if we instead had a neural network function approximator that could take the entire set X as input, analyze the relationships between the points, and directly output the k means and covariances?
Consider another intriguing problem: Given the set of images below, identify the odd one out.
By observing and comparing the features of all the images, a human would eventually spot the odd one in the red box. Why? Because the person in that image isn’t wearing glasses and doesn’t have a mustache. To achieve this, the human brain compares features across the entire set of images before making a decision. For a machine learning model to perform such tasks, it must also process the elements of the entire set as input simultaneously, enabling it to compare features across all elements and identify the outlier.
These are just two examples of many real-world problems involving data represented as sets. Training neural networks to excel at such tasks requires feeding the entire set as input while preserving the inherent relationships within the set.
The challenge, however, is that traditional neural architectures such as CNNs and RNNs are sensitive to the order of their inputs: feeding a shuffled version of the same set often produces drastically different outputs, breaking the fundamental symmetry of sets. This is where Set Transformers shine. They are designed to be permutation-invariant, meaning the model’s output remains the same regardless of the input order.
Background: The Set Processing Landscape
Before Set Transformers, research primarily focused on permutation-invariant models like Deep Sets, introduced by Zaheer et al. This architecture offered a simple yet effective solution for handling sets. For a given set of n elements, each element is processed individually through a shared neural network to generate n distinct encodings. These encodings are then aggregated using a symmetric function, such as summation or averaging, ensuring that the model remains invariant to the order of the elements in the set.
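As a concrete illustration, here is a minimal PyTorch-style sketch of the Deep Sets idea (the module name and layer sizes are our own choices, not taken from the original paper): a shared network φ encodes each element independently, a symmetric sum pools the encodings, and a second network ρ produces the output.

```python
import torch
import torch.nn as nn

class DeepSetsSketch(nn.Module):
    """Minimal Deep Sets-style model (illustrative): per-element encoder + symmetric pooling."""
    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()
        # phi: shared encoder applied to every element independently
        self.phi = nn.Sequential(nn.Linear(dim_in, dim_hidden), nn.ReLU(),
                                 nn.Linear(dim_hidden, dim_hidden))
        # rho: maps the pooled, order-insensitive summary to the final output
        self.rho = nn.Sequential(nn.Linear(dim_hidden, dim_hidden), nn.ReLU(),
                                 nn.Linear(dim_hidden, dim_out))

    def forward(self, x):                 # x: (batch, n, dim_in) -- a batch of sets
        encodings = self.phi(x)           # each element encoded independently
        pooled = encodings.sum(dim=1)     # summation is symmetric, hence permutation-invariant
        return self.rho(pooled)           # (batch, dim_out)
```

Because the sum ignores element order, the output is unchanged under any permutation of the set; but no element ever “sees” another before pooling, which is exactly the limitation described next.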
While effective, Deep Sets process each element of the set independently, which limits their ability to capture interactions between elements. This makes it challenging to model complex relationships within the set. To address this limitation, Set Transformers step in. They harness the power of attention mechanisms — the same concept that drives Transformers in NLP — to model intricate dependencies and interactions among elements within a set, enabling much richer representations.
Technical Approach: How Do Set Transformers Work?
The architecture of Set Transformers builds upon key innovations in attention mechanisms:
1. Encoder Block: Multihead Attention Meets Sets
In its encoder, the Set Transformer replaces the independent per-element feedforward processing used by Deep Sets with Multihead Attention blocks that capture the relationships between the different elements of the set.
Given two sets X and Y, a Multihead Attention Block (MAB) computes cross-attention between their elements, using X to produce the queries and Y to produce the keys and values.
Using this definition of an MAB, the authors define a Set Attention Block (SAB) as SAB(X) = MAB(X, X).
SAB uses the set X to generate queries, keys, and values and performs self-attention to compute pairwise interactions between all elements in the set X, capturing relationships among the elements in X. The attention mechanism ensures that each element’s contribution is weighted dynamically, depending on the context provided by the rest of the set.
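The snippet below is a rough sketch of an MAB and an SAB built on PyTorch’s nn.MultiheadAttention; the exact layer-norm placement and feedforward width are simplifications of the paper’s formulation, not a faithful reimplementation.

```python
import torch
import torch.nn as nn

class MAB(nn.Module):
    """Multihead Attention Block: queries come from X, keys and values come from Y."""
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, x, y):                      # x: (batch, n, dim), y: (batch, m, dim)
        h = self.ln1(x + self.attn(x, y, y)[0])   # cross-attention from X to Y, plus residual
        return self.ln2(h + self.ff(h))           # element-wise feedforward, plus residual

class SAB(nn.Module):
    """Set Attention Block: self-attention within one set, SAB(X) = MAB(X, X)."""
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.mab = MAB(dim, num_heads)

    def forward(self, x):
        return self.mab(x, x)
```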
2. Inducing Points for Efficiency
While Multihead Attention (MHA) is highly effective at capturing relationships within sets, its computational cost grows quadratically with the size of the set, i.e. O(n²) when the set has n elements.
This quickly becomes a bottleneck for large sets. To tackle this challenge, the Set Transformer introduces Induced Set Attention Blocks (ISABs), which significantly reduce computational complexity while maintaining performance.
Here’s how ISABs work:
- Inducing Points as a Bottleneck Representation: Instead of computing pairwise attention across all elements in the input set X = {x₁, x₂, x₃,…, xₙ}, ISABs introduce a smaller, learnable set of inducing points, denoted as I = {i₁, i₂, i₃,…, iₘ}, where m ≪ n. These inducing points act as a bottleneck, condensing information from the input set into a smaller, manageable representation.
- Two-Step Attention Mechanism: ISABs compute attention in two stages:
- First, attention is computed from the input set X to the inducing points I. This step condenses the information from X into I, where each inducing point aggregates features from all elements of the set.
- Next, attention is computed back from the inducing points I to the input set X. This allows the inducing points to propagate the condensed information back to each element of the set, enriching their representations.
Together, these two steps ensure that interactions between all elements in X are captured indirectly via the inducing points, significantly reducing computational overhead. By using a fixed number of inducing points m, ISABs reduce the computational cost of MHA from O(n²) to O(nm). Since m is much smaller than n, this approach makes it feasible to scale Set Transformers to larger sets without sacrificing efficiency or performance.
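Continuing the sketch above (and reusing the MAB class from the previous snippet), an ISAB can be written as two chained MABs; the number of inducing points m and their random initialization are our own illustrative choices.

```python
class ISAB(nn.Module):
    """Induced Set Attention Block: attention routed through m learnable inducing points."""
    def __init__(self, dim, num_inducing, num_heads=4):
        super().__init__()
        # I: learnable inducing points, shared across all input sets
        self.inducing = nn.Parameter(torch.randn(1, num_inducing, dim))
        self.mab1 = MAB(dim, num_heads)   # step 1: inducing points attend to X (cost O(n*m))
        self.mab2 = MAB(dim, num_heads)   # step 2: X attends back to the condensed summaries

    def forward(self, x):                             # x: (batch, n, dim)
        i = self.inducing.expand(x.size(0), -1, -1)   # broadcast inducing points over the batch
        h = self.mab1(i, x)                           # condense X into m summary vectors
        return self.mab2(x, h)                        # propagate summaries back to all n elements
```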
As you might have realized by now, if the input set is of size n, the output of both the SAB and the ISAB is also of size n. Furthermore, neither the SAB nor the ISAB is order-invariant in its output; instead, both are permutation equivariant. A function is permutation equivariant if applying any permutation to its input and then passing it through the function gives the same result as passing the input through the function first and then permuting the output. (We refer the readers to the Set Transformer paper for further details.)
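A quick way to convince yourself of the equivariance property, assuming the SAB sketch above, is to permute a random input set and compare the two orders of operations:

```python
# Sanity check (uses the SAB sketch above): permuting the input rows permutes the
# output rows in exactly the same way, i.e. SAB is permutation *equivariant*.
x = torch.randn(1, 6, 32)                  # one set of 6 elements with 32 features each
sab = SAB(dim=32)
perm = torch.randperm(6)
with torch.no_grad():
    permute_after = sab(x)[:, perm, :]     # f(X), then permute the output
    permute_before = sab(x[:, perm, :])    # permute the input, then f(X)
print(torch.allclose(permute_after, permute_before, atol=1e-5))  # expected: True
```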
3. Pooling by Multihead Attention (PMA)
A key challenge in processing sets of variable sizes is creating a fixed-size, order-invariant representation suitable for tasks like classification or regression. Set Transformers address this with Pooling by Multihead Attention (PMA), which acts as the decoder block of the architecture. PMA operates on the intermediate representation Z output by the encoder blocks (SAB or ISAB), not directly on the input set X. This step is crucial to ensure permutation invariance in the final representation.
1) How PMA Works: PMA uses a learnable set of query vectors Q = {q₁, q₂, q₃,…, qₖ} to attend to all n elements of the set Z = {z₁, z₂, z₃,…, zₙ}. Each query qᵢ interacts with the entries in Z through attention, effectively summarizing the entire set into k fixed-size representations/embeddings.
This pooling process can be mathematically represented as PMAₖ(Z) = MAB(Q, rFF(Z)) = {p₁, p₂,…, pₖ}, where rFF is a feedforward layer applied independently to each element of Z and pᵢ is the output representation corresponding to the iᵗʰ query. By varying the number of queries k, PMA can produce summaries of different sizes, making it adaptable to diverse tasks.
2) Ensuring Permutation Invariance: PMA guarantees permutation invariance, one of the core requirements for set processing. In PMA, the queries are learned, while the elements of Z are used to obtain the keys and the values. The attention mechanism computes the encoding for each query by taking its dot product with every key and using the resulting weights to aggregate the values. Since the queries are learned, and hence fixed, this weighted aggregation is a symmetric function of the elements of Z (and therefore of X), so the pooled representation remains invariant to permutations of the input set.
3) Decoder Role in the Architecture: PMA can be seen as the decoder block of Set Transformers. While the encoder block (e.g., SAB or ISAB) processes the set to model interactions between elements, PMA condenses the rich, contextual information captured in Z into a task-specific, permutation-invariant output suitable for downstream applications, such as classification, clustering, or regression.
4) Efficient and Adaptable: The learnable queries Q allow the model to dynamically adapt to the task at hand. For example, in point-cloud classification, different queries might focus on different aspects of the set, such as the overall shape or local geometric structure. By condensing a variable-sized input into a fixed k × d representation, PMA integrates seamlessly with downstream layers (d is the size of each of the k embeddings).
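Putting the pieces together, here is a sketch of PMA and of a small Set Transformer-style classifier that reuses the MAB and ISAB classes from the earlier snippets. Note that the paper’s decoder additionally applies a row-wise feedforward layer before PMA and can stack further SABs after it; those details are omitted here for brevity, and all sizes are illustrative.

```python
class PMA(nn.Module):
    """Pooling by Multihead Attention: k learnable queries attend to the encoded set Z."""
    def __init__(self, dim, num_queries=1, num_heads=4):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(1, num_queries, dim))  # learned query/seed vectors
        self.mab = MAB(dim, num_heads)

    def forward(self, z):                             # z: (batch, n, dim), output of the encoder
        q = self.queries.expand(z.size(0), -1, -1)
        return self.mab(q, z)                         # (batch, k, dim): fixed size, order-invariant

class SetTransformerSketch(nn.Module):
    """Encoder (two ISABs) + decoder (PMA) + linear head, e.g. for set classification."""
    def __init__(self, dim_in, dim, num_classes, num_inducing=16, num_heads=4):
        super().__init__()
        self.embed = nn.Linear(dim_in, dim)
        self.encoder = nn.Sequential(ISAB(dim, num_inducing, num_heads),
                                     ISAB(dim, num_inducing, num_heads))
        self.pool = PMA(dim, num_queries=1, num_heads=num_heads)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, x):                  # x: (batch, n, dim_in)
        z = self.encoder(self.embed(x))    # permutation-equivariant per-element encodings
        p = self.pool(z)                   # (batch, 1, dim): permutation-invariant summary
        return self.head(p.squeeze(1))     # (batch, num_classes)
```

Because the only learned inputs to the pooling step are the queries, permuting the elements of the input set leaves the pooled summary, and hence the prediction, unchanged.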
Experiments and Results: Validating the Power of Set Transformers
The authors of the Set Transformer paper conducted experiments to demonstrate the model’s ability to handle set-structured data efficiently and effectively. These experiments span a variety of tasks designed to highlight the model’s key strengths: permutation invariance, computational efficiency, and the ability to model complex interactions. We will discuss two of those experiments here.
1. Counting Unique Characters
Objective: To evaluate the model’s ability to count the number of unique characters in a set of images.
Setup:
- Dataset: Sets of handwritten character images sampled from the Omniglot dataset. The task is to predict the number of unique characters in each set.
Results:
- The Set Transformer outperformed Deep Sets in counting unique characters due to its ability to capture pairwise interactions via attention mechanisms.
- The ISAB variant demonstrated similar performance to SAB but with reduced computational complexity (linear vs. quadratic in set size). The accuracy of their approach surpasses that of other baselines, and the performance tends to increase as the set size increases.
Takeaway: The attention-based architecture of the Set Transformer makes it especially well-suited for tasks requiring recognition of unique elements in a set.
2. ModelNet40 Classification
Objective: To classify 3D objects represented as unordered point clouds.
Setup:
- Dataset: The ModelNet40 dataset, which consists of 3D point cloud representations of objects from 40 categories. Each object is represented as a set of unordered points.
- Task: Classify each object into one of the 40 categories based on its point cloud representation.
Results:
- Set Transformers were superior to other baselines when given smaller-sized sets, but were outperformed by some baselines on larger sets. The authors suspect that their approach was outperformed in the problems with large sets because such sets already had sufficient information for classification, diminishing the need to model complex interactions among points.
Takeaway: The Set Transformer is highly effective for tasks involving unordered point cloud data, offering a competitive alternative to specialized architectures like PointNet.
Compared to Deep Sets and other baselines, Set Transformers perform better across most benchmarks, bridging the gap between efficiency and representational power.
Conclusion: A New Era for Set-Based Problems
The Set Transformer is a game-changer for handling set-structured data. By using attention mechanisms for both encoding and aggregation, it captures complex interactions within sets, something earlier models struggled with. It also introduces Induced Set Attention Blocks (ISABs) to tackle the computational challenges of self-attention, making it scalable to large sets without compromising performance.
One of the most exciting aspects of Set Transformers is their theoretical grounding: they are proven to be universal approximators of permutation-invariant functions, ensuring their adaptability to diverse tasks.
Looking ahead, there’s immense potential for Set Transformers in areas like meta-learning, where they could enable efficient posterior inference in Bayesian models. Another intriguing direction would be exploring how uncertainty in set functions can be modeled by injecting noise into the architecture.
In essence, the Set Transformer is not just a novel solution but a stepping stone for future innovations in set-based machine learning problems.
In case you wish to try Set Transformers for your projects, you can find the code from the authors here: GitHub Link