Skip to content

Commit 6d9396c

Browse files
Merge pull request #328 from microsoft/v1.9.7
Added merge capabilities and enabled warehouse snapshots
2 parents 9882d7a + a538c27 commit 6d9396c

File tree

9 files changed

+552
-42
lines changed

9 files changed

+552
-42
lines changed

dbt/adapters/fabric/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
version = "1.9.6"
1+
version = "1.9.7"

dbt/adapters/fabric/fabric_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def valid_incremental_strategies(self):
177177
"""The set of standard builtin strategies which this adapter supports out-of-the-box.
178178
Not used to validate custom strategies defined by end users.
179179
"""
180-
return ["append", "delete+insert", "microbatch"]
180+
return ["append", "delete+insert", "microbatch", "merge"]
181181

182182
# This is for use in the test suite
183183
def run_sql_for_tests(self, sql, fetch, conn):

dbt/adapters/fabric/fabric_connection_manager.py

Lines changed: 106 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import atexit
12
import datetime as dt
23
import struct
4+
import sys
5+
import threading
36
import time
47
from contextlib import contextmanager
58
from itertools import chain, repeat
@@ -21,11 +24,21 @@
2124

2225
from dbt.adapters.fabric import __version__
2326
from dbt.adapters.fabric.fabric_credentials import FabricCredentials
27+
from dbt.adapters.fabric.warehouse_snapshots import WarehouseSnapshotManager as wh_snapshot_manager
28+
29+
_init_done = False
30+
_init_lock = threading.Lock()
31+
_snapshot_manager = None
32+
33+
# Command filtering
34+
TARGET_COMMANDS = {"run", "build", "snapshot"}
35+
2436

2537
AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default"
38+
POWER_BI_CREDENTIAL_SCOPE = "https://api.fabric.microsoft.com/.default"
2639
SYNAPSE_SPARK_CREDENTIAL_SCOPE = "DW"
2740
_TOKEN: Optional[AccessToken] = None
28-
AZURE_AUTH_FUNCTION_TYPE = Callable[[FabricCredentials], AccessToken]
41+
AZURE_AUTH_FUNCTION_TYPE = Callable[[FabricCredentials, Optional[str]], AccessToken]
2942

3043
logger = AdapterLogger("fabric")
3144

@@ -88,7 +101,9 @@ def convert_access_token_to_mswindows_byte_string(token: AccessToken) -> bytes:
88101
return convert_bytes_to_mswindows_byte_string(value)
89102

90103

91-
def get_synapse_spark_access_token(credentials: FabricCredentials) -> AccessToken:
104+
def get_synapse_spark_access_token(
105+
credentials: FabricCredentials, scope: Optional[str] = SYNAPSE_SPARK_CREDENTIAL_SCOPE
106+
) -> AccessToken:
92107
"""
93108
Get an Azure access token by using mspsarkutils
94109
Parameters
@@ -102,7 +117,7 @@ def get_synapse_spark_access_token(credentials: FabricCredentials) -> AccessToke
102117
"""
103118
from notebookutils import mssparkutils
104119

