TensorDictTokenizer¶
- class torchrl.data.TensorDictTokenizer(tokenizer, max_length, key='text', padding='max_length', truncation=True, return_tensordict=True, device=None)[source]¶
Factory for a process function that applies a tokenizer over a text example.
- Parameters:
tokenizer (tokenizer from transformers library) – the tokenizer to use.
max_length (int) – maximum length of the sequence.
key (str, optional) – the key where to find the text. Defaults to "text".
padding (str, optional) – type of padding. Defaults to "max_length".
truncation (bool, optional) – whether the sequences should be truncated to max_length.
return_tensordict (bool, optional) – if True, a TensorDict is returned. Otherwise, the original data will be returned.
device (torch.device, optional) – the device where to store the data. This option is ignored if return_tensordict=False.
- See transformers library for more information about tokenizers:
Padding and truncation: https://huggingface.co/docs/transformers/pad_truncation
Returns: a tensordict.TensorDict instance with the same batch-size as the input data.
Examples
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> tokenizer.pad_token = tokenizer.eos_token
>>> process = TensorDictTokenizer(tokenizer, max_length=10)
>>> # example with a single input
>>> example = {"text": "I am a little worried"}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
>>> # example with multiple inputs
>>> example = {"text": ["Let me reassure you", "It will be ok"]}
>>> process(example)
TensorDict(
    fields={
        attention_mask: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False),
        input_ids: Tensor(shape=torch.Size([2, 10]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
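To make the padding and truncation semantics concrete without requiring the transformers library, here is a minimal sketch of the process-function pattern that TensorDictTokenizer produces. The toy whitespace tokenizer, the `make_process_fn` factory name, and the vocabulary below are hypothetical stand-ins for illustration only; the padding and truncation logic mirrors the `padding="max_length"` and `truncation=True` defaults described above.

```python
# Hypothetical sketch: a factory returning a process function that
# pads/truncates token ids to max_length, as the defaults above do.
def make_process_fn(vocab, max_length, key="text", pad_id=0, unk_id=1):
    def process(example):
        # Toy whitespace "tokenizer" standing in for a transformers tokenizer.
        ids = [vocab.get(tok, unk_id) for tok in example[key].split()]
        ids = ids[:max_length]               # truncation=True
        mask = [1] * len(ids)                # real tokens get mask 1
        pad = max_length - len(ids)          # padding="max_length"
        return {
            "input_ids": ids + [pad_id] * pad,
            "attention_mask": mask + [0] * pad,  # padding gets mask 0
        }
    return process

vocab = {"I": 2, "am": 3, "a": 4, "little": 5, "worried": 6}
process = make_process_fn(vocab, max_length=10)
out = process({"text": "I am a little worried"})
# out["input_ids"]      -> [2, 3, 4, 5, 6, 0, 0, 0, 0, 0]
# out["attention_mask"] -> [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
```

The real class additionally stacks the results into a tensordict.TensorDict (when return_tensordict=True) and handles lists of strings by batching along the first dimension, as the Examples above show.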