GAILLoss

class torchrl.objectives.GAILLoss(*args, **kwargs)

TorchRL implementation of the Generative Adversarial Imitation Learning (GAIL) loss.

Presented in “Generative Adversarial Imitation Learning” <https://arxiv.org/pdf/1606.03476>

Parameters:

discriminator_network (TensorDictModule) – the discriminator network

Keyword Arguments:
  • use_grad_penalty (bool, optional) – Whether to use gradient penalty. Default: False.

  • gp_lambda (float, optional) – Weight (lambda) applied to the gradient penalty term. Default: 10.

  • reduction (str, optional) – Specifies the reduction to apply to the output: "none" | "mean" | "sum". "none": no reduction will be applied, "mean": the sum of the output will be divided by the number of elements in the output, "sum": the output will be summed. Default: "mean".
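The three reduction modes can be illustrated with a small plain-PyTorch sketch. The helper name reduce_loss and the sample values are hypothetical and are not part of TorchRL; the sketch only mirrors the behaviour described for the reduction argument.

```python
import torch


def reduce_loss(loss: torch.Tensor, reduction: str) -> torch.Tensor:
    """Hypothetical helper mirroring the documented `reduction` options."""
    if reduction == "none":
        return loss          # keep one value per element
    if reduction == "mean":
        return loss.mean()   # sum divided by the number of elements
    if reduction == "sum":
        return loss.sum()    # total over all elements
    raise ValueError(f"Unknown reduction: {reduction!r}")


# Made-up per-element losses with shape (4, 1)
per_element = torch.tensor([[0.5], [1.5], [1.0], [3.0]])

none_out = reduce_loss(per_element, "none")  # shape (4, 1), unchanged
mean_out = reduce_loss(per_element, "mean")  # scalar tensor
sum_out = reduce_loss(per_element, "sum")    # scalar tensor
```

With "none", the caller keeps per-sample values, which is useful for custom weighting before averaging.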

forward(tensordict: TensorDictBase = None) → TensorDictBase

The forward method.

Computes the discriminator loss and, if use_grad_penalty is True, the gradient penalty. When the penalty is enabled, a detached copy of the gradient penalty loss is also returned for logging purposes. The keys expected in the input tensordict and the keys written to the output are listed in the class’s “in_keys” and “out_keys” attributes.
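The computation described above can be sketched in plain PyTorch. This is a minimal illustration and not the TorchRL implementation. It rests on several assumptions: the small discriminator network, the random stand-in data, the label convention (expert = 1, policy = 0), and the WGAN-GP-style penalty on interpolated samples are all illustrative choices.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
obs_dim, act_dim, batch = 4, 2, 8

# Hypothetical discriminator: (observation, action) -> probability in (0, 1)
discriminator = nn.Sequential(
    nn.Linear(obs_dim + act_dim, 16),
    nn.Tanh(),
    nn.Linear(16, 1),
    nn.Sigmoid(),
)

# Random stand-ins for concatenated (observation, action) pairs
expert = torch.randn(batch, obs_dim + act_dim)
policy = torch.randn(batch, obs_dim + act_dim)

expert_pred = discriminator(expert)
policy_pred = discriminator(policy)

# Assumed label convention: expert samples -> 1, policy samples -> 0
bce = F.binary_cross_entropy(
    expert_pred, torch.ones_like(expert_pred), reduction="none"
) + F.binary_cross_entropy(
    policy_pred, torch.zeros_like(policy_pred), reduction="none"
)

# Assumed WGAN-GP-style penalty: push the gradient norm of the
# discriminator output w.r.t. interpolated inputs towards 1
eps = torch.rand(batch, 1)
mixed = (eps * expert + (1 - eps) * policy).requires_grad_(True)
mixed_pred = discriminator(mixed)
grads = torch.autograd.grad(mixed_pred.sum(), mixed, create_graph=True)[0]
gp_lambda = 10.0
gp_loss = gp_lambda * (grads.norm(2, dim=-1) - 1).pow(2).mean()

loss = bce.mean() + gp_loss    # "mean" reduction plus the penalty
logged_gp = gp_loss.detach()   # detached copy, for logging only
loss.backward()                # gradients reach the discriminator parameters
```

The detached copy carries no graph, so logging it does not keep the autograd history alive or contribute gradients.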
