Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions migra/changes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from collections import OrderedDict as od
from functools import partial

from networkx import lexicographical_topological_sort, DiGraph
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this use TopologicalSorter which is part of schemainspect: https://github.com/djrobstep/schemainspect/blob/066262d6fb4668f874925305a0b7dbb3ac866882/schemainspect/graphlib/__init__.py#L38 ?

Also, since I'm not sure migra/schemainspect are maintained. Would you have any interest opening similar PRs for this and djrobstep/schemainspect#90 against https://github.com/mmkal/pgkit? I ported migra and schemainspect to typescript in that repo, and I intend to continue to maintain them.


import schemainspect

from .statements import Statements
Expand All @@ -21,6 +23,7 @@
"collations",
"rlspolicies",
"triggers",
"comments",
]
PK = "PRIMARY KEY"

Expand Down Expand Up @@ -227,11 +230,23 @@ def get_table_changes(

statements += enums_pre

for t, v in added.items():
statements.append(v.create_statement)
if v.rowsecurity:
rls_alter = v.alter_rls_statement
statements.append(rls_alter)
# topologial sort of tables using table.dependents
# this is to ensure that tables are created in the correct order
# so that foreign keys can be created

G = DiGraph()
G.add_nodes_from(tables_target.keys())
for t, v in tables_target.items():
G.add_edges_from((t, d) for d in v.dependents)

tables_sorted = {k: tables_target[k] for k in lexicographical_topological_sort(G) if k in tables_target}

for t, v in tables_sorted.items():
if t in added:
statements.append(v.create_statement)
if v.rowsecurity:
rls_alter = v.alter_rls_statement
statements.append(rls_alter)

statements += enums_post

Expand Down Expand Up @@ -365,7 +380,6 @@ def get_selectable_differences(
not_replaceable = set()

if add_dependents_for_modified:

for k, m in changed_all.items():
old = selectables_from[k]

Expand Down Expand Up @@ -478,7 +492,7 @@ def get_selectable_changes(
statements = Statements()

def functions(d):
return {k: v for k, v in d.items() if v.relationtype == "f"}
return {k: v for k, v in d.items() if v.relationtype in ("f", "a")}

if not tables_only:
if not creations_only:
Expand Down Expand Up @@ -656,6 +670,15 @@ def sequences(self):
modifications=False,
)

@property
def comments(self):
return partial(
statements_for_changes,
self.i_from.comments,
self.i_target.comments,
modifications=False,
)

def __getattr__(self, name):
if name in THINGS:
return partial(
Expand Down
2 changes: 2 additions & 0 deletions migra/migra.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def add_extension_changes(self, creates=True, drops=True):
self.add(self.changes.extensions(drops_only=True))

def add_all_changes(self, privileges=False):
self.add(self.changes.comments(drops_only=True))
self.add(self.changes.schemas(creations_only=True))

self.add(self.changes.extensions(creations_only=True, modifications=False))
Expand Down Expand Up @@ -122,6 +123,7 @@ def add_all_changes(self, privileges=False):
self.add(self.changes.triggers(creations_only=True))
self.add(self.changes.collations(drops_only=True))
self.add(self.changes.schemas(drops_only=True))
self.add(self.changes.comments(creations_only=True))

@property
def sql(self):
Expand Down
Loading