diff --git a/.ipynb_checkpoints/trytoubload-checkpoint.py b/.ipynb_checkpoints/trytoubload-checkpoint.py
index ccb4fbc12bd34e9b6f58635736a637f3b2e782f0..23e61a5329316c52133e38eb1a49eaa365ca33a8 100644
--- a/.ipynb_checkpoints/trytoubload-checkpoint.py
+++ b/.ipynb_checkpoints/trytoubload-checkpoint.py
@@ -1,7 +1,33 @@
 from transformers import BertTokenizer, BertForSequenceClassification
 from peft import PeftModel
 import torch
-from micro_no_cross import MultiTaskBert
+
+
+class MultiTaskBert(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = BertModel(config)
+        self.classifier_safety = nn.Linear(config.hidden_size, 2)
+        self.classifier_attack = nn.Linear(config.hidden_size, 4)
+
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        # Переводим тензоры на устройство
+        input_ids, attention_mask, labels = map(lambda x: x.to(device) if x is not None else None, [input_ids, attention_mask, labels])
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
+        pooled_output = outputs.last_hidden_state[:, 0, :]
+
+        logits_safety = self.classifier_safety(pooled_output)
+        logits_attack = self.classifier_attack(pooled_output)
+
+        loss = None
+        if labels is not None:
+            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
+            loss_safety = nn.CrossEntropyLoss(weight=class_weights_task1_tensor)(logits_safety, labels_safety)
+            loss_attack = nn.CrossEntropyLoss(weight=class_weights_task2_tensor)(logits_attack, labels_attack)
+            loss = loss_safety + loss_attack
+
+        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}
+
 # Пути к сохранённой модели
 # BASE_MODEL_PATH = "./micro_no_cross_fine_tuned/base"
 # LORA_PATH = "./micro_no_cross_fine_tuned/lora"
diff --git a/micro_no_cross.py b/micro_no_cross.py
index 6d6fde16008d90574bb8471f8caebffc806691d8..5d31cdb45056e3c9818f68fb4a9034543fbe7e3e 100644
--- a/micro_no_cross.py
+++ b/micro_no_cross.py
@@ -159,6 +159,7 @@ training_args = TrainingArguments(
 
 # Обучение
 trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=compute_metrics)
+
 trainer.train()
 
 # Оценка
diff --git a/trytoubload.py b/trytoubload.py
index ccb4fbc12bd34e9b6f58635736a637f3b2e782f0..23e61a5329316c52133e38eb1a49eaa365ca33a8 100644
--- a/trytoubload.py
+++ b/trytoubload.py
@@ -1,7 +1,33 @@
 from transformers import BertTokenizer, BertForSequenceClassification
 from peft import PeftModel
 import torch
-from micro_no_cross import MultiTaskBert
+
+
+class MultiTaskBert(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = BertModel(config)
+        self.classifier_safety = nn.Linear(config.hidden_size, 2)
+        self.classifier_attack = nn.Linear(config.hidden_size, 4)
+
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        # Переводим тензоры на устройство
+        input_ids, attention_mask, labels = map(lambda x: x.to(device) if x is not None else None, [input_ids, attention_mask, labels])
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
+        pooled_output = outputs.last_hidden_state[:, 0, :]
+
+        logits_safety = self.classifier_safety(pooled_output)
+        logits_attack = self.classifier_attack(pooled_output)
+
+        loss = None
+        if labels is not None:
+            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
+            loss_safety = nn.CrossEntropyLoss(weight=class_weights_task1_tensor)(logits_safety, labels_safety)
+            loss_attack = nn.CrossEntropyLoss(weight=class_weights_task2_tensor)(logits_attack, labels_attack)
+            loss = loss_safety + loss_attack
+
+        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}
+
 # Пути к сохранённой модели
 # BASE_MODEL_PATH = "./micro_no_cross_fine_tuned/base"
 # LORA_PATH = "./micro_no_cross_fine_tuned/lora"