xingpng commited on
Commit
9e81046
·
verified ·
1 Parent(s): 5429671

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +30 -3
README.md CHANGED
@@ -1,3 +1,30 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+ ```python
5
+ # You can use the following code to call our trained style encoder. Hope it helps.
6
+ import torchvision.transforms.functional as F
7
+ from torchvision import transforms
8
+ from transformers import (AutoModel, AutoProcessor, AutoTokenizer, AutoConfig,
9
+ CLIPImageProcessor, CLIPVisionModelWithProjection)
10
+ class SEStyleEmbedding:
11
+ def __init__(self, pretrained_path: str = "xingpng/OneIG-StyleEncoder", device: str = "cuda", dtype=torch.bfloat16):
12
+ self.device = torch.device(device)
13
+ self.dtype = dtype
14
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_path)
15
+ self.image_encoder.to(self.device, dtype=self.dtype)
16
+ self.image_encoder.eval()
17
+ self.processor = CLIPImageProcessor()
18
+
19
+ def _l2_normalize(self, x):
20
+ return torch.nn.functional.normalize(x, p=2, dim=-1)
21
+
22
+ def get_style_embedding(self, image_path: str):
23
+ image = Image.open(image_path).convert('RGB')
24
+ inputs = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, dtype=self.dtype)
25
+
26
+ with torch.no_grad():
27
+ outputs = self.image_encoder(inputs)
28
+ image_embeds = outputs.image_embeds
29
+ image_embeds_norm = self._l2_normalize(image_embeds)
30
+ return image_embeds_norm