Source code for server.web.db.dns

"""dns.py — generic SQLAlchemy engine factory with DNS-aware pooling.

`get_dns_engine()` behaves exactly like :py:func:`sqlalchemy.create_engine`, **but**
ensures that every connection the pool hands out is opened against the *current*
IP address for the hostname in your database URL.  When the DNS A-record
changes, stale sockets are invalidated and new ones are created transparently.

Key points:
~~~~~~~~~~~
* Works with any **synchronous** SQLAlchemy dialect – no hard dependency on a
  particular DB-API.
* Accepts the full argument surface of :py:func:`sqlalchemy.create_engine`.
  The function injects its own `creator` for DNS-aware pooling; supplying
  *your own* `creator` will raise a clear `ValueError`, because a custom
  creator would defeat the point of DNS-aware pooling.
* Opens **zero** extra database connections while building the engine.
* Includes ``clear_dns_cache()`` for tests.

Example:
~~~~~~~~
>>> from server.web.db.dns import get_dns_engine, clear_dns_cache
>>> engine = get_dns_engine(
...     "postgresql+psycopg2://user:pass@db.internal/mydb",
...     echo=True, pool_size=10, pool_timeout=20
... )

If the host's IP address changes at runtime, the first pool checkout made after
the locally cached DNS entry expires raises ``DisconnectionError``; SQLAlchemy
automatically retries and a new connection is opened to the fresh IP.
"""

from __future__ import annotations

import importlib
import ipaddress
import logging
import random
import threading
import time
from types import ModuleType
from typing import Any, Dict, Optional, Tuple, Type

from dns import resolver
from sqlalchemy import create_engine, event, exc
from sqlalchemy.engine import URL, Engine, make_url
from sqlalchemy.engine.interfaces import DBAPIConnection, Dialect
from sqlalchemy.pool import _ConnectionRecord

__all__ = ["get_dns_engine", "clear_dns_cache", "get_ip_with_ttl"]

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# DNS-resolution helpers
# ---------------------------------------------------------------------------
# Module-wide cache of resolved addresses, keyed by hostname.  `expiry` is a
# `time.time()` timestamp; `failure_count` is reset to 0 on every successful
# lookup and incremented each time a refresh fails and the cached IP is reused.
_DNS_CACHE: Dict[str, Tuple[str, float, int]] = (
    {}
)  # host -> (ip, expiry, failure_count)
# Guards every read and write of _DNS_CACHE.
_CACHE_LOCK = threading.RLock()
_RESOLVER_POOL = threading.local()  # Thread-local resolver instances
# Fallback TTL when a DNS answer carries none, and the upper bound on how long
# a successful lookup stays in _DNS_CACHE.
_DEFAULT_MAX_TTL = 300  # seconds


def _get_thread_local_resolver(
    nameservers: Optional[list[str]] = None,
) -> resolver.Resolver:
    """Return the calling thread's DNS resolver, building or updating it as needed.

    One resolver is kept per thread in ``_RESOLVER_POOL`` so that a new
    instance is not constructed for every query.

    Args:
        nameservers: Explicit nameserver addresses to query. ``None`` (or an
            empty list) means "use the system configuration" that
            ``resolver.Resolver()`` picks up by default.

    Returns:
        A resolver with ``timeout=3`` and ``lifetime=12`` that targets
        ``nameservers`` when given.
    """
    has_cached = hasattr(_RESOLVER_POOL, "resolver") and hasattr(
        _RESOLVER_POOL, "nameservers"
    )
    if has_cached and _RESOLVER_POOL.nameservers == nameservers:
        return _RESOLVER_POOL.resolver

    if has_cached and nameservers:
        # Only the server list changed: keep the already-tuned instance.
        r = _RESOLVER_POOL.resolver
    else:
        # First use on this thread, or a reset to system defaults.  Either way
        # the fresh instance must get the same tuning; previously the reset
        # path returned an untuned resolver.
        r = resolver.Resolver()  # honors /etc/resolv.conf by default
        r.timeout = 3  # Per-server timeout
        r.lifetime = 12  # Total query timeout

    if nameservers:
        r.nameservers = list(nameservers)

    _RESOLVER_POOL.resolver = r
    # Store a snapshot, not the caller's list, so that a later in-place
    # mutation of that list is still detected as a change.
    _RESOLVER_POOL.nameservers = (
        list(nameservers) if nameservers is not None else None
    )
    return r


