Source code for server.web.routes.v2.fsm.task_instance

"""Routes for TaskInstances."""

from collections import defaultdict
from http import HTTPStatus as StatusCodes
from typing import Any, DefaultDict, Dict, Optional, cast

import sqlalchemy
import structlog
from fastapi import Request
from sqlalchemy import and_, select, update
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from starlette.responses import JSONResponse

from jobmon.core import constants
from jobmon.core.exceptions import InvalidStateTransition
from jobmon.core.serializers import SerializeTaskInstanceBatch
from jobmon.server.web._compat import add_time
from jobmon.server.web.db import get_sessionmaker
from jobmon.server.web.models.array import Array
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.task_instance import TaskInstance
from jobmon.server.web.models.task_instance_error_log import TaskInstanceErrorLog
from jobmon.server.web.routes.v2.fsm import fsm_router as api_v2_router
from jobmon.server.web.server_side_exception import ServerError

# Module-level structured logger; the routes below bind task_instance_id /
# array_id into its context with structlog.contextvars.bind_contextvars.
logger = structlog.get_logger(__name__)
# Session factory shared by every route in this module; each route opens its
# own session and transaction from it (``with SessionMaker() as session``).
SessionMaker = get_sessionmaker()
@api_v2_router.post("/task_instance/{task_instance_id}/log_running")
async def log_running(task_instance_id: int, request: Request) -> Any:
    """Mark a task instance as RUNNING on behalf of its worker node.

    The optional ``distributor_id`` and ``nodename`` are stored only when the
    worker supplied a non-None value, and ``process_group_id`` is always
    stored. The instance is then moved to RUNNING with a fresh report-by date.
    If the FSM rejects that transition, the current status decides what
    happens:

    - RUNNING: logged as a warning.
    - KILL_SELF: escalated to ERROR_FATAL.
    - NO_HEARTBEAT: moved to ERROR.
    - Any other status: logged as an illegal transition.

    Args:
        task_instance_id: id of the task instance that started running.
        request: fastapi request whose JSON body carries
            ``process_group_id`` and ``next_report_increment`` and optionally
            ``distributor_id`` and ``nodename``.

    Returns:
        JSONResponse carrying the worker-node wire format of the task instance
        under ``task_instance``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    payload = cast(Dict, await request.json())

    with SessionMaker() as session, session.begin():
        ti = (
            session.execute(
                select(TaskInstance).where(TaskInstance.id == task_instance_id)
            )
            .scalars()
            .one()
        )

        # Only overwrite these columns when the worker actually sent a value.
        for attr in ("distributor_id", "nodename"):
            if payload.get(attr, None) is not None:
                setattr(ti, attr, payload[attr])
        ti.process_group_id = payload["process_group_id"]

        try:
            ti.transition(constants.TaskInstanceStatus.RUNNING)
            ti.report_by_date = add_time(payload["next_report_increment"])
        except InvalidStateTransition as err:
            current = ti.status
            if current == constants.TaskInstanceStatus.RUNNING:
                logger.warning(err)
            elif current == constants.TaskInstanceStatus.KILL_SELF:
                ti.transition(constants.TaskInstanceStatus.ERROR_FATAL)
            elif current == constants.TaskInstanceStatus.NO_HEARTBEAT:
                ti.transition(constants.TaskInstanceStatus.ERROR)
            else:
                # The FSM refused a transition we have no recovery path for.
                logger.error(err)

        wire_format = ti.to_wire_as_worker_node_task_instance()

    return JSONResponse(
        content={"task_instance": wire_format}, status_code=StatusCodes.OK
    )
@api_v2_router.post("/task_instance/{task_instance_id}/log_report_by")
async def log_ti_report_by(task_instance_id: int, request: Request) -> Any:
    """Log a task_instance as being responsive with a new report_by_date.

    This is done at the worker node heartbeat_interval rate, so it may not
    happen at the same rate that the reconciler updates batch submitted
    report_by_dates (also because it causes a lot of traffic if all workers
    are logging report by_dates often compared to if the reconciler runs
    often). A task instance found in TRIAGING is moved back to RUNNING.

    Args:
        task_instance_id: id of the task_instance to log
        request: fastapi request object whose JSON body carries
            ``next_report_increment`` and optionally ``distributor_id``,
            ``stderr`` and ``stdout``.

    Returns:
        JSONResponse with the task instance's resulting ``status``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    data = cast(Dict, await request.json())
    with SessionMaker() as session:
        with session.begin():
            vals = {"report_by_date": add_time(data["next_report_increment"])}
            for optional_val in ["distributor_id", "stderr", "stdout"]:
                val = data.get(optional_val, None)
                # Skip keys the worker did not send, so previously stored
                # values (e.g. distributor_id) are not overwritten with NULL.
                if val is not None:
                    vals[optional_val] = val
            update_stmt = update(TaskInstance).where(
                TaskInstance.id == task_instance_id
            )
            session.execute(update_stmt.values(**vals))
            session.flush()

            select_stmt = select(TaskInstance).where(
                TaskInstance.id == task_instance_id
            )
            task_instance = session.execute(select_stmt).scalars().one()

            # A heartbeat from a triaged instance shows it is still alive.
            if task_instance.status == constants.TaskInstanceStatus.TRIAGING:
                task_instance.transition(constants.TaskInstanceStatus.RUNNING)
            # Read the status while the transaction is open.
            status = task_instance.status

    resp = JSONResponse(content={"status": status}, status_code=StatusCodes.OK)
    return resp
@api_v2_router.post("/task_instance/log_report_by/batch")
async def log_ti_report_by_batch(request: Request) -> Any:
    """Push forward the report_by_date of a batch of LAUNCHED task instances.

    This is done at the worker node heartbeat_interval rate, so it may not
    happen at the same rate that the reconciler updates batch submitted
    report_by_dates (also because it causes a lot of traffic if all workers
    are logging report by_dates often compared to if the reconciler runs
    often). Only instances currently in LAUNCHED are touched. An empty or
    missing id list is a no-op.

    Args:
        request: fastapi request whose JSON body carries
            ``next_report_increment`` and optionally ``task_instance_ids``.

    Returns:
        An empty JSON object with HTTP 200.
    """
    payload = cast(Dict, await request.json())
    ti_ids = payload.get("task_instance_ids", None)
    increment = float(payload["next_report_increment"])
    logger.debug(f"Log report_by for TI {ti_ids}.")

    if not ti_ids:
        return JSONResponse(content={}, status_code=StatusCodes.OK)

    with SessionMaker() as session, session.begin():
        session.execute(
            update(TaskInstance)
            .where(
                TaskInstance.id.in_(ti_ids),
                TaskInstance.status == constants.TaskInstanceStatus.LAUNCHED,
            )
            .values(report_by_date=add_time(increment))
        )

    return JSONResponse(content={}, status_code=StatusCodes.OK)
@api_v2_router.post("/task_instance/{task_instance_id}/log_done")
async def log_done(task_instance_id: int, request: Request) -> Any:
    """Mark a task instance as DONE.

    Worker-reported fields (``distributor_id``, ``stdout_log``,
    ``stderr_log``, ``nodename``, ``stdout``, ``stderr``) are copied onto the
    instance only when present with a non-None value. A rejected transition
    is logged as a warning if the instance is already DONE, and as an error
    otherwise. The request never fails because of the FSM.

    Args:
        task_instance_id: id of the task instance that finished.
        request: fastapi request whose JSON body may carry the optional
            fields listed above.

    Returns:
        JSONResponse with the task instance's resulting ``status``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    payload = cast(Dict, await request.json())

    reportable = (
        "distributor_id",
        "stdout_log",
        "stderr_log",
        "nodename",
        "stdout",
        "stderr",
    )
    reported = {
        field: payload[field]
        for field in reportable
        if payload.get(field, None) is not None
    }

    with SessionMaker() as session, session.begin():
        ti = (
            session.execute(
                select(TaskInstance).where(TaskInstance.id == task_instance_id)
            )
            .scalars()
            .one()
        )
        for attr, value in reported.items():
            setattr(ti, attr, value)

        try:
            ti.transition(constants.TaskInstanceStatus.DONE)
        except InvalidStateTransition as err:
            if ti.status != constants.TaskInstanceStatus.DONE:
                # A genuinely illegal transition, not a duplicate report.
                logger.error(err)
            else:
                logger.warning(err)

        final_status = ti.status

    return JSONResponse(content={"status": final_status}, status_code=StatusCodes.OK)
@api_v2_router.post("/task_instance/{task_instance_id}/log_error_worker_node")
async def log_error_worker_node(task_instance_id: int, request: Request) -> Any:
    """Log a task_instance as errored, as reported by its worker node.

    Optional worker-reported fields are copied onto the task instance only
    when the request supplied a non-None value. The owning Task row is
    selected ``FOR UPDATE`` before the transition, presumably to serialize
    concurrent status changes on the same task. An error log entry is added
    only if the transition succeeds.

    Args:
        task_instance_id (int): id of the task_instance to log as errored
        request (Request): fastapi request object whose JSON body carries
            ``error_state`` and ``error_description`` and optionally
            ``distributor_id``, ``stdout_log``, ``stderr_log``, ``nodename``,
            ``stdout`` and ``stderr``.

    Returns:
        JSONResponse with the task instance's resulting ``status``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    data = cast(Dict, await request.json())
    logger.info(f"Log ERROR for TI:{task_instance_id}.")
    with SessionMaker() as session:
        with session.begin():
            select_stmt = select(TaskInstance).where(
                TaskInstance.id == task_instance_id
            )
            task_instance = session.execute(select_stmt).scalars().one()

            optional_vals = [
                "distributor_id",
                "stdout_log",
                "stderr_log",
                "nodename",
                "stdout",
                "stderr",
            ]
            for optional_val in optional_vals:
                val = data.get(optional_val, None)
                # Skip keys the worker did not send, so values recorded
                # earlier (e.g. distributor_id) are not overwritten with None.
                if val is not None:
                    setattr(task_instance, optional_val, val)

            # add error log
            error_state = data["error_state"]
            error_description = data["error_description"]
            try:
                # Lock the owning Task row before transitioning.
                session.execute(
                    select(Task)
                    .where(Task.id == task_instance.task_id)
                    .with_for_update()
                ).scalar_one()
                task_instance.transition(error_state)
                # Add error log only if transition was successful
                error = TaskInstanceErrorLog(
                    task_instance_id=task_instance.id, description=error_description
                )
                session.add(error)
            except InvalidStateTransition as e:
                if task_instance.status == error_state:
                    logger.warning(e)
                else:
                    # Tried to move to an illegal state
                    logger.error(e)
            # Note: any other exception propagates out of session.begin(),
            # which rolls the transaction back.
            status = task_instance.status

    resp = JSONResponse(content={"status": status}, status_code=StatusCodes.OK)
    return resp
@api_v2_router.get("/task_instance/{task_instance_id}/task_instance_error_log")
async def get_task_instance_error_log(task_instance_id: int) -> Any:
    """Route to return all task_instance_error_log entries of the task_instance_id.

    Entries are ordered by the error log's ``id`` so the order is
    deterministic (presumably insertion order — confirm ids are
    monotonically assigned).

    Args:
        task_instance_id (int): ID of the task instance

    Return:
        jsonified task_instance_error_log result set, one ``to_wire()`` entry
        per log row, under ``task_instance_error_log``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    logger.info(f"Getting task instance error log for ti {task_instance_id}")
    with SessionMaker() as session:
        with session.begin():
            select_stmt = (
                select(TaskInstanceErrorLog)
                .where(TaskInstanceErrorLog.task_instance_id == task_instance_id)
                # Every row shares the filtered task_instance_id, so ordering
                # by it is a no-op; order by the log's own id instead.
                .order_by(TaskInstanceErrorLog.id)
            )
            res = session.execute(select_stmt).scalars().all()
            r = [tiel.to_wire() for tiel in res]
    resp = JSONResponse(
        content={"task_instance_error_log": r}, status_code=StatusCodes.OK
    )
    return resp
@api_v2_router.get("/get_array_task_instance_id/{array_id}/{batch_num}/{step_id}")
def get_array_task_instance_id(array_id: int, batch_num: int, step_id: int) -> Any:
    """Resolve one task instance id from its position inside an array batch.

    The task instance is identified by its array id, its array batch number
    and its step id within that batch. Exactly one row must match, otherwise
    ``.one()`` raises. This route is called once per array task instance
    worker node, so it must stay a single, cheap lookup.

    Args:
        array_id: id of the array the task instance belongs to.
        batch_num: array batch number of the task instance.
        step_id: step id of the task instance within the batch.

    Returns:
        JSONResponse with the matching ``task_instance_id``.
    """
    structlog.contextvars.bind_contextvars(array_id=array_id)
    lookup = select(TaskInstance.id).where(
        TaskInstance.array_id == array_id,
        TaskInstance.array_batch_num == batch_num,
        TaskInstance.array_step_id == step_id,
    )
    with SessionMaker() as session, session.begin():
        ti_id = session.execute(lookup).scalars().one()

    return JSONResponse(content={"task_instance_id": ti_id}, status_code=StatusCodes.OK)
@api_v2_router.post("/task_instance/{task_instance_id}/log_no_distributor_id")
async def log_no_distributor_id(task_instance_id: int, request: Request) -> Any:
    """Record that a task instance got no distributor id upon submission.

    The instance is moved to NO_DISTRIBUTOR_ID through
    ``_update_task_instance_state``, and the submitted message is stored as a
    TaskInstanceErrorLog entry. The entry is added whether or not the FSM
    accepted the transition.

    Args:
        task_instance_id: id of the task instance that got no distributor id.
        request: fastapi request whose JSON body carries ``no_id_err_msg``.

    Returns:
        JSONResponse whose ``message`` is the text returned by the state
        update (empty when the transition went through cleanly).
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    logger.info(
        f"Logging ti {task_instance_id} did not get distributor id upon submission"
    )
    payload = cast(Dict, await request.json())
    err_msg = payload["no_id_err_msg"]
    logger.debug(f"Log NO DISTRIBUTOR ID. Data {err_msg}")

    with SessionMaker() as session, session.begin():
        ti = (
            session.execute(
                select(TaskInstance).where(TaskInstance.id == task_instance_id)
            )
            .scalars()
            .one()
        )
        msg = _update_task_instance_state(
            ti, constants.TaskInstanceStatus.NO_DISTRIBUTOR_ID, request
        )
        session.add(TaskInstanceErrorLog(task_instance_id=ti.id, description=err_msg))

    return JSONResponse(content={"message": msg}, status_code=StatusCodes.OK)
@api_v2_router.post("/task_instance/{task_instance_id}/log_distributor_id")
async def log_distributor_id(task_instance_id: int, request: Request) -> Any:
    """Record the distributor id assigned to a task instance and mark it LAUNCHED.

    The LAUNCHED transition is attempted first via
    ``_update_task_instance_state``. The distributor id and a fresh
    report-by date are then stored regardless of whether the FSM accepted
    the transition.

    Args:
        task_instance_id: id of the task instance that was submitted.
        request: fastapi request whose JSON body carries ``distributor_id``
            and ``next_report_increment``.

    Returns:
        JSONResponse whose ``message`` is the text returned by the state
        update (empty when the transition went through cleanly).
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    payload = cast(Dict, await request.json())

    with SessionMaker() as session, session.begin():
        ti = (
            session.execute(
                select(TaskInstance).where(TaskInstance.id == task_instance_id)
            )
            .scalars()
            .one()
        )
        msg = _update_task_instance_state(
            ti, constants.TaskInstanceStatus.LAUNCHED, request
        )
        ti.distributor_id = payload["distributor_id"]
        ti.report_by_date = add_time(payload["next_report_increment"])

    return JSONResponse(content={"message": msg}, status_code=StatusCodes.OK)
@api_v2_router.post("/task_instance/{task_instance_id}/log_known_error")
async def log_known_error(task_instance_id: int, request: Request) -> Any:
    """Log a task_instance as errored.

    Delegates to ``_log_error``. If that raises an OperationalError, the
    call is retried once with the message reduced to latin-1 characters,
    where anything outside latin-1 becomes "?".

    Args:
        task_instance_id (int): id for task instance.
        request (Request): fastapi request object whose JSON body carries
            ``error_state`` and ``error_message`` and optionally
            ``distributor_id`` and ``nodename``.

    Returns:
        The JSONResponse produced by ``_log_error``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    data = cast(Dict, await request.json())
    error_state = data["error_state"]
    error_message = data["error_message"]
    distributor_id = data.get("distributor_id", None)
    nodename = data.get("nodename", None)
    logger.info(f"Log ERROR for TI:{task_instance_id}.")

    with SessionMaker() as session:
        with session.begin():
            select_stmt = select(TaskInstance).where(
                TaskInstance.id == task_instance_id
            )
            task_instance = session.execute(select_stmt).scalars().one()
            try:
                resp = _log_error(
                    session,
                    task_instance,
                    error_state,
                    error_message,
                    distributor_id,
                    nodename,
                    request,
                )
            except sqlalchemy.exc.OperationalError:
                # Replace characters outside latin-1 with "?" and retry. The
                # round trip must decode with the codec it encoded with:
                # decoding these bytes as UTF-8 raises UnicodeDecodeError for
                # any character in U+0080..U+00FF.
                # NOTE(review): _log_error catches Exception and re-raises it
                # as ServerError, so this branch may be unreachable — confirm.
                new_msg = error_message.encode("latin1", "replace").decode("latin1")
                resp = _log_error(
                    session,
                    task_instance,
                    error_state,
                    new_msg,
                    distributor_id,
                    nodename,
                    request,
                )
    return resp
@api_v2_router.post("/task_instance/{task_instance_id}/log_unknown_error")
async def log_unknown_error(task_instance_id: int, request: Request) -> Any:
    """Log a task_instance as errored unless it has reported a newer heartbeat.

    The error is only logged when the instance's report_by_date has already
    passed. If no such row exists, either because the instance checked in
    after reconciliation started or because it does not exist, nothing is
    changed and an empty message is returned. If ``_log_error`` raises an
    OperationalError, the call is retried once with the message reduced to
    latin-1 characters.

    Args:
        task_instance_id (int): id for task instance
        request (Request): fastapi request object whose JSON body carries
            ``error_state`` and ``error_message`` and optionally
            ``distributor_id`` and ``nodename``.

    Returns:
        The JSONResponse produced by ``_log_error``, or
        ``{"message": ""}`` with HTTP 200 when no error was logged.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    data = cast(Dict, await request.json())
    error_state = data["error_state"]
    error_message = data["error_message"]
    distributor_id = data.get("distributor_id", None)
    nodename = data.get("nodename", None)
    logger.info(f"Log ERROR for TI:{task_instance_id}.")

    with SessionMaker() as session:
        with session.begin():
            # make sure the task hasn't logged a new heartbeat since we began
            # reconciliation
            select_stmt = select(TaskInstance).where(
                TaskInstance.id == task_instance_id,
                TaskInstance.report_by_date <= func.now(),
            )
            task_instance = session.execute(select_stmt).scalars().one_or_none()
            session.flush()

            if task_instance is None:
                # Nothing to log. Without this branch ``resp`` would be
                # unbound at the return below.
                resp = JSONResponse(
                    content={"message": ""}, status_code=StatusCodes.OK
                )
            else:
                try:
                    resp = _log_error(
                        session,
                        task_instance,
                        error_state,
                        error_message,
                        distributor_id,
                        nodename,
                        request,
                    )
                except sqlalchemy.exc.OperationalError:
                    # Replace characters outside latin-1 with "?" and retry,
                    # decoding with the same codec used to encode.
                    # NOTE(review): _log_error wraps exceptions in ServerError,
                    # so this branch may be unreachable — confirm.
                    new_msg = error_message.encode("latin1", "replace").decode(
                        "latin1"
                    )
                    resp = _log_error(
                        session,
                        task_instance,
                        error_state,
                        new_msg,
                        distributor_id,
                        nodename,
                        request,
                    )
    return resp
@api_v2_router.post("/task_instance/instantiate_task_instances")
async def instantiate_task_instances(request: Request) -> Any:
    """Instantiate the given task instances and return them grouped into batches.

    The work happens in one transaction:

    1. Every Task that owns one of the given task instances and is still
       QUEUED moves to INSTANTIATING.
    2. Each given task instance whose Task is now INSTANTIATING moves to
       INSTANTIATED.
    3. The INSTANTIATED instances from the input (joined to their Array) are
       grouped by (array_id, array_batch_num, array name,
       task_resources_id).

    Args:
        request: fastapi request whose JSON body carries
            ``task_instance_ids``.

    Returns:
        JSONResponse with ``task_instance_batches``: one
        ``SerializeTaskInstanceBatch.to_wire`` entry per batch.
    """
    payload = cast(Dict, await request.json())
    ti_ids = tuple(int(raw_id) for raw_id in payload["task_instance_ids"])

    with SessionMaker() as session, session.begin():
        # Step 1: QUEUED -> INSTANTIATING for the tasks behind these instances.
        queued_task_ids = (
            select(Task.id)
            .join(TaskInstance, TaskInstance.task_id == Task.id)
            .where(
                and_(
                    TaskInstance.id.in_(ti_ids),
                    Task.status == constants.TaskStatus.QUEUED,
                )
            )
        ).alias("derived_table")
        session.execute(
            update(Task)
            .where(Task.id.in_(select(queued_task_ids.c.id)))
            .values(
                status=constants.TaskStatus.INSTANTIATING, status_date=func.now()
            )
            .execution_options(synchronize_session=False)
        )

        # Step 2: only instances whose task made that transition become
        # INSTANTIATED.
        transitioned_ti_ids = (
            select(TaskInstance.id)
            .join(Task, TaskInstance.task_id == Task.id)
            .where(
                and_(
                    Task.status == constants.TaskStatus.INSTANTIATING,
                    TaskInstance.id.in_(ti_ids),
                )
            )
        ).alias("derived_table")
        session.execute(
            update(TaskInstance)
            .where(TaskInstance.id.in_(select(transitioned_ti_ids.c.id)))
            .values(
                status=constants.TaskInstanceStatus.INSTANTIATED,
                status_date=func.now(),
            )
            .execution_options(synchronize_session=False)
        )
        session.flush()

        # Step 3: one row per instantiated instance, grouped client-side (no
        # group_concat) under the key
        # (array_id, array_batch_num, array_name, task_resources_id).
        batch_rows = session.execute(
            select(
                TaskInstance.array_id,
                Array.name,
                TaskInstance.array_batch_num,
                TaskInstance.task_resources_id,
                TaskInstance.id,
            ).where(
                and_(
                    TaskInstance.id.in_(ti_ids),
                    TaskInstance.status == constants.TaskInstanceStatus.INSTANTIATED,
                    TaskInstance.array_id == Array.id,
                )
            )
        )
        batches: DefaultDict = defaultdict(list)
        for arr_id, arr_name, batch_num, resources_id, ti_id in batch_rows:
            batches[(arr_id, batch_num, arr_name, resources_id)].append(int(ti_id))

    serialized = [
        SerializeTaskInstanceBatch.to_wire(
            array_id=arr_id,
            array_name=arr_name,
            array_batch_num=batch_num,
            task_resources_id=resources_id,
            task_instance_ids=members,
        )
        for (arr_id, batch_num, arr_name, resources_id), members in batches.items()
    ]
    return JSONResponse(
        content={"task_instance_batches": serialized},
        status_code=StatusCodes.OK,
    )
# ############################ HELPER FUNCTIONS ###############################
def _update_task_instance_state(
    task_instance: TaskInstance, status_id: str, request: Request
) -> Any:
    """Advance the state of task_instance and its associated Task.

    Return any messages that should be published based on the transition.

    Args:
        task_instance (TaskInstance): the task instance to transition.
        status_id (str): the TaskInstanceStatus to transition to.
        request (Request): fastapi request object; its path is included in
            the ServerError message.

    Returns:
        An empty string when the transition succeeded. Otherwise a message
        explaining why it was skipped: either the instance was already in
        ``status_id``, or the transition is illegal.

    Raises:
        ServerError: with status code 500, for any exception other than
            InvalidStateTransition raised by the transition.
    """
    response = ""
    try:
        task_instance.transition(status_id)
    except InvalidStateTransition:
        if task_instance.status == status_id:
            # It was already in that state, just log it
            msg = (
                f"Attempting to transition to existing state. "
                f"Not transitioning task, tid= "
                f"{task_instance.id} from {task_instance.status} to "
                f"{status_id}"
            )
            logger.warning(msg)
            response += msg
        else:
            # Tried to move to an illegal state
            msg = (
                f"Illegal state transition. Not transitioning task, "
                f"tid={task_instance.id}, from {task_instance.status} to "
                f"{status_id}"
            )
            logger.error(msg)
            response += msg
    except Exception as e:
        raise ServerError(
            f"General exception in _update_task_instance_state, tid "
            f"{task_instance.id}, transitioning to {status_id}. Not "
            f"transitioning task. Server Error in {request.url.path}",
            status_code=500,
        ) from e
    return response
def _log_error(
    session: Session,
    ti: TaskInstance,
    error_state: str,
    error_msg: str,
    distributor_id: Optional[int] = None,
    nodename: Optional[str] = None,
    request: Optional[Request] = None,
) -> Any:
    """Store an error log entry for ``ti`` and move it to ``error_state``.

    ``nodename`` and ``distributor_id`` are recorded when given; the
    distributor id is stored as a string. The TaskInstanceErrorLog row is
    added before the transition is attempted, so it is kept even when the
    transition is rejected as a no-op or illegal.

    Args:
        session: open session whose transaction receives the changes.
        ti: task instance being errored.
        error_state: TaskInstanceStatus to transition to.
        error_msg: description stored in the error log.
        distributor_id: optional distributor id to record.
        nodename: optional node name to record.
        request: optional fastapi request, used only for error messages.

    Returns:
        JSONResponse whose ``message`` is the text from
        ``_update_task_instance_state``.

    Raises:
        ServerError: with status code 500, after rolling back the session,
            for any exception while logging, transitioning or flushing.
    """
    if nodename is not None:
        ti.nodename = nodename  # type: ignore
    if distributor_id is not None:
        ti.distributor_id = str(distributor_id)

    try:
        session.add(TaskInstanceErrorLog(task_instance_id=ti.id, description=error_msg))
        message = _update_task_instance_state(ti, error_state, request)  # type: ignore
        session.flush()
        outcome = JSONResponse(content={"message": message}, status_code=StatusCodes.OK)
    except Exception as exc:
        session.rollback()
        detail = (
            "Unexpected Jobmon Server Error"
            if request is None
            else f"Unexpected Jobmon Server Error in {request.url.path}"
        )
        raise ServerError(detail, status_code=500) from exc
    return outcome