TenCrop¶
-
class
torchvision.transforms.
TenCrop
(size, vertical_flip=False)[source]¶ Crop the given image into four corners and the central crop plus the flipped version of these (horizontal flipping is used by default). If the image is torch Tensor, it is expected to have […, H, W] shape, where … means an arbitrary number of leading dimensions
Note
This transform returns a tuple of images and there may be a mismatch in the number of inputs and targets your Dataset returns. See below for an example of how to deal with this.
- Parameters
Example
>>> transform = Compose([ >>> TenCrop(size), # this is a list of PIL Images >>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor >>> ]) >>> #In your test loop you can do the following: >>> input, target = batch # input is a 5d tensor, target is 2d >>> bs, ncrops, c, h, w = input.size() >>> result = model(input.view(-1, c, h, w)) # fuse batch size and ncrops >>> result_avg = result.view(bs, ncrops, -1).mean(1) # avg over crops