from __future__ import annotations
import asyncio
import itertools as it
import signal
import sys
import time
import traceback
from collections import defaultdict
from typing import (
Any,
Callable,
Dict,
Generator,
Iterator,
List,
Optional,
Set,
Tuple,
Union,
)
import aiohttp
import structlog
from jobmon.core.cluster_protocol import ClusterDistributor
from jobmon.core.configuration import JobmonConfig
from jobmon.core.constants import TaskInstanceStatus
from jobmon.core.exceptions import DistributorInterruptedError
from jobmon.core.logging import set_jobmon_context
from jobmon.core.requester import Requester
from jobmon.core.serializers import SerializeTaskInstanceBatch
from jobmon.core.structlog_utils import bind_context
from jobmon.distributor.distributor_command import DistributorCommand
from jobmon.distributor.distributor_task_instance import DistributorTaskInstance
from jobmon.distributor.distributor_workflow_run import DistributorWorkflowRun
from jobmon.distributor.task_instance_batch import TaskInstanceBatch
[docs]
# Module-level structlog logger, named after this module.
logger = structlog.get_logger(__name__)
[docs]
class DistributorService:
def __init__(
    self,
    cluster_interface: ClusterDistributor,
    requester: Optional[Requester] = None,
    workflow_run_heartbeat_interval: Optional[int] = None,
    task_instance_heartbeat_interval: Optional[int] = None,
    heartbeat_report_by_buffer: Optional[float] = None,
    distributor_poll_interval: Optional[int] = None,
    raise_on_error: bool = False,
) -> None:
    """Build a distributor service around a cluster interface.

    Every timing argument left as ``None`` is read from ``JobmonConfig``.

    Args:
        cluster_interface: cluster API used to start/stop the cluster and submit work.
        requester: web service client; ``Requester.from_defaults()`` is used when omitted.
        workflow_run_heartbeat_interval: defaults to config section ``heartbeat``,
            key ``workflow_run_interval``.
        task_instance_heartbeat_interval: defaults to config section ``heartbeat``,
            key ``task_instance_interval``.
        heartbeat_report_by_buffer: defaults to config section ``heartbeat``,
            key ``report_by_buffer``.
        distributor_poll_interval: defaults to config section ``distributor``,
            key ``poll_interval``.
        raise_on_error: flag handed to each distributor command when it is executed.
    """
    settings = JobmonConfig()

    # operational args: an explicit argument wins, config is the fallback
    self._workflow_run_heartbeat_interval = (
        settings.get_int("heartbeat", "workflow_run_interval")
        if workflow_run_heartbeat_interval is None
        else workflow_run_heartbeat_interval
    )
    self._task_instance_heartbeat_interval = (
        settings.get_int("heartbeat", "task_instance_interval")
        if task_instance_heartbeat_interval is None
        else task_instance_heartbeat_interval
    )
    self._heartbeat_report_by_buffer = (
        settings.get_float("heartbeat", "report_by_buffer")
        if heartbeat_report_by_buffer is None
        else heartbeat_report_by_buffer
    )
    self._distributor_poll_interval = (
        settings.get_int("distributor", "poll_interval")
        if distributor_poll_interval is None
        else distributor_poll_interval
    )

    self.raise_on_error = raise_on_error

    # task instances by id, and batches by (array_id, batch number)
    self._task_instances: Dict[int, DistributorTaskInstance] = {}
    self._task_instance_batches: Dict[Tuple[int, int], TaskInstanceBatch] = {}

    # work queue; starts out empty
    self._distributor_commands: Iterator[DistributorCommand] = it.chain([])

    # task instances grouped by each status this service tracks
    tracked_statuses = (
        TaskInstanceStatus.QUEUED,
        TaskInstanceStatus.INSTANTIATED,
        TaskInstanceStatus.LAUNCHED,
        TaskInstanceStatus.RUNNING,
        TaskInstanceStatus.TRIAGING,
        TaskInstanceStatus.KILL_SELF,
        TaskInstanceStatus.NO_HEARTBEAT,
    )
    self._task_instance_status_map: Dict[str, Set[DistributorTaskInstance]] = {
        tracked: set() for tracked in tracked_statuses
    }

    # statuses that can produce work, mapped to the generator that produces it
    self._command_generator_map: Dict[
        str, Callable[..., Generator[DistributorCommand, None, None]]
    ] = {
        TaskInstanceStatus.QUEUED: self._check_queued_for_work,
        TaskInstanceStatus.INSTANTIATED: self._check_instantiated_for_work,
        TaskInstanceStatus.TRIAGING: self._check_triaging_for_work,
        TaskInstanceStatus.KILL_SELF: self._check_kill_self_for_work,
        TaskInstanceStatus.NO_HEARTBEAT: self._check_no_heartbeat_for_work,
    }

    # synchronization timing: when the last heartbeat round finished
    self._last_heartbeat_time = time.time()

    # cluster API
    self.cluster_interface = cluster_interface

    # web service API
    self.requester = Requester.from_defaults() if requester is None else requester
