diff --git a/.ipynb_checkpoints/trytoubload-checkpoint.py b/.ipynb_checkpoints/trytoubload-checkpoint.py index b4836269575d9e06fb159d8e9254c45927bc86e8..e04f14a5b07a28de154fc115792097198417ab9f 100644 --- a/.ipynb_checkpoints/trytoubload-checkpoint.py +++ b/.ipynb_checkpoints/trytoubload-checkpoint.py @@ -14,6 +14,8 @@ from transformers import BertTokenizer, BertPreTrainedModel, BertModel, Trainer, from torch import nn from peft import get_peft_model, LoraConfig, TaskType +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + class MultiTaskBert(BertPreTrainedModel): def __init__(self, config): diff --git a/trytoubload.py b/trytoubload.py index b4836269575d9e06fb159d8e9254c45927bc86e8..e04f14a5b07a28de154fc115792097198417ab9f 100644 --- a/trytoubload.py +++ b/trytoubload.py @@ -14,6 +14,8 @@ from transformers import BertTokenizer, BertPreTrainedModel, BertModel, Trainer, from torch import nn from peft import get_peft_model, LoraConfig, TaskType +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + class MultiTaskBert(BertPreTrainedModel): def __init__(self, config):