diff --git a/.ipynb_checkpoints/proverkabert-checkpoint.py b/.ipynb_checkpoints/proverkabert-checkpoint.py index 44a5249d02234acdb0259f7764a8ed446db52e01..61952b9d3f5639f49c4c3972c71dbebbaddeafed 100644 --- a/.ipynb_checkpoints/proverkabert-checkpoint.py +++ b/.ipynb_checkpoints/proverkabert-checkpoint.py @@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module): self.classifier_safety = nn.Linear(768, 2) # safe/unsafe self.classifier_attack = nn.Linear(768, 4) # 4 attack types - def forward(self, input_ids, attention_mask): - outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + def forward(self, input_ids, attention_mask, token_type_ids=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) pooled_output = outputs.last_hidden_state[:, 0, :] logits_safety = self.classifier_safety(pooled_output) logits_attack = self.classifier_attack(pooled_output) return logits_safety, logits_attack def load_model(model_path): - # Проверяем наличие файлов if not os.path.exists(model_path): raise FileNotFoundError(f"Директория {model_path} не существует") @@ -219,14 +218,4 @@ def main(): print(f"Ошибка при классификации промпта '{prompt}': {e}") if __name__ == "__main__": - main() - - - - - - - - - - + main() \ No newline at end of file diff --git a/proverkabert.py b/proverkabert.py index 44a5249d02234acdb0259f7764a8ed446db52e01..61952b9d3f5639f49c4c3972c71dbebbaddeafed 100644 --- a/proverkabert.py +++ b/proverkabert.py @@ -120,15 +120,14 @@ class MultiTaskBert(nn.Module): self.classifier_safety = nn.Linear(768, 2) # safe/unsafe self.classifier_attack = nn.Linear(768, 4) # 4 attack types - def forward(self, input_ids, attention_mask): - outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + def forward(self, input_ids, attention_mask, token_type_ids=None): + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) pooled_output = outputs.last_hidden_state[:, 0, :] logits_safety = self.classifier_safety(pooled_output) logits_attack = self.classifier_attack(pooled_output) return logits_safety, logits_attack def load_model(model_path): - # Проверяем наличие файлов if not os.path.exists(model_path): raise FileNotFoundError(f"Директория {model_path} не существует") @@ -219,14 +218,4 @@ def main(): print(f"Ошибка при классификации промпта '{prompt}': {e}") if __name__ == "__main__": - main() - - - - - - - - - - + main() \ No newline at end of file