proekt / obuch · Commits

Commit 2650a64c, authored 2 weeks ago by Мазур Грета Евгеньевна

    obuch with cross and graphic SAVING LORA

parent b67bb593 · master · no related merge requests found

Showing 2 changed files with 62 additions and 6 deletions:

    .ipynb_checkpoints/proverkabert-checkpoint.py   +31 −3
    proverkabert.py                                  +31 −3
.ipynb_checkpoints/proverkabert-checkpoint.py   +31 −3   (view file @ 2650a64c)

from transformers import BertTokenizer, BertModel, BertPreTrainedModel
import torch
import torch.nn as nn
from peft import get_peft_model, LoraConfig, TaskType

# Device used by forward() below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Loading the model with LoRA adapters
# Make sure the MultiTaskBert class is defined exactly as in your original code
class MultiTaskBert(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        # Two task heads: a binary safety classifier and a 4-way attack-type classifier
        self.classifier_safety = nn.Linear(config.hidden_size, 2)
        self.classifier_attack = nn.Linear(config.hidden_size, 4)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        # Move the tensors to the target device
        input_ids, attention_mask, labels = map(
            lambda x: x.to(device) if x is not None else None,
            [input_ids, attention_mask, labels],
        )

        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        # Take the [CLS] token representation as the pooled output
        pooled_output = outputs.last_hidden_state[:, 0, :]

        logits_safety = self.classifier_safety(pooled_output)
        logits_attack = self.classifier_attack(pooled_output)

        loss = None
        if labels is not None:
            # labels[:, 0] holds the safety label, labels[:, 1] the attack label
            labels_safety, labels_attack = labels[:, 0], labels[:, 1]
            loss_safety = nn.CrossEntropyLoss()(logits_safety, labels_safety)
            loss_attack = nn.CrossEntropyLoss()(logits_attack, labels_attack)
            # Joint loss: sum of both task losses
            loss = loss_safety + loss_attack

        return {'logits_safety': logits_safety, 'logits_attack': logits_attack, 'loss': loss}

# Loading the model with LoRA adapters
model = MultiTaskBert.from_pretrained('./fine-tuned-bert-lora_new').to(device)

# Restoring the model with LoRA adapters
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    ...

(Diff collapsed; remainder not shown.)
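The collapsed diff cuts off inside the LoraConfig call. For context, a minimal sketch of how a PEFT LoRA restoration typically continues; the lora_alpha, lora_dropout, and target_modules values below are illustrative assumptions, not taken from the commit:

from peft import PeftModel

# Hedged sketch only: typical PEFT flow for re-attaching LoRA adapters.
# All hyperparameters below are assumed, not read from the collapsed diff.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,                       # assumed
    lora_dropout=0.1,                    # assumed
    target_modules=["query", "value"],   # assumed: BERT attention projections
)

# get_peft_model(model, lora_config) would wrap the base model with freshly
# initialized adapters; to restore previously trained adapter weights, the
# usual route is PeftModel.from_pretrained on the saved adapter directory
# (here assumed to be the same directory as the base weights):
model = PeftModel.from_pretrained(model, './fine-tuned-bert-lora_new')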
proverkabert.py   +31 −3   (view file @ 2650a64c)

The diff for this file is identical to .ipynb_checkpoints/proverkabert-checkpoint.py above (the Jupyter checkpoint copy of the same script).

(Diff collapsed; remainder not shown.)
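For completeness, a small usage sketch of the restored model. It assumes the tokenizer was saved alongside the weights in './fine-tuned-bert-lora_new'; the class meanings are inferred from the head sizes in the code above, not documented in the commit:

# Hedged usage sketch: run one input through both heads of the restored model.
tokenizer = BertTokenizer.from_pretrained('./fine-tuned-bert-lora_new')

inputs = tokenizer("example input text", return_tensors='pt',
                   truncation=True, padding=True).to(device)

model.eval()
with torch.no_grad():
    out = model(**inputs)

# argmax over each head's logits: 2 safety classes, 4 attack classes
safety_class = torch.argmax(out['logits_safety'], dim=1).item()
attack_class = torch.argmax(out['logits_attack'], dim=1).item()
print(safety_class, attack_class)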