The Architecture of Focus: Attention in Transformers and Discrete Choice
How Transformer Attention and Econometric ‘Consideration Sets’ share a mathematical skeleton
NLP
discrete choice
econometrics
deep learning
PyMC
Author
Nathaniel Forde
Published
March 13, 2026
Introduction: Attention is Bias
Every act of attention is an act of exclusion. To focus on one thing is to ignore another. When the mechanism operates invisibly in systems that make consequential decisions about people, it becomes unaccountable bias.
Large language models now screen résumés and shortlist applicants before a human sees them. Given candidates and a job description, they decide which features deserve processing resources. That is selective relevance weighting. Econometricians have studied the same structure for decades under a different name: consideration set formation.
These systems encode bias. They learn it from data that biased humans generated. The real problem is detection. Can we locate and quantify what they encode? Consideration set models separate “who gets noticed” from “who gets chosen.” They provide the audit infrastructure that transformer attention lacks.
Road map
When AI systems screen candidates, recommend content, or triage decisions, their internal logic is opaque. Consideration set models offer an explicit audit: they separate gatekeeping bias (who gets noticed) from conditional-on-screening evaluation (who gets chosen, given notice). If you can identify which variable blocks the consideration funnel, you can fix the intake filter. This post makes the case in three movements.
Part I builds a toy transformer that replicates the conjunction fallacy. A single attention head learns to over-weight vivid but non-diagnostic words. We can see the bias in the attention map. Then we add heads: the same bias persists, but now it is distributed across anonymous subspaces. No single head is “the biased one.” The architecture becomes what Dan Davies calls an accountability sink: a structure that absorbs blame without assigning it. Wrap it in a vendor API and the sink deepens. The firm points to the model, the vendor points to the data, the data points to history.
Part II introduces the econometric alternative. Consideration set models separate screening from evaluation, with named parameters, posteriors, and an exclusion restriction that forces instrument validity into the open. We fit the model to synthetic hiring data (where we know the ground truth) and then to the Swiss Metro transport dataset (where we do not), adding per-individual random coefficients on travel time to capture heterogeneity in time sensitivity. The hiring case has clean instruments: the model recovers which firm screens on which variable, and how strongly. The Swiss Metro case is harder. The GA pass is pulled between consideration and utility. The model names the tension and earns its decomposition.
Part III draws the comparison. Internal attribution methods (SHAP, probing, causal tracing) are confounded by the architecture they try to explain. The consideration set model sidesteps this: it wraps a behavioural model around the black box and asks whether the system’s screening pattern exhibits gatekeeping bias. The streetlight effect and the accountability sink are two sides of the same problem. We search where the light is; the architecture ensures there is no light where the bias lives. Consideration set models bring their own lamp.
A companion essay tracks the development of this argument through 21 commits over 15 days.
The Streetlight Effect: we look where the light is, not where the keys fell.
The old joke about searching under the streetlight applies here. Interpretability tools examine attention weights because they are visible, not because they are where the bias lives. Single-head attention is the lamppost: you can see the weights clearly, so you look there. Multi-head attention scatters the keys into the dark. Consideration set models do something different. They ask not where does the mechanism put the bias? but does the system’s behaviour exhibit bias, and on which variables? They bring their own light.
The Linda Problem: Where Attention Bias Begins
Linda is 31 years old, single, outspoken, and very bright. She majored in philosophy. As a student, she was deeply concerned with issues of discrimination and social justice, and also participated in anti-nuclear demonstrations.
Kahneman and Tversky’s famous vignette asks which is more probable: that Linda is (a) a bank teller, or (b) a bank teller and active in the feminist movement. Most people choose (b), the conjunction, violating basic probability. The error is instructive. The words outspoken, philosophy, discrimination, social justice create an attentional field that makes “feminist” overwhelmingly salient. The base rate drowns. Context selectively weights descriptors by relevance to each hypothesis.
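The probability logic violated here fits in two lines. The numbers below are purely illustrative (not from the original study), but the inequality holds for any joint distribution:

```python
# Conjunction rule: P(A and B) = P(B) * P(A | B) <= P(B), since P(A | B) <= 1.
p_teller = 0.05                    # hypothetical marginal: Linda is a bank teller
p_feminist_given_teller = 0.30     # hypothetical conditional
p_both = p_teller * p_feminist_given_teller
assert p_both <= p_teller          # the conjunction can never be more probable
print(p_both)                      # 0.015
```

No matter how salient "feminist" is made by the vignette, the conjunction cannot overtake its own marginal; choosing (b) reverses the inequality.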
This selective relevance weighting is the same mathematical mechanism at work in two paradigms:
Transformer attention (we use a simplified cross-attention variant for clarity), and
Consideration set formation in discrete choice models.
Both score, normalise, and weight. The difference that matters for fairness: one framework names its biases, the other distributes them across anonymous subspaces where they resist audit.
The Core Analogy
Transformer Attention
For a sequence of tokens, attention computes:
\[
\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V
\]
where:
\(Q \in \mathbb{R}^{1 \times d}\) is the Query, the target word’s projection of “what am I looking for?”
\(K \in \mathbb{R}^{n \times d}\) is the Key, each context word’s projection of “what do I signal?”
\(V \in \mathbb{R}^{n \times d}\) is the Value, each context word’s “information payload”
The output is a weighted sum of value payloads, with weights set by query-key similarity. The crucial point: \(K\) and \(V\) are separate projections. What determines what to attend to is decoupled from what gets retrieved.
A note on terminology: cross-attention vs self-attention
In self-attention, \(Q\), \(K\), and \(V\) are all projections of the same sequence. In cross-attention, \(Q\) comes from one source and \(K\)/\(V\) from another. Our implementation is closer to cross-attention: the query is “Linda” and the keys/values are the context words. This isolates the scoring mechanism without the complexity of full self-attention. Demo 4 adds multi-head structure to show how interpretability degrades with scale.
What cross-attention cannot represent
Our cross-attention architecture computes each context word’s relevance to “Linda” independently. Words do not attend to each other, so the model cannot learn that “organized” + “concern” together signal activist more than either alone. The conjunction effect lives entirely in the training labels. The model absorbs word-label correlations from biased data. It does not discover conjunction structure through architectural composition.
Full self-attention would let context words interact before classification. Production transformers handle conjunction-like reasoning this way, through multi-layer self-attention across the sequence.
This limitation is conservative. Our conjunction fallacy results are a lower bound on the interpretability problem. Pure single-word correlations from biased training data suffice to reproduce the bias. Self-attention with word-word interactions would make the bias harder to locate: more interaction terms, more compositional subspaces, more places for discrimination to hide.
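The difference is visible in the shapes alone. A minimal NumPy sketch (random toy embeddings and projections, illustrative only): self-attention produces an \(n \times n\) interaction matrix, while our cross-attention setup produces a single row.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 4, 8                          # 4 context words, embedding dim 8
X = rng.normal(size=(n, d))          # toy token embeddings (illustrative only)
Wq, Wk = rng.normal(size=(d, d)), rng.normal(size=(d, d))

def softmax_rows(S):
    e = np.exp(S - S.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# Self-attention: queries AND keys come from the same sequence, giving an
# (n, n) matrix — every context word scores every other context word.
self_attn = softmax_rows((X @ Wq) @ (X @ Wk).T / np.sqrt(d))
print(self_attn.shape)               # (4, 4)

# Our cross-attention demo has only a (1, n) row: a single "Linda" query
# scores the context words, which never interact with each other.
q_linda = rng.normal(size=(1, d))
cross_attn = softmax_rows((q_linda @ Wq) @ (X @ Wk).T / np.sqrt(d))
print(cross_attn.shape)              # (1, 4)
```

The \(n \times n\) matrix is where word-word composition (and compositional bias) can live; the \(1 \times n\) row has nowhere to put it.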
Consideration Sets in Discrete Choice
A consumer does not evaluate every available alternative. She first forms a consideration set \(\mathcal{C}_n \subseteq \mathcal{J}\), then chooses from within it.
\(\pi_{nj} = \sigma\!\left(\gamma_{0j} + \mathbf{z}_n^\top \boldsymbol{\gamma}_j\right)\) is the consideration probability for alternative \(j\), i.e. the propensity to attend, modelled with a sigmoid (independent per alternative, not softmax)
\(P(\mathcal{C}) = \prod_k \pi_{nk}^{\mathbf{1}[k \in \mathcal{C}]} (1-\pi_{nk})^{\mathbf{1}[k \notin \mathcal{C}]}\) is the probability of a particular set forming
\(P(j \mid \mathcal{C}) = \text{softmax}(V_{n\mathcal{C}})_j\) is the standard logit choice within the consideration set
Marginalising over all \(2^J\) possible consideration sets is combinatorially expensive. The log-consideration-adjusted utility approximation collapses both stages into a single softmax:
\[
P(j) = \text{softmax}_j\!\left(\log \pi_{nj} + V_{nj}\right)
\]
This is a surrogate likelihood, not the structural model. The hard-gating story motivates the decomposition: each alternative is in or out. The log-adjusted form makes it estimable. We use it throughout because it is tractable, identified under the exclusion restriction, and sufficient for the audit question: which variables gate consideration, and how strongly?
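A small numerical sketch makes the surrogate/structural distinction concrete. With made-up consideration probabilities and utilities for \(J = 3\) alternatives (the values of `pi` and `V` are illustrative, not estimates), we can afford the exact marginalisation over all non-empty sets and compare it to the single-softmax surrogate:

```python
import itertools
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

J = 3
pi = np.array([0.9, 0.6, 0.3])   # consideration probabilities (hypothetical)
V = np.array([1.0, 0.5, 2.0])    # deterministic utilities (hypothetical)

# --- Exact structural model: marginalise over all 2^J - 1 non-empty sets ---
p_exact = np.zeros(J)
norm = 0.0
for r in range(1, J + 1):
    for C in itertools.combinations(range(J), r):
        p_C = np.prod([pi[k] if k in C else 1 - pi[k] for k in range(J)])
        norm += p_C
        choice = softmax(V[list(C)])           # logit choice within the set
        for idx, j in enumerate(C):
            p_exact[j] += p_C * choice[idx]
p_exact /= norm   # condition on a non-empty consideration set

# --- Surrogate: one softmax over log-consideration-adjusted utilities ---
p_surrogate = softmax(np.log(pi) + V)

print("exact    :", p_exact.round(3))
print("surrogate:", p_surrogate.round(3))
```

The two distributions agree on the ranking but not the magnitudes, which is the point: the surrogate is an estimable approximation of the hard-gating story, not the story itself.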
The Structural Map
The Attention–Consideration Analogy
Transformer Component
Discrete Choice Analogue
Query \(Q\)
Consumer’s current need state (context \(z_n\))
Key \(K_j\)
Alternative \(j\)’s signalled attributes
\(Q \cdot K_j^\top / \sqrt{d}\)
Alternative \(j\)’s relevance score \(\gamma_{0j} + \mathbf{z}_n^\top \boldsymbol{\gamma}_j\)
\(\text{softmax}(QK^\top / \sqrt{d})\)
Consideration probability \(\pi_{nj} = \sigma(\cdot)\); note: sigmoid (independent), not softmax (coupled)
How scores and values combine differs. Transformers aggregate values by weights: \(\sum_j w_j V_j\). Consideration models add scores and values inside a single softmax: \(\text{softmax}_j(\log \pi_j + V_j)\). The parallel is structural, not algebraic. Both separate a relevance signal from a payload.
The other critical difference is the gate type. Transformers use soft continuous weighting: every token always contributes. Consideration sets use hard stochastic gating: alternatives are in the set or not. The transformer is a differentiable relaxation of the consideration set mechanism. Sending the softmax temperature \(\tau \to 0\) recovers hard one-hot gating from the soft distribution, i.e. the same limit used in Gumbel-Softmax reparameterisation.
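The temperature limit is easy to see numerically. A minimal sketch (toy relevance scores, chosen for illustration): dividing the scores by a shrinking \(\tau\) before the softmax drives the soft weighting toward a hard one-hot gate.

```python
import numpy as np

def softmax_tau(scores, tau):
    """Temperature-scaled softmax: as tau -> 0, approaches hard one-hot gating."""
    z = scores / tau
    e = np.exp(z - z.max())
    return e / e.sum()

scores = np.array([1.0, 2.0, 0.5])   # hypothetical relevance scores
for tau in [1.0, 0.5, 0.1, 0.01]:
    print(f"tau={tau:>5}:", softmax_tau(scores, tau).round(4))
```

At \(\tau = 1\) every alternative keeps some weight (soft attention); by \(\tau = 0.01\) essentially all mass sits on the top-scoring alternative (hard consideration), which is the same limit Gumbel-Softmax exploits.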
Part I: The Transformer Side
We build a toy transformer that classifies “Linda” by attending to context. A single head learns to over-weight vivid but non-diagnostic words. The signal then fragments across multiple heads until it becomes unlocatable. The interpretability problem in miniature.
Demo 1 — Static Soft Attention
Before training anything, we can see the \(Q, K, V\) mechanism at work. The word "Linda" computes a scaled dot-product against context word keys. The \(Q \cdot K^\top\) score determines what to attend to, but \(V\) determines what information is retrieved.
Code
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

def static_qkv_attention(Q, K, V):
    """
    Full QKV scaled dot-product attention.

    Q : (1, d) — query (the target word: "what am I looking for?")
    K : (n, d) — keys (context words: "what do I signal?")
    V : (n, d) — values (context words: "what information do I deliver?")

    Returns
    -------
    context : (1, d) — relevance-weighted information payload
    weights : (1, n) — attention distribution
    """
    d_k = Q.shape[-1]
    scores = torch.matmul(Q, K.T) / (d_k ** 0.5)
    weights = F.softmax(scores, dim=-1)
    context = torch.matmul(weights, V)  # weighted sum of VALUE payloads
    return context, weights

torch.manual_seed(1)
d = 8

# Hand-crafted key embeddings — activism words project similarly to the feminist query
outspoken_K = torch.tensor([1.2, 0.1, 2.1, 0.5, 0.2, 0.0, 1.1, 0.8])
justice_K = torch.tensor([1.1, 0.0, 2.3, 0.4, 0.1, 0.1, 1.0, 0.9])
deposit_K = torch.tensor([0.1, 2.5, 0.2, 1.1, 1.5, 2.0, 0.1, 0.2])  # very different

# Feminist-sense query for "Linda"
linda_Q = torch.tensor([1.0, 0.2, 2.0, 0.5, 0.3, 0.1, 1.0, 0.8]).unsqueeze(0)

# Value payloads are separate from keys — they carry semantic content
V_static = torch.randn(3, d) * 0.5  # arbitrary payloads for illustration
K_static = torch.stack([outspoken_K, justice_K, deposit_K])

_, attn_w = static_qkv_attention(linda_Q, K_static, V_static)

words = ["outspoken", "justice", "deposit"]
w = attn_w[0].detach().numpy()

fig, ax = plt.subplots(figsize=(6, 3.5))
colors = [PAL["focal2"], PAL["focal2"], PAL["focal1"]]  # PAL: palette defined earlier in the post
bars = ax.bar(words, w, color=colors, alpha=0.85, edgecolor="white", linewidth=1.5)
ax.set_ylabel("Attention weight $w_j$")
ax.set_title("Static QKV attention: feminist-sense 'Linda' query")
ax.set_ylim(0, 1.05)
for bar, weight in zip(bars, w):
    ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.02,
            f"{weight:.3f}", ha="center", va="bottom", fontsize=11)
plt.tight_layout()
plt.show()
Static attention: a feminist-sense query for ‘Linda’ attends selectively to activism context. The V matrix decouples the signal (K) from the payload (V).
The feminist-sense query attends heavily to outspoken and justice, and nearly ignores deposit. The output is the weighted sum of the value payloads of those relevant words, not just the knowledge that they were relevant.
Demo 2 — Learning to Attend: The Linda Classification Task
Now the \(Q\), \(K\), \(V\) matrices are learned from data, not hand-crafted. We construct a toy corpus centred on "Linda", inspired by the conjunction fallacy experiment.
Data-Generating Process
Each training example places "Linda" in a context corresponding to either (0) a feminist activist reading or (1) a bank teller reading. Three features of the original T&K experiment are built directly into the DGP:
Base rate imbalance. In the real world, bank tellers vastly outnumber feminist activists. We set P(teller sense) = 0.70, P(activist sense) = 0.30. A calibrated model should respect this prior for ambiguous inputs.
Vignette injection. Every participant in the original experiment reads the same activist-flavoured description of Linda before judging. We mirror this by injecting one word from a vignette pool (“outspoken”, “philosophy”, “justice”, “discrimination”) into every context, regardless of the true generating sense. This creates systematic activist-flavoured noise that the transformer must either discount or be biased by.
Conjunction boosts. Words like “independent” and “organized” are genuinely ambiguous. Individually they signal neither sense clearly. But when a vignette word co-occurs with an ambiguous word, the combination shifts the label toward activist on a log-odds scale. Vivid descriptors plus ambiguous traits create a stronger activist impression than either alone.
Each word \(w\) is assigned signal weights \((a_w, t_w)\) representing how likely it is to appear when “Linda” carries its activist versus teller sense. Context words are sampled proportional to these weights. The label is determined stochastically: base sense provides a strong prior, but conjunction effects can override it.
Code
import random
import pandas as pd

random.seed(42)
np.random.seed(42)

# =============================================================================
# Vocabulary with graded signal strengths (Linda task).
#
# Format: word -> (activist_weight, teller_weight)
#
# activist_weight = proportional probability of appearing when Linda=feminist
# teller_weight   = proportional probability of appearing when Linda=bank teller
#
# Ambiguous words carry moderate weight in BOTH senses.
# Noise words carry low weight in either direction.
# =============================================================================
VOCAB_SIGNALS: dict[str, tuple[float, float]] = {
    # --- Strong ACTIVIST signals ---
    "protest":        (0.90, 0.05),
    "outspoken":      (0.85, 0.08),
    "justice":        (0.88, 0.06),
    "philosophy":     (0.82, 0.07),
    "discrimination": (0.80, 0.05),
    "rally":          (0.78, 0.06),
    "petition":       (0.75, 0.08),
    "equality":       (0.72, 0.05),
    # --- Genuinely AMBIGUOUS words ---
    "independent":    (0.50, 0.42),  # "independent thinker" vs "independent contractor"
    "organized":      (0.38, 0.48),  # "organized protests" vs "organized ledgers"
    "concern":        (0.45, 0.35),  # "social concern" vs "financial concern"
    # --- Strong TELLER signals ---
    "deposit":        (0.06, 0.92),
    "vault":          (0.04, 0.95),  # "vault of knowledge" leaks a mild activist signal
    "ledger":         (0.05, 0.90),
    "account":        (0.08, 0.88),  # "account of injustice" leaks activist
    "teller":         (0.03, 0.94),
    "transaction":    (0.06, 0.85),
    "cashier":        (0.05, 0.90),
    "receipt":        (0.07, 0.82),
    # --- Low-signal NOISE words ---
    "Tuesday":        (0.08, 0.07),
    "nearby":         (0.09, 0.10),
    "usually":        (0.07, 0.06),
    "recently":       (0.06, 0.08),
    "often":          (0.07, 0.07),
    "sometimes":      (0.06, 0.06),
    "rather":         (0.05, 0.05),
    "quite":          (0.08, 0.09),
}
VOCAB = list(VOCAB_SIGNALS.keys())
TARGET = "Linda"

# --- Vignette injection (the T&K design) --------------------------------------
# In the original experiment every participant reads the same activist-flavoured
# description of Linda before judging. We mirror this by injecting one word
# from this pool into EVERY context, regardless of the true generating sense.
# This creates a systematic activist signal in every training example —
# analogous to the "volume of text" that biases the transformer's mapping.
VIGNETTE_POOL = ["outspoken", "philosophy", "justice", "discrimination"]

# --- Conjunction effects (the Linda insight) -----------------------------------
# Certain word PAIRS boost P(activist) when they co-occur. Values are additive
# shifts on the log-odds scale of P(label = activist).
#
# Vignette × ambiguous pairs fire often (a vignette word is always present);
# ambiguous × ambiguous pairs fire rarely but carry stronger boosts.
CONJUNCTION_BOOST: dict[frozenset, float] = {
    # Vignette × ambiguous (moderate boost, fires often)
    frozenset({"philosophy", "organized"}): 1.5,    # philosophical organiser → activist
    frozenset({"justice", "concern"}): 1.5,         # justice + concern → activist
    frozenset({"outspoken", "independent"}): 1.5,   # outspoken independent → activist
    frozenset({"discrimination", "concern"}): 1.2,  # concern about discrimination → activist
    # Ambiguous × ambiguous (strong boost, fires rarely)
    frozenset({"organized", "concern"}): 3.0,       # organized concern → social activism
    frozenset({"independent", "organized"}): 2.5,   # two ambiguous traits → activist profile
    frozenset({"independent", "concern"}): 2.0,     # principled concern → activist mindset
}

def generate_context(sense: int, ctx_size: int = 4) -> list[str]:
    """
    Sample a context window for 'Linda' given its true sense.

    Every context contains one VIGNETTE word (always activist-flavoured,
    drawn uniformly from VIGNETTE_POOL) plus (ctx_size - 1) words sampled
    by signal strength. This mirrors the T&K design: every participant
    reads the same vivid activist description, regardless of Linda's
    actual occupation.
    """
    vignette = random.choice(VIGNETTE_POOL)
    remaining_vocab = [w for w in VOCAB if w != vignette]
    weights = np.array([VOCAB_SIGNALS[w][sense] for w in remaining_vocab])
    weights = weights / weights.sum()
    others = list(np.random.choice(
        remaining_vocab, size=ctx_size - 1, replace=False, p=weights
    ))
    return [vignette] + others

def build_corpus(n: int = 2500, ctx_size: int = 4) -> list[tuple]:
    """
    Generate a corpus of (context, target, label) triples.

    Base rate: P(teller sense) = 0.70, P(activist sense) = 0.30 —
    mirroring the real-world prior that bank tellers vastly outnumber
    feminist activists. The vignette injection and conjunction boosts
    then shift labels toward activist, creating a training distribution
    that over-represents activist labels relative to the base rate —
    the conjunction fallacy baked into the data.
    """
    corpus = []
    for _ in range(n):
        base_sense = 0 if random.random() < 0.30 else 1  # 30 % activist
        ctx = generate_context(base_sense, ctx_size)

        # Log-odds of activist label (base ≈ 90 % match with generating sense)
        base_logodds = 2.2 if base_sense == 0 else -2.2

        # Pair interaction: co-occurring words boost activist odds
        ctx_set = set(ctx)
        boost = sum(
            delta for pair, delta in CONJUNCTION_BOOST.items()
            if pair.issubset(ctx_set)
        )
        p_activist = 1 / (1 + np.exp(-(base_logodds + boost)))
        label = 0 if random.random() < p_activist else 1
        corpus.append((ctx, TARGET, label))
    return corpus

corpus = build_corpus(n=2500)
df = pd.DataFrame(corpus, columns=["context", "target", "label"])

# --- Label distribution: does the conjunction fallacy inflate activist labels? ---
n_act = (df["label"] == 0).sum()
n_tel = (df["label"] == 1).sum()
print(f"Label distribution: activist={n_act} ({n_act/len(df):.1%}) "
      f"teller={n_tel} ({n_tel/len(df):.1%})")
print(f"  (base-rate prior was 30 % activist / 70 % teller)\n")

# Vignette frequency check
for v in VIGNETTE_POOL:
    frac = df["context"].apply(lambda c: v in c).mean()
    print(f"  Vignette '{v}' appears in {frac:.0%} of contexts")

# Sanity checks: do signal words appear in the right contexts?
print("\nWord co-occurrence by label (0=activist, 1=teller):\n")
for word in ["vault", "protest", "independent", "Tuesday"]:
    df[f"has_{word}"] = df["context"].apply(lambda x: word in x)
    ct = pd.crosstab(df[f"has_{word}"], df["label"], normalize="columns")
    ct.index = [f"{word}=False", f"{word}=True"]
    print(ct.round(3))
    print()

# Show conjunction effect: how do pair co-occurrences shift the label?
print("--- Conjunction effects on P(activist label) ---\n")
for pair, boost_val in sorted(CONJUNCTION_BOOST.items(), key=lambda x: -x[1]):
    w1, w2 = sorted(pair)
    has_pair = df["context"].apply(lambda ctx: pair.issubset(set(ctx)))
    n_pair = has_pair.sum()
    if n_pair > 0:
        p_act = (df.loc[has_pair, "label"] == 0).mean()
        print(f"  {w1:>14s} + {w2:<14s} (boost={boost_val:+.1f}) "
              f"n={n_pair:3d} P(activist)={p_act:.2f}")
    else:
        print(f"  {w1:>14s} + {w2:<14s} (boost={boost_val:+.1f}) "
              f"n=  0 [pair never sampled]")
Although only 30% of examples were generated from the activist-sense distribution, the effective P(activist label) is higher. Conjunction boosts systematically flip teller-sense examples toward activist when vignette words co-occur with ambiguous descriptors. Every context contains at least one vignette word, so the model trains on text that is always partially activist-flavoured.
Code
log_odds = {
    w: np.log((VOCAB_SIGNALS[w][0] + 1e-3) / (VOCAB_SIGNALS[w][1] + 1e-3))
    for w in VOCAB
}
sorted_items = sorted(log_odds.items(), key=lambda x: x[1])
words_sorted = [x[0] for x in sorted_items]
lo_sorted = [x[1] for x in sorted_items]

fig, ax = plt.subplots(figsize=(8, 6))
colors = [PAL["focal2"] if lo > 0 else PAL["focal1"] for lo in lo_sorted]
ax.barh(words_sorted, lo_sorted, color=colors, alpha=0.85, edgecolor="white")
ax.axvline(0, color=PAL["ref_line"], linewidth=0.9, linestyle="--")
ax.set_xlabel(r"Log-odds $\log(a_w / t_w)$")
ax.set_title("Vocabulary signal structure: activist (blue) vs teller (orange)")
plt.tight_layout()
plt.show()
Log-odds of each word’s activist vs teller signal strength. Words near zero are genuine ambiguity carriers; words at the extremes are strong sense-specific signals.
The log-odds chart shows each word’s marginal diagnostic value. Words at the extremes are unambiguous signals; words near zero carry little or conflicting information. The four vignette words (outspoken, philosophy, justice, discrimination) all sit at the activist extreme, but because they appear in every context regardless of sense, their marginal diagnostic value is actually lower than this chart suggests. The model must learn to discount them or it will systematically over-predict activist.
Full QKV Attention Head
Code
import torch.nn as nn
import torch.optim as optim

class AttentionHead(nn.Module):
    """
    Single attention head (cross-attention style: Q from target, K/V from context).

    The analogy with discrete choice:
        Q <-> Consumer's current need state (what am I looking for?)
        K <-> What each word SIGNALS / advertises
        V <-> The information PAYLOAD each word delivers

    The attended context = softmax(QK^T / sqrt(d)) . V is a relevance-weighted
    extraction from the value space — not the raw embedding space. Without V,
    the model conflates "what I attend to" with "what I get", exactly as a
    consideration model that treats consideration probability and utility as
    the same quantity would.
    """

    def __init__(self, vocab_size: int, d_model: int = 32):
        super().__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)  # the crucial addition
        self.clf = nn.Linear(d_model, 2)

    def _attention_weights(self, ctx_idxs: torch.Tensor, tgt_idx: torch.Tensor
                           ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute soft attention weights and context embeddings."""
        c = self.embed(ctx_idxs)              # (ctx, d)
        t = self.embed(tgt_idx).unsqueeze(0)  # (1, d)
        Q = self.q_proj(t)                    # (1, d) — need-state query
        K = self.k_proj(c)                    # (ctx, d) — signal projection
        scores = torch.matmul(Q, K.T) / (self.d_model ** 0.5)
        weights = F.softmax(scores, dim=-1)
        return weights, c

    def forward(self, ctx_idxs: torch.Tensor, tgt_idx: torch.Tensor
                ) -> tuple[torch.Tensor, torch.Tensor]:
        weights, c_embs = self._attention_weights(ctx_idxs, tgt_idx)
        V = self.v_proj(c_embs)                 # (ctx, d) — value/payload projection
        context_vec = torch.matmul(weights, V)  # (1, d) — aggregate information
        logits = self.clf(context_vec)          # (1, 2)
        return logits, weights
Hard-Gated Attention (the Consideration Set Variant)
Training
We train with cross-entropy loss only, no entropy penalty. The task has three sources of structure: individual word signals, pair co-occurrences, and a base rate prior (70% teller). Every context contains an activist-flavoured vignette word. Whether the model learns to discount this omnipresent signal or is captured by it is the empirical question.
Code
torch.manual_seed(42)

# Build vocabulary (sorted for reproducibility)
all_words = sorted(set(w for ctx, _, _ in corpus for w in ctx) | {TARGET})
w2i = {w: i for i, w in enumerate(all_words)}
i2w = {i: w for w, i in w2i.items()}

train_corpus = corpus[:2000]
test_corpus = corpus[2000:]

model = AttentionHead(len(w2i), d_model=32)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

losses = []
for epoch in range(100):
    model.train()
    epoch_loss = 0.0
    for ctx, tgt, lbl in train_corpus:
        optimizer.zero_grad()
        c_idx = torch.tensor([w2i[w] for w in ctx])
        t_idx = torch.tensor(w2i[tgt])
        logits, _ = model(c_idx, t_idx)
        loss = criterion(logits, torch.tensor([lbl]))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    losses.append(epoch_loss / len(train_corpus))

# Evaluate
model.eval()
with torch.no_grad():
    correct = sum(
        torch.argmax(
            model(
                torch.tensor([w2i[w] for w in ctx]),
                torch.tensor(w2i[tgt])
            )[0]
        ).item() == lbl
        for ctx, tgt, lbl in test_corpus
    )
print(f"Test accuracy : {correct / len(test_corpus):.3f}")
print(f"Final loss : {losses[-1]:.4f}")
Test accuracy : 0.812
Final loss : 0.3654
What Has the Model Learned to Attend To?
Do the learned attention weights track the true signal strengths, or are they biased by the vignette injection? Vignette words appear in every context, so their marginal diagnostic value is zero. Yet their strong activist associations might cause the model to over-attend to them.
Code
attention_records = {0: {}, 1: {}}
model.eval()
with torch.no_grad():
    for ctx, tgt, lbl in corpus:
        c_idx = torch.tensor([w2i[w] for w in ctx])
        t_idx = torch.tensor(w2i[tgt])
        _, weights = model(c_idx, t_idx)
        for w, a in zip(ctx, weights[0].numpy()):
            attention_records[lbl].setdefault(w, []).append(float(a))

# Words appearing in both label groups — exclude vignette words from the
# Spearman analysis because their appearance is forced (not governed by
# VOCAB_SIGNALS), so their "true signal weight" is not a meaningful
# comparison point.
common_words = [
    w for w in VOCAB
    if w not in VIGNETTE_POOL
    and attention_records[0].get(w) and attention_records[1].get(w)
]
vignette_words_in_data = [
    w for w in VIGNETTE_POOL
    if attention_records[0].get(w) and attention_records[1].get(w)
]

activist_attn = [np.mean(attention_records[0][w]) for w in common_words]
teller_attn = [np.mean(attention_records[1][w]) for w in common_words]

fig, axes = plt.subplots(1, 2, figsize=(13, 6))
for ax, (attn, sense_name, color) in zip(
    axes,
    [(activist_attn, "Activist sense (0)", PAL["focal2"]),
     (teller_attn, "Teller sense (1)", PAL["focal1"])],
):
    order = np.argsort(attn)[::-1]
    ax.barh(
        [common_words[i] for i in order],
        [attn[i] for i in order],
        color=color, alpha=0.85,
    )
    baseline = 1 / len(common_words)
    ax.axvline(
        baseline, color=PAL["ref_line"], linestyle="--", alpha=0.5,
        label=f"Uniform baseline ({baseline:.2f})"
    )
    ax.set_xlabel("Average attention weight")
    ax.set_title(f"Learned attention: {sense_name}")
    ax.legend(fontsize=8)
plt.suptitle("Learned attention by label — do vignette words get discounted or over-attended?",
             y=1.02, fontsize=12)
plt.tight_layout()
plt.show()
Average attention weight per word, by label. Vignette words (outspoken, philosophy, justice, discrimination) appear in every context — does the model learn to discount them?
Code
from scipy.stats import spearmanr

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, (sense_idx, sense_label, attn_list, color) in zip(
    axes,
    [(0, "Activist", activist_attn, PAL["focal2"]),
     (1, "Teller", teller_attn, PAL["focal1"])],
):
    true_s = [VOCAB_SIGNALS[w][sense_idx] for w in common_words]
    rho, p = spearmanr(true_s, attn_list)

    # Non-vignette words (governed by VOCAB_SIGNALS)
    ax.scatter(true_s, attn_list, color=color, alpha=0.8, s=55)
    for w, ts, la in zip(common_words, true_s, attn_list):
        ax.annotate(w, (ts, la), fontsize=7, alpha=0.75,
                    xytext=(3, 2), textcoords="offset points")

    # Vignette words (always present — annotated separately as red triangles)
    for vw in vignette_words_in_data:
        vts = VOCAB_SIGNALS[vw][sense_idx]
        vla = np.mean(attention_records[sense_idx][vw])
        ax.scatter([vts], [vla], color=PAL["focal1"], marker="^",
                   s=80, zorder=5, alpha=0.9)
        ax.annotate(vw, (vts, vla), fontsize=7, alpha=0.9, color=PAL["focal1"],
                    xytext=(3, 2), textcoords="offset points")

    # Regression line
    m, b = np.polyfit(true_s, attn_list, 1)
    xs = np.linspace(min(true_s), max(true_s), 100)
    ax.plot(xs, m * xs + b, color=color, linestyle="--", alpha=0.6)

    ax.set_xlabel(f"True DGP signal weight ({sense_label.lower()} sense)")
    ax.set_ylabel("Learned attention weight")
    ax.set_title(f"{sense_label} sense (Spearman ρ = {rho:.2f}, p = {p:.3f})")
plt.suptitle("Does transformer attention recover the DGP signal weights?",
             y=1.02, fontsize=12)
plt.tight_layout()
plt.show()
Scatter of true DGP signal strength vs learned average attention weight, by label. Circles: non-vignette words (Spearman computed over these). Red triangles: vignette words (always present, excluded from correlation — their forced presence means VOCAB_SIGNALS does not govern their frequency).
How to read these plots. Each point is a word. The x-axis is the true DGP signal weight; the y-axis is the learned attention weight. Perfect recovery would place points on an increasing line. Red triangles are vignette words, excluded from the correlation because their appearance is forced.
Two features stand out. First, asymmetry between senses: the teller panel shows significant positive \(\rho\). The majority class (70%) gives the model enough data to learn which words are teller-diagnostic. The activist panel shows weak or negative \(\rho\), with fewer examples and vignette contamination.
Second, vignette words have the highest designed signal weights but zero diagnostic value. Elevated attention to them is the conjunction fallacy in learned weights. The model has been captured by the volume of activist-flavoured text, exactly as human respondents are captured by the vivid Linda description.
A weak \(\rho\) for non-vignette words teaches a further lesson. The V-projection and classifier can absorb classification signal without attention weights faithfully mirroring word importance. As Jain & Wallace (2019) and Wiegreffe & Pinter (2019) argued, attention provides a mechanism but not necessarily an explanation.
Demo 3 — The Conjunction Fallacy in Attention
Human respondents systematically judge P(feminist bank teller) > P(bank teller). Vivid descriptions overwhelm the base rate. Does our transformer exhibit the same bias? Seven diagnostic contexts separate three effects: the base rate prior, the conjunction boost, and the vignette injection.
Code
probes = {
    "Pure Teller": ["deposit", "vault", "ledger", "account"],
    "Teller + Vignette": ["outspoken", "deposit", "vault", "ledger"],
    "Ambiguous (no pairs)": ["independent", "Tuesday", "nearby", "usually"],
    "Ambiguous (with pairs)": ["independent", "organized", "concern", "Tuesday"],
    "Vignette + Amb pairs": ["philosophy", "independent", "organized", "concern"],
    "Vignette + Amb + Teller": ["justice", "organized", "deposit", "usually"],
    "Pure Activist": ["protest", "rally", "petition", "equality"],
}

probe_results = {}
model.eval()
with torch.no_grad():
    for name, ctx in probes.items():
        c_idx = torch.tensor([w2i[w] for w in ctx])
        t_idx = torch.tensor(w2i[TARGET])
        logits, attn_w = model(c_idx, t_idx)
        probs = F.softmax(logits, dim=-1)[0].numpy()
        pred = "activist" if probs[0] > probs[1] else "teller"
        probe_results[name] = {
            "P(activist)": probs[0],
            "prediction": pred,
            "attention": dict(zip(ctx, attn_w[0].numpy())),
        }

# --- Table ---
print(f"{'Probe':<26s} P(activist) Prediction")
print("-" * 58)
for name, res in probe_results.items():
    marker = " ***" if name == "Vignette + Amb pairs" else ""
    print(f"{name:<26s}{res['P(activist)']:>9.3f} {res['prediction']}{marker}")
print(f"\n Base rate prior = 0.30")

# --- Bar chart ---
fig, ax = plt.subplots(figsize=(11, 5))
names = list(probe_results.keys())
p_act = [probe_results[n]["P(activist)"] for n in names]
colors = [PAL["focal1"] if p > 0.5 else PAL["focal2"] for p in p_act]
bars = ax.bar(range(len(names)), p_act, color=colors, alpha=0.85, edgecolor="white")
ax.axhline(0.30, color=PAL["ref_line"], linestyle="--", alpha=0.6,
           label="Base rate (30 % activist)")
ax.axhline(0.50, color=PAL["ref_dash"], linestyle=":", alpha=0.4,
           label="Decision boundary")
ax.set_xticks(range(len(names)))
ax.set_xticklabels(names, rotation=30, ha="right", fontsize=9)
ax.set_ylabel("P(model predicts activist)")
ax.set_ylim(0, 1.05)
ax.legend(fontsize=8)
ax.set_title("The conjunction fallacy in attention: "
             "isolating base rate, conjunction pairs, and vignette effects")
plt.tight_layout()
plt.show()
Probe                      P(activist) Prediction
----------------------------------------------------------
Pure Teller                    0.028   teller
Teller + Vignette              0.115   teller
Ambiguous (no pairs)           0.012   teller
Ambiguous (with pairs)         0.855   activist
Vignette + Amb pairs           0.989   activist ***
Vignette + Amb + Teller        0.268   teller
Pure Activist                  0.972   activist

Base rate prior = 0.30
Conjunction fallacy probe: P(model predicts activist) for seven diagnostic contexts. The dashed line marks the 30% base rate. Moving left to right isolates the additive effects of conjunction pairs and vignette words.
The probes isolate three cumulative effects:
Base rate: Pure Teller and Pure Activist establish the endpoints.
Conjunction pairs: Compare Ambiguous (no pairs) to Ambiguous (with pairs). The first, lacking conjunction pairs, stays at the teller end (0.012), below even the 30% base rate a Bayesian would fall back on. The second packs three ambiguous words whose pairwise boosts total +7.5 on the log-odds scale. The jump to 0.855 measures the conjunction effect alone.
Vignette injection: Compare Ambiguous (with pairs) to Vignette + Amb pairs. Same conjunction words, but a noise word is replaced with philosophy. The further rise in P(activist), from 0.855 to 0.989, is the marginal vignette effect.
The transformer classifies by pattern similarity, not by Bayesian updating. The volume of activist-flavoured text shapes the learned mapping, exactly as the vivid Linda description biases human respondents away from the base rate.
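To make that contrast concrete, here is what an idealised Bayesian observer would do with the same kind of probe. This is a minimal sketch, not part of the demo code: the `bayes_posterior` helper and the per-word log-likelihood ratios are hypothetical numbers chosen for illustration.

```python
import math

def bayes_posterior(prior, log_lrs):
    """Posterior P(activist) after updating a base-rate prior with
    per-word log-likelihood ratios (hypothetical values)."""
    log_odds = math.log(prior / (1 - prior)) + sum(log_lrs)
    return 1 / (1 + math.exp(-log_odds))

# Hypothetical log-likelihood ratios. Vignette words are non-diagnostic
# (ratio 1, log-ratio 0), so a Bayesian ignores them entirely: adding
# "philosophy" or "justice" to a context cannot move the posterior.
context_no_diagnostic = [0.0, 0.0, 0.0, 0.0]   # four non-diagnostic words
context_diagnostic = [1.2, 1.2, 1.2, 0.0]      # three weakly diagnostic words

print(f"non-diagnostic context: {bayes_posterior(0.30, context_no_diagnostic):.3f}")
# → 0.300 — stays exactly at the base rate
print(f"diagnostic context    : {bayes_posterior(0.30, context_diagnostic):.3f}")
```

The pattern-matcher has no such anchor: every vivid word shifts its output, whether or not the word carries evidence.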
Demo 4 — Multi-Head Attention: Where Does the Bias Live?
The single-head model gives us one attention map and a direct link between what the model attends to and what it predicts. Real transformers use multiple heads, each operating in a low-dimensional subspace, recombined through nonlinear projections. Two questions. First, when bias is distributed across heads, does it become harder to detect? Second, does recombination amplify or attenuate the conjunction fallacy?
Why multi-head attention matters for interpretability
Four heads compute Q, K, V projections in 8-dimensional subspaces (\(d_{\text{head}} = 32 / 4 = 8\)). Head outputs are concatenated and passed through a two-layer MLP with GELU activation before classification. The nonlinear recombination destroys decomposability. You cannot express the final logit as a sum of per-head contributions.
We use a single-layer, four-head architecture. Production transformers stack many such layers. The distribution effect is already stark at one layer. The interpretability challenge is architectural, not a matter of scale.
The Multi-Head Architecture
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention with nonlinear recombination.

    Four heads, each in a d_head-dimensional subspace, recombined through
    a two-layer MLP (GELU activation) before classification. The nonlinear
    recombination is the key: it makes the mapping from per-head attention
    to output non-decomposable — you cannot attribute the final prediction
    to any single head's attention pattern.
    """

    def __init__(self, vocab_size: int, d_model: int = 32, n_heads: int = 4):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.embed = nn.Embedding(vocab_size, d_model)
        # Per-head Q, K, V projections (d_model -> d_head each)
        self.q_projs = nn.ModuleList([
            nn.Linear(d_model, self.d_head, bias=False) for _ in range(n_heads)
        ])
        self.k_projs = nn.ModuleList([
            nn.Linear(d_model, self.d_head, bias=False) for _ in range(n_heads)
        ])
        self.v_projs = nn.ModuleList([
            nn.Linear(d_model, self.d_head, bias=False) for _ in range(n_heads)
        ])
        # Nonlinear output MLP: concat -> GELU -> projection -> classifier
        self.out_mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, d_model),
        )
        self.clf = nn.Linear(d_model, 2)

    def _head_attention(self, head_idx, c_embs, t_emb):
        """Compute attention weights and output for a single head."""
        Q = self.q_projs[head_idx](t_emb.unsqueeze(0))   # (1, d_head)
        K = self.k_projs[head_idx](c_embs)               # (ctx, d_head)
        V = self.v_projs[head_idx](c_embs)               # (ctx, d_head)
        scores = torch.matmul(Q, K.T) / (self.d_head ** 0.5)
        weights = F.softmax(scores, dim=-1)              # (1, ctx)
        out = torch.matmul(weights, V)                   # (1, d_head)
        return weights, out

    def forward(self, ctx_idxs, tgt_idx):
        c_embs = self.embed(ctx_idxs)                    # (ctx, d_model)
        t_emb = self.embed(tgt_idx)                      # (d_model,)
        all_weights = []
        head_outputs = []
        for h in range(self.n_heads):
            w_h, out_h = self._head_attention(h, c_embs, t_emb)
            all_weights.append(w_h)
            head_outputs.append(out_h)
        # Nonlinear recombination: concat -> MLP -> classifier
        concat = torch.cat(head_outputs, dim=-1)         # (1, d_model)
        projected = self.out_mlp(concat)                 # (1, d_model)
        logits = self.clf(projected)                     # (1, 2)
        return logits, all_weights  # all_weights: list of n_heads tensors
Same embedding, same Q/K/V pattern, same classifier. Two additions: four heads in 8-dimensional subspaces instead of one head in 32, and a GELU MLP that recombines head outputs nonlinearly before classification. A linear output projection would preserve additive decomposability. The GELU layer destroys it.
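The decomposability claim can be checked numerically with a toy sketch. This is a numpy stand-in, not the trained torch model: random weights, a hand-rolled GELU, and ablation (zeroing one head's slice) as the attribution method.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, d_head = 4, 8
d_model = n_heads * d_head

head_outputs = [rng.normal(size=d_head) for _ in range(n_heads)]
concat = np.concatenate(head_outputs)                      # (d_model,)

# Linear recombination: the final logit is a dot product, so it splits
# exactly into per-head slices.
w_out = rng.normal(size=d_model) * 0.1                     # logit weights
full_linear = w_out @ concat
per_head_linear = [
    w_out[h * d_head:(h + 1) * d_head] @ head_outputs[h] for h in range(n_heads)
]
print("linear: per-head terms sum to the logit?",
      np.isclose(sum(per_head_linear), full_linear))       # True — decomposable

# Nonlinear recombination: a GELU between concat and the logit.
def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))

W1 = rng.normal(size=(d_model, d_model))
full_mlp = w_out @ gelu(W1 @ concat)

def ablated(h):
    """Logit with head h's slice zeroed out."""
    c = concat.copy()
    c[h * d_head:(h + 1) * d_head] = 0.0
    return w_out @ gelu(W1 @ c)

# "Per-head contribution" via ablation no longer adds up: the GELU
# mixes the slices, so interaction terms absorb part of the effect.
ablation_effects = [full_mlp - ablated(h) for h in range(n_heads)]
total_shift = full_mlp - w_out @ gelu(np.zeros(d_model) @ W1.T)
print("GELU: ablation effects sum to the total shift?",
      np.isclose(sum(ablation_effects), total_shift))
```

With the linear head the attribution question has an exact answer; with the GELU in place, the sum of single-head ablation effects and the total effect of all heads diverge, which is the non-decomposability the prose describes.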
Training
Same hyperparameters and data as the single-head model. The only difference is the architecture.
torch.manual_seed(42)
n_heads = 4
mh_model = MultiHeadAttention(len(w2i), d_model=32, n_heads=n_heads)
mh_opt = optim.Adam(mh_model.parameters(), lr=1e-3)
mh_losses = []

for epoch in range(100):
    mh_model.train()
    epoch_loss = 0.0
    for ctx, tgt, lbl in train_corpus:
        mh_opt.zero_grad()
        c_idx = torch.tensor([w2i[w] for w in ctx])
        t_idx = torch.tensor(w2i[tgt])
        logits, _ = mh_model(c_idx, t_idx)
        loss = criterion(logits, torch.tensor([lbl]))
        loss.backward()
        mh_opt.step()
        epoch_loss += loss.item()
    mh_losses.append(epoch_loss / len(train_corpus))

# Evaluate
mh_model.eval()
with torch.no_grad():
    mh_correct = sum(
        torch.argmax(
            mh_model(
                torch.tensor([w2i[w] for w in ctx]),
                torch.tensor(w2i[tgt])
            )[0]
        ).item() == lbl
        for ctx, tgt, lbl in test_corpus
    )

print(f"Single-head test accuracy : {correct / len(test_corpus):.3f}")
print(f"Multi-head test accuracy  : {mh_correct / len(test_corpus):.3f}")
print(f"Single-head final loss    : {losses[-1]:.4f}")
print(f"Multi-head final loss     : {mh_losses[-1]:.4f}")
Single-head test accuracy : 0.812
Multi-head test accuracy : 0.772
Single-head final loss : 0.3654
Multi-head final loss : 0.2230
More parameters, lower training loss, slightly worse generalisation. Four heads and a nonlinear MLP give the bias more places to hide.
Head Specialisation: Who Attends to What?
Do the four heads specialise, each attending to a distinct word class? Or do they learn diffuse, overlapping patterns? Specialisation would preserve some interpretability; diffusion would mean the bias is truly distributed.
Code
WORD_CATEGORIES = {
    "activist": [w for w, (a, t) in VOCAB_SIGNALS.items()
                 if a > 0.70 and w not in VIGNETTE_POOL],
    "vignette": VIGNETTE_POOL,
    "teller": [w for w, (a, t) in VOCAB_SIGNALS.items() if t > 0.80],
    "ambiguous": ["independent", "organized", "concern"],
    "noise": [w for w, (a, t) in VOCAB_SIGNALS.items()
              if a < 0.10 and t < 0.15],
}

# Collect per-head attention by word category
head_cat_attn = {
    h: {cat: [] for cat in WORD_CATEGORIES}
    for h in range(n_heads)
}
# Also collect per-word for later use
head_word_attn = {h: {} for h in range(n_heads)}

mh_model.eval()
with torch.no_grad():
    for ctx, tgt, lbl in corpus:
        c_idx = torch.tensor([w2i[w] for w in ctx])
        t_idx = torch.tensor(w2i[tgt])
        _, all_weights = mh_model(c_idx, t_idx)
        for h in range(n_heads):
            for w, a in zip(ctx, all_weights[h][0].numpy()):
                # Per-category
                for cat, words in WORD_CATEGORIES.items():
                    if w in words:
                        head_cat_attn[h][cat].append(float(a))
                        break
                # Per-word
                head_word_attn[h].setdefault(w, []).append(float(a))

# Build heatmap
categories = list(WORD_CATEGORIES.keys())
heatmap_data = np.array([
    [np.mean(head_cat_attn[h][cat]) if head_cat_attn[h][cat] else 0.0
     for cat in categories]
    for h in range(n_heads)
])

fig, ax = plt.subplots(figsize=(8, 3))
im = ax.imshow(heatmap_data, cmap="bone_r", aspect="auto")
ax.set_xticks(range(len(categories)))
ax.set_xticklabels(categories, fontsize=10)
ax.set_yticks(range(n_heads))
ax.set_yticklabels([f"Head {h}" for h in range(n_heads)], fontsize=10)
for i in range(n_heads):
    for j in range(len(categories)):
        ax.text(j, i, f"{heatmap_data[i, j]:.3f}",
                ha="center", va="center", fontsize=9,
                color="white" if heatmap_data[i, j] > heatmap_data.max() * 0.6
                else PAL["ref_line"])
plt.colorbar(im, ax=ax, label="Mean attention weight")
ax.set_title("Head specialisation: which heads attend to which word types?")
plt.tight_layout()
plt.show()
Mean attention by head and word category, averaged over the full corpus. If heads specialise cleanly, each row should have a distinct column peak. If bias is distributed, the pattern will be diffuse — and no single head will cleanly capture the vignette effect.
Partial specialisation, but not along the axis that matters. Head 0 tracks strong diagnostic signals and suppresses vignette. Head 1 picks up ambiguous and teller words. Head 2 is diffuse. Head 3 concentrates on vignette and activist words but conflates the two. No single head gives you “the vignette effect” in isolation. The bias has no address.
Vignette Bias: Concentrated or Distributed?
Vignette words appear in every context with zero diagnostic value. Does multi-head attention concentrate the over-attention in one head, or spread it across all four?
Code
vignette_words = VIGNETTE_POOL

# Collect single-head attention for vignette words
sh_vignette_attn = {}
for lbl in [0, 1]:
    for w in vignette_words:
        vals = attention_records[lbl].get(w, [])
        sh_vignette_attn.setdefault(w, []).extend(vals)

n_groups = len(vignette_words)
bar_width = 0.25
x = np.arange(n_groups)

fig, ax = plt.subplots(figsize=(10, 5))

# Single-head bars
sh_means = [np.mean(sh_vignette_attn.get(w, [0.0])) for w in vignette_words]
ax.bar(x - bar_width, sh_means, bar_width, label="Single head",
       color=PAL["focal2"], alpha=0.85)

# Multi-head per-head bars
colors_mh = [PAL["focal1"], PAL["focal3"], PAL["taupe"], PAL["brown"]]
for h in range(n_heads):
    mh_means = [np.mean(head_word_attn[h].get(w, [0.0]))
                for w in vignette_words]
    ax.bar(x + h * bar_width, mh_means, bar_width, label=f"Head {h}",
           color=colors_mh[h], alpha=0.85)

ax.axhline(1.0 / 4, color=PAL["ref_line"], linestyle="--", alpha=0.5,
           label="Uniform baseline (1/ctx_size)")
ax.set_xticks(x + bar_width * (n_heads - 1) / 2)
ax.set_xticklabels(vignette_words, fontsize=10)
ax.set_ylabel("Mean attention weight")
ax.set_title("Vignette word attention: single head vs per-head decomposition")
ax.legend(fontsize=9)
plt.tight_layout()
plt.show()
Vignette word attention: single-head vs multi-head decomposition. The single-head bar shows total learned attention to each vignette word. The multi-head bars show per-head contributions. The dashed line marks uniform attention (1/4 for context size 4).
No single head’s attention map tells the full story. Even if three heads suppress vignette attention, the fourth head’s value representation may carry the signal through the GELU MLP to the output. The bias is not in any one set of weights. It is in the interaction.
Does Multi-Head Attention Amplify the Conjunction Fallacy?
We run the same seven diagnostic probes from Demo 3 through the multi-head model and compare P(activist) side by side, with calibration curves on the test set.
Left: Conjunction fallacy probe comparison — single-head vs multi-head P(activist) for seven diagnostic contexts. Right: Calibration curves on the test set — does multi-head attention produce better-calibrated probabilities, or does it simply learn a more confident version of the same bias?
The conjunction fallacy persists. Multi-head attention does not eliminate the bias. Calibration does not improve.
The finding is not about amplification. It is about locatability. In Demo 3 we traced the bias to specific weights on specific words. Here the same discriminatory outcome resists inspection.
The interpretability cost of distribution
Single head: inspect one attention map, ask “is the model over-attending to vignette words?”, get a legible answer. This is the streetlight. The weights are visible, so we look there.
Four heads + GELU MLP: no clean answer. The bias lives in the interaction of concatenated value vectors, nonlinear recombination, and classifier. It cannot be decomposed into additive per-head contributions. Even linear probing of individual heads misses the interaction terms that drive the output. The keys are in the dark.
Dan Davies would recognise the structure. An accountability sink is an institutional arrangement that absorbs blame without assigning it: the call centre that cannot escalate, the committee where no individual voted for the outcome. Multi-head attention is an accountability sink for bias. The discrimination is real. No single head is responsible.
Part II: From Cognitive Bias to Institutional Bias
The Linda task is a thought experiment. The real world is worse. In hiring, the priming signals are protected characteristics: race, age, gender. When a firm replaces its screening committee with an API call to a vendor LLM, it outsources the decision and the accountability. The vendor says: “we provide a tool, not a decision.” The firm says: “we used the best available technology.” The candidate has no one to appeal to. The architecture distributes blame internally across anonymous heads; the procurement structure distributes it externally across organisations. The opacity is a side effect that happens to be convenient.
The consideration set model is the counter-architecture. Every screening variable gets a name, a coefficient, and a posterior. “Who gets considered?” is separate from “who gets chosen?”
Implementation
We use the surrogate likelihood from Part I: \(\text{softmax}_j(\log \pi_{nj} + V_{nj})\). Two identification choices, both applied in the hiring and Swiss Metro models below:
Identification requires K ≠ V: two aliasing problems
Problem 1 — Level aliasing. If both \(\gamma_{0j}\) and \(\alpha_j\) are free, they compete. Both shift the level of \(U_{\text{avail},nj}\) in the same direction, so a mode that is “often considered” can look identical to a mode with “high intrinsic utility.” The two strategies are: (a) drop \(\gamma_{0j}\) entirely, letting \(\alpha_j\) absorb both baselines, simpler but forcing the utility constant to do double duty; or (b) include \(\gamma_{0j}\) and accept the soft identification from priors and the nonlinearity of the sigmoid. The hiring model below uses strategy (a). The Swiss Metro model uses strategy (b): the consideration intercepts absorb baseline consideration, freeing \(\alpha_j\) to capture only utility preference, and the \(\gamma_z\) slopes to capture only the marginal instrument effects. The aliasing is not eliminated. It is managed, and the intercepts make the decomposition more interpretable at the cost of weaker identification.
Problem 2 — Misspecified heterogeneity. If the DGP generates individual heterogeneity in utility that the model does not capture, the unexplained variance bleeds into the consideration mechanism, attenuating the utility coefficients. Fix: match the model complexity to the DGP. The Swiss Metro model includes per-individual random effects on travel-time sensitivity to absorb this heterogeneity in the utility stage where it belongs.
With mean-centred instruments, the consideration model becomes:

\[
\pi_{nj} = \sigma\!\left(\gamma_{0j} + \gamma_{zj}^{\top}\, \tilde{z}_{nj}\right), \qquad \tilde{z}_{nj} = z_{nj} - \bar{z}_j,
\]

where \(\tilde{z}_{nj}\) is mean-centred so that the default consideration (\(z_{nj} = \bar{z}_j\)) is captured by \(\gamma_{0j}\). Only the within-alternative, cross-person variation in \(z\) identifies \(\gamma_{zj}\). When \(\gamma_{0j}\) is dropped (hiring model), this variation identifies \(\gamma_{zj}\) separately from \(\alpha_j\).
This is the discrete choice analogue of K ≠ V in transformers: the signal that determines what to consider (\(\tilde{z}\), the Key) must be structurally separate from what determines utility (\(x\), the Value), and both must be different projections of the available information.
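Written out as code, the surrogate is only a few lines. The sketch below is illustrative: the \(\gamma\) matrix, utilities, and instrument values are made up, but the shapes follow the (N, J, K) convention used in the models below.

```python
import numpy as np

def consideration_probs(V, Z, gamma):
    """Surrogate consideration-set likelihood: softmax_j(log pi_nj + V_nj).
    V: (N, J) utilities; Z: (N, J, K) raw instruments; gamma: (J, K) slopes."""
    Z_tilde = Z - Z.mean(axis=0, keepdims=True)         # mean-centre per alternative
    log_odds = np.einsum("jk,njk->nj", gamma, Z_tilde)  # consideration log-odds (N, J)
    log_pi = log_odds - np.logaddexp(0.0, log_odds)     # log sigmoid, numerically stable
    U = log_pi + V                                      # consideration-adjusted utility
    U -= U.max(axis=1, keepdims=True)
    P = np.exp(U)
    return P / P.sum(axis=1, keepdims=True)

# Two people with IDENTICAL utilities; only the second has the
# second instrument switched on (e.g. a screening characteristic).
V = np.array([[1.0, 0.5, 0.0],
              [1.0, 0.5, 0.0]])
Z = np.zeros((2, 3, 2))
Z[1, :, 1] = 1.0
gamma = np.array([[0.5, 2.0],   # per-alternative screening slopes (toy values)
                  [2.5, 0.5],
                  [0.5, 3.0]])

P = consideration_probs(V, Z, gamma)
print(P.round(3))   # rows sum to 1; the instrument alone reshuffles the choice mix
```

Same Values, different Keys, different choice probabilities: the instrument moves \(\pi_{nj}\) without touching \(V_{nj}\), which is exactly the K ≠ V separation.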
Data-Generating Process: Hiring Decisions
The consideration stage uses screening instruments: variables that shift whether a firm considers a candidate but do not determine job utility. Finding good instruments is the central modelling challenge. In hiring, two candidates are obvious:
grad_degree: a legitimate screening criterion (has a degree or not)
is_white: an illegitimate but empirically documented screening criterion
These are the Keys: they determine which candidates pass the initial filter. Age and experience are the Values: they determine how attractive a candidate is once considered.
The interesting wrinkle: different firms screen on different things. We set up the DGP so that Firm A screens on both variables moderately, Firm B screens heavily on education but weakly on race, and Firm C screens heavily on race but weakly on education. The model must recover not just the existence of bias but the per-firm pattern. That is the actionability point. If Firm C’s racial screening is the problem, you fix Firm C’s intake filter. A standard logit would bury these per-firm differences in aggregate constants.
Code
SEED = 42
np.random.seed(SEED)

# ---- True parameters ----
TRUE_ALPHA = np.array([0.5, 1.0, 0.0])  # firm intercepts (Firm C = reference)
TRUE_B_AGE = -0.8   # older candidates less preferred
TRUE_B_EXP = 0.5    # more experience preferred

# Consideration slopes: (3 firms) × (2 instruments: grad_degree, is_white)
# Each firm screens on a different mix — this is the pattern the model must recover.
TRUE_GAMMA = np.array([
    [1.5, 2.0],  # Firm A: screens on both moderately
    [2.5, 0.5],  # Firm B: screens heavily on grad, weakly on race
    [0.5, 3.0],  # Firm C: screens heavily on race, weakly on grad
])

# ---- Simulate data ----
N = 4000
alts = ["Firm A", "Firm B", "Firm C"]

# Consideration instruments (screening-relevant only)
grad_degree = np.random.binomial(1, 0.35, N).astype(float)
white = np.random.binomial(1, 0.65, N).astype(float)
Z_person = np.column_stack([grad_degree, white])
Z_MEANS = Z_person.mean(axis=0)
Z_tilde_sim = Z_person - Z_MEANS
# Stack to (N, J=3, K_z=2) — same person-level Z for all alternatives
Z_tilde_3d = np.stack([Z_tilde_sim] * 3, axis=1)

rows = []
for i in range(N):
    # Consideration probabilities — per-firm gamma × person instruments
    log_odds = TRUE_GAMMA @ Z_tilde_sim[i]  # (3,)
    pi = 1.0 / (1.0 + np.exp(-log_odds))

    # Alternative-specific utility covariates (scaled)
    age_A = np.random.uniform(22, 60) / 40
    age_B = np.random.uniform(22, 60) / 40
    age_C = np.random.uniform(22, 60) / 40
    exp_A = np.random.uniform(0, 30) / 20
    exp_B = np.random.uniform(0, 30) / 20
    exp_C = np.random.uniform(0, 30) / 20
    V = np.array([
        TRUE_ALPHA[0] + TRUE_B_AGE * age_A + TRUE_B_EXP * exp_A,
        TRUE_ALPHA[1] + TRUE_B_AGE * age_B + TRUE_B_EXP * exp_B,
        TRUE_ALPHA[2] + TRUE_B_AGE * age_C + TRUE_B_EXP * exp_C,
    ])

    # Log-consideration-adjusted utility (the surrogate likelihood)
    U_adj = np.log(pi + 1e-12) + V
    U_adj -= U_adj.max()
    p = np.exp(U_adj) / np.exp(U_adj).sum()
    choice = np.random.choice(alts, p=p)
    rows.append({
        "choice": choice,
        "age_A": age_A, "age_B": age_B, "age_C": age_C,
        "exp_A": exp_A, "exp_B": exp_B, "exp_C": exp_C,
        "grad_degree": grad_degree[i],
        "is_white": white[i],
    })

choice_df = pd.DataFrame(rows)

# ---- Diagnostics ----
print("True γ matrix (firms × instruments):")
print(f" {'':12s} grad_degree is_white")
for j, alt in enumerate(alts):
    print(f" {alt:12s}{TRUE_GAMMA[j, 0]:>8.1f}{TRUE_GAMMA[j, 1]:>8.1f}")
print(f"\nObserved hiring frequencies:")
print(choice_df["choice"].value_counts(normalize=True).round(3).to_string())
True γ matrix (firms × instruments):
grad_degree is_white
Firm A 1.5 2.0
Firm B 2.5 0.5
Firm C 0.5 3.0
Observed hiring frequencies:
choice
Firm B 0.502
Firm A 0.306
Firm C 0.192
The \(\gamma\) matrix is the key object. Firm B screens heavily on education (\(\gamma = 2.5\)) but weakly on race (\(\gamma = 0.5\)). Firm C does the opposite (\(\gamma = 0.5\) for education, \(3.0\) for race). A standard multinomial logit collapses both screening dimensions into a single firm-level constant. The consideration set model separates them because the Z instruments create individual-level variation in \(\pi_{nj}\). The posterior must recover not just that bias exists, but which firm screens on which variable.
Note
DGP and model likelihood are aligned. The simulation generates choices directly from the log-consideration-adjusted softmax, i.e. the same reduced-form surrogate we estimate. This is deliberate: it ensures clean parameter recovery. The hard binary gate \(C_{nj} \sim \text{Bernoulli}(\pi_{nj})\) motivates the decomposition, but neither the DGP nor the estimator instantiates it. Both work with the marginalised form. The structural claim is not “firms literally flip coins.” It is “the system’s screening behaviour is as if gated, and we can estimate the gate parameters.”
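The “as if gated” claim can be probed numerically. The sketch below uses toy \(\pi\) and \(V\) values (not fitted quantities) to compare the literal hard gate — draw Bernoulli consideration sets, then choose by softmax within the set — against the marginalised surrogate. The two agree on the ranking here but differ by several percentage points, which is precisely why the surrogate should be read as a reduced form, not the exact marginal of the gate.

```python
import numpy as np

rng = np.random.default_rng(42)
pi = np.array([0.95, 0.85, 0.10])   # toy consideration probabilities
V = np.array([0.0, 1.0, 0.5])       # toy deterministic utilities

# Marginalised surrogate: P_j ∝ pi_j * exp(V_j), i.e. softmax(log pi + V)
u = np.log(pi) + V
P_surrogate = np.exp(u - u.max())
P_surrogate /= P_surrogate.sum()

# Literal hard gate: C_j ~ Bernoulli(pi_j), then logit choice among considered
counts = np.zeros(3)
for _ in range(50_000):
    C = rng.random(3) < pi
    if not C.any():
        continue                     # empty consideration set: no choice recorded
    ev = np.where(C, np.exp(V), 0.0)
    counts[rng.choice(3, p=ev / ev.sum())] += 1
P_hard = counts / counts.sum()

print("surrogate :", P_surrogate.round(3))
print("hard gate :", P_hard.round(3))
# Same ranking, but the probabilities differ by a few points: the surrogate
# is a smooth stand-in, not the exact marginal of the Bernoulli gate.
```

Because the DGP above generates choices from the surrogate itself, estimation recovers the surrogate's parameters cleanly; the gap shown here is what you would face if the world really did flip coins.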
PyMC Implementation
Code
import pymc as pm
import pytensor.tensor as pt
import arviz as az

alt_enc = {"Firm A": 0, "Firm B": 1, "Firm C": 2}
choice_df["y"] = choice_df["choice"].map(alt_enc)

# X : (N, J, K_x) — utility covariates (age, experience per firm)
X = np.stack([
    np.column_stack([choice_df["age_A"], choice_df["exp_A"]]),
    np.column_stack([choice_df["age_B"], choice_df["exp_B"]]),
    np.column_stack([choice_df["age_C"], choice_df["exp_C"]]),
], axis=1)  # (N, 3, 2)

# Z_tilde : (N, J, K_z) — mean-centred consideration instruments
# Same person-level Z for all firms; gamma varies per firm.
Z_tilde = Z_tilde_3d  # shape: (N, 3, 2), already mean-centred
y = choice_df["y"].values
def build_consideration_model(
    X: np.ndarray, Z: np.ndarray, y: np.ndarray
) -> pm.Model:
    """
    Consideration Set Logit — per-firm screening on multiple instruments.

    gamma_z is now (J, K_z): each firm has its own sensitivity to each
    screening variable. The model must recover the pattern, not just the
    existence of bias.

    No consideration intercept: alpha_j absorbs both baseline utility and
    average consideration. Only the within-alternative, cross-person
    variation in z_tilde identifies gamma_z.
    """
    n_obs, n_alts, n_z = Z.shape
    _, _, n_x = X.shape
    z_names = ["grad_degree", "is_white"]
    alt_names_ref = alts[:n_alts - 1]

    with pm.Model(
        coords={
            "obs": range(n_obs),
            "alts": alts,
            "alts_ref": alt_names_ref,
            "z_instruments": z_names,
            "covs": ["age", "experience"],
        }
    ) as model:
        # STAGE 1: Consideration — per-firm, per-instrument
        gamma_z = pm.Normal("gamma_z", mu=0, sigma=2,
                            dims=("alts", "z_instruments"))
        Z_data = pm.Data("Z", Z, dims=("obs", "alts", "z_instruments"))
        # log_odds: (N, J) = sum_k gamma_z[j,k] * Z[n,j,k]
        log_odds = pt.sum(gamma_z[None, :, :] * Z_data, axis=2)
        pi = pm.Deterministic("pi", pm.math.sigmoid(log_odds),
                              dims=("obs", "alts"))
        # log(sigmoid(x)) = x - softplus(x), numerically stable
        log_pi = log_odds - pt.softplus(log_odds)

        # STAGE 2: Utility
        alpha_raw = pm.Normal("alpha_raw", mu=0, sigma=2, dims="alts_ref")
        alpha = pt.concatenate([alpha_raw, pt.zeros(1)])  # Firm C = 0
        beta_age = pm.Normal("beta_age", mu=0, sigma=2)
        beta_exp = pm.Normal("beta_exp", mu=0, sigma=2)
        beta = pt.stack([beta_age, beta_exp])
        X_data = pm.Data("X", X, dims=("obs", "alts", "covs"))
        V = alpha[None, :] + pt.sum(X_data * beta[None, None, :], axis=2)

        # Log-consideration-adjusted utility
        U_avail = V + log_pi
        U_c = U_avail - U_avail.max(axis=1, keepdims=True)
        p = pm.Deterministic("p", pm.math.softmax(U_c, axis=1),
                             dims=("obs", "alts"))
        y_obs = pm.Data("y_obs", y, dims="obs")
        _ = pm.Categorical("likelihood", p=p, observed=y_obs)

    return model


consideration_model = build_consideration_model(X, Z_tilde, y)
print(consideration_model)
<pymc.model.core.Model object at 0x16c2d0e20>
Prior Predictive Check
Code
with consideration_model:
    prior_pred = pm.sample_prior_predictive(500, random_seed=42)

ppc = prior_pred["prior_predictive"]["likelihood"].values  # (1, 500, N)
prior_freq = np.array([np.mean(ppc == k) for k in range(3)])

fig, ax = plt.subplots(figsize=(6, 3.5))
ax.bar(alts, prior_freq,
       color=[PAL["focal2"], PAL["focal3"], PAL["focal1"]], alpha=0.85)
ax.axhline(1 / 3, color=PAL["ref_line"], linestyle="--", alpha=0.5,
           label="Uniform baseline (1/3)")
ax.set_ylabel("Marginal hiring probability (prior predictive)")
ax.set_title("Prior predictive check: consideration set model with screening instruments")
ax.set_ylim(0, 0.5)
ax.legend()
plt.tight_layout()
plt.show()
Prior predictive marginal hiring frequencies. The weakly informative prior puts roughly equal weight on all three firms before data is seen.
Fitting (MCMC)
Code
# NOTE: Run with eval: true to sample. Requires PyMC + NumPyro (for NUTS).
with consideration_model:
    idata = pm.sample(
        draws=1000,
        tune=1000,
        target_accept=0.9,
        random_seed=42,
        progressbar=True,
    )
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [gamma_z, alpha_raw, beta_age, beta_exp]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 67 seconds.
Results
The key diagnostic: the \(\gamma_z\) posterior must recover the per-firm pattern. Firm B should show a large slope on grad_degree and a small one on is_white. Firm C should show the opposite. If the model recovers this pattern, we have located bias not just in the system but at a specific firm, on a specific variable.
The model recovers the per-firm screening pattern. Firm B’s consideration jumps sharply with grad_degree but barely moves with is_white. Firm C shows the opposite: race dominates, education barely registers. This is the actionability point. A standard multinomial logit would bury these per-firm differences in aggregate constants. The consideration set model tells you which firm screens on which variable, giving you a target for intervention.
Code
# Side-by-side: transformer attention weights vs posterior consideration probs.
# The key payoff visualisation — both tell the same story in their language.
#
# Left: Transformer — soft attention weights across context words
#   (activist sense: high on protest/justice, low on vault/deposit)
# Right: Consideration model — marginal pi per firm, by is_white subgroup
#   (structural K≠V separation: z instruments drive pi, age/experience drive V)

fig, axes = plt.subplots(1, 2, figsize=(13, 4.5))

# ---- LEFT: transformer attention (activist sense, top words) ----
ax = axes[0]
attn_activist_means = [
    np.mean(attention_records[0].get(w, [0])) for w in common_words
]
order = np.argsort(attn_activist_means)[::-1][:12]
colors_left = [
    PAL["focal2"] if VOCAB_SIGNALS[common_words[i]][0] > 0.5       # activist signal
    else PAL["focal1"] if VOCAB_SIGNALS[common_words[i]][1] > 0.5  # teller signal
    else PAL["grey1"]
    for i in order
]
ax.barh(
    [common_words[i] for i in order],
    [attn_activist_means[i] for i in order],
    color=colors_left, alpha=0.85,
)
ax.axvline(1 / len(common_words), color=PAL["ref_line"], linestyle="--", alpha=0.4)
ax.set_xlabel("Average attention weight (activist sense)")
ax.set_title("Transformer: learned attention\nBlue=activist signal, orange=teller signal")

# ---- RIGHT: consideration probabilities — race effect by firm ----
ax = axes[1]
pi_posterior = idata.posterior["pi"].mean(("chain", "draw"))  # (N, J) via xr
# Stratify by is_white across all firms to show per-firm pattern
white_mask = choice_df["is_white"].values.astype(bool)
labels_r, vals_r, colors_r = [], [], []
for j, (firm, col) in enumerate(zip(alts, [PAL["focal2"], PAL["focal3"], PAL["focal1"]])):
    pi_no = pi_posterior.values[~white_mask, j].mean()
    pi_yes = pi_posterior.values[white_mask, j].mean()
    labels_r.extend([f"{firm}\n(non-wh)", f"{firm}\n(white)"])
    vals_r.extend([pi_no, pi_yes])
    colors_r.extend([col, col])
bars = ax.bar(labels_r, vals_r, color=colors_r, alpha=0.85)
for i, (bar, v) in enumerate(zip(bars, vals_r)):
    bar.set_alpha(0.45 if i % 2 == 0 else 0.9)
    ax.text(bar.get_x() + bar.get_width() / 2, v + 0.01, f"{v:.2f}",
            ha="center", fontsize=9)
ax.set_ylabel("Posterior mean consideration probability π")
ax.set_title("Consideration model: per-firm race effect\n"
             "(Firm C screens heavily on race; Firm B barely)")
ax.set_ylim(0, 1.1)

plt.suptitle(
    "Two paradigms, one mechanism: selective relevance weighting\n"
    "Left = soft attention (transformer); Right = hard-gate consideration (hiring model)",
    y=1.03, fontsize=11,
)
plt.tight_layout()
plt.show()
Consideration Sets on Real Data: Transport Mode Choice
The hiring demo had clean instruments and known ground truth. Now the harder case: real data where nobody hands you the true \(\gamma_z\) and the instruments are impure.
The data comes from a 1998 stated preference survey in Switzerland. Respondents chose between three transport modes: Train, Swiss Metro (a hypothetical high-speed maglev), and Car, across scenarios that varied travel time, cost, and headway. Each respondent also reported person-level characteristics: age, income, whether the trip was for business, and whether they held a General Abonnement (GA), an annual rail pass covering most Swiss public transport. The question: which person-level characteristics gate consideration of a mode, as distinct from the mode-level attributes that drive utility conditional on consideration?
The GA pass is the key instrument and the key problem. Does holding a GA shift whether you consider rail, or does it shift how attractive rail is conditional on consideration? Plausibly both. We proceed with the assumption that person-level characteristics primarily gate consideration while mode-level attributes (time, cost) drive utility. The framework names this assumption as \(Z \neq X\) and invites scrutiny.
Code
import pandas as pd

# ---- Load Swiss Metro data (local copy) ----
sm_df = pd.read_csv("swissmetro.dat", sep="\t")

# Filters adapted from the standard Biogeme example (Bierlaire 2003):
# keep the stated-preference survey and rows with a known choice
sm_df = sm_df[sm_df["SP"] != 0]
sm_df = sm_df[sm_df["CHOICE"] != 0]

# Availability: keep only rows where at least 2 modes are available
sm_df = sm_df[
    (sm_df["TRAIN_AV"] + sm_df["SM_AV"] + sm_df["CAR_AV"]) >= 2
]

# Recode CHOICE to 0-indexed: 1=Train, 2=SM, 3=Car → 0, 1, 2
sm_df["y"] = sm_df["CHOICE"] - 1

print(f"Observations: {len(sm_df)}")
print(f"\nChoice frequencies:")
choice_labels = {0: "Train", 1: "Swiss Metro", 2: "Car"}
for k, v in sm_df["y"].value_counts().sort_index().items():
    print(f" {choice_labels[k]:12s}: {v:5d} ({v / len(sm_df):.1%})")

print(f"\nPerson-level variables (consideration instruments):")
print(f" GA (annual pass): {sm_df['GA'].mean():.2%} hold one")
print(f" PURPOSE (business): coded 1={sm_df[sm_df['PURPOSE'] == 1].shape[0]} obs")
print(f" AGE distribution: {sm_df['AGE'].value_counts().sort_index().to_dict()}")
print(f" INCOME distribution: {sm_df['INCOME'].value_counts().sort_index().to_dict()}")
Observations: 10719
Choice frequencies:
Train : 1423 (13.3%)
Swiss Metro : 6216 (58.0%)
Car : 3080 (28.7%)
Person-level variables (consideration instruments):
GA (annual pass): 14.11% hold one
PURPOSE (business): coded 1=1575 obs
AGE distribution: {1: 711, 2: 3339, 3: 3825, 4: 2025, 5: 810, 6: 9}
INCOME distribution: {0: 306, 1: 1719, 2: 3744, 3: 4041, 4: 909}
Code
# ---- Prepare arrays for the consideration set model ----
#
# Z (consideration instruments): person-level characteristics
#  - GA: annual rail pass (binary) — gates rail/SM consideration
#  - PURPOSE==1: business trip (binary) — business travellers may ignore car
#  - INCOME > 2: higher income (binary) — may screen out slow/cheap modes
#  - AGE > 3: older traveller (binary) — may screen out novel modes like SM
#
# X (utility covariates): mode-level attributes
#  - Travel time (TT) per mode, scaled to hours
#  - Cost (CO) per mode, scaled to CHF/100

sm_alts = ["Train", "Swiss Metro", "Car"]
n_sm = len(sm_df)

# ---- Person-level consideration instruments ----
ga = sm_df["GA"].values.astype(float)
business = (sm_df["PURPOSE"] == 1).values.astype(float)
high_inc = (sm_df["INCOME"] > 2).values.astype(float)
older = (sm_df["AGE"] > 3).values.astype(float)

# Composite Z: same instruments for each mode, but gamma varies by mode
# This lets e.g. GA shift train consideration differently from car consideration
Z_person = np.column_stack([ga, business, high_inc, older])  # (N, 4)
Z_means_sm = Z_person.mean(axis=0)
Z_tilde_sm = Z_person - Z_means_sm  # mean-centred
# Stack to (N, J, K_z) — same person-level Z for all modes, gamma_z differs per mode
Z_sm = np.stack([Z_tilde_sm, Z_tilde_sm, Z_tilde_sm], axis=1)  # (N, 3, 4)

# ---- Mode-level utility covariates ----
# Travel time in hours, cost in CHF/100
X_sm = np.stack([
    np.column_stack([sm_df["TRAIN_TT"].values / 60, sm_df["TRAIN_CO"].values / 100]),
    np.column_stack([sm_df["SM_TT"].values / 60, sm_df["SM_CO"].values / 100]),
    np.column_stack([sm_df["CAR_TT"].values / 60, sm_df["CAR_CO"].values / 100]),
], axis=1)  # (N, 3, 2)

# Availability mask: set unavailable modes' utility to -inf
avail = np.column_stack([
    sm_df["TRAIN_AV"].values, sm_df["SM_AV"].values, sm_df["CAR_AV"].values
])  # (N, 3)

y_sm = sm_df["y"].values

# ---- Individual-level index for random effects ----
# Each respondent (ID) appears in multiple choice scenarios.
# The random effect on travel time is per-individual, not per-observation.
unique_ids = sm_df["ID"].unique()
id_to_idx = {uid: i for i, uid in enumerate(unique_ids)}
individual_idx = sm_df["ID"].map(id_to_idx).values  # (N,) — maps each obs to its individual
n_individuals = len(unique_ids)

print(f"Z_sm shape: {Z_sm.shape} (N, J, K_z)")
print(f"X_sm shape: {X_sm.shape} (N, J, K_x)")
print(f"Availability: {avail.mean(axis=0).round(2)} (Train, SM, Car)")
print(f"Individuals: {n_individuals} (observations per individual: "
      f"{len(sm_df) / n_individuals:.1f} avg)")
print(f"\nZ instrument means (pre-centring):")
for name, m in zip(["GA", "Business", "High income", "Older"], Z_means_sm):
    print(f" {name:12s}: {m:.3f}")
Z_sm shape: (10719, 3, 4) (N, J, K_z)
X_sm shape: (10719, 3, 2) (N, J, K_x)
Availability: [1. 1. 0.84] (Train, SM, Car)
Individuals: 1191 (observations per individual: 9.0 avg)
Z instrument means (pre-centring):
GA : 0.141
Business : 0.147
High income : 0.462
Older : 0.265
The model adds one refinement over the hiring demo: a per-individual random coefficient on travel time. Not everyone values speed equally. Business commuters on tight schedules are acutely time-sensitive; retirees on leisure trips much less so. A fixed \(\beta_{tt}\) forces a single time-sensitivity onto the entire population. Unmodelled heterogeneity in utility does not vanish — it bleeds into the consideration stage, attenuating the \(\gamma_z\) slopes (this is aliasing problem 2 from the identification callout). The random effect \(\beta_{tt,i} \sim \mathcal{N}(\mu_{tt}, \sigma_{tt}^2)\) absorbs that heterogeneity where it belongs: in the utility stage. Since each respondent appears in multiple choice scenarios, the random effect is defined per individual, not per observation — an index lookup maps each observation to its respondent’s time sensitivity. This is the mixed logit component, the same extension that motivates combining consideration sets with random-coefficient models in the pymc-marketing implementation.
import pymc as pm
import pytensor.tensor as pt

def build_swissmetro_consideration(
    X: np.ndarray,
    Z: np.ndarray,
    avail: np.ndarray,
    y: np.ndarray,
    individual_idx: np.ndarray = None,
    n_individuals: int = None,
) -> pm.Model:
    """
    Consideration set model for Swiss Metro with random coefficients:
    - Person-level Z instruments (GA, business, income, age) driving consideration
    - Mode-level X covariates (travel time, cost) driving utility
    - Per-individual random effect on travel time sensitivity (mixed logit)
    - Availability constraints (unavailable modes get -inf utility)

    Each mode gets its own gamma_z vector: a GA pass may strongly increase
    train consideration but have no effect on car consideration.

    Consideration intercepts (gamma_0) absorb baseline consideration for each
    mode, freeing alpha_j to capture only utility preference and gamma_z
    slopes to capture marginal instrument effects. The aliasing between
    gamma_0 and alpha_j is managed by priors and the sigmoid nonlinearity,
    not eliminated.

    The random effect on beta_tt is per-individual, not per-observation.
    Each respondent appears in multiple choice scenarios; the same person has
    the same time sensitivity across scenarios. The index lookup maps each
    observation to its individual: beta_tt_n[obs] = beta_tt_i[individual_idx[obs]].
    """
    n_obs, n_alts, n_z = Z.shape
    _, _, n_x = X.shape
    z_names = ["GA", "Business", "High_income", "Older"]
    x_names = ["travel_time", "cost"]
    alt_names_ref = sm_alts[:n_alts - 1]  # reference-coded alternatives

    with pm.Model(
        coords={
            "obs": range(n_obs),
            "individual": range(n_individuals),
            "alts": sm_alts,
            "alts_ref": alt_names_ref,
            "z_instruments": z_names,
            "x_covs": x_names,
        }
    ) as model:
        # ---- STAGE 1: Consideration ----
        # Consideration intercepts: baseline log-odds of considering each mode
        # These absorb mode-specific baseline consideration, freeing gamma_z
        # slopes to capture the marginal effect of each instrument.
        gamma_intercept = pm.Normal("gamma_intercept", mu=0, sigma=2, dims="alts")

        # gamma_z: (J, K_z) — each mode has its own sensitivity to each instrument
        gamma_z = pm.Normal("gamma_z", mu=0, sigma=2, dims=("alts", "z_instruments"))

        Z_data = pm.Data("Z", Z, dims=("obs", "alts", "z_instruments"))

        # log_odds_consideration: (N, J) = gamma_0[j] + sum_k gamma_z[j,k] * Z[n,j,k]
        log_odds_c = gamma_intercept[None, :] + pt.sum(gamma_z[None, :, :] * Z_data, axis=2)
        pi = pm.Deterministic("pi", pm.math.sigmoid(log_odds_c), dims=("obs", "alts"))

        # log(sigmoid(x)) = x - softplus(x), numerically stable
        log_pi = log_odds_c - pt.softplus(log_odds_c)

        # ---- STAGE 2: Utility (mixed logit) ----
        alpha_raw = pm.Normal("alpha_raw", mu=0, sigma=2, dims="alts_ref")
        alpha = pt.concatenate([alpha_raw, pt.zeros(1)])  # Car = reference

        # Travel time: per-individual random coefficient (non-centred)
        # beta_tt_i = beta_tt_mu + beta_tt_sigma * offset_i
        beta_tt_mu = pm.Normal("beta_tt_mu", mu=0, sigma=2)
        beta_tt_sigma = pm.HalfNormal("beta_tt_sigma", sigma=1)
        beta_tt_offset = pm.Normal("beta_tt_offset", mu=0, sigma=1, dims="individual")
        beta_tt_i = pm.Deterministic(
            "beta_tt_i",
            beta_tt_mu + beta_tt_sigma * beta_tt_offset,
            dims="individual",
        )  # (n_individuals,)

        # Index lookup: map each observation to its individual's beta_tt
        idx = pm.Data("individual_idx", individual_idx, dims="obs")
        beta_tt_n = beta_tt_i[idx]  # (N,)

        # Cost: fixed coefficient (less heterogeneity expected)
        beta_cost = pm.Normal("beta_cost", mu=0, sigma=2)

        X_data = pm.Data("X", X, dims=("obs", "alts", "x_covs"))

        # V_nj = alpha_j + beta_tt_n * tt_nj + beta_cost * cost_nj
        V = (
            alpha[None, :]
            + beta_tt_n[:, None] * X_data[:, :, 0]
            + beta_cost * X_data[:, :, 1]
        )

        # ---- Availability + log-consideration-adjusted utility ----
        avail_data = pm.Data("avail", avail.astype(float), dims=("obs", "alts"))

        # Unavailable modes: set to large negative number
        U_avail = log_pi + V + pt.log(avail_data + 1e-8)
        U_c = U_avail - U_avail.max(axis=1, keepdims=True)
        p = pm.Deterministic("p", pm.math.softmax(U_c, axis=1), dims=("obs", "alts"))

        y_obs = pm.Data("y_obs", y, dims="obs")
        _ = pm.Categorical("likelihood", p=p, observed=y_obs)

    return model

sm_model = build_swissmetro_consideration(
    X_sm, Z_sm, avail, y_sm,
    individual_idx=individual_idx,
    n_individuals=n_individuals,
)
print(sm_model)
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [gamma_intercept, gamma_z, alpha_raw, beta_tt_mu, beta_tt_sigma, beta_tt_offset, beta_cost]
Code
# ---- Consideration parameters: gamma_z by mode ----
# Each row is a mode, each column is a person-level instrument.
# A positive gamma means that characteristic INCREASES consideration of that mode.
fig, axes = plt.subplots(1, 3, figsize=(14, 4), sharey=True)
z_names = ["GA", "Business", "High_income", "Older"]
z_colors = [PAL["focal2"], PAL["focal1"], PAL["focal3"], PAL["brown"]]

for j, (ax, mode) in enumerate(zip(axes, sm_alts)):
    gamma_post = sm_idata.posterior["gamma_z"].sel(alts=mode)
    for k, z_name in enumerate(z_names):
        samples = gamma_post.sel(z_instruments=z_name).values.flatten()
        ax.hist(samples, bins=40, alpha=0.6, label=z_name, density=True,
                color=z_colors[k])
    ax.axvline(0, color=PAL["ref_line"], linestyle="--", alpha=0.5)
    ax.set_title(f"{mode}")
    ax.set_xlabel("γ_z")
    if j == 0:
        ax.set_ylabel("Posterior density")
    ax.legend(fontsize=7)

plt.suptitle(
    "Swiss Metro: person-level screening parameters γ_z by mode\n"
    "Positive = characteristic increases consideration; zero line = no effect",
    y=1.03,
)
plt.tight_layout()
plt.show()

# ---- Utility parameters (population-level) ----
az.plot_posterior(
    sm_idata,
    var_names=["beta_tt_mu", "beta_tt_sigma", "beta_cost", "alpha_raw"],
    figsize=(14, 3),
)
plt.suptitle(
    "Utility stage: travel time (random), cost (fixed), mode constants\n"
    "β_tt_mu = population mean time sensitivity; β_tt_sigma = person-level spread",
    y=1.02,
)
plt.tight_layout()
plt.show()

# ---- Distribution of individual time sensitivities ----
# beta_tt_i has shape (chains, draws, n_individuals) — flatten across chains/draws,
# then take the mean per individual to show the distribution of person-level estimates.
beta_tt_i_post = sm_idata.posterior["beta_tt_i"]  # (chain, draw, individual)
beta_tt_i_mean = beta_tt_i_post.mean(("chain", "draw")).values  # (n_individuals,)

fig, ax = plt.subplots(figsize=(8, 3.5))
ax.hist(beta_tt_i_mean, bins=50, density=True, color=PAL["focal2"],
        alpha=0.7, edgecolor="white", linewidth=0.3)
ax.axvline(0, color=PAL["ref_line"], linestyle="--", alpha=0.5)
mu_hat = sm_idata.posterior["beta_tt_mu"].values.mean()
ax.axvline(mu_hat, color=PAL["focal1"], linestyle="-", alpha=0.8,
           label=f"Population mean ≈ {mu_hat:.2f}")
ax.set_xlabel("β_tt (posterior mean per individual)")
ax.set_ylabel("Density")
ax.set_title(f"Person-level heterogeneity in travel time sensitivity "
             f"({len(beta_tt_i_mean)} individuals)")
ax.legend(fontsize=9)
plt.tight_layout()
plt.show()
Code
# ---- Stratified consideration: GA holders vs non-holders ----
# The GA pass is the most interesting instrument: it should strongly shift
# consideration of rail modes (Train, SM) but have little effect on Car.
pi_post_sm = sm_idata.posterior["pi"].mean(("chain", "draw"))  # (N, J)
ga_mask = sm_df["GA"].values.astype(bool)

fig, ax = plt.subplots(figsize=(9, 4))
labels_sm = []
vals_sm = []
colors_sm = []
alphas_sm = []
mode_colors = {"Train": PAL["focal2"], "Swiss Metro": PAL["focal3"],
               "Car": PAL["focal1"]}

for j, mode in enumerate(sm_alts):
    pi_no_ga = pi_post_sm.values[~ga_mask, j].mean()
    pi_yes_ga = pi_post_sm.values[ga_mask, j].mean()
    labels_sm.extend([f"{mode}\n(no GA)", f"{mode}\n(GA)"])
    vals_sm.extend([pi_no_ga, pi_yes_ga])
    colors_sm.extend([mode_colors[mode], mode_colors[mode]])
    alphas_sm.extend([0.45, 0.9])

bars = ax.bar(labels_sm, vals_sm, color=colors_sm)
for bar, v, a in zip(bars, vals_sm, alphas_sm):
    bar.set_alpha(a)
    ax.text(bar.get_x() + bar.get_width() / 2, v + 0.01,
            f"{v:.2f}", ha="center", fontsize=10)

ax.set_ylabel("Posterior mean consideration probability π")
ax.set_title("Consideration probability by GA status\n"
             "Dark = GA holder; light = no GA")
ax.set_ylim(0, 1.1)
plt.tight_layout()
plt.show()
Before interpreting any posterior, check convergence. A Bayesian model that has not converged is a random number generator with delusions of structure. The trace plots show marginal posterior densities (left) and raw MCMC chains (right). We want smooth, unimodal densities and chains that look like hairy caterpillars: rapidly mixing, stationary, overlapping. Chains that drift, stick, or separate indicate the sampler has not explored the posterior.
Trace plots: left column shows posterior densities, right column shows MCMC chains. Well-mixed chains and smooth densities indicate convergence.
The chains mix well. The posteriors are unimodal. The summary table provides the numerical diagnostics. \(\hat{R}\) measures between-chain versus within-chain variance: values at or near 1.00 confirm convergence. Effective sample size (ESS) tells us how many independent draws the chains are worth after accounting for autocorrelation. Values in the hundreds or thousands mean the posteriors are well-estimated. \(\hat{R}\) above 1.01 or ESS below 100 would warrant longer chains or a reparameterisation.
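To make the \(\hat{R}\) diagnostic concrete, here is a minimal NumPy sketch of split-\(\hat{R}\). This is illustrative only: ArviZ's production implementation additionally rank-normalises the draws before computing the statistic, so the numbers will differ slightly. The `split_rhat` helper and the simulated chains are my own, not part of the post's code.

```python
import numpy as np

def split_rhat(chains: np.ndarray) -> float:
    """Split R-hat for an array of MCMC draws with shape (n_chains, n_draws)."""
    n_chains, n_draws = chains.shape
    half = n_draws // 2
    # Split each chain in half so slow within-chain drift is also detected
    split = chains[:, : 2 * half].reshape(n_chains * 2, half)
    chain_means = split.mean(axis=1)
    chain_vars = split.var(axis=1, ddof=1)
    W = chain_vars.mean()               # average within-chain variance
    B = half * chain_means.var(ddof=1)  # between-chain variance
    var_plus = (half - 1) / half * W + B / half
    return float(np.sqrt(var_plus / W))

rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))       # four well-mixed chains, same target
bad = good + np.arange(4)[:, None]      # chains stuck at four different levels
print(split_rhat(good))  # close to 1.0
print(split_rhat(bad))   # well above the 1.01 warning threshold
```

The intuition is exactly the one in the prose: if chains disagree about where the posterior mass is, the between-chain variance inflates `var_plus` relative to `W` and \(\hat{R}\) rises above 1.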
The summary table and trace plots give the full picture. Two results deserve comment: the GA pass and the random effects on travel time.
The GA pass as negative gate. The GA effect on Car is catastrophic (\(\gamma_z \approx -9.5\), HDI: \([-10.8, -8.1]\)): holders of the annual rail pass do not consider driving. The HDI excludes zero by a wide margin. The GA effect on Train, by contrast, is essentially zero (\(\gamma_z \approx 0.0\), HDI: \([-0.8, 0.7]\)). Once the consideration intercept absorbs the baseline, there is no additional boost from holding the pass. The commitment is already total. The pass does its strongest work not by promoting rail but by eliminating the car alternative. It is a negative gate, not a positive one. The remaining instruments (business trip, income, age) show a consistent pattern: negative slopes indicating narrower consideration sets, not wider ones. Habitual travellers have already decided.
Random effects on travel time. The per-individual random coefficient reveals massive heterogeneity: \(\beta_{tt,\mu} \approx -1.81\), \(\beta_{tt,\sigma} \approx 1.89\). The standard deviation is nearly as large as the mean. Individuals one standard deviation above the mean are barely time-sensitive (\(\mu + \sigma \approx 0.1\)); those one standard deviation below are extremely so (\(\mu - \sigma \approx -3.7\)). Without this random effect, the unmodelled heterogeneity bleeds into the consideration stage, attenuating the \(\gamma_z\) slopes. This is the aliasing problem described earlier. The per-individual parameterisation absorbs the heterogeneity where it belongs: in the utility stage.
The specification arrived through iteration, not foresight. The first model dropped consideration intercepts and had no random effects. The GA pass showed a flipped sign on Train. Dramatic but shallow: an artefact of uncentred instruments and missing intercepts. Adding per-individual random effects on travel time absorbed the heterogeneity that was bleeding into the consideration slopes. Restoring the consideration intercepts freed \(\alpha_j\) from doing double duty. Each perturbation changed the posteriors. What survived across all three specifications was the GA effect on Car: always catastrophically negative, always with an HDI excluding zero. What didn’t survive was the GA effect on Train. Its sign and magnitude shifted with the intercept specification. That asymmetry is the finding. The pass acts as a negative gate on alternatives, not a positive promoter of rail. Only the sensitivity analysis made this visible.
This is the core advantage over transformer attention. The consideration set model makes every modelling choice visible and arguable: the intercept specification, the exclusion restriction, the choice of which variables enter \(Z\) versus \(X\). When a multi-head model distributes mode preferences across anonymous subspaces, there is nothing to argue about.
Part III: What the Analogy Buys You
The most productive difference is the hard/soft divide. Transformer attention is a differentiable relaxation of discrete consideration. Cross-pollination runs both ways: attention weights read better as “how likely is this token to be decision-relevant” than “how much does this token contribute,” motivating sparsity-inducing variants (Sparsemax, \(\alpha\)-entmax). In the other direction, the \(Q/K/V\) factorisation suggests consideration models could separate the signal that drives consideration from the utility that drives choice.
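The hard/soft divide can be shown in a few lines. Below is a small NumPy sketch (my own, not from the post's codebase) of sparsemax as defined by Martins and Astudillo (2016): a Euclidean projection of the scores onto the probability simplex. Unlike softmax, which spreads strictly positive weight over every token, sparsemax assigns exact zeros, recovering a hard consideration set from soft scores.

```python
import numpy as np

def sparsemax(z: np.ndarray) -> np.ndarray:
    """Sparsemax: Euclidean projection of scores z onto the simplex."""
    z_sorted = np.sort(z)[::-1]
    k = np.arange(1, len(z) + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum   # which sorted scores stay in support
    k_max = k[support][-1]
    tau = (cumsum[k_max - 1] - 1) / k_max  # threshold subtracted from all scores
    return np.maximum(z - tau, 0.0)

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())
    return e / e.sum()

scores = np.array([2.0, 1.2, -0.5, -1.0])
print(softmax(scores).round(3))    # every token keeps some weight
print(sparsemax(scores).round(3))  # [0.9, 0.1, 0.0, 0.0] — low scorers excluded
```

The zeros are the point: sparsemax attention reads directly as "these tokens were not considered," which is the discrete consideration-set semantics that standard softmax attention only approximates.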
But the deeper difference is about what each framework makes arguable.
Behavioural Audit, Not Mechanistic Explanation
Mechanistic interpretability (probing classifiers, causal tracing, activation patching) asks: where inside the model does the bias live? This is the streetlight question. It searches where the representations are visible. The consideration set approach asks something different: does the system’s screening behaviour exhibit bias, and on which variables? It does not need to find the keys under the lamppost. It applies to any screener, human or algorithmic, without access to weights or architecture.
Internal attribution methods face a related problem. SHAP, LIME, and integrated gradients assume feature contributions are additive or locally linear. Multi-head attention violates both assumptions. The bias lives in the interaction of concatenated value vectors through the GELU MLP. A SHAP value for “applicant name” tells you the marginal contribution of that token holding other features fixed. It does not tell you whether the name gated consideration or shifted utility conditional on screening. The distinction matters: a blocked consideration funnel and a low utility score produce the same outcome (rejection) but require different interventions.
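The gating-versus-utility ambiguity is easy to demonstrate numerically. In the consideration-adjusted logit, choice probabilities depend only on the sum \(\log \pi_j + V_j\), so a consideration deficit and an equally sized utility deficit are observationally equivalent. The sketch below (illustrative; `choice_probs` is a hypothetical helper, not the post's model) shows two parameterisations producing identical choice probabilities, which is why an exclusion restriction is needed to tell them apart.

```python
import numpy as np

def choice_probs(log_pi: np.ndarray, V: np.ndarray) -> np.ndarray:
    """Softmax over consideration-adjusted utilities log(pi_j) + V_j."""
    u = log_pi + V
    e = np.exp(u - u.max())
    return e / e.sum()

# Candidate 0 is screened out: only a 5% chance of being considered
screened = choice_probs(np.log(np.array([0.05, 1.0, 1.0])),
                        np.array([2.0, 1.0, 0.0]))

# Candidate 0 is fully considered but scored log(0.05) lower on utility
scored = choice_probs(np.log(np.array([1.0, 1.0, 1.0])),
                      np.array([2.0 + np.log(0.05), 1.0, 0.0]))

print(np.allclose(screened, scored))  # True — same choice data, different stories
```

Same rejection probabilities, different mechanisms, different remedies: fixing the intake filter versus recalibrating the evaluation. Without a variable known to affect only one channel, no method, behavioural or mechanistic, can distinguish the two from choice data alone.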
More sophisticated methods get closer. Concept erasure (LEACE) projects out the linear subspace encoding a protected attribute and checks whether performance drops. Causal abstraction tests whether a structural decomposition — “this component screens on race” — faithfully describes the network. Both ask the right question. Neither can answer it without the same assumption the consideration model makes explicit: an exclusion restriction that separates the screening channel from the utility channel. The identification problem is not a limitation of any particular method. It is the problem itself.
The consideration set model does not try to open the black box. It wraps a separate behavioural model around it. The \(\gamma_z\) posteriors say the system’s decisions are consistent with a process where a given variable gates consideration with a given magnitude and uncertainty. This is a functional claim. It describes behaviour, not mechanism, with the same epistemic status as résumé audit studies that surface discrimination without opening the hiring pipeline. The consideration model decomposes screening from conditional-on-screening evaluation, and the exclusion restriction makes that decomposition identified.
Finding Instruments is the Hard Part
Everything rests on the exclusion restriction \(Z \neq X\). The two demos illustrate the range. In hiring, protected characteristics satisfy it almost by definition. “Should not affect job performance” is not a pious hope but the content of anti-discrimination law. When the exclusion restriction and the legal standard say the same thing, the modelling assumption is as defensible as it gets. The Swiss Metro case is harder. The GA pass is an impure instrument, pulled between consideration and utility. The framework names this tension rather than hiding it.
Other domains are harder still. In content recommendation, what feature shifts whether an item is considered without shifting its quality? In credit scoring? In triage? The framework does not pretend instruments are easy to find. It makes the difficulty visible as a modelling assumption you must defend. The same identification problem applies inside the transformer. Without exogenous variation that isolates the bias channel, internal attribution is underdetermined, for the same reason that a consideration model without valid instruments is unidentified.
Conclusion
Two metaphors, one problem. The streetlight effect: we search where the light is, not where the keys fell. The accountability sink: the architecture ensures there is no one to blame when the keys go missing. Multi-head attention is both at once. The bias is real, distributed across anonymous subspaces where it resists inspection, and no single head is responsible. Wrap the architecture in a vendor API and the sink deepens. The firm points to the model. The vendor points to the data. The data records past human decisions that nobody alive made. Accountability does not vanish. It is delegated into a void.
Consideration set models are the counter-architecture. They name the variables that gate access, estimate how strongly each one operates, and report their uncertainty. \(\gamma_z[\text{Car, GA}] \approx -9.5\) is not a diffuse activation pattern. It is a claim, with a sign, a magnitude, and a credible interval. You can show it to a regulator. You can argue about whether the instrument is valid. You can be wrong — and that is the point. A framework that can be wrong can also be held accountable.
Naming what you model is not a stylistic preference. It is audit infrastructure.