In [1]:
import torch
from torch_geometric.data import Data
from ogb.utils import smiles2graph
import os
import json
from rdkit import RDLogger
from rdkit import Chem
RDLogger.DisableLog('rdApp.*')
from tqdm import tqdm
import multiprocessing

def write_json(data, filename):
    """Serialize ``data`` to ``filename`` as pretty-printed JSON.

    The file is written as UTF-8 explicitly. ``ensure_ascii=False`` emits
    non-ASCII characters verbatim, which would fail or be mis-encoded if the
    platform's default text encoding were not UTF-8.

    Args:
        data: Any JSON-serializable object.
        filename: Destination path; overwritten if it already exists.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4, ensure_ascii=False)

def read_json(filename):
    """Load and return the JSON document stored in ``filename``.

    The file is decoded as UTF-8 explicitly, so files written by
    ``write_json`` (which keeps non-ASCII characters verbatim) round-trip
    regardless of the platform's default encoding.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The deserialized Python object.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

def smiles2data(smiles):
    """Convert one SMILES string into a PyG ``Data`` graph.

    Featurization is delegated to OGB's ``smiles2graph``. Its numpy arrays are
    wrapped with ``torch.from_numpy``, which shares memory with the arrays.

    Args:
        smiles: SMILES string of a single molecule.

    Returns:
        ``Data`` whose ``x``, ``edge_index`` and ``edge_attr`` come from the
        ``node_feat``, ``edge_index`` and ``edge_feat`` entries of the OGB
        graph dict.
    """
    ogb_graph = smiles2graph(smiles)
    return Data(
        x=torch.from_numpy(ogb_graph['node_feat']),
        edge_index=torch.from_numpy(ogb_graph['edge_index']),
        edge_attr=torch.from_numpy(ogb_graph['edge_feat']),
    )


In [None]:
# Build the pretraining graphs: one PyG graph per 'canon_smiles' entry of
# Abstract_property.json, keyed by that SMILES and cached to mol_graph_map.pt.
# If the cache file already exists, nothing is recomputed.
root = 'data/pretrain_data/'
mol_property_list = read_json(f'{root}/Abstract_property.json')
target_file = f'{root}/mol_graph_map.pt'

if not os.path.exists(target_file):
    # Repeated SMILES simply overwrite their earlier entry (last one wins).
    mol_graph_map = {
        entry['canon_smiles']: smiles2data(entry['canon_smiles'])
        for entry in tqdm(mol_property_list)
    }
    torch.save(mol_graph_map, target_file)

In [None]:
# Build the downstream (action prediction) graphs: every molecule listed under
# REACTANT, PRODUCT, CATALYST or SOLVENT in processed.json gets one PyG graph,
# keyed by its SMILES and cached to mol_graph_map.pt.
root = 'data/action_data'
target_file = f'{root}/mol_graph_map.pt'

if not os.path.exists(target_file):
    reaction_list = read_json(f'{root}/processed.json')
    rxn_keys = ['REACTANT', 'PRODUCT', 'CATALYST', 'SOLVENT']

    # A set de-duplicates on its own; no membership check is needed.
    all_mols = {mol for rxn in reaction_list for key in rxn_keys for mol in rxn[key]}

    # Iterate in sorted order. Set iteration over str depends on hash
    # randomization (PYTHONHASHSEED), which would make the saved map's key
    # order differ from run to run.
    mol_graph_map = {smiles: smiles2data(smiles) for smiles in tqdm(sorted(all_mols))}
    torch.save(mol_graph_map, target_file)

In [None]:
# Build the downstream (retrosynthesis / reaction prediction) graphs.
# For each dataset folder:
#   1. mol_graphid_map.json maps every raw SMILES fragment found in the split
#      files to an integer graph id. Fragments with the same RDKit canonical
#      SMILES share one id.
#   2. idx_graph_map.pt maps each graph id to a PyG graph.
root = 'data/synthesis_data'

for folder in [
    'USPTO_50K_PtoR',
    'USPTO_50K_PtoR_aug20',
    'USPTO-MIT_PtoR_aug5',
    'USPTO-MIT_RtoP_aug5_mixed',
    'USPTO-MIT_RtoP_aug5_separated',
    'USPTO_full_pretrain_aug5_masked_token',
    ]:
    mol_graphid_file = f'{root}/{folder}/mol_graphid_map.json'
    target_file = f'{root}/{folder}/idx_graph_map.pt'

    # Both outputs are cached, so skip this folder.
    # - If only the graph file is missing, it is rebuilt from the existing
    #   id map.
    # - If the id map is missing, both files are rebuilt so the ids in the
    #   two files agree.
    if os.path.exists(mol_graphid_file) and os.path.exists(target_file):
        continue

    if not os.path.exists(mol_graphid_file):
        # Pretrain folders are read from their tgt files; every other folder
        # is read from its src files.
        file = 'tgt' if 'pretrain' in folder else 'src'
        mol_set = set()
        for mode in ['train', 'val', 'test']:
            file_path = f'{root}/{folder}/{mode}/{file}-{mode}.txt'
            with open(file_path) as f:
                for line in f:
                    # Drop token spaces, then treat '<separated>' like the
                    # '.' fragment separator.
                    line = line.strip().replace(' ', '').replace('<separated>', '.')
                    mol_set.update(line.split('.'))

        # Sort before assigning ids. Set iteration over str is
        # hash-randomized, so without sorting the ids would change between
        # runs.
        smi_list = sorted(mol_set)
        # The context manager shuts the worker processes down once map()
        # has returned, instead of leaving one live pool per folder.
        with multiprocessing.Pool(processes=multiprocessing.cpu_count()) as pool:
            canon_list = pool.map(Chem.CanonSmiles, smi_list)

        canon_idx_map = {}
        mol_idx_map = {}
        for smi, canon in zip(smi_list, canon_list):
            if canon not in canon_idx_map:
                # Next unused id for a canonical form seen for the first time.
                canon_idx_map[canon] = len(canon_idx_map)
            mol_idx_map[smi] = canon_idx_map[canon]
        write_json(mol_idx_map, mol_graphid_file)
    else:
        mol_idx_map = read_json(mol_graphid_file)

    # One graph per id, featurized from the first SMILES mapped to that id.
    cid_graph_map = {}
    for smiles, graph_id in tqdm(mol_idx_map.items(), desc=folder):
        if graph_id not in cid_graph_map:
            cid_graph_map[graph_id] = smiles2data(smiles)
    torch.save(cid_graph_map, target_file)

In [3]:
# Build the ChEBI-20 graphs: one PyG graph per compound, keyed by the CID in
# the first column of each split file and cached to cid_graph_map.pt.
root = 'data/ChEBI-20_data'
target_file = f'{root}/cid_graph_map.pt'

cid_graph_map = {}
if not os.path.exists(target_file):
    for mode in ['train', 'validation', 'test']:
        # Decode explicitly as UTF-8 rather than relying on the platform
        # default. The discarded third column is presumably free text that
        # may contain non-ASCII characters (assumed UTF-8; confirm against
        # the dataset).
        with open(f'{root}/{mode}.txt', encoding='utf-8') as f:
            lines = f.readlines()
        # The first line is skipped (presumably a column header). Rows are
        # tab-separated as CID, SMILES, then the remainder.
        for line in tqdm(lines[1:], desc=mode):
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline
            cid, smiles, _ = line.strip().split('\t', maxsplit=2)
            cid_graph_map[cid] = smiles2data(smiles)
    torch.save(cid_graph_map, target_file)