Causal Mediation Analysis in Language Models
Description
Causal mediation analysis in language models is a mechanistic interpretability technique that investigates how specific internal components (neurons, attention heads, or layers) causally contribute to model outputs. Researchers perform controlled interventions, such as activating, deactivating, or replacing the activations of a component, and measure how the output changes, tracing the pathways through which information flows and is transformed within the model. Because the evidence comes from interventions rather than observed correlations, the approach can show not just which features influence outputs, but which components mediate that influence.
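The intervention logic described above can be sketched on a toy network. This is a minimal illustration and not a reference implementation: the weights, the `forward` and `mediation_effects` helpers, and the use of plain output differences as the effect measure are all assumptions made for the example. The indirect effect keeps the baseline input and patches one hidden unit to the value it takes under the alternative input; the direct effect changes the input while holding that unit at its baseline value.

```python
from typing import Dict, List, Optional

# Hypothetical toy model: 2 inputs -> 3 ReLU hidden units -> 1 output,
# plus a skip path from input to output. All weights are made up.
W_IN = [[1.0, -1.0], [0.5, 0.5], [-1.0, 2.0]]  # one row per hidden unit
W_OUT = [2.0, -1.0, 0.5]                       # hidden -> output
W_SKIP = [0.1, -0.2]                           # input -> output

def hidden(x: List[float]) -> List[float]:
    """Hidden activations (the candidate mediators) for input x."""
    return [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in W_IN]

def forward(x: List[float], patch: Optional[Dict[int, float]] = None) -> float:
    """Run the model, optionally overwriting selected hidden units."""
    h = hidden(x)
    for idx, value in (patch or {}).items():
        h[idx] = value  # the intervention: force a mediator to a chosen value
    return sum(w * hi for w, hi in zip(W_OUT, h)) + sum(
        w * xi for w, xi in zip(W_SKIP, x)
    )

def mediation_effects(
    x_base: List[float], x_alt: List[float], neuron: int
) -> Dict[str, float]:
    """Total, indirect, and direct effects of switching x_base -> x_alt."""
    h_base, h_alt = hidden(x_base), hidden(x_alt)
    y_base = forward(x_base)
    return {
        # Change the input, let everything respond.
        "total": forward(x_alt) - y_base,
        # Keep the baseline input, set only the mediator to its alternative value.
        "indirect": forward(x_base, {neuron: h_alt[neuron]}) - y_base,
        # Change the input, but hold the mediator at its baseline value.
        "direct": forward(x_alt, {neuron: h_base[neuron]}) - y_base,
    }

X_BASE, X_ALT = [1.0, 0.0], [0.0, 1.0]
for i in range(len(W_OUT)):
    print(i, mediation_effects(X_BASE, X_ALT, i))
```

In a real language model the same pattern would be applied to cached activations of neurons or attention heads, and the effect measure would typically be a change in the probability or logit of a target token rather than a raw scalar difference.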
Example Use Cases
Safety
Investigating causal pathways in content moderation models to understand how specific attention mechanisms contribute to flagging potentially harmful content, so that reviewers can verify that safety decisions rely on appropriate features rather than spurious correlations.
Reliability
Identifying specific neurons or attention heads that causally contribute to biased outputs in hiring or lending language models, enabling targeted interventions that reduce discriminatory behaviour whilst preserving model performance on legitimate tasks.
Explainability
Tracing causal pathways in large language models performing mathematical reasoning tasks to understand how intermediate steps are computed and stored, revealing which components are responsible for different aspects of logical inference and enabling validation of reasoning processes.
Limitations
- Requires sophisticated understanding of model architecture to design meaningful interventions, as poorly chosen intervention points may yield misleading causal conclusions or fail to capture relevant computational pathways.
- Results are highly dependent on the validity of underlying causal assumptions, which can be difficult to verify in complex, high-dimensional neural network spaces where multiple causal pathways may interact.
- Comprehensive causal analysis requires extensive computational resources, particularly for large models, as each intervention requires a separate forward pass and robust conclusions often depend on testing many combinations of interventions.
- Distinguishing between direct causal effects and indirect effects mediated through other components can be challenging, potentially leading to oversimplified causal narratives that miss important intermediate processes.
- Causal relationships identified in specific contexts or datasets may not generalise to different domains, tasks, or model versions, requiring careful validation across diverse scenarios to ensure robust findings.
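To make the computational cost noted in the limitations above concrete, the following back-of-the-envelope sketch counts patched forward passes. It assumes one forward pass per (component subset, prompt pair) and ignores the clean runs needed to cache activations; the `intervention_passes` helper and its parameters are hypothetical names introduced for this illustration.

```python
from math import comb

def intervention_passes(
    n_layers: int,
    heads_per_layer: int,
    n_prompt_pairs: int,
    subset_size: int = 1,
) -> int:
    """Patched forward passes needed to test every subset of attention heads
    of the given size on every prompt pair (assumed cost model)."""
    n_components = n_layers * heads_per_layer
    return comb(n_components, subset_size) * n_prompt_pairs

# A 12-layer model with 12 heads per layer (the size of GPT-2 small).
single = intervention_passes(12, 12, n_prompt_pairs=1000)
pairs = intervention_passes(12, 12, n_prompt_pairs=1000, subset_size=2)
print(f"single-head interventions: {single:,} passes")
print(f"pairwise interventions:    {pairs:,} passes")
```

The count grows combinatorially with the subset size, which is why exhaustive analysis of interacting components is rarely feasible and practitioners tend to restrict the candidate set first.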
Resources
Research Papers
Transformer Interpretability Beyond Attention Visualization
Self-attention techniques, and specifically Transformers, are dominating the field of text processing and are becoming increasingly popular in computer vision classification tasks. In order to visualize the parts of the image that led to a certain classification, existing methods either rely on the obtained attention maps or employ heuristic propagation along the attention graph. In this work, we propose a novel way to compute relevancy for Transformer networks. The method assigns local relevance based on the Deep Taylor Decomposition principle and then propagates these relevancy scores through the layers. This propagation involves attention layers and skip connections, which challenge existing methods. Our solution is based on a specific formulation that is shown to maintain the total relevancy across layers. We benchmark our method on very recent visual Transformer networks, as well as on a text classification problem, and demonstrate a clear advantage over the existing explainability methods.
Efficient Automated Circuit Discovery in Transformers using Contextual Decomposition
Automated mechanistic interpretation research has attracted great interest due to its potential to scale explanations of neural network internals to large models. Existing automated circuit discovery work relies on activation patching or its approximations to identify subgraphs in models for specific tasks (circuits). They often suffer from slow runtime, approximation errors, and specific requirements of metrics, such as non-zero gradients. In this work, we introduce contextual decomposition for transformers (CD-T) to build interpretable circuits in large language models. CD-T can produce circuits of arbitrary level of abstraction, and is the first able to produce circuits as fine-grained as attention heads at specific sequence positions efficiently. CD-T consists of a set of mathematical equations to isolate contribution of model features. Through recursively computing contribution of all nodes in a computational graph of a model using CD-T followed by pruning, we are able to reduce circuit discovery runtime from hours to seconds compared to state-of-the-art baselines. On three standard circuit evaluation datasets (indirect object identification, greater-than comparisons, and docstring completion), we demonstrate that CD-T outperforms ACDC and EAP by better recovering the manual circuits with an average of 97% ROC AUC under low runtimes. In addition, we provide evidence that faithfulness of CD-T circuits is not due to random chance by showing our circuits are 80% more faithful than random circuits of up to 60% of the original model size. Finally, we show CD-T circuits are able to perfectly replicate original models' behavior (faithfulness = 1) using fewer nodes than the baselines for all tasks. Our results underscore the great promise of CD-T for efficient automated mechanistic interpretability, paving the way for new insights into the workings of large language models.
Causal Mediation Analysis for Interpreting Neural NLP: The Case of Gender Bias
Common methods for interpreting neural models in natural language processing typically examine either their structure or their behavior, but not both. We propose a methodology grounded in the theory of causal mediation analysis for interpreting which parts of a model are causally implicated in its behavior. It enables us to analyze the mechanisms by which information flows from input to output through various model components, known as mediators. We apply this methodology to analyze gender bias in pre-trained Transformer language models. We study the role of individual neurons and attention heads in mediating gender bias across three datasets designed to gauge a model's sensitivity to gender bias. Our mediation analysis reveals that gender bias effects are (i) sparse, concentrated in a small part of the network; (ii) synergistic, amplified or repressed by different components; and (iii) decomposable into effects flowing directly from the input and indirectly through the mediators.