Source code for torchx.specs.file_linter

#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import abc
import argparse
import ast
import inspect
from dataclasses import dataclass
from typing import Callable, cast, Dict, List, Optional, Tuple

from docstring_parser import parse
from torchx.util.io import read_conf_file
from torchx.util.types import none_throws

# pyre-ignore-all-errors[16]

def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str, str]:
    parameters = inspect.signature(fn).parameters
    args_decs = {}
    for parameter_name in parameters.keys():
        # The None or Empty string values getting ignored during help command by argparse
        args_decs[parameter_name] = " "
    return args_decs

class TorchXArgumentHelpFormatter(argparse.HelpFormatter):
    """Help message formatter which adds default values and required to argument help.

    If the argument is required, the class appends `(required)` at the end of the
    help message. If the argument has default value, the class appends
    `(default: $DEFAULT)` at the end.

    The formatter is designed to be used only for the torchx components functions.
    These functions do not have both required and default arguments.
    """

    def _get_help_string(self, action: argparse.Action) -> str:
        """Return the help text of ``action`` with a required/default suffix.

        Args:
            action: The argparse action whose help text is being rendered.

        Returns:
            The action's help text (empty string when it has none), followed by
            `` (required)`` or `` (default: <value>)``. Actions whose default is
            ``argparse.SUPPRESS`` get no suffix.
        """
        help_text = action.help or ""
        # Only `--help` will have a SUPPRESS default, so we leave it untouched
        if action.default is argparse.SUPPRESS:
            return help_text
        if action.required:
            help_text += " (required)"
        else:
            help_text += f" (default: {action.default})"
        return help_text
def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
    """
    Parses the function and arguments description from the provided function.
    Docstring should be in google-style format.

    If function has no docstring, the function description will be the name of
    the function followed by a TIP on how to improve the help message, and the
    arguments descriptions will be the placeholders produced by
    ``_get_default_arguments_descriptions``.

    Arguments that are not described in the docstring keep their placeholder
    description.

    Args:
        fn: Function with or without docstring

    Returns:
        function description, arguments description where key is the name of the
        argument and value is the description
    """
    # NOTE(review): the link target after "(see:" was missing from this copy of
    # the file; it is restored here to the TorchX component best-practices page.
    # Confirm the URL (and the original line breaks) against upstream.
    default_fn_desc = (
        f"{fn.__name__} TIP: improve this help string by adding a docstring "
        "to your component (see: "
        "https://pytorch.org/torchx/latest/component_best_practices.html)"
    )
    args_description = _get_default_arguments_descriptions(fn)
    func_description = inspect.getdoc(fn)
    if not func_description:
        return default_fn_desc, args_description

    docstring = parse(func_description)
    for param in docstring.params:
        # Parameters documented without a description keep their placeholder.
        if param.description is not None:
            args_description[param.arg_name] = param.description

    short_func_description = docstring.short_description or default_fn_desc
    if docstring.long_description:
        # Signal that the docstring has more text than the summary shown.
        short_func_description += " ..."
    # default_fn_desc is never empty, so short_func_description is always truthy
    # here and needs no further fallback.
    return short_func_description, args_description
@dataclass
class LinterMessage:
    """A single finding produced while linting a component function."""

    # Identifier of the check that produced the message.
    name: str
    # Human-readable explanation of the finding.
    description: str
    # Source line number the finding refers to.
    line: int
    # Character offset within the line.
    char: int
    # Severity label; defaults to "error".
    severity: str = "error"
class TorchxFunctionValidator(abc.ABC):
    """Abstract base for checks that inspect one function definition node."""

    @abc.abstractmethod
    def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
        """
        Method to call to validate the provided function def.

        Args:
            app_specs_func_def: AST node of the function to validate.

        Returns:
            The linter messages found for the function.
        """
        raise NotImplementedError()

    def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
        """Create an error-severity ``LinterMessage`` at column 0 of ``lineno``.

        Args:
            description: Text explaining the problem.
            lineno: Line number the message is attached to.
        """
        return LinterMessage(
            name="TorchxFunctionValidator",
            description=description,
            line=lineno,
            char=0,
            severity="error",
        )
class TorchxFunctionArgsValidator(TorchxFunctionValidator):
    """Validates the type annotations of a function's arguments.

    Positional arguments, ``*args`` and keyword-only arguments are checked.
    An annotation passes when it is a bare name, ``Optional[<name>]``, or a
    ``Dict``/``List`` (optionally wrapped in ``Optional``) whose type
    parameters are bare names.
    """

    def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
        """Return the linter messages for every checked argument of the function.

        Args:
            app_specs_func_def: AST node of the function to validate.
        """
        function_name = app_specs_func_def.name
        linter_errors = []
        for arg_def in app_specs_func_def.args.args:
            linter_errors += self._validate_arg_def(function_name, arg_def)
        if app_specs_func_def.args.vararg:
            linter_errors += self._validate_arg_def(
                function_name, none_throws(app_specs_func_def.args.vararg)
            )
        for arg in app_specs_func_def.args.kwonlyargs:
            linter_errors += self._validate_arg_def(function_name, arg)
        return linter_errors

    @staticmethod
    def _unwrap_slice(subscript: ast.Subscript) -> ast.AST:
        """Return the subscript's type parameter, unwrapping ``ast.Index`` if present."""
        # ast module in python3.9 does not have objects wrapped in ast.Index
        if isinstance(subscript.slice, ast.Index):
            return subscript.slice.value
        return subscript.slice

    def _validate_arg_def(
        self, function_name: str, arg_def: ast.arg
    ) -> List[LinterMessage]:
        """Check the annotation of a single argument.

        Args:
            function_name: Name of the enclosing function, used in messages.
            arg_def: AST node of the argument.

        Returns:
            Linter messages for the argument; empty when the annotation is accepted.
        """
        if not arg_def.annotation:
            return [
                self._gen_linter_message(
                    f"Arg {arg_def.arg} missing type annotation", arg_def.lineno
                )
            ]
        if isinstance(arg_def.annotation, ast.Name):
            return []

        # NOTE(review): anything that is not a bare name is treated as a
        # subscript of a bare name (e.g. ``List[str]``); other annotation shapes
        # are not handled separately here.
        complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
        if complex_type_def.value.id == "Optional":
            complex_type_def = self._unwrap_slice(complex_type_def)
            # Check if type is Optional[primitive_type]
            if isinstance(complex_type_def, ast.Name):
                return []

        # Check if type is Union[Dict,List]
        type_name = complex_type_def.value.id
        if type_name != "Dict" and type_name != "List":
            desc = (
                f"`{function_name}` allows only Dict, List as complex types."
                f"Argument `{arg_def.arg}` has: {type_name}"
            )
            return [self._gen_linter_message(desc, arg_def.lineno)]

        linter_errors = []
        sub_type = self._unwrap_slice(complex_type_def)
        if type_name == "Dict":
            # Dict takes two type parameters, parsed as a tuple of elements.
            sub_type_tuple = cast(ast.Tuple, sub_type)
            for el in sub_type_tuple.elts:
                if not isinstance(el, ast.Name):
                    desc = "Dict can only have primitive types"
                    linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
        elif not isinstance(sub_type, ast.Name):
            desc = "List can only have primitive types"
            linter_errors.append(self._gen_linter_message(desc, arg_def.lineno))
        return linter_errors
