import os
import sys
from typing import Callable, Optional

import torch
from torch.types import Storage


__all__: list[str] = [
    "gds_register_buffer",
    "gds_deregister_buffer",
    "GdsFile",
]


def _dummy_fn(name: str) -> Callable:
    """Return a placeholder for the missing ``torch._C`` binding ``name``.

    The placeholder accepts any positional/keyword arguments and always
    raises ``RuntimeError`` stating that ``torch._C.<name>`` is unsupported.
    """
    message = f"torch._C.{name} is not supported on this platform"

    def fn(*args, **kwargs):  # type: ignore[no-untyped-def]
        raise RuntimeError(message)

    return fn


# If this torch build lacks the GDS bindings, install placeholders for every
# one of them, so that the wrappers in this module raise a descriptive
# RuntimeError instead of an AttributeError. The assert checks that none of
# the bindings exist before any placeholder is installed.
if not hasattr(torch._C, "_gds_register_buffer"):
    _gds_names = (
        "_gds_register_buffer",
        "_gds_deregister_buffer",
        "_gds_register_handle",
        "_gds_deregister_handle",
        "_gds_load_storage",
        "_gds_save_storage",
    )
    assert not any(hasattr(torch._C, n) for n in _gds_names)
    for _gds_name in _gds_names:
        torch._C.__dict__[_gds_name] = _dummy_fn(_gds_name)
    # Keep the module namespace identical to the unrolled original.
    del _gds_name, _gds_names
def gds_register_buffer(s: Storage) -> None:
    """Registers a CUDA-resident storage as a cuFile buffer.

    Example::

        >>> # xdoctest: +SKIP("gds filesystem requirements")
        >>> src = torch.randn(1024, device="cuda")
        >>> s = src.untyped_storage()
        >>> gds_register_buffer(s)

    Args:
        s (Storage): Storage on a CUDA device to register as a buffer.
    """
    register = torch._C._gds_register_buffer
    register(s)
def gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a storage previously registered as a cuFile buffer.

    This is the counterpart of :func:`gds_register_buffer`: ``s`` should be a
    storage on a CUDA device that was registered with that function.

    Example::

        >>> # xdoctest: +SKIP("gds filesystem requirements")
        >>> src = torch.randn(1024, device="cuda")
        >>> s = src.untyped_storage()
        >>> gds_register_buffer(s)
        >>> gds_deregister_buffer(s)

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)
class GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    See the `cufile docs <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_
    for more details.

    On construction the file is opened with ``os.open`` and its descriptor is
    registered with the cuFile driver. When the object is garbage collected,
    the handle is deregistered and the descriptor is closed.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file.
            ``os.O_DIRECT`` will be added automatically.

    Raises:
        RuntimeError: If constructed on Windows (``sys.platform == "win32"``).

    Example::

        >>> # xdoctest: +SKIP("gds filesystem requirements")
        >>> f = "/path/to/file/on/gds/filesystem.bin"
        >>> src1 = torch.randn(1024, device="cuda")
        >>> src2 = torch.randn(2, 1024, device="cuda")
        >>> file = torch.cuda.gds.GdsFile(f, os.O_CREAT | os.O_RDWR)
        >>> file.save_storage(src1.untyped_storage(), offset=0)
        >>> file.save_storage(src2.untyped_storage(), offset=src1.nbytes)
        >>> dest1 = torch.empty(1024, device="cuda")
        >>> dest2 = torch.empty(2, 1024, device="cuda")
        >>> file.load_storage(dest1.untyped_storage(), offset=0)
        >>> file.load_storage(dest2.untyped_storage(), offset=src1.nbytes)
        >>> torch.equal(src1, dest1)
        True
        >>> torch.equal(src2, dest2)
        True
    """

    def __init__(self, filename: str, flags: int):
        if sys.platform == "win32":
            raise RuntimeError("GdsFile is not supported on this platform.")
        self.filename = filename
        self.flags = flags
        self.fd = os.open(filename, flags | os.O_DIRECT)  # type: ignore[attr-defined]
        self.handle: Optional[int] = None
        try:
            self.register_handle()
        except BaseException:
            # The exception's traceback keeps this half-built object alive, so
            # waiting for __del__ would hold the descriptor open. Close it now
            # and mark it closed (-1) so __del__ does not close it again.
            os.close(self.fd)
            self.fd = -1
            raise

    def __del__(self) -> None:
        # __init__ may raise before ``handle``/``fd`` are assigned (win32
        # check, ``os.open`` failure), so read both attributes defensively.
        try:
            if getattr(self, "handle", None) is not None:
                self.deregister_handle()
        finally:
            # Close the descriptor even if deregistration raised. Mark it
            # closed first so a repeated call cannot close a reused fd number.
            fd = getattr(self, "fd", -1)
            if fd >= 0:
                self.fd = -1
                os.close(fd)

    def register_handle(self) -> None:
        """Registers file descriptor to cuFile Driver.

        This is a wrapper around ``cuFileHandleRegister``. The returned handle
        is stored in ``self.handle``.

        Asserts that no handle is currently registered.
        """
        assert (
            self.handle is None
        ), "Cannot register a handle that is already registered."
        self.handle = torch._C._gds_register_handle(self.fd)

    def deregister_handle(self) -> None:
        """Deregisters file descriptor from cuFile Driver.

        This is a wrapper around ``cuFileHandleDeregister``. Afterwards
        ``self.handle`` is reset to ``None``.

        Asserts that a handle is currently registered.
        """
        assert (
            self.handle is not None
        ), "Cannot deregister a handle that is not registered."
        torch._C._gds_deregister_handle(self.handle)
        self.handle = None

    def load_storage(self, storage: Storage, offset: int = 0) -> None:
        """Loads data from the file into the storage.

        This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
        will be loaded from the file at ``offset`` into the storage.

        Args:
            storage (Storage): Storage to load data into.
            offset (int, optional): Offset into the file to start loading from. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot load data from a file that is not registered."
        torch._C._gds_load_storage(self.handle, storage, offset)

    def save_storage(self, storage: Storage, offset: int = 0) -> None:
        """Saves data from the storage into the file.

        This is a wrapper around ``cuFileWrite``. All bytes of the storage
        will be written to the file at ``offset``.

        Args:
            storage (Storage): Storage to save data from.
            offset (int, optional): Offset into the file to start saving to. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot save data to a file that is not registered."
        torch._C._gds_save_storage(self.handle, storage, offset)
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.