TensorDictTokenizer

class torchrl.data.TensorDictTokenizer(tokenizer, max_length, key='text', padding='max_length', truncation=True, return_tensordict=True, device=None)

Factory for a process function that applies a tokenizer over a text example.

Parameters:
  • tokenizer (tokenizer from the transformers library) – the tokenizer to use.

  • max_length (int) – maximum length of the sequence.

  • key (str, optional) – the key where the text is stored. Defaults to "text".

  • padding (str, optional) – type of padding. Defaults to "max_length".

  • truncation (bool, optional) – whether the sequences should be truncated to max_length. Defaults to True.

  • return_tensordict (bool, optional) – if True, a TensorDict is returned. Otherwise, the original data will be returned. Defaults to True.

  • device (torch.device, optional) – the device on which to store the data. This option is ignored if return_tensordict=False.

See the transformers library documentation for more information about tokenizers:

Padding and truncation: https://huggingface.co/docs/transformers/pad_truncation
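
The default combination used here, padding="max_length" with truncation=True, pads short sequences up to max_length and cuts longer ones down to it, so every output shares one fixed length. A brief illustration of this standard transformers behavior (the sample strings are illustrative):

>>> from transformers import AutoTokenizer
>>> tok = AutoTokenizer.from_pretrained("gpt2")
>>> tok.pad_token = tok.eos_token  # GPT-2 has no pad token by default
>>> enc = tok("hi", max_length=10, padding="max_length", truncation=True)
>>> len(enc["input_ids"])  # short input is padded up to max_length
10
>>> enc = tok("a much longer example sentence that clearly exceeds the ten token budget set above",
...           max_length=10, padding="max_length", truncation=True)
>>> len(enc["input_ids"])  # long input is truncated down to max_length
10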

Returns: a tensordict.TensorDict instance with the same batch size as the input data.
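
To make the parameter interactions concrete, here is a minimal sketch of what the returned process function roughly does. This is an illustration under the assumptions above, not TorchRL's actual implementation, and the helper name make_process is hypothetical:

>>> from tensordict import TensorDict
>>> def make_process(tokenizer, max_length, key="text", padding="max_length",
...                  truncation=True, return_tensordict=True, device=None):
...     # hypothetical sketch, not the actual torchrl.data source
...     def process(example):
...         text = example[key]
...         out = tokenizer(text, max_length=max_length, padding=padding,
...                         truncation=truncation,
...                         return_tensors="pt" if return_tensordict else None)
...         if not return_tensordict:
...             return dict(out)  # plain python lists; the device argument is ignored
...         if isinstance(text, str):
...             # single example: drop the batch dim the tokenizer adds
...             return TensorDict({k: v.squeeze(0) for k, v in out.items()},
...                               batch_size=[], device=device)
...         return TensorDict(dict(out), batch_size=[len(text)], device=device)
...     return process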

Examples

>>> from transformers import AutoTokenizer
>>> from torchrl.data import TensorDictTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
>>> process = TensorDictTokenizer(tokenizer, max_length=10)
>>> # example with a single input
>>> example = {"text": "I am a little worried"}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # example with multiple inputs
>>> example = {"text": ["Let me reassure you", "It will be ok"]}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
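
Because the process function returns TensorDict instances, successive batches can be stacked like tensors, e.g. when iterating over a text dataset. Continuing the session above (the prompt strings are illustrative):

>>> import torch
>>> batches = [
...     {"text": ["Let me reassure you", "It will be ok"]},
...     {"text": ["Stay calm", "All is well"]},
... ]
>>> tds = [process(batch) for batch in batches]
>>> stacked = torch.stack(tds)  # TensorDicts support torch.stack
>>> stacked.batch_size
torch.Size([2, 2])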
