Source code for client.workflow

"""The overarching framework to create tasks and dependencies within."""

from __future__ import annotations

import copy
import fcntl
import hashlib
import itertools
import os
import sys
import time
import uuid
from subprocess import PIPE, Popen, TimeoutExpired
from types import TracebackType
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Union

import psutil
import structlog

from jobmon.client.array import Array
from jobmon.client.dag import Dag
from jobmon.client.swarm.workflow_run import WorkflowRun as SwarmWorkflowRun
from jobmon.client.task import Task
from jobmon.client.task_resources import TaskResources
from jobmon.client.tool_version import ToolVersion
from jobmon.client.workflow_run import WorkflowRunFactory
from jobmon.core.cluster import Cluster
from jobmon.core.configuration import JobmonConfig
from jobmon.core.constants import (
    MaxConcurrentlyRunning,
    TaskStatus,
    WorkflowRunStatus,
    WorkflowStatus,
)
from jobmon.core.exceptions import (
    ConfigError,
    DistributorStartupTimeout,
    DuplicateNodeArgsError,
    WorkflowAlreadyComplete,
    WorkflowAlreadyExists,
)
from jobmon.core.requester import Requester

if TYPE_CHECKING:
    from jobmon.client.tool import Tool


# Module-level structlog logger, named after this module; DistributorContext and Workflow
# below log through it.
[docs] logger = structlog.get_logger(__name__)
class DistributorContext:
    """Context manager that runs the jobmon distributor in a child process.

    Entering the context launches ``jobmon.distributor.cli start`` with the
    current interpreter and waits for the child to write ``ALIVE`` to its
    stderr. Leaving the context terminates the child; if no ``SHUTDOWN``
    confirmation is read back, the child and its descendants are killed.
    """

    def __init__(self, cluster_name: str, workflow_run_id: int, timeout: int) -> None:
        """Initialization of the DistributorContext.

        Args:
            cluster_name: passed to the distributor CLI as ``--cluster_name``.
            workflow_run_id: passed to the distributor CLI as ``--workflow_run_id``.
            timeout: seconds to wait for both startup and graceful shutdown.
        """
        self._cluster_name = cluster_name
        self._workflow_run_id = workflow_run_id
        self._timeout = timeout

    def wait_for_startup_signal(self, timeout: int = 180) -> bool:
        """Wait for startup signal with non-blocking reads to handle timing issues.

        Args:
            timeout: maximum number of seconds to poll the child's stderr.

        Returns:
            True once ``ALIVE`` has been seen on stderr; False if the child
            exits first or the timeout elapses.
        """
        buffer = ""
        start_time = time.time()
        assert self.process.stderr is not None  # keep mypy happy

        # Make stderr non-blocking to avoid hanging if signal already sent
        fd = self.process.stderr.fileno()
        fl = fcntl.fcntl(fd, fcntl.F_GETFL)
        fcntl.fcntl(fd, fcntl.F_SETFL, fl | os.O_NONBLOCK)

        while time.time() - start_time < timeout:
            try:
                chunk = self.process.stderr.read(100)
            except BlockingIOError:
                # No data available right now.
                chunk = None
            except Exception as e:
                logger.warning(f"Error reading stderr: {e}")
                chunk = None

            if chunk:
                buffer += chunk
                # Look for startup signal (handles package warnings naturally)
                if "ALIVE" in buffer:
                    logger.info("Received startup signal")
                    return True
                # More output may already be waiting; read again immediately.
                continue

            # Nothing was read (would-block, EOF or read error). Stop early if
            # the child is gone, otherwise back off briefly. Without this an
            # empty read at EOF would spin at full speed until the timeout.
            if self.process.poll() is not None:
                logger.error(
                    f"Distributor process exited with code: "
                    f"{self.process.returncode}"
                )
                return False
            time.sleep(0.1)

        return False

    def __enter__(self) -> DistributorContext:
        """Starts the Distributor Process.

        Raises:
            DistributorStartupTimeout: if the startup signal is not received
                within the configured timeout. The child is shut down first.
        """
        logger.info("Starting Distributor Process")

        # construct env
        env = os.environ.copy()
        entry_point = self.derive_jobmon_command_from_env()
        if entry_point is not None:
            env["JOBMON__DISTRIBUTOR__WORKER_NODE_ENTRY_POINT"] = f'"{entry_point}"'

        # Start the distributor. stderr is captured through a pipe so the
        # startup and shutdown signals can be read back.
        cmd = [
            sys.executable,
            "-m",  # safest way to find the entrypoint
            "jobmon.distributor.cli",
            "start",
            "--cluster_name",
            self._cluster_name,
            "--workflow_run_id",
            str(self._workflow_run_id),
        ]
        self.process = Popen(
            cmd,
            stderr=PIPE,
            universal_newlines=True,
            env=env,
        )

        # Use simple non-blocking startup detection
        if not self.wait_for_startup_signal(self._timeout):
            stderr_all = ""
            try:
                _, stderr_all = self.process.communicate(timeout=5)
            except TimeoutExpired:
                pass

            err = self._shutdown()
            raise DistributorStartupTimeout(
                f"Distributor process did not start within {self._timeout}s, "
                f"stderr='{err}'\n\nFull stderr: {stderr_all}"
            )
        return self

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> None:
        """Stops the Distributor Process."""
        logger.info("Stopping Distributor Process")
        err = self._shutdown()
        logger.info(f"Got {err} from Distributor Process")

    def alive(self) -> bool:
        """Return True while the distributor process has not exited."""
        self.process.poll()
        return self.process.returncode is None

    def _shutdown(self) -> str:
        """Shutdown the distributor process.

        Returns:
            The stderr text read while waiting for the child to exit, or an
            empty string if the graceful wait timed out.
        """
        self.process.terminate()
        try:
            _, err = self.process.communicate(timeout=self._timeout)
            if "SHUTDOWN" in err:
                logger.info("Received shutdown confirmation")
            else:
                logger.warning("No shutdown confirmation received")
        except TimeoutExpired:
            logger.warning("Timeout waiting for graceful shutdown")
            err = ""

        if "SHUTDOWN" not in err:
            # No confirmation: force-kill any descendants, then the child itself.
            try:
                parent = psutil.Process(self.process.pid)
                for child in parent.children(recursive=True):
                    child.kill()
            except psutil.NoSuchProcess:
                pass
            self.process.kill()
            self.process.wait()

        return err

    @staticmethod
    def derive_jobmon_command_from_env() -> Optional[str]:
        """If a singularity path is provided, use it when running the worker node.

        Returns:
            A ``singularity run`` command built from the ``IMGPATH`` environment
            variable, or None when that variable is unset or empty.
        """
        singularity_img_path = os.environ.get("IMGPATH", None)
        if singularity_img_path:
            return f"singularity run --app jobmon_command {singularity_img_path}"
        return None
