Source code for client.tool_version

"""A logical instance of a project or model that will be run many times over."""

from __future__ import annotations

from http import HTTPStatus as StatusCodes
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type

import structlog

from jobmon.client.task_template import TaskTemplate
from jobmon.core.exceptions import InvalidResponse
from jobmon.core.requester import Requester
from jobmon.core.serializers import SerializeClientToolVersion

if TYPE_CHECKING:
    from jobmon.client.tool import Tool


# Module-level structlog logger, named after this module; used below to report
# ignored values passed to ToolVersion.set_default_max_attempts.
logger = structlog.get_logger(__name__)
class ToolVersion:
    """Represents a logical instance of a project or model that will be run many times over.

    A ToolVersion is associated with one Tool. It caches the task templates loaded for
    it and holds tool-version-level defaults: per-cluster compute resources and resource
    scales, a default cluster name, and a default max attempts value.
    """

    def __init__(
        self, tool_version_id: int, requester: Optional[Requester] = None
    ) -> None:
        """Instantiate a tool version.

        Args:
            tool_version_id: an integer id associated with a Tool.
            requester: communicate with the FastApi services. When omitted,
                ``Requester.from_defaults()`` is used.

        Note:
            ``_tool`` is only declared here, not assigned; it is set by ``from_wire``.
            Reading ``tool`` on an instance built directly via ``__init__`` raises
            ``AttributeError``.
        """
        self.id = tool_version_id
        self._tool: Tool
        # Cache of task templates, keyed by template name.
        self.task_templates: Dict[str, TaskTemplate] = {}
        # Default compute resources, keyed by cluster name.
        self.default_compute_resources_set: Dict[str, Dict[str, Any]] = {}
        # Default resource scales, keyed by cluster name.
        self.default_resource_scales_set: Dict[str, Dict[str, float]] = {}
        self.default_cluster_name: str = ""
        self.default_max_attempt: Optional[int] = None
        if requester is None:
            requester = Requester.from_defaults()
        self.requester = requester

    @classmethod
    def get_tool_version(cls: Type[ToolVersion], tool: Tool) -> ToolVersion:
        """Get a new instance of a ToolVersion for ``tool`` from the database.

        Sends a POST to ``/tool_version`` carrying the tool's id, then builds the
        instance from the ``tool_version`` entry of the response.

        Args:
            tool: the Tool to get a version for; its requester is used for the call.

        Returns:
            A ToolVersion bound to ``tool``.

        Raises:
            InvalidResponse: if the server does not respond with HTTP 200.
            ValueError: if the returned tool version belongs to a different tool.
        """
        message = {"tool_id": tool.id}
        app_route = "/tool_version"
        return_code, response = tool.requester.send_request(
            app_route=app_route, message=message, request_type="post"
        )
        if return_code != StatusCodes.OK:
            raise InvalidResponse(
                f"Unexpected status code {return_code} from POST request through route "
                f"{app_route}. Expected code 200. Response content: {response}"
            )
        return cls.from_wire(response["tool_version"], tool)

    @classmethod
    def from_wire(cls: Type[ToolVersion], wire_tuple: Tuple, tool: Tool) -> ToolVersion:
        """Convert from the wire format of ToolVersion to an instance.

        Args:
            wire_tuple: Wire format for ToolVersion defined in jobmon.serializers.
            tool: The Tool object to verify the right tool_version based on tool_id.

        Returns:
            A ToolVersion sharing ``tool``'s requester, with ``tool`` attached.

        Raises:
            ValueError: if the wire tuple's tool_id differs from ``tool.id``.
        """
        tool_version_kwargs = SerializeClientToolVersion.kwargs_from_wire(wire_tuple)
        if tool_version_kwargs["tool_id"] != tool.id:
            raise ValueError(
                "tool_id in wire_tuple does not match tool object. "
                f"Expected {tool.id} in wire_tuple. Got "
                f"{tool_version_kwargs['tool_id']}"
            )
        tool_version = cls(tool_version_kwargs["id"], requester=tool.requester)
        tool_version._tool = tool
        return tool_version

    @property
    def tool(self) -> Tool:
        """The Tool this ToolVersion is associated with."""
        return self._tool

    def load_task_templates(self) -> None:
        """Get all task_templates associated with this tool version from the database.

        Each fetched template is stored in ``task_templates`` under its
        ``template_name``, overwriting any cached entry with the same name.

        Raises:
            InvalidResponse: if the server does not respond with HTTP 200.
        """
        app_route = f"/tool_version/{self.id}/task_templates"
        return_code, response = self.requester.send_request(
            app_route=app_route, message={}, request_type="get"
        )
        if return_code != StatusCodes.OK:
            # The request above is a GET; the message previously claimed POST.
            raise InvalidResponse(
                f"Unexpected status code {return_code} from GET request through route "
                f"{app_route}. Expected code 200. Response content: {response}"
            )
        for wire_tuple in response["task_templates"]:
            task_template = TaskTemplate.from_wire(wire_tuple, self)
            self.task_templates[task_template.template_name] = task_template

    def get_task_template(self, template_name: str) -> TaskTemplate:
        """Return the named task template, fetching it on a cache miss.

        If ``template_name`` is not already cached in ``task_templates``, it is
        retrieved via ``TaskTemplate.get_task_template``, its versions are loaded,
        and the result is cached before being returned.

        Args:
            template_name: name of the task template.

        Returns:
            The cached or freshly fetched TaskTemplate.
        """
        task_template = self.task_templates.get(template_name)
        if task_template is None:
            task_template = TaskTemplate.get_task_template(self, template_name)
            task_template.load_task_template_versions()
            self.task_templates[template_name] = task_template
        return task_template

    def update_default_compute_resources(
        self, cluster_name: str, **kwargs: Any
    ) -> None:
        """Update default compute resources in place, only overriding specified keys.

        Keys already set for ``cluster_name`` that are not passed in ``kwargs`` are
        kept. A new dict is stored for the cluster, so a dict previously handed to
        ``set_default_compute_resources_from_dict`` is not mutated. This does not
        change ``default_cluster_name``.

        Args:
            cluster_name: name of cluster to modify default values for.
            **kwargs: any key/value pair you want to update specified as an argument.
        """
        # Merge rather than replace: dict.update({cluster_name: kwargs}) would
        # discard every existing key for this cluster.
        merged = dict(self.default_compute_resources_set.get(cluster_name, {}))
        merged.update(kwargs)
        self.default_compute_resources_set[cluster_name] = merged

    def update_default_resource_scales(self, cluster_name: str, **kwargs: Any) -> None:
        """Update default resource scales in place, only overriding specified keys.

        Keys already set for ``cluster_name`` that are not passed in ``kwargs`` are
        kept. A new dict is stored for the cluster, so a dict previously handed to
        ``set_default_resource_scales_from_dict`` is not mutated. This does not
        change ``default_cluster_name``.

        Args:
            cluster_name: name of cluster to modify default values for.
            **kwargs: any key/value pair you want to update specified as an argument.
        """
        merged = dict(self.default_resource_scales_set.get(cluster_name, {}))
        merged.update(kwargs)
        self.default_resource_scales_set[cluster_name] = merged

    def set_default_compute_resources_from_dict(
        self, cluster_name: str, compute_resources: Dict[str, Any]
    ) -> None:
        """Set (replace) default compute resources for a given cluster_name.

        The given dict is stored as-is (not copied). This does not change
        ``default_cluster_name``.

        Args:
            cluster_name: name of cluster to set default values for.
            compute_resources: dictionary of default compute resources to run tasks
                with. Can be overridden at task template or task level.
                dict of {resource_name: resource_value}
        """
        self.default_compute_resources_set[cluster_name] = compute_resources

    def set_default_resource_scales_from_dict(
        self, cluster_name: str, resource_scales: Dict[str, float]
    ) -> None:
        """Set (replace) default resource scales for a given cluster_name.

        The given dict is stored as-is (not copied). This does not change
        ``default_cluster_name``.

        Args:
            cluster_name: name of cluster to set default values for.
            resource_scales: dictionary of default resource scales to adjust task
                resources with. Can be overridden at task template or task level.
                dict of {resource_name: scale_value}
        """
        self.default_resource_scales_set[cluster_name] = resource_scales

    def set_default_max_attempts(self, value: int) -> None:
        """Set default max attempts at the tool version level.

        Falsy values (``None`` and ``0``) are ignored: the current default is kept
        and an info message is logged.

        Args:
            value: the default max attempts value.
        """
        if value:
            self.default_max_attempt = value
        else:
            logger.info("The default_max_attempt for tool_version can not be None.")

    def __repr__(self) -> str:
        """A representation string for a ToolVersion instance."""
        return (
            f"ToolVersion(tool_version_id={self.id},"
            f"task_templates: {list(self.task_templates)})"
        )