Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions db/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ struct db {

/* Fatal if we try to write to db */
bool readonly;

/* Set during migrations to prevent STRICT mode on table creation */
bool in_migration;
};

struct db_query {
Expand Down
34 changes: 32 additions & 2 deletions db/db_sqlite3.c
Original file line number Diff line number Diff line change
Expand Up @@ -203,16 +203,46 @@ static bool db_sqlite3_setup(struct db *db, bool create)
"PRAGMA foreign_keys = ON;", -1, &stmt, NULL);
err = sqlite3_step(stmt);
sqlite3_finalize(stmt);
return err == SQLITE_DONE;

if (err != SQLITE_DONE)
return false;

if (db->developer) {
sqlite3_prepare_v2(conn2sql(db->conn),
"PRAGMA trusted_schema = OFF;", -1, &stmt, NULL);
sqlite3_step(stmt);
sqlite3_finalize(stmt);

sqlite3_prepare_v2(conn2sql(db->conn),
"PRAGMA cell_size_check = ON;", -1, &stmt, NULL);
sqlite3_step(stmt);
sqlite3_finalize(stmt);
}

return true;
}

static bool db_sqlite3_query(struct db_stmt *stmt)
{
sqlite3_stmt *s;
sqlite3 *conn = conn2sql(stmt->db->conn);
int err;
const char *query = stmt->query->query;
char *modified_query = NULL;

/* STRICT tables for developer mode, and not during upgrades. */
if (stmt->db->developer &&
!stmt->db->in_migration &&
strncasecmp(query, "CREATE TABLE", 12) == 0 &&
!strstr(query, "STRICT")) {
modified_query = tal_fmt(stmt, "%s STRICT", query);
query = modified_query;
}

err = sqlite3_prepare_v2(conn, query, -1, &s, NULL);

err = sqlite3_prepare_v2(conn, stmt->query->query, -1, &s, NULL);
if (modified_query)
tal_free(modified_query);

for (size_t i=0; i<stmt->query->placeholders; i++) {
struct db_binding *b = &stmt->bindings[i];
Expand Down
1 change: 1 addition & 0 deletions db/utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ struct db *db_open_(const tal_t *ctx, const char *filename,
db->in_transaction = NULL;
db->transaction_started = false;
db->changes = NULL;
db->in_migration = false;

/* This must be outside a transaction, so catch it */
assert(!db->in_transaction);
Expand Down
2 changes: 2 additions & 0 deletions devtools/sql-rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def rewrite_single(self, query):
r'BIGINT': 'INTEGER',
r'BIGINTEGER': 'INTEGER',
r'BIGSERIAL': 'INTEGER',
r'VARCHAR(?:\(\d+\))?': 'TEXT',
r'\bINT\b': 'INTEGER',
r'CURRENT_TIMESTAMP\(\)': "strftime('%s', 'now')",
r'INSERT INTO[ \t]+(.*)[ \t]+ON CONFLICT.*DO NOTHING;': 'INSERT OR IGNORE INTO \\1;',
# Rewrite "decode('abcd', 'hex')" to become "x'abcd'"
Expand Down
53 changes: 53 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ def test_scid_upgrade(node_factory, bitcoind):
assert l1.db_query('SELECT scid FROM channels;') == [{'scid': scid_to_int('103x1x1')}]
assert l1.db_query('SELECT failscid FROM payments;') == [{'failscid': scid_to_int('103x1x1')}]

faildetail_types = l1.db_query(
"SELECT id, typeof(faildetail) as type "
"FROM payments WHERE faildetail IS NOT NULL"
)
for row in faildetail_types:
assert row['type'] == 'text', \
f"Payment {row['id']}: faildetail has type {row['type']}, expected 'text'"


@unittest.skipIf(not COMPAT, "needs COMPAT to convert obsolete db")
@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "This test is based on a sqlite3 snapshot")
Expand Down Expand Up @@ -642,3 +650,48 @@ def test_channel_htlcs_id_change(bitcoind, node_factory):
# Make some HTLCS
for amt in (100, 500, 1000, 5000, 10000, 50000, 100000):
l1.pay(l3, amt)


@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "STRICT tables are SQLite3 specific")
def test_sqlite_strict_mode(node_factory):
"""Test that STRICT is appended to CREATE TABLE in developer mode."""
l1 = node_factory.get_node(options={'developer': None})

tables = l1.db_query("SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")

strict_tables = [t for t in tables if t['sql'] and 'STRICT' in t['sql']]
assert len(strict_tables) > 0, f"Expected at least one STRICT table in developer mode, found none out of {len(tables)}"

known_strict_tables = ['version', 'forwards', 'payments', 'local_anchors', 'addresses']
for table_name in known_strict_tables:
table_sql = next((t['sql'] for t in tables if t['name'] == table_name), None)
if table_sql:
assert 'STRICT' in table_sql, f"Expected table '{table_name}' to be STRICT in developer mode"


@unittest.skipIf(os.getenv('TEST_DB_PROVIDER', 'sqlite3') != 'sqlite3', "SQLite3-specific test")
@unittest.skipIf(not COMPAT, "needs COMPAT to test old database upgrade")
@unittest.skipIf(TEST_NETWORK != 'regtest', "The network must match the DB snapshot")
def test_strict_mode_with_old_database(node_factory, bitcoind):
"""Test old database upgrades work (STRICT not applied during migrations)."""
bitcoind.generate_block(1)

l1 = node_factory.get_node(dbfile='oldstyle-scids.sqlite3.xz',
options={'database-upgrade': True,
'developer': None})

assert l1.rpc.getinfo()['id'] is not None

# Upgraded tables won't be STRICT (only fresh databases get STRICT).
strict_tables = l1.db_query(
"SELECT name FROM sqlite_master "
"WHERE type='table' AND sql LIKE '%STRICT%'"
)
assert len(strict_tables) == 0, "Upgraded database should not have STRICT tables"

# Verify BLOB->TEXT migration ran for faildetail cleanup.
result = l1.db_query(
"SELECT COUNT(*) as count FROM payments "
"WHERE typeof(faildetail) = 'blob'"
)
assert result[0]['count'] == 0, "Found BLOB-typed faildetail after migration"
61 changes: 61 additions & 0 deletions wallet/db.c
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ static void migrate_initialize_channel_htlcs_wait_indexes_and_fixup_forwards(str
struct db *db);
static void migrate_fail_pending_payments_without_htlcs(struct lightningd *ld,
struct db *db);
static void migrate_fix_payments_faildetail_type(struct lightningd *ld,
struct db *db);

/* Do not reorder or remove elements from this array, it is used to
* migrate existing databases from a previous state, based on the
Expand Down Expand Up @@ -1102,6 +1104,8 @@ static struct migration dbmigrations[] = {
")"), NULL},
{NULL, migrate_fail_pending_payments_without_htlcs},
{SQL("ALTER TABLE channels ADD withheld INTEGER DEFAULT 0;"), NULL},
/* Fix BLOB→TEXT in payments.faildetail for old databases. */
{NULL, migrate_fix_payments_faildetail_type},
};

