Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
Version history
===============

**4.0.0rc1**

- Added python enum generation for native database ENUM types (e.g., PostgreSQL / MySQL ENUM)
Retained synthetic Python enum generation from CHECK constraints with IN clauses (e.g., ``column IN ('val1', 'val2', ...)``)
Use ``--options nonativeenums`` to disable enum generation for native database enums
Use ``--options nosyntheticenums`` to disable enum generation for synthetic database enums (VARCHAR columns with check constraints)
(PR by @sheinbergon)

**3.2.0**

- Dropped support for Python 3.9
Expand Down
2 changes: 2 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ values must be delimited by commas, e.g. ``--options noconstraints,nobidi``):
* ``noconstraints``: ignore constraints (foreign key, unique etc.)
* ``nocomments``: ignore table/column comments
* ``noindexes``: ignore indexes
* ``nonativeenums``: don't generate Python enum classes for native database ENUM types (e.g., PostgreSQL ENUM); use string-based SQLAlchemy Enum instead (legacy name: ``noenums``)
* ``nosyntheticenums``: don't generate Python enum classes from CHECK constraints with IN clauses (e.g., ``column IN ('value1', 'value2', ...)``); preserves CHECK constraints as-is
* ``noidsuffix``: prevent the special naming logic for single column many-to-one
and one-to-one relationships (see `Relationship naming logic`_ for details)
* ``include_dialect_options``: render a table' dialect options, such as ``starrocks_partition`` for StarRocks' specific options.
Expand Down
213 changes: 184 additions & 29 deletions src/sqlacodegen/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ class TablesGenerator(CodeGenerator):
"noindexes",
"noconstraints",
"nocomments",
"nonativeenums",
"nosyntheticenums",
"include_dialect_options",
"keep_dialect_types",
}
Expand All @@ -148,6 +150,11 @@ def __init__(
# Keep dialect-specific types instead of adapting to generic SQLAlchemy types
self.keep_dialect_types: bool = "keep_dialect_types" in self.options

# Track Python enum classes: maps (table_name, column_name) -> enum_class_name
self.enum_classes: dict[tuple[str, str], str] = {}
# Track enum values: maps enum_class_name -> list of values
self.enum_values: dict[str, list[str]] = {}

@property
def views_supported(self) -> bool:
return True
Expand Down Expand Up @@ -192,19 +199,22 @@ def generate(self) -> str:
models: list[Model] = self.generate_models()

# Render module level variables
variables = self.render_module_variables(models)
if variables:
if variables := self.render_module_variables(models):
sections.append(variables + "\n")

# Render enum classes
if enum_classes := self.render_enum_classes():
sections.append(enum_classes + "\n")

# Render models
rendered_models = self.render_models(models)
if rendered_models:
if rendered_models := self.render_models(models):
sections.append(rendered_models)

# Render collected imports
groups = self.group_imports()
imports = "\n\n".join("\n".join(line for line in group) for group in groups)
if imports:
if imports := "\n\n".join(
"\n".join(line for line in group) for group in groups
):
sections.insert(0, imports)

return "\n\n".join(sections) + "\n"
Expand Down Expand Up @@ -324,7 +334,8 @@ def get_collection(package: str) -> list[str]:
return collection

for package in sorted(self.imports):
imports = ", ".join(sorted(self.imports[package]))
imports_list = sorted(self.imports[package])
imports = ", ".join(imports_list)

collection = get_collection(package)
collection.append(f"from {package} import {imports}")
Expand Down Expand Up @@ -467,7 +478,7 @@ def render_column(
# Render the column type if there are no foreign keys on it or any of them
# points back to itself
if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
args.append(self.render_column_type(column.type))
args.append(self.render_column_type(column.type, column))

for fk in dedicated_fks:
args.append(self.render_constraint(fk))
Expand Down Expand Up @@ -528,7 +539,20 @@ def render_column_callable(self, is_table: bool, *args: Any, **kwargs: Any) -> s
else:
return render_callable("mapped_column", *args, kwargs=kwargs)

def render_column_type(self, coltype: TypeEngine[Any]) -> str:
def render_column_type(
self, coltype: TypeEngine[Any], column: Column[Any] | None = None
) -> str:
# Check if this is an enum column with a Python enum class
if isinstance(coltype, Enum) and column is not None:
table_name = column.table.name
column_name = column.name
if (table_name, column_name) in self.enum_classes:
enum_class_name = self.enum_classes[(table_name, column_name)]
# Import SQLAlchemy Enum (will be handled in collect_imports)
self.add_import(Enum)
# Return the Python enum class as the type parameter
return f"Enum({enum_class_name})"

args = []
kwargs: dict[str, Any] = {}
sig = inspect.signature(coltype.__class__.__init__)
Expand Down Expand Up @@ -709,6 +733,87 @@ def find_free_name(

return name

def _enum_name_to_class_name(self, enum_name: str) -> str:
"""Convert a database enum name to a Python class name (PascalCase)."""
parts = []
for part in enum_name.split("_"):
if part:
parts.append(part.capitalize())
return "".join(parts)

def _create_enum_class(
self, table_name: str, column_name: str, values: list[str]
) -> str:
"""
Create a Python enum class name and register it.

Returns the enum class name to use in generated code.
"""
# Generate enum class name from table and column names
# Convert to PascalCase: user_status -> UserStatus
parts = []
for part in table_name.split("_"):
if part:
parts.append(part.capitalize())
for part in column_name.split("_"):
if part:
parts.append(part.capitalize())

base_name = "".join(parts)

# Ensure uniqueness
enum_class_name = base_name
counter = 1
while enum_class_name in self.enum_values:
# Check if it's the same enum (same values)
if self.enum_values[enum_class_name] == values:
# Reuse existing enum class
return enum_class_name
enum_class_name = f"{base_name}{counter}"
counter += 1

# Register the new enum class
self.enum_values[enum_class_name] = values
return enum_class_name

def render_enum_classes(self) -> str:
"""Render Python enum class definitions."""
if not self.enum_values:
return ""

self.add_module_import("enum")

enum_defs = []
for enum_class_name, values in sorted(self.enum_values.items()):
# Create enum members with valid Python identifiers
members = []
for value in values:
# Unescape SQL escape sequences (e.g., \' -> ')
# The value from the CHECK constraint has SQL escaping
unescaped_value = value.replace("\\'", "'").replace("\\\\", "\\")

# Create a valid identifier from the enum value
member_name = _re_invalid_identifier.sub("_", unescaped_value).upper()
if not member_name:
member_name = "EMPTY"
elif member_name[0].isdigit():
member_name = "_" + member_name
elif iskeyword(member_name):
member_name += "_"

# Re-escape for Python string literal
python_escaped = unescaped_value.replace("\\", "\\\\").replace(
"'", "\\'"
)
members.append(f" {member_name} = '{python_escaped}'")

enum_def = f"class {enum_class_name}(str, enum.Enum):\n" + "\n".join(
members
)
enum_defs.append(enum_def)

return "\n\n\n".join(enum_defs)

def fix_column_types(self, table: Table) -> None:
"""Adjust the reflected column types."""
# Detect check constraints for boolean and enum columns
Expand All @@ -718,34 +823,76 @@ def fix_column_types(self, table: Table) -> None:

# Turn any integer-like column with a CheckConstraint like
# "column IN (0, 1)" into a Boolean
match = _re_boolean_check_constraint.match(sqltext)
if match:
colname_match = _re_column_name.match(match.group(1))
if colname_match:
if match := _re_boolean_check_constraint.match(sqltext):
if colname_match := _re_column_name.match(match.group(1)):
colname = colname_match.group(3)
table.constraints.remove(constraint)
table.c[colname].type = Boolean()
continue

# Turn any string-type column with a CheckConstraint like
# "column IN (...)" into an Enum
match = _re_enum_check_constraint.match(sqltext)
if match:
colname_match = _re_column_name.match(match.group(1))
if colname_match:
colname = colname_match.group(3)
items = match.group(2)
if isinstance(table.c[colname].type, String):
table.constraints.remove(constraint)
if not isinstance(table.c[colname].type, Enum):
options = _re_enum_item.findall(items)
table.c[colname].type = Enum(
*options, native_enum=False
# Turn VARCHAR columns with CHECK constraints like "column IN ('a', 'b')"
# into synthetic Enum types with Python enum classes
if "nosyntheticenums" not in self.options:
if match := _re_enum_check_constraint.match(sqltext):
if colname_match := _re_column_name.match(match.group(1)):
colname = colname_match.group(3)
items = match.group(2)
if isinstance(table.c[colname].type, String):
if not isinstance(table.c[colname].type, Enum):
options = _re_enum_item.findall(items)
# Create Python enum class
enum_class_name = self._create_enum_class(
table.name, colname, options
)
self.enum_classes[(table.name, colname)] = (
enum_class_name
)
# Convert to Enum type but KEEP the constraint
table.c[colname].type = Enum(
*options, native_enum=False
)
continue

for column in table.c:
# Handle native database Enum types (e.g., PostgreSQL ENUM)
if (
"nonativeenums" not in self.options
and "noenums" not in self.options
and isinstance(column.type, Enum)
and column.type.enums
):
if column.type.name:
# Named enum - create shared enum class if not already created
if (table.name, column.name) not in self.enum_classes:
# Check if we've already created an enum for this name
existing_class = None
for (t, c), cls in self.enum_classes.items():
if cls == self._enum_name_to_class_name(column.type.name):
existing_class = cls
break

if existing_class:
enum_class_name = existing_class
else:
# Create new enum class from the enum's name
enum_class_name = self._enum_name_to_class_name(
column.type.name
)
# Register the enum values if not already registered
if enum_class_name not in self.enum_values:
self.enum_values[enum_class_name] = list(
column.type.enums
)

continue
self.enum_classes[(table.name, column.name)] = enum_class_name
else:
# Unnamed enum - create enum class per column
if (table.name, column.name) not in self.enum_classes:
enum_class_name = self._create_enum_class(
table.name, column.name, list(column.type.enums)
)
self.enum_classes[(table.name, column.name)] = enum_class_name

for column in table.c:
if not self.keep_dialect_types:
try:
column.type = self.get_adapted_type(column.type)
Expand Down Expand Up @@ -1326,6 +1473,14 @@ def get_type_qualifiers() -> tuple[str, TypeEngine[Any], str]:
return "".join(pre), column_type, "]" * post_size

def render_python_type(column_type: TypeEngine[Any]) -> str:
# Check if this is an enum column with a Python enum class
if isinstance(column_type, Enum):
table_name = column.table.name
column_name = column.name
if (table_name, column_name) in self.enum_classes:
enum_class_name = self.enum_classes[(table_name, column_name)]
return enum_class_name

if isinstance(column_type, DOMAIN):
column_type = column_type.data_type

Expand Down
8 changes: 1 addition & 7 deletions src/sqlacodegen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,4 @@ def decode_postgresql_sequence(clause: TextClause) -> tuple[str | None, str | No


def get_stdlib_module_names() -> set[str]:
major, minor = sys.version_info.major, sys.version_info.minor
if (major, minor) > (3, 9):
return set(sys.builtin_module_names) | set(sys.stdlib_module_names)
else:
from stdlib_list import stdlib_list

return set(sys.builtin_module_names) | set(stdlib_list(f"{major}.{minor}"))
return set(sys.builtin_module_names) | set(sys.stdlib_module_names)
Loading