Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ type (
//sys SQLAllocHandle(handleType SQLSMALLINT, inputHandle SQLHANDLE, outputHandle *SQLHANDLE) (ret SQLRETURN) = odbc32.SQLAllocHandle
//sys SQLBindCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, targetType SQLSMALLINT, targetValuePtr SQLPOINTER, bufferLength SQLLEN, vallen *SQLLEN) (ret SQLRETURN) = odbc32.SQLBindCol
//sys SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, inputOutputType SQLSMALLINT, valueType SQLSMALLINT, parameterType SQLSMALLINT, columnSize SQLULEN, decimalDigits SQLSMALLINT, parameterValue SQLPOINTER, bufferLength SQLLEN, ind *SQLLEN) (ret SQLRETURN) = odbc32.SQLBindParameter
//sys SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLCancel
//sys SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) = odbc32.SQLCloseCursor
//sys SQLDescribeCol(statementHandle SQLHSTMT, columnNumber SQLUSMALLINT, columnName *SQLWCHAR, bufferLength SQLSMALLINT, nameLengthPtr *SQLSMALLINT, dataTypePtr *SQLSMALLINT, columnSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLDescribeColW
//sys SQLDescribeParam(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, dataTypePtr *SQLSMALLINT, parameterSizePtr *SQLULEN, decimalDigitsPtr *SQLSMALLINT, nullablePtr *SQLSMALLINT) (ret SQLRETURN) = odbc32.SQLDescribeParam
Expand Down
12 changes: 12 additions & 0 deletions api/api_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ const (
SQL_CP_DEFAULT = SQL_CP_OFF
SQL_CP_STRICT_MATCH = uintptr(C.SQL_CP_STRICT_MATCH)
SQL_CP_RELAXED_MATCH = uintptr(C.SQL_CP_RELAXED_MATCH)

//Transaction Isolation
SQL_ATTR_TXN_ISOLATION = C.SQL_ATTR_TXN_ISOLATION
SQL_TXN_READ_COMMITTED = C.SQL_TXN_READ_COMMITTED
SQL_TXN_READ_UNCOMMITTED = C.SQL_TXN_READ_UNCOMMITTED
SQL_TXN_REPEATABLE_READ = C.SQL_TXN_REPEATABLE_READ
SQL_TXN_SERIALIZABLE = C.SQL_TXN_SERIALIZABLE

//Access Mode
SQL_ATTR_ACCESS_MODE = C.SQL_ATTR_ACCESS_MODE
SQL_MODE_READ_ONLY = C.SQL_MODE_READ_ONLY
SQL_MODE_READ_WRITE = C.SQL_MODE_READ_WRITE
)

type (
Expand Down
12 changes: 12 additions & 0 deletions api/api_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ const (
SQL_CP_DEFAULT = SQL_CP_OFF
SQL_CP_STRICT_MATCH = 0
SQL_CP_RELAXED_MATCH = uintptr(1)

//Transaction Isolation
SQL_ATTR_TXN_ISOLATION = 108
SQL_TXN_READ_COMMITTED = 2
SQL_TXN_READ_UNCOMMITTED = 1
SQL_TXN_REPEATABLE_READ = 4
SQL_TXN_SERIALIZABLE = 8

//Access Mode
SQL_ATTR_ACCESS_MODE = 101
SQL_MODE_READ_ONLY = 1
SQL_MODE_READ_WRITE = 0
)

type (
Expand Down
5 changes: 5 additions & 0 deletions api/zapi_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ func SQLBindParameter(statementHandle SQLHSTMT, parameterNumber SQLUSMALLINT, in
return SQLRETURN(r)
}

func SQLCancel(statementHandle SQLHSTMT) (ret SQLRETURN) {
r := C.SQLCancel(C.SQLHSTMT(statementHandle))
return SQLRETURN(r)
}

func SQLCloseCursor(statementHandle SQLHSTMT) (ret SQLRETURN) {
r := C.SQLCloseCursor(C.SQLHSTMT(statementHandle))
return SQLRETURN(r)
Expand Down
32 changes: 20 additions & 12 deletions api/zapi_windows.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 32 additions & 27 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,48 +5,33 @@
package odbc

import (
"context"
"database/sql/driver"
"strings"
"unsafe"

"github.com/polytomic/odbc/api"
"go.uber.org/atomic"
)

type Conn struct {
h api.SQLHDBC
tx *Tx
bad bool
bad *atomic.Bool
closingInBG *atomic.Bool
isMSAccessDriver bool
}

var accessDriverSubstr = strings.ToUpper(strings.Replace("DRIVER={Microsoft Access Driver", " ", "", -1))

func (d *Driver) Open(dsn string) (driver.Conn, error) {
if d.initErr != nil {
return nil, d.initErr
}

var out api.SQLHANDLE
ret := api.SQLAllocHandle(api.SQL_HANDLE_DBC, api.SQLHANDLE(d.h), &out)
if IsError(ret) {
return nil, NewError("SQLAllocHandle", d.h)
}
h := api.SQLHDBC(out)
drv.Stats.updateHandleCount(api.SQL_HANDLE_DBC, 1)

b := api.StringToUTF16(dsn)
ret = api.SQLDriverConnect(h, 0,
(*api.SQLWCHAR)(unsafe.Pointer(&b[0])), api.SQL_NTS,
nil, 0, nil, api.SQL_DRIVER_NOPROMPT)
if IsError(ret) {
defer releaseHandle(h)
return nil, NewError("SQLDriverConnect", h)
// implement driver.Conn
func (c *Conn) Close() (err error) {
if c.closingInBG.Load() {
//if we are cancelling/closing in a background thread, ignore requests to Close this connection from the driver
return nil
}
isAccess := strings.Contains(strings.ToUpper(strings.Replace(dsn, " ", "", -1)), accessDriverSubstr)
return &Conn{h: h, isMSAccessDriver: isAccess}, nil
return c.close()
}

func (c *Conn) Close() (err error) {
func (c *Conn) close() (err error) {
if c.tx != nil {
c.tx.Rollback()
}
Expand All @@ -68,7 +53,27 @@ func (c *Conn) Close() (err error) {
func (c *Conn) newError(apiName string, handle interface{}) error {
err := NewError(apiName, handle)
if err == driver.ErrBadConn {
c.bad = true
c.bad.Store(true)
}
return err
}

// implement driver.Conn
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}

// implement driver.Conn
func (c *Conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}

//implement driver.Execer
func (c *Conn) Exec(query string, args []driver.Value) (driver.Result, error) {
return c.ExecContext(context.Background(), query, toNamedValues(args))
}

//implement driver.Queryer
func (c *Conn) Query(query string, args []driver.Value) (driver.Rows, error) {
return c.QueryContext(context.Background(), query, toNamedValues(args))
}
131 changes: 131 additions & 0 deletions conn_go18.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
package odbc

import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"unsafe"

"github.com/polytomic/odbc/api"
)

var ErrTXAlreadyStarted = errors.New("already in a transaction")

var sqlIsolationLevel = map[sql.IsolationLevel]uintptr{
sql.LevelReadCommitted: api.SQL_TXN_READ_COMMITTED,
sql.LevelReadUncommitted: api.SQL_TXN_READ_UNCOMMITTED,
sql.LevelRepeatableRead: api.SQL_TXN_REPEATABLE_READ,
sql.LevelSerializable: api.SQL_TXN_SERIALIZABLE,
}

var testBeginErr error // used during tests

//implement driver.ConnBeginTx
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (tx driver.Tx, err error) {
if c.bad.Load() {
return nil, driver.ErrBadConn
}
//TODO(ninthclowd): refactor to use mocks / test hook functions or behavior tests so there isn't test logic in production
if testBeginErr != nil {
c.bad.Store(true)
return nil, testBeginErr
}

if c.tx != nil {
return nil, ErrTXAlreadyStarted
}
c.tx = &Tx{c: c, opts: opts}

if ret := api.SQLSetConnectUIntPtrAttr(c.h, api.SQL_ATTR_AUTOCOMMIT, api.SQL_AUTOCOMMIT_OFF, api.SQL_IS_UINTEGER); IsError(ret) {
c.bad.Store(true)
return nil, NewError("SQLSetConnectUIntPtrAttr", c.h)
}

if isolation, modeAvailable := sqlIsolationLevel[sql.IsolationLevel(opts.Isolation)]; modeAvailable {
if ret := api.SQLSetConnectUIntPtrAttr(c.h, api.SQL_ATTR_TXN_ISOLATION, isolation, api.SQL_IS_UINTEGER); IsError(ret) {
c.bad.Store(true)
return nil, NewError("SQLSetConnectUIntPtrAttr", c.h)
}
}
if opts.ReadOnly {
if ret := api.SQLSetConnectUIntPtrAttr(c.h, api.SQL_ATTR_ACCESS_MODE, api.SQL_MODE_READ_ONLY, api.SQL_IS_UINTEGER); IsError(ret) {
c.bad.Store(true)
return nil, NewError("SQLSetConnectUIntPtrAttr", c.h)
}
}

return c.tx, nil
}

//implement driver.ConnPrepareContext
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if c.bad.Load() {
return nil, driver.ErrBadConn
}

var out api.SQLHANDLE
ret := api.SQLAllocHandle(api.SQL_HANDLE_STMT, api.SQLHANDLE(c.h), &out)
if IsError(ret) {
return nil, c.newError("SQLAllocHandle", c.h)
}
h := api.SQLHSTMT(out)
drv.Stats.StmtCount.Inc()

b := api.StringToUTF16(query)
ret = api.SQLPrepare(h, (*api.SQLWCHAR)(unsafe.Pointer(&b[0])), api.SQL_NTS)
if IsError(ret) {
defer releaseHandle(h)
return nil, c.newError("SQLPrepare", h)
}
ps, err := ExtractParameters(h)
if err != nil {
defer releaseHandle(h)
return nil, err
}

return &Stmt{
c: c,
query: query,
h: h,
parameters: ps,
rows: nil,
}, nil
}

//implement driver.ExecerContext
func (c *Conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (result driver.Result, err error) {
//TODO(ninthclowd): build and execute a statement with SQLExecDirect instead of preparing the statement
return nil, driver.ErrSkip
}

//implement driver.QueryerContext
func (c *Conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (rows driver.Rows, err error) {
//TODO(ninthclowd): build and execute a statement with SQLExecDirect instead of preparing the statement
return nil, driver.ErrSkip
}

//implement driver.SessionResetter
func (c *Conn) ResetSession(ctx context.Context) error {
if c.bad.Load() {
return driver.ErrBadConn
}
return nil
}

//implement driver.Pinger
func (c *Conn) Ping(ctx context.Context) error {
if c.bad.Load() {
return driver.ErrBadConn
}
stmt, err := c.PrepareContext(ctx, "SELECT 1")
if err != nil {
return driver.ErrBadConn
}
defer stmt.Close()

if _, err := stmt.(*Stmt).ExecContext(ctx, nil); err != nil {
return driver.ErrBadConn
}
return nil
}
21 changes: 21 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package odbc

import (
"context"
"database/sql/driver"
)

type connector struct {
d *Driver
name string
}

//implement driver.Connector
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.d.open(c.name, ctx)
}

//implement driver.Connector
func (c *connector) Driver() driver.Driver {
return c.d
}
Loading