# Source code for sqlalchemy_jdbcapi.dialects.oracle

"""
Oracle JDBC dialect for SQLAlchemy.

Provides full Oracle Database support through JDBC, compatible with SQLAlchemy 2.0+.
"""

from __future__ import annotations

import logging
import re
from typing import Any

from sqlalchemy import exc, sql, util
from sqlalchemy.dialects.oracle.base import OracleDialect
from sqlalchemy.engine import Connection

from .base import BaseJDBCDialect, JDBCDriverConfig

logger = logging.getLogger(__name__)


class OracleDialect(BaseJDBCDialect, OracleDialect):  # type: ignore
    """
    Oracle Database dialect using JDBC driver.

    Combines ``BaseJDBCDialect`` (JDBC connection handling) with SQLAlchemy's
    built-in Oracle dialect. Supports Oracle-specific features including:
    - Sequences
    - Synonyms
    - Database links
    - Packages
    - Custom types

    Connection URL formats:
        jdbcapi+oracle://user:password@host:1521/service_name
        jdbcapi+oraclejdbc://user:password@host:1521/service_name  # Alias

    When a port is given, the JDBC URL is rendered from the template
    ``jdbc:oracle:thin:@{host}:{port}/{database}``, i.e. the database part
    follows a ``/`` (thin-driver service-name syntax). SID-style
    ``host:port:SID`` URLs are not produced by this dialect.

    For TNS connections, omit the port:
        jdbcapi+oracle://user:password@tnsname
    The host is passed through as ``jdbc:oracle:thin:@tnsname`` (with
    ``/database`` appended when the URL has a database part).
    """

    name = "oracle"
    driver = "jdbcapi"

    # Oracle-specific capabilities
    supports_sequences = True
    supports_native_boolean = False  # Oracle < 23c doesn't have native boolean
    supports_identity_columns = True  # Oracle 12c+

    # Fallback version used before / when the server version is unknown.
    _DEFAULT_VERSION: tuple[int, ...] = (11, 0, 0)

    # Override column specifications for JDBC type handling.
    # NOTE: inside this class body, ``OracleDialect`` still names the imported
    # SQLAlchemy base class -- the new class is only bound to that name once
    # the ``class`` statement has finished executing.
    colspecs = util.update_copy(
        OracleDialect.colspecs,  # type: ignore
        {
            # Add JDBC-specific type mappings here if needed
        },
    )

    @classmethod
    def get_driver_config(cls) -> JDBCDriverConfig:
        """Get Oracle JDBC driver configuration."""
        return JDBCDriverConfig(
            driver_class="oracle.jdbc.OracleDriver",
            jdbc_url_template="jdbc:oracle:thin:@{host}:{port}/{database}",
            default_port=1521,
            supports_transactions=True,
            supports_schemas=True,
            supports_sequences=True,
        )

    def create_connect_args(self, url: Any) -> tuple[list[Any], dict[str, Any]]:
        """
        Create connection arguments from a SQLAlchemy URL.

        A URL with a host but no port is treated as a TNS name and passed
        through verbatim; otherwise the driver config's URL template is used.

        Args:
            url: SQLAlchemy ``URL`` object.

        Returns:
            ``([], kwargs)`` where ``kwargs`` contains ``jclassname`` (driver
            class), ``url`` (JDBC URL) and ``driver_args`` (user / password /
            URL query parameters, or ``None`` when there are none).
        """
        config = self.get_driver_config()

        # Host without a port -> TNS name. (Parsed URL hosts should never
        # contain "/", so that last check is defensive.)
        if url.port is None and url.host and "/" not in url.host:
            jdbc_url = f"jdbc:oracle:thin:@{url.host}"
            if url.database:
                jdbc_url = f"{jdbc_url}/{url.database}"
        else:
            jdbc_url = config.format_jdbc_url(
                host=url.host or "localhost",
                # Reaching this branch without a port means the URL had no
                # host either; use the configured default port rather than
                # handing None to the URL template.
                port=url.port or config.default_port,
                database=url.database,
                query_params=dict(url.query) if url.query else None,
            )

        logger.debug("Creating Oracle connection to: %s", jdbc_url)

        # Build driver arguments
        driver_args: dict[str, Any] = {}
        if url.username:
            driver_args["user"] = url.username
        if url.password:
            driver_args["password"] = url.password

        # Add query parameters as connection properties
        if url.query:
            driver_args.update(url.query)

        kwargs = {
            "jclassname": config.driver_class,
            "url": jdbc_url,
            "driver_args": driver_args or None,
        }
        return ([], kwargs)

    def initialize(self, connection: Connection) -> None:
        """Initialize Oracle connection."""
        super().initialize(connection)
        logger.debug("Initialized Oracle JDBC dialect")

    def _get_server_version_info(self, connection: Connection) -> tuple[int, ...]:
        """
        Get Oracle server version from ``v$version``.

        Returns:
            Up to three version numbers parsed from the ``Release x.y.z``
            part of the banner (e.g. ``(19, 0, 0)``), or ``(11, 0, 0)`` when
            the query fails or the banner cannot be parsed.
        """
        try:
            banner = connection.execute(
                sql.text("SELECT BANNER FROM v$version")
            ).scalar()
        except exc.DBAPIError as e:
            logger.warning("Failed to get Oracle server version: %s", e)
        else:
            if banner:
                # Parse version from string like:
                # "Oracle Database 19c Enterprise Edition Release 19.0.0.0.0 - Production"
                # The pattern only captures dot-separated digit groups, so a
                # stray trailing "." can't yield an empty part for int().
                match = re.search(r"Release (\d+(?:\.\d+)*)", str(banner))
                if match:
                    parts = match.group(1).split(".")
                    return tuple(int(p) for p in parts[:3])

        # Default fallback
        return self._DEFAULT_VERSION

    def _version_or_default(self) -> tuple[int, ...]:
        """Return SQLAlchemy's ``server_version_info``, or the fallback version if unset."""
        # SQLAlchemy stores the result of _get_server_version_info() in
        # ``server_version_info`` (no leading underscore) during initialize().
        return getattr(self, "server_version_info", None) or self._DEFAULT_VERSION

    @property
    def _is_oracle_8(self) -> bool:
        """Check if connected to Oracle 8 (legacy support)."""
        return self._version_or_default()[0] < 9

    def _check_max_identifier_length(self, connection: Connection) -> int | None:
        """
        Get maximum identifier length for this Oracle version.

        Oracle 12.2+ supports 128 characters, earlier versions support 30.
        """
        if self._version_or_default() >= (12, 2):
            return 128
        return 30

    def do_ping(self, dbapi_connection: Any) -> bool:
        """
        Check if Oracle connection is alive.

        Runs ``SELECT 1 FROM DUAL``. Any exception is logged at debug level
        and reported as ``False`` rather than raised.
        """
        cursor = None
        try:
            cursor = dbapi_connection.cursor()
            cursor.execute("SELECT 1 FROM DUAL")
            cursor.close()
            cursor = None
            return True
        except Exception as e:
            logger.debug("Oracle ping failed: %s", e)
            if cursor is not None:
                # Don't leak the cursor when execute() (or close()) failed.
                try:
                    cursor.close()
                except Exception:
                    pass  # connection is likely already broken
            return False
# Export the dialect under the conventional module-level name ``dialect``.
dialect = OracleDialect