"#### Note\n\nToday, nvFuser in TorchScript is the only exposure of\nnvFuser that allows for dynamic shape changes, although we will\nexpand this capability to other systems in the future. For more\ninsight into how dynamic shapes are implemented in nvFuser, you can\nview this presentation from GTC 2021:\nhttps://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31952/\n\n\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining novel operations with nvFuser and FuncTorch\n\nOne of the primary benefits of nvFuser is the ability to define\nnovel operations composed of PyTorch \u201cprimitives\u201d which are then\njust-in-time compiled into efficient kernels.\n\nPyTorch has strong performance for any individual operation,\nespecially composite operations like LayerNorm. However, if\nLayerNorm wasn\u2019t already implemented in PyTorch as a composite\noperation, then you\u2019d have to define it as a series of simpler\n(primitive) operations. Let\u2019s make such a definition and run it\nwithout nvFuser.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def primitive_definition(\n input1: torch.Tensor,\n input2: torch.Tensor,\n weight: torch.Tensor,\n bias1: torch.Tensor,\n bias2: torch.Tensor,\n normalization_axis: int,\n dropout_prob: float,\n keepdim: bool,\n) -> torch.Tensor:\n bias1_out = input1 + bias1\n dropout_out = F.dropout(bias1_out, dropout_prob, training=True)\n norm_input = dropout_out + input2\n mean = norm_input.mean(normalization_axis, keepdim=keepdim)\n diff = norm_input - mean\n diff_sq = diff * diff\n var = diff_sq.mean(normalization_axis, keepdim=keepdim)\n pre_shift_scale_norm_output = (norm_input - mean) / torch.sqrt(var + 1e-12)\n norm_output = weight * pre_shift_scale_norm_output + bias2\n return norm_output\n\n\n# Profile primitive definition\nfunc = functools.partial(\n primitive_definition,\n input1,\n input2,\n weight,\n bias1,\n bias2,\n normalization_axis=2,\n dropout_prob=0.1,\n keepdim=True,\n)\nprofile_workload(\n func, grad_output, iteration_count=100, label=\"Eager Mode - Primitive Definition\"\n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"While the above is mathematically equivalent to our previous\ndefinition, benchmarking our new function with the original static\nshape using TorchScript and nvFuser shows that the iterations per\nsecond decrease \u2013 mostly due to the cost of accessing memory to save\nintermediate results.\n\n.. figure:: /_static/img/nvfuser_intro/nvfuser_tutorial_3.png\n\nThe geomean is 260 iterations per second, 3.26x slower than the\ncomposite definition in eager mode and 5.35x slower than the nvFuser\ncomposite operation! For more information on why there\u2019s such a\ndrastic decrease in compute speed, please see this presentation from\nGTC 2022:\nhttps://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41958/\n\nnvFuser with TorchScript can improve the performance of this\noperation even though it\u2019s defined with primitive PyTorch\noperations. Simply by enabling TorchScript on the new function\n(just like before), we can recover much of the performance.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Profile scripted primitive definition\nscripted_primitive_definition = torch.jit.script(primitive_definition)\nfunc = functools.partial(\n scripted_primitive_definition,\n input1,\n input2,\n weight,\n bias1,\n bias2,\n normalization_axis=2,\n dropout_prob=0.1,\n keepdim=True,\n)\nprofile_workload(\n func, grad_output, iteration_count=100, label=\"TorchScript - Primitive definition\"\n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
".. figure:: /_static/img/nvfuser_intro/nvfuser_tutorial_4.png\n\nHowever, the performance is still slower than the original eager\nmode performance of the composite definition. TorchScript works well\nwhen predefined composite operations are used; however, TorchScript\u2019s\napplication of Autograd saves all of the activations for each\noperator in the fusion for reuse in the backward pass. This is not\ntypically the optimal choice: especially when chaining together\nmultiple simple operations, it is often much faster to recompute\nsome intermediate tensors rather than spend the time storing and\nretrieving several saved results from memory.\n\nIt\u2019s possible to optimize away many of these unnecessary memory\naccesses, but it requires building a connected forward and backward\ngraph, which isn\u2019t possible with TorchScript. The\n`memory_efficient_fusion` pass in FuncTorch, however, is such an\noptimization pass. To use this pass, we have to redefine our\nfunction to pull the constants inside (for now it\u2019s easiest to make\nnon-tensor constants literals in the function definition):\n\n\n"
]
},
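{
"cell_type": "markdown",
"metadata": {},
"source": [
"As an aside (not part of the original tutorial), the recompute-versus-store trade-off described above can be illustrated with PyTorch\u2019s standard `torch.utils.checkpoint`, which deliberately discards intermediate activations in the forward pass and recomputes them during backward:\n\n```python\nimport torch\nfrom torch.utils.checkpoint import checkpoint\n\ndef chain(x):\n    # A chain of cheap pointwise ops; saving every intermediate for\n    # backward can cost more memory traffic than recomputing them.\n    return torch.relu(x * 2.0 + 1.0).sin()\n\nx = torch.randn(1024, 1024, requires_grad=True)\nout = checkpoint(chain, x)  # intermediates are recomputed in backward\nout.sum().backward()\n```\n\nCheckpointing is a manual version of the choice that `memory_efficient_fusion` makes automatically: trade a little extra compute for fewer saved tensors.\n\n\n"
]
},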
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def primitive_definition_for_memory_efficient_fusion(\n input1: torch.Tensor,\n input2: torch.Tensor,\n weight: torch.Tensor,\n bias1: torch.Tensor,\n bias2: torch.Tensor,\n) -> torch.Tensor:\n bias1_out = input1 + bias1\n dropout_out = F.dropout(bias1_out, 0.1, training=True)\n norm_input = dropout_out + input2\n mean = norm_input.mean(2, keepdim=True)\n diff = norm_input - mean\n diff_sq = diff * diff\n var = diff_sq.mean(2, keepdim=True)\n pre_shift_scale_norm_output = (norm_input - mean) / torch.sqrt(var + 1e-12)\n norm_output = weight * pre_shift_scale_norm_output + bias2\n return norm_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, instead of passing our function to TorchScript, we will pass it\nto FuncTorch\u2019s optimization pass.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Optimize the model with FuncTorch tracing and the memory efficiency\n# optimization pass\nmemory_efficient_primitive_definition = memory_efficient_fusion(\n primitive_definition_for_memory_efficient_fusion\n)\n\n# Profile memory efficient primitive definition\nfunc = functools.partial(\n memory_efficient_primitive_definition, input1, input2, weight, bias1, bias2\n)\nprofile_workload(\n func,\n grad_output,\n iteration_count=100,\n label=\"FuncTorch - Primitive definition\",\n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This recovers even more speed, but it\u2019s still not as fast as\nTorchScript\u2019s original performance with the composite definition.\nHowever, this is still faster than running this new definition\nwithout nvFuser, and is still faster than the composite definition\nwithout nvFuser.\n\n.. figure:: /_static/img/nvfuser_intro/nvfuser_tutorial_5.png\n\n#### Note\n\nFuncTorch\u2019s memory efficient pass specializes on the shapes of\nthe inputs to the function. If new inputs are provided with\ndifferent shapes, then you need to construct a new function\nusing `memory_efficient_fusion` and apply it to the new inputs.\n\n\n"
]
},
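{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, re-specializing on new shapes looks like the following sketch (it reuses `memory_efficient_fusion`, `primitive_definition_for_memory_efficient_fusion`, `weight`, `bias1`, and `bias2` from the cells above; the new shape here is hypothetical):\n\n```python\n# New inputs with a different (hypothetical) shape than the ones traced above\nnew_input1 = torch.randn(64, 128, 1024, device=\"cuda\", requires_grad=True)\nnew_input2 = torch.rand_like(new_input1).requires_grad_()\n\n# Re-run the optimization pass so nvFuser specializes on the new shapes\nrespecialized = memory_efficient_fusion(\n    primitive_definition_for_memory_efficient_fusion\n)\nout = respecialized(new_input1, new_input2, weight, bias1, bias2)\n```\n\n\n"
]
},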
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Transformer Block With a Novel Normalization\nThe ability to quickly execute chains of simple operations is\nimportant as not every operation has a composite operation defined\nin PyTorch. Previously, this meant researchers either had to define\nan entirely new operation in PyTorch \u2013 which takes a lot of time and\nknowledge of the lower-level PyTorch code as well as parallel\nprogramming \u2013 or write the operation in simpler PyTorch ops and\nsettle for poor performance. For example, let's replace LayerNorm\nin our example with RMSNorm. Even though RMSNorm is a bit simpler\nthan LayerNorm, it doesn\u2019t have an existing composite operation in\nPyTorch. See the [Root Mean Square Layer Normalization](https://doi.org/10.48550/arXiv.1910.07467) paper for more information about RMSNorm.\nAs before, we\u2019ll define our new transformer block with\nprimitive PyTorch operations.\n\n\n"
]
},
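{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, RMSNorm normalizes by the root mean square of the input rather than subtracting the mean; with a small $\\epsilon$ for numerical stability and a learned weight $w$, it can be written as:\n\n$$\\mathrm{RMSNorm}(x) = \\frac{x}{\\sqrt{\\frac{1}{n}\\sum_{i=1}^{n} x_i^2 + \\epsilon}} \\odot w$$\n\n\n"
]
},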
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def with_rms_norm(\n    input1: torch.Tensor,\n    input2: torch.Tensor,\n    weight: torch.Tensor,\n    bias: torch.Tensor,\n    normalization_axis: int,\n    dropout_prob: float,\n    keepdim: bool,\n) -> torch.Tensor:\n    bias_out = input1 + bias\n    dropout_out = F.dropout(bias_out, dropout_prob, training=True)\n    norm_input = dropout_out + input2\n    var = norm_input.mul(norm_input).mean(normalization_axis, keepdim=keepdim)\n    pre_shift_scale_norm_output = norm_input / torch.sqrt(var + 1e-12)\n    norm_output = weight * pre_shift_scale_norm_output\n    return norm_output"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"As before, we\u2019ll get a baseline by running PyTorch without nvFuser.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Profile rms_norm\nfunc = functools.partial(\n with_rms_norm,\n input1,\n input2,\n weight,\n bias1,\n normalization_axis=2,\n dropout_prob=0.1,\n keepdim=True,\n)\nprofile_workload(func, grad_output, iteration_count=100, label=\"Eager Mode - RMS Norm\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With nvFuser through TorchScript.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Profile scripted rms_norm\nscripted_with_rms_norm = torch.jit.script(with_rms_norm)\nfunc = functools.partial(\n scripted_with_rms_norm,\n input1,\n input2,\n weight,\n bias1,\n normalization_axis=2,\n dropout_prob=0.1,\n keepdim=True,\n)\nprofile_workload(func, grad_output, iteration_count=100, label=\"TorchScript - RMS Norm\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With nvFuser through FuncTorch.\n\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"def with_rms_norm_for_memory_efficient_fusion(\n    input1: torch.Tensor, input2: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor\n) -> torch.Tensor:\n    bias_out = input1 + bias\n    dropout_out = F.dropout(bias_out, 0.1, training=True)\n    norm_input = dropout_out + input2\n    var = norm_input.mul(norm_input).mean(2, keepdim=True)\n    pre_shift_scale_norm_output = norm_input / torch.sqrt(var + 1e-12)\n    norm_output = weight * pre_shift_scale_norm_output\n    return norm_output\n\n\n# Profile memory efficient rms_norm\nmemory_efficient_rms_norm = memory_efficient_fusion(\n    with_rms_norm_for_memory_efficient_fusion\n)\nfunc = functools.partial(memory_efficient_rms_norm, input1, input2, weight, bias1)\nprofile_workload(func, grad_output, iteration_count=100, label=\"FuncTorch - RMS Norm\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
".. figure:: /_static/img/nvfuser_intro/nvfuser_tutorial_6.png\n\nSince RMSNorm is simpler than LayerNorm, the performance of our new\ntransformer block is a little higher than the primitive definition\nwithout nvFuser (354 iterations per second compared with 260\niterations per second). The iterations per second increase by 2.68x\nto 952 with TorchScript and by 3.36x to 1,191 with FuncTorch\u2019s\nmemory efficient optimization pass. The performance of this new\noperation nearly matches the performance of the composite LayerNorm\ndefinition with TorchScript.\n\nnvFuser provides the ability to define novel operations in simple\nPyTorch and get performance that\u2019s close to a highly optimized\ncomposite operation in PyTorch. We believe this will enable research\ninto novel network topologies without paying a sometimes devastating\nprice in training speed. nvFuser provides this unique ability because\nit analyzes users\u2019 programs to deliver performance close to that of a\nhighly hand-tuned implementation, regardless of how the operations are\ndefined. nvFuser still cannot support every operation in PyTorch, but\nits capabilities will continue to grow over time.\n\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 0
}