class TorchxReturnValidator(TorchxFunctionValidator):
    """Validates the return annotation of a function."""

    def _get_return_annotation(
        self, app_specs_func_def: ast.FunctionDef
    ) -> Optional[str]:
        """Extract the return annotation of the function.

        Args:
            app_specs_func_def: AST node of the function.

        Returns:
            The attribute name for ``module.Name`` annotations, the identifier
            for ``Name`` annotations, the constant's value for constant (for
            example quoted) annotations, and ``None`` when the annotation is
            absent or has any other shape.
        """
        return_def = app_specs_func_def.returns
        if not return_def:
            return None
        if isinstance(return_def, ast.Attribute):
            return return_def.attr
        if isinstance(return_def, ast.Name):
            return return_def.id
        # Constant is checked before the legacy Str node so that interpreters
        # which parse string literals as ast.Constant never touch ast.Str.
        if isinstance(return_def, ast.Constant):
            return return_def.value
        # ast.Str is deprecated since Python 3.8 and missing from newer
        # interpreters, so look it up defensively instead of referencing it.
        legacy_str = getattr(ast, "Str", None)
        if legacy_str is not None and isinstance(return_def, legacy_str):
            return return_def.s
        return None

    def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
        """
        Validates return annotation of the torchx function. Current allowed annotations:
            * AppDef
            * specs.AppDef

        Args:
            app_specs_func_def: AST node of the function to validate.

        Returns:
            A one-element list when the annotation is missing, unknown or not
            ``AppDef``; an empty list otherwise.
        """
        supported_return_annotation = "AppDef"
        return_annotation = self._get_return_annotation(app_specs_func_def)
        linter_errors = []
        if not return_annotation:
            desc = (
                f"Function: {app_specs_func_def.name} missing return annotation "
                f"or has unknown annotations. Supported return annotation: {supported_return_annotation}"
            )
            linter_errors.append(
                self._gen_linter_message(desc, app_specs_func_def.lineno)
            )
        elif return_annotation != supported_return_annotation:
            desc = (
                f"Function: {app_specs_func_def.name} has incorrect return annotation, "
                f"supported annotation: {supported_return_annotation}"
            )
            linter_errors.append(
                self._gen_linter_message(desc, app_specs_func_def.lineno)
            )
        return linter_errors
class TorchFunctionVisitor(ast.NodeVisitor):
    """
    Visitor that finds the component_function and runs registered validators on it.
    Current registered validators:

    * TorchxFunctionArgsValidator - validates arguments of the function.
        Criteria:
        * Each argument should be annotated with the type
        * The following types are supported:
            - primitive_types: {int, str, float},
            - Optional[primitive_types],
            - Dict[primitive_types, primitive_types],
            - List[primitive_types],
            - Optional[Dict[primitive_types, primitive_types]],
            - Optional[List[primitive_types]]

    * TorchxReturnValidator - validates the return annotation of the function.
    """

    def __init__(self, component_function_name: str) -> None:
        """
        Args:
            component_function_name: Name of the function definition to validate.
        """
        self.validators = [
            TorchxFunctionArgsValidator(),
            TorchxReturnValidator(),
        ]
        # Messages accumulated from every validator run on a matching function.
        self.linter_errors: List[LinterMessage] = []
        self.component_function_name = component_function_name
        # Set once a function with the requested name has been visited.
        self.visited_function = False

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """
        Validates the function def with the child validators.

        Function definitions whose name differs from ``component_function_name``
        are skipped.
        """
        if node.name != self.component_function_name:
            return
        self.visited_function = True
        for validator in self.validators:
            self.linter_errors += validator.validate(node)
def validate(path: str, component_function: str) -> List[LinterMessage]:
    """
    Validates the function to make sure it complies with the component standard.

    ``validate`` finds the ``component_function`` and validates it with the
    validators registered on ``TorchFunctionVisitor``, which check that:

    #. All function parameters are annotated with supported types
    #. The function's return annotation is ``AppDef``
       (:py:class:`torchx.specs.api.AppDef`)

    NOTE(review): the component standard also asks for google-style docstrings;
    no check for that is visible among the validators run here.

    Args:
        path: Path to python source file.
        component_function: Name of the function to be validated.

    Returns:
        List[LinterMessage]: List of validation errors. A source file that fails
        to parse yields a single message built from the ``SyntaxError``; a
        missing function yields a "not found" message.
    """
    source = read_conf_file(path)
    try:
        module = ast.parse(source)
    except SyntaxError as ex:
        # lineno/offset may be None on a SyntaxError; fall back to 0.
        linter_message = LinterMessage(
            name="TorchXValidator",
            description=ex.msg,
            line=ex.lineno or 0,
            char=ex.offset or 0,
            severity="error",
        )
        return [linter_message]

    visitor = TorchFunctionVisitor(component_function)
    visitor.visit(module)
    linter_errors = visitor.linter_errors
    if not visitor.visited_function:
        linter_errors.append(
            LinterMessage(
                name="TorchXValidator",
                description=f"Function {component_function} not found",
                line=0,
                char=0,
                severity="error",
            )
        )
    return linter_errors


Access comprehensive developer documentation for PyTorch

View Docs


Get in-depth tutorials for beginners and advanced developers

View Tutorials


Find development resources and get your questions answered

View Resources