@property
[docs]
def _next_report_increment(self) -> float:
return self._heartbeat_report_by_buffer * self._task_instance_heartbeat_interval
[docs]
def set_workflow_run(self, workflow_run_id: int) -> None:
    """Attach a workflow run to this service and transition it to instantiated.

    Args:
        workflow_run_id: id of the workflow run this distributor will work on.
    """
    # register the id with jobmon's logging context before anything is logged
    set_jobmon_context(workflow_run_id=workflow_run_id)
    self.workflow_run = DistributorWorkflowRun(
        workflow_run_id, requester=self.requester
    )
    self.workflow_run.transition_to_instantiated()
    logger.info("Workflow run initialized")
[docs]
def run(self) -> None:
    """Run the distributor loop until it is interrupted or an error escapes.

    Each pass walks the tracked statuses in order: the status is refreshed from the
    database and, when it has a command generator, its work is processed. A pass stops
    when the time budget until the next workflow run heartbeat is used up; statuses that
    were not reached go first on the next pass. After sleeping off any remaining budget,
    heartbeats are logged for launched task instances.

    ``DistributorInterruptedError`` ends the loop quietly; any other exception is logged
    and re-raised. The cluster interface is stopped in every case.
    """
    logger.info("Distributor running")
    try:
        # start the cluster
        self._initialize_signal_handlers()
        self.cluster_interface.start()
        self.workflow_run.transition_to_launched()

        # Send simple startup signal
        sys.stderr.write("ALIVE")
        sys.stderr.flush()

        # statuses still to visit this pass (in order) and those already visited
        todo: List[str] = [
            TaskInstanceStatus.QUEUED,
            TaskInstanceStatus.INSTANTIATED,
            TaskInstanceStatus.LAUNCHED,
            TaskInstanceStatus.RUNNING,
            TaskInstanceStatus.TRIAGING,
            TaskInstanceStatus.KILL_SELF,
            TaskInstanceStatus.NO_HEARTBEAT,
        ]
        done: List[str] = []

        while True:
            # budget left before the next heartbeat is due
            time_till_next_heartbeat = self._workflow_run_heartbeat_interval - (
                time.time() - self._last_heartbeat_time
            )

            while todo and time_till_next_heartbeat > 0:
                started_at = time.time()
                status = todo.pop(0)

                # refresh internal state from db and charge the time it took
                self.refresh_status_from_db(status)
                refreshed_at = time.time()
                time_till_next_heartbeat -= refreshed_at - started_at

                # statuses without a command generator are refresh-only
                finished_at = refreshed_at
                if status in self._command_generator_map:
                    self.process_status(status, time_till_next_heartbeat)
                    finished_at = time.time()
                    time_till_next_heartbeat -= finished_at - refreshed_at

                done.append(status)

                duration = int(finished_at - started_at)
                if duration > 5:  # Only log if took significant time
                    logger.info(
                        f"Status processing completed in {duration}s",
                        status=status,
                        duration_seconds=duration,
                    )

            # visited statuses go to the back of the work order
            todo.extend(done)
            done = []

            logger.info(
                f"Distributor service time_till_next_heartbeat: {time_till_next_heartbeat}"
            )
            if time_till_next_heartbeat > 0:
                time.sleep(time_till_next_heartbeat)
            self.log_task_instance_report_by_date()

    except DistributorInterruptedError as e:
        logger.info(f"Distributor interrupted: {e}")
    except Exception as e:
        logger.exception("Distributor error", error=str(e))
        raise
    finally:
        logger.info("Distributor stopping")
        # stop distributor
        self.cluster_interface.stop()
        # Send simple shutdown signal
        sys.stderr.write("SHUTDOWN")
        sys.stderr.flush()
