Source code for torchtext.datasets.babi

import os
from io import open

import torch

from ..data import Dataset, Field, Example, Iterator
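
For reference, each bAbI task file interleaves numbered statement lines with tab-separated question lines (question, answer, supporting fact ids); the _parse helper below turns these into (story, query, answer) triplets. An illustrative excerpt from task 1:

# 1 Mary moved to the bathroom.
# 2 John went to the hallway.
# 3 Where is Mary?	bathroom	1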


class BABI20Field(Field):

    def __init__(self, memory_size, **kwargs):
        super(BABI20Field, self).__init__(**kwargs)
        self.memory_size = memory_size
        self.unk_token = None
        self.batch_first = True

    def preprocess(self, x):
        # a story arrives as a list of sentence strings, so preprocess each
        # sentence; queries and answers arrive as plain strings
        if isinstance(x, list):
            return [super(BABI20Field, self).preprocess(s) for s in x]
        else:
            return super(BABI20Field, self).preprocess(x)

    def pad(self, minibatch):
        # stories are lists of tokenized sentences: pad every sentence to the
        # longest one in the batch, then pad each story with empty sentences
        # up to memory_size
        if isinstance(minibatch[0][0], list):
            self.fix_length = max(max(len(x) for x in ex) for ex in minibatch)
            padded = []
            for ex in minibatch:
                # sentences are indexed in reverse order and truncated to memory_size
                nex = ex[::-1][:self.memory_size]
                padded.append(
                    super(BABI20Field, self).pad(nex)
                    + [[self.pad_token] * self.fix_length]
                    * (self.memory_size - len(nex)))
            self.fix_length = None
            return padded
        else:
            return super(BABI20Field, self).pad(minibatch)

    def numericalize(self, arr, device=None):
        # numericalize each story separately and stack the results into a
        # single (batch, memory_size, sentence_length) tensor
        if isinstance(arr[0][0], list):
            tmp = [
                super(BABI20Field, self).numericalize(x, device=device).data
                for x in arr
            ]
            arr = torch.stack(tmp)
            if self.sequential:
                arr = arr.contiguous()
            return arr
        else:
            return super(BABI20Field, self).numericalize(arr, device=device)
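
A minimal sketch of what the nested pad above produces, assuming a torchtext release that still ships the legacy Field API (whose defaults give pad_token='<pad>'):

# Sketch: one story of two sentences, padded to memory_size=3.
field = BABI20Field(memory_size=3)
batch = [[['mary', 'moved', 'to', 'the', 'bathroom'],
          ['john', 'went', 'to', 'the', 'hallway']]]
print(field.pad(batch))
# [[['john', 'went', 'to', 'the', 'hallway'],
#   ['mary', 'moved', 'to', 'the', 'bathroom'],
#   ['<pad>', '<pad>', '<pad>', '<pad>', '<pad>']]]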


class BABI20(Dataset):
    urls = ['http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2.tar.gz']
    name = ''
    dirname = ''

    def __init__(self, path, text_field, only_supporting=False, **kwargs):
        fields = [('story', text_field), ('query', text_field),
                  ('answer', text_field)]
        self.sort_key = lambda x: len(x.query)

        with open(path, 'r', encoding="utf-8") as f:
            triplets = self._parse(f, only_supporting)
            examples = [Example.fromlist(triplet, fields) for triplet in triplets]

        super(BABI20, self).__init__(examples, fields, **kwargs)

    @staticmethod
    def _parse(file, only_supporting):
        data, story = [], []
        for line in file:
            tid, text = line.rstrip('\n').split(' ', 1)
            if tid == '1':
                story = []
            # sentence
            if text.endswith('.'):
                story.append(text[:-1])
            # question
            else:
                # remove any leading or trailing whitespace after splitting
                query, answer, supporting = (x.strip() for x in text.split('\t'))
                if only_supporting:
                    substory = [story[int(i) - 1] for i in supporting.split()]
                else:
                    substory = [x for x in story if x]
                data.append((substory, query[:-1], answer))  # remove '?'
                story.append("")
        return data

    @classmethod
    def splits(cls, text_field, path=None, root='.data', task=1, joint=False,
               tenK=False, only_supporting=False, train=None, validation=None,
               test=None, **kwargs):
        assert isinstance(task, int) and 1 <= task <= 20
        if tenK:
            cls.dirname = os.path.join('tasks_1-20_v1-2', 'en-valid-10k')
        else:
            cls.dirname = os.path.join('tasks_1-20_v1-2', 'en-valid')
        if path is None:
            path = cls.download(root)
        if train is None:
            if joint:  # put all tasks together for joint learning
                train = 'all_train.txt'
                if not os.path.isfile(os.path.join(path, train)):
                    with open(os.path.join(path, train), 'w') as tf:
                        for task in range(1, 21):
                            with open(os.path.join(
                                    path, 'qa' + str(task) + '_train.txt')) as f:
                                tf.write(f.read())
            else:
                train = 'qa' + str(task) + '_train.txt'
        if validation is None:
            if joint:  # put all tasks together for joint learning
                validation = 'all_valid.txt'
                if not os.path.isfile(os.path.join(path, validation)):
                    with open(os.path.join(path, validation), 'w') as tf:
                        for task in range(1, 21):
                            with open(os.path.join(
                                    path, 'qa' + str(task) + '_valid.txt')) as f:
                                tf.write(f.read())
            else:
                validation = 'qa' + str(task) + '_valid.txt'
        if test is None:
            test = 'qa' + str(task) + '_test.txt'
        return super(BABI20, cls).splits(path=path, root=root,
                                         text_field=text_field, train=train,
                                         validation=validation, test=test,
                                         **kwargs)

    @classmethod
    def iters(cls, batch_size=32, root='.data', memory_size=50, task=1,
              joint=False, tenK=False, only_supporting=False, sort=False,
              shuffle=False, device=None, **kwargs):
        text = BABI20Field(memory_size)
        train, val, test = BABI20.splits(text, root=root, task=task, joint=joint,
                                         tenK=tenK,
                                         only_supporting=only_supporting,
                                         **kwargs)
        text.build_vocab(train)
        return Iterator.splits((train, val, test), batch_size=batch_size,
                               sort=sort, shuffle=shuffle, device=device)
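
A minimal end-to-end sketch of how these pieces fit together, under the same legacy-API assumption; network access is needed on first run to download the bAbI archive into '.data':

# Hypothetical usage sketch, not part of the module itself.
train_iter, val_iter, test_iter = BABI20.iters(batch_size=32, task=1,
                                               memory_size=50)
batch = next(iter(train_iter))
print(batch.story.shape)   # torch.Size([32, 50, longest_sentence_len])
print(batch.query.shape)   # torch.Size([32, longest_query_len])
print(batch.answer.shape)  # torch.Size([32, 1])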
