例子
from argparse import ArgumentParser

import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import SGD
from torchvision.transforms import Compose, ToTensor, Normalize

from ignite.engines import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss


def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
    """Train ``Net`` with SGD, logging training loss and validation metrics.

    The training loss is printed every ``log_interval`` iterations of an epoch,
    and validation accuracy / NLL loss are printed after every epoch.

    NOTE: ``get_data_loaders`` and ``Net`` are not defined in this excerpt;
    they must be provided by the surrounding script.

    Args:
        train_batch_size: batch size of the training loader.
        val_batch_size: batch size of the validation loader.
        epochs: number of epochs to train (passed as ``max_epochs``).
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_interval: print the training loss every this many iterations
            within an epoch.
    """
    cuda = torch.cuda.is_available()
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    if cuda:
        model = model.cuda()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, cuda=cuda)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'accuracy': CategoricalAccuracy(),
            'nll': Loss(F.nll_loss),
        },
        cuda=cuda,
    )

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # The iteration counter lives on engine.state and keeps counting across
        # epochs; map it back to a 1-based position inside the current epoch.
        iteration_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration_in_epoch % log_interval == 0:
            print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch, iteration_in_epoch, len(train_loader),
                engine.state.output))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # evaluator.run returns its State; computed metrics are in state.metrics.
        metrics = evaluator.run(val_loader).metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
            engine.state.epoch, avg_accuracy, avg_nll))

    trainer.run(train_loader, max_epochs=epochs)
---------------------
流程
创建模型, 创建 Dataloader
创建 trainer
创建 evaluator
为一些事件注册函数, @trainer.on()
trainer.run()
Event
from enum import Enum


class Events(Enum):
    """Enumeration of the events an engine fires during a run."""

    EPOCH_STARTED = "epoch_started"  # fired when a new epoch starts
    EPOCH_COMPLETED = "epoch_completed"  # fired when an epoch finishes
    STARTED = "started"  # fired when model training starts
    COMPLETED = "completed"  # fired when training finishes
    ITERATION_STARTED = "iteration_started"  # fired when an iteration starts
    ITERATION_COMPLETED = "iteration_completed"  # fired when an iteration finishes
    EXCEPTION_RAISED = "exception_raised"  # fired when an exception is raised
---------------------
State
class State(object):
    """Container for the state of an engine run.

    Always provides ``iteration``, ``output`` and ``batch``; any keyword
    arguments are stored as extra attributes (and may override those three,
    since they are applied last).
    """

    def __init__(self, **kwargs):
        self.iteration = 0  # iteration counter
        self.output = None  # output of the current iteration; for the supervised trainer, the loss
        self.batch = None  # mini-batch of the current iteration
        # Any other state the caller wants recorded.
        for name, value in kwargs.items():
            setattr(self, name, value)
---------------------
create_supervised_trainer
def create_supervised_trainer(model, optimizer, loss_fn, cuda=False):
    """Build a trainer for a supervised model.

    Args:
        model (torch.nn.Module): the model to train.
        optimizer (torch.optim.Optimizer): the optimizer used to update the model.
        loss_fn (torch.nn loss function): the loss function to use.
        cuda (bool, optional): whether to move each batch to the GPU
            (default: False).

    Returns:
        Trainer: a trainer whose update function performs supervised training.

    Note:
        Only the signature and docstring are reproduced in this excerpt; the
        body is omitted, so this stub itself returns ``None``.
    """
---------------------
create_supervised_evaluator
def create_supervised_evaluator(model, metrics={}, cuda=False):
    """Build an evaluator for a supervised model.

    Args:
        model (torch.nn.Module): the model to evaluate.
        metrics (dict of str: Metric): maps metric names to Metric objects.
            NOTE(review): the ``{}`` default is one shared object; an
            implementation must not mutate it.
        cuda (bool, optional): whether to move each batch to the GPU
            (default: False).

    Returns:
        Evaluator: an evaluator whose inference function runs the model on
        each batch.

    Note:
        Only the signature and docstring are reproduced in this excerpt; the
        body is omitted, so this stub itself returns ``None``.
    """
---------------------
Trainer
# Trainer inherits from Engine.
class Trainer(Engine):
    def __init__(self, process_function):
        """Create a trainer driven by ``process_function``.

        ``process_function`` has the signature ``func(batch) -> anything``.
        A typical implementation::

            def func(batch):  # the batch is also stored in state.batch
                # 1. process the batch
                # 2. forward computation
                # 3. compute the loss
                # 4. compute the gradients
                # 5. update the parameters
                # 6. return the loss (or anything else); the returned
                #    value is stored in state.output
        """


# Register a handler for an event; it is called every time the event fires.
# The handler receives the engine (here, the trainer) as its first argument,
# and the current State is available as ``engine.state``.
@trainer.on(...)
def some_func(trainer):
    pass


trainer.run(train_loader, max_epochs=epochs)  # train the model
---------------------
Evaluator
# Evaluator inherits from Engine.
class Evaluator(Engine):
    def __init__(self, process_function):
        """Create an evaluator driven by ``process_function``.

        ``process_function`` has the signature ``func(batch) -> anything``.
        A typical implementation::

            def func(batch):  # the batch is also stored in state.batch
                # 1. process the batch
                # 2. forward computation
                # 3. return something; the returned value is stored in
                #    state.output and is used to compute the Metrics
        """


# Register handlers for some of the evaluator's events.
@evaluator.on(...)
def func(evaluator):
    pass


state = evaluator.run(val_loader)  # run the evaluation; returns the State
state.metrics  # results of the metrics computed on the validation set are stored here
---------------------