Commit ce4721e · 1 Parent(s): 7192f11

fix: non sparse models do not require deepspeed anymore

modeling_bert.py  +6 -124  CHANGED
@@ -23,8 +23,6 @@ import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
-import numpy as np
-
 import torch
 import torch.utils.checkpoint
 from packaging import version
@@ -306,7 +304,11 @@ class BertSelfAttention(nn.Module):
             self.rotary_emb = RotaryEmbedding(self.rotary_dim, base=self.rotary_base)
 
         if self.is_sparse:
-
+            try:
+                from deepspeed.ops.sparse_attention import SparseSelfAttention
+            except ImportError as e:
+                logger.error(f'DeepSpeed is required for Sparse Ops: {e}')
+                raise
             self.sparse_self_attention = SparseSelfAttention(self.sparse_config, max_seq_length=self.max_seq_len)
 
     def transpose_for_scores(self, x):
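For context, the added lines use a lazy, guarded import: DeepSpeed is only imported on the sparse code path, so dense models can be built without it installed, and a missing installation surfaces as a logged ImportError. A minimal self-contained sketch of the same pattern (the build_sparse_attention helper and its arguments are illustrative, not part of the file):

import logging

logger = logging.getLogger(__name__)

def build_sparse_attention(sparse_config, max_seq_len):
    # Import DeepSpeed only when sparse attention is actually requested;
    # dense models never reach this function, so the dependency stays optional.
    try:
        from deepspeed.ops.sparse_attention import SparseSelfAttention
    except ImportError as e:
        logger.error(f"DeepSpeed is required for Sparse Ops: {e}")
        raise
    return SparseSelfAttention(sparse_config, max_seq_length=max_seq_len)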
@@ -1871,126 +1873,6 @@ class BertForSequenceClassification(BertPreTrainedModel):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
-
-
-class APARENTLoss(nn.Module):
-    def __init__(self):
-        super(APARENTLoss, self).__init__()
-
-    def forward(self, p, y):
-        for i, n in enumerate(y):
-            if n == 0.:
-                y[i] += 1e-3
-            elif n == 1.:
-                y[i] -= 1e-3
-
-        loss = p * torch.log(p / y) + (1 - p) * torch.log((1 - p) / (1 - y))
-
-        return loss.mean()
-
-
-
-@add_start_docstrings(
-    """
-    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
-    output) e.g. for GLUE tasks.
-    """,
-    BERT_START_DOCSTRING,
-)
-class BertForAPARENTSequenceRegression(BertPreTrainedModel):
-    def __init__(self, config):
-        super().__init__(config)
-        self.num_labels = config.num_labels
-        self.config = config
-
-        self.bert = BertModel(config)
-        classifier_dropout = (
-            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
-        )
-        self.dropout = nn.Dropout(classifier_dropout)
-        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
-
-        # Initialize weights and apply final processing
-        self.post_init()
-
-
-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        pos_weight=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-        r"""
-        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
-            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
-            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
-        """
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if np.all(input_ids[:, -1].detach().cpu().numpy() == np.array([3 for i in range(len(input_ids))])):
-            pass
-        else:
-            print("#########################################NOT ENOUGH TOKENS#######################################")
-
-        outputs = self.bert(
-            input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        pooled_output = outputs[1]
-
-        pooled_output = self.dropout(pooled_output)
-        logits = self.classifier(pooled_output)
-        logits = torch.sigmoid(logits)
-
-        loss = None
-        if labels is not None:
-            if self.config.problem_type is None:
-                if self.num_labels == 1:
-                    self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
-                    self.config.problem_type = "single_label_classification"
-                else:
-                    self.config.problem_type = "multi_label_classification"
-
-            if self.config.problem_type == "regression":
-                loss_fct = MSELoss() #APARENTLoss()
-                if self.num_labels == 1:
-                    loss = loss_fct(logits.squeeze().float(), labels.squeeze().float()) # if it is not a sparse model then --- labels.squeeze().float(), else --- labels.squeeze().half()
-                else:
-                    loss = loss_fct(logits, labels)
-            elif self.config.problem_type == "single_label_classification":
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
-            elif self.config.problem_type == "multi_label_classification":
-                loss_fct = BCEWithLogitsLoss(pos_weight=pos_weight)
-                loss = loss_fct(logits, labels)
-        if not return_dict:
-            output = (logits,) + outputs[2:]
-            return ((loss,) + output) if loss is not None else output
-
-        return SequenceClassifierOutput(
-            loss=loss,
-            logits=logits,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
 
 
 @add_start_docstrings(
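For reference, the deleted APARENTLoss is the mean pointwise KL divergence between Bernoulli(p) and Bernoulli(y), with targets nudged off the endpoints 0 and 1 so the logarithms stay finite. A minimal standalone sketch of the same computation (the clamp stands in for the elementwise +/- 1e-3 loop; the function name is illustrative):

import torch

def bernoulli_kl_loss(p: torch.Tensor, y: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
    # Keep targets strictly inside (0, 1); the removed class did this with an
    # explicit loop that adjusted exact 0s and 1s by 1e-3.
    y = y.clamp(min=eps, max=1.0 - eps)
    # Pointwise KL(Bernoulli(p) || Bernoulli(y)), averaged over all elements.
    loss = p * torch.log(p / y) + (1 - p) * torch.log((1 - p) / (1 - y))
    return loss.mean()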
@@ -2174,7 +2056,7 @@ class BertForTokenClassification(BertPreTrainedModel):
             loss_fct = BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)
             loss = loss_fct(logits, labels)
             loss = loss * labels_mask.unsqueeze(-1)
-            loss = loss.sum() / labels_mask.sum() if labels_mask.sum() != 0.0 else 0.0
+            loss = loss.sum() / labels_mask.sum() if labels_mask.sum() != 0.0 else torch.tensor(0.0, device=logits.device)
 
         if not return_dict:
             output = (logits,) + outputs[2:]