In [1]:
import torch
In [2]:
class VanillaRNN(torch.nn.Module):
    """Minimal Elman-style RNN over a sequence of integer token ids.

    Each step embeds one token and updates the hidden state with
    ``h_t = tanh(Wxh(x_t) + Whh(h_{t-1}))`` (both ``Linear`` layers carry a
    bias), starting from an all-zeros hidden state.

    Subclassing ``torch.nn.Module`` registers the embedding and both linear
    layers as parameters. As a result ``rnn.parameters()``, ``rnn.to(device)``,
    ``rnn.state_dict()`` and ``rnn(seq)`` all work.

    Args:
        vocab_size: Number of distinct token ids (rows of the embedding table).
        embedding_dim: Size of each token embedding (input size of ``Wxh``).
        hidden_size: Size of the hidden state.
    """

    def __init__(self, vocab_size: int = 4, embedding_dim: int = 2, hidden_size: int = 3):
        super().__init__()
        self.hidden_size = hidden_size
        # Buffer rather than parameter: it follows .to(device) but is not trained.
        # persistent=False keeps it out of state_dict so checkpoints hold only weights.
        self.register_buffer("_initial_state", torch.zeros(hidden_size), persistent=False)
        # Layers are built in the order Wxh, Whh, embedding. Keep that order so a
        # given seed reproduces the same initial weights.
        self.Wxh = torch.nn.Linear(in_features=embedding_dim, out_features=hidden_size)
        self.Whh = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

    def forward(self, seq):
        """Run the RNN over ``seq`` and return every hidden state.

        Args:
            seq: Sequence of integer token ids (a list or an integer tensor),
                each in ``[0, vocab_size)``.

        Returns:
            ``(outs, h_t)``:
            - ``outs`` has shape ``(len(seq), hidden_size)``; row ``i`` is the
              hidden state after token ``i``.
            - ``h_t`` is the final hidden state, shape ``(hidden_size,)``.

            For an empty ``seq``, ``outs`` has shape ``(0, hidden_size)`` and
            ``h_t`` is a copy of the initial (zero) state.
        """
        # as_tensor avoids the copy and UserWarning that torch.tensor() causes
        # when seq is already a tensor.
        token_ids = torch.as_tensor(seq, dtype=torch.long, device=self.embedding.weight.device)
        # One embedding lookup for the whole sequence; row i equals embedding(seq[i]).
        embedded = self.embedding(token_ids)
        h_t = self._initial_state
        outs = []
        for x_t in embedded:
            h_t = self._f(h_t, x_t)
            outs.append(h_t)
        if not outs:
            # torch.stack([]) raises, so build the empty result explicitly; clone
            # so callers cannot mutate the registered buffer through the return value.
            empty_outs = self._initial_state.new_empty((0, self.hidden_size))
            return empty_outs, self._initial_state.clone()
        return torch.stack(outs), h_t

    def _f(self, h_t: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
        """One recurrence step: ``tanh(Wxh(x_t) + Whh(h_{t-1}))``."""
        return torch.tanh(self.Wxh(x_t) + self.Whh(h_t))
In [3]:
# Fix the RNG so the randomly initialised weights (and hence the outputs below) are reproducible.
SEED = 1234
torch.manual_seed(SEED)

vocab = {'I': 0, 'have': 1, 'socks': 2, '.': 3}
sentence = ['I', 'have', 'socks', '.']
token_ids = [vocab[word] for word in sentence]

# Inference only: no autograd graph is needed.
with torch.no_grad():
    rnn = VanillaRNN()
    outs, h_t = rnn.forward(seq=token_ids)
In [4]:
# Hidden state after each token: one row per input token (shape (4, 3) here).
outs
Out[4]:
tensor([[-0.6570,  0.2775,  0.4972],
        [-0.9238, -0.3656, -0.7586],
        [-0.8673, -0.3682,  0.3496],
        [-0.2447,  0.0817,  0.6288]])
In [5]:
# Final hidden state; identical to the last row of `outs`.
h_t
Out[5]:
tensor([-0.2447,  0.0817,  0.6288])