Spaces:
Build error
Build error
| import unittest | |
| import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist | |
| import cv2 | |
| import numpy | |
| class TestSegmentAnything2Assist(unittest.TestCase): | |
| def setUp(self) -> None: | |
| return super().setUp() | |
| def tearDown(self) -> None: | |
| return super().tearDown() | |
| def _loading_all_sam_model_types(self): | |
| # Test loading all types of SAM2 models. | |
| all_sam_models_type = [ | |
| "sam2_hiera_tiny", | |
| "sam2_hiera_small", | |
| "sam2_hiera_base_plus", | |
| "sam2_hiera_large", | |
| ] | |
| for sam_model_type in all_sam_models_type: | |
| sam_model = SegmentAnything2Assist.SegmentAnything2Assist( | |
| sam_model_name=sam_model_type, download=True, device="cpu" | |
| ) | |
| self.assertEqual(sam_model.is_model_available(), True) | |
| sam_model = SegmentAnything2Assist.SegmentAnything2Assist( | |
| sam_model_name=sam_model_type, | |
| download=False, | |
| model_path=f".tmp/checkpoints/{sam_model_type}.pth", | |
| device="cpu", | |
| ) | |
| with self.assertRaises(Exception): | |
| sam_model = SegmentAnything2Assist.SegmentAnything2Assist( | |
| sam_model_name=sam_model_type, | |
| download=False, | |
| model_path=".", | |
| device="cpu", | |
| ) | |
| def _generate_automatic_mask(self): | |
| image = cv2.imread("test/assets/liberty.jpg") | |
| sam_model = SegmentAnything2Assist.SegmentAnything2Assist( | |
| sam_model_name="sam2_hiera_tiny", download=True, device="cpu" | |
| ) | |
| segmentation_masks, bboxes, predicted_iou, stability_score = ( | |
| sam_model.generate_automatic_masks(image) | |
| ) | |
| self.assertEqual(len(segmentation_masks.shape), 4) | |
| self.assertEqual(segmentation_masks[0].shape, image.shape) | |
| self.assertEqual(segmentation_masks.shape[3], 3) | |
| self.assertEqual(type(segmentation_masks[0][0][0][0]), numpy.uint8) | |
| self.assertEqual(len(bboxes.shape), 2) | |
| self.assertEqual(bboxes[0].shape, (4,)) | |
| self.assertEqual(type(bboxes[0][0]), numpy.uint32) | |
| self.assertEqual(len(predicted_iou.shape), 1) | |
| self.assertEqual(type(predicted_iou[0]), numpy.float32) | |
| self.assertEqual(len(stability_score.shape), 1) | |
| self.assertEqual(type(stability_score[0]), numpy.float32) | |
| for segmentation_mask in segmentation_masks: | |
| self.assertEqual(segmentation_mask.shape, image.shape) | |
| def test_generate_masks_from_image(self): | |
| image = cv2.imread("test/assets/liberty.jpg") | |
| sam_model = SegmentAnything2Assist.SegmentAnything2Assist( | |
| sam_model_name="sam2_hiera_tiny", download=True, device="cpu" | |
| ) | |
| mask_chw, mask_iou = sam_model.generate_masks_from_image( | |
| image, None, None, None | |
| ) | |
| self.assertEqual(len(mask_chw.shape), 3) | |
| self.assertEqual(mask_chw[0].shape, image.shape) | |
| self.assertEqual(mask_chw.shape[0], 1) | |
| self.assertEqual(len(mask_iou.shape), 1) | |
| self.assertEqual(mask_iou.shape[0], 1) | |