from typing import Literal, Optional

import torch

from torchcodec import FrameBatch
from torchcodec.decoders import VideoDecoder
from torchcodec.samplers._common import (
    _FRAMEBATCH_RETURN_DOCS,
    _POLICY_FUNCTION_TYPE,
    _POLICY_FUNCTIONS,
    _reshape_4d_framebatch_into_5d,
    _validate_common_params,
)


def _validate_params_index_based(*, num_clips, num_indices_between_frames):
    """Validate the index-based sampler parameters that don't depend on the video.

    Raises:
        ValueError: if ``num_clips`` or ``num_indices_between_frames`` is not
            strictly positive.
    """
    if num_clips <= 0:
        raise ValueError(f"num_clips ({num_clips}) must be > 0")

    if num_indices_between_frames <= 0:
        raise ValueError(
            f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive"
        )


def _validate_sampling_range_index_based(
    *,
    num_indices_between_frames,
    num_frames_per_clip,
    sampling_range_start,
    sampling_range_end,
    num_frames_in_video,
):
    """Normalize and validate the ``[sampling_range_start, sampling_range_end)`` range.

    Negative bounds are interpreted relative to the end of the video, i.e. a
    negative value ``val`` becomes ``num_frames_in_video + val``.

    When ``sampling_range_end`` is None, it is set to the largest exclusive bound
    such that a clip starting at ``sampling_range_end - 1`` still fits in the
    video (but never less than 1). An explicit ``sampling_range_end`` is clamped
    to ``num_frames_in_video``.

    Returns:
        tuple: the normalized ``(sampling_range_start, sampling_range_end)``.

    Raises:
        ValueError: if the normalized start is not in ``[0, num_frames_in_video)``,
            or if the resulting range is empty.
    """
    if sampling_range_start < 0:
        # Support negative values so that -1 means last frame. Values below
        # -num_frames_in_video would still be negative after normalization and
        # would lead to invalid (negative) clip start indices, so reject them.
        if sampling_range_start < -num_frames_in_video:
            raise ValueError(
                f"sampling_range_start ({sampling_range_start}) must be greater than "
                f"or equal to -{num_frames_in_video} (minus the number of frames)."
            )
        sampling_range_start = num_frames_in_video + sampling_range_start

    if sampling_range_start >= num_frames_in_video:
        raise ValueError(
            f"sampling_range_start ({sampling_range_start}) must be smaller than "
            f"the number of frames ({num_frames_in_video})."
        )

    clip_span = _get_clip_span(
        num_indices_between_frames=num_indices_between_frames,
        num_frames_per_clip=num_frames_per_clip,
    )

    if sampling_range_end is None:
        # Exclusive bound such that the last possible clip ends exactly on the
        # last frame. The max(..., 1) keeps index 0 available as a start even
        # when a single clip spans more frames than the video has.
        sampling_range_end = max(num_frames_in_video - clip_span + 1, 1)
        if sampling_range_start >= sampling_range_end:
            raise ValueError(
                f"We determined that sampling_range_end should be {sampling_range_end}, "
                "but it is smaller than or equal to sampling_range_start "
                f"({sampling_range_start})."
            )
    else:
        if sampling_range_end < 0:
            # Support negative values so that -1 means last frame.
            sampling_range_end = num_frames_in_video + sampling_range_end
        # A clip can never start past the end of the video.
        sampling_range_end = min(sampling_range_end, num_frames_in_video)
        if sampling_range_start >= sampling_range_end:
            raise ValueError(
                f"sampling_range_start ({sampling_range_start}) must be smaller than "
                f"sampling_range_end ({sampling_range_end})."
            )

    return sampling_range_start, sampling_range_end


def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip):
    """Return the span of a clip, i.e. the number of frames (or indices)
    between the first and last frame in the clip, both included.

    This isn't the same as the number of frames in a clip!
    Example: f means a frame in the clip, x means a frame excluded from the clip
    num_frames_per_clip = 4
    num_indices_between_frames = 1, clip = ffff      , span = 4
    num_indices_between_frames = 2, clip = fxfxfxf   , span = 7
    num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10
    """
    return num_indices_between_frames * (num_frames_per_clip - 1) + 1