def _resolve_with_retry(
    host: str,
    timeout_seconds: int,
    nameservers: Optional[list[str]] = None,
    max_retries: int = 3,
) -> Tuple[str, int]:
    """Resolve *host* to an IPv4 address, retrying with exponential backoff.

    Args:
        host: Hostname to look up (an ``A`` query; search domains disabled).
        timeout_seconds: Total lifetime budget handed to each ``resolve()`` call.
        nameservers: Optional explicit nameservers; ``None`` uses the
            thread-local resolver's defaults.
        max_retries: Number of attempts. Values below 1 are treated as 1 so
            that at least one real lookup is always made.

    Returns:
        ``(ip, ttl_seconds)`` for the first record in the answer. A missing or
        zero TTL is reported as ``_DEFAULT_MAX_TTL``.

    Raises:
        Exception: Whatever the final failed attempt raised.
    """
    r = _get_thread_local_resolver(nameservers)

    # Previously max_retries <= 0 skipped the loop entirely and surfaced a
    # misleading "unknown reason" RuntimeError without ever querying DNS.
    attempts = max(1, max_retries)

    attempt = 0
    while True:
        attempt += 1
        try:
            ans = r.resolve(host, "A", lifetime=timeout_seconds, search=False)
        except Exception as e:
            if attempt >= attempts:
                logger.warning(
                    "DNS resolution failed for %s after %s attempts: %s",
                    host,
                    attempts,
                    e,
                )
                raise
            # Exponential backoff with jitter to avoid thundering herd
            delay = (2 ** (attempt - 1)) * 0.1 + random.uniform(0, 0.1)
            logger.debug(
                "DNS retry %s/%s for %s after %.2fs (error: %s)",
                attempt,
                attempts,
                host,
                delay,
                e,
            )
            time.sleep(delay)
            continue

        # `or` deliberately maps both a missing TTL and a TTL of 0 to the default.
        ttl = getattr(ans.rrset, "ttl", None) or _DEFAULT_MAX_TTL
        if attempt > 1:
            logger.info(
                "DNS resolution succeeded for %s on retry %s/%s",
                host,
                attempt,
                attempts,
            )
        return ans[0].address, int(ttl)


def _resolve(
    host: str, timeout_seconds: int, nameservers: list[str] | None
) -> Tuple[str, int]:
    """Look up *host* and return ``(ip, ttl_seconds)``.

    Thin convenience wrapper: delegates to ``_resolve_with_retry`` with a
    fixed budget of three attempts, forwarding the timeout and nameserver
    list unchanged.
    """
    attempts = 3
    return _resolve_with_retry(
        host,
        timeout_seconds,
        nameservers,
        max_retries=attempts,
    )


