Spaces:
Runtime error
Runtime error
Make device dependent on the machine's capacity
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ import random
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
from sentence_transformers import SentenceTransformer, util
|
| 5 |
import logging
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
# Create a custom logger
|
| 8 |
logger = logging.getLogger(__name__)
|
|
@@ -22,9 +23,10 @@ c_handler.setFormatter(c_format)
|
|
| 22 |
logger.addHandler(c_handler)
|
| 23 |
|
| 24 |
class SearchEngine:
|
| 25 |
-
def __init__(self):
|
|
|
|
| 26 |
self.model = SentenceTransformer('clip-ViT-B-32')
|
| 27 |
-
self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device=
|
| 28 |
image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
|
| 29 |
self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}
|
| 30 |
|
|
@@ -35,12 +37,12 @@ class SearchEngine:
|
|
| 35 |
|
| 36 |
def search_images_from_text(self, text):
|
| 37 |
logger.info("Searching images from text")
|
| 38 |
-
emb = self.model.encode(text, convert_to_tensor=True, device=
|
| 39 |
return self.get_candidates(query_embedding=emb)
|
| 40 |
|
| 41 |
def search_images_from_image(self, image):
|
| 42 |
logger.info("Searching images from image")
|
| 43 |
-
emb = self.model.encode(Image.fromarray(image), convert_to_tensor=True, device=
|
| 44 |
return self.get_candidates(query_embedding=emb)
|
| 45 |
|
| 46 |
def main():
|
|
|
|
| 3 |
from datasets import load_dataset
|
| 4 |
from sentence_transformers import SentenceTransformer, util
|
| 5 |
import logging
|
| 6 |
+
import torch
|
| 7 |
from PIL import Image
|
| 8 |
# Create a custom logger
|
| 9 |
logger = logging.getLogger(__name__)
|
|
|
|
| 23 |
logger.addHandler(c_handler)
|
| 24 |
|
| 25 |
class SearchEngine:
|
| 26 |
+
def __init__(self, device="cpu"):
    """Initialize the search engine: CLIP model plus the Unsplash data.

    Args:
        device: Preferred torch device (e.g. "cuda"). Falls back to
            "cpu" when no CUDA device is available on this machine.
    """
    # Resolve the device first so everything below lands on it.
    self.device = device if torch.cuda.is_available() else "cpu"
    # Place the model on the resolved device at construction time instead
    # of relying on every encode() call to move it there.
    self.model = SentenceTransformer('clip-ViT-B-32', device=self.device)
    # Precomputed CLIP embeddings for the 25k Unsplash images, returned as
    # torch tensors already on self.device.
    self.embedding_dataset = load_dataset("JLD/unsplash25k-image-embeddings", trust_remote_code=True, split="train").with_format("torch", device=self.device)
    image_dataset = load_dataset("jamescalam/unsplash-25k-photos", trust_remote_code=True, revision="refs/pr/3")
    # Map photo_id -> photo_image_url so candidates resolve to URLs in O(1).
    self.image_dataset = {image["photo_id"]: image["photo_image_url"] for image in image_dataset["train"]}
|
| 32 |
|
|
|
|
| 37 |
|
| 38 |
def search_images_from_text(self, text):
    """Encode a free-text query with CLIP and return candidate images."""
    logger.info("Searching images from text")
    query_emb = self.model.encode(
        text, convert_to_tensor=True, device=self.device
    )
    return self.get_candidates(query_embedding=query_emb)
|
| 42 |
|
| 43 |
def search_images_from_image(self, image):
    """Encode a query image (numpy array) with CLIP and return candidates."""
    logger.info("Searching images from image")
    pil_image = Image.fromarray(image)
    query_emb = self.model.encode(
        pil_image, convert_to_tensor=True, device=self.device
    )
    return self.get_candidates(query_embedding=query_emb)
|
| 47 |
|
| 48 |
def main():
|