Skip to content

Commit 772d896

Browse files
authored
V1.9.5 release (#295)
* Added OAuth, pyodbc 5.20.0 and retryable exceptions support
1 parent d3b8665 commit 772d896

File tree

6 files changed

+99
-76
lines changed

6 files changed

+99
-76
lines changed

.github/workflows/integration-tests-azure.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,4 +67,4 @@ jobs:
6767
DBT_TEST_USER_1: dbo
6868
DBT_TEST_USER_2: dbo
6969
DBT_TEST_USER_3: dbo
70-
run: pytest -ra -v tests/functional --profile "${{ matrix.profile }}"
70+
run: pytest -ra -vv -x tests/functional --profile "${{ matrix.profile }}"

dbt/adapters/fabric/__version__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
1-
version = "1.9.4"
2-
1+
version = "1.9.5"

dbt/adapters/fabric/fabric_connection_manager.py

Lines changed: 91 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
from contextlib import contextmanager
55
from itertools import chain, repeat
6-
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Union
6+
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type, Union
77

88
import agate
99
import dbt_common.exceptions
@@ -12,7 +12,7 @@
1212
from azure.identity import AzureCliCredential, DefaultAzureCredential, EnvironmentCredential
1313
from dbt.adapters.contracts.connection import AdapterResponse, Connection, ConnectionState
1414
from dbt.adapters.events.logging import AdapterLogger
15-
from dbt.adapters.events.types import ConnectionUsed, SQLQuery, SQLQueryStatus
15+
from dbt.adapters.events.types import AdapterEventDebug, ConnectionUsed, SQLQuery, SQLQueryStatus
1616
from dbt.adapters.sql import SQLConnectionManager
1717
from dbt_common.clients.agate_helper import empty_table
1818
from dbt_common.events.contextvars import get_node_info
@@ -182,7 +182,7 @@ def get_environment_access_token(credentials: FabricCredentials) -> AccessToken:
182182

183183
def get_pyodbc_attrs_before_credentials(credentials: FabricCredentials) -> Dict:
184184
"""
185-
Get the pyodbc attrs before.
185+
Get the pyodbc attributes for authentication.
186186
187187
Parameters
188188
----------
@@ -191,63 +191,34 @@ def get_pyodbc_attrs_before_credentials(credentials: FabricCredentials) -> Dict:
191191
192192
Returns
193193
-------
194-
out : Dict
195-
The pyodbc attrs before.
196-
197-
Source
198-
------
199-
Authentication for SQL server with an access token:
200-
https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory?view=sql-server-ver15#authenticating-with-an-access-token
194+
Dict
195+
The pyodbc attributes for authentication.
201196
"""
202197
global _TOKEN
203-
attrs_before: Dict
198+
sql_copt_ss_access_token = 1256 # ODBC constant for access token
204199
MAX_REMAINING_TIME = 300
205200

206-
authentication = str(credentials.authentication).lower()
207-
if authentication in AZURE_AUTH_FUNCTIONS:
208-
time_remaining = (_TOKEN.expires_on - time.time()) if _TOKEN else MAX_REMAINING_TIME
209-
210-
if _TOKEN is None or (time_remaining < MAX_REMAINING_TIME):
211-
azure_auth_function = AZURE_AUTH_FUNCTIONS[authentication]
212-
_TOKEN = azure_auth_function(credentials)
213-
214-
token_bytes = convert_access_token_to_mswindows_byte_string(_TOKEN)
215-
sql_copt_ss_access_token = 1256 # see source in docstring
216-
attrs_before = {sql_copt_ss_access_token: token_bytes}
217-
else:
218-
attrs_before = {}
219-
220-
return attrs_before
221-
222-
223-
def get_pyodbc_attrs_before_accesstoken(accessToken: str) -> Dict:
224-
"""
225-
Get the pyodbc attrs before.
201+
if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS:
202+
if not _TOKEN or (_TOKEN.expires_on - time.time() < MAX_REMAINING_TIME):
203+
_TOKEN = AZURE_AUTH_FUNCTIONS[credentials.authentication.lower()](credentials)
204+
return {sql_copt_ss_access_token: convert_access_token_to_mswindows_byte_string(_TOKEN)}
226205

227-
Parameters
228-
----------
229-
credentials : Access Token for Integration Tests
230-
Credentials.
231-
232-
Returns
233-
-------
234-
out : Dict
235-
The pyodbc attrs before.
236-
237-
Source
238-
------
239-
Authentication for SQL server with an access token:
240-
https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory?view=sql-server-ver15#authenticating-with-an-access-token
241-
"""
242-
243-
access_token_utf16 = accessToken.encode("utf-16-le")
244-
token_struct = struct.pack(
245-
f"<I{len(access_token_utf16)}s", len(access_token_utf16), access_token_utf16
246-
)
247-
sql_copt_ss_access_token = 1256 # see source in docstring
248-
attrs_before = {sql_copt_ss_access_token: token_struct}
206+
if credentials.authentication.lower() == "activedirectoryaccesstoken":
207+
if credentials.access_token is None or credentials.access_token_expires_on is None:
208+
raise ValueError(
209+
"Access token and access token expiry are required for ActiveDirectoryAccessToken authentication."
210+
)
211+
_TOKEN = AccessToken(
212+
token=credentials.access_token,
213+
expires_on=int(
214+
time.time() + 4500.0
215+
if credentials.access_token_expires_on == 0
216+
else credentials.access_token_expires_on
217+
),
218+
)
219+
return {sql_copt_ss_access_token: convert_access_token_to_mswindows_byte_string(_TOKEN)}
249220

250-
return attrs_before
221+
return {}
251222

252223

253224
def bool_to_connection_string_arg(key: str, value: bool) -> str:
@@ -362,6 +333,8 @@ def open(cls, connection: Connection) -> Connection:
362333

363334
assert credentials.authentication is not None
364335

336+
# Access token authentication does not additional connection string parameters. The access token
337+
# is passed in the pyodbc attributes.
365338
if (
366339
"ActiveDirectory" in credentials.authentication
367340
and credentials.authentication != "ActiveDirectoryAccessToken"
@@ -429,10 +402,9 @@ def open(cls, connection: Connection) -> Connection:
429402
def connect():
430403
logger.debug(f"Using connection string: {con_str_display}")
431404
pyodbc.pooling = True
432-
if credentials.authentication == "ActiveDirectoryAccessToken":
433-
attrs_before = get_pyodbc_attrs_before_accesstoken(credentials.access_token)
434-
else:
435-
attrs_before = get_pyodbc_attrs_before_credentials(credentials)
405+
406+
# pyodbc attributes includes the access token provided by the user if required.
407+
attrs_before = get_pyodbc_attrs_before_credentials(credentials)
436408

437409
handle = pyodbc.connect(
438410
con_str_concat,
@@ -469,7 +441,58 @@ def add_query(
469441
auto_begin: bool = True,
470442
bindings: Optional[Any] = None,
471443
abridge_sql_log: bool = False,
444+
retryable_exceptions: Tuple[Type[Exception], ...] = (),
445+
retry_limit: int = 2,
472446
) -> Tuple[Connection, Any]:
447+
"""
448+
Retry function encapsulated here to avoid commitment to some
449+
user-facing interface. Right now, Redshift commits to a 1 second
450+
retry timeout so this serves as a default.
451+
"""
452+
453+
def _execute_query_with_retry(
454+
cursor: Any,
455+
sql: str,
456+
bindings: Optional[Any],
457+
retryable_exceptions: Tuple[Type[Exception], ...],
458+
retry_limit: int,
459+
attempt: int,
460+
):
461+
"""
462+
A success sees the try exit cleanly and avoid any recursive
463+
retries. Failure begins a sleep and retry routine.
464+
"""
465+
try:
466+
# pyodbc does not handle a None type binding!
467+
if bindings is None:
468+
cursor.execute(sql)
469+
else:
470+
bindings = [
471+
binding if not isinstance(binding, dt.datetime) else binding.isoformat()
472+
for binding in bindings
473+
]
474+
cursor.execute(sql, bindings)
475+
except retryable_exceptions as e:
476+
# Cease retries and fail when limit is hit.
477+
if attempt >= retry_limit:
478+
raise e
479+
480+
fire_event(
481+
AdapterEventDebug(
482+
message=f"Got a retryable error {type(e)}. {retry_limit-attempt} retries left. Retrying in 1 second.\nError:\n{e}"
483+
)
484+
)
485+
time.sleep(1)
486+
487+
return _execute_query_with_retry(
488+
cursor=cursor,
489+
sql=sql,
490+
bindings=bindings,
491+
retryable_exceptions=retryable_exceptions,
492+
retry_limit=retry_limit,
493+
attempt=attempt + 1,
494+
)
495+
473496
connection = self.get_thread_connection()
474497

475498
if auto_begin and connection.transaction_open is False:
@@ -498,16 +521,16 @@ def add_query(
498521
pre = time.time()
499522

500523
cursor = connection.handle.cursor()
501-
502-
# pyodbc does not handle a None type binding!
503-
if bindings is None:
504-
cursor.execute(sql)
505-
else:
506-
bindings = [
507-
binding if not isinstance(binding, dt.datetime) else binding.isoformat()
508-
for binding in bindings
509-
]
510-
cursor.execute(sql, bindings)
524+
credentials = self.get_credentials(connection.credentials)
525+
526+
_execute_query_with_retry(
527+
cursor=cursor,
528+
sql=sql,
529+
bindings=bindings,
530+
retryable_exceptions=retryable_exceptions,
531+
retry_limit=credentials.retries if credentials.retries > 3 else retry_limit,
532+
attempt=1,
533+
)
511534

512535
# convert DATETIMEOFFSET binary structures to datetime ojbects
513536
# https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794
@@ -568,4 +591,3 @@ def execute(
568591
while cursor.nextset():
569592
pass
570593
return response, table
571-

dbt/adapters/fabric/fabric_credentials.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ class FabricCredentials(Credentials):
1818
client_id: Optional[str] = None
1919
client_secret: Optional[str] = None
2020
access_token: Optional[str] = None
21-
authentication: Optional[str] = "ActiveDirectoryServicePrincipal"
21+
# Added for access token expiration for oAuth and integration tests scenarios.
22+
access_token_expires_on: Optional[int] = 0
23+
authentication: str = "ActiveDirectoryServicePrincipal"
2224
encrypt: Optional[bool] = True # default value in MS ODBC Driver 18 as well
2325
trust_cert: Optional[bool] = False # default value in MS ODBC Driver 18 as well
2426
retries: int = 3

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def run(self):
6666
packages=find_namespace_packages(include=["dbt", "dbt.*"]),
6767
include_package_data=True,
6868
install_requires=[
69-
"pyodbc>=4.0.35,<5.2.0",
69+
"pyodbc>=5.2.0",
7070
"azure-identity>=1.12.0",
7171
"dbt-common>=1.0.4,<2.0",
7272
"dbt-core>=1.8.0",

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def _profile_ci_azure_base():
5353
return {
5454
**_all_profiles_base(),
5555
**{
56-
"host": os.getenv("DBT_AZURESQL_SERVER"),
57-
"database": os.getenv("DBT_AZURESQL_DB"),
56+
"host": os.getenv("DBT_AZURESQL_SERVER", os.getenv("FABRIC_TEST_HOST")),
57+
"database": os.getenv("DBT_AZURESQL_DB", os.getenv("FABRIC_TEST_DBNAME")),
5858
"encrypt": True,
5959
"trust_cert": True,
6060
"trace_flag": False,

0 commit comments

Comments
 (0)