# Source code for worker_node.worker_node_task_instance

"""The Task Instance Object once it has been submitted to run on a worker node."""

import asyncio
import codecs
import os
import signal
import socket
from time import time
from typing import Dict, Optional, TextIO

import structlog

from jobmon.core.cluster_protocol import ClusterWorkerNode
from jobmon.core.configuration import JobmonConfig
from jobmon.core.constants import TaskInstanceStatus
from jobmon.core.exceptions import ReturnCodes, TransitionError
from jobmon.core.requester import Requester
from jobmon.core.serializers import SerializeTaskInstance
from jobmon.core.structlog_utils import bind_method_context

# Module-level structured logger shared by everything in this file.
logger = structlog.get_logger(__name__)
class WorkerNodeTaskInstance:
    """The Task Instance object once it has been submitted to run on a worker node."""

    def __init__(
        self,
        cluster_interface: ClusterWorkerNode,
        task_instance_id: int,
        task_instance_heartbeat_interval: Optional[int] = None,
        heartbeat_report_by_buffer: Optional[float] = None,
        command_interrupt_timeout: Optional[int] = None,
        requester: Optional[Requester] = None,
    ) -> None:
        """A mechanism whereby a running task_instance can communicate back to the JSM.

        Logs its status, errors, usage details, etc.

        Args:
            cluster_interface: interface that gathers executor info in the
                execution_wrapper.
            task_instance_id: the id of the TaskInstance that is reporting back.
            task_instance_heartbeat_interval: how often to log a report by with the db.
                Read from the "heartbeat"/"task_instance_interval" config value when None.
            heartbeat_report_by_buffer: multiplier for report by date in case we miss a
                few. Read from the "heartbeat"/"report_by_buffer" config value when None.
            command_interrupt_timeout: the amount of time to wait for the child process
                to terminate. Read from the "worker_node"/"command_interrupt_timeout"
                config value when None.
            requester: communicate with the flask services. Built with
                Requester.from_defaults() when None.
        """
        # identity attributes
        self._task_instance_id = task_instance_id

        # service API
        if requester is None:
            requester = Requester.from_defaults()
        self.requester = requester

        # cluster API
        self.cluster_interface = cluster_interface

        # get distributor id from executor
        self._distributor_id = self.cluster_interface.distributor_id

        # config: explicit arguments win, otherwise fall back to the JobmonConfig values
        config = JobmonConfig()
        if task_instance_heartbeat_interval is None:
            self._task_instance_heartbeat_interval = config.get_int(
                "heartbeat", "task_instance_interval"
            )
        else:
            self._task_instance_heartbeat_interval = task_instance_heartbeat_interval
        if heartbeat_report_by_buffer is None:
            self._heartbeat_report_by_buffer = config.get_float(
                "heartbeat", "report_by_buffer"
            )
        else:
            self._heartbeat_report_by_buffer = heartbeat_report_by_buffer
        if command_interrupt_timeout is None:
            self._command_interrupt_timeout = config.get_int(
                "worker_node", "command_interrupt_timeout"
            )
        else:
            self._command_interrupt_timeout = command_interrupt_timeout

        # attrs set by log running
        self._status: Optional[str] = None
        self._command: Optional[str] = None
        self._command_add_env: Optional[Dict[str, str]] = None
        self._stdout: Optional[str] = None
        self._stderr: Optional[str] = None

        # set last heartbeat
        self.last_heartbeat_time = time()

    @property
    def task_instance_id(self) -> int:
        """Returns a task instance ID if it's been bound.

        Raises:
            AttributeError: if the stored id is None.
        """
        if self._task_instance_id is None:
            raise AttributeError("Cannot access task_instance_id because it is None.")
        return self._task_instance_id

    @property
    def distributor_id(self) -> Optional[str]:
        """Executor id given from the executor it is being run on."""
        return self._distributor_id

    @property
    def nodename(self) -> Optional[str]:
        """Node it is being run on (socket.getfqdn(), cached on first access)."""
        if not hasattr(self, "_nodename"):
            self._nodename = socket.getfqdn()
        return self._nodename

    @property
    def process_group_id(self) -> Optional[int]:
        """Process group to track parent and child processes.

        The value is os.getpid() of this process, cached on first access.
        """
        if not hasattr(self, "_process_group_id"):
            self._process_group_id = os.getpid()
        return self._process_group_id

    @property
    def status(self) -> str:
        """Returns the last known status of the task instance.

        Raises:
            AttributeError: if no status has been recorded yet.
        """
        if self._status is None:
            raise AttributeError(
                "Cannot access status until log_running() has been called."
            )
        return self._status

    @property
    def stdout(self) -> str:
        """Returns the stdout log file path produced by log_running().

        Raises:
            AttributeError: if log_running() has not set it yet.
        """
        if self._stdout is None:
            raise AttributeError(
                "Cannot access stdout until log_running() has been called."
            )
        return self._stdout

    @property
    def stderr(self) -> str:
        """Returns the stderr log file path produced by log_running().

        Raises:
            AttributeError: if log_running() has not set it yet.
        """
        if self._stderr is None:
            raise AttributeError(
                "Cannot access stderr until log_running() has been called."
            )
        return self._stderr

    @property
    def command(self) -> str:
        """Returns the command this task instance will run.

        Raises:
            AttributeError: if log_running() has not set it yet.
        """
        if self._command is None:
            raise AttributeError(
                "Cannot access command until log_running() has been called."
            )
        return self._command

    @property
    def command_add_env(self) -> Dict[str, str]:
        """Returns the extra JOBMON_* environment variables added for the command.

        Raises:
            AttributeError: if log_running() has not set it yet.
        """
        if self._command_add_env is None:
            raise AttributeError(
                "Cannot access command_add_env until log_running() has been called."
            )
        return self._command_add_env

    @property
    def command_returncode(self) -> int:
        """Returns the exit code of the command that was run.

        Raises:
            AttributeError: if the command output has not been recorded yet.
        """
        if not hasattr(self, "_proc_returncode"):
            raise AttributeError(
                "Cannot access command_returncode until run() has been called"
            )
        return self._proc_returncode

    @property
    def command_stdout(self) -> str:
        """Returns the last 10k characters of the commands stdout.

        Raises:
            AttributeError: if the command output has not been recorded yet.
        """
        if not hasattr(self, "_proc_stdout"):
            raise AttributeError(
                "Cannot access command_stdout until run() has been called"
            )
        return self._proc_stdout

    @property
    def command_stderr(self) -> str:
        """Returns the last 10k characters of the commands stderr.

        Raises:
            AttributeError: if the command output has not been recorded yet.
        """
        if not hasattr(self, "_proc_stderr"):
            raise AttributeError(
                "Cannot access command_stderr until run() has been called"
            )
        return self._proc_stderr

    @bind_method_context(task_instance_id="_task_instance_id")
    def log_done(self) -> None:
        """Tell the JobStateManager that this task_instance is done.

        Raises:
            TransitionError: if the status returned by the server is not DONE.
        """
        logger.info(
            f"Task instance {self.task_instance_id} marked as DONE",
            task_instance_id=self.task_instance_id,
            return_code=self.command_returncode,
        )

        message = {
            "stdout": self.stdout,
            "stderr": self.stderr,
            "stdout_log": self.command_stdout,
            "stderr_log": self.command_stderr,
            "nodename": self.nodename,
            "distributor_id": self.distributor_id,
        }

        app_route = f"/task_instance/{self.task_instance_id}/log_done"
        _, response = self.requester.send_request(
            app_route=app_route,
            message=message,
            request_type="post",
        )

        # the server's answer is the source of truth for our status
        self._status = response["status"]
        if self.status != TaskInstanceStatus.DONE:
            logger.error(
                "Task instance failed to transition to DONE",
                expected_status=TaskInstanceStatus.DONE,
                actual_status=self.status,
            )
            raise TransitionError(
                f"TaskInstance {self.task_instance_id} failed because it could not transition "
                f"to {TaskInstanceStatus.DONE} status. Current status is {self.status}."
            )

        logger.info(
            "Task instance successfully transitioned to DONE",
            status=self.status,
        )

    @bind_method_context(task_instance_id="_task_instance_id")
    def log_error(self, error_state: str, description: str) -> None:
        """Tell the JobStateManager that this task_instance has errored.

        Args:
            error_state: the error status the task instance should transition to.
            description: error description sent to the server in full (only the first
                200 characters are written to the local log).

        Raises:
            TransitionError: if the status returned by the server is not error_state.
        """
        logger.info(
            "Worker node logging task instance error",
            error_state=error_state,
            error_description=description[:200],  # Truncate for readability
            nodename=self.nodename,
            distributor_id=self.distributor_id,
        )

        message = {
            "error_state": error_state,
            "error_description": description,
            "stdout_log": self.command_stdout,
            "stderr_log": self.command_stderr,
            "stdout": self.stdout,
            "stderr": self.stderr,
            "nodename": self.nodename,
            "distributor_id": self.distributor_id,
        }

        app_route = f"/task_instance/{self.task_instance_id}/log_error_worker_node"
        _, response = self.requester.send_request(
            app_route=app_route,
            message=message,
            request_type="post",
        )

        self._status = response["status"]
        if self.status != error_state:
            logger.error(
                "Task instance failed to transition to error state",
                expected_status=error_state,
                actual_status=self.status,
            )
            raise TransitionError(
                f"TaskInstance {self.task_instance_id} failed because it could not transition "
                f"to {error_state} status. Current status is {self.status}."
            )

        logger.info(
            "Task instance successfully transitioned to error state",
            error_state=self.status,
        )

    @bind_method_context(
        task_instance_id="_task_instance_id",
        nodename="nodename",
        distributor_id="distributor_id",
    )
    def log_running(self) -> None:
        """Tell the JobStateManager that this task_instance is running.

        Update the report_by_date to be further in the future in case it gets
        reconciled immediately. The response populates status, command, the stdout and
        stderr log file paths and the extra JOBMON_* environment variables.

        Raises:
            TransitionError: if the status returned by the server is not RUNNING.
        """
        logger.info(
            f"Task instance {self.task_instance_id} started running",
            task_instance_id=self.task_instance_id,
        )

        message = {
            "nodename": self.nodename,
            "process_group_id": str(self.process_group_id),
            "next_report_increment": (
                self._task_instance_heartbeat_interval * self._heartbeat_report_by_buffer
            ),
        }
        if self.distributor_id is not None:
            message["distributor_id"] = str(self.distributor_id)
        else:
            logger.warning("No distributor_id in worker environment")

        app_route = f"/task_instance/{self.task_instance_id}/log_running"
        _, response = self.requester.send_request(
            app_route=app_route,
            message=message,
            request_type="post",
        )

        kwargs = SerializeTaskInstance.kwargs_from_wire_worker_node(
            response["task_instance"]
        )
        self._status = kwargs.pop("status")
        self._command = kwargs.pop("command")
        task_name = kwargs.pop("name")
        self._stdout = self.cluster_interface.initialize_logfile(
            "stdout", kwargs.pop("stdout_dir"), task_name
        )
        self._stderr = self.cluster_interface.initialize_logfile(
            "stderr", kwargs.pop("stderr_dir"), task_name
        )
        # every remaining field is exposed to the command as a JOBMON_<FIELD> env var
        self._command_add_env = {
            f"JOBMON_{k.upper()}": str(v) for k, v in kwargs.items()
        }
        self.last_heartbeat_time = time()

        if self.status != TaskInstanceStatus.RUNNING:
            logger.error(
                "Task instance failed to transition to RUNNING",
                expected_status=TaskInstanceStatus.RUNNING,
                actual_status=self.status,
            )
            raise TransitionError(
                f"TaskInstance {self.task_instance_id} failed because it could not transition "
                f"to {TaskInstanceStatus.RUNNING} status. Current status is {self.status}."
            )

        logger.info(
            "Task instance successfully transitioned to RUNNING",
            status=self.status,
        )

    @bind_method_context(task_instance_id="_task_instance_id")
    def log_report_by(self) -> None:
        """Log the heartbeat to show that the task instance is still alive.

        Raises:
            TransitionError: if the status returned by the server is not RUNNING.
        """
        logger.debug("Worker node logging heartbeat")

        message: Dict = {
            "next_report_increment": (
                self._task_instance_heartbeat_interval * self._heartbeat_report_by_buffer
            ),
            "stdout": self.stdout,
            "stderr": self.stderr,
        }
        if self.distributor_id is not None:
            message["distributor_id"] = str(self.distributor_id)
        else:
            logger.debug("No distributor_id was found in the sbatch env at this time")

        app_route = f"/task_instance/{self.task_instance_id}/log_report_by"
        _, response = self.requester.send_request(
            app_route=app_route,
            message=message,
            request_type="post",
        )

        self._status = response["status"]
        self.last_heartbeat_time = time()

        if self.status != TaskInstanceStatus.RUNNING:
            raise TransitionError(
                f"TaskInstance {self.task_instance_id} failed because it could not transition "
                f"to {TaskInstanceStatus.RUNNING} status. Current status is {self.status}."
            )

    def run(self) -> None:
        """This script executes on the target node and wraps the target application.

        Could be in any language, anything that can execute on linux. Similar to a stub
        or a container set ENV variables in case tasks need to access them.

        Raises:
            TransitionError: if log_running() fails, or if a heartbeat found the task
                instance in an unexpected status other than KILL_SELF.
            RuntimeError: if running the command failed for any other reason.
        """
        # If it logs running and is not able to transition it raises TransitionError
        self.log_running()

        try:
            # run the command in a subprocess
            asyncio.run(self._run_cmd())

        # some other deployment unit transitioned task instance out of R state
        except RuntimeError as e:
            # _run_cmd wraps every failure in RuntimeError; the cause tells them apart
            if not isinstance(e.__cause__, TransitionError):
                raise

            logger.error("Task in unexpected state", current_status=self.status)

            # log an error with db if we are in K state
            if self.status == TaskInstanceStatus.KILL_SELF:
                msg = (
                    f"Command: '{self.command}' got KILL_SELF event. Process shut down"
                    f" with exit code: '{self.command_returncode}'"
                )
                logger.error("Task killed", exit_code=self.command_returncode)
                self.log_error(TaskInstanceStatus.ERROR_FATAL, msg)

            # otherwise raise the error cause we are in trouble
            else:
                raise e.__cause__

        # normal happy path
        else:
            if self.command_returncode == ReturnCodes.OK:
                logger.info(
                    f"Task instance {self.task_instance_id} completed successfully",
                    task_instance_id=self.task_instance_id,
                    return_code=self.command_returncode,
                )
                self.log_done()
            else:
                logger.info(
                    "Command exited with non-zero code",
                    return_code=self.command_returncode,
                )
                error_state, msg = self.cluster_interface.get_exit_info(
                    self.command_returncode, self.command_stderr
                )
                self.log_error(error_state, msg)

    def set_command_output(self, returncode: int, stdout: str, stderr: str) -> None:
        """Record the result of the command so the command_* properties can expose it.

        Args:
            returncode: exit code of the command.
            stdout: captured tail of the command's stdout.
            stderr: captured tail of the command's stderr.
        """
        self._proc_returncode = returncode
        self._proc_stdout = stdout
        self._proc_stderr = stderr

    @staticmethod
    async def _communicate(
        async_stream: asyncio.StreamReader, output_stream: TextIO, chunk_size: int = 64
    ) -> str:
        """Copy a subprocess stream into a text file, keeping the tail in memory.

        Args:
            async_stream: byte stream to read from until EOF.
            output_stream: open text stream that receives the decoded output.
            chunk_size: maximum number of bytes requested per read.

        Returns:
            The most recently decoded output, capped at 10k characters while reading
            (plus an error note if reading or writing failed). Whatever was collected
            is also returned if the task is cancelled.
        """
        # An incremental decoder carries a multi-byte UTF-8 character that straddles a
        # chunk boundary over to the next read; invalid bytes become U+FFFD instead of
        # discarding output.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        mem_buffer = ""
        try:
            at_eof = False
            while not at_eof:
                # Read a chunk of data. If no data is returned, we've reached EOF.
                data_chunk = await async_stream.read(chunk_size)
                at_eof = not data_chunk

                # final=True at EOF flushes any incomplete trailing byte sequence.
                data_chunk_str = decoder.decode(data_chunk, final=at_eof)
                if data_chunk_str:
                    output_stream.write(data_chunk_str)
                    output_stream.flush()
                    # Keep only the last 10k characters in memory.
                    mem_buffer = (mem_buffer + data_chunk_str)[-10000:]
        except asyncio.CancelledError:
            # Deliberately absorbed: the caller cancels this reader during shutdown
            # and still wants the partial output as the task result.
            pass
        except Exception as e:
            # Log unexpected errors. This could be any exception raised by the reading
            # or writing operations; note it in the buffer so it reaches the caller.
            logger.exception("Stream reading error", error=str(e))
            mem_buffer += "\n[Error reading stream: {}]".format(e)
        return mem_buffer

    async def _process_poller(self, process: asyncio.subprocess.Process) -> int:
        """Wait for the subprocess to exit, sending a heartbeat each interval.

        Args:
            process: the running subprocess to wait on.

        Returns:
            The return code of the finished process.

        Raises:
            TransitionError: propagated from log_report_by() when a heartbeat finds the
                task instance is no longer RUNNING.
            AttributeError: if the process finished waiting without a return code.
        """
        keep_polling = True
        while keep_polling:
            time_till_next_heartbeat = self._task_instance_heartbeat_interval - (
                time() - self.last_heartbeat_time
            )
            try:
                await asyncio.wait_for(process.wait(), timeout=time_till_next_heartbeat)
                keep_polling = False
            except asyncio.TimeoutError:
                # still running when the heartbeat came due: report and keep waiting
                self.log_report_by()

        # keep typecheck happy
        returncode = process.returncode
        if returncode is None:
            raise AttributeError(
                "process finished polling but does not have a return code."
            )
        return returncode

    async def _run_cmd(self) -> None:
        """Run the command in a shell subprocess and record its output.

        stdout and stderr are streamed to the log files while the process is polled
        and heartbeats are sent. The return code and captured output are always stored
        through set_command_output(), even on failure.

        Raises:
            RuntimeError: if anything fails while the command runs; the original
                exception is attached as __cause__.
        """
        # Copy the current environment variables and update them with additional settings
        env = os.environ.copy()
        env.update(self.command_add_env)

        # capture stdout and stderr for asynchronous reading
        process = await asyncio.create_subprocess_shell(
            self.command,
            env=env,
            stdout=asyncio.subprocess.PIPE,  # Captures stdout
            stderr=asyncio.subprocess.PIPE,  # Captures stderr
        )

        # Assert that stdout and stderr are not None for the type checker.
        assert process.stdout is not None
        assert process.stderr is not None

        # Initialize task variables to None. These will hold the asyncio tasks for
        # communicating with and monitoring the subprocess.
        stdout_task = stderr_task = heartbeat_task = None

        try:
            # Context manager to ensure file streams are properly closed after writing.
            # errors="replace" keeps a character the file encoding cannot represent
            # from raising inside the reader, which would stop draining the pipe.
            with open(self.stdout, "w", errors="replace") as stdout_stream, open(
                self.stderr, "w", errors="replace"
            ) as stderr_stream:
                # Create asyncio tasks for reading subprocess stdout and stderr.
                stdout_task = asyncio.create_task(
                    self._communicate(process.stdout, stdout_stream)
                )
                stderr_task = asyncio.create_task(
                    self._communicate(process.stderr, stderr_stream)
                )
                # Task for monitoring the subprocess (e.g., for timeouts or heartbeats).
                heartbeat_task = asyncio.create_task(self._process_poller(process))

                # Await the completion of communication and monitoring tasks.
                # We gather all tasks together, ensuring they are completed before
                # proceeding.
                valid_tasks = [stdout_task, stderr_task, heartbeat_task]
                await asyncio.gather(*valid_tasks)

        except Exception as e:
            # If an exception occurs, cancel all ongoing tasks to prevent dangling
            # operations.
            tasks_to_cancel = [
                t for t in (stdout_task, stderr_task, heartbeat_task) if t is not None
            ]
            for task in tasks_to_cancel:
                task.cancel()
            # Await cancellation, ignoring any exceptions raised from cancellation.
            await asyncio.gather(*tasks_to_cancel, return_exceptions=True)

            # Attempt a graceful shutdown of the subprocess if it's still running.
            # The process may exit on its own at any moment, so a failed signal
            # delivery must not mask the original error.
            if process.returncode is None:
                try:
                    process.send_signal(signal.SIGINT)
                except ProcessLookupError:
                    pass
            try:
                # Wait for the process to terminate, with a specified timeout.
                await asyncio.wait_for(
                    process.wait(), timeout=self._command_interrupt_timeout
                )
            except asyncio.TimeoutError:
                # Forcefully terminate the subprocess
                try:
                    process.kill()
                except ProcessLookupError:
                    pass
                await process.wait()

            logger.exception(
                "Task instance command execution failed",
                error=str(e),
                command=self.command,
            )
            raise RuntimeError(
                f"Failed to execute command '{self.command}': {e}"
            ) from e

        finally:
            # After tasks have been handled and the subprocess is ensured to be
            # terminated, retrieve the results from the tasks if they completed.
            # A cancelled task has no result; result() would raise CancelledError.
            stdout_result = stderr_result = ""
            if stdout_task and stdout_task.done() and not stdout_task.cancelled():
                stdout_result = stdout_task.result()
            if stderr_task and stderr_task.done() and not stderr_task.cancelled():
                stderr_result = stderr_task.result()

            # Ensure the subprocess is fully terminated before exiting this method.
            returncode = await process.wait()

            # Update the task instance with the final results of command execution.
            self.set_command_output(
                returncode=returncode,
                stdout=stdout_result,
                stderr=stderr_result,
            )