{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9371cf89",
   "metadata": {},
   "source": [
    "# Loading Script"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f67fdbad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "import os\n",
    "from pathlib import Path\n",
    "import importlib.util\n",
    "import huggingface_hub\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "def download_and_setup_model(model_name=\"gbyuvd/ChemMiniQ3-SAbRLo\", local_dir=\"./chemq3_model\"):\n",
    "    \"\"\"Download the model repository snapshot from the Hugging Face Hub.\n",
    "\n",
    "    Args:\n",
    "        model_name: Hub repo id to download.\n",
    "        local_dir: Local directory the snapshot is written into.\n",
    "\n",
    "    Returns:\n",
    "        Path to the downloaded snapshot, or None if the download failed\n",
    "        (the error is printed, not raised).\n",
    "    \"\"\"\n",
    "    print(f\"📥 Downloading model files from {model_name}...\")\n",
    "\n",
    "    try:\n",
    "        # Download all files to a local directory\n",
    "        model_path = huggingface_hub.snapshot_download(\n",
    "            repo_id=model_name,\n",
    "            local_dir=local_dir,\n",
    "            local_files_only=False,\n",
    "            resume_download=True\n",
    "        )\n",
    "\n",
    "        print(f\"✅ Model downloaded to: {model_path}\")\n",
    "\n",
    "        # List downloaded files (top level only; subdirectories are skipped)\n",
    "        print(\"📁 Downloaded files:\")\n",
    "        for file in Path(model_path).iterdir():\n",
    "            if file.is_file():\n",
    "                print(f\"   {file.name} ({file.stat().st_size} bytes)\")\n",
    "\n",
    "        return Path(model_path)\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"❌ Download failed: {e}\")\n",
    "        return None\n",
    "\n",
    "def load_custom_modules(model_path):\n",
    "    \"\"\"Import the custom configuration, modeling and tokenizer modules from ``model_path``.\n",
    "\n",
    "    Each file is executed as a top-level module named after the file and is\n",
    "    registered in ``sys.modules`` *before* it runs. A later file that imports\n",
    "    one of these module names therefore gets the same module object (and the\n",
    "    same classes) that is registered with transformers, rather than a second\n",
    "    copy imported through ``sys.path``.\n",
    "\n",
    "    Args:\n",
    "        model_path: Directory containing the ``.py`` files.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping module name -> module object, or None if a file is\n",
    "        missing or fails to import.\n",
    "    \"\"\"\n",
    "    model_path = Path(model_path)\n",
    "\n",
    "    # Add the model directory to Python path so the modules' own imports resolve\n",
    "    if str(model_path) not in sys.path:\n",
    "        sys.path.insert(0, str(model_path))\n",
    "\n",
    "    print(f\"🔧 Loading custom modules from {model_path}...\")\n",
    "\n",
    "    # Required module files, loaded in this order (configuration first)\n",
    "    required_files = {\n",
    "        'configuration_chemq3mtp.py': 'configuration_chemq3mtp',\n",
    "        'modeling_chemq3mtp.py': 'modeling_chemq3mtp',\n",
    "        'FastChemTokenizerHF.py': 'FastChemTokenizerHF'\n",
    "    }\n",
    "\n",
    "    loaded_modules = {}\n",
    "\n",
    "    # Load each required module\n",
    "    for filename, module_name in required_files.items():\n",
    "        file_path = model_path / filename\n",
    "\n",
    "        if not file_path.exists():\n",
    "            print(f\"❌ Required file not found: {filename}\")\n",
    "            return None\n",
    "\n",
    "        try:\n",
    "            spec = importlib.util.spec_from_file_location(module_name, file_path)\n",
    "            module = importlib.util.module_from_spec(spec)\n",
    "\n",
    "            # Register before executing (standard importlib recipe) so that\n",
    "            # imports of this name from the other files reuse this object\n",
    "            sys.modules[module_name] = module\n",
    "            spec.loader.exec_module(module)\n",
    "            loaded_modules[module_name] = module\n",
    "\n",
    "            print(f\"   ✅ Loaded {filename}\")\n",
    "\n",
    "        except Exception as e:\n",
    "            # Don't leave a half-initialised module behind in sys.modules\n",
    "            sys.modules.pop(module_name, None)\n",
    "            print(f\"   ❌ Failed to load {filename}: {e}\")\n",
    "            return None\n",
    "\n",
    "    return loaded_modules\n",
    "\n",
    "def register_model_components(loaded_modules):\n",
    "    \"\"\"Register the custom config, model and tokenizer classes with the transformers Auto* classes.\n",
    "\n",
    "    Args:\n",
    "        loaded_modules: Dict returned by ``load_custom_modules``.\n",
    "\n",
    "    Returns:\n",
    "        ``(config_class, model_class, tokenizer_class)``, or\n",
    "        ``(None, None, None)`` if a class is missing or registration fails.\n",
    "    \"\"\"\n",
    "    print(\"🔗 Registering model components...\")\n",
    "\n",
    "    try:\n",
    "        # Get the classes from loaded modules\n",
    "        ChemQ3MTPConfig = loaded_modules['configuration_chemq3mtp'].ChemQ3MTPConfig\n",
    "        ChemQ3MTPForCausalLM = loaded_modules['modeling_chemq3mtp'].ChemQ3MTPForCausalLM\n",
    "        FastChemTokenizerSelfies = loaded_modules['FastChemTokenizerHF'].FastChemTokenizerSelfies\n",
    "\n",
    "        # Register with transformers under the \"chemq3_mtp\" model type\n",
    "        AutoConfig.register(\"chemq3_mtp\", ChemQ3MTPConfig)\n",
    "        AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)\n",
    "        AutoTokenizer.register(ChemQ3MTPConfig, FastChemTokenizerSelfies)\n",
    "\n",
    "        print(\"✅ Model components registered successfully\")\n",
    "\n",
    "        return ChemQ3MTPConfig, ChemQ3MTPForCausalLM, FastChemTokenizerSelfies\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"❌ Registration failed: {e}\")\n",
    "        return None, None, None\n",
    "\n",
    "def load_model(model_path):\n",
    "    \"\"\"Load config, model and tokenizer from ``model_path`` through the registered Auto* classes.\n",
    "\n",
    "    Weights are loaded as float16 when CUDA is available, float32 otherwise.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, tokenizer, config)``, or ``(None, None, None)`` on failure\n",
    "        (the traceback is printed).\n",
    "    \"\"\"\n",
    "    print(\"🚀 Loading model...\")\n",
    "\n",
    "    try:\n",
    "        # Load config\n",
    "        config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=False)\n",
    "        print(f\"✅ Config loaded: {config.__class__.__name__}\")\n",
    "\n",
    "        # Load model\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            str(model_path),\n",
    "            config=config,\n",
    "            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\n",
    "            trust_remote_code=False  # We've already registered everything\n",
    "        )\n",
    "        print(f\"✅ Model loaded: {model.__class__.__name__}\")\n",
    "\n",
    "        # Load tokenizer\n",
    "        tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=False)\n",
    "        print(f\"✅ Tokenizer loaded: {tokenizer.__class__.__name__}\")\n",
    "\n",
    "        return model, tokenizer, config\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"❌ Model loading failed: {e}\")\n",
    "        import traceback\n",
    "        traceback.print_exc()\n",
    "        return None, None, None\n",
    "\n",
    "def test_model(model, tokenizer, config):\n",
    "    \"\"\"Smoke-test the loaded model and print the results.\n",
    "\n",
    "    Prints model metadata, tokenizes a few SELFIES strings, samples three\n",
    "    continuations for each test prompt and reports which MTP helper methods\n",
    "    the model exposes. Tokenization and generation errors are printed per\n",
    "    input rather than raised.\n",
    "\n",
    "    Side effects: moves ``model`` to the available device, sets eval mode,\n",
    "    and sets ``tokenizer.pad_token`` to the EOS token if no pad token is set.\n",
    "    \"\"\"\n",
    "    print(\"\\n🧪 Testing model...\")\n",
    "\n",
    "    # Setup device\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(f\"🖥️ Using device: {device}\")\n",
    "\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # Model info\n",
    "    print(f\"\\n📊 Model Information:\")\n",
    "    print(f\"   Model class: {model.__class__.__name__}\")\n",
    "    print(f\"   Config class: {config.__class__.__name__}\")\n",
    "    print(f\"   Tokenizer class: {tokenizer.__class__.__name__}\")\n",
    "    print(f\"   Model type: {config.model_type}\")\n",
    "    print(f\"   Vocab size: {config.vocab_size}\")\n",
    "\n",
    "    # Set pad token if needed\n",
    "    if not hasattr(tokenizer, 'pad_token') or tokenizer.pad_token is None:\n",
    "        if hasattr(tokenizer, 'eos_token'):\n",
    "            tokenizer.pad_token = tokenizer.eos_token\n",
    "            print(\"✅ Set pad_token to eos_token\")\n",
    "\n",
    "    # Test tokenization\n",
    "    print(\"\\n🔤 Testing tokenization...\")\n",
    "    test_inputs = [\"[C][C][O]\", \"[C]\", \"[O]\"]\n",
    "\n",
    "    for test_input in test_inputs:\n",
    "        try:\n",
    "            tokens = tokenizer(test_input, return_tensors=\"pt\")\n",
    "            print(f\"   '{test_input}' -> {tokens.input_ids.tolist()}\")\n",
    "        except Exception as e:\n",
    "            print(f\"   ❌ Tokenization failed for '{test_input}': {e}\")\n",
    "            continue\n",
    "\n",
    "    # Test generation\n",
    "    print(\"\\n🎯 Testing generation...\")\n",
    "    test_prompts = [\"[C]\", \"[C][C]\"]\n",
    "\n",
    "    for prompt in test_prompts:\n",
    "        try:\n",
    "            input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                outputs = model.generate(\n",
    "                    input_ids,\n",
    "                    max_length=input_ids.shape[1] + 20,\n",
    "                    temperature=0.8,\n",
    "                    top_p=0.9,\n",
    "                    top_k=50,\n",
    "                    do_sample=True,\n",
    "                    pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0,\n",
    "                    num_return_sequences=3\n",
    "                )\n",
    "\n",
    "            print(f\"\\n   Prompt: '{prompt}'\")\n",
    "            for i, output in enumerate(outputs):\n",
    "                generated = tokenizer.decode(output, skip_special_tokens=True)\n",
    "                print(f\"     {i+1}: {generated}\")\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"   ❌ Generation failed for '{prompt}': {e}\")\n",
    "\n",
    "    # Test MTP functionality if available\n",
    "    print(\"\\n🔬 Testing MTP functionality...\")\n",
    "    try:\n",
    "        has_mtp_training = hasattr(model, 'set_mtp_training')\n",
    "        has_mtp_generation = hasattr(model, 'generate_with_logprobs')\n",
    "        if has_mtp_training:\n",
    "            print(\"   ✅ MTP training methods available\")\n",
    "        if has_mtp_generation:\n",
    "            print(\"   ✅ MTP generation methods available\")\n",
    "        # Only call it a standard model when neither MTP hook is present\n",
    "        if not (has_mtp_training or has_mtp_generation):\n",
    "            print(\"   ℹ️ Standard model - no MTP methods detected\")\n",
    "    except Exception as e:\n",
    "        print(f\"   ⚠️ MTP test error: {e}\")\n",
    "\n",
    "def main(model_name=\"gbyuvd/ChemMiniQ3-SAbRLo\", local_dir=\"./chemq3_model\"):\n",
    "    \"\"\"Download, import, register, load and smoke-test the model.\n",
    "\n",
    "    Args:\n",
    "        model_name: Hub repo id to download.\n",
    "        local_dir: Local directory for the downloaded snapshot.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, tokenizer, config)`` on success, ``(None, None, None)`` if any step fails.\n",
    "    \"\"\"\n",
    "    print(\"🚀 ChemQ3-MTP Model Loader Starting...\\n\")\n",
    "\n",
    "    # Step 1: Download model files\n",
    "    model_path = download_and_setup_model(model_name, local_dir)\n",
    "    if model_path is None:\n",
    "        return None, None, None\n",
    "\n",
    "    print()\n",
    "\n",
    "    # Step 2: Load custom modules\n",
    "    loaded_modules = load_custom_modules(model_path)\n",
    "    if loaded_modules is None:\n",
    "        return None, None, None\n",
    "\n",
    "    print()\n",
    "\n",
    "    # Step 3: Register components\n",
    "    config_class, model_class, tokenizer_class = register_model_components(loaded_modules)\n",
    "    if config_class is None:\n",
    "        return None, None, None\n",
    "\n",
    "    print()\n",
    "\n",
    "    # Step 4: Load the model\n",
    "    model, tokenizer, config = load_model(model_path)\n",
    "    if model is None:\n",
    "        return None, None, None\n",
    "\n",
    "    # Step 5: Test the model\n",
    "    test_model(model, tokenizer, config)\n",
    "\n",
    "    print(\"\\n🎉 Model loading and testing completed successfully!\")\n",
    "\n",
    "    return model, tokenizer, config\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    model, tokenizer, config = main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2ea169c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate SELFIES\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "input_ids = tokenizer(\"\", return_tensors=\"pt\").input_ids.to(device)\n",
    "gen = model.generate(input_ids, max_length=256, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n",
    "print(tokenizer.decode(gen[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcd4f1fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Manually convert it to SMILES\n",
    "import selfies as sf\n",
    "\n",
    "test = tokenizer.decode(gen[0], skip_special_tokens=True)\n",
    "test = test.replace(' ', '')\n",
    "print(sf.decoder(test))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eaf273a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate Mol Viz\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import Draw\n",
    "\n",
    "input_ids = tokenizer(\"\", return_tensors=\"pt\").input_ids.to(device)\n",
    "gen = model.generate(input_ids, max_length=25, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n",
"generatedmol = tokenizer.decode(gen[0], skip_special_tokens=True)\n", "\n", "test = generatedmol.replace(' ', '')\n", "csmi_gen = sf.decoder(test)\n", "print(csmi_gen)\n", "mol = Chem.MolFromSmiles(csmi_gen)\n", "\n", "# Draw the molecule\n", "Draw.MolToImage(mol)" ] }, { "cell_type": "code", "execution_count": 12, "id": "3f68f519", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using MTP-specific generation...\n", "Generated SELFIES: .[C][N][C][Branch1][#C][S][C][C][=Branch1][C][=O][C][=C][C][=C][C][=C][Ring1][=Branch1][=N][N][=C][Ring1][#C][C][C][=C][O][C][=Ring1][Branch1][C]\n", "Decoded SMILES: CN1C(SCC(=O)C2=CC=CC=C2)=NN=C1C=3C=COC=3C\n" ] }, { "data": { "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/2wBDAQkJCQwLDBgNDRgyIRwhMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjIyMjL/wAARCAEsASwDASIAAhEBAxEB/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/8QAHwEAAwEBAQEBAQEBAQAAAAAAAAECAwQFBgcICQoL/8QAtREAAgECBAQDBAcFBAQAAQJ3AAECAxEEBSExBhJBUQdhcRMiMoEIFEKRobHBCSMzUvAVYnLRChYkNOEl8RcYGRomJygpKjU2Nzg5OkNERUZHSElKU1RVVldYWVpjZGVmZ2hpanN0dXZ3eHl6goOEhYaHiImKkpOUlZaXmJmaoqOkpaanqKmqsrO0tba3uLm6wsPExcbHyMnK0tPU1dbX2Nna4uPk5ebn6Onq8vP09fb3+Pn6/9oADAMBAAIRAxEAPwD3+iiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKK
KACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKzPEepDRvDOqannBtLSWYfVVJH6igCj4P8AE58WabeX4tBbwxX01tCRJv8AORGwH6DGeeOenWuhrkfhfpp0r4Z6BbsMO9qJ2z1zITIc/wDfVddQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRTJporeF5p5UiijUs7uwVVA6kk9BQA+ivP0+MvhF9aNibi5W03+WuqNDi0Z/wC6JM/rjHvjmu9iljnhSaGRJInAZHRgQwPQgjqKAH0UUUAFFFFABRRRQAUUUUAFcF8YZ5B8Pp9PgOLjVLmCxi+ryDI/INXe15343mi1T4i+B9AWRGMd5JqM6BhlPKjJjJHud2PoaAPQLeCO1toreIbY4kCIPQAYFSUUUAFFFFABRRRQBS1jUV0fRb7UnjaRLSB5yinBYKpOB+VU9P8AEdtqN1YQQxODd20053cGIxPGjxsP7waTB/3TVPx7f2ln4RvLa6lZG1FTYwBE3s0koKjCjk4zk+wNUb7wXZax4jvLldQdbOW3eG4t7aXY8c7PC+8MPukiJMj8e5r0KFGi6SlWur319Ev1f4W6kNu+h2dFcL/wget2BzovjrWIcdEvwt2v0+bGBQW+Jmm9Y9A1iIf3S9vK35/KKn6nTl/DqxfreL/FW/EfM+qO6orgofFOvQ608urWi2NkLaFLi03rKbWWR5VSTzAOVJRQew3g9jXVeG7ue/8AC+kXly/mXFxZQyyvgDczICTgcDk9qzr4SpRjzSatps776+n9LoCkmadFFFcpQUUUUAFFFFABRRRQAUUUUAFFFeY+IdQ8ReJvHOqeFbPUG0jR9OtI7i6uLNC13cBxnYh/h7jgZ475xQBveJPiHpmiXv8AZNhDNrOvPwmm2I3OD6yN0Qeuee+K878QXMt9qEcXji7k1XUnIa18H6IxKKeo89x1x1OTx1GRxWh4c8P6xqtkbPwvp8ng/wAOy/63UJ13alfD1GeUB9SeM5HpXo/hnwfonhK1aHSrQJJJzNcyHfNMfV3PJ55x09qAMvwppOtXGj3Vv4q0/SILCdVS20a2gDR20YzlWPRieOgwMcdcVlTeA9Z8KTPeeANTEEJJeTRL5i9rJ67D1jP+cgV6NRQBxOhfEiwvdQXRtetZtA1zp9kvSAkp9YpPuuPToT2zXbVkeJNF0TW9Gnh1+zguLKNGkYyrzGAOWUjlTjuOa898LXd7oHjjQ9F0vXptX8K6vZS3NqLz5pIAgyAr4BK9AAexPHGaAPWaKKKACiiqerLcto18tmSLo28ghI679p2/rik3ZXKhHmko3tctLIjsyq6ll+8AeR9awdX8ZaRpM/2QSPe35OFs7NfNkJ9CBwPxrJig0W8k0i20iKMTfMl2IV2yJCYmDiXHIJYr97ndzXS6ToOl6FB5Om2UVuD95lGWb6seT+NZKU5/DY75UcPQd6vM30W3VrV69ui+aOQudP8AHniydlnv4/DGjNjEdp+8vZB3Bfon1HNc14h8HeE9AvtO0S1S80rWbs/aLDxHLIXzdKeI3cnq2eRgA5GOa9lrL8Q+H9O8UaLPpWqQCW2mHbhkbsynsw9a1WiOCTTk2lbyMPwV4wm1lrjRNcgWy8S6f8t3bdpV7Sx+qHj6Z+mewrw6807U49Ys/D2rX32XxXYZfw54gPC38Y/5YSnuexBz175+f0jwX4xTxPaz2t5bmx12wbyr+wfrG395fVD1B/8A1lknU0UUUAFFFcX8SNfu9O0e30bRznXdbl+x2QB5jB+/L7BVOc9iQaAOZ1nxBFe67qHiyZPO0nw9m00qLGRd3r8Fh6gcAH
8R3rtPA3h6bQNA3XzeZq19Ibu/lPVpW5I+g6fn61xHhbQrbVPF1npVhk+G/Bo8pW7XWoEfOx9duSfYn0Neu131sc6lFUYxstPuWy+9uT7t+SIUbO4UUUVwFnl3iDxS0Nj4wuW0y0aSzurewuhKXZZrVjgEjIwf3jdPXvXpVnaQWFjb2dsmy3t41iiXJO1VGAMnk8CvI/FVujar8SrDeubnTbW/CHjAiHJyeOo/SvVNDuvt2gabd7g3n2sUuQcg7lB/rXrY50nQg6T666/3Ifk+ZfgZxvfUv0UUV5JoFFFFABRRRQAUUUUAFFFFABXn+jcfHDxMP72mWp/U16BXn+l8fHTXh/e0e3P/AI+aAPQKKKKACiiigDO8QDd4b1UetnKP/HDXkXg4/wDE6+Fjf3tHu1/JRXsGtDdoWoD1tpB/46a8b8Gn/iZ/CVv72nX6/klAHuVFFFABRRRQBy8FzPF8TbuzeeRoJtMSeOMuSqkPtOB0FdRXF6zdx2PxX8MowbfqNpdQKQOP3YEhBrtKiCavfudWKnCfI4fypP1Wn+QUUUVZymJ4q8L6f4t0WTTb9WXkPBPHxJBIPuuh7Ef/AFq8x8T6frnh7w3a+KdUuoIfF+mXK2Vrc22X/taIsAqSJxywycdRjPHb2mvNrT/ivfibJfn59B8MOYbfus96fvP7hBgD3wR1oAnW5+LN8ilNP8LacCBkXEs0rr/3xxmlHh34lXh/0rxzY2KnqtlpSP8Aq5yK9BooA4D/AIVvqd3/AMhP4geJps9RazLbA/goNZWq/DS48NT2/ifwld3t7rVhktBqNwZvtcRGGjBP3WIzgivVKKAOe8Ga7o3iLQ/7Q0aGO3WSRjc24jCPFP8AxhwP4s9+/Wuhrg/Eeg33h7WJfF/ha3D3Dc6ppyjAvUH8Q/6aDr7/AFyG6rQtdsPEekQanps3mW8o78Mjd1YdiPSuiWHkqSrJ3T0fk+z/ADXf5MV9bGlRRWNr/ivRPDFsZtWv44DglYhlpH/3UGSfyxWCi5aJDOK8TafJN8VLmJY2MOq+FbizYgcbw5PPvg4rf+Fl59u+F/h6XOdtoIv++CU/9lrI/t/xz4x+Xw7pS+HtMf8A5iWrJunYescPb/gXBqjpmjeMfhfYR22mKnifw9FlmtUQQ3UGSSxj6hxkk4684GKQHqlFYHhfxlovi61kl0u5JlhO2e2mXZNA3o6np0PPI4PNb9ABRRTZJEhjaSRgiKMlmOABQ3bcB1FecT+ONa8XyvZeALENbhikuu3yFbeP18tSMyN+GPbHNRr4G8X+HSdT0DxdcalqMnz3dpq3zW9y3+xjmL0GPbkCgD0uiuI0T4k2NzqC6N4htJfD+udPs14R5cvvFJ91h+Xtmu3oAKKKiubq3sraS5up44IIl3PLKwVVHqSeBQBLXn1mwj+O2sMxwv8AYULE/SQ1FP471jxbO9h8P7ESwglJdcvUK20fr5YIzI36exHNRr8JXj/4mkXivV18Uk7m1YvkP/sGL7vl8fd/XtQB2Hh3xZofiq1a40bUIrkIcSR8rJGfRkOCPyrarwrX7Z7HUUuPGuny6JqanbB4u0IHy3PQecg5Ge+Rz0GBXTWPjvXPDVtFL4phi1fQ3H7rxFpA8yPb6yxr933I47DNAHp9FU9L1bT9asI77TLyG7tZPuywuGH09j7Hmp7m5gs7aS4uZo4YI13PJIwVVHqSeAKAGX8LXGnXMKDLyRMij3IIryDR9JvPD/iX4T6VqMaxXtvb6kksYcNt/d5HI46V0Nx481bxVcSaf8P7BbiNSUl1u8UpaxHvsGMyN9OOnUVq+Gvh7ZaLqY1zUry41nxCVIbUbtjlARgiNM4QYJHrgkZwcUAdjRRRQAUUUUAed/EqaLTPEHgfWZ5Uhht9X+zySuwVUWWMgkk8AYWr2qeKvFNnqUz6Z4Tj1rRvl8m6s9RjEhG0Zyh6856dsV1OqaRput2f2TVLC3vbfcG8u4jDqGHQgHoeTz71yE
/wg8ImUz6fbXmk3B/5a6deSREfQZI/SgCIfF7RrTC69pOu6Gw4LX1g+z8GXOR+Fb+l+PPCes4Fh4h06V26RmcI5/4C2D+lc+fAvi3Tgf7F+IWoMn/PLVbdLvcPQucEfgKwtQ8M+Lt6yar4I8H+IdhyJLQ/ZZz9WcYz9OKAOv8AiL4iutF0KOw0n5tc1eUWWnoDyHb70nsFBznscVr+FfDtr4U8NWWjWnzJbph5COZHPLOfckk1yfhHTta8Q+Mrvxh4m0qTTHto/sel2MrhzCp/1kme5YnGcdMj0r0WgAooooAKKKKACuG1aA+EtdhutBtQq38kt3qVugz56oq52D+FgGZuMbiMHrXc1ynjHU20rUfDkltbxTX1zqC2cfmFsJHIP3jYBHICjrXZgXL2vIldNO66PR7+m/luiZbXMxX8UeLVm/snX7Wx0k3MqG6ih3zsgPyhD93BUj5uv1rY0HwHoXh+c3cVu93qLcvfXr+bMx9dx6fgBWxpekWOi20ltp8AgheV5igJwGY5OPQZ7dBV6qrY2TTp0fdh5JJv1tv+QKPVhRRRXCUULnR7O4macRCG5bBM8Q2uSOmT3/GmRS6hayrFcR/aYjwJ4xgj/eFaVBzg461zTw0XP2kG4y626+q2f5+Y79CgL6eSz+0Q2ok3Y2KkgJ59fT9aiTTZbtxLqcgkxysCcIv+NWre1eO5kncx7nUKRGu0H3PPWrVZRw7qpPEXdumlvVpfk72He2w2ONIY1jiRURRhVUYAHsKdRRXcSZut6BpPiTT2sdYsIby3b+GReVPqp6qfcYNcR/YXjHwJ+88N3T+IdETrpV/Ji4iX0il7j0U/gCa9JooA8tHxy0LULW2g0Oxvb/Xbp/Ki0xk8sq/+25+UD3BP4VdtfAOp+J7mPUviBfrd7W3w6NaEraQntu7yN9eOvUVwXheKOPS/h3KsaCT+37pGcKMkbpMAn8TX0FQBHBBDbQJBbxJFDGoVI41CqoHYAdBUlFFADJYo54nilRZI3BVkcZDA9iO9ef33w3n0e5l1LwJqP9j3Eh3S6dKC9jcexT+D6r07AV6HRQB4rpXiCLwpr11DH8PNRtPFt9GE+x2Lj7FcYOfMVs7VHqcHGPrXS2/gLVfFNzHqHxAv1uUVt8Oi2bFbWI9t56yN9eOvUV6LRQBHb28Nrbx29vDHDDGoVI41Cqo9ABwBUlFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAVw2t41L4ueG7HcpXTrO4v5E3DPzfu1JHXg967mvPvC3/Ez+LXjPVDzHZpbabC30XfIP8AvrFbUKzoyco7tNfemv1E1c9BooorEYUUUUAFFFFABRRRQAUUUUAFFFFAHjmjafaxaZ4RVYsCHX5inzH5WLNk9a9jrzHT9K1BLHQo2sbgPBrzyyKYyCkZLfOR2X3r06sKDk173l+R62a06MJr2SW8tv8AE7fgFFFFbnkhRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAUUUUAFFFFABRRRQAVnaVoenaK9+9hAYmv7p7u4Jdm3ytjceScdOg4HatGigAooooAKKKKACiiigAooooAKKKKACiiigDgLcaf8A2VYrGIf+EoF1EZcY+0+Z5g80t/Fs27uvG3HtXf0d80VEIcp1YnE+3e3d7336ei6IKKKKs5QooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiig
AooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooAKKKKACiiigAooooA/9k=", "image/png": "", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Generate Mol Viz with MTP-specific generation\n", "from rdkit import Chem\n", "from rdkit.Chem import Draw\n", "import selfies as sf\n", "import torch\n", "\n", "# Setup device first\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Check if MTP-specific generation is available\n", "if hasattr(model, 'generate_with_logprobs'):\n", " print(\"Using MTP-specific generation...\")\n", " input_ids = tokenizer(\"\", return_tensors=\"pt\").input_ids.to(device)\n", " \n", " # Try MTP-specific generation with log probabilities\n", " try:\n", " outputs = model.generate_with_logprobs(\n", " input_ids,\n", " max_new_tokens=25, # Correct parameter name\n", " temperature=1,\n", " top_k=50,\n", " do_sample=True,\n", " return_probs=True, # This returns action probabilities\n", " tokenizer=tokenizer # Pass tokenizer for decoding\n", " )\n", " \n", " # Handle the output (returns: decoded_list, logprobs, tokens, probs)\n", " gen = outputs[2] # Get the generated token IDs (index 2)\n", " except Exception as e:\n", " print(f\"MTP generation failed: {e}, falling back to standard generation\")\n", " gen = model.generate(input_ids, max_length=25, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n", "else:\n", " print(\"Using standard generation...\")\n", " input_ids = tokenizer(\"\", return_tensors=\"pt\").input_ids.to(device)\n", " gen = model.generate(input_ids, max_length=25, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n", "\n", "# Decode and process the generated 
molecule\n",
    "generatedmol = tokenizer.decode(gen[0], skip_special_tokens=True)\n",
    "test = generatedmol.replace(' ', '')\n",
    "csmi_gen = sf.decoder(test)\n",
    "print(f\"Generated SELFIES: {test}\")\n",
    "print(f\"Decoded SMILES: {csmi_gen}\")\n",
    "\n",
    "mol = Chem.MolFromSmiles(csmi_gen)\n",
    "\n",
    "# Draw the molecule\n",
    "if mol is not None:\n",
    "    img = Draw.MolToImage(mol)\n",
    "    display(img)  # Use display() in Jupyter notebooks\n",
    "else:\n",
    "    print(\"❌ Could not create molecule from generated SMILES\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b16c5461",
   "metadata": {},
   "source": [
    "# Testing MTP Head Generation (Local)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cefc1a68",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`torch_dtype` is deprecated! Use `dtype` instead!\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🚀 ChemQ3-MTP Model Loader Starting...\n",
      "\n",
      "📁 Loading library from: ./ChemQ3MTP\n",
      "🔧 Loading custom modules from ChemQ3MTP...\n",
      " ✅ Loaded configuration_chemq3mtp.py\n",
      " ✅ Loaded modeling_chemq3mtp.py\n",
      " ✅ Loaded FastChemTokenizerHF.py\n",
      "\n",
      "🔗 Registering model components...\n",
      "✅ Model components registered successfully\n",
      "\n",
      "📁 Loading model weights from checkpoint: ./chunk-4/\n",
      "📁 Checkpoint files:\n",
      " config.json (1161 bytes)\n",
      " generation_config.json (174 bytes)\n",
      " model.safetensors (39437252 bytes)\n",
      " tokenizer_config.json (302 bytes)\n",
      " training_args.bin (5368 bytes)\n",
      " training_config.json (248 bytes)\n",
      " vocab.json (21574 bytes)\n",
      "\n",
      "🚀 Loading model...\n",
      "✅ Config loaded: ChemQ3MTPConfig\n",
      "✅ Model loaded: ChemQ3MTPForCausalLM\n",
      "✅ Tokenizer loaded: FastChemTokenizerSelfies\n",
      "\n",
      "🧪 Testing model...\n",
      "🖥️ Using device: cuda\n",
      "\n",
      "📊 Model Information:\n",
      " Model class: ChemQ3MTPForCausalLM\n",
      " Config class: ChemQ3MTPConfig\n",
      " Tokenizer class: FastChemTokenizerSelfies\n",
      " Model type: chemq3_mtp\n",
      " Vocab size: 782\n",
      "\n",
      "🔤 Testing tokenization...\n",
      " '[C][C][O]' -> [[0, 379, 379, 377, 1]]\n",
      " '[C]' -> [[0, 379, 1]]\n",
      " '[O]' -> [[0, 377, 1]]\n",
      "\n",
      "🎯 Testing generation...\n",
      "\n",
      " Prompt: '[C]'\n",
      " 1: [C] [Ring1] [Ring1] [Branch1] [C] [Cl] [Cl]\n",
      " 2: [C] .[Cl]\n",
      " 3: [C] .[O] [C] [C] [N] [C] [C] [C] [C] [C] [C] [C] [C] [N] [C] [C] [=C] [C] [=C] [Branch1] [=Branch2] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=C] [Ring1] [N]\n",
      "\n",
      " Prompt: '[C][C]'\n",
      " 1: [C] [C] .[C] [C] [C] [N] [Branch1] [Ring2] [C] [C] [C] [C] [C] [C] [N] [C] [=Branch1] [C] [=O] [C] [=C] [C] [C] [=Branch1] [C] [=O] [N] [Branch1] [C] [C] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [C] [=Ring1] [N] [N] [Ring1] [#C] [C]\n",
      " 2: [C] [C]\n",
      " 3: [C] [C]\n",
      "\n",
      "🔬 Testing MTP functionality...\n",
      " ✅ MTP training methods available\n",
      " ✅ MTP generation methods available\n",
      "\n",
      "🎉 Model loading and testing completed successfully!\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import sys\n",
    "import os\n",
    "from pathlib import Path\n",
    "import importlib.util\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "def load_custom_modules(library_path):\n",
    "    \"\"\"Import the custom configuration, modeling and tokenizer modules from ``library_path``.\n",
    "\n",
    "    Each file is executed as a top-level module named after the file and is\n",
    "    registered in ``sys.modules`` *before* it runs. A later file that imports\n",
    "    one of these module names therefore gets the same module object (and the\n",
    "    same classes) that is registered with transformers, rather than a second\n",
    "    copy imported through ``sys.path``.\n",
    "\n",
    "    Args:\n",
    "        library_path: Directory containing the ``.py`` files.\n",
    "\n",
    "    Returns:\n",
    "        Dict mapping module name -> module object, or None if a file is\n",
    "        missing or fails to import.\n",
    "    \"\"\"\n",
    "    library_path = Path(library_path)\n",
    "\n",
    "    # Add the library directory to Python path so the modules' own imports resolve\n",
    "    if str(library_path) not in sys.path:\n",
    "        sys.path.insert(0, str(library_path))\n",
    "\n",
    "    print(f\"🔧 Loading custom modules from {library_path}...\")\n",
    "\n",
    "    # Required module files, loaded in this order (configuration first)\n",
    "    required_files = {\n",
    "        'configuration_chemq3mtp.py': 'configuration_chemq3mtp',\n",
    "        'modeling_chemq3mtp.py': 'modeling_chemq3mtp',\n",
    "        'FastChemTokenizerHF.py': 'FastChemTokenizerHF'\n",
    "    }\n",
    "\n",
    "    loaded_modules = {}\n",
    "\n",
    "    # Load each required module\n",
    "    for filename, module_name in required_files.items():\n",
    "        file_path = library_path / filename\n",
    "\n",
    "        if not file_path.exists():\n",
    "            print(f\"❌ Required file not found: {filename}\")\n",
    "            return None\n",
    "\n",
    "        try:\n",
    "            spec = importlib.util.spec_from_file_location(module_name, file_path)\n",
    "            module = importlib.util.module_from_spec(spec)\n",
    "\n",
    "            # Register before executing (standard importlib recipe) so that\n",
    "            # imports of this name from the other files reuse this object\n",
    "            sys.modules[module_name] = module\n",
    "            spec.loader.exec_module(module)\n",
    "            loaded_modules[module_name] = module\n",
    "\n",
    "            print(f\"   ✅ Loaded {filename}\")\n",
    "\n",
    "        except Exception as e:\n",
    "            # Don't leave a half-initialised module behind in sys.modules\n",
    "            sys.modules.pop(module_name, None)\n",
    "            print(f\"   ❌ Failed to load {filename}: {e}\")\n",
    "            return None\n",
    "\n",
    "    return loaded_modules\n",
    "\n",
    "def register_model_components(loaded_modules):\n",
    "    \"\"\"Register the custom config, model and tokenizer classes with the transformers Auto* classes.\n",
    "\n",
    "    Args:\n",
    "        loaded_modules: Dict returned by ``load_custom_modules``.\n",
    "\n",
    "    Returns:\n",
    "        ``(config_class, model_class, tokenizer_class)``, or\n",
    "        ``(None, None, None)`` if a class is missing or registration fails.\n",
    "    \"\"\"\n",
    "    print(\"🔗 Registering model components...\")\n",
    "\n",
    "    try:\n",
    "        # Get the classes from loaded modules\n",
    "        ChemQ3MTPConfig = loaded_modules['configuration_chemq3mtp'].ChemQ3MTPConfig\n",
    "        ChemQ3MTPForCausalLM = loaded_modules['modeling_chemq3mtp'].ChemQ3MTPForCausalLM\n",
    "        FastChemTokenizerSelfies = loaded_modules['FastChemTokenizerHF'].FastChemTokenizerSelfies\n",
    "\n",
    "        # Register with transformers under the \"chemq3_mtp\" model type\n",
    "        AutoConfig.register(\"chemq3_mtp\", ChemQ3MTPConfig)\n",
    "        AutoModelForCausalLM.register(ChemQ3MTPConfig, ChemQ3MTPForCausalLM)\n",
    "        AutoTokenizer.register(ChemQ3MTPConfig, FastChemTokenizerSelfies)\n",
    "\n",
    "        print(\"✅ Model components registered successfully\")\n",
    "\n",
    "        return ChemQ3MTPConfig, ChemQ3MTPForCausalLM, FastChemTokenizerSelfies\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"❌ Registration failed: {e}\")\n",
    "        return None, None, None\n",
    "\n",
    "def load_model(model_path):\n",
    "    \"\"\"Load config, model and tokenizer from ``model_path`` through the registered Auto* classes.\n",
    "\n",
    "    Weights are loaded as float16 when CUDA is available, float32 otherwise.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, tokenizer, config)``, or ``(None, None, None)`` on failure\n",
    "        (the traceback is printed).\n",
    "    \"\"\"\n",
    "    print(\"🚀 Loading model...\")\n",
    "\n",
    "    try:\n",
    "        # Load config\n",
    "        config = AutoConfig.from_pretrained(str(model_path), trust_remote_code=False)\n",
    "        print(f\"✅ Config loaded: {config.__class__.__name__}\")\n",
    "\n",
    "        # Load model\n",
    "        model = AutoModelForCausalLM.from_pretrained(\n",
    "            str(model_path),\n",
    "            config=config,\n",
    "            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,\n",
    "            trust_remote_code=False  # We've already registered everything\n",
    "        )\n",
    "        print(f\"✅ Model loaded: {model.__class__.__name__}\")\n",
    "\n",
    "        # Load tokenizer\n",
    "        tokenizer = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=False)\n",
    "        print(f\"✅ Tokenizer loaded: {tokenizer.__class__.__name__}\")\n",
    "\n",
    "        return model, tokenizer, config\n",
    "\n",
    "    except Exception as e:\n",
    "        print(f\"❌ Model loading failed: {e}\")\n",
    "        import traceback\n",
    "        traceback.print_exc()\n",
    "        return None, None, None\n",
    "\n",
    "def test_model(model, tokenizer, config):\n",
    "    \"\"\"Smoke-test the loaded model and print the results.\n",
    "\n",
    "    Prints model metadata, tokenizes a few SELFIES strings, samples three\n",
    "    continuations for each test prompt and reports which MTP helper methods\n",
    "    the model exposes. Tokenization and generation errors are printed per\n",
    "    input rather than raised.\n",
    "\n",
    "    Side effects: moves ``model`` to the available device, sets eval mode,\n",
    "    and sets ``tokenizer.pad_token`` to the EOS token if no pad token is set.\n",
    "    \"\"\"\n",
    "    print(\"\\n🧪 Testing model...\")\n",
    "\n",
    "    # Setup device\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(f\"🖥️ Using device: {device}\")\n",
    "\n",
    "    model = model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # Model info\n",
    "    print(f\"\\n📊 Model Information:\")\n",
    "    print(f\"   Model class: {model.__class__.__name__}\")\n",
    "    print(f\"   Config class: {config.__class__.__name__}\")\n",
    "    print(f\"   Tokenizer class: {tokenizer.__class__.__name__}\")\n",
    "    print(f\"   Model type: {config.model_type}\")\n",
    "    print(f\"   Vocab size: {config.vocab_size}\")\n",
    "\n",
    "    # Set pad token if needed\n",
    "    if not hasattr(tokenizer, 'pad_token') or tokenizer.pad_token is None:\n",
    "        if hasattr(tokenizer, 'eos_token'):\n",
    "            tokenizer.pad_token = tokenizer.eos_token\n",
    "            print(\"✅ Set pad_token to eos_token\")\n",
    "\n",
    "    # Test tokenization\n",
    "    print(\"\\n🔤 Testing tokenization...\")\n",
    "    test_inputs = [\"[C][C][O]\", \"[C]\", \"[O]\"]\n",
    "\n",
    "    for test_input in test_inputs:\n",
    "        try:\n",
    "            tokens = tokenizer(test_input, return_tensors=\"pt\")\n",
    "            print(f\"   '{test_input}' -> {tokens.input_ids.tolist()}\")\n",
    "        except Exception as e:\n",
    "            print(f\"   ❌ Tokenization failed for '{test_input}': {e}\")\n",
    "            continue\n",
    "\n",
    "    # Test generation\n",
    "    print(\"\\n🎯 Testing generation...\")\n",
    "    test_prompts = [\"[C]\", \"[C][C]\"]\n",
    "\n",
    "    for prompt in test_prompts:\n",
    "        try:\n",
    "            input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
    "\n",
    "            with torch.no_grad():\n",
    "                outputs = model.generate(\n",
    "                    input_ids,\n",
    "                    max_length=input_ids.shape[1] + 20,\n",
    "                    temperature=0.8,\n",
    "                    top_p=0.9,\n",
    "                    top_k=50,\n",
    "                    do_sample=True,\n",
    "                    pad_token_id=tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else 0,\n",
    "                    num_return_sequences=3\n",
    "                )\n",
    "\n",
    "            print(f\"\\n   Prompt: '{prompt}'\")\n",
    "            for i, output in enumerate(outputs):\n",
    "                generated = tokenizer.decode(output, skip_special_tokens=True)\n",
    "                print(f\"     {i+1}: {generated}\")\n",
    "\n",
    "        except Exception as e:\n",
    "            print(f\"   ❌ Generation failed for '{prompt}': {e}\")\n",
    "\n",
    "    # Test MTP functionality if available\n",
    "    print(\"\\n🔬 Testing MTP functionality...\")\n",
    "    try:\n",
    "        has_mtp_training = hasattr(model, 'set_mtp_training')\n",
    "        has_mtp_generation = hasattr(model, 'generate_with_logprobs')\n",
    "        if has_mtp_training:\n",
    "            print(\"   ✅ MTP training methods available\")\n",
    "        if has_mtp_generation:\n",
    "            print(\"   ✅ MTP generation methods available\")\n",
    "        # Only call it a standard model when neither MTP hook is present\n",
    "        if not (has_mtp_training or has_mtp_generation):\n",
    "            print(\"   ℹ️ Standard model - no MTP methods detected\")\n",
    "    except Exception as e:\n",
    "        print(f\"   ⚠️ MTP test error: {e}\")\n",
    "\n",
    "def main(library_dir=\"./ChemQ3MTP\", checkpoint_dir=\"./chunk-4/\"):\n",
    "    \"\"\"Load the custom code from a local library dir and the weights from a local checkpoint, then smoke-test.\n",
    "\n",
    "    Args:\n",
    "        library_dir: Directory containing the custom ``.py`` files.\n",
    "        checkpoint_dir: Checkpoint directory passed to ``load_model``.\n",
    "\n",
    "    Returns:\n",
    "        ``(model, tokenizer, config)`` on success, ``(None, None, None)`` if any step fails.\n",
    "    \"\"\"\n",
    "    print(\"🚀 ChemQ3-MTP Model Loader Starting...\\n\")\n",
    "\n",
    "    # Check if library directory exists\n",
    "    if not Path(library_dir).exists():\n",
    "        print(f\"❌ Library directory does not exist: {library_dir}\")\n",
    "        return None, None, None\n",
    "\n",
    "    print(f\"📁 Loading library from: {library_dir}\")\n",
    "\n",
    "    # Load custom modules from library directory\n",
    "    loaded_modules = load_custom_modules(Path(library_dir))\n",
    "    if loaded_modules is None:\n",
    "        return None, None, None\n",
    "\n",
    "    print()\n",
    "\n",
    "    # Register components\n",
    "    config_class, model_class, tokenizer_class = register_model_components(loaded_modules)\n",
    "    if config_class is None:\n",
    "        return None, None, None\n",
    "\n",
    "    print()\n",
    "\n",
    "    # Check if checkpoint directory exists\n",
    "    if not Path(checkpoint_dir).exists():\n",
    "        print(f\"❌ Checkpoint directory does not exist: {checkpoint_dir}\")\n",
    "        return None, None, None\n",
    "\n",
    "    print(f\"📁 Loading model weights from checkpoint: {checkpoint_dir}\")\n",
    "\n",
    "    # List checkpoint files (top level only)\n",
    "    print(\"📁 Checkpoint files:\")\n",
    "    for file in Path(checkpoint_dir).iterdir():\n",
    "        if file.is_file():\n",
    "            print(f\"   {file.name} ({file.stat().st_size} bytes)\")\n",
    "\n",
    "    print()\n",
    "\n",
    "    # Load the model from checkpoint\n",
    "    model, tokenizer, config = load_model(Path(checkpoint_dir))\n",
    "    if model is None:\n",
    "        return None, None, None\n",
    "\n",
    "    # Test the model\n",
    "    test_model(model, tokenizer, config)\n",
    "\n",
    "    print(\"\\n🎉 Model loading and testing completed successfully!\")\n",
    "\n",
    "    return model, tokenizer, config\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    model, tokenizer, config = main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "56628930",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using MTP-specific generation...\n",
      "Generated SELFIES: .[F][C][=C][C][=C][C][=C][Ring1][=Branch1][C][N][C][C][N][Branch1][#Branch2][C][C][=C][C][=C][C][=C][Ring1][=Branch1][C][C][Ring1][=N]\n",
      "Decoded SMILES: FC1=CC=CC=C1CN2CCN(CC3=CC=CC=C3)CC2\n"
     ]
    },
    {
     "data": {
      "image/jpeg": "",
      "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAASwAAAEsCAIAAAD2HxkiAAAegUlEQVR4nO3deXxM9/oH8GcmiURWCRJqb5AKt5ZIUaWWVkuV2NXSUlsp5dJcCW0ttVZRVNPaLlW0KLdorOVX4VKNpddyE9cSQUkqYSayTDIz5/fHl9PI1mTmnPNMks/75Y/MmeR8n7T5zJk553uer06SJAIAPnruAgDKO4QQgBlCCMAMIQRghhACMEMIAZghhADMEEIAZgghADOEEIAZQgjADCEEYIYQAjBDCAGYIYQAzBBCAGYIIQAzhBCAGUIIwAwhBGCGEAIwQwgBmCGEAMwQQgBmCCEAM4QQgBlCCMAMIQRghhACMEMIAZghhADMEEIAZgghADOEEIAZQgjADCEEYIYQAjBDCAGYIYQAzBBCAGYIIQAzhBCAGUIIwAwhBGCGEAIwQwgBmCGEAMwQQgBmCCEAM4QQgBlCCMAMIQRghhACMEMIAZghhADMEEIAZgghADOEEIAZQgjADCEEYIYQAjBDCAGYIYQAzBBCAGYIIQAzhBCAGUIIwAwhBGCGEAIwQwgBmCGEAMwQQgBmCCEAM4QQgBlCCMAMIQRghhACMEMIAZghhADMEEIAZgghADOEEIAZQgjADCEEYIYQAjBDCAGYIYQAzBBCAGYIIQAzZ+4C4E/Hjx/v16/fvXv3/Pz8PDw88n+Dj4+PXl/o66a7u7urq2thz7q4uHh6ehqNxosXL9atWzcqKqpJkybK1A320UmSxF0DPNKpU6ejR49aLBa1B9Lr9S1atPj111/VHgiKAyF0FAcPHuzSpYuHh8eGDRsaNmxY4JHQYDBYrdbC9pCRkWEymQp7Nicn5+HDhwaDITY2duPGjRkZGfv37+/SpYsy1YM9JHAAJpMpKCiIiBYvXqzBcPPmzSOi4ODgnJwcDYaDouHEjENYtmxZfHx8o0aNJkyYoMFwkydPDgwMvHTp0rp16zQYDoqGt6P8kpKSgoKCDAbD3r17X331VW0G3bp164ABA/z9/S9fvuzj46PNoFAgHAn5TZ061WAwhIWFaZZAIurfv3+7du2Sk5PFW1NghCMhs9jY2FatWrm4uJw/f75BgwZaDn327NmWLVs6OztfuHBB46EhNxwJOVmt1nfffddqtYaHh2sfg+bNmw8ZMiQ7OzsyMlLjoSE3HAk5rV27duTIkTVr1oyLiyvwmoTabt++HRQUlJ6efvDgwZdeekn7AoBwJGRkNBo//PBDIlq0aBFLAomoRo0a//jHP4goPDy8iCuQoCqEkM3MmTPv3LnTtm3bAQMGMJYRHh5ep06dc+fOrV+/nrGM8gwh5HHpEv32W29PT/+VK1fqdDrGSipWrDh37lwiWrdu28OHjIWUXwghj0mT6PDhF0aNSmzatCl3LTRo0KChQ7edPLln/nzuUsolnJhhsHMn9e5Nvr50+TJVqcJdDRER/fILtWlDrq703/9S3brc1ZQzOBJqzWSiqVOJiObMcZQEElGrVvTGG5SVRRER3KWUPzgSau3jj+mjj6hxYzp3jpwd6XbOW7coKIgyMujoUWrXjrua8gRHQk3dukULFxIRff65YyWQiGrWpClTiIgmTSJcrdASQqip99+n9HQaOJA6dOAupSAREVSrFp05Q5s2cZdSnuDtqHaOHaP27cnNjf77X6pTh7uaQmzYQMOGUY0aFB9PTDMIyh0cCTVisdD48SRJNG2a4yaQiIYOpdBQun2bFi3iLqXcwJFQIytX0vjx9PTTdPEiublxV1OkEyeobVtyc6O4OKpdm7uacgBHQi2kptLMmUREixc7egKJqE0b6tuXMjNp+nTuUsoHhFALc+bQvXvUpQuFhXGXUjwLF5KbG23aRMePc5dSDjjYafIy6oMPyGymsWO56yi2evVo0iTatInS07lLKQfwmVB5kZHUqBG9+eajh1YrTZtGo0ZRYCBrWSWUnk4GA2VlUd26JDccTk2lrCx66inWysocvB1V3sKFNGI
EnT376KHVSgsXUmIia00l5+FBq1dTYCCtXfvnxqVLadAgvprKKIRQFSEhNG5cWZh34u1NERGUnMxdR5mGEKoiMpKuXqXVq7nrsFtICIWE0Pvvc9dRppXHEzMJCQlNmzY1Go3ioZOTk7e3dxHf37798fPnGxX2bMWKf151qFqV9u4lIqpUiebOpalTKSyMKldWqnAeS5dS8+Y0fDh17PhoS3o6RUeTtzc5ORX6U25uVLEiEZFen2C1puR5tsD/5snJyUFBQb6+vsrVXjqUxxD26tVLTiARWSyW+/fvF/H9KSnO164Va8+1av359YgRtHYthYdTaW9y3bgxjR9PY8fSb7892pKURP37F/fHX3hh5rFjG4rznS4uLtWrV7969aqzo81tV1n5+m2J6NKlSxcuXNDpdNu2bevTpw8RWSyW3JnMLy3NzWwu9NnMTMrKevR1hQp/btfr6YsvqFUrGjVKibpZzZpFW7fSZ589eujhQf36kdFIRawflZVFmZlERLVq1QkJCcnzbIH/zRMSEhITE8eNG7dq1SrFSi8V+JbB4CHWIZowYYJ6Q+h00v/936Ovx4+XQkMlIunwYfUGVMvMmVLHjo++3rJF8vKSRo2SXnxRreEmTpxIRP7+/gaDQa0xHFL5OjGzffv2AwcO+Pn5zZgxQ5sR586lW7e0GUpdAwfSc8/R11+rOMTSpUvbtm2bnJy8UNxzWW6UoxBmZmaGh4cT0bx58yprdbbE25vKzF9UVJS6+9fpdIsXL9bpdEuWLElISFB3MIfCfSjWjjj6NW/e3Gw2az/6pUuS1ar9sHY5dkzavPmJLTt2SJs2qTvooEGDiGjAgAHqDuNIyksIExMTPTw8dDrdzz//rP3okZGSk1PeP2gHt2yZNGWKpP2ns5s3b4p+5EePHtV6bCblJYR9+/YlosGDB7OMvmaNRCTVrCmlp7OMX2J370re3hKRtG8fw+hidYAWLVpYLBaG4TVXLkL4008/EZG7u/uNGzdYCrBYpJAQiUiaPZtl/BIbNkwiknr25Bk9PT29Vq1aRPT111/zVKCtsh9Cs9n87LPPEtG8efMYyzhyRCKS3N2lxETGKoolNlbS66UKFaTLl9lqEAtj1KhR4+HDh2xFaKXsh3DZsmVEFBgYmJmZyVtJr14SkTRsGG8Vf8FikVq1koik6dM5y7BaraGhoUQ0Y8YMzjo0UcZDmJKSIq5G7Nq1i7sW6epVydVV0uulU6e4Sync2rUSkVSjhpSWxlzJv//9b51OV7FiRa4PEZop4yEcPXo0Eb388svchTwSHi4RSW3aOOjlCqNRql5dInKUE7n9+vUjoiFDhnAXoq6yHMIzZ844OTlVqFAhLi6Ou5ZHjEapWjWJSNq6lbuUgkyeLBFJbds6ymvEtWvX3NzcdDrdsWPHuGtRUZkNodVqbd++PRGFh4dz1/KEqCiJSKpXT+L+iJrXpUuSi4uk10u//spdSi4RERFE1Lp1a6uDvDCooMyGcOPGjUQUEBDw4MED7lqeYDZLzz4rEUnz53OX8qRXX5WIpLFjuet4ktForF69OhFt2bKFuxa1lM0QpqWl1ahRg4jWr1/PXUsBfvpJIpK8vKTff+cu5bGdOyUiyddX+uMP7lLyWb16NRHVrFkzvbTMdSihshlC8R6mZcuWDjvl4vXXpaeeSoyMXMNdiCRJUmZm5nPPmYikFSu4SymIxWIRdyR+/PHH3LWoogyG8MqVK66urnq9/pdffuGupVDx8Q+8vX2cnJzOnj3LXYs0Z84cT8/qAwfG5uRwl1KII0eOEJGnp+ft27e5a1FeGQzha6+9RkQjRozgLuQvTJo0iYg6dOjAW4Y8YfrIkSO8lRQtLCyMiIYPH85diPLKWgj3799PRN7e3r87zuetQqSmplapUoWIdu7cyVjGwIEDS8WtQ1evXhVvcH51qLO3SihTITSZTEFBQUS0ZMkS7lqKZcWKFUT09NNPZ2VlsRRw7NgxMSslISGBpYASmTJ
lChE9//zzZexyRZkK4YIFC4ioUaNG2dnZ3LUUS05OTpMmTYjo008/1X50s9nctGlTIppdSm7uMBgMAQEBRLR9+3buWpRUdkJ49+5d0cpy79693LWUQHR0tHj/fPfuXY2HXrlyJRHVrl27FJ36/+KLL4ioXr167NPxFVR2Qjh06FAi6tWrF3chJfbqq68S0VhtL5OnpKSIT6Q7duzQclw7mc3mv/3tb0S0YMEC7loUU0ZCKGbcu7q6Xma8B85Wly5dcnFxcXJy+s9//qPZoGPHjiWizp07azaiUsQt2l5eXnfu3OGuRRllIYQWi+W5554jog8//JC7Fhu9++67RNSpUydthjt//ryzs7Ozs7OWsVeQuAo1evRo7kKUoW4IMzIyrl+/npiYmJqP0WhUahTRsLlmzZql9y7slJQUPz8/ItqzZ48Gw7344otE9Pe//12DsdQQFxfn4uKi1+tPnz7NXYsCVFwkNCMjo0GDBr///rttP16pUiWdTlfYsx4eHhUeN52/d++e0Wj89ttvBwwYYNtYjmDJkiVTpkzx9fXt2rVr+/btC/vdvby8iliqwc3NraJYh6UgTk5O6enp0dHRUVFR/v7+8fHxlSpVsr9yFhMnTly+fHnHjh0PHz7MXYu9VAxhnz59duzYQUQuLi6enp55njWbzWlpaYoM5OPjk5aWdvr06WbNmimyQxbZ2dk1atS4d++eBmM5OTktXrxYtJ0vpVJTU+vXr3///v1nnnnG29u76HW1fHx89PpC+1y7u7u7uroW+FRSUtK1a9dat269dOlSMa9IDWqFMDk5uWHDhgaDYdasWR999JENexC3IBX2bHp6enZ2tvh6/vz5a9as6dChg5hhWHrFxMTMnz8/ISGhXbt2hX1PWlqaufDlabKysjLFOiwFsVgscXFxBoOhXbt2e/bsKdWLH0mSFBwcnJKS8scff2gwnKp/XWqFcMSIEevWrevevfvu3bvV2H9u9+/fb9iw4b1793bu3ClmGEKZt2XLlkGDBlWpUmXbtm3u7u5Fr6tlMBishS+bnPsFPY+kpKQDBw4cO3bMycnp2rVrtWvXtrfuAqnxQTM2Nlav11eoUCE+Pl6N/efHPv8LtJSenl6nTh0iWrt2rQbDvfzyy6Rmqxvlj4SSJD3//PMnT56MiIiYP3++2Jienp7/Y2Hukyv5FfhJMrcmTZr88MMP4muz2dy8efMLFy4sWrTofSzuXNZ98MEHc+fODQkJOXXqVBEf9pRy/fr14OBgk8kUExPTtm1b5QdQPNb//Oc/iahatWq5V5kzGAyKV96gQYPc4x48eJCIvLy8tJ//BVq6evWq6P504sQJsSVN/faMqra6UfhImJaWFhQUdOfOnY0bNw4ZMqToby7ivTgR5eTkPHz4sIgfz8zMbNy4ce4tXbt23bdv39ixY8UMQyiTwsLCfvjhh2HDhomXe4vFEhoaGhgYGBUVJSbiKSIlJWX27Nk+Pj6zZ8+mXH/YW7ZsETd/KUnZTIu3glw3m7DM/wItye935PtFxQuu4tPQT548qdPp3Nzcrl+/Lrao1+pGyRD+73//E7ddnuJrMT1+/HjScP4XaEm+82vRokVii3xj9Pfff6/4cIMHDyai/v37i4fqtbpRMoRdu3YlojFjxii4z5JKTU0Vfe93797NWAao4dNPPyWi+vXry+fAVZ1ze+vWrTwrJarU6kaxEIoTlb6+vsnJyUrt0zZLliwhoqCgoNJyay8UR1JSko+PDxFFR0eLLRcuXFB7GrqYZ9K8eXO5bZ8arW6UCaHJZGrYsCERLV++XJEd2iM7O1sU89lnn3HXAooZPnw4EfXo0UPe0qFDByKaNGmSeoNmZGSIC/QbNmwQW9RodaNMCOfOnUtEwcHBDnLwkQ/LfzhgL1soufzTP7799lsi8vf3v3//vqpDb9iwgZ5cKVHxs48KhPDWrVviqvr+/fvt35tSunTpQkQTJkzgLgTsZbVaW7VqRUTTpk0TWzIyMsS
MmdWrV2swurhb9aOPPhJbjEajaHWzbds2RYZQIISDBg0ion79+tm/KwVdvHhRfGC4cOECdy1gl3Xr1oljkXxRXuNF7eWVEuWedMq2urE3hMePHxf1yZdTHMeYMWOIqFu3btyFgO3kBWE2bdoktty4ccPd3V2n08knLTXQv39/Iho8eLB4qGyrG7tCaLFYWrZsSUQzZ860vxTFJScni/Nppav/GuQ2efLkPB/AevXqRURvvvmmlmUkJiaK5MsrJcqtbuxvM21XCKOiooioVq1aDtsz75NPPqFS1YkUchNToHKfijx06JD409d+UYrIyEh6cvpo9+7diWjUqFF27tn2EMqTFRy5E6vJZGrQoAERrVy5krsWKDHRDPKdd94RD3NycsSbwIULF2pfTFpamnhjvPnxYuJyq5vY2Fh79mx7CEvLBLHvv/+eiPz8/FJSUrhrgRLYuXNnnutMYhpG7hkzGluzZk2e6aOiRUjHjh3t2a2NIdRgsoKCXnrpJSKaPHkydyFQXFlZWeItzIrHayYmJSWJtlQ//vgjV1Xy9FF54QD5/eC//vUvm3drYwg7duxIRBMnTrR5YC2dPXvWycnJxcVFszv9wU4ff/wxETVu3Djn8ZqJb7/9NhF1796dtzAxfdTd3T0xMVFsWb58OREFBgbafHy2JYTfffcdEVWuXPnevXu2jao98b+wZ8+e3IXAX5OXTDx8+LDYon3DlCKI07PDhg0TD3NycsR9rYsXL7ZthyUOYUZGRt26dYlo1apVtg3JQl4u5sCBA9y1wF8Qd83K9xBZrdbWrVsTUUREBG9hgjx9VL5l78cffySiSpUq2TZNssQhlOeVm81mG8ZjJCa45n6Ho5KHZvMf2dm5/2kxraOskJdMlKd/FNgwhVd4eDgRtWnTRr5cIU7kjhs3zoa9lSyELJMVlJKZmSmO4V999ZWqA829cSMkNjb3vxRcpSweecnEWbNmiS3yjJmNGzfy1pab0WisVq0aEW3dulVssaerQ8lC2Lt3byIaOnRoSYdxEPLUe9FZWCVzb9z4+5Ur6u2/DMu/ZCJvw5QiiJkquaePjhs3jmxa6KoEIVy0aBEReXh4aD9ZQUGiufXUqVPVGwIhtE3+XhWO0DClMGaz+dlnnyWi+fPniy3yqj4lvYhS3G5rZrM5ICAgNTX1lVde2bdvn7z91KlT4kYPh2W1WnO3pjxz5kxoaKher3/77bd9fX3FRr1eL2aZFqj6Cy94NmxIha9O46LTVXw8RFNPz9V37vxuMs2oW1dsqajXezo52f+LlHldunQ5ePBgp06dxLRMIurWrdvevXvHjBnz5Zdf8tZWoMOHD3fu3NnLyys+Pl68Zxar+gQFBZ09e7aIlXnyKG4Is7OzAwICHjx40KdPn+3bt4uNPXv23LVr188//9y+fXvbfg0NhIeH37hxY/HixbVq1RJbGjRokJ2dnZiYWMw9hM2bd7NLl2J+82f16x81GHbkWiAhrEqVD+rUKVHN5dDu3bt79OhBRKdOnQoNDSWiffv2de3a1dfXNz4+vmrVqtwFFqxHjx67d+8eOXKk6MWWk5Pz9NNP37p1q2/fvtu2bSvmToq7JEiFChXeeeedBQsWHDp0KCUlRTRTCgkJ2bVr18SJE0+fPq1BI2QbXLlyZcWKFTk5OeHh4SKEP/3005UrV9zd3adNmyYv5WO1WotoT1ynVav7VatS4a9WOZKU+Xipg6ouLkTUvlKlTwMDxZZCD6CQS0xMDBFVqFBBJJCIOnbsuGDBAj8/P4dNIBEtWrToyJEjcoUuLi5hYWGff/55dHR0RkaGu7t7sfZSojevoim/vLikfIPzunXrSrQfzeRpAKfNDGB8JrTB9evXxeu4PD26tMjTX0O8iFSrVq34eyhZCM+dOyfmf8XFxYkt33zzDREFBAQ4zjUcWf4GcNrMAEYIbTNnzhwq5av6iHXI9Xr9zp07i/9TJb5YP3LkSCJ6/fXXxUOr1SqWyJg+fXpJd6W
q/A3gNJsBjBDaxmw25+ntW+rYtg55iUOYlJQkPkrJbZ1OnDiRp2G4I8jfAG7EiBGkyQzgEwbDYZVbgJVVpXpVn82bN5NNDeBsmcA9b9488fctz/964403iGjgwIE27E0N+RvAiVNHDjIDGIog5n+NHTuWu5CSkZdMXLNmTUl/1pYQmkym+vXrE9GXX34ptty8eVOcCIqJibFhh4rL0wDOarW+8MIL5DAzgKEIpXRVn+nTp5OtDeBsvJ9w69atRFS1alV5/pfGXeiKkL8B3Pr168nBZgBDEUpL0waZvGSibQch29tbiAv04eHh4mF6erq4EPf111/bvE/75W8AZzQan3rqKXKwGcBQhFK3qk/Pnj2J6K233rLtx20P4ZkzZ8SnrMuXL4st4paT3A3DtZe/AVz+u07A8ckXk0wmE3ctf0E+mWTznGq7Wh6+9dZbRNS7d2/xUKyZSkQzZsywZ7c2y98AzpFnAEMRSsuqPvKSiZ988onNO7ErhLdv3xYnIQ8ePCi2yJ/Hbty4Yc+ebZP/s0S3bt2IaPTo0doXA3YqFav65F8y0Qb2tsEXK3o3a9ZMvtG+X79+xHHPYf4GcLt27SLHWDIRbOPgq/rISybaOf3D3hDK00fXrl0rtly7dk2cKfrll1/s3HmJ5GkAJ8+YWbZsmZZlgIIcfFUf0T1Mnj1mMwVWZdq0aVOe6aMRERH0ZMNwteVvACfPKEAD/FLNYVf1UbABnAIhlKePysvHyX1BtmzZYv/+/1L+BnB37tzJM7cOSinHXNVHbgAXGRlp/96UWan35MmTeaaPrlq1ip5sGK6e/A3gBg8eTER9+/ZVe2jQgAOu6qNsAzhlQig9nik2YMAA8VBuGD5nzhylhihQ/gZwjrxkItjA0Vb1kd/offPNN4rsULEQyl2T5TCIhuGenp6qNobK0wCO/VolqMGhVvVxxDXrZfmnj4aFhRHR8OHDFRwlN9ERKHfORUegWrVqMc7aATU4yKo+akz/UDKE8vTRDRs2iC1yw3B5kUcFyb0q5CWLU1NTRbePbdu2KT4c8BKr+rDfjJanYYoilAyh9Ph+hdzTR6dMmaLssVu2dOlSenI1nAkTJpSu2fdQIuyr+uRvmKIIhUNotVrzfCS7f/++ODop21RC7lWxZ88eseXixYviPrTffvtNwYHAcfCu6pO/YYpSitt3tPhOnDjRtm1bNze3uLi42rVrE9HGjRszMjJGjhzppFwPXJPJtGTJknPnzonL9ET0yiuvHDhw4L333lu2bJlSo4CjmTdv3vTp0xs3bnzu3Dln5z8bdkZHR6enpxf2U66urkV0H9TpdOIFPY9mzZrl/osVQwcHB587d87FxcXGX6BAymZaENNHhwwZosbOCyRuMvbz8ytFSyaCDeRVffKszCdaPSgrLS1N3n/+hikKUv5ISEQ3b9585plnMjMzY2JixGQaVWVmZgYHByckJHz11VejR49Wezjg9d133w0cONDf3//y5cvy4gXvvffe3bt3C/sRk8mUkZFR2LOSJD148CD/9piYGLmV/eDBgzdv3tyvXz/xcq8wxWMtREZGElFgYKAG3S7effddKp1LJoJtNFjVJze1p3+ociQkorS0tMqVK+fk5Pj7+7dp0yYgICD/9xS9DAsReXt7F/ExsmLFiiaT6fjx4+Kc1aFDhzp37mx/5eD4xKo+zs7Ox48fF91M1GOxWFq3bh0bGztz5swZM2aoMYRaISSiUaNGrVmzRqWd51GvXr1r165pMxY4gtdeey06OtrPz2/06NEtWrQo7OSKzMvLK/eJnDzc3NwKXEQpOTl51qxZ+/btq1at2tWrV4u7tkQJqRhCq9W6YsWK2NjYhg0bFngkLHoZFiIyGo0Wi6WwZzMzM+/evXv+/PnAwMAvvvhCXnQJyoPTp0+Hhoaq99ebx5QpU8RN9GpQMYQAqoqKilq2bJnFYmnevLlUyMkV2V++oGdlZeXfnpycnJOT065dO9HNSSUIIQAzR1x
UEKBcQQgBmCGEAMwQQgBmCCEAM4QQgBlCCMAMIQRghhACMEMIAZghhADMEEIAZgghADOEEIAZQgjADCEEYIYQAjBDCAGYIYQAzBBCAGYIIQAzhBCAGUIIwAwhBGCGEAIwQwgBmCGEAMwQQgBmCCEAM4QQgBlCCMAMIQRghhACMEMIAZghhADMEEIAZgghADOEEIAZQgjADCEEYIYQAjBDCAGYIYQAzBBCAGYIIQAzhBCAGUIIwAwhBGCGEAIwQwgBmCGEAMwQQgBmCCEAM4QQgBlCCMAMIQRghhACMEMIAZghhADMEEIAZgghADOEEIAZQgjADCEEYIYQAjBDCAGYIYQAzBBCAGYIIQAzhBCAGUIIwAwhBGCGEAIwQwgBmCGEAMwQQgBmCCEAM4QQgBlCCMAMIQRghhACMEMIAZghhADMEEIAZgghADOEEIAZQgjADCEEYIYQAjD7f6ajdl+MJgciAAAAAElFTkSuQmCC", "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Generate Mol Viz with MTP-specific generation\n", "from rdkit import Chem\n", "from rdkit.Chem import Draw\n", "import selfies as sf\n", "import torch\n", "\n", "# Setup device first\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "# Check if MTP-specific generation is available\n", "if hasattr(model, 'generate_with_logprobs'):\n", " print(\"Using MTP-specific generation...\")\n", " input_ids = tokenizer(\"\", return_tensors=\"pt\").input_ids.to(device)\n", " \n", " # Try MTP-specific generation with log probabilities\n", " try:\n", " outputs = model.generate_with_logprobs(\n", " input_ids,\n", " max_new_tokens=25, # Correct parameter name\n", " temperature=1,\n", " top_k=50,\n", " do_sample=True,\n", " return_probs=True, # This returns action probabilities\n", " tokenizer=tokenizer # Pass tokenizer for decoding\n", " )\n", " \n", " # Handle the output (returns: decoded_list, logprobs, tokens, probs)\n", " gen = outputs[2] # Get the generated token IDs (index 2)\n", " except Exception as e:\n", " print(f\"MTP generation failed: {e}, falling back to standard generation\")\n", " gen = model.generate(input_ids, max_length=25, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n", "else:\n", " print(\"Using standard generation...\")\n", " input_ids = tokenizer(\"\", return_tensors=\"pt\").input_ids.to(device)\n", " gen = 
model.generate(input_ids, max_length=25, top_k=50, temperature=1, do_sample=True, pad_token_id=tokenizer.pad_token_id)\n", "\n", "# Decode and process the generated molecule\n", "generatedmol = tokenizer.decode(gen[0], skip_special_tokens=True)\n", "test = generatedmol.replace(' ', '')\n", "csmi_gen = sf.decoder(test)\n", "print(f\"Generated SELFIES: {test}\")\n", "print(f\"Decoded SMILES: {csmi_gen}\")\n", "\n", "mol = Chem.MolFromSmiles(csmi_gen)\n", "\n", "# Draw the molecule\n", "if mol is not None:\n", " img = Draw.MolToImage(mol)\n", " display(img) # Use display() in Jupyter notebooks\n", "else:\n", " print(\"โŒ Could not create molecule from generated SMILES\")\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "0dc9e278", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "--- Standard Generation Test ---\n", "Generated SELFIES 1: [C]\n", "Generated SELFIES 2: [C]\n", "Generated SELFIES 3: [C] .[C] [=N] [C] [Branch1] [Ring1] [O] [C] [=C] [C] [Branch1] [C] [F] [=C] [Ring1] [=Branch2] [O]\n" ] } ], "source": [ "print(\"\\n--- Standard Generation Test ---\")\n", "input_ids = tokenizer(\" [C]\", return_tensors=\"pt\").input_ids.to(device)\n", "with torch.no_grad():\n", " model.set_mtp_training(False)\n", " gen = model.generate(\n", " input_ids,\n", " max_length=25,\n", " top_k=50,\n", " top_p=0.9,\n", " temperature=1.0,\n", " do_sample=True,\n", " pad_token_id=tokenizer.pad_token_id,\n", " eos_token_id=tokenizer.eos_token_id,\n", " num_return_sequences=3,\n", " )\n", " for i, sequence in enumerate(gen):\n", " result = tokenizer.decode(sequence, skip_special_tokens=True)\n", " print(f\"Generated SELFIES {i+1}: {result}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "edf549c4", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.0" } }, "nbformat": 4, "nbformat_minor": 5 }