ncdia.trainers¶
ncdia.trainers.base.py¶
BaseTrainer¶
Basic trainer class for training models.
Attributes:
- model (nn.Module): Neural network model.
- train_loader (DataLoader): DataLoader for training.
- val_loader (DataLoader): DataLoader for validation.
- test_loader (DataLoader): DataLoader for testing.
- optimizer (Optimizer): Optimizer.
- scheduler (lr_scheduler._LRScheduler): Learning rate scheduler.
- criterion (Callable): Criterion for training.
- algorithm (object): Algorithm for training.
- metrics (dict): Metrics for evaluation and testing.
- session (int): Session number.
- max_epochs (int): Total epochs for training.
- max_train_iters (int): Iterations per epoch for training.
- max_val_iters (int): Iterations per epoch for validation.
- max_test_iters (int): Iterations per epoch for testing.
- epoch (int): Current training epoch.
- iter (int): Current iteration or index of the current batch.
- cfg (Configs): Configuration for trainer.
- hooks (List[Hook]): List of registered hooks.
- logger (Logger): Logger for logging information.
- device (torch.device): Device to use.
- work_dir (str): Working directory to save logs and checkpoints.
- exp_name (str): Experiment name.
- load_from (str): Checkpoint file path to load.
Methods:
- __init__(self, cfg, session, model, train_loader, val_loader, test_loader, default_hooks, custom_hooks, load_from, exp_name, work_dir)
The constructor method that initializes an instance of BaseTrainer.
Parameters:
- cfg (dict, optional): Configuration for trainer (a configuration sketch follows this parameter list). Contains:
    - 'trainer' (dict):
        - 'type' (str): Type of trainer.
    - 'algorithm' (dict):
        - 'type' (str): Type of algorithm.
    - 'criterion' (dict):
        - 'type' (str): Type of criterion for training.
    - 'optimizer' (dict):
        - 'type' (str): Name of optimizer.
        - 'param_groups' (dict | None): If provided, directly optimize param_groups and abandon model.
        - kwargs (dict) for optimizer, such as 'lr', 'weight_decay', etc.
    - 'scheduler' (dict):
        - 'type' (str): Name of scheduler.
        - kwargs (dict) for scheduler, such as 'step_size', 'gamma', etc.
    - 'device' (str | torch.device | None): Device to use. If None, use 'cuda' if available.
    - 'trainloader' (dict):
        - 'dataset' (dict):
            - 'type' (str): Type of dataset.
            - kwargs (dict) for dataset, such as 'root', 'split', etc.
        - kwargs (dict) for DataLoader, such as 'batch_size', 'shuffle', etc.
    - 'valloader' (dict):
        - 'dataset' (dict):
            - 'type' (str): Type of dataset.
            - kwargs (dict) for dataset, such as 'root', 'split', etc.
        - kwargs (dict) for DataLoader, such as 'batch_size', 'shuffle', etc.
    - 'testloader' (dict):
        - 'dataset' (dict):
            - 'type' (str): Type of dataset.
            - kwargs (dict) for dataset, such as 'root', 'split', etc.
        - kwargs (dict) for DataLoader, such as 'batch_size', 'shuffle', etc.
    - 'exp_name' (str): Experiment name.
    - 'work_dir' (str): Working directory to save logs and checkpoints.
- session (int): Session number. If == 0, execute pre-training. If > 0, execute incremental training.
- model (nn.Module): Model to be trained.
- train_loader (DataLoader | dict, optional): DataLoader for training.
- val_loader (DataLoader | dict, optional): DataLoader for validation.
- test_loader (DataLoader | dict, optional): DataLoader for testing.
- default_hooks (dict, optional): Default hooks to be registered.
- custom_hooks (list, optional): Custom hooks to be registered.
- load_from (str, optional): Checkpoint file path to load.
- exp_name (str, optional): Experiment name.
- work_dir (str, optional): Working directory to save logs and checkpoints.
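For orientation, a configuration following this structure might look like the sketch below. Every 'type' value and keyword argument is a hypothetical placeholder; the names actually registered in ncdia may differ.

```python
# Hypothetical configuration sketch for BaseTrainer; every 'type' value below is a
# placeholder and must be replaced with a type actually registered in ncdia.
cfg = dict(
    trainer=dict(type='PreTrainer'),
    algorithm=dict(type='SomeAlgorithm'),       # placeholder algorithm type
    criterion=dict(type='CrossEntropyLoss'),    # placeholder criterion type
    optimizer=dict(type='SGD', lr=0.1, weight_decay=5e-4),
    scheduler=dict(type='StepLR', step_size=30, gamma=0.1),
    device='cuda',
    trainloader=dict(
        dataset=dict(type='SomeDataset', root='data/', split='train'),  # placeholder dataset
        batch_size=64,
        shuffle=True,
    ),
    valloader=dict(
        dataset=dict(type='SomeDataset', root='data/', split='val'),
        batch_size=64,
        shuffle=False,
    ),
    exp_name='baseline',
    work_dir='./work_dirs',
)
```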
- train_step(self, batch, **kwargs)
Training step. This method should be implemented in subclasses.
Parameters:
- batch (dict | tuple | list): A batch of data from the data loader.
Returns:
- results (dict): Contains the following: {"key1": value1, "key2": value2, ...}, where the keys describe the values, such as "loss", "acc", "ccr", etc., and the values are the corresponding values, which can be int, float, str, etc.
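To illustrate the expected contract, a subclass could implement train_step roughly as in the sketch below. This is an assumption-laden example, not the library's implementation: it assumes batches arrive as (data, label) pairs and that zero_grad()/step() are performed by OptimizerHook (see the hooks section below).

```python
from ncdia.trainers.base import BaseTrainer  # import path inferred from this page

class ClassificationTrainer(BaseTrainer):    # hypothetical subclass for illustration
    def train_step(self, batch, **kwargs):
        data, label = batch                                  # assumes (data, label) batches
        data, label = data.to(self.device), label.to(self.device)
        logits = self.model(data)
        loss = self.criterion(logits, label)
        loss.backward()                                      # zero_grad()/step() assumed to be handled by OptimizerHook
        acc = (logits.argmax(dim=1) == label).float().mean().item()
        return {"loss": loss.item(), "acc": acc}
```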
- val_step(self, batch, **kwargs)
Validation step. This method should be implemented in subclasses.
Parameters:
- batch (dict | tuple | list): A batch of data from the data loader.
Returns:
- results (dict): Contains the following: {"key1": value1, "key2": value2, ...}, where the keys describe the values, such as "loss", "acc", "ccr", etc., and the values are the corresponding values, which can be int, float, str, etc.
- test_step(self, batch, **kwargs)
Test step. This method should be implemented in subclasses.
Parameters:
- batch (dict | tuple | list): A batch of data from the data loader.
Returns:
- results (dict): Contains the following: {"key1": value1, "key2": value2, ...}, where the keys describe the values, such as "loss", "acc", "ccr", etc., and the values are the corresponding values, which can be int, float, str, etc.
- train(self)
Launch the training process.
Returns:
- model (nn.Module): Trained model.
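A minimal usage sketch, assuming the cfg dict from the earlier example and a concrete subclass such as PreTrainer; MyModel is a placeholder for any nn.Module, and the data loaders are assumed to be built from cfg.

```python
model = MyModel()                                      # placeholder for any nn.Module
trainer = PreTrainer(cfg=cfg, session=0, model=model)  # loaders assumed to be built from cfg
trained_model = trainer.train()                        # runs the full loop with the registered hooks
```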
- val(self)
Validation process.
- test(self)
Test process.
- load_ckpt(self, fpath, device='cpu')
Load checkpoint from file.
Parameters:
- fpath (str): Checkpoint file path.
- device (str): Device to load checkpoint. Defaults to 'cpu'.
Returns:
- model (nn.Module): Loaded model.
- save_ckpt(self, fpath)
Save checkpoint to file.
Parameters:
- fpath (str): Checkpoint file path.
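A usage sketch for the two checkpoint methods; the file path is only an example.

```python
trainer.save_ckpt('work_dirs/baseline/epoch_10.pth')            # write current weights
model = trainer.load_ckpt('work_dirs/baseline/epoch_10.pth',
                          device='cpu')                         # returns the loaded model
```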
- call_hook(self, fn_name: str, **kwargs)
Call all hooks with the specified function name.
Parameters:
- fn_name (str): Function name to be called, such as:
- 'before_train_epoch'
- 'after_train_epoch'
- 'before_train_iter'
- 'after_train_iter'
- 'before_val_epoch'
- 'after_val_epoch'
- 'before_val_iter'
- 'after_val_iter'
- kwargs (dict): Arguments for the function.
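A sketch of how a training epoch might drive these hook points. This is a guess at typical usage, not the actual ncdia loop, and the results keyword argument is illustrative.

```python
# Sketch of an epoch driving the documented hook points (not the actual ncdia code).
def run_train_epoch(trainer):
    trainer.call_hook('before_train_epoch')
    for trainer.iter, batch in enumerate(trainer.train_loader):
        trainer.call_hook('before_train_iter')
        results = trainer.train_step(batch)
        trainer.call_hook('after_train_iter', results=results)  # 'results' kwarg is illustrative
    trainer.call_hook('after_train_epoch')
```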
- register_hook(self, hook, priority=None)
Register a hook into the hook list.
The hook will be inserted into a priority queue, with the specified priority (see `Priority` for details of priorities). For hooks with the same priority, they will be triggered in the same order in which they are registered. The priority of a hook is decided as follows:
- If the `priority` argument is given, it will be the priority of the hook.
- If the `hook` argument is a dict and `priority` is in it, the priority will be the value of `hook['priority']`.
- If the `hook` argument is a dict but `priority` is not in it, or `hook` is an instance of `Hook`, the priority will be `hook.priority`.
Parameters:
- hook (Hook or dict): The hook to be registered.
- priority (int or str or Priority, optional): Hook priority. Lower value means higher priority.
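For example, a hook can be registered either as a config dict or as an instance; the LoggerHook constructor arguments are assumed to be optional here.

```python
trainer.register_hook(dict(type='LoggerHook', priority='BELOW_NORMAL'))  # config-dict form
trainer.register_hook(LoggerHook(), priority=60)                         # instance form; constructor args assumed optional
```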
- register_default_hooks(self, hooks=None)
Register default hooks into hook list.
`hooks` will be registered into the runner to execute some default actions like updating model parameters or saving checkpoints.
Default hooks and their priorities:

| Hooks | Priority |
|---|---|
| RuntimeInfoHook | VERY_HIGH (10) |
| IterTimerHook | NORMAL (50) |
| DistSamplerSeedHook | NORMAL (50) |
| LoggerHook | BELOW_NORMAL (60) |
| ParamSchedulerHook | LOW (70) |
| CheckpointHook | VERY_LOW (90) |

If `hooks` is None, the following hooks will be registered by default:

```python
default_hooks = dict(
    logger=dict(type='LoggerHook'),
    model=dict(type='ModelHook'),
    alg=dict(type='AlgHook'),
    optimizer=dict(type='OptimizerHook'),
    scheduler=dict(type='SchedulerHook'),
    metric=dict(type='MetricHook'),
)
```

If not None, `hooks` will be merged into `default_hooks`. If a value in `default_hooks` is None, the corresponding item will be popped from `default_hooks`:

```python
hooks = dict(timer=None)
```

The final registered default hooks will be `RuntimeInfoHook`, `DistSamplerSeedHook`, `LoggerHook`, `ParamSchedulerHook` and `CheckpointHook`.
Parameters:
- hooks (dict[str, Hook or dict]): Default hooks or configs to be registered.
- register_custom_hooks(self, hooks)
Register custom hooks into hook list.
Parameters:
- hooks (list[Hook | dict]): List of hooks or configs to be registered.
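For example (the hook configs and the EpochTimerHook instance are illustrative; see the Hook section below for a custom-hook sketch):

```python
trainer.register_custom_hooks([
    dict(type='NCDHook'),        # documented hook, config form
    EpochTimerHook(),            # hypothetical user-defined Hook instance (sketched in the Hook section below)
])
```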
- register_hooks(self, default_hooks=None, custom_hooks=None)
Register default hooks and custom hooks into hook list.
Parameters:
- default_hooks (dict[str, dict] or dict[str, Hook]): Hooks to execute default actions like updating model parameters and saving checkpoints. Defaults to None.
- custom_hooks (list[dict] or list[Hook]): Hooks to execute custom actions like visualizing images processed by pipeline. Defaults to None.
- get_hooks_info(self)
Get registered hooks information.
Returns:
- info (str): Information of registered hooks.
ncdia.trainers.pretrainer.py¶
PreTrainer¶
PreTrainer class for pre-training a model on session 0.
Attributes:
- max_epochs (int): Total epochs for training.
Methods:
- __init__(self, max_epochs=1, **kwargs): The constructor method that initializes an instance of PreTrainer. max_epochs (int) is the total number of epochs for training.
- train_step(self, batch, **kwargs): Training step.
- val_step(self, batch, **kwargs): Validation step.
- test_step(self, batch, **kwargs): Test step.
- batch_parser(batch)
Parse a batch of data.
Parameters:
- batch (dict | tuple | list): A batch of data.
Returns:
- data (torch.Tensor | list): Input data.
- label (torch.Tensor | list): Label data.
- attribute (torch.Tensor | list): Attribute data.
- imgpath (list of str): Image path.
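A usage sketch, assuming batch_parser can be called on a trainer instance as the self-less signature suggests; the helper function name is hypothetical.

```python
from ncdia.trainers.pretrainer import PreTrainer  # import path inferred from this page

def inspect_batch(trainer: PreTrainer, batch):
    """Hypothetical helper: unpack one batch via batch_parser and report its parts."""
    data, label, attribute, imgpath = trainer.batch_parser(batch)
    print(type(data), type(label), type(attribute), imgpath[:2])
```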
ncdia.trainers.inctrainer.py¶
IncTrainer¶
IncTrainer class for incremental training.
Attributes:
- sess_cfg (Configs): Session configuration.
- num_sess (int): Number of sessions.
- session (int): Session number. If == 0, execute pre-training. If > 0, execute incremental training.
- hist_trainset (MergedDataset): Historical training dataset.
- hist_valset (MergedDataset): Historical validation dataset.
- hist_testset (MergedDataset): Historical testing dataset.
Methods:
- __init__(self, cfg=None, sess_cfg=None, ncd_cfg=None, session=0, model=None, hist_trainset=None, hist_testset=None, old_model=None, **kwargs)
The constructor method that initializes an instance of IncTrainer.
Parameters:
- model (nn.Module): Model to be trained.
- cfg (dict): Configuration for trainer.
- sess_cfg (Configs): Session configuration.
- session (int): Session number. Default: 0.
- train(self)
Incremental training.
self.num_sess determines the number of sessions, and the current session number is stored in self.session.
Returns:
- model (nn.Module): Trained model.
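A usage sketch; cfg and sess_cfg are assumed to follow the configuration format described for BaseTrainer and are not shown here.

```python
trainer = IncTrainer(cfg=cfg, sess_cfg=sess_cfg, session=0, model=model)
trained_model = trainer.train()   # internally iterates over self.num_sess sessions
```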
ncdia.trainers.hooks¶
Implements some of the commonly used hooks.
Hook¶
ncdia.trainers.hooks.hook.py
Base hook class. All hooks should inherit from this class.
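A minimal custom hook sketch. The callback names are assumed to mirror the fn_name values documented for call_hook, and the callback signature (receiving the trainer plus keyword arguments) is an assumption not confirmed by this page.

```python
import time
from ncdia.trainers.hooks import Hook   # import path inferred from this page

class EpochTimerHook(Hook):             # hypothetical custom hook
    """Logs how long each training epoch takes."""

    def before_train_epoch(self, trainer, **kwargs):
        self._start = time.time()

    def after_train_epoch(self, trainer, **kwargs):
        trainer.logger.info(f"Epoch {trainer.epoch} took {time.time() - self._start:.1f}s")
```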
AlgHook¶
ncdia.trainers.hooks.alghook.py
A hook to modify algorithm state in the pipeline. This class is a base class for all algorithm hooks.
LoggerHook¶
ncdia.trainers.hooks.loggerhook.py
A hook to log information during training and evaluation.
MetricHook¶
ncdia.trainers.hooks.metrichook.py
A hook to calculate metrics during evaluation and testing.
ModelHook¶
ncdia.trainers.hooks.modelhook.py
A hook to change model state in the pipeline, such as setting device, changing model to eval mode, etc.
NCDHook¶
ncdia.trainers.hooks.ncdhook.py
A hook to execute OOD and NCD detection to relabel data.
OptimizerHook¶
ncdia.trainers.hooks.optimizerhook.py
A hook to put optimizer to zero_grad and step during training.
SchedulerHook¶
ncdia.trainers.hooks.schedulerhook.py
A hook to change learning rate during training.
ncdia.trainers.optims¶
ncdia.trainers.optims.optimizer.py¶
- build_optimizer(type, model, param_groups=None, **kwargs)
Build optimizer.
Parameters:
- type (str): Type of optimizer.
- model (nn.Module | dict): Model or param_groups.
- param_groups (dict | None): If provided, directly optimize param_groups and abandon model.
- kwargs (dict): Arguments for the optimizer.
Returns:
- optimizer (torch.optim.Optimizer): The built optimizer.
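Usage sketch; the name 'SGD' is assumed to resolve to torch.optim.SGD, and the import location is inferred from the module path above.

```python
from ncdia.trainers.optims.optimizer import build_optimizer  # module path as documented above

# model: any nn.Module defined elsewhere
optimizer = build_optimizer('SGD', model, lr=0.1, momentum=0.9, weight_decay=5e-4)
```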
ncdia.trainers.optims.scheduler.py¶
Implements some of the commonly used schedulers:
- CosineWarmupLR
- LinearWarmupLR
- ConstantLR
Methods:
- build_scheduler(type, optimizer, **kwargs)
Build learning rate scheduler.
Parameters:
- type (str): Type of scheduler.
- optimizer (torch.optim.Optimizer): Optimizer to schedule.
- kwargs (dict): Arguments for the scheduler.
Returns:
- lr_scheduler (torch.optim.lr_scheduler._LRScheduler): The built learning rate scheduler.
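Usage sketch continuing from the optimizer above; the keyword arguments for CosineWarmupLR are illustrative, so check the implementation for the actual signature.

```python
from ncdia.trainers.optims.scheduler import build_scheduler  # module path as documented above

# Keyword arguments below are illustrative, not the confirmed CosineWarmupLR signature.
scheduler = build_scheduler('CosineWarmupLR', optimizer, warmup_epochs=5, max_epochs=100)
```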
ncdia.trainers.priority¶
Hook priority levels.
Priority¶
| Level | Value |
|---|---|
| HIGHEST | 0 |
| VERY_HIGH | 10 |
| HIGH | 30 |
| ABOVE_NORMAL | 40 |
| NORMAL | 50 |
| BELOW_NORMAL | 60 |
| LOW | 70 |
| VERY_LOW | 90 |
| LOWEST | 100 |