| from typing import * | |
| import torch | |
| from .span import Span | |
| def _tensor2span_batch( | |
| span_boundary: torch.Tensor, | |
| span_labels: torch.Tensor, | |
| parent_indices: torch.Tensor, | |
| num_spans: torch.Tensor, | |
| label_confidence: torch.Tensor, | |
| idx2label: Dict[int, str], | |
| label_ignore: List[int], | |
| ) -> Span: | |
| spans = list() | |
| for (start_idx, end_idx), parent_idx, label, label_conf in \ | |
| list(zip(span_boundary, parent_indices, span_labels, label_confidence))[:int(num_spans)]: | |
| if label not in label_ignore: | |
| span = Span(int(start_idx), int(end_idx), idx2label[int(label)], True, confidence=float(label_conf)) | |
| if int(parent_idx) < len(spans): | |
| spans[int(parent_idx)].add_child(span) | |
| spans.append(span) | |
| return spans[0] | |
| def tensor2span( | |
| span_boundary: torch.Tensor, | |
| span_labels: torch.Tensor, | |
| parent_indices: torch.Tensor, | |
| num_spans: torch.Tensor, | |
| label_confidence: torch.Tensor, | |
| idx2label: Dict[int, str], | |
| label_ignore: Optional[List[int]] = None, | |
| ) -> List[Span]: | |
| """ | |
| Generate spans in dict from vectors. Refer to the model part for the meaning of these variables. | |
| If idx_ignore is provided, some labels will be ignored. | |
| :return: | |
| """ | |
| label_ignore = label_ignore or [] | |
| if span_boundary.device.type != 'cpu': | |
| span_boundary = span_boundary.to(device='cpu') | |
| parent_indices = parent_indices.to(device='cpu') | |
| span_labels = span_labels.to(device='cpu') | |
| num_spans = num_spans.to(device='cpu') | |
| label_confidence = label_confidence.to(device='cpu') | |
| ret = list() | |
| for args in zip( | |
| span_boundary.unbind(0), span_labels.unbind(0), parent_indices.unbind(0), num_spans.unbind(0), | |
| label_confidence.unbind(0), | |
| ): | |
| ret.append(_tensor2span_batch(*args, label_ignore=label_ignore, idx2label=idx2label)) | |
| return ret | |