
torchtext.transforms

torchtext.transforms provides common text transforms. They can be chained together using torch.nn.Sequential, or using torchtext.transforms.Sequential to keep the chain scriptable with TorchScript.

SentencePieceTokenizer

class torchtext.transforms.SentencePieceTokenizer(sp_model_path: str)

Transform for a SentencePiece tokenizer loaded from a pre-trained SentencePiece model.

Additional details: https://github.com/google/sentencepiece

Parameters:

sp_model_path (str) – Path to pre-trained sentencepiece model

Example
>>> from torchtext.transforms import SentencePieceTokenizer
>>> transform = SentencePieceTokenizer("spm_model")
>>> transform(["hello world", "attention is all you need!"])
Tutorials using SentencePieceTokenizer:
SST-2 Binary text classification with XLM-RoBERTa model

forward(input: Any) → Any
Parameters:

input (Union[str, List[str]]) – Input sentence or list of sentences on which to apply tokenizer.

Returns:

tokenized text

Return type:

Union[List[str], List[List[str]]]

GPT2BPETokenizer

class torchtext.transforms.GPT2BPETokenizer(encoder_json_path: str, vocab_bpe_path: str, return_tokens: bool = False)

Transform for the GPT-2 BPE tokenizer.

Reimplements OpenAI's GPT-2 BPE tokenizer in TorchScript. Original OpenAI implementation: https://github.com/openai/gpt-2/blob/master/src/encoder.py

Parameters:
  • encoder_json_path (str) – Path to the GPT-2 BPE encoder JSON file.

  • vocab_bpe_path (str) – Path to the BPE vocab file.

  • return_tokens (bool) – Whether to return split tokens. If False, encoded token IDs are returned as strings (default: False)
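To make the BPE step concrete, here is a minimal plain-Python sketch of the greedy merge loop that GPT-2-style BPE applies to each pre-tokenized word: repeatedly merge the adjacent pair with the lowest merge rank until no known pair remains. It is not torchtext's TorchScript implementation; it omits GPT-2's byte-to-unicode mapping, regex pre-tokenization, caching, and the encoder lookup, and the merge table `toy_ranks` is hypothetical.

```python
from typing import Dict, List, Tuple


def bpe_merge(word: str, ranks: Dict[Tuple[str, str], int]) -> List[str]:
    """Greedily apply the lowest-ranked (earliest-learned) merge until none applies."""
    symbols = list(word)
    while len(symbols) > 1:
        # Collect adjacent pairs and pick the one learned earliest.
        pairs = [(symbols[i], symbols[i + 1]) for i in range(len(symbols) - 1)]
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no remaining pair is a known merge
        merged: List[str] = []
        i = 0
        while i < len(symbols):
            # Merge every non-overlapping occurrence of the best pair, left to right.
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    return symbols


# Hypothetical merge table: a lower rank means the merge was learned earlier.
toy_ranks = {("l", "o"): 0, ("lo", "w"): 1, ("e", "r"): 2}
print(bpe_merge("lower", toy_ranks))  # ['low', 'er']
```

In a real tokenizer each resulting subword would then be looked up in the encoder to obtain its token ID.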

forward(input: Any) → Any
Parameters:

input (Union[str, List[str]]) – Input sentence or list of sentences on which to apply tokenizer.

Returns:

tokenized text

Return type:

Union[List[str], List[List[str]]]

CLIPTokenizer

class torchtext.transforms.CLIPTokenizer(merges_path: str, encoder_json_path: Optional[str] = None, num_merges: Optional[int] = None, return_tokens: bool = False)

Transform for the CLIP tokenizer, based on byte-level BPE.

Reimplements the CLIP tokenizer in TorchScript. Original implementation: https://github.com/mlfoundations/open_clip/blob/main/src/clip/tokenizer.py

This tokenizer has been trained to treat spaces as parts of the tokens (somewhat like SentencePiece), so a word is encoded differently depending on whether it appears at the beginning of the sentence (without a preceding space) or not.

The code snippet below shows how to use the CLIP tokenizer with the encoder and merges files taken from the original paper's implementation.

Example
>>> from torchtext.transforms import CLIPTokenizer
>>> MERGES_FILE = "http://download.pytorch.org/models/text/clip_merges.bpe"
>>> ENCODER_FILE = "http://download.pytorch.org/models/text/clip_encoder.json"
>>> tokenizer = CLIPTokenizer(merges_path=MERGES_FILE, encoder_json_path=ENCODER_FILE)
>>> tokenizer("the quick brown fox jumped over the lazy dog")
Parameters:
  • merges_path (str) – Path to bpe merges file.

  • encoder_json_path (Optional[str]) – Optional path to the BPE encoder JSON file. When specified, it is used to infer num_merges.

  • num_merges (Optional[int]) – Optional number of merges to read from the BPE merges file.

  • return_tokens (bool) – Whether to return split tokens. If False, encoded token IDs are returned as strings (default: False)

forward(input: Any) → Any
Parameters:

input (Union[str, List[str]]) – Input sentence or list of sentences on which to apply tokenizer.

Returns:

tokenized text

Return type:

Union[List[str], List[List[str]]]

RegexTokenizer

class torchtext.transforms.RegexTokenizer(patterns_list)

Regex tokenizer for a string sentence that applies all regex replacements defined in patterns_list. It is backed by the C++ RE2 regular expression engine from Google.

Parameters:
  • patterns_list (List[Tuple[str, str]]) – a list of tuples (ordered pairs), each containing the regex pattern string as the first element and the replacement string as the second element.

Caveats
  • The RE2 library does not support arbitrary lookahead or lookbehind assertions, nor does it support backreferences. See the RE2 syntax documentation for more information.

  • The final tokenization step always uses spaces as separators. To split strings based on a specific regex pattern, similar to Python’s re.split, a tuple of ('<regex_pattern>', ' ') can be provided.
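The two-phase behavior described above (apply every replacement in order, then split on spaces) can be approximated in plain Python with the standard `re` module. This is a sketch, not RegexTokenizer itself: Python's `re` accepts some syntax that RE2 rejects (such as lookarounds), and discarding the empty strings produced by consecutive spaces is an assumption of this sketch. The helper `regex_tokenize` is hypothetical.

```python
import re
from typing import List, Tuple


def regex_tokenize(line: str, patterns_list: List[Tuple[str, str]]) -> List[str]:
    """Apply each (pattern, replacement) pair in order, then split on spaces."""
    for pattern, replacement in patterns_list:
        line = re.sub(pattern, replacement, line)
    # Final step: split on single spaces and drop empty strings left by repeated spaces.
    return [tok for tok in line.split(" ") if tok]


sample = "Basic.Regex,Tokenization_for+a..Line,,of  Text"
print(regex_tokenize(sample, [(r"[,._+ ]+", " ")]))
# ['Basic', 'Regex', 'Tokenization', 'for', 'a', 'Line', 'of', 'Text']
```

The single `(pattern, " ")` pair is what makes this behave like `re.split`: every match becomes a space, and the final space split does the rest.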

Example
Regex tokenization based on (patterns, replacements) list.
>>> import torch
>>> from torchtext.transforms import RegexTokenizer
>>> test_sample = 'Basic Regex Tokenization for a Line of Text'
>>> patterns_list = [
...     (r"'", " '  "),
...     (r'"', '')]
>>> reg_tokenizer = RegexTokenizer(patterns_list)
>>> jit_reg_tokenizer = torch.jit.script(reg_tokenizer)
>>> tokens = jit_reg_tokenizer(test_sample)
Regex tokenization based on (single_pattern, ' ') list.
>>> import torch
>>> from torchtext.transforms import RegexTokenizer
>>> test_sample = 'Basic.Regex,Tokenization_for+a..Line,,of  Text'
>>> patterns_list = [
...     (r'[,._+ ]+', r' ')]
>>> reg_tokenizer = RegexTokenizer(patterns_list)
>>> jit_reg_tokenizer = torch.jit.script(reg_tokenizer)
>>> tokens = jit_reg_tokenizer(test_sample)
forward(line: str) → List[str]
Parameters:

line (str) – a text string to tokenize.

Returns:

a list of tokens after applying the regex replacements.

Return type:

List[str]

BERTTokenizer

class torchtext.transforms.BERTTokenizer(vocab_path: str, do_lower_case: bool = True, strip_accents: Optional[bool] = None, return_tokens=False, never_split: Optional[List[str]] = None)

Transform for BERT Tokenizer.

Based on the WordPiece algorithm introduced in the paper: https://static.googleusercontent.com/media/research.google.com/ja//pubs/archive/37842.pdf

The backend kernel implementation is taken and modified from https://github.com/LieluoboAi/radish.

See the summary of PR https://github.com/pytorch/text/pull/1707 for more details.
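WordPiece tokenization in BERT-style models splits each word greedily, taking the longest vocabulary match first and marking continuation pieces with a `##` prefix. The sketch below illustrates that idea in plain Python with a hypothetical toy vocabulary; it omits the basic tokenization (lowercasing, accent stripping, punctuation splitting) and the maximum-word-length check of a full BERT tokenizer, and it is not the radish-derived kernel used by BERTTokenizer.

```python
from typing import List, Optional, Set


def wordpiece(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first WordPiece split of a single word."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match: Optional[str] = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a "##" prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # the whole word maps to the unknown token if any part fails
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"un", "##aff", "##able", "aff"}
print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
```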

The code snippet below shows how to use the BERT tokenizer with a pre-trained vocab file.

Example
>>> from torchtext.transforms import BERTTokenizer
>>> VOCAB_FILE = "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt"
>>> tokenizer = BERTTokenizer(vocab_path=VOCAB_FILE, do_lower_case=True, return_tokens=True)
>>> tokenizer("Hello World, How are you!") # single sentence input
>>> tokenizer(["Hello World","How are you!"]) # batch input
Parameters:
  • vocab_path (str) – Path to pre-trained vocabulary file. The path can be either local or URL.

  • do_lower_case (bool) – Whether to lowercase the input. (default: True)

  • strip_accents (Optional[bool]) – Whether to strip accents. (default: None)

  • return_tokens (bool) – Whether to return tokens. If False, the corresponding token IDs are returned as strings (default: False)

  • never_split (Optional[List[str]]) – Collection of tokens that will not be split during tokenization. (default: None)

forward(input: Any) → Any
Parameters:

input (Union[str, List[str]]) – Input sentence or list of sentences on which to apply tokenizer.

Returns:

tokenized text

Return type:

Union[List[str], List[List[str]]]

VocabTransform

class torchtext.transforms.VocabTransform(vocab: Vocab)

Vocab transform to convert an input sequence or batch of tokens into the corresponding token ids.

Parameters:

vocab – an instance of the torchtext.vocab.Vocab class.

Example

>>> import torch
>>> from torchtext.vocab import vocab
>>> from torchtext.transforms import VocabTransform
>>> from collections import OrderedDict
>>> vocab_obj = vocab(OrderedDict([('a', 1), ('b', 1), ('c', 1)]))
>>> vocab_transform = VocabTransform(vocab_obj)
>>> output = vocab_transform([['a','b'],['a','b','c']])
>>> jit_vocab_transform = torch.jit.script(vocab_transform)
Tutorials using VocabTransform:
SST-2 Binary text classification with XLM-RoBERTa model

forward(input: Any) → Any
Parameters:

input (Union[List[str], List[List[str]]]) – Input sequence or batch of tokens to convert to the corresponding token ids

Returns:

Converted input into corresponding token ids

Return type:

Union[List[int], List[List[int]]]
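Conceptually, VocabTransform performs a token-to-id lookup on either a single sequence or each sequence of a batch. The plain-Python sketch below shows that dispatch with a dict standing in for a torchtext.vocab.Vocab; the table `toy_stoi` is hypothetical and assumes ids are assigned in insertion order, and an unknown token here simply raises KeyError.

```python
from typing import Dict, List, Union


def lookup_indices(
    stoi: Dict[str, int], tokens: Union[List[str], List[List[str]]]
) -> Union[List[int], List[List[int]]]:
    """Map a token sequence, or a batch of sequences, to ids via a lookup table."""
    if tokens and isinstance(tokens[0], list):
        # Batch input: convert each sequence independently.
        return [[stoi[tok] for tok in seq] for seq in tokens]
    # Single sequence input.
    return [stoi[tok] for tok in tokens]


toy_stoi = {"a": 0, "b": 1, "c": 2}
print(lookup_indices(toy_stoi, [["a", "b"], ["a", "b", "c"]]))  # [[0, 1], [0, 1, 2]]
```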

ToTensor

class torchtext.transforms.ToTensor(padding_value: Optional[int] = None, dtype: dtype = torch.int64)

Convert input to a torch tensor.

Parameters:
  • padding_value (Optional[int]) – Value used to pad each sequence in the batch to the length of the longest sequence in the batch.

  • dtype (torch.dtype) – torch.dtype of the output tensor (default: torch.int64)

forward(input: Any) → Tensor
Parameters:

input (Union[List[int], List[List[int]]]) – Sequence or batch of token ids

Return type:

Tensor
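The effect of padding_value on a batch can be sketched with nested lists: each sequence is right-padded to the length of the longest one, after which the batch is rectangular and can become a 2-D tensor. The hypothetical `pad_batch` helper below returns lists rather than a tensor; raising ValueError for ragged input when no padding value is set is this sketch's own choice, standing in for the failure you would get when building a tensor from sequences of unequal length.

```python
from typing import List, Optional


def pad_batch(batch: List[List[int]], padding_value: Optional[int]) -> List[List[int]]:
    """Right-pad every sequence to the length of the longest one."""
    if padding_value is None:
        # Without a padding value, all sequences must already have equal length.
        if len({len(seq) for seq in batch}) > 1:
            raise ValueError("sequences differ in length; set padding_value")
        return [list(seq) for seq in batch]
    max_len = max(len(seq) for seq in batch)
    return [seq + [padding_value] * (max_len - len(seq)) for seq in batch]


print(pad_batch([[1, 2], [3, 4, 5]], padding_value=0))  # [[1, 2, 0], [3, 4, 5]]
```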

LabelToIndex

class torchtext.transforms.LabelToIndex(label_names: Optional[List[str]] = None, label_path: Optional[str] = None, sort_names=False)

Transform labels from string names to ids.

Parameters:
  • label_names (Optional[List[str]]) – a list of unique label names

  • label_path (Optional[str]) – a path to a file containing unique label names, one label per line. Supply either label_names or label_path, but not both.

forward(input: Any) → Any
Parameters:

input (Union[str, List[str]]) – Input labels to convert to corresponding ids

Return type:

Union[int, List[int]]
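A minimal plain-Python sketch of the label-to-id mapping, assuming ids follow the order of the label names and that sort_names=True sorts the names lexicographically before ids are assigned. sort_names is not documented above, so treat that part as an assumption. Both helpers are hypothetical.

```python
from typing import Dict, List, Union


def build_label_map(label_names: List[str], sort_names: bool = False) -> Dict[str, int]:
    # Assumption: ids follow list order, after an optional lexicographic sort.
    names = sorted(label_names) if sort_names else list(label_names)
    return {name: idx for idx, name in enumerate(names)}


def labels_to_ids(label_map: Dict[str, int], labels: Union[str, List[str]]) -> Union[int, List[int]]:
    # A single label maps to an int; a list of labels maps to a list of ints.
    if isinstance(labels, str):
        return label_map[labels]
    return [label_map[label] for label in labels]


label_map = build_label_map(["positive", "negative"], sort_names=True)
print(labels_to_ids(label_map, ["positive", "negative", "positive"]))  # [1, 0, 1]
```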

Truncate

class torchtext.transforms.Truncate(max_seq_len: int)

Truncate an input sequence or a batch of sequences.

Parameters:

max_seq_len (int) – The maximum allowable length for input sequence

Tutorials using Truncate:
SST-2 Binary text classification with XLM-RoBERTa model

forward(input: Any) → Any
Parameters:

input (Union[List[Union[str, int]], List[List[Union[str, int]]]]) – Input sequence or batch of sequences to be truncated

Returns:

Truncated sequence or batch of sequences

Return type:

Union[List[Union[str, int]], List[List[Union[str, int]]]]
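A plain-Python sketch of the truncation behavior on nested lists, assuming the first max_seq_len elements are kept. The `truncate` helper is hypothetical and simply slices a single sequence or each sequence of a batch.

```python
from typing import List, Union

Seq = List[Union[str, int]]


def truncate(seqs: Union[Seq, List[Seq]], max_seq_len: int) -> Union[Seq, List[Seq]]:
    if seqs and isinstance(seqs[0], list):
        # Batch input: truncate each sequence independently.
        return [seq[:max_seq_len] for seq in seqs]
    # Single sequence input.
    return seqs[:max_seq_len]


print(truncate([[1, 2, 3, 4], [5, 6]], 3))  # [[1, 2, 3], [5, 6]]
print(truncate(["a", "b", "c"], 2))  # ['a', 'b']
```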

AddToken

class torchtext.transforms.AddToken(token: Union[int, str], begin: bool = True)

Add a token to the beginning or end of a sequence.

Parameters:
  • token (Union[int, str]) – The token to be added

  • begin (bool, optional) – Whether to insert the token at the start (True) or end (False) of the sequence (default: True)

Tutorials using AddToken:
SST-2 Binary text classification with XLM-RoBERTa model

forward(input: Any) → Any
Parameters:

input (Union[List[Union[str, int]], List[List[Union[str, int]]]]) – Input sequence or batch
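In plain Python, this behavior reduces to prepending or appending the token to a single sequence or to every sequence in a batch. The hypothetical `add_token` helper below sketches that and returns new lists instead of modifying its input.

```python
from typing import List, Union

Seq = List[Union[str, int]]


def add_token(seqs: Union[Seq, List[Seq]], token: Union[int, str], begin: bool = True):
    def add(seq: Seq) -> Seq:
        # Build a new list so the caller's sequence is left untouched.
        return [token] + seq if begin else seq + [token]

    if seqs and isinstance(seqs[0], list):
        return [add(seq) for seq in seqs]
    return add(seqs)


print(add_token([[5, 6], [7]], 0))  # [[0, 5, 6], [0, 7]]
print(add_token(["hello", "world"], "</s>", begin=False))  # ['hello', 'world', '</s>']
```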

Sequential

class torchtext.transforms.Sequential(*args: Module)
class torchtext.transforms.Sequential(arg: OrderedDict[str, Module])

A container to host a sequence of text transforms.

Tutorials using Sequential:
SST-2 Binary text classification with XLM-RoBERTa model

forward(input: Any) → Any
Parameters:

input (Any) – Input sequence or batch. The input type must be supported by the first transform in the sequence.
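A container such as Sequential passes the input to the first transform and feeds each output into the next one. The plain-Python sketch below shows that chaining with three hypothetical stages (space splitting, a dict lookup, truncation); it is not torchtext's Sequential, which is a torch.nn.Module container intended to stay scriptable with TorchScript.

```python
from typing import Any, Callable, List


def run_sequential(transforms: List[Callable[[Any], Any]], x: Any) -> Any:
    # Each transform's output becomes the next transform's input.
    for transform in transforms:
        x = transform(x)
    return x


# Hypothetical stages: tokenize on spaces, map tokens to ids, then keep two ids.
toy_stoi = {"hello": 0, "world": 1, "again": 2}
pipeline = [
    lambda s: s.split(" "),
    lambda toks: [toy_stoi[t] for t in toks],
    lambda ids: ids[:2],
]
print(run_sequential(pipeline, "hello world again"))  # [0, 1]
```

Because each stage only has to accept the previous stage's output type, the first transform determines which inputs the whole chain supports, as noted above.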

PadTransform

class torchtext.transforms.PadTransform(max_length: int, pad_value: int)

Pad a tensor to a fixed length with the given padding value.

Parameters:
  • max_length (int) – Maximum length to pad to

  • pad_value (int) – Value to pad the tensor with

forward(x: Tensor) → Tensor
Parameters:

x (Tensor) – The tensor to pad

Returns:

Tensor padded up to max_length with pad_value

Return type:

Tensor
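A plain-Python sketch of the padding rule on a single row, assuming padding is applied on the right along the last dimension and that inputs already at least max_length long are returned unchanged rather than truncated. The `pad_to_length` helper is hypothetical and works on lists, not tensors.

```python
from typing import List


def pad_to_length(row: List[int], max_length: int, pad_value: int) -> List[int]:
    # Right-pad only; rows already at or beyond max_length are returned unchanged.
    if len(row) >= max_length:
        return list(row)
    return row + [pad_value] * (max_length - len(row))


print(pad_to_length([1, 2, 3], max_length=5, pad_value=0))  # [1, 2, 3, 0, 0]
print(pad_to_length([1, 2, 3, 4, 5, 6], max_length=5, pad_value=0))  # [1, 2, 3, 4, 5, 6]
```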

StrToIntTransform

class torchtext.transforms.StrToIntTransform

Convert string tokens to integers (either single sequence or batch).

forward(input: Any) → Any
Parameters:

input (Union[List[str], List[List[str]]]) – sequence or batch of string tokens to convert

Returns:

sequence or batch converted into corresponding token ids

Return type:

Union[List[int], List[List[int]]]
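This transform pairs naturally with the tokenizers above that return token IDs as strings (for example with return_tokens=False). The conversion itself amounts to calling int on each token, as in this hypothetical plain-Python sketch:

```python
from typing import List, Union


def str_to_int(tokens: Union[List[str], List[List[str]]]) -> Union[List[int], List[List[int]]]:
    if tokens and isinstance(tokens[0], list):
        # Batch input: convert each sequence independently.
        return [[int(t) for t in seq] for seq in tokens]
    # Single sequence input.
    return [int(t) for t in tokens]


print(str_to_int(["4", "17", "2"]))  # [4, 17, 2]
print(str_to_int([["1", "2"], ["3"]]))  # [[1, 2], [3]]
```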

CharBPETokenizer

class torchtext.transforms.CharBPETokenizer(bpe_encoder_path: str, bpe_merges_path: str, return_tokens: bool = False, unk_token: Optional[str] = None, suffix: Optional[str] = None, special_tokens: Optional[List[str]] = None)

Transform for a Character Byte-Pair-Encoding Tokenizer.

Parameters:
  • bpe_encoder_path (str) – Path to the BPE encoder json file.

  • bpe_merges_path (str) – Path to the BPE merges text file.

  • return_tokens (bool) – Indicate whether to return split tokens. If False, it will return encoded token IDs (default: False).

  • unk_token (Optional[str]) – The unknown token. If provided, it must exist in encoder.

  • suffix (Optional[str]) – The suffix to be used for every subword that is an end-of-word.

  • special_tokens (Optional[List[str]]) – Special tokens which should not be split into individual characters. If provided, these must exist in encoder.

forward(input: Union[str, List[str]]) → Union[List, List[List]]

Encode a string or a list of strings into token IDs.

Parameters:

input – Input sentence or list of sentences on which to apply tokenizer.

Returns:

A list or list of lists of token IDs
