Source code for torchtext.datasets.babi
import os
from io import open
import torch
from ..data import Dataset, Field, Example, Iterator
class BABI20Field(Field):
    """Field handling both flat token lists and bAbI stories.

    Queries and answers are flat lists of tokens and are delegated unchanged
    to the parent ``Field``. Stories are lists of tokenized sentences; they
    are padded to exactly ``memory_size`` sentences per example and, when
    numericalized, each story is numericalized by the parent ``Field`` and
    the per-story tensors are combined with ``torch.stack``.
    """

    def __init__(self, memory_size, **kwargs):
        super(BABI20Field, self).__init__(**kwargs)
        # number of sentences kept (and padded to) per story
        self.memory_size = memory_size
        # disable the parent Field's unk token for this field
        self.unk_token = None
        self.batch_first = True

    @staticmethod
    def _is_nested(batch):
        """Return True if ``batch`` holds stories (lists of sentence lists).

        The first non-empty example decides; an empty example (e.g. a story
        with no sentences) no longer raises ``IndexError``.
        """
        for ex in batch:
            if len(ex) > 0:
                return isinstance(ex[0], list)
        return False

    def preprocess(self, x):
        """Preprocess a single string, or each sentence of a story list."""
        if isinstance(x, list):
            return [super(BABI20Field, self).preprocess(s) for s in x]
        else:
            return super(BABI20Field, self).preprocess(x)

    def pad(self, minibatch):
        """Pad a batch of queries/answers, or a batch of stories.

        For stories, sentences are taken in reverse order (last sentence
        first) and truncated to ``memory_size``; every sentence is padded to
        the longest sentence in the batch, and stories shorter than
        ``memory_size`` get extra rows made only of ``pad_token``.
        """
        if not self._is_nested(minibatch):
            return super(BABI20Field, self).pad(minibatch)

        # flattened so that empty stories do not break the max() below;
        # at least one sentence exists because _is_nested returned True
        longest = max([len(sent) for ex in minibatch for sent in ex])

        # The parent pad() pads to ``fix_length`` when it is set, so borrow
        # that attribute for the duration of this call and restore whatever
        # the user configured afterwards (even if padding fails).
        saved_fix_length = self.fix_length
        self.fix_length = longest
        try:
            padded = []
            for ex in minibatch:
                # sentences are indexed in reverse order and truncated to memory_size
                nex = ex[::-1][:self.memory_size]
                filler = [[self.pad_token] * self.fix_length
                          for _ in range(self.memory_size - len(nex))]
                padded.append(super(BABI20Field, self).pad(nex) + filler)
        finally:
            self.fix_length = saved_fix_length
        return padded

    def numericalize(self, arr, device=None):
        """Turn a padded batch into a tensor.

        Flat batches are delegated to the parent ``Field``. For stories, each
        padded story is numericalized separately and the results are stacked
        along a new leading dimension.
        """
        if not self._is_nested(arr):
            return super(BABI20Field, self).numericalize(arr, device=device)

        per_story = [
            super(BABI20Field, self).numericalize(x, device=device).data
            for x in arr
        ]
        arr = torch.stack(per_story)
        if self.sequential:
            arr = arr.contiguous()
        return arr
