luz.predictors module

class BaseLearner(**hparams)

Bases: luz.predictors.Predictor

abstract criterion()

Return loss criterion.

Return type

Callable[[Tensor, Tensor], Tensor]
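A minimal sketch of a criterion, assuming PyTorch is installed: any callable that maps `(output, target)` tensors to a loss tensor matches the documented return type, so a subclass could simply return the built-in `torch.nn.functional.mse_loss`. `MyLearner` is a hypothetical stand-in. A real subclass would derive from `BaseLearner` and implement the other abstract methods as well.

```python
import torch
import torch.nn.functional as F


class MyLearner:  # hypothetical stand-in for a BaseLearner subclass
    def criterion(self):
        # Callable[[Tensor, Tensor], Tensor]: (output, target) -> scalar loss
        return F.mse_loss


loss_fn = MyLearner().criterion()
loss = loss_fn(torch.tensor([1.0, 2.0]), torch.tensor([1.0, 4.0]))
```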

abstract fit_params(train_dataset, val_dataset, device)

Return fit parameters.

Parameters
  • train_dataset – Training dataset used to learn a model.

  • val_dataset – Validation dataset.

  • device – Device to use for learning.

Returns

Dictionary of fit parameters.

Return type

dict[str, Any]
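The sketch below only illustrates the shape of the return value, a `dict[str, Any]`. The keys `batch_size` and `epochs` are hypothetical: this section does not say which keys luz reads from the dictionary, so check the fit loop before relying on any particular name. `MyLearner` is likewise a hypothetical stand-in.

```python
from typing import Any


class MyLearner:  # hypothetical stand-in for a BaseLearner subclass
    def fit_params(self, train_dataset, val_dataset, device) -> dict[str, Any]:
        # Hypothetical keys; luz may expect different names.
        return {
            "batch_size": min(32, len(train_dataset)),  # never exceed the dataset size
            "epochs": 10,
            "device": device,
        }


params = MyLearner().fit_params(
    train_dataset=list(range(10)), val_dataset=None, device="cpu"
)
```

Deriving values such as the batch size from `train_dataset` is one reason the method receives the datasets rather than returning a fixed dictionary.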

abstract model(train_dataset)

Return module to be trained.

Parameters

train_dataset – Training dataset used to learn a model.

Returns

Module to be trained.

Return type

torch.nn.Module
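A minimal sketch of `model`, assuming each dataset item is an `(input, target)` pair of 1-D tensors. That layout is an assumption made for illustration, not something this section specifies. The training dataset is used only to size a single linear layer.

```python
import torch
from torch import nn


class MyLearner:  # hypothetical stand-in for a BaseLearner subclass
    def model(self, train_dataset) -> nn.Module:
        # Infer input/output sizes from the first (input, target) pair.
        x, y = train_dataset[0]
        return nn.Linear(x.shape[-1], y.shape[-1])


dataset = [(torch.randn(3), torch.randn(1)) for _ in range(4)]
net = MyLearner().model(dataset)
```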

abstract optimizer(model)

Return optimizer for model.

Return type

Optimizer
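A minimal sketch of `optimizer`, assuming the returned object is a standard `torch.optim.Optimizer` built over the module's trainable parameters. The choice of Adam and the learning rate are arbitrary illustrations.

```python
import torch
from torch import nn


class MyLearner:  # hypothetical stand-in for a BaseLearner subclass
    def optimizer(self, model: nn.Module) -> torch.optim.Optimizer:
        # Optimize every parameter of the module; Adam and lr=1e-3 are illustrative.
        return torch.optim.Adam(model.parameters(), lr=1e-3)


opt = MyLearner().optimizer(nn.Linear(3, 1))
```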

abstract run_batch(model, data)

Run model on a single batch.

Returns

  • torch.Tensor – Model output.

  • torch.Tensor – Batch loss.

Return type

tuple[Tensor, Tensor]
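A minimal sketch of `run_batch`, assuming `data` is an `(inputs, targets)` pair. The actual batch layout depends on the dataset and is not specified here. It returns the model output and the batch loss in the documented order, reusing `criterion` to compute the loss.

```python
import torch
import torch.nn.functional as F
from torch import nn


class MyLearner:  # hypothetical stand-in for a BaseLearner subclass
    def criterion(self):
        return F.mse_loss

    def run_batch(self, model: nn.Module, data) -> tuple[torch.Tensor, torch.Tensor]:
        # Assumed batch layout: (inputs, targets).
        x, y = data
        output = model(x)
        loss = self.criterion()(output, y)
        return output, loss  # (model output, batch loss)


model = nn.Linear(3, 1)
batch = (torch.randn(8, 3), torch.randn(8, 1))
output, loss = MyLearner().run_batch(model, batch)
```

In a typical PyTorch training step the loss would then be followed by `loss.backward()` and `optimizer.step()`. This section does not state whether luz makes those calls itself.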

class BaseTuner(num_iterations, **hparams)

Bases: luz.predictors.Predictor

abstract get_trial(hparams, trials, scores)
abstract hparams()
abstract learner(trial)
abstract scorer()
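This section lists `BaseTuner`'s abstract methods without describing them. The sketch below is therefore built on an assumption: a tuner runs `num_iterations` rounds, and in each round it picks a trial from the search space given earlier trials and scores, builds a learner for it, and scores that learner. It uses plain random search, which ignores past results. The class `RandomSearchTuner`, the `tune` driver method, the search space, and the placeholder scorer are all hypothetical and are not luz's implementation.

```python
import random
from typing import Any


class RandomSearchTuner:
    """Hypothetical stand-in mirroring BaseTuner's abstract methods."""

    def __init__(self, num_iterations: int, seed: int = 0) -> None:
        self.num_iterations = num_iterations
        self.rng = random.Random(seed)

    def hparams(self) -> dict[str, list[Any]]:
        # Illustrative search space: candidate values per hyperparameter.
        return {"lr": [1e-3, 1e-2, 1e-1], "hidden": [8, 16, 32]}

    def get_trial(self, hparams, trials, scores) -> dict[str, Any]:
        # Random search ignores past trials and scores; a smarter tuner would use them.
        return {name: self.rng.choice(values) for name, values in hparams.items()}

    def learner(self, trial):
        # Placeholder: a real tuner would return a learner configured with `trial`.
        return trial

    def scorer(self):
        # Placeholder score (higher is better): prefers lr close to 1e-2.
        return lambda learner: -abs(learner["lr"] - 1e-2)

    def tune(self):
        # Hypothetical driver loop tying the abstract methods together.
        trials, scores = [], []
        space = self.hparams()
        score_fn = self.scorer()
        for _ in range(self.num_iterations):
            trial = self.get_trial(space, trials, scores)
            trials.append(trial)
            scores.append(score_fn(self.learner(trial)))
        best = max(range(len(scores)), key=scores.__getitem__)
        return trials[best], scores[best]


best_trial, best_score = RandomSearchTuner(num_iterations=20).tune()
```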