Source code for server.web.repositories.task_repository

"""Repository for Task operations."""

from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Union

import pandas as pd
import structlog
from sqlalchemy import and_, select, update
from sqlalchemy.orm import Session

from jobmon.core import constants
from jobmon.core.constants import Direction
from jobmon.core.serializers import SerializeTaskResourceUsage
from jobmon.server.web.models.edge import Edge
from jobmon.server.web.models.node import Node
from jobmon.server.web.models.queue import Queue
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.task_instance import TaskInstance
from jobmon.server.web.models.task_instance_error_log import TaskInstanceErrorLog
from jobmon.server.web.models.task_instance_status import TaskInstanceStatus
from jobmon.server.web.models.task_resources import TaskResources
from jobmon.server.web.models.task_template import TaskTemplate
from jobmon.server.web.models.task_template_version import TaskTemplateVersion
from jobmon.server.web.models.workflow import Workflow
from jobmon.server.web.models.workflow_run import WorkflowRun
from jobmon.server.web.schemas.task import (
    DownstreamTasksResponse,
    TaskDependenciesResponse,
    TaskDependencyItem,
    TaskDetailItem,
    TaskDetailsResponse,
    TaskInstanceDetailItem,
    TaskInstanceDetailsResponse,
    TaskResourceUsageResponse,
    TaskStatusResponse,
    TaskSubdagResponse,
)
from jobmon.server.web.server_side_exception import InvalidUsage

# Module-level structlog logger, keyed by this module's import path.
logger = structlog.get_logger(__name__)
[docs] _task_instance_label_mapping = { "Q": "PENDING", "B": "PENDING", "I": "PENDING", "R": "RUNNING", "E": "FATAL", "Z": "FATAL", "W": "FATAL", "U": "FATAL", "K": "FATAL", "D": "DONE", }
[docs] _reversed_task_instance_label_mapping = { "PENDING": ["Q", "B", "I"], "RUNNING": ["R"], "FATAL": ["E", "Z", "W", "U", "K"], "DONE": ["D"], }
class TaskRepository:
    """Database access layer for Task reads and status updates."""

    def __init__(self, session: Session) -> None:
        """Bind the repository to the database session it will query through.

        Args:
            session: session used by every query this repository issues.
        """
        self.session = session
def update_task_statuses(
    self,
    workflow_id: str,
    recursive: bool,
    workflow_status: Optional[str],
    task_ids: Union[List[int], str],
    new_status: str,
) -> None:
    """Update the status of tasks and reconcile the owning workflow.

    Description:
      - When task_ids='all', every task in the workflow is updated and
        recursive is forced to False: the selection already contains the
        whole workflow, so walking the DAG would only cost queries.
      - When recursive=True, the given tasks and their dependencies all the
        way up (DONE) or down (REGISTERING) the DAG are updated.
      - When recursive=False, only the tasks in task_ids are updated.
      - When workflow_status is None, it is read from the database when the
        REGISTERING handling needs it.
      - After the tasks are updated, the workflow status is adjusted for
        the REGISTERING and DONE cases.

    Args:
        workflow_id: workflow that owns the tasks.
        recursive: whether to expand task_ids along the DAG.
        workflow_status: current workflow status, or None to look it up.
        task_ids: explicit task IDs, or the string "all".
        new_status: status to assign to the selected tasks.

    Raises:
        InvalidUsage: if task_ids is a string other than "all"; also
            propagated from the recursive expansion.
    """
    if task_ids == "all":
        task_ids = self._get_all_task_ids(workflow_id)
        # Nothing can be added to "every task", so skip the DAG traversal as
        # the contract above promises.
        recursive = False

    if isinstance(task_ids, str):
        raise InvalidUsage(f"Invalid task_ids value: {task_ids}")

    task_ids = list(task_ids)

    if recursive:
        task_ids = self._get_recursive_task_ids(task_ids, new_status)

    self._update_task_statuses_in_db(task_ids, new_status)

    # Status-specific follow-up on task instances and the workflow row.
    if new_status == constants.TaskStatus.REGISTERING:
        self._handle_registering_status(workflow_id, task_ids, workflow_status)
    elif new_status == constants.TaskStatus.DONE:
        self._handle_done_status(workflow_id, new_status)
