codys12 committed on
Commit
bab2a77
·
verified ·
1 Parent(s): b161686

Upload modeling_hunyuan.py

Browse files
Files changed (1) hide show
  1. modeling_hunyuan.py +63 -63
modeling_hunyuan.py CHANGED
@@ -1605,6 +1605,68 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
1605
  )
1606
  return reordered_past
1607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1608
 
1609
  @add_start_docstrings(
1610
  """
@@ -1845,66 +1907,4 @@ class HunYuanDenseMoE(nn.Module):
1845
  sparse_out = self._sparse_path(x, probs)
1846
 
1847
  out = dense_out + (sparse_out - dense_out).detach() # STE
1848
- return out.view(bsz, seq_len, self.hidden_size)
1849
-
1850
- # -----------------------------------------------------------------------
1851
- # Helper for module replacement
1852
- # -----------------------------------------------------------------------
1853
- def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
1854
- """Replace a (possibly nested) sub‑module.
1855
-
1856
- ``target`` is the dotted path returned by ``model.named_modules()``.
1857
- """
1858
- parts = target.split('.')
1859
- parent = root
1860
- for p in parts[:-1]:
1861
- parent = getattr(parent, p)
1862
- setattr(parent, parts[-1], new_module)
1863
-
1864
- # -----------------------------------------------------------------------
1865
- # Public APIs
1866
- # -----------------------------------------------------------------------
1867
- def densify(model: nn.Module):
1868
- """Convert all :class:`HunYuanMoE` modules under *model* to
1869
- :class:`HunYuanDenseMoE`. Operates **in‑place**."""
1870
- replacements = []
1871
- for name, module in model.named_modules():
1872
- if isinstance(module, HunYuanMoE):
1873
- replacements.append((name, module))
1874
- for name, sparse_moe in replacements:
1875
- dense_moe = HunYuanDenseMoE(sparse_moe).to(next(sparse_moe.parameters()).device)
1876
- _replace_submodule(model, name, dense_moe)
1877
- return model
1878
-
1879
-
1880
- def sparsify(model: nn.Module):
1881
- """Rebuild standard sparse :class:`HunYuanMoE` modules from their
1882
- fused :class:`HunYuanDenseMoE` form. Operates **in‑place**."""
1883
- replacements = []
1884
- for name, module in model.named_modules():
1885
- if isinstance(module, HunYuanDenseMoE):
1886
- replacements.append((name, module))
1887
- for name, dense_moe in replacements:
1888
- cfg = dense_moe.config
1889
- sparse_moe = HunYuanMoE(cfg, layer_idx=dense_moe.layer_idx).to(next(dense_moe.parameters()).device)
1890
-
1891
- # Copy router
1892
- sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())
1893
-
1894
- # Slice fused weights back to per‑expert
1895
- for idx, expert in enumerate(sparse_moe.experts):
1896
- start = idx * dense_moe.intermediate_size
1897
- end = (idx + 1) * dense_moe.intermediate_size
1898
-
1899
- expert.gate_proj.weight.data.copy_(
1900
- dense_moe.fused_gate_proj.weight.data[start:end]
1901
- )
1902
- expert.up_proj.weight.data.copy_(
1903
- dense_moe.fused_up_proj.weight.data[start:end]
1904
- )
1905
- expert.down_proj.weight.data.copy_(
1906
- dense_moe.fused_down_proj.weight.data[:, start:end]
1907
- )
1908
-
1909
- _replace_submodule(model, name, sparse_moe)
1910
- return model
 
1605
  )
1606
  return reordered_past
1607
 
1608
+ # -----------------------------------------------------------------------
1609
+ # Helper for module replacement
1610
+ # -----------------------------------------------------------------------
1611
+ def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
1612
+ """Replace a (possibly nested) sub‑module.
1613
+
1614
+ ``target`` is the dotted path returned by ``model.named_modules()``.
1615
+ """
1616
+ parts = target.split('.')
1617
+ parent = root
1618
+ for p in parts[:-1]:
1619
+ parent = getattr(parent, p)
1620
+ setattr(parent, parts[-1], new_module)
1621
+
1622
+ # -----------------------------------------------------------------------
1623
+ # Public APIs
1624
+ # -----------------------------------------------------------------------
1625
+ def densify(model: nn.Module):
1626
+ """Convert all :class:`HunYuanMoE` modules under *model* to
1627
+ :class:`HunYuanDenseMoE`. Operates **in‑place**."""
1628
+ replacements = []
1629
+ for name, module in model.named_modules():
1630
+ if isinstance(module, HunYuanMoE):
1631
+ replacements.append((name, module))
1632
+ for name, sparse_moe in replacements:
1633
+ dense_moe = HunYuanDenseMoE(sparse_moe).to(next(sparse_moe.parameters()).device)
1634
+ _replace_submodule(model, name, dense_moe)
1635
+ return model
1636
+
1637
+
1638
+ def sparsify(model: nn.Module):
1639
+ """Rebuild standard sparse :class:`HunYuanMoE` modules from their
1640
+ fused :class:`HunYuanDenseMoE` form. Operates **in‑place**."""
1641
+ replacements = []
1642
+ for name, module in model.named_modules():
1643
+ if isinstance(module, HunYuanDenseMoE):
1644
+ replacements.append((name, module))
1645
+ for name, dense_moe in replacements:
1646
+ cfg = dense_moe.config
1647
+ sparse_moe = HunYuanMoE(cfg, layer_idx=dense_moe.layer_idx).to(next(dense_moe.parameters()).device)
1648
+
1649
+ # Copy router
1650
+ sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())
1651
+
1652
+ # Slice fused weights back to per‑expert
1653
+ for idx, expert in enumerate(sparse_moe.experts):
1654
+ start = idx * dense_moe.intermediate_size
1655
+ end = (idx + 1) * dense_moe.intermediate_size
1656
+
1657
+ expert.gate_proj.weight.data.copy_(
1658
+ dense_moe.fused_gate_proj.weight.data[start:end]
1659
+ )
1660
+ expert.up_proj.weight.data.copy_(
1661
+ dense_moe.fused_up_proj.weight.data[start:end]
1662
+ )
1663
+ expert.down_proj.weight.data.copy_(
1664
+ dense_moe.fused_down_proj.weight.data[:, start:end]
1665
+ )
1666
+
1667
+ _replace_submodule(model, name, sparse_moe)
1668
+ return model
1669
+
1670
 
1671
  @add_start_docstrings(
1672
  """
 
1907
  sparse_out = self._sparse_path(x, probs)
1908
 
1909
  out = dense_out + (sparse_out - dense_out).detach() # STE
1910
+ return out.view(bsz, seq_len, self.hidden_size)