# Source code for sqlalchemy_jdbcapi.dialects.base

"""
Base JDBC dialect implementation following SOLID principles.

This module provides the abstract base class for all JDBC dialects,
implementing common functionality and defining the interface that
database-specific dialects must implement.
"""

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from sqlalchemy import pool
from sqlalchemy.engine import Connection, Dialect, reflection
from sqlalchemy.engine.url import URL
from sqlalchemy.types import (
    BIGINT,
    BINARY,
    BOOLEAN,
    CHAR,
    DATE,
    DECIMAL,
    FLOAT,
    INTEGER,
    NUMERIC,
    REAL,
    SMALLINT,
    TIME,
    TIMESTAMP,
    VARBINARY,
    VARCHAR,
)

from ..jdbc.exceptions import DatabaseError, OperationalError

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class JDBCDriverConfig:
    """
    Configuration for a JDBC driver.

    This encapsulates all driver-specific configuration in a single
    immutable object, following the Single Responsibility Principle.
    """

    driver_class: str
    """Fully qualified Java class name of the JDBC driver"""

    jdbc_url_template: str
    """Template for JDBC URL (e.g., 'jdbc:postgresql://{host}:{port}/{database}')"""

    default_port: int
    """Default port number for the database"""

    supports_transactions: bool = True
    """Whether the database supports transactions"""

    supports_schemas: bool = True
    """Whether the database supports schemas"""

    supports_sequences: bool = True
    """Whether the database supports sequences"""

    def format_jdbc_url(
        self,
        host: str,
        port: int | None,
        database: str | None,
        query_params: dict[str, Any] | None = None,
    ) -> str:
        """
        Format a JDBC connection URL.

        Args:
            host: Database host
            port: Database port (uses default_port if None)
            database: Database name
            query_params: Additional query parameters

        Returns:
            Formatted JDBC URL
        """
        port = port or self.default_port
        url = self.jdbc_url_template.format(
            host=host, port=port, database=database or ""
        )

        if query_params:
            params = "&".join(f"{k}={v}" for k, v in query_params.items())
            separator = "?" if "?" not in url else "&"
            url = f"{url}{separator}{params}"

        return url


