"""A node represents an individual task within a DAG."""
from __future__ import annotations

import hashlib
import logging
from typing import Any, Dict, Iterable, List, Optional, Set

from jobmon.client.task_template_version import TaskTemplateVersion
from jobmon.core.constants import SpecialChars
from jobmon.core.requester import Requester
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class Node:
    """A node represents an individual task within a Dag."""

    def __init__(
        self,
        task_template_version: TaskTemplateVersion,
        node_args: Dict[str, Any],
        requester: Optional[Requester] = None,
    ) -> None:
        """A node represents an individual task within a Dag.

        This includes its relationship to other nodes that it is dependent upon or nodes that
        depend upon it. A node stores node arguments (arguments relating to the
        actual shape of the dag and number of tasks created for a given
        stage/task template) and can register itself with the database via the
        Jobmon Query Service and the Jobmon State Manager.

        Args:
            task_template_version: The associated TaskTemplateVersion.
            node_args: key-value pairs of arg_name and a value.
            requester: Requester object to communicate with the flask services.
                When omitted, ``Requester.from_defaults()`` is used.
        """
        # Stays None until the node is bound; read through the ``node_id`` property.
        self._node_id: Optional[int] = None
        self.task_template_version = task_template_version
        self.node_args = node_args
        # node_args with arg names converted to ids by the TaskTemplateVersion.
        self.mapped_node_args = self.task_template_version.convert_arg_names_to_ids(
            **self.node_args
        )
        # Computed once here; later mutation of node_args is not reflected.
        self.node_args_hash = self._hash_node_args()
        self.upstream_nodes: Set[Node] = set()
        self.downstream_nodes: Set[Node] = set()
        if requester is None:
            requester = Requester.from_defaults()
        self.requester = requester

    @property
    def node_id(self) -> int:
        """Unique id for each node.

        Raises:
            AttributeError: if accessed before the node has been bound.
        """
        if self._node_id is None:
            raise AttributeError("node_id cannot be accessed before node is bound")
        return self._node_id

    @node_id.setter
    def node_id(self, val: int) -> None:
        """Record the unique id assigned to this node when it is bound."""
        self._node_id = val

    @property
    def task_template_version_id(self) -> int:
        """Id of the TaskTemplateVersion this node belongs to."""
        return self.task_template_version.id

    @property
    def default_name(self) -> str:
        """The default name of this node in the array.

        Built as ``<template_name>_<arg>-<value>_<arg>-<value>...`` in
        node_args iteration order. Characters listed in
        ``SpecialChars.ILLEGAL_SPECIAL_CHARACTERS`` are replaced by ``_``, and
        the result is truncated to at most 249 characters.
        """
        # ``!s`` keeps exact str() formatting (format() can differ, e.g. for enums).
        arg_part = "_".join(f"{key!s}-{value!s}" for key, value in self.node_args.items())
        name = self.task_template_version.task_template.template_name + "_" + arg_part
        # special char protection
        illegal = SpecialChars.ILLEGAL_SPECIAL_CHARACTERS
        name = "".join("_" if c in illegal else c for c in name)
        # long name protection: never longer than 249 characters
        return name[:249]

    def _hash_node_args(self) -> int:
        """A hash of the node's arguments.

        The hash is the SHA-256 digest, as an int, of the sorted arg ids followed
        by their values, all stringified and concatenated together.
        """
        # NOTE(review): the parts are joined with no separator, so different arg
        # sets can encode to the same string. For example, ids [1, 2] with values
        # (1, 11) and with values (11, 1) both give "12111", so those nodes hash,
        # and therefore compare, equal. The scheme is left unchanged because
        # node_args_hash is presumably persisted server-side, and changing it would
        # alter node identity for existing workflows. Confirm before fixing.
        arg_ids = sorted(self.mapped_node_args)
        str_arg_ids = [str(arg) for arg in arg_ids]
        arg_values = [str(self.mapped_node_args[key]) for key in arg_ids]
        payload = "".join(str_arg_ids + arg_values)
        return int(hashlib.sha256(payload.encode("utf-8")).hexdigest(), 16)

    def add_upstream_node(self, upstream_node: Node) -> None:
        """Add a single node to this one's upstream Nodes.

        The edge is recorded on both ends: ``self`` is also added to
        ``upstream_node.downstream_nodes``.

        Args:
            upstream_node: node to add a dependency on
        """
        self.upstream_nodes.add(upstream_node)
        # Set the reverse edge directly instead of calling add_downstream_node,
        # which would recurse back into this method.
        upstream_node.downstream_nodes.add(self)

    def add_upstream_nodes(self, upstream_nodes: Iterable[Node]) -> None:
        """Add many nodes to this one's upstream Nodes.

        Args:
            upstream_nodes: nodes to add dependencies on (any iterable, e.g. a list)
        """
        for node in upstream_nodes:
            self.add_upstream_node(node)

    def add_downstream_node(self, downstream_node: Node) -> None:
        """Add a node to this one's downstream Nodes.

        The edge is recorded on both ends: ``self`` is also added to
        ``downstream_node.upstream_nodes``.

        Args:
            downstream_node: Node that will be dependent on this node
        """
        self.downstream_nodes.add(downstream_node)
        # avoid endless recursion, set directly
        downstream_node.upstream_nodes.add(self)

    def add_downstream_nodes(self, downstream_nodes: Iterable[Node]) -> None:
        """Add several nodes as this node's downstream nodes.

        Args:
            downstream_nodes: Nodes that will be dependent on this node (any iterable).
        """
        for node in downstream_nodes:
            self.add_downstream_node(node)

    def __repr__(self) -> str:
        """Repr the node attributes."""
        return (
            "Node("
            f"task_template_version_id={self.task_template_version_id}, "
            f"node_args={self.node_args}, "
            f"node_args_hash={self.node_args_hash})"
        )

    def __eq__(self, other: object) -> bool:
        """Check if two nodes have equal hashes.

        Any non-Node compares unequal.
        """
        if not isinstance(other, Node):
            return False
        return hash(self) == hash(other)

    def __lt__(self, other: Node) -> bool:
        """Check if this node's hash is less than another's."""
        return hash(self) < hash(other)

    def __hash__(self) -> int:
        """Create a hash that will be a unique identifier for the node.

        SHA-256 over node_args_hash followed by task_template_version_id, both
        stringified. The value is computed on first use and cached on the
        instance as ``_hash_val``.
        """
        if not hasattr(self, "_hash_val"):
            digest = hashlib.sha256()
            digest.update(str(self.node_args_hash).encode("utf-8"))
            digest.update(str(self.task_template_version_id).encode("utf-8"))
            self._hash_val = int(digest.hexdigest(), 16)
        return self._hash_val