105-
aad_token = mssparkutils.credentials.getToken(SYNAPSE_SPARK_CREDENTIAL_SCOPE)
120+
aad_token = mssparkutils.credentials.getToken(scope)
106121
expires_on = int(time.time() + 4500.0)
107122
token = AccessToken(
108123
token=aad_token,
@@ -111,7 +126,9 @@ def get_synapse_spark_access_token(credentials: FabricCredentials) -> AccessToke
111126
return token
112127

113128

114-
def get_cli_access_token(credentials: FabricCredentials) -> AccessToken:
129+
def get_cli_access_token(
130+
credentials: FabricCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE
131+
) -> AccessToken:
115132
"""
116133
Get an Azure access token using the CLI credentials
117134
@@ -133,12 +150,14 @@ def get_cli_access_token(credentials: FabricCredentials) -> AccessToken:
133150
"""
134151
_ = credentials
135152
token = AzureCliCredential().get_token(
136-
AZURE_CREDENTIAL_SCOPE, timeout=getattr(credentials, "login_timeout", None)
153+
scope, timeout=getattr(credentials, "login_timeout", None)
137154
)
138155
return token
139156

140157

141-
def get_auto_access_token(credentials: FabricCredentials) -> AccessToken:
158+
def get_auto_access_token(
159+
credentials: FabricCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE
160+
) -> AccessToken:
142161
"""
143162
Get an Azure access token automatically through azure-identity
144163
@@ -153,12 +172,14 @@ def get_auto_access_token(credentials: FabricCredentials) -> AccessToken:
153172
The access token.
154173
"""
155174
token = DefaultAzureCredential().get_token(
156-
AZURE_CREDENTIAL_SCOPE, timeout=getattr(credentials, "login_timeout", None)
175+
scope, timeout=getattr(credentials, "login_timeout", None)
157176
)
158177
return token
159178

160179

161-
def get_environment_access_token(credentials: FabricCredentials) -> AccessToken:
180+
def get_environment_access_token(
181+
credentials: FabricCredentials, scope: Optional[str] = AZURE_CREDENTIAL_SCOPE
182+
) -> AccessToken:
162183
"""
163184
Get an Azure access token by reading environment variables
164185
@@ -173,7 +194,7 @@ def get_environment_access_token(credentials: FabricCredentials) -> AccessToken:
173194
The access token.
174195
"""
175196
token = EnvironmentCredential().get_token(
176-
AZURE_CREDENTIAL_SCOPE, timeout=getattr(credentials, "login_timeout", None)
197+
scope, timeout=getattr(credentials, "login_timeout", None)
177198
)
178199
return token
179200

@@ -206,7 +227,9 @@ def get_pyodbc_attrs_before_credentials(credentials: FabricCredentials) -> Dict:
206227

207228
if credentials.authentication.lower() in AZURE_AUTH_FUNCTIONS:
208229
if not _TOKEN or (_TOKEN.expires_on - time.time() < MAX_REMAINING_TIME):
209-
_TOKEN = AZURE_AUTH_FUNCTIONS[credentials.authentication.lower()](credentials)
230+
_TOKEN = AZURE_AUTH_FUNCTIONS[credentials.authentication.lower()](
231+
credentials, AZURE_CREDENTIAL_SCOPE
232+
)
210233
return {sql_copt_ss_access_token: convert_access_token_to_mswindows_byte_string(_TOKEN)}
211234

212235
if credentials.authentication.lower() == "activedirectoryaccesstoken":
@@ -280,6 +303,63 @@ def byte_array_to_datetime(value: bytes) -> dt.datetime:
280303
)
281304

282305

306+
def _should_run_init() -> bool:
307+
"""Check if we should run init for this command."""
308+
try:
309+
argv_lower = [a.lower() for a in sys.argv]
310+
# Only run for run, build, snapshot
311+
return any(cmd in argv_lower for cmd in TARGET_COMMANDS)
312+
except Exception:
313+
return False
314+
315+
316+
def _run_start_action(credentials: FabricCredentials) -> Dict[str, Any]:
317+
"""Enhanced run start action with snapshot management."""
318+
global _snapshot_manager
319+
320+
try:
321+
# Get credentials from dbt context
322+
workspace_id = credentials.workspace_id
323+
if workspace_id is None:
324+
logger.warning("No workspace_id provided; skipping snapshot management.")
325+
return {}
326+
327+
access_token = AZURE_AUTH_FUNCTIONS[credentials.authentication.lower()](
328+
credentials, POWER_BI_CREDENTIAL_SCOPE
329+
).token
330+
_snapshot_manager = wh_snapshot_manager(workspace_id, access_token)
331+
332+
if credentials.warehouse_snapshot_name is None:
333+
logger.info(
334+
"No warehouse snapshot name provided; skipping pre-run snapshot management."
335+
)
336+
return {}
337+
338+
snapshot_Result = _snapshot_manager.orchestrate_snapshot_management(
339+
warehouse_name=credentials.database,
340+
snapshot_name=credentials.warehouse_snapshot_name,
341+
)
342+
return snapshot_Result
343+
except Exception as e:
344+
logger.error(f"Pre-run snapshot failed: {e}")
345+
raise e
346+
347+
348+
def _run_end_action(snapshot_result: Optional[Dict[str, Any]] = None):
349+
"""Enhanced run end action with snapshot result."""
350+
global _snapshot_manager
351+
352+
try:
353+
if snapshot_result and _snapshot_manager is not None:
354+
print(
355+
"Updating warehouse snapshot timestamp at end of run...",
356+
snapshot_result["displayName"],
357+
)
358+
_snapshot_manager.update_warehouse_snapshot(snapshot_id=snapshot_result["snapshot_id"])
359+
except Exception as e:
360+
logger.error(f"Post-run action failed: {e}")
361+
362+
283363
class FabricConnectionManager(SQLConnectionManager):
284364
TYPE = "fabric"
285365

