Source code for sqlalchemy_jdbcapi.dialects.mysql
"""
MySQL and MariaDB JDBC dialects for SQLAlchemy.
Provides support for MySQL and MariaDB through JDBC.
"""
from __future__ import annotations
import logging
import re
from typing import Any
from sqlalchemy import exc, sql
from sqlalchemy.dialects.mysql.base import MySQLDialect as BaseMySQLDialect
from sqlalchemy.engine import Connection
from .base import BaseJDBCDialect, JDBCDriverConfig
logger = logging.getLogger(__name__)
[docs]
class MySQLDialect(BaseJDBCDialect, BaseMySQLDialect):
"""
MySQL dialect using JDBC driver.
Supports MySQL-specific features including:
- AUTO_INCREMENT
- Full-text indexes
- JSON columns (MySQL 5.7+)
- Spatial types
Connection URL format:
jdbcapi+mysql://user:password@host:3306/database
"""
name = "mysql"
driver = "jdbcapi"
# MySQL capabilities
supports_native_boolean = False # MySQL uses TINYINT(1)
supports_native_enum = True
supports_sequences = False # MySQL < 8.0 doesn't support sequences
supports_statement_cache = True
[docs]
@classmethod
def get_driver_config(cls) -> JDBCDriverConfig:
"""Get MySQL JDBC driver configuration."""
return JDBCDriverConfig(
driver_class="com.mysql.cj.jdbc.Driver", # MySQL Connector/J 8.0+
jdbc_url_template="jdbc:mysql://{host}:{port}/{database}",
default_port=3306,
supports_transactions=True,
supports_schemas=True,
supports_sequences=False,
)
def _detect_charset(self, connection: Connection) -> str:
"""Detect MySQL connection character set."""
try:
result = connection.exec_driver_sql(
"SELECT @@character_set_client"
).scalar()
return result or "utf8mb4"
except Exception as e:
logger.warning(f"Failed to detect charset: {e}")
return "utf8mb4" # Default to utf8mb4
[docs]
def initialize(self, connection: Connection) -> None:
"""Initialize MySQL connection."""
super().initialize(connection)
logger.debug("Initialized MySQL JDBC dialect")
def _get_server_version_info(self, connection: Connection) -> tuple[int, ...]:
"""
Get MySQL server version.
Returns:
Tuple of version numbers (e.g., (8, 0, 32))
"""
try:
result = connection.execute(sql.text("SELECT VERSION()")).scalar()
if result:
# Parse version from string like:
# "8.0.32" or "5.7.40-log"
match = re.search(r"(\d+)\.(\d+)\.(\d+)", result)
if match:
major = int(match.group(1))
minor = int(match.group(2))
patch = int(match.group(3))
return (major, minor, patch)
except exc.DBAPIError as e:
logger.warning(f"Failed to get MySQL server version: {e}")
# Default fallback
return (5, 7, 0)
[docs]
def do_ping(self, dbapi_connection: Any) -> bool:
"""Check if MySQL connection is alive."""
try:
cursor = dbapi_connection.cursor()
cursor.execute("SELECT 1")
cursor.close()
return True
except Exception as e:
logger.debug(f"MySQL ping failed: {e}")
return False
[docs]
class MariaDBDialect(MySQLDialect):
"""
MariaDB dialect using JDBC driver.
MariaDB is a MySQL fork with additional features.
This dialect extends MySQL with MariaDB-specific capabilities.
Connection URL format:
jdbcapi+mariadb://user:password@host:3306/database
"""
name = "mariadb"
[docs]
@classmethod
def get_driver_config(cls) -> JDBCDriverConfig:
"""Get MariaDB JDBC driver configuration."""
return JDBCDriverConfig(
driver_class="org.mariadb.jdbc.Driver",
jdbc_url_template="jdbc:mariadb://{host}:{port}/{database}",
default_port=3306,
supports_transactions=True,
supports_schemas=True,
supports_sequences=True, # MariaDB 10.3+ supports sequences
)
def _get_server_version_info(self, connection: Connection) -> tuple[int, ...]:
"""
Get MariaDB server version.
Returns:
Tuple of version numbers (e.g., (10, 11, 2))
"""
try:
result = connection.execute(sql.text("SELECT VERSION()")).scalar()
if result:
# Parse version from string like:
# "10.11.2-MariaDB" or "10.6.12-MariaDB-log"
match = re.search(r"(\d+)\.(\d+)\.(\d+)", result)
if match:
major = int(match.group(1))
minor = int(match.group(2))
patch = int(match.group(3))
return (major, minor, patch)
except exc.DBAPIError as e:
logger.warning(f"Failed to get MariaDB server version: {e}")
# Default fallback
return (10, 6, 0)
# Export dialects
dialect = MySQLDialect