1616 from ...client .context import Context
1717
1818
19+ _PS_OUT_PARAMS_MASK = constants .STATUS .PS_OUT_PARAMS
20+ _SESSION_STATE_CHANGED = constants .STATUS .SESSION_STATE_CHANGED
21+ _SESSION_TRACKING_CAP = constants .CAPABILITY .SESSION_TRACKING
22+
1923class OkPacket (Completion ):
2024 """
2125 OK Packet from MariaDB server
@@ -33,16 +37,15 @@ class OkPacket(Completion):
3337 'server_status' ,
3438 'info' ,
3539 )
40+
3641 def __init__ (
3742 self ,
38- affected_rows : int = 0 ,
39- insert_id : int = 0 ,
40- server_status : int = 0 ,
41- warning_count : int = 0 ,
42- info : bytes = b'' ,
43+ affected_rows : int ,
44+ insert_id : int ,
45+ server_status : int ,
46+ warning_count : int ,
47+ info : bytes ,
4348 ):
44- """Initialize OK packet with affected rows, insert ID, status, warnings, and info"""
45- # Direct assignment is faster than super().__init__
4649 self .affected_rows = affected_rows
4750 self .insert_id = insert_id
4851 self .warning_count = warning_count
@@ -51,95 +54,67 @@ def __init__(
5154 self .info = info
5255
5356 def is_output_parameters (self ) -> bool :
54- """Check if completion has output parameters"""
55- return (self .server_status & constants .STATUS .PS_OUT_PARAMS ) != 0
57+ return (self .server_status & _PS_OUT_PARAMS_MASK ) != 0
5658
5759 @staticmethod
5860 def decode (data : memoryview , context : 'Context' ) -> 'OkPacket' :
59- """Decode OK packet from bytearray with context"""
6061 parser = PayloadParser (data )
6162
62- parser .skip (1 ) # Skip OK marker (0x00 or 0xFE)
63+ parser .skip (1 )
6364 affected_rows = parser .read_length_encoded_int ()
6465 insert_id = parser .read_length_encoded_int ()
65-
66- # Read server_status and warning_count in one operation (4 bytes total)
6766 server_status = parser .read_int16 ()
6867 warning_count = parser .read_int16 ()
6968
70- # Update context with server status (context is always present)
7169 context .server_status = server_status
7270 context .warning_count = warning_count
7371
74- # Optional info string and session tracking
72+ # Fast path: no info/tracking (most common case)
73+ if not parser .has_remaining ():
74+ return OkPacket (affected_rows , insert_id , server_status , warning_count , b'' )
75+
7576 info = b''
76- if parser .has_remaining ():
77- # Check if session tracking is present
78- has_session_tracking = (context .has_capability (constants .CAPABILITY .SESSION_TRACKING ) and
79- (server_status & constants .STATUS .SESSION_STATE_CHANGED ))
80-
81- try :
82- # Read info string length
83- info_length = parser .read_length_encoded_int ()
84- if info_length > 0 :
85- # Read info bytes (may contain fingerprint validation hash)
86- info = parser .read_bytes (info_length )
87-
88- # Process session tracking data if present
89- if has_session_tracking and parser .has_remaining ():
90- while parser .has_remaining ():
91- # Total length of session tracking data (length-encoded)
92- total_length = parser .read_length_encoded_int ()
93- if total_length == 0 :
94- break
95-
96- # Track start position to ensure we don't read beyond this tracking block
97- start_pos = parser .pos
98-
99- # Session tracking type (1 byte)
100- tracking_type = parser .read_byte ()
101-
102- # Data length (length-encoded)
103- data_length = parser .read_length_encoded_int ()
104-
105- # Process based on tracking type
106- if tracking_type == constants .SESSION_TRACK .SYSTEM_VARIABLES :
107- # System variable change
108- end_pos = start_pos + total_length
109- while parser .pos < end_pos :
110- var_name_len = parser .read_length_encoded_int ()
111- var_name = parser .read_bytes (var_name_len ).decode ('utf-8' )
112-
113- var_value_len = parser .read_length_encoded_int ()
114- var_value = parser .read_bytes (var_value_len ).decode ('utf-8' )
115-
116- # Update context with system variable change
117- if hasattr (context , 'update_system_variable' ):
118- context .update_system_variable (var_name , var_value )
77+ info_length = parser .read_length_encoded_int ()
78+ if info_length > 0 :
79+ info = parser .read_bytes (info_length )
80+
81+ # Session tracking check
82+ if ((server_status & _SESSION_STATE_CHANGED ) and
83+ context .has_capability (_SESSION_TRACKING_CAP ) and
84+ parser .has_remaining ()):
85+ _process_session_tracking (parser , context )
11986
120- elif tracking_type == constants .SESSION_TRACK .SCHEMA :
121- # Schema change
122- schema_len = parser .read_length_encoded_int ()
123- schema = parser .read_bytes (schema_len ).decode ('utf-8' )
124- if hasattr (context , 'database' ):
125- context .database = schema
126- else :
127- # Unknown tracking type - skip data
128- parser .skip (data_length )
87+ return OkPacket (affected_rows , insert_id , server_status , warning_count , info )
12988
130- # Ensure we're at the correct position
131- expected_pos = start_pos + total_length
132- if parser .pos < expected_pos :
133- parser .skip (expected_pos - parser .pos )
13489
135- except Exception :
136- # Don't fail on info/session tracking errors
137- pass
90+ def _process_session_tracking (parser : PayloadParser , context : 'Context' ) -> None :
91+ """Process session tracking data (separate function for better branch prediction)"""
92+ while parser .has_remaining ():
93+ total_length = parser .read_length_encoded_int ()
94+ if total_length == 0 :
95+ break
96+
97+ start_pos = parser .pos
98+ tracking_type = parser .read_byte ()
99+ data_length = parser .read_length_encoded_int ()
100+
101+ # NOT NEEDED FOR NOW
102+ #if tracking_type == constants.SESSION_TRACK.SYSTEM_VARIABLES:
103+ # end_pos = start_pos + total_length
104+ # while parser.pos < end_pos:
105+ # var_name_len = parser.read_length_encoded_int()
106+ # var_name = parser.read_bytes(var_name_len).decode('utf-8')
107+ # var_value_len = parser.read_length_encoded_int()
108+ # var_value = parser.read_bytes(var_value_len).decode('utf-8')
109+ # context.update_system_variable(var_name, var_value)
110+
111+ if tracking_type == constants .SESSION_TRACK .SCHEMA :
112+ schema_len = parser .read_length_encoded_int ()
113+ context .database = parser .read_bytes (schema_len ).decode ('utf-8' )
114+
115+ else :
116+ parser .skip (data_length )
138117
139- return OkPacket (
140- affected_rows ,
141- insert_id ,
142- server_status ,
143- warning_count ,
144- info
145- )
118+ expected_pos = start_pos + total_length
119+ if parser .pos < expected_pos :
120+ parser .skip (expected_pos - parser .pos )
0 commit comments