In [1]:
import torch
In [2]:
class VanillaRNN(torch.nn.Module):
    """Minimal Elman-style RNN over a sequence of integer token ids.

    Each token id is looked up in an embedding table. The hidden state is then
    updated as ``h_t = tanh(Wxh(x_t) + Whh(h_{t-1}))``, starting from an
    all-zeros state.

    Subclassing ``torch.nn.Module`` registers ``Wxh``, ``Whh`` and ``embedding``
    as sub-modules. Their parameters are therefore visible to ``.parameters()``
    and optimizers, and ``.to(device)`` moves the whole model.

    Args:
        vocab_size: number of distinct token ids accepted by the embedding.
        embedding_dim: size of each token embedding (input size of ``Wxh``).
        hidden_size: size of the hidden state.

    The defaults (4, 2, 3) are the sizes this notebook originally hard-coded.
    """

    def __init__(self, vocab_size: int = 4, embedding_dim: int = 2, hidden_size: int = 3):
        super().__init__()
        self.hidden_size = hidden_size
        # A buffer rather than a parameter: it follows .to(device) but is not trained.
        # persistent=False keeps it out of state_dict(), since it is always zeros.
        self.register_buffer("_initial_state", torch.zeros(hidden_size), persistent=False)
        # Keep the construction order Wxh -> Whh -> embedding. Under a fixed seed,
        # parameter initialisation consumes the global RNG in this order, so
        # reordering would change the initial weights.
        self.Wxh = torch.nn.Linear(in_features=embedding_dim, out_features=hidden_size)
        self.Whh = torch.nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

    def forward(self, seq):
        """Run the recurrence over ``seq``, one token at a time.

        Args:
            seq: iterable of integer token ids (Python ints or 0-d integer tensors).

        Returns:
            ``(outs, h_t)``:
            - ``outs`` stacks the hidden state after every step, with shape
              ``(len(seq), hidden_size)``.
            - ``h_t`` is the final hidden state, with shape ``(hidden_size,)``.

            For an empty ``seq``, ``outs`` has shape ``(0, hidden_size)`` and
            ``h_t`` is a copy of the all-zeros initial state.
        """
        h_t = self._initial_state
        outs = []
        device = self.embedding.weight.device
        for token in seq:
            # as_tensor does not copy (or warn) when token is already a tensor.
            # Placing the index on the embedding's device keeps lookups working
            # after .to(device).
            x_t = self.embedding(torch.as_tensor(token, device=device))
            h_t = self._f(h_t, x_t)
            outs.append(h_t)
        if not outs:
            # torch.stack([]) raises, so return well-shaped empty results instead.
            return self._initial_state.new_empty((0, self.hidden_size)), self._initial_state.clone()
        return torch.stack(outs), h_t

    def _f(self, h_t: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
        """Compute one recurrence step: ``tanh(Wxh(x_t) + Whh(h_t))``."""
        return torch.tanh(self.Wxh(x_t) + self.Whh(h_t))
In [3]:
# Seed the global RNG before the model is built, so the randomly initialised
# weights (and hence the outputs displayed below) are reproducible.
torch.manual_seed(1234)
# Token-to-id table: each word's id is its position in the list.
vocab = {word: idx for idx, word in enumerate(['I', 'have', 'socks', '.'])}
# Autograd tracking is disabled, so the resulting tensors carry no grad history.
with torch.no_grad():
    rnn = VanillaRNN()
    outs, h_t = rnn.forward(seq=[vocab[word] for word in ['I', 'have', 'socks', '.']])
In [4]:
# Hidden state after every input token, stacked along dim 0 (one row per token).
outs
Out[4]:
tensor([[-0.6570, 0.2775, 0.4972], [-0.9238, -0.3656, -0.7586], [-0.8673, -0.3682, 0.3496], [-0.2447, 0.0817, 0.6288]])
In [5]:
# Final hidden state returned by forward; it is the same tensor as the last row of `outs`.
h_t
Out[5]:
tensor([-0.2447, 0.0817, 0.6288])