Mathieu5454 commited on
Commit
72a647c
·
1 Parent(s): d615950

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +97 -0
utils.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ import numpy as np
4
+
5
+ import torch
6
+
7
+ import torchvision.transforms as T
8
+
9
+
10
+
11
+ totensor = T.ToTensor()
12
+
13
+ topil = T.ToPILImage()
14
+
15
+
16
+
17
+ def recover_image(image, init_image, mask, background=False):
18
+
19
+ image = totensor(image)
20
+
21
+ mask = totensor(mask)
22
+
23
+ init_image = totensor(init_image)
24
+
25
+ if background:
26
+
27
+ result = mask * init_image + (1 - mask) * image
28
+
29
+ else:
30
+
31
+ result = mask * image + (1 - mask) * init_image
32
+
33
+ return topil(result)
34
+
35
+
36
+
37
+ def preprocess(image):
38
+
39
+ w, h = image.size
40
+
41
+ w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
42
+
43
+ image = image.resize((w, h), resample=Image.LANCZOS)
44
+
45
+ image = np.array(image).astype(np.float32) / 255.0
46
+
47
+ image = image[None].transpose(0, 3, 1, 2)
48
+
49
+ image = torch.from_numpy(image)
50
+
51
+ return 2.0 * image - 1.0
52
+
53
+
54
+
55
+ def prepare_mask_and_masked_image(image, mask):
56
+
57
+ image = np.array(image.convert("RGB"))
58
+
59
+ image = image[None].transpose(0, 3, 1, 2)
60
+
61
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
62
+
63
+
64
+
65
+ mask = np.array(mask.convert("L"))
66
+
67
+ mask = mask.astype(np.float32) / 255.0
68
+
69
+ mask = mask[None, None]
70
+
71
+ mask[mask < 0.5] = 0
72
+
73
+ mask[mask >= 0.5] = 1
74
+
75
+ mask = torch.from_numpy(mask)
76
+
77
+
78
+
79
+ masked_image = image * (mask < 0.5)
80
+
81
+
82
+
83
+ return mask, masked_image
84
+
85
+
86
+
87
+ def prepare_image(image):
88
+
89
+ image = np.array(image.convert("RGB"))
90
+
91
+ image = image[None].transpose(0, 3, 1, 2)
92
+
93
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
94
+
95
+
96
+
97
+ return image[0]