feat: Add tensorboard callback to lightgbm

pull/12763/head
Robert Caulk 3 months ago
parent 0ef85e161e
commit 455954e0e3

@ -5,6 +5,7 @@ from lightgbm import LGBMClassifier
from freqtrade.freqai.base_models.BaseClassifierModel import BaseClassifierModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.tensorboard import LightGBMCallback
logger = logging.getLogger(__name__)
@ -46,6 +47,10 @@ class LightGBMClassifier(BaseClassifierModel):
init_model = self.get_init_model(dk.pair)
model = LGBMClassifier(**self.model_training_parameters)
activate_tensorboard = self.freqai_info.get("activate_tensorboard", True)
callbacks = []
if LightGBMCallback is not None:
callbacks = [LightGBMCallback(dk.data_path, activate_tensorboard)]
model.fit(
X=X,
y=y,
@ -53,6 +58,7 @@ class LightGBMClassifier(BaseClassifierModel):
sample_weight=train_weights,
eval_sample_weight=[test_weights],
init_model=init_model,
callbacks=callbacks,
)
return model

@ -6,6 +6,7 @@ from lightgbm import LGBMClassifier
from freqtrade.freqai.base_models.BaseClassifierModel import BaseClassifierModel
from freqtrade.freqai.base_models.FreqaiMultiOutputClassifier import FreqaiMultiOutputClassifier
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.tensorboard import LightGBMCallback
logger = logging.getLogger(__name__)
@ -53,6 +54,11 @@ class LightGBMClassifierMultiTarget(BaseClassifierModel):
else:
init_models = [None] * y.shape[1]
activate_tensorboard = self.freqai_info.get("activate_tensorboard", True)
callbacks = []
if LightGBMCallback is not None:
callbacks = [LightGBMCallback(dk.data_path, activate_tensorboard)]
fit_params = []
for i in range(len(eval_sets)):
fit_params.append(
@ -60,6 +66,7 @@ class LightGBMClassifierMultiTarget(BaseClassifierModel):
"eval_set": eval_sets[i],
"eval_sample_weight": eval_weights,
"init_model": init_models[i],
"callbacks": callbacks,
}
)

@ -5,6 +5,7 @@ from lightgbm import LGBMRegressor
from freqtrade.freqai.base_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.tensorboard import LightGBMCallback
logger = logging.getLogger(__name__)
@ -42,6 +43,11 @@ class LightGBMRegressor(BaseRegressionModel):
model = LGBMRegressor(**self.model_training_parameters)
activate_tensorboard = self.freqai_info.get("activate_tensorboard", True)
callbacks = []
if LightGBMCallback is not None:
callbacks = [LightGBMCallback(dk.data_path, activate_tensorboard)]
model.fit(
X=X,
y=y,
@ -49,6 +55,7 @@ class LightGBMRegressor(BaseRegressionModel):
sample_weight=train_weights,
eval_sample_weight=[eval_weights],
init_model=init_model,
callbacks=callbacks,
)
return model

@ -6,6 +6,7 @@ from lightgbm import LGBMRegressor
from freqtrade.freqai.base_models.BaseRegressionModel import BaseRegressionModel
from freqtrade.freqai.base_models.FreqaiMultiOutputRegressor import FreqaiMultiOutputRegressor
from freqtrade.freqai.data_kitchen import FreqaiDataKitchen
from freqtrade.freqai.tensorboard import LightGBMCallback
logger = logging.getLogger(__name__)
@ -55,6 +56,11 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
else:
init_models = [None] * y.shape[1]
activate_tensorboard = self.freqai_info.get("activate_tensorboard", True)
callbacks = []
if LightGBMCallback is not None:
callbacks = [LightGBMCallback(dk.data_path, activate_tensorboard)]
fit_params = []
for i in range(len(eval_sets)):
fit_params.append(
@ -62,6 +68,7 @@ class LightGBMRegressorMultiTarget(BaseRegressionModel):
"eval_set": eval_sets[i],
"eval_sample_weight": eval_weights,
"init_model": init_models[i],
"callbacks": callbacks,
}
)

@ -1,9 +1,11 @@
# ensure users can still use a non-torch freqai version
try:
from freqtrade.freqai.tensorboard.lightgbm_callback import LightGBMTensorboardCallback
from freqtrade.freqai.tensorboard.tensorboard import TensorBoardCallback, TensorboardLogger
TBLogger = TensorboardLogger
TBCallback = TensorBoardCallback
LightGBMCallback = LightGBMTensorboardCallback
except ModuleNotFoundError:
from freqtrade.freqai.tensorboard.base_tensorboard import (
BaseTensorBoardCallback,
@ -12,5 +14,6 @@ except ModuleNotFoundError:
TBLogger = BaseTensorboardLogger # type: ignore
TBCallback = BaseTensorBoardCallback # type: ignore
LightGBMCallback = None # type: ignore
__all__ = ("TBLogger", "TBCallback")
__all__ = ("TBLogger", "TBCallback", "LightGBMCallback")

@ -0,0 +1,24 @@
from __future__ import annotations
from freqtrade.freqai.tensorboard import TBLogger
class LightGBMTensorboardCallback:
def __init__(self, logdir, activate: bool) -> None:
self.activate = activate
self.logger = TBLogger(logdir, activate)
def __call__(self, env) -> None:
if not self.activate:
return
evals = getattr(env, "evaluation_result_list", None)
if not evals:
return
for data_name, metric_name, value, _ in evals:
self.logger.log_scalar(f"{data_name}-{metric_name}", value, env.iteration)
end_iteration = getattr(env, "end_iteration", None)
if end_iteration is not None and env.iteration + 1 >= end_iteration:
self.logger.close()
Loading…
Cancel
Save