Jingjing Zhai
Remove dependency on external yairschiff/caduceus_base repository; switch to self-contained config and local model files
5ac1d5e
| """Caduceus config for Hugging Face. | |
| """ | |
| from typing import Optional, Union | |
| from transformers import PretrainedConfig | |
class CaduceusConfig(PretrainedConfig):
    """Hugging Face config for Caduceus models.

    Subclasses ``PretrainedConfig`` and mirrors the fields of the original
    ``MambaConfig``, adding parameters relevant to bi-directionality and
    reverse-complement (RC) equivariance. Every constructor argument is stored
    unchanged as an attribute of the same name.

    Args:
        d_model: Carried over from the original ``MambaConfig``.
        n_layer: Carried over from the original ``MambaConfig``.
        vocab_size: Carried over from the original ``MambaConfig``.
        ssm_cfg: Optional dict, carried over from the original ``MambaConfig``.
        rms_norm: Carried over from the original ``MambaConfig``.
        residual_in_fp32: Carried over from the original ``MambaConfig``.
        fused_add_norm: Carried over from the original ``MambaConfig``.
        pad_vocab_size_multiple: Carried over from the original ``MambaConfig``.
        norm_epsilon: Epsilon used in layer norm. Not part of the original
            ``MambaConfig``; its default matches the default argument of
            ``create_block`` in the ``mamba_ssm`` repo.
        initializer_cfg: Optional dict used in ``init_weights``.
        bidirectional: Caduceus-specific bi-directionality flag.
        bidirectional_strategy: Caduceus-specific strategy name (default
            ``"add"``). NOTE(review): presumably selects how the two
            directions are combined; confirm against the modeling code.
        bidirectional_weight_tie: Caduceus-specific flag. NOTE(review):
            presumably ties weights across directions; confirm.
        rcps: Caduceus-specific RC-equivariance switch.
        complement_map: Optional dict used by ``RCPSEmbedding`` /
            ``RCPSLMHead``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "caduceus"

    def __init__(
        self,
        # From original MambaConfig
        d_model: int = 2560,
        n_layer: int = 64,
        vocab_size: int = 50277,
        ssm_cfg: Optional[dict] = None,
        rms_norm: bool = True,
        residual_in_fp32: bool = True,
        fused_add_norm: bool = True,
        pad_vocab_size_multiple: int = 8,
        # Not in original MambaConfig, but default arg in create_block in mamba_ssm repo; used in layer norm
        norm_epsilon: float = 1e-5,
        # Used in init_weights
        initializer_cfg: Optional[dict] = None,
        # Caduceus-specific params
        bidirectional: bool = True,
        bidirectional_strategy: Optional[str] = "add",
        bidirectional_weight_tie: bool = True,
        rcps: bool = False,
        complement_map: Optional[dict] = None,  # used for RCPSEmbedding / RCPSLMHead
        **kwargs,
    ):
        # The base class runs first, so the explicit assignments below take
        # precedence over anything it may have set from **kwargs.
        super().__init__(**kwargs)

        # Fields mirrored from MambaConfig.
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.ssm_cfg = ssm_cfg
        self.rms_norm = rms_norm
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.pad_vocab_size_multiple = pad_vocab_size_multiple
        self.norm_epsilon = norm_epsilon
        self.initializer_cfg = initializer_cfg

        # Caduceus-specific fields.
        self.bidirectional = bidirectional
        self.bidirectional_strategy = bidirectional_strategy
        self.bidirectional_weight_tie = bidirectional_weight_tie
        self.rcps = rcps
        self.complement_map = complement_map