def _cached_ip(
    host: str,
    dns_timeout: int = 12,
    dns_nameservers: list[str] | None = None,
    dns_grace_ttl: int = 30,
    dns_max_retries: int = 3,
    dns_extend_grace: bool = True,
) -> Tuple[str, int]:
    """Return ``(ip, ttl_seconds)`` for *host*, consulting the module cache first.

    A fresh lookup is made only when the cache has no unexpired entry. If that
    lookup fails but an earlier IP is known, the stale IP is served for a
    grace period instead of raising.

    Args:
        host: Hostname to resolve.
        dns_timeout: Lifetime budget (seconds) for each DNS attempt.
        dns_nameservers: Optional explicit nameservers.
        dns_grace_ttl: Base number of seconds a stale IP is kept after a
            failed refresh.
        dns_max_retries: Attempts per refresh.
        dns_extend_grace: When true, the grace period doubles with each
            recorded failure (exponent capped at 4, total capped at 300 s).

    Returns:
        On a cache hit, the cached IP and its remaining lifetime. On a fresh
        lookup, the IP and the TTL reported by the resolver (the cache entry
        itself is capped at ``_DEFAULT_MAX_TTL``). On fallback, the stale IP
        and the grace period.

    Raises:
        Exception: The resolver's error, when the lookup fails and no
            previous IP is cached.
    """
    # Check cache first
    with _CACHE_LOCK:
        cached_ip, exp, failure_count = _DNS_CACHE.get(host, (None, 0.0, 0))

    now = time.time()
    if cached_ip and exp > now:
        return cached_ip, int(exp - now)

    # Attempt DNS resolution with retry
    try:
        ip, ttl = _resolve_with_retry(
            host,
            timeout_seconds=dns_timeout,
            nameservers=dns_nameservers,
            max_retries=dns_max_retries,
        )
    except Exception as err:
        logger.exception("DNS resolve failed for %s: %s", host, err)

        if not cached_ip:
            # No cached IP available
            raise

        # Calculate the grace period, extended according to failure history
        if dns_extend_grace and failure_count > 0:
            # Extend grace period for repeated failures (max 5 minutes)
            grace_period = min(dns_grace_ttl * (2 ** min(failure_count, 4)), 300)
            logger.info(
                "Using cached IP %s for %s with extended grace period "
                "%ss (failure #%s)",
                cached_ip,
                host,
                grace_period,
                failure_count + 1,
            )
        else:
            grace_period = dns_grace_ttl
            logger.info(
                "Using cached IP %s for %s with %ss grace period",
                cached_ip,
                host,
                grace_period,
            )

        # The failed lookup may have blocked for longer than the grace period
        # itself (retries x timeout).  Stamp the expiry from the current time,
        # not from before the lookup, or the entry would be born expired and
        # every caller would repeat the full failing lookup.
        with _CACHE_LOCK:
            _DNS_CACHE[host] = (
                cached_ip,
                time.time() + grace_period,
                failure_count + 1,
            )

        return cached_ip, int(grace_period)

    # Success - reset failure count and cache the result.  The expiry counts
    # from when the answer arrived, not from before the query started.
    with _CACHE_LOCK:
        _DNS_CACHE[host] = (ip, time.time() + min(ttl, _DEFAULT_MAX_TTL), 0)

    if failure_count > 0:
        logger.info(
            "DNS resolution recovered for %s -> %s (TTL: %ss) after %s failures",
            host,
            ip,
            ttl,
            failure_count,
        )
    else:
        logger.debug("DNS resolved %s -> %s (TTL: %ss)", host, ip, ttl)
    return ip, ttl


def clear_dns_cache() -> None:
    """Flush the local DNS cache and reset failure counts (useful in unit tests).

    Every cached host entry is removed.  The resolver and nameserver list
    stored on ``_RESOLVER_POOL`` for the *calling* thread are dropped as well,
    so the next lookup on this thread builds a fresh resolver.
    """
    with _CACHE_LOCK:
        _DNS_CACHE.clear()

    # Also clear thread-local resolver instances
    for attr_name in ("resolver", "nameservers"):
        if hasattr(_RESOLVER_POOL, attr_name):
            delattr(_RESOLVER_POOL, attr_name)
