Summary

LSTM (Long Short-Term Memory) is a type of RNN designed to overcome the vanishing gradient problem and remember information over longer sequences.

  • Input vector $x_t$: the data point at the current time step (e.g., a word embedding)
  • Hidden state $h_t$: the output at time $t$, passed to the next time step and possibly the next layer
  • Cell state $C_t$: internal memory that carries forward important information through time

The LSTM uses gates to control information flow — allowing it to selectively “remember” or “forget” things.
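
To make these pieces concrete, here is a minimal sketch using PyTorch's built-in nn.LSTMCell just to inspect the shapes of $x_t$, $h_t$, and $C_t$ (the sizes are arbitrary example values):

import torch
import torch.nn as nn

# Arbitrary example sizes: batch of 4, embeddings of size 10, hidden/cell state of size 20
cell = nn.LSTMCell(input_size=10, hidden_size=20)

x_t = torch.randn(4, 10)       # input vector at the current time step
h_prev = torch.zeros(4, 20)    # previous hidden state
C_prev = torch.zeros(4, 20)    # previous cell state

h_t, C_t = cell(x_t, (h_prev, C_prev))
print(h_t.shape, C_t.shape)    # torch.Size([4, 20]) torch.Size([4, 20])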

Gates in an LSTM Cell

At each time step $t$, the LSTM processes:

  • the current input $x_t$
  • the previous hidden state $h_{t-1}$
  • the previous cell state $C_{t-1}$

Forget Gate

What does it do?

The forget gate decides what information to erase from the previous cell state:

$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$

  • Outputs values in $[0, 1]$
  • 0 = forget completely, 1 = retain fully (see the small numeric sketch below)
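
A tiny numeric sketch of the forget gate's effect, with hand-picked gate values rather than learned ones:

import torch

C_prev = torch.tensor([2.0, -1.0, 0.5])   # previous cell state
f_t = torch.tensor([0.0, 0.5, 1.0])       # hand-picked forget gate values in [0, 1]

print(f_t * C_prev)  # -> [0.0, -0.5, 0.5]: forgotten, halved, kept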

Input Gate

What does it do?

The input gate decides what new information should be stored in the memory cell:

$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$

Compute the candidate memory values:

$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$
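
A tiny numeric sketch of the input gate scaling the candidate values (hand-picked numbers, not learned ones):

import torch

i_t = torch.tensor([0.9, 0.1])        # input gate: how much of each candidate to admit
g_t = torch.tensor([0.7, -0.8])       # candidate memory values in (-1, 1) from tanh

print(i_t * g_t)  # -> [0.63, -0.08]: the contribution that will be added to the cell state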


Update the Cell State

Update the internal memory:

$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$

  • Forget some of the past: $f_t \odot C_{t-1}$
  • Add some of the new: $i_t \odot \tilde{C}_t$ (both terms are combined in the sketch below)
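
A small numeric sketch of the update, reusing the hand-picked input gate and candidate values from above:

import torch

C_prev = torch.tensor([2.0, -1.0])    # previous cell state
f_t = torch.tensor([0.5, 1.0])        # forget gate
i_t = torch.tensor([0.9, 0.1])        # input gate
g_t = torch.tensor([0.7, -0.8])       # candidate memory

C_t = f_t * C_prev + i_t * g_t        # forget some of the past, add some of the new
print(C_t)  # -> [1.63, -1.08]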

📤 Output Gate

What does it do?

The output gate decides what part of the memory to pass forward as the hidden state:

$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$

Final hidden state output:

$h_t = o_t \odot \tanh(C_t)$
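
One last hand-picked sketch; the full implementation below wires all four gates together with learned weights:

import torch

C_t = torch.tensor([1.63, -1.08])     # updated cell state from the previous step
o_t = torch.tensor([1.0, 0.2])        # output gate: how much of the memory to expose

h_t = o_t * torch.tanh(C_t)           # hidden state passed to the next time step
print(h_t)  # -> approximately [0.93, -0.16]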

PyTorch Implementation

import torch
import torch.nn as nn
 
class CustomLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
 
        # Combine all gates into one matrix for efficiency:
        # [forget_gate | input_gate | candidate_memory | output_gate]
        self.linear = nn.Linear(input_size + hidden_size, 4 * hidden_size)
 
    def forward(self, x_t, h_prev, C_prev):
        # Concatenate previous hidden state and current input
        combined = torch.cat((h_prev, x_t), dim=1)  # shape: [batch_size, input_size + hidden_size]
 
        # Apply linear transformation to get all gate pre-activations
        gates = self.linear(combined)  # shape: [batch_size, 4 * hidden_size]
        f_t, i_t, g_t, o_t = torch.chunk(gates, chunks=4, dim=1)
 
        # Apply activations
        f_t = torch.sigmoid(f_t)       # forget gate
        i_t = torch.sigmoid(i_t)       # input gate
        g_t = torch.tanh(g_t)          # candidate memory
        o_t = torch.sigmoid(o_t)       # output gate
 
        # Update cell state
        C_t = f_t * C_prev + i_t * g_t
 
        # Compute new hidden state
        h_t = o_t * torch.tanh(C_t)
 
        return h_t, C_t
 
class CustomLSTM(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.cell = CustomLSTMCell(input_size, hidden_size)
 
    def forward(self, x):
        # x: shape [seq_len, batch_size, input_size]
        seq_len, batch_size, _ = x.size()
        # Initialize hidden and cell states to zeros on the same device/dtype as the input
        h_t = torch.zeros(batch_size, self.cell.hidden_size, device=x.device, dtype=x.dtype)
        C_t = torch.zeros(batch_size, self.cell.hidden_size, device=x.device, dtype=x.dtype)
 
        outputs = []
 
        for t in range(seq_len):
            x_t = x[t]
            h_t, C_t = self.cell(x_t, h_t, C_t)
            outputs.append(h_t)
 
        return torch.stack(outputs)  # shape: [seq_len, batch_size, hidden_size]
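
A quick usage sketch with arbitrary example dimensions, checking the output shape and comparing it against PyTorch's built-in nn.LSTM:

# Example: sequence of 5 time steps, batch of 3, embeddings of size 10, hidden size 20
lstm = CustomLSTM(input_size=10, hidden_size=20)
x = torch.randn(5, 3, 10)           # [seq_len, batch_size, input_size]

out = lstm(x)
print(out.shape)                    # torch.Size([5, 3, 20])

# Same input through the built-in nn.LSTM (different weights, same output shape)
builtin = nn.LSTM(input_size=10, hidden_size=20)
out_builtin, (h_n, c_n) = builtin(x)
print(out_builtin.shape)            # torch.Size([5, 3, 20])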

References

Understanding LSTM Networks — colah’s blog