File size: 655 Bytes
2e5bb45
 
 
0be6bfd
2e5bb45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn

from .activation import GELU

class FeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> GELU -> Linear.

    The first linear layer widens the features from ``cfg["emb_dim"]`` to
    four times that size, the ``GELU`` imported from the sibling
    ``activation`` module is applied, and the second linear layer maps the
    result back to ``cfg["emb_dim"]``.

    ``nn.Linear`` operates on the last dimension only, so every position
    along the leading dimensions is transformed with the same weights and
    the output has the same shape as the input.

    Args:
        cfg: Mapping-like configuration; only the ``"emb_dim"`` entry is
            read here.
    """

    def __init__(self, cfg):
        super().__init__()
        width = cfg["emb_dim"]
        hidden = 4 * width  # fixed 4x expansion of the embedding width

        # Sub-modules are built in the order they run so that parameter
        # creation order matches their position in the Sequential.
        expand = nn.Linear(width, hidden)
        activation = GELU()
        project = nn.Linear(hidden, width)

        # Attribute name and ordering determine the state_dict keys
        # (layers.0.*, layers.2.*).
        self.layers = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Run ``x`` through the expand -> GELU -> project stack.

        ``x`` must have ``cfg["emb_dim"]`` features in its last dimension.
        """
        out = self.layers(x)
        return out