RolloutFromModel

class torchrl.data.RolloutFromModel(model, ref_model, reward_model, kl_coef=0.1, max_new_tokens=50, score_clip=10.0, kl_scheduler: KLControllerBase | None = None, num_steps: int | None = None)[source]

A class for performing rollouts with causal language models.

It is assumed that the model wrapped by this class takes tokenized text as input and is trained to predict the next token in a sequence given the preceding tokens.

Parameters:
  • model (transformers.Transformer) – the model to be used. Should have a generate() method.

  • ref_model (transformers.Transformer) – a frozen version of model where params are in their initial configuration. This is used to compute a KL penalty for the reward, to stop the model from straying too far from the reference model during training (a sketch of this penalty follows the parameter list).

  • reward_model (nn.Module, tensordict.nn.TensorDictModule) – a model which, given input_ids and attention_mask, calculates rewards for each token and end_scores (the reward for the final token in each sequence).

  • kl_coef (float, optional) – initial KL coefficient. Defaults to 0.1.

  • max_new_tokens (int, optional) – the maximum number of new tokens to generate per sequence. Defaults to 50.

  • score_clip (float, optional) – Scores from the reward model are clipped to the range (-score_clip, score_clip). Defaults to 10.

  • kl_scheduler (KLControllerBase, optional) – the KL coefficient scheduler.

  • num_steps (int, optional) – number of steps between two optimization updates.
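
The KL penalty mentioned for ref_model is applied per generated token: the reward model's score is reduced in proportion to how far the generative model's token probabilities drift from the reference model's. A rough sketch of such a combination (illustrative only, not the exact internal computation of this class; tensor names are assumptions):

>>> import torch
>>>
>>> # Illustrative shapes: 4 sequences, 50 generated tokens, one scalar per token.
>>> reward_raw = torch.randn(4, 50, 1)   # per-token scores from a reward model
>>> log_ratio = torch.randn(4, 50, 1)    # log p_model(token) - log p_ref(token)
>>> kl_coef = 0.1
>>>
>>> # Penalise divergence from the reference model, then combine with the raw reward.
>>> reward_kl = -kl_coef * log_ratio
>>> reward = reward_raw + reward_kl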

Examples

>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules.models.rlhf import GPT2RewardModel
>>> from torchrl.data.rlhf.utils import RolloutFromModel
>>> from torchrl.data.rlhf.dataset import get_dataloader
>>> from torchrl.data.rlhf.prompt import PromptData
>>> from transformers import GPT2LMHeadModel
>>>
>>> dl = get_dataloader(
...     batch_size=4,
...     block_size=550,
...     tensorclass_type=PromptData,
...     device="cpu",
...     dataset_name="CarperAI/openai_summarize_tldr",
... )
>>> model = GPT2LMHeadModel.from_pretrained("gpt2")
>>> # we load ref_model with random weights so it differs from model
>>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class())
>>> reward_model = GPT2RewardModel(model_path="gpt2")
>>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
>>>
>>> batch = next(dl)
>>> rollout = rollout_from_model.rollout_from_data(batch)
>>> rollout
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
        input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
                done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
                reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4, 50]),
            device=cpu,
            is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 50]),
    device=cpu,
    is_shared=False)
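
Since the rollout is a regular TensorDict, the fields shown above can be read with plain or nested key indexing, for example:

>>> rollout["action"].shape
torch.Size([4, 50])
>>> rollout.get(("next", "reward")).shape
torch.Size([4, 50, 1])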
create_rollout_td(batch, generated, log_probs, log_ratio)[source]

A TensorDict wrapper for generated data.

This function takes a batch plus the generated tokens and replicates the tensordict structure that would have been obtained from a rollout with a TorchRL env that sampled one token each timestep.

Parameters:
  • batch (TensorDict) – A batch of data containing the original prompt together with a field “rindex” indicating the right index of the prompt.

  • generated (torch.Tensor) – Tokenized prompt followed by generated tokens. This can be obtained by calling the generate method.

  • log_probs (torch.Tensor) – The log probabilities of the generated tokens. Can be obtained by calling the generate method.

  • log_ratio (torch.Tensor) – The log ratio of the probabilities of the generated tokens according to the generative model and the reference model. Can be obtained by calling the generate method.

Returns:

A TensorDict with the following keys:

  • "action": the sequence of actions (generated tokens)

  • "input_ids": the input_ids passed to the generative model at each time step.

  • "attention_mask": the attention_masks passed to the generative model at each time step

  • "sample_log_prob": the log probability of each token during generation

  • ("next", "input_ids"): the sequence of tokens after generation. Makes up part of the inputs that will be used for generating the next token.

  • ("next", "attention_mask"): updated attention_mask after token has been generated. Passed to the generative model on the next time step

  • ("next", "terminated"): Boolean array indicating whether we’ve reached a terminal state (either because we generated EOS token or because we reached the token limit)

  • ("next", "done"): Boolean array indicating whether we’ve reached a final state. Currently a copy of "terminated".

  • ("next", "reward"): The reward received at each time step

  • ("next", "reward_raw"): The raw reward from the reward model, without the KL term. This is mainly for debugging and logging, it is not used in training

  • ("next", "reward_kl"): The KL term from the reward. This is mainly for debugging and logging, it is not used in training.

Return type:

TensorDict
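
A hedged sketch of how this method could be chained with generate() on the batch from the class-level example above (assuming generate() returns the three tensors documented below):

>>> # Sketch: build the rollout TensorDict explicitly from generate() outputs.
>>> # `rollout_from_model` and `batch` are taken from the class-level example.
>>> generated, log_probs, log_ratio = rollout_from_model.generate(batch)
>>> rollout = rollout_from_model.create_rollout_td(batch, generated, log_probs, log_ratio)
>>> rollout["action"].shape
torch.Size([4, 50])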

generate(batch: PromptData, generation_config=None)[source]

Generates a sequence of tokens from a batch of data sampled from the data collector.

Parameters:
  • batch (PromptData) – the data to be used. Must have input_ids and prompt_rindex fields.

  • generation_config (GenerationConfig, optional) – the configuration for the call to generate.

Returns:

  • generated (torch.Tensor): a [B x (Ti + To)] sequence of integer tokens, where Ti is the length of the input sequence and To is the length of the generated sequence.

  • log_probs_gen (torch.Tensor): the log-probabilities of the generated tokens.

  • log_ratio (torch.Tensor): the log ratio between the probabilities of the generated tokens under the generative model and the frozen reference model.
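
A custom transformers GenerationConfig can be passed to control decoding; a minimal sketch, assuming the standard Hugging Face GenerationConfig arguments (the settings shown are illustrative):

>>> from transformers import GenerationConfig
>>>
>>> # Greedy decoding with at most 50 new tokens; illustrative settings only.
>>> config = GenerationConfig(max_new_tokens=50, do_sample=False)
>>> generated, log_probs_gen, log_ratio = rollout_from_model.generate(
...     batch, generation_config=config)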

static logprobs_of_labels(logits, labels)[source]

Log probabilities of the labels.

These are calculated from the logits. The labels (token ids) are used to index the logits along the relevant dimension.
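
A minimal sketch of the underlying computation (log-softmax over the vocabulary dimension, then gathering the entries at the label indices); the exact logits/labels alignment used inside the class may differ:

>>> import torch
>>> import torch.nn.functional as F
>>>
>>> # Illustrative shapes: batch of 2, sequence length 5, vocabulary of 100 tokens.
>>> logits = torch.randn(2, 5, 100)
>>> labels = torch.randint(100, (2, 5))
>>>
>>> # Log probability of each label token under the model's distribution.
>>> log_probs = F.log_softmax(logits, dim=-1)
>>> logprobs = torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1)
>>> logprobs.shape
torch.Size([2, 5])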
