"""Routes for Arrays."""
from collections import defaultdict
from http import HTTPStatus as StatusCodes
from typing import Any, cast, Dict
from flask import jsonify, request
from sqlalchemy import and_, case, func, insert, literal_column, select, update
import structlog
from jobmon.core.constants import TaskInstanceStatus
from jobmon.server.web._compat import add_time
from jobmon.server.web.models.array import Array
from jobmon.server.web.models.task import Task
from jobmon.server.web.models.task_instance import TaskInstance
from jobmon.server.web.models.task_status import TaskStatus
from jobmon.server.web.routes.v1 import api_v1_blueprint
from jobmon.server.web.routes.v2 import api_v2_blueprint
from jobmon.server.web.routes.v2 import SessionLocal
[docs]
logger = structlog.get_logger(__name__)
@api_v1_blueprint.route("/array", methods=["POST"])
@api_v2_blueprint.route("/array", methods=["POST"])
[docs]
def add_array() -> Any:
"""Return an array ID by workflow and task template version ID.
If not found, bind the array.
"""
data = cast(Dict, request.get_json())
workflow_id = int(data["workflow_id"])
task_template_version_id = int(data["task_template_version_id"])
structlog.contextvars.bind_contextvars(
task_template_version_id=task_template_version_id,
workflow_id=workflow_id,
)
# Check if the array is already bound, if so return it
session = SessionLocal()
with session.begin():
select_stmt = select(Array).where(
Array.workflow_id == workflow_id,
Array.task_template_version_id == task_template_version_id,
)
array = session.execute(select_stmt).scalars().one_or_none()
if array is None: # not found, so need to add it
array = Array(
task_template_version_id=data["task_template_version_id"],
workflow_id=data["workflow_id"],
max_concurrently_running=data["max_concurrently_running"],
name=data["name"],
)
session.add(array)
else:
array_condition = and_(Array.id == array.id)
# take a lock on the array that will be updated
array_locks = (
select(Array.id)
.where(array_condition)
.with_for_update()
.execution_options(synchronize_session=False)
)
session.execute(array_locks)
# update the array with the new max_concurrently_running
update_stmt = (
update(Array)
.where(array_condition)
.values(max_concurrently_running=data["max_concurrently_running"])
)
session.execute(update_stmt)
session.commit()
# return result
resp = jsonify(array_id=array.id)
resp.status_code = StatusCodes.OK
return resp
@api_v1_blueprint.route("/array/<array_id>/queue_task_batch", methods=["POST"])
@api_v2_blueprint.route("/array/<array_id>/queue_task_batch", methods=["POST"])
[docs]
def record_array_batch_num(array_id: int) -> Any:
"""Record a batch number to associate sets of task instances with an array submission."""
data = cast(Dict, request.get_json())
array_id = int(array_id)
task_ids = [int(task_id) for task_id in data["task_ids"]]
task_resources_id = int(data["task_resources_id"])
workflow_run_id = int(data["workflow_run_id"])
task_condition = and_(
Task.id.in_(task_ids),
Task.status.in_([TaskStatus.REGISTERING, TaskStatus.ADJUSTING_RESOURCES]),
)
session = SessionLocal()
with session.begin():
# Acquire locks on tasks to be updated
task_locks = (
select(Task.id)
.where(task_condition)
.with_for_update()
.execution_options(synchronize_session=False)
)
session.execute(task_locks)
# update task status to acquire lock
update_stmt = (
update(Task)
.where(task_condition)
.values(
status=TaskStatus.QUEUED,
status_date=func.now(),
num_attempts=(Task.num_attempts + 1),
)
)
session.execute(update_stmt)
# now insert them into task instance
insert_stmt = insert(TaskInstance).from_select(
# columns map 1:1 to selected rows
[
"task_id",
"workflow_run_id",
"array_id",
"task_resources_id",
"array_batch_num",
"array_step_id",
"status",
"status_date",
],
# select statement
select(
# unique id
Task.id.label("task_id"),
# static associations
literal_column(str(workflow_run_id)).label("workflow_run_id"),
literal_column(str(array_id)).label("array_id"),
literal_column(str(task_resources_id)).label("task_resources_id"),
# batch info
select(func.coalesce(func.max(TaskInstance.array_batch_num) + 1, 1))
.where((TaskInstance.array_id == array_id))
.label("array_batch_num"),
(func.row_number().over(order_by=Task.id) - 1).label("array_step_id"),
# status columns
literal_column(f"'{TaskInstanceStatus.QUEUED}'").label("status"),
func.now().label("status_date"),
)
.where(Task.id.in_(task_ids), Task.status == TaskStatus.QUEUED)
.with_for_update(),
# no python side defaults. Server defaults only
include_defaults=False,
)
session.execute(insert_stmt)
with session.begin():
tasks_by_status_query = (
select(Task.status, Task.id)
.where(Task.id.in_(task_ids))
.with_for_update()
.order_by(
Task.status
) # This line is optional but helps in organizing the result
)
result_dict = defaultdict(list)
for row in session.execute(tasks_by_status_query):
result_dict[row[0]].append(row[1])
resp = jsonify(tasks_by_status=result_dict)
resp.status_code = StatusCodes.OK
return resp
@api_v1_blueprint.route("/array/<array_id>/transition_to_launched", methods=["POST"])
@api_v2_blueprint.route("/array/<array_id>/transition_to_launched", methods=["POST"])
[docs]
def transition_array_to_launched(array_id: int) -> Any:
"""Transition TIs associated with an array_id and batch_num to launched."""
structlog.contextvars.bind_contextvars(array_id=array_id)
data = cast(Dict, request.get_json())
batch_num = data["batch_number"]
next_report = data["next_report_increment"]
session = SessionLocal()
with session.begin():
# Acquire a lock and update tasks to launched
task_ids_query = (
select(TaskInstance.task_id)
.where(
TaskInstance.array_id == array_id,
TaskInstance.array_batch_num == batch_num,
)
.execution_options(synchronize_session=False)
)
task_ids = session.execute(task_ids_query).scalars()
task_condition = and_(
Task.array_id == array_id,
Task.id.in_(task_ids),
Task.status == TaskStatus.INSTANTIATING,
)
task_locks = (
select(Task.id)
.where(task_condition)
.with_for_update()
.execution_options(synchronize_session=False)
)
session.execute(task_locks)
update_task_stmt = (
update(Task)
.where(task_condition)
.values(status=TaskStatus.LAUNCHED, status_date=func.now())
).execution_options(synchronize_session=False)
session.execute(update_task_stmt)
# Update the task instances in a separate session
_update_task_instance(array_id, batch_num, next_report)
resp = jsonify()
resp.status_code = StatusCodes.OK
return resp
[docs]
def _update_task_instance(array_id: int, batch_num: int, next_report: int) -> None:
task_instance_condition = and_(
TaskInstance.array_id == array_id,
TaskInstance.status == TaskInstanceStatus.INSTANTIATED,
TaskInstance.array_batch_num == batch_num,
)
session = SessionLocal()
with session.begin():
# Acquire a lock and update tasks to launched
task_instance_ids_query = (
select(TaskInstance.id)
.where(task_instance_condition)
.with_for_update()
.execution_options(synchronize_session=False)
)
session.execute(task_instance_ids_query)
# Transition all the task instances in the batch
# Bypassing the ORM for performance reasons.
update_stmt = (
update(TaskInstance)
.where(task_instance_condition)
.values(
status=TaskInstanceStatus.LAUNCHED,
submitted_date=func.now(),
status_date=func.now(),
report_by_date=add_time(next_report),
)
).execution_options(synchronize_session=False)
session.execute(update_stmt)
@api_v1_blueprint.route("/array/<array_id>/log_distributor_id", methods=["POST"])
@api_v2_blueprint.route("/array/<array_id>/log_distributor_id", methods=["POST"])
[docs]
def log_array_distributor_id(array_id: int) -> Any:
"""Add distributor_id, stderr/stdout paths to the DB for all TIs in an array."""
data = request.get_json()
id_lst = list(data.keys())
where_condition = and_(
TaskInstance.id.in_(id_lst),
TaskInstance.array_id == array_id,
)
# Prepare to acquire locks on the task instances
task_instance_ids_query = (
select(TaskInstance.id)
.where(where_condition)
.with_for_update()
.execution_options(synchronize_session=False)
)
# Prepare the case statement for dynamic updating based on conditions
case_stmt = case(
*[
(TaskInstance.id == int(task_instance_id), distributor_id)
for task_instance_id, distributor_id in data.items()
],
else_=TaskInstance.distributor_id,
)
# Acquire locks and update TaskInstances
session = SessionLocal()
with session.begin():
# locks for the updates
session.execute(task_instance_ids_query)
# Using the session to construct an update statement for ORM objects
update_stmt = (
update(TaskInstance)
.where(where_condition)
.values(distributor_id=case_stmt)
.execution_options(synchronize_session="fetch")
)
# updates
session.execute(update_stmt)
resp = jsonify(success=True)
resp.status_code = StatusCodes.OK
return resp