| from rdkit import Chem | |
| import torch | |
| from torch import nn | |
| from pytorch_lightning import LightningModule | |
| import torchmetrics | |
| from fsr_fg_model import FsrFgModel | |
| from data import FsrFgDataModule | |
| from pytorch_lightning.cli import LightningCLI | |
class FsrFgLightning(LightningModule):
    """LightningModule that trains ``FsrFgModel`` with a classification loss plus a
    down-weighted reconstruction loss.

    Batches are unpacked as ``(fg, mfg, num_features, y)``. The wrapped network is
    expected to return ``(y_pred, recon)``. ``y_pred`` is scored with cross-entropy and
    AUROC. ``recon`` is compared against the model's input features with
    ``BCEWithLogitsLoss`` during training only.

    ``method`` selects which inputs are passed to the network:
    ``'FG'`` -> fg only, ``'MFG'`` -> mfg only, ``'FGR'`` -> fg and mfg,
    anything else -> fg, mfg and num_features.
    """

    # Weight of the auxiliary reconstruction term in the training loss.
    _RECON_WEIGHT = 1e-4

    def __init__(self, fg_input_dim=2786, mfg_input_dim=2586, num_input_dim=208,
                 enc_dec_dims=(500, 100), output_dims=(200, 100, 50), num_tasks=2, dropout=0.8,
                 method='FGR', lr=1e-4, **kwargs):
        super().__init__()
        # Only the explicitly listed constructor args are stored in checkpoints;
        # extra **kwargs are accepted but ignored.
        self.save_hyperparameters('fg_input_dim', 'mfg_input_dim', 'num_input_dim', 'enc_dec_dims',
                                  'output_dims', 'num_tasks', 'dropout', 'method', 'lr')
        self.net = FsrFgModel(fg_input_dim, mfg_input_dim, num_input_dim, enc_dec_dims,
                              output_dims, num_tasks, dropout, method)
        self.lr = lr
        self.method = method
        self.criterion = nn.CrossEntropyLoss()
        self.recon_loss = nn.BCEWithLogitsLoss()
        self.softmax = nn.Softmax(dim=1)
        # Separate metric objects per stage so their accumulated states never mix.
        self.train_auc = torchmetrics.AUROC(num_classes=num_tasks)
        self.valid_auc = torchmetrics.AUROC(num_classes=num_tasks)
        self.test_auc = torchmetrics.AUROC(num_classes=num_tasks)

    def forward(self, fg, mfg, num_features):
        """Run the wrapped network on the inputs selected by ``self.method``.

        Returns whatever ``self.net`` returns. The step methods unpack it as
        ``(y_pred, recon)``.
        """
        if self.method == 'FG':
            return self.net(fg=fg)
        if self.method == 'MFG':
            return self.net(mfg=mfg)
        if self.method == 'FGR':
            return self.net(fg=fg, mfg=mfg)
        # Every other method value also feeds the numeric features.
        return self.net(fg=fg, mfg=mfg, num_features=num_features)

    def configure_optimizers(self):
        """Create an AdamW optimizer with a per-batch OneCycle learning-rate schedule."""
        optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.lr, weight_decay=0.3)
        # NOTE(review): OneCycleLR overwrites each param group's lr with
        # max_lr / div_factor on construction, so ``self.lr`` does not actually set
        # the starting learning rate here.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-2, total_steps=self.trainer.estimated_stepping_batches)
        # total_steps counts optimizer steps, so the scheduler must be stepped every
        # batch. Lightning's default interval ('epoch') would leave the cycle
        # essentially unstarted.
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'},
        }

    def _recon_target(self, fg, mfg):
        """Return the tensor the reconstruction head should reproduce for ``self.method``."""
        if self.method == 'FG':
            return fg
        if self.method == 'MFG':
            return mfg
        # Concatenate along dim 1; presumably the feature axis of (batch, features) inputs.
        return torch.cat([fg, mfg], dim=1)

    def training_step(self, batch, batch_idx):
        """Compute classification + weighted reconstruction loss and update train AUROC."""
        fg, mfg, num_features, y = batch
        y_pred, recon = self(fg, mfg, num_features)
        loss_r_pre = self._RECON_WEIGHT * self.recon_loss(recon, self._recon_target(fg, mfg))
        loss = self.criterion(y_pred, y) + loss_r_pre
        self.train_auc(self.softmax(y_pred), y)
        self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        self.log('train_auc', self.train_auc, on_epoch=True, on_step=False, prog_bar=True, logger=True)
        return loss

    def _eval_step(self, batch, metric, prefix):
        """Shared val/test logic: log ``{prefix}_loss`` (classification only) and ``{prefix}_auc``."""
        fg, mfg, num_features, y = batch
        # The reconstruction output is not used for evaluation.
        y_pred, _ = self(fg, mfg, num_features)
        loss = self.criterion(y_pred, y)
        metric(self.softmax(y_pred), y)
        self.log(f'{prefix}_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log(f'{prefix}_auc', metric, on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation loss and AUROC."""
        self._eval_step(batch, self.valid_auc, 'val')

    def test_step(self, batch, batch_idx):
        """Log epoch-level test loss and AUROC."""
        self._eval_step(batch, self.test_auc, 'test')
if __name__ == '__main__':
    # Instantiate model, datamodule and trainer from command-line/config arguments
    # without launching a subcommand (run=False), so fit and test are chained explicitly.
    cli = LightningCLI(
        model_class=FsrFgLightning,
        datamodule_class=FsrFgDataModule,
        save_config_callback=None,  # do not write the resolved config to disk
        run=False,
    )
    model, datamodule = cli.model, cli.datamodule
    trainer = cli.trainer
    trainer.fit(model, datamodule)
    # Evaluate with the checkpoint the trainer's checkpoint callback reports as best.
    trainer.test(model, datamodule, ckpt_path='best')