Upload modeling_hunyuan.py
Browse files- modeling_hunyuan.py +63 -63
modeling_hunyuan.py
CHANGED
|
@@ -1605,6 +1605,68 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
|
|
| 1605 |
)
|
| 1606 |
return reordered_past
|
| 1607 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1608 |
|
| 1609 |
@add_start_docstrings(
|
| 1610 |
"""
|
|
@@ -1845,66 +1907,4 @@ class HunYuanDenseMoE(nn.Module):
|
|
| 1845 |
sparse_out = self._sparse_path(x, probs)
|
| 1846 |
|
| 1847 |
out = dense_out + (sparse_out - dense_out).detach() # STE
|
| 1848 |
-
return out.view(bsz, seq_len, self.hidden_size)
|
| 1849 |
-
|
| 1850 |
-
# -----------------------------------------------------------------------
|
| 1851 |
-
# Helper for module replacement
|
| 1852 |
-
# -----------------------------------------------------------------------
|
| 1853 |
-
def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
|
| 1854 |
-
"""Replace a (possibly nested) sub‑module.
|
| 1855 |
-
|
| 1856 |
-
``target`` is the dotted path returned by ``model.named_modules()``.
|
| 1857 |
-
"""
|
| 1858 |
-
parts = target.split('.')
|
| 1859 |
-
parent = root
|
| 1860 |
-
for p in parts[:-1]:
|
| 1861 |
-
parent = getattr(parent, p)
|
| 1862 |
-
setattr(parent, parts[-1], new_module)
|
| 1863 |
-
|
| 1864 |
-
# -----------------------------------------------------------------------
|
| 1865 |
-
# Public APIs
|
| 1866 |
-
# -----------------------------------------------------------------------
|
| 1867 |
-
def densify(model: nn.Module):
    """Swap every :class:`HunYuanMoE` found under *model* for a
    :class:`HunYuanDenseMoE` built from it.

    Each replacement is moved to the device of the first parameter of the
    module it replaces.

    Args:
        model: Root module to search; modified in place.

    Returns:
        The same *model* object.
    """
    # Gather the matches up front; the swaps below only start once the
    # named_modules() walk has finished.
    sparse_modules = [
        (path, mod)
        for path, mod in model.named_modules()
        if isinstance(mod, HunYuanMoE)
    ]

    for path, moe in sparse_modules:
        fused = HunYuanDenseMoE(moe)
        fused = fused.to(next(moe.parameters()).device)
        _replace_submodule(model, path, fused)

    return model
|
| 1878 |
-
|
| 1879 |
-
|
| 1880 |
-
def sparsify(model: nn.Module):
    """Turn every :class:`HunYuanDenseMoE` under *model* back into a
    :class:`HunYuanMoE`.

    For each fused module a new :class:`HunYuanMoE` is constructed from the
    fused module's ``config`` and ``layer_idx``, moved to the device of the
    fused module's first parameter, and filled with the router state plus
    the per-expert slices of the fused projection weights. The rebuilt
    module then takes the fused module's place in *model*.

    Args:
        model: Root module to search; modified in place.

    Returns:
        The same *model* object.
    """
    # NOTE(review): only ``gate`` and each expert's gate/up/down projection
    # weights are copied below. Any other parameters of the rebuilt module,
    # and its dtype, stay as the constructor left them -- confirm that this
    # is intended.
    fused_modules = [
        (path, mod)
        for path, mod in model.named_modules()
        if isinstance(mod, HunYuanDenseMoE)
    ]

    for path, fused in fused_modules:
        rebuilt = HunYuanMoE(fused.config, layer_idx=fused.layer_idx)
        rebuilt = rebuilt.to(next(fused.parameters()).device)

        # The router carries over unchanged.
        rebuilt.gate.load_state_dict(fused.gate.state_dict())

        # Expert ``i`` owns the i-th ``intermediate_size``-wide band of each
        # fused weight: a band of rows for gate/up, of columns for down.
        for expert_no, expert in enumerate(rebuilt.experts):
            band = slice(
                expert_no * fused.intermediate_size,
                (expert_no + 1) * fused.intermediate_size,
            )
            gate_rows = fused.fused_gate_proj.weight.data[band]
            expert.gate_proj.weight.data.copy_(gate_rows)

            up_rows = fused.fused_up_proj.weight.data[band]
            expert.up_proj.weight.data.copy_(up_rows)

            down_cols = fused.fused_down_proj.weight.data[:, band]
            expert.down_proj.weight.data.copy_(down_cols)

        _replace_submodule(model, path, rebuilt)

    return model
|
|
|
|
| 1605 |
)
|
| 1606 |
return reordered_past
|
| 1607 |
|
| 1608 |
+
# -----------------------------------------------------------------------
|
| 1609 |
+
# Helper for module replacement
|
| 1610 |
+
# -----------------------------------------------------------------------
|
| 1611 |
+
def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
|
| 1612 |
+
"""Replace a (possibly nested) sub‑module.
|
| 1613 |
+
|
| 1614 |
+
``target`` is the dotted path returned by ``model.named_modules()``.
|
| 1615 |
+
"""
|
| 1616 |
+
parts = target.split('.')
|
| 1617 |
+
parent = root
|
| 1618 |
+
for p in parts[:-1]:
|
| 1619 |
+
parent = getattr(parent, p)
|
| 1620 |
+
setattr(parent, parts[-1], new_module)
|
| 1621 |
+
|
| 1622 |
+
# -----------------------------------------------------------------------
|
| 1623 |
+
# Public APIs
|
| 1624 |
+
# -----------------------------------------------------------------------
|
| 1625 |
+
def densify(model: nn.Module):
    """Swap every :class:`HunYuanMoE` found under *model* for a
    :class:`HunYuanDenseMoE` built from it.

    Each replacement is moved to the device of the first parameter of the
    module it replaces.

    Args:
        model: Root module to search; modified in place.

    Returns:
        The same *model* object.
    """
    # Gather the matches up front; the swaps below only start once the
    # named_modules() walk has finished.
    sparse_modules = [
        (path, mod)
        for path, mod in model.named_modules()
        if isinstance(mod, HunYuanMoE)
    ]

    for path, moe in sparse_modules:
        fused = HunYuanDenseMoE(moe)
        fused = fused.to(next(moe.parameters()).device)
        _replace_submodule(model, path, fused)

    return model
|
| 1636 |
+
|
| 1637 |
+
|
| 1638 |
+
def sparsify(model: nn.Module):
    """Turn every :class:`HunYuanDenseMoE` under *model* back into a
    :class:`HunYuanMoE`.

    For each fused module a new :class:`HunYuanMoE` is constructed from the
    fused module's ``config`` and ``layer_idx``, moved to the device of the
    fused module's first parameter, and filled with the router state plus
    the per-expert slices of the fused projection weights. The rebuilt
    module then takes the fused module's place in *model*.

    Args:
        model: Root module to search; modified in place.

    Returns:
        The same *model* object.
    """
    # NOTE(review): only ``gate`` and each expert's gate/up/down projection
    # weights are copied below. Any other parameters of the rebuilt module,
    # and its dtype, stay as the constructor left them -- confirm that this
    # is intended.
    fused_modules = [
        (path, mod)
        for path, mod in model.named_modules()
        if isinstance(mod, HunYuanDenseMoE)
    ]

    for path, fused in fused_modules:
        rebuilt = HunYuanMoE(fused.config, layer_idx=fused.layer_idx)
        rebuilt = rebuilt.to(next(fused.parameters()).device)

        # The router carries over unchanged.
        rebuilt.gate.load_state_dict(fused.gate.state_dict())

        # Expert ``i`` owns the i-th ``intermediate_size``-wide band of each
        # fused weight: a band of rows for gate/up, of columns for down.
        for expert_no, expert in enumerate(rebuilt.experts):
            band = slice(
                expert_no * fused.intermediate_size,
                (expert_no + 1) * fused.intermediate_size,
            )
            gate_rows = fused.fused_gate_proj.weight.data[band]
            expert.gate_proj.weight.data.copy_(gate_rows)

            up_rows = fused.fused_up_proj.weight.data[band]
            expert.up_proj.weight.data.copy_(up_rows)

            down_cols = fused.fused_down_proj.weight.data[:, band]
            expert.down_proj.weight.data.copy_(down_cols)

        _replace_submodule(model, path, rebuilt)

    return model
|
| 1669 |
+
|
| 1670 |
|
| 1671 |
@add_start_docstrings(
|
| 1672 |
"""
|
|
|
|
| 1907 |
sparse_out = self._sparse_path(x, probs)
|
| 1908 |
|
| 1909 |
out = dense_out + (sparse_out - dense_out).detach() # STE
|
| 1910 |
+
return out.view(bsz, seq_len, self.hidden_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|