# --------------------------------------------------------------------------- # Dialect helpers – *zero* extra DB connections # --------------------------------------------------------------------------- _DIALECT_CACHE: Dict[str, Type[Dialect]] = {} def _get_dialect_cls(url: URL) -> Type[Dialect]: if url.drivername not in _DIALECT_CACHE: _DIALECT_CACHE[url.drivername] = url.get_dialect() # no I/O return _DIALECT_CACHE[url.drivername] def _import_explicit_driver(drivername: str) -> ModuleType | None: if "+" not in drivername: return None _dialect, driver = drivername.split("+", 1) try: return importlib.import_module(driver) except ModuleNotFoundError as exc: raise RuntimeError( f"Requested driver '{driver}' for dialect '{_dialect}' is not installed" ) from exc # --------------------------------------------------------------------------- # Engine factory # ---------------------------------------------------------------------------
def get_dns_engine(
    uri: str | URL,
    *engine_args: Any,
    dns_timeout: int = 12,
    dns_nameservers: list[str] | None = None,
    dns_grace_ttl: int = 30,
    dns_fallback: bool = True,
    dns_max_retries: int = 3,
    dns_extend_grace: bool = True,
    **engine_kwargs: Any,
) -> Engine:
    """Create a DNS-aware engine from a SQLAlchemy URL.

    For a URL whose host is a DNS name, the engine is built with a custom
    ``creator`` that resolves the hostname (through the module's DNS cache)
    each time the pool opens a new connection, and with pool listeners that
    discard a pooled connection at checkout when the hostname now resolves to
    a different IP than the one recorded for that connection.

    SQLite URLs, ``localhost`` and literal IP addresses need no DNS handling
    and get a plain ``create_engine()`` engine.

    Args:
        uri (str | URL): Database URL.
        *engine_args (Any): Extra positional arguments for ``create_engine``.
        dns_timeout (int): DNS query timeout in seconds.
        dns_nameservers (list[str] | None): Custom nameservers to use.
        dns_grace_ttl (int): Grace period for cached IPs during DNS failures.
        dns_fallback (bool): When DNS resolution fails inside the creator,
            connect using the original hostname instead of raising.
        dns_max_retries (int): Maximum retry attempts for DNS resolution.
        dns_extend_grace (bool): Whether to extend the grace period on
            repeated failures.
        **engine_kwargs (Any): Keyword arguments for ``create_engine``.
            ``connect_args`` is merged into the DB-API connect parameters.

    Returns:
        Engine: The configured engine.

    Raises:
        ValueError: If the URL has no host and is not a SQLite URL, or if a
            ``creator`` is supplied for a DNS-aware engine.
    """
    url = make_url(uri)
    host = url.host
    # Never write credentials to the log: always log the masked rendering,
    # not the raw `uri` string the caller passed in.
    safe_url = url.render_as_string(hide_password=True)

    logger.debug(
        "get_dns_engine: "
        "uri=%s engine_kwargs=%s dns_timeout=%s "
        "dns_fallback=%s dns_max_retries=%s "
        "dns_extend_grace=%s",
        safe_url,
        engine_kwargs,
        dns_timeout,
        dns_fallback,
        dns_max_retries,
        dns_extend_grace,
    )

    # Extract and save user connect_args
    user_connect_args = engine_kwargs.pop("connect_args", {})

    logger.debug("get_dns_engine: url=%s", safe_url)

    is_sqlite_url = url.drivername.startswith("sqlite")

    # If host is None AND it's not a SQLite URL, then it's an invalid URL for this function.
    if host is None and not is_sqlite_url:
        raise ValueError(f"URL has no host and is not a SQLite URL: {safe_url}")

    # For SQLite, or a non-DNS hostname (localhost), or an IP address, use a standard engine.
    # For SQLite, host will be None; the `host and ...` guard handles this.
    if (
        is_sqlite_url
        or host == "127.0.0.1"
        or host == "localhost"
        or (host and _is_ip_address(host))
    ):
        logger.info("Creating non-DNS engine for %s", safe_url)
        # Restore the connect_args before creating the standard engine
        if user_connect_args:
            engine_kwargs["connect_args"] = user_connect_args
        # Pass the URL object, not str(url): SQLAlchemy 2.0 masks the password
        # in str(URL), which would make the engine authenticate with "***".
        return create_engine(url, *engine_args, **engine_kwargs)

    # If we've reached here, it's a DNS-aware case, and host should not be None.
    assert host is not None, "Host should be a string at this point for DNS resolution"

    if "creator" in engine_kwargs:
        # Without this check the duplicate keyword surfaces as an opaque TypeError.
        raise ValueError(
            "get_dns_engine() installs its own 'creator' for DNS-aware pooling; "
            "a custom 'creator' cannot be combined with it"
        )

    # Build a throwaway engine only to obtain a configured dialect; this also
    # fails early if the URL is malformed.  No connection is opened.
    minimal_dialect = create_engine(url, _initialize=False).dialect
    dbapi_module = getattr(minimal_dialect, "dbapi", None)

    logger.info(
        "Creating DNS-aware engine for %s with dialect %r",
        safe_url,
        minimal_dialect.name,
    )

    def creator() -> DBAPIConnection:
        """Return a connection to the currently resolved IP address."""
        try:
            ip_now, _ = _cached_ip(
                host,
                dns_timeout=dns_timeout,
                dns_nameservers=dns_nameservers,
                dns_grace_ttl=dns_grace_ttl,
                dns_max_retries=dns_max_retries,
                dns_extend_grace=dns_extend_grace,
            )
            connect_url = url.set(host=ip_now)
        except Exception:
            if not dns_fallback:
                raise
            # Fall back to original hostname
            connect_url = url
            logger.warning(
                "DNS resolution failed; falling back to original hostname for %s",
                host,
            )

        cargs, cparams = minimal_dialect.create_connect_args(connect_url)

        # Merge user-supplied connect_args into the parameters
        if user_connect_args:
            cparams.update(user_connect_args)
            logger.debug(
                "Updated connection params with user args: %s", user_connect_args
            )

        module = dbapi_module or minimal_dialect.dbapi
        if module is None:
            raise RuntimeError(
                f"Could not determine DBAPI module for dialect {minimal_dialect.name}"
            )

        return module.connect(*cargs, **cparams)

    # No placeholder URL is needed with a custom creator: the engine keeps the
    # real URL so tooling such as OpenTelemetry can read the connection details.
    # `future` is defaulted, not forced, so a caller passing it explicitly
    # no longer triggers a duplicate-keyword TypeError.
    engine_kwargs.setdefault("future", True)
    engine = create_engine(
        url,
        *engine_args,
        creator=creator,
        **engine_kwargs,
    )
    # Logged once here; it used to be emitted inside `creator` on every
    # connection attempt, before any connection had been made.
    logger.info("Successfully created DNS-aware engine for %s.", safe_url)

    # ---------------------------------------------------------------
    # Pool listeners: invalidate when the A-record changes
    # ---------------------------------------------------------------
    @event.listens_for(engine, "connect")
    def _store_peer_ip(
        dbapi_conn: DBAPIConnection, record: _ConnectionRecord
    ) -> None:  # type: ignore[func-returns-value]
        try:
            record.info["peer_ip"] = _cached_ip(
                host,
                dns_timeout=dns_timeout,
                dns_nameservers=dns_nameservers,
                dns_grace_ttl=dns_grace_ttl,
                dns_max_retries=dns_max_retries,
                dns_extend_grace=dns_extend_grace,
            )[0]
        except Exception:
            # If DNS fails here, keep existing info or store hostname marker
            record.info["peer_ip"] = record.info.get("peer_ip", host)

    # This makes sure the DB connection's IP matches current DNS. If it has changed,
    # drop the connection. The insert=True guarantees this listener is first so it
    # runs before any others on checkout.
    @event.listens_for(engine, "checkout", insert=True)
    def _ensure_ip_fresh(
        dbapi_conn: DBAPIConnection, record: _ConnectionRecord, proxy: Any
    ) -> None:  # type: ignore[func-returns-value]
        try:
            current_ip = _cached_ip(
                host,
                dns_timeout=dns_timeout,
                dns_nameservers=dns_nameservers,
                dns_grace_ttl=dns_grace_ttl,
                dns_max_retries=dns_max_retries,
                dns_extend_grace=dns_extend_grace,
            )[0]
        except Exception:
            # If DNS refresh fails during checkout, keep the connection
            return

        previous_ip = record.info.get("peer_ip")
        if previous_ip != current_ip:
            record.invalidate()
            raise exc.DisconnectionError(
                f"DNS changed; reconnecting ({previous_ip} -> {current_ip})"
            )

    return engine