class Workflow(object):
    """(aka Batch, aka Swarm).

    A Workflow is the framework in which a user defines how tasks relate to
    one another and how repeated runs of the same set of tasks relate to each
    other. Its main benefit is that it is resumable.

    A Workflow can only be re-loaded when two things exactly match a previous
    Workflow:

    1. WorkflowArgs: passing a meaningful unique identifier as workflow_args
       is recommended, because it makes resuming easier. For a one-off project
       the Workflow may be created anonymously; the WorkflowArgs then default
       to a randomly generated UUID, which is harder to remember and therefore
       harder to resume. Workflow args must be hashable (for example a
       CodCorrect or Como version). For now WorkflowArgs is assumed to be a
       string.

    2. The tasks added to the workflow. A Workflow is built up with
       Workflow.add_task(). To resume, all of the same tasks must be added
       with the same dependencies between them.
    """

    def __init__(
        self,
        tool_version: ToolVersion,
        workflow_args: str = "",
        name: str = "",
        description: str = "",
        workflow_attributes: Optional[Union[List, dict]] = None,
        max_concurrently_running: int = MaxConcurrentlyRunning.MAXCONCURRENTLYRUNNING,
        requester: Optional[Requester] = None,
        chunk_size: int = 500,  # TODO: should be in the config
    ) -> None:
        """Initialization of the client workflow.

        Args:
            tool_version: ToolVersion this workflow is associated with.
            workflow_args: Unique identifier of a workflow. A UUID is generated
                when this is empty.
            name: Name of the workflow.
            description: Description of the workflow.
            workflow_attributes: Attributes that make this workflow different
                from other workflows that the user wants to record. Either a
                list of attribute names or a dict of names to values.
            max_concurrently_running: How many running jobs to allow in parallel.
            requester: object to communicate with the FastApi services.
            chunk_size: how many tasks to bind in a single request.

        Raises:
            ValueError: if workflow_attributes is non-empty and is neither a
                list nor a dict.
        """
        self._tool_version = tool_version
        self.name = name
        self.description = description
        self.max_concurrently_running: int = max_concurrently_running

        self.requester = Requester.from_defaults() if requester is None else requester
        self._dag = Dag(self.requester)

        # hash to task object mapping. ensure only 1
        self.tasks: Dict[int, Task] = {}
        self.arrays: Dict[str, Array] = {}
        self._chunk_size: int = chunk_size

        self.workflow_args = workflow_args or str(uuid.uuid4())
        if not workflow_args:
            logger.info(
                "Workflow_args defaulting to uuid {}. To resume this "
                "workflow, you must re-instantiate Workflow and pass "
                "this uuid in as the workflow_args. As a uuid is hard "
                "to remember, we recommend you name your workflows and"
                " make workflow_args a meaningful unique identifier. "
                "Then add the same tasks to this workflow".format(self.workflow_args)
            )
        args_digest = hashlib.sha256(self.workflow_args.encode("utf-8")).hexdigest()
        self.workflow_args_hash = int(args_digest, 16)

        self.workflow_attributes: Dict[str, Any] = {}
        if workflow_attributes:
            if isinstance(workflow_attributes, list):
                # Bare attribute names are recorded without a value.
                self.workflow_attributes.update(dict.fromkeys(workflow_attributes))
            elif isinstance(workflow_attributes, dict):
                self.workflow_attributes.update(
                    {str(key): str(value) for key, value in workflow_attributes.items()}
                )
            else:
                raise ValueError(
                    "workflow_attributes must be provided as a list of attributes or a "
                    "dictionary of attributes and their values"
                )

        # Cache for clusters and task resources
        self._clusters: Dict[str, Cluster] = {}
        self._task_resources: Dict[int, TaskResources] = {}

        self.default_cluster_name: str = ""
        self._default_max_attempts: Optional[int] = None
        self.default_compute_resources_set: Dict[str, Dict[str, Any]] = {}
        self.default_resource_scales_set: Dict[str, Dict[str, float]] = {}

        self._fail_after_n_executions = 1_000_000_000
        self.last_workflow_run_id: Optional[int] = None
@property
def tool(self) -> Tool:
    """The Tool that owns this workflow's tool version."""
    tool_version = self._tool_version
    return tool_version.tool
@property
def is_bound(self) -> bool:
    """Whether the workflow has been bound to the db (i.e. it has a ``_workflow_id``)."""
    return hasattr(self, "_workflow_id")
@property
def workflow_id(self) -> int:
    """The id assigned to this workflow when it was bound.

    Raises:
        AttributeError: if the workflow has not been bound yet.
    """
    if self.is_bound:
        return self._workflow_id
    raise AttributeError("workflow_id cannot be accessed before workflow is bound")
@property
def dag_id(self) -> int:
    """The dag_id of this workflow's DAG.

    Raises:
        AttributeError: if the workflow has not been bound yet.
    """
    if self.is_bound:
        return self._dag.dag_id
    raise AttributeError("dag_id cannot be accessed before workflow is bound")
@property
def task_hash(self) -> int:
    """Hash of all of the tasks.

    The tasks are sorted and the string form of each task's ``hash()`` is fed
    into a SHA-256 digest. With no tasks the result is the digest of empty input.
    """
    digest = hashlib.sha256()
    for ordered_task in sorted(self.tasks.values()):
        digest.update(str(hash(ordered_task)).encode("utf-8"))
    return int(digest.hexdigest(), 16)
@property
def task_errors(self) -> Dict:
    """Map each fatally-errored task's name to the result of its ``get_errors()``."""
    errors_by_name = {}
    for task in self.tasks.values():
        if task.final_status == TaskStatus.ERROR_FATAL:
            errors_by_name[task.name] = task.get_errors()
    return errors_by_name
@property
def default_max_attempts(self) -> Optional[int]:
    """Return the workflow default max attempts.

    When no workflow-level value has been set, the tool's
    ``default_max_attempts`` is read and stored on the workflow.
    """
    attempts = self._default_max_attempts
    if attempts is None:
        attempts = self._default_max_attempts = self.tool.default_max_attempts
    return attempts
def add_attributes(self, workflow_attributes: dict) -> None:
    """Update existing workflow attributes or add new ones.

    Sends a PUT request for this workflow's attributes. Reading
    ``self.workflow_id`` means the workflow must already be bound.

    Args:
        workflow_attributes: attributes to be bound to the db that describe
            this workflow.
    """
    payload = {"workflow_attributes": workflow_attributes}
    self.requester.send_request(
        app_route=f"/workflow/{self.workflow_id}/workflow_attributes",
        message=payload,
        request_type="put",
    )