class BaseJDBCDialect(Dialect, ABC):
    """
    Abstract base class for JDBC-based SQLAlchemy dialects.

    This class implements the Template Method pattern, providing common
    JDBC functionality while allowing database-specific customization
    through abstract methods.

    Subclasses must implement:
    - get_driver_config(): Return driver configuration
    - _get_server_version_info(): Parse database version

    SQLAlchemy 2.0+ compatible with full type hints.
    """

    # DB-API module (our custom JDBC bridge)
    driver = "jdbcapi"

    # SQLAlchemy capabilities
    supports_native_decimal = True
    supports_sane_rowcount = False
    supports_sane_multi_rowcount = False
    supports_unicode_binds = True
    supports_statement_cache = True
    supports_server_side_cursors = False
    supports_native_boolean = True

    # Connection pooling
    poolclass = pool.QueuePool

    # java.sql.Types code -> SQLAlchemy type class. Codes not listed here
    # fall back to VARCHAR in _jdbc_type_to_sqlalchemy().
    _jdbc_type_map = {
        -7: BOOLEAN,  # BIT
        -6: SMALLINT,  # TINYINT
        -5: BIGINT,  # BIGINT
        -4: VARBINARY,  # LONGVARBINARY
        -3: VARBINARY,  # VARBINARY
        -2: BINARY,  # BINARY
        -1: VARCHAR,  # LONGVARCHAR
        1: CHAR,  # CHAR
        2: NUMERIC,  # NUMERIC
        3: DECIMAL,  # DECIMAL
        4: INTEGER,  # INTEGER
        5: SMALLINT,  # SMALLINT
        6: FLOAT,  # FLOAT
        7: REAL,  # REAL
        8: FLOAT,  # DOUBLE
        12: VARCHAR,  # VARCHAR
        16: BOOLEAN,  # BOOLEAN
        91: DATE,  # DATE
        92: TIME,  # TIME
        93: TIMESTAMP,  # TIMESTAMP
    }

    @classmethod
    def import_dbapi(cls) -> type:
        """
        Return the DB-API module.

        Returns our custom JDBC bridge module that implements
        the Python DB-API 2.0 specification.
        """
        # Imported lazily so the bridge is only loaded when a dialect is used.
        from .. import jdbc

        return jdbc  # type: ignore

    @classmethod
    def dbapi(cls) -> type:
        """Deprecated: Use import_dbapi() instead."""
        return cls.import_dbapi()

    @classmethod
    @abstractmethod
    def get_driver_config(cls) -> JDBCDriverConfig:
        """
        Get JDBC driver configuration for this dialect.

        This method must be implemented by each database dialect
        to provide driver-specific configuration.

        Returns:
            JDBCDriverConfig instance
        """
        ...

    def create_connect_args(self, url: URL) -> tuple[list[Any], dict[str, Any]]:
        """
        Create connection arguments from SQLAlchemy URL.

        Converts a SQLAlchemy URL into arguments for our JDBC connect()
        function, following the Adapter pattern.

        Args:
            url: SQLAlchemy connection URL

        Returns:
            Tuple of (args, kwargs) for jdbc.connect(). ``args`` is always
            empty; ``kwargs`` carries ``jclassname``, ``url`` and
            ``driver_args`` (``None`` when there are no driver properties).
        """
        config = self.get_driver_config()

        # Build JDBC URL
        jdbc_url = config.format_jdbc_url(
            host=url.host or "localhost",
            port=url.port,
            database=url.database,
            query_params=dict(url.query) if url.query else None,
        )

        logger.debug(f"Creating connection to: {jdbc_url}")

        # Build driver arguments
        driver_args: dict[str, Any] = {}

        if url.username:
            driver_args["user"] = url.username
        if url.password:
            driver_args["password"] = url.password

        # Add query parameters as connection properties (they are also part
        # of the JDBC URL built above).
        if url.query:
            driver_args.update(url.query)

        # Connection arguments for jdbc.connect()
        kwargs = {
            "jclassname": config.driver_class,
            "url": jdbc_url,
            "driver_args": driver_args if driver_args else None,
        }

        return ([], kwargs)

    def initialize(self, connection: Connection) -> None:
        """
        Initialize a new connection.

        Called when a new connection is established to set up
        connection-specific settings.

        Args:
            connection: SQLAlchemy connection object
        """
        super().initialize(connection)

        # Set up server version once per dialect instance
        if not hasattr(self, "_server_version_info"):
            self._server_version_info = self._get_server_version_info(connection)
            logger.debug(f"Server version: {self._server_version_info}")

    @abstractmethod
    def _get_server_version_info(self, connection: Connection) -> tuple[int, ...]:
        """
        Get database server version information.

        This must be implemented by each dialect to parse version
        information in a database-specific way.

        Args:
            connection: SQLAlchemy connection

        Returns:
            Tuple of version numbers (e.g., (14, 5, 0))
        """
        ...

    def is_disconnect(self, e: Exception, connection: Any, cursor: Any) -> bool:
        """
        Check if an exception indicates a database disconnect.

        Only DatabaseError / OperationalError instances are considered; the
        decision is made by searching the lower-cased message for known
        disconnect phrases.

        Args:
            e: Exception that occurred
            connection: Database connection
            cursor: Database cursor

        Returns:
            True if this is a disconnect error
        """
        if isinstance(e, (DatabaseError, OperationalError)):
            error_str = str(e).lower()
            disconnect_indicators = [
                "connection is closed",
                "cursor is closed",
                "connection reset",
                "broken pipe",
                "connection refused",
                "connection lost",
                "can't connect",
                "connection terminated",
            ]
            return any(indicator in error_str for indicator in disconnect_indicators)

        return False

    def do_rollback(self, dbapi_connection: Any) -> None:
        """
        Perform a rollback on the connection.

        Some JDBC drivers have issues with rollback, this can be
        overridden by subclasses. Failures are logged as warnings and
        deliberately not re-raised.

        Args:
            dbapi_connection: DB-API connection object
        """
        try:
            dbapi_connection.rollback()
        except Exception as e:
            logger.warning(f"Rollback failed: {e}")

    def do_commit(self, dbapi_connection: Any) -> None:
        """
        Perform a commit on the connection.

        Args:
            dbapi_connection: DB-API connection object
        """
        dbapi_connection.commit()

    def do_close(self, dbapi_connection: Any) -> None:
        """
        Close the connection.

        Args:
            dbapi_connection: DB-API connection object
        """
        dbapi_connection.close()

    def do_ping(self, dbapi_connection: Any) -> bool:
        """
        Check if connection is alive.

        Executes ``SELECT 1`` on a fresh cursor. Any exception is treated as
        "not alive" and logged at debug level.

        Args:
            dbapi_connection: DB-API connection object

        Returns:
            True if connection is alive, False otherwise
        """
        try:
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute("SELECT 1")
            finally:
                # Release the cursor even when the probe statement fails.
                cursor.close()
            return True
        except Exception as e:
            logger.debug(f"Ping failed: {e}")
            return False

    def get_isolation_level(self, dbapi_connection: Any) -> str | None:
        """
        Get the current transaction isolation level.

        Args:
            dbapi_connection: DB-API connection object

        Returns:
            Isolation level name or None (this base implementation always
            returns None)
        """
        # This would need JDBC-specific implementation
        # Most JDBC connections support getTransactionIsolation()
        return None

    def set_isolation_level(self, dbapi_connection: Any, level: str) -> None:
        """
        Set the transaction isolation level.

        This base implementation is a no-op.

        Args:
            dbapi_connection: DB-API connection object
            level: Isolation level to set
        """
        # This would need JDBC-specific implementation
        # Most JDBC connections support setTransactionIsolation()

    # ========================================================================
    # JDBC Reflection Methods - Using DatabaseMetaData API
    # ========================================================================

    def _get_jdbc_metadata(self, connection: Connection) -> Any:
        """
        Get JDBC DatabaseMetaData object from connection.

        Args:
            connection: SQLAlchemy connection

        Returns:
            JDBC DatabaseMetaData object

        Raises:
            OperationalError: If the DB-API connection does not expose a
                ``_jdbc_connection`` attribute.
        """
        # Get the raw JDBC connection
        dbapi_conn = connection.connection.dbapi_connection
        if hasattr(dbapi_conn, "_jdbc_connection"):
            jdbc_conn = dbapi_conn._jdbc_connection
            return jdbc_conn.getMetaData()
        raise OperationalError("Cannot access JDBC connection metadata")

    def _jdbc_type_to_sqlalchemy(self, jdbc_type_name: str, jdbc_type: int) -> Any:
        """
        Convert JDBC type to SQLAlchemy type.

        Args:
            jdbc_type_name: JDBC type name (e.g., 'VARCHAR'); currently unused
                by this base implementation
            jdbc_type: JDBC type code (from java.sql.Types)

        Returns:
            SQLAlchemy type instance (VARCHAR for unknown type codes)
        """
        # Always instantiate so callers get an instance for mapped and
        # unmapped codes alike.
        return self._jdbc_type_map.get(jdbc_type, VARCHAR)()

    @reflection.cache
    def get_schema_names(self, connection: Connection, **kw: Any) -> list[str]:
        """
        Get list of schema names using JDBC DatabaseMetaData.

        Args:
            connection: SQLAlchemy connection
            **kw: Additional keyword arguments

        Returns:
            List of schema names (empty list if the lookup fails)
        """
        try:
            metadata = self._get_jdbc_metadata(connection)
            schemas = []

            rs = metadata.getSchemas()
            try:
                while rs.next():
                    schema_name = rs.getString("TABLE_SCHEM")
                    if schema_name:
                        schemas.append(schema_name)
            finally:
                rs.close()

            logger.debug(f"Found {len(schemas)} schemas")
            return schemas

        except Exception as e:
            logger.warning(f"Failed to get schema names: {e}")
            return []

    @reflection.cache
    def get_table_names(
        self, connection: Connection, schema: str | None = None, **kw: Any
    ) -> list[str]:
        """
        Get list of table names using JDBC DatabaseMetaData.

        Args:
            connection: SQLAlchemy connection
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            List of table names (empty list if the lookup fails)
        """
        try:
            metadata = self._get_jdbc_metadata(connection)
            tables = []

            # getTables(catalog, schemaPattern, tableNamePattern, types[])
            rs = metadata.getTables(None, schema, "%", ["TABLE"])
            try:
                while rs.next():
                    table_name = rs.getString("TABLE_NAME")
                    if table_name:
                        tables.append(table_name)
            finally:
                rs.close()

            logger.debug(f"Found {len(tables)} tables in schema '{schema}'")
            return tables

        except Exception as e:
            logger.warning(f"Failed to get table names: {e}")
            return []

    @reflection.cache
    def get_view_names(
        self, connection: Connection, schema: str | None = None, **kw: Any
    ) -> list[str]:
        """
        Get list of view names using JDBC DatabaseMetaData.

        Args:
            connection: SQLAlchemy connection
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            List of view names (empty list if the lookup fails)
        """
        try:
            metadata = self._get_jdbc_metadata(connection)
            views = []

            rs = metadata.getTables(None, schema, "%", ["VIEW"])
            try:
                while rs.next():
                    view_name = rs.getString("TABLE_NAME")
                    if view_name:
                        views.append(view_name)
            finally:
                rs.close()

            logger.debug(f"Found {len(views)} views in schema '{schema}'")
            return views

        except Exception as e:
            logger.warning(f"Failed to get view names: {e}")
            return []

    def has_table(
        self,
        connection: Connection,
        table_name: str,
        schema: str | None = None,
        **kw: Any,
    ) -> bool:
        """
        Check if a table exists using JDBC DatabaseMetaData.

        Args:
            connection: SQLAlchemy connection
            table_name: Table name to check
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            True if table exists, False otherwise (including when the
            lookup itself fails)
        """
        try:
            metadata = self._get_jdbc_metadata(connection)

            rs = metadata.getTables(None, schema, table_name, ["TABLE"])
            try:
                # At least one row means a matching table was found.
                exists = rs.next()
            finally:
                rs.close()

            logger.debug(f"Table '{schema}.{table_name}' exists: {exists}")
            return exists

        except Exception as e:
            logger.warning(f"Failed to check table existence: {e}")
            return False

    @reflection.cache
    def get_columns(
        self,
        connection: Connection,
        table_name: str,
        schema: str | None = None,
        **kw: Any,
    ) -> list[dict[str, Any]]:
        """
        Get column definitions for a table using JDBC DatabaseMetaData.

        Args:
            connection: SQLAlchemy connection
            table_name: Table name
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            List of column dictionaries with keys:
            - name: Column name
            - type: SQLAlchemy type instance
            - nullable: Boolean
            - default: Default value (string or None)
            - autoincrement: Boolean

            An empty list is returned if the lookup fails.
        """
        try:
            metadata = self._get_jdbc_metadata(connection)
            columns = []

            rs = metadata.getColumns(None, schema, table_name, "%")
            try:
                while rs.next():
                    column_name = rs.getString("COLUMN_NAME")
                    data_type = rs.getInt("DATA_TYPE")
                    type_name = rs.getString("TYPE_NAME")
                    column_size = rs.getInt("COLUMN_SIZE")
                    nullable = (
                        rs.getInt("NULLABLE") == 1
                    )  # DatabaseMetaData.columnNullable
                    column_def = rs.getString("COLUMN_DEF")
                    is_autoincrement = rs.getString("IS_AUTOINCREMENT")

                    # Convert JDBC type to SQLAlchemy type
                    sa_type = self._jdbc_type_to_sqlalchemy(type_name, data_type)
                    if isinstance(sa_type, type):
                        # A subclass override may hand back a type class;
                        # normalise so the length handling below applies.
                        sa_type = sa_type()

                    # Apply size for character/binary types
                    if column_size and isinstance(
                        sa_type, (VARCHAR, CHAR, VARBINARY, BINARY)
                    ):
                        sa_type = type(sa_type)(length=column_size)

                    columns.append(
                        {
                            "name": column_name,
                            "type": sa_type,
                            "nullable": nullable,
                            "default": column_def,
                            # None / "" / "NO" all compare unequal to "YES".
                            "autoincrement": is_autoincrement == "YES",
                        }
                    )
            finally:
                rs.close()

            logger.debug(
                f"Found {len(columns)} columns for table '{schema}.{table_name}'"
            )
            return columns

        except Exception as e:
            logger.warning(f"Failed to get columns: {e}")
            return []

    @reflection.cache
    def get_pk_constraint(
        self,
        connection: Connection,
        table_name: str,
        schema: str | None = None,
        **kw: Any,
    ) -> dict[str, Any]:
        """
        Get primary key constraint for a table using JDBC DatabaseMetaData.

        Args:
            connection: SQLAlchemy connection
            table_name: Table name
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            Dictionary with keys:
            - name: Constraint name
            - constrained_columns: List of column names

            ``{"name": None, "constrained_columns": []}`` is returned if the
            lookup fails.
        """
        try:
            metadata = self._get_jdbc_metadata(connection)
            pk_columns = []
            pk_name = None

            rs = metadata.getPrimaryKeys(None, schema, table_name)
            try:
                while rs.next():
                    column_name = rs.getString("COLUMN_NAME")
                    pk_name = rs.getString("PK_NAME")
                    key_seq = rs.getInt("KEY_SEQ")
                    pk_columns.append((key_seq, column_name))
            finally:
                rs.close()

            # Sort by KEY_SEQ to maintain correct column order
            pk_columns.sort(key=lambda x: x[0])
            constrained_columns = [col for _, col in pk_columns]

            result = {
                "name": pk_name,
                "constrained_columns": constrained_columns,
            }

            logger.debug(f"Primary key for '{schema}.{table_name}': {result}")
            return result

        except Exception as e:
            logger.warning(f"Failed to get primary key: {e}")
            return {"name": None, "constrained_columns": []}

    @reflection.cache
    def get_foreign_keys(
        self,
        connection: Connection,
        table_name: str,
        schema: str | None = None,
        **kw: Any,
    ) -> list[dict[str, Any]]:
        """
        Get foreign key constraints for a table using JDBC DatabaseMetaData.

        Rows are grouped per constraint by (FK_NAME, PKTABLE_SCHEM,
        PKTABLE_NAME) so that unnamed foreign keys pointing at different
        tables are not merged into a single constraint.

        Args:
            connection: SQLAlchemy connection
            table_name: Table name
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            List of dictionaries with keys:
            - name: Constraint name
            - constrained_columns: List of local column names
            - referred_schema: Referenced schema name
            - referred_table: Referenced table name
            - referred_columns: List of referenced column names

            An empty list is returned if the lookup fails.
        """
        try:
            metadata = self._get_jdbc_metadata(connection)
            # group key -> (constraint dict, [(KEY_SEQ, fk column, pk column)])
            grouped: dict[Any, tuple[dict[str, Any], list[tuple[Any, Any, Any]]]] = {}

            rs = metadata.getImportedKeys(None, schema, table_name)
            try:
                while rs.next():
                    fk_name = rs.getString("FK_NAME")
                    fk_column = rs.getString("FKCOLUMN_NAME")
                    pk_table = rs.getString("PKTABLE_NAME")
                    pk_schema = rs.getString("PKTABLE_SCHEM")
                    pk_column = rs.getString("PKCOLUMN_NAME")
                    key_seq = rs.getInt("KEY_SEQ")

                    group_key = (fk_name, pk_schema, pk_table)
                    if group_key not in grouped:
                        grouped[group_key] = (
                            {
                                "name": fk_name,
                                "constrained_columns": [],
                                "referred_schema": pk_schema,
                                "referred_table": pk_table,
                                "referred_columns": [],
                            },
                            [],
                        )
                    grouped[group_key][1].append((key_seq, fk_column, pk_column))
            finally:
                rs.close()

            # Order each constraint's column pairs by KEY_SEQ
            result = []
            for fk_data, pairs in grouped.values():
                pairs.sort(key=lambda item: item[0])
                fk_data["constrained_columns"] = [local for _, local, _ in pairs]
                fk_data["referred_columns"] = [remote for _, _, remote in pairs]
                result.append(fk_data)

            logger.debug(
                f"Found {len(result)} foreign keys for '{schema}.{table_name}'"
            )
            return result

        except Exception as e:
            logger.warning(f"Failed to get foreign keys: {e}")
            return []

    @reflection.cache
    def get_indexes(
        self,
        connection: Connection,
        table_name: str,
        schema: str | None = None,
        **kw: Any,
    ) -> list[dict[str, Any]]:
        """
        Get indexes for a table using JDBC DatabaseMetaData.

        Args:
            connection: SQLAlchemy connection
            table_name: Table name
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            List of dictionaries with keys:
            - name: Index name
            - column_names: List of column names
            - unique: Boolean

            An empty list is returned if the lookup fails.
        """
        try:
            metadata = self._get_jdbc_metadata(connection)
            # index name -> (index dict, [(ORDINAL_POSITION, column name)])
            grouped: dict[str, tuple[dict[str, Any], list[tuple[Any, Any]]]] = {}

            # getIndexInfo(catalog, schema, table, unique, approximate)
            rs = metadata.getIndexInfo(None, schema, table_name, False, True)
            try:
                while rs.next():
                    index_name = rs.getString("INDEX_NAME")

                    # Skip statistics rows, which carry no index name
                    if not index_name:
                        continue

                    column_name = rs.getString("COLUMN_NAME")
                    non_unique = rs.getBoolean("NON_UNIQUE")
                    ordinal_position = rs.getInt("ORDINAL_POSITION")

                    if index_name not in grouped:
                        grouped[index_name] = (
                            {
                                "name": index_name,
                                "column_names": [],
                                "unique": not non_unique,
                            },
                            [],
                        )
                    grouped[index_name][1].append((ordinal_position, column_name))
            finally:
                rs.close()

            # Order each index's columns by ordinal position
            result = []
            for idx_data, positioned in grouped.values():
                positioned.sort(key=lambda item: item[0])
                idx_data["column_names"] = [col for _, col in positioned]
                result.append(idx_data)

            logger.debug(f"Found {len(result)} indexes for '{schema}.{table_name}'")
            return result

        except Exception as e:
            logger.warning(f"Failed to get indexes: {e}")
            return []

    @reflection.cache
    def get_unique_constraints(
        self,
        connection: Connection,
        table_name: str,
        schema: str | None = None,
        **kw: Any,
    ) -> list[dict[str, Any]]:
        """
        Get unique constraints for a table.

        This is extracted from get_indexes() by filtering for unique indexes.

        Args:
            connection: SQLAlchemy connection
            table_name: Table name
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            List of dictionaries with keys:
            - name: Constraint name
            - column_names: List of column names

            An empty list is returned if the lookup fails.
        """
        try:
            indexes = self.get_indexes(connection, table_name, schema, **kw)
            unique_constraints = [
                {"name": idx["name"], "column_names": idx["column_names"]}
                for idx in indexes
                if idx.get("unique", False)
            ]

            logger.debug(
                f"Found {len(unique_constraints)} unique constraints for '{schema}.{table_name}'"
            )
            return unique_constraints

        except Exception as e:
            logger.warning(f"Failed to get unique constraints: {e}")
            return []

    @reflection.cache
    def get_check_constraints(
        self,
        connection: Connection,
        table_name: str,
        schema: str | None = None,
        **kw: Any,
    ) -> list[dict[str, Any]]:
        """
        Get check constraints for a table.

        Note: JDBC DatabaseMetaData doesn't have a standard method for check
        constraints. This returns an empty list. Database-specific dialects
        should override this method to query system tables if check
        constraint information is needed.

        Args:
            connection: SQLAlchemy connection
            table_name: Table name
            schema: Schema name (None for default schema)
            **kw: Additional keyword arguments

        Returns:
            List of dictionaries with keys:
            - name: Constraint name
            - sqltext: Check constraint SQL expression
        """
        # JDBC doesn't provide standard access to check constraints
        # Subclasses can override to query database-specific system tables
        logger.debug(
            f"Check constraints not available via JDBC for '{schema}.{table_name}'"
        )
        return []

    def __repr__(self) -> str:
        """String representation of the dialect."""
        return f"<{self.__class__.__name__}>"