import ast
import importlib
import inspect
import logging
import shutil
import traceback
from types import ModuleType
from typing import (
Any,
Callable,
Collection,
Dict,
Iterable,
List,
Optional,
Type,
Union,
get_args,
)
import docstring_parser
from docutils import nodes # type: ignore
from docutils import statemachine
from docutils.parsers.rst import Directive # type: ignore
from sphinx import application # type: ignore
from sphinx.util import nodes as sphinx_nodes # type: ignore
from jobmon.client.api import Tool
from jobmon.client.task import Task
from jobmon.client.workflow import Workflow
[docs]
# Scalar types with built-in serialization (``str()``) and deserialization.
SIMPLE_TYPES = {str, int, float, bool}

# Built-in collection types whose items are serialized one by one.
BUILT_IN_COLLECTIONS = {list, tuple, set}

# Objects that on their own mark an annotation as optional: the bare ``Optional`` special
# form and ``NoneType``.
OPTIONAL_TYPES = {Optional, type(None)}

# Name of the executable that is looked up on the PATH and used to run each task.
TASK_RUNNER_NAME = "worker_node_entry_point"

# Sub-command placed right after the executable in the generated command template.
TASK_RUNNER_SUB_COMMAND = "task_generator"

# First line of the text returned by ``TaskGenerator.help``.
HELP_TEXT_INTRO_FORMAT = "Command Line Documentation for {full_path}"

# Placeholder used on the command line for an empty string argument.
SERIALIZED_EMPTY_STRING = '""'

logger = logging.getLogger(__name__)
[docs]
def create_task_name(
    kwargs_for_name: dict,
    prefix: str = "",
    name_func: Optional[Callable] = None,
) -> str:
    """Create a task name from the kwargs.

    Args:
        kwargs_for_name: Mapping of argument name to its serialized (string) value. The
            mapping is never modified.
        prefix: Text the name starts with.
        name_func: Optional callable invoked as ``name_func(prefix, kwargs_for_name)``. When
            provided, its return value is used as the name and no other processing happens.

    Returns:
        ``prefix`` followed by one ``:name=value`` segment per entry of ``kwargs_for_name``.
    """
    if name_func:
        return name_func(prefix, kwargs_for_name)

    # Use the FHS default if no name_func is provided.
    # Illegal characters (``/``, ``\``, ``"``, ``'`` and space) are each replaced by ``_``.
    illegal_characters = str.maketrans({char: "_" for char in "/\\\"' "})
    segments = [prefix]
    for item_name, value in kwargs_for_name.items():
        # An empty-string placeholder becomes empty in the name; Jobmon will not accept
        # quotes in the name. Work on a local copy so the caller's dict is left untouched.
        if value == SERIALIZED_EMPTY_STRING:
            value = ""
        # Trim the leading and ending single quote added by make_cli_argument_string. The
        # length guard keeps empty and one-character values from raising IndexError.
        if len(value) >= 2 and value[0] == "'" and value[-1] == "'":
            value = value[1:-1]
        # Trim the leading [ and ending ] that come from converting a list to a string.
        if len(value) >= 2 and value[0] == "[" and value[-1] == "]":
            value = value[1:-1]
        segments.append(f"{item_name}={value.translate(illegal_characters)}")
    return ":".join(segments)
[docs]
def _is_multidimensional_type(type_hint: Any) -> bool:
    """Return whether ``type_hint`` is a built-in collection of built-in collections.

    Only the first type parameter of the outer collection is inspected. Ex.
    ``List[List[int]]`` would return ``True``, ``List[int]`` would return ``False``.
    """
    # The outer annotation must be a parameterized list, tuple or set
    outer_origin = getattr(type_hint, "__origin__", None)
    if outer_origin not in BUILT_IN_COLLECTIONS:
        return False

    type_parameters = get_args(type_hint)
    if not type_parameters:
        return False

    # ...and so must its first type parameter
    inner_origin = getattr(type_parameters[0], "__origin__", None)
    return inner_origin in BUILT_IN_COLLECTIONS
[docs]
def _find_executable_path(executable_name: str) -> str:
    """Return the path ``shutil.which`` resolves for ``executable_name``.

    Raises:
        FileNotFoundError: If ``shutil.which`` cannot find the executable.
    """
    resolved_path = shutil.which(executable_name)
    if resolved_path is None:
        raise FileNotFoundError(
            f"The executable '{executable_name}' was not found in the system PATH."
        )
    return resolved_path
[docs]
def _is_unannotated_built_in_collection_type(obj_type: Any) -> bool:
    """Return whether ``obj_type`` is a bare built-in collection type.

    Bare ``list``, ``tuple`` and ``set`` return ``True``; a parameterized annotation such as
    ``list[str]`` returns ``False``.
    """
    is_bare_collection = obj_type in BUILT_IN_COLLECTIONS
    return is_bare_collection
[docs]
def _get_collection_type(obj_type: Any) -> Optional[Type]:
    """Return the collection type behind ``obj_type``.

    A bare built-in collection (``list``, ``tuple``, ``set``) is returned unchanged. For
    anything else the ``__origin__`` attribute is returned, or ``None`` when there is none.

    Ex. ``list[int]`` would return ``list``, ``tuple[str]`` would return ``tuple``, and
    ``float`` would return ``None`` (because it is not a collection).
    """
    if obj_type not in BUILT_IN_COLLECTIONS:
        # Parameterized annotations expose their collection type as ``__origin__``
        return getattr(obj_type, "__origin__", None)
    return obj_type
[docs]
def _get_generic_type_parameters(obj_type: Any) -> Any:
    """Return the ``__args__`` of ``obj_type``, or ``None`` if it has none.

    For generic annotations this is a tuple of types: ``list[int]`` gives ``(int,)``,
    ``tuple[int, str]`` gives ``(int, str)`` and ``Union[str, None]`` gives
    ``(str, NoneType)``.
    """
    try:
        return obj_type.__args__
    except AttributeError:
        return None
[docs]
def _is_annotated_built_in_collection_type(obj_type: Any) -> bool:
    """Return whether ``obj_type`` is a built-in collection that has type parameters."""
    if _get_collection_type(obj_type) not in BUILT_IN_COLLECTIONS:
        return False
    return _get_generic_type_parameters(obj_type) is not None
[docs]
def _zip_collection_items_and_types(
    obj: Collection[str], collection_items_type: tuple
) -> Any:
    """Return a list of tuples mapping collection items and their types.

    Ex. ``obj`` is ``[1, 2, 3]`` and ``collection_items_type`` is ``(int,)`` would return
    ``[(1, int), (2, int), (3, int)]``.

    Args:
        obj: The collection whose items are paired with a type.
        collection_items_type: The type parameters of the collection annotation.

    Returns:
        A list of ``(item, type)`` tuples, one per item of ``obj``.

    Raises:
        TypeError: If there is one type parameter per position and the number of items
            differs from the number of type parameters.
    """
    type_count = len(collection_items_type)

    # If the collection has a single internal type annotation, or if the second
    # internal type annotation is the Ellipsis (``...``), match the first (only) internal
    # annotation to all items in the collection. The length check comes first so that an
    # empty parameter tuple falls through instead of raising IndexError.
    if type_count == 1 or (type_count == 2 and collection_items_type[1] is Ellipsis):
        item_type = collection_items_type[0]
        return [(item, item_type) for item in obj]

    # If the collection has multiple internal type annotations but the object
    # provided is a collection of a different length, raise an error
    if len(obj) != type_count:
        raise TypeError(
            f"Expected a collection of internal type ``{collection_items_type}`` but got "
            f"``{obj}``. Note the mismatched lengths."
        )

    # Otherwise, match each item of the collection to the type of the corresponding internal
    # type annotation
    return list(zip(obj, collection_items_type))
[docs]
def is_optional_type(obj_type: Any) -> bool:
    """Return whether the ``obj_type`` is an Optional type.

    Ex. ``Optional[str]`` would return ``True``, ``NoneType`` would return ``True``,
    ``str | None`` would return ``True``, ``str`` would return ``False``.

    Only unions count: a non-union generic that merely has ``None`` among its parameters
    (``Tuple[int, None]``) is not optional, and neither is ``Type[int]``.
    """
    # Imported here because the module only imports ``ModuleType`` from ``types`` and does
    # not import ``get_origin`` from ``typing``.
    import types
    from typing import get_origin

    # This is true if ``obj_type=NoneType`` or the bare ``Optional`` special form
    if obj_type in OPTIONAL_TYPES:
        return True

    # ``Optional[X]`` / ``Union[X, None]`` have ``Union`` as their origin. ``X | None`` has
    # ``types.UnionType``, which is looked up defensively because it does not exist on
    # every Python version.
    union_origins = [Union]
    pep604_union = getattr(types, "UnionType", None)
    if pep604_union is not None:
        union_origins.append(pep604_union)
    if get_origin(obj_type) not in union_origins:
        return False

    # A union is optional exactly when ``NoneType`` is one of its members
    return type(None) in get_args(obj_type)
[docs]
def get_optional_type_parameter(obj_type: Type) -> Type:
    """Return the first type parameter of ``obj_type`` that is not ``NoneType``.

    Raises:
        TypeError: If ``obj_type`` has no type parameters, or all of them are ``NoneType``.
    """
    # For an optional this is, in theory, a tuple holding NoneType and the other type
    type_parameters = _get_generic_type_parameters(obj_type)
    if type_parameters is None:
        type_parameters = ()

    for candidate in type_parameters:
        is_real_type = candidate != type(None)  # noqa: E721
        if is_real_type:
            return candidate

    raise TypeError(
        f"Looks like we couldn't extract a non-None type from optional-type {obj_type}. This "
        "shouldn't have happend."
    )
[docs]
def _clean_arg_name(arg_name: str) -> str:
    """Strip every ``--`` from ``arg_name``, then turn remaining dashes into underscores."""
    without_double_dashes = arg_name.replace("--", "")
    return without_double_dashes.replace("-", "_")
[docs]
def _get_short_description(task_function_docstring: docstring_parser.Docstring) -> str:
    """Return the docstring's short description on a new line, or ``""`` if there is none."""
    summary = task_function_docstring.short_description
    return "\n" + summary if summary else ""
[docs]
def make_cli_argument_string(arg_value: Union[str, List]) -> str:
    """Wrap a serialized argument value in single quotes for the command line.

    A list such as ``["a", "b", "c"]`` becomes the string ``'[a,b,c]'`` (items joined by
    commas without spaces). Any other value is formatted inside single quotes as is.

    Raises:
        ValueError: If ``arg_value`` is a list and the joined items contain a single quote.
    """
    if not isinstance(arg_value, list):
        return "'" + f"{arg_value}" + "'"

    joined_items = ",".join(arg_value)
    # A single quote inside the list would end the surrounding quoting early
    if "'" in joined_items:
        raise ValueError(
            f"This version of TaskGenertor cannot serialize list with single quote in it: "
            f"{arg_value}."
        )
    return "'[" + joined_items + "]'"
[docs]
class TaskGenerator:
    """Class for auto-generating jobmon tasks from a python function.

    Arguments on the task_function are inspected to generate the task template, and type
    annotations are used for serializing and deserializing arguments to strings passed on the
    command line.

    The TaskGenerator has built in serializers for:

    * str, int, bool, float
    * list, tuple, set of other serializable types
    * Optional types (and types that are equivalent to Optional types) of other
      serializable types

    Users can also supply serializers for custom types by providing a dictionary of type to
    a tuple of serialization and deserialization functions. Serializers must always turn
    objects into a single string. Lists of strings are represented on the command line as
    ``arg_name='[value1,value2]'``.

    Note:
        While lists, tuples and sets (and their ``Optional`` counterparts) can currently
        serialize empty collections, we can't currently serialize custom types as empty lists.
    """

    def __init__(
        self,
        task_function: Callable,
        serializers: Dict,
        tool_name: str,
        naming_args: Optional[List[str]] = None,
        max_attempts: Optional[int] = None,
        module_source_path: Optional[str] = None,
        name_func: Optional[Callable] = None,
        default_cluster_name: str = "",
        default_compute_resources: Optional[Dict[str, Any]] = None,
        default_resource_scales: Optional[Dict[str, float]] = None,
        yaml_file: Optional[str] = None,
    ) -> None:
        """Initialize TaskGenerator.

        Validates the task function. The task template is generated on the first call to
        ``create_task`` or ``create_tasks``.

        Args:
            task_function: The function that the task will run.
            serializers: A dict mapping types to two callables. The first callable is the
                serializer, which converts an object of the type to a string. And the second
                is a deserializer, which converts a string to an object of the type.
            tool_name: A jobmon tool name for generating tasks.
            naming_args: A list of arguments to use in the task name. If not provided (or
                ``None``), it uses all the arguments. If ``[]`` is provided, it uses no
                naming arguments, and the name of the task will just be the name of the task
                function.
            max_attempts: The max number of attempts jobmon will make on the tasks.
            module_source_path: The path to the module source code. If not provided,
                the module is assumed to be installed in the system.
            name_func: A function that takes in the task name prefix and the kwargs to
                generate the task name. If not provided, the FHS default is used.
            default_cluster_name: The default cluster name to use when creating tasks. If
                not provided, an empty string is used.
            default_compute_resources: The default compute resources when creating tasks.
            default_resource_scales: The default resource scales to use when creating tasks.
            yaml_file: The path to the yaml file handed to the task template.

        Raises:
            TypeError: If a parameter of ``task_function`` has no type annotation or an
                annotation that cannot be serialized.
        """
        self.task_function = task_function
        self.serializers = serializers
        self.max_attempts = max_attempts
        self.mod_name = f"{task_function.__module__}"
        self.name = task_function.__name__
        self.full_path = f"{task_function.__module__}:{self.name}"
        self.module_source_path = module_source_path
        # Parameter name -> type annotation, in signature order
        self.params = {
            name: details.annotation
            for name, details in inspect.signature(
                self.task_function
            ).parameters.items()
        }
        if naming_args is None:
            logger.warning(
                "You have specified no naming_args on a task_generator, which means that all "
                "task args will be used in the task name. Usually, this results in a name "
                "that's too long for Jobmon, so you might need to specify naming_args. "
                "Include just the arguments that differ from task to task within a workflow."
            )
        self._naming_args = (
            naming_args if naming_args is not None else self.params.keys()
        )
        self._validate_task_function()

        # ``_generate_task_template`` reads ``self.tool``; the ``tool`` property builds it
        # from ``tool_name`` on first access.
        self.tool_name = tool_name
        self._tool = None
        self._task_template = None
        self._default_cluster_name = default_cluster_name
        self._default_compute_resources = default_compute_resources
        self._default_resource_scales = default_resource_scales
        self._yaml_file = yaml_file
        # If the user provides a name_func, use that to generate the task name
        # otherwise, use FHS default
        self.name_func = name_func

    @property
    def tool(self) -> Tool:
        """The jobmon ``Tool`` named ``tool_name``, created on first access."""
        if self._tool is None:
            self._tool = Tool(name=self.tool_name)
        return self._tool

    @tool.setter
    def tool(self, tool: Tool) -> None:
        self._tool = tool

    def _validate_task_function(self) -> None:
        """Check that a task can be generated from the task_function.

        Currently checks:
          - All parameters have type annotations
          - All parameters have serializers

        Raises:
            TypeError: If either check fails.
        """
        # Verify that all parameters have type annotations
        unannotated_params = [
            name
            for name, annotation in self.params.items()
            if annotation is inspect.Parameter.empty
        ]
        if unannotated_params:
            raise TypeError(
                f"Can't generate a task template for {self.name} because "
                "the following parameters are missing type annotations: "
                f"{' '.join(unannotated_params)}"
            )

        # Verify that we know how to serialize all of the parameters
        for arg_name, annotation in self.params.items():
            if not self._is_valid_annotation(annotation):
                raise TypeError(
                    f"Unknown annotation on task_function ``{self.name}`` for parameter "
                    f"``{arg_name}: {annotation}``. Please either provide a serializer and "
                    f"deserializer for this type in the ``serializers`` dictionary, or "
                    f"annotate this parameter with one of the known types: "
                    f"{SIMPLE_TYPES | BUILT_IN_COLLECTIONS | OPTIONAL_TYPES}."
                )

    def _is_valid_annotation(self, annotation: Type) -> bool:
        """Returns True if the annotation is valid, False otherwise."""
        # Check for simple types and types that were provided in the serializers dict
        if annotation in SIMPLE_TYPES or annotation in self.serializers:
            return True

        # Check for built-in collection types. Note that we validate the underlying type
        if _is_annotated_built_in_collection_type(annotation):
            # The type can either be a simple type, a type in the serializers dict, or
            # the Ellipses (``...``) type
            allowed_item_types = SIMPLE_TYPES | set(self.serializers) | {Ellipsis}
            if all(
                item_type in allowed_item_types
                for item_type in _get_generic_type_parameters(annotation)
            ):
                return True

        # Check for optional types. Note that we validate the underlying type
        if is_optional_type(annotation):
            return self._is_valid_annotation(get_optional_type_parameter(annotation))

        # If we've made it through all the conditions above, return False
        return False

    def _generate_task_template(self) -> None:
        """Generate and store the task template."""
        # args convert to foo={foo} bar={bar}
        args_template = " ".join(
            f"{arg_name}={{{arg_name}}}" for arg_name in self.params
        )
        command_parts = [
            "{executable}",
            TASK_RUNNER_SUB_COMMAND,
            "--module_name",
            self.mod_name,
            "--func_name",
            self.name,
        ]
        # The source path is only passed along when one was given
        if self.module_source_path:
            command_parts += ["--module_source_path", self.module_source_path]

        self._task_template = self.tool.get_task_template(
            template_name=self.name,
            command_template=" ".join(command_parts) + " " + args_template,
            node_args=self.params.keys(),  # type: ignore
            op_args=["executable"],
            default_cluster_name=self._default_cluster_name,
            default_compute_resources=self._default_compute_resources,
            default_resource_scales=self._default_resource_scales,
            yaml_file=self._yaml_file,
        )

    def _is_valid_type(self, obj: Any, expected_type: Type) -> bool:
        """Check that the obj type matches the expected type.

        For collections only the collection itself is checked, not its items.
        """
        collection_type = _get_collection_type(expected_type)
        if collection_type in BUILT_IN_COLLECTIONS:
            return isinstance(obj, collection_type)  # type: ignore
        if is_optional_type(expected_type):
            if obj is None:
                return True
            return self._is_valid_type(obj, get_optional_type_parameter(expected_type))
        return isinstance(obj, expected_type)

    def serialize(self, obj: Any, expected_type: Type) -> str:
        """Serialize ``obj``, validating that it is actually of type ``expected_type``.

        Raises:
            TypeError: If ``obj`` does not match ``expected_type``, if the type is a
                multi-dimensional collection or otherwise cannot be serialized, or if a
                serializer returns something that is not a string.
        """
        if not self._is_valid_type(obj, expected_type):
            raise TypeError(
                f"Expected something of type ``{expected_type}``, but got ``{obj}`` of "
                f"type ``{type(obj)}``."
            )

        # raise an error if the expected_type is a multi-dimensional collection
        if _is_multidimensional_type(expected_type):
            raise TypeError(
                f"This version of Task Generator cannot "
                f"serialize multi-dimensional collection: {expected_type}."
            )

        serialized_result: Any
        if expected_type in self.serializers:
            # The 0'th index of the serializers dict is the serialization function
            serialized_result = self.serializers[expected_type][0](obj)
        elif expected_type in SIMPLE_TYPES:
            # If the object is an empty string use the placeholder, otherwise just stringify
            if expected_type == str and len(obj) == 0:
                serialized_result = SERIALIZED_EMPTY_STRING
            else:
                serialized_result = str(obj)
        elif is_optional_type(expected_type):
            if obj is None:
                serialized_result = str(None)
            else:
                serialized_result = self.serialize(
                    obj=obj, expected_type=get_optional_type_parameter(expected_type)
                )
        elif _get_collection_type(expected_type) in BUILT_IN_COLLECTIONS:
            internal_type = _get_generic_type_parameters(expected_type)
            if internal_type is None:
                raise TypeError(
                    f"Cannot serialize collection with unknown internal type: {expected_type}."
                )
            serialized_items = [
                self.serialize(obj=item, expected_type=item_type)
                for item, item_type in _zip_collection_items_and_types(
                    obj, internal_type
                )
            ]
            # convert list to string
            serialized_result = f"[{','.join(serialized_items)}]"
        else:
            type_name = getattr(expected_type, "__name__", expected_type)
            raise TypeError(
                f"Cannot serialize unknown type {type_name}. Upon "
                "constructing this TaskGenerator, please provide serialization and "
                "deserialization functions for this type in the ``serializers`` dictionary."
            )

        # error out if the output is not a string (a custom serializer may return anything)
        if not isinstance(serialized_result, str):
            raise TypeError(
                f"Expected a string to serialize, but got {type(serialized_result)}."
            )
        return serialized_result

    def serialize_array(self, obj: Any, expected_type: Type) -> List:
        """Serialize obj into a list of serialized string.

        A list is serialized item by item; anything else becomes a one-item list.
        """
        items = obj if isinstance(obj, list) else [obj]
        return [self.serialize(item, expected_type) for item in items]

    @staticmethod
    def _split_serialized_collection(obj: str) -> list:
        """Split the string form of a collection into its (possibly still raw) items."""
        try:
            parsed = ast.literal_eval(obj)
        except Exception:
            # handle input like "[a,b,c]", which is not a Python literal
            text = obj.strip()
            if not text:
                raise ValueError("Cannot deserialize an empty string as a collection.")
            # handle input like "'[a,b,c]'"
            if text[0] == "'" and text[-1] == "'":
                text = text[1:-1]
            if not text:
                raise ValueError("Cannot deserialize an empty string as a collection.")
            # remove leading and trailing brackets
            if text[0] in ["[", "("] and text[-1] in ["]", ")"]:
                text = text[1:-1]
            return text.split(",")
        # if the input is a single item, convert it to a list
        if type(parsed) not in BUILT_IN_COLLECTIONS:
            return [parsed]
        return list(parsed)

    def deserialize(self, obj: str, obj_type: Type) -> Any:
        """Deserialize ``obj`` into an object of type ``obj_type``.

        Collections are returned as the annotated collection type (``list``, ``tuple`` or
        ``set``). With one type parameter per position (``tuple[int, str]``) each item is
        converted with the type at its position.

        Raises:
            TypeError: If ``obj`` is not a string, or ``obj_type`` cannot be deserialized.
            ValueError: If a collection is expected and ``obj`` is empty.
        """
        # error out if the input is not a string
        if not isinstance(obj, str):
            raise TypeError(f"Expected a string to deserialize, but got {type(obj)}.")

        # raise an error if the expected_type is a multi-dimensional collection
        if _is_multidimensional_type(obj_type):
            raise TypeError(
                f"This version of Task Generator cannot "
                f"deserialize multi-dimensional collection: {obj_type}."
            )

        if is_optional_type(obj_type):
            if obj == str(None):
                return None
            obj_type = get_optional_type_parameter(obj_type)

        if obj_type in self.serializers:
            # The 1'st index of the serializers dict is the deserialization function
            return self.serializers[obj_type][1](obj)

        if obj_type in SIMPLE_TYPES:
            # If we simply called something like ``bool("False")`` we would get ``True`` due
            # to the way Python thinks about truthiness. So instead we use
            # ``ast.literal_eval`` to evaluate the string as a Python literal.
            if obj_type == bool:
                return ast.literal_eval(obj)
            # If the object is our serialized empty string, return an empty string
            if obj_type == str and obj == SERIALIZED_EMPTY_STRING:
                return ""
            # For all other simple types, we can simply call the type on ``obj``
            return obj_type(obj)

        if _is_unannotated_built_in_collection_type(obj_type):
            raise TypeError(
                f"The provided type annotation ``{obj_type}`` does not provide enough "
                "information to deserialize. Please provide a more specific type "
                "annotation like ``list[int]`` or ``tuple[str]``."
            )

        if _is_annotated_built_in_collection_type(obj_type):
            raw_items = self._split_serialized_collection(obj)
            item_types = _get_generic_type_parameters(obj_type)

            # One type parameter (or ``X, ...``) applies to every item; otherwise the
            # parameters are positional and must match the number of items.
            if len(item_types) == 1 or (
                len(item_types) == 2 and item_types[1] is Ellipsis
            ):
                typed_items = [(item, item_types[0]) for item in raw_items]
            elif len(raw_items) == len(item_types):
                typed_items = list(zip(raw_items, item_types))
            else:
                raise TypeError(
                    f"Expected a collection of internal type ``{item_types}`` but got "
                    f"``{obj}``. Note the mismatched lengths."
                )

            # NOTE(review): items that ``ast.literal_eval`` could parse arrive here already
            # typed, so ``list[str]`` values that look like literals (e.g. ``1e3``) are
            # re-stringified from the parsed value. Confirm whether that is intended.
            deserialized_items = []
            for item, item_type in typed_items:
                if item_type in self.serializers:
                    deserialized_items.append(self.serializers[item_type][1](item))
                else:
                    deserialized_items.append(item_type(item))

            # Hand back the annotated collection type rather than always a list
            return _get_collection_type(obj_type)(deserialized_items)  # type: ignore

        type_name = getattr(obj_type, "__name__", obj_type)
        raise TypeError(
            f"Cannot deserialize unknown type {type_name}. Upon "
            "constructing this TaskGenerator, please provide serialization and "
            "deserialization functions for this type in the ``serializers`` dictionary."
        )

    def _cluster_resource_check(
        self, cluster_name: str, compute_resources: Optional[Dict]
    ) -> str:
        """Make sure a cluster name is available to use, and compute_resource is a dict.

        Returns:
            ``cluster_name``, or the default cluster name when ``cluster_name`` is empty.

        Raises:
            TypeError: If ``compute_resources`` is neither ``None`` nor a dict.
        """
        # add compute_resources type protection
        if compute_resources is not None and not isinstance(compute_resources, dict):
            raise TypeError(
                f"Expected a dictionary for compute_resources, "
                f"but got {type(compute_resources)}."
            )

        # if cluster_name is an empty string, use the default cluster name
        if cluster_name == "":
            cluster_name = self._default_cluster_name
            # if cluster_name is still an empty string, leave to wf to handle
            if cluster_name == "":
                logger.info(
                    "Cluster name is empty. Will use the wf default if one is provided."
                )
        return cluster_name

    def create_task(
        self,
        cluster_name: str = "",
        compute_resources: Optional[Dict] = None,
        resource_scales: Optional[Dict[str, Any]] = None,
        upstream_tasks: Optional[List[Task]] = None,
        task_attributes: Optional[Union[List[str], Dict[str, Any]]] = None,
        **kwargs: Any,
    ) -> Task:
        """Create a task for the task_function with the given kwargs.

        Parameters of the task function that are missing from ``kwargs`` are passed on as
        the serialized ``None``.

        Args:
            cluster_name: Cluster to run on; empty means the generator default.
            compute_resources: Compute resources handed to the task template.
            resource_scales: Resource scales handed to the task template.
            upstream_tasks: Tasks this task depends on. Defaults to no upstream tasks.
            task_attributes: Attributes handed to the task template. Defaults to none.
            **kwargs: Values for the task function's parameters.

        Raises:
            KeyError: If a kwarg is not a parameter of the task function.
            TypeError: If a value cannot be serialized as its annotated type.
            FileNotFoundError: If the task runner executable is not on the PATH.
        """
        # Fresh containers per call; a shared mutable default would leak between calls
        if upstream_tasks is None:
            upstream_tasks = []
        if task_attributes is None:
            task_attributes = {}

        # if the task template is not generated, generate it
        if self._task_template is None:
            self._generate_task_template()

        # check input and adjust the cluster_name value
        cluster_name = self._cluster_resource_check(cluster_name, compute_resources)
        executable_path = _find_executable_path(executable_name=TASK_RUNNER_NAME)

        # Serialize the kwargs
        serialized_kwargs = {
            name: self.serialize(obj=value, expected_type=self.params[name])
            for name, value in kwargs.items()
        }
        # assign the None value to args in self.params but not in serialized_kwargs
        for name in self.params:
            serialized_kwargs.setdefault(name, str(None))

        # Format the kwargs for the task
        kwargs_for_task = {
            name: make_cli_argument_string(arg_value=value)
            for name, value in serialized_kwargs.items()
        }

        # We want a slightly different format to put list kwargs into the name
        kwargs_for_name = {
            name: value
            for name, value in serialized_kwargs.items()
            if name in self._naming_args
        }
        name = create_task_name(
            kwargs_for_name=kwargs_for_name, prefix=self.name, name_func=self.name_func
        )

        # Create the task
        return self._task_template.create_task(  # type: ignore
            name=name,
            cluster_name=cluster_name,
            compute_resources=compute_resources,
            resource_scales=resource_scales,
            max_attempts=self.max_attempts,
            executable=executable_path,
            upstream_tasks=upstream_tasks,
            task_attributes=task_attributes,
            **kwargs_for_task,  # type: ignore
        )

    def create_tasks(
        self,
        cluster_name: str = "",
        compute_resources: Optional[Dict] = None,
        resource_scales: Optional[Dict[str, Any]] = None,
        upstream_tasks: Optional[List[Task]] = None,
        **kwargs: Any,
    ) -> List[Task]:
        """Create a task array for the task_function with the given kwargs.

        Each kwarg may be a single value or a list of values; see ``serialize_array``.
        """
        # if the task template is not generated, generate it
        if self._task_template is None:
            self._generate_task_template()

        # check input and adjust the cluster_name value
        cluster_name = self._cluster_resource_check(cluster_name, compute_resources)
        executable_path = _find_executable_path(executable_name=TASK_RUNNER_NAME)

        # Serialize the kwargs; every entry is a list of serialized strings
        serialized_kwargs = {
            name: self.serialize_array(obj=value, expected_type=self.params[name])
            for name, value in kwargs.items()
        }

        # Each individual element in the array should be formatted before sending to the cli
        kwargs_for_task = {
            name: [make_cli_argument_string(arg_value=item) for item in values]
            for name, values in serialized_kwargs.items()
        }

        # name is auto for array
        return self._task_template.create_tasks(  # type: ignore
            cluster_name=cluster_name,
            compute_resources=compute_resources,
            resource_scales=resource_scales,
            max_attempts=self.max_attempts,
            executable=executable_path,
            upstream_tasks=upstream_tasks,
            **kwargs_for_task,  # type: ignore
        )

    def run(self, args: List[str]) -> Any:
        """Run the task_function with the given args and return any result.

        Args:
            args: Strings of the form ``name=value``, e.g. ``["arg1=1", "arg2=[2, 3]"]``.
                Only the first ``=`` separates name from value, so values may contain ``=``.

        Raises:
            ValueError: If an entry has no ``=``, if the names do not match the task
                function's parameters, or if a value cannot be deserialized.
        """
        # Parse the args
        parsed_arg_value_pairs: Dict[str, Union[str, List[str]]] = {}
        for arg in args:
            arg_name, separator, arg_value = arg.partition("=")
            if not separator:
                raise ValueError(
                    f"Expected an argument of the form ``name=value``, but got ``{arg}``."
                )
            # if the arg_name key already exists, collect the values in a list
            if arg_name not in parsed_arg_value_pairs:
                parsed_arg_value_pairs[arg_name] = arg_value
            elif isinstance(parsed_arg_value_pairs[arg_name], list):
                parsed_arg_value_pairs[arg_name].append(arg_value)  # type: ignore
            else:
                parsed_arg_value_pairs[arg_name] = [
                    parsed_arg_value_pairs[arg_name],  # type: ignore
                    arg_value,
                ]

        # Raise an error if the user did not provide all of the arguments
        missing = set(self.params.keys()) - set(parsed_arg_value_pairs.keys())
        if missing:
            raise ValueError(
                "Some arguments to the task_function were not provided on the command line. "
                "The following arguments were not provided: "
                f"{missing}"
            )
        unexpected = set(parsed_arg_value_pairs.keys()) - set(self.params.keys())
        if unexpected:
            raise ValueError(
                "Some arguments provided on the command line are not parameters of the "
                f"task_function: {sorted(unexpected)}"
            )

        # Deserialize the args, catching any errors that may come from the deserialization fn
        deserialized_args = {}
        for arg_name, arg_value in parsed_arg_value_pairs.items():
            try:
                deserialized_args[arg_name] = self.deserialize(
                    obj=arg_value, obj_type=self.params[arg_name]  # type: ignore
                )
            except Exception as err:
                raise ValueError(
                    f"Could not deserialize argument ``{arg_name}`` with value "
                    f"``{arg_value}``. This error could be the result of a bug in the "
                    f"deserialization function, or a result of not providing a value for "
                    f"this argument on the command line."
                ) from err

        # Run the task_function
        return self.task_function(**deserialized_args)

    def help(self) -> str:
        """Return help text for the task_function."""
        # Parse the task function's docstring - Note that there may be nothing!
        task_function_docstring = docstring_parser.parse(  # type: ignore
            self.task_function.__doc__ or ""
        )

        # Map the parameter names to their annotations and descriptions
        documented_params = task_function_docstring.params
        names_to_annotations = {
            param.arg_name: param.type_name for param in documented_params
        }
        names_to_descriptions = {
            param.arg_name: param.description for param in documented_params
        }

        # Combine the full help text and return it
        return "\n".join(
            [
                HELP_TEXT_INTRO_FORMAT.format(full_path=self.full_path),
                _get_short_description(task_function_docstring),
                *self._get_param_names_and_descriptions(
                    names_to_annotations, names_to_descriptions
                ),
            ]
        )

    def _get_param_names_and_descriptions(
        self,
        task_function_docstring_param_names_to_annotations: Dict,
        task_function_docstring_param_names_to_descriptions: Dict,
    ) -> List[str]:
        """Return one help line per parameter of the task function.

        Each line has the form ``\\t--name [annotation]: description``. The annotation and
        description from the docstring are preferred; otherwise the signature annotation and
        a fixed "no description" text are used.
        """
        parameter_names_and_descriptions = []
        for param, annotation in self.params.items():
            if param in task_function_docstring_param_names_to_annotations:
                # The docstring states a type for this parameter, use that for the text
                annotation_string = task_function_docstring_param_names_to_annotations[
                    param
                ]
            else:
                # Otherwise use the annotation's name (works for simple types like str,
                # int, etc.), falling back to its string form (for custom or union types
                # like ``str | None``)
                try:
                    annotation_string = annotation.__name__
                except AttributeError:
                    annotation_string = str(annotation)

            # If the parameter has a description in the docstring, add that to the text
            if param in task_function_docstring_param_names_to_descriptions:
                description = (
                    f"{task_function_docstring_param_names_to_descriptions[param]}"
                )
            else:
                description = "No description found in docstring."
            parameter_names_and_descriptions.append(
                f"\t--{param} [{annotation_string}]: " + description
            )
        return parameter_names_and_descriptions
[docs]
def task_generator(
    serializers: Dict,
    tool_name: str,
    naming_args: Optional[List[str]] = None,
    max_attempts: Optional[int] = None,
    module_source_path: Optional[str] = None,
    name_func: Optional[Callable] = None,
    default_cluster_name: str = "",
    default_compute_resources: Optional[Dict[str, Any]] = None,
    default_resource_scales: Optional[Dict[str, float]] = None,
    yaml_file: Optional[str] = None,
) -> Callable:
    """Decorator factory that turns a python function into a ``TaskGenerator``.

    Every argument is forwarded unchanged to the ``TaskGenerator`` constructor, together
    with the decorated function as ``task_function``.
    """
    # Gather the options once; the decorator only adds the function itself
    generator_options = dict(
        serializers=serializers,
        tool_name=tool_name,
        naming_args=naming_args,
        max_attempts=max_attempts,
        module_source_path=module_source_path,
        name_func=name_func,
        default_cluster_name=default_cluster_name,
        default_compute_resources=default_compute_resources,
        default_resource_scales=default_resource_scales,
        yaml_file=yaml_file,
    )

    def wrapper(task_function: Callable) -> TaskGenerator:
        """Build the TaskGenerator for ``task_function``."""
        return TaskGenerator(task_function=task_function, **generator_options)

    return wrapper
[docs]
def get_tasks_by_node_args(
    workflow: Workflow,
    task_generator: TaskGenerator,
    node_args_dict: dict[str, Any],
    error_on_empty: bool = True,
) -> list[Task]:
    """Get the tasks of a TaskGenerator in a workflow that have the given node arguments.

    Each value is serialized with the generator and formatted as a CLI argument string
    before the lookup is delegated to ``workflow.get_tasks_by_node_args``.

    Raises:
        ValueError: If ``error_on_empty`` is set and the lookup raises ``ValueError`` or
            returns nothing.
    """
    # Re-serialize the node args dict and format them as CLI arguments
    serialized_node_args = {}
    for arg_name, arg_value in node_args_dict.items():
        serialized_value = task_generator.serialize(
            obj=arg_value, expected_type=task_generator.params[arg_name]
        )
        serialized_node_args[arg_name] = make_cli_argument_string(
            arg_value=serialized_value,
        )

    try:
        result = workflow.get_tasks_by_node_args(
            task_template_name=task_generator.name, **serialized_node_args
        )
    except ValueError:
        if error_on_empty:
            raise
        # A failed lookup counts as "no tasks" when errors are not requested
        result = []

    if not result and error_on_empty:
        raise ValueError(
            f"There were no tasks in the workflow that matched {node_args_dict}"
        )
    return result
[docs]
class TaskGeneratorDocumenter(Directive):
    """Directive for generating documentation for a single task generator."""

    # ``run`` reads ``self.arguments[0]``, so exactly one argument must be declared
    required_arguments = 1
    final_argument_whitespace = True
    option_spec = {"optional": lambda x: x}  # Defines the 'optional' option

    def run(self) -> list[nodes.Node]:
        """The function sphinx/docutils use to generate documentation for the directive."""
        # The argument has the form ``module:attribute``
        task_generator_path = self.arguments[0]
        # The 'optional' option, when given, is the path of the module's source file
        module_path = self.options.get("optional", None)
        task_generator = self._load_task_generator(task_generator_path, module_path)
        return _generate_nodes(task_generator, self.state)

    def _load_task_generator(
        self, task_generator_path: str, module_path: Optional[str] = None
    ) -> TaskGenerator:
        """Load the task generator from the given path.

        Args:
            task_generator_path: ``module:attribute`` naming the task generator.
            module_path: Optional file to load the module from instead of importing it by
                name.

        Raises:
            The directive error built by ``self.error`` if the path is malformed, the import
            fails, the attribute is missing, or it is not a ``TaskGenerator``.
        """
        # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded
        import importlib.util

        try:
            module_name, attr_name = task_generator_path.split(":", 1)
        except ValueError:
            raise self.error(
                f'"{task_generator_path}" is not of format "module:parser"'
            )

        try:
            if module_path:
                spec = importlib.util.spec_from_file_location(
                    module_name, module_path
                )  # type: ignore
                mod = importlib.util.module_from_spec(spec)  # type: ignore
                spec.loader.exec_module(mod)  # type: ignore
            else:
                mod = __import__(module_name, globals(), locals(), [attr_name])
        except (Exception, SystemExit) as exc:
            err_msg = f'Failed to import "{attr_name}" from "{module_name}". '
            if isinstance(exc, SystemExit):
                err_msg += "The module appeared to call sys.exit()"
            else:
                err_msg += (
                    f"The following exception was raised:\n{traceback.format_exc()}"
                )
            raise self.error(err_msg)

        if not hasattr(mod, attr_name):
            raise self.error(f'Module "{module_name}" has no attribute "{attr_name}"')

        task_generator = getattr(mod, attr_name)
        if not isinstance(task_generator, TaskGenerator):
            raise self.error(
                f'Attribute "{attr_name}" of module "{module_name}" is not a TaskGenerator'
            )
        return task_generator
[docs]
class TaskGeneratorModuleDocumenter(Directive):
    """Directive for generating documentation for all the task generators in a module.

    The directive takes one argument, the module name. If the ``optional`` option is
    supplied, its value is used as a file path and the module is loaded from that file
    instead of being imported by name.
    """

    # ``run`` reads ``self.arguments[0]``. docutils only parses directive arguments when the
    # directive class declares them (the base-class default is zero), so declare exactly one.
    required_arguments = 1
    final_argument_whitespace = True
    option_spec = {"optional": lambda x: x}  # Defines the 'optional' option

    def run(self) -> list[nodes.Node]:
        """The function sphinx/docutils use to generate documentation for the directive.

        Returns:
            A one-element list holding the module section, which contains the module
            overview followed by the nodes of every task generator found in the module.
        """
        module_name = self.arguments[0]
        # The module file path, if any, is supplied through the ``optional`` option.
        module_path = self.options.get("optional", None)
        module, task_generators = self._load_module_task_generators(
            module_name, module_path
        )
        section = self._generate_module_section(module_name, module, task_generators)
        for task_generator in task_generators:
            section.extend(_generate_nodes(task_generator, self.state))
        return [section]

    def _load_module_task_generators(
        self, module_name: str, module_path: Optional[str] = None
    ) -> tuple[ModuleType, list[TaskGenerator]]:
        """Load all the task generators in a given module.

        Args:
            module_name: Name of the module to import, or the name to give the module when
                it is loaded from ``module_path``.
            module_path: Optional file path. When truthy, the module is executed from this
                file; otherwise the module is imported by name.

        Returns:
            The loaded module and the ``TaskGenerator`` instances bound to its public
            (non-underscore) attribute names, in the order ``dir()`` lists those names.

        Raises:
            The exception built by ``self.error`` when the module cannot be loaded,
            including when it calls ``sys.exit()``.
        """
        # ``import importlib`` on its own does not guarantee that the ``util`` submodule has
        # been loaded, so import it explicitly where it is used.
        import importlib.util

        try:
            if module_path:
                spec = importlib.util.spec_from_file_location(module_name, module_path)
                if spec is None or spec.loader is None:
                    # No loader could be determined for the file; fail with a clear message
                    # instead of an AttributeError on ``None`` further down.
                    raise ImportError(
                        f'Could not create an import spec for "{module_name}" '
                        f'from "{module_path}"'
                    )
                mod = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(mod)
            else:
                mod = importlib.import_module(module_name)
        except (Exception, SystemExit) as exc:
            # SystemExit is caught as well so that a module calling sys.exit() at import
            # time is reported as a directive error rather than ending the build.
            err_msg = f'Failed to import "{module_name}". '
            if isinstance(exc, SystemExit):
                err_msg += "The module appeared to call sys.exit()"
            else:
                err_msg += (
                    f"The following exception was raised:\n{traceback.format_exc()}"
                )
            raise self.error(err_msg) from exc

        task_generators = []
        for attr_name in dir(mod):
            if attr_name.startswith("_"):
                continue
            attr = getattr(mod, attr_name)
            if isinstance(attr, TaskGenerator):
                task_generators.append(attr)
        return mod, task_generators

    def _generate_module_section(
        self,
        module_name: str,
        module: ModuleType,
        task_generators: Iterable[TaskGenerator],
    ) -> nodes.section:
        """Makes the docutils node for the module section.

        Includes the docstring for the module, and a list of task generator names. Doesn't
        include any task generator detailed documentation.

        Args:
            module_name: Used as the section title, for the section id/name, and as the
                source label of every generated rST line.
            module: The loaded module; its docstring, if any, opens the section.
            task_generators: Generators to list; each one gets a shell code block showing
                ``TASK_RUNNER_NAME`` followed by the generator's ``full_path``.

        Returns:
            The populated section node.
        """
        # Set up base node
        section = nodes.section(
            "",
            nodes.title(text=module_name),
            ids=[nodes.make_id(module_name)],
            names=[nodes.fully_normalize_name(module_name)],
        )

        # Use sphinx and docutils tooling to format the docstring into rST and then parse it
        # into a docutils node.
        result = statemachine.ViewList()
        if module.__doc__:
            for line in statemachine.string2lines(
                module.__doc__, tab_width=4, convert_whitespace=True
            ):
                result.append(line, module_name)

        # Start with an empty line so the code blocks are separated from the docstring text.
        task_generator_path_lines = [""]
        for task_generator in task_generators:
            task_generator_path_lines.append(".. code-block:: shell")
            task_generator_path_lines.append("")
            task_generator_path_lines.append(
                f" {TASK_RUNNER_NAME} {task_generator.full_path}"
            )
            task_generator_path_lines.append("")
        for line in task_generator_path_lines:
            result.append(line, module_name)

        sphinx_nodes.nested_parse_with_titles(self.state, result, section)
        return section
[docs]
def _generate_nodes(task_generator: TaskGenerator, state: Any) -> list[nodes.Node]:
    """Makes the docutils nodes with the documentation for a task generator.

    Includes the docstring description for the task generator and all the arguments.

    Args:
        task_generator: The generator to document. Its ``name`` is the section title, its
            ``full_path`` provides the section id/name, and the docstring of its
            ``task_function`` is parsed for the description and options.
        state: The directive's parser state, handed to
            ``sphinx_nodes.nested_parse_with_titles`` to parse the generated rST.

    Returns:
        A one-element list containing the section node. Both directives in this module use
        the result as a list.
    """
    # Set up base node
    section = nodes.section(
        "",
        nodes.title(text=task_generator.name),
        ids=[nodes.make_id(task_generator.full_path)],
        names=[nodes.fully_normalize_name(task_generator.full_path)],
    )

    # Get rST lines for the description and options
    # NOTE(review): when the task function has no docstring, ``str(None)`` makes the parser
    # see the literal text "None" -- confirm that this is the intended fallback.
    parsed_docstring = docstring_parser.parse(str(task_generator.task_function.__doc__))
    lines = [
        *_format_description(
            task_generator=task_generator, parsed_docstring=parsed_docstring
        ),
        *_format_options(
            task_generator=task_generator, parsed_docstring=parsed_docstring
        ),
    ]

    # Convert the rST lines into a docutils node
    result = statemachine.ViewList()
    for line in lines:
        result.append(line, task_generator.name)
    sphinx_nodes.nested_parse_with_titles(state, result, section)
    return [section]
[docs]
def setup(app: application.Sphinx) -> dict:
    """Register this extension's directives with sphinx.

    Adds the ``task_generator`` and ``task_generator_module`` directives to ``app``.

    Returns:
        The extension metadata: its version and flags marking it as safe for parallel
        reading and parallel writing.
    """
    directive_classes = (
        ("task_generator", TaskGeneratorDocumenter),
        ("task_generator_module", TaskGeneratorModuleDocumenter),
    )
    for directive_name, directive_class in directive_classes:
        app.add_directive(directive_name, directive_class)

    extension_metadata = dict(
        version="0.1",
        parallel_read_safe=True,
        parallel_write_safe=True,
    )
    return extension_metadata