LSTMCell (torch.nn.LSTMCell)
A torch.nn.LSTMCell corresponds to a single cell of a Long Short-Term Memory layer (torch.nn.LSTM). A torch.nn.LSTMCell takes in an input \(x\), a hidden state \(h\) and a cell state \(c\). Internally, it has an input gate \(i\), a forget gate \(f\), a cell gate \(g\) and an output gate \(o\) that help to propagate information between sequence steps. These are combined to generate the torch.nn.LSTMCell outputs. The relationship between these tensors is defined as

\(i = \sigma\left(W_{ii}x+b_{ii}+W_{hi}h+b_{hi}\right)\)

\(f = \sigma\left(W_{if}x+b_{if}+W_{hf}h+b_{hf}\right)\)

\(g = \text{tanh}\left(W_{ig}x+b_{ig}+W_{hg}h+b_{hg}\right)\)

\(o = \sigma\left(W_{io}x+b_{io}+W_{ho}h+b_{ho}\right)\)

\(c\prime = f\odot c+i\odot g\)

\(h\prime = o\odot\text{tanh}\left(c\prime\right)\)
Where
- \(x\) is the input tensor of size \(\left(N, H_{in}\right)\) or \(\left(H_{in}\right)\).
- \(h\) is the hidden state tensor of size \(\left(N, H_{out}\right)\) or \(\left(H_{out}\right)\).
- \(c\) is the cell state tensor of size \(\left(N, H_{out}\right)\) or \(\left(H_{out}\right)\).
- \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) are weight tensors of size \(\left(H_{out}, H_{in}\right)\).
- \(W_{hi}\), \(W_{hf}\), \(W_{hg}\) and \(W_{ho}\) are weight tensors of size \(\left(H_{out}, H_{out}\right)\).
- \(\sigma\) is the sigmoid function and can be defined as \(\sigma\left(x\right)=\frac{1}{1+e^{-x}}\).
- \(\text{tanh}\) is the hyperbolic tangent function and can be defined as \(\text{tanh}\left(x\right)=\frac{e^x-e^{-x}}{e^x+e^{-x}}\).
- \(\odot\) is the Hadamard product or element-wise product.
- \(b_{ii}\), \(b_{hi}\), \(b_{if}\), \(b_{hf}\), \(b_{ig}\), \(b_{hg}\), \(b_{io}\) and \(b_{ho}\) are bias tensors of size \(\left(H_{out}\right)\).
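For reference, the following minimal sketch (with arbitrarily chosen sizes) shows these shapes on an actual torch.nn.LSTMCell:

```python
import torch
import torch.nn as nn

N, H_in, H_out = 8, 10, 20  # arbitrary batch size, input features and hidden features

cell = nn.LSTMCell(input_size=H_in, hidden_size=H_out)

x = torch.randn(N, H_in)    # input x of size (N, H_in)
h = torch.randn(N, H_out)   # hidden state h of size (N, H_out)
c = torch.randn(N, H_out)   # cell state c of size (N, H_out)

h_next, c_next = cell(x, (h, c))   # one recurrence step: returns h' and c'
print(h_next.shape, c_next.shape)  # torch.Size([8, 20]) torch.Size([8, 20])
```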
Complexity
In order to compute the complexity of a single torch.nn.LSTMCell, it is necessary to estimate the number of operations of all six aforementioned equations. For the sake of simplicity, the sigmoid and hyperbolic tangent will be counted using the definitions given above, with each exponential counted as a single operation. Under this convention, the sigmoid amounts to three operations per element (one exponential, one addition and one division) and the hyperbolic tangent, as written, to seven (four exponentials, one subtraction, one addition and one division).
Note
During the following operations, some tensors have to be transposed in order to have compatible dimensions for matrix multiplication, even though this is not explicitly mentioned in PyTorch's torch.nn.LSTMCell documentation. Additionally, some weight tensors are stacked. For instance, \(W_{ii}\), \(W_{if}\), \(W_{ig}\) and \(W_{io}\) are implemented as a single tensor of size \(\left(4\times H_{out},H_{in} \right)\), and \(W_{hi}\), \(W_{hf}\), \(W_{hg}\) and \(W_{ho}\) are implemented as a single tensor of size \(\left(4\times H_{out},H_{out} \right)\), likely for efficiency reasons.
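This stacking can be checked directly on the module's parameters. A short sketch (sizes are arbitrary):

```python
import torch.nn as nn

H_in, H_out = 10, 20
cell = nn.LSTMCell(input_size=H_in, hidden_size=H_out)

# The four input-side weight tensors are stored as one stacked parameter,
# and likewise for the hidden-side weights and the two bias tensors.
print(cell.weight_ih.shape)  # torch.Size([80, 10]) -> (4 x H_out, H_in)
print(cell.weight_hh.shape)  # torch.Size([80, 20]) -> (4 x H_out, H_out)
print(cell.bias_ih.shape)    # torch.Size([80])     -> (4 x H_out)
print(cell.bias_hh.shape)    # torch.Size([80])     -> (4 x H_out)
```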
Input gate
The tensor sizes involved in the operations performed to calculate the input gate \(i\) are \(\left(H_{out}, H_{in}\right)\) for \(W_{ii}\), \(\left(N, H_{in}\right)\) for \(x\), \(\left(H_{out}, H_{out}\right)\) for \(W_{hi}\), \(\left(N, H_{out}\right)\) for \(h\), and \(\left(H_{out}\right)\) for \(b_{ii}\) and \(b_{hi}\).
In this case, \(x\) (with shape \(\left(N, H_{in}\right)\)) and \(h\) (with shape \(\left(N, H_{out}\right)\)) have to be transposed. Additionally, \(b_{ii}\) and \(b_{hi}\) will be implicitly broadcast so that they can be summed with the matrix multiplication results. Then, the unwrapped and transposed shapes involved in the operations are

\(\left(H_{out}, H_{in}\right)\times\left(H_{in}, N\right)+\left(H_{out}, N\right)+\left(H_{out}, H_{out}\right)\times\left(H_{out}, N\right)+\left(H_{out}, N\right)\)
This will result in a tensor of shape \(\left(H_{out}, N\right)\). To estimate the complexity of this operation, it is possible to reuse the results from torch.nn.Linear for both matrix multiplications and add the sigmoid operations \(\sigma\). \(i_{ops}\) (the operations of the input gate \(i\)) can then be broken down into four parts:
- The operations needed to compute \(W_{ii}x+b_{ii}\).
- The operations needed to compute \(W_{hi}h+b_{hi}\).
- The operations needed to sum both results.
- The operations needed to compute the sigmoid function \(\sigma\) of this result.
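The following sketch reproduces the input gate computation with the transposed and broadcast shapes described above (the slicing assumes PyTorch's i, f, g, o ordering of the stacked parameters; sizes are arbitrary):

```python
import torch
import torch.nn as nn

N, H_in, H_out = 8, 10, 20
cell = nn.LSTMCell(H_in, H_out)
x = torch.randn(N, H_in)
h = torch.randn(N, H_out)

# The first H_out rows of the stacked parameters correspond to the input gate i.
W_ii, b_ii = cell.weight_ih[:H_out], cell.bias_ih[:H_out]  # (H_out, H_in), (H_out)
W_hi, b_hi = cell.weight_hh[:H_out], cell.bias_hh[:H_out]  # (H_out, H_out), (H_out)

# i = sigmoid(W_ii x + b_ii + W_hi h + b_hi), with x and h transposed so that the
# products are (H_out, H_in) @ (H_in, N) and (H_out, H_out) @ (H_out, N),
# and with the biases broadcast from (H_out) to (H_out, N).
i = torch.sigmoid(W_ii @ x.T + b_ii[:, None] + W_hi @ h.T + b_hi[:, None])
print(i.shape)  # torch.Size([20, 8]) -> (H_out, N)
```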
For simplicity's sake, the torch.nn.Linear results are reused for the two matrix multiplications (with their biases), and the sigmoid is counted as three operations per element, as discussed above.
Then, in terms of operations (\(ops\)), when bias=True

\(i_{ops} = 2\times N\times H_{out}\times H_{in}+2\times N\times H_{out}\times H_{out}+N\times H_{out}+3\times N\times H_{out} = 2\times N\times H_{out}\times\left(H_{in}+H_{out}+2\right)\)

and when bias=False

\(i_{ops} = N\times H_{out}\times\left(2\times H_{in}-1\right)+N\times H_{out}\times\left(2\times H_{out}-1\right)+N\times H_{out}+3\times N\times H_{out} = 2\times N\times H_{out}\times\left(H_{in}+H_{out}+1\right)\)
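As a sanity check, a small sketch of this count (the function name is illustrative; the sigmoid is counted as three operations per element, as discussed above):

```python
def input_gate_ops(N, H_in, H_out, bias=True):
    """Estimated operation count of the input gate i, following the four parts above."""
    if bias:
        matmuls = 2 * N * H_out * H_in + 2 * N * H_out * H_out  # W_ii x + b_ii and W_hi h + b_hi
    else:
        matmuls = N * H_out * (2 * H_in - 1) + N * H_out * (2 * H_out - 1)
    summation = N * H_out    # summing both intermediate results
    sigmoid = 3 * N * H_out  # one exponential, one addition and one division per element
    return matmuls + summation + sigmoid

print(input_gate_ops(8, 10, 20))         # 10240 = 2 * 8 * 20 * (10 + 20 + 2)
print(input_gate_ops(8, 10, 20, False))  # 9920  = 2 * 8 * 20 * (10 + 20 + 1)
```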
Forget and output gates
Since the dimensions of these gates are the same as those of the input gate \(i\), it is trivial to observe that

\(f_{ops} = o_{ops} = i_{ops}\)
Cell gate
The argument of the \(\text{tanh}\) function has the same shape as in the previously computed gates, so the only difference between this gate and the others is the complexity of the \(\text{tanh}\) itself; then

\(g_{ops} = i_{ops}-3\times N\times H_{out}+7\times N\times H_{out} = i_{ops}+4\times N\times H_{out}\)

Replacing with the previously calculated results, when bias=True

\(g_{ops} = 2\times N\times H_{out}\times\left(H_{in}+H_{out}+4\right)\)

and when bias=False

\(g_{ops} = 2\times N\times H_{out}\times\left(H_{in}+H_{out}+3\right)\)
\(c\prime\)
The complexity of \(c\prime\) corresponds to three element-wise operations between tensors of shape \(\left(H_{out}, N\right)\). Therefore, its complexity is

\(c\prime_{ops} = 3\times N\times H_{out}\)
\(h\prime\)
The complexity of \(h\prime\) corresponds to one element-wise product and one \(\text{tanh}\) evaluation over a tensor of shape \(\left(H_{out}, N\right)\)

\(h\prime_{ops} = N\times H_{out}+7\times N\times H_{out} = 8\times N\times H_{out}\)
Total complexity
Finally, the total complexity is the sum of all individual contributions

\(\text{LSTMCell}_{ops} = i_{ops}+f_{ops}+g_{ops}+o_{ops}+c\prime_{ops}+h\prime_{ops}\)
In the case of bias=True, the total number of operations is

\(\text{LSTMCell}_{ops} = 3\times 2\times N\times H_{out}\times\left(H_{in}+H_{out}+2\right)+2\times N\times H_{out}\times\left(H_{in}+H_{out}+4\right)+3\times N\times H_{out}+8\times N\times H_{out} = 8\times N\times H_{out}\times\left(H_{in}+H_{out}+3.875\right)\)

and for bias=False

\(\text{LSTMCell}_{ops} = 3\times 2\times N\times H_{out}\times\left(H_{in}+H_{out}+1\right)+2\times N\times H_{out}\times\left(H_{in}+H_{out}+3\right)+3\times N\times H_{out}+8\times N\times H_{out} = 8\times N\times H_{out}\times\left(H_{in}+H_{out}+2.875\right)\)
Summary
The number of operations \(\phi\) performed by a torch.nn.LSTMCell module can be estimated, when bias=True, as

\(\text{LSTMCell}_{ops} = 8\times N\times H_{out}\times\left(H_{in}+H_{out}+3.875\right)\)

and, when bias=False, as

\(\text{LSTMCell}_{ops} = 8\times N\times H_{out}\times\left(H_{in}+H_{out}+2.875\right)\)
Where
- \(N\) is the batch size.
- \(H_\text{in}\) is the number of input features.
- \(H_\text{out}\) is the number of output features.
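A minimal sketch of these closed-form estimates (the function name is illustrative):

```python
def lstm_cell_ops(N, H_in, H_out, bias=True):
    """Estimated operation count of a single torch.nn.LSTMCell step."""
    extra = 3.875 if bias else 2.875
    return 8 * N * H_out * (H_in + H_out + extra)

# Example: a batch of 8 samples, 10 input features and 20 hidden features.
print(lstm_cell_ops(8, 10, 20))         # 43360.0
print(lstm_cell_ops(8, 10, 20, False))  # 42080.0
```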