# Source code for torchrl.data.rlhf.utils
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import abc
import collections
import importlib
from typing import List, Tuple
import numpy as np
import torch
from tensordict import TensorDict
from torch import nn, Tensor
from torch.nn import functional as F
from torchrl.data.rlhf.prompt import PromptData
# True when the optional ``transformers`` package can be located in the current
# environment; RolloutFromModel.__init__ raises ImportError when this is False.
_has_transformers = importlib.util.find_spec("transformers") is not None
class KLControllerBase(abc.ABC):
    """Abstract interface shared by all KL-coefficient controllers.

    A controller owns the coefficient that scales the KL term of the reward.
    Concrete subclasses implement :meth:`update`, which receives the KL values
    observed since the previous update and returns the coefficient to use next.
    """

    @abc.abstractmethod
    def update(self, kl_values: List[float]) -> float:
        # Subclasses return the (possibly adjusted) KL coefficient.
        ...
class ConstantKLController(KLControllerBase):
    """KL controller whose coefficient never changes.

    Whatever values are passed to :meth:`update`, the coefficient supplied at
    construction time is kept as is.

    Keyword Arguments:
        kl_coef (float): The coefficient to multiply KL with when calculating the
            reward.
        model (nn.Module, optional): wrapped model that needs to be controlled.
            Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will
            be updated in-place.
    """

    def __init__(
        self,
        *,
        kl_coef: float = None,
        model: nn.Module | None = None,
    ):
        # Reject a model that cannot receive the coefficient before storing anything.
        if model is not None and not hasattr(model, "kl_coef"):
            raise AttributeError(
                "Model input to ConstantKLController doesn't have attribute 'kl_coef'"
            )
        self.model = model
        self.coef = kl_coef
        if model is not None:
            self.model.kl_coef = self.coef

    def update(self, kl_values: List[float] = None) -> float:
        """Re-apply the fixed coefficient to the wrapped model and return it.

        ``kl_values`` is accepted for interface compatibility and ignored.
        """
        controlled = self.model
        if controlled is not None:
            controlled.kl_coef = self.coef
        return self.coef
class AdaptiveKLController(KLControllerBase):
    """Adaptive KL Controller as described in Ziegler et al. "Fine-Tuning Language Models from Human Preferences".

    Keyword Arguments:
        init_kl_coef (float): The starting value of the coefficient.
        target (float): The target KL value. When the observed KL is smaller, the
            coefficient is decreased, thereby relaxing the KL penalty in the training
            objective and allowing the model to stray further from the reference model.
            When the observed KL is greater than the target, the KL coefficient is
            increased, thereby pulling the model back towards the reference model.
        horizon (int): Scaling factor to control how aggressively we update the
            coefficient.
        model (nn.Module, optional): wrapped model that needs to be controlled.
            Must have an attribute ``"kl_coef"``. If provided, the ``"kl_coef"`` will
            be updated in-place.

    Reference: Section 2.2 https://arxiv.org/pdf/1909.08593.pdf#page=2
    Source: https://github.com/openai/lm-human-preferences/blob/master/lm_human_preferences/train_policy.py
    """

    def __init__(
        self,
        *,
        init_kl_coef: float,
        target: float,
        horizon: int,
        model: nn.Module | None = None,
    ):
        self.model = model
        self.coef = init_kl_coef
        self.target = target
        self.horizon = horizon
        if model is not None:
            self.model.kl_coef = self.coef

    def update(self, kl_values: List[float]):
        """Update ``self.coef`` adaptively.

        Arguments:
            kl_values (sequence of float): The KL reward terms collected since the
                last update. They are expected to be already scaled by
                ``-self.coef``; the scaling is undone here before comparing
                against ``target``.

        Returns:
            The updated coefficient, as a Python float. When ``kl_values`` is
            empty the coefficient is returned unchanged.

        Raises:
            ValueError: if ``kl_values`` is ``None``.
        """
        if kl_values is None:
            raise ValueError(
                f"The kl_values were not provided to {type(self)}. "
                f"Make sure these values are provided for the scheduler to be updated "
                f"accordingly. "
            )
        n_steps = len(kl_values)
        if n_steps == 0:
            # Nothing was observed: the mean of an empty sequence is NaN, which
            # would poison the coefficient for good. Keep the current value.
            if self.model is not None:
                self.model.kl_coef = self.coef
            return self.coef
        # renormalize kls; ``.item()`` yields a plain Python float so that the
        # coefficient does not silently turn into a tensor / numpy scalar
        mean_kl_term = torch.as_tensor(kl_values).float().mean().item()
        kl_value = -mean_kl_term / self.coef
        # ϵₜ, clipped to [-0.2, 0.2]
        proportional_error = min(max(kl_value / self.target - 1, -0.2), 0.2)
        mult = 1 + proportional_error * n_steps / self.horizon
        self.coef *= mult  # βₜ₊₁
        if self.model is not None:
            self.model.kl_coef = self.coef
        return self.coef
class RolloutFromModel:
    """A class for performing rollouts with causal language models.

    It is assumed that the model this class wraps takes as input tokenized text and
    whose task is to predict the next word in a sentence having read the n previous
    words.

    Args:
        model (transformers.Transformer): the model to be used. Should have a
            :meth:`generate` method.
        ref_model (transformers.Transformer): a frozen version of ``model``
            where params are in their initial configuration. This is used to compute a
            KL penalty for the reward, to stop the model from straying too far from the
            reference model during training.
        reward_model: (nn.Module, tensordict.nn.TensorDictModule): a model which, given
            ``input_ids`` and ``attention_mask``, calculates rewards for each token and
            end_scores (the reward for the final token in each sequence).
        kl_coef: (float, optional): initial kl coefficient. Defaults to 0.1.
        max_new_tokens (int, optional): the maximum length of the sequence.
            Defaults to 50.
        score_clip (float, optional): Scores from the reward model are clipped to the
            range ``(-score_clip, score_clip)``. Defaults to 10.
        kl_scheduler (KLControllerBase, optional): the KL coefficient scheduler.
        num_steps (int, optional): number of steps between two optimization.
            When given, only the last ``num_steps`` KL values are kept for the
            scheduler.

    Examples:
        >>> from tensordict.nn import TensorDictModule
        >>> from torchrl.modules.models.rlhf import GPT2RewardModel
        >>> from torchrl.data.rlhf.utils import RolloutFromModel
        >>> from torchrl.data.rlhf.dataset import get_dataloader
        >>> from torchrl.data.rlhf.prompt import PromptData
        >>> from transformers import GPT2LMHeadModel
        >>>
        >>> dl = get_dataloader(
        ...     batch_size=4,
        ...     block_size=550,
        ...     tensorclass_type=PromptData,
        ...     device="cpu",
        ...     dataset_name="CarperAI/openai_summarize_tldr",
        ... )
        >>> model = GPT2LMHeadModel.from_pretrained("gpt2")
        >>> # we load ref_model with random weights so it differs from model
        >>> ref_model = GPT2LMHeadModel(GPT2LMHeadModel.config_class())
        >>> reward_model = GPT2RewardModel(model_path="gpt2")
        >>> rollout_from_model = RolloutFromModel(model, ref_model, reward_model)
        >>>
        >>> batch = next(dl)
        >>> rollout = rollout_from_model.rollout_from_data(batch)
        >>> rollout
        TensorDict(
            fields={
                action: Tensor(shape=torch.Size([4, 50]), device=cpu, dtype=torch.int64, is_shared=False),
                attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
                input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
                next: TensorDict(
                    fields={
                        attention_mask: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.bool, is_shared=False),
                        done: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        input_ids: Tensor(shape=torch.Size([4, 50, 600]), device=cpu, dtype=torch.int64, is_shared=False),
                        reward: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward_kl: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward_raw: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([4, 50]),
                    device=cpu,
                    is_shared=False),
                sample_log_prob: Tensor(shape=torch.Size([4, 50, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([4, 50]),
            device=cpu,
            is_shared=False)

    """

    # Token id used both as end-of-sequence marker and as padding value.
    EOS_TOKEN_ID = 50256

    def __init__(
        self,
        model,
        ref_model,
        reward_model,
        kl_coef=0.1,
        max_new_tokens=50,
        score_clip=10.0,
        kl_scheduler: KLControllerBase | None = None,
        num_steps: int | None = None,
    ):
        if not _has_transformers:
            raise ImportError(
                "transformers module couldn't be found. Make sure it is installed in your "
                "environment."
            )
        self.model = model
        self.ref_model = ref_model
        self.reward_model = reward_model
        self.max_new_tokens = max_new_tokens
        self.score_clip = score_clip
        self.kl_coef = kl_coef
        self.kl_scheduler = kl_scheduler
        if num_steps is not None:
            # bounded history: only the last ``num_steps`` KL values are kept
            self._kl_queue = collections.deque(maxlen=num_steps)
        else:
            # we create a list. Value appended to it will be detached scalars so very cheap to store,
            # even if the update is not called.
            # The scheduler update will take care of erasing these values.
            self._kl_queue = []

    @torch.no_grad()
    def rollout_from_data(self, batch):
        """Generate from ``batch`` and wrap the result as a rollout TensorDict.

        Chains :meth:`generate` and :meth:`create_rollout_td`.
        """
        generated, log_probs, log_ratio = self.generate(batch)
        return self.create_rollout_td(batch, generated, log_probs, log_ratio)

    @torch.no_grad()
    def create_rollout_td(self, batch, generated, log_probs, log_ratio):
        """A TensorDict wrapper for generated data.

        This function takes a batch plus the generated tokens and replicates the
        tensordict structure that would have been obtained from a rollout with a TorchRL
        env that sampled one token each timestep.

        Args:
            batch (TensorDict): A batch of data containing the original prompt together with a field
                "prompt_rindex" indicating the right index of the prompt.
            generated (torch.Tensor): Tokenized prompt followed by generated tokens. This can be obtained
                by calling the ``generate`` method.
            log_probs (torch.Tensor): The log probabilities of the generated tokens. Can be obtained by
                calling the ``generate`` method.
            log_ratio (torch.Tensor): The log ratio of the probabilities of the generated tokens
                according to the generative model and the reference model. Can be
                obtained by calling the ``generate`` method.

        Returns:
            A :class:`~tensordict.TensorDict` with the following keys:

            - ``"action"``: the sequence of actions (generated tokens)
            - ``"input_ids"``: the input_ids passed to the generative model at each time
              step.
            - ``"attention_mask"``: the attention_masks passed to the generative model at
              each time step
            - ``"sample_log_prob"``: the log probability of each token during generation
            - ``("next", "input_ids")``: the sequence of tokens after generation. Makes up
              part of the inputs that will be used for generating the next token.
            - ``("next", "attention_mask")``: updated attention_mask after token has been
              generated. Passed to the generative model on the next time step
            - ``("next", "terminated")``: Boolean array indicating whether we've reached a
              terminal state because an EOS token was generated before the
              token limit
            - ``("next", "done")``: Boolean array indicating whether we've reached a
              final state (terminated, or the token limit was reached).
            - ``("next", "reward")``: The reward received at each time step
            - ``("next", "reward_raw")``: The raw reward from the reward model, without the
              KL term. This is mainly for debugging and logging, it is not used in
              training
            - ``("next", "reward_kl")``: The KL term from the reward. This is mainly for
              debugging and logging, it is not used in training.

        Side effects:
            The mean of the KL reward term is appended to the internal queue that
            :meth:`step_scheduler` consumes.

        """
        rollout_generated = self._get_rollout_generated(generated, batch)
        # every EOS/padding position is masked out of attention
        rollout_attention_mask = rollout_generated != self.EOS_TOKEN_ID

        done, terminated = self._get_done_status(generated, batch)
        action = self._get_action(generated, batch)
        end_scores, end_scores_labels = self._get_end_scores(
            rollout_generated, rollout_attention_mask, batch
        )

        # the reward is zero except for the timestep where we reached a stopping condition
        clipped_scores = torch.clip(
            end_scores - end_scores_labels, -self.score_clip, self.score_clip
        )
        reward_raw = clipped_scores.unsqueeze(-1).unsqueeze(-1)
        reward_raw = reward_raw * done
        reward_kl = -self.kl_coef * log_ratio.unsqueeze(-1)
        reward = reward_raw + reward_kl
        td = {
            "action": action,
            "input_ids": rollout_generated[:, :-1].clone(),
            "attention_mask": rollout_attention_mask[:, :-1].clone(),
            "sample_log_prob": log_probs,
            "next": {
                "input_ids": rollout_generated[:, 1:].clone(),
                "attention_mask": rollout_attention_mask[:, 1:].clone(),
                "done": done,
                "terminated": terminated,
                "reward": reward,
                "reward_raw": reward_raw,
                "reward_kl": reward_kl,
            },
        }
        # keep a cheap scalar record of the KL term for the scheduler
        self._kl_queue.append(reward_kl.detach().mean())
        return TensorDict(
            td, batch_size=done.shape[:2], device=generated.device
        ).refine_names(..., "time")

    def _get_rollout_generated(self, generated, batch):
        # stack the individual timesteps during generation into a single tensor:
        # at step ``i`` only the first ``prompt_rindex + i`` tokens of a row are
        # visible, everything after is replaced with EOS.
        positions = torch.arange(generated.shape[1], device=generated.device)
        steps = torch.arange(self.max_new_tokens + 1, device=generated.device)
        per_sequence = []
        for rindex, row in zip(batch.prompt_rindex, generated):
            # [steps, 1] limits broadcast against [1, positions]
            visible = positions.unsqueeze(0) < (rindex + steps).unsqueeze(-1)
            per_sequence.append(torch.where(visible, row, self.EOS_TOKEN_ID))
        return torch.stack(per_sequence)

    def _get_done_status(self, generated, batch):
        # done is True when we either first sample an EOS token or reach the maximum number
        # of generated tokens
        last_step = self.max_new_tokens - 1
        n_generated = (generated != self.EOS_TOKEN_ID).sum(dim=-1) - batch.prompt_rindex
        # clamp with a Python int: no extra tensor that could live on another device
        done_idx = n_generated.clamp(max=last_step)
        truncated_idx = torch.full_like(
            done_idx, last_step, device=generated.device
        )
        zeros = torch.zeros(
            done_idx.numel(),
            self.max_new_tokens,
            dtype=torch.bool,
            device=generated.device,
        )
        truncated = zeros.scatter(-1, truncated_idx.unsqueeze(-1), 1).unsqueeze(-1)
        done = zeros.scatter(-1, done_idx.unsqueeze(-1), 1).unsqueeze(-1)
        terminated = (
            done & ~truncated
        )  # we assume that if it's not truncated, it was terminated
        return truncated | terminated, terminated

    def _get_action(self, generated, batch):
        # the sequence of actions for each trajectory is just the generated token ids
        action_idx = torch.arange(self.max_new_tokens, device=generated.device)
        action_idx = action_idx + batch.prompt_rindex.unsqueeze(-1)
        return generated.gather(-1, action_idx)

    def _get_end_scores(self, rollout_generated, rollout_attention_mask, batch):
        # calculate the reward for the finished sequence
        _, end_scores = self.reward_model(
            input_ids=rollout_generated[:, -1],
            attention_mask=rollout_attention_mask[:, -1],
        )
        # and the reward of the batch's own input_ids, used as a baseline by the caller
        _, end_scores_labels = self.reward_model(
            input_ids=batch.input_ids,
            attention_mask=batch.attention_mask,
        )
        return end_scores, end_scores_labels

    @classmethod
    def _padded_right_to_left(cls, tensor, *, eos_token_id=None, dim=1):
        # move the non-EOS tokens of each row to the right end, EOS padding to the left
        if eos_token_id is None:
            eos_token_id = cls.EOS_TOKEN_ID
        mask = tensor != eos_token_id
        out = torch.full_like(tensor, eos_token_id)
        out[mask.flip(dim)] = tensor[mask]
        return out

    @classmethod
    def _padded_left_to_right(
        cls, tensor, *, sequence_length=None, eos_token_id=None, dim=1
    ):
        # some care must be taken here, because generated sequences may have both left
        # and right padding, and generation may also stop early if all sequences in the
        # batch reached EOS before reaching the token limit
        if sequence_length is None:
            sequence_length = tensor.size(dim)
        if dim < 0:
            dim = tensor.ndim + dim
        if eos_token_id is None:
            eos_token_id = cls.EOS_TOKEN_ID
        mask = tensor != eos_token_id
        # convert [0, 0, 1, 1, 0] to [0, 0, 1, 1, 1] to avoid right eos
        mask = ~((~mask).to(torch.uint8).cumprod(dim).bool())
        shape = list(mask.shape)
        shape[dim] = sequence_length
        # match the input dtype (as ``_padded_right_to_left`` does through
        # ``full_like``); otherwise the masked assignment below fails for
        # anything but int64 inputs
        out = torch.full(
            torch.Size(shape), eos_token_id, dtype=tensor.dtype, device=tensor.device
        )
        index = (slice(None),) * dim + (slice(tensor.size(dim)),)
        # ``out[index]`` is a view, so the masked write lands in ``out``
        out[index][mask.flip(dim)] = tensor[mask]
        return out

    @property
    def _default_conf(self):
        # generation settings used when ``generate`` receives no explicit config
        from transformers import GenerationConfig

        return GenerationConfig(
            pad_token_id=self.EOS_TOKEN_ID,
            max_new_tokens=self.max_new_tokens,
            return_dict_in_generate=True,
            output_scores=True,
            do_sample=True,
        )

    def _get_scores(
        self, scores: Tuple, generated_tokens: Tensor = None, use_max=False, pad_to=None
    ):
        # ``scores`` holds one tensor per generated step; stack them on dim 1
        scores = torch.stack(scores, 1)
        if scores.shape[1] != self.max_new_tokens:
            # generation stopped early: fill the missing steps with -inf
            scores = F.pad(
                scores,
                (0, 0, 0, self.max_new_tokens - scores.shape[1]),
                value=float("-inf"),
            )
        scores = F.log_softmax(scores, dim=-1)
        if use_max:
            scores = scores.max(dim=-1).values
        else:
            index = generated_tokens.unsqueeze(-1)
            scores = torch.gather(scores, dim=-1, index=index)
        if pad_to is not None:
            pad = pad_to - scores.shape[1]
            # ``pad`` is measured on dim 1, so dim 1 must be the one padded.
            # F.pad specs start from the last dimension: skip any trailing dims
            # (the gather branch keeps a trailing singleton dim).
            trailing = (0, 0) * (scores.ndim - 2)
            return F.pad(scores, trailing + (0, pad), value=-float("inf"))
        return scores

    @staticmethod
    def logprobs_of_labels(logits, labels):
        """Log probabilities of the labels.

        These are calculated from the logits. The labels (token ids) are used to index
        the logits along the relevant dimension.
        """
        logprobs = F.log_softmax(logits, dim=-1)
        logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1))
        return logprobs_labels.squeeze(-1)

    @torch.no_grad()
    def _log_ratio(self, generated, prompt_rindex):
        # get the scores and normalise for log probabilities
        attention_mask = (generated != self.EOS_TOKEN_ID).bool()
        logits = self.model(
            input_ids=generated, attention_mask=attention_mask, return_dict=True
        ).logits
        # logits at position t score the token at position t + 1
        logprobs = self.logprobs_of_labels(logits[:, :-1], generated[:, 1:])
        ref_logits = self.ref_model(
            input_ids=generated.to(self.ref_model.device),
            attention_mask=attention_mask.to(self.ref_model.device),
            return_dict=True,
        ).logits.to(logits.device)
        ref_logprobs = self.logprobs_of_labels(ref_logits[:, :-1], generated[:, 1:])
        log_ratio = logprobs - ref_logprobs
        log_ratio = log_ratio.masked_fill(~attention_mask[:, :-1], 0)
        # keep only the window that corresponds to the generated tokens
        log_ratio = torch.stack(
            [
                row[rindex - 1 : rindex + self.max_new_tokens - 1]
                for row, rindex in zip(log_ratio, prompt_rindex)
            ],
            dim=0,
        )
        return log_ratio

    def _get_generated_tokens(self, generated, rindex):
        # extracts the generated tokens from the full sequence of prompt + generated
        idx = torch.arange(generated.shape[1], device=generated.device)
        rindex = rindex.unsqueeze(-1)
        mask = (idx >= rindex) & (idx < rindex + self.max_new_tokens)
        return generated[mask].reshape(-1, self.max_new_tokens)

    @torch.no_grad()
    def generate(self, batch: PromptData, generation_config=None):
        """Generates a sequence of tokens from a batch of data sampled from the data collector.

        Args:
            batch (PromptData): the data to be used. Must have ``input_ids``
                and ``prompt_rindex`` fields.
            generation_config (GenerationConfig, optional): the configuration for the
                call to generate.

        Returns:
            generated (torch.Tensor): a [B x (Ti +To)] sequence of integers (tokens),
                where Ti is the length of the input sequence and To is the length
                of the generated sequence.
            log_probs_gen: the log-probabilities of the token generated.
            log_ratio: the log ratio between probabilities under the generative
                model and the frozen version.

        """
        input_ids = batch.mask_label().input_ids

        # move padding tokens to left pad
        # huggingface models expect left padding for generation
        input_ids = self._padded_right_to_left(input_ids)

        # generate and capture scores
        if generation_config is None:
            generation_config = self._default_conf

        attention_mask = (input_ids != self.EOS_TOKEN_ID).bool()
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
        )
        samples = outputs.sequences

        # we'll insert generated tokens into a tensor prepopulated with padding tokens,
        # thereby moving back to right padding for reward model
        generated = self._padded_left_to_right(
            samples,
            sequence_length=input_ids.shape[1] + self.max_new_tokens,
            eos_token_id=self.EOS_TOKEN_ID,
        )
        generated_tokens = self._get_generated_tokens(generated, batch.prompt_rindex)
        # get the scores and normalise for log probabilities
        log_probs_gen = self._get_scores(outputs.scores, generated_tokens)

        log_ratio = self._log_ratio(generated, batch.prompt_rindex)
        return generated, log_probs_gen, log_ratio

    def step_scheduler(self):
        """Update ``kl_coef`` from the collected KL values, then forget them.

        Requires a ``kl_scheduler`` to have been provided at construction time.
        """
        # recover true kl
        self.kl_coef = self.kl_scheduler.update(self._kl_queue)
        if isinstance(self._kl_queue, (list, collections.deque)):
            # remove all values in one call
            self._kl_queue.clear()