FIRE Relative Positional Encodings #2325

Open
kaddu341 opened this issue Jan 31, 2025 · 2 comments

Comments

@kaddu341

Hi,

I'm currently working on the length generalization capabilities of transformers. As shown by Zhou et al. (https://arxiv.org/abs/2402.09371), FIRE positional encodings are excellent for this purpose: in combination with other techniques, they can yield generalization to sequences up to 2.5x the training length.

FIRE, which stands for Functional Interpolation for Relative Positional Encodings, was introduced by Li et al. (https://arxiv.org/pdf/2310.04418).

I am planning to implement the algorithm from this paper myself, but I thought it would be useful to turn it into a PyTorch module so that others can benefit too. (I posted this originally in the PyTorch Core repo, but they suggested bringing it here.) Therefore, I am proposing to add this feature to the torchtune library.

Please let me know what you think!

There are many other positional encoding types (sinusoidal, RoPE, learned, etc.), but for the specific task of length generalization, FIRE seems to be the most suitable based on several papers, which is why I am proposing this feature addition.

Like other relative attention mechanisms, FIRE introduces positional information in the attention layers rather than adding it to the input.
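To make that distinction concrete, here is a minimal illustrative sketch (not from the paper; the tensor names, shapes, and random bias are placeholders) of the two places positional information can enter:

import torch

batch_size, seq_length, dim_model = 2, 8, 16
x = torch.randn(batch_size, seq_length, dim_model)    # token embeddings

# absolute/learned encodings: positional information is added to the input
pos_emb = torch.randn(seq_length, dim_model)
x_abs = x + pos_emb                                    # fed into attention as usual

# relative encodings such as FIRE: the input is left alone, and a
# position-dependent bias is added to the attention logits instead
q = k = torch.randn(batch_size, seq_length, dim_model)
bias = torch.randn(seq_length, seq_length)             # stand-in for the FIRE bias matrix B
logits = torch.bmm(q, k.transpose(1, 2)) / dim_model**0.5 + bias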

Here is a screenshot of some evaluation results for FIRE from the original paper (Li et al., 2024):

[Image: evaluation results for FIRE from Li et al., 2024]
@acisseJZhong
Contributor

Hi @kaddu341, thank you for bringing up this new positional embedding that extrapolates well to longer lengths. It looks pretty exciting! We would love to have it in Torchtune so that more users can benefit from it.

Would you be open to drafting an RFC about the design of FIRE so people can review and comment? It would be helpful to include context, motivation, and details on the modules/files you plan to add or change. You can find example RFCs here: #102, #2105

@kaddu341
Author

kaddu341 commented Feb 4, 2025

Sure!

Context:

The length generalization of transformers is a problem currently generating much interest in the machine learning community. Generalization performance depends largely on the type of positional encoding used. Over the years, several new positional encoding schemes have been proposed, including FIRE (Li et al., 2023/2024), which has been shown to produce impressive generalization results (Zhou et al., 2024).

Motivation:

For many tasks, the training sequences available to us are shorter than the sequences we would like to run inference on. This is especially important for physical simulations, but also for certain natural language processing (NLP) tasks.

Details:

Like most other relative positional encoding schemes, FIRE adds a bias matrix to the standard self-attention computation (before applying softmax) as follows:

$$\mathrm{Attention}(X) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}} + B\right)V$$

where $B$ is calculated by:

$$B_{ij} = f_\theta\!\left(\frac{\psi(i-j)}{\psi(\max(i, L))}\right), \qquad \psi(x) = \log(cx + 1)$$

(Note: $c > 0$ and the threshold $L$ are learnable scalars, and $f_\theta$ here is a standard MLP, i.e. a Linear → ReLU → Linear → ReLU network; the code sketch below uses GELU activations instead.)
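As a quick worked example (numbers chosen here for illustration, not taken from the paper): with $c = 1$, $L = 2$, and query/key positions $i = 5$, $j = 2$, the input to the MLP is $\psi(3)/\psi(\max(5, 2)) = \log(4)/\log(6) \approx 0.77$, and $f_\theta$ maps this scalar to the bias entry $B_{5,2}$. Normalizing by the query position keeps the MLP input bounded, which is what lets the same learned function cover sequence lengths not seen during training.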

In code, this can be expressed as follows (please note this is not production code and may have typos; it is just to give an idea of the algorithm):

import math

import torch
import torch.nn as nn


# class to implement a single self-attention head with a FIRE bias
class FIRE_AttentionHead(nn.Module):
    def __init__(self, dim_model, kdim, hidden_size):
        super().__init__()
        self.kdim = kdim

        # initialize parameter matrices
        self.W_q = nn.Linear(dim_model, kdim, bias=False)
        self.W_k = nn.Linear(dim_model, kdim, bias=False)
        self.W_v = nn.Linear(dim_model, kdim, bias=False)

        # initialize learnable scalars
        self.c = nn.Parameter(torch.tensor(1.0))
        self.L = nn.Parameter(torch.tensor(2.0))

        # initialize learnable continuous function
        self.f_theta = nn.Sequential(
            nn.Linear(1, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1)
        )
    
    # concave function to amplify differences among local positions
    def phi(self, c, x):
        return torch.log(c*x + 1)
    
    def forward(self, src):
        # Assuming src has shape (batch_size, seq_length, dim_model)
        batch_size = src.shape[0]
        seq_length = src.shape[1]

        # constrain c to be > 0
        c = torch.nn.functional.softplus(self.c)

        # compute bias matrix on the same device as the input
        B = torch.zeros(seq_length, seq_length, device=src.device)
        for i in range(1, seq_length):
            for j in range(0, i):
                # progressive interpolation: log-scaled distance normalized by the
                # log-scaled query position (clipped below by the learnable threshold L)
                B[i, j] = self.phi(c, i - j) / self.phi(c, max(self.L, i + 1))
        # make it antisymmetric:
        B = B - B.T
        # apply MLP to bias matrix
        B = self.f_theta(B.unsqueeze(2)).squeeze(2)
        # repeat B for batch_size
        B = B.repeat(batch_size, 1, 1)

        # get Query, Key, and Value matrices for each sequence
        Q = self.W_q(src)
        K = self.W_k(src)
        V = self.W_v(src)

        # calculate scaled attention scores, adding the FIRE bias before the softmax
        K_T = torch.transpose(K, 1, 2)
        attn_logits = torch.bmm(Q, K_T) / math.sqrt(self.kdim) + B
        attn_weights = torch.nn.functional.softmax(attn_logits, dim=-1)
        attn_outputs = torch.bmm(attn_weights, V)
        return attn_outputs

# class to implement multi-head self-attention with relative positional encodings according to the FIRE paper
class FIRE_SelfAttention(nn.Module):
    def __init__(self, dim_model, num_heads, hidden_size):
        super().__init__()

        # make sure num_heads divides dim_model:
        assert dim_model % num_heads == 0, "Number of heads must divide dimension of model"

        # compute kdim = vdim
        kdim = dim_model // num_heads

        # initialize attention heads
        self.attention_heads = nn.ModuleList([FIRE_AttentionHead(dim_model, kdim, hidden_size) for _ in range(num_heads)])

        # final linear layer
        self.W_o = nn.Linear(dim_model, dim_model, bias=False)

    def forward(self, src):
        # src should have shape (batch_size, seq_length, dim_model)
        # Pass src through the attention heads
        attn_results = [attn_head(src) for attn_head in self.attention_heads]
        # concatenate results
        attn_results = torch.cat(attn_results, dim=-1)
        # pass through final linear layer
        return self.W_o(attn_results)
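
For reference, here is a hypothetical usage sketch of the module above (the hyperparameters and tensor shapes are placeholders chosen for illustration, not values from the paper):

# hypothetical usage of the sketch above
model = FIRE_SelfAttention(dim_model=64, num_heads=4, hidden_size=32)
src = torch.randn(2, 10, 64)      # (batch_size, seq_length, dim_model)
out = model(src)                  # -> shape (2, 10, 64)

# because the bias depends only on positions, the same module can be applied
# to longer sequences at inference time
longer = torch.randn(2, 50, 64)
out_longer = model(longer)        # -> shape (2, 50, 64)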

I am planning to make this class inherit from nn.Module and essentially function as a drop-in replacement for nn.MultiheadAttention.
