Source code for distributor.distributor_service

from __future__ import annotations

import asyncio
import itertools as it
import logging
import signal
import sys
import time
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import aiohttp

from jobmon.core import __version__
from jobmon.core.cluster_protocol import ClusterDistributor
from jobmon.core.configuration import JobmonConfig
from jobmon.core.constants import TaskInstanceStatus
from jobmon.core.exceptions import DistributorInterruptedError, InvalidResponse
from jobmon.core.requester import http_request_ok, Requester
from jobmon.core.serializers import SerializeTaskInstanceBatch
from jobmon.distributor.distributor_command import DistributorCommand
from jobmon.distributor.distributor_task_instance import DistributorTaskInstance
from jobmon.distributor.distributor_workflow_run import DistributorWorkflowRun
from jobmon.distributor.task_instance_batch import TaskInstanceBatch


logger = logging.getLogger(__name__)


class DistributorService:
    """Drive task instances of one workflow run through a cluster distributor.

    The service keeps an in-memory index of task instances keyed by id and by
    status. Its main loop (``run``) repeatedly syncs one status at a time with
    the server, turns the task instances in that status into
    ``DistributorCommand`` objects, executes them, and then logs a heartbeat for
    launched task instances the cluster still reports as submitted or running.
    """

    def __init__(
        self,
        cluster_interface: ClusterDistributor,
        requester: Optional[Requester] = None,
        workflow_run_heartbeat_interval: Optional[int] = None,
        task_instance_heartbeat_interval: Optional[int] = None,
        heartbeat_report_by_buffer: Optional[float] = None,
        distributor_poll_interval: Optional[int] = None,
        raise_on_error: bool = False,
    ) -> None:
        """Initialization of DistributorService.

        Args:
            cluster_interface: cluster API used to start/stop the distributor and to
                submit, query and terminate work.
            requester: web service API. Defaults to ``Requester.from_defaults()``.
            workflow_run_heartbeat_interval: read from the ``heartbeat`` config section
                (``workflow_run_interval``) when None.
            task_instance_heartbeat_interval: read from the ``heartbeat`` config
                section (``task_instance_interval``) when None.
            heartbeat_report_by_buffer: read from the ``heartbeat`` config section
                (``report_by_buffer``) when None.
            distributor_poll_interval: read from the ``distributor`` config section
                (``poll_interval``) when None.
            raise_on_error: passed to every DistributorCommand when it is executed.
        """
        # operational args
        config = JobmonConfig()
        if workflow_run_heartbeat_interval is None:
            self._workflow_run_heartbeat_interval = config.get_int(
                "heartbeat", "workflow_run_interval"
            )
        else:
            self._workflow_run_heartbeat_interval = workflow_run_heartbeat_interval
        if task_instance_heartbeat_interval is None:
            self._task_instance_heartbeat_interval = config.get_int(
                "heartbeat", "task_instance_interval"
            )
        else:
            self._task_instance_heartbeat_interval = task_instance_heartbeat_interval
        if heartbeat_report_by_buffer is None:
            self._heartbeat_report_by_buffer = config.get_float(
                "heartbeat", "report_by_buffer"
            )
        else:
            self._heartbeat_report_by_buffer = heartbeat_report_by_buffer
        if distributor_poll_interval is None:
            self._distributor_poll_interval = config.get_int(
                "distributor", "poll_interval"
            )
        else:
            self._distributor_poll_interval = distributor_poll_interval
        self.raise_on_error = raise_on_error

        # indexing of task instance by associated id
        self._task_instances: Dict[int, DistributorTaskInstance] = {}
        self._task_instance_batches: Dict[Tuple[int, int], TaskInstanceBatch] = {}

        # work queue
        self._distributor_commands: Iterator[DistributorCommand] = it.chain([])

        # indexing of task instances by status
        self._task_instance_status_map: Dict[str, Set[DistributorTaskInstance]] = {
            TaskInstanceStatus.QUEUED: set(),
            TaskInstanceStatus.INSTANTIATED: set(),
            TaskInstanceStatus.LAUNCHED: set(),
            TaskInstanceStatus.RUNNING: set(),
            TaskInstanceStatus.TRIAGING: set(),
            TaskInstanceStatus.KILL_SELF: set(),
            TaskInstanceStatus.NO_HEARTBEAT: set(),
        }

        # order through which we process work
        gen_map: Dict[str, Callable[..., Generator[DistributorCommand, None, None]]] = {
            TaskInstanceStatus.QUEUED: self._check_queued_for_work,
            TaskInstanceStatus.INSTANTIATED: self._check_instantiated_for_work,
            TaskInstanceStatus.TRIAGING: self._check_triaging_for_work,
            TaskInstanceStatus.KILL_SELF: self._check_kill_self_for_work,
            TaskInstanceStatus.NO_HEARTBEAT: self._check_no_heartbeat_for_work,
        }
        self._command_generator_map = gen_map

        # synchronization timings
        self._last_heartbeat_time = time.time()

        # cluster API
        self.cluster_interface = cluster_interface

        # web service API
        if requester is None:
            self.requester = Requester.from_defaults()
        else:
            self.requester = requester

    @property
    def _next_report_increment(self) -> float:
        """Heartbeat report-by buffer multiplied by the task instance heartbeat interval."""
        return self._heartbeat_report_by_buffer * self._task_instance_heartbeat_interval

    def set_workflow_run(self, workflow_run_id: int) -> None:
        """Bind this service to a workflow run and transition it to instantiated."""
        workflow_run = DistributorWorkflowRun(workflow_run_id, self.requester)
        self.workflow_run = workflow_run
        self.workflow_run.transition_to_instantiated()

    def run(self) -> None:
        """Start the cluster interface and process statuses until interrupted.

        Writes "ALIVE" to stderr once started and "SHUTDOWN" once the cluster
        interface has been stopped. A DistributorInterruptedError ends the loop
        quietly; any other exception is logged and re-raised.
        """
        # start the cluster
        try:
            self._initialize_signal_handlers()
            self.cluster_interface.start()
            self.workflow_run.transition_to_launched()

            # signal via pipe that we are alive
            sys.stderr.write("ALIVE")
            sys.stderr.flush()

            done: List[str] = []
            todo = [
                TaskInstanceStatus.QUEUED,
                TaskInstanceStatus.INSTANTIATED,
                TaskInstanceStatus.LAUNCHED,
                TaskInstanceStatus.RUNNING,
                TaskInstanceStatus.TRIAGING,
                TaskInstanceStatus.KILL_SELF,
                TaskInstanceStatus.NO_HEARTBEAT,
            ]
            while True:
                # loop through all statuses and do as much work as we can till the
                # heartbeat
                time_till_next_heartbeat = self._workflow_run_heartbeat_interval - (
                    time.time() - self._last_heartbeat_time
                )
                while todo and time_till_next_heartbeat > 0:
                    # log when this status started
                    start_time = time.time()

                    # remove status from todo and add to done
                    status = todo.pop(0)

                    # refresh internal state from db
                    self.refresh_status_from_db(status)

                    # how long the refresh took
                    refresh_time = time.time()
                    time_till_next_heartbeat -= refresh_time - start_time

                    if status in self._command_generator_map:
                        # process any work
                        self.process_status(status, time_till_next_heartbeat)

                        # how long the full status took
                        end_time = time.time()
                        time_till_next_heartbeat -= end_time - refresh_time
                    else:
                        end_time = refresh_time

                    done.append(status)
                    logger.info(
                        f"Status processing for status={status} took "
                        f"{int((end_time - start_time))}s."
                    )

                # append done work to the end of the work order
                todo += done
                done = []

                if time_till_next_heartbeat > 0:
                    time.sleep(time_till_next_heartbeat)

                self.log_task_instance_report_by_date()

        except DistributorInterruptedError:
            logger.info("Interrupt received!")
        except Exception as e:
            logger.exception(e)
            raise
        finally:
            # stop distributor
            self.cluster_interface.stop()

            # signal via pipe that we are shutdown
            sys.stderr.write("SHUTDOWN")
            sys.stderr.flush()

    def process_status(self, status: str, timeout: Union[int, float] = -1) -> None:
        """Processes commands until all work is done or timeout is reached.

        At least one command is always executed if the status has any work. After
        processing, the task instances that were indexed under ``status`` are
        re-indexed by their current status; task instances whose status is not
        tracked by this service are dropped from the status index.

        Args:
            status: which status to process work for.
            timeout: time until we stop processing. -1 means process till no more work
        """
        start = time.time()

        # generate new distributor commands from this status
        command_generator_callable = self._command_generator_map[status]
        command_generator = command_generator_callable()
        self._distributor_commands = it.chain(command_generator)

        # this way we always process at least 1 command
        keep_iterating = True
        while keep_iterating:
            # run commands
            try:
                # get next command
                distributor_command = next(self._distributor_commands)
                distributor_command(self.raise_on_error)

                # if we need a status sync close the main generator. we will process
                # remaining transactions, but nothing new from the generator
                if not ((time.time() - start) < timeout or timeout == -1):
                    command_generator.close()

            except StopIteration:
                # stop processing commands if we are out of commands
                keep_iterating = False

        # update the state map
        task_instances = self._task_instance_status_map.pop(status)
        self._task_instance_status_map[status] = set()
        for task_instance in task_instances:
            # a command may have moved the task instance to a status this service
            # does not track (e.g. an error state); such instances are expired from
            # the status index, mirroring refresh_status_from_db
            tracked = self._task_instance_status_map.get(task_instance.status)
            if tracked is not None:
                tracked.add(task_instance)

    def instantiate_task_instances(
        self, task_instances: List[DistributorTaskInstance]
    ) -> None:
        """Instantiate task instances on the server and group them into batches.

        Posts the task instance ids to the instantiate route, then for every batch
        in the response reuses or creates the matching TaskInstanceBatch, marks each
        listed task instance INSTANTIATED and adds it to the batch.
        """
        app_route = "/task_instance/instantiate_task_instances"
        _, result = self.requester.send_request(
            app_route=app_route,
            message={
                "task_instance_ids": [
                    task_instance.task_instance_id for task_instance in task_instances
                ]
            },
            request_type="post",
        )

        # construct batch. associations are made inside batch init
        for batch in result["task_instance_batches"]:
            task_instance_batch_kwargs = SerializeTaskInstanceBatch.kwargs_from_wire(
                batch
            )
            array_id = task_instance_batch_kwargs["array_id"]
            batch_number = task_instance_batch_kwargs["array_batch_num"]

            try:
                task_instance_batch = self._task_instance_batches[
                    (array_id, batch_number)
                ]
            except KeyError:
                task_instance_batch = TaskInstanceBatch(
                    array_id=array_id,
                    array_name=task_instance_batch_kwargs["array_name"],
                    array_batch_num=batch_number,
                    task_resources_id=task_instance_batch_kwargs["task_resources_id"],
                    requester=self.requester,
                )
                self._task_instance_batches[(array_id, batch_number)] = (
                    task_instance_batch
                )

            for task_instance_id in task_instance_batch_kwargs["task_instance_ids"]:
                task_instance = self._task_instances[task_instance_id]
                task_instance.status = TaskInstanceStatus.INSTANTIATED
                task_instance_batch.add_task_instance(task_instance)

    def launch_task_instance_batch(
        self, task_instance_batch: TaskInstanceBatch
    ) -> None:
        """Submit a batch as an array and queue the follow-up commands.

        The follow-up commands are always pushed to the front of the work queue:
        per-task-instance launches if array submission is not implemented,
        transitions to the no-distributor-id state on any other error, or the
        launched transition plus distributor id logging on success.
        """
        self._task_instance_batches.pop(
            (task_instance_batch.array_id, task_instance_batch.batch_number)
        )

        # record batch info in db
        task_instance_batch.prepare_task_instance_batch_for_launch()

        # build worker node command
        command = self.cluster_interface.build_worker_node_command(
            task_instance_id=None,
            array_id=task_instance_batch.array_id,
            batch_number=task_instance_batch.batch_number,
        )

        distributor_commands: List[DistributorCommand] = []
        try:
            # submit array to distributor
            distributor_id_map = (
                self.cluster_interface.submit_array_to_batch_distributor(
                    command=command,
                    name=task_instance_batch.submission_name,
                    requested_resources=task_instance_batch.requested_resources,
                    array_length=len(task_instance_batch.task_instances),
                )
            )
            task_instance_batch.set_distributor_ids(distributor_id_map)

        except NotImplementedError:
            # create DistributorCommands to submit the launch if array isn't
            # implemented
            for task_instance in task_instance_batch.task_instances:
                distributor_command = DistributorCommand(
                    self.launch_task_instance,
                    task_instance,
                )
                distributor_commands.append(distributor_command)

        except Exception as e:
            # if other error, transition to No ID status
            logger.exception(e)
            for task_instance in task_instance_batch.task_instances:
                distributor_command = DistributorCommand(
                    task_instance.transition_to_no_distributor_id, no_id_err_msg=str(e)
                )
                distributor_commands.append(distributor_command)

        else:
            # if successful log a transition to launched
            launch_command = DistributorCommand(
                task_instance_batch.transition_to_launched, self._next_report_increment
            )
            # Log the distributor IDs
            log_distributor_ids_command = DistributorCommand(
                task_instance_batch.log_distributor_ids
            )
            distributor_commands.append(launch_command)
            distributor_commands.append(log_distributor_ids_command)

        finally:
            self._distributor_commands = it.chain(
                distributor_commands, self._distributor_commands
            )

    def launch_task_instance(self, task_instance: DistributorTaskInstance) -> None:
        """Submits a single task instance on the cluster distributor.

        On success the task instance is transitioned to launched with the
        distributor id; on any submission error it is transitioned to the
        no-distributor-id state with the error text.
        """
        # load resources
        try:
            requested_resources = task_instance.batch.requested_resources
        except AttributeError:
            task_instance.batch.load_requested_resources()
            requested_resources = task_instance.batch.requested_resources

        # Fetch the worker node command
        command = self.cluster_interface.build_worker_node_command(
            task_instance_id=task_instance.task_instance_id
        )

        # Submit to batch distributor
        try:
            distributor_id = self.cluster_interface.submit_to_batch_distributor(
                command=command,
                name=task_instance.submission_name,
                requested_resources=requested_resources,
            )

        except Exception as e:
            logger.exception(e)
            task_instance.transition_to_no_distributor_id(no_id_err_msg=str(e))

        else:
            # move from register queue to launch queue
            task_instance.transition_to_launched(
                distributor_id, self._next_report_increment
            )

    def triage_error(self, task_instance: DistributorTaskInstance) -> None:
        """Triage a running task instance that has missed a heartbeat.

        Allowed transitions are (R, U, Z, F)
        """
        r_value, r_msg = self.cluster_interface.get_remote_exit_info(
            task_instance.distributor_id
        )
        task_instance.transition_to_error(r_msg, r_value)

    def kill_self(self, task_instance: DistributorTaskInstance) -> None:
        """Terminate a task instance that has received a Kill Self signal.

        This signal is sent from a cold workflow resume, and transitions the task
        instance to an ERROR_FATAL state with no retries.
        """
        self.cluster_interface.terminate_task_instances([task_instance.distributor_id])
        task_instance.transition_to_error(
            "Task instance was self-killed.", TaskInstanceStatus.ERROR_FATAL
        )

    def no_heartbeat_error(self, task_instance: DistributorTaskInstance) -> None:
        """Move a task instance in NO_HEARTBEAT state to a recoverable error state.

        This signal is sent from the swarm in the event a task instance in LAUNCHED
        state fails to log a heartbeat, either due to the distributor failing to log
        a heartbeat batch or due to the worker node failing to start up properly.

        ERROR state allows for a retry, so that a new task instance can attempt to
        run.
        """
        task_instance.transition_to_error(
            "Task instance never reported a heartbeat after scheduling. Will retry",
            TaskInstanceStatus.ERROR,
        )

    def log_task_instance_report_by_date(self) -> None:
        """Log the heartbeat to show that the task instance is still alive.

        Only launched task instances that the cluster still reports as submitted
        or running are heartbeated. The last heartbeat time is refreshed whether or
        not anything was sent.
        """
        task_instances_launched = self._task_instance_status_map[
            TaskInstanceStatus.LAUNCHED
        ]
        submitted_or_running = self.cluster_interface.get_submitted_or_running(
            [x.distributor_id for x in task_instances_launched]
        )

        task_instance_ids_to_heartbeat: List[int] = [
            task_instance_launched.task_instance_id
            for task_instance_launched in task_instances_launched
            if task_instance_launched.distributor_id in submitted_or_running
        ]

        # test for non-emptiness, not for truthiness of the individual ids
        if task_instance_ids_to_heartbeat:
            # Create batches of task instance IDs
            chunk_size = 500
            task_instance_batches = [
                task_instance_ids_to_heartbeat[i : i + chunk_size]
                for i in range(0, len(task_instance_ids_to_heartbeat), chunk_size)
            ]
            # Send heartbeat for each batch
            asyncio.run(self._log_heartbeats(task_instance_batches))

        self._last_heartbeat_time = time.time()

    async def _log_heartbeats(self, task_instance_batches: List[List[int]]) -> None:
        """Create a task for each batch of task instances to send heartbeat."""
        async with aiohttp.ClientSession(self.requester.base_url) as session:
            heartbeat_tasks = [
                asyncio.create_task(self._log_heartbeat_by_batch(session, batch))
                for batch in task_instance_batches
            ]
            await asyncio.gather(*heartbeat_tasks)

    async def _log_heartbeat_by_batch(
        self, session: aiohttp.ClientSession, task_instance_ids_to_heartbeat: List[int]
    ) -> None:
        """Send heartbeat for a batch of task instances.

        Retries on 5xx and 423 responses with a geometrically growing delay.

        Raises:
            InvalidResponse: if the final response status is not OK.
        """
        message: Dict = {
            "next_report_increment": self._next_report_increment,
            "task_instance_ids": task_instance_ids_to_heartbeat,
        }
        app_route = f"{self.requester.route_prefix}/task_instance/log_report_by/batch"

        # Super basic retrying logic, to avoid fussing with tenacity logic.
        # TODO: Factor out into an asynchronous requester
        max_attempts, wait_time, backoff_factor = 10, 1.5, 1.5
        while max_attempts > 0:
            async with session.post(
                app_route,
                json=message,
                params={"client_jobmon_version": __version__},
                headers={"Content-Type": "application/json"},
            ) as response:
                return_code = response.status
                response_text = await response.text()

            if 499 < return_code < 600:
                logger.warning(
                    f"Got HTTP status_code={return_code} from server. "
                    f"app_route: {app_route}."
                )
            elif return_code == 423:
                logger.info(
                    f"Got HTTP status_code=423 from server. app_route: {app_route}. "
                    f"Retrying as per design."
                )
            else:
                break

            max_attempts -= 1
            if max_attempts > 0:
                # grow the delay by a constant factor; squaring it would make the
                # waits explode after a handful of retries
                await asyncio.sleep(wait_time)
                wait_time *= backoff_factor

        if not http_request_ok(return_code):
            raise InvalidResponse(
                f"Unexpected status code {return_code} from POST "
                f"request through route {app_route}. Expected "
                f"code 200. Response content: {response_text}"
            )

    def _initialize_signal_handlers(self) -> None:
        """Install handlers: SIGTERM and SIGHUP interrupt the service, SIGINT is ignored."""

        def handle_sighup(signum: int, frame: Any) -> None:
            raise DistributorInterruptedError("Got signal SIGHUP.")

        def handle_sigterm(signum: int, frame: Any) -> None:
            raise DistributorInterruptedError("Got signal SIGTERM.")

        def handle_sigint(signum: int, frame: Any) -> None:
            pass

        signal.signal(signal.SIGTERM, handle_sigterm)
        signal.signal(signal.SIGHUP, handle_sighup)
        signal.signal(signal.SIGINT, handle_sigint)

    def refresh_status_from_db(self, status: str) -> None:
        """Sync the task instances indexed under ``status`` with the server.

        Sends the ids currently indexed under ``status`` to the workflow run's
        sync_status route and applies the returned ``status_updates``: unknown task
        instances are created and indexed, known ones are moved to their new
        status. Task instances whose status is not tracked by this service are
        left out of the status index.
        """
        message = {
            "task_instance_ids": [
                task_instance.task_instance_id
                for task_instance in self._task_instance_status_map[status]
            ],
            "status": status,
        }
        app_route = f"/workflow_run/{self.workflow_run.workflow_run_id}/sync_status"
        _, result = self.requester.send_request(
            app_route=app_route, message=message, request_type="post"
        )

        # mutate the statuses and update the status map
        status_updates: Dict[str, List[int]] = result["status_updates"]
        for new_status, task_instance_ids in status_updates.items():
            for task_instance_id in task_instance_ids:
                try:
                    task_instance = self._task_instances[task_instance_id]

                except KeyError:
                    task_instance = DistributorTaskInstance(
                        task_instance_id,
                        self.workflow_run.workflow_run_id,
                        new_status,
                        self.requester,
                    )
                    tracked = self._task_instance_status_map.get(task_instance.status)
                    if tracked is None:
                        # nothing for the distributor to do in an untracked status
                        continue
                    tracked.add(task_instance)
                    self._task_instances[task_instance.task_instance_id] = task_instance

                else:
                    # remove from old status set
                    previous_status = task_instance.status
                    self._task_instance_status_map[previous_status].remove(
                        task_instance
                    )

                    # change to new status and move to new set
                    task_instance.status = new_status

                    try:
                        self._task_instance_status_map[task_instance.status].add(
                            task_instance
                        )
                    except KeyError:
                        # If the task instance is in a terminal state, e.g. D, E, etc.,
                        # expire it from the distributor
                        continue

    def _check_queued_for_work(self) -> Generator[DistributorCommand, None, None]:
        """Yield instantiate commands for QUEUED task instances in sorted chunks of 500."""
        queued_task_instances = list(
            self._task_instance_status_map[TaskInstanceStatus.QUEUED]
        )
        queued_task_instances.sort()
        chunk_size = 500
        while queued_task_instances:
            ti_list = queued_task_instances[:chunk_size]
            queued_task_instances = queued_task_instances[chunk_size:]
            yield DistributorCommand(self.instantiate_task_instances, ti_list)

    def _check_instantiated_for_work(self) -> Generator[DistributorCommand, None, None]:
        """Yield one launch command per distinct batch of INSTANTIATED task instances."""
        # compute the task_instances that can be launched
        task_instance_batches = {
            task_instance.batch
            for task_instance in self._task_instance_status_map[
                TaskInstanceStatus.INSTANTIATED
            ]
        }
        for batch in task_instance_batches:
            yield DistributorCommand(self.launch_task_instance_batch, batch)

    def _check_triaging_for_work(self) -> Generator[DistributorCommand, None, None]:
        """Handle TIs in TRIAGING state.

        For TaskInstances with TRIAGING status, check the nature of no heartbeat,
        and change the statuses accordingly.
        """
        triaging_task_instances = self._task_instance_status_map[
            TaskInstanceStatus.TRIAGING
        ]
        for task_instance in triaging_task_instances:
            yield DistributorCommand(self.triage_error, task_instance)

    def _check_kill_self_for_work(self) -> Generator[DistributorCommand, None, None]:
        """Handle TIs in KILL_SELF state.

        For TaskInstances with KILL_SELF status, terminate it and transition it to
        error accordingly.
        """
        kill_self_task_instances = self._task_instance_status_map[
            TaskInstanceStatus.KILL_SELF
        ]
        for task_instance in kill_self_task_instances:
            yield DistributorCommand(self.kill_self, task_instance)

    def _check_no_heartbeat_for_work(self) -> Generator[DistributorCommand, None, None]:
        """Handle TIs in NO_HEARTBEAT state.

        For TaskInstances with NO_HEARTBEAT status, move to an error recoverable
        state
        """
        no_heartbeat_task_instances = self._task_instance_status_map[
            TaskInstanceStatus.NO_HEARTBEAT
        ]
        for task_instance in no_heartbeat_task_instances:
            yield DistributorCommand(self.no_heartbeat_error, task_instance)