GRUCell (torch.nn.GRUCell)
A torch.nn.GRUCell corresponds to a single cell of a Gated Recurrent Unit (torch.nn.GRU). A torch.nn.GRUCell takes an input \(x\) and a hidden state \(h\). Internally, it has a reset gate \(r\) and an update gate \(z\) that help to propagate information between time steps. These are combined to generate \(n\), which is then used to create a new hidden state \(h'\). The relationship between these tensors is defined as

\[r=\sigma\left(W_{ir}x+b_{ir}+W_{hr}h+b_{hr}\right)\]

\[z=\sigma\left(W_{iz}x+b_{iz}+W_{hz}h+b_{hz}\right)\]

\[n=\text{tanh}\left(W_{in}x+b_{in}+r\odot\left(W_{hn}h+b_{hn}\right)\right)\]

\[h'=\left(1-z\right)\odot n+z\odot h\]
Where
- \(x\) is the input tensor of size \(\left(N, H_{in}\right)\) or \(\left(H_{in}\right)\).
- \(h\) is the hidden state tensor of size \(\left(N, H_{out}\right)\) or \(\left(H_{out}\right)\).
- \(H_{in}\) and \(H_{out}\) are the number of input and output features, respectively.
- \(N\) is the batch size.
- \(W_{ir}\), \(W_{iz}\) and \(W_{in}\) are weight tensors of size \(\left(H_{out}, H_{in}\right)\).
- \(W_{hr}\), \(W_{hz}\) and \(W_{hn}\) are weight tensors of size \(\left(H_{out}, H_{out}\right)\).
- \(\sigma\) is the sigmoid function and can be defined as \(\sigma\left(x\right)=\frac{1}{1+e^{-x}}\).
- \(\text{tanh}\) is the hyperbolic tangent function and can be defined as \(\text{tanh}\left(x\right)=\frac{e^x-e^{-x}}{e^x+e^{-x}}\).
- \(\odot\) is the Hadamard product or element-wise product.
- \(b_{ir}\), \(b_{iz}\), \(b_{in}\), \(b_{hr}\), \(b_{hz}\) and \(b_{hn}\) are bias tensors of size \(\left(H_{out}\right)\).
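To make these shapes concrete, here is a minimal usage sketch; the sizes \(N\), \(H_{in}\) and \(H_{out}\) are arbitrary examples, not values prescribed by this section.

```python
import torch

N, H_in, H_out = 4, 10, 20  # arbitrary example sizes

cell = torch.nn.GRUCell(H_in, H_out)
x = torch.randn(N, H_in)   # input of size (N, H_in)
h = torch.randn(N, H_out)  # hidden state of size (N, H_out)

h_prime = cell(x, h)       # new hidden state of size (N, H_out)
print(h_prime.shape)       # torch.Size([4, 20])
```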
Complexity
In order to compute the complexity of a single torch.nn.GRUCell, it is necessary to estimate the number of operations in each of the four aforementioned equations. For the sake of simplicity, the sigmoid and hyperbolic tangent will be counted using the formulas listed above, with each exponential counted as a single operation.
Note
During the following operations, some tensors have to be transposed in order to have compatible dimensions for matrix multiplication, even though this is not explicitly mentioned in PyTorch's torch.nn.GRUCell documentation. Additionally, some weight tensors are stacked. For instance, \(W_{ir}\), \(W_{iz}\) and \(W_{in}\) are implemented as a single tensor of size \(\left(3\times H_{out}, H_{in}\right)\), and \(W_{hr}\), \(W_{hz}\) and \(W_{hn}\) are implemented as a single tensor of size \(\left(3\times H_{out}, H_{out}\right)\), likely for efficiency reasons.
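The stacking can be verified directly. The following sketch unstacks the packed parameters and reproduces the four equations above by hand, checking the result against the module's own output; the chunk order \(\left(r, z, n\right)\) matches the parameter layout documented by PyTorch, and the sizes are arbitrary examples.

```python
import torch

torch.manual_seed(0)
N, H_in, H_out = 4, 10, 20  # arbitrary example sizes

cell = torch.nn.GRUCell(H_in, H_out, bias=True)
x = torch.randn(N, H_in)
h = torch.randn(N, H_out)

# Unstack the packed parameters: weight_ih has size (3*H_out, H_in) and holds
# W_ir, W_iz and W_in; weight_hh has size (3*H_out, H_out) and holds W_hr,
# W_hz and W_hn. The biases are stacked the same way.
W_ir, W_iz, W_in = cell.weight_ih.chunk(3, dim=0)
W_hr, W_hz, W_hn = cell.weight_hh.chunk(3, dim=0)
b_ir, b_iz, b_in = cell.bias_ih.chunk(3)
b_hr, b_hz, b_hn = cell.bias_hh.chunk(3)

# The four equations, with x @ W.T standing in for the transposes discussed below.
r = torch.sigmoid(x @ W_ir.T + b_ir + h @ W_hr.T + b_hr)
z = torch.sigmoid(x @ W_iz.T + b_iz + h @ W_hz.T + b_hz)
n = torch.tanh(x @ W_in.T + b_in + r * (h @ W_hn.T + b_hn))
h_prime = (1 - z) * n + z * h

assert torch.allclose(h_prime, cell(x, h), atol=1e-6)
```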
Reset gate
The tensor sizes involved in the operations performed to calculate the reset gate \(r\) are

\[r=\sigma\left(\underbrace{W_{ir}}_{\left(H_{out}, H_{in}\right)}\underbrace{x}_{\left(N, H_{in}\right)}+\underbrace{b_{ir}}_{\left(H_{out}\right)}+\underbrace{W_{hr}}_{\left(H_{out}, H_{out}\right)}\underbrace{h}_{\left(N, H_{out}\right)}+\underbrace{b_{hr}}_{\left(H_{out}\right)}\right)\]

In this case, \(x\) (with shape \(\left(N, H_{in}\right)\)) and \(h\) (with shape \(\left(N, H_{out}\right)\)) have to be transposed. Additionally, \(b_{ir}\) and \(b_{hr}\) will be implicitly broadcast to be summable with the matrix multiplication results. Then, the unwrapped and transposed shapes involved in the operations are

\[r=\sigma\left(\underbrace{W_{ir}}_{\left(H_{out}, H_{in}\right)}\underbrace{x^\top}_{\left(H_{in}, N\right)}+\underbrace{b_{ir}}_{\left(H_{out}, N\right)}+\underbrace{W_{hr}}_{\left(H_{out}, H_{out}\right)}\underbrace{h^\top}_{\left(H_{out}, N\right)}+\underbrace{b_{hr}}_{\left(H_{out}, N\right)}\right)\]
This will result in a tensor of shape \(\left(H_{out}, N\right)\). To estimate the complexity of this operation, it is possible to reuse the results from torch.nn.Linear for both matrix multiplications and add the sigmoid operations \(\sigma\). \(r_{ops}\) (the number of operations of the reset gate \(r\)) can then be broken down into four parts:
- The operations needed to compute \(W_{ir}x+b_{ir}\).
- The operations needed to compute \(W_{hr}h+b_{hr}\).
- The operations needed to sum both results.
- The operations needed to compute the sigmoid function \(\sigma\) of this result.
For simplicity's sake, the following definitions will be used:

\[\left(W_{ir}x+b_{ir}\right)_{ops}=\begin{cases}2\times N\times H_{out}\times H_{in} & \text{bias=True}\\N\times H_{out}\times\left(2\times H_{in}-1\right) & \text{bias=False}\end{cases}\]

\[\left(W_{hr}h+b_{hr}\right)_{ops}=\begin{cases}2\times N\times H_{out}\times H_{out} & \text{bias=True}\\N\times H_{out}\times\left(2\times H_{out}-1\right) & \text{bias=False}\end{cases}\]

\[\text{sum}_{ops}=N\times H_{out}\]

\[\sigma_{ops}=3\times N\times H_{out}\]

where \(\sigma_{ops}\) counts one exponential, one sum and one division per element. Then, in terms of operations (\(ops\)) when bias=True

\[r_{ops}=2\times N\times H_{out}\times H_{in}+2\times N\times H_{out}\times H_{out}+4\times N\times H_{out}=2\times N\times H_{out}\times\left(H_{in}+H_{out}+2\right)\]

and when bias=False

\[r_{ops}=N\times H_{out}\times\left(2\times H_{in}-1\right)+N\times H_{out}\times\left(2\times H_{out}-1\right)+4\times N\times H_{out}=2\times N\times H_{out}\times\left(H_{in}+H_{out}+1\right)\]
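These counts translate directly into code. Below is a minimal sketch of the reset gate count, assuming the torch.nn.Linear results referenced above; the function name reset_gate_ops is hypothetical.

```python
def reset_gate_ops(N: int, H_in: int, H_out: int, bias: bool = True) -> int:
    """Number of operations of the reset gate r, following the four parts above."""
    input_linear = 2 * N * H_out * H_in if bias else N * H_out * (2 * H_in - 1)     # W_ir x (+ b_ir)
    hidden_linear = 2 * N * H_out * H_out if bias else N * H_out * (2 * H_out - 1)  # W_hr h (+ b_hr)
    elementwise_sum = N * H_out  # summing both linear results
    sigmoid = 3 * N * H_out      # one exponential, one sum, one division per element
    return input_linear + hidden_linear + elementwise_sum + sigmoid


# Sanity check against the closed-form expressions derived above.
assert reset_gate_ops(4, 10, 20, bias=True) == 2 * 4 * 20 * (10 + 20 + 2)
assert reset_gate_ops(4, 10, 20, bias=False) == 2 * 4 * 20 * (10 + 20 + 1)
```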
Update gate
Since the dimensions of this gate are the same as those of the reset gate \(r\), it is trivial to observe that

\[z_{ops}=r_{ops}\]
\(n\)
\(n\) has a slightly different configuration. Besides the matrix multiplications, there is a Hadamard product \(\odot\) and a hyperbolic tangent \(\text{tanh}\) function. The involved tensor sizes are

\[n=\text{tanh}\left(\underbrace{W_{in}}_{\left(H_{out}, H_{in}\right)}\underbrace{x}_{\left(N, H_{in}\right)}+\underbrace{b_{in}}_{\left(H_{out}\right)}+\underbrace{r}_{\left(H_{out}, N\right)}\odot\left(\underbrace{W_{hn}}_{\left(H_{out}, H_{out}\right)}\underbrace{h}_{\left(N, H_{out}\right)}+\underbrace{b_{hn}}_{\left(H_{out}\right)}\right)\right)\]

Again, it becomes necessary to broadcast and transpose some tensors to be able to perform all operations, resulting in

\[n=\text{tanh}\left(\underbrace{W_{in}}_{\left(H_{out}, H_{in}\right)}\underbrace{x^\top}_{\left(H_{in}, N\right)}+\underbrace{b_{in}}_{\left(H_{out}, N\right)}+\underbrace{r}_{\left(H_{out}, N\right)}\odot\left(\underbrace{W_{hn}}_{\left(H_{out}, H_{out}\right)}\underbrace{h^\top}_{\left(H_{out}, N\right)}+\underbrace{b_{hn}}_{\left(H_{out}, N\right)}\right)\right)\]
Now \(n_{ops}\) (the operations performed by \(n\)) can be divided into five parts:
- The operations needed to compute \(W_{in}x+b_{in}\).
- The operations needed to compute \(W_{hn}h+b_{hn}\).
- The operations needed to compute the Hadamard product between \(r\) and \(W_{hn}h+b_{hn}\).
- The operations needed to sum the terms that form the \(\text{tanh}\) function argument.
- The operations needed to compute the hyperbolic tangent \(\text{tanh}\) function of this result.
Then, the different parts that contribute to \(n_{ops}\) can be defined as

\[\left(W_{in}x+b_{in}\right)_{ops}=\begin{cases}2\times N\times H_{out}\times H_{in} & \text{bias=True}\\N\times H_{out}\times\left(2\times H_{in}-1\right) & \text{bias=False}\end{cases}\]

\[\left(W_{hn}h+b_{hn}\right)_{ops}=\begin{cases}2\times N\times H_{out}\times H_{out} & \text{bias=True}\\N\times H_{out}\times\left(2\times H_{out}-1\right) & \text{bias=False}\end{cases}\]

\[\odot_{ops}=N\times H_{out}\]

\[\text{sum}_{ops}=N\times H_{out}\]

\[\text{tanh}_{ops}=7\times N\times H_{out}\]

Then, when bias=True

\[n_{ops}=2\times N\times H_{out}\times H_{in}+2\times N\times H_{out}\times H_{out}+9\times N\times H_{out}=2\times N\times H_{out}\times\left(H_{in}+H_{out}+4.5\right)\]

and when bias=False

\[n_{ops}=N\times H_{out}\times\left(2\times H_{in}-1\right)+N\times H_{out}\times\left(2\times H_{out}-1\right)+9\times N\times H_{out}=2\times N\times H_{out}\times\left(H_{in}+H_{out}+3.5\right)\]
Note
Please note that there are many possible formulations for the number of operations carried out by \(\text{tanh}\). This calculation uses the formula mentioned at the beginning, in which case there are 7 operations per element: 4 exponentials, 1 sum, 1 difference and 1 division. Please also note that sign inversion operations are ignored, assuming they will usually have a negligible computational cost.
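As with the reset gate, the count for \(n\) can be sketched as a small function; candidate_state_ops is a hypothetical name, and the per-element costs follow the breakdown above.

```python
def candidate_state_ops(N: int, H_in: int, H_out: int, bias: bool = True) -> int:
    """Number of operations of n, following the five parts above."""
    input_linear = 2 * N * H_out * H_in if bias else N * H_out * (2 * H_in - 1)     # W_in x (+ b_in)
    hidden_linear = 2 * N * H_out * H_out if bias else N * H_out * (2 * H_out - 1)  # W_hn h (+ b_hn)
    hadamard = N * H_out         # r ⊙ (W_hn h + b_hn)
    elementwise_sum = N * H_out  # summing the tanh argument
    tanh = 7 * N * H_out         # 4 exponentials, 1 sum, 1 difference, 1 division per element
    return input_linear + hidden_linear + hadamard + elementwise_sum + tanh


# Sanity check against the closed-form expression derived above.
assert candidate_state_ops(4, 10, 20, bias=True) == 2 * 4 * 20 * (10 + 20 + 4.5)
```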
\(h'\)
The operations computed to obtain \(h'\) can be divided into four parts:
- The operations needed to compute the difference \(\left(1-z\right)\).
- The operations needed to compute the Hadamard product between \(\left(1-z\right)\) and \(n\).
- The operations needed to compute the Hadamard product between \(z\) and \(h\).
- The operations needed to sum both Hadamard product results.
In this case, all operations are element-wise, so it is trivial to see that each part contributes \(N\times H_{out}\) operations. Therefore

\[h'_{ops}=4\times N\times H_{out}\]
Total complexity
Finally, the total complexity is the sum of all individual contributions

\[\text{GRUCell}_{ops}=r_{ops}+z_{ops}+n_{ops}+h'_{ops}\]

In the case of bias=True, the total number of operations is

\[\text{GRUCell}_{ops}=2\times2\times N\times H_{out}\times\left(H_{in}+H_{out}+2\right)+2\times N\times H_{out}\times\left(H_{in}+H_{out}+4.5\right)+4\times N\times H_{out}=6\times N\times H_{out}\times\left(H_{in}+H_{out}+3.5\right)\]

and for bias=False

\[\text{GRUCell}_{ops}=2\times2\times N\times H_{out}\times\left(H_{in}+H_{out}+1\right)+2\times N\times H_{out}\times\left(H_{in}+H_{out}+3.5\right)+4\times N\times H_{out}=6\times N\times H_{out}\times\left(H_{in}+H_{out}+2.5\right)\]
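Putting everything together, a self-contained sketch (with the hypothetical name gru_cell_ops) sums the four contributions and cross-checks them against the closed-form expression:

```python
def gru_cell_ops(N: int, H_in: int, H_out: int, bias: bool = True) -> int:
    """Total GRUCell operation count, summing the r, z, n and h' contributions."""
    def linear_ops(F: int) -> int:
        # Operations of W x (+ b) with W of size (H_out, F), reusing torch.nn.Linear's result.
        return 2 * N * H_out * F if bias else N * H_out * (2 * F - 1)

    r_ops = linear_ops(H_in) + linear_ops(H_out) + N * H_out + 3 * N * H_out      # sum + sigmoid
    z_ops = r_ops                                       # same dimensions as the reset gate
    n_ops = linear_ops(H_in) + linear_ops(H_out) + 2 * N * H_out + 7 * N * H_out  # ⊙, sum, tanh
    h_prime_ops = 4 * N * H_out                         # four element-wise parts
    total = r_ops + z_ops + n_ops + h_prime_ops
    # Cross-check against the closed-form expression derived above.
    assert total == 6 * N * H_out * (H_in + H_out + (3.5 if bias else 2.5))
    return total
```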
Summary
The number of operations \(\phi\) performed by a torch.nn.GRUCell module can be estimated as

When bias=True:

\(\text{GRUCell}_{ops} = 6\times N \times H_{out}\times\left(H_{in}+H_{out}+3.5\right)\)

When bias=False:

\(\text{GRUCell}_{ops} = 6\times N \times H_{out}\times\left(H_{in}+H_{out}+2.5\right)\)
Where
- \(N\) is the batch size.
- \(H_\text{in}\) is the number of input features.
- \(H_\text{out}\) is the number of output features.
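As a quick numeric illustration with an arbitrary configuration:

```python
N, H_in, H_out = 4, 10, 20  # arbitrary example configuration
print(6 * N * H_out * (H_in + H_out + 3.5))  # bias=True:  16080.0 operations
print(6 * N * H_out * (H_in + H_out + 2.5))  # bias=False: 15600.0 operations
```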