```python
# create a Trainer object to manage the training process
import lightning.pytorch as pl
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

pl.seed_everything(42)
trainer = pl.Trainer(
    accelerator="cpu",
    # Gradient clipping is a technique to prevent exploding gradients, especially
    # in recurrent networks (RNNs): gradients exceeding this value are scaled
    # down to avoid numerical instability during training.
    gradient_clip_val=0.1,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    # not meaningful for finding the learning rate but otherwise very important
    learning_rate=0.03,
    hidden_size=8,  # most important hyperparameter apart from learning rate
    # number of attention heads; set up to 4 for large datasets
    attention_head_size=1,
    dropout=0.1,  # between 0.1 and 0.3 are good values
    hidden_continuous_size=8,  # set to <= hidden_size
    loss=QuantileLoss(),
    optimizer="ranger",
    # reduce learning rate if no improvement in validation loss after x epochs
    # reduce_on_plateau_patience=1000,
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

# find optimal learning rate
# from lightning.pytorch.tuner import Tuner
# res = Tuner(trainer).lr_find(
#     tft,
#     train_dataloaders=train_dataloader,
#     val_dataloaders=val_dataloader,
#     max_lr=10.0,
#     min_lr=1e-6,
# )
# print(f"suggested learning rate: {res.suggestion()}")
# fig = res.plot(show=True, suggest=True)
# fig.show()
```
Number of parameters in network: 13.5k
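The `gradient_clip_val=0.1` argument above is worth unpacking. The standalone sketch below (a hypothetical toy model for illustration, not part of the tutorial code) shows roughly what Lightning does under the hood with its default norm-based clipping: after the backward pass, `torch.nn.utils.clip_grad_norm_` rescales the gradients in place whenever their total norm exceeds the threshold.

```python
import torch

# Minimal illustration (toy linear model, not the TFT) of norm-based gradient
# clipping, which Lightning applies when gradient_clip_val is set.
model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).sum()
loss.backward()

# Rescale gradients in place so their total L2 norm is at most 0.1;
# the function returns the norm measured before clipping.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
print(f"gradient norm before clipping: {total_norm:.3f}")
```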
```python
# configure network and trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger

early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
)
lr_logger = LearningRateMonitor()  # log the learning rate
logger = TensorBoardLogger("lightning_logs")  # log results to TensorBoard

trainer = pl.Trainer(
    max_epochs=50,
    accelerator="cpu",
    enable_model_summary=True,
    gradient_clip_val=0.1,
    limit_train_batches=50,  # limit each epoch to 50 training batches to speed up the example
    # fast_dev_run=True,  # comment in to check that the network or dataset has no serious bugs
    callbacks=[lr_logger, early_stop_callback],
    logger=logger,
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=2,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),
    log_interval=10,  # log example predictions every 10 batches
    optimizer="ranger",
    reduce_on_plateau_patience=4,  # reduce learning rate if validation loss does not improve for 4 epochs
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")
```
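With the network and trainer configured, training is launched with `trainer.fit`. A minimal sketch, assuming `train_dataloader` and `val_dataloader` are the DataLoaders built from the `TimeSeriesDataSet` earlier in the tutorial:

```python
# Fit the network; early stopping monitors val_loss, and the learning rate
# monitor logs to TensorBoard (inspect with: tensorboard --logdir=lightning_logs).
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)
```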