Source code for server.web.routes.v2.reaper.reaper

"""Routes used to move through the finite state."""

from http import HTTPStatus as StatusCodes
from typing import Any, Sequence, Tuple

from flask import jsonify, request
from sqlalchemy import func, Row, Select, select, text, update
import structlog

from jobmon.core.exceptions import InvalidStateTransition
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.workflow import Workflow
from jobmon.server.web.models.workflow_run import WorkflowRun
from jobmon.server.web.models.workflow_run_status import WorkflowRunStatus
from jobmon.server.web.models.workflow_status import WorkflowStatus
from jobmon.server.web.routes.v1 import api_v1_blueprint
from jobmon.server.web.routes.v2 import api_v2_blueprint
from jobmon.server.web.routes.v2 import SessionLocal

# Module-level structlog logger shared by every route in this file. Request-specific
# fields are attached through structlog.contextvars (see reap_workflow_run).
# NOTE(review): the previous comment said a new logger is created per flask request
# and stored as flask.g.logger; nothing in this file does that -- confirm elsewhere.
logger = structlog.get_logger(__name__)
@api_v1_blueprint.route(
    "/workflow/<workflow_id>/fix_status_inconsistency", methods=["PUT"]
)
@api_v2_blueprint.route(
    "/workflow/<workflow_id>/fix_status_inconsistency", methods=["PUT"]
)
def fix_wf_inconsistency(workflow_id: int) -> Any:
    """Move workflows in status "F" whose tasks are all in status "D" to "D".

    Only workflows with ``workflow_id < id <= workflow_id + increase_step`` are
    examined per call, so the caller can sweep the table slice by slice. The step
    size is read from the JSON request body (key ``increase_step``); passing it in
    means it can be tuned by redeploying the reaper instead of the service.

    Args:
        workflow_id: exclusive lower bound of the id window. It is taken from the
            URL and converted with ``int()`` before use.

    Returns:
        A 200 JSON response ``{"wfid": <next start id>}``. The next start id is
        ``workflow_id + increase_step``, or 0 once that value exceeds the number
        of workflow rows, which restarts the sweep.
    """
    data = request.get_json()
    # Normalize both bounds once so the log line, the window arithmetic and the
    # SQL comparisons all use the same integer values.
    start_wf_id = int(workflow_id)
    increase_step = int(data["increase_step"])
    window_end = start_wf_id + increase_step

    logger.debug(
        f"Fix inconsistencies starting at workflow {start_wf_id} by {increase_step}"
    )

    session = SessionLocal()
    with session.begin():
        # Let the database count instead of loading every workflow id.
        total_wf = session.execute(select(func.count(Workflow.id))).scalar_one()

    # Move the starting row forward by increase_step. It takes about 1 second per
    # thousand; increase_step is passed in from the reaper. When the window runs
    # past the end, restart from workflow id 0 so unfinished workflows are
    # revisited later without querying the whole db every time.
    # NOTE(review): the bound is the row COUNT, not MAX(id); if ids have gaps the
    # sweep wraps before reaching the highest ids -- confirm this is intended.
    current_max_wf_id = window_end
    if current_max_wf_id > total_wf:
        logger.info("Fix inconsistencies starting from workflow_id zero again")
        current_max_wf_id = 0

    # Fetch one (workflow id, task status) row per task of every "F" workflow in
    # the window.
    session = SessionLocal()
    with session.begin():
        sql: Select[Tuple[int, str]] = select(Workflow.id, Task.status).where(
            Workflow.id > start_wf_id,
            Workflow.id <= window_end,
            Workflow.status == "F",
            Workflow.id == Task.workflow_id,
        )
        rows: Sequence[Row[Tuple[int, str]]] = session.execute(sql).all()

    # A workflow is inconsistent when none of its tasks is outside "D".
    failed_wf_ids = {r[0] for r in rows}
    unfinished_wf_ids = {r[0] for r in rows if r[1] != "D"}
    result_list = sorted(failed_wf_ids - unfinished_wf_ids)

    if not result_list:
        logger.debug("No inconsistent F-D workflows to fix.")
    else:
        logger.info(f"Fixing inconsistent F-D workflow: {result_list}")
        session = SessionLocal()
        # The begin() context manager commits on successful exit.
        with session.begin():
            update_stmt = (
                update(Workflow)
                .where(Workflow.id.in_(result_list))
                .values(status="D", status_date=func.now())
            )
            session.execute(update_stmt)
    logger.debug("Done fixing F-D inconsistent workflows.")

    resp = jsonify({"wfid": current_max_wf_id})
    resp.status_code = StatusCodes.OK
    return resp
@api_v1_blueprint.route(
    "/workflow/<workflow_id>/workflow_name_and_args", methods=["GET"]
)
@api_v2_blueprint.route(
    "/workflow/<workflow_id>/workflow_name_and_args", methods=["GET"]
)
def get_wf_name_and_args(workflow_id: int) -> Any:
    """Look up the name and workflow_args of the workflow with the given id.

    Returns:
        A 200 JSON response with keys ``workflow_name`` and ``workflow_args``.
        Both values are null when no workflow row matches ``workflow_id``.
    """
    session = SessionLocal()
    with session.begin():
        stmt = select(Workflow.name, Workflow.workflow_args).where(
            Workflow.id == workflow_id
        )
        matches = session.execute(stmt).all()

    if matches:
        wf_name, wf_args = matches[0][0], matches[0][1]
    else:
        # No such workflow: answer with empty values in case of DB inconsistency
        wf_name, wf_args = None, None

    response = jsonify(workflow_name=wf_name, workflow_args=wf_args)
    response.status_code = StatusCodes.OK
    return response
