Skip to content

Commit 315fa99

Browse files
committed
.,cgosqlite: add log handler callback
SQLite logs can contain useful diagnostic information such as access misuse and other out of band failure modes. Updates tailscale/corp#33477 Updates tailscale/corp#33305 Updates tailscale/corp#33577
1 parent 35d6745 commit 315fa99

File tree

4 files changed

+112
-1
lines changed

4 files changed

+112
-1
lines changed

cgosqlite/cgosqlite.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,4 +149,15 @@ static sqlite3_int64 ts_sqlite3_column_int64(handle_sqlite3_stmt stmt, int iCol)
149149

150150
static int ts_sqlite3_disable_function(sqlite3 *db, const char *zFunctionName, int nArg) {
151151
return sqlite3_create_function(db, zFunctionName, nArg, SQLITE_ANY, NULL, NULL, NULL, NULL);
152+
}
153+
154+
void logCallbackGo(void* userData, int errCode, char* msgC);
155+
156+
static void log_callback_into_go(void *userData, int errCode, const char *msg) {
157+
logCallbackGo(userData, errCode, (char*)msg);
158+
}
159+
160+
// ts_sqlite3_config_log configures SQLite to call into Go for log messages.
161+
static int ts_sqlite3_config_log(void) {
162+
return sqlite3_config(SQLITE_CONFIG_LOG, log_callback_into_go, NULL);
152163
}

cgosqlite/logcallback.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package cgosqlite
2+
3+
// #include "cgosqlite.h"
4+
import "C"
5+
import (
6+
"sync"
7+
"unsafe"
8+
9+
"github.com/tailscale/sqlite/sqliteh"
10+
)
11+
12+
// LogCallback receives SQLite log messages.
13+
type LogCallback func(code sqliteh.Code, msg string)
14+
15+
var (
16+
logCallbackMu sync.Mutex
17+
logCallback LogCallback
18+
)
19+
20+
//export logCallbackGo
21+
func logCallbackGo(userData unsafe.Pointer, errCode C.int, msgC *C.char) {
22+
logCallbackMu.Lock()
23+
cb := logCallback
24+
logCallbackMu.Unlock()
25+
26+
if cb == nil {
27+
return
28+
}
29+
30+
msg := C.GoString(msgC)
31+
cb(sqliteh.Code(errCode), msg)
32+
}
33+
34+
// SetLogCallback sets the global SQLite log callback.
35+
// If callback is nil, logs are discarded.
36+
func SetLogCallback(callback LogCallback) error {
37+
logCallbackMu.Lock()
38+
logCallback = callback
39+
logCallbackMu.Unlock()
40+
41+
res := C.ts_sqlite3_config_log()
42+
return errCode(res)
43+
}

sqlite_cgo.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,20 @@
33

44
package sqlite
55

6-
import "github.com/tailscale/sqlite/cgosqlite"
6+
import (
7+
"github.com/tailscale/sqlite/cgosqlite"
8+
"github.com/tailscale/sqlite/sqliteh"
9+
)
710

811
func init() {
912
Open = cgosqlite.Open
1013
}
14+
15+
// LogCallback receives SQLite log messages.
16+
type LogCallback func(code sqliteh.Code, msg string)
17+
18+
// SetLogCallback sets the global SQLite log callback.
19+
// If callback is nil, logs are discarded.
20+
func SetLogCallback(callback LogCallback) error {
21+
return cgosqlite.SetLogCallback(cgosqlite.LogCallback(callback))
22+
}

sqlite_cgo_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//go:build cgo
2+
// +build cgo
3+
4+
package sqlite
5+
6+
import (
7+
"sync"
8+
"testing"
9+
10+
"github.com/tailscale/sqlite/cgosqlite"
11+
"github.com/tailscale/sqlite/sqliteh"
12+
)
13+
14+
// ensure LogCallback is convertible to cgosqlite.LogCallback
15+
var _ cgosqlite.LogCallback = cgosqlite.LogCallback(LogCallback(func(code sqliteh.Code, msg string) {}))
16+
17+
func TestSetLogCallback(t *testing.T) {
18+
var mu sync.Mutex
19+
var logs []string
20+
21+
err := SetLogCallback(func(code sqliteh.Code, msg string) {
22+
mu.Lock()
23+
defer mu.Unlock()
24+
logs = append(logs, msg)
25+
})
26+
if err != nil {
27+
t.Fatal(err)
28+
}
29+
defer SetLogCallback(nil)
30+
31+
db := openTestDB(t)
32+
33+
_, err = db.Exec("SELECT * FROM nonexistent_table")
34+
if err == nil {
35+
t.Fatal("expected error from invalid SQL")
36+
}
37+
38+
mu.Lock()
39+
gotLogs := len(logs) > 0
40+
mu.Unlock()
41+
42+
if !gotLogs {
43+
t.Fatal("expected to receive log messages")
44+
}
45+
}

0 commit comments

Comments
 (0)