diff --git a/.ipynb_checkpoints/bertwithgoodgrad-checkpoint.py b/.ipynb_checkpoints/bertwithgoodgrad-checkpoint.py index fceea1b617079852c4bd9c5aeba5af12bb726776..2be84ffccb96fce8a2c1ecda0c34e37f5f3883cc 100644 --- a/.ipynb_checkpoints/bertwithgoodgrad-checkpoint.py +++ b/.ipynb_checkpoints/bertwithgoodgrad-checkpoint.py @@ -43,7 +43,7 @@ class Classifiers(torch.nn.Module): def forward(self, hidden_state): cls_token = hidden_state[:, 0, :] # Берем [CLS] токен return { - 'safety': torch.softmax(self.safety(cls_token), + 'safety': torch.softmax(self.safety(cls_token)), 'attack': torch.softmax(self.attack(cls_token)) } diff --git a/bertwithgoodgrad.py b/bertwithgoodgrad.py index fceea1b617079852c4bd9c5aeba5af12bb726776..2be84ffccb96fce8a2c1ecda0c34e37f5f3883cc 100644 --- a/bertwithgoodgrad.py +++ b/bertwithgoodgrad.py @@ -43,7 +43,7 @@ class Classifiers(torch.nn.Module): def forward(self, hidden_state): cls_token = hidden_state[:, 0, :] # Берем [CLS] токен return { - 'safety': torch.softmax(self.safety(cls_token), + 'safety': torch.softmax(self.safety(cls_token)), 'attack': torch.softmax(self.attack(cls_token)) }