def _build_all_clips_indices(
    *,
    clip_start_indices: torch.Tensor,  # 1D int tensor
    num_frames_per_clip: int,
    num_indices_between_frames: int,
    num_frames_in_video: int,
    policy_fun: _POLICY_FUNCTION_TYPE,
) -> list[int]:
    """Expand clip start indices into the flat list of all frame indices.

    From the clip_start_indices [f_00, f_10, f_20, ...] and from the rest of the
    parameters, return the list of all the frame indices that make up all the
    clips, i.e. [f_00, f_01, f_02, f_03, f_10, f_11, f_12, f_13, ...], where
    f_01 is the index of frame 1 in clip 0.

    All clips in the output are of length num_frames_per_clip (=4 in the example
    above). When the frame indices go beyond num_frames_in_video, we force the
    frame indices back to valid values by applying the user's policy (wrap,
    repeat, etc.) through ``policy_fun``.
    """
    all_clips_indices: list[int] = []

    clip_span = _get_clip_span(
        num_indices_between_frames=num_indices_between_frames,
        num_frames_per_clip=num_frames_per_clip,
    )

    # .tolist() yields plain Python ints, so range() and the policy function
    # operate on ints rather than on 0-d tensors.
    for start_index in clip_start_indices.tolist():
        frame_index_upper_bound = min(start_index + clip_span, num_frames_in_video)
        frame_indices = list(
            range(start_index, frame_index_upper_bound, num_indices_between_frames)
        )
        if len(frame_indices) < num_frames_per_clip:
            # The clip was truncated by the end of the video: let the policy
            # decide how to fill it back up to num_frames_per_clip.
            frame_indices = policy_fun(frame_indices, num_frames_per_clip)  # type: ignore[assignment]
        all_clips_indices += frame_indices

    return all_clips_indices


def _generic_index_based_sampler(
    kind: Literal["random", "regular"],
    decoder: VideoDecoder,
    *,
    num_clips: int,
    num_frames_per_clip: int,
    num_indices_between_frames: int,
    sampling_range_start: int,
    sampling_range_end: Optional[int],  # interval is [start, end).
    # Important note: sampling_range_end defines the upper bound of where a clip
    # can *start*, not where a clip can end.
    policy: Literal["repeat_last", "wrap", "error"],
) -> FrameBatch:
    """Shared implementation of the random and regular index-based samplers.

    Validates the parameters, picks ``num_clips`` clip start indices within
    ``[sampling_range_start, sampling_range_end)`` (uniformly at random with
    ``torch.randint`` for ``kind="random"``, equally spaced with
    ``torch.linspace`` otherwise), expands them into frame indices, decodes
    them in one ``decoder.get_frames_at`` call, and regroups the decoded frames
    per clip with ``_reshape_4d_framebatch_into_5d``.

    Raises:
        ValueError: on invalid parameters or an empty sampling range.
    """
    _validate_common_params(
        decoder=decoder,
        num_frames_per_clip=num_frames_per_clip,
        policy=policy,
    )
    _validate_params_index_based(
        num_clips=num_clips,
        num_indices_between_frames=num_indices_between_frames,
    )

    num_frames_in_video = len(decoder)

    sampling_range_start, sampling_range_end = _validate_sampling_range_index_based(
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        num_frames_in_video=num_frames_in_video,
    )

    if kind == "random":
        clip_start_indices = torch.randint(
            low=sampling_range_start, high=sampling_range_end, size=(num_clips,)
        )
    else:
        # Note [num clips larger than sampling range]
        # If we ask for more clips than there are frames in the sampling range or
        # in the video, we rely on torch.linspace behavior which will return
        # duplicated indices.
        # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns
        # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10
        # Alternatively we could wrap around, but the current behavior is closer to
        # the expected "equally spaced indices" sampling.
        clip_start_indices = torch.linspace(
            sampling_range_start,
            sampling_range_end - 1,
            steps=num_clips,
            dtype=torch.int,
        )

    all_clips_indices = _build_all_clips_indices(
        clip_start_indices=clip_start_indices,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
        num_frames_in_video=num_frames_in_video,
        policy_fun=_POLICY_FUNCTIONS[policy],
    )

    frames = decoder.get_frames_at(indices=all_clips_indices)
    return _reshape_4d_framebatch_into_5d(
        frames=frames,
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
    )
