Commit 17d77dde authored by Мазур Грета Евгеньевна's avatar Мазур Грета Евгеньевна
Browse files

proverka obuchenya

parent 34496ab5
No related merge requests found
Showing with 6 additions and 28 deletions
+6 -28
...@@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module): ...@@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module):
self.classifier_safety = nn.Linear(768, 2) # safe/unsafe self.classifier_safety = nn.Linear(768, 2) # safe/unsafe
self.classifier_attack = nn.Linear(768, 4) # 4 attack types self.classifier_attack = nn.Linear(768, 4) # 4 attack types
def forward(self, input_ids, attention_mask): def forward(self, input_ids, attention_mask, token_type_ids=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
pooled_output = outputs.last_hidden_state[:, 0, :] pooled_output = outputs.last_hidden_state[:, 0, :]
logits_safety = self.classifier_safety(pooled_output) logits_safety = self.classifier_safety(pooled_output)
logits_attack = self.classifier_attack(pooled_output) logits_attack = self.classifier_attack(pooled_output)
return logits_safety, logits_attack return logits_safety, logits_attack
def load_model(model_path): def load_model(model_path):
# Проверяем наличие файлов
if not os.path.exists(model_path): if not os.path.exists(model_path):
raise FileNotFoundError(f"Директория {model_path} не существует") raise FileNotFoundError(f"Директория {model_path} не существует")
...@@ -219,14 +218,4 @@ def main(): ...@@ -219,14 +218,4 @@ def main():
print(f"Ошибка при классификации промпта '{prompt}': {e}") print(f"Ошибка при классификации промпта '{prompt}': {e}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
...@@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module): ...@@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module):
self.classifier_safety = nn.Linear(768, 2) # safe/unsafe self.classifier_safety = nn.Linear(768, 2) # safe/unsafe
self.classifier_attack = nn.Linear(768, 4) # 4 attack types self.classifier_attack = nn.Linear(768, 4) # 4 attack types
def forward(self, input_ids, attention_mask): def forward(self, input_ids, attention_mask, token_type_ids=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
pooled_output = outputs.last_hidden_state[:, 0, :] pooled_output = outputs.last_hidden_state[:, 0, :]
logits_safety = self.classifier_safety(pooled_output) logits_safety = self.classifier_safety(pooled_output)
logits_attack = self.classifier_attack(pooled_output) logits_attack = self.classifier_attack(pooled_output)
return logits_safety, logits_attack return logits_safety, logits_attack
def load_model(model_path): def load_model(model_path):
# Проверяем наличие файлов
if not os.path.exists(model_path): if not os.path.exists(model_path):
raise FileNotFoundError(f"Директория {model_path} не существует") raise FileNotFoundError(f"Директория {model_path} не существует")
...@@ -219,14 +218,4 @@ def main(): ...@@ -219,14 +218,4 @@ def main():
print(f"Ошибка при классификации промпта '{prompt}': {e}") print(f"Ошибка при классификации промпта '{prompt}': {e}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment