(beta) Using TORCH_LOGS Python API with torch.compile
Created On: Jan 24, 2024 | Last Updated: Jan 31, 2024 | Last Verified: Nov 05, 2024
Author: Michael Lazos
import logging
This tutorial introduces the TORCH_LOGS environment variable and the corresponding Python API, and demonstrates how to use them to observe the phases of torch.compile.
Note
This tutorial requires PyTorch 2.2.0 or later.
Setup
In this example, we’ll set up a simple Python function that performs an elementwise add, and observe the compilation process with the TORCH_LOGS Python API.
Note
There is also an environment variable, TORCH_LOGS, which can be used to change logging settings at the command line. The equivalent environment variable setting is shown for each example.
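As a rough illustration of the command-line route, the sketch below builds a TORCH_LOGS value from set_logs-style keyword arguments and runs a command in a child process with that variable set. The helper names torch_logs_value and run_with_torch_logs are hypothetical and not part of PyTorch. The mapping covers only the two forms that appear in this tutorial's examples (a component set to logging.DEBUG becomes "+name", an artifact set to True becomes "name"), joined with commas.

```python
import logging
import os
import subprocess
import sys


def torch_logs_value(**settings):
    """Build a TORCH_LOGS string from set_logs-style keyword arguments.

    Hypothetical helper. Covers only the two forms used in this tutorial:
    name=logging.DEBUG -> "+name" and name=True -> "name".
    """
    parts = []
    for name, value in settings.items():
        if value is True:
            parts.append(name)
        elif value == logging.DEBUG:
            parts.append(f"+{name}")
        else:
            raise ValueError(f"unsupported setting in this sketch: {name}={value!r}")
    return ",".join(parts)


def run_with_torch_logs(argv, **settings):
    """Run a command in a child process with TORCH_LOGS set in its environment."""
    env = dict(os.environ, TORCH_LOGS=torch_logs_value(**settings))
    return subprocess.run(argv, env=env, capture_output=True, text=True)


if __name__ == "__main__":
    # The child process here only echoes the variable it received;
    # replace the command with your own script that imports torch.
    demo = run_with_torch_logs(
        [sys.executable, "-c", "import os; print(os.environ['TORCH_LOGS'])"],
        dynamo=logging.DEBUG,
        output_code=True,
    )
    print(demo.stdout.strip())
```

Setting the variable in the child's environment means it is already in place when that process starts, so nothing depends on the point at which the logging settings are read.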
import torch

# exit cleanly if CUDA is unavailable or the device doesn't support torch.compile
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Skipping because torch.compile is not supported on this device.")
else:

    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2

    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))

    # print separator and reset dynamo
    # between each example
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()

    separator("Dynamo Tracing")
    # View dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph")
    # View traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View output code generated by inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    fn(*inputs)

    separator("")
===================Dynamo Tracing=========================
V0418 18:40:20.524000 634 torch/_dynamo/convert_frame.py:1345] skipping: _is_skip_guard_eval_unsafe_stance (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] torchdynamo start compiling fn /var/lib/workspace/recipes_source/torch_logs.py:39, stack (elided 5 frames):
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/bin/sphinx-build", line 8, in <module>
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] sys.exit(main())
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] return make_main(argv)
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] return make_mode.run_make_mode(argv[1:])
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] return make.run_generic_build(args[0])
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] return build_main(args + opts)
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] self._init_builder()
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] self.events.emit('builder-inited')
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] results.append(listener.handler(self.app, *args))
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] ) = generate_dir_rst(
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] intro, title, cost = generate_file_rst(
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] output_blocks, time_elapsed = execute_script(script_blocks,
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] output_blocks.append(execute_code_block(
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] is_last_expr, mem_max = _exec_and_get_memory(
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] mem_max, _ = gallery_conf['call_memory'](
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] return 0., func()
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] exec(self.code, self.fake_main.__dict__)
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] File "/var/lib/workspace/recipes_source/torch_logs.py", line 59, in <module>
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0] fn(*inputs)
V0418 18:40:20.525000 634 torch/_dynamo/convert_frame.py:930] [0/0]
I0418 18:40:20.527000 634 torch/_dynamo/symbolic_convert.py:2706] [0/0] Step 1: torchdynamo start tracing fn /var/lib/workspace/recipes_source/torch_logs.py:39
I0418 18:40:20.527000 634 torch/fx/experimental/symbolic_shapes.py:3192] [0/0] create_env
V0418 18:40:20.531000 634 torch/_dynamo/symbolic_convert.py:932] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:41 in fn (fn)
V0418 18:40:20.531000 634 torch/_dynamo/symbolic_convert.py:932] [0/0] [__trace_source] z = x + y
V0418 18:40:20.532000 634 torch/_dynamo/symbolic_convert.py:955] [0/0] [__trace_bytecode] TRACE LOAD_FAST x []
V0418 18:40:20.532000 634 torch/_dynamo/symbolic_convert.py:955] [0/0] [__trace_bytecode] TRACE LOAD_FAST y [LazyVariableTracker()]
V0418 18:40:20.532000 634 torch/_dynamo/symbolic_convert.py:955] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [LazyVariableTracker(), LazyVariableTracker()]
V0418 18:40:20.533000 634 torch/_dynamo/variables/builder.py:2853] [0/0] wrap_to_fake L['x'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0418 18:40:20.534000 634 torch/_dynamo/output_graph.py:2156] [0/0] create_graph_input L_x_ L['x'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0418 18:40:20.535000 634 torch/_dynamo/variables/builder.py:2853] [0/0] wrap_to_fake L['y'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='y', is_input=True, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0418 18:40:20.536000 634 torch/_dynamo/output_graph.py:2156] [0/0] create_graph_input L_y_ L['y'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0418 18:40:20.538000 634 torch/_dynamo/symbolic_convert.py:955] [0/0] [__trace_bytecode] TRACE STORE_FAST z [TensorVariable()]
V0418 18:40:20.538000 634 torch/_dynamo/symbolic_convert.py:932] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:42 in fn (fn)
V0418 18:40:20.538000 634 torch/_dynamo/symbolic_convert.py:932] [0/0] [__trace_source] return z + 2
V0418 18:40:20.538000 634 torch/_dynamo/symbolic_convert.py:955] [0/0] [__trace_bytecode] TRACE LOAD_FAST z []
V0418 18:40:20.539000 634 torch/_dynamo/symbolic_convert.py:955] [0/0] [__trace_bytecode] TRACE LOAD_CONST 2 [TensorVariable()]
V0418 18:40:20.539000 634 torch/_dynamo/symbolic_convert.py:955] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [TensorVariable(), ConstantVariable(int: 2)]
V0418 18:40:20.540000 634 torch/_dynamo/symbolic_convert.py:955] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TensorVariable()]
I0418 18:40:20.540000 634 torch/_dynamo/symbolic_convert.py:3028] [0/0] Step 1: torchdynamo done tracing fn (RETURN_VALUE)
V0418 18:40:20.541000 634 torch/_dynamo/symbolic_convert.py:3032] [0/0] RETURN_VALUE triggered compile
V0418 18:40:20.541000 634 torch/_dynamo/output_graph.py:972] [0/0] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /var/lib/workspace/recipes_source/torch_logs.py, line 42 in fn>], graph_break=False)
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] TRACED GRAPH
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] ===== __compiled_fn_1 =====
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] def forward(self, L_x_: "f32[2, 2][2, 1]cuda:0", L_y_: "f32[2, 2][2, 1]cuda:0"):
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] l_x_ = L_x_
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] l_y_ = L_y_
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code]
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] # File: /var/lib/workspace/recipes_source/torch_logs.py:41 in fn, code: z = x + y
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] z: "f32[2, 2][2, 1]cuda:0" = l_x_ + l_y_; l_x_ = l_y_ = None
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code]
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] # File: /var/lib/workspace/recipes_source/torch_logs.py:42 in fn, code: return z + 2
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] add_1: "f32[2, 2][2, 1]cuda:0" = z + 2; z = None
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code] return (add_1,)
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code]
V0418 18:40:20.542000 634 torch/_dynamo/output_graph.py:1353] [0/0] [__graph_code]
I0418 18:40:20.544000 634 torch/_dynamo/output_graph.py:1458] [0/0] Step 2: calling compiler function inductor
I0418 18:40:21.789000 634 torch/fx/experimental/symbolic_shapes.py:4547] [0/0] produce_guards
I0418 18:40:21.793000 634 torch/_dynamo/output_graph.py:1463] [0/0] Step 2: done compiler function inductor
I0418 18:40:21.795000 634 torch/fx/experimental/symbolic_shapes.py:4547] [0/0] produce_guards
V0418 18:40:21.796000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['x'].size()[0] 2 None
V0418 18:40:21.796000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['x'].size()[1] 2 None
V0418 18:40:21.796000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['x'].stride()[0] 2 None
V0418 18:40:21.796000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['x'].stride()[1] 1 None
V0418 18:40:21.797000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['x'].storage_offset() 0 None
V0418 18:40:21.797000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['y'].size()[0] 2 None
V0418 18:40:21.797000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['y'].size()[1] 2 None
V0418 18:40:21.797000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['y'].stride()[0] 2 None
V0418 18:40:21.797000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['y'].stride()[1] 1 None
V0418 18:40:21.798000 634 torch/fx/experimental/symbolic_shapes.py:4755] [0/0] track_symint L['y'].storage_offset() 0 None
V0418 18:40:21.798000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['x'].size()[0] == 2
V0418 18:40:21.798000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['x'].size()[1] == 2
V0418 18:40:21.799000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['x'].stride()[0] == 2
V0418 18:40:21.799000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['x'].stride()[1] == 1
V0418 18:40:21.799000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['x'].storage_offset() == 0
V0418 18:40:21.799000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['y'].size()[0] == 2
V0418 18:40:21.800000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['y'].size()[1] == 2
V0418 18:40:21.800000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['y'].stride()[0] == 2
V0418 18:40:21.800000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['y'].stride()[1] == 1
V0418 18:40:21.800000 634 torch/fx/experimental/symbolic_shapes.py:4958] [0/0] Skipping guard L['y'].storage_offset() == 0
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2364] [0/0] [__guards] GUARDS:
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards]
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] TREE_GUARD_MANAGER:
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] +- RootGuardManager
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None # _dynamo/output_graph.py:493 in init_ambient_guards
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=DictGetItemGuardAccessor('x')
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1]) # z = x + y # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False # z = x + y # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | +- GuardManager: source=L['y'], accessed_by=DictGetItemGuardAccessor('y')
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1]) # z = x + y # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False # z = x + y # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards] | | +- NO_TENSOR_ALIASING
V0418 18:40:21.801000 634 torch/_dynamo/guards.py:2321] [0/0] [__guards]
V0418 18:40:22.802000 634 torch/_dynamo/guards.py:2346] [0/0] [__guards] Guard eval latency = 0.90 us
I0418 18:40:22.803000 634 torch/_dynamo/pgo.py:636] [0/0] put_code_state: no cache key, skipping
V0418 18:40:22.807000 634 torch/_dynamo/convert_frame.py:1345] skipping: _fn (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0418 18:40:22.808000 634 torch/_dynamo/convert_frame.py:1345] skipping: _callback_from_stance (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0418 18:40:22.808000 634 torch/_dynamo/convert_frame.py:1345] skipping: _maybe_set_eval_frame (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0418 18:40:22.808000 634 torch/_dynamo/convert_frame.py:1345] skipping: justknobs_check (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py)
===================Traced Graph=========================
I0418 18:40:22.809000 634 torch/_dynamo/__init__.py:99] torch._dynamo.reset
I0418 18:40:22.809000 634 torch/_dynamo/__init__.py:132] torch._dynamo.reset_code_caches
===================Fusion Decisions=========================
===================Output Code=========================
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1116] [0/0] [__output_code] Output code written to: /tmp/torchinductor_ci-user/kq/ckqum5d54uw5vrn4ovqu3ozxvleh2tky6bhsgwcpfvkh2e2266p2.py
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] Output code:
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # AOT ID: ['1_inference']
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import torch
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import math
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import random
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import os
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import tempfile
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from math import inf, nan
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch import device, empty_strided
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import triton
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import triton.language as tl
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.runtime.triton_heuristics import (
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] grid,
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] split_scan_grid,
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] grid_combo_kernels,
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] start_graph,
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] end_graph,
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] cooperative_reduction_grid,
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] )
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] aten = torch.ops.aten
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] _quantized = torch.ops._quantized
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] async_compile = AsyncCompile()
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # kernel path: /tmp/torchinductor_ci-user/ea/ceamkuk7na23fw4f7lhqhrqai3i7ypqjiqxrfg3gsladwvs7sqfl.py
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # Source node to ATen node mapping:
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # add_1 => add_1
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # z => add
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # Graph fragment:
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, 2), kwargs = {})
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] triton_poi_fused_add_0 = async_compile.triton('triton_poi_fused_add_0', '''
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import triton
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] import triton.language as tl
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from triton.compiler.compiler import AttrsDescriptor
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.runtime import triton_helpers, triton_heuristics
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] triton_helpers.set_driver_to_gpu()
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] @triton_heuristics.pointwise(
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] size_hints={'x': 4},
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] filename=__file__,
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=80, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '0E91B58DAB54C915AAF8467E3EDB6871F6D05685FF049BBEEDA70C789216121A', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] min_elem_per_thread=0
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] )
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] @triton.jit
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] xnumel = 4
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] xoffset = tl.program_id(0) * XBLOCK
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] xindex = xoffset + tl.arange(0, XBLOCK)[:]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] xmask = xindex < xnumel
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] x0 = xindex
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] tmp0 = tl.load(in_ptr0 + (x0), xmask)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] tmp1 = tl.load(in_ptr1 + (x0), xmask)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] tmp2 = tmp0 + tmp1
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] tmp3 = 2.0
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] tmp4 = tmp2 + tmp3
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] tl.store(out_ptr0 + (x0), tmp4, xmask)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] ''', device_str='cuda')
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] async_compile.wait(globals())
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] del async_compile
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] def call(args):
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] arg0_1, arg1_1 = args
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] args.clear()
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] assert_size_stride(arg0_1, (2, 2), (2, 1))
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] assert_size_stride(arg1_1, (2, 2), (2, 1))
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] with torch.cuda._DeviceGuard(0):
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] torch.cuda.set_device(0)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] buf0 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] stream0 = get_raw_stream(0)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 4, grid=grid(4), stream=stream0)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] del arg0_1
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] del arg1_1
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] return (buf0, )
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._dynamo.testing import rand_strided
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.utils import print_performance
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] arg0_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] arg1_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] fn = lambda: call([arg0_1, arg1_1])
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] return print_performance(fn, times=times, repeat=repeat)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] if __name__ == "__main__":
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] from torch._inductor.wrapper_benchmark import compiled_module_main
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code] compiled_module_main('None', benchmark_compiled_module)
V0418 18:40:23.163000 634 torch/_inductor/codecache.py:1117] [0/0] [__output_code]
============================================
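The log lines in the output above share a common prefix: a level letter, a four-digit date, a timestamp, a process ID, and a source file with line number, optionally followed by a bracketed pair such as [0/0] and an artifact tag such as [__guards]. The following is a minimal sketch of a parser for that layout, inferred only from the output shown here rather than from a documented format. The names LOG_LINE, LogRecord, and parse_log_line are hypothetical, as is the field name compile_id for the [0/0] pair.

```python
import re
from typing import NamedTuple, Optional

# Pattern inferred from the sample output above; other PyTorch versions
# may format their log lines differently.
LOG_LINE = re.compile(
    r"^(?P<level>[A-Z])(?P<date>\d{4}) (?P<time>\d{2}:\d{2}:\d{2}\.\d+) "
    r"(?P<pid>\d+) (?P<file>[^\s:]+):(?P<line>\d+)\] ?"
    r"(?:\[(?P<compile_id>\d+/\d+)\] ?)?"
    r"(?:\[(?P<artifact>__\w+)\] ?)?"
    r"(?P<message>.*)$"
)


class LogRecord(NamedTuple):
    level: str                  # single letter, e.g. "V" or "I"
    file: str                   # source file that emitted the line
    line: int                   # line number within that file
    compile_id: Optional[str]   # e.g. "0/0" (field name is an assumption)
    artifact: Optional[str]     # e.g. "__guards", "__output_code"
    message: str


def parse_log_line(text: str) -> Optional[LogRecord]:
    """Return a LogRecord, or None for lines without the prefix
    (such as the "=====" separator lines printed by the script)."""
    m = LOG_LINE.match(text)
    if m is None:
        return None
    return LogRecord(
        m["level"], m["file"], int(m["line"]),
        m["compile_id"], m["artifact"], m["message"],
    )


if __name__ == "__main__":
    sample = [
        "===================Traced Graph=========================",
        "I0418 18:40:22.809000 634 torch/_dynamo/__init__.py:99] torch._dynamo.reset",
        "V0418 18:40:22.802000 634 torch/_dynamo/guards.py:2346] [0/0] [__guards] Guard eval latency = 0.90 us",
    ]
    records = [r for r in map(parse_log_line, sample) if r is not None]
    # Keep only the messages tagged with the __guards artifact.
    for record in records:
        if record.artifact == "__guards":
            print(record.message)
```

Filtering on the artifact field is one way to pull a single stream, such as the generated output code or the guards, out of a long capture like the one above.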
Conclusion
In this tutorial we introduced the TORCH_LOGS environment variable and Python API by experimenting with a small number of the available logging options. To view descriptions of all available options, run any Python script that imports torch with TORCH_LOGS set to "help".
Alternatively, you can view the torch._logging documentation to see descriptions of all available logging options.
For more information on torch.compile, see the torch.compile tutorial.