[docs]
def process_status(self, status: str, timeout: Union[int, float] = -1) -> None:
    """Process commands for a status until the work is done or the timeout is reached.

    At least one command is always attempted. Once ``timeout`` has elapsed the status'
    command generator is closed so it produces nothing new; commands that were already
    pushed onto the front of the work queue are still executed.

    Afterwards every task instance that was filed under ``status`` is re-filed under the
    status it holds now. A task instance whose status is not tracked by this service
    (e.g. a terminal state) is expired from ``self._task_instances`` instead of raising
    ``KeyError``, matching what ``refresh_status_from_db`` does for such statuses.

    Args:
        status: which status to process work for.
        timeout: time until we stop generating new work. -1 means process till no
            more work.
    """
    start = time.time()

    # generate new distributor commands from this status
    command_generator = self._command_generator_map[status]()
    self._distributor_commands = it.chain(command_generator)

    while True:
        try:
            # the check below runs after the command, so at least one command is processed
            distributor_command = next(self._distributor_commands)
            distributor_command(self.raise_on_error)

            # out of time: close the main generator. remaining queued commands are
            # still processed, but nothing new comes from the generator
            if timeout != -1 and (time.time() - start) >= timeout:
                command_generator.close()
        except StopIteration:
            # out of commands
            break

    # update the state map
    task_instances = self._task_instance_status_map.pop(status)
    self._task_instance_status_map[status] = set()
    for task_instance in task_instances:
        try:
            self._task_instance_status_map[task_instance.status].add(task_instance)
        except KeyError:
            # a command moved it to an untracked status: expire it from the distributor
            self._task_instances.pop(task_instance.task_instance_id, None)
[docs]
def instantiate_task_instances(
    self, task_instances: List[DistributorTaskInstance]
) -> None:
    """Ask the server to instantiate task instances and file them into batches.

    The server response lists task instance batches. For each one the matching
    ``TaskInstanceBatch`` is looked up by ``(array_id, array_batch_num)`` or created, and
    every listed task instance is marked INSTANTIATED and added to it.

    Args:
        task_instances: the task instances to instantiate.
    """
    requested = len(task_instances)
    task_instance_ids = [ti.task_instance_id for ti in task_instances]

    # Log each task instance ID for traceability (info level - state transition)
    for ti in task_instances:
        logger.info(
            f"Task instance {ti.task_instance_id} queued for instantiation",
            task_instance_id=ti.task_instance_id,
        )
    logger.debug(
        f"Requesting instantiation of {requested} task instances",
        num_tasks=requested,
    )

    _, response = self.requester.send_request(
        app_route="/task_instance/instantiate_task_instances",
        message={"task_instance_ids": task_instance_ids},
        request_type="post",
    )
    wire_batches = response["task_instance_batches"]
    logger.debug(
        "Batch instantiation completed",
        num_batches=len(wire_batches),
        num_tasks=requested,
    )

    for wire_batch in wire_batches:
        batch_kwargs = SerializeTaskInstanceBatch.kwargs_from_wire(wire_batch)
        array_id = batch_kwargs["array_id"]
        batch_number = batch_kwargs["array_batch_num"]
        batch_key = (array_id, batch_number)
        logger.debug(
            "Distributor processing instantiated batch",
            array_id=array_id,
            array_batch_num=batch_number,
            batch_size=len(batch_kwargs["task_instance_ids"]),
        )

        # reuse a batch we already track, otherwise create and index a new one
        task_instance_batch = self._task_instance_batches.get(batch_key)
        if task_instance_batch is None:
            task_instance_batch = TaskInstanceBatch(
                array_id=array_id,
                array_name=batch_kwargs["array_name"],
                array_batch_num=batch_number,
                task_resources_id=batch_kwargs["task_resources_id"],
                requester=self.requester,
            )
            self._task_instance_batches[batch_key] = task_instance_batch

        for task_instance_id in batch_kwargs["task_instance_ids"]:
            member = self._task_instances[task_instance_id]
            member.status = TaskInstanceStatus.INSTANTIATED
            task_instance_batch.add_task_instance(member)
@bind_context(
    array_id="task_instance_batch.array_id",
    batch_number="task_instance_batch.batch_number",
)
def launch_task_instance_batch(
    self, task_instance_batch: TaskInstanceBatch
) -> None:
    """Submit a batch of task instances to the cluster as one array.

    The batch is removed from the batch index, prepared for launch and submitted.
    Follow-up commands are then placed at the front of the work queue:

    * on success: transition the batch to launched, then log its distributor ids;
    * on ``NotImplementedError``: one ``launch_task_instance`` command per task instance;
    * on any other error: one ``transition_to_no_distributor_id`` command per task
      instance, carrying the stack trace.

    Args:
        task_instance_batch: the batch to launch.
    """
    self._task_instance_batches.pop(
        (task_instance_batch.array_id, task_instance_batch.batch_number)
    )

    batch_size = len(task_instance_batch.task_instances)
    logger.debug(
        "Distributor preparing batch for launch",
        array_id=task_instance_batch.array_id,
        array_batch_num=task_instance_batch.batch_number,
        batch_size=batch_size,
    )
    # Log each task instance (info level - state transition)
    for member in task_instance_batch.task_instances:
        logger.info(
            f"Task instance {member.task_instance_id} preparing for launch",
            task_instance_id=member.task_instance_id,
        )

    # record batch info in db
    task_instance_batch.prepare_task_instance_batch_for_launch()

    # build worker node command
    worker_command = self.cluster_interface.build_worker_node_command(
        task_instance_id=None,
        array_id=task_instance_batch.array_id,
        batch_number=task_instance_batch.batch_number,
    )

    follow_ups: List[DistributorCommand] = []
    try:
        # submit array to batch distributor
        logger.debug(
            "Submitting batch to cluster",
            array_id=task_instance_batch.array_id,
            array_batch_num=task_instance_batch.batch_number,
            batch_size=batch_size,
            submission_name=task_instance_batch.submission_name,
        )
        id_map = self.cluster_interface.submit_array_to_batch_distributor(
            command=worker_command,
            name=task_instance_batch.submission_name,
            requested_resources=task_instance_batch.requested_resources,
            array_length=batch_size,
        )
        task_instance_batch.set_distributor_ids(id_map)
        logger.info(
            "Batch submitted to cluster successfully",
            array_id=task_instance_batch.array_id,
            array_batch_num=task_instance_batch.batch_number,
            batch_size=batch_size,
        )
    except NotImplementedError:
        # array submission isn't implemented: launch every task instance on its own
        logger.debug(
            "Array submission not supported, launching individually",
            batch_size=batch_size,
        )
        follow_ups.extend(
            DistributorCommand(self.launch_task_instance, member)
            for member in task_instance_batch.task_instances
        )
    except Exception as e:
        # any other error: transition to No ID status
        stack_trace = traceback.format_exc()
        logger.exception(
            "Batch launch failed",
            error=str(e),
            batch_size=batch_size,
        )
        follow_ups.extend(
            DistributorCommand(
                member.transition_to_no_distributor_id,
                no_id_err_msg=stack_trace,
            )
            for member in task_instance_batch.task_instances
        )
    else:
        # success: transition to launched, then log the distributor ids
        follow_ups.append(
            DistributorCommand(
                task_instance_batch.transition_to_launched,
                self._next_report_increment,
            )
        )
        follow_ups.append(DistributorCommand(task_instance_batch.log_distributor_ids))
    finally:
        # follow-up commands run before whatever is still queued
        self._distributor_commands = it.chain(follow_ups, self._distributor_commands)
@bind_context(task_instance_id="task_instance.task_instance_id")
def launch_task_instance(self, task_instance: DistributorTaskInstance) -> None:
    """Submit a single task instance to the cluster.

    On success the task instance is transitioned to launched with the distributor id it
    was given; if submission raises, it is transitioned to the no-distributor-id state
    with the stack trace as the error message.

    Args:
        task_instance: the task instance to submit.
    """
    # load resources, fetching them first if the batch does not have them yet
    try:
        requested_resources = task_instance.batch.requested_resources
    except AttributeError:
        task_instance.batch.load_requested_resources()
        requested_resources = task_instance.batch.requested_resources

    # Fetch the worker node command
    worker_command = self.cluster_interface.build_worker_node_command(
        task_instance_id=task_instance.task_instance_id
    )

    # Submit to batch distributor
    try:
        distributor_id = self.cluster_interface.submit_to_batch_distributor(
            command=worker_command,
            name=task_instance.submission_name,
            requested_resources=requested_resources,
        )
        logger.debug("Task instance launched", distributor_id=distributor_id)
    except Exception as e:
        stack_trace = traceback.format_exc()
        logger.exception("Task instance launch failed", error=str(e))
        task_instance.transition_to_no_distributor_id(no_id_err_msg=stack_trace)
        return

    # submission worked: move from register queue to launch queue
    task_instance.transition_to_launched(distributor_id, self._next_report_increment)
@bind_context(task_instance_id="task_instance.task_instance_id")
def triage_error(self, task_instance: DistributorTaskInstance) -> None:
    """Triage a task instance in TRIAGING state using the cluster's exit info.

    The exit code and message reported by the cluster for the task instance's distributor
    id are passed to ``transition_to_error``.

    Allowed transitions are (R, U, Z, F)
    """
    distributor_id = task_instance.distributor_id
    logger.info(
        "Distributor triaging task instance error",
        distributor_id=distributor_id,
    )

    exit_code, exit_message = self.cluster_interface.get_remote_exit_info(
        distributor_id
    )

    # Truncate for log readability
    if exit_message:
        logged_message = exit_message[:100]
    else:
        logged_message = None
    logger.info(
        "Retrieved exit info from cluster",
        return_code=exit_code,
        error_message=logged_message,
    )

    task_instance.transition_to_error(exit_message, exit_code)
    logger.info(
        "Task instance triage completed",
        new_status=task_instance.status,
        error_state=task_instance.error_state,
    )
@bind_context(
    array_id="task_instance_batch.array_id",
    batch_number="task_instance_batch.batch_number",
)
def kill_self_batch(self, task_instance_batch: TaskInstanceBatch) -> None:
    """Terminate every task instance in a batch and mark the batch as killed.

    Only task instances that have a distributor id are sent to the cluster for
    termination; the batch is transitioned to killed either way.

    Args:
        task_instance_batch: The batch of task instances to terminate.
    """
    logger.info(
        "Distributor terminating KILL_SELF batch",
        batch_size=len(task_instance_batch.task_instances),
    )

    # 1) Collect the distributor IDs to terminate
    ids_to_terminate = [
        member.distributor_id
        for member in task_instance_batch.task_instances
        if member.distributor_id is not None
    ]

    # 2) If there are jobs to terminate, call the cluster
    if ids_to_terminate:
        num_to_terminate = len(ids_to_terminate)
        logger.info(
            "Sending termination signal to cluster",
            num_tasks=num_to_terminate,
            distributor_ids=ids_to_terminate[:10],  # Log first 10
        )
        self.cluster_interface.terminate_task_instances(ids_to_terminate)
        logger.info(
            "Cluster termination completed",
            num_tasks=num_to_terminate,
        )

    # 3) Mark them as killed in the DB
    task_instance_batch.transition_to_killed()
@bind_context(task_instance_id="task_instance.task_instance_id")
def no_heartbeat_error(self, task_instance: DistributorTaskInstance) -> None:
    """Move a task instance in NO_HEARTBEAT state to a recoverable error state.

    This signal is sent from the swarm in the event a task instance in LAUNCHED state
    fails to log a heartbeat, either due to the distributor failing to log a heartbeat
    batch or due to the worker node failing to start up properly.

    ERROR state allows for a retry, so that a new task instance can attempt to run.
    """
    logger.info(
        "Distributor processing NO_HEARTBEAT task instance",
        distributor_id=task_instance.distributor_id,
    )

    # user-facing explanation stored with the error
    explanation = (
        "Task instance never reported a heartbeat after scheduling. Will retry. "
        "May be caused by distributor heartbeat failure or worker startup issue often due "
        "to cluster node problem. If the retry fails, resume the task with Slurm logs "
        "enabled by setting 'standard_error' and 'standard_output' in your compute "
        "resources dictionary."
    )
    task_instance.transition_to_error(explanation, TaskInstanceStatus.ERROR)

    logger.info(
        "Task instance transitioned NO_HEARTBEAT → ERROR",
        new_status=task_instance.status,
    )
[docs]
def log_task_instance_report_by_date(self) -> None:
    """Send heartbeats for launched task instances the cluster still reports.

    The cluster interface is asked which distributor ids of LAUNCHED task instances are
    submitted or running. Heartbeats for the matching task instances are sent in chunks
    of 500 ids. ``self._last_heartbeat_time`` is updated whether or not anything was sent.
    """
    launched = self._task_instance_status_map[TaskInstanceStatus.LAUNCHED]
    submitted_or_running = self.cluster_interface.get_submitted_or_running(
        [ti.distributor_id for ti in launched]
    )

    task_instance_ids_to_heartbeat: List[int] = [
        ti.task_instance_id
        for ti in launched
        if ti.distributor_id in submitted_or_running
    ]

    # Plain truthiness asks "is there anything to send?". any() would instead test the
    # ids themselves and skip a list that only holds falsy ids such as 0.
    if task_instance_ids_to_heartbeat:
        # Create batches of task instance IDs
        chunk_size = 500
        task_instance_batches = [
            task_instance_ids_to_heartbeat[i : i + chunk_size]
            for i in range(0, len(task_instance_ids_to_heartbeat), chunk_size)
        ]

        # Send heartbeat for each batch
        logger.info(
            f"Sending heartbeats for {len(task_instance_ids_to_heartbeat)} task instances",
            num_tasks=len(task_instance_ids_to_heartbeat),
            num_batches=len(task_instance_batches),
        )
        asyncio.run(self._log_heartbeats(task_instance_batches))

    self._last_heartbeat_time = time.time()
[docs]
async def _log_heartbeats(self, task_instance_batches: List[List[int]]) -> None:
    """Send all heartbeat batches concurrently over one shared HTTP session.

    Args:
        task_instance_batches: lists of task instance ids, one list per request.
    """
    async with aiohttp.ClientSession() as session:
        requests = (
            self._log_heartbeat_by_batch(session, id_batch)
            for id_batch in task_instance_batches
        )
        # one task per batch, then wait for all of them
        await asyncio.gather(*map(asyncio.create_task, requests))
[docs]
async def _log_heartbeat_by_batch(
    self, session: aiohttp.ClientSession, task_instance_ids_to_heartbeat: List[int]
) -> None:
    """Post one heartbeat request covering a batch of task instance ids.

    The request is sent with ``tenacious=True``; presumably that enables the requester's
    retry handling (not visible from this method).

    Args:
        session: HTTP session to send the request on.
        task_instance_ids_to_heartbeat: ids of the task instances to heartbeat.
    """
    await self.requester.send_request_async(
        session=session,
        app_route="/task_instance/log_report_by/batch",
        message={
            "next_report_increment": self._next_report_increment,
            "task_instance_ids": task_instance_ids_to_heartbeat,
        },
        request_type="post",
        tenacious=True,
    )
[docs]
def _initialize_signal_handlers(self) -> None:
    """Install handlers for SIGTERM, SIGHUP and SIGINT.

    SIGTERM and SIGHUP raise ``DistributorInterruptedError``; SIGINT is handled by a
    function that does nothing.
    """

    def interrupt_with(message: str) -> Callable[[int, Any], None]:
        """Build a handler that raises DistributorInterruptedError with ``message``."""

        def handler(signum: int, frame: Any) -> None:
            raise DistributorInterruptedError(message)

        return handler

    def do_nothing(signum: int, frame: Any) -> None:
        pass

    signal.signal(signal.SIGTERM, interrupt_with("Got signal SIGTERM."))
    signal.signal(signal.SIGHUP, interrupt_with("Got signal SIGHUP."))
    signal.signal(signal.SIGINT, do_nothing)
[docs]
def refresh_status_from_db(self, status: str) -> None:
    """Sync the task instances filed under ``status`` with the database.

    The ids currently filed under ``status`` are posted to the workflow run's
    ``sync_status`` route. The response maps each new status to task instance ids:

    * an unknown id gets a new ``DistributorTaskInstance`` that is indexed and filed;
    * a known id is moved to the set for its new status;
    * a known id whose new status is not tracked here (a terminal state, e.g. D, E) is
      removed from its old set and expired from the distributor.

    Args:
        status: the status whose task instances should be refreshed.
    """
    known_ids = [ti.task_instance_id for ti in self._task_instance_status_map[status]]
    _, response = self.requester.send_request(
        app_route=f"/workflow_run/{self.workflow_run.workflow_run_id}/sync_status",
        message={"task_instance_ids": known_ids, "status": status},
        request_type="post",
    )

    # mutate the statuses and update the status map
    status_updates: Dict[str, List[int]] = response["status_updates"]
    for new_status, task_instance_ids in status_updates.items():
        for task_instance_id in task_instance_ids:
            task_instance = self._task_instances.get(task_instance_id)

            if task_instance is None:
                # first time we see this task instance: create, file and index it
                task_instance = DistributorTaskInstance(
                    task_instance_id,
                    self.workflow_run.workflow_run_id,
                    new_status,
                    self.requester,
                )
                self._task_instance_status_map[task_instance.status].add(task_instance)
                self._task_instances[task_instance.task_instance_id] = task_instance
                continue

            # remove from old status set
            self._task_instance_status_map[task_instance.status].remove(task_instance)

            destination = self._task_instance_status_map.get(new_status)
            if destination is None:
                # If the task instance is in a terminal state, e.g. D, E, etc.,
                # expire it from the distributor
                del self._task_instances[task_instance_id]
            else:
                # move to new set and change to new status
                destination.add(task_instance)
                task_instance.status = new_status
[docs]
def _check_queued_for_work(self) -> Generator[DistributorCommand, None, None]:
    """Yield instantiate commands for QUEUED task instances, in sorted chunks of 500."""
    ordered = sorted(self._task_instance_status_map[TaskInstanceStatus.QUEUED])
    chunk_size = 500
    for offset in range(0, len(ordered), chunk_size):
        yield DistributorCommand(
            self.instantiate_task_instances, ordered[offset : offset + chunk_size]
        )
[docs]
def _check_instantiated_for_work(self) -> Generator[DistributorCommand, None, None]:
    """Yield one launch command per distinct batch among INSTANTIATED task instances."""
    # the distinct batches are computed up front, before any command is yielded
    launchable_batches = {
        task_instance.batch
        for task_instance in self._task_instance_status_map[
            TaskInstanceStatus.INSTANTIATED
        ]
    }
    for task_instance_batch in launchable_batches:
        yield DistributorCommand(self.launch_task_instance_batch, task_instance_batch)
[docs]
def _check_triaging_for_work(self) -> Generator[DistributorCommand, None, None]:
    """Handle TIs in TRIAGING state.

    Yields one ``triage_error`` command for every task instance currently filed under
    TRIAGING.
    """
    in_triage = self._task_instance_status_map[TaskInstanceStatus.TRIAGING]
    if in_triage:
        logger.info(
            "Distributor processing TRIAGING task instances",
            num_task_instances=len(in_triage),
        )
    yield from (
        DistributorCommand(self.triage_error, task_instance)
        for task_instance in in_triage
    )
[docs]
def _check_kill_self_for_work(self) -> Generator[DistributorCommand, None, None]:
    """Handle TIs in KILL_SELF state, grouped by their TaskInstanceBatch.

    Yields one ``kill_self_batch`` command per distinct batch, in the order each batch
    is first seen.
    """
    marked = list(self._task_instance_status_map[TaskInstanceStatus.KILL_SELF])

    if marked:
        logger.info(
            "Distributor processing KILL_SELF task instances",
            num_task_instances=len(marked),
        )

    # Log each task instance being killed (info level - state transition)
    for task_instance in marked:
        logger.info(
            "Task instance marked for termination",
            task_instance_id=task_instance.task_instance_id,
            distributor_id=task_instance.distributor_id,
        )

    # distinct batches in first-seen order; only the batch itself is needed
    distinct_batches = dict.fromkeys(task_instance.batch for task_instance in marked)
    for task_instance_batch in distinct_batches:
        yield DistributorCommand(self.kill_self_batch, task_instance_batch)
[docs]
def _check_no_heartbeat_for_work(self) -> Generator[DistributorCommand, None, None]:
    """Handle TIs in NO_HEARTBEAT state.

    Yields one ``no_heartbeat_error`` command for every task instance currently filed
    under NO_HEARTBEAT, which moves it to a recoverable error state.
    """
    silent = self._task_instance_status_map[TaskInstanceStatus.NO_HEARTBEAT]
    if silent:
        logger.info(
            "Distributor processing NO_HEARTBEAT task instances",
            num_task_instances=len(silent),
        )
    yield from (
        DistributorCommand(self.no_heartbeat_error, task_instance)
        for task_instance in silent
    )