Source code for client.dag

"""The DAG captures the interconnected graph of tasks and their dependencies."""

import hashlib
from http import HTTPStatus as StatusCodes
from typing import Dict, List, Optional, Set, Tuple, Union

import structlog

from jobmon.client.node import Node
from jobmon.core.exceptions import (
    CyclicGraphError,
    DuplicateNodeArgsError,
    InvalidResponse,
    NodeDependencyNotExistError,
)
from jobmon.core.requester import Requester

# Module-level structlog logger, used by Dag._get_dag_id and Dag._bulk_insert_edges.
logger = structlog.get_logger(__name__)
class Dag(object):
    """The DAG captures the interconnected graph of tasks and their dependencies."""

    def __init__(self, requester: Optional[Requester] = None) -> None:
        """Directed Acyclic Graph.

        The DAG captures the tasks (nodes) as they are related to each other in their
        dependency structure. The Dag is traversed in the order of node dependencies so
        a workflow run is a single instance of traversing through a dag. This object
        stores the nodes and communicates with the server with regard to itself.

        Args:
            requester (Requester, optional): requester used to communicate with the
                server. When omitted, ``Requester.from_defaults()`` is used.
        """
        self.nodes: Set[Node] = set()
        if requester is None:
            requester = Requester.from_defaults()
        self.requester = requester

    @property
    def dag_id(self) -> int:
        """Database unique ID of this DAG.

        Raises:
            AttributeError: if accessed before ``bind`` has stored an id.
        """
        if not hasattr(self, "_dag_id"):
            raise AttributeError("_dag_id cannot be accessed before dag is bound")
        return self._dag_id

    def add_node(self, node: Node) -> None:
        """Add a node to this dag.

        Args:
            node (jobmon.client.node.Node): Node to add to the dag.

        Raises:
            DuplicateNodeArgsError: if an equal node is already present in the dag.
        """
        # validate node has unique node args within this task template version;
        # set membership relies on Node's own __eq__/__hash__.
        if node in self.nodes:
            raise DuplicateNodeArgsError(
                "A duplicate node was found for task_template_version_id="
                f"{node.task_template_version_id}. Node args were {node.node_args}"
            )
        # wf.add_task should call ClientNode.add_node() + pass the tasks' node
        self.nodes.add(node)

    def bind(self, chunk_size: int = 500) -> int:
        """Retrieve an id for a matching dag from the server.

        If it doesn't exist, first create one, including its edges.

        Args:
            chunk_size: maximum number of nodes sent per request while binding nodes.
                Edges are inserted using ``_bulk_insert_edges``' own default chunk size.

        Returns:
            The ``dag_id`` returned by the server; it is also stored on the instance.

        Raises:
            RuntimeError: if the dag contains no nodes.
        """
        if not self.nodes:
            raise RuntimeError(
                "No nodes were found in the dag. An empty dag cannot be bound."
            )

        self._bulk_bind_nodes(chunk_size)

        # hash() (rather than __hash__()) reduces the sha256 integer produced by
        # __hash__ to a machine-sized int; that reduced value is what is sent.
        dag_hash = hash(self)
        _, response = self.requester.send_request(
            app_route="/dag",
            message={"dag_hash": dag_hash},
            request_type="post",
        )
        dag_id = response["dag_id"]

        # no created date means bind edges
        if response["created_date"] is None:
            self._bulk_insert_edges(dag_id)

        self._dag_id = dag_id
        return dag_id

    def validate(self) -> None:
        """Validate the nodes and their dependencies.

        Raises:
            NodeDependencyNotExistError: if a node has an upstream or downstream node
                that was never added to this dag.
            CyclicGraphError: if following downstream dependencies leads to a cycle.
        """
        nodes_in_dag = self.nodes
        for node in nodes_in_dag:
            # Make sure no task contains up/down stream tasks that are not in the
            # workflow. Upstream nodes are checked before downstream nodes.
            for direction, neighbors in (
                ("Upstream", node.upstream_nodes),
                ("Downstream", node.downstream_nodes),
            ):
                for n in neighbors:
                    if n not in nodes_in_dag:
                        raise NodeDependencyNotExistError(
                            f"{direction} node, {hash(n)}, for node, {hash(node)}, "
                            "does not exist in the dag. Check that every task has been "
                            "added to the workflow and is in the correct order."
                        )

        dag_map = {node: node.downstream_nodes for node in nodes_in_dag}
        if self._is_cyclic(dag_map):
            raise CyclicGraphError(
                "Cycle detected in the task graph. Please ensure that your task "
                "dependencies flow in only one direction."
            )

    def _is_cyclic(self, dag_map: Dict[Node, Set[Node]]) -> bool:
        """Return True if the graph described by ``dag_map`` contains a cycle.

        This is an iterative depth-first search that looks for a node already on the
        current path. It uses the "stack of iterators" pattern instead of recursion so
        that deep graphs do not hit Python's default recursion limit (1000).

        Args:
            dag_map: mapping from each node to the nodes it points to (its downstream
                nodes).
        """
        visited = set()
        # Sentinel root entry so the outermost iterator has a path entry to pop.
        path = [object()]
        path_set = set(path)
        stack = [iter(dag_map)]
        while stack:
            for v in stack[-1]:
                if v in path_set:
                    # v is already on the current DFS path: back edge, so a cycle.
                    return True
                elif v not in visited:
                    visited.add(v)
                    path.append(v)
                    path_set.add(v)
                    stack.append(iter(dag_map.get(v, ())))
                    break
            else:
                # Current iterator exhausted: backtrack one level.
                path_set.remove(path.pop())
                stack.pop()
        return False

    def _bulk_bind_nodes(self, chunk_size: int) -> None:
        """Bind every node on the server in chunks and store the returned node ids.

        Args:
            chunk_size: maximum number of nodes sent per ``/nodes`` request.

        Raises:
            ValueError: if ``chunk_size`` is less than 1.
            InvalidResponse: if the server response lacks an id for some node.
        """
        if chunk_size < 1:
            # A non-positive chunk size would never advance through the nodes.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        nodes_in_dag = list(self.nodes)
        nodes_received = {}
        for start in range(0, len(nodes_in_dag), chunk_size):
            nodes_to_send = [
                {
                    "task_template_version_id": node.task_template_version_id,
                    "node_args_hash": str(node.node_args_hash),
                    "node_args": node.mapped_node_args,
                }
                for node in nodes_in_dag[start : start + chunk_size]
            ]
            _, response = self.requester.send_request(
                app_route="/nodes",
                message={"nodes": nodes_to_send},
                request_type="post",
            )
            nodes_received.update(response["nodes"])

        for node in nodes_in_dag:
            key = f"{node.task_template_version_id}:{node.node_args_hash}"
            if key not in nodes_received:
                # NOTE: `response` is the response of the last chunk sent only.
                raise InvalidResponse(
                    f"Fail to find node_id in HTTP response for node_args_hash "
                    f"{node.node_args_hash} and task_template_version_id "
                    f"{node.task_template_version_id} HTTP Response:\n {response}"
                )
            node.node_id = int(nodes_received[key])

    def _get_dag_id(self) -> Optional[int]:
        """Query the server for the id of the dag matching this dag's hash.

        Returns:
            The ``dag_id`` field of the server response.

        Raises:
            ValueError: if the server responds with a status code other than 200.
        """
        dag_hash = hash(self)
        logger.info(f"Querying for dag with hash: {dag_hash}")
        return_code, response = self.requester.send_request(
            app_route="/dag",
            message={"dag_hash": dag_hash},
            request_type="get",
        )
        if return_code != StatusCodes.OK:
            raise ValueError(
                f"Unexpected status code {return_code} from GET "
                f"request through route /dag for dag_hash {dag_hash}. "
                f"Expected code 200. Response content: "
                f"{response}"
            )
        return response["dag_id"]

    def _bulk_insert_edges(self, dag_id: int, chunk_size: int = 500) -> None:
        """Post each node's upstream/downstream node ids to the server in chunks.

        Every chunk except the last is sent with ``mark_created=False``; the last chunk
        is sent with ``mark_created=True``.

        Args:
            dag_id: id of the dag the edges belong to.
            chunk_size: maximum number of per-node edge records sent per request.

        Raises:
            ValueError: if ``chunk_size`` is less than 1.
        """
        if chunk_size < 1:
            # A non-positive chunk size would never consume the edge list.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

        # compile full list of edges; node ids are expected to already be set
        # (bind() calls _bulk_bind_nodes before this method).
        all_edges: List[Dict[str, Union[List, int]]] = [
            {
                "node_id": node.node_id,
                "upstream_node_ids": [up.node_id for up in node.upstream_nodes],
                "downstream_node_ids": [down.node_id for down in node.downstream_nodes],
            }
            for node in self.nodes
        ]
        logger.debug(f"message included in edge post request: {all_edges}")

        app_route = f"/dag/{dag_id}/edges"
        total = len(all_edges)
        for start in range(0, total, chunk_size):
            end = start + chunk_size
            message: Dict[str, Union[List[Dict], bool]] = {
                "edges_to_add": all_edges[start:end],
                # more edges to add -> don't mark it created; last chunk -> mark it
                "mark_created": end >= total,
            }
            self.requester.send_request(
                app_route=app_route, message=message, request_type="post"
            )

    def __hash__(self) -> int:
        """Determined by hashing all sorted node hashes and their downstream.

        Returns the full sha256 digest as an int (``hash(dag)`` reduces it further).
        Nodes must be orderable, since they are sorted for a stable digest. An empty
        dag yields the digest of empty input.
        """
        hash_value = hashlib.sha256()
        for node in sorted(self.nodes):
            hash_value.update(str(hash(node)).encode("utf-8"))
            for downstream_node in sorted(node.downstream_nodes):
                hash_value.update(str(hash(downstream_node)).encode("utf-8"))
        return int(hash_value.hexdigest(), 16)

    def __repr__(self) -> str:
        """A representation string for a Dag instance."""
        # __hash__() is called directly so the full sha256 value is shown.
        return f"Dag(hash={self.__hash__()})"