
PromptTensorDictTokenizer

class torchrl.data.PromptTensorDictTokenizer(tokenizer, max_length, key='prompt', padding='max_length', truncation=True, return_tensordict=True, device=None)

Tokenization recipe for prompt datasets.

Returns a callable tokenizer that reads an example containing a prompt and a label and tokenizes them.

Parameters:
  • tokenizer (tokenizer from the transformers library) – the tokenizer to use.

  • max_length (int) – maximum length of the sequence.

  • key (str, optional) – the key under which the prompt text is stored. Defaults to "prompt".

  • padding (str, optional) – type of padding. Defaults to "max_length".

  • truncation (bool, optional) – whether the sequences should be truncated to max_length. Defaults to True.

  • return_tensordict (bool, optional) – if True, a TensorDict is returned. Otherwise, the original data is returned.

  • device (torch.device, optional) – the device where to store the data. This option is ignored if return_tensordict=False.

The __call__() method of this class will execute the following operations:

  • Concatenate the prompt string with the label string and tokenize the result. The token ids are stored in the "input_ids" TensorDict entry, alongside an "attention_mask" entry.

  • Write a "prompt_rindex" entry with the index of the last valid token from the prompt.

  • Write a "valid_sample" entry that identifies which entries in the tensordict have enough tokens to meet the max_length criterion.

  • Return a tensordict.TensorDict instance with tokenized inputs.

The batch size of the returned tensordict matches the batch size of the input.
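The operations above can be sketched in plain Python. This is a minimal, hypothetical illustration, not the torchrl implementation: `toy_tokenize` and `tokenize_prompts` are made-up names, a whitespace tokenizer stands in for the transformers tokenizer, prompt and label are joined with a space, padding is applied on the right, and a sample is assumed to be valid when the concatenated sequence fits within max_length without truncation. The result is a dict of lists whose length matches the input batch size.

```python
from typing import Dict, List


def toy_tokenize(text: str, vocab: Dict[str, int]) -> List[int]:
    # Hypothetical whitespace tokenizer standing in for a transformers tokenizer.
    # Each new word gets the next free id; id 0 is reserved for padding.
    return [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]


def tokenize_prompts(example, max_length: int, key: str = "prompt", pad_id: int = 0):
    """Hypothetical sketch of the documented steps, returning lists, not a TensorDict."""
    vocab: Dict[str, int] = {}
    out = {
        "input_ids": [],
        "attention_mask": [],
        "prompt_rindex": [],
        "valid_sample": [],
    }
    for prompt, label in zip(example[key], example["label"]):
        prompt_ids = toy_tokenize(prompt, vocab)
        # Concatenate prompt and label (space-joined for the toy tokenizer).
        full_ids = toy_tokenize(prompt + " " + label, vocab)
        # Assumption: valid means the sequence fits in max_length untruncated.
        out["valid_sample"].append(len(full_ids) <= max_length)
        # Index of the last prompt token, clipped to the truncated length.
        out["prompt_rindex"].append(min(len(prompt_ids), max_length) - 1)
        # truncation=True, then padding="max_length" (right padding assumed).
        full_ids = full_ids[:max_length]
        n_pad = max_length - len(full_ids)
        out["input_ids"].append(full_ids + [pad_id] * n_pad)
        out["attention_mask"].append([1] * len(full_ids) + [0] * n_pad)
    return out


example = {"prompt": ["a b c", "a b c d"], "label": ["d e", "e f g"]}
result = tokenize_prompts(example, max_length=6)
```

In this sketch, result["prompt_rindex"] is [2, 3] and result["valid_sample"] is [True, False], because the second sample needs 7 tokens and is truncated to 6.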

Examples

>>> from transformers import AutoTokenizer
>>> from torchrl.data import PromptTensorDictTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token
>>> example = {
...     "prompt": ["This prompt is long enough to be tokenized.", "this one too!"],
...     "label": ["Indeed it is.", "It might as well be."],
... }
>>> fn = PromptTensorDictTokenizer(tokenizer, 50)
>>> print(fn(example))
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        prompt_rindex: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
        valid_sample: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
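Downstream code may need to separate the prompt tokens from the label tokens and to skip samples flagged as invalid. The sketch below does this on a plain dict of lists that mirrors the TensorDict fields shown above; `split_prompt_and_label` is a hypothetical helper, and it assumes right padding, so the real tokens are the ones where attention_mask is 1.

```python
def split_prompt_and_label(batch):
    """Hypothetical helper: return (prompt_ids, label_ids) for each valid sample.

    `batch` is a dict of per-sample lists mirroring the TensorDict fields.
    Assumes right padding, so non-padding tokens come first in each row.
    """
    pairs = []
    for ids, mask, rindex, valid in zip(
        batch["input_ids"],
        batch["attention_mask"],
        batch["prompt_rindex"],
        batch["valid_sample"],
    ):
        if not valid:
            continue  # skip samples flagged as invalid
        length = sum(mask)  # number of non-padding tokens
        # Prompt ends at prompt_rindex; the label fills the rest up to the padding.
        pairs.append((ids[: rindex + 1], ids[rindex + 1 : length]))
    return pairs


batch = {
    "input_ids": [[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6]],
    "attention_mask": [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]],
    "prompt_rindex": [2, 3],
    "valid_sample": [True, False],
}
pairs = split_prompt_and_label(batch)
# pairs == [([1, 2, 3], [4, 5])]
```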
