Update modelling_cxrrg.py
Browse files- modelling_cxrrg.py +1 -1
modelling_cxrrg.py
CHANGED
|
@@ -541,4 +541,4 @@ class CXRRGModel(VisionEncoderDecoderModel):
|
|
| 541 |
causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
|
| 542 |
|
| 543 |
mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
|
| 544 |
-
return mixed_causality_4d_attention_mask
|
|
|
|
| 541 |
causal_2d_attention_mask = causal_2d_attention_mask[:, None, None, :]
|
| 542 |
|
| 543 |
mixed_causality_4d_attention_mask = torch.cat((non_causal_2d_attention_mask, causal_2d_attention_mask), dim=-1)
|
| 544 |
+
return mixed_causality_4d_attention_mask.to(dtype=torch.float)
|