Source code for epbd_bert.datasets.data_collators
"""Data collators for mini-batching tokenized sequences: pad variable-length
``input_ids`` to the longest sequence in the batch, stack fixed-size tensors,
and build the matching attention masks."""

from dataclasses import dataclass
from typing import Dict, Sequence

import torch
@dataclass
class SeqLabelEPBDDataCollator:
    """Collates instances carrying ``input_ids``, ``labels``, and
    ``epbd_features`` into a padded mini-batch."""

    pad_token_id: int = 0

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels, epbd_features = tuple(
            [instance[key] for instance in instances]
            for key in ("input_ids", "labels", "epbd_features")
        )
        # pad input_ids in the mini-batch to the length of the longest sequence
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.pad_token_id
        )
        # epbd_features are fixed-length per instance, so they stack directly
        epbd_features = torch.stack(epbd_features)
        # stack labels and cast to float32 for the loss computation
        labels = torch.stack(labels).to(dtype=torch.float32)
        # attention mask: 1 for real tokens, 0 for padding
        attention_mask = input_ids.ne(self.pad_token_id).int()
        return dict(
            input_ids=input_ids,
            epbd_features=epbd_features,
            labels=labels,
            attention_mask=attention_mask,
        )
# dc = SeqLabelEPBDDataCollator(pad_token_id=100)
# x = [
#     dict(
#         input_ids=torch.ones(10), labels=torch.ones(3), epbd_features=torch.rand(1200)
#     ),
#     dict(input_ids=torch.ones(7), labels=torch.ones(3), epbd_features=torch.rand(1200)),
#     dict(input_ids=torch.ones(3), labels=torch.ones(3), epbd_features=torch.rand(1200)),
# ]
# print(dc(x))
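# A hedged sketch of wiring this collator into torch's DataLoader, kept
# commented like the demo above; the toy instances are hypothetical stand-ins
# for the project's real dataset items, which must carry the same three keys
# (epbd_features sized 1200 per the example above).
# from torch.utils.data import DataLoader
# toy = [
#     dict(
#         input_ids=torch.ones(n, dtype=torch.long),
#         labels=torch.ones(3),
#         epbd_features=torch.rand(1200),
#     )
#     for n in (10, 7, 3)
# ]
# loader = DataLoader(toy, batch_size=3, collate_fn=SeqLabelEPBDDataCollator())
# batch = next(iter(loader))  # input_ids padded to (3, 10); epbd_features (3, 1200)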
@dataclass
class SeqLabelDataCollator:
    """Collates instances carrying ``input_ids`` and ``labels`` into a
    padded mini-batch."""

    pad_token_id: int = 0

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels")
        )
        # pad input_ids in the mini-batch to the length of the longest sequence
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.pad_token_id
        )
        # stack labels and cast to float32 for the loss computation
        labels = torch.stack(labels).to(dtype=torch.float32)
        # attention mask: 1 for real tokens, 0 for padding
        attention_mask = input_ids.ne(self.pad_token_id).int()
        return dict(input_ids=input_ids, labels=labels, attention_mask=attention_mask)
# dc = SeqLabelDataCollator()
# x = [
#     dict(input_ids=torch.ones(10), labels=torch.ones(3)),
#     dict(input_ids=torch.ones(7), labels=torch.ones(3)),
#     dict(input_ids=torch.ones(3), labels=torch.ones(3)),
# ]
# print(dc(x))
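# A minimal runnable smoke test, assuming only torch: the toy instances mirror
# the commented examples above and are not part of the published API. It shows
# that a plain list of dicts works as a map-style dataset with this collator.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    # toy instances with variable-length input_ids and fixed-size labels
    toy_data = [
        dict(input_ids=torch.ones(n, dtype=torch.long), labels=torch.ones(3))
        for n in (10, 7, 3)
    ]
    loader = DataLoader(toy_data, batch_size=3, collate_fn=SeqLabelDataCollator())
    batch = next(iter(loader))
    # input_ids and attention_mask are padded to the longest sequence: (3, 10)
    print(batch["input_ids"].shape, batch["attention_mask"].shape)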