Daniellesry commited on
Commit
19cc0f7
·
1 Parent(s): 4c7d24d
Files changed (1) hide show
  1. dkt/pipelines/pipeline.py +2 -19
dkt/pipelines/pipeline.py CHANGED
@@ -956,26 +956,9 @@ class DKTPipeline:
956
 
957
 
958
 
959
- @spaces.GPU(duration=30)
960
  @torch.inference_mode()
961
- def moge_infer(self, input_image):
962
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
963
-
964
- # Ensure model and input are on the same device
965
- model_device = next(self.moge_pipe.parameters()).device
966
- logger.info(f'moge_infer: model device={model_device}, target device={device}')
967
- print(f'moge_infer: model device={model_device}, target device={device}')
968
-
969
- if model_device != device:
970
- logger.info(f'Moving MoGe model from {model_device} to {device}')
971
- print(f'Moving MoGe model from {model_device} to {device}')
972
- self.moge_pipe = self.moge_pipe.to(device)
973
-
974
- input_image = input_image.to(device)
975
- logger.info(f'moge_infer: input device={input_image.device}')
976
- print(f'moge_infer: input device={input_image.device}')
977
-
978
-
979
  return self.moge_pipe.infer(input_image)
980
 
981
 
 
956
 
957
 
958
 
959
+ @spaces.GPU()
960
  @torch.inference_mode()
961
+ def moge_infer(self, input_image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
962
  return self.moge_pipe.infer(input_image)
963
 
964