36 changes: 35 additions & 1 deletion datafaker/create.py
@@ -6,8 +6,9 @@

from sqlalchemy import Connection, insert, inspect
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema, MetaData, Table
from sqlalchemy.schema import CreateColumn, CreateSchema, CreateTable, MetaData, Table

from datafaker.base import FileUploader, TableGenerator
from datafaker.settings import get_settings
@@ -24,6 +25,39 @@
RowCounts = Counter[str]


@compiles(CreateColumn, "duckdb")
def remove_serial(element: CreateColumn, compiler: Any, **kw: Any) -> str:
"""
Intercept compilation of column creation, removing PostgreSQL's ``SERIAL``.

DuckDB does not understand ``SERIAL``, and we don't care about
autoincrementing in datafaker. Ideally ``duckdb_engine`` would remove
this for us, or DuckDB would implement ``SERIAL``.

:param element: The CreateColumn being executed.
:param compiler: Actually a DDLCompiler, but that type is not exported.
:param kw: Further arguments.
:return: Corrected SQL.
"""
text: str = compiler.visit_create_column(element, **kw)
return text.replace(" SERIAL ", " INTEGER ")


@compiles(CreateTable, "duckdb")
def remove_on_delete_cascade(element: CreateTable, compiler: Any, **kw: Any) -> str:
"""
Intercept compilation of table creation, removing ``ON DELETE CASCADE``.

DuckDB does not understand cascades, and we don't care about
cascading deletes in datafaker. Ideally ``duckdb_engine`` would remove this for us.

:param element: The CreateTable being executed.
:param compiler: Actually a DDLCompiler, but that type is not exported.
:param kw: Further arguments.
:return: Corrected SQL.
"""
text: str = compiler.visit_create_table(element, **kw)
return text.replace(" ON DELETE CASCADE", "")
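
For context, both hooks use SQLAlchemy's ``@compiles`` extension, which lets a package override how a DDL construct is rendered for a particular dialect. A minimal sketch of the mechanism, not part of this change (the table is illustrative, and the ``duckdb:///`` URL assumes the ``duckdb_engine`` package is installed):

from sqlalchemy import Column, Integer, MetaData, Table, create_engine

meta = MetaData()
Table("example", meta, Column("id", Integer, primary_key=True))
engine = create_engine("duckdb:///:memory:")
meta.create_all(engine)  # DDL compiled for the duckdb dialect runs the hooks above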


def create_db_tables(metadata: MetaData) -> None:
"""Create tables described by the sqlalchemy metadata object."""
settings = get_settings()
201 changes: 172 additions & 29 deletions datafaker/dump.py
@@ -1,38 +1,181 @@
"""Data dumping functions."""
import csv
import io
from typing import TYPE_CHECKING
from abc import ABC, abstractmethod
from pathlib import Path

import pandas as pd
import sqlalchemy
from sqlalchemy.schema import MetaData

from datafaker.utils import create_db_engine, get_sync_engine, logger

if TYPE_CHECKING:
from _csv import Writer


def _make_csv_writer(file: io.TextIOBase) -> "Writer":
"""Make the standard CSV file writer."""
return csv.writer(file, quoting=csv.QUOTE_MINIMAL)


def dump_db_tables(
metadata: MetaData,
dsn: str,
schema: str | None,
table_name: str,
file: io.TextIOBase,
) -> None:
"""Output the table as CSV."""
if table_name not in metadata.tables:
logger.error("%s is not a table described in the ORM file", table_name)
return
table = metadata.tables[table_name]
csv_out = _make_csv_writer(file)
csv_out.writerow(table.columns.keys())
engine = get_sync_engine(create_db_engine(dsn, schema_name=schema))
with engine.connect() as connection:
result = connection.execute(sqlalchemy.select(table))
for row in result:
csv_out.writerow(row)

class TableWriter(ABC):
"""Writes a table out to a file."""

EXTENSION = ".csv"

def __init__(self, metadata: MetaData, dsn: str, schema: str | None) -> None:
"""
Initialize the TableWriter.

:param metadata: The metadata for our database.
:param dsn: The connection string for our database.
:param schema: The schema name for our database, or None for the default.
"""
self._metadata = metadata
self._dsn = dsn
self._schema = schema

def connect(self) -> sqlalchemy.engine.Connection:
"""Connect to the database."""
engine = get_sync_engine(create_db_engine(self._dsn, schema_name=self._schema))
return engine.connect()

@abstractmethod
def write_file(self, table: sqlalchemy.Table, filepath: Path) -> bool:
"""
Write the named table into the named file.

:param table: The table to output.
:param filepath: The path of the file to write to.
:return: ``True`` on success, otherwise ``False``.
"""

def write(self, table: sqlalchemy.Table, directory: Path) -> bool:
"""
Write the table into a directory with a filename based on the table's name.

:param table: The table to write out.
:param directory: The directory to write the table into.
:return: ``True`` on success, otherwise ``False``.
"""
tn = table.name
# DuckDB tables derived from files carry confusing file suffixes,
# so strip them here
tn = tn.removesuffix(".csv")
tn = tn.removesuffix(".parquet")
return self.write_file(table, directory / f"{tn}{self.EXTENSION}")


class ParquetTableWriter(TableWriter):
"""Writes the table to a Parquet file."""

EXTENSION = ".parquet"

def write_file(self, table: sqlalchemy.Table, filepath: Path) -> bool:
"""
Write the named table into the named file.

:param table: The table to output.
:param filepath: The path of the file to write to.
:return: ``True`` on success, otherwise ``False``.
"""
with self.connect() as connection:
dates = [
str(name)
for name, col in table.columns.items()
if isinstance(
col.type, (sqlalchemy.types.DATE, sqlalchemy.types.DATETIME)
)
]
df = pd.read_sql(
sql=f"SELECT * FROM {table.name}",
con=connection,
columns=[str(col.name) for col in table.columns.values()],
parse_dates=dates,
)
df.to_parquet(filepath)
return True


class DuckDbParquetTableWriter(ParquetTableWriter):
"""
Writes the table to a Parquet file using DuckDB SQL.

The Pandas method used by ParquetTableWriter currently
does not work with DuckDB.
"""

def write_file(self, table: sqlalchemy.Table, filepath: Path) -> bool:
"""
Write the named table into the named file.

:param table: The table to output.
:param filepath: The path of the file to write to.
:return: ``True`` on success, otherwise ``False``.
"""
with self.connect() as connection:
result = connection.execute(
sqlalchemy.text(
# We need the double quotes so DuckDB reads the table, not the file.
f"COPY \"{table.name}\" TO '{filepath}' (FORMAT PARQUET)"
)
)
return result is not None


def get_parquet_table_writer(
metadata: MetaData, dsn: str, schema: str | None
) -> TableWriter:
"""
Get a ``TableWriter`` that writes Parquet files.

:param metadata: The database metadata containing the tables to be dumped to files.
:param dsn: The database connection string.
:param schema: The schema name, if required.
:return: A ``TableWriter`` that writes Parquet files.
"""
if dsn.startswith("duckdb:"):
return DuckDbParquetTableWriter(metadata, dsn, schema)
return ParquetTableWriter(metadata, dsn, schema)
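
A hypothetical call site for the writer hierarchy above; the DSN and output directory names are illustrative:

writer = get_parquet_table_writer(metadata, "duckdb:///example.db", schema=None)
for table in metadata.tables.values():
    writer.write(table, Path("dumps"))  # writes dumps/<table>.parquet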


class TableWriterIO(TableWriter):
"""Writes the table to an output object."""

@abstractmethod
def write_io(self, table: sqlalchemy.Table, out: io.TextIOBase) -> bool:
"""
Write the table to the given output stream.

:param table: The table to output.
:param out: The text stream to write to.
:return: ``True`` on success, otherwise ``False``.
"""

def write_file(self, table: sqlalchemy.Table, filepath: Path) -> bool:
"""
Write the named table into the named file.

:param table: The table to output.
:param filepath: The path of the file to write to.
:return: ``True`` on success, otherwise ``False``.
"""
with open(filepath, "wt", newline="", encoding="utf-8") as out:
return self.write_io(table, out)


class CsvTableWriter(TableWriterIO):
"""Writes the table to a CSV file."""

def write_io(self, table: sqlalchemy.Table, out: io.TextIOBase) -> bool:
"""
Write the table as CSV to the given output stream.

:param table: The table to output.
:param out: The text stream to write to.
:return: ``True`` on success, otherwise ``False``.
"""
if table.name not in self._metadata.tables:
logger.error("%s is not a table described in the ORM file", table.name)
return False
table = self._metadata.tables[table.name]
csv_out = csv.writer(out, quoting=csv.QUOTE_MINIMAL)
csv_out.writerow(table.columns.keys())
with self.connect() as connection:
result = connection.execute(sqlalchemy.select(table))
for row in result:
csv_out.writerow(row)
return True
42 changes: 21 additions & 21 deletions datafaker/generators/__init__.py
@@ -1,5 +1,6 @@
"""Generators write generator function definitions and queries into config.yaml."""

from collections.abc import Mapping
from functools import lru_cache

from datafaker.generators.base import (
@@ -28,26 +29,25 @@
)


# Using a cache instead of just initializing an object to avoid
# startup time being spent when it isn't needed.
@lru_cache(1)
def everything_factory() -> GeneratorFactory:
"""Get a factory that encapsulates all the other factories."""
def everything_factory(config: Mapping) -> GeneratorFactory:
"""
Get a factory that encapsulates all the other factories.

:param config: The ``config.yaml`` configuration.
"""
return MultiGeneratorFactory(
[
MimesisStringGeneratorFactory(),
MimesisIntegerGeneratorFactory(),
MimesisFloatGeneratorFactory(),
MimesisDateGeneratorFactory(),
MimesisDateTimeGeneratorFactory(),
MimesisTimeGeneratorFactory(),
ContinuousDistributionGeneratorFactory(),
ContinuousLogDistributionGeneratorFactory(),
ChoiceGeneratorFactory(),
ConstantGeneratorFactory(),
MultivariateNormalGeneratorFactory(),
MultivariateLogNormalGeneratorFactory(),
NullPartitionedNormalGeneratorFactory(),
NullPartitionedLogNormalGeneratorFactory(),
]
MimesisStringGeneratorFactory(),
MimesisIntegerGeneratorFactory(),
MimesisFloatGeneratorFactory(),
MimesisDateGeneratorFactory(),
MimesisDateTimeGeneratorFactory(),
MimesisTimeGeneratorFactory(),
ContinuousDistributionGeneratorFactory(),
ContinuousLogDistributionGeneratorFactory(),
ChoiceGeneratorFactory(),
ConstantGeneratorFactory(),
MultivariateNormalGeneratorFactory(),
MultivariateLogNormalGeneratorFactory(),
NullPartitionedNormalGeneratorFactory(config),
NullPartitionedLogNormalGeneratorFactory(config),
)
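
A hypothetical call site, assuming ``config`` is the parsed ``config.yaml`` mapping and ``columns`` and ``engine`` come from the caller:

factory = everything_factory(config)
generators = factory.get_generators(columns, engine)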
26 changes: 16 additions & 10 deletions datafaker/generators/base.py
@@ -86,7 +86,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]:
"""
return {}

def custom_queries(self) -> dict[str, dict[str, str]]:
def custom_queries(self) -> dict[str, dict[str, Any]]:
"""
Get the SQL queries to add to SRC_STATS.

@@ -95,7 +95,7 @@ def custom_queries(self) -> dict[str, dict[str, str]]:

For example {"myquery": {
"query": "SELECT one, too AS two FROM mytable WHERE too > 1",
"comment": "big enough one and two from table mytable"
"comments": ["big enough one and two from table mytable"]
}}
will populate SRC_STATS["myquery"]["results"][0]["one"]
and SRC_STATS["myquery"]["results"][0]["two"]
@@ -209,7 +209,7 @@ def __init__(
logger.debug("Custom query %s is '%s'", name, query)
self._custom_queries[name] = {
"query": query,
"comment": comments[0] if comments else None,
"comments": comments,
}

def function_name(self) -> str:
@@ -224,7 +224,7 @@ def select_aggregate_clauses(self) -> dict[str, dict[str, str]]:
"""Get the query fragments the generators need to call."""
return self._select_aggregate_clauses

def custom_queries(self) -> dict[str, dict[str, str]]:
def custom_queries(self) -> dict[str, dict[str, Any]]:
"""Get the queries the generators need to call."""
return self._custom_queries

@@ -248,7 +248,9 @@ class GeneratorFactory(ABC):

@abstractmethod
def get_generators(
self, columns: list[Column], engine: Engine
self,
columns: list[Column],
engine: Engine,
) -> Sequence[Generator]:
"""Get the generators appropriate to these columns."""

@@ -289,9 +291,13 @@ def __init__(
)
self.buckets: Sequence[int] = [0] * 10
for rb in raw_buckets:
if rb.b is not None:
bucket = min(9, max(0, int(rb.b) + 1))
self.buckets[bucket] += rb.f / count
try:
x = float(rb.b)
if x.is_integer():
bucket = min(9, max(0, int(x) + 1))
self.buckets[bucket] += rb.f / count
except TypeError:
pass
self.mean = mean
self.stddev = stddev
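
For intuition, a worked instance of the bucketing above with illustrative numbers:

# With count == 100 and a raw bucket rb where rb.b == -3.0 and rb.f == 5:
# float(-3.0).is_integer() is True, so
# bucket = min(9, max(0, int(-3.0) + 1)) == 0  (clamped into the lowest bucket)
# and buckets[0] gains 5 / 100 == 0.05 of the probability mass.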

@@ -350,7 +356,7 @@ def fit_from_values(self, values: list[float]) -> float:
class MultiGeneratorFactory(GeneratorFactory):
"""A composite factory."""

def __init__(self, factories: list[GeneratorFactory]):
def __init__(self, *factories: GeneratorFactory):
"""Initialise a MultiGeneratorFactory."""
super().__init__()
self.factories = factories
@@ -404,7 +410,7 @@ class ConstantGeneratorFactory(GeneratorFactory):
"""Just the null generator."""

def get_generators(
self, columns: list[Column], engine: Engine
self, columns: list[Column], _engine: Engine
) -> Sequence[Generator]:
"""Get the generators appropriate for these columns."""
if len(columns) != 1: