diff --git a/validate.py b/validate.py
index cb71a0624..86327731f 100755
--- a/validate.py
+++ b/validate.py
@@ -314,11 +314,15 @@ def validate(args):
     model.eval()
     with torch.no_grad():
         # warmup, reduce variability of first batch time, especially for comparing torchscript vs non
-        input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)
-        if args.channels_last:
-            input = input.contiguous(memory_format=torch.channels_last)
-        with amp_autocast():
-            model(input)
+        inputs = [torch.randn((args.batch_size,) + tuple(data_config['input_size'])).to(device)]
+        last_batch_size = len(dataset) % args.batch_size
+        if last_batch_size:
+            inputs.append(torch.randn((last_batch_size,) + tuple(data_config['input_size'])).to(device))
+        for inp in inputs:
+            if args.channels_last:
+                inp = inp.contiguous(memory_format=torch.channels_last)
+            with amp_autocast():
+                model(inp)
         end = time.time()
         for batch_idx, (input, target) in enumerate(loader):
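
For context, here is a minimal standalone sketch of the idea behind the change: warm the model once per distinct batch shape the loader will produce, i.e. the full batch size plus the smaller final batch when the dataset length is not a multiple of the batch size, so a shape-dependent recompile (e.g. under torchscript) does not skew the timing of the last batch. The `warmup` helper and the toy model/values below are illustrative assumptions, not code from validate.py.

```python
import torch


def warmup(model, batch_size, dataset_len, input_size, device='cpu', channels_last=False):
    """Run one forward pass per distinct batch shape the loader will yield.

    The final batch is smaller whenever dataset_len is not a multiple of
    batch_size, so both shapes are exercised once before timing starts.
    """
    sizes = [batch_size]
    last_batch_size = dataset_len % batch_size
    if last_batch_size:
        sizes.append(last_batch_size)

    model.eval()
    with torch.no_grad():
        for n in sizes:
            # random input with the same shape the real loader would produce
            inp = torch.randn((n,) + tuple(input_size)).to(device)
            if channels_last:
                inp = inp.contiguous(memory_format=torch.channels_last)
            model(inp)


if __name__ == '__main__':
    # illustrative values: 10 samples with batch size 4 leaves a final batch of 2
    model = torch.nn.Conv2d(3, 8, 3)
    warmup(model, batch_size=4, dataset_len=10, input_size=(3, 32, 32))
```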