1+ import atexit
12import datetime as dt
23import struct
4+ import sys
5+ import threading
36import time
47from contextlib import contextmanager
58from itertools import chain , repeat
2124
2225from dbt .adapters .fabric import __version__
2326from dbt .adapters .fabric .fabric_credentials import FabricCredentials
27+ from dbt .adapters .fabric .warehouse_snapshots import WarehouseSnapshotManager as wh_snapshot_manager
28+
29+ _init_done = False
30+ _init_lock = threading .Lock ()
31+ _snapshot_manager = None
32+
33+ # Command filtering
34+ TARGET_COMMANDS = {"run" , "build" , "snapshot" }
35+
2436
2537AZURE_CREDENTIAL_SCOPE = "https://database.windows.net//.default"
38+ POWER_BI_CREDENTIAL_SCOPE = "https://api.fabric.microsoft.com/.default"
2639SYNAPSE_SPARK_CREDENTIAL_SCOPE = "DW"
2740_TOKEN : Optional [AccessToken ] = None
28- AZURE_AUTH_FUNCTION_TYPE = Callable [[FabricCredentials ], AccessToken ]
41+ AZURE_AUTH_FUNCTION_TYPE = Callable [[FabricCredentials , Optional [ str ] ], AccessToken ]
2942
3043logger = AdapterLogger ("fabric" )
3144
@@ -88,7 +101,9 @@ def convert_access_token_to_mswindows_byte_string(token: AccessToken) -> bytes:
88101 return convert_bytes_to_mswindows_byte_string (value )
89102
90103
91- def get_synapse_spark_access_token (credentials : FabricCredentials ) -> AccessToken :
104+ def get_synapse_spark_access_token (
105+ credentials : FabricCredentials , scope : Optional [str ] = SYNAPSE_SPARK_CREDENTIAL_SCOPE
106+ ) -> AccessToken :
92107 """
93108 Get an Azure access token by using mspsarkutils
94109 Parameters
@@ -102,7 +117,7 @@ def get_synapse_spark_access_token(credentials: FabricCredentials) -> AccessToke
102117 """
103118 from notebookutils import mssparkutils
104119
105- aad_token = mssparkutils .credentials .getToken (SYNAPSE_SPARK_CREDENTIAL_SCOPE )
120+ aad_token = mssparkutils .credentials .getToken (scope )
106121 expires_on = int (time .time () + 4500.0 )
107122 token = AccessToken (
108123 token = aad_token ,
@@ -111,7 +126,9 @@ def get_synapse_spark_access_token(credentials: FabricCredentials) -> AccessToke
111126 return token
112127
113128
114- def get_cli_access_token (credentials : FabricCredentials ) -> AccessToken :
129+ def get_cli_access_token (
130+ credentials : FabricCredentials , scope : Optional [str ] = AZURE_CREDENTIAL_SCOPE
131+ ) -> AccessToken :
115132 """
116133 Get an Azure access token using the CLI credentials
117134
@@ -133,12 +150,14 @@ def get_cli_access_token(credentials: FabricCredentials) -> AccessToken:
133150 """
134151 _ = credentials
135152 token = AzureCliCredential ().get_token (
136- AZURE_CREDENTIAL_SCOPE , timeout = getattr (credentials , "login_timeout" , None )
153+ scope , timeout = getattr (credentials , "login_timeout" , None )
137154 )
138155 return token
139156
140157
141- def get_auto_access_token (credentials : FabricCredentials ) -> AccessToken :
158+ def get_auto_access_token (
159+ credentials : FabricCredentials , scope : Optional [str ] = AZURE_CREDENTIAL_SCOPE
160+ ) -> AccessToken :
142161 """
143162 Get an Azure access token automatically through azure-identity
144163
@@ -153,12 +172,14 @@ def get_auto_access_token(credentials: FabricCredentials) -> AccessToken:
153172 The access token.
154173 """
155174 token = DefaultAzureCredential ().get_token (
156- AZURE_CREDENTIAL_SCOPE , timeout = getattr (credentials , "login_timeout" , None )
175+ scope , timeout = getattr (credentials , "login_timeout" , None )
157176 )
158177 return token
159178
160179
161- def get_environment_access_token (credentials : FabricCredentials ) -> AccessToken :
180+ def get_environment_access_token (
181+ credentials : FabricCredentials , scope : Optional [str ] = AZURE_CREDENTIAL_SCOPE
182+ ) -> AccessToken :
162183 """
163184 Get an Azure access token by reading environment variables
164185
@@ -173,7 +194,7 @@ def get_environment_access_token(credentials: FabricCredentials) -> AccessToken:
173194 The access token.
174195 """
175196 token = EnvironmentCredential ().get_token (
176- AZURE_CREDENTIAL_SCOPE , timeout = getattr (credentials , "login_timeout" , None )
197+ scope , timeout = getattr (credentials , "login_timeout" , None )
177198 )
178199 return token
179200
@@ -206,7 +227,9 @@ def get_pyodbc_attrs_before_credentials(credentials: FabricCredentials) -> Dict:
206227
207228 if credentials .authentication .lower () in AZURE_AUTH_FUNCTIONS :
208229 if not _TOKEN or (_TOKEN .expires_on - time .time () < MAX_REMAINING_TIME ):
209- _TOKEN = AZURE_AUTH_FUNCTIONS [credentials .authentication .lower ()](credentials )
230+ _TOKEN = AZURE_AUTH_FUNCTIONS [credentials .authentication .lower ()](
231+ credentials , AZURE_CREDENTIAL_SCOPE
232+ )
210233 return {sql_copt_ss_access_token : convert_access_token_to_mswindows_byte_string (_TOKEN )}
211234
212235 if credentials .authentication .lower () == "activedirectoryaccesstoken" :
@@ -280,6 +303,63 @@ def byte_array_to_datetime(value: bytes) -> dt.datetime:
280303 )
281304
282305
306+ def _should_run_init () -> bool :
307+ """Check if we should run init for this command."""
308+ try :
309+ argv_lower = [a .lower () for a in sys .argv ]
310+ # Only run for run, build, snapshot
311+ return any (cmd in argv_lower for cmd in TARGET_COMMANDS )
312+ except Exception :
313+ return False
314+
315+
316+ def _run_start_action (credentials : FabricCredentials ) -> Dict [str , Any ]:
317+ """Enhanced run start action with snapshot management."""
318+ global _snapshot_manager
319+
320+ try :
321+ # Get credentials from dbt context
322+ workspace_id = credentials .workspace_id
323+ if workspace_id is None :
324+ logger .warning ("No workspace_id provided; skipping snapshot management." )
325+ return {}
326+
327+ access_token = AZURE_AUTH_FUNCTIONS [credentials .authentication .lower ()](
328+ credentials , POWER_BI_CREDENTIAL_SCOPE
329+ ).token
330+ _snapshot_manager = wh_snapshot_manager (workspace_id , access_token )
331+
332+ if credentials .warehouse_snapshot_name is None :
333+ logger .info (
334+ "No warehouse snapshot name provided; skipping pre-run snapshot management."
335+ )
336+ return {}
337+
338+ snapshot_Result = _snapshot_manager .orchestrate_snapshot_management (
339+ warehouse_name = credentials .database ,
340+ snapshot_name = credentials .warehouse_snapshot_name ,
341+ )
342+ return snapshot_Result
343+ except Exception as e :
344+ logger .error (f"Pre-run snapshot failed: { e } " )
345+ raise e
346+
347+
348+ def _run_end_action (snapshot_result : Optional [Dict [str , Any ]] = None ):
349+ """Enhanced run end action with snapshot result."""
350+ global _snapshot_manager
351+
352+ try :
353+ if snapshot_result and _snapshot_manager is not None :
354+ print (
355+ "Updating warehouse snapshot timestamp at end of run..." ,
356+ snapshot_result ["displayName" ],
357+ )
358+ _snapshot_manager .update_warehouse_snapshot (snapshot_id = snapshot_result ["snapshot_id" ])
359+ except Exception as e :
360+ logger .error (f"Post-run action failed: { e } " )
361+
362+
283363class FabricConnectionManager (SQLConnectionManager ):
284364 TYPE = "fabric"
285365
@@ -422,14 +502,29 @@ def connect():
422502 logger .debug (f"Connected to db: { credentials .database } " )
423503 return handle
424504
425- return cls .retry_connection (
505+ # Open the connection (with retries) and capture the returned Connection.
506+ conn = cls .retry_connection (
426507 connection ,
427508 connect = connect ,
428509 logger = logger ,
429510 retry_limit = credentials .retries ,
430511 retryable_exceptions = retryable_exceptions ,
431512 )
432513
514+ # Simple one-time init with command detection
515+ if _should_run_init ():
516+ global _init_done
517+ with _init_lock :
518+ if not _init_done :
519+ try :
520+ result = _run_start_action (credentials )
521+ atexit .register (lambda : _run_end_action (result ))
522+ except Exception as e :
523+ logger .debug ("Failed to run init actions" , e )
524+ _init_done = True
525+
526+ return conn
527+
433528 def cancel (self , connection : Connection ):
434529 logger .debug ("Cancel query" )
435530
0 commit comments