In [1]:
import torch
from torch.nn.functional import softmax
import math
import numpy as np
import matplotlib.pyplot as plt
In [2]:
# Fix the seeds of both random number generators used in this notebook
# (torch for the tensors, numpy for the shuffle order) so re-runs are reproducible.
torch.manual_seed(1234)
np.random.seed(1234)

Let $T=5$ be the number of tokens (rows of the input) and $D=3$ the number of features per token.

In [3]:
T = 5
D = 3

Initialize the query, key and value projection matrices $Q_{proj}$, $K_{proj}$ and $V_{proj}$ (each $D \times D$) with random values.

In [4]:
# Random D x D projections, drawn in Q, K, V order so the seeded torch RNG
# stream is consumed exactly as if they were created one per line.
Q_proj, K_proj, V_proj = [torch.randn((D, D)) for _ in range(3)]
In [5]:
def get_features(
    x: torch.Tensor,
    q_proj: "torch.Tensor | None" = None,
    k_proj: "torch.Tensor | None" = None,
    v_proj: "torch.Tensor | None" = None,
) -> torch.Tensor:
    """
    Single-head scaled dot-product self-attention over the rows of ``x``.

    Projects ``x`` to queries Q, keys K and values V, computes the attention
    weights ``A = softmax(Q K^T / sqrt(d_k))`` along the key axis, and returns
    ``A @ V``.

    :param x: input tensor of shape ``(..., T, D)``: ``T`` tokens with ``D``
        features each; leading batch dimensions are optional.
    :param q_proj: query projection of shape ``(D, d_k)``; defaults to the
        notebook-level ``Q_proj``.
    :param k_proj: key projection of shape ``(D, d_k)``; defaults to ``K_proj``.
    :param v_proj: value projection of shape ``(D, d_v)``; defaults to ``V_proj``.
    :return: the attention output ``A @ V`` of shape ``(..., T, d_v)``.
    """
    # Fall back to the notebook's global projections so existing calls such as
    # ``get_features(x=...)`` behave exactly as before.
    if q_proj is None:
        q_proj = Q_proj
    if k_proj is None:
        k_proj = K_proj
    if v_proj is None:
        v_proj = V_proj

    Q = x @ q_proj
    K = x @ k_proj
    V = x @ v_proj

    # Scale by the key dimension actually produced by the projection (equal to
    # the global D for D x D projections) instead of relying on a global constant.
    A = Q @ K.transpose(-2, -1) * (1.0 / math.sqrt(K.shape[-1]))
    # Normalise over the key axis so each query row sums to 1; dim=-1 also
    # handles inputs with leading batch dimensions.
    A = softmax(A, dim=-1)
    y = A @ V
    return y

## Without positional encoding

First, create a random $T \times D$ input tensor and compute the attention output.

In [6]:
# Random input sequence: T tokens (rows), each with D features.
x = torch.randn(T, D)
# Attention output for the unshuffled sequence.
y = get_features(x)

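Now shuffle the rows (tokens) of the input tensor and compute the attention output again.

The output rows are permuted in exactly the same way as the input rows: without positional information, self-attention is permutation-equivariant.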

In [7]:
# Random permutation of the T token positions; row i of x_perm is row
# shuffled_order[i] of x.
shuffled_order = np.random.permutation(T)
x_perm = x[shuffled_order]
y_perm = get_features(x=x_perm)

# Check the claim numerically rather than only by eye: without positional
# information, shuffling the input rows shuffles the output rows identically.
assert torch.allclose(y_perm, y[shuffled_order], atol=1e-6), \
    "attention output is not a row permutation of the unshuffled output"

_, axes = plt.subplots(ncols=2, figsize=(10, 5))
panels = [
    (y, range(T), 'y without shuffle'),
    (y_perm, shuffled_order, 'y, inputs shuffled'),
]
for panel_ax, (values, row_labels, title) in zip(axes, panels):
    # Shared colour range so identical rows get identical colours in both panels.
    panel_ax.imshow(values, vmin=y.min(), vmax=y.max())
    # Each row is labelled with the index of the input token it came from.
    panel_ax.set_yticks(range(T), labels=row_labels)
    panel_ax.set_xticks([])
    panel_ax.set_title(title)
    panel_ax.set_xlabel('y')
    panel_ax.set_ylabel('T')
plt.show()
No description has been provided for this image

## With positional encoding

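Next, create a (random) positional embedding tensor with one row per position and add it to the input tensor.

The embedding for position $t$ is added to whichever token sits at position $t$. As a result, shuffling the input no longer just permutes the output rows: the output now depends on the order of the input.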

In [8]:
# Random positional embedding: one D-dimensional row per position 0..T-1.
positional_embedding = torch.randn((T, D))
y_with_pos = get_features(x=x + positional_embedding)
# The embedding is indexed by position, not by token, so after shuffling each
# token is combined with a different positional row than before.
y_shuffle_with_pos = get_features(x=x[shuffled_order] + positional_embedding)

# Would be ~0 if attention were still permutation-equivariant.
max_deviation = (y_shuffle_with_pos - y_with_pos[shuffled_order]).abs().max().item()

# Colour range computed from the values actually shown; reusing the range of
# `y` (no positional encoding) would clip values outside it.
shown_values = torch.stack([y_with_pos, y_shuffle_with_pos])
color_min, color_max = shown_values.min().item(), shown_values.max().item()

fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
panels = [
    (y_with_pos, range(T), 'y with positional encoding'),
    (y_shuffle_with_pos, shuffled_order, 'y with positional encoding, inputs shuffled'),
]
for panel_ax, (values, row_labels, title) in zip(axes, panels):
    panel_ax.imshow(values, vmin=color_min, vmax=color_max)
    # Each row is labelled with the index of the input token it came from.
    panel_ax.set_yticks(range(T), labels=row_labels)
    panel_ax.set_xticks([])
    panel_ax.set_title(title)
    panel_ax.set_xlabel('y')
    panel_ax.set_ylabel('T')
fig.suptitle(f'Max deviation from a pure row permutation: {max_deviation:.3f}')
plt.show()
No description has been provided for this image