SymmetricLCWATrainingLoop
- class SymmetricLCWATrainingLoop(model, triples_factory, optimizer=None, optimizer_kwargs=None, lr_scheduler=None, lr_scheduler_kwargs=None, automatic_memory_optimization=True, mode=None, result_tracker=None, result_tracker_kwargs=None)
Bases: TrainingLoop[Tuple[LongTensor], Tuple[LongTensor]]
A “symmetric” LCWA training loop that scores heads and tails at once.
This objective was introduced by [lacroix2018] as
\[l_{i,j,k}(X) = -X_{i,j,k} + \log \left( \sum_{k'} \exp(X_{i,j,k'}) \right) - X_{k,j+P,i} + \log \left( \sum_{i'} \exp(X_{k,j+P,i'}) \right)\]
which can be seen as a “symmetric LCWA”, where for one batch of triples we score both heads and tails, each given the remainder of the triple. Here, P is the number of relations, so that the relation index j+P denotes the inverse of relation j.
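To make the objective concrete, the following is a minimal, framework-free sketch that evaluates l_{i,j,k}(X) for a single triple, term by term. It assumes a dense score table `X[h][r][t]` whose relation axis has 2P slots, with slot j + P holding the inverse of relation j (P being the number of relations). The names `logsumexp`, `symmetric_lcwa_loss`, and `mean_symmetric_lcwa_loss` are hypothetical and not part of PyKEEN, and averaging over the batch is an assumed reduction.

```python
import math
from typing import List, Sequence, Tuple

# X[h][r][t]: score of (head h, relation r, tail t). The relation axis has 2P slots;
# slot j + P is ASSUMED to hold the inverse of relation j (hypothetical layout).
ScoreTensor = List[List[List[float]]]
Triple = Tuple[int, int, int]  # (i, j, k) = (head, relation, tail)


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def symmetric_lcwa_loss(X: ScoreTensor, triple: Triple, num_relations: int) -> float:
    """Compute l_{i,j,k}(X) for one triple, following the formula term by term."""
    i, j, k = triple
    num_entities = len(X)
    # Tail part: -X_{i,j,k} + log sum_{k'} exp(X_{i,j,k'})
    tail_term = -X[i][j][k] + logsumexp([X[i][j][kp] for kp in range(num_entities)])
    # Head part via the inverse relation j + P: -X_{k,j+P,i} + log sum_{i'} exp(X_{k,j+P,i'})
    inv = j + num_relations
    head_term = -X[k][inv][i] + logsumexp([X[k][inv][ip] for ip in range(num_entities)])
    return tail_term + head_term


def mean_symmetric_lcwa_loss(X: ScoreTensor, triples: Sequence[Triple], num_relations: int) -> float:
    """Average the per-triple loss over a batch (the mean reduction is an assumption)."""
    return sum(symmetric_lcwa_loss(X, t, num_relations) for t in triples) / len(triples)


# Usage: 3 entities, P = 1 relation (so 2 relation slots), all scores zero.
X = [[[0.0] * 3 for _ in range(2)] for _ in range(3)]
print(symmetric_lcwa_loss(X, (0, 0, 1), num_relations=1))  # 2 * log(3)
```

With all scores equal, both softmaxes are uniform, so each of the two parts equals log(number of entities); each part is an ordinary softmax cross-entropy, which is why the objective reads as LCWA applied once for tails and once (through the inverse relation) for heads.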
Note
At the same time, there is also a difference from the LCWATrainingLoop: we do not group by, e.g., head+relation pairs. Thus, the name might be suboptimal and may change in the future.
Initialize the training loop.
- Parameters:
  - model (Model) – The model to train.
  - triples_factory (CoreTriplesFactory) – The training triples factory.
  - optimizer (Union[str, Optimizer, Type[Optimizer], None]) – The optimizer to use while training the model.
  - optimizer_kwargs (Optional[Mapping[str, Any]]) – Additional keyword-based parameters to instantiate the optimizer (if necessary). params will be added automatically based on the model.
  - lr_scheduler (Union[str, _LRScheduler, Type[_LRScheduler], None]) – The learning rate scheduler to use while training the model.
  - lr_scheduler_kwargs (Optional[Mapping[str, Any]]) – Additional keyword-based parameters to instantiate the LR scheduler (if necessary). optimizer will be added automatically.
  - automatic_memory_optimization (bool) – Whether to automatically optimize the sub-batch size during training and the batch size during evaluation with regard to the hardware at hand.
  - mode (Optional[Literal['training', 'validation', 'testing']]) – The inductive training mode; None if transductive.
  - result_tracker (Union[str, ResultTracker, Type[ResultTracker], None]) – The result tracker.
  - result_tracker_kwargs (Optional[Mapping[str, Any]]) – Additional keyword-based parameters to instantiate the result tracker.
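The note on grouping can be illustrated with a small stdlib sketch contrasting how training instances might be formed. LCWA-style grouping (here by (head, relation) pairs) yields one instance per pair with a multi-hot target over every known tail, whereas the symmetric objective, as its formula implies, treats every triple as its own instance with a single positive tail and a single positive head. The helpers `lcwa_multi_hot` and `symmetric_one_hot` are hypothetical illustrations, not PyKEEN functions, and plain Python lists stand in for the tensors a real implementation would use.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Set, Tuple

Triple = Tuple[int, int, int]  # (head, relation, tail) as integer IDs


def lcwa_multi_hot(triples: Sequence[Triple], num_entities: int) -> Dict[Tuple[int, int], List[int]]:
    """LCWA-style: one instance per (head, relation) pair, multi-hot over all known tails."""
    tails_by_pair: Dict[Tuple[int, int], Set[int]] = defaultdict(set)
    for h, r, t in triples:
        tails_by_pair[(h, r)].add(t)
    return {
        pair: [1 if e in tails else 0 for e in range(num_entities)]
        for pair, tails in tails_by_pair.items()
    }


def symmetric_one_hot(triples: Sequence[Triple], num_entities: int) -> List[Tuple[List[int], List[int]]]:
    """Symmetric-LCWA-style: no grouping; each triple gets one-hot (tail, head) targets."""
    instances = []
    for h, _, t in triples:
        tail_target = [1 if e == t else 0 for e in range(num_entities)]
        head_target = [1 if e == h else 0 for e in range(num_entities)]
        instances.append((tail_target, head_target))
    return instances


triples = [(0, 0, 1), (0, 0, 2), (1, 0, 2)]
print(lcwa_multi_hot(triples, 3))     # {(0, 0): [0, 1, 1], (1, 0): [0, 0, 1]}
print(symmetric_one_hot(triples, 3))  # three instances, one per triple
```

For these three triples, the grouped view produces two instances while the ungrouped view produces three, so the two loops see different numbers of training instances per epoch.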