Update README.md
README.md CHANGED
@@ -1,3 +1,34 @@
 ---
 license: apache-2.0
 ---
+```python
+# You can use the following code to call our trained style encoder. Hope it helps.
+import torch
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
+
+class SEStyleEmbedding:
+    def __init__(self, pretrained_path: str = "xingpng/OneIG-StyleEncoder", device: str = "cuda", dtype=torch.bfloat16):
+        self.device = torch.device(device)
+        self.dtype = dtype
+        # Load the trained style encoder: a CLIP vision tower with a projection head.
+        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(pretrained_path)
+        self.image_encoder.to(self.device, dtype=self.dtype)
+        self.image_encoder.eval()
+        self.processor = CLIPImageProcessor()
+
+    def _l2_normalize(self, x):
+        return torch.nn.functional.normalize(x, p=2, dim=-1)
+
+    def get_style_embedding(self, image_path: str):
+        image = Image.open(image_path).convert('RGB')
+        inputs = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, dtype=self.dtype)
+
+        with torch.no_grad():
+            outputs = self.image_encoder(inputs)
+            image_embeds = outputs.image_embeds
+            # L2-normalize so style similarity reduces to cosine similarity.
+            image_embeds_norm = self._l2_normalize(image_embeds)
+        return image_embeds_norm
+```
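For quick reference, a minimal usage sketch of the `SEStyleEmbedding` class added above. The image paths `style_a.jpg` and `style_b.jpg` are placeholder assumptions, not files in this repository, and a CUDA device is assumed to be available:

```python
# Minimal usage sketch; assumes the SEStyleEmbedding class above is defined
# and that "style_a.jpg" / "style_b.jpg" are placeholder local images.
encoder = SEStyleEmbedding(pretrained_path="xingpng/OneIG-StyleEncoder", device="cuda")

emb_a = encoder.get_style_embedding("style_a.jpg")  # shape: (1, projection_dim)
emb_b = encoder.get_style_embedding("style_b.jpg")

# The embeddings are L2-normalized, so their dot product is the cosine
# similarity between the two styles.
similarity = (emb_a * emb_b).sum(dim=-1)
print(similarity.item())
```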