|
|
|
|
@ -51,9 +51,7 @@ class PyTorchTransformerClassifier(BasePyTorchClassifier):
|
|
|
|
|
)
|
|
|
|
|
model.to(self.device)
|
|
|
|
|
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
|
|
|
|
|
criterion = torch.nn.CrossEntropyLoss(
|
|
|
|
|
weight=torch.tensor([10.0, 10.0, 1.0]).to(self.device)
|
|
|
|
|
)
|
|
|
|
|
criterion = torch.nn.CrossEntropyLoss()
|
|
|
|
|
# check if continual_learning is activated, and retrieve the model to continue training
|
|
|
|
|
trainer = self.get_init_model(dk.pair)
|
|
|
|
|
if trainer is None:
|
|
|
|
|
|