import math
import os
import typing

import numpy as np
import pandas as pd
import torch

import utils

base_path = ""

LOGGER = utils.get_logger(__name__)

DNA_ALPHABET = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
INDEX_TO_DNA = {v: k for k, v in DNA_ALPHABET.items()}
# Vectorized index -> nucleotide lookup used for fast batch detokenization.
lookup_array = np.array([INDEX_TO_DNA[i] for i in range(len(INDEX_TO_DNA))])

def dna_detokenize(seq):
    """seq: iterable of token indices; returns the corresponding DNA string."""
    return ''.join([INDEX_TO_DNA[int(i)] for i in seq])


def batch_dna_detokenize(batch_seq):
    """
    batch_seq: numpy array of shape [batch_size, seq_len]
    return: list of strings
    """
    detokenized_batch = lookup_array[batch_seq]
    detokenized_batch = [''.join(seq) for seq in detokenized_batch]
    return detokenized_batch


def dna_tokenize(seq):
    """seq: DNA string; returns a list of token indices."""
    return [DNA_ALPHABET[c] for c in seq]


def batch_dna_tokenize(batch_seq):
    """
    batch_seq: list of strings (all of equal length)
    return: numpy array of shape [batch_size, seq_len]
    """
    tokenized_batch = np.array([[DNA_ALPHABET[c] for c in seq] for seq in batch_seq])
    return tokenized_batch

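# Example (illustrative sketch, not taken from the Gosai data): round-tripping
# a small batch through the tokenizers above.
#
#   batch = ['ACGT', 'TTCA']
#   tokens = batch_dna_tokenize(batch)           # array([[0, 1, 2, 3], [3, 3, 1, 0]])
#   assert batch_dna_detokenize(tokens) == batch
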

class GosaiDataset(torch.utils.data.Dataset):
    def __init__(self):
        data_df = pd.read_csv(os.path.join(
            base_path, 'mdlm/gosai_data/processed_data/gosai_all.csv'))
        self.seqs = torch.tensor(
            data_df['seq'].apply(lambda x: [DNA_ALPHABET[c] for c in x]).tolist())
        self.clss = torch.tensor(data_df[['hepg2', 'k562', 'sknsh']].to_numpy())
        LOGGER.info(
            f'Loaded data: seqs shape: {self.seqs.shape}, clss shape: {self.clss.shape}')

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        return {'seqs': self.seqs[idx],
                'clss': self.clss[idx],
                'attention_mask': torch.ones(len(self.seqs[idx]))}


def get_datasets_gosai():
    return GosaiDataset()

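# Example (hedged sketch): loading the dataset and inspecting one item. This
# assumes the processed CSV exists under `base_path` as in GosaiDataset above.
#
#   dataset = get_datasets_gosai()
#   sample = dataset[0]
#   # sample['seqs']: tensor of token indices for one sequence
#   # sample['clss']: activity values for (hepg2, k562, sknsh)
#   # sample['attention_mask']: all-ones tensor with the sequence length
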
def get_dataloaders_gosai(config, skip_valid=False, valid_seed=None):
    num_gpus = torch.cuda.device_count()
    if config.loader.global_batch_size % (
            num_gpus * config.trainer.accumulate_grad_batches) != 0:
        raise ValueError(
            f'Train batch size {config.loader.global_batch_size} '
            f'not divisible by {num_gpus} gpus with accumulation '
            f'{config.trainer.accumulate_grad_batches}.')
    if config.loader.eval_global_batch_size % num_gpus != 0:
        raise ValueError(
            f'Eval batch size {config.loader.eval_global_batch_size} '
            f'not divisible by {num_gpus}.')
    train_set = GosaiDataset()

    # Validation and test sets are random 40k-example subsets of the training data.
    valid_set = torch.utils.data.Subset(
        train_set, np.random.choice(len(train_set), 40000, replace=False))
    test_set = torch.utils.data.Subset(
        train_set, np.random.choice(len(train_set), 40000, replace=False))

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=config.loader.batch_size,
        num_workers=config.loader.num_workers,
        pin_memory=config.loader.pin_memory,
        shuffle=not config.data.streaming,
        persistent_workers=True)
    if skip_valid:
        valid_loader = None
        test_loader = None
    else:
        if valid_seed is None:
            shuffle_valid = False
            generator = None
        else:
            shuffle_valid = True
            generator = torch.Generator().manual_seed(valid_seed)
        valid_loader = torch.utils.data.DataLoader(
            valid_set,
            batch_size=config.loader.eval_batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=shuffle_valid,
            generator=generator)
        test_loader = torch.utils.data.DataLoader(
            test_set,
            batch_size=config.loader.eval_batch_size,
            num_workers=config.loader.num_workers,
            pin_memory=config.loader.pin_memory,
            shuffle=shuffle_valid,
            generator=generator)

    return train_loader, valid_loader, test_loader

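# Example (hedged sketch): the attribute accesses above suggest a hydra-style
# config with `loader`, `trainer`, and `data` sub-configs; concrete values are
# illustrative only.
#
#   train_loader, valid_loader, test_loader = get_dataloaders_gosai(config)
#   batch = next(iter(train_loader))
#   batch['seqs'].shape    # [config.loader.batch_size, seq_len]
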

class RandomFaultTolerantSampler(torch.utils.data.RandomSampler):

    def __init__(self, *args, generator=None, **kwargs):
        # Create a seeded generator so the sampler's random state can be
        # checkpointed and restored exactly.
        if generator is None:
            seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator().manual_seed(seed)
        kwargs.pop('shuffle', None)
        super().__init__(*args, generator=generator, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        return {'random_state': self.generator.get_state(),
                'counter': self.counter}

    def load_state_dict(self, state_dict):
        self.generator.set_state(state_dict.get('random_state'))
        self.counter = state_dict['counter']
        self.restarting = True

    def __iter__(self) -> typing.Iterator[int]:
        n = len(self.data_source)

        self.state = self.generator.get_state()
        indices = torch.randperm(n, generator=self.generator).tolist()

        if not self.restarting:
            self.counter = 0
        else:
            # Resume mid-epoch: skip the indices already served before the restart.
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0

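# Example (hedged sketch): checkpointing and resuming the sampler mid-epoch.
#
#   sampler = RandomFaultTolerantSampler(train_set)
#   ...iterate part of an epoch, then save `sampler.state_dict()`...
#   sampler.load_state_dict(saved_state)
#   # the next iteration over the sampler skips the first `counter` indices
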

class FaultTolerantDistributedSampler(torch.utils.data.DistributedSampler):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.counter = 0
        self.restarting = False

    def state_dict(self):
        return {'epoch': self.epoch, 'counter': self.counter}

    def load_state_dict(self, state_dict):
        self.epoch = state_dict['epoch']
        self.counter = state_dict['counter']
        self.restarting = True

    def __iter__(self):
        if self.shuffle:
            # Deterministically shuffle based on epoch and seed.
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # Add extra samples to make the index list evenly divisible.
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(
                    padding_size / len(indices)))[:padding_size]
        else:
            # Remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # Subsample this rank's shard.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        if not self.restarting:
            self.counter = 0
        else:
            # Resume mid-epoch after a restart.
            indices = indices[self.counter:]
            self.restarting = False

        for index in indices:
            self.counter += 1
            yield index

        self.counter = 0
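
# Example (hedged sketch): using the fault-tolerant distributed sampler with a
# DataLoader. Process-group setup (rank / world size) is assumed to be handled
# by the surrounding training framework.
#
#   sampler = FaultTolerantDistributedSampler(train_set)
#   loader = torch.utils.data.DataLoader(
#       train_set, batch_size=config.loader.batch_size, sampler=sampler)
#   sampler.set_epoch(epoch)         # standard DistributedSampler epoch seeding
#   state = sampler.state_dict()     # {'epoch': ..., 'counter': ...} for resuming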