from collections import namedtuple
import torch

# Two-field tuple base; the subclass below only adds documentation.
PackedSequence_ = namedtuple('PackedSequence', 'data batch_sizes')


class PackedSequence(PackedSequence_):
    r"""Holds the data and list of batch_sizes of a packed sequence.

    All RNN modules accept packed sequences as inputs.

    Note:
        Instances of this class should never be created manually. They are
        meant to be instantiated by functions like
        :func:`pack_padded_sequence`.

    Attributes:
        data (Variable): Variable containing packed sequence
        batch_sizes (list[int]): list of integers holding information about
            the batch size at each sequence step
    """

def pack_padded_sequence(input, lengths, batch_first=False):
    r"""Packs a Variable containing padded sequences of variable length.

    Input can be of size ``TxBx*`` where T is the length of the longest
    sequence (equal to ``lengths[0]``), B is the batch size, and * is any
    number of dimensions (including 0). If ``batch_first`` is True, ``BxTx*``
    inputs are expected.

    The sequences should be sorted by length in a decreasing order, i.e.
    ``input[:,0]`` should be the longest sequence, and ``input[:,B-1]`` the
    shortest one.

    Note:
        This function accepts any input that has at least two dimensions.
        You can apply it to pack the labels, and use the output of the RNN
        with them to compute the loss directly. A Variable can be retrieved
        from a :class:`PackedSequence` object by accessing its ``.data``
        attribute.

    Arguments:
        input (Variable): padded batch of variable length sequences.
        lengths (list[int]): list of sequences lengths of each batch element.
        batch_first (bool, optional): if True, the input is expected in
            ``BxTx*`` format.

    Returns:
        a :class:`PackedSequence` object

    Raises:
        ValueError: if ``lengths`` is empty, contains a non-positive length,
            does not match the batch size, exceeds the number of time steps
            in ``input``, or is not sorted in decreasing order.
    """
    if len(lengths) == 0:
        raise ValueError("lengths array has to contain at least one element")
    # lengths is expected to be sorted decreasingly, so the last element is
    # the smallest; unsorted input is rejected inside the loop below.
    if lengths[-1] <= 0:
        raise ValueError("length of all samples has to be greater than 0, "
                         "but found an element in 'lengths' that is <=0")
    if batch_first:
        input = input.transpose(0, 1)

    batch_size = input.size(1)
    if len(lengths) != batch_size:
        raise ValueError("lengths array has incorrect size")
    if lengths[0] > input.size(0):
        # Slicing past the end of the time axis would silently truncate the
        # data while batch_sizes still counted the full length.
        raise ValueError("longest length ({}) exceeds the number of time "
                         "steps in input ({})".format(lengths[0], input.size(0)))

    steps = []
    batch_sizes = []
    trailing_dims = input.size()[2:]
    prev_l = 0
    # Walk lengths from shortest to longest. When the length grows from
    # prev_l to l, time steps prev_l..l-1 are present in exactly the first
    # batch_size - i sequences (the i shortest ones have already ended).
    for i, l in enumerate(reversed(lengths)):
        if l > prev_l:
            c_batch_size = batch_size - i
            steps.append(input[prev_l:l, :c_batch_size].contiguous().view(-1, *trailing_dims))
            batch_sizes.extend([c_batch_size] * (l - prev_l))
            prev_l = l
        elif prev_l > l:
            # l sits earlier in `lengths` than prev_l, so it must not be shorter.
            raise ValueError("lengths array has to be sorted in decreasing order")

    return PackedSequence(torch.cat(steps), batch_sizes)

def pad_packed_sequence(sequence, batch_first=False):
    r"""Pads a packed batch of variable length sequences.

    It is an inverse operation to :func:`pack_padded_sequence`.

    The returned Variable's data will be of size ``TxBx*``, where T is the
    length of the longest sequence and B is the batch size. If
    ``batch_first`` is True, the data will be transposed into ``BxTx*``
    format.

    Batch elements will be ordered decreasingly by their length.

    Arguments:
        sequence (PackedSequence): batch to pad. It is only unpacked into
            ``(data, batch_sizes)``, so any such pair is accepted.
        batch_first (bool, optional): if True, the output will be in
            ``BxTx*`` format.

    Returns:
        Tuple of Variable containing the padded sequence, and a list of
        lengths of each sequence in the batch.
    """
    var_data, batch_sizes = sequence
    max_batch_size = batch_sizes[0]
    trailing_dims = tuple(var_data.size()[1:])
    # Zero-filled buffer with var_data's dtype/device. Only the valid
    # (step, batch) cells are written below, so padding stays zero.
    output = var_data.new_zeros((len(batch_sizes), max_batch_size) + trailing_dims)

    lengths = []
    data_offset = 0
    prev_batch_size = batch_sizes[0]
    prev_i = 0
    for i, batch_size in enumerate(batch_sizes):
        if batch_size != prev_batch_size:
            # Steps prev_i..i-1 share prev_batch_size, so their rows form one
            # contiguous run in var_data; reshape it to (steps, batch, *).
            step_count = i - prev_i
            run_len = prev_batch_size * step_count
            output[prev_i:i, :prev_batch_size] = var_data[data_offset:data_offset + run_len].view(
                step_count, prev_batch_size, *trailing_dims)
            data_offset += run_len
            prev_i = i
            # A drop of `dec` in batch size means `dec` sequences ended at step i.
            dec = prev_batch_size - batch_size
            if dec > 0:
                lengths.extend((i,) * dec)
            prev_batch_size = batch_size

    # Copy the final run, which extends to the last time step.
    step_count = len(batch_sizes) - prev_i
    run_len = prev_batch_size * step_count
    output[prev_i:, :prev_batch_size] = var_data[data_offset:data_offset + run_len].view(
        step_count, prev_batch_size, *trailing_dims)

    # The sequences still present at the last step span every time step.
    lengths.extend((len(batch_sizes),) * prev_batch_size)
    # Lengths were collected shortest-first; report them longest-first.
    lengths.reverse()

    if batch_first:
        output = output.transpose(0, 1)
    return output, lengths