Initial upload with README & teaser

Files changed:
- .gitattributes +1 -0
- README.md +14 -3
- assets/NovoMolGen.png +3 -0
- modeling_novomolgen.py +6 -6
 
    	
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/NovoMolGen.png filter=lfs diff=lfs merge=lfs -text
    	
README.md CHANGED

@@ -14,7 +14,12 @@ pipeline_tag: text-generation
 
 # NovoMolGen
 
-NovoMolGen is a family of molecular foundation models trained on
+NovoMolGen is a family of molecular foundation models trained on
+1.5 billion ZINC-22 molecules with Llama architectures and FlashAttention.
+It achieves state-of-the-art performance on both unconstrained and
+goal-directed molecule generation tasks.
+
+<img src="assets/NovoMolGen.png" width="900"/>
 
 ## How to load
 
@@ -24,9 +29,14 @@ tokenizer = AutoTokenizer.from_pretrained("chandar-lab/NovoMolGen_157M_SMILES_At
 model = AutoModelForCausalLM.from_pretrained("chandar-lab/NovoMolGen_157M_SMILES_AtomWise", trust_remote_code=True)
 ```
 
-##
+## Quick-start (FlashAttention + bf16)
 
 ```python
+from accelerate import Accelerator
+
+acc = Accelerator(mixed_precision='bf16')
+model = acc.prepare(model)
+
 outputs = model.sample(tokenizer=tokenizer, batch_size=4)
 print(outputs['SMILES'])
 ```
@@ -36,7 +46,8 @@ print(outputs['SMILES'])
 ```bibtex
 @article{chitsaz2024novomolgen,
   title={NovoMolGen: Rethinking Molecular Language Model Pretraining},
-  author={Chitsaz, Kamran and Balaji, Roshan and Fournier, Quentin and
+  author={Chitsaz, Kamran and Balaji, Roshan and Fournier, Quentin and
+          Bhatt, Nirav Pravinbhai and Chandar, Sarath},
   journal={arXiv preprint},
   year={2025},
 }
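Taken together, the two README snippets above describe one workflow: load the tokenizer and model from the Hub, wrap the model with Accelerate for bf16 inference, then sample SMILES strings. Below is a minimal end-to-end sketch assembled from those snippets; it assumes accelerate, flash-attn, and a GPU are available, and relies only on the calls shown in the README diff.

```python
# Minimal sketch combining the "How to load" and "Quick-start (FlashAttention + bf16)"
# snippets added to the README in this commit. Assumes accelerate and flash-attn are
# installed and a GPU is available.
from accelerate import Accelerator
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "chandar-lab/NovoMolGen_157M_SMILES_AtomWise"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

# Run the model in bf16 mixed precision on the available device via Accelerate.
acc = Accelerator(mixed_precision='bf16')
model = acc.prepare(model)

# `sample` is a custom method provided by the model's remote code; per the README,
# it returns a dict whose 'SMILES' entry holds the generated molecules.
outputs = model.sample(tokenizer=tokenizer, batch_size=4)
print(outputs['SMILES'])
```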
    	
assets/NovoMolGen.png ADDED

Binary image (the README teaser figure), stored with Git LFS.
    	
modeling_novomolgen.py CHANGED

@@ -33,7 +33,7 @@ except ImportError:
     inv_remap_state_dict_hf_llama = None
 
 
-def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=None, dtype=None):
+def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=None, dtype=None, **kwargs):
     """
     code modified from: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/pretrained.py
     """
@@ -45,10 +45,10 @@ def state_dict_from_pretrained(model_name, checkpoint_path: str = "", device=Non
 
     # Try loading from HF hub instead of from local files
     resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_NAME),
-                                        _raise_exceptions_for_missing_entries=False)
+                                        _raise_exceptions_for_missing_entries=False, **kwargs)
     if resolved_archive_file is None:
         resolved_archive_file = cached_file(model_name, os.path.join(checkpoint_path, WEIGHTS_INDEX_NAME),
-                                            _raise_exceptions_for_missing_entries=False)
+                                            _raise_exceptions_for_missing_entries=False, **kwargs)
         if resolved_archive_file is not None:
             is_sharded = True
 
@@ -115,7 +115,7 @@ class NovoMolGenConfig(LlamaConfig):
 
         resolved_archive_config_file = cached_file(pretrained_model_name_or_path,
                                                    os.path.join(checkpoint_path, "config.json"),
-                                                   _raise_exceptions_for_missing_entries=False)
+                                                   _raise_exceptions_for_missing_entries=False, force_download=force_download)
 
         if resolved_archive_config_file is not None:
             with open(resolved_archive_config_file, "r", encoding="utf-8") as reader:
@@ -266,13 +266,13 @@ class NovoMolGen(GPTLMHeadModel):
             **kwargs,
             ):
         if config is None:
-            config = NovoMolGenConfig.from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path)
+            config = NovoMolGenConfig.from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
         model = cls(config)
 
         if os.path.exists(pretrained_model_name_or_path):
             state_dict = torch.load(os.path.join(pretrained_model_name_or_path, checkpoint_path, WEIGHTS_NAME))
         else:
-            state_dict = state_dict_from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path)
+            state_dict = state_dict_from_pretrained(pretrained_model_name_or_path, checkpoint_path=checkpoint_path, **kwargs)
         model.load_state_dict(state_dict)
         return model
 
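The net effect of these changes is that download-related keyword arguments passed to `NovoMolGen.from_pretrained` (for example when it is reached through `AutoModelForCausalLM.from_pretrained` with `trust_remote_code=True`) are now forwarded through `NovoMolGenConfig.from_pretrained` and `state_dict_from_pretrained` to the underlying `cached_file` lookups, with the config lookup additionally wiring `force_download` through explicitly. The sketch below illustrates that path under those assumptions; it is not code from the repository, and the choice of `force_download` is only an example of a standard Hub download option.

```python
# Illustrative, assumption-labeled sketch: with the **kwargs forwarding added in this
# commit, Hub download options supplied at the top level can reach the cached_file()
# calls that resolve config.json and the weight files.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "chandar-lab/NovoMolGen_157M_SMILES_AtomWise",
    trust_remote_code=True,
    force_download=True,  # forwarded: from_pretrained -> state_dict_from_pretrained -> cached_file
)
```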