|
|
|
|
@ -148,7 +148,8 @@ class PyTorchModelTrainer(PyTorchTrainerInterface):
|
|
|
|
|
the motivation here is that `n_steps` is easier to optimize and keep stable,
|
|
|
|
|
across different n_obs - the number of data points.
|
|
|
|
|
"""
|
|
|
|
|
assert isinstance(self.n_steps, int), "Either `n_steps` or `n_epochs` should be set."
|
|
|
|
|
if not isinstance(self.n_steps, int):
|
|
|
|
|
raise ValueError("Either `n_steps` or `n_epochs` should be set.")
|
|
|
|
|
n_batches = n_obs // self.batch_size
|
|
|
|
|
n_epochs = max(self.n_steps // n_batches, 1)
|
|
|
|
|
if n_epochs <= 10:
|
|
|
|
|
|