File size: 655 Bytes
2e5bb45 0be6bfd 2e5bb45 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import torch
import torch.nn as nn
from .activation import GELU
class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear.

    Expands the last dimension from ``cfg["emb_dim"]`` to
    ``expansion * cfg["emb_dim"]``, applies the project's ``GELU``
    activation (imported from ``.activation``), then projects back to
    ``cfg["emb_dim"]``.

    ``nn.Linear`` acts only on the last dimension, so each position along
    the leading dimensions is transformed independently. No information is
    mixed across tokens here.

    This module contains no attention. It is presumably the FFN sub-layer
    that follows attention inside a transformer block; confirm against the
    caller.

    Args:
        cfg: Config mapping. Only ``cfg["emb_dim"]`` (the input/output
            feature size) is read.
        expansion: Width multiplier for the hidden layer. Defaults to 4,
            the previously hard-coded value. Existing callers and saved
            state dicts are therefore unaffected.
    """

    def __init__(self, cfg, *, expansion=4):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        hidden_dim = expansion * emb_dim
        # Keep the Sequential layout (Linear, GELU, Linear) unchanged so
        # parameter names stay "layers.0.*" / "layers.2.*" and existing
        # checkpoints still load.
        self.layers = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, x):
        """Apply the feed-forward stack to ``x``.

        Args:
            x: Tensor whose last dimension is ``cfg["emb_dim"]``. Presumably
                shaped (batch, seq, emb_dim); TODO confirm with callers.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``cfg["emb_dim"]``. This assumes the project
            ``GELU`` is element-wise.
        """
        return self.layers(x)