Source code for torch.ao.nn.quantizable.modules.rnn
import numbers
from typing import Optional, Tuple
import warnings

import torch
from torch import Tensor

"""
We will recreate all the RNN modules as we require the modules to be decomposed
into its building blocks to be able to observe.
"""

__all__ = [
    "LSTMCell",
    "LSTM",
]


class LSTMCell(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM) cell.

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)
        >>> input = torch.randn(6, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias

        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
        self.cell_gate = torch.nn.Tanh()
        self.output_gate = torch.nn.Sigmoid()

        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

        self.initial_hidden_state_qparams: Tuple[float, int] = (1.0, 0)
        self.initial_cell_state_qparams: Tuple[float, int] = (1.0, 0)
        self.hidden_state_dtype: torch.dtype = torch.quint8
        self.cell_state_dtype: torch.dtype = torch.quint8

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        if hidden is None or hidden[0] is None or hidden[1] is None:
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx, cx = hidden

        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = self.input_gate(input_gate)
        forget_gate = self.forget_gate(forget_gate)
        cell_gate = self.cell_gate(cell_gate)
        out_gate = self.output_gate(out_gate)

        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        # TODO: make this tanh a member of the module so its qparams can be configured
        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
        h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
        if is_quantized:
            (h_scale, h_zp) = self.initial_hidden_state_qparams
            (c_scale, c_zp) = self.initial_cell_state_qparams
            h = torch.quantize_per_tensor(h, scale=h_scale, zero_point=h_zp, dtype=self.hidden_state_dtype)
            c = torch.quantize_per_tensor(c, scale=c_scale, zero_point=c_zp, dtype=self.cell_state_dtype)
        return h, c

    def _get_name(self):
        return 'QuantizableLSTMCell'

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.
        Args:
            wi, wh: Weights for the input and hidden layers
            bi, bh: Biases for the input and hidden layers
        """
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_dim=input_size, hidden_dim=hidden_size,
                   bias=(bi is not None))
        cell.igates.weight = torch.nn.Parameter(wi)
        if bi is not None:
            cell.igates.bias = torch.nn.Parameter(bi)
        cell.hgates.weight = torch.nn.Parameter(wh)
        if bh is not None:
            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        observed = cls.from_params(other.weight_ih, other.weight_hh,
                                   other.bias_ih, other.bias_hh)
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        return observed


class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    The difference between a layer and a cell is that the layer can process a
    sequence, while the cell only expects an instantaneous value.
    """
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        result = []
        seq_len = x.shape[0]
        for i in range(seq_len):
            hidden = self.cell(x[i], hidden)
            result.append(hidden[0])  # type: ignore[index]
        result_tensor = torch.stack(result, 0)
        return result_tensor, hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = LSTMCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer


class _LSTMLayer(torch.nn.Module):
    r"""A single bi-directional LSTM layer."""
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 batch_first: bool = False, bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
        if self.bidirectional:
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, 'layer_bw') and self.bidirectional:
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of
        this class. This function is here just to mimic the behavior of the
        `prepare` within the `torch.ao.quantization` flow.
        """
        assert hasattr(other, 'qconfig') or (qconfig is not None)

        input_size = kwargs.get('input_size', other.input_size)
        hidden_size = kwargs.get('hidden_size', other.hidden_size)
        bias = kwargs.get('bias', other.bias)
        batch_first = kwargs.get('batch_first', other.batch_first)
        bidirectional = kwargs.get('bidirectional', other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, 'qconfig', qconfig)
        wi = getattr(other, f'weight_ih_l{layer_idx}')
        wh = getattr(other, f'weight_hh_l{layer_idx}')
        bi = getattr(other, f'bias_ih_l{layer_idx}', None)
        bh = getattr(other, f'bias_hh_l{layer_idx}', None)

        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
            wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
            bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
            bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer
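
# Illustrative sketch (hypothetical variable names, shown here only as a comment):
# `_LSTMLayer.from_float` pulls the per-layer parameters (`weight_ih_l{idx}`,
# `weight_hh_l{idx}` and the optional biases) out of a float `nn.LSTM` and wraps
# them into observable single-direction layers, e.g.:
#
#     float_lstm = torch.nn.LSTM(10, 20, num_layers=2)
#     float_lstm.qconfig = torch.ao.quantization.default_qconfig
#     layer0 = _LSTMLayer.from_float(float_lstm, layer_idx=0)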


class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].weight_ih)
        tensor([[...]])
        >>> print(rnn.layers[0].weight_hh)
        AssertionError: There is no reverse path in the non-bidirectional layer
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True,
                 batch_first: bool = False, dropout: float = 0.,
                 bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False  # We don't want to train using this module
        num_directions = 2 if bidirectional else 1

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0:
            warnings.warn("dropout option for quantizable LSTM is ignored. "
                          "If you are training, please, use nn.LSTM version "
                          "followed by `prepare` step.")
            if num_layers == 1:
                warnings.warn("dropout option adds dropout after all but last "
                              "recurrent layer, so non-zero dropout expects "
                              f"num_layers greater than 1, but got dropout={dropout} "
                              f"and num_layers={num_layers}")

        layers = [_LSTMLayer(self.input_size, self.hidden_size,
                             self.bias, batch_first=False,
                             bidirectional=self.bidirectional,
                             **factory_kwargs)]
        for layer in range(1, num_layers):
            layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
                                     self.bias, batch_first=False,
                                     bidirectional=self.bidirectional,
                                     **factory_kwargs))
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(num_directions, max_batch_size,
                                self.hidden_size, dtype=torch.float,
                                device=x.device)
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(zeros, scale=1.0,
                                                  zero_point=0, dtype=x.dtype)
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size)
                cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size)
                hxcx = [(hx[idx].squeeze(0), cx[idx].squeeze(0)) for idx in range(self.num_layers)]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # We are creating another dimension for bidirectional case
        # need to collapse it
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return 'QuantizableLSTM'

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert (hasattr(other, 'qconfig') or qconfig)
        observed = cls(other.input_size, other.hidden_size, other.num_layers,
                       other.bias, other.batch_first,
                       other.dropout, other.bidirectional)
        observed.qconfig = getattr(other, 'qconfig', qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
                                                         batch_first=False)

        # TODO: Remove setting observed to eval to enable QAT.
        observed.eval()
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-quantizable LSTM module. Please, see "
                                  "the examples on quantizable LSTMs.")
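
A minimal sketch of the intended float -> observed -> quantized flow, assuming a standalone float ``torch.nn.LSTM``, the default eager-mode qconfig, and that the default static module mappings pair this class with ``torch.ao.nn.quantized.LSTM``; the variable names below are illustrative only::

    >>> import torch
    >>> import torch.ao.nn.quantizable as nnqa
    >>> float_lstm = torch.nn.LSTM(10, 20, num_layers=2)
    >>> float_lstm.qconfig = torch.ao.quantization.default_qconfig
    >>> # `from_float` rebuilds the LSTM out of observable layers and runs `prepare`
    >>> observed = nnqa.LSTM.from_float(float_lstm)
    >>> _ = observed(torch.randn(5, 3, 10))  # calibration pass through the observers
    >>> quantized = torch.ao.quantization.convert(observed)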