diff --git a/.ipynb_checkpoints/trytoubload-checkpoint.py b/.ipynb_checkpoints/trytoubload-checkpoint.py
index ccb4fbc12bd34e9b6f58635736a637f3b2e782f0..23e61a5329316c52133e38eb1a49eaa365ca33a8 100644
--- a/.ipynb_checkpoints/trytoubload-checkpoint.py
+++ b/.ipynb_checkpoints/trytoubload-checkpoint.py
@@ -1,7 +1,41 @@
 from transformers import BertTokenizer, BertForSequenceClassification
 from peft import PeftModel
 import torch
-from micro_no_cross import MultiTaskBert
+from transformers import BertPreTrainedModel, BertModel
+import torch.nn as nn
+
+# Device used inside MultiTaskBert.forward to move input tensors.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# NOTE: class_weights_task1_tensor / class_weights_task2_tensor are only
+# referenced when labels is not None (training); define them before training.
+
+
+class MultiTaskBert(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = BertModel(config)
+        self.classifier_safety = nn.Linear(config.hidden_size, 2)
+        self.classifier_attack = nn.Linear(config.hidden_size, 4)
+
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        # Переводим тензоры на устройство
+        input_ids, attention_mask, labels = map(lambda x: x.to(device) if x is not None else None, [input_ids, attention_mask, labels])
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
+        pooled_output = outputs.last_hidden_state[:, 0, :]
+
+        logits_safety = self.classifier_safety(pooled_output)
+        logits_attack = self.classifier_attack(pooled_output)
+
+        loss = None
+        if labels is not None:
+            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
+            loss_safety = nn.CrossEntropyLoss(weight=class_weights_task1_tensor)(logits_safety, labels_safety)
+            loss_attack = nn.CrossEntropyLoss(weight=class_weights_task2_tensor)(logits_attack, labels_attack)
+            loss = loss_safety + loss_attack
+
+        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}
+
 # Пути к сохранённой модели
 # BASE_MODEL_PATH = "./micro_no_cross_fine_tuned/base"
 # LORA_PATH = "./micro_no_cross_fine_tuned/lora"
diff --git a/micro_no_cross.py b/micro_no_cross.py
index 6d6fde16008d90574bb8471f8caebffc806691d8..5d31cdb45056e3c9818f68fb4a9034543fbe7e3e 100644
--- a/micro_no_cross.py
+++ b/micro_no_cross.py
@@ -159,6 +159,7 @@ training_args = TrainingArguments(
 
 # Обучение
 trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, compute_metrics=compute_metrics)
+
 trainer.train()
 
 # Оценка
diff --git a/trytoubload.py b/trytoubload.py
index ccb4fbc12bd34e9b6f58635736a637f3b2e782f0..23e61a5329316c52133e38eb1a49eaa365ca33a8 100644
--- a/trytoubload.py
+++ b/trytoubload.py
@@ -1,7 +1,41 @@
 from transformers import BertTokenizer, BertForSequenceClassification
 from peft import PeftModel
 import torch
-from micro_no_cross import MultiTaskBert
+from transformers import BertPreTrainedModel, BertModel
+import torch.nn as nn
+
+# Device used inside MultiTaskBert.forward to move input tensors.
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# NOTE: class_weights_task1_tensor / class_weights_task2_tensor are only
+# referenced when labels is not None (training); define them before training.
+
+
+class MultiTaskBert(BertPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.bert = BertModel(config)
+        self.classifier_safety = nn.Linear(config.hidden_size, 2)
+        self.classifier_attack = nn.Linear(config.hidden_size, 4)
+
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        # Переводим тензоры на устройство
+        input_ids, attention_mask, labels = map(lambda x: x.to(device) if x is not None else None, [input_ids, attention_mask, labels])
+        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
+        pooled_output = outputs.last_hidden_state[:, 0, :]
+
+        logits_safety = self.classifier_safety(pooled_output)
+        logits_attack = self.classifier_attack(pooled_output)
+
+        loss = None
+        if labels is not None:
+            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
+            loss_safety = nn.CrossEntropyLoss(weight=class_weights_task1_tensor)(logits_safety, labels_safety)
+            loss_attack = nn.CrossEntropyLoss(weight=class_weights_task2_tensor)(logits_attack, labels_attack)
+            loss = loss_safety + loss_attack
+
+        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}
+
 # Пути к сохранённой модели
 # BASE_MODEL_PATH = "./micro_no_cross_fine_tuned/base"
 # LORA_PATH = "./micro_no_cross_fine_tuned/lora"