Update saicinpainting/training/trainers/default.py
saicinpainting/training/trainers/default.py
CHANGED
@@ -1,5 +1,5 @@
 import logging
-
+
 import torch
 import torch.nn.functional as F
 from omegaconf import OmegaConf
@@ -13,30 +13,6 @@ from saicinpainting.utils import add_prefix_to_keys, get_ramp
 
 LOGGER = logging.getLogger(__name__)
 
-def resize_to_square(image, target_size):
-    h, w = image.shape[:2]
-    if h == w:
-        return cv2.resize(image, (target_size, target_size))
-
-    dif = h if h > w else w
-    interpolation = cv2.INTER_AREA if dif > target_size else cv2.INTER_CUBIC
-
-    x_pos = (dif - w) // 2
-    y_pos = (dif - h) // 2
-
-    if len(image.shape) == 2:
-        mask = np.zeros((dif, dif), dtype=image.dtype)
-        mask[y_pos:y_pos+h, x_pos:x_pos+w] = image
-    else:
-        mask = np.zeros((dif, dif, image.shape[2]), dtype=image.dtype)
-        mask[y_pos:y_pos+h, x_pos:x_pos+w, :] = image
-
-    return cv2.resize(mask, (target_size, target_size), interpolation=interpolation)
-
-# Usage
-target_size = 256
-resized_frame = resize_to_square(frame, target_size)
-
 
 def make_constant_area_crop_batch(batch, **kwargs):
     crop_y, crop_x, crop_height, crop_width = make_constant_area_crop_params(img_height=batch['image'].shape[2],
@@ -48,9 +24,25 @@ def make_constant_area_crop_batch(batch, **kwargs):
 
 
 class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
-    def __init__(self, *args,
+    def __init__(self, *args, concat_mask=True, rescale_scheduler_kwargs=None, image_to_discriminator='predicted_image',
+                 add_noise_kwargs=None, noise_fill_hole=False, const_area_crop_kwargs=None,
+                 distance_weighter_kwargs=None, distance_weighted_mask_for_discr=False,
+                 fake_fakes_proba=0, fake_fakes_generator_kwargs=None,
+                 **kwargs):
         super().__init__(*args, **kwargs)
-        self.
+        self.concat_mask = concat_mask
+        self.rescale_size_getter = get_ramp(**rescale_scheduler_kwargs) if rescale_scheduler_kwargs is not None else None
+        self.image_to_discriminator = image_to_discriminator
+        self.add_noise_kwargs = add_noise_kwargs
+        self.noise_fill_hole = noise_fill_hole
+        self.const_area_crop_kwargs = const_area_crop_kwargs
+        self.refine_mask_for_losses = make_mask_distance_weighter(**distance_weighter_kwargs) \
+            if distance_weighter_kwargs is not None else None
+        self.distance_weighted_mask_for_discr = distance_weighted_mask_for_discr
+
+        self.fake_fakes_proba = fake_fakes_proba
+        if self.fake_fakes_proba > 1e-3:
+            self.fake_fakes_gen = FakeFakesGenerator(**(fake_fakes_generator_kwargs or {}))
 
     def forward(self, batch):
         if self.training and self.rescale_size_getter is not None:
@@ -58,29 +50,6 @@ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
             batch['image'] = F.interpolate(batch['image'], size=cur_size, mode='bilinear', align_corners=False)
             batch['mask'] = F.interpolate(batch['mask'], size=cur_size, mode='nearest')
 
-        # Insert the resize code here
-        resized_images = []
-        resized_masks = []
-        for img, mask in zip(batch['image'], batch['mask']):
-            # Convert from tensor to numpy array
-            img_np = img.permute(1, 2, 0).cpu().numpy()
-            mask_np = mask.squeeze().cpu().numpy()
-
-            # Resize
-            img_resized = resize_to_square(img_np, self.target_size)
-            mask_resized = resize_to_square(mask_np, self.target_size)
-
-            # Convert back to tensor
-            img_resized = torch.from_numpy(img_resized).permute(2, 0, 1).float().to(img.device)
-            mask_resized = torch.from_numpy(mask_resized).unsqueeze(0).float().to(mask.device)
-
-            resized_images.append(img_resized)
-            resized_masks.append(mask_resized)
-
-        batch['image'] = torch.stack(resized_images)
-        batch['mask'] = torch.stack(resized_masks)
-
-        # Continue with the rest of the forward method
         if self.training and self.const_area_crop_kwargs is not None:
             batch = make_constant_area_crop_batch(batch, **self.const_area_crop_kwargs)
 
@@ -203,4 +172,4 @@ class DefaultInpaintingTrainingModule(BaseInpaintingTrainingModule):
            metrics['discr_adv_fake_fakes'] = fake_fakes_adv_discr_loss
            metrics.update(add_prefix_to_keys(fake_fakes_adv_metrics, 'adv_'))
 
-        return total_loss, metrics
+        return total_loss, metrics
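
For reference, the snippet below is a tensor-only sketch of the geometry the deleted resize_to_square helper implemented: center the input on a zero-filled square canvas of side max(H, W), then resize to target_size. It is not part of this commit or of the repository; the name resize_to_square_torch and the use of F.pad / F.interpolate in place of cv2 and numpy are assumptions for illustration only.

import torch
import torch.nn.functional as F

def resize_to_square_torch(image, target_size):
    # Pad a CxHxW tensor (C=1 for masks) to a zero-filled square canvas of
    # side max(H, W), centering the original content, then resize it to
    # (target_size, target_size) -- the same geometry as the deleted helper.
    _, h, w = image.shape
    if h != w:
        dif = max(h, w)
        x_pos = (dif - w) // 2
        y_pos = (dif - h) // 2
        # F.pad takes (left, right, top, bottom) for the last two dimensions
        image = F.pad(image, (x_pos, dif - w - x_pos, y_pos, dif - h - y_pos))
    # bilinear suits images; switch to mode='nearest' for binary masks
    return F.interpolate(image.unsqueeze(0), size=(target_size, target_size),
                         mode='bilinear', align_corners=False).squeeze(0)

# Example: a 3x200x300 image becomes 3x256x256
print(resize_to_square_torch(torch.rand(3, 200, 300), 256).shape)

The cv2 version switched between INTER_AREA and INTER_CUBIC depending on whether it was shrinking or enlarging; torch has no direct equivalent of that choice, so the sketch simply uses bilinear throughout.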