"""Routes used to move through the finite state."""
from http import HTTPStatus as StatusCodes
from typing import Any, Union
import structlog
from fastapi import Depends, Query, Request
from sqlalchemy import case, func, select, text, update
from sqlalchemy.orm import Session
from starlette.responses import JSONResponse
from jobmon.core.exceptions import InvalidStateTransition
from jobmon.core.logging import set_jobmon_context
from jobmon.server.web.db.deps import get_db
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.workflow import Workflow
from jobmon.server.web.models.workflow_run import WorkflowRun
from jobmon.server.web.models.workflow_run_status import WorkflowRunStatus
from jobmon.server.web.models.workflow_status import WorkflowStatus
from jobmon.server.web.routes.v3.reaper import reaper_router as api_v3_router
# Module-level structlog logger shared by every route defined in this module.
logger = structlog.get_logger(__name__)
@api_v3_router.put("/workflow/{workflow_id}/fix_status_inconsistency")
async def fix_wf_inconsistency(
    workflow_id: int, request: Request, db: Session = Depends(get_db)
) -> Any:
    """Set workflows in status "F" whose tasks are all in status "D" to "D".

    Only workflows with an id in the window
    ``(workflow_id, workflow_id + increase_step]`` are inspected on each call, so
    the caller can sweep the table incrementally instead of scanning it whole.
    The step size is supplied by the caller because it is easier to redeploy the
    reaper than the service.

    Args:
        workflow_id: exclusive lower bound of the id window to inspect.
        request: request whose JSON body must contain ``increase_step``.
        db: database session.

    Returns:
        A JSONResponse ``{"wfid": <id to start the next sweep from>}``. The value
        is 0 once the end of the window exceeds the number of workflows, which
        tells the caller to wrap around and start again.
    """
    data = await request.json()
    # Normalise once so the log line, the wrap-around check and the query window
    # all use the same integer value (a JSON string such as "1000" is accepted).
    increase_step = int(data["increase_step"])
    window_start = int(workflow_id)
    window_end = window_start + increase_step
    logger.debug(
        f"Fix inconsistencies starting at workflow {workflow_id} by {increase_step}"
    )

    # NOTE(review): this is a row count, not max(Workflow.id). If ids have gaps,
    # the sweep wraps around before reaching the highest ids -- confirm intended.
    total_wf = db.execute(select(func.count(Workflow.id))).scalar() or 0

    # Move the starting row forward by increase_step.
    # It takes about 1 second per thousand; increase_step is passed in from the
    # reaper. If the starting row > max row, restart from workflow id 0.
    # This way we can get to the unfinished workflows later without querying the
    # whole db every time.
    current_max_wf_id = window_end
    if current_max_wf_id > total_wf:
        logger.info("Fix inconsistencies starting from workflow_id zero again")
        current_max_wf_id = 0

    # Find workflows in "F" within the window where every task has status "D".
    # A CTE is used for better performance.
    done_count = func.sum(case((Task.status == "D", 1), else_=0))
    candidate_workflows = select(Workflow.id).where(
        Workflow.id > window_start,
        Workflow.id <= window_end,
        Workflow.status == "F",
    )
    workflow_tasks_cte = (
        select(
            Task.workflow_id,
            func.count(Task.id).label("total_tasks"),
            done_count.label("done_tasks"),
        )
        .where(Task.workflow_id.in_(candidate_workflows))
        .group_by(Task.workflow_id)
        .having(func.count(Task.id) == done_count)
        .cte("workflow_tasks")
    )
    sql = select(workflow_tasks_cte.c.workflow_id)
    result_list = [row[0] for row in db.execute(sql).all()]

    if not result_list:
        logger.debug("No inconsistent F-D workflows to fix.")
    else:
        logger.info(f"Fixing inconsistent F-D workflow: {result_list}")
        update_stmt = (
            update(Workflow)
            .where(Workflow.id.in_(result_list))
            .values(status="D", status_date=func.now())
        )
        db.execute(update_stmt)
        db.commit()
        logger.debug("Done fixing F-D inconsistent workflows.")

    return JSONResponse(
        content={"wfid": current_max_wf_id}, status_code=StatusCodes.OK
    )
@api_v3_router.get("/workflow/{workflow_id}/workflow_name_and_args")
def get_wf_name_and_args(workflow_id: int, db: Session = Depends(get_db)) -> Any:
    """Return the name and args of the workflow with the given id.

    Args:
        workflow_id: id of the workflow to look up.
        db: database session.

    Returns:
        A JSONResponse with keys ``workflow_name`` and ``workflow_args``. Both
        values are ``None`` when no workflow with that id exists.
    """
    sql = select(Workflow.name, Workflow.workflow_args).where(
        Workflow.id == workflow_id
    )
    row = db.execute(sql).first()

    if row is None:
        # Return empty values in case of DB inconsistency instead of indexing
        # into an empty result.
        workflow_name, workflow_args = None, None
    else:
        workflow_name, workflow_args = row[0], row[1]

    return JSONResponse(
        content={"workflow_name": workflow_name, "workflow_args": workflow_args},
        status_code=StatusCodes.OK,
    )
@api_v3_router.get("/lost_workflow_run")
def get_lost_workflow_runs(
    status: Union[str, list[str]] = Query(...),
    version: str = Query(...),
    db: Session = Depends(get_db),
) -> Any:
    """Return workflow runs in the given status(es) whose heartbeat date has passed.

    Args:
        status: one status or a list of statuses to match.
        version: required query parameter; not used by this function.
        db: database session.

    Returns:
        A JSONResponse whose ``workflow_runs`` entry holds one
        ``(workflow_run_id, workflow_id)`` pair per matching run.
    """
    # A single status is treated as a one-element list.
    statuses = [status] if isinstance(status, str) else status

    stmt = select(WorkflowRun.id, WorkflowRun.workflow_id).where(
        WorkflowRun.status.in_(statuses),
        WorkflowRun.heartbeat_date <= func.now(),
    )
    lost_runs = [(wfr_id, wf_id) for wfr_id, wf_id in db.execute(stmt).all()]

    return JSONResponse(
        content={"workflow_runs": lost_runs}, status_code=StatusCodes.OK
    )
@api_v3_router.put("/workflow_run/{workflow_run_id}/reap")
def reap_workflow_run(workflow_run_id: int, db: Session = Depends(get_db)) -> Any:
    """Move a workflow run whose heartbeat date has passed into a terminal state.

    The run is only considered when its ``heartbeat_date`` is not in the future.
    The target states depend on the run's current status:

    * LINKING -> run ABORTED, workflow ABORTED
    * COLD_RESUME / HOT_RESUME -> run TERMINATED, workflow HALTED
    * RUNNING -> run ERROR, workflow FAILED

    Nothing is written when the run is not found, its heartbeat has not expired,
    its status is not one of the above, or the transition is not listed in
    ``WorkflowRun().valid_transitions`` (e.g. after a race with another writer).

    Args:
        workflow_run_id: id of the workflow run to reap.
        db: database session.

    Returns:
        A JSONResponse ``{"status": <new workflow run status>}``, or
        ``{"status": ""}`` when nothing was changed.
    """
    set_jobmon_context(workflow_run_id=workflow_run_id)
    logger.info(f"Reap wfr: {workflow_run_id}")

    # Get the wfr, but only if its heartbeat date has already passed.
    sql = select(WorkflowRun.id, WorkflowRun.workflow_id, WorkflowRun.status).where(
        WorkflowRun.id == workflow_run_id,
        WorkflowRun.heartbeat_date <= func.now(),
    )
    rows = db.execute(sql).all()
    if not rows:
        return JSONResponse(content={"status": ""}, status_code=StatusCodes.OK)

    wfr_id, wf_id, wfr_status = rows[0][0], rows[0][1], rows[0][2]

    # Pick the target states for the run and its workflow.
    if wfr_status == WorkflowRunStatus.LINKING:
        logger.debug(f"Transitioning wfr {wfr_id} to ABORTED")
        target_wfr_status = WorkflowRunStatus.ABORTED
        target_wf_status = WorkflowStatus.ABORTED
    elif wfr_status in (WorkflowRunStatus.COLD_RESUME, WorkflowRunStatus.HOT_RESUME):
        logger.debug(f"Transitioning wfr {wfr_id} to TERMINATED")
        target_wfr_status = WorkflowRunStatus.TERMINATED
        target_wf_status = WorkflowStatus.HALTED
    elif wfr_status == WorkflowRunStatus.RUNNING:
        logger.debug(f"Transitioning wfr {wfr_id} to ERROR")
        target_wfr_status = WorkflowRunStatus.ERROR
        target_wf_status = WorkflowStatus.FAILED
    else:
        # No reap target is defined for this status; leave the run untouched
        # rather than failing on an unbound target.
        logger.debug(
            f"Unable to reap workflow_run {wfr_id}: status {wfr_status} is not reapable"
        )
        return JSONResponse(content={"status": ""}, status_code=StatusCodes.OK)

    # Validate the transition; on failure report a no-op instead of writing.
    if (wfr_status, target_wfr_status) not in WorkflowRun().valid_transitions:
        # The exception is built only for its message; it is not raised.
        err = InvalidStateTransition(
            model="WorkflowRun",
            id=wfr_id,
            old_state=wfr_status,
            new_state=target_wfr_status,
        )
        logger.debug(f"Unable to reap workflow_run {wfr_id}: {err}")
        return JSONResponse(content={"status": ""}, status_code=StatusCodes.OK)

    # Update both rows with bound parameters (no string-built SQL, and no
    # dialect-specific double-quoted string literals).
    db.execute(
        text("UPDATE workflow_run SET status = :status WHERE id = :row_id"),
        {"status": target_wfr_status, "row_id": wfr_id},
    )
    db.execute(
        text("UPDATE workflow SET status = :status WHERE id = :row_id"),
        {"status": target_wf_status, "row_id": wf_id},
    )
    db.commit()

    return JSONResponse(
        content={"status": target_wfr_status}, status_code=StatusCodes.OK
    )