.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/blitz/data_parallel_tutorial.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_blitz_data_parallel_tutorial.py:


Optional: Data Parallelism
==========================
**Authors**: `Sung Kim `_ and `Jenny Kang `_

In this tutorial, we will learn how to use multiple GPUs with ``DataParallel``.

It's very easy to use GPUs with PyTorch. You can put the model on a GPU:

.. code:: python

    device = torch.device("cuda:0")
    model.to(device)

Then, you can copy all your tensors to the GPU:

.. code:: python

    mytensor = my_tensor.to(device)

Please note that calling ``my_tensor.to(device)`` returns a new copy of
``my_tensor`` on the GPU rather than modifying ``my_tensor`` in place. You need
to assign the result to a variable and use that tensor on the GPU.

It's natural to run your forward and backward passes on multiple GPUs.
However, PyTorch will only use one GPU by default. You can easily run your
operations on multiple GPUs by making your model run in parallel with
``DataParallel``:

.. code:: python

    model = nn.DataParallel(model)

That's the core idea behind this tutorial. We will explore it in more detail
below.

Imports and parameters
----------------------

Import PyTorch modules and define parameters.

.. code-block:: default


    import torch
    import torch.nn as nn
    from torch.utils.data import Dataset, DataLoader

    # Parameters and DataLoaders
    input_size = 5
    output_size = 2

    batch_size = 30
    data_size = 100


Select the device: the first GPU if CUDA is available, otherwise the CPU.

.. code-block:: default

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
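As noted above, ``Tensor.to()`` hands back a tensor instead of moving the original in place, whereas ``nn.Module.to()`` moves a module's parameters in place and returns the same module. The minimal sketch below checks both behaviors; it also runs on a CPU-only machine, where ``device`` is the CPU and ``.to()`` simply returns the original tensor because no copy is needed. The names ``moved``, ``layer`` and ``same_layer`` are illustrative only.

.. code-block:: default

    import torch
    import torch.nn as nn

    # Use the first GPU if one is available, otherwise fall back to the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    my_tensor = torch.ones(2, 3)
    # Tensor.to() returns a tensor on `device`; `my_tensor` itself is not moved.
    moved = my_tensor.to(device)
    print("original:", my_tensor.device, "moved:", moved.device)

    # Rebind the name to keep working with the tensor on `device`.
    my_tensor = my_tensor.to(device)

    # nn.Module.to() moves the parameters in place and returns the same module.
    layer = nn.Linear(5, 2)
    same_layer = layer.to(device)
    print(same_layer is layer, next(layer.parameters()).device)

This is why the tutorial writes ``model.to(device)`` without reassignment but
reassigns tensors, as in ``input = data.to(device)``.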
Dummy DataSet
-------------

Make a dummy (random) dataset. You just need to implement ``__getitem__``
and ``__len__``.

.. code-block:: default


    class RandomDataset(Dataset):

        def __init__(self, size, length):
            self.len = length
            self.data = torch.randn(length, size)

        def __getitem__(self, index):
            return self.data[index]

        def __len__(self):
            return self.len


    rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                             batch_size=batch_size, shuffle=True)


Simple Model
------------

For the demo, our model just gets an input, performs a linear operation, and
gives an output. However, you can use ``DataParallel`` on any model (CNN, RNN,
Capsule Net, etc.).

We've placed a print statement inside the model to monitor the size of the
input and output tensors. Please pay attention to the batch dimension
(dimension 0) of the printed sizes.

.. code-block:: default


    class Model(nn.Module):
        # Our model

        def __init__(self, input_size, output_size):
            super(Model, self).__init__()
            self.fc = nn.Linear(input_size, output_size)

        def forward(self, input):
            output = self.fc(input)
            print("\tIn Model: input size", input.size(),
                  "output size", output.size())

            return output


Create Model and DataParallel
-----------------------------

This is the core part of the tutorial. First, we need to make a model instance
and check if we have multiple GPUs. If we have multiple GPUs, we can wrap our
model using ``nn.DataParallel``. Then we move the model to ``device`` with
``model.to(device)``: ``DataParallel`` expects the wrapped module's parameters
to be on the first GPU (``cuda:0`` here) and copies them to the other GPUs on
each forward pass.

.. code-block:: default


    model = Model(input_size, output_size)
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = nn.DataParallel(model)

    model.to(device)


.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Let's use 4 GPUs!

    DataParallel(
      (module): Model(
        (fc): Linear(in_features=5, out_features=2, bias=True)
      )
    )


Run the Model
-------------

Now we can see the sizes of the input and output tensors.

.. code-block:: default

    for data in rand_loader:
        input = data.to(device)
        output = model(input)
        print("Outside: input size", input.size(),
              "output_size", output.size())


.. rst-class:: sphx-glr-script-out

.. code-block:: none

            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([6, 5]) output size torch.Size([6, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([6, 5]) output size torch.Size([6, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([8, 5]) output size torch.Size([8, 2])
            In Model: input size torch.Size([6, 5]) output size torch.Size([6, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([3, 5]) output size torch.Size([3, 2])
            In Model: input size torch.Size([3, 5]) output size torch.Size([3, 2])
            In Model: input size torch.Size([3, 5]) output size torch.Size([3, 2])
            In Model: input size torch.Size([1, 5]) output size torch.Size([1, 2])
    Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])
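The sizes printed inside the model follow from how each batch is split along dimension 0. The sketch below assumes the split behaves like ``torch.chunk``: every piece holds at most ``ceil(batch / num_gpus)`` samples, the last piece may be smaller, and there can be fewer pieces than GPUs. ``chunk_sizes`` is a hypothetical helper written only for illustration; it is not part of PyTorch.

.. code-block:: default

    import math


    def chunk_sizes(batch_size, num_gpus):
        """Hypothetical helper: sizes of the pieces a batch is split into along dim 0.

        Assumes torch.chunk-style splitting: each piece has ceil(batch_size / num_gpus)
        samples except possibly the last, and there may be fewer pieces than GPUs.
        """
        step = math.ceil(batch_size / num_gpus)
        sizes = []
        remaining = batch_size
        while remaining > 0:
            sizes.append(min(step, remaining))
            remaining -= step
        return sizes


    # 100 samples with batch_size=30 give batches of 30, 30, 30 and a final batch of 10.
    for gpus in (1, 2, 3, 4, 8):
        print(gpus, "GPU(s):", chunk_sizes(30, gpus), chunk_sizes(10, gpus))

For 4 GPUs this gives ``[8, 8, 8, 6]`` and ``[3, 3, 3, 1]``, matching the run
above. For 8 GPUs the final batch of 10 is cut into only five pieces of 2, so
only five replicas do any work for that batch.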
Results
-------

If you have no GPU or only one GPU, the model is not wrapped in
``DataParallel``, so for a batch of 30 inputs it receives all 30 and produces
30 outputs, as expected. With multiple GPUs, each batch is split across the
GPUs and you get results like the ones below. The order of the ``In Model``
lines can vary from batch to batch because the replicas run concurrently.

2 GPUs
~~~~~~

If you have 2 GPUs, you will see:

.. code:: none

    # on 2 GPUs
    Let's use 2 GPUs!
            In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
            In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
            In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
            In Model: input size torch.Size([15, 5]) output size torch.Size([15, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2])
            In Model: input size torch.Size([5, 5]) output size torch.Size([5, 2])
    Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])

3 GPUs
~~~~~~

If you have 3 GPUs, you will see:

.. code:: none

    Let's use 3 GPUs!
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
            In Model: input size torch.Size([10, 5]) output size torch.Size([10, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
    Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])

8 GPUs
~~~~~~

If you have 8 GPUs, you will see:

.. code:: none

    Let's use 8 GPUs!
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([4, 5]) output size torch.Size([4, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
    Outside: input size torch.Size([30, 5]) output_size torch.Size([30, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
            In Model: input size torch.Size([2, 5]) output size torch.Size([2, 2])
    Outside: input size torch.Size([10, 5]) output_size torch.Size([10, 2])


Summary
-------

``DataParallel`` automatically splits each input batch along the batch
dimension, replicates your model onto each GPU, and runs every replica on its
share of the data in parallel. After all replicas finish, ``DataParallel``
gathers and concatenates their outputs on the first device before returning
them to you.

For more information, please check out
https://pytorch.org/tutorials/beginner/former\_torchies/parallelism\_tutorial.html.
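The scatter, parallel apply, and gather steps described in the summary can be mimicked without any GPU. The toy sketch below uses plain Python lists in place of tensors and a thread pool in place of devices; ``toy_data_parallel`` and ``linear`` are hypothetical names, and the sketch is only an analogy for what ``DataParallel`` does, not its implementation (the real module also copies parameters to each GPU and runs CUDA kernels).

.. code-block:: default

    import math
    from concurrent.futures import ThreadPoolExecutor


    def linear(rows, weight, bias):
        """Stand-in for nn.Linear: y = W x + b for every row x."""
        return [
            [sum(x * w for x, w in zip(row, w_row)) + b for w_row, b in zip(weight, bias)]
            for row in rows
        ]


    def toy_data_parallel(batch, num_workers, weight, bias):
        """Hypothetical toy: scatter -> parallel apply -> gather on plain lists."""
        if not batch:
            return []
        # Scatter: split along dim 0 into pieces of at most ceil(len / workers) rows.
        step = math.ceil(len(batch) / num_workers)
        chunks = [batch[i:i + step] for i in range(0, len(batch), step)]
        # Parallel apply: each chunk is processed by its own worker thread,
        # all using the same (replicated) weights.
        with ThreadPoolExecutor(max_workers=len(chunks)) as pool:
            outputs = list(pool.map(lambda chunk: linear(chunk, weight, bias), chunks))
        # Gather: concatenate the per-chunk outputs in chunk order.
        return [row for chunk_out in outputs for row in chunk_out]


    weight = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]  # 2 outputs x 3 inputs
    bias = [0.1, -0.2]
    batch = [[float(i), float(i + 1), float(i + 2)] for i in range(10)]

    parallel_out = toy_data_parallel(batch, 4, weight, bias)
    print(len(parallel_out), parallel_out == linear(batch, weight, bias))  # 10 True

``pool.map`` returns results in chunk order even when workers finish out of
order, which is how the gather step keeps every output row aligned with its
input row.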