Feature Attribution with Integrated Gradients in NLP
Description
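Applies Integrated Gradients to natural language processing models to attribute a prediction to individual input tokens, words, or subword units. Because text inputs are discrete, the method typically operates in embedding space. It computes gradients of the model output at points along a straight-line path from a baseline (for example all-zero embeddings, padding-token embeddings, or the embeddings of neutral text) to the actual input embeddings. It then averages (integrates) these gradients and multiplies the result by the difference between input and baseline to obtain attribution scores. Unlike vanilla gradient methods, Integrated Gradients satisfies the axioms of sensitivity and implementation invariance, and its attributions sum to the difference between the model's output on the input and on the baseline (completeness). This makes it particularly valuable for understanding transformer-based language models, where token interactions are complex.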
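The computation described above can be sketched in plain Python. This is a minimal sketch under stated assumptions, not a library implementation. A hypothetical toy classifier (`model`, `model_grad`, weights `W` and `B`) with a hand-written gradient stands in for a real language model, whose gradients would normally come from automatic differentiation. The function `integrated_gradients` is illustrative and approximates IG_i(x) = (x_i - x'_i) * integral from 0 to 1 of dF(x' + a(x - x'))/dx_i da with a midpoint Riemann sum. It then sums each token's embedding dimensions into a single score.

```python
import math
from typing import Callable, List

Vector = List[float]
Tokens = List[Vector]  # one embedding vector per token

# Hypothetical toy classifier standing in for a real NLP model:
# p(x) = sigmoid(B + sum_i W . x_i), where x_i is the embedding of token i.
W: Vector = [0.8, -0.5, 0.3]
B = -0.2


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def model(x: Tokens) -> float:
    z = B + sum(w * v for tok in x for w, v in zip(W, tok))
    return sigmoid(z)


def model_grad(x: Tokens) -> Tokens:
    # Analytic gradient of the toy model: dp/dx[i][d] = p * (1 - p) * W[d].
    p = model(x)
    return [[p * (1.0 - p) * w for w in W] for _ in x]


def integrated_gradients(
    x: Tokens,
    baseline: Tokens,
    grad_fn: Callable[[Tokens], Tokens],
    n_steps: int = 50,
) -> List[float]:
    """Per-token IG scores via a midpoint Riemann sum over the straight-line path."""
    avg_grad = [[0.0] * len(tok) for tok in x]
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval of [0, 1]
        point = [
            [b + alpha * (v - b) for v, b in zip(tok, base)]
            for tok, base in zip(x, baseline)
        ]
        for i, row in enumerate(grad_fn(point)):
            for d, g in enumerate(row):
                avg_grad[i][d] += g / n_steps
    # Scale the path-averaged gradient by (input - baseline), then sum over
    # embedding dimensions to obtain one attribution score per token.
    return [
        sum((v - b) * g for v, b, g in zip(tok, base, grads))
        for tok, base, grads in zip(x, baseline, avg_grad)
    ]


# Hypothetical 3-token input with 3-dimensional embeddings and a zero baseline.
tokens = [[1.0, 0.2, 0.0], [0.1, 1.5, -0.3], [0.4, 0.0, 2.0]]
zero_baseline = [[0.0] * 3 for _ in tokens]

scores = integrated_gradients(tokens, zero_baseline, model_grad, n_steps=100)
gap = model(tokens) - model(zero_baseline)
print("per-token attributions:", [round(s, 4) for s in scores])
print("completeness error:", abs(sum(scores) - gap))
```

Because of completeness, the per-token scores should sum, up to discretisation error, to `model(tokens) - model(zero_baseline)`. The printed error therefore serves as a quick check on whether `n_steps` is large enough. In this toy model the second token pushes the prediction down, so it receives a negative score.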
Example Use Cases
Safety
In a clinical decision support system that processes doctors' notes to predict patient risk, Integrated Gradients identifies which medical terms, symptoms, or phrases most strongly influence risk predictions. This lets clinicians verify that the model focuses on clinically relevant information rather than spurious correlations, and it supports regulatory compliance in healthcare AI.
Fairness
For automated loan approval systems that process free-text application descriptions, Integrated Gradients reveals which words or phrases drive acceptance decisions. This supports fairness audits by highlighting whether words that reveal or proxy for protected characteristics inadvertently influence decisions, and it enables transparent explanations to customers about application outcomes.
Explainability
In content moderation systems flagging potentially harmful posts, Integrated Gradients identifies which specific words or linguistic patterns trigger safety classifications, enabling platform teams to debug false positives and validate that models focus on genuinely problematic language rather than demographic markers.
Limitations
- Computational cost is substantial. Each explanation requires a forward and backward pass at every integration step (typically 20-300 steps), and each pass becomes more expensive as document length grows, which makes real-time applications or large-scale document processing challenging.
- Choice of baseline input (zero embeddings, padding tokens, neutral text, or average embeddings) substantially affects attribution results, but optimal baseline selection remains domain-specific and often requires extensive experimentation.
- In transformer models with attention mechanisms, importance often spreads across many tokens, making it difficult to identify clear, actionable insights, especially for complex reasoning tasks where multiple tokens contribute collectively.
- Modern NLP models use subword tokenisation (BPE, WordPiece), so a single word may be split into several tokens, each with its own attribution score. This makes results difficult to interpret at the word level.
- Integrated Gradients describes how the model's prediction depends on each input token relative to the chosen baseline. It does not establish causal relationships in the underlying domain, and it cannot by itself distinguish spurious correlations the model has learned from meaningful semantic dependencies in the input text.
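One common way to mitigate the subword problem is to merge subword attributions back into word-level scores. The sketch below assumes BERT-style WordPiece output, where continuation pieces carry a `##` prefix. BPE tokenisers mark word boundaries differently and would need a different merge rule. Summing preserves the total attribution of each word's pieces, which is consistent with completeness. Averaging is an alternative that does not preserve that total. The function `merge_wordpiece_attributions`, the set of special tokens it drops, and the example tokens and scores are all hypothetical.

```python
from typing import List, Tuple

# Assumption: BERT-style special tokens; other tokenisers use different markers.
SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]"}


def merge_wordpiece_attributions(
    tokens: List[str], scores: List[float], reduce: str = "sum"
) -> List[Tuple[str, float]]:
    """Merge WordPiece subword attributions into word-level scores.

    Continuation pieces (prefix '##') are glued onto the preceding word.
    Special tokens are dropped, which is a choice some analyses avoid.
    """
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    words: List[Tuple[str, List[float]]] = []
    for tok, score in zip(tokens, scores):
        if tok in SPECIAL_TOKENS:
            continue
        if tok.startswith("##") and words:
            text, parts = words[-1]
            words[-1] = (text + tok[2:], parts + [score])
        else:
            words.append((tok, [score]))
    if reduce == "sum":
        return [(w, sum(p)) for w, p in words]
    if reduce == "mean":
        return [(w, sum(p) / len(p)) for w, p in words]
    raise ValueError(f"unknown reduce mode: {reduce}")


# Hypothetical WordPiece tokenisation of "chest pain radiating to arm"
# with made-up per-token IG scores.
tokens = ["[CLS]", "chest", "pain", "radi", "##ating", "to", "arm", "[SEP]"]
scores = [0.01, 0.30, 0.25, 0.10, 0.15, 0.02, 0.12, 0.00]
print(merge_wordpiece_attributions(tokens, scores))
print(merge_wordpiece_attributions(tokens, scores, reduce="mean"))
```

With fast Hugging Face tokenisers, word boundaries can also be recovered from the tokeniser's word-index mapping instead of string prefixes. The prefix rule is used here only to keep the sketch dependency-free.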
Resources
Captum: Model Interpretability for PyTorch
Open-source PyTorch library implementing Integrated Gradients with multi-modal support including text, featuring easy integration with transformer models and comprehensive NLP tutorials (BERT SQuAD, IMDB classification, Llama2 attribution).
Axiomatic Attribution for Deep Networks
Original paper introducing the Integrated Gradients method and its fundamental axioms of sensitivity and implementation invariance. It demonstrates applications that include text models and provides theoretical foundations for attribution methods.