vjawa_move_preprocessing_to_device

#8
by VibhuJawa - opened
nemotron_graphic_elements_v1/model.py CHANGED
@@ -141,6 +141,7 @@ class YoloXWrapper(nn.Module):
141
  """
142
  if not isinstance(image, torch.Tensor):
143
  image = torch.from_numpy(image)
 
144
  image = image.permute(2, 0, 1) # [H, W, 3] -> [3, H, W]
145
  image = resize_pad(image, self.img_size)
146
  return image.float()
 
141
  """
142
  if not isinstance(image, torch.Tensor):
143
  image = torch.from_numpy(image)
144
+ image = image.to(self.device)
145
  image = image.permute(2, 0, 1) # [H, W, 3] -> [3, H, W]
146
  image = resize_pad(image, self.img_size)
147
  return image.float()