class BABI20(Dataset):
    """The bAbI QA tasks 1-20 (``tasks_1-20_v1-2``).

    Each example has three fields, all processed by the same ``text_field``:
    ``story`` (list of context sentences), ``query`` (the question with its
    trailing ``?`` removed) and ``answer``.
    """

    urls = ['http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz']
    name = ''
    dirname = ''

    def __init__(self, path, text_field, only_supporting=False, **kwargs):
        """Load one bAbI split file.

        Arguments:
            path: Path to a bAbI text file (read as UTF-8).
            text_field: Field used for the story, query and answer columns.
            only_supporting: If True, each story contains only the sentences
                listed as supporting facts for its question instead of the
                whole preceding context.
            **kwargs: Passed on to ``Dataset.__init__``.
        """
        fields = [('story', text_field), ('query', text_field), ('answer', text_field)]
        self.sort_key = lambda x: len(x.query)
        with open(path, 'r', encoding="utf-8") as f:
            triplets = self._parse(f, only_supporting)
        examples = [Example.fromlist(triplet, fields) for triplet in triplets]
        super(BABI20, self).__init__(examples, fields, **kwargs)

    @staticmethod
    def _parse(file, only_supporting):
        """Parse bAbI lines into ``(story, query, answer)`` triplets.

        Each line is ``<id> <text>``; an id of ``1`` starts a new story.
        Lines whose text ends in ``.`` are context sentences; any other line
        is a question of the form ``question?<TAB>answer<TAB>supporting ids``.
        Blank lines are skipped.
        """
        data, story = [], []
        for line in file:
            line = line.rstrip('\n')
            # a blank line (e.g. a trailing empty line) has no "<id> <text>"
            # pair and would otherwise fail to unpack below
            if not line.strip():
                continue
            tid, text = line.split(' ', 1)
            if tid == '1':
                story = []
            if text.endswith('.'):
                # sentence: drop the final period
                story.append(text[:-1])
            else:
                # question; remove any leading or trailing whitespace after splitting
                query, answer, supporting = (x.strip() for x in text.split('\t'))
                if only_supporting:
                    # supporting ids are 1-based line ids within the story
                    substory = [story[int(i) - 1] for i in supporting.split()]
                else:
                    substory = [x for x in story if x]
                data.append((substory, query[:-1], answer))  # remove '?'
                # empty placeholder keeps story indices aligned with line ids
                story.append("")
        return data

    @staticmethod
    def _build_joint_file(path, split, out_name):
        """Concatenate ``qa1_<split>.txt`` .. ``qa20_<split>.txt`` in ``path``.

        The result is written to ``out_name`` inside ``path`` only if that
        file does not exist yet. Returns ``out_name`` so it can be used
        directly as a split file name.
        """
        out_path = os.path.join(path, out_name)
        if not os.path.isfile(out_path):
            with open(out_path, 'w', encoding='utf-8') as tf:
                # a dedicated loop variable: this must not touch the
                # caller's ``task``, which still selects the test file
                for task_id in range(1, 21):
                    src = os.path.join(path, 'qa' + str(task_id) + '_' + split + '.txt')
                    with open(src, 'r', encoding='utf-8') as f:
                        content = f.read()
                    # avoid fusing the last line of one task with the first
                    # line of the next if a file lacks a trailing newline
                    if content and not content.endswith('\n'):
                        content += '\n'
                    tf.write(content)
        return out_name

    @classmethod
    def splits(cls, text_field, path=None, root='.data', task=1, joint=False, tenK=False,
               only_supporting=False, train=None, validation=None, test=None, **kwargs):
        """Create train/validation/test datasets for a bAbI task.

        Arguments:
            text_field: Field used for the story, query and answer columns.
            path: Directory holding the ``qaN_*.txt`` files; when None it is
                obtained from ``cls.download(root)``.
            root: Root directory passed to ``cls.download``.
            task: Task number, an int in 1..20.
            joint: If True, the default train and validation splits are the
                concatenation of all 20 tasks (cached as ``all_train.txt`` /
                ``all_valid.txt`` in ``path``); the default test split still
                uses ``task``.
            tenK: Use the ``en-valid-10k`` directory instead of ``en-valid``.
            only_supporting: Forwarded to the constructor so each story keeps
                only its supporting facts.
            train, validation, test: File names inside ``path`` overriding
                the defaults.
            **kwargs: Forwarded to ``Dataset.splits``.
        """
        assert isinstance(task, int) and 1 <= task <= 20
        if tenK:
            cls.dirname = os.path.join('tasks_1-20_v1-2', 'en-valid-10k')
        else:
            cls.dirname = os.path.join('tasks_1-20_v1-2', 'en-valid')
        if path is None:
            path = cls.download(root)
        if train is None:
            if joint:  # put all tasks together for joint learning
                train = cls._build_joint_file(path, 'train', 'all_train.txt')
            else:
                train = 'qa' + str(task) + '_train.txt'
        if validation is None:
            if joint:  # put all tasks together for joint learning
                validation = cls._build_joint_file(path, 'valid', 'all_valid.txt')
            else:
                validation = 'qa' + str(task) + '_valid.txt'
        if test is None:
            test = 'qa' + str(task) + '_test.txt'
        # only_supporting must be forwarded explicitly: it is a named
        # parameter here, so it is not part of **kwargs
        return super(BABI20, cls).splits(
            path=path, root=root, text_field=text_field, train=train,
            validation=validation, test=test, only_supporting=only_supporting,
            **kwargs)

    @classmethod
    def iters(cls, batch_size=32, root='.data', memory_size=50, task=1, joint=False,
              tenK=False, only_supporting=False, sort=False, shuffle=False, device=None,
              **kwargs):
        """Build a ``BABI20Field``, load the splits and return iterators.

        The vocabulary is built from the train split only. Extra ``kwargs``
        go to ``splits``; the return value is ``Iterator.splits`` over
        ``(train, val, test)``.
        """
        text = BABI20Field(memory_size)
        train, val, test = cls.splits(text, root=root, task=task, joint=joint,
                                      tenK=tenK, only_supporting=only_supporting,
                                      **kwargs)
        text.build_vocab(train)
        return Iterator.splits((train, val, test), batch_size=batch_size, sort=sort,
                               shuffle=shuffle, device=device)