import gc
import math
import os
import re
import warnings
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from ..utils import _log_api_usage_once
from . import _video_opt

# PyAV is an optional dependency. When it cannot be used, ``av`` is bound to an
# ``ImportError`` *instance* (not raised) so that the failure is only reported
# when a function that actually needs PyAV is called.
try:
    import av

    av.logging.set_level(av.logging.ERROR)
    if not hasattr(av.video.frame.VideoFrame, "pict_type"):
        av = ImportError(
            """\
Your version of PyAV is too old for the necessary video operations in torchvision.
If you are on Python 3.5, you will have to build from source (the conda-forge
packages are not up-to-date). See
https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
        )
except ImportError:
    av = ImportError(
        """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    )

if isinstance(av, Exception):
    # PyAV is unusable, so ``av`` is an exception instance and has neither
    # ``FFmpegError`` nor ``AVError``. Looking them up here would raise an
    # AttributeError and break the import of this module. Bind a placeholder
    # instead so that ``except FFmpegError`` clauses remain valid; it is never
    # raised by anything.
    class _UnavailableFFmpegError(Exception):
        """Placeholder for ``FFmpegError`` when PyAV is not usable."""

    FFmpegError = _UnavailableFFmpegError
else:
    try:
        FFmpegError = av.FFmpegError  # from av 14 https://github.com/PyAV-Org/PyAV/blob/main/CHANGELOG.rst
    except AttributeError:
        FFmpegError = av.AVError


def _check_av_available() -> None:
    """Raise the stored ``ImportError`` if PyAV could not be imported or is too old."""
    if isinstance(av, Exception):
        raise av


def _av_available() -> bool:
    """Return ``True`` when ``av`` is the real PyAV module rather than a stored error."""
    return not isinstance(av, Exception)


# PyAV has some reference cycles
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 10
def write_video(
    filename: str,
    video_array: torch.Tensor,
    fps: float,
    video_codec: str = "libx264",
    options: Optional[Dict[str, Any]] = None,
    audio_array: Optional[torch.Tensor] = None,
    audio_fps: Optional[float] = None,
    audio_codec: Optional[str] = None,
    audio_options: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file.

    This function relies on PyAV (therefore, ultimately FFmpeg) to encode
    videos, you can get more fine-grained control by referring to the other
    options at your disposal within `the FFMpeg wiki
    <http://trac.ffmpeg.org/wiki#Encoding>`_.

    .. warning::

        In the near future, we intend to centralize PyTorch's video decoding
        capabilities within the `torchcodec
        <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
        try it out and share your feedback, as the torchvision video decoders
        will eventually be deprecated.

    Args:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): video frames per second
        video_codec (str): the name of the video codec, i.e. "libx264", "h264", etc.
        options (Dict): dictionary containing options to be passed into the PyAV video stream.
            The list of options is codec-dependent and can all
            be found from `the FFMpeg wiki <http://trac.ffmpeg.org/wiki#Encoding>`_.
        audio_array (Tensor[C, N]): tensor containing the audio, where C is the number of channels
            and N is the number of samples
        audio_fps (Number): audio sample rate, typically 44100 or 48000.
            Required when ``audio_array`` is given.
        audio_codec (str): the name of the audio codec, i.e. "mp3", "aac", etc.
            Required when ``audio_array`` is given.
        audio_options (Dict): dictionary containing options to be passed into the PyAV audio stream.
            The list of options is codec-dependent and can all
            be found from `the FFMpeg wiki <http://trac.ffmpeg.org/wiki#Encoding>`_.

    Raises:
        ImportError: if PyAV is not installed or is too old.
        ValueError: if ``audio_array`` is given without ``audio_fps`` or ``audio_codec``.

    Examples::
        >>> # Creating libx264 video with CRF 17, for visually lossless footage:
        >>>
        >>> from torchvision.io import write_video
        >>> # 1000 frames of 100x100, 3-channel image.
        >>> vid = torch.randint(0, 256, (1000, 100, 100, 3), dtype=torch.uint8)
        >>> write_video("video.mp4", vid, fps=30, options={"crf": "17"})

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_video)
    _check_av_available()

    # Fail early, before the output file is created, instead of failing inside
    # PyAV with a less helpful error once encoding has started.
    if audio_array is not None and (audio_fps is None or audio_codec is None):
        raise ValueError("audio_fps and audio_codec must be provided when audio_array is given")

    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy(force=True)

    # PyAV does not support floating point numbers with decimal point
    # and will throw OverflowException in case this is not the case
    if isinstance(fps, float):
        fps = np.round(fps)

    with av.open(filename, mode="w") as container:
        stream = container.add_stream(video_codec, rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = "yuv420p" if video_codec != "libx264rgb" else "rgb24"
        stream.options = options or {}

        if audio_array is not None:
            # Maps the sample format name reported by the audio stream to the
            # numpy dtype string used to convert the samples.
            audio_format_dtypes = {
                "dbl": "<f8",
                "dblp": "<f8",
                "flt": "<f4",
                "fltp": "<f4",
                "s16": "<i2",
                "s16p": "<i2",
                "s32": "<i4",
                "s32p": "<i4",
                "u8": "u1",
                "u8p": "u1",
            }
            a_stream = container.add_stream(audio_codec, rate=audio_fps)
            a_stream.options = audio_options or {}

            num_channels = audio_array.shape[0]
            audio_layout = "stereo" if num_channels > 1 else "mono"
            audio_sample_fmt = container.streams.audio[0].format.name

            format_dtype = np.dtype(audio_format_dtypes[audio_sample_fmt])
            audio_array = torch.as_tensor(audio_array).numpy(force=True).astype(format_dtype)

            frame = av.AudioFrame.from_ndarray(audio_array, format=audio_sample_fmt, layout=audio_layout)

            frame.sample_rate = audio_fps

            for packet in a_stream.encode(frame):
                container.mux(packet)

            # Flush the audio encoder
            for packet in a_stream.encode():
                container.mux(packet)

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format="rgb24")
            try:
                frame.pict_type = "NONE"
            except TypeError:
                # Fall back to the enum when assigning the string is rejected.
                from av.video.frame import PictureType  # noqa

                frame.pict_type = PictureType.NONE

            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
def _read_from_stream(
    container: "av.container.Container",
    start_offset: float,
    end_offset: float,
    pts_unit: str,
    stream: "av.stream.Stream",
    stream_name: Dict[str, Optional[Union[int, Tuple[int, ...], List[int]]]],
) -> List["av.frame.Frame"]:
    """Decode the frames of ``stream`` whose pts lie in ``[start_offset, end_offset]``.

    Args:
        container: opened container to seek in and decode from.
        start_offset: first presentation timestamp wanted, in the unit given by ``pts_unit``.
        end_offset: last presentation timestamp wanted; may be ``float("inf")``.
        pts_unit: ``"sec"`` to convert the offsets with ``stream.time_base``;
            any other value leaves them untouched and emits a warning.
        stream: stream used for seeking and for the DivX packed B-frames check.
        stream_name: keyword arguments forwarded to ``container.decode``,
            e.g. ``{"video": 0}``.

    Returns:
        The selected frames sorted by pts. When no frame has a pts equal to
        ``start_offset`` (and ``start_offset > 0``), the closest earlier frame
        is prepended. An empty list is returned when seeking fails.
    """
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    # Periodically force a collection (PyAV has some reference cycles).
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    if pts_unit == "sec":
        # TODO: we should change all of this from ground up to simply take
        # sec and convert to MS in C++
        start_offset = int(math.floor(start_offset * (1 / stream.time_base)))
        if end_offset != float("inf"):
            end_offset = int(math.ceil(end_offset * (1 / stream.time_base)))
    else:
        warnings.warn("The pts_unit 'pts' gives wrong results. Please use pts_unit 'sec'.")

    frames = {}
    should_buffer = True
    max_buffer_size = 5
    if stream.type == "video":
        # DivX-style packed B-frames can have out-of-order pts (2 frames in a single pkt)
        # so need to buffer some extra frames to sort everything
        # properly
        extradata = stream.codec_context.extradata
        # overly complicated way of finding if `divx_packed` is set, following
        # https://github.com/FFmpeg/FFmpeg/commit/d5a21172283572af587b3d939eba0091484d3263
        if extradata and b"DivX" in extradata:
            # can't use regex directly because of some weird characters sometimes...
            pos = extradata.find(b"DivX")
            d = extradata[pos:]
            o = re.search(rb"DivX(\d+)Build(\d+)(\w)", d)
            if o is None:
                o = re.search(rb"DivX(\d+)b(\d+)(\w)", d)
            if o is not None:
                should_buffer = o.group(3) == b"p"
    seek_offset = start_offset
    # some files don't seek to the right location, so better be safe here
    seek_offset = max(seek_offset - 1, 0)
    if should_buffer:
        # FIXME this is kind of a hack, but we will jump to the previous keyframe
        # so this will be safe
        seek_offset = max(seek_offset - max_buffer_size, 0)
    try:
        # TODO check if stream needs to always be the video stream here or not
        container.seek(seek_offset, any_frame=False, backward=True, stream=stream)
    except FFmpegError:
        # TODO add some warnings in this case
        # print("Corrupted file?", container.name)
        return []
    buffer_count = 0
    try:
        for frame in container.decode(**stream_name):
            frames[frame.pts] = frame
            if frame.pts >= end_offset:
                # Keep decoding a few extra frames past the end so that
                # out-of-order pts can still be sorted correctly.
                if should_buffer and buffer_count < max_buffer_size:
                    buffer_count += 1
                    continue
                break
    except FFmpegError:
        # TODO add a warning
        pass
    # ensure that the results are sorted wrt the pts
    result = [frames[i] for i in sorted(frames) if start_offset <= frames[i].pts <= end_offset]
    if len(frames) > 0 and start_offset > 0 and start_offset not in frames:
        # if there is no frame that exactly matches the pts of start_offset
        # add the last frame smaller than start_offset, to guarantee that
        # we will have all the necessary data. This is most useful for audio
        preceding_frames = [i for i in frames if i < start_offset]
        if len(preceding_frames) > 0:
            first_frame_pts = max(preceding_frames)
            result.insert(0, frames[first_frame_pts])
    return result


