"""Routes for Workflows."""
from collections import defaultdict
from http import HTTPStatus as StatusCodes
from typing import Any, cast, Dict, List, Optional, Tuple
from flask import jsonify, request
import sqlalchemy
from sqlalchemy import func, select, update
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.orm import Session
import structlog
from jobmon.server.web.models.array import Array
from jobmon.server.web.models.dag import Dag
from jobmon.server.web.models.queue import Queue
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.task_resources import TaskResources
from jobmon.server.web.models.task_status import TaskStatus
from jobmon.server.web.models.workflow import Workflow
from jobmon.server.web.models.workflow_attribute import WorkflowAttribute
from jobmon.server.web.models.workflow_attribute_type import WorkflowAttributeType
from jobmon.server.web.routes.v1 import api_v1_blueprint
from jobmon.server.web.routes.v2 import api_v2_blueprint
from jobmon.server.web.routes.v2 import SessionLocal
from jobmon.server.web.server_side_exception import InvalidUsage
# Module-level structlog logger; the route handlers below attach request
# context (workflow_id, dag_id, ...) via structlog.contextvars before logging.
logger = structlog.get_logger(__name__)
def _add_workflow_attributes(
    workflow_id: int, workflow_attributes: Dict[str, str], session: Session
) -> None:
    """Insert attribute rows for a workflow.

    Every name in *workflow_attributes* is resolved to a WorkflowAttributeType
    id via ``_add_or_get_wf_attribute_type`` (which creates the type when it
    does not exist yet). A WorkflowAttribute row is then built for each pair.
    All of this happens inside one ``session.begin_nested()`` block. These are
    plain inserts, not upserts. In this file the function is only called by
    ``bind_workflow`` for a workflow it has just created.

    Args:
        workflow_id: id of the workflow that owns the attributes.
        workflow_attributes: mapping of attribute name to value.
        session: an active session; the caller owns the outer transaction.
    """
    structlog.contextvars.bind_contextvars(workflow_id=workflow_id)
    logger.info(f"Add Attributes: {workflow_attributes}")

    new_rows: List[WorkflowAttribute] = []
    with session.begin_nested():
        for attr_name, attr_value in workflow_attributes.items():
            new_rows.append(
                WorkflowAttribute(
                    workflow_id=workflow_id,
                    workflow_attribute_type_id=_add_or_get_wf_attribute_type(
                        attr_name, session
                    ),
                    value=attr_value,
                )
            )
            logger.debug(f"Attribute name: {attr_name}, value: {attr_value}")
        session.add_all(new_rows)
@api_v1_blueprint.route("/workflow", methods=["POST"])
@api_v2_blueprint.route("/workflow", methods=["POST"])
def bind_workflow() -> Any:
    """Bind a workflow to the database.

    The workflow is looked up by the combination of tool_version_id, dag_id,
    workflow_args_hash and task_hash.

    - If no row matches, a new workflow is inserted and any supplied
      attributes are added to it.
    - If a row matches, its mutable fields (description, name,
      max_concurrently_running) are overwritten with the request values and
      the supplied attributes are upserted.

    Returns:
        A 200 JSON response with ``workflow_id``, ``status`` and
        ``newly_created``.

    Raises:
        InvalidUsage: with status 400 when the JSON body lacks a required key
            or a value fails its int/str conversion.
    """
    try:
        payload = cast(Dict, request.get_json())
        tv_id = int(payload["tool_version_id"])
        dag_id = int(payload["dag_id"])
        whash = str(payload["workflow_args_hash"])
        thash = str(payload["task_hash"])
        description = payload["description"]
        wf_name = payload["name"]
        workflow_args = payload["workflow_args"]
        max_concurrently_running = payload["max_concurrently_running"]
        workflow_attributes = payload["workflow_attributes"]
    except Exception as e:
        raise InvalidUsage(
            f"{str(e)} in request to {request.path}", status_code=400
        ) from e

    structlog.contextvars.bind_contextvars(
        dag_id=dag_id,
        tool_version_id=tv_id,
        workflow_args_hash=str(whash),
        task_hash=str(thash),
    )
    logger.info("Bind workflow")

    lookup = select(Workflow).where(
        Workflow.tool_version_id == tv_id,
        Workflow.dag_id == dag_id,
        Workflow.workflow_args_hash == whash,
        Workflow.task_hash == thash,
    )

    session = SessionLocal()
    with session.begin():
        workflow = session.execute(lookup).scalars().one_or_none()
        newly_created = workflow is None

        if newly_created:
            workflow = Workflow(
                tool_version_id=tv_id,
                dag_id=dag_id,
                workflow_args_hash=whash,
                task_hash=thash,
                description=description,
                name=wf_name,
                workflow_args=workflow_args,
                max_concurrently_running=max_concurrently_running,
            )
            session.add(workflow)
            session.flush()  # flush before workflow.id is read below
            logger.info("Created new workflow")
            if workflow_attributes and workflow and workflow.id:
                _add_workflow_attributes(workflow.id, workflow_attributes, session)
                session.flush()
        else:
            # Mutable fields are refreshed on every re-bind (this logic was
            # moved here from the set_resume method).
            workflow.description = description
            workflow.name = wf_name
            workflow.max_concurrently_running = max_concurrently_running
            session.flush()
            if workflow_attributes:
                logger.info("Upsert attributes for workflow")
                for attr_name, attr_value in workflow_attributes.items():
                    if workflow and workflow.id:
                        _upsert_wf_attribute(
                            workflow.id, attr_name, attr_value, session
                        )

    resp = jsonify(
        {
            "workflow_id": workflow.id,
            "status": workflow.status,
            "newly_created": newly_created,
        }
    )
    resp.status_code = StatusCodes.OK
    return resp
@api_v1_blueprint.route("/workflow/<workflow_args_hash>", methods=["GET"])
@api_v2_blueprint.route("/workflow/<workflow_args_hash>", methods=["GET"])
def get_matching_workflows_by_workflow_args(workflow_args_hash: str) -> Any:
    """Return the workflows that share the given workflow_args_hash.

    The hash is normalised through ``int`` so that, e.g., leading zeros do not
    prevent a match.

    Returns:
        A 200 JSON response whose ``matching_workflows`` is a list of
        ``[task_hash, tool_version_id, dag_hash]`` entries, one per match.

    Raises:
        InvalidUsage: with status 400 when the hash is not integer-like.
    """
    try:
        workflow_args_hash = str(int(workflow_args_hash))
    except Exception as e:
        raise InvalidUsage(
            f"{str(e)} in request to {request.path}", status_code=400
        ) from e

    structlog.contextvars.bind_contextvars(workflow_args_hash=workflow_args_hash)
    logger.info(f"Looking for wf with hash {workflow_args_hash}")

    session = SessionLocal()
    with session.begin():
        stmt = (
            select(Workflow.task_hash, Workflow.tool_version_id, Dag.hash)
            .join_from(Workflow, Dag, Workflow.dag_id == Dag.id)
            .where(Workflow.workflow_args_hash == workflow_args_hash)
        )
        matches = [
            (row.task_hash, row.tool_version_id, row.hash)
            for row in session.execute(stmt).all()
        ]

    if matches:
        logger.debug(
            f"Found {matches} workflow for workflow_args_hash {workflow_args_hash}"
        )

    resp = jsonify(matching_workflows=matches)
    resp.status_code = StatusCodes.OK
    return resp
def _add_or_get_wf_attribute_type(name: str, session: Session) -> Optional[int]:
    """Return the id of the WorkflowAttributeType called *name*, creating it if needed.

    An insert is attempted first inside ``session.begin_nested()``. If that
    raises IntegrityError (presumably a unique constraint on ``name`` — not
    visible here, confirm against the model), the nested transaction is
    rolled back and the existing row is selected instead.

    Args:
        name: attribute type name.
        session: an active session; the caller owns the outer transaction.

    Returns:
        The attribute type's id.

    Raises:
        ValueError: if no attribute type object was obtained.
    """
    try:
        with session.begin_nested():
            attr_type = WorkflowAttributeType(name=name)
            session.add(attr_type)
    except sqlalchemy.exc.IntegrityError:
        with session.begin_nested():
            attr_type = (
                session.execute(
                    select(WorkflowAttributeType).where(
                        WorkflowAttributeType.name == name
                    )
                )
                .scalars()
                .one()
            )

    if not attr_type:
        raise ValueError(f"Could not find or create attribute type {name}")
    return attr_type.id  # type: ignore
def _upsert_wf_attribute(
    workflow_id: int, name: str, value: str, session: Session
) -> None:
    """Insert or update a single attribute value for a workflow.

    The attribute type id is resolved (and created if missing) first. A
    dialect-specific upsert is then executed, keyed on
    (workflow_id, workflow_attribute_type_id). Everything runs inside
    ``session.begin_nested()``.

    Args:
        workflow_id: id of the workflow the attribute belongs to.
        name: attribute type name.
        value: attribute value to store.
        session: an active session; the caller owns the outer transaction.

    Raises:
        NotImplementedError: if *session* is bound to a dialect other than
            MySQL or SQLite, for which no upsert statement is implemented.
    """
    with session.begin_nested():
        wf_attrib_id = _add_or_get_wf_attribute_type(name, session)
        # Inspect the session we were handed, not the global session factory,
        # so the statement is built for the connection it will execute on.
        dialect_name = session.get_bind().dialect.name

        if dialect_name == "mysql":
            mysql_stmt = mysql_insert(WorkflowAttribute).values(
                workflow_id=workflow_id,
                workflow_attribute_type_id=wf_attrib_id,
                value=value,
            )
            upsert_stmt = mysql_stmt.on_duplicate_key_update(
                value=mysql_stmt.inserted.value
            )
        elif dialect_name == "sqlite":
            sqlite_stmt = sqlite_insert(WorkflowAttribute).values(
                workflow_id=workflow_id,
                workflow_attribute_type_id=wf_attrib_id,
                value=value,
            )
            upsert_stmt = sqlite_stmt.on_conflict_do_update(  # type: ignore
                index_elements=["workflow_id", "workflow_attribute_type_id"],
                set_=dict(value=value),
            )
        else:
            # Without this branch ``upsert_stmt`` would be unbound below and
            # the failure would surface as an opaque UnboundLocalError.
            raise NotImplementedError(
                f"Workflow attribute upsert is not supported for dialect "
                f"{dialect_name!r}"
            )

        session.execute(upsert_stmt)
@api_v1_blueprint.route("/workflow/<workflow_id>/workflow_attributes", methods=["PUT"])
@api_v2_blueprint.route("/workflow/<workflow_id>/workflow_attributes", methods=["PUT"])
def update_workflow_attribute(workflow_id: int) -> Any:
    """Add or update the attributes of a given workflow.

    Expects a JSON body of the form ``{"workflow_attributes": {name: value}}``.
    Each pair is upserted. An empty or falsy mapping is a no-op.

    Args:
        workflow_id: workflow id taken from the URL; must be int-convertible.

    Returns:
        An empty 200 JSON response.

    Raises:
        InvalidUsage: with status 400 when ``workflow_id`` is not an integer,
            or the body is null/not an object/lacks ``workflow_attributes``.
    """
    structlog.contextvars.bind_contextvars(workflow_id=workflow_id)
    try:
        workflow_id = int(workflow_id)
        # Parse the body inside the try so a missing key or null body is
        # reported as a 400 InvalidUsage instead of an unhandled error.
        data = cast(Dict, request.get_json())
        attributes = data["workflow_attributes"]
    except (KeyError, TypeError, ValueError) as e:
        raise InvalidUsage(
            f"{str(e)} in request to {request.path}", status_code=400
        ) from e
    logger.debug("Update attributes")

    if attributes:
        session = SessionLocal()
        with session.begin():
            for attr_name, attr_value in attributes.items():
                _upsert_wf_attribute(workflow_id, attr_name, attr_value, session)

    resp = jsonify()
    resp.status_code = StatusCodes.OK
    return resp
@api_v1_blueprint.route("/workflow/<workflow_id>/set_resume", methods=["POST"])
@api_v2_blueprint.route("/workflow/<workflow_id>/set_resume", methods=["POST"])
def set_resume(workflow_id: int) -> Any:
    """Set resume on a workflow.

    Reads ``reset_running_jobs`` from the JSON body (coerced with ``bool``) and
    passes it to ``Workflow.resume``. If no workflow has this id, nothing is
    changed and the response is still 200.

    Returns:
        An empty 200 JSON response.

    Raises:
        InvalidUsage: with status 400 when the body cannot be read or lacks
            ``reset_running_jobs``.
    """
    structlog.contextvars.bind_contextvars(workflow_id=workflow_id)
    try:
        body = cast(Dict, request.get_json())
        logger.info("Set resume for workflow")
        reset_running_jobs = bool(body["reset_running_jobs"])
    except Exception as e:
        raise InvalidUsage(
            f"{str(e)} in request to {request.path}", status_code=400
        ) from e

    session = SessionLocal()
    with session.begin():
        target = (
            session.execute(select(Workflow).where(Workflow.id == workflow_id))
            .scalars()
            .one_or_none()
        )
        if target:
            # Trigger resume on the active workflow run.
            target.resume(reset_running_jobs)
            session.flush()
            logger.info(f"Resume set for wf {workflow_id}")

    resp = jsonify()
    resp.status_code = StatusCodes.OK
    return resp
@api_v1_blueprint.route("/workflow/<workflow_id>/is_resumable", methods=["GET"])
@api_v2_blueprint.route("/workflow/<workflow_id>/is_resumable", methods=["GET"])
def workflow_is_resumable(workflow_id: int) -> Any:
    """Report whether a workflow is currently in a resumable state.

    Returns:
        A 200 JSON response ``{"workflow_is_resumable": Workflow.is_resumable}``.
        The lookup uses ``.one()``, which raises if no workflow has this id.
    """
    structlog.contextvars.bind_contextvars(workflow_id=workflow_id)
    session = SessionLocal()
    with session.begin():
        wf = (
            session.execute(select(Workflow).where(Workflow.id == workflow_id))
            .scalars()
            .one()
        )
        logger.info(f"Workflow is resumable: {wf.is_resumable}")

    resp = jsonify(workflow_is_resumable=wf.is_resumable)
    resp.status_code = StatusCodes.OK
    return resp
@api_v1_blueprint.route(
    "/workflow/<workflow_id>/get_max_concurrently_running", methods=["GET"]
)
@api_v2_blueprint.route(
    "/workflow/<workflow_id>/get_max_concurrently_running", methods=["GET"]
)
def get_max_concurrently_running(workflow_id: int) -> Any:
    """Return the maximum concurrency of this workflow.

    Returns:
        A 200 JSON response ``{"max_concurrently_running": <value>}``.
        The lookup uses ``.one()``, which raises if no workflow has this id.
    """
    structlog.contextvars.bind_contextvars(workflow_id=workflow_id)
    session = SessionLocal()
    with session.begin():
        wf = (
            session.execute(select(Workflow).where(Workflow.id == workflow_id))
            .scalars()
            .one()
        )

    resp = jsonify(max_concurrently_running=wf.max_concurrently_running)
    resp.status_code = StatusCodes.OK
    return resp
@api_v1_blueprint.route(
    "/workflow/<workflow_id>/update_max_concurrently_running", methods=["PUT"]
)
@api_v2_blueprint.route(
    "/workflow/<workflow_id>/update_max_concurrently_running", methods=["PUT"]
)
def update_max_running(workflow_id: int) -> Any:
    """Update the number of tasks that can be running concurrently for a given workflow.

    Expects a JSON body containing ``max_tasks``. Its value is written as-is to
    ``Workflow.max_concurrently_running``.

    Returns:
        A 200 JSON response whose ``message`` states whether a row was updated.
        No error is raised when the UPDATE affects no rows (e.g. no workflow
        with this id).

    Raises:
        InvalidUsage: with status 400 when the body is null/not an object or
            lacks ``max_tasks``.
    """
    data = cast(Dict, request.get_json())
    structlog.contextvars.bind_contextvars(workflow_id=workflow_id)
    logger.debug("Update workflow max concurrently running")
    try:
        new_limit = data["max_tasks"]
    except (KeyError, TypeError) as e:
        # TypeError covers a null / non-object JSON body.
        raise InvalidUsage(
            f"{str(e)} in request to {request.path}", status_code=400
        ) from e

    session = SessionLocal()
    with session.begin():
        update_stmt = (
            update(Workflow)
            .where(Workflow.id == workflow_id)
            .values(max_concurrently_running=new_limit)
        )
        res = session.execute(update_stmt)

    if res.rowcount == 0:  # Return a warning message if no update was performed
        message = (
            f"No update performed for workflow ID {workflow_id}, max_concurrently_running is "
            f"{new_limit}"
        )
    else:
        message = (
            f"Workflow ID {workflow_id} max concurrently running updated to {new_limit}"
        )

    resp = jsonify(message=message)
    resp.status_code = StatusCodes.OK
    return resp
@api_v1_blueprint.route("/workflow/<workflow_id>/task_status_updates", methods=["POST"])
@api_v2_blueprint.route("/workflow/<workflow_id>/task_status_updates", methods=["POST"])
def task_status_updates(workflow_id: int) -> Any:
    """Return the ids of a workflow's tasks grouped by their current status.

    If the JSON body is an object containing ``last_sync``, only tasks whose
    ``status_date`` is greater than or equal to that value are included.
    Otherwise, including when the body is null or not an object, every task
    in the workflow is included.

    Args:
        workflow_id (int): the ID of the workflow.

    Returns:
        A 200 JSON response with:

        - ``tasks_by_status``: mapping of status to a list of task ids.
        - ``time``: the database's ``now()`` formatted as
          ``%Y-%m-%d %H:%M:%S``, or null if no time was returned.
    """
    structlog.contextvars.bind_contextvars(workflow_id=workflow_id)
    data = request.get_json()
    logger.info("Get task by status")

    filter_criteria: List[Any] = [Task.workflow_id == workflow_id]
    # Check the body's shape explicitly: indexing a null/list body used to raise
    # an uncaught TypeError (only KeyError was handled).
    if isinstance(data, dict) and "last_sync" in data:
        filter_criteria.append(Task.status_date >= data["last_sync"])

    session = SessionLocal()
    with session.begin():
        # Read the DB clock in the same transaction as the task query.
        db_time = session.execute(select(func.now())).scalar()
        str_time = db_time.strftime("%Y-%m-%d %H:%M:%S") if db_time else None

        tasks_by_status_query = select(Task.status, Task.id).where(*filter_criteria)
        result_dict: Dict[Any, List[Any]] = defaultdict(list)
        for row in session.execute(tasks_by_status_query):
            result_dict[row.status].append(row.id)

    resp = jsonify(tasks_by_status=result_dict, time=str_time)
    resp.status_code = StatusCodes.OK
    return resp
@api_v1_blueprint.route(
    "/workflow/<workflow_id>/fetch_workflow_metadata", methods=["GET"]
)
@api_v2_blueprint.route(
    "/workflow/<workflow_id>/fetch_workflow_metadata", methods=["GET"]
)
@api_v1_blueprint.route("/workflow/get_tasks/<workflow_id>", methods=["GET"])
@api_v2_blueprint.route("/workflow/get_tasks/<workflow_id>", methods=["GET"])
def get_tasks_from_workflow(workflow_id: int) -> Any:
    """Return tasks associated with specified Workflow ID, one chunk at a time.

    Only tasks whose status is not DONE are returned.

    NOTE(review): the ``fetch_workflow_metadata`` routes are stacked on this
    same view function — confirm that is intended.

    Query parameters:
        max_task_id: only tasks with ``id`` strictly greater than this are
            returned. Defaults to 0. When 0, the floor is replaced by (smallest
            task id in the workflow - 1), because filtering on
            ``task.id > 0`` performs poorly.
        chunk_size: maximum number of tasks to return; no limit if omitted.

    Returns:
        A 200 JSON response ``{"tasks": {task_id: [...]}}``. Each task's list
        contains, in order: array_id, status, max_attempts, resource_scales,
        fallback_queues, requested_resources, cluster_name, queue_name, and
        the array's max_concurrently_running.

    Raises:
        InvalidUsage: with status 400 when a query parameter is not an integer.
    """
    try:
        # Query-string values arrive as str. Convert them so the ``== 0`` fast
        # path below can actually trigger (the string "0" never equals 0) and
        # so Task.id is compared with an integer.
        max_task_id = int(request.args.get("max_task_id", 0))
        raw_chunk_size = request.args.get("chunk_size")
        chunk_size = None if raw_chunk_size is None else int(raw_chunk_size)
    except ValueError as e:
        raise InvalidUsage(
            f"{str(e)} in request to {request.path}", status_code=400
        ) from e

    session = SessionLocal()
    if max_task_id == 0:
        # Performance suffers heavily with WHERE task.id > 0, so use the
        # smallest task id in the workflow as the initial floor instead.
        with session.begin():
            min_task_id = session.execute(
                select(func.min(Task.id)).where(Task.workflow_id == workflow_id)
            ).scalar()
        max_task_id = min_task_id - 1 if min_task_id else 0

    with session.begin():
        query = (
            select(
                Task.id,
                Task.array_id,
                Task.status,
                Task.max_attempts,
                Task.resource_scales,
                Task.fallback_queues,
                TaskResources.requested_resources,
                TaskResources.queue_id,
            )
            .join_from(Task, Array, Task.array_id == Array.id)
            .join_from(Task, TaskResources, Task.task_resources_id == TaskResources.id)
            .where(
                Task.workflow_id == workflow_id,
                # Because of this status != DONE filter, only the unfinished
                # portion of the DAG is returned. Assumes all tasks in a
                # workflow belong to the same DAG, and that no downstream node
                # can be DONE for any unfinished task.
                Task.status != TaskStatus.DONE,
                Task.id > max_task_id,
            )
            .order_by(Task.id)
            .limit(chunk_size)
        )
        rows = session.execute(query).all()

        # Group task ids by queue and by array so each Queue / Array row is
        # looked up once.
        queue_map: Dict[int, List[int]] = defaultdict(list)
        array_map: Dict[int, List[int]] = defaultdict(list)
        resp_dict: Dict[int, List[Any]] = {}
        for row in rows:
            task_id = row[0]
            # row[1:7] is array_id .. requested_resources (queue_id excluded).
            resp_dict[task_id] = list(row[1:7])
            queue_map[row[7]].append(task_id)
            array_map[row[1]].append(task_id)

        # Append cluster and queue names.
        for queue_id, queue_task_ids in queue_map.items():
            queue = session.get(Queue, queue_id)
            queue_name = queue.name if queue else None
            cluster_name = queue.cluster.name if queue else None  # type: ignore
            for task_id in queue_task_ids:
                resp_dict[task_id].extend([cluster_name, queue_name])

        # Append the array-level concurrency limit.
        for array_id, array_task_ids in array_map.items():
            array: Any = session.get(Array, array_id)
            max_concurrently_running = array.max_concurrently_running
            for task_id in array_task_ids:
                resp_dict[task_id].append(max_concurrently_running)

    resp = jsonify(tasks=resp_dict)
    resp.status_code = StatusCodes.OK
    return resp