"""Routes for TaskInstances."""
from collections import defaultdict
from http import HTTPStatus as StatusCodes
from typing import Any, DefaultDict, Dict, Optional, cast
import sqlalchemy
import structlog
from fastapi import Request
from sqlalchemy import and_, select, update
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from starlette.responses import JSONResponse
from jobmon.core import constants
from jobmon.core.exceptions import InvalidStateTransition
from jobmon.core.serializers import SerializeTaskInstanceBatch
from jobmon.server.web._compat import add_time
from jobmon.server.web.db import get_sessionmaker
from jobmon.server.web.models.array import Array
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.task_instance import TaskInstance
from jobmon.server.web.models.task_instance_error_log import TaskInstanceErrorLog
from jobmon.server.web.routes.v2.fsm import fsm_router as api_v2_router
from jobmon.server.web.server_side_exception import ServerError
# Module-level structured logger shared by every route in this file.
logger = structlog.get_logger(__name__)

# Session factory obtained once at import time; each route below opens its own
# short-lived session from it via ``with SessionMaker() as session``.
SessionMaker = get_sessionmaker()
@api_v2_router.post("/task_instance/{task_instance_id}/log_running")
async def log_running(task_instance_id: int, request: Request) -> Any:
    """Record that a task instance has started running.

    The JSON body must contain ``process_group_id`` and
    ``next_report_increment``. ``distributor_id`` and ``nodename`` are copied
    onto the task instance only when they are present and not null.

    Args:
        task_instance_id: id of the task_instance to log as running
        request: fastapi request object

    Returns:
        A 200 JSON response whose ``task_instance`` key holds the result of
        ``to_wire_as_worker_node_task_instance()``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    payload = cast(Dict, await request.json())
    statuses = constants.TaskInstanceStatus

    with SessionMaker() as session:
        with session.begin():
            query = select(TaskInstance).where(TaskInstance.id == task_instance_id)
            ti = session.execute(query).scalars().one()

            # Optional fields: leave the stored value alone when not supplied.
            for field in ("distributor_id", "nodename"):
                supplied = payload.get(field, None)
                if supplied is not None:
                    setattr(ti, field, payload[field])
            ti.process_group_id = payload["process_group_id"]

            try:
                ti.transition(statuses.RUNNING)
                ti.report_by_date = add_time(payload["next_report_increment"])
            except InvalidStateTransition as exc:
                current = ti.status
                if current == statuses.RUNNING:
                    # Already running: a repeated call, only worth a warning.
                    logger.warning(exc)
                elif current == statuses.KILL_SELF:
                    ti.transition(statuses.ERROR_FATAL)
                elif current == statuses.NO_HEARTBEAT:
                    ti.transition(statuses.ERROR)
                else:
                    # Tried to move to an illegal state
                    logger.error(exc)

            # Serialize while the instance is still attached to the session.
            wire_format = ti.to_wire_as_worker_node_task_instance()

    return JSONResponse(
        content={"task_instance": wire_format}, status_code=StatusCodes.OK
    )
@api_v2_router.post("/task_instance/{task_instance_id}/log_report_by")
async def log_ti_report_by(task_instance_id: int, request: Request) -> Any:
    """Log a task_instance as being responsive with a new report_by_date.

    This is done at the worker node heartbeat_interval rate, so it may not
    happen at the same rate that the reconciler updates batch submitted
    report_by_dates (also because it causes a lot of traffic if all workers
    are logging report_by_dates often compared to if the reconciler runs
    often).

    The JSON body must contain ``next_report_increment``. ``distributor_id``,
    ``stderr`` and ``stdout`` are optional and are written only when present
    and not null.

    Args:
        task_instance_id: id of the task_instance to log
        request: fastapi request object

    Returns:
        A 200 JSON response with the task instance's ``status`` after the
        update.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    data = cast(Dict, await request.json())

    with SessionMaker() as session:
        with session.begin():
            vals: Dict[str, Any] = {
                "report_by_date": add_time(data["next_report_increment"])
            }
            for optional_val in ("distributor_id", "stderr", "stdout"):
                val = data.get(optional_val, None)
                # Test the individual value, not the whole payload: an absent
                # key must not overwrite the stored column with NULL.
                if val is not None:
                    vals[optional_val] = val

            update_stmt = update(TaskInstance).where(
                TaskInstance.id == task_instance_id
            )
            session.execute(update_stmt.values(**vals))
            session.flush()

            select_stmt = select(TaskInstance).where(
                TaskInstance.id == task_instance_id
            )
            task_instance = session.execute(select_stmt).scalars().one()
            # A heartbeat from an instance in TRIAGING moves it back to RUNNING.
            if task_instance.status == constants.TaskInstanceStatus.TRIAGING:
                task_instance.transition(constants.TaskInstanceStatus.RUNNING)

            # Read the status while the instance is still attached.
            status = task_instance.status

    return JSONResponse(content={"status": status}, status_code=StatusCodes.OK)
@api_v2_router.post("/task_instance/log_report_by/batch")
async def log_ti_report_by_batch(request: Request) -> Any:
    """Push out the report_by_date of a batch of launched task instances.

    The JSON body must contain ``next_report_increment`` (converted with
    ``float``) and may contain ``task_instance_ids``. Only the listed task
    instances whose status is LAUNCHED are updated; when the id list is
    missing or empty no database work is done.

    Args:
        request: fastapi request object

    Returns:
        A 200 JSON response with an empty body.
    """
    payload = cast(Dict, await request.json())
    tis = payload.get("task_instance_ids", None)
    increment = float(payload["next_report_increment"])
    logger.debug(f"Log report_by for TI {tis}.")

    if not tis:
        # Nothing to update; skip opening a session.
        return JSONResponse(content={}, status_code=StatusCodes.OK)

    with SessionMaker() as session:
        with session.begin():
            launched_only = update(TaskInstance).where(
                TaskInstance.id.in_(tis),
                TaskInstance.status == constants.TaskInstanceStatus.LAUNCHED,
            )
            session.execute(launched_only.values(report_by_date=add_time(increment)))

    return JSONResponse(content={}, status_code=StatusCodes.OK)
@api_v2_router.post("/task_instance/{task_instance_id}/log_done")
async def log_done(task_instance_id: int, request: Request) -> Any:
    """Mark a task instance as done.

    Any of ``distributor_id``, ``stdout_log``, ``stderr_log``, ``nodename``,
    ``stdout`` and ``stderr`` found in the JSON body with a non-null value is
    copied onto the task instance before the DONE transition is attempted.

    Args:
        task_instance_id: id of the task_instance to log done
        request: fastapi request object

    Returns:
        A 200 JSON response with the task instance's resulting ``status``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    payload = cast(Dict, await request.json())
    done_status = constants.TaskInstanceStatus.DONE
    updatable_fields = (
        "distributor_id",
        "stdout_log",
        "stderr_log",
        "nodename",
        "stdout",
        "stderr",
    )

    with SessionMaker() as session:
        with session.begin():
            query = select(TaskInstance).where(TaskInstance.id == task_instance_id)
            ti = session.execute(query).scalars().one()

            for field in updatable_fields:
                supplied = payload.get(field, None)
                if supplied is not None:
                    setattr(ti, field, supplied)

            try:
                ti.transition(done_status)
            except InvalidStateTransition as exc:
                # Already DONE is a harmless repeat; anything else is an
                # illegal move and is logged as an error.
                report = logger.warning if ti.status == done_status else logger.error
                report(exc)

            final_status = ti.status

    return JSONResponse(content={"status": final_status}, status_code=StatusCodes.OK)
@api_v2_router.post("/task_instance/{task_instance_id}/log_error_worker_node")
async def log_error_worker_node(task_instance_id: int, request: Request) -> Any:
    """Log a task_instance as errored, as reported by the worker node.

    The JSON body must contain ``error_state`` and ``error_description``.
    ``distributor_id``, ``stdout_log``, ``stderr_log``, ``nodename``,
    ``stdout`` and ``stderr`` are optional and are written only when present
    and not null.

    Args:
        task_instance_id: id of the task_instance to log as errored
        request: fastapi request object

    Returns:
        A 200 JSON response with the task instance's resulting ``status``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    data = cast(Dict, await request.json())
    logger.info(f"Log ERROR for TI:{task_instance_id}.")

    with SessionMaker() as session:
        with session.begin():
            select_stmt = select(TaskInstance).where(
                TaskInstance.id == task_instance_id
            )
            task_instance = session.execute(select_stmt).scalars().one()

            optional_vals = (
                "distributor_id",
                "stdout_log",
                "stderr_log",
                "nodename",
                "stdout",
                "stderr",
            )
            for optional_val in optional_vals:
                val = data.get(optional_val, None)
                # Test the individual value, not the whole payload: an absent
                # key must not wipe a value already stored on the instance.
                if val is not None:
                    setattr(task_instance, optional_val, val)

            error_state = data["error_state"]
            error_description = data["error_description"]
            try:
                # Take a row lock on the parent Task before transitioning.
                # NOTE(review): presumably serializes concurrent updates to
                # the same task; confirm against TaskInstance.transition.
                session.execute(
                    select(Task)
                    .where(Task.id == task_instance.task_id)
                    .with_for_update()
                ).scalar_one()

                task_instance.transition(error_state)

                # Add error log only if transition was successful
                error = TaskInstanceErrorLog(
                    task_instance_id=task_instance.id, description=error_description
                )
                session.add(error)
            except InvalidStateTransition as e:
                if task_instance.status == error_state:
                    # Already in the requested error state: a repeated call.
                    logger.warning(e)
                else:
                    # Tried to move to an illegal state
                    logger.error(e)

            # Read the status while the instance is still attached.
            status = task_instance.status

    return JSONResponse(content={"status": status}, status_code=StatusCodes.OK)
@api_v2_router.get("/task_instance/{task_instance_id}/task_instance_error_log")
async def get_task_instance_error_log(task_instance_id: int) -> Any:
    """Return every error-log entry recorded for one task instance.

    Args:
        task_instance_id: ID of the task instance

    Returns:
        A 200 JSON response whose ``task_instance_error_log`` key is the list
        of ``to_wire()`` results, one per matching row.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    logger.info(f"Getting task instance error log for ti {task_instance_id}")

    query = (
        select(TaskInstanceErrorLog)
        .where(TaskInstanceErrorLog.task_instance_id == task_instance_id)
        .order_by(TaskInstanceErrorLog.task_instance_id)
    )
    with SessionMaker() as session:
        with session.begin():
            error_rows = session.execute(query).scalars().all()
            # Serialize inside the session so the rows are still attached.
            wire_rows = [row.to_wire() for row in error_rows]

    return JSONResponse(
        content={"task_instance_error_log": wire_rows}, status_code=StatusCodes.OK
    )
@api_v2_router.get("/get_array_task_instance_id/{array_id}/{batch_num}/{step_id}")
def get_array_task_instance_id(array_id: int, batch_num: int, step_id: int) -> Any:
    """Look up the single task instance ID for an array / batch / step triple.

    This route will be called once per array task instance worker node, so it
    must be scalable. Exactly one matching row is required; the lookup raises
    otherwise.

    Args:
        array_id: id of the array the task instance belongs to
        batch_num: value matched against ``array_batch_num``
        step_id: value matched against ``array_step_id``

    Returns:
        A 200 JSON response with the matching ``task_instance_id``.
    """
    structlog.contextvars.bind_contextvars(array_id=array_id)

    lookup = select(TaskInstance.id).where(
        TaskInstance.array_id == array_id,
        TaskInstance.array_batch_num == batch_num,
        TaskInstance.array_step_id == step_id,
    )
    with SessionMaker() as session:
        with session.begin():
            found_id = session.execute(lookup).scalar_one()

    return JSONResponse(
        content={"task_instance_id": found_id}, status_code=StatusCodes.OK
    )
@api_v2_router.post("/task_instance/{task_instance_id}/log_no_distributor_id")
async def log_no_distributor_id(task_instance_id: int, request: Request) -> Any:
    """Record that a task instance got no distributor id upon submission.

    The JSON body must contain ``no_id_err_msg``, which is stored as an error
    log entry after the NO_DISTRIBUTOR_ID transition is attempted.

    Args:
        task_instance_id: id of the task_instance to log
        request: fastapi request object

    Returns:
        A 200 JSON response whose ``message`` is whatever
        ``_update_task_instance_state`` reported.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    logger.info(
        f"Logging ti {task_instance_id} did not get distributor id upon submission"
    )
    data = cast(Dict, await request.json())
    logger.debug(f"Log NO DISTRIBUTOR ID. Data {data['no_id_err_msg']}")
    description = data["no_id_err_msg"]

    with SessionMaker() as session:
        with session.begin():
            ti = session.execute(
                select(TaskInstance).where(TaskInstance.id == task_instance_id)
            ).scalar_one()
            transition_msg = _update_task_instance_state(
                ti, constants.TaskInstanceStatus.NO_DISTRIBUTOR_ID, request
            )
            session.add(
                TaskInstanceErrorLog(task_instance_id=ti.id, description=description)
            )

    return JSONResponse(
        content={"message": transition_msg}, status_code=StatusCodes.OK
    )
@api_v2_router.post("/task_instance/{task_instance_id}/log_distributor_id")
async def log_distributor_id(task_instance_id: int, request: Request) -> Any:
    """Store a task instance's distributor id and mark it LAUNCHED.

    The JSON body must contain ``distributor_id`` and
    ``next_report_increment``. Both are written after the LAUNCHED transition
    is attempted, whether or not that transition was accepted.

    Args:
        task_instance_id: id of the task_instance to log
        request: fastapi request object

    Returns:
        A 200 JSON response whose ``message`` is whatever
        ``_update_task_instance_state`` reported.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    payload = cast(Dict, await request.json())

    with SessionMaker() as session:
        with session.begin():
            ti = session.execute(
                select(TaskInstance).where(TaskInstance.id == task_instance_id)
            ).scalar_one()
            transition_msg = _update_task_instance_state(
                ti, constants.TaskInstanceStatus.LAUNCHED, request
            )
            ti.distributor_id = payload["distributor_id"]
            ti.report_by_date = add_time(payload["next_report_increment"])

    return JSONResponse(
        content={"message": transition_msg}, status_code=StatusCodes.OK
    )
@api_v2_router.post("/task_instance/{task_instance_id}/log_known_error")
async def log_known_error(task_instance_id: int, request: Request) -> Any:
    """Log a task_instance as errored.

    The JSON body must contain ``error_state`` and ``error_message``;
    ``distributor_id`` and ``nodename`` are optional.

    Args:
        task_instance_id: id for task instance.
        request: fastapi request object.

    Returns:
        The JSON response produced by ``_log_error``.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    data = cast(Dict, await request.json())
    error_state = data["error_state"]
    error_message = data["error_message"]
    distributor_id = data.get("distributor_id", None)
    nodename = data.get("nodename", None)
    logger.info(f"Log ERROR for TI:{task_instance_id}.")

    with SessionMaker() as session:
        with session.begin():
            select_stmt = select(TaskInstance).where(
                TaskInstance.id == task_instance_id
            )
            task_instance = session.execute(select_stmt).scalars().one()

            try:
                resp = _log_error(
                    session,
                    task_instance,
                    error_state,
                    error_message,
                    distributor_id,
                    nodename,
                    request,
                )
            except sqlalchemy.exc.OperationalError:
                # Retry with a message reduced to latin1: characters outside
                # latin1 become "?". Decode with the same codec used to
                # encode; decoding latin1 bytes as utf-8 raises
                # UnicodeDecodeError for any character in U+0080-U+00FF.
                # NOTE(review): _log_error in this module wraps every
                # exception in ServerError, so OperationalError may never
                # reach this handler - confirm.
                new_msg = error_message.encode("latin1", "replace").decode("latin1")
                resp = _log_error(
                    session,
                    task_instance,
                    error_state,
                    new_msg,
                    distributor_id,
                    nodename,
                    request,
                )

    return resp
@api_v2_router.post("/task_instance/{task_instance_id}/log_unknown_error")
async def log_unknown_error(task_instance_id: int, request: Request) -> Any:
    """Log a task_instance as errored, unless it has heartbeat since.

    The error is recorded only when a row exists for ``task_instance_id``
    whose ``report_by_date`` is not later than the database's current time.
    When no such row matches, nothing is written and a 200 response with an
    explanatory message is returned.

    The JSON body must contain ``error_state`` and ``error_message``;
    ``distributor_id`` and ``nodename`` are optional.

    Args:
        task_instance_id: id for task instance
        request: fastapi request object

    Returns:
        The JSON response produced by ``_log_error``, or a 200 JSON response
        with a ``message`` when no eligible task instance was found.
    """
    structlog.contextvars.bind_contextvars(task_instance_id=task_instance_id)
    data = cast(Dict, await request.json())
    error_state = data["error_state"]
    error_message = data["error_message"]
    distributor_id = data.get("distributor_id", None)
    nodename = data.get("nodename", None)
    logger.info(f"Log ERROR for TI:{task_instance_id}.")

    with SessionMaker() as session:
        with session.begin():
            # make sure the task hasn't logged a new heartbeat since we began
            # reconciliation
            select_stmt = select(TaskInstance).where(
                TaskInstance.id == task_instance_id,
                TaskInstance.report_by_date <= func.now(),
            )
            task_instance = session.execute(select_stmt).scalars().one_or_none()
            session.flush()

            if task_instance is None:
                # No eligible row: answer explicitly instead of falling
                # through with no response assigned.
                msg = (
                    f"Task instance {task_instance_id} not found with an expired "
                    f"report_by_date; no error logged."
                )
                logger.info(msg)
                return JSONResponse(
                    content={"message": msg}, status_code=StatusCodes.OK
                )

            try:
                resp = _log_error(
                    session,
                    task_instance,
                    error_state,
                    error_message,
                    distributor_id,
                    nodename,
                    request,
                )
            except sqlalchemy.exc.OperationalError:
                # Retry with a message reduced to latin1: characters outside
                # latin1 become "?". Decode with the same codec used to
                # encode; decoding latin1 bytes as utf-8 raises
                # UnicodeDecodeError for any character in U+0080-U+00FF.
                # NOTE(review): _log_error in this module wraps every
                # exception in ServerError, so OperationalError may never
                # reach this handler - confirm.
                new_msg = error_message.encode("latin1", "replace").decode("latin1")
                resp = _log_error(
                    session,
                    task_instance,
                    error_state,
                    new_msg,
                    distributor_id,
                    nodename,
                    request,
                )

    return resp
@api_v2_router.post("/task_instance/instantiate_task_instances")
async def instantiate_task_instances(request: Request) -> Any:
    """Instantiate the given task instances and report them grouped by batch.

    Works in three steps inside one transaction:

    1. Tasks in QUEUED status that own one of the requested task instances are
       set to INSTANTIATING.
    2. Requested task instances whose task is INSTANTIATING are set to
       INSTANTIATED.
    3. Requested task instances in INSTANTIATED status are read back together
       with their array name and grouped by
       (array_id, array_batch_num, array_name, task_resources_id).

    Args:
        request: fastapi request object; its JSON body must contain
            ``task_instance_ids``.

    Returns:
        A 200 JSON response whose ``task_instance_batches`` key lists one
        ``SerializeTaskInstanceBatch.to_wire`` result per group.
    """
    payload = cast(Dict, await request.json())
    requested_ids = tuple(int(tid) for tid in payload["task_instance_ids"])

    with SessionMaker() as session:
        with session.begin():
            # Step 1: update the task table where the FSM allows it.
            # NOTE(review): the subquery is wrapped in a named alias and
            # re-selected - presumably so the UPDATE does not select directly
            # from its own target table; confirm.
            queued_tasks = (
                select(Task.id)
                .join(TaskInstance, TaskInstance.task_id == Task.id)
                .where(
                    and_(
                        TaskInstance.id.in_(requested_ids),
                        Task.status == constants.TaskStatus.QUEUED,
                    )
                )
            ).alias("derived_table")
            session.execute(
                update(Task)
                .where(Task.id.in_(select(queued_tasks.c.id)))
                .values(
                    status=constants.TaskStatus.INSTANTIATING,
                    status_date=func.now(),
                )
                .execution_options(synchronize_session=False)
            )

            # Step 2: propagate into task instances where a change was made.
            instantiable = (
                select(TaskInstance.id)
                .join(Task, TaskInstance.task_id == Task.id)
                .where(
                    and_(
                        # a successful transition
                        (Task.status == constants.TaskStatus.INSTANTIATING),
                        # and part of the current set
                        TaskInstance.id.in_(requested_ids),
                    )
                )
            ).alias("derived_table")
            session.execute(
                update(TaskInstance)
                .where(TaskInstance.id.in_(select(instantiable.c.id)))
                .values(
                    status=constants.TaskInstanceStatus.INSTANTIATED,
                    status_date=func.now(),
                )
                .execution_options(synchronize_session=False)
            )
            session.flush()

            # Step 3: fetch rows individually (no group_concat) and group them
            # in Python.
            batch_rows = select(
                TaskInstance.array_id,
                Array.name,
                TaskInstance.array_batch_num,
                TaskInstance.task_resources_id,
                TaskInstance.id,
            ).where(
                and_(
                    TaskInstance.id.in_(requested_ids),
                    TaskInstance.status == constants.TaskInstanceStatus.INSTANTIATED,
                    TaskInstance.array_id == Array.id,
                )
            )

            # Key: (array_id, array_batch_num, array_name, task_resources_id)
            # Value: ids of the task instances in that batch
            batches: DefaultDict = defaultdict(list)
            for a_id, a_name, batch_num, resources_id, ti_id in session.execute(
                batch_rows
            ):
                batches[(a_id, batch_num, a_name, resources_id)].append(int(ti_id))

    serialized_batches = [
        SerializeTaskInstanceBatch.to_wire(
            array_id=a_id,
            array_name=a_name,
            array_batch_num=batch_num,
            task_resources_id=resources_id,
            task_instance_ids=ti_ids,
        )
        for (a_id, batch_num, a_name, resources_id), ti_ids in batches.items()
    ]
    return JSONResponse(
        content={"task_instance_batches": serialized_batches},
        status_code=StatusCodes.OK,
    )
# ############################ HELPER FUNCTIONS ###############################
def _update_task_instance_state(
    task_instance: TaskInstance, status_id: str, request: Request
) -> Any:
    """Advance the states of task_instance and its associated Task.

    Return any messages that should be published based on the transition.

    Args:
        task_instance: the TaskInstance to transition
        status_id: id of the status to which to transition
        request: fastapi request object, used only to name the route in the
            ServerError message. ``_log_error`` may pass ``None`` here.

    Returns:
        An empty string when the transition succeeded, otherwise the message
        that was logged for the rejected transition.

    Raises:
        ServerError: if ``transition`` raises anything other than
            InvalidStateTransition.
    """
    response = ""

    try:
        task_instance.transition(status_id)
    except InvalidStateTransition:
        if task_instance.status == status_id:
            # It was already in that state, just log it
            msg = (
                f"Attempting to transition to existing state."
                f"Not transitioning task, tid= "
                f"{task_instance.id} from {task_instance.status} to "
                f"{status_id}"
            )
            logger.warning(msg)
            response += msg
        else:
            # Tried to move to an illegal state
            msg = (
                f"Illegal state transition. Not transitioning task, "
                f"tid={task_instance.id}, from {task_instance.status} to "
                f"{status_id}"
            )
            logger.error(msg)
            response += msg
    except Exception as e:
        # Guard against request=None so building the message cannot itself
        # raise and hide the original exception.
        location = (
            f" Server Error in {request.url.path}" if request is not None else ""
        )
        raise ServerError(
            f"General exception in _update_task_instance_state, jid "
            f"{task_instance}, transitioning to {status_id}. Not "
            f"transitioning task.{location}",
            status_code=500,
        ) from e
    return response
def _log_error(
    session: Session,
    ti: TaskInstance,
    error_state: str,
    error_msg: str,
    distributor_id: Optional[int] = None,
    nodename: Optional[str] = None,
    request: Optional[Request] = None,
) -> Any:
    """Attach an error log entry to a task instance and move it to error_state.

    ``nodename`` and ``distributor_id`` are written onto the task instance
    only when given (the distributor id is stored as a string). The error log
    row is added before the transition is attempted, then the session is
    flushed.

    Args:
        session: open database session to work in
        ti: task instance being logged as errored
        error_state: status to transition the task instance to
        error_msg: text stored as the error log description
        distributor_id: optional distributor id to record
        nodename: optional node name to record
        request: optional fastapi request, used to name the route in errors

    Returns:
        A 200 JSON response whose ``message`` is whatever
        ``_update_task_instance_state`` reported.

    Raises:
        ServerError: on any exception, after rolling the session back.
    """
    if nodename is not None:
        ti.nodename = nodename  # type: ignore
    if distributor_id is not None:
        ti.distributor_id = str(distributor_id)

    try:
        session.add(TaskInstanceErrorLog(task_instance_id=ti.id, description=error_msg))
        transition_msg = _update_task_instance_state(
            ti, error_state, request  # type: ignore
        )
        session.flush()
        return JSONResponse(
            content={"message": transition_msg}, status_code=StatusCodes.OK
        )
    except Exception as exc:
        session.rollback()
        if request is None:
            detail = "Unexpected Jobmon Server Error"
        else:
            detail = f"Unexpected Jobmon Server Error in {request.url.path}"
        raise ServerError(detail, status_code=500) from exc