# add_task: register a single Task on this workflow and return it.
#
# Visible steps, in order:
#   1. Raise ValueError if a task with the same hash is already in self.tasks.
#   2. Call self._link_array_and_workflow(task.array). On AttributeError, look up the Array
#      stored in self.arrays under the task template's name (creating one from the node's
#      task template version if absent) and add the task to it. On ValueError, re-raise
#      only when the array stored under that template name differs from task.array.
#   3. Add task.node to the DAG; a DuplicateNodeArgsError is re-raised with a message that
#      names the task, its task_template_version_id and its node_args.
#   4. Store the task in self.tasks under hash(task), set task.workflow = self, log, return.
#
# NOTE(review): this definition is collapsed onto one physical line, so the exact nesting
# of the try/except branches cannot be confirmed from here -- check the original file
# before restructuring. _link_array_and_workflow is not defined in the visible source.
[docs] def add_task(self, task: Task) -> Task: """Add a task to the workflow to be executed. Set semantics - add tasks once only, based on hash name. Args: task: single task to add. """ logger.debug(f"Adding Task {task}") if hash(task) in self.tasks.keys(): raise ValueError( f"A task with hash {hash(task)} already exists. " f"All tasks in a workflow must have unique " f"commands. Your command was: {task.command}" ) try: # link array self._link_array_and_workflow(task.array) except AttributeError: # or infer if not already created template_name = task.node.task_template_version.task_template.template_name try: array = self.arrays[template_name] except KeyError: # create array from the task template version on the node array = Array( task_template_version=task.node.task_template_version, task_args=task.task_args, op_args=task.op_args, cluster_name=task.cluster_name, requester=self.requester, ) self._link_array_and_workflow(array) # add task to inferred array array.add_task(task) except ValueError: # check if current task array is the same as the one attached to the workflow template_name = task.node.task_template_version.task_template.template_name if self.arrays[template_name] != task.array: raise # set array max_attempts # task.array.max_attempts = self._default_max_attempts # add node to task try: self._dag.add_node(task.node) except DuplicateNodeArgsError: raise DuplicateNodeArgsError( "All tasks for a given task template in a workflow must have unique node_args." f"Found duplicate node args for {task}. task_template_version_id=" f"{task.node.task_template_version_id}, node_args={task.node.node_args}" ) # add task to workflow self.tasks[hash(task)] = task task.workflow = self logger.debug(f"Task {hash(task)} added") return task
def add_tasks(self, tasks: Iterable[Task]) -> None:
    """Add every task from an iterable to the workflow, in iteration order.

    Each element is handed to ``add_task``; an error raised for one task
    stops the iteration.
    """
    for new_task in tasks:
        self.add_task(new_task)
def set_default_compute_resources_from_yaml(
    self, cluster_name: str, yaml_file: str
) -> None:
    """Set workflow-level default compute resources from a user provided yaml file.

    TODO: Implement this method. It currently does nothing with its arguments.

    Args:
        cluster_name: name of cluster to set default values for.
        yaml_file: the yaml file that is providing the compute resource values.
    """
    return None
def set_default_compute_resources_from_dict(
    self, cluster_name: str, dictionary: Dict[str, Any]
) -> None:
    """Record the default compute resources to use on one cluster.

    Any value previously stored for ``cluster_name`` is replaced. The
    dictionary is stored as-is, not copied.

    Args:
        cluster_name: name of cluster to set default values for.
        dictionary: dictionary of default compute resources to run tasks
            with. Can be overridden at task template, tool or task level.
    """
    # TODO: Do we need to handle the scenario where no cluster name is specified?
    resources_by_cluster = self.default_compute_resources_set
    resources_by_cluster[cluster_name] = dictionary
def set_default_resource_scales_from_dict(
    self, cluster_name: str, dictionary: Dict[str, float]
) -> None:
    """Record the default resource scales to use on one cluster.

    Any value previously stored for ``cluster_name`` is replaced. The
    dictionary is stored as-is, not copied.

    Args:
        cluster_name: name of cluster to set default values for.
        dictionary: dictionary of default resource scales to adjust task
            resources with. Can be overridden at task template or task level.
    """
    # TODO: Do we need to handle the scenario where no cluster name is specified?
    scales_by_cluster = self.default_resource_scales_set
    scales_by_cluster[cluster_name] = dictionary
def set_default_cluster_name(self, cluster_name: str) -> None:
    """Choose the workflow's default cluster.

    Args:
        cluster_name: name of cluster to set as default.
    """
    self.default_cluster_name = cluster_name
    return None
def set_default_max_attempts(self, value: int) -> None:
    """Store the workflow-level default for max attempts.

    Args:
        value: value of max_attempts.
    """
    self._default_max_attempts = value
    return None
def get_tasks_by_node_args(
    self, task_template_name: str, **kwargs: Any
) -> List[Task]:
    """Query tasks by node args. Used for setting dependencies.

    Args:
        task_template_name: key of the array in ``self.arrays`` to query.
        **kwargs: forwarded unchanged to the array's ``get_tasks_by_node_args``.

    Raises:
        ValueError: if no array is registered under ``task_template_name``.
    """
    try:
        template_array = self.arrays[task_template_name]
    except KeyError:
        known_names = self.arrays.keys()
        raise ValueError(
            f"task_template_name={task_template_name} not found on workflow. Known "
            f"template_names are {known_names}."
        )
    return template_array.get_tasks_by_node_args(**kwargs)
def set_max_concurrently_running(
    self, task_template_name: str, max_concurrently_running: int
) -> None:
    """Placeholder that currently does nothing with its arguments.

    NOTE(review): set_task_template_max_concurrency_limit takes the same kind
    of arguments and does update the array limit -- confirm whether this
    method is meant to delegate to it.
    """
    return None
