Commit 2650a64c authored by Мазур Грета Евгеньевна

obuch with cross and graphic SAVING LORA

parent b67bb593
Showing with 62 additions and 6 deletions
from transformers import BertTokenizer, BertModel, BertPreTrainedModel
import torch
import torch.nn as nn
from peft import get_peft_model, LoraConfig, TaskType

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Loading the model with LoRA adapters.
# Make sure the MultiTaskBert class is defined exactly as in the original training code.
class MultiTaskBert(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Two task heads: binary safety classification and 4-way attack-type classification
        self.classifier_safety = nn.Linear(config.hidden_size, 2)
        self.classifier_attack = nn.Linear(config.hidden_size, 4)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Move the tensors to the target device
        input_ids, attention_mask, labels = map(
            lambda x: x.to(device) if x is not None else None,
            [input_ids, attention_mask, labels],
        )
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        # Use the [CLS] token embedding as the pooled representation
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits_safety = self.classifier_safety(pooled_output)
        logits_attack = self.classifier_attack(pooled_output)

        loss = None
        if labels is not None:
            # labels[:, 0] holds the safety label, labels[:, 1] the attack-type label
            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
            loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety)
            loss_attack = nn.CrossEntropyLoss()(logits_attack, labels_attack)
            # Total loss is the sum of the two task losses
            loss = loss_safety + loss_attack
        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}

# Load the fine-tuned model with LoRA adapters
model = MultiTaskBert.from_pretrained('./fine-tuned-bert-lora_new').to(device)

# Restore the LoRA configuration
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
......
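
The LoraConfig above is cut off in the diff, so here is a minimal sketch of how the restore step typically continues. The lora_alpha, lora_dropout, and target_modules values below are assumptions, not taken from this commit, and PeftModel.from_pretrained only works if the adapters were previously saved to that path with save_pretrained().

from peft import PeftModel

# Possible completion of the truncated config; the hyperparameter values
# below are assumptions, not what this commit actually used:
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,                       # assumed scaling factor
    lora_dropout=0.1,                    # assumed dropout on LoRA layers
    target_modules=["query", "value"],   # assumed BERT attention projections
)

# If the adapters were saved separately with save_pretrained(), they can be
# re-attached to the base model instead of rebuilding them with get_peft_model:
model = PeftModel.from_pretrained(model, './fine-tuned-bert-lora_new')
model.eval()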
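
For completeness, a small usage sketch of the restored model. The tokenizer checkpoint name 'bert-base-uncased' and the label order in the comments are assumptions, since the commit does not show them.

# Smoke test: tokenize one prompt and read out both task predictions.
# 'bert-base-uncased' is an assumed checkpoint name.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

inputs = tokenizer("Is this prompt safe?", return_tensors='pt')
with torch.no_grad():
    out = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])

pred_safety = out['logits_safety'].argmax(dim=-1).item()   # 0 = safe, 1 = unsafe (assumed order)
pred_attack = out['logits_attack'].argmax(dim=-1).item()   # index of one of the 4 attack classes
print(pred_safety, pred_attack)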