@@ -422,14 +502,29 @@ def connect():
422502
logger.debug(f"Connected to db: {credentials.database}")
423503
return handle
424504

425-
return cls.retry_connection(
505+
# Open the connection (with retries) and capture the returned Connection.
506+
conn = cls.retry_connection(
426507
connection,
427508
connect=connect,
428509
logger=logger,
429510
retry_limit=credentials.retries,
430511
retryable_exceptions=retryable_exceptions,
431512
)
432513

514+
# Simple one-time init with command detection
515+
if _should_run_init():
516+
global _init_done
517+
with _init_lock:
518+
if not _init_done:
519+
try:
520+
result = _run_start_action(credentials)
521+
atexit.register(lambda: _run_end_action(result))
522+
except Exception as e:
523+
logger.debug("Failed to run init actions", e)
524+
_init_done = True
525+
526+
return conn
527+
433528
def cancel(self, connection: Connection):
434529
logger.debug("Cancel query")
435530

dbt/adapters/fabric/fabric_credentials.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ class FabricCredentials(Credentials):
2727
schema_authorization: Optional[str] = None
2828
login_timeout: Optional[int] = 0
2929
query_timeout: Optional[int] = 0
30+
workspace_id: Optional[str] = None
31+
warehouse_snapshot_name: Optional[str] = None
32+
warehouse_snapshot_id: Optional[str] = None
33+
snapshot_timestamp: Optional[str] = None
3034

3135
_ALIASES = {
3236
"user": "UID",
@@ -41,12 +45,24 @@ class FabricCredentials(Credentials):
4145
"TrustServerCertificate": "trust_cert",
4246
"schema_auth": "schema_authorization",
4347
"SQL_ATTR_TRACE": "trace_flag",
48+
"workspace_id": "workspace_id",
49+
"warehouse_snapshot_name": "warehouse_snapshot_name",
4450
}
4551

4652
@property
4753
def type(self):
4854
return "fabric"
4955

56+
def validate_snapshot_properties(self):
57+
workspace_provided = self.workspace_id is not None
58+
snapshot_name_provided = self.warehouse_snapshot_name is not None
59+
60+
if workspace_provided != snapshot_name_provided:
61+
raise ValueError(
62+
"Both workspace_id and warehouse_snapshot_name must be provided together, "
63+
"or both must be None. Cannot have one without the other."
64+
)
65+
5066
def _connection_keys(self):
5167
# return an iterator of keys to pretty-print in 'dbt debug'
5268
# raise NotImplementedError
@@ -56,19 +72,23 @@ def _connection_keys(self):
5672
if self.authentication.lower().strip() == "serviceprincipal":
5773
self.authentication = "ActiveDirectoryServicePrincipal"
5874

75+
self.validate_snapshot_properties()
76+
5977
return (
6078
"server",
6179
"database",
6280
"schema",
81+
"warehouse_snapshot_name",
82+
"snapshot_timestamp",
6383
"UID",
64-
"client_id",
84+
"workspace_id",
6585
"authentication",
66-
"encrypt",
67-
"trust_cert",
6886
"retries",
6987
"login_timeout",
7088
"query_timeout",
7189
"trace_flag",
90+
"encrypt",
91+
"trust_cert",
7292
)
7393

7494
@property

0 commit comments

Comments
 (0)