/**
Expand All @@ -1118,6 +1122,9 @@ static bool db_migrate(struct lightningd *ld, struct db *db,
orig = current = db_get_version(db);
available = ARRAY_SIZE(dbmigrations) - 1;

/* Disable STRICT for upgrades: legacy data may have wrong type affinity. */
db->in_migration = (current != -1);

if (current == -1)
log_info(ld->log, "Creating database");
else if (available < current) {
Expand Down Expand Up @@ -1195,6 +1202,8 @@ struct db *db_setup(const tal_t *ctx, struct lightningd *ld,

db_commit_transaction(db);

db->in_migration = false;

/* This needs to be done outside a transaction, apparently.
* It's a good idea to do this every so often, and on db
* upgrade is a reasonable time. */
Expand Down Expand Up @@ -2153,3 +2162,55 @@ static void migrate_fail_pending_payments_without_htlcs(struct lightningd *ld,
db_bind_int(stmt, payment_status_in_db(PAYMENT_PENDING));
db_exec_prepared_v2(take(stmt));
}

static void migrate_fix_payments_faildetail_type(struct lightningd *ld,
struct db *db)
{
/* Historical databases may have BLOB-typed faildetail data.
* STRICT mode rejects this, so convert or NULL out invalid UTF-8. */
struct db_stmt *stmt;
size_t fixed = 0, invalid = 0;

stmt = db_prepare_v2(db, SQL("SELECT id, faildetail "
"FROM payments "
"WHERE typeof(faildetail) = 'blob'"));
db_query_prepared(stmt);

while (db_step(stmt)) {
u64 id = db_col_u64(stmt, "id");
const u8 *blob = db_col_blob(stmt, "faildetail");
size_t len = db_col_bytes(stmt, "faildetail");
struct db_stmt *upd;

if (!utf8_check(blob, len)) {
log_unusual(ld->log, "Payment %"PRIu64": "
"Invalid UTF-8 in faildetail, setting to NULL",
id);
upd = db_prepare_v2(db,
SQL("UPDATE payments "
"SET faildetail = NULL "
"WHERE id = ?"));
db_bind_u64(upd, id);
db_exec_prepared_v2(take(upd));
invalid++;
continue;
}

char *text = tal_strndup(tmpctx, (char *)blob, len);
upd = db_prepare_v2(db,
SQL("UPDATE payments "
"SET faildetail = ? "
"WHERE id = ?"));
db_bind_text(upd, text);
db_bind_u64(upd, id);
db_exec_prepared_v2(take(upd));
fixed++;
}

tal_free(stmt);

if (fixed > 0 || invalid > 0)
log_info(ld->log, "payments.faildetail migration: "
"%zu converted, %zu invalid UTF-8 nulled",
fixed, invalid);
}
Loading