def _get_all_task_ids(self, workflow_id: str) -> List[int]:
    """Return the ID of every task that belongs to the given workflow."""
    id_query = self.session.query(Task.id).filter(Task.workflow_id == workflow_id)
    rows = id_query.all()
    # Each row is a one-column result; keep only the ID itself.
    return [row[0] for row in rows]
def _get_recursive_task_ids(
    self, task_ids: List[int], new_status: str
) -> List[int]:
    """Expand task_ids along the DAG in the direction implied by new_status.

    DONE expands upstream, REGISTERING expands downstream.

    Raises:
        InvalidUsage: if new_status is neither DONE nor REGISTERING.
    """
    expand_upstream = new_status == constants.TaskStatus.DONE
    if not expand_upstream and not (
        new_status == constants.TaskStatus.REGISTERING
    ):
        raise InvalidUsage(
            f"Invalid new_status {new_status} for recursive update",
            status_code=400,
        )

    if expand_upstream:
        logger.info("recursive update including upstream tasks")
        direction = constants.Direction.UP
    else:
        logger.info("recursive update including downstream tasks")
        direction = constants.Direction.DOWN

    expanded = self._get_tasks_recursive(set(task_ids), direction)
    logger.info(f"reset status to new_status: {new_status}")
    return list(expanded)
def _update_task_statuses_in_db(self, task_ids: List[int], new_status: str) -> None:
    """Set new_status on the given tasks, skipping rows that already have it."""
    needs_change = and_(Task.id.in_(task_ids), Task.status != new_status)
    statement = update(Task).where(needs_change).values(status=new_status)
    self.session.execute(statement)
    self.session.flush()
def _get_workflow_run(self, workflow_id: str) -> WorkflowRun | None:
    """Return the workflow run with the highest ID for this workflow, if any."""
    runs = self.session.query(WorkflowRun).filter(
        WorkflowRun.workflow_id == workflow_id
    )
    newest_first = runs.order_by(WorkflowRun.id.desc())
    return newest_first.first()
def _kill_active_task_instances(
    self, task_ids: List[int], workflow_run_id: int
) -> None:
    """Mark still-active task instances of the given tasks as KILL_SELF.

    Only instances that belong to workflow_run_id and are in one of the
    active statuses listed below are touched.

    Args:
        task_ids: tasks whose instances should be killed.
        workflow_run_id: workflow run the instances must belong to.
    """
    active_statuses = [
        constants.TaskInstanceStatus.SUBMITTED_TO_BATCH_DISTRIBUTOR,
        constants.TaskInstanceStatus.INSTANTIATED,
        constants.TaskInstanceStatus.LAUNCHED,
        constants.TaskInstanceStatus.QUEUED,
        constants.TaskInstanceStatus.RUNNING,
        constants.TaskInstanceStatus.TRIAGING,
        constants.TaskInstanceStatus.NO_HEARTBEAT,
    ]

    # Work through the tasks in slices to reduce lock duration.
    batch_size = 100
    for start in range(0, len(task_ids), batch_size):
        chunk = task_ids[start : start + batch_size]

        # Collect the IDs of the rows to update as a subquery first.
        # NOTE(review): the extra derived-table hop looks like the usual
        # workaround for databases that refuse to update a table selected in
        # the same statement -- confirm before simplifying.
        target_ids = (
            self.session.query(TaskInstance.id)
            .filter(
                TaskInstance.workflow_run_id == workflow_run_id,
                TaskInstance.task_id.in_(chunk),
                TaskInstance.status.in_(active_statuses),
            )
            .subquery()
        )

        kill_stmt = update(TaskInstance).where(
            TaskInstance.id.in_(self.session.query(target_ids.c.id))
        )
        new_values = {"status": constants.TaskInstanceStatus.KILL_SELF}
        self.session.execute(kill_stmt.values(**new_values))
        # NOTE(review): flush reconstructed as once per batch -- confirm
        # against the unflattened source.
        self.session.flush()
def _handle_registering_status(
    self,
    workflow_id: str,
    task_ids: List[int],
    workflow_status: Optional[str],
) -> None:
    """Apply the follow-up work needed after tasks are set to REGISTERING.

    Kills active task instances of the latest workflow run and, when the
    workflow is DONE, moves it to FAILED so that it can be resumed.

    Raises:
        ValueError: if the workflow has no workflow run.
    """
    latest_run = self._get_workflow_run(workflow_id)
    if latest_run is None:
        raise ValueError(f"No workflow run found for workflow_id: {workflow_id}")

    self._kill_active_task_instances(task_ids, latest_run.id)

    # Fall back to the stored workflow status when the caller gave none.
    if workflow_status is None:
        status_query = self.session.query(Workflow.status).filter(
            Workflow.id == workflow_id
        )
        workflow_status = status_query.scalar()

    if not (workflow_status == constants.WorkflowStatus.DONE):
        return

    # A DONE workflow has to be put into an error state before resuming.
    logger.info(f"reset workflow status for workflow_id: {workflow_id}")
    reset_stmt = (
        update(Workflow)
        .where(Workflow.id == workflow_id)
        .values(status=constants.WorkflowStatus.FAILED)
    )
    self.session.execute(reset_stmt)
def _handle_done_status(self, workflow_id: str, new_status: str) -> None:
    """Mark the workflow DONE once none of its tasks has another status.

    Args:
        workflow_id: workflow to inspect and possibly update.
        new_status: the status just applied to the tasks; any task of the
            workflow whose status differs from it blocks the update.
    """
    # Only existence matters, so fetch at most one unfinished task instead
    # of loading the ID of every unfinished task in the workflow.
    unfinished_task = (
        self.session.query(Task.id)
        .filter(Task.workflow_id == workflow_id, Task.status != new_status)
        .first()
    )
    if unfinished_task is not None:
        return

    logger.info(f"set workflow status to DONE for workflow_id: {workflow_id}")
    workflow_update_stmt = update(Workflow).where(Workflow.id == workflow_id)
    vals = {"status": constants.WorkflowStatus.DONE}
    self.session.execute(workflow_update_stmt.values(**vals))
def _get_tasks_recursive(
    self, task_ids: Set[int], direction: Direction
) -> Set[int]:
    """Collect every task reachable from task_ids in the given direction.

    The walk is iterative and the result includes the starting tasks. All
    starting tasks must resolve to exactly one workflow.

    Args:
        task_ids (Set[int]): Initial set of task IDs.
        direction (Direction): Either Direction.UP or Direction.DOWN.

    Returns:
        Set[int]: IDs of all tasks connected in the specified direction.

    Raises:
        InvalidUsage: if the tasks do not resolve to exactly one workflow.
    """
    workflow_rows = (
        self.session.query(Task.workflow_id)
        .filter(Task.id.in_(task_ids))
        .distinct()
        .all()
    )
    if len(workflow_rows) != 1:
        raise InvalidUsage(
            f"{task_ids} in request belong to different workflow_ids ",
            status_code=400,
        )
    workflow_id = workflow_rows[0][0]

    dag_id = (
        self.session.query(Workflow.dag_id)
        .filter(Workflow.id == workflow_id)
        .scalar()
    )

    # Translate the starting tasks into their DAG nodes.
    node_rows = self.session.query(Task.node_id).filter(Task.id.in_(task_ids)).all()
    pending = [int(node_row[0]) for node_row in node_rows]

    # Anything that is not DOWN is walked as UP.
    walk_direction = (
        Direction.DOWN if direction == constants.Direction.DOWN else Direction.UP
    )

    visited: Set[int] = set()
    while pending:
        node = pending.pop()
        if node in visited:
            continue
        visited.add(node)
        neighbours = self._get_node_dependencies({node}, dag_id, walk_direction)
        pending.extend(neighbours)

    # Translate the discovered nodes back into task IDs.
    tasks_by_id = self._get_tasks_from_nodes(workflow_id, list(visited), [])
    return set(tasks_by_id.keys())
