{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# NestedTensors\n\nNestedTensors are similar to regular tensors, except for their shape:\n\n* for a regular tensor, each dimension has a size\n\n* for a nestedtensor, not all dimensions have regular sizes; some of them are jagged\n\nNestedtensors are a natural solution for representing sequential data within various domains:\n\n* in NLP, sentences can have variable lengths, so a batch of sentences forms a nestedtensor\n\n* in CV, images can have variable shapes, so a batch of images forms a nestedtensor\n\nIn this tutorial, we will demonstrate basic usage of nestedtensors and motivate their usefulness\nfor operating on sequential data of varying lengths with a real-world example.\n\nNestedTensors are currently a prototype feature and are subject to change.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\nimport torch.nn.functional as F\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## NestedTensor Initialization\n\nFrom the Python frontend, a nestedtensor can be created from a list of tensors.\nWe denote nt[i] as the ith tensor component of a nestedtensor.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"nt = torch.nested.nested_tensor([torch.arange(12).reshape(\n 2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)\nprint(f\"{nt=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"By padding every underlying tensor to the same shape,\na nestedtensor can be converted to a regular tensor.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)\nprint(f\"{padded_out_tensor=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"All tensors possess an attribute that indicates whether they are nested:\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print(f\"nt is nested: {nt.is_nested}\")\nprint(f\"padded_out_tensor is nested: {padded_out_tensor.is_nested}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is common to construct nestedtensors from batches of irregularly shaped tensors;\nthat is, dimension 0 is assumed to be the batch dimension.\nIndexing into dimension 0 gives back the corresponding underlying tensor component.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"print(\"First underlying tensor component:\", nt[0], sep='\\n')\nprint(\"last column of 2nd underlying tensor component:\", nt[1, :, -1], sep='\\n')\n\n# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.\nprint(f\"First underlying tensor component is nested: {nt[0].is_nested}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An important note is that slicing along dimension 0 is not yet supported,\nwhich means it is not currently possible to construct a view that combines the underlying\ntensor components.\n\n"
]
},
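{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# A quick sanity check (a sketch; the exact error type and message may change\n# while nestedtensors remain a prototype feature):\ntry:\n    nt[0:1]\nexcept Exception as e:\n    print(f\"Slicing dimension 0 failed: {e}\")\nelse:\n    print(\"Slicing dimension 0 succeeded (support may have been added)\")"
]
},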
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Nested Tensor Operations\n\nAs each operation must be explicitly implemented for nestedtensors,\noperation coverage for nestedtensors is currently narrower than that of regular tensors.\nFor now, only basic operations such as index, dropout, softmax, transpose, reshape, linear, and bmm are covered.\nHowever, coverage is being expanded.\nIf you need certain operations, please file an [issue](https://github.com/pytorch/pytorch)\nto help us prioritize coverage.\n\n**reshape**\n\nThe reshape op changes the shape of a tensor.\nIts full semantics for regular tensors can be found\n[here](https://pytorch.org/docs/stable/generated/torch.reshape.html).\nFor regular tensors, when specifying the new shape,\na single dimension may be -1, in which case it is inferred\nfrom the remaining dimensions and the number of elements.\n\nThe semantics for nestedtensors are similar, except that -1 no longer infers a size;\ninstead, it inherits the old size (here 2 for ``nt[0]`` and 3 for ``nt[1]``).\n-1 is the only legal size to specify for a jagged dimension.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"nt_reshaped = nt.reshape(2, -1, 2, 3)\nprint(f\"{nt_reshaped=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**transpose**\n\nThe transpose op swaps two dimensions of a tensor.\nIts full semantics can be found\n[here](https://pytorch.org/docs/stable/generated/torch.transpose.html).\nNote that for nestedtensors dimension 0 is special;\nit is assumed to be the batch dimension,\nso transposes involving nestedtensor dimension 0 are not supported.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"nt_transposed = nt_reshaped.transpose(1, 2)\nprint(f\"{nt_transposed=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**others**\n\nOther operations have the same semantics as for regular tensors.\nApplying the operation on a nestedtensor is equivalent to\napplying the operation to the underlying tensor components,\nwith the result being a nestedtensor as well.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)\nnt3 = torch.matmul(nt_transposed, nt_mm)\nprint(f\"Result of Matmul:\\n {nt3}\")\n\nnt4 = F.dropout(nt3, 0.1)\nprint(f\"Result of Dropout:\\n {nt4}\")\n\nnt5 = F.softmax(nt4, -1)\nprint(f\"Result of Softmax:\\n {nt5}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Why Nested Tensor\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"When data is sequential, it is often the case that each sample has a different length.\nFor example, in a batch of sentences, each sentence has a different number of words.\nA common technique for handling varying sequences is to manually pad each data tensor\nto the same shape in order to form a batch.\nFor example, suppose we have 2 sentences with different lengths and a vocabulary.\nIn order to represent them as a single tensor, we pad with 0 to the max length in the batch.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sentences = [[\"goodbye\", \"padding\"],\n [\"embrace\", \"nested\", \"tensor\"]]\nvocabulary = {\"goodbye\": 1.0, \"padding\": 2.0,\n \"embrace\": 3.0, \"nested\": 4.0, \"tensor\": 5.0}\npadded_sentences = torch.tensor([[1.0, 2.0, 0.0],\n [3.0, 4.0, 5.0]])\nnested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),\n torch.tensor([3.0, 4.0, 5.0])])\nprint(f\"{padded_sentences=}\")\nprint(f\"{nested_sentences=}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This technique of padding a batch of data to its max length is not optimal.\nThe padded data is not needed for computation and wastes memory by allocating\nlarger tensors than necessary.\nFurther, not all operations have the same semantics when applied to padded data.\nFor matrix multiplications, one needs to pad with 0 in order to ignore the padded entries,\nwhile for softmax one has to pad with -inf to ignore specific entries.\n\n"
]
},
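{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of the memory overhead, we can compute the fraction of a\npadded batch that is pure padding (the lengths below are hypothetical):\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# hypothetical sentence lengths; the padded batch has shape (batch, max_length)\nlengths = [2, 3, 5, 7]\nmax_length = max(lengths)\npadded_entries = len(lengths) * max_length\nreal_entries = sum(lengths)\nwaste = 1 - real_entries / padded_entries\nprint(f\"{waste:.0%} of the padded batch is padding\")"
]
},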
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float(\"-inf\")],\n [3.0, 4.0, 5.0]])\nprint(F.softmax(padded_sentences_for_softmax, -1))\nprint(F.softmax(nested_sentences, -1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us take a look at a practical example: the multi-head attention component\nutilized in [Transformers](https://arxiv.org/pdf/1706.03762.pdf).\nThe nestedtensor version is straightforward.\n\n"
]
},
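{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reminder, each attention head computes scaled dot-product attention\n(from the Transformer paper; $d_k$ is the per-head embedding size):\n\n$$\\mathrm{Attention}(Q, K, V) = \\mathrm{softmax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n\n"
]
},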
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import math\n\ndef mha_nested(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,\n W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,\n b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,\n dropout_p: float = 0.0) -> torch.Tensor:\n \"\"\"Compute multi-head attention with nested tensors.\n Args:\n query (torch.Tensor): query of shape (N, L_t, E_q)\n key (torch.Tensor): key of shape (N, L_s, E_k)\n value (torch.Tensor): value of shape (N, L_s, E_v)\n nheads (int): number of heads in multi-head attention\n W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)\n W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)\n W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)\n W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)\n b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.\n b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.\n b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.\n b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.\n dropout_p (float, optional): Dropout probability. 
Defaults to 0.0.\n\n Where:\n N is the batch size\n L_t is the target sequence length (jagged)\n L_s is the source sequence length (jagged)\n E_q is the embedding size for query\n E_k is the embedding size for key\n E_v is the embedding size for value\n E_total is the embedding size for all heads combined\n E_out is the output embedding size\n Returns:\n torch.Tensor: Output of shape (N, L_t, E_out)\n \"\"\"\n\n N = query.size(0)\n E_total = W_q.size(0)\n assert E_total % nheads == 0, \"Embedding dim is not divisible by nheads\"\n E_head = E_total // nheads\n\n # apply input projection\n # (N, L_t, E_q) -> (N, L_t, E_total)\n query = F.linear(query, W_q, b_q)\n # (N, L_s, E_k) -> (N, L_s, E_total)\n key = F.linear(key, W_k, b_k)\n # (N, L_s, E_v) -> (N, L_s, E_total)\n value = F.linear(value, W_v, b_v)\n\n # reshape query, key, value to separate by head\n # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)\n query = query.reshape(N, -1, nheads, E_head).transpose(1, 2)\n # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)\n key = key.reshape(N, -1, nheads, E_head).transpose(1, 2)\n # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)\n value = value.reshape(N, -1, nheads, E_head).transpose(1, 2)\n\n # query matmul key^T\n # (N, nheads, L_t, E_head) x (N, nheads, L_s, E_head)^T -> (N, nheads, L_t, L_s)\n keyT = key.transpose(-1, -2)\n attn_weights = torch.matmul(query, keyT)\n\n # scale down\n attn_weights = attn_weights * (1.0 / math.sqrt(E_head))\n\n # softmax\n attn_weights = F.softmax(attn_weights, dim=-1)\n\n # dropout\n if dropout_p > 0.0:\n attn_weights = F.dropout(attn_weights, p=dropout_p)\n\n # attention_weights matmul value\n # (N, nheads, L_t, L_s) x (N, nheads, L_s, E_head) -> (N, nheads, L_t, E_head)\n attn_output = torch.matmul(attn_weights, value)\n\n # merge heads\n # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)\n attn_output = attn_output.transpose(1, 
2).reshape(N, -1, E_total)\n\n # apply output projection\n # (N, L_t, E_total) -> (N, L_t, E_out)\n attn_output = F.linear(attn_output, W_out, b_out)\n\n return attn_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The 0-padded tensor version additionally requires masks\nfor more complicated treatment of padded entries.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def mha_padded(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, nheads: int,\n attn_mask_q: torch.Tensor, attn_mask_kv: torch.Tensor,\n W_q: torch.Tensor, W_k: torch.Tensor, W_v: torch.Tensor, W_out: torch.Tensor,\n b_q: torch.Tensor = None, b_k: torch.Tensor = None, b_v: torch.Tensor = None, b_out: torch.Tensor = None,\n dropout_p: float = 0.0) -> torch.Tensor:\n \"\"\"Compute multi-head attention for padded out dense tensors.\n\n Args:\n query (torch.Tensor): query of shape (N, L_t, E_q)\n key (torch.Tensor): key of shape (N, L_s, E_k)\n value (torch.Tensor): value of shape (N, L_s, E_v)\n nheads (int): number of heads in multi-head attention\n attn_mask_q (torch.Tensor): boolean mask indicating locations that should not take part in attention for query, shape (N, L_t)\n attn_mask_kv (torch.Tensor): boolean mask indicating locations that should not take part in attention for key and value, shape (N, L_s)\n W_q (torch.Tensor): Weight for query input projection of shape (E_total, E_q)\n W_k (torch.Tensor): Weight for key input projection of shape (E_total, E_k)\n W_v (torch.Tensor): Weight for value input projection of shape (E_total, E_v)\n W_out (torch.Tensor): Weight for output projection of shape (E_out, E_total)\n b_q (torch.Tensor, optional): Bias for query input projection of shape E_total. Defaults to None.\n b_k (torch.Tensor, optional): Bias for key input projection of shape E_total. Defaults to None.\n b_v (torch.Tensor, optional): Bias for value input projection of shape E_total. Defaults to None.\n b_out (torch.Tensor, optional): Bias for output projection of shape E_out. Defaults to None.\n dropout_p (float, optional): Dropout probability. 
Defaults to 0.0.\n\n Where:\n N is the batch size\n L_t is the target sequence length (padded)\n L_s is the source sequence length (padded)\n E_q is the embedding size for query\n E_k is the embedding size for key\n E_v is the embedding size for value\n E_total is the embedding size for all heads combined\n E_out is the output embedding size\n Returns:\n torch.Tensor: Output of shape (N, L_t, E_out)\n \"\"\"\n N = query.size(0)\n L_t = query.size(1)\n L_s = key.size(1)\n E_total = W_q.size(0)\n assert E_total % nheads == 0, \"Embedding dim is not divisible by nheads\"\n assert L_t == L_s, \"This implementation assumes equal query and key sequence lengths\"\n E_head = E_total // nheads\n\n # apply input projection\n # (N, L_t, E_q) -> (N, L_t, E_total)\n query = F.linear(query, W_q, b_q)\n # (N, L_s, E_k) -> (N, L_s, E_total)\n key = F.linear(key, W_k, b_k)\n # (N, L_s, E_v) -> (N, L_s, E_total)\n value = F.linear(value, W_v, b_v)\n\n # reshape query, key, value to separate by head\n # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head) -> (N * nheads, L_t, E_head)\n query = query.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)\n # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)\n key = key.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)\n # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head) -> (N * nheads, L_s, E_head)\n value = value.reshape(N, -1, nheads, E_head).transpose(1, 2).reshape(N * nheads, -1, E_head)\n\n # query bmm key^T\n # (N * nheads, L_t, E_head) x (N * nheads, L_s, E_head)^T -> (N * nheads, L_t, L_s)\n keyT = key.transpose(-1, -2)\n attn_weights = torch.bmm(query, keyT)\n\n # scale down\n attn_weights = attn_weights * (1.0 / math.sqrt(E_head))\n\n # Have to manipulate masks in order to apply them to the attention weights\n key_padding_mask = attn_mask_q.view(N, 1, 1, L_t).expand(-1, 
nheads, -1, -1).reshape(N*nheads, 1, L_t).to(device=device)\n attn_mask = torch.zeros(key_padding_mask.shape, device=device, dtype=torch.float32)\n attn_mask = attn_mask.masked_fill_(key_padding_mask, float(\"-inf\"))\n\n # Zero out the attention weights where the mask is True by adding -inf prior to softmax\n attn_weights.add_(attn_mask)\n\n # softmax\n attn_weights = F.softmax(attn_weights, dim=-1).nan_to_num_(0.0)\n\n # dropout\n if dropout_p > 0.0:\n attn_weights = F.dropout(attn_weights, p=dropout_p)\n\n # attention_weights bmm value\n # (N * nheads, L_t, L_s) x (N * nheads, L_s, E_head) -> (N * nheads, L_t, E_head)\n attn_output = attn_weights.bmm(value)\n\n # merge heads\n # (N * nheads, L_t, E_head) -> (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)\n attn_output = attn_output.reshape(N, nheads, -1, E_head).transpose(1, 2).reshape(N, -1, E_total)\n\n # apply output projection\n # (N, L_t, E_total) -> (N, L_t, E_out)\n attn_output = F.linear(attn_output, W_out, b_out)\n\n # padding-specific step: remove output projection bias from padded entries\n attn_output[attn_mask_q, :] = 0.0\n\n return attn_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"set hyperparameters following [the Transformer paper](https://arxiv.org/pdf/1706.03762.pdf)\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"N = 512\nE_q, E_k, E_v, E_total, E_out = 512, 512, 512, 512, 512\nnheads = 8"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"except for dropout probability: set to 0 for correctness check\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dropout_p = 0.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let us generate some realistic fake data from Zipf's law.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import numpy as np\n\ndef zipf_sentence_lengths(alpha: float, batch_size: int) -> np.ndarray:\n # generate fake corpus by unigram Zipf distribution\n # from wikitext-2 corpus, we get rank \".\" = 3, \"!\" = 386, \"?\" = 858\n sentence_lengths = np.empty(batch_size, dtype=int)\n for ibatch in range(batch_size):\n sentence_lengths[ibatch] = 1\n word = np.random.zipf(alpha)\n while word != 3 and word != 386 and word != 858:\n sentence_lengths[ibatch] += 1\n word = np.random.zipf(alpha)\n return sentence_lengths\n\nalpha = 1.2\n\nsentence_lengths = zipf_sentence_lengths(alpha, N)\nL_t = np.max(sentence_lengths)\nL_s = L_t"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"create inputs\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# create parameters\nW_q, b_q = torch.randn((E_total, E_q), device=device), torch.randn(E_total, device=device)\nW_k, b_k = torch.randn((E_total, E_k), device=device), torch.randn(E_total, device=device)\nW_v, b_v = torch.randn((E_total, E_v), device=device), torch.randn(E_total, device=device)\nW_out, b_out = torch.randn((E_out, E_total), device=device), torch.randn(E_out, device=device)\n\n# create nested input\nqueries = []\nkeys = []\nvalues = []\nfor i in range(N):\n l = sentence_lengths[i]\n s = l\n queries.append(torch.randn((l, E_q), device=device))\n keys .append(torch.randn((s, E_k), device=device))\n values .append(torch.randn((s, E_v), device=device))\nquery = torch.nested.nested_tensor(queries)\nkey = torch.nested.nested_tensor(keys)\nvalue = torch.nested.nested_tensor(values)\n\n# pad input\npadded_query = torch.nested.to_padded_tensor(query, 0.0, (N, L_t, E_q))\npadded_key = torch.nested.to_padded_tensor(key, 0.0, (N, L_s, E_k))\npadded_value = torch.nested.to_padded_tensor(value, 0.0, (N, L_s, E_v))\n\n# create attention masks\nattn_mask_q = torch.zeros((N, L_t), dtype=torch.bool)\nattn_mask_kv = torch.zeros((N, L_s), dtype=torch.bool)\n\n# We need to mask out the padding entries in the attention weights.\nfor i, entry_length in enumerate(sentence_lengths):\n attn_mask_q[i, entry_length:] = True\n attn_mask_kv[i, entry_length:] = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"check correctness and performance\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import timeit\n\nt0 = timeit.default_timer()\nout_nested = mha_nested(\n query, key, value, nheads,\n W_q, W_k, W_v, W_out,\n b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,\n dropout_p=dropout_p)\n\nt1 = timeit.default_timer()\nout_padded = mha_padded(\n padded_query, padded_key, padded_value, nheads,\n attn_mask_q, attn_mask_kv,\n W_q, W_k, W_v, W_out,\n b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,\n dropout_p=dropout_p)\nt2 = timeit.default_timer()\n\nprint(\"nested and padded calculations differ by\", (torch.nested.to_padded_tensor(out_nested, 0.0, (N, L_t, E_out)) - out_padded).abs().max().item())\nprint(\"nestedtensor multi-head attention takes\", t1 - t0, \"seconds\")\nprint(\"padded tensor multi-head attention takes\", t2 - t1, \"seconds\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Although the nestedtensor version avoids wasted computation on padding, it is not faster\nthan the equivalent padded tensor version. This is because the nestedtensor version\nimplements a few of its kernels, such as softmax, in a suboptimal way.\n\nThere are plans to implement performance-critical operations using the new PyTorch 2.0 stack.\nFor now, some performant kernels are provided for specific use cases, e.g.\nself-attention evaluation by the multi-head attention formula.\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# embeddings are assumed to be the same\nE = E_total\nmha_lib = torch.nn.MultiheadAttention(E, nheads, batch_first=True, device=device)\nmha_lib.eval()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"extract parameters for correctness check\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"mha_lib.in_proj_weight.requires_grad_(False)\nmha_lib.in_proj_bias.requires_grad_(False)\nmha_lib.out_proj.weight.requires_grad_(False)\nmha_lib.out_proj.bias.requires_grad_(False)\nW_q, b_q = mha_lib.in_proj_weight[: E, :], mha_lib.in_proj_bias[: E]\nW_k, b_k = mha_lib.in_proj_weight[E : 2 * E, :], mha_lib.in_proj_bias[E : 2 * E]\nW_v, b_v = mha_lib.in_proj_weight[2 * E :, :], mha_lib.in_proj_bias[2 * E :]\nW_out, b_out = mha_lib.out_proj.weight, mha_lib.out_proj.bias"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we set need_weights to False, this will enable the fast path in the library.\nUnder the hood, this will call _scaled_dot_product_attention. If your tensors\nare on CUDA, then a fused, efficient attention kernel will be used. For\nmore detailed performance characteristics, look at the benchmark in\npytorch/benchmarks/transformer/sdp.py\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"with torch.inference_mode():\n t0 = timeit.default_timer()\n out_lib, out_lib_weights = mha_lib(query, query, query, need_weights=False)\n\n t1 = timeit.default_timer()\n padded_out = mha_padded(\n padded_query, padded_query, padded_query, nheads,\n attn_mask_q, attn_mask_q,\n W_q, W_k, W_v, W_out,\n b_q=b_q, b_k=b_k, b_v=b_v, b_out=b_out,\n dropout_p=dropout_p)\n t2 = timeit.default_timer()\n\nnested_time = t1 - t0\npadded_time = t2 - t1\nprint(\"Nested and padded calculations differ by\", (torch.nested.to_padded_tensor(out_lib, 0.0) - padded_out).abs().max().item())\nprint(\"Nested library multi-head attention takes\", nested_time, \"seconds\")\nprint(\"Padded tensor multi-head attention takes\", padded_time, \"seconds\")\nprint(f\"Nested Speedup: {padded_time / nested_time:.3f}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}