Skip to content
GitLab
Explore
Projects
Groups
Topics
Snippets
Projects
Groups
Topics
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
proekt
obuch
Commits
17d77dde
Commit
17d77dde
authored
1 week ago
by
Мазур Грета Евгеньевна
Browse files
Options
Download
Patches
Plain Diff
proverka obuchenya
parent
34496ab5
master
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
.ipynb_checkpoints/proverkabert-checkpoint.py
+3
-14
.ipynb_checkpoints/proverkabert-checkpoint.py
proverkabert.py
+3
-14
proverkabert.py
with
6 additions
and
28 deletions
+6
-28
.ipynb_checkpoints/proverkabert-checkpoint.py
+
3
−
14
View file @
17d77dde
...
...
@@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module):
self
.
classifier_safety
=
nn
.
Linear
(
768
,
2
)
# safe/unsafe
self
.
classifier_attack
=
nn
.
Linear
(
768
,
4
)
# 4 attack types
def
forward
(
self
,
input_ids
,
attention_mask
):
outputs
=
self
.
bert
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
)
def
forward
(
self
,
input_ids
,
attention_mask
,
token_type_ids
=
None
):
outputs
=
self
.
bert
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
)
pooled_output
=
outputs
.
last_hidden_state
[:,
0
,
:]
logits_safety
=
self
.
classifier_safety
(
pooled_output
)
logits_attack
=
self
.
classifier_attack
(
pooled_output
)
return
logits_safety
,
logits_attack
def
load_model
(
model_path
):
# Проверяем наличие файлов
if
not
os
.
path
.
exists
(
model_path
):
raise
FileNotFoundError
(
f
"Директория
{
model_path
}
не существует"
)
...
...
@@ -219,14 +218,4 @@ def main():
print
(
f
"Ошибка при классификации промпта '
{
prompt
}
':
{
e
}
"
)
if
__name__
==
"__main__"
:
main
()
main
()
\ No newline at end of file
This diff is collapsed.
Click to expand it.
proverkabert.py
+
3
−
14
View file @
17d77dde
...
...
@@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module):
self
.
classifier_safety
=
nn
.
Linear
(
768
,
2
)
# safe/unsafe
self
.
classifier_attack
=
nn
.
Linear
(
768
,
4
)
# 4 attack types
def
forward
(
self
,
input_ids
,
attention_mask
):
outputs
=
self
.
bert
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
)
def
forward
(
self
,
input_ids
,
attention_mask
,
token_type_ids
=
None
):
outputs
=
self
.
bert
(
input_ids
=
input_ids
,
attention_mask
=
attention_mask
,
token_type_ids
=
token_type_ids
)
pooled_output
=
outputs
.
last_hidden_state
[:,
0
,
:]
logits_safety
=
self
.
classifier_safety
(
pooled_output
)
logits_attack
=
self
.
classifier_attack
(
pooled_output
)
return
logits_safety
,
logits_attack
def
load_model
(
model_path
):
# Проверяем наличие файлов
if
not
os
.
path
.
exists
(
model_path
):
raise
FileNotFoundError
(
f
"Директория
{
model_path
}
не существует"
)
...
...
@@ -219,14 +218,4 @@ def main():
print
(
f
"Ошибка при классификации промпта '
{
prompt
}
':
{
e
}
"
)
if
__name__
==
"__main__"
:
main
()
main
()
\ No newline at end of file
This diff is collapsed.
Click to expand it.
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment
Menu
Explore
Projects
Groups
Topics
Snippets