def _align_audio_frames(
    aframes: torch.Tensor, audio_frames: List["av.frame.Frame"], ref_start: int, ref_end: float
) -> torch.Tensor:
    """Trim decoded audio samples so that they cover ``[ref_start, ref_end]``.

    Args:
        aframes: tensor whose second dimension indexes the audio samples.
        audio_frames: the decoded frames ``aframes`` was built from; only the
            ``pts`` of the first and last frame are used.
        ref_start: requested start, in the same unit as the frame pts.
        ref_end: requested end, in the same unit as the frame pts
            (may be ``float("inf")``).

    Returns:
        A slice of ``aframes`` along its second dimension.
    """
    total_aframes = aframes.shape[1]
    # Nothing to align, and the step computation below would divide by zero.
    if total_aframes == 0 or not audio_frames:
        return aframes

    start, end = audio_frames[0].pts, audio_frames[-1].pts
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # ``(ref_end - end)`` is negative: it is the number of samples to drop
        # from the end. Compute the end index from the total length rather than
        # using the negative value directly as a slice bound: an overshoot of
        # less than one sample truncates to 0, and ``[:, s_idx:0]`` would drop
        # every sample. Clamp at 0 so a large overshoot gives an empty slice.
        e_idx = max(total_aframes + int((ref_end - end) / step_per_aframe), 0)
    return aframes[:, s_idx:e_idx]
