Source code for torch.ao.nn.quantized.dynamic.modules.rnn
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import numbers
import warnings

from typing_extensions import deprecated

import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Dict, List, Optional, Tuple, Union  # noqa: F401
from torch.ao.nn.quantized.modules.utils import _quantize_weight
from torch.nn.utils.rnn import PackedSequence


__all__ = [
    "pack_weight_bias",
    "PackedParameter",
    "RNNBase",
    "LSTM",
    "GRU",
    "RNNCellBase",
    "RNNCell",
    "LSTMCell",
    "GRUCell",
    "apply_permutation",
]


def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


@deprecated(
    "`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead",
    category=FutureWarning,
)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return _apply_permutation(tensor, permutation, dim)


def pack_weight_bias(qweight, bias, dtype):
    if dtype == torch.qint8:
        # for each layer, for each direction we need to quantize and pack
        # weights and pack parameters in this order:
        #
        #   w_ih, w_hh
        packed_weight = torch.ops.quantized.linear_prepack(qweight, bias)
        return packed_weight
    else:
        # for each layer, for each direction we need to quantize and pack
        # weights and pack parameters in this order:
        #
        #   packed_ih, packed_hh, b_ih, b_hh
        packed_weight = torch.ops.quantized.linear_prepack_fp16(qweight, bias)
        return packed_weight


class PackedParameter(torch.nn.Module):
    def __init__(self, param):
        super().__init__()
        self.param = param

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "param"] = self.param

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self.param = state_dict[prefix + "param"]
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class RNNBase(torch.nn.Module):
    _FLOAT_MODULE = nn.RNNBase

    _version = 2

    def __init__(
        self,
        mode,
        input_size,
        hidden_size,
        num_layers=1,
        bias=True,
        batch_first=False,
        dropout=0.0,
        bidirectional=False,
        dtype=torch.qint8,
    ):
        super().__init__()

        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.dtype = dtype
        self.version = 2
        self.training = False
        num_directions = 2 if bidirectional else 1

        # "type: ignore" is required since ints and Numbers are not fully comparable
        # https://github.com/python/mypy/issues/8566
        if (
            not isinstance(dropout, numbers.Number)
            or not 0 <= dropout <= 1  # type: ignore[operator]
            or isinstance(dropout, bool)
        ):
            raise ValueError(
                "dropout should be a number in range [0, 1] "
                "representing the probability of an element being "
                "zeroed"
            )
        if dropout > 0 and num_layers == 1:  # type: ignore[operator]
            warnings.warn(
                "dropout option adds dropout after all but last "
                "recurrent layer, so non-zero dropout expects "
                f"num_layers greater than 1, but got dropout={dropout} and "
                f"num_layers={num_layers}"
            )

        if mode == "LSTM":
            gate_size = 4 * hidden_size
        elif mode == "GRU":
            gate_size = 3 * hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        _all_weight_values = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = (
                    input_size if layer == 0 else hidden_size * num_directions
                )

                w_ih = torch.randn(gate_size, layer_input_size).to(torch.float)
                w_hh = torch.randn(gate_size, hidden_size).to(torch.float)
                b_ih = torch.randn(gate_size).to(torch.float)
                b_hh = torch.randn(gate_size).to(torch.float)
                if dtype == torch.qint8:
                    w_ih = torch.quantize_per_tensor(
                        w_ih, scale=0.1, zero_point=0, dtype=torch.qint8
                    )
                    w_hh = torch.quantize_per_tensor(
                        w_hh, scale=0.1, zero_point=0, dtype=torch.qint8
                    )
                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh, True
                            )
                        )
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

    def _get_name(self):
        return "DynamicQuantizedRNN"

    def extra_repr(self):
        s = "{input_size}, {hidden_size}"
        if self.num_layers != 1:
            s += ", num_layers={num_layers}"
        if self.bias is not True:
            s += ", bias={bias}"
        if self.batch_first is not False:
            s += ", batch_first={batch_first}"
        if self.dropout != 0:
            s += ", dropout={dropout}"
        if self.bidirectional is not False:
            s += ", bidirectional={bidirectional}"
        return s.format(**self.__dict__)

    def __repr__(self):
        # We don't want to show `ModuleList` children, hence custom
        # `__repr__`. This is the same as nn.Module.__repr__, except the check
        # for the `PackedParameter` and `nn.ModuleList`.
        # You should still override `extra_repr` to add more info.
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split("\n")
        child_lines = []
        for key, module in self._modules.items():
            if isinstance(module, (PackedParameter, nn.ModuleList)):
                continue
            mod_str = repr(module)
            mod_str = nn.modules.module._addindent(mod_str, 2)
            child_lines.append("(" + key + "): " + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + "("
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += "\n  " + "\n  ".join(lines) + "\n"

        main_str += ")"
        return main_str

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                f"input must have {expected_input_dim} dimensions, got {input.dim()}"
            )
        if self.input_size != input.size(-1):
            raise RuntimeError(
                f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
            )

    def get_expected_hidden_size(
        self, input: Tensor, batch_sizes: Optional[Tensor]
    ) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (
            self.num_layers * num_directions,
            mini_batch,
            self.hidden_size,
        )
        return expected_hidden_size

    def check_hidden_size(
        self,
        hx: Tensor,
        expected_hidden_size: Tuple[int, int, int],
        msg: str = "Expected hidden size {}, got {}",
    ) -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(
            hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
        )

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)
        self.version = version
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )

    def set_weight_bias(self, weight_bias_dict):
        def weight_bias_name(ihhh, layer, suffix):
            weight_name = f"weight_{ihhh}_l{layer}{suffix}"
            bias_name = f"bias_{ihhh}_l{layer}{suffix}"
            return weight_name, bias_name

        num_directions = 2 if self.bidirectional else 1
        # TODO: dedup with __init__ of RNNBase
        _all_weight_values = []
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""
                w_ih_name, b_ih_name = weight_bias_name("ih", layer, suffix)
                w_hh_name, b_hh_name = weight_bias_name("hh", layer, suffix)
                w_ih = weight_bias_dict[w_ih_name]
                b_ih = weight_bias_dict[b_ih_name]
                w_hh = weight_bias_dict[w_hh_name]
                b_hh = weight_bias_dict[b_hh_name]
                if w_ih.dtype == torch.qint8:
                    packed_ih = torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, b_ih, b_hh, True
                            )
                        )
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) in {
            torch.nn.LSTM,
            torch.nn.GRU,
        }, "nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError(
                f"Unsupported dtype for dynamic RNN quantization: {dtype}"
            )
        # RNNBase can be either LSTM or GRU
        qRNNBase: Union[LSTM, GRU]
        if mod.mode == "LSTM":
            qRNNBase = LSTM(
                mod.input_size,
                mod.hidden_size,
                mod.num_layers,
                mod.bias,
                mod.batch_first,
                mod.dropout,
                mod.bidirectional,
                dtype,
            )
        elif mod.mode == "GRU":
            qRNNBase = GRU(
                mod.input_size,
                mod.hidden_size,
                mod.num_layers,
                mod.bias,
                mod.batch_first,
                mod.dropout,
                mod.bidirectional,
                dtype,
            )
        else:
            raise NotImplementedError(
                "Only LSTM/GRU is supported for QuantizedRNN for now"
            )

        num_directions = 2 if mod.bidirectional else 1

        assert mod.bias

        _all_weight_values = []
        for layer in range(qRNNBase.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""

                def retrieve_weight_bias(ihhh):
                    weight_name = f"weight_{ihhh}_l{layer}{suffix}"
                    bias_name = f"bias_{ihhh}_l{layer}{suffix}"
                    weight = getattr(mod, weight_name)
                    bias = getattr(mod, bias_name)
                    return weight, bias

                weight_ih, bias_ih = retrieve_weight_bias("ih")
                weight_hh, bias_hh = retrieve_weight_bias("hh")

                if dtype == torch.qint8:

                    def quantize_and_pack(w, b):
                        weight_observer = weight_observer_method()
                        weight_observer(w)
                        qweight = _quantize_weight(w.float(), weight_observer)
                        packed_weight = torch.ops.quantized.linear_prepack(qweight, b)
                        return packed_weight

                    packed_ih = quantize_and_pack(weight_ih, bias_ih)
                    packed_hh = quantize_and_pack(weight_hh, bias_hh)
                    if qRNNBase.version is None or qRNNBase.version < 2:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, bias_ih, bias_hh
                            )
                        )
                    else:
                        cell_params = (
                            torch.ops.quantized.make_quantized_cell_params_dynamic(
                                packed_ih, packed_hh, bias_ih, bias_hh, True
                            )
                        )
                elif dtype == torch.float16:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
                        weight_ih.float(), bias_ih
                    )
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
                        weight_hh.float(), bias_hh
                    )
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh
                    )
                else:
                    raise RuntimeError(
                        "Unsupported dtype specified for dynamic quantized LSTM!"
                    )

                _all_weight_values.append(PackedParameter(cell_params))
        qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values)

        return qRNNBase

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}}
        count = 0
        num_directions = 2 if self.bidirectional else 1
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = "_reverse" if direction == 1 else ""
                key_name1 = f"weight_ih_l{layer}{suffix}"
                key_name2 = f"weight_hh_l{layer}{suffix}"
                # packed weights are part of torchbind class, CellParamsSerializationType
                # Within the packed weight class, the weight and bias are accessible as Tensors
                packed_weight_bias = self._all_weight_values[
                    count
                ].param.__getstate__()[0][4]
                weight_bias_dict["weight"][key_name1] = packed_weight_bias[
                    0
                ].__getstate__()[0][0]
                weight_bias_dict["weight"][key_name2] = packed_weight_bias[
                    1
                ].__getstate__()[0][0]
                key_name1 = f"bias_ih_l{layer}{suffix}"
                key_name2 = f"bias_hh_l{layer}{suffix}"
                weight_bias_dict["bias"][key_name1] = packed_weight_bias[
                    0
                ].__getstate__()[0][1]
                weight_bias_dict["bias"][key_name2] = packed_weight_bias[
                    1
                ].__getstate__()[0][1]
                count = count + 1
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()["weight"]

    def get_bias(self):
        return self._weight_bias()["bias"]
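# --- Illustrative sketch, not part of the original module ---
# The usual way to reach RNNBase.from_float above is torch.ao.quantization.quantize_dynamic,
# which swaps float nn.LSTM / nn.GRU submodules for the dynamic quantized versions defined
# in this file. The function name, wrapper module, and sizes below are arbitrary example values.
def _example_quantize_dynamic_lstm():
    import torch
    from torch import nn
    from torch.ao.quantization import quantize_dynamic

    # A float model with an LSTM submodule; quantize_dynamic replaces the child in-place
    # in a copy of the model (inplace=False by default).
    float_model = nn.Sequential(nn.LSTM(10, 20, num_layers=2))
    quantized_model = quantize_dynamic(float_model, {nn.LSTM}, dtype=torch.qint8)

    x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
    out, (hn, cn) = quantized_model[0](x)  # quantized_model[0] is the dynamic quantized LSTM
    return out.shape, hn.shape, cn.shape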
class LSTM(RNNBase):
    r"""
    A dynamic quantized LSTM module with floating point tensor as inputs and outputs.
    We adopt the same interface as `torch.nn.LSTM`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """

    _FLOAT_MODULE = nn.LSTM

    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, *args, **kwargs):
        super().__init__("LSTM", *args, **kwargs)

    def _get_name(self):
        return "DynamicQuantizedLSTM"

    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tuple[Tensor, Tensor]],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = [m.param for m in self._all_weight_values]
        if batch_sizes is None:
            result = torch.quantized_lstm(
                input,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
                self.batch_first,
                dtype=self.dtype,
                use_dynamic=True,
            )
        else:
            result = torch.quantized_lstm(
                input,
                batch_sizes,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                float(self.dropout),
                self.training,
                self.bidirectional,
                dtype=self.dtype,
                use_dynamic=True,
            )
        output = result[0]
        hidden = result[1:]

        return output, hidden

    @torch.jit.export
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    # "type: ignore" is required due to issue #43072
    def permute_hidden(  # type: ignore[override]
        self,
        hx: Tuple[Tensor, Tensor],
        permutation: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return _apply_permutation(hx[0], permutation), _apply_permutation(
            hx[1], permutation
        )

    # "type: ignore" is required due to issue #43072
    def check_forward_args(  # type: ignore[override]
        self,
        input: Tensor,
        hidden: Tuple[Tensor, Tensor],
        batch_sizes: Optional[Tensor],
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(
            hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
        )
        self.check_hidden_size(
            hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
        )

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
            "We are assuming weight_ih_l0 "
            "exists in LSTM, may need to relax the assumption to support the use case"
        )
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod
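# --- Illustrative sketch, not part of the original module ---
# LSTM.from_float can also be called directly when the float module carries a qconfig;
# passing a PackedSequence exercises the forward_packed path above. The function name
# and all sizes/lengths below are arbitrary example values.
def _example_lstm_from_float_packed():
    import torch
    from torch import nn
    from torch.ao.quantization.qconfig import default_dynamic_qconfig
    from torch.nn.utils.rnn import pack_padded_sequence

    float_lstm = nn.LSTM(10, 20, num_layers=1)
    float_lstm.qconfig = default_dynamic_qconfig  # weight observer selects torch.qint8
    q_lstm = LSTM.from_float(float_lstm)

    x = torch.randn(5, 3, 10)  # (seq_len, batch, input_size)
    packed = pack_padded_sequence(x, lengths=[5, 3, 2])  # lengths sorted descending
    out_packed, (hn, cn) = q_lstm(packed)  # dispatches to forward_packed
    return out_packed, hn.shape, cn.shape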
class GRU(RNNBase):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\
            h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product.

    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two GRUs together to form a `stacked GRU`,
            with the second GRU taking in outputs of the first GRU and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            GRU layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
          for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features h_t from the last layer of the GRU,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.
          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.

          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Shape:
        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
        - Input2: :math:`(S, N, H_{out})` tensor
          containing the initial hidden state for each element in the batch,
          where :math:`H_{out}=\text{hidden\_size}` and
          :math:`S=\text{num\_layers} * \text{num\_directions}`.
          Defaults to zero if not provided.
          If the RNN is bidirectional, num_directions should be 2, else it should be 1.
        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. note::
        The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks.
        In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the
        previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix
        `W` and addition of bias:

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn})
            \end{aligned}

        This is in contrast to PyTorch implementation, which is done after :math:`W_{hn} h_{(t-1)}`

        .. math::
            \begin{aligned}
                n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn}))
            \end{aligned}

        This implementation differs on purpose for efficiency.

    .. include:: ../cudnn_persistent_rnn.rst

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.GRU(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """

    _FLOAT_MODULE = nn.GRU

    __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

    def __init__(self, *args, **kwargs):
        super().__init__("GRU", *args, **kwargs)

    def _get_name(self):
        return "DynamicQuantizedGRU"

    def check_forward_args(
        self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(
            hidden, expected_hidden_size, "Expected hidden size {}, got {}"
        )

    def forward_impl(
        self,
        input: Tensor,
        hx: Optional[Tensor],
        batch_sizes: Optional[Tensor],
        max_batch_size: int,
        sorted_indices: Optional[Tensor],
    ) -> Tuple[Tensor, Tensor]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(
                self.num_layers * num_directions,
                max_batch_size,
                self.hidden_size,
                dtype=input.dtype,
                device=input.device,
            )
            hx = zeros
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = [m.param for m in self._all_weight_values]
        if batch_sizes is None:
            result = torch.quantized_gru(
                input,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
                self.batch_first,
            )
        else:
            result = torch.quantized_gru(
                input,
                batch_sizes,
                hx,
                _all_params,
                self.bias,
                self.num_layers,
                self.dropout,
                self.training,
                self.bidirectional,
            )
        output = result[0]
        hidden = result[1]

        return output, hidden

    @torch.jit.export
    def forward_tensor(
        self, input: Tensor, hx: Optional[Tensor] = None
    ) -> Tuple[Tensor, Tensor]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices
        )

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(
        self, input: PackedSequence, hx: Optional[Tensor] = None
    ) -> Tuple[PackedSequence, Tensor]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = int(batch_sizes[0])

        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices
        )

        output = PackedSequence(output_, batch_sizes, sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return _apply_permutation(hx, permutation)

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_l0_dtype"), (
            "We are assuming weight_ih_l0 "
            "exists in LSTM, may need to relax the assumption to support the use case"
        )
        qmod = cls(
            ref_mod.input_size,
            ref_mod.hidden_size,
            ref_mod.num_layers,
            ref_mod.bias,
            ref_mod.batch_first,
            ref_mod.dropout,
            ref_mod.bidirectional,
            # assuming there is layer 0, which should be OK
            ref_mod.weight_ih_l0_dtype,
        )
        qmod.set_weight_bias(ref_mod.get_quantized_weight_bias_dict())
        return qmod
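# --- Illustrative sketch, not part of the original module ---
# GRU follows the same conversion path as LSTM. get_weight()/get_bias() (defined on
# RNNBase above) unpack the per-layer quantized weights for inspection. The function
# name, wrapper module, and sizes below are arbitrary example values.
def _example_quantize_dynamic_gru():
    import torch
    from torch import nn
    from torch.ao.quantization import quantize_dynamic

    float_model = nn.Sequential(nn.GRU(8, 16, num_layers=2))
    quantized_model = quantize_dynamic(float_model, {nn.GRU}, dtype=torch.qint8)
    q_gru = quantized_model[0]

    x = torch.randn(4, 2, 8)  # (seq_len, batch, input_size)
    out, hn = q_gru(x)
    # Weight keys follow the float naming scheme, e.g. "weight_ih_l0", "weight_hh_l1",
    # with a "_reverse" suffix for the backward direction of bidirectional models.
    return out.shape, hn.shape, list(q_gru.get_weight().keys())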
class RNNCellBase(torch.nn.Module):
    # _FLOAT_MODULE = nn.CellRNNBase
    __constants__ = ["input_size", "hidden_size", "bias"]

    def __init__(
        self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        self.weight_dtype = dtype
        if bias:
            self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
            self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
        else:
            self.register_parameter("bias_ih", None)
            self.register_parameter("bias_hh", None)

        weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float)
        weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float)
        if dtype == torch.qint8:
            weight_ih = torch.quantize_per_tensor(
                weight_ih, scale=1, zero_point=0, dtype=torch.qint8
            )
            weight_hh = torch.quantize_per_tensor(
                weight_hh, scale=1, zero_point=0, dtype=torch.qint8
            )

        if dtype == torch.qint8:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   w_ih, w_hh
            packed_weight_ih = torch.ops.quantized.linear_prepack(
                weight_ih, self.bias_ih
            )
            packed_weight_hh = torch.ops.quantized.linear_prepack(
                weight_hh, self.bias_hh
            )
        else:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   packed_ih, packed_hh, b_ih, b_hh
            packed_weight_ih = torch.ops.quantized.linear_prepack_fp16(
                weight_ih, self.bias_ih
            )
            packed_weight_hh = torch.ops.quantized.linear_prepack_fp16(
                weight_hh, self.bias_hh
            )

        self._packed_weight_ih = packed_weight_ih
        self._packed_weight_hh = packed_weight_hh

    def _get_name(self):
        return "DynamicQuantizedRNNBase"

    def extra_repr(self):
        s = "{input_size}, {hidden_size}"
        if "bias" in self.__dict__ and self.bias is not True:
            s += ", bias={bias}"
        if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
            s += ", nonlinearity={nonlinearity}"
        return s.format(**self.__dict__)

    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}"
            )

    def check_forward_hidden(
        self, input: Tensor, hx: Tensor, hidden_label: str = ""
    ) -> None:
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}"
            )

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}"
            )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        assert type(mod) in {
            torch.nn.LSTMCell,
            torch.nn.GRUCell,
            torch.nn.RNNCell,
        }, "nn.quantized.dynamic.RNNCellBase.from_float only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell"
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.ao.quantization.qconfig import default_dynamic_qconfig

            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError(
                f"Unsupported dtype for dynamic RNN quantization: {dtype}"
            )

        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]

        if type(mod) == torch.nn.LSTMCell:
            qRNNCellBase = LSTMCell(
                mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
            )
        elif type(mod) == torch.nn.GRUCell:
            qRNNCellBase = GRUCell(
                mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype
            )
        elif type(mod) == torch.nn.RNNCell:
            qRNNCellBase = RNNCell(
                mod.input_size,
                mod.hidden_size,
                bias=mod.bias,
                nonlinearity=mod.nonlinearity,
                dtype=dtype,
            )
        else:
            raise NotImplementedError(
                "Only LSTMCell, GRUCell and RNNCell are supported for QuantizedRNN for now"
            )

        assert mod.bias

        def _observe_and_quantize_weight(weight):
            if dtype == torch.qint8:
                weight_observer = weight_observer_method()
                weight_observer(weight)
                qweight = _quantize_weight(weight.float(), weight_observer)
                return qweight
            else:
                return weight.float()

        qRNNCellBase._packed_weight_ih = pack_weight_bias(
            _observe_and_quantize_weight(mod.weight_ih), mod.bias_ih, dtype
        )
        qRNNCellBase._packed_weight_hh = pack_weight_bias(
            _observe_and_quantize_weight(mod.weight_hh), mod.bias_hh, dtype
        )
        return qRNNCellBase

    @classmethod
    def from_reference(cls, ref_mod):
        assert hasattr(ref_mod, "weight_ih_dtype"), (
            "We are assuming weight_ih "
            "exists in reference module, may need to relax the assumption to support the use case"
        )
        if hasattr(ref_mod, "nonlinearity"):
            qmod = cls(
                ref_mod.input_size,
                ref_mod.hidden_size,
                ref_mod.bias,
                ref_mod.nonlinearity,
                dtype=ref_mod.weight_ih_dtype,
            )
        else:
            qmod = cls(
                ref_mod.input_size,
                ref_mod.hidden_size,
                ref_mod.bias,
                dtype=ref_mod.weight_ih_dtype,
            )
        weight_bias_dict = {
            "weight": {
                "weight_ih": ref_mod.get_quantized_weight_ih(),
                "weight_hh": ref_mod.get_quantized_weight_hh(),
            },
            "bias": {
                "bias_ih": ref_mod.bias_ih,
                "bias_hh": ref_mod.bias_hh,
            },
        }
        qmod.set_weight_bias(weight_bias_dict)
        return qmod

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {"weight": {}, "bias": {}}
        w1, b1 = self._packed_weight_ih.__getstate__()[0]
        w2, b2 = self._packed_weight_hh.__getstate__()[0]
        # TODO: these can be simplified to one level? e.g. using weight_ih as key
        # directly
        weight_bias_dict["weight"]["weight_ih"] = w1
        weight_bias_dict["weight"]["weight_hh"] = w2
        weight_bias_dict["bias"]["bias_ih"] = b1
        weight_bias_dict["bias"]["bias_hh"] = b2
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()["weight"]

    def get_bias(self):
        return self._weight_bias()["bias"]

    def set_weight_bias(self, weight_bias_dict):
        # TODO: these can be simplified to one level? e.g. using weight_ih as key
        # directly
        self._packed_weight_ih = pack_weight_bias(
            weight_bias_dict["weight"]["weight_ih"],
            weight_bias_dict["bias"]["bias_ih"],
            self.weight_dtype,
        )
        self._packed_weight_hh = pack_weight_bias(
            weight_bias_dict["weight"]["weight_hh"],
            weight_bias_dict["bias"]["bias_hh"],
            self.weight_dtype,
        )

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + "_packed_weight_ih"] = self._packed_weight_ih
        destination[prefix + "_packed_weight_hh"] = self._packed_weight_hh

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        self._packed_weight_ih = state_dict.pop(prefix + "_packed_weight_ih")
        self._packed_weight_hh = state_dict.pop(prefix + "_packed_weight_hh")
        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            False,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
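# --- Illustrative sketch, not part of the original module ---
# pack_weight_bias (defined at the top of this file) is the helper the cell classes use:
# for torch.qint8 it expects an already-quantized weight, for torch.float16 a float weight.
# The function name and tensor shapes below are arbitrary example values
# (num_chunks=4 mirrors the LSTMCell layout).
def _example_pack_weight_bias():
    import torch

    w = torch.randn(4 * 20, 10)  # (num_chunks * hidden_size, input_size)
    b = torch.randn(4 * 20)

    # qint8 path: quantize first, then prepack with the float bias.
    qw = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)
    packed_int8 = pack_weight_bias(qw, b, torch.qint8)

    # fp16 path: the float weight is handed to the fp16 prepack op directly.
    packed_fp16 = pack_weight_bias(w, b, torch.float16)
    return packed_int8, packed_fp16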
class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.

    A dynamic quantized RNNCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    __constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"]

    def __init__(
        self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8
    ):
        super().__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return "DynamicQuantizedRNNCell"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        if self.nonlinearity == "tanh":
            ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic(
                input,
                hx,
                self._packed_weight_ih,
                self._packed_weight_hh,
                self.bias_ih,
                self.bias_hh,
            )
        elif self.nonlinearity == "relu":
            ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic(
                input,
                hx,
                self._packed_weight_ih,
                self._packed_weight_hh,
                self.bias_ih,
                self.bias_hh,
            )
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
        return ret

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
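# --- Illustrative sketch, not part of the original module ---
# A single cell can be converted directly through RNNCellBase.from_float when a qconfig
# is attached; the nonlinearity ("tanh" or "relu") is carried over. The function name
# and sizes below are arbitrary example values.
def _example_rnn_cell_from_float():
    import torch
    from torch import nn
    from torch.ao.quantization.qconfig import default_dynamic_qconfig

    float_cell = nn.RNNCell(10, 20, nonlinearity="relu")
    float_cell.qconfig = default_dynamic_qconfig
    q_cell = RNNCell.from_float(float_cell)

    x = torch.randn(3, 10)   # (batch, input_size)
    hx = torch.randn(3, 20)  # (batch, hidden_size)
    return q_cell(x, hx).shape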
class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    A dynamic quantized LSTMCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, num_chunks=4, **kwargs)  # type: ignore[misc]

    def _get_name(self):
        return "DynamicQuantizedLSTMCell"

    def forward(
        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
    ) -> Tuple[Tensor, Tensor]:
        self.check_forward_input(input)
        if hx is None:
            zeros = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
            hx = (zeros, zeros)
        self.check_forward_hidden(input, hx[0], "[0]")
        self.check_forward_hidden(input, hx[1], "[1]")
        return torch.ops.quantized.quantized_lstm_cell_dynamic(
            input,
            hx,
            self._packed_weight_ih,
            self._packed_weight_hh,
            self.bias_ih,
            self.bias_hh,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell.

    A dynamic quantized GRUCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation.

    Examples::

        >>> # xdoctest: +SKIP
        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx = rnn(input[i], hx)
        ...     output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
        super().__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype)

    def _get_name(self):
        return "DynamicQuantizedGRUCell"

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(
                input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
            )
        self.check_forward_hidden(input, hx, "")
        return torch.ops.quantized.quantized_gru_cell_dynamic(
            input,
            hx,
            self._packed_weight_ih,
            self._packed_weight_hh,
            self.bias_ih,
            self.bias_hh,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        return super().from_float(
            mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
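# --- Illustrative sketch, not part of the original module ---
# torch.ao.quantization.quantize_dynamic can swap the float cell modules for the dynamic
# quantized cells defined in this file in one pass. The function name, container, and
# sizes below are arbitrary example values.
def _example_quantize_dynamic_cells():
    import torch
    from torch import nn
    from torch.ao.quantization import quantize_dynamic

    float_model = nn.ModuleDict(
        {
            "lstm_cell": nn.LSTMCell(10, 20),
            "gru_cell": nn.GRUCell(10, 20),
        }
    )
    quantized_model = quantize_dynamic(
        float_model, {nn.LSTMCell, nn.GRUCell}, dtype=torch.qint8
    )

    x = torch.randn(3, 10)  # (batch, input_size)
    hx, cx = quantized_model["lstm_cell"](x)  # hidden defaults to zeros when omitted
    h = quantized_model["gru_cell"](x)
    return hx.shape, cx.shape, h.shape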