@api_v1_blueprint.route("/lost_workflow_run", methods=["GET"])
@api_v2_blueprint.route("/lost_workflow_run", methods=["GET"])
def get_lost_workflow_runs() -> Any:
    """List workflow runs matching the requested statuses and server version.

    Query-string arguments:
        status: may be repeated; a run is selected when its status is one of them.
        version: a run is selected only when its jobmon_server_version equals it.

    A run must additionally have a heartbeat_date that is not later than the
    database's current time.

    Returns:
        A 200 JSON response whose ``workflow_runs`` key holds one
        (workflow run id, workflow id) pair per selected run.
    """
    wanted_version = request.args.get("version")
    wanted_statuses = request.args.getlist("status")

    stmt = select(WorkflowRun.id, WorkflowRun.workflow_id).where(
        WorkflowRun.status.in_(wanted_statuses),
        WorkflowRun.heartbeat_date <= func.now(),
        WorkflowRun.jobmon_server_version == wanted_version,
    )

    session = SessionLocal()
    with session.begin():
        lost_runs = [(row[0], row[1]) for row in session.execute(stmt).all()]

    response = jsonify(workflow_runs=lost_runs)
    response.status_code = StatusCodes.OK
    return response
def _reap_response(status: Any) -> Any:
    """Build the 200 JSON reply ``{"status": status}`` used by reap_workflow_run."""
    resp = jsonify(status=status)
    resp.status_code = StatusCodes.OK
    return resp


@api_v1_blueprint.route("/workflow_run/<workflow_run_id>/reap", methods=["PUT"])
@api_v2_blueprint.route("/workflow_run/<workflow_run_id>/reap", methods=["PUT"])
def reap_workflow_run(workflow_run_id: int) -> Any:
    """Reap a workflow run whose heartbeat is overdue.

    The run is looked up by id and is only considered when its heartbeat_date is
    not later than the database's current time. Its current status decides the
    new status of the run and of its workflow:

    * LINKING                  -> run ABORTED,    workflow ABORTED
    * COLD_RESUME / HOT_RESUME -> run TERMINATED, workflow HALTED
    * RUNNING                  -> run ERROR,      workflow FAILED

    Nothing is written when the run is not found, when its status is none of the
    above, or when the resulting transition is not listed in
    ``WorkflowRun().valid_transitions``.

    Args:
        workflow_run_id: id of the workflow run to reap, taken from the URL.

    Returns:
        A 200 JSON response ``{"status": <new run status>}``, or
        ``{"status": ""}`` when nothing was reaped.
    """
    structlog.contextvars.bind_contextvars(workflow_run_id=workflow_run_id)
    logger.info(f"Reap wfr: {workflow_run_id}")

    session = SessionLocal()
    with session.begin():
        # get the wfr
        sql = select(WorkflowRun.id, WorkflowRun.workflow_id, WorkflowRun.status).where(
            WorkflowRun.id == workflow_run_id,
            WorkflowRun.heartbeat_date <= func.now(),
        )
        rows = session.execute(sql).all()

    if not rows:
        return _reap_response("")

    wfr_id, wf_id, wfr_status = rows[0][0], rows[0][1], rows[0][2]

    # Pick the target statuses for the run and for its workflow.
    if wfr_status == WorkflowRunStatus.LINKING:
        logger.debug(f"Transitioning wfr {wfr_id} to ABORTED")
        target_wfr_status = WorkflowRunStatus.ABORTED
        target_wf_status = WorkflowStatus.ABORTED
    elif wfr_status in [WorkflowRunStatus.COLD_RESUME, WorkflowRunStatus.HOT_RESUME]:
        logger.debug(f"Transitioning wfr {wfr_id} to TERMINATED")
        target_wfr_status = WorkflowRunStatus.TERMINATED
        target_wf_status = WorkflowStatus.HALTED
    elif wfr_status == WorkflowRunStatus.RUNNING:
        logger.debug(f"Transitioning wfr {wfr_id} to ERROR")
        target_wfr_status = WorkflowRunStatus.ERROR
        target_wf_status = WorkflowStatus.FAILED
    else:
        # Previously this fell through with the targets unbound and crashed with
        # UnboundLocalError; a run in any other status is simply not reapable.
        logger.debug(
            f"Unable to reap workflow_run {wfr_id}: status {wfr_status} is not reapable"
        )
        return _reap_response("")

    # Validate the transition and skip the update when it is not allowed. The
    # exception object is only built to reuse its message in the log line.
    if (wfr_status, target_wfr_status) not in WorkflowRun().valid_transitions:
        invalid = InvalidStateTransition(
            model="WorkflowRun",
            id=wfr_id,
            old_state=wfr_status,
            new_state=target_wfr_status,
        )
        logger.debug(f"Unable to reap workflow_run {wfr_id}: {invalid}")
        return _reap_response("")

    # Update both rows in one transaction; begin() commits on successful exit.
    # Bound parameters replace string interpolation and the MySQL-only
    # double-quoted string literals.
    session = SessionLocal()
    with session.begin():
        session.execute(
            text("UPDATE workflow_run SET status = :status WHERE id = :row_id"),
            {"status": target_wfr_status, "row_id": wfr_id},
        )
        session.execute(
            text("UPDATE workflow SET status = :status WHERE id = :row_id"),
            {"status": target_wf_status, "row_id": wf_id},
        )

    return _reap_response(target_wfr_status)