Source code for epbd_bert.models_factory
import torch
from torch.utils.data import DataLoader
import transformers
from epbd_bert.datasets.data_collators import (
    SeqLabelEPBDDataCollator,
    SeqLabelDataCollator,
)
from epbd_bert.path_configs import (
    dnabert2_classifier_ckptpath,
    epbd_dnabert2_ckptpath,
    epbd_dnabert2_crossattn_ckptpath,
    epbd_dnabert2_crossattn_best_ckptpath,
)
def get_model_and_dataloader(
    model_name,
    data_path,
    tokenizer: transformers.PreTrainedTokenizer,
    batch_size=64,
    num_workers=8,
):
    """Load a pretrained model from its checkpoint and build the matching
    evaluation dataloader.

    Supported model_name values: "dnabert2_classifier", "epbd_dnabert2",
    "epbd_dnabert2_crossattn_best".
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if model_name == "dnabert2_classifier":
        from epbd_bert.dnabert2_classifier.model import DNABERT2Classifier
        from epbd_bert.datasets.sequence_dataset import SequenceDataset

        # ckpt_path = "dnabert2_classifier/backups/version_2/checkpoints/epoch=20-step=50547-val_loss=0.052-val_aucroc=0.939.ckpt"
        ckpt_path = dnabert2_classifier_ckptpath
        model = DNABERT2Classifier.load_from_checkpoint(ckpt_path)
        ds = SequenceDataset(data_path, tokenizer)
        data_collator = SeqLabelDataCollator(pad_token_id=tokenizer.pad_token_id)
    elif model_name == "epbd_dnabert2":
        from epbd_bert.dnabert2_epbd.model import Dnabert2EPBDModel
        from epbd_bert.datasets.sequence_epbd_dataset import SequenceEPBDDataset

        # ckpt_path = "dnabert2_epbd/backups/version_0/checkpoints/epoch=17-step=43326-val_loss=0.053-val_aucroc=0.938.ckpt"
        ckpt_path = epbd_dnabert2_ckptpath
        model = Dnabert2EPBDModel.load_from_checkpoint(ckpt_path)
        ds = SequenceEPBDDataset(data_path, tokenizer)
        data_collator = SeqLabelEPBDDataCollator(pad_token_id=tokenizer.pad_token_id)
    elif model_name == "epbd_dnabert2_crossattn_best":
        from epbd_bert.dnabert2_epbd_crossattn.model import EPBDDnabert2Model
        from epbd_bert.datasets.sequence_epbd_multimodal_dataset import (
            SequenceEPBDMultiModalDataset,
        )

        # ckpt_path = "analysis/best_model/epoch=9-step=255700.ckpt"  # best model
        ckpt_path = epbd_dnabert2_crossattn_best_ckptpath
        model = EPBDDnabert2Model.load_from_checkpoint(ckpt_path)
        ds = SequenceEPBDMultiModalDataset(data_path, tokenizer)
        data_collator = SeqLabelEPBDDataCollator(pad_token_id=tokenizer.pad_token_id)
    else:
        # Fail loudly instead of hitting an UnboundLocalError below.
        raise ValueError(f"Unknown model_name: {model_name}")

    model.to(device)
    model.eval()
    dl = DataLoader(
        ds,
        collate_fn=data_collator,
        shuffle=False,  # keep deterministic order for evaluation
        pin_memory=False,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    print("DS, DL:", len(ds), len(dl))
    return model, dl
# Smoke test: load a model with its corresponding test dataloader and push a
# single batch through it. A runnable, __main__-guarded version is at the
# bottom of this module.
# Inspect a raw checkpoint directly, loading tensors onto CPU:
# checkpoint = torch.load(
#     dnabert2_classifier_ckptpath, map_location=lambda storage, loc: storage
# )
# print(checkpoint.keys())
# print(checkpoint["hyper_parameters"])  # e.g., {"learning_rate": the_value, "another_parameter": the_other_value}
# print(checkpoint["state_dict"])