# run: bind the workflow and its tasks, start a distributor process and drive the swarm.
#
# Visible steps, in order:
#   1. Ensure structlog is configured. When configure_logging is True, set the gRPC
#      fork-support environment variables (only if OTLP_AVAILABLE) and configure client
#      component logging.
#   2. self.bind(), then log the workflow id together with the configured GUI url
#      (empty string when the "http"/"gui_url" config lookup raises ConfigError).
#   3. Raise WorkflowAlreadyComplete if the bound status is DONE; raise
#      WorkflowAlreadyExists if the workflow was not newly created and resume is False;
#      only warn if it was newly created but resume is True.
#   4. On a real resume, call factory.set_workflow_resume(...) before binding tasks.
#   5. Bind tasks, create the workflow run and move it to BOUND.
#   6. Start a DistributorContext on the first cluster name cached in self._clusters,
#      build the SwarmWorkflowRun from this workflow and call swarm.run(...).
#      An Exception raised by swarm.run is logged as a warning and not re-raised.
#   7. Copy each task's final status from the swarm, record _num_newly_completed and
#      last_workflow_run_id, and return swarm.status.
#
# NOTE(review): this definition is collapsed onto a few physical lines, so the extent of
# the `finally:` body and of the `with` block cannot be confirmed from here -- check the
# original file before restructuring.
# NOTE(review): list(self._clusters.keys())[0] presumably assumes at least one task has
# populated the cluster cache -- confirm behavior for an empty workflow.
[docs] def run( self, fail_fast: bool = False, seconds_until_timeout: int = 36000, resume: bool = False, reset_running_jobs: bool = True, distributor_startup_timeout: int = 180, resume_timeout: int = 300, configure_logging: bool = False, ) -> Optional[str]: """Run the workflow. Traverse the dag and submitting new tasks when their tasks have completed successfully. Args: fail_fast: whether to break out of distributor on first failure. seconds_until_timeout: amount of time (in seconds) to wait until the whole workflow times out. Submitted jobs will continue resume: whether the workflow should be resumed or not, if it is not set to resume and an identical workflow already exists, the workflow will error out reset_running_jobs: whether or not to reset running jobs upon resume distributor_startup_timeout: amount of time to wait for the distributor process to start up resume_timeout: seconds to wait for a workflow to become resumable before giving up configure_logging: setup jobmon client logging. If False, no logging will be configured. If True, automatic component logging will be configured. 
Returns: str of WorkflowRunStatus """ # LAZY STRUCTLOG CONFIGURATION: Ensure structlog is configured before any logging # This happens on first workflow.run() call, ensuring host application always # has the opportunity to configure structlog first (eliminating import-order issues) from jobmon.client.logging import ensure_structlog_configured ensure_structlog_configured() # Set gRPC fork support environment variables BEFORE any gRPC initialization # This must happen before configure_logging as OTLP handlers create gRPC connections if configure_logging is True: # Only set gRPC fork support if OTLP is available from jobmon.core.otlp import OTLP_AVAILABLE if OTLP_AVAILABLE: import os os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "true" os.environ["GRPC_POLL_STRATEGY"] = "poll" self._configure_component_logging() # bind to database logger.info("Adding Workflow metadata to database") self.bind() config = JobmonConfig() try: gui_url = config.get("http", "gui_url") except ConfigError: gui_url = "" logger.info( f"Workflow ID {self.workflow_id} assigned. Progress can be monitored at " f"{gui_url}/#/workflow/{self.workflow_id}" ) # Check if this workflow is already complete and is runnable if self._status == WorkflowStatus.DONE: raise WorkflowAlreadyComplete( f"Workflow ({self.workflow_id}) is already in done state and cannot be resumed" ) if not self._newly_created and not resume: raise WorkflowAlreadyExists( "This workflow already exists. If you are trying to resume a workflow, " "please set the resume flag. If you are not trying to resume a workflow, make " "sure the workflow args are unique or the tasks are unique" ) if self._newly_created and resume: logger.warning( "The resume flag has been set but no previous workflow_args exist." "Note that the workflow will execute as a new workflow." 
) # Bind tasks logger.info("Adding task metadata to database") # Need to wait for resume signal to be sent before resetting tasks, in case of a resume factory = WorkflowRunFactory(self.workflow_id) if not self._newly_created and resume: factory.set_workflow_resume( reset_running_jobs=reset_running_jobs, resume_timeout=resume_timeout ) self._bind_tasks( reset_if_running=reset_running_jobs, chunk_size=self._chunk_size ) # create workflow_run logger.info("Adding WorkflowRun metadata to database") wfr = factory.create_workflow_run() # Update the workflowrun to BOUND state immediately in this API. All metadata already # bound, so the swarm can start immediately. wfr._update_status(WorkflowRunStatus.BOUND) logger.info(f"WorkflowRun ID {wfr.workflow_run_id} assigned") # start distributor cluster_name = list(self._clusters.keys())[0] with DistributorContext( cluster_name, wfr.workflow_run_id, distributor_startup_timeout ) as distributor: # set up swarm and initial DAG swarm = SwarmWorkflowRun( workflow_run_id=wfr.workflow_run_id, fail_after_n_executions=self._fail_after_n_executions, requester=self.requester, fail_fast=fail_fast, status=wfr.status, ) swarm.from_workflow(self) self._num_previously_completed = swarm.num_previously_complete try: swarm.run(distributor.alive, seconds_until_timeout) except Exception as e: logger.warning(f"Error running workflow: {e}") finally: # figure out doneness num_new_completed = ( len(swarm.done_tasks) - swarm.num_previously_complete ) if swarm.status != WorkflowRunStatus.DONE: logger.info( f"WorkflowRun execution ended, num failed {len(swarm.failed_tasks)}" ) else: logger.info( f"WorkflowRun execute finished successfully, {num_new_completed} tasks" ) # update workflow tasks with final status for task in self.tasks.values(): task.final_status = swarm.tasks[task.task_id].status self._num_newly_completed = num_new_completed self.last_workflow_run_id = wfr.workflow_run_id return swarm.status
def _configure_component_logging(self) -> None:
    """Configure component logging for client workflow operations.

    The logging helper is imported at call time rather than at module import.
    """
    from jobmon.client.logging import configure_client_logging as _configure_logging

    _configure_logging()
def set_task_template_max_concurrency_limit(
    self, task_template_name: str, limit: int
) -> None:
    """Set the max concurrently running limit on one task template's array.

    Args:
        task_template_name: key of the array in ``self.arrays`` to update.
        limit: value assigned to that array's ``max_concurrently_running``.

    Raises:
        KeyError: if no array is registered under ``task_template_name``.
    """
    try:
        array = self.arrays[task_template_name]
    except KeyError:
        # Only a failed lookup means "unknown template"; any other error
        # (e.g. an unhashable name) is left to propagate unchanged.
        raise KeyError(
            f"There is no task_template named '{task_template_name}' "
            f"associated with this workflow. Workflow name: {self.name}"
        )
    array.max_concurrently_running = limit
# validate: check tasks, arrays and workflow-level invariants.
#
# Visible checks:
#   * For each task without a compute_resources_callable: look up its queue on its cluster
#     and validate a TaskResources built from the task's requested resources
#     (validate_resources(strict)).
#   * For each array: call array.validate().
#   * Raise RuntimeError if more than one cluster is cached in self._clusters, then call
#     self._dag.validate() and self._matching_wf_args_diff_hash().
#
# raise_on_error=True propagates failures (resource failures as ValueError); otherwise they
# are logged (logger.info for task/array problems, logger.exception for the final group).
#
# NOTE(review): this definition is collapsed onto one physical line, so the nesting of the
# array loop and of the `continue` cannot be confirmed from here -- check the original
# file before restructuring.
[docs] def validate(self, strict: bool = True, raise_on_error: bool = False) -> None: """Confirm that the tasks in this workflow are valid. This method will: - access the database to confirm the requested resources are valid for the specified cluster - confirm that the workflow args are valid - make sure no task contains up/down stream tasks that are not in the workflow """ # construct task resources for task in self.tasks.values(): # get the cluster for this task cluster = self.get_cluster_by_name(task.cluster_name) # not dynamic resource request. Construct TaskResources if task.compute_resources_callable is None: try: queue = cluster.get_queue(task.queue_name) except ValueError as e: if raise_on_error: raise e else: logger.info(e) continue # validate the constructed resources task_resources = TaskResources( requested_resources=task.requested_resources, queue=queue ) is_valid, msg = task_resources.validate_resources(strict) if not is_valid: if raise_on_error: raise ValueError(f"Failed validation, reasons: {msg}") else: logger.info(f"Failed validation, reasons: {msg}") for array in self.arrays.values(): try: array.validate() except ValueError as e: if raise_on_error: raise else: logger.info(e) try: cluster_names = list(self._clusters.keys()) if len(list(self._clusters.keys())) > 1: raise RuntimeError( f"Workflow can only use one cluster. Found cluster_names={cluster_names}" ) # check if workflow is valid self._dag.validate() self._matching_wf_args_diff_hash() except Exception as e: if raise_on_error: raise else: logger.exception("Workflow validation error", error=str(e))
def bind(self) -> None:
    """Get a workflow_id.

    Does nothing when the workflow is already bound. Otherwise validates the
    workflow (raising on error), binds the DAG, posts the workflow metadata
    and stores the returned id, status and newly-created flag on the instance.
    """
    if self.is_bound:
        return

    # LAZY STRUCTLOG CONFIGURATION: Ensure structlog is configured before any logging
    from jobmon.client.logging import ensure_structlog_configured

    ensure_structlog_configured()

    # strict = False means we can coerce. obviously we need to raise at this point
    self.validate(strict=False, raise_on_error=True)

    # bind dag
    self._dag.bind(self._chunk_size)

    # bind workflow
    workflow_metadata = {
        "tool_version_id": self._tool_version.id,
        "dag_id": self._dag.dag_id,
        "workflow_args_hash": self.workflow_args_hash,
        "task_hash": self.task_hash,
        "description": self.description,
        "name": self.name,
        "workflow_args": self.workflow_args,
        "max_concurrently_running": self.max_concurrently_running,
        "workflow_attributes": self.workflow_attributes,
    }
    _, response = self.requester.send_request(
        app_route="/workflow",
        message=workflow_metadata,
        request_type="post",
    )

    self._workflow_id = response["workflow_id"]
    self._status = response["status"]
    self._newly_created = response["newly_created"]
# _bind_tasks: send every task in self.tasks to "/task/bind_tasks_no_args" in chunks.
#
# For each chunk of at most chunk_size task hashes:
#   * bind the task's array if it is not bound yet and resolve its original task resources;
#   * build a serializable copy of task.resource_scales (callables are replaced by their
#     __name__, iterators by a list of up to max_attempts - 1 elements taken from a deepcopy);
#   * PUT the chunk, with mark_created=True only on the last chunk;
#   * copy the returned task_id and initial status back onto each task.
# After all chunks, task args and task attributes are bound with the same chunk_size.
#
# NOTE(review): the inline comment describing the payload layout (workflow_id(0),
# node_id(1), ...) does not match the list built below, which starts at node_id and has
# no workflow_id or task_args/task_attributes entries -- confirm against the server.
# NOTE(review): this definition is collapsed onto a few physical lines, so the exact
# nesting of the loops cannot be confirmed from here.
[docs] def _bind_tasks( self, reset_if_running: bool = True, chunk_size: int = 500, ) -> None: app_route = "/task/bind_tasks_no_args" remaining_task_hashes = list(self.tasks.keys()) while remaining_task_hashes: # split off first chunk elements from queue. task_hashes_chunk = remaining_task_hashes[:chunk_size] remaining_task_hashes = remaining_task_hashes[chunk_size:] # If this is the last chunk, mark the created_date field in the # database. mark_created = len(remaining_task_hashes) == 0 # send to server in a format of: # {<hash>:[workflow_id(0), node_id(1), task_args_hash(2), array_id(3), # name(4), command(5), max_attempts(6)], reset_if_running(7), task_args(8), # task_attributes(9), resource_scales(10), fallback_queues(11)} # flat the data structure so that the server won't depend on the client task_metadata: Dict[int, List] = {} for task_hash in task_hashes_chunk: task = self.tasks[task_hash] # get array id array = task.array if not array.is_bound: array.bind() # get task resources id self._set_original_task_resources(task) serializable_resource_scales = copy.copy(task.resource_scales) for resource, scaler in task.resource_scales.items(): # We can't serialize a callable, so use the function name instead. if callable(scaler): serializable_resource_scales[resource] = getattr( # type: ignore scaler, "__name__", "Unknown Callable" # type: ignore ) # We can't serialize an iterator, so take the relevant elements as a # list. 
elif isinstance(scaler, Iterator): serializable_resource_scales[resource] = list( # type: ignore itertools.islice( copy.deepcopy(scaler), task.max_attempts - 1 ) ) task_metadata[task_hash] = [ task.node.node_id, str(task.task_args_hash), task.array.array_id, task.original_task_resources.id, task.name, task.command, task.max_attempts, reset_if_running, serializable_resource_scales, task.fallback_queues, ] parameters = { "workflow_id": self.workflow_id, "tasks": task_metadata, "mark_created": mark_created, } return_code, response = self.requester.send_request( app_route=app_route, message=parameters, request_type="put", ) # populate returned values onto task dict return_tasks = response["tasks"] for k in return_tasks.keys(): task = self.tasks[int(k)] task.task_id = return_tasks[k][0] task.initial_status = return_tasks[k][1] # Bind task arguments and attributes as well self._bind_task_args(chunk_size) self._bind_task_attributes(chunk_size)
def _bind_task_args(self, chunk_size: int = 500) -> None:
    """Bind all task args to the database.

    Walks the task dict in chunks of ``chunk_size`` hashes and sends one PUT
    request per chunk containing ``(task_id, arg_id, value)`` triples built
    from each task's ``mapped_task_args``.
    """
    pending_hashes = list(self.tasks.keys())
    while pending_hashes:
        # split off first chunk elements from queue.
        batch, pending_hashes = pending_hashes[:chunk_size], pending_hashes[chunk_size:]

        arg_rows = []
        for task_hash in batch:
            bound_task = self.tasks[task_hash]
            arg_rows.extend(
                (bound_task.task_id, arg_id, value)
                for arg_id, value in bound_task.mapped_task_args.items()
            )

        self.requester.send_request(
            app_route="/task/bind_task_args",
            message={"task_args": arg_rows},
            request_type="put",
        )
def _bind_task_attributes(self, chunk_size: int = 500) -> None:
    """Bind all task attributes to the database.

    Walks the task dict in chunks of ``chunk_size`` hashes and sends one PUT
    request per chunk mapping each task_id to that task's ``task_attributes``.
    """
    pending_hashes = list(self.tasks.keys())
    while pending_hashes:
        # split off first chunk elements from queue.
        batch, pending_hashes = pending_hashes[:chunk_size], pending_hashes[chunk_size:]

        attributes_by_task_id = {
            self.tasks[task_hash].task_id: self.tasks[task_hash].task_attributes
            for task_hash in batch
        }

        # Send the request
        self.requester.send_request(
            app_route="/task/bind_task_attributes",
            message={"task_attributes": attributes_by_task_id},
            request_type="put",
        )
def get_errors(
    self, limit: Optional[int] = 1000
) -> Optional[Dict[int, Dict[str, Union[int, List[Dict[str, Union[str, int]]]]]]]:
    """Method to get all errors.

    Return a dictionary with the erring task_id as the key, and the
    Task.get_errors content as the value.

    Args:
        limit: maximum number of erring tasks to return. When limit is
            specifically set as None from the client, this return set will
            pass back all the erred tasks in the workflow.

    Returns:
        Dict keyed by task_id with at most ``limit`` entries; empty when no
        task reports errors.
    """
    errors = {}
    for task in self.tasks.values():
        # Stop as soon as the requested number of erring tasks is collected,
        # so exactly `limit` entries (not `limit - 1`) can be returned.
        if limit is not None and len(errors) >= limit:
            break
        task_errors = task.get_errors()
        if task_errors is not None and len(task_errors) > 0:
            errors[task.task_id] = task_errors
    return errors
def get_cluster_by_name(self, cluster_name: str) -> Cluster:
    """Return the cached Cluster for a name, creating and binding it on first use.

    A cluster that is not in ``self._clusters`` yet is constructed with this
    workflow's requester, bound, and stored in the cache before being returned.
    """
    if cluster_name in self._clusters:
        return self._clusters[cluster_name]

    new_cluster = Cluster(cluster_name=cluster_name, requester=self.requester)
    new_cluster.bind()
    self._clusters[cluster_name] = new_cluster
    return new_cluster
def _set_original_task_resources(self, task: Task) -> None:
    """Attach a (cached) TaskResources object to ``task.original_task_resources``.

    A TaskResources is built from the task's requested resources and queue.
    If one with the same hash is already cached it is reused; otherwise the new
    one is bound and added to ``self._task_resources``.
    """
    queue = self.get_cluster_by_name(task.cluster_name).get_queue(task.queue_name)
    candidate = TaskResources(
        requested_resources=task.requested_resources, queue=queue
    )

    lookup_key = hash(candidate)
    if lookup_key in self._task_resources:
        resolved = self._task_resources[lookup_key]
    else:
        candidate.bind()
        # Hash again after binding, as the cache key is taken from the bound object.
        self._task_resources[hash(candidate)] = candidate
        resolved = candidate

    task.original_task_resources = resolved
def _matching_wf_args_diff_hash(self) -> None:
    """Check that an existing workflow does not contain different tasks.

    Fetches the workflows that share this workflow's workflow_args hash. If
    one of them has the same tool version id but a different task hash or dag
    hash, the workflow args are already in use for a different set of tasks.

    Raises:
        WorkflowAlreadyExists: when such a conflicting workflow is found.
    """
    _, response = self.requester.send_request(
        app_route=f"/workflow/{str(self.workflow_args_hash)}",
        message={},
        request_type="get",
    )

    for task_hash, tool_version_id, dag_hash in response["matching_workflows"]:
        if self._tool_version.id != tool_version_id:
            continue
        if str(self.task_hash) != task_hash or str(hash(self._dag)) != dag_hash:
            raise WorkflowAlreadyExists(
                "The unique workflow_args already belong to a workflow "
                "that contains different tasks than the workflow you are "
                "creating, either change your workflow args so that they "
                "are unique for this set of tasks, or make sure your tasks"
                " match the workflow you are trying to resume"
            )
def __hash__(self) -> int:
    """Hash to encompass tool version id, workflow args, tasks and dag."""
    components = (
        hash(self._tool_version.id),
        self.workflow_args_hash,
        self.task_hash,
        hash(self._dag),
    )
    digest = hashlib.sha256()
    for component in components:
        digest.update(str(component).encode("utf-8"))
    return int(digest.hexdigest(), 16)
def __repr__(self) -> str:
    """A representation string for a Workflow instance.

    The workflow_id is included only when reading it does not raise
    AttributeError, i.e. once the workflow is bound.
    """
    prefix = f"Workflow(workflow_args={self.workflow_args}, name={self.name}"
    try:
        return f"{prefix}, workflow_id={self.workflow_id})"
    except AttributeError:
        # Workflow not yet bound so no ID to add to repr
        return prefix + ")"