diff --git a/api/api.go b/api/api.go index 65c214a..88599eb 100644 --- a/api/api.go +++ b/api/api.go @@ -46,6 +46,7 @@ type ( //sys SQLAllocHandle(handleType SQLSMALLINT, inputHandle SQLHANDLE, outputHandle *SQLHANDLE) (ret SQLRETURN) = odbc32.SQLAllocHandle //sys SQLBindCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, targetType SQLSMALLINT, targetValuePtr SQLPOINTER, bufferLength SQLLEN, vallen *SQLLEN) (ret SQLRETURN) = odbc32.SQLBindCol //sys SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, inputOutputType SQLSMALLINT, valueType SQLSMALLINT, parameterType SQLSMALLINT, columnSize SQLULEN, decimalDigits SQLSMALLINT, parameterValue SQLPOINTER, bufferLength SQLLEN, ind *SQLLEN) (ret SQLRETURN) = odbc32.SQLBindParameter +//sys SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLCancel //sys SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLCloseCursor //sys SQLDescribeCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, columnName *SQLWCHAR, bufferLength SQLSMALLINT, nameLengthPtr *SQLSMALLINT, dataTypePtr *SQLSMALLINT, columnSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLDescribeColW //sys SQLDescribeParam(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, dataTypePtr *SQLSMALLINT, parameterSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLDescribeParam diff --git a/api/api_unix.go b/api/api_unix.go index 06ea270..3b6b6e0 100644 --- a/api/api_unix.go +++ b/api/api_unix.go @@ -126,6 +126,18 @@ const ( SQL_CP_DEFAULT = SQL_CP_OFF SQL_CP_STRICT_MATCH = uintptr(C.SQL_CP_STRICT_MATCH) SQL_CP_RELAXED_MATCH = uintptr(C.SQL_CP_RELAXED_MATCH) + + //Transaction Isolation + SQL_ATTR_TXN_ISOLATION = C.SQL_ATTR_TXN_ISOLATION + SQL_TXN_READ_COMMITTED = C.SQL_TXN_READ_COMMITTED + SQL_TXN_READ_UNCOMMITTED = C.SQL_TXN_READ_UNCOMMITTED + SQL_TXN_REPEATABLE_READ = C.SQL_TXN_REPEATABLE_READ + SQL_TXN_SERIALIZABLE = C.SQL_TXN_SERIALIZABLE + + //Access Mode + SQL_ATTR_ACCESS_MODE = C.SQL_ATTR_ACCESS_MODE + SQL_MODE_READ_ONLY = C.SQL_MODE_READ_ONLY + SQL_MODE_READ_WRITE = C.SQL_MODE_READ_WRITE ) type ( diff --git a/api/api_windows.go b/api/api_windows.go index 30e978f..b2d0be2 100644 --- a/api/api_windows.go +++ b/api/api_windows.go @@ -108,6 +108,18 @@ const ( SQL_CP_DEFAULT = SQL_CP_OFF SQL_CP_STRICT_MATCH = 0 SQL_CP_RELAXED_MATCH = uintptr(1) + + //Transaction Isolation + SQL_ATTR_TXN_ISOLATION = 108 + SQL_TXN_READ_COMMITTED = 2 + SQL_TXN_READ_UNCOMMITTED = 1 + SQL_TXN_REPEATABLE_READ = 4 + SQL_TXN_SERIALIZABLE = 8 + + //Access Mode + SQL_ATTR_ACCESS_MODE = 101 + SQL_MODE_READ_ONLY = 1 + SQL_MODE_READ_WRITE = 0 ) type ( diff --git a/api/zapi_unix.go b/api/zapi_unix.go index e36d53f..2055440 100644 --- a/api/zapi_unix.go +++ b/api/zapi_unix.go @@ -35,6 +35,11 @@ func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, in return SQLRETURN(r) } +func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) { + r := C.SQLCancel(C.SQLHSTMT(statementHandle)) + return SQLRETURN(r) +} + func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) { r := C.SQLCloseCursor(C.SQLHSTMT(statementHandle)) return SQLRETURN(r) diff --git a/api/zapi_windows.go b/api/zapi_windows.go index 3657da3..bc5dd2b 100644 --- a/api/zapi_windows.go +++ b/api/zapi_windows.go @@ -1,4 +1,4 @@ -// MACHINE GENERATED BY 'go generate' COMMAND; DO NOT EDIT +// Code generated by 'go generate'; DO NOT EDIT. package api @@ -19,6 +19,7 @@ const ( var ( errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING) + errERROR_EINVAL error = syscall.EINVAL ) // errnoErr returns common boxed Errno values, to prevent @@ -26,7 +27,7 @@ var ( func errnoErr(e syscall.Errno) error { switch e { case 0: - return nil + return errERROR_EINVAL case errnoERROR_IO_PENDING: return errERROR_IO_PENDING } @@ -42,6 +43,7 @@ var ( procSQLAllocHandle = mododbc32.NewProc("SQLAllocHandle") procSQLBindCol = mododbc32.NewProc("SQLBindCol") procSQLBindParameter = mododbc32.NewProc("SQLBindParameter") + procSQLCancel = mododbc32.NewProc("SQLCancel") procSQLCloseCursor = mododbc32.NewProc("SQLCloseCursor") procSQLDescribeColW = mododbc32.NewProc("SQLDescribeColW") procSQLDescribeParam = mododbc32.NewProc("SQLDescribeParam") @@ -53,13 +55,13 @@ var ( procSQLFreeHandle = mododbc32.NewProc("SQLFreeHandle") procSQLGetData = mododbc32.NewProc("SQLGetData") procSQLGetDiagRecW = mododbc32.NewProc("SQLGetDiagRecW") - procSQLNumParams = mododbc32.NewProc("SQLNumParams") procSQLMoreResults = mododbc32.NewProc("SQLMoreResults") + procSQLNumParams = mododbc32.NewProc("SQLNumParams") procSQLNumResultCols = mododbc32.NewProc("SQLNumResultCols") procSQLPrepareW = mododbc32.NewProc("SQLPrepareW") procSQLRowCount = mododbc32.NewProc("SQLRowCount") - procSQLSetEnvAttr = mododbc32.NewProc("SQLSetEnvAttr") procSQLSetConnectAttrW = mododbc32.NewProc("SQLSetConnectAttrW") + procSQLSetEnvAttr = mododbc32.NewProc("SQLSetEnvAttr") ) func SQLAllocHandle(handleType SQLSMALLINT, inputHandle SQLHANDLE, outputHandle *SQLHANDLE) (ret SQLRETURN) { @@ -80,6 +82,12 @@ func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, in return } +func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLCancel.Addr(), 1, uintptr(statementHandle), 0, 0) + ret = SQLRETURN(r0) + return +} + func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) { r0, _, _ := syscall.Syscall(procSQLCloseCursor.Addr(), 1, uintptr(statementHandle), 0, 0) ret = SQLRETURN(r0) @@ -146,14 +154,14 @@ func SQLGetDiagRec(handleType SQLSMALLINT, handle SQLHANDLE, recNumber SQLSMALLI return } -func SQLNumParams(statementHandle SQLHSTMT, parameterCountPtr *SQLSMALLINT) (ret SQLRETURN) { - r0, _, _ := syscall.Syscall(procSQLNumParams.Addr(), 2, uintptr(statementHandle), uintptr(unsafe.Pointer(parameterCountPtr)), 0) +func SQLMoreResults(statementHandle SQLHSTMT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLMoreResults.Addr(), 1, uintptr(statementHandle), 0, 0) ret = SQLRETURN(r0) return } -func SQLMoreResults(statementHandle SQLHSTMT) (ret SQLRETURN) { - r0, _, _ := syscall.Syscall(procSQLMoreResults.Addr(), 1, uintptr(statementHandle), 0, 0) +func SQLNumParams(statementHandle SQLHSTMT, parameterCountPtr *SQLSMALLINT) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall(procSQLNumParams.Addr(), 2, uintptr(statementHandle), uintptr(unsafe.Pointer(parameterCountPtr)), 0) ret = SQLRETURN(r0) return } @@ -176,14 +184,14 @@ func SQLRowCount(statementHandle SQLHSTMT, rowCountPtr *SQLLEN) (ret SQLRETURN) return } -func SQLSetEnvAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) { - r0, _, _ := syscall.Syscall6(procSQLSetEnvAttr.Addr(), 4, uintptr(environmentHandle), uintptr(attribute), uintptr(valuePtr), uintptr(stringLength), 0, 0) +func SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLSetConnectAttrW.Addr(), 4, uintptr(connectionHandle), uintptr(attribute), uintptr(valuePtr), uintptr(stringLength), 0, 0) ret = SQLRETURN(r0) return } -func SQLSetConnectAttr(connectionHandle SQLHDBC, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) { - r0, _, _ := syscall.Syscall6(procSQLSetConnectAttrW.Addr(), 4, uintptr(connectionHandle), uintptr(attribute), uintptr(valuePtr), uintptr(stringLength), 0, 0) +func SQLSetEnvAttr(environmentHandle SQLHENV, attribute SQLINTEGER, valuePtr SQLPOINTER, stringLength SQLINTEGER) (ret SQLRETURN) { + r0, _, _ := syscall.Syscall6(procSQLSetEnvAttr.Addr(), 4, uintptr(environmentHandle), uintptr(attribute), uintptr(valuePtr), uintptr(stringLength), 0, 0) ret = SQLRETURN(r0) return } diff --git a/conn.go b/conn.go index aec716f..0f32f29 100644 --- a/conn.go +++ b/conn.go @@ -5,48 +5,33 @@ package odbc import ( + "context" "database/sql/driver" "strings" - "unsafe" "github.com/polytomic/odbc/api" + "go.uber.org/atomic" ) type Conn struct { h api.SQLHDBC tx *Tx - bad bool + bad *atomic.Bool + closingInBG *atomic.Bool isMSAccessDriver bool } var accessDriverSubstr = strings.ToUpper(strings.Replace("DRIVER={Microsoft Access Driver", " ", "", -1)) -func (d *Driver) Open(dsn string) (driver.Conn, error) { - if d.initErr != nil { - return nil, d.initErr - } - - var out api.SQLHANDLE - ret := api.SQLAllocHandle(api.SQL_HANDLE_DBC, api.SQLHANDLE(d.h), &out) - if IsError(ret) { - return nil, NewError("SQLAllocHandle", d.h) - } - h := api.SQLHDBC(out) - drv.Stats.updateHandleCount(api.SQL_HANDLE_DBC, 1) - - b := api.StringToUTF16(dsn) - ret = api.SQLDriverConnect(h, 0, - (*api.SQLWCHAR)(unsafe.Pointer(&b[0])), api.SQL_NTS, - nil, 0, nil, api.SQL_DRIVER_NOPROMPT) - if IsError(ret) { - defer releaseHandle(h) - return nil, NewError("SQLDriverConnect", h) +// implement driver.Conn +func (c *Conn) Close() (err error) { + if c.closingInBG.Load() { + //if we are cancelling/closing in a background thread, ignore requests to Close this connection from the driver + return nil } - isAccess := strings.Contains(strings.ToUpper(strings.Replace(dsn, " ", "", -1)), accessDriverSubstr) - return &Conn{h: h, isMSAccessDriver: isAccess}, nil + return c.close() } - -func (c *Conn) Close() (err error) { +func (c *Conn) close() (err error) { if c.tx != nil { c.tx.Rollback() } @@ -68,7 +53,27 @@ func (c *Conn) Close() (err error) { func (c *Conn) newError(apiName string, handle interface{}) error { err := NewError(apiName, handle) if err == driver.ErrBadConn { - c.bad = true + c.bad.Store(true) } return err } + +// implement driver.Conn +func (c *Conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +// implement driver.Conn +func (c *Conn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +//implement driver.Execer +func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) { + return c.ExecContext(context.Background(), query, toNamedValues(args)) +} + +//implement driver.Queryer +func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) { + return c.QueryContext(context.Background(), query, toNamedValues(args)) +} diff --git a/conn_go18.go b/conn_go18.go new file mode 100644 index 0000000..916cb58 --- /dev/null +++ b/conn_go18.go @@ -0,0 +1,131 @@ +package odbc + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "unsafe" + + "github.com/polytomic/odbc/api" +) + +var ErrTXAlreadyStarted = errors.New("already in a transaction") + +var sqlIsolationLevel = map[sql.IsolationLevel]uintptr{ + sql.LevelReadCommitted: api.SQL_TXN_READ_COMMITTED, + sql.LevelReadUncommitted: api.SQL_TXN_READ_UNCOMMITTED, + sql.LevelRepeatableRead: api.SQL_TXN_REPEATABLE_READ, + sql.LevelSerializable: api.SQL_TXN_SERIALIZABLE, +} + +var testBeginErr error // used during tests + +//implement driver.ConnBeginTx +func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) { + if c.bad.Load() { + return nil, driver.ErrBadConn + } + //TODO(ninthclowd): refactor to use mocks / test hook functions or behavior tests so there isn't test logic in production + if testBeginErr != nil { + c.bad.Store(true) + return nil, testBeginErr + } + + if c.tx != nil { + return nil, ErrTXAlreadyStarted + } + c.tx = &Tx{c: c, opts: opts} + + if ret := api.SQLSetConnectUIntPtrAttr(c.h, api.SQL_ATTR_AUTOCOMMIT, api.SQL_AUTOCOMMIT_OFF, api.SQL_IS_UINTEGER); IsError(ret) { + c.bad.Store(true) + return nil, NewError("SQLSetConnectUIntPtrAttr", c.h) + } + + if isolation, modeAvailable := sqlIsolationLevel[sql.IsolationLevel(opts.Isolation)]; modeAvailable { + if ret := api.SQLSetConnectUIntPtrAttr(c.h, api.SQL_ATTR_TXN_ISOLATION, isolation, api.SQL_IS_UINTEGER); IsError(ret) { + c.bad.Store(true) + return nil, NewError("SQLSetConnectUIntPtrAttr", c.h) + } + } + if opts.ReadOnly { + if ret := api.SQLSetConnectUIntPtrAttr(c.h, api.SQL_ATTR_ACCESS_MODE, api.SQL_MODE_READ_ONLY, api.SQL_IS_UINTEGER); IsError(ret) { + c.bad.Store(true) + return nil, NewError("SQLSetConnectUIntPtrAttr", c.h) + } + } + + return c.tx, nil +} + +//implement driver.ConnPrepareContext +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if c.bad.Load() { + return nil, driver.ErrBadConn + } + + var out api.SQLHANDLE + ret := api.SQLAllocHandle(api.SQL_HANDLE_STMT, api.SQLHANDLE(c.h), &out) + if IsError(ret) { + return nil, c.newError("SQLAllocHandle", c.h) + } + h := api.SQLHSTMT(out) + drv.Stats.StmtCount.Inc() + + b := api.StringToUTF16(query) + ret = api.SQLPrepare(h, (*api.SQLWCHAR)(unsafe.Pointer(&b[0])), api.SQL_NTS) + if IsError(ret) { + defer releaseHandle(h) + return nil, c.newError("SQLPrepare", h) + } + ps, err := ExtractParameters(h) + if err != nil { + defer releaseHandle(h) + return nil, err + } + + return &Stmt{ + c: c, + query: query, + h: h, + parameters: ps, + rows: nil, + }, nil +} + +//implement driver.ExecerContext +func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) { + //TODO(ninthclowd): build and execute a statement with SQLExecDirect instead of preparing the statement + return nil, driver.ErrSkip +} + +//implement driver.QueryerContext +func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) { + //TODO(ninthclowd): build and execute a statement with SQLExecDirect instead of preparing the statement + return nil, driver.ErrSkip +} + +//implement driver.SessionResetter +func (c *Conn) ResetSession(ctx context.Context) error { + if c.bad.Load() { + return driver.ErrBadConn + } + return nil +} + +//implement driver.Pinger +func (c *Conn) Ping(ctx context.Context) error { + if c.bad.Load() { + return driver.ErrBadConn + } + stmt, err := c.PrepareContext(ctx, "SELECT 1") + if err != nil { + return driver.ErrBadConn + } + defer stmt.Close() + + if _, err := stmt.(*Stmt).ExecContext(ctx, nil); err != nil { + return driver.ErrBadConn + } + return nil +} diff --git a/connector.go b/connector.go new file mode 100644 index 0000000..3decf40 --- /dev/null +++ b/connector.go @@ -0,0 +1,21 @@ +package odbc + +import ( + "context" + "database/sql/driver" +) + +type connector struct { + d *Driver + name string +} + +//implement driver.Connector +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + return c.d.open(c.name, ctx) +} + +//implement driver.Connector +func (c *connector) Driver() driver.Driver { + return c.d +} diff --git a/driver.go b/driver.go index 12e4fa3..871d70b 100644 --- a/driver.go +++ b/driver.go @@ -7,9 +7,14 @@ package odbc import ( + "context" "database/sql" + "database/sql/driver" + "strings" + "unsafe" "github.com/polytomic/odbc/api" + "go.uber.org/atomic" ) var drv Driver @@ -21,7 +26,12 @@ type Driver struct { } func initDriver() error { - + //initialize allocation counters + drv.Stats = Stats{ + EnvCount: atomic.NewInt32(0), + ConnCount: atomic.NewInt32(0), + StmtCount: atomic.NewInt32(0), + } //Allocate environment handle var out api.SQLHANDLE in := api.SQLHANDLE(api.SQL_NULL_HANDLE) @@ -30,10 +40,7 @@ func initDriver() error { return NewError("SQLAllocHandle", api.SQLHENV(in)) } drv.h = api.SQLHENV(out) - err := drv.Stats.updateHandleCount(api.SQL_HANDLE_ENV, 1) - if err != nil { - return err - } + drv.Stats.EnvCount.Inc() // will use ODBC v3 ret = api.SQLSetEnvUIntPtrAttr(drv.h, api.SQL_ATTR_ODBC_VERSION, api.SQL_OV_ODBC3, 0) @@ -63,6 +70,7 @@ func initDriver() error { return nil } +//TODO(ninthclowd): this is not part of the driver.Driver interface and will never be called by a consumer func (d *Driver) Close() error { // TODO(brainman): who will call (*Driver).Close (to dispose all opened handles)? h := d.h @@ -77,3 +85,33 @@ func init() { } sql.Register("odbc", &drv) } + +// implement driver.Driver +func (d *Driver) Open(name string) (driver.Conn, error) { + return d.open(name, context.Background()) +} + +func (d *Driver) open(name string, dialContext context.Context) (driver.Conn, error) { + if d.initErr != nil { + return nil, d.initErr + } + //TODO(ninthclowd): return early if dialContext expires while connecting + var out api.SQLHANDLE + ret := api.SQLAllocHandle(api.SQL_HANDLE_DBC, api.SQLHANDLE(d.h), &out) + if IsError(ret) { + return nil, NewError("SQLAllocHandle", d.h) + } + h := api.SQLHDBC(out) + drv.Stats.ConnCount.Inc() + + b := api.StringToUTF16(name) + ret = api.SQLDriverConnect(h, 0, + (*api.SQLWCHAR)(unsafe.Pointer(&b[0])), api.SQL_NTS, + nil, 0, nil, api.SQL_DRIVER_NOPROMPT) + if IsError(ret) { + defer releaseHandle(h) + return nil, NewError("SQLDriverConnect", h) + } + isAccess := strings.Contains(strings.ToUpper(strings.Replace(name, " ", "", -1)), accessDriverSubstr) + return &Conn{h: h, isMSAccessDriver: isAccess, bad: atomic.NewBool(false), closingInBG: atomic.NewBool(false)}, nil +} diff --git a/driver_go18.go b/driver_go18.go new file mode 100644 index 0000000..40477a7 --- /dev/null +++ b/driver_go18.go @@ -0,0 +1,11 @@ +package odbc + +import "database/sql/driver" + +//implement driver.DriverContext +func (d *Driver) OpenConnector(name string) (driver.Connector, error) { + return &connector{ + d: d, + name: name, + }, nil +} diff --git a/go.mod b/go.mod index 05f00a7..9374847 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,6 @@ go 1.13 require ( github.com/go-ole/go-ole v1.2.6 - golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3 + go.uber.org/atomic v1.7.0 + golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46 ) diff --git a/go.sum b/go.sum index a452934..f58c7e6 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,16 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3 h1:7TYNF4UdlohbFwpNH04CoPMp1cHUZgO1Ebq5r2hIjfo= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46 h1:V066+OYJ66oTjnhm4Yrn7SXIwSCiDQJxpBxmvqb1N1c= +golang.org/x/sys v0.0.0-20210228012217-479acdf4ea46/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/mssql_go18_test.go b/mssql_go18_test.go new file mode 100644 index 0000000..bb7d511 --- /dev/null +++ b/mssql_go18_test.go @@ -0,0 +1,220 @@ +package odbc + +import ( + "context" + "database/sql" + "errors" + "sync" + "testing" + "time" +) + +func TestMSSQLMultipleExecOnStatement(t *testing.T) { + + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + exec(t, db, "drop table if exists dbo.temp") + exec(t, db, "create table dbo.temp (id int, a varchar(255))") + + stmt1, err := db.Prepare(`insert into dbo.temp (id, a) VALUES (?,?)`) + + if err != nil { + t.Fatal(err) + } + if _, err = stmt1.Exec(1, "TEST 1"); err != nil { + t.Errorf("Failed to insert record with ID 1: %s", err) + } + if _, err = stmt1.Exec(2, "TEST 2"); err != nil { + t.Errorf("Failed to insert record with ID 2: %s", err) + } + + if err = stmt1.Close(); err != nil { + t.Errorf("Failed to close exec statement: %s", err) + } + + stmt2, err := db.Prepare(`SELECT a FROM dbo.temp WHERE id = ?`) + + if err != nil { + t.Fatal(err) + } + if rows, err := stmt2.Query(1); err != nil { + t.Fatalf("Failed to query record with ID 1: %s", err) + } else if found := rows.Next(); !found { + t.Fatalf("No results returned from query") + } else { + var field string + rows.Scan(&field) + if field != "TEST 1" { + t.Fatalf("Got unexpected value from query: %s", field) + } + rows.Close() + } + + if rows, err := stmt2.Query(2); err != nil { + t.Fatalf("Failed to query record with ID 2: %s", err) + } else if found := rows.Next(); !found { + t.Fatalf("No results returned from query") + } else { + var field string + rows.Scan(&field) + if field != "TEST 2" { + t.Fatalf("Got unexpected value from query: %s", field) + } + rows.Close() + } + + if err = stmt2.Close(); err != nil { + t.Errorf("Failed to close query statement: %s", err) + } +} + +func TestMSSQLContextExpired(t *testing.T) { + + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + expiredContext, cancel := context.WithDeadline(context.Background(), time.Now().Add(-1*time.Second)) + defer cancel() + + if _, err := db.PrepareContext(expiredContext, `insert into dbo.temp (id, a) VALUES (?,?)`); !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected a context expired error from PrepareContext") + } + + if _, err := db.QueryContext(expiredContext, `SELECT * FROM dbo.temp`); !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected a context expired error from QueryContext") + } + + if _, err := db.ExecContext(expiredContext, `SELECT 1`); !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected a context expired error from ExecContext") + } + + if _, err := db.BeginTx(expiredContext, nil); !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("expected a context expired error from BeginTx") + } +} + +func TestMSSQLQueryContextTimeout(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + + testingIssue5 = false + + start := time.Now() + + var wgTest sync.WaitGroup + + for i := 0; i < 10; i++ { + wgTest.Add(1) + go func() { + defer wgTest.Done() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + if _, qErr := db.QueryContext(ctx, "WAITFOR DELAY '00:00:10'"); qErr == nil { + t.Error("expected an error to be returned") + } else if !errors.Is(qErr, context.DeadlineExceeded) { + t.Errorf("expected a context deadline error. got: %s", qErr.Error()) + } + if time.Since(start).Seconds() >= 2 { + t.Error("query should have been canceled after 1 second") + } + }() + } + + wgTest.Wait() + + //wait for the query to finish cancelling in the background and for the connection to close + time.Sleep(3 * time.Second) + + closeDB(t, db, sc, sc) + +} + +func TestMSSQLExecContextTimeout(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + + testingIssue5 = false + + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + if _, qErr := db.ExecContext(ctx, "WAITFOR DELAY '00:00:10'"); qErr == nil { + t.Fatal("expected an error to be returned") + } else if !errors.Is(qErr, context.DeadlineExceeded) { + t.Fatalf("expected a context canceled error. got: %s", qErr.Error()) + } + if time.Since(start).Seconds() > 2 { + t.Fatal("exec should have been canceled after 1 second") + } + + if _, qErr := db.ExecContext(ctx, "SELECT 1"); qErr == nil { + t.Fatal("expected an error to be returned for subsequent exec on expired context") + } else if !errors.Is(qErr, context.DeadlineExceeded) { + t.Fatalf("expected a context canceled error. got: %s", qErr.Error()) + } + + if _, qErr := db.ExecContext(context.Background(), "SELECT 1"); qErr != nil { + t.Fatalf("exec on a fresh context should execute without error. Got: %s", qErr.Error()) + } + + //wait for the query to finish cancelling in the background and for the connection to close + time.Sleep(3 * time.Second) + + closeDB(t, db, sc, sc) +} + +func TestMSSQLPing(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + c, _ := db.Conn(context.Background()) + + if err := c.PingContext(context.Background()); err != nil { + t.Fatalf("did not expect an error from ping. got %s", err.Error()) + } + + c.Close() + + if err := c.PingContext(context.Background()); err == nil { + t.Fatalf("expected ping to fail after being closed") + } + +} + +func TestMSSQLTxOptions(t *testing.T) { + db, sc, err := mssqlConnect() + if err != nil { + t.Fatal(err) + } + defer closeDB(t, db, sc, sc) + + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ + Isolation: sql.LevelReadUncommitted, + ReadOnly: true, + }) + if err != nil { + t.Fatalf("expected no error starting transaction. Got %s", err.Error()) + } + if _, err = tx.ExecContext(context.Background(), "SELECT 1"); err != nil { + t.Fatalf("expected no error from exec on transaction. Got %s", err.Error()) + } + if err := tx.Rollback(); err != nil { + t.Fatalf("expected no error rolling back transaction. Got %s", err.Error()) + } + +} diff --git a/mssql_test.go b/mssql_test.go index b1d79b6..319601d 100644 --- a/mssql_test.go +++ b/mssql_test.go @@ -110,26 +110,26 @@ func (params connParams) makeODBCConnectionString() string { return c } -func mssqlConnectWithParams(params connParams) (db *sql.DB, stmtCount int, err error) { +func mssqlConnectWithParams(params connParams) (db *sql.DB, stmtCount int32, err error) { db, err = sql.Open("odbc", params.makeODBCConnectionString()) if err != nil { return nil, 0, err } stats := db.Driver().(*Driver).Stats - return db, stats.StmtCount, nil + return db, stats.StmtCount.Load(), nil } -func mssqlConnect() (db *sql.DB, stmtCount int, err error) { +func mssqlConnect() (db *sql.DB, stmtCount int32, err error) { return mssqlConnectWithParams(newConnParams()) } -func closeDB(t *testing.T, db *sql.DB, shouldStmtCount, ignoreIfStmtCount int) { +func closeDB(t *testing.T, db *sql.DB, shouldStmtCount, ignoreIfStmtCount int32) { s := db.Driver().(*Driver).Stats err := db.Close() if err != nil { t.Fatalf("error closing DB: %v", err) } - switch s.StmtCount { + switch s.StmtCount.Load() { case shouldStmtCount: // all good case ignoreIfStmtCount: @@ -867,8 +867,9 @@ func TestMSSQLStmtAndRows(t *testing.T) { } }() - if db.Driver().(*Driver).Stats.StmtCount != sc { - t.Fatalf("invalid statement count: expected %v, is %v", sc, db.Driver().(*Driver).Stats.StmtCount) + gotSC := db.Driver().(*Driver).Stats.StmtCount.Load() + if gotSC != sc { + t.Fatalf("invalid statement count: expected %v, is %v", sc, gotSC) } // no resource tracking past this point @@ -1581,12 +1582,12 @@ func TestMSSQLMarkTxBadConn(t *testing.T) { testFn := func(endTx func(driver.Tx) error, nextFn func(driver.Conn) error) { proxy.restart() - cc, sc := drv.Stats.ConnCount, drv.Stats.StmtCount + cc, sc := drv.Stats.ConnCount.Load(), drv.Stats.StmtCount.Load() defer func() { - if should, is := sc, drv.Stats.StmtCount; should != is { + if should, is := sc, drv.Stats.StmtCount.Load(); should != is { t.Errorf("leaked statement, should=%d, is=%d", should, is) } - if should, is := cc, drv.Stats.ConnCount; should != is { + if should, is := cc, drv.Stats.ConnCount.Load(); should != is { t.Errorf("leaked connection, should=%d, is=%d", should, is) } }() @@ -1662,12 +1663,12 @@ func TestMSSQLMarkBeginBadConn(t *testing.T) { params := newConnParams() testFn := func(label string, nextFn func(driver.Conn) error) { - cc, sc := drv.Stats.ConnCount, drv.Stats.StmtCount + cc, sc := drv.Stats.ConnCount.Load(), drv.Stats.StmtCount.Load() defer func() { - if should, is := sc, drv.Stats.StmtCount; should != is { + if should, is := sc, drv.Stats.StmtCount.Load(); should != is { t.Errorf("leaked statement, should=%d, is=%d", should, is) } - if should, is := cc, drv.Stats.ConnCount; should != is { + if should, is := cc, drv.Stats.ConnCount.Load(); should != is { t.Errorf("leaked connection, should=%d, is=%d", should, is) } }() @@ -1696,9 +1697,9 @@ func TestMSSQLMarkBeginBadConn(t *testing.T) { }() // database/sql might return the broken driver.Conn to the pool. The - // next operation on the driver connection must return - // driver.ErrBadConn to prevent the bad connection from getting used - // again. + // next operation on the driver connection should return + // driver.ErrBadConn, or driver.SessionResetter should return driver.ErrBadConn + // to prevent the bad connection from getting used again. if should, is := driver.ErrBadConn, nextFn(dc); should != is { t.Errorf("%s: should=\"%v\", is=\"%v\"", label, should, is) } diff --git a/mysql_test.go b/mysql_test.go index 3d73470..da331a8 100644 --- a/mysql_test.go +++ b/mysql_test.go @@ -19,7 +19,7 @@ var ( mypass = flag.String("mypass", "", "mysql password") ) -func mysqlConnect() (db *sql.DB, stmtCount int, err error) { +func mysqlConnect() (db *sql.DB, stmtCount int32, err error) { // from https://dev.mysql.com/doc/connector-odbc/en/connector-odbc-configuration-connection-parameters.html conn := fmt.Sprintf("driver=mysql;server=%s;database=%s;user=%s;password=%s;", *mysrv, *mydb, *myuser, *mypass) @@ -28,7 +28,7 @@ func mysqlConnect() (db *sql.DB, stmtCount int, err error) { return nil, 0, err } stats := db.Driver().(*Driver).Stats - return db, stats.StmtCount, nil + return db, stats.StmtCount.Load(), nil } func TestMYSQLTime(t *testing.T) { diff --git a/ns_test.go b/ns_test.go new file mode 100644 index 0000000..0adc337 --- /dev/null +++ b/ns_test.go @@ -0,0 +1,91 @@ +package odbc + +import ( + "context" + "database/sql" + "errors" + "flag" + "fmt" + "testing" + "time" +) + +var ( + nsdsn = flag.String("nsdsn", "NetSuite", "") + nsuid = flag.String("nsuid", "", "") + nspwd = flag.String("nspwd", "", "") + nsacct = flag.String("nsacct", "", "") + nsrole = flag.String("nsrole", "3", "") +) + +func TestNSExecContextTimeout(t *testing.T) { + connStr := fmt.Sprintf( + "DSN=%s;Uid=%s;Pwd=%s;CustomProperties=AccountID=%s;RoleID=%s", + *nsdsn, + *nsuid, + *nspwd, + *nsacct, + *nsrole) + t.Log(connStr) + db, err := sql.Open("odbc", connStr) + if err != nil { + t.Skip("skipping ns tests") + } + + start := time.Now() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + + if _, qErr := db.ExecContext(ctx, "SELECT * FROM TRANSACTIONS"); qErr == nil { + t.Fatal("expected an error to be returned") + } else if !errors.Is(qErr, context.DeadlineExceeded) { + t.Fatalf("expected a context canceled error. got: %s", qErr.Error()) + } + if time.Since(start).Seconds() > 2 { + t.Fatal("exec should have been canceled after 1 second") + } + + if _, qErr := db.ExecContext(ctx, "SELECT 1"); qErr == nil { + t.Fatal("expected an error to be returned for subsequent exec on expired context") + } else if !errors.Is(qErr, context.DeadlineExceeded) { + t.Fatalf("expected a context canceled error. got: %s", qErr.Error()) + } + + if _, qErr := db.ExecContext(context.Background(), "SELECT 1"); qErr != nil { + t.Fatalf("exec on a fresh context should execute without error. Got: %s", qErr.Error()) + } + + //wait for the query to finish cancelling in the background and for the connection to close + time.Sleep(3 * time.Second) + + db.Close() +} + +func TestNSPing(t *testing.T) { + connStr := fmt.Sprintf( + "DSN=%s;Uid=%s;Pwd=%s;CustomProperties=AccountID=%s;RoleID=%s", + *nsdsn, + *nsuid, + *nspwd, + *nsacct, + *nsrole) + t.Log(connStr) + db, err := sql.Open("odbc", connStr) + + if err != nil { + t.Fatal(err) + } + + c, _ := db.Conn(context.Background()) + + if err := c.PingContext(context.Background()); err != nil { + t.Fatalf("did not expect an error from ping. got %s", err.Error()) + } + + c.Close() + + if err := c.PingContext(context.Background()); err == nil { + t.Fatalf("expected ping to fail after being closed") + } + +} diff --git a/odbcstmt.go b/odbcstmt.go deleted file mode 100644 index 373f100..0000000 --- a/odbcstmt.go +++ /dev/null @@ -1,160 +0,0 @@ -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package odbc - -import ( - "database/sql/driver" - "errors" - "fmt" - "sync" - "time" - "unsafe" - - "github.com/polytomic/odbc/api" -) - -// TODO(brainman): see if I could use SQLExecDirect anywhere - -type ODBCStmt struct { - h api.SQLHSTMT - Parameters []Parameter - Cols []Column - // locking/lifetime - mu sync.Mutex - usedByStmt bool - usedByRows bool -} - -func (c *Conn) PrepareODBCStmt(query string) (*ODBCStmt, error) { - var out api.SQLHANDLE - ret := api.SQLAllocHandle(api.SQL_HANDLE_STMT, api.SQLHANDLE(c.h), &out) - if IsError(ret) { - return nil, c.newError("SQLAllocHandle", c.h) - } - h := api.SQLHSTMT(out) - err := drv.Stats.updateHandleCount(api.SQL_HANDLE_STMT, 1) - if err != nil { - return nil, err - } - - b := api.StringToUTF16(query) - ret = api.SQLPrepare(h, (*api.SQLWCHAR)(unsafe.Pointer(&b[0])), api.SQL_NTS) - if IsError(ret) { - defer releaseHandle(h) - return nil, c.newError("SQLPrepare", h) - } - ps, err := ExtractParameters(h) - if err != nil { - defer releaseHandle(h) - return nil, err - } - return &ODBCStmt{ - h: h, - Parameters: ps, - usedByStmt: true, - }, nil -} - -func (s *ODBCStmt) closeByStmt() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.usedByStmt { - defer func() { s.usedByStmt = false }() - if !s.usedByRows { - return s.releaseHandle() - } - } - return nil -} - -func (s *ODBCStmt) closeByRows() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.usedByRows { - defer func() { s.usedByRows = false }() - if s.usedByStmt { - ret := api.SQLCloseCursor(s.h) - if IsError(ret) { - return NewError("SQLCloseCursor", s.h) - } - return nil - } else { - return s.releaseHandle() - } - } - return nil -} - -func (s *ODBCStmt) releaseHandle() error { - h := s.h - s.h = api.SQLHSTMT(api.SQL_NULL_HSTMT) - return releaseHandle(h) -} - -var testingIssue5 bool // used during tests - -func (s *ODBCStmt) Exec(args []driver.Value, conn *Conn) error { - if len(args) != len(s.Parameters) { - return fmt.Errorf("wrong number of arguments %d, %d expected", len(args), len(s.Parameters)) - } - for i, a := range args { - // this could be done in 2 steps: - // 1) bind vars right after prepare; - // 2) set their (vars) values here; - // but rebinding parameters for every new parameter value - // should be efficient enough for our purpose. - if err := s.Parameters[i].BindValue(s.h, i, a, conn); err != nil { - return err - } - } - if testingIssue5 { - time.Sleep(10 * time.Microsecond) - } - ret := api.SQLExecute(s.h) - if ret == api.SQL_NO_DATA { - // success but no data to report - return nil - } - if IsError(ret) { - return NewError("SQLExecute", s.h) - } - return nil -} - -func (s *ODBCStmt) BindColumns() error { - // count columns - var n api.SQLSMALLINT - ret := api.SQLNumResultCols(s.h, &n) - if IsError(ret) { - return NewError("SQLNumResultCols", s.h) - } - if n < 1 { - return errors.New("Stmt did not create a result set") - } - // fetch column descriptions - s.Cols = make([]Column, n) - binding := true - for i := range s.Cols { - c, err := NewColumn(s.h, i) - if err != nil { - return err - } - s.Cols[i] = c - // Once we found one non-bindable column, we will not bind the rest. - // http://www.easysoft.com/developer/languages/c/odbc-tutorial-fetching-results.html - // ... One common restriction is that SQLGetData may only be called on columns after the last bound column. ... - if !binding { - continue - } - bound, err := s.Cols[i].Bind(s.h, i) - if err != nil { - return err - } - if !bound { - binding = false - } - } - return nil -} diff --git a/result.go b/result.go index 9d4a0da..5b3a362 100644 --- a/result.go +++ b/result.go @@ -12,11 +12,13 @@ type Result struct { rowCount int64 } +// implement driver.Result func (r *Result) LastInsertId() (int64, error) { // TODO(brainman): implement (*Result).LastInsertId return 0, errors.New("not implemented") } +//implement driver.Result func (r *Result) RowsAffected() (int64, error) { return r.rowCount, nil } diff --git a/rows.go b/rows.go index 81c5bda..8a67893 100644 --- a/rows.go +++ b/rows.go @@ -12,27 +12,29 @@ import ( ) type Rows struct { - os *ODBCStmt + s *Stmt } +// implement driver.Rows func (r *Rows) Columns() []string { - names := make([]string, len(r.os.Cols)) + names := make([]string, len(r.s.cols)) for i := 0; i < len(names); i++ { - names[i] = r.os.Cols[i].Name() + names[i] = r.s.cols[i].Name() } return names } +// implement driver.Rows func (r *Rows) Next(dest []driver.Value) error { - ret := api.SQLFetch(r.os.h) + ret := api.SQLFetch(r.s.h) if ret == api.SQL_NO_DATA { return io.EOF } if IsError(ret) { - return NewError("SQLFetch", r.os.h) + return NewError("SQLFetch", r.s.h) } for i := range dest { - v, err := r.os.Cols[i].Value(r.os.h, i) + v, err := r.s.cols[i].Value(r.s.h, i) if err != nil { return err } @@ -41,24 +43,35 @@ func (r *Rows) Next(dest []driver.Value) error { return nil } +// implement driver.Rows func (r *Rows) Close() error { - return r.os.closeByRows() + if r.s.c.closingInBG.Load() { + //if we are cancelling/closing in a background thread, ignore requests to Close this statement from the driver + return nil + } + r.s.rows = nil + if ret := api.SQLCloseCursor(r.s.h); IsError(ret) { + return NewError("SQLCloseCursor", r.s.h) + } + return nil } +// implement driver.RowsNextResultSet func (r *Rows) HasNextResultSet() bool { return true } +// implement driver.RowsNextResultSet func (r *Rows) NextResultSet() error { - ret := api.SQLMoreResults(r.os.h) + ret := api.SQLMoreResults(r.s.h) if ret == api.SQL_NO_DATA { return io.EOF } if IsError(ret) { - return NewError("SQLMoreResults", r.os.h) + return NewError("SQLMoreResults", r.s.h) } - err := r.os.BindColumns() + err := r.s.bindColumns() if err != nil { return err } diff --git a/stats.go b/stats.go index a5d2e7b..21ed307 100644 --- a/stats.go +++ b/stats.go @@ -6,28 +6,25 @@ package odbc import ( "fmt" - "sync" "github.com/polytomic/odbc/api" + "go.uber.org/atomic" ) type Stats struct { - EnvCount int - ConnCount int - StmtCount int - mu sync.Mutex + EnvCount *atomic.Int32 + ConnCount *atomic.Int32 + StmtCount *atomic.Int32 } -func (s *Stats) updateHandleCount(handleType api.SQLSMALLINT, change int) error { - s.mu.Lock() - defer s.mu.Unlock() +func (s *Stats) updateHandleCount(handleType api.SQLSMALLINT, change int32) error { switch handleType { case api.SQL_HANDLE_ENV: - s.EnvCount += change + s.EnvCount.Add(change) case api.SQL_HANDLE_DBC: - s.ConnCount += change + s.ConnCount.Add(change) case api.SQL_HANDLE_STMT: - s.StmtCount += change + s.StmtCount.Add(change) default: return fmt.Errorf("unexpected handle type %d", handleType) } diff --git a/stmt.go b/stmt.go index 3e1e9f9..1069b31 100644 --- a/stmt.go +++ b/stmt.go @@ -5,104 +5,112 @@ package odbc import ( + "context" "database/sql/driver" "errors" - "sync" "github.com/polytomic/odbc/api" ) +// TODO(brainman): see if I could use SQLExecDirect anywhere + type Stmt struct { c *Conn query string - os *ODBCStmt - mu sync.Mutex -} -func (c *Conn) Prepare(query string) (driver.Stmt, error) { - if c.bad { - return nil, driver.ErrBadConn - } - os, err := c.PrepareODBCStmt(query) - if err != nil { - return nil, err - } - return &Stmt{c: c, os: os, query: query}, nil + h api.SQLHSTMT + parameters []Parameter + cols []Column + + //each statement can only have one open rows. If a second query is executed while rows is still open, + //the driver will prepare a new statement to execute on + rows *Rows } +// implement driver.Stmt func (s *Stmt) NumInput() int { - if s.os == nil { + if s.parameters == nil { return -1 } - return len(s.os.Parameters) + return len(s.parameters) } +// implement driver.Stmt +// Close closes the statement. +// +// As of Go 1.1, a Stmt will not be closed if it's in use +// by any queries. func (s *Stmt) Close() error { - if s.os == nil { - return errors.New("Stmt is already closed") + if s.c.closingInBG.Load() { + //if we are cancelling/closing in a background thread, ignore requests to Close this statement from the driver + return nil } - ret := s.os.closeByStmt() - s.os = nil - return ret + return s.close() +} +func (s *Stmt) close() error { + return s.releaseHandle() } +// implement driver.Stmt - per documentation, not supposed to be used by multiple goroutines func (s *Stmt) Exec(args []driver.Value) (driver.Result, error) { - if s.os == nil { - return nil, errors.New("Stmt is closed") + return s.ExecContext(context.Background(), toNamedValues(args)) +} + +// implement driver.Stmt - per documentation, not supposed to be used by multiple goroutines +func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { + return s.QueryContext(context.Background(), toNamedValues(args)) +} + +func (s *Stmt) releaseHandle() error { + h := s.h + s.h = api.SQLHSTMT(api.SQL_NULL_HSTMT) + return releaseHandle(h) +} + +func (s *Stmt) bindColumns() error { + // count columns + var n api.SQLSMALLINT + ret := api.SQLNumResultCols(s.h, &n) + if IsError(ret) { + return NewError("SQLNumResultCols", s.h) + } + if n < 1 { + return errors.New("statement did not create a result set") } - s.mu.Lock() - defer s.mu.Unlock() - if s.os.usedByRows { - s.os.closeByStmt() - s.os = nil - os, err := s.c.PrepareODBCStmt(s.query) + // fetch column descriptions + s.cols = make([]Column, n) + binding := true + for i := range s.cols { + c, err := NewColumn(s.h, i) if err != nil { - return nil, err + return err } - s.os = os - } - err := s.os.Exec(args, s.c) - if err != nil { - return nil, err - } - var sumRowCount int64 - for { - var c api.SQLLEN - ret := api.SQLRowCount(s.os.h, &c) - if IsError(ret) { - return nil, NewError("SQLRowCount", s.os.h) + s.cols[i] = c + // Once we found one non-bindable column, we will not bind the rest. + // http://www.easysoft.com/developer/languages/c/odbc-tutorial-fetching-results.html + // ... One common restriction is that SQLGetData may only be called on columns after the last bound column. ... + if !binding { + continue + } + bound, err := s.cols[i].Bind(s.h, i) + if err != nil { + return err } - sumRowCount += int64(c) - if ret = api.SQLMoreResults(s.os.h); ret == api.SQL_NO_DATA { - break + if !bound { + binding = false } } - return &Result{rowCount: sumRowCount}, nil + return nil } -func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) { - if s.os == nil { - return nil, errors.New("Stmt is closed") - } - s.mu.Lock() - defer s.mu.Unlock() - if s.os.usedByRows { - s.os.closeByStmt() - s.os = nil - os, err := s.c.PrepareODBCStmt(s.query) - if err != nil { - return nil, err +func toNamedValues(values []driver.Value) []driver.NamedValue { + namedValues := make([]driver.NamedValue, len(values)) + for idx, value := range values { + namedValues[idx] = driver.NamedValue{ + Name: "", + Ordinal: idx + 1, + Value: value, } - s.os = os - } - err := s.os.Exec(args, s.c) - if err != nil { - return nil, err - } - err = s.os.BindColumns() - if err != nil { - return nil, err } - s.os.usedByRows = true // now both Stmt and Rows refer to it - return &Rows{os: s.os}, nil + return namedValues } diff --git a/stmt_go18.go b/stmt_go18.go new file mode 100644 index 0000000..155122e --- /dev/null +++ b/stmt_go18.go @@ -0,0 +1,118 @@ +package odbc + +import ( + "context" + "database/sql/driver" + "fmt" + "sync" + "time" + + "go.uber.org/atomic" + + "github.com/polytomic/odbc/api" +) + +//implement driver.StmtExecContext +func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + + if err := s.exec(ctx, args); err != nil { + return nil, err + } + + var sumRowCount int64 + for { + var c api.SQLLEN + ret := api.SQLRowCount(s.h, &c) + if IsError(ret) { + return nil, NewError("SQLRowCount", s.h) + } + sumRowCount += int64(c) + if ret = api.SQLMoreResults(s.h); ret == api.SQL_NO_DATA { + break + } + } + return &Result{rowCount: sumRowCount}, nil +} + +//implement driver.StmtQueryContext +func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + if err := s.exec(ctx, args); err != nil { + return nil, err + } + + if err := s.bindColumns(); err != nil { + return nil, err + } + + s.rows = &Rows{s: s} + return s.rows, nil + +} + +var testingIssue5 bool // used during tests + +func (s *Stmt) exec(ctx context.Context, args []driver.NamedValue) error { + if len(args) != len(s.parameters) { + return fmt.Errorf("wrong number of arguments %d, %d expected", len(args), len(s.parameters)) + } + for _, namedValue := range args { + // this could be done in 2 steps: + // 1) bind vars right after prepare; + // 2) set their (vars) values here; + // but rebinding parameters for every new parameter value + // should be efficient enough for our purpose. + if err := s.parameters[namedValue.Ordinal-1].BindValue(s.h, namedValue.Ordinal-1, namedValue.Value, s.c); err != nil { + return err + } + } + if testingIssue5 { + time.Sleep(10 * time.Microsecond) + } + + sqlResult, cancelExec := s.sqlExecuteAsync() + + select { + case <-ctx.Done(): + //mark the connection as bad, so that the driver does not reuse it + s.c.bad.Store(true) + //mark the statement as closing in bg so stmt.Close and conn.Close do not block + s.c.closingInBG.Store(true) + //cancel the query and close the statement and connection in the background + go cancelExec() + return ctx.Err() + case err := <-sqlResult: + return err + } + +} + +func (s *Stmt) sqlExecuteAsync() (err <-chan error, cancel func()) { + var wgExecuting sync.WaitGroup + cancelled := atomic.NewBool(false) + cancel = func() { + if cancelled.Load() { + return + } + cancelled.Store(true) + //cancel the running statement + _ = api.SQLCancel(s.h) + //wait for the query to finish + wgExecuting.Wait() + s.close() + s.c.close() + } + errChannel := make(chan error) + wgExecuting.Add(1) + go func() { + defer wgExecuting.Done() + ret := api.SQLExecute(s.h) + if !cancelled.Load() { + var execErr error + if ret != api.SQL_NO_DATA && IsError(ret) { + execErr = NewError("SQLExecute", s.h) + } + errChannel <- execErr + } + }() + return errChannel, cancel +} diff --git a/tx.go b/tx.go index fb9d64b..9fb6a9b 100644 --- a/tx.go +++ b/tx.go @@ -11,67 +11,44 @@ import ( "github.com/polytomic/odbc/api" ) + +var ErrTXCompleted = errors.New("transaction already completed") + type Tx struct { - c *Conn + c *Conn + opts driver.TxOptions } -var testBeginErr error // used during tests - -func (c *Conn) setAutoCommitAttr(a uintptr) error { - if testBeginErr != nil { - return testBeginErr - } - ret := api.SQLSetConnectUIntPtrAttr(c.h, api.SQL_ATTR_AUTOCOMMIT, a, api.SQL_IS_UINTEGER) - if IsError(ret) { - return c.newError("SQLSetConnectUIntPtrAttr", c.h) - } - return nil +// implement driver.Tx +func (tx *Tx) Commit() error { + return tx.endTx(api.SQL_COMMIT) } -func (c *Conn) Begin() (driver.Tx, error) { - if c.bad { - return nil, driver.ErrBadConn - } - if c.tx != nil { - return nil, errors.New("already in a transaction") - } - c.tx = &Tx{c: c} - err := c.setAutoCommitAttr(api.SQL_AUTOCOMMIT_OFF) - if err != nil { - c.bad = true - return nil, err - } - return c.tx, nil +// implement driver.Tx +func (tx *Tx) Rollback() error { + return tx.endTx(api.SQL_ROLLBACK) } -func (c *Conn) endTx(commit bool) error { - if c.tx == nil { - return errors.New("not in a transaction") +func (tx *Tx) endTx(mode api.SQLSMALLINT) error { + if tx.c.tx == nil { + return ErrTXCompleted } - var howToEnd api.SQLSMALLINT - if commit { - howToEnd = api.SQL_COMMIT - } else { - howToEnd = api.SQL_ROLLBACK + tx.c.tx = nil + if ret := api.SQLEndTran(api.SQL_HANDLE_DBC, api.SQLHANDLE(tx.c.h), mode); IsError(ret) { + tx.c.bad.Store(true) + return tx.c.newError("SQLEndTran", tx.c.h) } - ret := api.SQLEndTran(api.SQL_HANDLE_DBC, api.SQLHANDLE(c.h), howToEnd) - if IsError(ret) { - c.bad = true - return c.newError("SQLEndTran", c.h) + + if ret := api.SQLSetConnectUIntPtrAttr(tx.c.h, api.SQL_ATTR_AUTOCOMMIT, api.SQL_AUTOCOMMIT_ON, api.SQL_IS_UINTEGER); IsError(ret) { + tx.c.bad.Store(true) + return tx.c.newError("SQLSetConnectUIntPtrAttr", tx.c.h) } - c.tx = nil - err := c.setAutoCommitAttr(api.SQL_AUTOCOMMIT_ON) - if err != nil { - c.bad = true - return err + + if tx.opts.ReadOnly { + if ret := api.SQLSetConnectUIntPtrAttr(tx.c.h, api.SQL_ATTR_ACCESS_MODE, api.SQL_MODE_READ_WRITE, api.SQL_IS_UINTEGER); IsError(ret) { + tx.c.bad.Store(true) + return NewError("SQLSetConnectUIntPtrAttr", tx.c.h) + } } return nil } - -func (tx *Tx) Commit() error { - return tx.c.endTx(true) -} - -func (tx *Tx) Rollback() error { - return tx.c.endTx(false) -}