"""Routes for Tasks."""
from http import HTTPStatus as StatusCodes
from typing import Any, cast, Dict
from flask import jsonify, request
import sqlalchemy
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError, OperationalError
from sqlalchemy.orm import Session
import structlog
from jobmon.core import constants
from jobmon.server.web.models.arg import Arg
from jobmon.server.web.models.task_template import TaskTemplate
from jobmon.server.web.models.task_template_version import TaskTemplateVersion
from jobmon.server.web.models.template_arg_map import TemplateArgMap
from jobmon.server.web.routes.v1 import api_v1_blueprint
from jobmon.server.web.routes.v2 import api_v2_blueprint
from jobmon.server.web.routes.v2 import SessionLocal
from jobmon.server.web.server_side_exception import InvalidUsage
[docs]
# Module-level structlog logger used by the route handlers in this file.
logger = structlog.get_logger(__name__)
@api_v1_blueprint.route("/task_template", methods=["POST"])
@api_v2_blueprint.route("/task_template", methods=["POST"])
def get_task_template() -> Any:
    """Register a task template for a tool version, or look up the existing one.

    Reads ``tool_version_id`` and ``task_template_name`` from the JSON request
    body and attempts to insert a matching ``TaskTemplate`` row. When the
    insert raises ``IntegrityError``, the row with the same tool version and
    name is selected instead.

    Returns:
        A JSON response carrying ``task_template_id`` with HTTP status 200.

    Raises:
        InvalidUsage: with status 400 when reading or converting either
            request field fails for any reason.
    """
    # Validate the request payload.
    payload = cast(Dict, request.get_json())
    try:
        tool_version_id = int(payload["tool_version_id"])
        name = payload["task_template_name"]
    except Exception as e:
        message = f"{str(e)} in request to {request.path}"
        raise InvalidUsage(message, status_code=400) from e

    structlog.contextvars.bind_contextvars(tool_version_id=tool_version_id)
    logger.info(f"Add task tamplate for tool_version_id {tool_version_id} ")

    # Persist the template; the transaction ends when the ``with`` block exits.
    session = SessionLocal()
    try:
        with session.begin():
            task_template = TaskTemplate(tool_version_id=tool_version_id, name=name)
            session.add(task_template)
    except sqlalchemy.exc.IntegrityError:
        # The insert was rejected -- presumably the (tool_version_id, name)
        # pair already exists -- so read the stored row back instead.
        with session.begin():
            lookup_stmt = select(TaskTemplate).where(
                TaskTemplate.tool_version_id == tool_version_id,
                TaskTemplate.name == name,
            )
            matches = session.execute(lookup_stmt).scalars()
            task_template = matches.one()

    response = jsonify(task_template_id=task_template.id)
    response.status_code = StatusCodes.OK
    return response
@api_v1_blueprint.route("/task_template/<task_template_id>/versions", methods=["GET"])
@api_v2_blueprint.route("/task_template/<task_template_id>/versions", methods=["GET"])
def get_task_template_versions(task_template_id: int) -> Any:
    """List every version stored for one task template.

    Args:
        task_template_id: identifier taken from the URL path.
            NOTE(review): the route declares no ``int:`` converter, so Flask
            presumably passes this as a string despite the annotation --
            confirm.

    Returns:
        A JSON response with HTTP status 200 whose ``task_template_versions``
        key holds one wire-format entry per matching version.
    """
    structlog.contextvars.bind_contextvars(task_template_id=task_template_id)
    logger.info(f"Getting task template version for task template: {task_template_id}")

    session = SessionLocal()
    with session.begin():
        query = select(TaskTemplateVersion).where(
            TaskTemplateVersion.task_template_id == task_template_id
        )
        versions = session.execute(query).scalars().all()

        # Convert to the wire format before the transaction block exits.
        wire_versions = []
        for version in versions:
            wire_versions.append(version.to_wire_as_client_task_template_version())

    response = jsonify(task_template_versions=wire_versions)
    response.status_code = StatusCodes.OK
    return response
[docs]
def _add_or_get_arg(name: str, session: Session) -> Arg:
    """Insert an ``Arg`` row called ``name``, or fetch the one already stored.

    The insert is attempted first. If it raises ``IntegrityError`` the row
    with that name is selected instead (presumably ``Arg.name`` carries a
    unique constraint -- confirm against the model). An ``OperationalError``
    whose message contains ``"Deadlock"`` causes the attempt to be repeated.

    Args:
        name: the argument name to insert or look up.
        session: the session used for both the insert and the fallback
            select; each runs in its own ``session.begin()`` transaction.

    Returns:
        The ``Arg`` that was inserted or found.

    Raises:
        OperationalError: immediately for any non-deadlock operational
            error, and for a deadlock once the retry budget is exhausted.
            An ``OperationalError`` raised by the fallback select is not
            retried either, because it occurs inside the ``except`` handler.
    """
    # Five retries after the first attempt, i.e. at most six attempts overall.
    max_deadlock_retries = 5
    deadlocks_seen = 0

    while True:
        try:
            with session.begin():
                arg = Arg(name=name)
                session.add(arg)
            # Only reached when the transaction above ended without raising.
            return arg
        except IntegrityError:
            with session.begin():
                select_stmt = select(Arg).where(Arg.name == name)
                arg = session.execute(select_stmt).scalars().one()
            return arg
        except OperationalError as e:
            if "Deadlock" not in str(e):
                raise  # unrelated operational failure: propagate unchanged
            deadlocks_seen += 1
            if deadlocks_seen > max_deadlock_retries:
                # Retry budget exhausted: surface the deadlock rather than
                # returning an Arg whose insert was rolled back.
                raise
@api_v1_blueprint.route(
    "/task_template/<task_template_id>/add_version", methods=["POST"]
)
@api_v2_blueprint.route(
    "/task_template/<task_template_id>/add_version", methods=["POST"]
)
def add_task_template_version(task_template_id: int) -> Any:
    """Create a version of a task template, or return the one already stored.

    Reads ``node_args``, ``task_args``, ``op_args``, ``command_template`` and
    ``arg_mapping_hash`` from the JSON request body. Each argument name is
    resolved to an ``Arg`` row via ``_add_or_get_arg``. A
    ``TaskTemplateVersion`` row and one ``TemplateArgMap`` row per argument
    are then inserted in a single transaction. When that transaction raises
    ``IntegrityError``, the version matching the same template id, command
    template and arg mapping hash is selected instead.

    Args:
        task_template_id: identifier taken from the URL path; converted with
            ``int()`` before use.

    Returns:
        A JSON response with HTTP status 200 whose ``task_template_version``
        key holds the wire-format representation of the version.

    Raises:
        InvalidUsage: with status 400 when converting the id or reading any
            of the request fields fails for any reason.
    """
    structlog.contextvars.bind_contextvars(task_template_id=task_template_id)

    # Validate the request payload.
    payload = cast(Dict, request.get_json())
    try:
        task_template_id = int(task_template_id)
        node_args = payload["node_args"]
        task_args = payload["task_args"]
        op_args = payload["op_args"]
        command_template = payload["command_template"].strip()
        arg_mapping_hash = str(payload["arg_mapping_hash"]).strip()
    except Exception as e:
        message = f"{str(e)} in request to {request.path}"
        raise InvalidUsage(message, status_code=400) from e

    session = SessionLocal()

    # Resolve every argument name to an Arg row, grouped by argument type and
    # processed in node -> task -> op order.
    names_by_type = (
        (constants.ArgType.NODE_ARG, node_args),
        (constants.ArgType.TASK_ARG, task_args),
        (constants.ArgType.OP_ARG, op_args),
    )
    arg_mapping_dct: dict = {arg_type: [] for arg_type, _ in names_by_type}
    for arg_type, arg_names in names_by_type:
        for arg_name in arg_names:
            arg_mapping_dct[arg_type].append(_add_or_get_arg(arg_name, session))

    try:
        with session.begin():
            ttv = TaskTemplateVersion(
                task_template_id=task_template_id,
                command_template=command_template,
                arg_mapping_hash=arg_mapping_hash,
            )
            session.add(ttv)
            session.flush()

            # Re-read the new row FOR UPDATE so it stays locked while its
            # argument mappings are written.
            session.refresh(ttv, with_for_update=True)

            for arg_type_id, args in arg_mapping_dct.items():
                for arg in args:
                    session.add(
                        TemplateArgMap(
                            task_template_version_id=ttv.id,
                            arg_id=arg.id,
                            arg_type_id=arg_type_id,
                        )
                    )
            session.flush()
            task_template_version = ttv.to_wire_as_client_task_template_version()
    except sqlalchemy.exc.IntegrityError:
        with session.begin():
            # If another process is adding this same version, this query is
            # expected to block until that process has populated and
            # committed the template_arg_map rows.
            existing_stmt = select(TaskTemplateVersion).where(
                TaskTemplateVersion.task_template_id == task_template_id,
                TaskTemplateVersion.command_template == command_template,
                TaskTemplateVersion.arg_mapping_hash == arg_mapping_hash,
            )
            ttv = session.execute(existing_stmt).scalars().one()
            task_template_version = ttv.to_wire_as_client_task_template_version()

    response = jsonify(task_template_version=task_template_version)
    response.status_code = StatusCodes.OK
    return response