GRU (torch.nn.GRU)
A torch.nn.GRU corresponds to a Gated Recurrent Unit. In essence, it is an arrangement of torch.nn.GRUCell modules that can process an input tensor containing a sequence of steps and that can stack cells to a configurable depth. The equations that govern a torch.nn.GRU are the same as those of torch.nn.GRUCell, except for the sequence length and the number of layers. Unlike a single torch.nn.GRUCell, a torch.nn.GRU has a hidden state (\(h_t\)), a reset gate (\(r_t\)), an update gate (\(z_t\)) and a candidate state tensor (\(n_t\)) per sequence step. Please note that the hidden state of the current step, \(h_t\), depends on the hidden state of the previous step, \(h_{t-1}\).
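For reference, the per-step equations, as documented by PyTorch for torch.nn.GRUCell, are:
\(r_t = \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right)\)
\(z_t = \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right)\)
\(n_t = \text{tanh}\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right)\)
\(h_t = \left(1 - z_t\right) \odot n_t + z_t \odot h_{t-1}\)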
Where
- \(x\) is the input tensor of size \(\left(L, H_{in}\right)\) or \(\left(L, N, H_{in}\right)\) when batch_first=False, or \(\left(N, L, H_{in}\right)\) when batch_first=True.
- \(h_t\) is the hidden state tensor at sequence step \(t\) of size \(\left(N, H_{out}\right)\) or \(\left(H_{out}\right)\).
- \(H_{in}\) and \(H_{out}\) are the number of input and output features, respectively.
- \(L\) is the sequence length.
- \(N\) is the batch size.
- \(W_{ir}\), \(W_{iz}\) and \(W_{in}\) are weight tensors of size \(\left(H_{out}, H_{in}\right)\) in the first layer and \(\left(H_{out}, D\times H_{out}\right)\) in subsequent layers.
- \(W_{hr}\), \(W_{hz}\) and \(W_{hn}\) are weight tensors of size \(\left(H_{out}, H_{out}\right)\).
- \(D\) is \(2\) if bidirectional=True and \(1\) if bidirectional=False.
- \(\sigma\) is the sigmoid function and can be defined as \(\sigma\left(x\right)=\frac{1}{1+e^{-x}}\).
- \(\text{tanh}\) is the hyperbolic tangent function and can be defined as \(\text{tanh}\left(x\right)=\frac{e^x-e^{-x}}{e^x+e^{-x}}\).
- \(\odot\) is the Hadamard product or element-wise product.
- \(b_{ir}\), \(b_{iz}\), \(b_{in}\), \(b_{hr}\), \(b_{hz}\) and \(b_{hn}\) are bias tensors of size \(\left(H_{out}\right)\).
Note
Please note that some weight tensor sizes may differ from PyTorch's torch.nn.GRU documentation because some tensors are stacked. For instance, the \(W_{ir}\), \(W_{iz}\) and \(W_{in}\) tensors of each layer are implemented as a single tensor of size \(\left(3\times H_{out}, H_{in}\right)\) for the first layer, and \(\left(3\times H_{out}, D\times H_{out}\right)\) for subsequent layers. Similarly, \(W_{hr}\), \(W_{hz}\) and \(W_{hn}\) are implemented as a single tensor of size \(\left(3\times H_{out}, H_{out}\right)\). The number of layers is controlled by the num_layers parameter, and the number of directions \(D\) is controlled by the bidirectional parameter.
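As a minimal sketch of the stacking described above (the flat parameter names weight_ih_l0, weight_hh_l0 and bias_ih_l0 are the ones exposed by torch.nn.GRU), the shapes can be inspected directly:

```python
import torch

H_in, H_out = 10, 20
gru = torch.nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=2)

# W_ir, W_iz and W_in stacked into one tensor: (3 * H_out, H_in)
print(gru.weight_ih_l0.shape)  # torch.Size([60, 10])
# W_hr, W_hz and W_hn stacked into one tensor: (3 * H_out, H_out)
print(gru.weight_hh_l0.shape)  # torch.Size([60, 20])
# b_ir, b_iz and b_in stacked into one tensor: (3 * H_out,)
print(gru.bias_ih_l0.shape)    # torch.Size([60])
```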
Note
The cost of the dropout parameter is not considered in the following calculations, since dropout is usually only active during training and is disabled during inference.
Complexity
It is possible to reuse the calculation for torch.nn.GRUCell to estimate the complexity of torch.nn.GRU. However, there are a couple of additional considerations. First, when num_layers > 1, each layer after the first takes the output(s) of the previous layer as input. This means that \(W_{ir}\), \(W_{iz}\) and \(W_{in}\) will have size \(\left(H_{out}, H_{out}\right)\) if bidirectional=False and size \(\left(H_{out}, 2\times H_{out}\right)\) if bidirectional=True. Second, unlike torch.nn.GRUCell, torch.nn.GRU can process an input containing longer sequences, therefore the calculations estimated before are repeated \(L\) times, where \(L\) is the sequence length.
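To make the first consideration concrete, here is a short sketch inspecting the second-layer input weights in both cases (the _reverse suffix is how PyTorch names the reverse-direction parameters):

```python
import torch

H_in, H_out = 10, 20

# bidirectional=False: second-layer input weights have size (3 * H_out, H_out)
uni = torch.nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=2, bidirectional=False)
print(uni.weight_ih_l1.shape)          # torch.Size([60, 20])

# bidirectional=True: second-layer input weights have size (3 * H_out, 2 * H_out)
bi = torch.nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=2, bidirectional=True)
print(bi.weight_ih_l1.shape)           # torch.Size([60, 40])
print(bi.weight_ih_l1_reverse.shape)   # torch.Size([60, 40])
```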
Warning
Please review the torch.nn.GRUCell complexity documentation before continuing, as the subsequent sections reference formulas from that layer without re-deriving them.
Unidirectional
The complexity of the first layer is the same as that of torch.nn.GRUCell, which when bias=True can be simplified to
\(6\times N \times H_{out}\times \left(H_{in}+H_{out}+3.5\right)\)
and when bias=False to
\(6\times N \times H_{out}\times \left(H_{in}+H_{out}+2.5\right)\)
For subsequent layers it is necessary to replace \(H_{in}\) with \(H_{out}\). Then, when bias=True, the complexity is
\(6\times N \times H_{out}\times \left(2\times H_{out}+3.5\right)\)
and when bias=False
\(6\times N \times H_{out}\times \left(2\times H_{out}+2.5\right)\)
Total complexity
Now it is necessary to include the sequence length \(L\) of the input tensor \(x\) to obtain the total complexity, since the previous calculation is repeated \(L\) times. The total complexity for bidirectional=False is therefore \(L\) times the sum of the first-layer complexity and \(\left(\text{num\_layers}-1\right)\) times the subsequent-layer complexity. When bias=True this expression becomes
\(\text{GRU}_{ops} = 6\times L\times N \times H_{out}\times \left(H_{in}+\left(2\times\text{num\_layers}-1\right)\times H_{out}+3.5\times\text{num\_layers}\right)\)
and when bias=False
\(\text{GRU}_{ops} = 6\times L\times N \times H_{out}\times \left(H_{in}+\left(2\times\text{num\_layers}-1\right)\times H_{out}+2.5\times\text{num\_layers}\right)\)
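As a sanity check on the algebra above, the following sketch sums the per-layer costs and compares the result with the closed-form total (the helper names are illustrative, not part of any library):

```python
def gru_layer_ops(n, h_in, h_out, bias):
    # Simplified per-step cost of torch.nn.GRUCell used above:
    # 6 * N * H_out * (H_in + H_out + 3.5) with bias, 2.5 without.
    c = 3.5 if bias else 2.5
    return 6 * n * h_out * (h_in + h_out + c)

def gru_ops_unidirectional(l, n, h_in, h_out, num_layers, bias):
    # The first layer sees H_in inputs, subsequent layers see H_out inputs.
    ops = gru_layer_ops(n, h_in, h_out, bias)
    ops += (num_layers - 1) * gru_layer_ops(n, h_out, h_out, bias)
    return l * ops

def gru_ops_closed_form(l, n, h_in, h_out, num_layers, bias):
    # Closed-form total for bidirectional=False.
    c = 3.5 if bias else 2.5
    return 6 * l * n * h_out * (h_in + (2 * num_layers - 1) * h_out + c * num_layers)

L, N, H_in, H_out, layers = 7, 4, 10, 20, 3
for bias in (True, False):
    assert (gru_ops_unidirectional(L, N, H_in, H_out, layers, bias)
            == gru_ops_closed_form(L, N, H_in, H_out, layers, bias))
```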
Bidirectional
For the case of bidirectional=True
the same considerations explained at the beginning of this section should be taken into account. Additionally, each cell will approximately duplicate its calculations because one subset of the output is calculated using the forward direction of the input sequence \(x\), and the remaining one uses the reverse input sequence \(x\). Please note that each direction of the input sequence will have its own set of weights, even though this is not documented at the moment of writing this documentation. Finally, both outputs will be concatenated to produce a tensor of size \(\left(L, N, D\times H_{out}\right)\) with \(D=2\) in this case. When num_layers > 1
, this is also the size of the input size for layers after the first one.
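A brief sketch illustrating the concatenated output size (standard torch.nn.GRU behaviour, included here only as a reminder):

```python
import torch

L, N, H_in, H_out = 5, 3, 10, 20
gru = torch.nn.GRU(input_size=H_in, hidden_size=H_out, num_layers=2, bidirectional=True)

x = torch.randn(L, N, H_in)   # batch_first=False, so the layout is (L, N, H_in)
out, h_n = gru(x)

print(out.shape)   # torch.Size([5, 3, 40])  ->  (L, N, D * H_out) with D = 2
print(h_n.shape)   # torch.Size([4, 3, 20])  ->  (D * num_layers, N, H_out)
```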
The complexity of the first layer when bidirectional=True and bias=True is
\(12\times N \times H_{out}\times \left(H_{in}+H_{out}+3.5\right)\)
and when bias=False
\(12\times N \times H_{out}\times \left(H_{in}+H_{out}+2.5\right)\)
For subsequent layers it is necessary to replace \(H_{in}\) with \(2\times H_{out}\). Then, when bias=True, the complexity is
\(12\times N \times H_{out}\times \left(3\times H_{out}+3.5\right)\)
and when bias=False
\(12\times N \times H_{out}\times \left(3\times H_{out}+2.5\right)\)
Total complexity
Now it is necessary to include the sequence length \(L\) of the input tensor \(x\) to obtain the total complexity, since the previous calculation is repeated \(L\) times. The total complexity for bidirectional=True is again \(L\) times the sum of the first-layer complexity and \(\left(\text{num\_layers}-1\right)\) times the subsequent-layer complexity. When bias=True this expression becomes
\(\text{GRU}_{ops} = 12\times L\times N \times H_{out}\times \left(H_{in}+\left(3\times\text{num\_layers}-2\right)\times H_{out}+3.5\times\text{num\_layers}\right)\)
and when bias=False
\(\text{GRU}_{ops} = 12\times L\times N \times H_{out}\times \left(H_{in}+\left(3\times\text{num\_layers}-2\right)\times H_{out}+2.5\times\text{num\_layers}\right)\)
Summary
The number of operations performed by a torch.nn.GRU module can be estimated as follows.
When bidirectional=False and bias=True:
\(\text{GRU}_{ops} = 6\times L\times N \times H_{out}\times \left(H_{in}+\left(2\times\text{num\_layers}-1\right)\times H_{out}+3.5\times\text{num\_layers}\right)\)
When bidirectional=False and bias=False:
\(\text{GRU}_{ops} = 6\times L\times N \times H_{out}\times \left(H_{in}+\left(2\times\text{num\_layers}-1\right)\times H_{out}+2.5\times\text{num\_layers}\right)\)
When bidirectional=True and bias=True:
\(\text{GRU}_{ops} = 12\times L\times N \times H_{out}\times \left(H_{in}+\left(3\times\text{num\_layers}-2\right)\times H_{out}+3.5\times\text{num\_layers}\right)\)
When bidirectional=True and bias=False:
\(\text{GRU}_{ops} = 12\times L\times N \times H_{out}\times \left(H_{in}+\left(3\times\text{num\_layers}-2\right)\times H_{out}+2.5\times\text{num\_layers}\right)\)
Where
- \(N\) is the batch size.
- \(H_\text{in}\) is the number of input features.
- \(H_\text{out}\) is the number of output features.
- \(L\) is the sequence length.
- \(\text{num\_layers}\) is the number of layers. When num_layers > 1, the output of the first layer is fed into the second one.
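The summary expressions can be wrapped in a small helper for quick estimates (a sketch only; the function name and signature are assumptions, not a library API):

```python
def gru_ops(seq_len, batch_size, h_in, h_out, num_layers=1,
            bias=True, bidirectional=False):
    """Estimate the number of operations of a torch.nn.GRU forward pass."""
    d = 2 if bidirectional else 1    # number of directions D
    c = 3.5 if bias else 2.5         # per-layer constant with / without bias
    # (d + 1) * num_layers - d equals 2*num_layers - 1 (unidirectional)
    # or 3*num_layers - 2 (bidirectional), matching the summary formulas.
    hidden_term = ((d + 1) * num_layers - d) * h_out
    return 6 * d * seq_len * batch_size * h_out * (h_in + hidden_term + c * num_layers)

# Example: 2-layer bidirectional GRU over a sequence of length 100
print(gru_ops(seq_len=100, batch_size=8, h_in=64, h_out=128,
              num_layers=2, bias=True, bidirectional=True))
```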