# Source code for torch.backends.mkldnn

import sys
import torch
from contextlib import contextmanager
from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation

def is_available():
    r"""Returns whether PyTorch is built with MKL-DNN support.

    Returns:
        bool: the build-time MKL-DNN flag exported by ``torch._C``.

    Raises:
        AttributeError: if ``torch._C`` exposes neither ``has_mkldnn`` nor
            ``_has_mkldnn``.
    """
    # Older builds export the flag as ``has_mkldnn``; newer builds renamed it
    # to ``_has_mkldnn``. Try the original name first so behaviour on older
    # builds is unchanged.
    try:
        return torch._C.has_mkldnn
    except AttributeError:
        return torch._C._has_mkldnn
def set_flags(_enabled):
    r"""Set the global MKL-DNN enabled flag.

    Args:
        _enabled (bool): new value passed to ``torch._C._set_mkldnn_enabled``.

    Returns:
        tuple: one-element tuple holding the value returned by
        ``torch._C._get_mkldnn_enabled()`` before the change, so the caller
        can restore it later.
    """
    orig_flags = (torch._C._get_mkldnn_enabled(),)
    torch._C._set_mkldnn_enabled(_enabled)
    return orig_flags


@contextmanager
def flags(enabled=False):
    r"""Context manager that sets the MKL-DNN enabled flag for a ``with`` body.

    The flag is set to ``enabled`` on entry. The previous value is restored on
    exit, including when the body raises.

    Args:
        enabled (bool): value of the flag inside the ``with`` body.
            Defaults to ``False``.
    """
    # NOTE(review): __allow_nonbracketed_mutation comes from torch.backends;
    # presumably it temporarily lifts the guard against bare flag assignment
    # there -- confirm against torch/backends/__init__.py.
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(enabled)
    try:
        yield
    finally:
        with __allow_nonbracketed_mutation():
            set_flags(orig_flags[0])


class MkldnnModule(PropModule):
    r"""Module wrapper that exposes the MKL-DNN flag as the ``enabled`` attribute."""

    def __init__(self, m, name):
        super().__init__(m, name)

    # Backed by torch._C._get_mkldnn_enabled / torch._C._set_mkldnn_enabled
    # through torch.backends.ContextProp.
    enabled = ContextProp(torch._C._get_mkldnn_enabled, torch._C._set_mkldnn_enabled)


# Same module-replacement trick as torch/backends/cudnn/: replace this module
# in sys.modules with an MkldnnModule that wraps it, so ``enabled`` is
# available as a module-level attribute.
sys.modules[__name__] = MkldnnModule(sys.modules[__name__], __name__)


