fix: remove hardcoded class weight for debug.

pull/11219/head
b2r66sun 1 year ago committed by GitHub
parent c3309f2152
commit 7de74f4730
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -51,9 +51,7 @@ class PyTorchTransformerClassifier(BasePyTorchClassifier):
)
model.to(self.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=self.learning_rate)
criterion = torch.nn.CrossEntropyLoss(
weight=torch.tensor([10.0, 10.0, 1.0]).to(self.device)
)
criterion = torch.nn.CrossEntropyLoss()
# check if continual_learning is activated, and retrieve the model to continue training
trainer = self.get_init_model(dk.pair)
if trainer is None:

Loading…
Cancel
Save