def clips_at_random_indices(
    decoder: VideoDecoder,
    *,
    num_clips: int = 1,
    num_frames_per_clip: int = 1,
    num_indices_between_frames: int = 1,
    sampling_range_start: int = 0,
    sampling_range_end: Optional[int] = None,  # interval is [start, end).
    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
    # The public docstring is assigned after the definition (built from
    # _COMMON_DOCS) so it can be shared with clips_at_regular_indices.
    return _generic_index_based_sampler(
        kind="random",
        decoder=decoder,
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        policy=policy,
    )
def clips_at_regular_indices(
    decoder: VideoDecoder,
    *,
    num_clips: int = 1,
    num_frames_per_clip: int = 1,
    num_indices_between_frames: int = 1,
    sampling_range_start: int = 0,
    sampling_range_end: Optional[int] = None,  # interval is [start, end).
    policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
) -> FrameBatch:
    # The public docstring is assigned after the definition (built from
    # _COMMON_DOCS) so it can be shared with clips_at_random_indices.
    return _generic_index_based_sampler(
        kind="regular",
        decoder=decoder,
        num_clips=num_clips,
        num_frames_per_clip=num_frames_per_clip,
        num_indices_between_frames=num_indices_between_frames,
        sampling_range_start=sampling_range_start,
        sampling_range_end=sampling_range_end,
        policy=policy,
    )
# Shared parameter documentation for the two public index-based samplers. The
# docstrings are assigned to __doc__ after the function definitions so that
# both functions render the same Args / Returns sections.
_COMMON_DOCS = f"""
    Args:
        decoder (VideoDecoder): The :class:`~torchcodec.decoders.VideoDecoder`
            instance to sample clips from.
        num_clips (int, optional): The number of clips to return. Default: 1.
        num_frames_per_clip (int, optional): The number of frames per clip. Default: 1.
        num_indices_between_frames (int, optional): The number of indices
            between the frames *within* a clip. Default: 1, which means frames
            are consecutive. This is sometimes referred to as "dilation".
        sampling_range_start (int, optional): The start of the sampling range,
            which defines the first index that a clip may *start* at. Default:
            0, i.e. the start of the video. Negative values are accepted and are
            equivalent to ``len(video) + val``.
        sampling_range_end (int or None, optional): The end of the sampling
            range, which defines the upper bound of the indices that a clip may
            *start* at. This value is exclusive, i.e. a clip may only start within
            [``sampling_range_start``, ``sampling_range_end``). If None
            (default), the value is set automatically such that the clips never
            span beyond the end of the video (unless a single clip spans more
            frames than the video has). For example if the video has 100 frames
            (the last valid index is 99) and the clips span 10 frames, this value
            is set to 100 - 10 + 1 = 91, so the last clip may start at index 90.
            Negative values are accepted and are equivalent to
            ``len(video) + val``. When a clip spans beyond the end of the video,
            the ``policy`` parameter defines how to construct such a clip.
        policy (str, optional): Defines how to construct clips that span beyond
            the end of the video. This is best described with an example:
            assuming the last valid index in a video is 99, and a clip was
            sampled to start at index 95, with ``num_frames_per_clip=5`` and
            ``num_indices_between_frames=2``, the indices of the frames in the
            clip are supposed to be [95, 97, 99, 101, 103]. But 101 and 103 are
            invalid indices, so the ``policy`` parameter defines how to replace
            those frames with valid indices:

            - "repeat_last": repeats the last valid frame of the clip. We would
              get [95, 97, 99, 99, 99].
            - "wrap": wraps around to the beginning of the clip. We would get
              [95, 97, 99, 95, 97].
            - "error": raises an error.

            Default is "repeat_last". Note that when ``sampling_range_end=None``
            (default), this policy parameter is unlikely to be relevant.

    {_FRAMEBATCH_RETURN_DOCS}
"""

clips_at_random_indices.__doc__ = f"""Sample :term:`clips` at random indices.
{_COMMON_DOCS}
"""

clips_at_regular_indices.__doc__ = f"""Sample :term:`clips` at regular (equally-spaced) indices.
{_COMMON_DOCS}
"""
Docs
Access comprehensive developer documentation for PyTorch
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: Cookies Policy.