def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the audio frames

    .. warning::

        In the near future, we intend to centralize PyTorch's video decoding
        capabilities within the `torchcodec
        <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
        try it out and share your feedback, as the torchvision video decoders
        will eventually be deprecated.

    Args:
        filename (str): path to the video file. If using the pyav backend, this can be whatever ``av.open`` accepts.
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)

    Raises:
        ValueError: if ``output_format`` is not "THWC" / "TCHW" (case-insensitive), or,
            with the pyav backend, if ``end_pts`` is smaller than ``start_pts``.
        RuntimeError: with a non-pyav backend, if ``filename`` does not exist.
    """
    if not (torch.jit.is_scripting() or torch.jit.is_tracing()):
        _log_api_usage_once(read_video)

    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    from torchvision import get_video_backend

    if get_video_backend() == "pyav":
        _check_av_available()

        # An open-ended read is represented by an infinite end timestamp.
        if end_pts is None:
            end_pts = float("inf")

        if end_pts < start_pts:
            raise ValueError(
                f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}"
            )

        info = {}
        decoded_video = []
        decoded_audio = []
        audio_timebase = _video_opt.default_timebase

        try:
            with av.open(filename, metadata_errors="ignore") as container:
                if container.streams.audio:
                    audio_timebase = container.streams.audio[0].time_base

                if container.streams.video:
                    decoded_video = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.video[0],
                        {"video": 0},
                    )
                    fps = container.streams.video[0].average_rate
                    # guard against potentially corrupted files
                    if fps is not None:
                        info["video_fps"] = float(fps)

                if container.streams.audio:
                    decoded_audio = _read_from_stream(
                        container,
                        start_pts,
                        end_pts,
                        pts_unit,
                        container.streams.audio[0],
                        {"audio": 0},
                    )
                    info["audio_fps"] = container.streams.audio[0].rate
        except FFmpegError:
            # TODO raise a warning?
            pass

        rgb_arrays = [frame.to_rgb().to_ndarray() for frame in decoded_video]
        audio_chunks = [frame.to_ndarray() for frame in decoded_audio]

        if rgb_arrays:
            vframes = torch.as_tensor(np.stack(rgb_arrays))
        else:
            vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

        if not audio_chunks:
            aframes = torch.empty((1, 0), dtype=torch.float32)
        else:
            aframes = torch.as_tensor(np.concatenate(audio_chunks, 1))
            if pts_unit == "sec":
                # Express the requested range in audio time-base units before
                # trimming the samples.
                start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
                if end_pts != float("inf"):
                    end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
            aframes = _align_audio_frames(aframes, decoded_audio, start_pts, end_pts)
    else:
        if not os.path.exists(filename):
            raise RuntimeError(f"File not found: {filename}")

        vframes, aframes, info = _video_opt._read_video(filename, start_pts, end_pts, pts_unit)

    if output_format == "TCHW":
        # [T,H,W,C] --> [T,C,H,W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
def _can_read_timestamps_from_packets(container: "av.container.Container") -> bool:
    """Return ``True`` when the first stream's codec extradata contains ``b"Lavc"``."""
    extradata = container.streams[0].codec_context.extradata
    return extradata is not None and b"Lavc" in extradata


def _decode_video_timestamps(container: "av.container.Container") -> List[int]:
    """Collect the non-``None`` pts of video stream 0.

    The pts are read from demuxed packets when
    ``_can_read_timestamps_from_packets`` allows it (fast path), and from
    decoded frames otherwise.
    """
    if _can_read_timestamps_from_packets(container):
        # fast path
        source = container.demux(video=0)
    else:
        source = container.decode(video=0)
    return [item.pts for item in source if item.pts is not None]
def read_video_timestamps(filename: str, pts_unit: str = "pts") -> Tuple[List[int], Optional[float]]:
    """
    List the video frames timestamps.

    .. warning::

        In the near future, we intend to centralize PyTorch's video decoding
        capabilities within the `torchcodec
        <https://github.com/pytorch/torchcodec>`_ project. We encourage you to
        try it out and share your feedback, as the torchvision video decoders
        will eventually be deprecated.

    Note that the function decodes the whole video frame-by-frame.

    Args:
        filename (str): path to the video file
        pts_unit (str, optional): unit in which timestamp values will be returned
            either 'pts' or 'sec'. Defaults to 'pts'.

    Returns:
        pts (List[int] if pts_unit = 'pts', List[Fraction] if pts_unit = 'sec'):
            presentation timestamps for each one of the frames in the video.
        video_fps (float, optional): the frame rate for the video, or ``None`` when
            it could not be determined.

    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_video_timestamps)
    from torchvision import get_video_backend

    if get_video_backend() != "pyav":
        return _video_opt._read_video_timestamps(filename, pts_unit)

    _check_av_available()

    video_fps = None
    video_time_base = None
    pts = []

    try:
        with av.open(filename, metadata_errors="ignore") as container:
            if container.streams.video:
                video_stream = container.streams.video[0]
                video_time_base = video_stream.time_base
                try:
                    pts = _decode_video_timestamps(container)
                except FFmpegError:
                    warnings.warn(f"Failed decoding frames for file {filename}")
                # guard against potentially corrupted files, for which the
                # average rate can be None (same guard as in ``read_video``)
                average_rate = video_stream.average_rate
                if average_rate is not None:
                    video_fps = float(average_rate)
    except FFmpegError as e:
        msg = f"Failed to open container for {filename}; Caught error: {e}"
        warnings.warn(msg, RuntimeWarning)

    pts.sort()

    if pts_unit == "sec":
        # ``pts`` is only non-empty when a video stream was found, in which
        # case ``video_time_base`` has been set above.
        pts = [x * video_time_base for x in pts]

    return pts, video_fps
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.