Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func (c *conn) Close() error {

if err != nil {
log.Err(err).Msg("databricks: failed to close connection")
return dbsqlerrint.NewRequestError(ctx, dbsqlerr.ErrCloseConnection, err)
return dbsqlerrint.NewBadConnectionError(err)
}
return nil
}
Expand Down Expand Up @@ -168,9 +168,7 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
return nil, dbsqlerrint.NewExecutionError(ctx, dbsqlerr.ErrQueryExecution, err, opStatusResp)
}

corrId := driverctx.CorrelationIdFromContext(ctx)
rows, err := rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)

rows, err := rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
return rows, err

}
Expand Down Expand Up @@ -367,7 +365,7 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
var statusResp *cli_service.TGetOperationStatusResp
ctx = driverctx.NewContextWithConnId(ctx, c.id)
newCtx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), c.id), corrId)
newCtx := context.WithoutCancel(ctx)
pollSentinel := sentinel.Sentinel{
OnDoneFn: func(statusResp any) (any, error) {
return statusResp, nil
Expand Down Expand Up @@ -566,7 +564,6 @@ func (c *conn) execStagingOperation(
return nil
}

corrId := driverctx.CorrelationIdFromContext(ctx)
var row driver.Rows
var err error

Expand All @@ -589,7 +586,7 @@ func (c *conn) execStagingOperation(
}

if len(driverctx.StagingPathsFromContext(ctx)) != 0 {
row, err = rows.NewRows(c.id, corrId, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
row, err = rows.NewRows(ctx, exStmtResp.OperationHandle, c.client, c.cfg, exStmtResp.DirectResults)
if err != nil {
return dbsqlerrint.NewDriverError(ctx, "error reading row.", err)
}
Expand Down
31 changes: 21 additions & 10 deletions internal/rows/arrowbased/arrowRecordIterator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"testing"

"github.com/databricks/databricks-sql-go/driverctx"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/client"
"github.com/databricks/databricks-sql-go/internal/config"
Expand All @@ -32,15 +33,17 @@ func TestArrowRecordIterator(t *testing.T) {

var fetchesInfo []fetchResultsInfo

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2})
rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 7311),
5000,
nil,
false,
simpleClient,
"connectionId",
"correlationId",
logger,
)

Expand Down Expand Up @@ -126,17 +129,19 @@ func TestArrowRecordIterator(t *testing.T) {
fetchResp3 := cli_service.TFetchResultsResp{}
loadTestData2(t, "multipleFetch/FetchResults3.json", &fetchResp3)

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

var fetchesInfo []fetchResultsInfo

simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2, fetchResp3})
rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 0),
5000,
nil,
false,
simpleClient,
"connectionId",
"correlationId",
logger,
)

Expand Down Expand Up @@ -199,16 +204,18 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
fetchResp1 := cli_service.TFetchResultsResp{}
loadTestData2(t, "directResultsMultipleFetch/FetchResults1.json", &fetchResp1)

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

var fetchesInfo []fetchResultsInfo
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 0),
5000,
nil,
false,
simpleClient,
"connectionId",
"correlationId",
logger,
)

Expand Down Expand Up @@ -251,16 +258,18 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
fetchResp1 := cli_service.TFetchResultsResp{}
loadTestData2(t, "multipleFetch/FetchResults1.json", &fetchResp1)

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

var fetchesInfo []fetchResultsInfo
simpleClient := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1})
rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 0),
5000,
nil,
false,
simpleClient,
"connectionId",
"correlationId",
logger,
)

Expand Down Expand Up @@ -293,14 +302,16 @@ func TestArrowRecordIteratorSchema(t *testing.T) {
},
}

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 0),
5000,
nil,
false,
failingClient,
"connectionId",
"correlationId",
logger,
)

Expand Down
7 changes: 5 additions & 2 deletions internal/rows/arrowbased/arrowRows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/apache/arrow/go/v12/arrow"
"github.com/apache/arrow/go/v12/arrow/array"
"github.com/databricks/databricks-sql-go/driverctx"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/config"
Expand Down Expand Up @@ -1525,18 +1526,20 @@ func TestArrowRowScanner(t *testing.T) {
fetchResp2 := cli_service.TFetchResultsResp{}
loadTestData2(t, "directResultsMultipleFetch/FetchResults2.json", &fetchResp2)

ctx := driverctx.NewContextWithConnId(context.Background(), "connectionId")
ctx = driverctx.NewContextWithCorrelationId(ctx, "correlationId")

var fetchesInfo []fetchResultsInfo
client := getSimpleClient(&fetchesInfo, []cli_service.TFetchResultsResp{fetchResp1, fetchResp2})
logger := dbsqllog.WithContext("connectionId", "correlationId", "")

rpi := rowscanner.NewResultPageIterator(
ctx,
rowscanner.NewDelimiter(0, 7311),
5000,
nil,
false,
client,
"connectionId",
"correlationId",
logger)

cfg := config.WithDefaults()
Expand Down
16 changes: 7 additions & 9 deletions internal/rows/rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,22 +67,22 @@ var _ driver.RowsColumnTypeLength = (*rows)(nil)
var _ dbsqlrows.Rows = (*rows)(nil)

func NewRows(
connId string,
correlationId string,
ctx context.Context,
opHandle *cli_service.TOperationHandle,
client cli_service.TCLIService,
config *config.Config,
directResults *cli_service.TSparkDirectResults,
) (driver.Rows, dbsqlerr.DBError) {

connId := driverctx.ConnIdFromContext(ctx)
correlationId := driverctx.CorrelationIdFromContext(ctx)

var logger *dbsqllog.DBSQLLogger
var ctx context.Context
if opHandle != nil {
logger = dbsqllog.WithContext(connId, correlationId, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
ctx = driverctx.NewContextWithQueryId(driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId), dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
ctx = driverctx.NewContextWithQueryId(ctx, dbsqlclient.SprintGuid(opHandle.OperationId.GUID))
} else {
logger = dbsqllog.WithContext(connId, correlationId, "")
ctx = driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), connId), correlationId)
}

if client == nil {
Expand Down Expand Up @@ -140,13 +140,12 @@ func NewRows(
// the operations.
closedOnServer := directResults != nil && directResults.CloseOperation != nil
r.ResultPageIterator = rowscanner.NewResultPageIterator(
ctx,
d,
pageSize,
opHandle,
closedOnServer,
client,
connId,
correlationId,
r.logger(),
)

Expand Down Expand Up @@ -417,9 +416,8 @@ func (r *rows) getResultSetSchema() (*cli_service.TTableSchema, dbsqlerr.DBError
req := cli_service.TGetResultSetMetadataReq{
OperationHandle: r.opHandle,
}
ctx := driverctx.NewContextWithCorrelationId(driverctx.NewContextWithConnId(context.Background(), r.connId), r.correlationId)

resp, err2 := r.client.GetResultSetMetadata(ctx, &req)
resp, err2 := r.client.GetResultSetMetadata(r.ctx, &req)
if err2 != nil {
r.logger().Err(err2).Msg(err2.Error())
return nil, dbsqlerr_int.NewRequestError(r.ctx, errRowsMetadataFetchFailed, err)
Expand Down
Loading