luz.predictors module
- class BaseLearner(**hparams)
Bases: luz.predictors.Predictor
- abstract criterion()
- Return type
Callable[[Tensor, Tensor], Tensor]
- abstract fit_params()
Return fit parameters.
- Parameters
train_dataset – Training dataset used to learn a model.
val_dataset – Validation dataset.
device – Device to use for learning.
- Returns
Dictionary of fit parameters.
- Return type
dict[str, Any]
- abstract model(train_dataset)
Return module to be trained.
- Parameters
train_dataset – Training dataset used to learn a model.
- Returns
Module to be trained.
- Return type
torch.nn.Module
- abstract optimizer(model)
- Return type
Optimizer
- abstract run_batch(model, data)
Run model on a single batch.
- Returns
torch.Tensor – Model output.
torch.Tensor – Batch loss.
- Return type
tuple[Tensor, Tensor]
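The contract listed above (criterion() returns a loss callable, model() returns the thing to be trained, run_batch() returns an output and loss pair) can be illustrated with a minimal, dependency-free sketch. This is not luz's implementation. SketchLearner and ScaleLearner are hypothetical names. Plain Python lists stand in for torch.Tensor, and a plain function stands in for torch.nn.Module. Only three of the five abstract methods are mirrored; optimizer() and fit_params() are left out for brevity.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable

# Assumption: plain lists of floats stand in for torch.Tensor so the
# sketch runs without any third-party dependency.
Vector = list[float]


class SketchLearner(ABC):
    """Hypothetical stand-in for BaseLearner; NOT luz's actual code."""

    def __init__(self, **hparams: Any) -> None:
        # Hyperparameters are stored as given, mirroring BaseLearner(**hparams).
        self.hparams = hparams

    @abstractmethod
    def criterion(self) -> Callable[[Vector, Vector], float]:
        """Return a loss function mapping (output, target) to a loss."""

    @abstractmethod
    def model(self, train_dataset: list) -> Callable[[Vector], Vector]:
        """Return the model to be trained."""

    @abstractmethod
    def run_batch(self, model, data) -> tuple[Vector, float]:
        """Run the model on a single batch and return (output, loss)."""


class ScaleLearner(SketchLearner):
    """Toy learner whose 'model' multiplies each input by a fixed scale."""

    def criterion(self):
        def mse(output: Vector, target: Vector) -> float:
            # Mean squared error over the elements of one batch.
            return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)

        return mse

    def model(self, train_dataset):
        scale = self.hparams.get("scale", 1.0)
        return lambda x: [scale * v for v in x]

    def run_batch(self, model, data):
        x, y = data
        output = model(x)
        loss = self.criterion()(output, y)
        return output, loss


learner = ScaleLearner(scale=2.0)
batch = ([1.0, 2.0], [2.0, 5.0])  # (inputs, targets)
net = learner.model(train_dataset=[batch])
output, loss = learner.run_batch(net, batch)
print(output, loss)
```

Because the methods are declared abstract, instantiating SketchLearner directly raises TypeError; a concrete learner has to supply every abstract method before it can be constructed. The same holds for any Python class built on abc, which is presumably why the methods above are marked abstract.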