Spaces:
Runtime error
Runtime error
Commit
·
8081e40
1
Parent(s):
ce1de29
update
Browse files- diffueraser/diffueraser.py +6 -5
- diffueraser/pipeline_diffueraser.py +2 -1
- gradio_app.py +3 -2
- propainter/inference.py +21 -20
- propainter/model/misc.py +9 -3
- run_diffueraser.py +3 -2
diffueraser/diffueraser.py
CHANGED
|
@@ -22,6 +22,7 @@ from libs.unet_motion_model import MotionAdapter, UNetMotionModel
|
|
| 22 |
from libs.brushnet_CA import BrushNetModel
|
| 23 |
from libs.unet_2d_condition import UNet2DConditionModel
|
| 24 |
from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
checkpoints = {
|
|
@@ -318,7 +319,7 @@ class DiffuEraser:
|
|
| 318 |
latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
|
| 319 |
latents = torch.cat(latents, dim=0)
|
| 320 |
latents = latents * self.vae.config.scaling_factor #[(b f), c1, h, w], c1=4
|
| 321 |
-
|
| 322 |
timesteps = torch.tensor([0], device=self.device)
|
| 323 |
timesteps = timesteps.long()
|
| 324 |
|
|
@@ -349,7 +350,7 @@ class DiffuEraser:
|
|
| 349 |
guidance_scale=guidance_scale_final,
|
| 350 |
latents=latents_pre,
|
| 351 |
).latents
|
| 352 |
-
|
| 353 |
|
| 354 |
def decode_latents(latents, weight_dtype):
|
| 355 |
latents = 1 / self.vae.config.scaling_factor * latents
|
|
@@ -363,7 +364,7 @@ class DiffuEraser:
|
|
| 363 |
with torch.no_grad():
|
| 364 |
video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
|
| 365 |
images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
|
| 366 |
-
|
| 367 |
|
| 368 |
## replace input frames with updated frames
|
| 369 |
black_image = Image.new('L', validation_masks_input[0].size, color=0)
|
|
@@ -376,7 +377,7 @@ class DiffuEraser:
|
|
| 376 |
latents_pre_out=None
|
| 377 |
sample_index=None
|
| 378 |
gc.collect()
|
| 379 |
-
|
| 380 |
|
| 381 |
################ Frame-by-frame inference ################
|
| 382 |
## add priori
|
|
@@ -396,7 +397,7 @@ class DiffuEraser:
|
|
| 396 |
images = images[:real_video_length]
|
| 397 |
|
| 398 |
gc.collect()
|
| 399 |
-
|
| 400 |
|
| 401 |
################ Compose ################
|
| 402 |
binary_masks = validation_masks_input_ori
|
|
|
|
| 22 |
from libs.brushnet_CA import BrushNetModel
|
| 23 |
from libs.unet_2d_condition import UNet2DConditionModel
|
| 24 |
from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline
|
| 25 |
+
import devicetorch
|
| 26 |
|
| 27 |
|
| 28 |
checkpoints = {
|
|
|
|
| 319 |
latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
|
| 320 |
latents = torch.cat(latents, dim=0)
|
| 321 |
latents = latents * self.vae.config.scaling_factor #[(b f), c1, h, w], c1=4
|
| 322 |
+
devicetorch.empty_cache(torch)
|
| 323 |
timesteps = torch.tensor([0], device=self.device)
|
| 324 |
timesteps = timesteps.long()
|
| 325 |
|
|
|
|
| 350 |
guidance_scale=guidance_scale_final,
|
| 351 |
latents=latents_pre,
|
| 352 |
).latents
|
| 353 |
+
devicetorch.empty_cache(torch)
|
| 354 |
|
| 355 |
def decode_latents(latents, weight_dtype):
|
| 356 |
latents = 1 / self.vae.config.scaling_factor * latents
|
|
|
|
| 364 |
with torch.no_grad():
|
| 365 |
video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
|
| 366 |
images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
|
| 367 |
+
devicetorch.empty_cache(torch)
|
| 368 |
|
| 369 |
## replace input frames with updated frames
|
| 370 |
black_image = Image.new('L', validation_masks_input[0].size, color=0)
|
|
|
|
| 377 |
latents_pre_out=None
|
| 378 |
sample_index=None
|
| 379 |
gc.collect()
|
| 380 |
+
devicetorch.empty_cache(torch)
|
| 381 |
|
| 382 |
################ Frame-by-frame inference ################
|
| 383 |
## add priori
|
|
|
|
| 397 |
images = images[:real_video_length]
|
| 398 |
|
| 399 |
gc.collect()
|
| 400 |
+
devicetorch.empty_cache(torch)
|
| 401 |
|
| 402 |
################ Compose ################
|
| 403 |
binary_masks = validation_masks_input_ori
|
diffueraser/pipeline_diffueraser.py
CHANGED
|
@@ -36,6 +36,7 @@ from diffusers import (
|
|
| 36 |
from libs.unet_2d_condition import UNet2DConditionModel
|
| 37 |
from libs.brushnet_CA import BrushNetModel
|
| 38 |
|
|
|
|
| 39 |
|
| 40 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 41 |
|
|
@@ -1326,7 +1327,7 @@ class StableDiffusionDiffuEraserPipeline(
|
|
| 1326 |
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
| 1327 |
self.unet.to("cpu")
|
| 1328 |
self.brushnet.to("cpu")
|
| 1329 |
-
|
| 1330 |
|
| 1331 |
if output_type == "latent":
|
| 1332 |
image = latents
|
|
|
|
| 36 |
from libs.unet_2d_condition import UNet2DConditionModel
|
| 37 |
from libs.brushnet_CA import BrushNetModel
|
| 38 |
|
| 39 |
+
import devicetorch
|
| 40 |
|
| 41 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 42 |
|
|
|
|
| 1327 |
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
| 1328 |
self.unet.to("cpu")
|
| 1329 |
self.brushnet.to("cpu")
|
| 1330 |
+
devicetorch.empty_cache(torch)
|
| 1331 |
|
| 1332 |
if output_type == "latent":
|
| 1333 |
image = latents
|
gradio_app.py
CHANGED
|
@@ -8,6 +8,7 @@ import gradio as gr
|
|
| 8 |
|
| 9 |
# Download Weights
|
| 10 |
from huggingface_hub import snapshot_download
|
|
|
|
| 11 |
|
| 12 |
# List of subdirectories to create inside "checkpoints"
|
| 13 |
subfolders = [
|
|
@@ -93,7 +94,7 @@ def infer(input_video, input_mask):
|
|
| 93 |
inference_time = end_time - start_time
|
| 94 |
print(f"DiffuEraser inference time: {inference_time:.4f} s")
|
| 95 |
|
| 96 |
-
|
| 97 |
|
| 98 |
return output_path
|
| 99 |
|
|
@@ -150,4 +151,4 @@ demo.queue().launch(show_api=False, show_error=True)
|
|
| 150 |
|
| 151 |
|
| 152 |
|
| 153 |
-
|
|
|
|
| 8 |
|
| 9 |
# Download Weights
|
| 10 |
from huggingface_hub import snapshot_download
|
| 11 |
+
import devicetorch
|
| 12 |
|
| 13 |
# List of subdirectories to create inside "checkpoints"
|
| 14 |
subfolders = [
|
|
|
|
| 94 |
inference_time = end_time - start_time
|
| 95 |
print(f"DiffuEraser inference time: {inference_time:.4f} s")
|
| 96 |
|
| 97 |
+
devicetorch.empty_cache(torch)
|
| 98 |
|
| 99 |
return output_path
|
| 100 |
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
|
| 154 |
+
|
propainter/inference.py
CHANGED
|
@@ -24,6 +24,7 @@ except:
|
|
| 24 |
from propainter.core.utils import to_tensors
|
| 25 |
from propainter.model.misc import get_device
|
| 26 |
|
|
|
|
| 27 |
import warnings
|
| 28 |
warnings.filterwarnings("ignore")
|
| 29 |
|
|
@@ -247,15 +248,15 @@ class Propainter:
|
|
| 247 |
|
| 248 |
gt_flows_f_list.append(flows_f)
|
| 249 |
gt_flows_b_list.append(flows_b)
|
| 250 |
-
|
| 251 |
|
| 252 |
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
|
| 253 |
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
|
| 254 |
gt_flows_bi = (gt_flows_f, gt_flows_b)
|
| 255 |
else:
|
| 256 |
gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
|
| 257 |
-
|
| 258 |
-
|
| 259 |
gc.collect()
|
| 260 |
|
| 261 |
if use_half:
|
|
@@ -284,7 +285,7 @@ class Propainter:
|
|
| 284 |
|
| 285 |
pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
|
| 286 |
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
|
| 287 |
-
|
| 288 |
|
| 289 |
pred_flows_f = torch.cat(pred_flows_f, dim=1)
|
| 290 |
pred_flows_b = torch.cat(pred_flows_b, dim=1)
|
|
@@ -292,8 +293,8 @@ class Propainter:
|
|
| 292 |
else:
|
| 293 |
pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
|
| 294 |
pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
|
| 295 |
-
|
| 296 |
-
|
| 297 |
gc.collect()
|
| 298 |
|
| 299 |
|
|
@@ -321,15 +322,15 @@ class Propainter:
|
|
| 321 |
|
| 322 |
gt_flows_f_list.append(flows_f)
|
| 323 |
gt_flows_b_list.append(flows_b)
|
| 324 |
-
|
| 325 |
|
| 326 |
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
|
| 327 |
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
|
| 328 |
sample_gt_flows_bi = (gt_flows_f, gt_flows_b)
|
| 329 |
else:
|
| 330 |
sample_gt_flows_bi = self.fix_raft(sample_frames, iters=raft_iter)
|
| 331 |
-
|
| 332 |
-
|
| 333 |
gc.collect()
|
| 334 |
|
| 335 |
if use_half:
|
|
@@ -356,7 +357,7 @@ class Propainter:
|
|
| 356 |
|
| 357 |
pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
|
| 358 |
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
|
| 359 |
-
|
| 360 |
|
| 361 |
pred_flows_f = torch.cat(pred_flows_f, dim=1)
|
| 362 |
pred_flows_b = torch.cat(pred_flows_b, dim=1)
|
|
@@ -364,8 +365,8 @@ class Propainter:
|
|
| 364 |
else:
|
| 365 |
sample_pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(sample_gt_flows_bi, sample_flow_masks)
|
| 366 |
sample_pred_flows_bi = self.fix_flow_complete.combine_flow(sample_gt_flows_bi, sample_pred_flows_bi, sample_flow_masks)
|
| 367 |
-
|
| 368 |
-
|
| 369 |
gc.collect()
|
| 370 |
|
| 371 |
masked_frames = sample_frames * (1 - sample_masks_dilated)
|
|
@@ -391,7 +392,7 @@ class Propainter:
|
|
| 391 |
|
| 392 |
updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
| 393 |
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
| 394 |
-
|
| 395 |
|
| 396 |
updated_frames = torch.cat(updated_frames, dim=1)
|
| 397 |
updated_masks = torch.cat(updated_masks, dim=1)
|
|
@@ -400,7 +401,7 @@ class Propainter:
|
|
| 400 |
prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, sample_pred_flows_bi, sample_masks_dilated, 'nearest')
|
| 401 |
updated_frames = sample_frames * (1 - sample_masks_dilated) + prop_imgs.view(b, t, 3, h, w) * sample_masks_dilated
|
| 402 |
updated_masks = updated_local_masks.view(b, t, 1, h, w)
|
| 403 |
-
|
| 404 |
|
| 405 |
## replace input frames/masks with updated frames/masks
|
| 406 |
for i,index in enumerate(index_sample):
|
|
@@ -432,7 +433,7 @@ class Propainter:
|
|
| 432 |
|
| 433 |
updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
| 434 |
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
| 435 |
-
|
| 436 |
|
| 437 |
updated_frames = torch.cat(updated_frames, dim=1)
|
| 438 |
updated_masks = torch.cat(updated_masks, dim=1)
|
|
@@ -441,7 +442,7 @@ class Propainter:
|
|
| 441 |
prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
|
| 442 |
updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
|
| 443 |
updated_masks = updated_local_masks.view(b, t, 1, h, w)
|
| 444 |
-
|
| 445 |
|
| 446 |
comp_frames = [None] * video_length
|
| 447 |
|
|
@@ -451,7 +452,7 @@ class Propainter:
|
|
| 451 |
else:
|
| 452 |
ref_num = -1
|
| 453 |
|
| 454 |
-
|
| 455 |
# ---- feature propagation + transformer ----
|
| 456 |
for f in tqdm(range(0, video_length, neighbor_stride)):
|
| 457 |
neighbor_ids = [
|
|
@@ -488,7 +489,7 @@ class Propainter:
|
|
| 488 |
|
| 489 |
comp_frames[idx] = comp_frames[idx].astype(np.uint8)
|
| 490 |
|
| 491 |
-
|
| 492 |
|
| 493 |
##save composed video##
|
| 494 |
comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
|
|
@@ -499,7 +500,7 @@ class Propainter:
|
|
| 499 |
writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 500 |
writer.release()
|
| 501 |
|
| 502 |
-
|
| 503 |
|
| 504 |
return output_path
|
| 505 |
|
|
@@ -517,4 +518,4 @@ if __name__ == '__main__':
|
|
| 517 |
res = propainter.forward(video, mask, output)
|
| 518 |
|
| 519 |
|
| 520 |
-
|
|
|
|
| 24 |
from propainter.core.utils import to_tensors
|
| 25 |
from propainter.model.misc import get_device
|
| 26 |
|
| 27 |
+
import devicetorch
|
| 28 |
import warnings
|
| 29 |
warnings.filterwarnings("ignore")
|
| 30 |
|
|
|
|
| 248 |
|
| 249 |
gt_flows_f_list.append(flows_f)
|
| 250 |
gt_flows_b_list.append(flows_b)
|
| 251 |
+
devicetorch.empty_cache(torch)
|
| 252 |
|
| 253 |
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
|
| 254 |
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
|
| 255 |
gt_flows_bi = (gt_flows_f, gt_flows_b)
|
| 256 |
else:
|
| 257 |
gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
|
| 258 |
+
devicetorch.empty_cache(torch)
|
| 259 |
+
devicetorch.empty_cache(torch)
|
| 260 |
gc.collect()
|
| 261 |
|
| 262 |
if use_half:
|
|
|
|
| 285 |
|
| 286 |
pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
|
| 287 |
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
|
| 288 |
+
devicetorch.empty_cache(torch)
|
| 289 |
|
| 290 |
pred_flows_f = torch.cat(pred_flows_f, dim=1)
|
| 291 |
pred_flows_b = torch.cat(pred_flows_b, dim=1)
|
|
|
|
| 293 |
else:
|
| 294 |
pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
|
| 295 |
pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
|
| 296 |
+
devicetorch.empty_cache(torch)
|
| 297 |
+
devicetorch.empty_cache(torch)
|
| 298 |
gc.collect()
|
| 299 |
|
| 300 |
|
|
|
|
| 322 |
|
| 323 |
gt_flows_f_list.append(flows_f)
|
| 324 |
gt_flows_b_list.append(flows_b)
|
| 325 |
+
devicetorch.empty_cache(torch)
|
| 326 |
|
| 327 |
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
|
| 328 |
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
|
| 329 |
sample_gt_flows_bi = (gt_flows_f, gt_flows_b)
|
| 330 |
else:
|
| 331 |
sample_gt_flows_bi = self.fix_raft(sample_frames, iters=raft_iter)
|
| 332 |
+
devicetorch.empty_cache(torch)
|
| 333 |
+
devicetorch.empty_cache(torch)
|
| 334 |
gc.collect()
|
| 335 |
|
| 336 |
if use_half:
|
|
|
|
| 357 |
|
| 358 |
pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
|
| 359 |
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
|
| 360 |
+
devicetorch.empty_cache(torch)
|
| 361 |
|
| 362 |
pred_flows_f = torch.cat(pred_flows_f, dim=1)
|
| 363 |
pred_flows_b = torch.cat(pred_flows_b, dim=1)
|
|
|
|
| 365 |
else:
|
| 366 |
sample_pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(sample_gt_flows_bi, sample_flow_masks)
|
| 367 |
sample_pred_flows_bi = self.fix_flow_complete.combine_flow(sample_gt_flows_bi, sample_pred_flows_bi, sample_flow_masks)
|
| 368 |
+
devicetorch.empty_cache(torch)
|
| 369 |
+
devicetorch.empty_cache(torch)
|
| 370 |
gc.collect()
|
| 371 |
|
| 372 |
masked_frames = sample_frames * (1 - sample_masks_dilated)
|
|
|
|
| 392 |
|
| 393 |
updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
| 394 |
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
| 395 |
+
devicetorch.empty_cache(torch)
|
| 396 |
|
| 397 |
updated_frames = torch.cat(updated_frames, dim=1)
|
| 398 |
updated_masks = torch.cat(updated_masks, dim=1)
|
|
|
|
| 401 |
prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, sample_pred_flows_bi, sample_masks_dilated, 'nearest')
|
| 402 |
updated_frames = sample_frames * (1 - sample_masks_dilated) + prop_imgs.view(b, t, 3, h, w) * sample_masks_dilated
|
| 403 |
updated_masks = updated_local_masks.view(b, t, 1, h, w)
|
| 404 |
+
devicetorch.empty_cache(torch)
|
| 405 |
|
| 406 |
## replace input frames/masks with updated frames/masks
|
| 407 |
for i,index in enumerate(index_sample):
|
|
|
|
| 433 |
|
| 434 |
updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
| 435 |
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
| 436 |
+
devicetorch.empty_cache(torch)
|
| 437 |
|
| 438 |
updated_frames = torch.cat(updated_frames, dim=1)
|
| 439 |
updated_masks = torch.cat(updated_masks, dim=1)
|
|
|
|
| 442 |
prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
|
| 443 |
updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
|
| 444 |
updated_masks = updated_local_masks.view(b, t, 1, h, w)
|
| 445 |
+
devicetorch.empty_cache(torch)
|
| 446 |
|
| 447 |
comp_frames = [None] * video_length
|
| 448 |
|
|
|
|
| 452 |
else:
|
| 453 |
ref_num = -1
|
| 454 |
|
| 455 |
+
devicetorch.empty_cache(torch)
|
| 456 |
# ---- feature propagation + transformer ----
|
| 457 |
for f in tqdm(range(0, video_length, neighbor_stride)):
|
| 458 |
neighbor_ids = [
|
|
|
|
| 489 |
|
| 490 |
comp_frames[idx] = comp_frames[idx].astype(np.uint8)
|
| 491 |
|
| 492 |
+
devicetorch.empty_cache(torch)
|
| 493 |
|
| 494 |
##save composed video##
|
| 495 |
comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
|
|
|
|
| 500 |
writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
| 501 |
writer.release()
|
| 502 |
|
| 503 |
+
devicetorch.empty_cache(torch)
|
| 504 |
|
| 505 |
return output_path
|
| 506 |
|
|
|
|
| 518 |
res = propainter.forward(video, mask, output)
|
| 519 |
|
| 520 |
|
| 521 |
+
|
propainter/model/misc.py
CHANGED
|
@@ -7,6 +7,7 @@ import torch.nn as nn
|
|
| 7 |
import logging
|
| 8 |
import numpy as np
|
| 9 |
from os import path as osp
|
|
|
|
| 10 |
|
| 11 |
def constant_init(module, val, bias=0):
|
| 12 |
if hasattr(module, 'weight') and module.weight is not None:
|
|
@@ -81,8 +82,13 @@ def set_random_seed(seed):
|
|
| 81 |
random.seed(seed)
|
| 82 |
np.random.seed(seed)
|
| 83 |
torch.manual_seed(seed)
|
| 84 |
-
|
| 85 |
-
torch.cuda.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
def get_time_str():
|
|
@@ -128,4 +134,4 @@ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
|
|
| 128 |
else:
|
| 129 |
continue
|
| 130 |
|
| 131 |
-
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
|
|
|
| 7 |
import logging
|
| 8 |
import numpy as np
|
| 9 |
from os import path as osp
|
| 10 |
+
import devicetorch
|
| 11 |
|
| 12 |
def constant_init(module, val, bias=0):
|
| 13 |
if hasattr(module, 'weight') and module.weight is not None:
|
|
|
|
| 82 |
random.seed(seed)
|
| 83 |
np.random.seed(seed)
|
| 84 |
torch.manual_seed(seed)
|
| 85 |
+
|
| 86 |
+
if torch.cuda.is_available():
|
| 87 |
+
torch.cuda.manual_seed(seed)
|
| 88 |
+
torch.cuda.manual_seed_all(seed)
|
| 89 |
+
|
| 90 |
+
if torch.backends.mps.is_available():
|
| 91 |
+
torch.mps.manual_seed(seed)
|
| 92 |
|
| 93 |
|
| 94 |
def get_time_str():
|
|
|
|
| 134 |
else:
|
| 135 |
continue
|
| 136 |
|
| 137 |
+
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
run_diffueraser.py
CHANGED
|
@@ -4,6 +4,7 @@ import time
|
|
| 4 |
import argparse
|
| 5 |
from diffueraser.diffueraser import DiffuEraser
|
| 6 |
from propainter.inference import Propainter, get_device
|
|
|
|
| 7 |
|
| 8 |
def main():
|
| 9 |
|
|
@@ -53,10 +54,10 @@ def main():
|
|
| 53 |
inference_time = end_time - start_time
|
| 54 |
print(f"DiffuEraser inference time: {inference_time:.4f} s")
|
| 55 |
|
| 56 |
-
|
| 57 |
|
| 58 |
if __name__ == '__main__':
|
| 59 |
main()
|
| 60 |
|
| 61 |
|
| 62 |
-
|
|
|
|
| 4 |
import argparse
|
| 5 |
from diffueraser.diffueraser import DiffuEraser
|
| 6 |
from propainter.inference import Propainter, get_device
|
| 7 |
+
import devicetorch
|
| 8 |
|
| 9 |
def main():
|
| 10 |
|
|
|
|
| 54 |
inference_time = end_time - start_time
|
| 55 |
print(f"DiffuEraser inference time: {inference_time:.4f} s")
|
| 56 |
|
| 57 |
+
devicetorch.empty_cache(torch)
|
| 58 |
|
| 59 |
if __name__ == '__main__':
|
| 60 |
main()
|
| 61 |
|
| 62 |
|
| 63 |
+
|