import torch
import numpy as np
from torchvision import transforms
from transformers import CLIPProcessor, CLIPVisionModel

from models.pretrained_models import Plonk
from models.samplers.riemannian_flow_sampler import riemannian_flow_sampler
from models.postprocessing import CartesiantoGPS
from models.schedulers import (
    SigmoidScheduler,
    LinearScheduler,
    CosineScheduler,
)
from models.preconditioning import DDPMPrecond
from utils.image_processing import CenterCrop

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
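
# Released PLONK checkpoints, mapped to the image encoder each one expects for conditioning.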
MODELS = {
    "nicolas-dufour/PLONK_YFCC": {"emb_name": "dinov2"},
    "nicolas-dufour/PLONK_OSV_5M": {"emb_name": "street_clip"},
    "nicolas-dufour/PLONK_iNaturalist": {"emb_name": "dinov2"},
}


def scheduler_fn(
    scheduler_type: str, start: float, end: float, tau: float, clip_min: float = 1e-9
):
    if scheduler_type == "sigmoid":
        return SigmoidScheduler(start, end, tau, clip_min)
    elif scheduler_type == "cosine":
        return CosineScheduler(start, end, tau, clip_min)
    elif scheduler_type == "linear":
        return LinearScheduler(clip_min=clip_min)
    else:
        raise ValueError(f"Scheduler type {scheduler_type} not supported")
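
# For reference, PlonkPipeline.__init__ below builds its default schedule as
# scheduler_fn("sigmoid", start=-7, end=3, tau=1.0).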


class DinoV2FeatureExtractor:
    """Embeds images with DINOv2 (ViT-L/14 with registers) for PLONK conditioning."""

    def __init__(self, device=device):
        self.device = device
        self.emb_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
        self.emb_model.eval()
        self.emb_model.to(self.device)
        self.augmentation = transforms.Compose(
            [
                CenterCrop(ratio="1:1"),
                transforms.Resize(
                    336, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                ),
            ]
        )

    def __call__(self, batch):
        embs = []
        with torch.no_grad():
            for img in batch["img"]:
                emb = self.emb_model(
                    self.augmentation(img).unsqueeze(0).to(self.device)
                ).squeeze(0)
                embs.append(emb)
        batch["emb"] = torch.stack(embs)
        return batch


class StreetClipFeatureExtractor:
    """Embeds images with StreetCLIP's vision encoder (CLS token) for PLONK conditioning."""

    def __init__(self, device=device):
        self.device = device
        self.emb_model = CLIPVisionModel.from_pretrained("geolocal/StreetCLIP").to(
            device
        )
        self.processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")

    def __call__(self, batch):
        inputs = self.processor(images=batch["img"], return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.emb_model(**inputs)
        embeddings = outputs.last_hidden_state[:, 0]
        batch["emb"] = embeddings
        return batch


def load_preprocessing(model_name, dtype=torch.float32):
    if MODELS[model_name]["emb_name"] == "dinov2":
        return DinoV2FeatureExtractor()
    elif MODELS[model_name]["emb_name"] == "street_clip":
        return StreetClipFeatureExtractor()
    else:
        raise ValueError(f"Embedding model {MODELS[model_name]['emb_name']} not found")
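
# Example usage of the extractor returned above (hypothetical "photo.jpg"): it maps
# a batch dict holding PIL images under "img" to the same dict with an added,
# stacked "emb" tensor, e.g.
#   extractor = load_preprocessing("nicolas-dufour/PLONK_YFCC")
#   batch = extractor({"img": [Image.open("photo.jpg")]})  # batch["emb"]: (1, D)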


class PlonkPipeline:
    """
    The PlonkPipeline class predicts GPS coordinates from images using a
    pre-trained PLONK model. It combines a conditioning feature extractor, a
    Riemannian flow sampler, a noise scheduler, and a Cartesian-to-GPS
    post-processing step.

    Initialization:
        PlonkPipeline(
            model_path,
            scheduler="sigmoid",
            scheduler_start=-7,
            scheduler_end=3,
            scheduler_tau=1.0,
            device=device,
        )

    Parameters:
        model_path (str): Path to the pre-trained PLONK model (one of the keys of MODELS).
        scheduler (str): The scheduler type to use. Options are "sigmoid", "cosine", "linear". Default is "sigmoid".
        scheduler_start (float): Start value for the scheduler. Default is -7.
        scheduler_end (float): End value for the scheduler. Default is 3.
        scheduler_tau (float): Tau value for the scheduler. Default is 1.0.
        device (str or torch.device): Device to run the model on. Defaults to CUDA if available.

    Methods:
        model(*args, **kwargs):
            Runs the preconditioning on the network with the provided arguments.

        __call__(...):
            Samples GPS coordinates conditioned on the provided images. See the
            method docstring for the full parameter list.

        to(device):
            Moves the model and its components to the specified device and
            returns the pipeline instance.

    Example Usage:
        pipe = PlonkPipeline("nicolas-dufour/PLONK_YFCC")
        pipe.to("cuda")
        coords = pipe(image, batch_size=4)
    """

    def __init__(
        self,
        model_path,
        scheduler="sigmoid",
        scheduler_start=-7,
        scheduler_end=3,
        scheduler_tau=1.0,
        device=device,
    ):
        self.network = Plonk.from_pretrained(model_path).to(device)
        self.network.requires_grad_(False).eval()
        assert scheduler in [
            "sigmoid",
            "cosine",
            "linear",
        ], f"Scheduler {scheduler} not supported"
        self.scheduler = scheduler_fn(
            scheduler, scheduler_start, scheduler_end, scheduler_tau
        )
        self.cond_preprocessing = load_preprocessing(model_name=model_path)
        self.postprocessing = CartesiantoGPS()
        self.sampler = riemannian_flow_sampler
        self.model_path = model_path
        self.preconditioning = DDPMPrecond()
        self.device = device

    def model(self, *args, **kwargs):
        return self.preconditioning(self.network, *args, **kwargs)

    def __call__(
        self,
        images,
        batch_size=None,
        x_N=None,
        num_steps=None,
        scheduler=None,
        cfg=0,
        generator=None,
        callback=None,
    ):
| """Sample from the model given conditioning. | |
| Args: | |
| cond: Conditioning input (image or list of images) | |
| batch_size: Number of samples to generate (inferred from cond if not provided) | |
| x_N: Initial noise tensor (generated if not provided) | |
| num_steps: Number of sampling steps (uses default if not provided) | |
| sampler: Custom sampler function (uses default if not provided) | |
| scheduler: Custom scheduler function (uses default if not provided) | |
| cfg: Classifier-free guidance scale (default 15) | |
| generator: Random number generator | |
| callback: Optional callback function to report progress (step, total_steps) | |
| Returns: | |
| Sampled GPS coordinates after postprocessing | |
| """ | |
        # Set up batch size and initial noise
        shape = [3]
        if not isinstance(images, list):
            images = [images]
        if x_N is None:
            if batch_size is None:
                batch_size = len(images)
            x_N = torch.randn(
                batch_size, *shape, device=self.device, generator=generator
            )
        else:
            x_N = x_N.to(self.device)
            if x_N.ndim == 1:
                # A single unbatched noise vector of shape (3,): add a batch dim
                x_N = x_N.unsqueeze(0)
            batch_size = x_N.shape[0]

        # Set up batch with conditioning
        batch = {"y": x_N}
        batch["img"] = images
        batch = self.cond_preprocessing(batch)
        if len(images) > 1:
            assert len(images) == batch_size
        else:
            # A single image conditions every sample in the batch
            batch["emb"] = batch["emb"].repeat(batch_size, 1)

        # Use the default sampler; fall back to the default scheduler if none is given
        sampler = self.sampler
        if scheduler is None:
            scheduler = self.scheduler

        # Sample from the model
        if num_steps is None:
            num_steps = 16  # Default number of steps

        # Wrap the model so each denoising step can report progress via the callback
        def model_with_progress(*args, **kwargs):
            step = kwargs.pop("current_step", 0)
            if callback:
                callback(step, num_steps)
            return self.model(*args, **kwargs)

        output = sampler(
            model_with_progress,
            batch,
            conditioning_keys="emb",
            scheduler=scheduler,
            num_steps=num_steps,
            cfg_rate=cfg,
            generator=generator,
            callback=callback,
        )

        # Apply post-processing and convert coordinates from radians to degrees
        output = self.postprocessing(output)
        output = np.degrees(output.detach().cpu().numpy())
        return output

    def to(self, device):
        self.network.to(device)
        self.postprocessing.to(device)
        self.device = torch.device(device)
        return self
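

# Minimal end-to-end sketch, guarded so importing this module stays side-effect
# free. Assumptions (not part of the pipeline itself): a local "photo.jpg"
# exists and the Hugging Face checkpoint can be downloaded.
if __name__ == "__main__":
    from PIL import Image

    pipe = PlonkPipeline("nicolas-dufour/PLONK_YFCC")
    pipe.to("cuda")
    img = Image.open("photo.jpg").convert("RGB")
    # Draw 4 candidate locations for the same image; the result should be a
    # (4, 2) array of latitude/longitude pairs in degrees.
    coords = pipe(img, batch_size=4, num_steps=16)
    print(coords)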