Example of usage with PyTorch Lightning

This section provides an example of how to use Prov4ML with PyTorch Lightning.

In any LightningModule, the training_step, validation_step, and test_step methods can be overridden to log the necessary information.


Example:
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = self.loss(y_hat, y)
    # Log the training loss and the FLOPs used for this batch
    prov4ml.log_metric("MSE_train", loss, prov4ml.Context.TRAINING, step=self.current_epoch)
    prov4ml.log_flops_per_batch("train_flops", self, batch, prov4ml.Context.TRAINING, step=self.current_epoch)
    return loss

This will log the mean squared error and the number of FLOPs per batch at each training step.
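
The same pattern applies to the validation and test steps. Below is a minimal sketch for validation_step; the metric name "MSE_val" and the use of prov4ml.Context.VALIDATION are illustrative assumptions, not fixed names:

Example:
def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = self.loss(y_hat, y)
    # Log the validation loss under an assumed VALIDATION context
    prov4ml.log_metric("MSE_val", loss, prov4ml.Context.VALIDATION, step=self.current_epoch)
    return loss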

Alternatively, the on_train_epoch_end method can be overridden to log information at the end of each epoch.


Example:
def on_train_epoch_end(self) -> None:
    # Log the epoch counter and save a model version as a provenance entity
    prov4ml.log_metric("epoch", self.current_epoch, prov4ml.Context.TRAINING, step=self.current_epoch)
    prov4ml.save_model_version(self, f"model_version_{self.current_epoch}", prov4ml.Context.TRAINING, step=self.current_epoch)
    # Record system, carbon, and timing information once per epoch
    prov4ml.log_system_metrics(prov4ml.Context.TRAINING, step=self.current_epoch)
    prov4ml.log_carbon_metrics(prov4ml.Context.TRAINING, step=self.current_epoch)
    prov4ml.log_current_execution_time("train_epoch_time", prov4ml.Context.TRAINING, self.current_epoch)
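
When instrumenting the module manually like this, the run itself still has to be opened and closed around the Trainer call with the start_run and end_run directives (see the logger section below, which removes this requirement). A minimal sketch follows; the start_run argument names and MyLightningModule are assumptions for illustration and should be checked against the yProv4ML documentation:

Example:
import lightning as L
import prov4ml

# Open the provenance run before training starts
# (argument names below are assumptions, not the confirmed signature)
prov4ml.start_run(
    prov_user_namespace="www.example.org",
    experiment_name="lightning_experiment",
)

model = MyLightningModule()  # placeholder for the instrumented module defined above
trainer = L.Trainer(max_epochs=10)
trainer.fit(model)

# Close the run so the provenance graph is finalized and written out
prov4ml.end_run()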

Example of usage with PyTorch Lightning Logger

When integrating with Lightning, a much easier way to produce the provenance graph is through the ProvMLLogger.


Example:
import lightning as L
import prov4ml

EPOCHS = 10  # number of training epochs (placeholder value)

trainer = L.Trainer(
    accelerator="cuda",
    devices=1,
    max_epochs=EPOCHS,
    enable_checkpointing=False,
    log_every_n_steps=1,
    logger=[prov4ml.ProvMLLogger()],
)

When logging in this way, there is no need to call the start_run and end_run directives; everything is logged automatically. If necessary, it is still possible to call all yProv4ML directives, such as log_param and log_metric, and the data will be saved in the current execution directory.
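
For instance, explicit yProv4ML calls can be mixed into a module that is trained with the ProvMLLogger. A minimal sketch, assuming log_param takes a key and a value:

Example:
def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = self.loss(y_hat, y)
    # Explicit directives still work alongside the ProvMLLogger;
    # their output is saved in the current execution directory.
    prov4ml.log_param("batch_size", x.shape[0])
    prov4ml.log_metric("MSE_train", loss, prov4ml.Context.TRAINING, step=self.current_epoch)
    return loss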