LSTM (torch.nn.LSTM)
A torch.nn.LSTM corresponds to a Long Short-Term Memory module. In essence, it is an arrangement of torch.nn.LSTMCell units that can process an input tensor containing a sequence of steps and can use cell arrangements of configurable depth. The equations that rule a torch.nn.LSTM are the same as those of torch.nn.LSTMCell, except for the sequence length and the number of layers. Differently from a single torch.nn.LSTMCell, a torch.nn.LSTM has a hidden state per sequence step (\(h_t\)), an input gate per sequence step (\(i_t\)), a forget gate per sequence step (\(f_t\)), a cell gate per sequence step (\(g_t\)), and an output gate per sequence step (\(o_t\)); thus, there is also a \(c\) tensor per time step (\(c_t\)) and an \(h\) tensor per time step (\(h_t\)). Please note that the current step hidden state \(h_t\) depends on the previous step hidden state \(h_{t-1}\).
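For reference, these are the standard per-step equations (the same as those of torch.nn.LSTMCell, applied at every sequence step \(t\)):

\(i_t=\sigma\left(W_{ii}x_t+b_{ii}+W_{hi}h_{t-1}+b_{hi}\right)\)

\(f_t=\sigma\left(W_{if}x_t+b_{if}+W_{hf}h_{t-1}+b_{hf}\right)\)

\(g_t=\text{tanh}\left(W_{ig}x_t+b_{ig}+W_{hg}h_{t-1}+b_{hg}\right)\)

\(o_t=\sigma\left(W_{io}x_t+b_{io}+W_{ho}h_{t-1}+b_{ho}\right)\)

\(c_t=f_t\odot c_{t-1}+i_t\odot g_t\)

\(h_t=o_t\odot\text{tanh}\left(c_t\right)\)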
Where
- \(x\) is the input tensor of size \(\left(L, H_{in}\right)\) for unbatched input, \(\left(L, N, H_{in}\right)\) when batch_first=False, or \(\left(N, L, H_{in}\right)\) when batch_first=True.
- \(h_t\) is the hidden state tensor at sequence step \(t\) of size \(\left(N, H_{out}\right)\) or \(\left(H_{out}\right)\).
- \(H_{in}\) and \(H_{out}\) are the number of input and output features, respectively.
- \(L\) is the sequence length.
- \(N\) is the batch size.
- \(c\) is the cell state tensor of size \(\left(N, H_{out}\right)\) or \(\left(H_{out}\right)\).
- \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) are weight tensors of size \(\left(H_{out}, H_{in}\right)\) in the first layer and \(\left(H_{out}, D\times H_{out}\right)\) in subsequent layers.
- \(W_{hi}\), \(W_{hf}\), \(W_{hg}\) and \(W_{ho}\) are weight tensors of size \(\left(H_{out}, H_{out}\right)\).
- \(\sigma\) is the sigmoid function and can be defined as \(\sigma\left(x\right)=\frac{1}{1+e^{-x}}\).
- \(\text{tanh}\) is the hyperbolic tangent function and can be defined as \(\text{tanh}\left(x\right)=\frac{e^x-e^{-x}}{e^x+e^{-x}}\).
- \(\odot\) is the Hadamard product or element-wise product.
- \(b_{ii}\), \(b_{hi}\), \(b_{if}\), \(b_{hf}\), \(b_{ig}\), \(b_{hg}\), \(b_{io}\) and \(b_{ho}\) are bias tensors of size \(\left(H_{out}\right)\).
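As a quick way to see these sizes in practice, the following minimal sketch (with arbitrarily chosen sizes) runs a random batch through a torch.nn.LSTM and prints the shapes of the output and states:

```python
import torch

L, N, H_in, H_out = 7, 3, 10, 20  # arbitrary sequence length, batch and feature sizes

lstm = torch.nn.LSTM(input_size=H_in, hidden_size=H_out, num_layers=1)

x = torch.randn(L, N, H_in)       # batch_first=False, so the layout is (L, N, H_in)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([7, 3, 20]) -> (L, N, H_out), one h_t per step
print(h_n.shape)     # torch.Size([1, 3, 20]) -> (num_layers, N, H_out)
print(c_n.shape)     # torch.Size([1, 3, 20]) -> (num_layers, N, H_out)
```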
Note
Please note that some weight tensor sizes may differ from PyTorch's torch.nn.LSTM documentation because some tensors are stacked in the implementation. For instance, the \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) tensors of each layer are implemented as a single tensor of size \(\left(4\times H_{out}, H_{in}\right)\) for the first layer, and \(\left(4\times H_{out}, D\times H_{out}\right)\) for subsequent layers. Similarly, \(W_{hi}\), \(W_{hf}\), \(W_{hg}\) and \(W_{ho}\) are implemented as a single tensor of size \(\left(4\times H_{out}, H_{out}\right)\). The number of layers is controlled by the num_layers parameter, and the number of directions \(D\) is controlled by the bidirectional parameter.
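The stacking can be inspected directly on the module's parameters. A minimal sketch, with a single-layer configuration chosen only for illustration:

```python
import torch

H_in, H_out = 10, 20

lstm = torch.nn.LSTM(input_size=H_in, hidden_size=H_out, num_layers=1)

# W_ii, W_if, W_ig and W_io are stacked into a single (4*H_out, H_in) tensor.
print(lstm.weight_ih_l0.shape)  # torch.Size([80, 10])
# W_hi, W_hf, W_hg and W_ho are stacked into a single (4*H_out, H_out) tensor.
print(lstm.weight_hh_l0.shape)  # torch.Size([80, 20])
# The bias vectors are likewise stored as two stacked (4*H_out,) tensors.
print(lstm.bias_ih_l0.shape)    # torch.Size([80])
print(lstm.bias_hh_l0.shape)    # torch.Size([80])
```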
Note
The complexity introduced by the dropout parameter is not considered in the following calculations, since dropout is usually only applied during training and is disabled during inference.
Complexity
It is possible to reuse the calculation for torch.nn.LSTMCell to estimate the complexity of torch.nn.LSTM. However, there are a couple of additional considerations. First, when num_layers > 1, the second layer takes the output(s) of the first layer as input. This means that \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) will have size \(\left(H_{out}, H_{out}\right)\) if bidirectional=False and size \(\left(H_{out}, 2\times H_{out}\right)\) if bidirectional=True. Secondly, and differently from torch.nn.LSTMCell, torch.nn.LSTM can process inputs with longer sequence lengths, therefore the same calculations estimated before will be repeated \(L\) times, where \(L\) is the sequence length.
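The change of input width for layers after the first one can also be checked directly on the module's parameters. A small sketch with arbitrary sizes:

```python
import torch

H_in, H_out = 10, 20

lstm = torch.nn.LSTM(input_size=H_in, hidden_size=H_out,
                     num_layers=2, bidirectional=True)

# Layer 0 consumes the H_in input features.
print(lstm.weight_ih_l0.shape)  # torch.Size([80, 10]) -> (4*H_out, H_in)
# Layer 1 consumes the concatenated forward/backward outputs: 2*H_out features.
print(lstm.weight_ih_l1.shape)  # torch.Size([80, 40]) -> (4*H_out, 2*H_out)
```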
Warning
Please review the torch.nn.LSTMCell complexity documentation before continuing, as the subsequent sections will reference formulas from that layer without re-deriving them.
Unidirectional
The complexity of the first layer is the same as that of torch.nn.LSTMCell, which when bias=True can be simplified to

\(8\times N\times H_{out}\times\left(H_{in}+H_{out}+3.875\right)\)

and when bias=False

\(8\times N\times H_{out}\times\left(H_{in}+H_{out}+2.875\right)\)

For subsequent layers it is necessary to replace \(H_{in}\) by \(H_{out}\). Then, when bias=True

\(8\times N\times H_{out}\times\left(2\times H_{out}+3.875\right)\)

and when bias=False

\(8\times N\times H_{out}\times\left(2\times H_{out}+2.875\right)\)
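As an illustrative check with assumed values \(N=1\), \(H_{in}=10\), \(H_{out}=20\) and bias=True, the first layer performs \(8\times 1\times 20\times\left(10+20+3.875\right)=5420\) operations per sequence step.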
Total complexity
Now it is necessary to include the sequence length \(L\) in the input tensor \(x\) to obtain the total complexity, since the previous calculation will be repeated \(L\) times. The total complexity for bidirectional=False when bias=True becomes

\(\text{LSTM}_{ops} = 8\times L\times N \times H_{out}\times \left(H_{in}+\left(2\times\text{num\_layers}-1\right)\times H_{out}+3.875\times\text{num\_layers}\right)\)

and when bias=False

\(\text{LSTM}_{ops} = 8\times L\times N \times H_{out}\times \left(H_{in}+\left(2\times\text{num\_layers}-1\right)\times H_{out}+2.875\times\text{num\_layers}\right)\)
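Continuing the assumed example above with \(L=5\), \(N=1\), \(H_{in}=10\), \(H_{out}=20\), num_layers = 1 and bias=True, the total is \(8\times 5\times 1\times 20\times\left(10+\left(2\times 1-1\right)\times 20+3.875\times 1\right)=27100\) operations.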
Bidirectional
For the case of bidirectional=True, the same considerations explained at the beginning of this section should be taken into account. Additionally, each cell will approximately double its calculations, because one subset of the output is calculated using the forward direction of the input sequence \(x\) and the remaining one uses the reversed input sequence \(x\). Please note that each direction of the input sequence has its own set of weights. Finally, both outputs are concatenated to produce a tensor of size \(\left(L, N, D\times H_{out}\right)\), with \(D=2\) in this case. When num_layers > 1, this is also the input size for layers after the first one.
The complexity of the first layer when bidirectional=True and bias=True is

\(16\times N\times H_{out}\times\left(H_{in}+H_{out}+3.875\right)\)

and when bias=False

\(16\times N\times H_{out}\times\left(H_{in}+H_{out}+2.875\right)\)

For subsequent layers it is necessary to replace \(H_{in}\) by \(2\times H_{out}\). Then, when bias=True

\(16\times N\times H_{out}\times\left(3\times H_{out}+3.875\right)\)

and when bias=False

\(16\times N\times H_{out}\times\left(3\times H_{out}+2.875\right)\)
Total complexity
Now it is necessary to include the sequence length \(L\) in the input tensor \(x\) to obtain the total complexity, since the previous calculation will be repeated \(L\) times. The total complexity for bidirectional=True when bias=True becomes

\(\text{LSTM}_{ops} = 16\times L\times N \times H_{out}\times \left(H_{in}+\left(3\times\text{num\_layers}-2\right)\times H_{out}+3.875\times\text{num\_layers}\right)\)

and when bias=False

\(\text{LSTM}_{ops} = 16\times L\times N \times H_{out}\times \left(H_{in}+\left(3\times\text{num\_layers}-2\right)\times H_{out}+2.875\times\text{num\_layers}\right)\)
Summary
The number of operations performed by a torch.nn.LSTM module can be estimated as follows.

When bidirectional=False and bias=True:

\(\text{LSTM}_{ops} = 8\times L\times N \times H_{out}\times \left(H_{in}+\left(2\times\text{num\_layers}-1\right)\times H_{out}+3.875\times\text{num\_layers}\right)\)

When bidirectional=False and bias=False:

\(\text{LSTM}_{ops} = 8\times L\times N \times H_{out}\times \left(H_{in}+\left(2\times\text{num\_layers}-1\right)\times H_{out}+2.875\times\text{num\_layers}\right)\)

When bidirectional=True and bias=True:

\(\text{LSTM}_{ops} = 16\times L\times N \times H_{out}\times \left(H_{in}+\left(3\times\text{num\_layers}-2\right)\times H_{out}+3.875\times\text{num\_layers}\right)\)

When bidirectional=True and bias=False:

\(\text{LSTM}_{ops} = 16\times L\times N \times H_{out}\times \left(H_{in}+\left(3\times\text{num\_layers}-2\right)\times H_{out}+2.875\times\text{num\_layers}\right)\)
Where
- \(L\) is the sequence length.
- \(N\) is the batch size.
- \(H_{in}\) and \(H_{out}\) are the number of input and output features, respectively.
- \(\text{num\_layers}\) is the number of layers.
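As a convenience, the estimate above can be wrapped in a small helper. The function below is a minimal sketch (the name lstm_ops and its signature are ours, not part of PyTorch) that implements the four summary formulas:

```python
def lstm_ops(L, N, H_in, H_out, num_layers, bias=True, bidirectional=False):
    """Estimate the number of operations of a torch.nn.LSTM forward pass
    using the summary formulas above."""
    bias_term = 3.875 if bias else 2.875
    if bidirectional:
        return 16 * L * N * H_out * (
            H_in + (3 * num_layers - 2) * H_out + bias_term * num_layers
        )
    return 8 * L * N * H_out * (
        H_in + (2 * num_layers - 1) * H_out + bias_term * num_layers
    )


# Example: the unidirectional, single-layer case used earlier in this section.
print(lstm_ops(L=5, N=1, H_in=10, H_out=20, num_layers=1,
               bias=True, bidirectional=False))  # 27100.0
```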