RolloutFromModel
- class torchrl.data.RolloutFromModel(model, ref_model, reward_model, kl_coef=0.1, max_new_tokens=50, score_clip=10.0, kl_scheduler: Optional[KLControllerBase] = None, num_steps: Optional[int] = None)
A class for performing rollouts with causal language models.
The wrapped model is assumed to take tokenized text as input and to predict the next token of a sequence given the n preceding tokens.
- Parameters:
model (transformers.Transformer) – the model to be used. Should have a generate() method.
ref_model (transformers.Transformer) – a frozen copy of model whose parameters stay at their initial values. It is used to compute a KL penalty on the reward, which stops the model from straying too far from the reference model during training.
reward_model (nn.Module or tensordict.nn.TensorDictModule) – a model which, given input_ids and attention_mask, calculates rewards for each token and end_scores (the reward for the final token in each sequence).
kl_coef (float, optional) – initial KL coefficient. Defaults to 0.1.
max_new_tokens (int, optional) – the maximum number of new tokens to generate. Defaults to 50.
score_clip (float, optional) – scores from the reward model are clipped to the range (-score_clip, score_clip). Defaults to 10.0.
kl_scheduler (KLControllerBase, optional) – the KL coefficient scheduler.
num_steps (int, optional) – number of steps between two optimizations.
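To make the interplay of kl_coef and score_clip concrete, here is a minimal pure-Python sketch of KL-penalized reward shaping for a single sequence. It is not taken from the TorchRL source: it assumes that the clipped reward-model score is paid only on the last generated token and that every token receives a penalty of -kl_coef * log_ratio, with log_ratio = log p_model - log p_ref. The helper name shaped_rewards is hypothetical.

```python
from typing import List, Tuple


def shaped_rewards(
    end_score: float,
    log_ratio: List[float],
    kl_coef: float = 0.1,
    score_clip: float = 10.0,
) -> Tuple[List[float], List[float], List[float]]:
    """Return (reward, reward_raw, reward_kl) for one generated sequence.

    Hypothetical helper. Assumptions: the clipped end score is paid only on
    the final generated token, and each token is penalised by
    -kl_coef * log_ratio, where log_ratio = log p_model - log p_ref.
    """
    n = len(log_ratio)
    if n == 0:
        return [], [], []
    # Clip the reward-model score to [-score_clip, score_clip]
    clipped = max(-score_clip, min(score_clip, end_score))
    # Raw reward: zero everywhere except the last generated token
    reward_raw = [0.0] * n
    reward_raw[-1] = clipped
    # KL penalty at every step discourages drifting from the reference model
    reward_kl = [-kl_coef * lr for lr in log_ratio]
    # Total reward is the sum of the two terms
    reward = [raw + kl for raw, kl in zip(reward_raw, reward_kl)]
    return reward, reward_raw, reward_kl


reward, reward_raw, reward_kl = shaped_rewards(end_score=12.0, log_ratio=[0.5, -0.2, 0.1])
print(reward_raw)  # [0.0, 0.0, 10.0]: the score 12.0 is clipped to 10.0
```

The three lists loosely mirror the ("next", "reward"), ("next", "reward_raw") and ("next", "reward_kl") entries of a rollout. A scheduler passed as kl_scheduler is meant to adjust the KL coefficient over training instead of keeping it fixed as this sketch does.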
Examples
>>> from tensordict.nn import TensorDictModule
>>> from torchrl.modules.models.rlhf import GPT2RewardModel
>>> from torchrl.data.rlhf.utils import RolloutFromModel
>>> from torchrl.data.rlhf.dataset import get_dataloader
>>> from torchrl.data.rlhf.prompt import PromptData
>>> from transformers import GPT2LMHeadModel
>>>
>>> dl = get_dataloader(
...     batch_size=4,
...     block_size=550,
...     tensorclass_type=PromptData,
...     device="cpu",
...     dataset_name="CarperAI/openai_summarize_tldr",
... )
>>> model = GPT2LMHeadModel.from_pretrained("gpt2")
>>> # we load ref_model with random weights so it differs from model
>>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class())
>>> reward_model = GPT2RewardModel(model_path="gpt2")
>>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
>>>
>>> batch = next(dl)
>>> rollout = rollout_from_model.rollout_from_data(batch)
>>> rollout
TensorDict(
    fields={
        action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False),
        attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
        input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
        next: TensorDict(
            fields={
                attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
                done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
                reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4, 50]),
            device=cpu,
            is_shared=False),
        sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([4, 50]),
    device=cpu,
    is_shared=False)
- create_rollout_td(batch, generated, log_probs, log_ratio)
A TensorDict wrapper for generated data.
This function takes a batch plus the generated tokens and replicates the tensordict structure that would have been obtained from a rollout with a TorchRL env that sampled one token each timestep.
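The unrolling described above can be pictured with a small pure-Python sketch for one prompt/completion pair: step t observes the prompt plus the first t generated tokens, takes the t-th generated token as its action, and its next observation appends that token. This is not TorchRL's implementation; it assumes right-padding to a fixed length with a hypothetical pad_id, the helper names (unroll_generation, _pad, _mask) are hypothetical, and tuple keys such as ("next", "input_ids") only mimic nested TensorDict keys.

```python
from typing import List


def _pad(tokens: List[int], length: int, pad_id: int) -> List[int]:
    # Right-pad a token list to a fixed length (padding side is an assumption)
    return tokens + [pad_id] * (length - len(tokens))


def _mask(n_real: int, length: int) -> List[bool]:
    # True for real tokens, False for padding
    return [True] * n_real + [False] * (length - n_real)


def unroll_generation(prompt: List[int], completion: List[int], pad_id: int = 0) -> List[dict]:
    """Split one prompt + completion into one transition per generated token."""
    total = len(prompt) + len(completion)
    steps = []
    for t, token in enumerate(completion):
        obs = prompt + completion[:t]  # what the model had read before step t
        nxt = obs + [token]            # observation after emitting the token
        steps.append(
            {
                "input_ids": _pad(obs, total, pad_id),
                "attention_mask": _mask(len(obs), total),
                "action": token,
                ("next", "input_ids"): _pad(nxt, total, pad_id),
                ("next", "attention_mask"): _mask(len(nxt), total),
            }
        )
    return steps


for step in unroll_generation(prompt=[11, 12], completion=[7, 8, 9]):
    print(step["action"], step["input_ids"])
```

Stacking such per-step entries over a batch gives arrays shaped [batch, steps, total_length], which is consistent with the [4, 50, 600] shapes in the rollout example (550 prompt tokens plus 50 generated ones).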
- Parameters:
batch (TensorDict) – a batch of data containing the original prompt together with a field “rindex” indicating the right index of the prompt.
generated (torch.Tensor) – tokenized prompt followed by generated tokens. This can be obtained by calling the generate method.
log_probs (torch.Tensor) – the log probabilities of the generated tokens. Can be obtained by calling the generate method.
log_ratio (torch.Tensor) – the log ratio of the probabilities of the generated tokens according to the generative model and the reference model. Can be obtained by calling the generate method.
- Returns:
A TensorDict with the following keys:
"action": the sequence of actions (generated tokens).
"input_ids": the input_ids passed to the generative model at each time step.
"attention_mask": the attention masks passed to the generative model at each time step.
"sample_log_prob": the log probability of each token during generation.
("next", "input_ids"): the sequence of tokens after generation. Makes up part of the inputs that will be used for generating the next token.
("next", "attention_mask"): the updated attention_mask after the token has been generated. Passed to the generative model on the next time step.
("next", "terminated"): Boolean array indicating whether we’ve reached a terminal state (either because we generated the EOS token or because we reached the token limit).
("next", "done"): Boolean array indicating whether we’ve reached a final state. Currently a copy of "terminated".
("next", "reward"): the reward received at each time step.
("next", "reward_raw"): the raw reward from the reward model, without the KL term. This is mainly for debugging and logging; it is not used in training.
("next", "reward_kl"): the KL term from the reward. This is mainly for debugging and logging; it is not used in training.
- Return type:
TensorDict
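The termination rule stated for ("next", "terminated") can be sketched per sequence in pure Python. This is a simplification under stated assumptions: eos_token_id is a hypothetical parameter, the helper name terminated_flags is hypothetical, and the list is cut at the first terminal step, whereas the rollout in the example keeps a fixed [4, 50] batch size.

```python
from typing import List


def terminated_flags(completion: List[int], eos_token_id: int, max_new_tokens: int) -> List[bool]:
    """Per-step terminal flags for one generated sequence (hypothetical helper).

    A step is terminal if it emitted the EOS token or if it is the
    max_new_tokens-th generated token; the sketch stops there.
    """
    flags = []
    for t, token in enumerate(completion[:max_new_tokens]):
        # Terminal on EOS or on reaching the token limit
        terminal = token == eos_token_id or t == max_new_tokens - 1
        flags.append(terminal)
        if terminal:
            break
    return flags


terminated = terminated_flags([5, 6, 2, 9], eos_token_id=2, max_new_tokens=50)
done = list(terminated)  # "done" is documented as a copy of "terminated"
```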
- generate(batch: PromptData, generation_config=None)
Generates a sequence of tokens from a batch of data sampled from the data collector.
- Parameters:
batch (PromptData) – the data to be used. Must have input_ids and prompt_rindex fields.
generation_config (GenerationConfig, optional) – the configuration for the call to generate.
- Returns:
generated (torch.Tensor): a [B x (Ti + To)] sequence of integers (tokens), where B is the batch size, Ti is the length of the input sequence and To is the length of the generated sequence.
log_probs_gen (torch.Tensor): the log-probabilities of the generated tokens.
log_ratio (torch.Tensor): the log ratio between the probabilities under the generative model and under the frozen reference model.
- Return type:
a tuple (generated, log_probs_gen, log_ratio) of torch.Tensor
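The relationship between generated, log_probs_gen and log_ratio can be illustrated with a pure-Python sketch for a single sequence. It assumes that log_ratio is the per-token difference log p_model - log p_ref of the generated tokens, each scored with that model's next-token logits; the helpers log_softmax and token_log_ratio and the toy logits are hypothetical, and real code would operate on batched torch tensors.

```python
import math
from typing import List, Tuple


def log_softmax(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax over a vocabulary
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum_exp for x in logits]


def token_log_ratio(
    model_logits: List[List[float]],
    ref_logits: List[List[float]],
    generated_tokens: List[int],
) -> Tuple[List[float], List[float]]:
    """Return (log_probs_gen, log_ratio) for the generated tokens of one sequence.

    model_logits[t] and ref_logits[t] are the next-token logits each model
    produced before generated_tokens[t] was emitted.
    """
    log_probs_gen, log_ratio = [], []
    for lm, lr, tok in zip(model_logits, ref_logits, generated_tokens):
        lp_model = log_softmax(lm)[tok]
        lp_ref = log_softmax(lr)[tok]
        log_probs_gen.append(lp_model)
        # Positive when the generative model favours the token more than the reference
        log_ratio.append(lp_model - lp_ref)
    return log_probs_gen, log_ratio


generated = [0, 1, 1, 0]  # Ti = 2 prompt tokens followed by To = 2 generated tokens
ti = 2
completion = generated[ti:]  # only the generated part is scored
log_probs_gen, log_ratio = token_log_ratio(
    model_logits=[[0.0, 2.0], [0.0, 1.0]],
    ref_logits=[[0.0, 0.0], [0.0, 0.0]],
    generated_tokens=completion,
)
```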