import segmentation_models_pytorch as smp
import torch
from transformers import PreTrainedModel

from .hf_config import UnetConfig


class HFUnetPlusPlus(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``smp.UnetPlusPlus``.

    Wrapping the segmentation_models_pytorch network in a ``PreTrainedModel``
    with ``config_class = UnetConfig`` lets it go through the ``transformers``
    save/load machinery (``save_pretrained`` / ``from_pretrained``).

    Config attributes read:
        encoder_name: forwarded as ``encoder_name``.
        decoder_channels: forwarded as ``decoder_channels``.
        input_channels: forwarded as ``in_channels``.
        num_classes: forwarded as ``classes``.
        encoder_weights (optional): pretrained encoder weights to load;
            falls back to ``"imagenet"`` when the config does not define it.
            Set it to ``None`` to skip the download (e.g. when restoring a
            checkpoint, whose weights replace the encoder's anyway).
        decoder_attention_type (optional): decoder attention block; falls
            back to ``"scse"`` when the config does not define it.
    """

    config_class = UnetConfig

    def __init__(self, config: UnetConfig):
        super().__init__(config)
        self.model = smp.UnetPlusPlus(
            encoder_name=config.encoder_name,
            # Previously hard-coded to "imagenet", which forced a weight
            # download on every construction, including from_pretrained().
            # The getattr default keeps the old behaviour for existing configs.
            encoder_weights=getattr(config, "encoder_weights", "imagenet"),
            decoder_channels=config.decoder_channels,
            in_channels=config.input_channels,
            classes=config.num_classes,
            # Same pattern: configurable, defaulting to the previous value.
            decoder_attention_type=getattr(
                config, "decoder_attention_type", "scse"
            ),
        )

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Run the wrapped UNet++ on ``tensor`` and return its output.

        Args:
            tensor: input batch; presumably shaped
                (batch, config.input_channels, H, W) -- TODO confirm
                against callers.

        Returns:
            The output of ``smp.UnetPlusPlus``. No ``activation`` is passed
            when building the model, so this is presumably raw per-class
            logits (smp's default) -- verify if consumers expect probabilities.
        """
        return self.model(tensor)