Skip to content

Commit 74e07d8

Browse files
aivanoufacebook-github-bot
authored andcommitted
Make docstring optional (#259)
Summary: Pull Request resolved: #259 * Refactor docstring functions: combines two functions that retrieve docstring into one * Make docstring optional * Remove docstring validator Git issue: #253 Reviewed By: kiukchung Differential Revision: D31671125 fbshipit-source-id: 2da71fcecf0d05f03c04dcc29b44ec43ab919eaa
1 parent bd1c36d commit 74e07d8

File tree

7 files changed

+210
-224
lines changed

7 files changed

+210
-224
lines changed

docs/source/component_best_practices.rst

+8
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,14 @@ others to understand how to use it.
7474
return AppDef(roles=[Role(..., num_replicas=num_replicas)])
7575
7676
77+
Documentation
78+
^^^^^^^^^^^^^^^^^^^^^
79+
80+
The documentation is optional, but it is the best practice to keep component functions documented,
81+
especially if you want to share your components. See :ref:Component Authoring<components/overview:Authoring>
82+
for more details.
83+
84+
7785
Named Resources
7886
-----------------
7987

torchx/specs/api.py

+11-19
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
Generic,
2020
Iterator,
2121
List,
22-
Mapping,
2322
Optional,
2423
Tuple,
2524
Type,
@@ -28,8 +27,7 @@
2827
)
2928

3029
import yaml
31-
from pyre_extensions import none_throws
32-
from torchx.specs.file_linter import parse_fn_docstring
30+
from torchx.specs.file_linter import get_fn_docstring, TorchXArgumentHelpFormatter
3331
from torchx.util.types import decode_from_string, decode_optional, is_bool, is_primitive
3432

3533

@@ -748,22 +746,21 @@ def get_argparse_param_type(parameter: inspect.Parameter) -> Callable[[str], obj
748746
return str
749747

750748

751-
def _create_args_parser(
752-
fn_name: str,
753-
parameters: Mapping[str, inspect.Parameter],
754-
function_desc: str,
755-
args_desc: Dict[str, str],
756-
) -> argparse.ArgumentParser:
749+
def _create_args_parser(app_fn: Callable[..., AppDef]) -> argparse.ArgumentParser:
750+
parameters = inspect.signature(app_fn).parameters
751+
function_desc, args_desc = get_fn_docstring(app_fn)
757752
script_parser = argparse.ArgumentParser(
758-
prog=f"torchx run ...torchx_params... {fn_name} ",
759-
description=f"App spec: {function_desc}",
753+
prog=f"torchx run <<torchx_params>> {app_fn.__name__} ",
754+
description=f"AppDef for {function_desc}",
755+
formatter_class=TorchXArgumentHelpFormatter,
760756
)
761757

762758
remainder_arg = []
763759

764760
for param_name, parameter in parameters.items():
761+
param_desc = args_desc[parameter.name]
765762
args: Dict[str, Any] = {
766-
"help": args_desc[param_name],
763+
"help": param_desc,
767764
"type": get_argparse_param_type(parameter),
768765
}
769766
if parameter.default != inspect.Parameter.empty:
@@ -788,20 +785,15 @@ def _create_args_parser(
788785
def _get_function_args(
789786
app_fn: Callable[..., AppDef], app_args: List[str]
790787
) -> Tuple[List[object], List[str], Dict[str, object]]:
791-
docstring = none_throws(inspect.getdoc(app_fn))
792-
function_desc, args_desc = parse_fn_docstring(docstring)
793-
794-
parameters = inspect.signature(app_fn).parameters
795-
script_parser = _create_args_parser(
796-
app_fn.__name__, parameters, function_desc, args_desc
797-
)
788+
script_parser = _create_args_parser(app_fn)
798789

799790
parsed_args = script_parser.parse_args(app_args)
800791

801792
function_args = []
802793
var_arg = []
803794
kwargs = {}
804795

796+
parameters = inspect.signature(app_fn).parameters
805797
for param_name, parameter in parameters.items():
806798
arg_value = getattr(parsed_args, param_name)
807799
parameter_type = parameter.annotation

torchx/specs/file_linter.py

+56-81
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
# LICENSE file in the root directory of this source tree.
77

88
import abc
9+
import argparse
910
import ast
11+
import inspect
1012
from dataclasses import dataclass
11-
from typing import Dict, List, Optional, Tuple, cast
13+
from typing import Dict, List, Optional, Tuple, cast, Callable
1214

1315
from docstring_parser import parse
1416
from pyre_extensions import none_throws
@@ -18,53 +20,66 @@
1820
# pyre-ignore-all-errors[16]
1921

2022

21-
def get_arg_names(app_specs_func_def: ast.FunctionDef) -> List[str]:
22-
arg_names = []
23-
fn_args = app_specs_func_def.args
24-
for arg_def in fn_args.args:
25-
arg_names.append(arg_def.arg)
26-
if fn_args.vararg:
27-
arg_names.append(fn_args.vararg.arg)
28-
for arg in fn_args.kwonlyargs:
29-
arg_names.append(arg.arg)
30-
return arg_names
23+
def _get_default_arguments_descriptions(fn: Callable[..., object]) -> Dict[str, str]:
24+
parameters = inspect.signature(fn).parameters
25+
args_decs = {}
26+
for parameter_name in parameters.keys():
27+
# The None or Empty string values getting ignored during help command by argparse
28+
args_decs[parameter_name] = " "
29+
return args_decs
3130

3231

33-
def parse_fn_docstring(func_description: str) -> Tuple[str, Dict[str, str]]:
32+
class TorchXArgumentHelpFormatter(argparse.HelpFormatter):
33+
"""Help message formatter which adds default values and required to argument help.
34+
35+
If the argument is required, the class appends `(required)` at the end of the help message.
36+
If the argument has default value, the class appends `(default: $DEFAULT)` at the end.
37+
The formatter is designed to be used only for the torchx components functions.
38+
These functions do not have both required and default arguments.
3439
"""
35-
Given a docstring in a google-style format, returns the function description and
36-
description of all arguments.
37-
See: https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html
40+
41+
def _get_help_string(self, action: argparse.Action) -> str:
42+
help = action.help or ""
43+
# Only `--help` will have be SUPPRESS, so we ignore it
44+
if action.default is argparse.SUPPRESS:
45+
return help
46+
if action.required:
47+
help += " (required)"
48+
else:
49+
help += f" (default: {action.default})"
50+
return help
51+
52+
53+
def get_fn_docstring(fn: Callable[..., object]) -> Tuple[str, Dict[str, str]]:
3854
"""
39-
args_description = {}
55+
Parses the function and arguments description from the provided function. Docstring should be in
56+
`google-style format <https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_
57+
58+
If function has no docstring, the function description will be the name of the function, TIP
59+
on how to improve the help message and arguments descriptions will be names of the arguments.
60+
61+
The arguments that are not present in the docstring will contain default/required information
62+
63+
Args:
64+
fn: Function with or without docstring
65+
66+
Returns:
67+
function description, arguments description where key is the name of the argument and value
68+
if the description
69+
"""
70+
default_fn_desc = f"""{fn.__name__} TIP: improve this help string by adding a docstring
71+
to your component (see: https://pytorch.org/torchx/latest/component_best_practices.html)"""
72+
args_description = _get_default_arguments_descriptions(fn)
73+
func_description = inspect.getdoc(fn)
74+
if not func_description:
75+
return default_fn_desc, args_description
4076
docstring = parse(func_description)
4177
for param in docstring.params:
4278
args_description[param.arg_name] = param.description
43-
short_func_description = docstring.short_description
44-
return (short_func_description or "", args_description)
45-
46-
47-
def _get_fn_docstring(
48-
source: str, function_name: str
49-
) -> Optional[Tuple[str, Dict[str, str]]]:
50-
module = ast.parse(source)
51-
for expr in module.body:
52-
if type(expr) == ast.FunctionDef:
53-
func_def = cast(ast.FunctionDef, expr)
54-
if func_def.name == function_name:
55-
docstring = ast.get_docstring(func_def)
56-
if not docstring:
57-
return None
58-
return parse_fn_docstring(docstring)
59-
return None
60-
61-
62-
def get_short_fn_description(path: str, function_name: str) -> Optional[str]:
63-
source = read_conf_file(path)
64-
docstring = _get_fn_docstring(source, function_name)
65-
if not docstring:
66-
return None
67-
return docstring[0]
79+
short_func_description = docstring.short_description or default_fn_desc
80+
if docstring.long_description:
81+
short_func_description += " ..."
82+
return (short_func_description or default_fn_desc, args_description)
6883

6984

7085
@dataclass
@@ -91,38 +106,6 @@ def _gen_linter_message(self, description: str, lineno: int) -> LinterMessage:
91106
)
92107

93108

94-
class TorchxDocstringValidator(TorchxFunctionValidator):
95-
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
96-
"""
97-
Validates the docstring of the `get_app_spec` function. Criteria:
98-
* There mast be google-style docstring
99-
* If there are more than zero arguments, there mast be a `Args:` section defined
100-
with all arguments included.
101-
"""
102-
docsting = ast.get_docstring(app_specs_func_def)
103-
lineno = app_specs_func_def.lineno
104-
if not docsting:
105-
desc = (
106-
f"`{app_specs_func_def.name}` is missing a Google Style docstring, please add one. "
107-
"For more information on the docstring format see: "
108-
"https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html"
109-
)
110-
return [self._gen_linter_message(desc, lineno)]
111-
112-
arg_names = get_arg_names(app_specs_func_def)
113-
_, docstring_arg_defs = parse_fn_docstring(docsting)
114-
missing_args = [
115-
arg_name for arg_name in arg_names if arg_name not in docstring_arg_defs
116-
]
117-
if len(missing_args) > 0:
118-
desc = (
119-
f"`{app_specs_func_def.name}` not all function arguments are present"
120-
f" in the docstring. Missing args: {missing_args}"
121-
)
122-
return [self._gen_linter_message(desc, lineno)]
123-
return []
124-
125-
126109
class TorchxFunctionArgsValidator(TorchxFunctionValidator):
127110
def validate(self, app_specs_func_def: ast.FunctionDef) -> List[LinterMessage]:
128111
linter_errors = []
@@ -149,7 +132,6 @@ def _validate_arg_def(
149132
)
150133
]
151134
if isinstance(arg_def.annotation, ast.Name):
152-
# TODO(aivanou): add support for primitive type check
153135
return []
154136
complex_type_def = cast(ast.Subscript, none_throws(arg_def.annotation))
155137
if complex_type_def.value.id == "Optional":
@@ -239,12 +221,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
239221
Visitor that finds the component_function and runs registered validators on it.
240222
Current registered validators:
241223
242-
* TorchxDocstringValidator - validates the docstring of the function.
243-
Criteria:
244-
* There format should be google-python
245-
* If there are more than zero arguments defined, there
246-
should be obligatory `Args:` section that describes each argument on a new line.
247-
248224
* TorchxFunctionArgsValidator - validates arguments of the function.
249225
Criteria:
250226
* Each argument should be annotated with the type
@@ -260,7 +236,6 @@ class TorchFunctionVisitor(ast.NodeVisitor):
260236

261237
def __init__(self, component_function_name: str) -> None:
262238
self.validators = [
263-
TorchxDocstringValidator(),
264239
TorchxFunctionArgsValidator(),
265240
TorchxReturnValidator(),
266241
]

torchx/specs/finder.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from pyre_extensions import none_throws
1919
from torchx.specs import AppDef
20-
from torchx.specs.file_linter import get_short_fn_description, validate
20+
from torchx.specs.file_linter import get_fn_docstring, validate
2121
from torchx.util import entrypoints
2222
from torchx.util.io import read_conf_file
2323

@@ -40,14 +40,15 @@ class _Component:
4040
Args:
4141
name: The name of the component, which usually MODULE_PATH.FN_NAME
4242
description: The description of the component, taken from the desrciption
43-
of the function that creates component
43+
of the function that creates component. In case of no docstring, description
44+
will be the same as name
4445
fn_name: Function name that creates component
4546
fn: Function that creates component
4647
validation_errors: Validation errors
4748
"""
4849

4950
name: str
50-
description: Optional[str]
51+
description: str
5152
fn_name: str
5253
fn: Callable[..., AppDef]
5354
validation_errors: List[str]
@@ -150,7 +151,7 @@ def _get_components_from_module(
150151
module_path = os.path.abspath(module.__file__)
151152
for function_name, function in functions:
152153
linter_errors = validate(module_path, function_name)
153-
component_desc = get_short_fn_description(module_path, function_name)
154+
component_desc, _ = get_fn_docstring(function)
154155
component_def = _Component(
155156
name=self._get_component_name(
156157
base_module, module.__name__, function_name
@@ -197,7 +198,6 @@ def find(self) -> List[_Component]:
197198
validation_errors = self._get_validation_errors(
198199
self._filepath, self._function_name
199200
)
200-
fn_desc = get_short_fn_description(self._filepath, self._function_name)
201201

202202
file_source = read_conf_file(self._filepath)
203203
namespace = globals()
@@ -207,6 +207,7 @@ def find(self) -> List[_Component]:
207207
f"Function {self._function_name} does not exist in file {self._filepath}"
208208
)
209209
app_fn = namespace[self._function_name]
210+
fn_desc, _ = get_fn_docstring(app_fn)
210211
return [
211212
_Component(
212213
name=f"{self._filepath}:{self._function_name}",

0 commit comments

Comments
 (0)