Microsoft’s Differential Transformer cancels attention noise in LLMs

Improving the ability of large language models (LLMs) to retrieve information from their prompts remains an area of active research, with implications for important applications such as retrieval-augmented generation (RAG) and in-context learning (ICL).

Microsoft Research and Tsinghua University researchers have introduced Differential Transformer (Diff Transformer), a new LLM architecture that improves performance by amplifying attention to relevant context while filtering out noise. Their findings, published in a research paper, show that Diff Transformer outperforms the classic Transformer architecture in various settings.

Transformers and the “lost-in-the-middle” phenomenon

The Transformer architecture is the foundation of most modern LLMs. It uses an attention mechanism to weigh the importance of different parts of the input sequence when generating output. The attention mechanism employs the softmax function, which normalizes a vector of values into a probability distribution. In Transformers, the softmax function assigns attention scores to different tokens in the input sequence.
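To make the softmax step concrete, here is a minimal single-query sketch of scaled dot-product attention in plain Python. The toy vectors and function names are illustrative assumptions, not taken from the paper; the point is only that softmax turns raw query-key similarity scores into positive weights that sum to 1, which are then used to mix the value vectors.

```python
import math

def softmax(scores):
    """Turn a list of real-valued scores into a probability distribution."""
    m = max(scores)  # subtract the max before exp() for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(query, keys, values):
    """Scaled dot-product attention for a single query vector.

    query: d floats; keys: n vectors of d floats; values: n vectors of d_v floats.
    Returns the attention weights over the n tokens and the weighted sum of values.
    """
    d = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys]
    weights = softmax(scores)
    d_v = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_v)]
    return weights, output

# Three toy tokens; the query is most similar to the first key.
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
weights, output = attention(query, keys, values)
print([round(w, 3) for w in weights])  # every weight is positive; together they sum to 1
```

Even the two tokens that match the query poorly receive a noticeable share of the attention, which is the behavior the researchers set out to address.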

However, studies have shown that Transformers struggle to retrieve key information from long contexts.

“We began by investigating the so-called ‘lost-in-the-middle’ phenomenon,” Furu Wei, Partner Research Manager at Microsoft Research, told VentureBeat, referring to previous research findings that showed that LLMs “do not robustly make use of information in long input contexts” and that “performance significantly degrades when models must access relevant information in the middle of long contexts.”

Wei and his colleagues also observed that some LLM hallucinations, where the model produces incorrect outputs despite having relevant context information, correlate with spurious attention patterns.

“For example, large language models are easily distracted by context,” Wei said. “We analyzed the attention patterns and found that the Transformer attention tends to over-attend irrelevant context because of the softmax bottleneck.”

The softmax function used in Transformer’s attention mechanism tends to distribute attention scores across all tokens, even those that are not relevant to the task. This can cause the model to lose focus on the most important parts of the input, especially in long contexts.

“Previous studies indicate that the softmax attention has a bias to learn low-frequency signals because the softmax attention scores are restricted to positive values and have to be summed to 1,” Wei said. “The theoretical bottleneck renders [it] such that the classic Transformer cannot learn sparse attention distributions. In other words, the attention scores tend to flatten rather than focusing on relevant context.”
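The constraint Wei describes (scores must be positive and sum to 1) can be illustrated with a toy calculation. This is a simplified illustration, not the paper's analysis: one relevant token with a fixed, made-up score of 3.0 competes with a growing number of distractor tokens that all score 0.

```python
import math

def relevant_weight(relevant_score, n_distractors, distractor_score=0.0):
    """Softmax weight on one relevant token that competes with n identical distractors."""
    # Every exp() term is strictly positive, so each distractor always receives
    # some attention mass, and all the shares must add up to 1.
    num = math.exp(relevant_score)
    den = num + n_distractors * math.exp(distractor_score)
    return num / den

# The relevant token keeps the same raw score (3.0); only the context grows.
context_sizes = [10, 100, 1000, 10000]
shares = [relevant_weight(3.0, n) for n in context_sizes]
for n, share in zip(context_sizes, shares):
    print(f"{n:>6} distractors -> weight on relevant token {share:.4f}")
```

With 10 distractors the relevant token keeps about two-thirds of the attention; with 10,000 it keeps roughly 0.2%, even though its own score never changed. In this toy setting, that is the "flattening" effect that long contexts make worse.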

Differential Transformer

Differential transformer
Differential Transformer (source: arXiv)

To address this limitation, the researchers developed the Diff Transformer, a new foundation architecture for LLMs. The core idea is to use a “differential attention” mechanism that cancels out noise and amplifies the attention given to the most relevant parts of the input.

The Transformer uses three vectors to compute attention: query, key, and value. The classic attention mechanism applies the softmax function to the scaled dot products between the full query and key vectors, producing a single attention map.

The proposed differential attention works by partitioning the query and key vectors into two groups and computing two separate softmax attention maps. The difference between these two maps is then used as the attention score. Because the subtraction cancels noise that is common to both maps, the model is encouraged to focus on the information that is pertinent to the input.

The researchers compare their approach to noise-canceling headphones or differential amplifiers in electrical engineering, where the difference between two signals cancels out common-mode noise.
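The mechanism can be sketched for a single query in plain Python. This is a minimal illustration under stated assumptions, not the researchers' implementation: each query and key vector is split into two equal halves, the second map is weighted by a fixed scalar `lam` (the paper describes a learnable scalar λ), and the function names and toy inputs are hypothetical.

```python
import math

def softmax(scores):
    m = max(scores)  # numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attn_map(query, keys):
    """Softmax attention map of one query over all keys (scaled dot product)."""
    d = len(query)
    return softmax([sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys])

def diff_attention(query, keys, values, lam):
    """Simplified single-query differential attention (illustrative sketch).

    The query and each key are split into two halves. Each half yields its own
    softmax map; the second map, scaled by lam, is subtracted from the first, and
    the resulting (possibly negative) weights are applied to the values.
    """
    h = len(query) // 2
    a1 = attn_map(query[:h], [k[:h] for k in keys])
    a2 = attn_map(query[h:], [k[h:] for k in keys])
    weights = [w1 - lam * w2 for w1, w2 in zip(a1, a2)]
    d_v = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(d_v)]
    return a1, weights, output

# Token 0 is relevant; tokens 1-3 are distractors.
# First half: the query matches token 0, but softmax still spreads mass onto the rest.
# Second half: the query matches nothing, so this map is uniform "noise".
query = [2.0, 0.0, 1.0, 0.0]
keys = [
    [2.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0, 0.0],
]
values = [[1.0], [0.0], [0.0], [0.0]]
a1, weights, output = diff_attention(query, keys, values, lam=0.2)
print("single softmax map:", [round(w, 4) for w in a1])
print("differential map:  ", [round(w, 4) for w in weights])
```

Here `lam=0.2` is hand-picked so that the uniform second map almost exactly offsets the weight the first map leaks onto the distractors, much as a differential amplifier removes a signal present on both inputs. In the paper λ is learned, and the full layer includes further components (multiple heads, learned projections, normalization) that this sketch leaves out.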

While Diff Transformer involves an additional subtraction operation compared to the classic Transformer, it maintains efficiency thanks to parallelization and optimization techniques.

“In the experimental setup, we matched the number of parameters and FLOPs with Transformers,” Wei said. “Because the basic operator is still softmax, it can also benefit from the widely used FlashAttention CUDA kernels for acceleration.”
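One way to see why an existing softmax-attention kernel can be reused: assuming both maps are applied to the same value vectors, (A1 - λ·A2)V equals A1·V - λ·A2·V, so differential attention can be assembled from two calls to an unmodified attention routine plus one subtraction. The sketch below checks that identity numerically. `softmax_attention` is a hypothetical plain-Python stand-in for an optimized kernel; it does not use FlashAttention's actual API, and the inputs and `lam` value are arbitrary.

```python
import math

def softmax(scores):
    m = max(scores)  # numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_attention(query, keys, values):
    """Plain single-query softmax attention; stands in for an optimized kernel."""
    d = len(query)
    weights = softmax([sum(q * k for q, k in zip(query, key)) / math.sqrt(d) for key in keys])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

def diff_attention_two_calls(query, keys, values, lam):
    """Differential attention built from two ordinary attention calls."""
    h = len(query) // 2
    o1 = softmax_attention(query[:h], [k[:h] for k in keys], values)
    o2 = softmax_attention(query[h:], [k[h:] for k in keys], values)
    return [x - lam * y for x, y in zip(o1, o2)]

def diff_attention_direct(query, keys, values, lam):
    """Same quantity computed by subtracting the maps first, then applying values."""
    h = len(query) // 2

    def attn_map(q, ks):
        return softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in ks])

    a1 = attn_map(query[:h], [k[:h] for k in keys])
    a2 = attn_map(query[h:], [k[h:] for k in keys])
    weights = [w1 - lam * w2 for w1, w2 in zip(a1, a2)]
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

query = [0.3, -1.2, 0.8, 0.5]
keys = [[1.0, 0.2, -0.4, 0.7], [0.1, -0.9, 0.3, 0.0], [-0.5, 0.4, 1.1, -0.2]]
values = [[1.0, 2.0], [0.5, -1.0], [-0.3, 0.7]]
two_calls = diff_attention_two_calls(query, keys, values, lam=0.5)
direct = diff_attention_direct(query, keys, values, lam=0.5)
print(two_calls)
print(direct)  # matches two_calls up to floating-point rounding
```

The design point is that nothing inside the attention routine changes; the extra work is a second call and an element-wise subtraction, which fits Wei's remark that the approach can ride on existing softmax kernels.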

In retrospect, the method used in Diff Transformer seems like a simple and intuitive solution. Wei compares it to ResNet, a popular deep learning architecture that introduced “residual connections” to improve the training of very deep neural networks. Residual connections made a very simple change to the traditional architecture yet had a profound impact.

“In research, the key is to figure out ‘what is the right problem?’” Wei said. “Once we can ask the right question, the solution is often intuitive. Similar to ResNet, the residual connection is an addition, compared with the subtraction in Diff Transformer, so it wasn’t immediately apparent for researchers to propose the idea.”

Diff Transformer in action

The researchers evaluated Diff Transformer on various language modeling tasks, scaling it up in terms of model size (from 3 billion to 13 billion parameters), training tokens, and context length (up to 64,000 tokens).

Their experiments showed that Diff Transformer consistently outperforms the classic Transformer architecture across different benchmarks. A 3-billion-parameter Diff Transformer trained on 1 trillion tokens showed consistent improvements of several percentage points compared to similarly sized Transformer models.

Further experiments with different model sizes and training dataset sizes confirmed the scalability of Diff Transformer. Their findings suggest that, in general, Diff Transformer needs only around 65% of the model size or training tokens required by a classic Transformer to achieve comparable performance.

Diff Transformer performance
The Diff Transformer is more efficient than the classic Transformer in terms of both parameters and train tokens (source: arXiv)

The researchers also found that Diff Transformer is particularly effective at making use of longer contexts. It showed significant improvements in key information retrieval, hallucination mitigation, and in-context learning.

While the initial results are promising, there’s still room for improvement. The research team is working on scaling Diff Transformer to larger model sizes and training datasets. They also plan to extend it to other modalities, including image, audio, video, and multimodal data.

The researchers have released the code for Diff Transformer, implemented with different attention and optimization mechanisms. They believe the architecture can help improve performance across various LLM applications.

“As the model can attend to relevant context more accurately, it is expected that these language models can better understand the context information with less in-context hallucinations,” Wei said. “For example, for the retrieval-augmented generation settings (such as Bing Chat, Perplexity, and customized models for specific domains or industries), the models can generate more accurate responses by conditioning on the retrieved documents.”



Source link

About The Author

Scroll to Top