Recurrent Neural Networks and Attention: Mathematical Notes and Intuition

1. Sequence Modeling Problem Setup

Many deep learning tasks use ordered data:

  • Language modeling
  • Machine translation
  • Speech recognition
  • Time-series forecasting

For a sequence input $$ X = (x_1, x_2, \dots, x_T) $$ a model must preserve order and context over time.


2. Vanilla RNN: Core Equations

A recurrent neural network updates its hidden state step-by-step:

\[ \begin{aligned} h_t &= \phi(W_h h_{t-1} + W_x x_t + b_h) \\ y_t &= W_y h_t + b_y \end{aligned} \]

Where:

  • \(x_t\): input at time step \(t\)
  • \(h_t\): hidden state (memory)
  • \(y_t\): output at time step \(t\)
  • \(\phi\): activation function (typically \(\tanh\) or ReLU)

The key idea is parameter sharing across time: the same \(W_h, W_x, W_y\) are reused at every step.
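The recurrence can be sketched in NumPy as follows; the sizes (`d_in`, `d_h`, `d_out`, `T`) and the small Gaussian initialization are illustrative assumptions, not values prescribed by these notes:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_h, d_out, T = 3, 4, 2, 5           # illustrative sizes

# Shared parameters, reused at every time step
W_h = rng.normal(scale=0.1, size=(d_h, d_h))
W_x = rng.normal(scale=0.1, size=(d_h, d_in))
W_y = rng.normal(scale=0.1, size=(d_out, d_h))
b_h = np.zeros(d_h)
b_y = np.zeros(d_out)

def rnn_step(h_prev, x_t):
    """One recurrence: h_t = tanh(W_h h_{t-1} + W_x x_t + b_h)."""
    h_t = np.tanh(W_h @ h_prev + W_x @ x_t + b_h)
    y_t = W_y @ h_t + b_y
    return h_t, y_t

h = np.zeros(d_h)                          # h_0
xs = rng.normal(size=(T, d_in))
outputs = []
for x_t in xs:
    h, y = rnn_step(h, x_t)
    outputs.append(y)
```

Note that the same `W_h`, `W_x`, `W_y` are applied at every step of the loop, which is exactly the parameter sharing described above.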

```mermaid
flowchart LR
  X1["x1"] --> H1["h1"]
  H0["h0"] --> H1
  X2["x2"] --> H2["h2"]
  H1 --> H2
  X3["x3"] --> H3["h3"]
  H2 --> H3
  H1 --> Y1["y1"]
  H2 --> Y2["y2"]
  H3 --> Y3["y3"]
```

3. Training RNNs with BPTT

RNNs are trained using Backpropagation Through Time (BPTT):

  • Unroll the computation graph across time steps.
  • Compute loss over sequence outputs.
  • Backpropagate gradients from later time steps to earlier ones.

3.1 Vanishing Gradient Challenge

For long sequences, repeated multiplication by Jacobians causes gradients to shrink:

  • Early tokens receive weak learning signal.
  • Long-range dependencies become hard to learn.

This is the main reason gated architectures (LSTM/GRU) and attention mechanisms became essential.
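The shrinking effect can be seen in a toy scalar recurrence, assumed here as \(h_t = \tanh(w\,h_{t-1})\): the gradient of \(h_T\) with respect to \(h_0\) is a product of \(T\) per-step Jacobians, each of magnitude less than one.

```python
import numpy as np

# Toy scalar RNN h_t = tanh(w * h_{t-1}); the gradient of h_T w.r.t. h_0
# is the product of the T per-step Jacobians w * (1 - h_t^2).
w = 0.9
h = 0.5
grad = 1.0
for t in range(50):
    h = np.tanh(w * h)
    grad *= w * (1.0 - h**2)   # d h_t / d h_{t-1} for the scalar tanh RNN

print(grad)   # shrinks toward zero as the number of steps grows
```

Since every factor is below 0.9, after 50 steps the gradient is smaller than \(0.9^{50} \approx 0.005\): the signal reaching early tokens is effectively gone.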


4. Practical RNN Variants

4.1 Bidirectional RNN (BiRNN)

BiRNN uses two passes:

  • Forward pass: \(x_1 \to x_T\)
  • Backward pass: \(x_T \to x_1\)

Combined representation: $$ h_t^{(\text{bi})} = [\overrightarrow{h_t}; \overleftarrow{h_t}] $$

Useful when future context is available (offline NLP); less suitable for strict real-time streaming, where the backward pass cannot be computed until the sequence ends.

4.2 Stacked (Deep) RNN

Multiple recurrent layers are stacked: $$ h_t^{(l)} = \phi\left(W_h^{(l)} h_{t-1}^{(l)} + W_x^{(l)} h_t^{(l-1)} + b^{(l)}\right) $$

Benefits:

  • Lower layers capture local patterns.
  • Upper layers capture more abstract temporal structure.

5. Sequence-to-Sequence (Encoder-Decoder) Bottleneck

For variable-length input/output tasks (e.g., translation), encoder-decoder RNNs were introduced.

5.1 Fixed Context Vector Limitation

Classical seq2seq compresses the whole input into one fixed vector \(c\): $$ c = h_T^{(\text{enc})} $$

Decoder then generates outputs from \(c\).

Limitation:

  • Information bottleneck for long inputs
  • Early-token details may be lost
  • All source tokens are indirectly forced into one summary

```mermaid
flowchart LR
  A["Input Sequence"] --> B["Encoder RNN"]
  B --> C["Fixed Context c"]
  C --> D["Decoder RNN"]
  D --> E["Output Sequence"]
```

6. Attention Mechanism: Why It Works

Attention removes the single-vector bottleneck by creating a dynamic context per decoder step.

At decoder step \(t\), compute alignment scores between decoder state \(s_{t-1}\) and each encoder state \(h_i\): $$ e_{t,i} = a(s_{t-1}, h_i) $$

Convert to normalized weights: $$ \alpha_{t,i} = \frac{\exp(e_{t,i})}{\sum_j \exp(e_{t,j})} $$

Build context as weighted sum: $$ c_t = \sum_{i=1}^{T_x} \alpha_{t,i} h_i $$

Interpretation:

  • \(\alpha_{t,i}\) tells how much source position \(i\) contributes to target step \(t\).
  • Different output words can attend to different source words.

```mermaid
flowchart TD
  H["Encoder States h1...hT"] --> S["Score with decoder state"]
  S --> A["Softmax -> attention weights alpha"]
  A --> C["Weighted sum -> context c_t"]
  C --> O["Decoder output at step t"]
```
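The three steps (score, softmax, weighted sum) can be sketched in NumPy. A plain dot-product score is assumed here for simplicity; the additive (Bahdanau) score mentioned later would replace the first line:

```python
import numpy as np

def attention_context(s_prev, H):
    """Dot-product attention: scores, softmax weights, weighted-sum context.

    s_prev: decoder state, shape (d,)
    H: encoder states stacked row-wise, shape (T_x, d)
    """
    e = H @ s_prev                          # e_{t,i} = s_{t-1} . h_i
    e = e - e.max()                         # subtract max for numerical stability
    alpha = np.exp(e) / np.exp(e).sum()     # softmax over source positions
    c = alpha @ H                           # c_t = sum_i alpha_{t,i} h_i
    return alpha, c

rng = np.random.default_rng(1)
H = rng.normal(size=(6, 4))                 # 6 encoder states, dim 4
s = rng.normal(size=4)
alpha, c = attention_context(s, H)
```

`alpha` is a proper probability distribution over source positions, so `c` is a convex combination of the encoder states.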

7. Query-Key-Value View

Attention can be framed as retrieval:

  • Query \(Q\): what current step is looking for
  • Keys \(K\): descriptors of available memory slots
  • Values \(V\): content to retrieve

7.1 Scaled Dot-Product Attention

\[ \text{Attention}(Q,K,V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V \]

Why divide by \(\sqrt{d_k}\):

  • Dot products grow with dimension.
  • Scaling stabilizes softmax and gradients.
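A minimal NumPy sketch of this formula, with the row-wise softmax taken over keys; the matrix sizes are illustrative assumptions:

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V, with softmax taken row-wise over keys."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                    # scale to control variance
    scores -= scores.max(axis=-1, keepdims=True)       # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

rng = np.random.default_rng(2)
Q = rng.normal(size=(3, 8))    # 3 queries of dimension d_k = 8
K = rng.normal(size=(5, 8))    # 5 keys
V = rng.normal(size=(5, 8))    # 5 values
out, w = scaled_dot_product_attention(Q, K, V)
```

Each output row is a weighted mixture of the value rows, with weights determined by query-key similarity.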

7.2 Common Scoring Functions

  • Dot-product score: \(q^\top k\)
  • Scaled dot-product score: \(q^\top k / \sqrt{d_k}\)
  • Additive score (Bahdanau): small feed-forward alignment network

8. Self-Attention and Multi-Head Attention

8.1 Self-Attention

When \(Q,K,V\) are all derived from the same sequence representation, it is self-attention.

8.2 Multi-Head Attention

A single attention map may miss different relation types. Multi-head attention learns multiple relation subspaces in parallel:

\[ \begin{aligned} \text{head}_i &= \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) \\ \text{MHA}(Q,K,V) &= \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O \end{aligned} \]

Each head can focus on different structure: syntax, positional linkage, semantic dependency, or coreference.
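A compact NumPy sketch of multi-head self-attention. For brevity, the per-head projections \(W_i^Q, W_i^K, W_i^V\) are assumed to be column slices of larger matrices `Wq`, `Wk`, `Wv` (a common but not mandatory implementation choice):

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(Q, K, V, Wq, Wk, Wv, Wo, n_heads):
    """Concat of per-head scaled dot-product attentions, then output projection Wo."""
    d_model = Q.shape[-1]
    d_k = d_model // n_heads
    heads = []
    for i in range(n_heads):
        # Per-head projections W_i^Q, W_i^K, W_i^V (column slices here)
        q = Q @ Wq[:, i*d_k:(i+1)*d_k]
        k = K @ Wk[:, i*d_k:(i+1)*d_k]
        v = V @ Wv[:, i*d_k:(i+1)*d_k]
        w = softmax(q @ k.T / np.sqrt(d_k))
        heads.append(w @ v)
    return np.concatenate(heads, axis=-1) @ Wo

rng = np.random.default_rng(3)
d_model, n_heads, T = 8, 2, 5
X = rng.normal(size=(T, d_model))          # self-attention: Q = K = V = X
Wq, Wk, Wv, Wo = (rng.normal(scale=0.1, size=(d_model, d_model)) for _ in range(4))
out = multi_head_attention(X, X, X, Wq, Wk, Wv, Wo, n_heads)
```

Because `Q`, `K`, and `V` are all derived from the same `X`, this call is self-attention; each of the two heads computes its own attention map over the sequence.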


9. Teacher Forcing vs Inference Decoding

9.1 Teacher Forcing (Training)

At step \(t\), decoder receives true previous token \(y_{t-1}^{\text{true}}\).

Advantage:

  • Faster and more stable training.

9.2 Free Running (Inference)

At step \(t\), decoder receives its own previous prediction \(\hat{y}_{t-1}\).

Challenge:

  • Errors can accumulate across steps (exposure bias).
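The two decoding regimes differ only in what is fed back at each step. The sketch below uses a hypothetical `decoder_step` stand-in (deterministic toy dynamics, not a real decoder) purely to contrast the two loops:

```python
import numpy as np

def decoder_step(state, prev_token_id, vocab=10):
    """Toy stand-in for a decoder cell: returns next state and token logits."""
    rng = np.random.default_rng(state + prev_token_id)   # deterministic toy dynamics
    logits = rng.normal(size=vocab)
    return state + 1, logits

target = [3, 7, 1, 4]               # ground-truth tokens

# Teacher forcing: feed the TRUE previous token at each step
state = 0
tf_inputs = [0] + target[:-1]       # 0 = assumed <sos> token id
for prev in tf_inputs:
    state, logits = decoder_step(state, prev)

# Free running: feed the model's OWN previous prediction
state, prev = 0, 0
preds = []
for _ in range(len(target)):
    state, logits = decoder_step(state, prev)
    prev = int(np.argmax(logits))   # a mistake here propagates to later steps
    preds.append(prev)
```

In the free-running loop the input distribution at step \(t\) depends on earlier predictions, which is exactly the train/inference mismatch behind exposure bias.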

10. Geometric Interpretation of Attention Weights

In dot-product attention, score \(q^\top k\) is a similarity measure in representation space.

  • Larger alignment -> larger softmax weight.
  • Context vector is a convex combination of value vectors.

So attention computes a task-dependent projection of memory onto the query direction.


11. End-to-End Workflow (Exam-Friendly)

  1. Encode source sequence into hidden states \((h_1,\dots,h_T)\).
  2. Initialize decoder state.
  3. For each target step \(t\):
     a. Compute scores \(e_{t,i}\).
     b. Compute weights \(\alpha_{t,i}\) using softmax.
     c. Form context \(c_t\).
     d. Produce output token distribution.
  4. Train with teacher forcing + cross-entropy.
  5. Infer with autoregressive decoding.

12. What RNN + Attention Solves

RNNs provide temporal memory, but struggle with very long dependencies. Attention adds direct access to all encoder states, reducing bottlenecks and improving alignment in long or complex sequences.

This progression explains the practical evolution:

  • Vanilla RNN -> gated RNNs (LSTM/GRU) -> attention-enhanced seq2seq -> transformer-style full attention models.

13. RNN vs GRU vs LSTM (Architecture and Practical Trade-offs)

This comparison is important for model selection in sequence tasks.

13.1 Vanilla RNN

\[ h_t = \phi(W_h h_{t-1} + W_x x_t + b_h) \]

Strengths:

  • Simple architecture
  • Fewer parameters
  • Fastest per-step computation

Limitations:

  • Vanishing gradients on long sequences
  • Weak long-term memory retention

Best fit:

  • Short sequences
  • Lightweight baselines

13.2 GRU (Gated Recurrent Unit)

GRU introduces update and reset gates:

\[ z_t = \sigma(W_z x_t + U_z h_{t-1} + b_z), \quad r_t = \sigma(W_r x_t + U_r h_{t-1} + b_r) \]
\[ \tilde{h}_t = \tanh(W_h x_t + U_h(r_t \odot h_{t-1}) + b_h) \]
\[ h_t = (1-z_t)\odot h_{t-1} + z_t \odot \tilde{h}_t \]
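The gated update can be written directly from these equations; the weight shapes and initialization below are illustrative assumptions:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(h_prev, x_t, p):
    """One GRU update following the equations above; p is a dict of weights."""
    z = sigmoid(p["Wz"] @ x_t + p["Uz"] @ h_prev + p["bz"])           # update gate
    r = sigmoid(p["Wr"] @ x_t + p["Ur"] @ h_prev + p["br"])           # reset gate
    h_tilde = np.tanh(p["Wh"] @ x_t + p["Uh"] @ (r * h_prev) + p["bh"])
    return (1.0 - z) * h_prev + z * h_tilde                           # interpolate

rng = np.random.default_rng(4)
d_in, d_h = 3, 4
p = {f"W{g}": rng.normal(scale=0.1, size=(d_h, d_in)) for g in "zrh"}
p |= {f"U{g}": rng.normal(scale=0.1, size=(d_h, d_h)) for g in "zrh"}
p |= {f"b{g}": np.zeros(d_h) for g in "zrh"}
h = gru_step(np.zeros(d_h), rng.normal(size=d_in), p)
```

The final line of `gru_step` is the key: \(z_t\) interpolates between keeping the old state and writing the candidate, which is what eases gradient flow relative to a vanilla RNN.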

Strengths:

  • Better long-range handling than vanilla RNN
  • Fewer gates and parameters than LSTM
  • Often faster training than LSTM

Best fit:

  • Medium/long sequences when compute is constrained
  • Strong default when a speed-performance balance is needed

13.3 LSTM (Long Short-Term Memory)

LSTM uses a cell state and three gates:

\[ \begin{aligned} f_t &= \sigma(W_f[h_{t-1},x_t] + b_f) \\ i_t &= \sigma(W_i[h_{t-1},x_t] + b_i) \\ \tilde{C}_t &= \tanh(W_c[h_{t-1},x_t] + b_c) \\ C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\ o_t &= \sigma(W_o[h_{t-1},x_t] + b_o) \\ h_t &= o_t \odot \tanh(C_t) \end{aligned} \]
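A direct transcription of these equations, with each gate weight `W[g]` acting on the concatenated \([h_{t-1}, x_t]\); shapes and initialization are illustrative assumptions:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(h_prev, C_prev, x_t, W, b):
    """One LSTM update; each W[g] maps the concatenated [h_{t-1}, x_t]."""
    hx = np.concatenate([h_prev, x_t])
    f = sigmoid(W["f"] @ hx + b["f"])          # forget gate
    i = sigmoid(W["i"] @ hx + b["i"])          # input gate
    C_tilde = np.tanh(W["c"] @ hx + b["c"])    # candidate cell state
    C = f * C_prev + i * C_tilde               # additive cell-state update
    o = sigmoid(W["o"] @ hx + b["o"])          # output gate
    return o * np.tanh(C), C

rng = np.random.default_rng(5)
d_in, d_h = 3, 4
W = {g: rng.normal(scale=0.1, size=(d_h, d_h + d_in)) for g in "fico"}
b = {g: np.zeros(d_h) for g in "fico"}
h, C = lstm_step(np.zeros(d_h), np.zeros(d_h), rng.normal(size=d_in), W, b)
```

The additive form \(C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\) is what gives the cell state its gradient highway across time steps.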

Strengths:

  • Strongest explicit memory control
  • Robust for longer dependencies
  • Stable in many real sequence problems

Limitations:

  • More parameters than GRU/RNN
  • Higher compute and memory cost

Best fit:

  • Long-context tasks
  • Accuracy prioritized over speed

13.4 Quick Comparison Table

| Model | Memory Mechanism | Parameters | Training Speed | Long-Dependency Handling |
| --- | --- | --- | --- | --- |
| RNN | Hidden state only | Low | High | Low |
| GRU | Update + reset gates | Medium | Medium-High | Medium-High |
| LSTM | Cell state + 3 gates | High | Medium | High |

13.5 Practical Selection Rule

  1. Start with GRU for strong baseline efficiency.
  2. Move to LSTM if long-range dependency errors remain.
  3. Keep vanilla RNN for short-sequence baselines or tiny-resource settings.
