Source code for torch.nn.quantized.dynamic.modules.rnn
import numbers
import warnings

import torch
import torch.nn as nn
from torch import Tensor  # noqa: F401
from torch._jit_internal import Tuple, Optional, List, Union, Dict  # noqa: F401
from torch.nn.utils.rnn import PackedSequence
from torch.nn.quantized.modules.utils import _quantize_weight


def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
    return tensor.index_select(dim, permutation)


class PackedParameter(torch.nn.Module):
    def __init__(self, param):
        super(PackedParameter, self).__init__()
        self.param = param

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super(PackedParameter, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + 'param'] = self.param

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.param = state_dict[prefix + 'param']
        super(PackedParameter, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                           missing_keys, unexpected_keys, error_msgs)


class RNNBase(torch.nn.Module):

    _FLOAT_MODULE = nn.RNNBase

    _version = 2

    def __init__(self, mode, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0., bidirectional=False, dtype=torch.qint8):
        super(RNNBase, self).__init__()

        self.mode = mode
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.dtype = dtype
        self.version = 2
        self.training = False
        num_directions = 2 if bidirectional else 1

        # "type: ignore" is required since ints and Numbers are not fully comparable
        # https://github.com/python/mypy/issues/8566
        if not isinstance(dropout, numbers.Number) \
                or not 0 <= dropout <= 1 or isinstance(dropout, bool):  # type: ignore[operator]
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0 and num_layers == 1:  # type: ignore[operator]
            warnings.warn("dropout option adds dropout after all but last "
                          "recurrent layer, so non-zero dropout expects "
                          "num_layers greater than 1, but got dropout={} and "
                          "num_layers={}".format(dropout, num_layers))

        if mode == 'LSTM':
            gate_size = 4 * hidden_size
        elif mode == 'GRU':
            gate_size = 3 * hidden_size
        else:
            raise ValueError("Unrecognized RNN mode: " + mode)

        _all_weight_values = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                layer_input_size = input_size if layer == 0 else hidden_size * num_directions

                w_ih = torch.randn(gate_size, layer_input_size).to(torch.float)
                w_hh = torch.randn(gate_size, hidden_size).to(torch.float)
                b_ih = torch.randn(gate_size).to(torch.float)
                b_hh = torch.randn(gate_size).to(torch.float)
                if dtype == torch.qint8:
                    w_ih = torch.quantize_per_tensor(w_ih, scale=0.1, zero_point=0, dtype=torch.qint8)
                    w_hh = torch.quantize_per_tensor(w_hh, scale=0.1, zero_point=0, dtype=torch.qint8)
                    packed_ih = \
                        torch.ops.quantized.linear_prepack(w_ih, b_ih)
                    packed_hh = \
                        torch.ops.quantized.linear_prepack(w_hh, b_hh)
                    if self.version is None or self.version < 2:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, b_ih, b_hh)
                    else:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, b_ih, b_hh, True)
                else:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(w_ih, b_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(w_hh, b_hh)
                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh)

                _all_weight_values.append(PackedParameter(cell_params))
        self._all_weight_values = torch.nn.ModuleList(_all_weight_values)

    def _get_name(self):
        return 'DynamicQuantizedRNN'

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        return s.format(**self.__dict__)

    def __repr__(self):
        # We don't want to show `ModuleList` children, hence custom
        # `__repr__`. This is the same as nn.Module.__repr__, except the check
        # for the `PackedParameter` and `nn.ModuleList`.
        # You should still override `extra_repr` to add more info.
        extra_lines = []
        extra_repr = self.extra_repr()
        # empty string will be split into list ['']
        if extra_repr:
            extra_lines = extra_repr.split('\n')
        child_lines = []
        for key, module in self._modules.items():
            if isinstance(module, (PackedParameter, nn.ModuleList)):
                continue
            mod_str = repr(module)
            mod_str = nn.modules.module._addindent(mod_str, 2)
            child_lines.append('(' + key + '): ' + mod_str)
        lines = extra_lines + child_lines

        main_str = self._get_name() + '('
        if lines:
            # simple one-liner info, which most builtin Modules will use
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'

        main_str += ')'
        return main_str

    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
        expected_input_dim = 2 if batch_sizes is not None else 3
        if input.dim() != expected_input_dim:
            raise RuntimeError(
                'input must have {} dimensions, got {}'.format(
                    expected_input_dim, input.dim()))
        if self.input_size != input.size(-1):
            raise RuntimeError(
                'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                    self.input_size, input.size(-1)))

    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
        if batch_sizes is not None:
            mini_batch = int(batch_sizes[0])
        else:
            mini_batch = input.size(0) if self.batch_first else input.size(1)
        num_directions = 2 if self.bidirectional else 1
        expected_hidden_size = (self.num_layers * num_directions,
                                mini_batch, self.hidden_size)
        return expected_hidden_size

    def check_hidden_size(self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
                          msg: str = 'Expected hidden size {}, got {}') -> None:
        if hx.size() != expected_hidden_size:
            raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))

    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
        self.check_hidden_size(hidden, expected_hidden_size,
                               msg='Expected hidden size {}, got {}')

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        version = local_metadata.get('version', None)
        self.version = version
        super(RNNBase, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                   missing_keys, unexpected_keys, error_msgs)

    @classmethod
    def from_float(cls, mod):
        assert type(mod) in set(
            [torch.nn.LSTM, torch.nn.GRU]
        ), 'nn.quantized.dynamic.RNNBase.from_float only works for nn.LSTM and nn.GRU'
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.qconfig import default_dynamic_qconfig
            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
        # RNNBase can be either LSTM or GRU
        qRNNBase: Union[LSTM, GRU]
        if mod.mode == 'LSTM':
            qRNNBase = LSTM(mod.input_size, mod.hidden_size, mod.num_layers,
                            mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
        elif mod.mode == 'GRU':
            qRNNBase = GRU(mod.input_size, mod.hidden_size, mod.num_layers,
                           mod.bias, mod.batch_first, mod.dropout, mod.bidirectional, dtype)
        else:
            raise NotImplementedError('Only LSTM/GRU is supported for QuantizedRNN for now')

        num_directions = 2 if mod.bidirectional else 1

        assert mod.bias

        _all_weight_values = []
        for layer in range(qRNNBase.num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''

                def retrieve_weight_bias(ihhh):
                    weight_name = 'weight_{}_l{}{}'.format(ihhh, layer, suffix)
                    bias_name = 'bias_{}_l{}{}'.format(ihhh, layer, suffix)
                    weight = getattr(mod, weight_name)
                    bias = getattr(mod, bias_name)
                    return weight, bias

                weight_ih, bias_ih = retrieve_weight_bias('ih')
                weight_hh, bias_hh = retrieve_weight_bias('hh')

                if dtype == torch.qint8:
                    def quantize_and_pack(w, b):
                        weight_observer = weight_observer_method()
                        weight_observer(w)
                        qweight = _quantize_weight(w.float(), weight_observer)
                        packed_weight = \
                            torch.ops.quantized.linear_prepack(qweight, b)
                        return packed_weight
                    packed_ih = quantize_and_pack(weight_ih, bias_ih)
                    packed_hh = quantize_and_pack(weight_hh, bias_hh)
                    if qRNNBase.version is None or qRNNBase.version < 2:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, bias_ih, bias_hh)
                    else:
                        cell_params = torch.ops.quantized.make_quantized_cell_params_dynamic(
                            packed_ih, packed_hh, bias_ih, bias_hh, True)

                elif dtype == torch.float16:
                    packed_ih = torch.ops.quantized.linear_prepack_fp16(
                        weight_ih.float(), bias_ih)
                    packed_hh = torch.ops.quantized.linear_prepack_fp16(
                        weight_hh.float(), bias_hh)

                    cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
                        packed_ih, packed_hh)
                else:
                    raise RuntimeError('Unsupported dtype specified for dynamic quantized LSTM!')

                _all_weight_values.append(PackedParameter(cell_params))
        qRNNBase._all_weight_values = torch.nn.ModuleList(_all_weight_values)

        return qRNNBase

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {'weight': {}, 'bias': {}}
        count = 0
        num_directions = 2 if self.bidirectional else 1
        for layer in range(self.num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                key_name1 = 'weight_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                key_name2 = 'weight_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                # packed weights are part of torchbind class, CellParamsSerializationType
                # Within the packed weight class, the weight and bias are accessible as Tensors
                packed_weight_bias = self._all_weight_values[count].param.__getstate__()[0][4]
                weight_bias_dict['weight'][key_name1] = packed_weight_bias[0].__getstate__()[0][0]
                weight_bias_dict['weight'][key_name2] = packed_weight_bias[1].__getstate__()[0][0]
                key_name1 = 'bias_ih_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                key_name2 = 'bias_hh_l{layer_idx}{suffix}'.format(layer_idx=layer, suffix=suffix)
                weight_bias_dict['bias'][key_name1] = packed_weight_bias[0].__getstate__()[0][1]
                weight_bias_dict['bias'][key_name2] = packed_weight_bias[1].__getstate__()[0][1]
                count = count + 1
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()['weight']

    def get_bias(self):
        return self._weight_bias()['bias']

class LSTM(RNNBase):
    r"""
    A dynamic quantized LSTM module with floating point tensor as inputs and outputs.
    We adopt the same interface as `torch.nn.LSTM`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.LSTM for documentation.

    Examples::

        >>> rnn = nn.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
    """
    _FLOAT_MODULE = nn.LSTM

    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    def __init__(self, *args, **kwargs):
        super(LSTM, self).__init__('LSTM', *args, **kwargs)

    def _get_name(self):
        return 'DynamicQuantizedLSTM'

    def forward_impl(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
                     batch_sizes: Optional[Tensor], max_batch_size: int,
                     sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = ([m.param for m in self._all_weight_values])
        if batch_sizes is None:
            result = torch.quantized_lstm(input, hx, _all_params, self.bias, self.num_layers,
                                          float(self.dropout), self.training, self.bidirectional,
                                          self.batch_first, dtype=self.dtype, use_dynamic=True)
        else:
            result = torch.quantized_lstm(input, batch_sizes, hx, _all_params, self.bias,
                                          self.num_layers, float(self.dropout), self.training,
                                          self.bidirectional, dtype=self.dtype, use_dynamic=True)
        output = result[0]
        hidden = result[1:]

        return output, hidden

    @torch.jit.export
    def forward_tensor(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
                       ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices)

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
                       ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)

        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices)

        output = PackedSequence(output_, batch_sizes,
                                sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    # "type: ignore" is required due to issue #43072
    def permute_hidden(  # type: ignore[override]
        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        if permutation is None:
            return hx
        return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)

    # "type: ignore" is required due to issue #43072
    def check_forward_args(  # type: ignore[override]
        self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]
    ) -> None:
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden[0], expected_hidden_size,
                               'Expected hidden[0] size {}, got {}')
        self.check_hidden_size(hidden[1], expected_hidden_size,
                               'Expected hidden[1] size {}, got {}')

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod):
        return super(LSTM, cls).from_float(mod)
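
# --- Illustrative usage sketch (not part of the original module source) ---
# A DynamicQuantizedLSTM is normally produced by converting a float model with
# torch.quantization.quantize_dynamic, which routes through LSTM.from_float above.
# The wrapper class `Seq` and the shapes below are example values, not anything
# defined in this file.
#
#     import torch
#     import torch.nn as nn
#
#     class Seq(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.lstm = nn.LSTM(10, 20, num_layers=2)
#
#         def forward(self, x, hidden=None):
#             return self.lstm(x, hidden)
#
#     model = torch.quantization.quantize_dynamic(Seq(), {nn.LSTM}, dtype=torch.qint8)
#     print(model.lstm)                              # DynamicQuantizedLSTM(10, 20, num_layers=2)
#     out, (hn, cn) = model(torch.randn(5, 3, 10))   # (seq_len, batch, input_size)
#     # get_weight() above unpacks the qint8 weights, keyed 'weight_ih_l0', 'weight_hh_l0', ...
#     print(model.lstm.get_weight().keys())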

class GRU(RNNBase):
    r"""Applies a multi-layer gated recurrent unit (GRU) RNN to an input sequence.

    For each element in the input sequence, each layer computes the following
    function:

    .. math::
        \begin{array}{ll}
            r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
            z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
            n_t = \tanh(W_{in} x_t + b_{in} + r_t * (W_{hn} h_{(t-1)} + b_{hn})) \\
            h_t = (1 - z_t) * n_t + z_t * h_{(t-1)}
        \end{array}

    where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input
    at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer
    at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`,
    :math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively.
    :math:`\sigma` is the sigmoid function, and :math:`*` is the Hadamard product.

    In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer
    (:math:`l >= 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by
    dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random
    variable which is :math:`0` with probability :attr:`dropout`.

    Args:
        input_size: The number of expected features in the input `x`
        hidden_size: The number of features in the hidden state `h`
        num_layers: Number of recurrent layers. E.g., setting ``num_layers=2``
            would mean stacking two GRUs together to form a `stacked GRU`,
            with the second GRU taking in outputs of the first GRU and
            computing the final results. Default: 1
        bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`.
            Default: ``True``
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``
        dropout: If non-zero, introduces a `Dropout` layer on the outputs of each
            GRU layer except the last layer, with dropout probability equal to
            :attr:`dropout`. Default: 0
        bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False``

    Inputs: input, h_0
        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
          of the input sequence. The input can also be a packed variable length
          sequence. See :func:`torch.nn.utils.rnn.pack_padded_sequence`
          for details.
        - **h_0** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the initial hidden state for each element in the batch.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.

    Outputs: output, h_n
        - **output** of shape `(seq_len, batch, num_directions * hidden_size)`: tensor
          containing the output features h_t from the last layer of the GRU,
          for each `t`. If a :class:`torch.nn.utils.rnn.PackedSequence` has been
          given as the input, the output will also be a packed sequence.
          For the unpacked case, the directions can be separated
          using ``output.view(seq_len, batch, num_directions, hidden_size)``,
          with forward and backward being direction `0` and `1` respectively.

          Similarly, the directions can be separated in the packed case.
        - **h_n** of shape `(num_layers * num_directions, batch, hidden_size)`: tensor
          containing the hidden state for `t = seq_len`

          Like *output*, the layers can be separated using
          ``h_n.view(num_layers, num_directions, batch, hidden_size)``.

    Shape:
        - Input1: :math:`(L, N, H_{in})` tensor containing input features where
          :math:`H_{in}=\text{input\_size}` and `L` represents a sequence length.
        - Input2: :math:`(S, N, H_{out})` tensor containing the initial hidden
          state for each element in the batch, where
          :math:`H_{out}=\text{hidden\_size}` and
          :math:`S=\text{num\_layers} * \text{num\_directions}`.
          Defaults to zero if not provided. If the RNN is bidirectional,
          num_directions should be 2, else it should be 1.
        - Output1: :math:`(L, N, H_{all})` where :math:`H_{all}=\text{num\_directions} * \text{hidden\_size}`
        - Output2: :math:`(S, N, H_{out})` tensor containing the next hidden state
          for each element in the batch

    Attributes:
        weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer
            (W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`.
            Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)`
        weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer
            (W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)`
        bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer
            (b_ir|b_iz|b_in), of shape `(3*hidden_size)`
        bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer
            (b_hr|b_hz|b_hn), of shape `(3*hidden_size)`

    .. note::
        All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`
        where :math:`k = \frac{1}{\text{hidden\_size}}`

    .. include:: cudnn_persistent_rnn.rst

    Examples::

        >>> rnn = nn.GRU(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> output, hn = rnn(input, h0)
    """
    _FLOAT_MODULE = nn.GRU

    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    def __init__(self, *args, **kwargs):
        super(GRU, self).__init__('GRU', *args, **kwargs)

    def _get_name(self):
        return 'DynamicQuantizedGRU'

    def check_forward_args(self, input, hidden, batch_sizes):
        # type: (Tensor, Tensor, Optional[Tensor]) -> None
        self.check_input(input, batch_sizes)
        expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)

        self.check_hidden_size(hidden, expected_hidden_size,
                               'Expected hidden size {}, got {}')

    def forward_impl(self, input: Tensor, hx: Optional[Tensor],
                     batch_sizes: Optional[Tensor], max_batch_size: int,
                     sorted_indices: Optional[Tensor]) -> Tuple[Tensor, Tensor]:
        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = torch.zeros(self.num_layers * num_directions,
                                max_batch_size, self.hidden_size,
                                dtype=input.dtype, device=input.device)
            hx = zeros
        else:
            # Each batch of the hidden state should match the input sequence that
            # the user believes he/she is passing in.
            hx = self.permute_hidden(hx, sorted_indices)

        self.check_forward_args(input, hx, batch_sizes)

        _all_params = ([m.param for m in self._all_weight_values])
        if batch_sizes is None:
            result = torch.quantized_gru(input, hx, _all_params, self.bias, self.num_layers,
                                         self.dropout, self.training, self.bidirectional,
                                         self.batch_first)
        else:
            result = torch.quantized_gru(input, batch_sizes, hx, _all_params, self.bias,
                                         self.num_layers, self.dropout, self.training,
                                         self.bidirectional)
        output = result[0]
        hidden = result[1]

        return output, hidden

    @torch.jit.export
    def forward_tensor(self, input: Tensor, hx: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        batch_sizes = None
        max_batch_size = input.size(0) if self.batch_first else input.size(1)
        sorted_indices = None
        unsorted_indices = None

        output, hidden = self.forward_impl(
            input, hx, batch_sizes, max_batch_size, sorted_indices)

        return output, self.permute_hidden(hidden, unsorted_indices)

    @torch.jit.export
    def forward_packed(self, input: PackedSequence, hx: Optional[Tensor] = None
                       ) -> Tuple[PackedSequence, Tensor]:
        input_, batch_sizes, sorted_indices, unsorted_indices = input
        max_batch_size = batch_sizes[0]
        max_batch_size = int(max_batch_size)
        output_, hidden = self.forward_impl(
            input_, hx, batch_sizes, max_batch_size, sorted_indices)

        output = PackedSequence(output_, batch_sizes,
                                sorted_indices, unsorted_indices)
        return output, self.permute_hidden(hidden, unsorted_indices)

    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        if permutation is None:
            return hx
        return apply_permutation(hx, permutation)

    @torch.jit.ignore
    def forward(self, input, hx=None):
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        else:
            return self.forward_tensor(input, hx)

    @classmethod
    def from_float(cls, mod):
        return super(GRU, cls).from_float(mod)


class RNNCellBase(torch.nn.Module):
    # _FLOAT_MODULE = nn.CellRNNBase
    __constants__ = ['input_size', 'hidden_size', 'bias']

    def __init__(self, input_size, hidden_size, bias=True, num_chunks=4, dtype=torch.qint8):
        super(RNNCellBase, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        if bias:
            self.bias_ih = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
            self.bias_hh = torch.randn(num_chunks * hidden_size).to(dtype=torch.float)
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)

        weight_ih = torch.randn(num_chunks * hidden_size, input_size).to(torch.float)
        weight_hh = torch.randn(num_chunks * hidden_size, hidden_size).to(torch.float)
        if dtype == torch.qint8:
            weight_ih = torch.quantize_per_tensor(weight_ih, scale=1, zero_point=0, dtype=torch.qint8)
            weight_hh = torch.quantize_per_tensor(weight_hh, scale=1, zero_point=0, dtype=torch.qint8)

        if dtype == torch.qint8:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   w_ih, w_hh
            packed_weight_ih = \
                torch.ops.quantized.linear_prepack(weight_ih, self.bias_ih)
            packed_weight_hh = \
                torch.ops.quantized.linear_prepack(weight_hh, self.bias_hh)
        else:
            # for each layer, for each direction we need to quantize and pack
            # weights and pack parameters in this order:
            #
            #   packed_ih, packed_hh, b_ih, b_hh
            packed_weight_ih = torch.ops.quantized.linear_prepack_fp16(
                weight_ih, self.bias_ih)
            packed_weight_hh = torch.ops.quantized.linear_prepack_fp16(
                weight_hh, self.bias_hh)

        self._packed_weight_ih = packed_weight_ih
        self._packed_weight_hh = packed_weight_hh

    def _get_name(self):
        return 'DynamicQuantizedRNNBase'

    def extra_repr(self):
        s = '{input_size}, {hidden_size}'
        if 'bias' in self.__dict__ and self.bias is not True:
            s += ', bias={bias}'
        if 'nonlinearity' in self.__dict__ and self.nonlinearity != "tanh":
            s += ', nonlinearity={nonlinearity}'
        return s.format(**self.__dict__)

    def check_forward_input(self, input):
        if input.size(1) != self.input_size:
            raise RuntimeError(
                "input has inconsistent input_size: got {}, expected {}".format(
                    input.size(1), self.input_size))

    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
        if input.size(0) != hx.size(0):
            raise RuntimeError(
                "Input batch size {} doesn't match hidden{} batch size {}".format(
                    input.size(0), hidden_label, hx.size(0)))

        if hx.size(1) != self.hidden_size:
            raise RuntimeError(
                "hidden{} has inconsistent hidden_size: got {}, expected {}".format(
                    hidden_label, hx.size(1), self.hidden_size))

    @classmethod
    def from_float(cls, mod):
        assert type(mod) in set([torch.nn.LSTMCell,
                                 torch.nn.GRUCell,
                                 torch.nn.RNNCell]), \
            'nn.quantized.dynamic.RNNCellBase.from_float only works for nn.LSTMCell, nn.GRUCell and nn.RNNCell'
        assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'

        if mod.qconfig is not None and mod.qconfig.weight is not None:
            weight_observer_method = mod.qconfig.weight
        else:
            # We have the circular import issues if we import the qconfig in the beginning of this file:
            # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
            # import until we need it.
            from torch.quantization.qconfig import default_dynamic_qconfig
            weight_observer_method = default_dynamic_qconfig.weight

        dtype = weight_observer_method().dtype
        supported_scalar_types = [torch.qint8, torch.float16]
        if dtype not in supported_scalar_types:
            raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))

        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]

        if type(mod) == torch.nn.LSTMCell:
            qRNNCellBase = LSTMCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
        elif type(mod) == torch.nn.GRUCell:
            qRNNCellBase = GRUCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
        elif type(mod) == torch.nn.RNNCell:
            qRNNCellBase = RNNCell(mod.input_size, mod.hidden_size, bias=mod.bias,
                                   nonlinearity=mod.nonlinearity, dtype=dtype)
        else:
            raise NotImplementedError('Only LSTMCell, GRUCell and RNNCell '
                                      'are supported for QuantizedRNN for now')

        assert mod.bias

        def process_weights(weight, bias, dtype):
            if dtype == torch.qint8:
                # for each layer, for each direction we need to quantize and pack
                # weights and pack parameters in this order:
                #
                #   w_ih, w_hh
                weight_observer = weight_observer_method()
                weight_observer(weight)
                qweight = _quantize_weight(weight.float(), weight_observer)
                packed_weight = \
                    torch.ops.quantized.linear_prepack(qweight, bias)

                return packed_weight
            else:
                # for each layer, for each direction we need to quantize and pack
                # weights and pack parameters in this order:
                #
                #   packed_ih, packed_hh, b_ih, b_hh
                packed_weight = torch.ops.quantized.linear_prepack_fp16(
                    weight.float(), bias)

                return packed_weight

        qRNNCellBase._packed_weight_ih = process_weights(mod.weight_ih, mod.bias_ih, dtype)
        qRNNCellBase._packed_weight_hh = process_weights(mod.weight_hh, mod.bias_hh, dtype)
        return qRNNCellBase

    def _weight_bias(self):
        # Returns a dict of weights and biases
        weight_bias_dict: Dict[str, Dict] = {'weight': {}, 'bias': {}}
        w1, b1 = self._packed_weight_ih.__getstate__()[0]
        w2, b2 = self._packed_weight_hh.__getstate__()[0]
        weight_bias_dict['weight']['weight_ih'] = w1
        weight_bias_dict['weight']['weight_hh'] = w2
        weight_bias_dict['bias']['bias_ih'] = b1
        weight_bias_dict['bias']['bias_hh'] = b2
        return weight_bias_dict

    def get_weight(self):
        return self._weight_bias()['weight']

    def get_bias(self):
        return self._weight_bias()['bias']

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super(RNNCellBase, self)._save_to_state_dict(destination, prefix, keep_vars)
        destination[prefix + '_packed_weight_ih'] = self._packed_weight_ih
        destination[prefix + '_packed_weight_hh'] = self._packed_weight_hh

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self._packed_weight_ih = state_dict.pop(prefix + '_packed_weight_ih')
        self._packed_weight_hh = state_dict.pop(prefix + '_packed_weight_hh')
        super(RNNCellBase, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                       missing_keys, unexpected_keys, error_msgs)

class RNNCell(RNNCellBase):
    r"""An Elman RNN cell with tanh or ReLU non-linearity.
    A dynamic quantized RNNCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.RNNCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.RNNCell for documentation.

    Examples::

        >>> rnn = nn.RNNCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
                hx = rnn(input[i], hx)
                output.append(hx)
    """
    __constants__ = ['input_size', 'hidden_size', 'bias', 'nonlinearity']

    def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8):
        super(RNNCell, self).__init__(input_size, hidden_size, bias, num_chunks=1, dtype=dtype)
        self.nonlinearity = nonlinearity

    def _get_name(self):
        return 'DynamicQuantizedRNNCell'

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        if self.nonlinearity == "tanh":
            ret = torch.ops.quantized.quantized_rnn_tanh_cell_dynamic(
                input, hx,
                self._packed_weight_ih, self._packed_weight_hh,
                self.bias_ih, self.bias_hh)
        elif self.nonlinearity == "relu":
            ret = torch.ops.quantized.quantized_rnn_relu_cell_dynamic(
                input, hx,
                self._packed_weight_ih, self._packed_weight_hh,
                self.bias_ih, self.bias_hh)
        else:
            ret = input  # TODO: remove when jit supports exception flow
            raise RuntimeError(
                "Unknown nonlinearity: {}".format(self.nonlinearity))
        return ret

    @classmethod
    def from_float(cls, mod):
        return super(RNNCell, cls).from_float(mod)

class LSTMCell(RNNCellBase):
    r"""A long short-term memory (LSTM) cell.

    A dynamic quantized LSTMCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.LSTMCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.LSTMCell for documentation.

    Examples::

        >>> rnn = nn.LSTMCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
                hx, cx = rnn(input[i], (hx, cx))
                output.append(hx)
    """

    def __init__(self, *args, **kwargs):
        super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs)  # type: ignore[misc]

    def _get_name(self):
        return 'DynamicQuantizedLSTMCell'

    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        self.check_forward_input(input)
        if hx is None:
            zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
            hx = (zeros, zeros)
        self.check_forward_hidden(input, hx[0], '[0]')
        self.check_forward_hidden(input, hx[1], '[1]')
        return torch.ops.quantized.quantized_lstm_cell_dynamic(
            input, hx,
            self._packed_weight_ih, self._packed_weight_hh,
            self.bias_ih, self.bias_hh)

    @classmethod
    def from_float(cls, mod):
        return super(LSTMCell, cls).from_float(mod)

class GRUCell(RNNCellBase):
    r"""A gated recurrent unit (GRU) cell

    A dynamic quantized GRUCell module with floating point tensor as inputs and outputs.
    Weights are quantized to 8 bits. We adopt the same interface as `torch.nn.GRUCell`,
    please see https://pytorch.org/docs/stable/nn.html#torch.nn.GRUCell for documentation.

    Examples::

        >>> rnn = nn.GRUCell(10, 20)
        >>> input = torch.randn(6, 3, 10)
        >>> hx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
                hx = rnn(input[i], hx)
                output.append(hx)
    """

    def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
        super(GRUCell, self).__init__(input_size, hidden_size, bias, num_chunks=3, dtype=dtype)

    def _get_name(self):
        return 'DynamicQuantizedGRUCell'

    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
        self.check_forward_input(input)
        if hx is None:
            hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
        self.check_forward_hidden(input, hx, '')
        return torch.ops.quantized.quantized_gru_cell_dynamic(
            input, hx,
            self._packed_weight_ih, self._packed_weight_hh,
            self.bias_ih, self.bias_hh,
        )

    @classmethod
    def from_float(cls, mod):
        return super(GRUCell, cls).from_float(mod)
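
# --- Illustrative usage sketch (not part of the original module source) ---
# The cell variants are obtained the same way, either through
# torch.quantization.quantize_dynamic or by calling from_float directly once a
# qconfig is attached (RNNCellBase.from_float above asserts that `qconfig` is defined).
# `float_cell` and the shapes below are example values.
#
#     import torch
#     import torch.nn as nn
#     import torch.nn.quantized.dynamic as nnqd
#     from torch.quantization import default_dynamic_qconfig
#
#     float_cell = nn.LSTMCell(10, 20)
#     float_cell.qconfig = default_dynamic_qconfig
#     q_cell = nnqd.LSTMCell.from_float(float_cell)   # 'DynamicQuantizedLSTMCell'
#
#     hx = torch.zeros(3, 20)
#     cx = torch.zeros(3, 20)
#     hx, cx = q_cell(torch.randn(3, 10), (hx, cx))   # same interface as nn.LSTMCell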