LitModule
- class LitModule(dataset='nations', dataset_kwargs=None, mode=None, model='distmult', model_kwargs=None, batch_size=32, learning_rate=0.001, label_smoothing=0.0, optimizer=None, optimizer_kwargs=None)[source]
Bases: LightningModule
A base module for training models with PyTorch Lightning.
Create the lightning module.
- Parameters:
  - dataset (Union[str, Dataset, Type[Dataset], None]) – the dataset, or a hint thereof
  - dataset_kwargs (Optional[Mapping[str, Any]]) – additional keyword-based parameters passed to the dataset
  - mode (Optional[Literal['training', 'validation', 'testing']]) – the inductive mode; defaults to transductive training
  - model (Union[str, Model, Type[Model], None]) – the model, or a hint thereof
  - model_kwargs (Optional[Mapping[str, Any]]) – additional keyword-based parameters passed to the model
  - batch_size (int) – the training batch size
  - learning_rate (float) – the learning rate
  - label_smoothing (float) – the label smoothing value
  - optimizer (Union[str, Optimizer, Type[Optimizer], None]) – the optimizer, or a hint thereof
  - optimizer_kwargs (Optional[Mapping[str, Any]]) – additional keyword-based parameters passed to the optimizer; should not contain lr or params
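The constraint on optimizer_kwargs can be sketched with a small dependency-free check (the helper below is hypothetical, not part of the library): lr is supplied via learning_rate and params via the model's parameters, so passing either again would conflict.

```python
# Hypothetical helper illustrating the documented constraint on optimizer_kwargs:
# `lr` is filled in from `learning_rate` and `params` from the model,
# so neither may appear in optimizer_kwargs.
def check_optimizer_kwargs(optimizer_kwargs):
    forbidden = {"lr", "params"}
    clash = forbidden & set(optimizer_kwargs or {})
    if clash:
        raise ValueError(f"optimizer_kwargs must not contain: {sorted(clash)}")
    return dict(optimizer_kwargs or {})

check_optimizer_kwargs({"weight_decay": 1e-4})  # fine
```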
Methods Summary
- configure_optimizers() – Configure the optimizers.
- forward(hr_batch) – Perform the prediction or inference step by wrapping pykeen.models.ERModel.predict_t().
- on_before_zero_grad(optimizer) – Called after training_step() and before optimizer.zero_grad().
- train_dataloader() – Create the training data loader.
- training_step(batch, batch_idx) – Perform a training step.
- val_dataloader() – Create the validation data loader.
- validation_step(batch, batch_idx, *args, ...) – Perform a validation step.
Methods Documentation
- forward(hr_batch)[source]
  Perform the prediction or inference step by wrapping pykeen.models.ERModel.predict_t().
  - Parameters:
    hr_batch (LongTensor) – shape: (batch_size, 2), dtype: long. The indices of (head, relation) pairs.
  - Return type:
    FloatTensor
  - Returns:
    shape: (batch_size, num_entities), dtype: float. For each (head, relation) pair, the scores for all possible tails.
Note
In Lightning, forward() defines the prediction/inference actions.
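The shape contract of forward() can be illustrated with a dependency-free stub (StubLitModule is hypothetical; the real method delegates to pykeen.models.ERModel.predict_t() and operates on torch tensors, with plain lists standing in here):

```python
# Hypothetical stub mimicking forward()'s shape contract:
# input: (batch_size, 2) pairs of (head, relation) indices,
# output: (batch_size, num_entities) tail scores.
class StubLitModule:
    def __init__(self, num_entities):
        self.num_entities = num_entities

    def forward(self, hr_batch):
        # one row of tail scores per (head, relation) pair
        return [[0.0] * self.num_entities for _ in hr_batch]

module = StubLitModule(num_entities=14)  # e.g., the Nations dataset has 14 entities
scores = module.forward([(0, 1), (2, 3)])  # a batch of two (head, relation) pairs
```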
- on_before_zero_grad(optimizer)[source]
  Called after training_step() and before optimizer.zero_grad().
  Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information once the weights have been updated.
  This is where it is called:

  for optimizer in optimizers:
      out = training_step(...)
      model.on_before_zero_grad(optimizer)  # <---- called here
      optimizer.zero_grad()
      backward()

  - Parameters:
    optimizer – The optimizer for which grads should be zeroed.
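The hook order above can be traced with a dependency-free sketch (FakeModel and FakeOptimizer are stand-ins, not Lightning classes):

```python
# Stand-in classes tracing the documented hook order:
# training_step -> on_before_zero_grad -> zero_grad.
calls = []

class FakeOptimizer:
    def zero_grad(self):
        calls.append("zero_grad")

class FakeModel:
    def training_step(self, batch, batch_idx):
        calls.append("training_step")
        return 0.0  # a dummy loss

    def on_before_zero_grad(self, optimizer):
        # weights are already updated here, so they can be inspected
        calls.append("on_before_zero_grad")

model, optimizer = FakeModel(), FakeOptimizer()
out = model.training_step(batch=None, batch_idx=0)
model.on_before_zero_grad(optimizer)  # <---- called here
optimizer.zero_grad()
```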