def get_ip_with_ttl(
    host: str,
    dns_timeout: int = 12,
    dns_nameservers: list[str] | None = None,
    dns_grace_ttl: int = 30,
    dns_max_retries: int = 3,
    dns_extend_grace: bool = True,
) -> Tuple[str, int]:
    """Resolve *host* through the module's DNS cache and report its TTL.

    Public entry point that forwards every option unchanged to the cached
    lookup used by the engine's pool listeners.

    Args:
        host (str): The hostname to resolve
        dns_timeout (int): DNS query timeout in seconds
        dns_nameservers (list[str] | None): Custom nameservers to use
        dns_grace_ttl (int): Grace period for cached IPs during failures
        dns_max_retries (int): Maximum retry attempts for DNS resolution
        dns_extend_grace (bool): Whether to extend grace period on repeated failures

    Returns:
        Tuple[str, int]: A tuple of (ip_address, ttl_in_seconds)
    """
    lookup_options = {
        "dns_timeout": dns_timeout,
        "dns_nameservers": dns_nameservers,
        "dns_grace_ttl": dns_grace_ttl,
        "dns_max_retries": dns_max_retries,
        "dns_extend_grace": dns_extend_grace,
    }
    return _cached_ip(host, **lookup_options)
def _is_ip_address(host: str) -> bool: """Check if a string is a valid IP address.""" try: ipaddress.ip_address(host) return True except ValueError: return False