Recurrent neural networks (RNNs) are designed for sequential data: they process elements one at a time while maintaining a memory (hidden state) of previous inputs.
Used for language, time-series data, audio, etc.
Key components
Input vector x_t: the data point at the current time step.
Hidden state h_t: a vector representing the RNN's memory, updated at each time step.
Output vector y_t: the output at each time step, which may depend on h_t.
Updating the hidden state
At each time step t, the RNN takes the current input x_t, combines it with the previous hidden state h_{t-1}, produces a new hidden state h_t, and optionally produces an output y_t.
h_t = tanh(W_x x_t + W_h h_{t-1} + b_h)
W_x = weight matrix connecting the input x_t to the hidden layer
W_h = weight matrix connecting the previous hidden state h_{t-1} to the current hidden state
b_h = bias term
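A minimal NumPy sketch of this update, using the notation above (the dimensions and initialization are illustrative choices, not from the notes):

```python
import numpy as np

def rnn_step(x_t, h_prev, W_x, W_h, b_h):
    """Combine the current input with the previous hidden state."""
    return np.tanh(W_x @ x_t + W_h @ h_prev + b_h)

# Illustrative sizes: input dimension 4, hidden dimension 8
rng = np.random.default_rng(0)
W_x = rng.normal(0, 0.1, (8, 4))   # input-to-hidden weights
W_h = rng.normal(0, 0.1, (8, 8))   # hidden-to-hidden weights
b_h = np.zeros(8)
h_t = rnn_step(rng.normal(size=4), np.zeros(8), W_x, W_h, b_h)
```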
Generating the output
y_t = W_y h_t + b_y
W_y = weight matrix for the output layer
b_y = bias for the output layer
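Continuing the sketch above, the output is a linear map of the hidden state (a softmax or other nonlinearity can be applied on top, depending on the task):

```python
def rnn_output(h_t, W_y, b_y):
    """Produce the output y_t from the hidden state h_t."""
    return W_y @ h_t + b_y
```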
Backpropagation through time (BPTT)
Training an RNN means minimizing the loss summed over all time steps:
Unroll the RNN through time (turning it into a deep feedforward network where each layer corresponds to a different time step)
Apply standard backprop through unrolled structure
Accumulate gradients for the shared weights at each time step
Update parameters
Forward pass
For a sequence of length T, inputs x_1, ..., x_T generate hidden states h_1, ..., h_T and outputs y_1, ..., y_T.
At each time step:
h_t = tanh(W_xh x_t + W_hh h_{t-1} + b_h)
y_t = W_hy h_t + b_y
ℓ_t = Loss(y_t, ŷ_t), where ŷ_t is the target at time t
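A sketch of the unrolled forward pass, caching every hidden state for the backward pass. The squared-error loss is an assumption for illustration; the notes don't fix a particular loss:

```python
def forward(xs, targets, Wxh, Whh, Why, bh, by):
    """Run the RNN over the whole sequence; cache states for BPTT."""
    hs = {-1: np.zeros_like(bh)}   # h_{-1} = 0 initial hidden state
    ys, total_loss = {}, 0.0
    for t in range(len(xs)):
        hs[t] = np.tanh(Wxh @ xs[t] + Whh @ hs[t - 1] + bh)
        ys[t] = Why @ hs[t] + by
        total_loss += 0.5 * np.sum((ys[t] - targets[t]) ** 2)  # assumed loss
    return hs, ys, total_loss
```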
Backward pass
During the backward pass, we compute gradients of the total loss L = ∑_{t=1}^T ℓ_t with respect to all trainable parameters: W_xh, W_hh, and W_hy.
First, compute the output gradient at each time step:
δ_t^y = ∂ℓ_t / ∂y_t
∂L/∂W_hy = ∑_{t=1}^T δ_t^y h_t^⊤
Then compute gradients flowing into the hidden states:
δ_t^h = δ_t^y W_hy^⊤ + δ_{t+1}^h · ∂h_{t+1}/∂h_t (with δ_{T+1}^h = 0 at the final step)
where ∂h_{t+1}/∂h_t = diag(1 − h_{t+1}^2) · W_hh^⊤ (from the tanh derivative)
For each time step t, accumulate parameter gradients:
∂L/∂W_hh += (δ_t^h ⊙ (1 − h_t^2)) · h_{t-1}^⊤
∂L/∂W_xh += (δ_t^h ⊙ (1 − h_t^2)) · x_t^⊤
The gradients are summed across all time steps because the same weights are shared across t = 1, ..., T.
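The accumulation rules above, written out as a sketch (same squared-error assumption as the forward pass, so δ_t^y = y_t − ŷ_t; bias gradients are included for completeness):

```python
def backward(xs, targets, hs, ys, Wxh, Whh, Why):
    """BPTT: accumulate gradients of the summed loss over all time steps."""
    dWxh, dWhh, dWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
    dbh, dby = np.zeros(Wxh.shape[0]), np.zeros(Why.shape[0])
    dh_next = np.zeros(Wxh.shape[0])        # delta^h arriving from step t+1
    for t in reversed(range(len(xs))):
        dy = ys[t] - targets[t]             # delta^y_t (squared-error assumption)
        dWhy += np.outer(dy, hs[t])
        dby += dy
        dh = Why.T @ dy + dh_next           # delta^h_t
        dpre = dh * (1 - hs[t] ** 2)        # through the tanh derivative
        dWxh += np.outer(dpre, xs[t])
        dWhh += np.outer(dpre, hs[t - 1])
        dbh += dpre
        dh_next = Whh.T @ dpre              # gradient passed on to step t-1
    return dWxh, dWhh, dWhy, dbh, dby
```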
Optionally, apply gradient clipping to avoid exploding gradients: if the gradient norm ‖g‖ exceeds a threshold θ, rescale g ← θ · g / ‖g‖.
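A common clip-by-global-norm sketch (the threshold 5.0 and learning rate 0.01 are illustrative choices), followed by the parameter update step from the list above, using the gradients returned by backward:

```python
def clip_gradients(grads, max_norm=5.0):
    """Rescale all gradients if their global L2 norm exceeds max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > max_norm:
        grads = [g * (max_norm / total_norm) for g in grads]
    return grads

# Vanilla SGD update (in-place) with an illustrative learning rate
lr = 0.01
grads = clip_gradients([dWxh, dWhh, dWhy, dbh, dby])
for param, grad in zip([Wxh, Whh, Why, bh, by], grads):
    param -= lr * grad
```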