Linear Recurrent Unit in TensorFlow 2.10

I wanted to check whether my RNN theory worked on state-space models, so I implemented the Linear Recurrent Unit (LRU) in TensorFlow 2.10, and since I have it, why not share it? I tried to keep the code clean and easy to use and understand. In the coming days I’ll turn it into a pip package. The LRU was introduced in Resurrecting Recurrent Neural Networks for Long Sequences at ICML 2023, and belongs to the family of state-space models, which handle extremely long sequences more gracefully than attention-based architectures. You can find here the JAX implementation that I took as a reference, recommended by one of the authors of the LRU.
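For reference, at its core the LRU computes a linear, diagonal recurrence (I'm paraphrasing the paper here; the stable exponential parametrization of the eigenvalues is described there):

$$x_k = \Lambda x_{k-1} + B u_k, \qquad y_k = \mathrm{Re}(C x_k) + D u_k,$$

where $\Lambda = \mathrm{diag}(\lambda)$ is a learned diagonal matrix of complex eigenvalues with $|\lambda_j| < 1$ for stability. Because the recurrence is linear, the whole sequence of states can be evaluated with a parallel scan instead of a sequential loop.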

I’d like to complete the job with JAX and PyTorch implementations as well. However, parallel scans are not natively implemented in PyTorch, as noted here, although custom implementations exist, such as this one.
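For context on why the parallel scan matters: each step of the LRU applies an affine map to the hidden state (multiply element-wise by the eigenvalues, then add the projected input), and composing affine maps is associative, so all the hidden states can be computed with an associative scan. Below is a minimal sketch in TensorFlow, with illustrative names of my own choosing; the package ships its own, properly parametrized version.

```python
import tensorflow as tf

def lru_recurrence(lam, Bu):
    """Sequential (recurrent) form: h_t = lam * h_{t-1} + Bu_t,
    element-wise over the diagonal complex state.
    lam: (d,) complex64, Bu: (T, d) complex64."""
    h0 = tf.zeros_like(Bu[0])
    return tf.scan(lambda h, bu: lam * h + bu, Bu, initializer=h0)  # (T, d)

def scan_combine(left, right):
    """Associative operator for the parallel (scan) form.
    Each element (a, b) encodes the affine map h -> a * h + b;
    composing left-then-right gives (a1 * a2, a2 * b1 + b2)."""
    a1, b1 = left
    a2, b2 = right
    return a1 * a2, a2 * b1 + b2
```

Running an associative scan with scan_combine over the pairs (lam broadcast over time, Bu), for example with tfp.math.scan_associative from TensorFlow Probability, gives the same hidden states as the sequential loop, but with logarithmic parallel depth instead of one step per time step, which is where the GPU speed-up comes from.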

I implemented both the LRU unit and the final LRU residual block used in the paper, and for each I provide a recurrent form and a scan form. In my tests, the scan form was up to 300x faster than the recurrent form on a GPU, while producing the same output. You can import them as follows:

```python
from lru_unofficial.tf.linear_recurrent_unit import LinearRecurrentUnitCell, LinearRecurrentUnitFFN
from lru_unofficial.tf.linear_recurrent_unit import ResLRUCell, ResLRUFFN
```
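And here is a minimal usage sketch. Note my assumptions: that the Cell variant follows the standard Keras RNN-cell interface (so it can be wrapped with tf.keras.layers.RNN) and that the constructors take the number of units as their first argument; check the repository for the actual signatures.

```python
import tensorflow as tf
from lru_unofficial.tf.linear_recurrent_unit import LinearRecurrentUnitCell, LinearRecurrentUnitFFN

x = tf.random.normal((8, 128, 64))  # (batch, time, features)

# Recurrent form: a Keras RNN cell, stepped through time by tf.keras.layers.RNN
# (assuming the standard cell interface and a units-style first argument).
rec = tf.keras.layers.RNN(LinearRecurrentUnitCell(64), return_sequences=True)
y_rec = rec(x)

# Scan form: consumes the whole sequence at once, much faster on GPU.
ffn = LinearRecurrentUnitFFN(64)
y_scan = ffn(x)
```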

Enjoy!

Written on November 14, 2023