Source code for torchtext.datasets.sequence_tagging

from .. import data
import random

class SequenceTaggingDataset(data.Dataset):
    """Dataset of parallel token/tag sequences read from a column file.

    Every example holds several aligned lists -- typically a list of words
    and one or more lists of tags. For part-of-speech tagging, an example
    looks like

        [I, love, PyTorch, .]  paired with  [PRON, VERB, PROPN, PUNCT]

    See torchtext/test/ for usage examples.
    """

    @staticmethod
    def sort_key(example):
        """Return the length of the first data attribute of ``example``.

        Attributes are visited in ``dir(example)`` order; callables and
        names starting with a double underscore are skipped. Returns 0
        when no attribute qualifies.
        """
        for name in dir(example):
            value = getattr(example, name)
            if callable(value) or name.startswith("__"):
                continue
            return len(value)
        return 0

    def __init__(self, path, fields, encoding="utf-8", separator="\t",
                 **kwargs):
        """Read ``path`` and build one example per blank-line-delimited group.

        Arguments:
            path: File to read; opened in text mode with ``encoding``.
            fields: Passed to ``data.Example.fromlist`` for every example
                and on to the parent constructor.
            encoding: Text encoding used to open ``path``.
            separator: String that separates the columns on each line.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        sentences = []
        current = []  # current[i] collects the i-th column of the open group

        with open(path, encoding=encoding) as handle:
            for raw in handle:
                stripped = raw.strip()
                if stripped:
                    for idx, cell in enumerate(stripped.split(separator)):
                        # Grow the column list the first time an index is seen.
                        if idx >= len(current):
                            current.append([])
                        current[idx].append(cell)
                    continue

                # A blank line closes the group collected so far (if any).
                if current:
                    sentences.append(data.Example.fromlist(current, fields))
                current = []

            # The file may end without a trailing blank line.
            if current:
                sentences.append(data.Example.fromlist(current, fields))

        super(SequenceTaggingDataset, self).__init__(sentences, fields,
                                                     **kwargs)
class UDPOS(SequenceTaggingDataset):
    # Universal Dependencies English Web Treebank.
    # NOTE(review): the download and license links are missing from these
    # comments and the single `urls` entry is an empty string in this copy
    # of the source -- confirm the intended values against upstream.
    urls = ['']
    dirname = 'en-ud-v2'
    name = 'udpos'

    @classmethod
    def splits(cls, fields, root=".data", train="en-ud-tag.v2.train.txt",
               validation="", test="en-ud-tag.v2.test.txt", **kwargs):
        """Load the Universal Dependencies Version 2 POS tagged data.

        All arguments are handed to the parent class's ``splits`` and its
        result is returned as is.

        Arguments:
            fields: Forwarded as ``fields``.
            root: Forwarded as ``root``.
            train: File name forwarded as the ``train`` split.
            validation: File name forwarded as the ``validation`` split.
                NOTE(review): the default is an empty string here --
                confirm that is intended.
            test: File name forwarded as the ``test`` split.
            **kwargs: Extra keyword arguments forwarded unchanged.
        """
        parent = super(UDPOS, cls)
        return parent.splits(
            fields=fields, root=root,
            train=train, validation=validation, test=test,
            **kwargs)
class CoNLL2000Chunking(SequenceTaggingDataset):
    # CoNLL 2000 Chunking Dataset
    # NOTE(review): both `urls` entries and `dirname` are empty strings in
    # this copy of the source -- confirm the intended values against upstream.
    urls = ['', '']
    dirname = ''
    name = 'conll2000'

    @classmethod
    def splits(cls, fields, root=".data", train="train.txt",
               test="test.txt", validation_frac=0.1, **kwargs):
        """Load the CoNLL 2000 Chunking dataset as (train, val, test).

        Only a train and a test file are loaded, so the validation set is
        carved out of the train set: ``train.split`` is called with a
        ratio of ``1 - validation_frac``.

        Arguments:
            fields: Forwarded to the parent class's ``splits``.
            root: Forwarded to the parent class's ``splits``.
            train: File name forwarded as the ``train`` split.
            test: File name forwarded as the ``test`` split.
            validation_frac: Fraction used to derive the split ratio
                (default 0.1).
            **kwargs: Extra keyword arguments forwarded to the parent
                class's ``splits``. ``separator`` is always passed as a
                single space, so it must not appear here.

        Returns:
            A ``(train, val, test)`` tuple of datasets.
        """
        train, test = super(CoNLL2000Chunking, cls).splits(
            fields=fields, root=root, train=train,
            test=test, separator=' ', **kwargs)

        # HACK: keep a reference to the sort key so it can be re-attached to
        # the datasets returned by split() below.
        sort_key = train.sort_key

        # Make the split deterministic without touching the process-wide
        # generator: a private Random seeded with 0 yields the same state
        # that `random.seed(0); random.getstate()` would, but leaves any seed
        # the caller set on the global `random` module intact.
        split_state = random.Random(0).getstate()
        train, val = train.split(1 - validation_frac,
                                 random_state=split_state)

        # HACK: re-attach the sort key to both halves.
        train.sort_key = sort_key
        val.sort_key = sort_key

        return train, val, test


