Commit 17d77dde authored by Мазур Грета Евгеньевна's avatar Мазур Грета Евгеньевна
Browse files

proverka obuchenya

parent 34496ab5
No related merge requests found
Showing with 6 additions and 28 deletions
+6 -28
......@@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module):
self.classifier_safety = nn.Linear(768, 2) # safe/unsafe
self.classifier_attack = nn.Linear(768, 4) # 4 attack types
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
def forward(self, input_ids, attention_mask, token_type_ids=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
pooled_output = outputs.last_hidden_state[:, 0, :]
logits_safety = self.classifier_safety(pooled_output)
logits_attack = self.classifier_attack(pooled_output)
return logits_safety, logits_attack
def load_model(model_path):
# Проверяем наличие файлов
if not os.path.exists(model_path):
raise FileNotFoundError(f"Директория {model_path} не существует")
......@@ -219,14 +218,4 @@ def main():
print(f"Ошибка при классификации промпта '{prompt}': {e}")
if __name__ == "__main__":
main()
main()
\ No newline at end of file
......@@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module):
self.classifier_safety = nn.Linear(768, 2) # safe/unsafe
self.classifier_attack = nn.Linear(768, 4) # 4 attack types
def forward(self, input_ids, attention_mask):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
def forward(self, input_ids, attention_mask, token_type_ids=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
pooled_output = outputs.last_hidden_state[:, 0, :]
logits_safety = self.classifier_safety(pooled_output)
logits_attack = self.classifier_attack(pooled_output)
return logits_safety, logits_attack
def load_model(model_path):
# Проверяем наличие файлов
if not os.path.exists(model_path):
raise FileNotFoundError(f"Директория {model_path} не существует")
......@@ -219,14 +218,4 @@ def main():
print(f"Ошибка при классификации промпта '{prompt}': {e}")
if __name__ == "__main__":
main()
main()
\ No newline at end of file
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment