Source code for server.web.routes.v3.reaper.reaper

"""Routes used to move through the finite state machine."""

from http import HTTPStatus as StatusCodes
from typing import Any, Union

import structlog
from fastapi import Depends, Query, Request
from sqlalchemy import case, func, select, text, update
from sqlalchemy.orm import Session
from starlette.responses import JSONResponse

from jobmon.core.exceptions import InvalidStateTransition
from jobmon.core.logging import set_jobmon_context
from jobmon.server.web.db.deps import get_db
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.workflow import Workflow
from jobmon.server.web.models.workflow_run import WorkflowRun
from jobmon.server.web.models.workflow_run_status import WorkflowRunStatus
from jobmon.server.web.models.workflow_status import WorkflowStatus
from jobmon.server.web.routes.v3.reaper import reaper_router as api_v3_router

# Module-level structlog logger shared by every route in this module.
# NOTE(review): the previous comment described a per-request Flask logger
# (flask.g.logger); this module only uses FastAPI, so that description looks
# stale -- confirm how per-request context is bound.
logger = structlog.get_logger(__name__)
@api_v3_router.put("/workflow/{workflow_id}/fix_status_inconsistency")
async def fix_wf_inconsistency(
    workflow_id: int, request: Request, db: Session = Depends(get_db)
) -> Any:
    """Find workflows in status "F" whose tasks are all in status "D" and fix them.

    Only workflows with ``workflow_id < id <= workflow_id + increase_step`` are
    inspected on each call. Every workflow in that window that has status "F"
    and whose tasks all have status "D" is updated to status "D" (with
    ``status_date`` set to the database's current time).

    The step size is read from the JSON request body (key ``increase_step``)
    so that it can be tuned by redeploying the reaper rather than the service.

    Args:
        workflow_id: exclusive lower bound of the workflow-id window.
        request: incoming request; its JSON body must contain ``increase_step``.
        db: database session supplied by the ``get_db`` dependency.

    Returns:
        A JSON response ``{"wfid": <next starting id>}``; the value is 0 when
        the window has moved past the workflow count and the scan should
        restart from the beginning.
    """
    data = await request.json()
    # Normalise both bounds once so the arithmetic and the SQL window agree
    # even if the payload carries the step as a string.
    window_start = int(workflow_id)
    increase_step = int(data["increase_step"])
    window_end = window_start + increase_step

    logger.debug(
        f"Fix inconsistencies starting at workflow {workflow_id} by {increase_step}"
    )

    total_wf = db.execute(select(func.count(Workflow.id))).scalar() or 0

    # Move the starting row forward by increase_step.
    # It takes about 1 second per thousand; increase_step is passed in from the reaper.
    # If the starting row > max row, restart from workflow id 0.
    # This way we can get to the unfinished workflows later
    # without querying the whole db every time.
    # NOTE(review): total_wf is a row COUNT, not MAX(id); if workflow ids have
    # gaps, ids above the count are never scanned -- confirm this is intended.
    current_max_wf_id = window_end
    if current_max_wf_id > total_wf:
        logger.info("Fix inconsistencies starting from workflow_id zero again")
        current_max_wf_id = 0

    # Find workflows in the window where every task has status "D";
    # a CTE is used for better performance.
    done_task_count = func.sum(case((Task.status == "D", 1), else_=0))
    failed_workflows_in_window = select(Workflow.id).where(
        Workflow.id > window_start,
        Workflow.id <= window_end,
        Workflow.status == "F",
    )
    workflow_tasks_cte = (
        select(
            Task.workflow_id,
            func.count(Task.id).label("total_tasks"),
            done_task_count.label("done_tasks"),
        )
        .where(Task.workflow_id.in_(failed_workflows_in_window))
        .group_by(Task.workflow_id)
        .having(func.count(Task.id) == done_task_count)
        .cte("workflow_tasks")
    )
    sql = select(workflow_tasks_cte.c.workflow_id)
    result_list = [row[0] for row in db.execute(sql).all()]

    if not result_list:
        logger.debug("No inconsistent F-D workflows to fix.")
    else:
        # Log the real ids; the message was previously missing its f-prefix.
        logger.info(f"Fixing inconsistent F-D workflow: {result_list}")
        update_stmt = (
            update(Workflow)
            .where(Workflow.id.in_(result_list))
            .values(status="D", status_date=func.now())
        )
        db.execute(update_stmt)
        db.commit()
        logger.debug("Done fixing F-D inconsistent workflows.")

    resp = JSONResponse(
        content={"wfid": current_max_wf_id}, status_code=StatusCodes.OK
    )
    return resp
@api_v3_router.get("/workflow/{workflow_id}/workflow_name_and_args")
def get_wf_name_and_args(workflow_id: int, db: Session = Depends(get_db)) -> Any:
    """Return workflow name and args associated with specified workflow ID.

    Args:
        workflow_id: id of the workflow to look up.
        db: database session supplied by the ``get_db`` dependency.

    Returns:
        A JSON response with keys ``workflow_name`` and ``workflow_args``.
        Both values are ``None`` when no workflow row matches ``workflow_id``.
    """
    sql = select(Workflow.name, Workflow.workflow_args).where(
        Workflow.id == workflow_id
    )
    row = db.execute(sql).first()

    if row is None:
        # Return empty values in case of DB inconsistency. Returning here is
        # required: falling through would index into an empty result.
        return JSONResponse(
            content={"workflow_name": None, "workflow_args": None},
            status_code=StatusCodes.OK,
        )

    return JSONResponse(
        content={"workflow_name": row[0], "workflow_args": row[1]},
        status_code=StatusCodes.OK,
    )
@api_v3_router.get("/lost_workflow_run")
def get_lost_workflow_runs(
    status: Union[str, list[str]] = Query(...),
    version: str = Query(...),
    db: Session = Depends(get_db),
) -> Any:
    """List workflow runs in the given status(es) whose heartbeat date has passed.

    Args:
        status: one status value, or a list of them, to match against
            ``WorkflowRun.status``.
        version: required query parameter; not used by this handler.
        db: database session supplied by the ``get_db`` dependency.

    Returns:
        A JSON response ``{"workflow_runs": [...]}`` where each entry is a
        ``(workflow_run_id, workflow_id)`` pair.
    """
    # A lone string is treated as a one-element list of statuses.
    wanted_statuses = [status] if isinstance(status, str) else status

    stmt = (
        select(WorkflowRun.id, WorkflowRun.workflow_id)
        .where(WorkflowRun.status.in_(wanted_statuses))
        .where(WorkflowRun.heartbeat_date <= func.now())
    )

    lost_runs = []
    for wfr_id, wf_id in db.execute(stmt).all():
        lost_runs.append((wfr_id, wf_id))

    return JSONResponse(
        content={"workflow_runs": lost_runs}, status_code=StatusCodes.OK
    )
@api_v3_router.put("/workflow_run/{workflow_run_id}/reap")
def reap_workflow_run(workflow_run_id: int, db: Session = Depends(get_db)) -> Any:
    """Transition a workflow run whose heartbeat date has passed to a terminal state.

    The run is looked up only if its ``heartbeat_date`` is not later than the
    database's current time. The target states depend on the run's current
    status:

    * LINKING                   -> run ABORTED,    workflow ABORTED
    * COLD_RESUME / HOT_RESUME  -> run TERMINATED, workflow HALTED
    * RUNNING                   -> run ERROR,      workflow FAILED

    Nothing is updated when no such run is found, when the run is in any other
    status, or when the transition is not listed in
    ``WorkflowRun().valid_transitions``.

    Args:
        workflow_run_id: id of the workflow run to reap.
        db: database session supplied by the ``get_db`` dependency.

    Returns:
        A JSON response ``{"status": <new workflow run status>}``, or
        ``{"status": ""}`` when nothing was reaped.
    """
    set_jobmon_context(workflow_run_id=workflow_run_id)
    logger.info(f"Reap wfr: {workflow_run_id}")

    # get the wfr
    sql = select(WorkflowRun.id, WorkflowRun.workflow_id, WorkflowRun.status).where(
        WorkflowRun.id == workflow_run_id,
        WorkflowRun.heartbeat_date <= func.now(),
    )
    rows = db.execute(sql).all()
    if not rows:
        return JSONResponse(content={"status": ""}, status_code=StatusCodes.OK)

    # reap wfr
    wfr_id, wf_id, wfr_status = rows[0][0], rows[0][1], rows[0][2]
    if wfr_status == WorkflowRunStatus.LINKING:
        logger.debug(f"Transitioning wfr {wfr_id} to ABORTED")
        target_wfr_status = WorkflowRunStatus.ABORTED
        target_wf_status = WorkflowStatus.ABORTED
    elif wfr_status in [WorkflowRunStatus.COLD_RESUME, WorkflowRunStatus.HOT_RESUME]:
        logger.debug(f"Transitioning wfr {wfr_id} to TERMINATED")
        target_wfr_status = WorkflowRunStatus.TERMINATED
        target_wf_status = WorkflowStatus.HALTED
    elif wfr_status == WorkflowRunStatus.RUNNING:
        logger.debug(f"Transitioning wfr {wfr_id} to ERROR")
        target_wfr_status = WorkflowRunStatus.ERROR
        target_wf_status = WorkflowStatus.FAILED
    else:
        # Any other status has no reap target; previously this path crashed
        # on an unbound target variable.
        logger.debug(
            f"Unable to reap workflow_run {wfr_id}: status {wfr_status} "
            "is not reapable"
        )
        return JSONResponse(content={"status": ""}, status_code=StatusCodes.OK)

    # validate transition
    if (wfr_status, target_wfr_status) not in WorkflowRun().valid_transitions:
        # Handles a race condition; the exception is built only to reuse its
        # message. Previously the updates ran even after this check failed.
        error = InvalidStateTransition(
            model="WorkflowRun",
            id=wfr_id,
            old_state=wfr_status,
            new_state=target_wfr_status,
        )
        logger.debug(f"Unable to reap workflow_run {wfr_id}: {error}")
        return JSONResponse(content={"status": ""}, status_code=StatusCodes.OK)

    # update status -- bound parameters instead of string-built SQL
    db.execute(
        text("UPDATE workflow_run SET status = :status WHERE id = :wfr_id"),
        {"status": target_wfr_status, "wfr_id": wfr_id},
    )
    db.execute(
        text("UPDATE workflow SET status = :status WHERE id = :wf_id"),
        {"status": target_wf_status, "wf_id": wf_id},
    )
    db.commit()

    return JSONResponse(
        content={"status": target_wfr_status}, status_code=StatusCodes.OK
    )