| 
							 | 
						from typing import Callable | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						import torch | 
					
					
						
						| 
							 | 
						import torch.nn as nn | 
					
					
						
						| 
							 | 
						import torch.nn.functional as F | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						def softmax(x: torch.Tensor, dim: int | tuple[int, ...]) -> torch.Tensor: | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    Compute the softmax along the specified dimensions. | 
					
					
						
						| 
							 | 
						    This function adds the option to specify multiple dimensions | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Args: | 
					
					
						
						| 
							 | 
						        x (torch.Tensor): Input tensor. | 
					
					
						
						| 
							 | 
						        dims (int or tuple[int]): The dimension or list of dimensions along which the softmax probabilities are computed. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    Returns: | 
					
					
						
						| 
							 | 
						        torch.Tensor: Output tensor containing softmax probabilities along the specified dimensions. | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    dtype = x.dtype | 
					
					
						
						| 
							 | 
						    x = x.to(torch.float32) | 
					
					
						
						| 
							 | 
						    max_vals = torch.amax(x, dim=dim, keepdim=True) | 
					
					
						
						| 
							 | 
						    e_x = torch.exp(x - max_vals) | 
					
					
						
						| 
							 | 
						    sum_exp = e_x.sum(dim=dim, keepdim=True) | 
					
					
						
						| 
							 | 
						    return (e_x / sum_exp).to(dtype) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						class SteelSoftMoEV3(nn.Module): | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						    A wrapper class to create a Soft Mixture of Experts layer. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						    From "From Sparse to Soft Mixtures of Experts" | 
					
					
						
						| 
							 | 
						    https://arxiv.org/pdf/2308.00951.pdf | 
					
					
						
						| 
							 | 
						    """ | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def __init__( | 
					
					
						
						| 
							 | 
						        self, | 
					
					
						
						| 
							 | 
						        config, | 
					
					
						
						| 
							 | 
						        layer: Callable, | 
					
					
						
						| 
							 | 
						    ) -> None: | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        Args: | 
					
					
						
						| 
							 | 
						            dim (int): Dimensionality of input features. | 
					
					
						
						| 
							 | 
						            num_experts (int): Number of experts. | 
					
					
						
						| 
							 | 
						            slots_per_expert (int): Number of token slots per expert. | 
					
					
						
						| 
							 | 
						            layer (Callable): Network layer of the experts. | 
					
					
						
						| 
							 | 
						            normalize (bool): Normalize input and phi (sec. 2.3 from paper) | 
					
					
						
						| 
							 | 
						            **layer_kwargs: Additional keyword arguments for the layer class. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        super().__init__() | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        self.dim = config.hidden_size | 
					
					
						
						| 
							 | 
						        self.num_experts = config.n_experts | 
					
					
						
						| 
							 | 
						        self.slots_per_expert = config.slots_per_expert if hasattr(config, "slots_per_expert") else 1 | 
					
					
						
						| 
							 | 
						        self.normalize = True | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        self.phi = nn.Parameter(torch.zeros(self.dim, self.num_experts, self.slots_per_expert)) | 
					
					
						
						| 
							 | 
						        if self.normalize: | 
					
					
						
						| 
							 | 
						            self.scale = nn.Parameter(torch.ones(1)) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        nn.init.normal_(self.phi, mean=0, std=1 / self.dim**0.5) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        self.experts = nn.ModuleList( | 
					
					
						
						| 
							 | 
						            [layer(config) for _ in range(self.num_experts)] | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						    def forward(self, x: torch.Tensor) -> torch.Tensor: | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        Forward pass through the Soft-MoE layer (algorithm 1 from paper). | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						        Args: | 
					
					
						
						| 
							 | 
						            x (torch.Tensor): Input tensor of shape [batch_size, seq_len, input_dim]. | 
					
					
						
						| 
							 | 
						 | 
					
					
						
						| 
							 | 
						        Returns: | 
					
					
						
						| 
							 | 
						            torch.Tensor: Output tensor of shape [batch_size, seq_len, input_dim]. | 
					
					
						
						| 
							 | 
						        """ | 
					
					
						
						| 
							 | 
						        assert ( | 
					
					
						
						| 
							 | 
						            x.shape[-1] == self.dim | 
					
					
						
						| 
							 | 
						        ), f"Input feature dim of {x.shape[-1]} does not match layer dim of {self.dim}" | 
					
					
						
						| 
							 | 
						        assert ( | 
					
					
						
						| 
							 | 
						            len(x.shape) == 3 | 
					
					
						
						| 
							 | 
						        ), f"Input expected to have 3 dimensions but has {len(x.shape)}" | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        phi = self.phi | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        if self.normalize: | 
					
					
						
						| 
							 | 
						            x = F.normalize(x, dim=2)   | 
					
					
						
						| 
							 | 
						            phi = self.scale * F.normalize(phi, dim=0)   | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        logits = torch.einsum("bmd,dnp->bmnp", x, phi) | 
					
					
						
						| 
							 | 
						        d = softmax(logits, dim=1) | 
					
					
						
						| 
							 | 
						        c = softmax(logits, dim=(2, 3)) | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        xs = torch.einsum("bmd,bmnp->bnpd", x, d) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        ys = torch.stack( | 
					
					
						
						| 
							 | 
						            [f_i(xs[:, i, :, :]) for i, f_i in enumerate(self.experts)], dim=1 | 
					
					
						
						| 
							 | 
						        ) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						         | 
					
					
						
						| 
							 | 
						        y = torch.einsum("bnpd,bmnp->bmd", ys, c) | 
					
					
						
						| 
							 | 
						
 | 
					
					
						
						| 
							 | 
						        return y |