def _get_node_dependencies(
    self, nodes: set, dag_id: int, direction: Direction
) -> Set[int]:
    """Return the direct upstream or downstream node IDs of the given nodes.

    Args:
        nodes (set): set of nodes
        dag_id (int): ID of DAG
        direction (Direction): either up or down

    Raises:
        ValueError: if an edge row is found and direction is neither
            Direction.UP nor Direction.DOWN.
    """
    edge_query = select(Edge).where(
        Edge.dag_id == int(dag_id), Edge.node_id.in_(list(nodes))
    )

    linked_nodes: Set[int] = set()
    for result_row in self.session.execute(edge_query).all():
        edge = result_row[0]
        if direction == Direction.UP:
            neighbours = edge.upstream_node_ids
        elif direction == Direction.DOWN:
            neighbours = edge.downstream_node_ids
        else:
            raise ValueError(
                f"Invalid direction type. Expected one of: {Direction}"
            )
        # An edge may carry no IDs for this direction.
        if neighbours:
            linked_nodes.update(neighbours)
    return linked_nodes
[docs] def _get_tasks_from_nodes( self, workflow_id: int, nodes: List, task_status: List ) -> dict: """Get task ids of the given node ids. Args: workflow_id (int): ID of the workflow nodes (list): list of nodes task_status (list): list of task statuses """ if not nodes: return {} select_stmt = select(Task.id, Task.status, Task.name).where( Task.workflow_id == workflow_id, Task.node_id.in_(list(nodes)) ) result = self.session.execute(select_stmt).all() task_dict = {} for r in result: # When task_status not specified, return the full subdag if not task_status: task_dict[r[0]] = [r[1], r[2]] else: if r[1] in task_status: task_dict[r[0]] = [r[1], r[2]] return task_dict
def get_task_status(
    self,
    task_ids: Optional[Union[int, List[int]]],
    status: Optional[Union[str, List[str]]],
) -> TaskStatusResponse:
    """Return task instance details for the given tasks as a JSON table.

    Args:
        task_ids: one task ID or a list of task IDs; required.
        status: optional label or list of labels (keys of
            _reversed_task_instance_label_mapping) used to filter task
            instances by status.

    Returns:
        TaskStatusResponse wrapping DataFrame.to_json() output with one row
        per task instance / error log combination.

    Raises:
        InvalidUsage: if task_ids is missing or empty, or if status contains
            a label that is not in _reversed_task_instance_label_mapping.
    """
    if task_ids is None:
        raise InvalidUsage("Missing task_ids in request", status_code=400)
    if isinstance(task_ids, int):
        task_ids = [task_ids]
    if len(task_ids) == 0:
        raise InvalidUsage(f"Missing {task_ids} in request", status_code=400)
    if status and isinstance(status, str):
        status = [status]

    query_filter = [
        Task.id == TaskInstance.task_id,
        TaskInstanceStatus.id == TaskInstance.status,
    ]
    if status:
        # Reject unknown labels as a client error instead of letting the
        # lookup below fail with a bare KeyError.
        unknown_labels = [
            label
            for label in status
            if label not in _reversed_task_instance_label_mapping
        ]
        if unknown_labels:
            raise InvalidUsage(
                f"Invalid status value(s) {unknown_labels} in request",
                status_code=400,
            )
        # Flatten once; the previous nested comprehension repeated every
        # code len(status) times in the IN list.
        status_codes = [
            code
            for label in status
            for code in _reversed_task_instance_label_mapping[label]
        ]
        query_filter.append(TaskInstance.status.in_(status_codes))
    query_filter.append(Task.id.in_(task_ids))

    sql = (
        select(
            Task.id,
            Task.status,
            TaskInstance.id,
            TaskInstance.distributor_id,
            TaskInstanceStatus.label,
            TaskInstance.usage_str,
            TaskInstance.stdout,
            TaskInstance.stderr,
            TaskInstanceErrorLog.description,
        )
        .join_from(
            TaskInstance,
            TaskInstanceErrorLog,
            TaskInstance.id == TaskInstanceErrorLog.task_instance_id,
            isouter=True,
        )
        .where(*query_filter)
    )
    rows = self.session.execute(sql).all()

    column_names = (
        "TASK_ID",
        "task_status",
        "TASK_INSTANCE_ID",
        "DISTRIBUTOR_ID",
        "STATUS",
        "RESOURCE_USAGE",
        "STDOUT",
        "STDERR",
        "ERROR_TRACE",
    )
    if rows:
        # assign to dataframe for serialization
        df = pd.DataFrame(rows, columns=column_names)
        # remap to jobmon_cli statuses; assign the result back because an
        # in-place replace on the column accessor operates on a temporary
        # under pandas copy-on-write and would leave df unchanged.
        df["STATUS"] = df["STATUS"].replace(_task_instance_label_mapping)
    else:
        df = pd.DataFrame({}, columns=column_names)

    return TaskStatusResponse(task_instance_status=df.to_json())
def get_task_subdag(
    self, task_ids: List[int], task_status: List[str]
) -> TaskSubdagResponse:
    """Return the downstream sub-DAG of the given tasks as a task mapping.

    Args:
        task_ids: tasks whose descendants are wanted; required.
        task_status: statuses to keep; None or empty keeps every task.

    Raises:
        InvalidUsage: if task_ids is empty.
    """
    if not task_ids:
        raise InvalidUsage(f"Missing {task_ids} in request", status_code=400)
    if task_status is None:
        task_status = []

    lookup = (
        select(
            Task.workflow_id.label("workflow_id"),
            Workflow.dag_id.label("dag_id"),
            Task.node_id.label("node_id"),
        )
        .join_from(Task, Workflow, Task.workflow_id == Workflow.id)
        .where(Task.id.in_(task_ids))
    )

    # Group node IDs by (workflow_id, dag_id), preserving first-seen order.
    nodes_by_workflow: Dict = {}
    for row in self.session.execute(lookup):
        group_key = (row.workflow_id, row.dag_id)
        nodes_by_workflow.setdefault(group_key, []).append(row.node_id)

    if not nodes_by_workflow:
        return TaskSubdagResponse(workflow_id=None, sub_task=None)

    # The caller (status_command) has already validated that all tasks share
    # one workflow, so only the first group is used.
    (workflow_id, dag_id), raw_node_ids = next(iter(nodes_by_workflow.items()))
    node_ids = [int(raw_id) for raw_id in raw_node_ids]

    sub_dag_tree = self._get_subdag(node_ids, dag_id)
    sub_task_tree = self._get_tasks_from_nodes(workflow_id, sub_dag_tree, task_status)
    return TaskSubdagResponse(workflow_id=workflow_id, sub_task=sub_task_tree)
def _get_subdag(self, node_ids: List[int], dag_id: int) -> List[int]:
    """Return the given nodes together with all of their descendants.

    Args:
        node_ids (list): list of node IDs
        dag_id (int): ID of DAG

    Returns:
        list: the starting nodes plus every node reachable downstream of
        them, in no particular order.
    """
    seen = set(node_ids)
    frontier = set(seen)
    while frontier:
        children = self._get_node_dependencies(frontier, dag_id, Direction.DOWN)
        # Expand only nodes that have not been visited yet. This keeps shared
        # descendants from being re-queried at every depth and guarantees
        # termination even if the stored edges contain a cycle.
        frontier = children - seen
        seen |= frontier
    return list(seen)
def get_task_dependencies(self, task_id: int) -> TaskDependenciesResponse:
    """Return the direct upstream and downstream tasks of a task.

    Each neighbour is reported with its ID, status and name. Both lists are
    empty when the task cannot be resolved.
    """
    dag_id, workflow_id, node_id = self._get_dag_and_wf_id(task_id)
    if dag_id is None or workflow_id is None or node_id is None:
        return TaskDependenciesResponse(up=[], down=[])

    def as_items(task_dict: Any) -> list:
        # Every task is wrapped in its own one-element list: the "standard"
        # JSON layout kept for reuse by a future GUI.
        if not task_dict:
            return []
        return [
            [
                TaskDependencyItem(
                    id=dep_id,
                    status=task_dict[dep_id][0],
                    name=task_dict[dep_id][1],
                )
            ]
            for dep_id in task_dict
        ]

    upstream_nodes = self._get_node_dependencies({node_id}, dag_id, Direction.UP)
    downstream_nodes = self._get_node_dependencies(
        {node_id}, dag_id, Direction.DOWN
    )
    upstream_tasks = self._get_tasks_from_nodes(
        workflow_id, list(upstream_nodes), []
    )
    downstream_tasks = self._get_tasks_from_nodes(
        workflow_id, list(downstream_nodes), []
    )

    return TaskDependenciesResponse(
        up=as_items(upstream_tasks), down=as_items(downstream_tasks)
    )
def _get_dag_and_wf_id(self, task_id: int) -> tuple:
    """Look up (dag_id, workflow_id, node_id) for a task.

    Returns:
        tuple: three ints, or (None, None, None) when no row matches.
    """
    columns = (
        Workflow.dag_id.label("dag_id"),
        Task.workflow_id.label("workflow_id"),
        Task.node_id.label("node_id"),
    )
    lookup = (
        select(*columns)
        .join_from(Task, Workflow, Task.workflow_id == Workflow.id)
        .where(Task.id == task_id)
    )
    match = self.session.execute(lookup).one_or_none()
    if match is not None:
        return int(match.dag_id), int(match.workflow_id), int(match.node_id)
    return None, None, None
def get_task_resource_usage(self, task_id: int) -> TaskResourceUsageResponse:
    """Return the resource usage for a given Task ID.

    Usage is read from the task's most recent task instance with status
    "D". When there is none, every field is reported as None.

    Args:
        task_id: task to report on.

    Returns:
        TaskResourceUsageResponse wrapping the SerializeTaskResourceUsage
        wire tuple (num_attempts, nodename, wallclock, maxrss) as a list.
    """
    # Select the fields required by SerializeTaskResourceUsage.to_wire
    select_stmt = (
        select(
            Task.num_attempts,
            TaskInstance.nodename,
            TaskInstance.wallclock,
            TaskInstance.maxrss,
        )
        .join_from(
            TaskInstance,
            Task,
            TaskInstance.task_id == Task.id,
        )
        .where(TaskInstance.task_id == task_id, TaskInstance.status == "D")
        # A task that was reset and rerun can own several "D" instances;
        # one_or_none() would raise MultipleResultsFound for it, so take
        # the newest instance explicitly.
        .order_by(TaskInstance.id.desc())
        .limit(1)
    )
    result = self.session.execute(select_stmt).first()

    if result is None:
        resource_usage = SerializeTaskResourceUsage.to_wire(None, None, None, None)
    else:
        resource_usage = SerializeTaskResourceUsage.to_wire(
            result.num_attempts,
            result.nodename,
            result.wallclock,
            result.maxrss,
        )

    return TaskResourceUsageResponse(resource_usage=list(resource_usage))
def get_downstream_tasks(
    self, task_ids: List[int], dag_id: int, client_version: Optional[str] = None
) -> DownstreamTasksResponse:
    """Return only the direct downstream nodes of each given task.

    Args:
        task_ids: tasks to look up.
        dag_id: DAG whose edges are consulted.
        client_version: passed to normalize_node_ids_for_client to format
            the downstream node IDs.

    Returns:
        DownstreamTasksResponse mapping task ID ->
        [node_id, formatted downstream node IDs].
    """
    from jobmon.server.web.utils.json_compat import normalize_node_ids_for_client

    lookup = select(Task.id, Task.node_id, Edge.downstream_node_ids).where(
        Task.id.in_(task_ids),
        Task.node_id == Edge.node_id,
        Edge.dag_id == dag_id,
    )
    edge_rows = self.session.execute(lookup).all()

    downstream_by_task = {
        edge_row.id: [
            edge_row.node_id,
            normalize_node_ids_for_client(
                edge_row.downstream_node_ids, client_version
            ),
        ]
        for edge_row in edge_rows
    }
    return DownstreamTasksResponse(downstream_tasks=downstream_by_task)
def get_task_instance_details(self, task_id: int) -> TaskInstanceDetailsResponse:
    """Return details of every TaskInstance recorded for a Task ID.

    Error log rows are outer-joined; task resources and queue are
    inner-joined. Datetime columns are returned as ISO-8601 strings.
    """

    def to_iso(value: Any) -> Any:
        # Non-datetime values are passed through untouched.
        return value.isoformat() if isinstance(value, datetime) else value

    detail_query = (
        select(
            TaskInstance.id,
            TaskInstanceStatus.label,
            TaskInstance.stdout,
            TaskInstance.stderr,
            TaskInstance.stdout_log,
            TaskInstance.stderr_log,
            TaskInstance.distributor_id,
            TaskInstance.nodename,
            TaskInstanceErrorLog.description,
            TaskInstance.wallclock,
            TaskInstance.maxrss,
            TaskResources.requested_resources,
            TaskInstance.submitted_date,
            TaskInstance.status_date,
            Queue.name,
        )
        .outerjoin_from(
            TaskInstance,
            TaskInstanceErrorLog,
            TaskInstance.id == TaskInstanceErrorLog.task_instance_id,
        )
        .join(
            TaskResources,
            TaskInstance.task_resources_id == TaskResources.id,
        )
        .join(
            Queue,
            TaskResources.queue_id == Queue.id,
        )
        .where(
            TaskInstance.task_id == task_id,
            TaskInstance.status == TaskInstanceStatus.id,
        )
    )

    details = []
    for (
        ti_id,
        status_label,
        stdout,
        stderr,
        stdout_log,
        stderr_log,
        distributor_id,
        nodename,
        error_description,
        wallclock,
        maxrss,
        requested_resources,
        submitted_date,
        status_date,
        queue_name,
    ) in self.session.execute(detail_query).all():
        details.append(
            TaskInstanceDetailItem(
                ti_id=ti_id,
                ti_status=status_label,
                ti_stdout=stdout,
                ti_stderr=stderr,
                ti_stdout_log=stdout_log,
                ti_stderr_log=stderr_log,
                ti_distributor_id=distributor_id,
                ti_nodename=nodename,
                ti_error_log_description=error_description,
                ti_wallclock=wallclock,
                ti_maxrss=maxrss,
                ti_resources=requested_resources,
                ti_submit_date=to_iso(submitted_date),
                ti_status_date=to_iso(status_date),
                ti_queue_name=queue_name,
            )
        )

    return TaskInstanceDetailsResponse(taskinstances=details)
def get_task_details_viz(self, task_id: int) -> TaskDetailsResponse:
    """Return status, workflow, name, command and template ID of a Task.

    The task template is resolved through the task's node and task template
    version. The status date is returned as an ISO-8601 string when it is a
    datetime.
    """
    detail_query = (
        select(
            Task.status,
            Task.workflow_id,
            Task.name,
            Task.command,
            Task.status_date,
            TaskTemplate.id,
        )
        .join(Node, Task.node_id == Node.id)
        .join(
            TaskTemplateVersion,
            Node.task_template_version_id == TaskTemplateVersion.id,
        )
        .join(
            TaskTemplate,
            TaskTemplateVersion.task_template_id == TaskTemplate.id,
        )
        .where(Task.id == task_id)
    )

    def to_item(row: Any) -> TaskDetailItem:
        status, workflow_id, name, command, status_date, template_id = row
        if isinstance(status_date, datetime):
            status_date = status_date.isoformat()
        return TaskDetailItem(
            task_status=status,
            workflow_id=workflow_id,
            task_name=name,
            task_command=command,
            task_status_date=status_date,
            task_template_id=template_id,
        )

    found = self.session.execute(detail_query).all()
    return TaskDetailsResponse(task_details=[to_item(row) for row in found])