Skip to content

Commit 25b3174

Browse files
committed
refactor: optimize OkPacket parsing with mandatory parameters
1 parent 973f3d5 commit 25b3174

File tree

3 files changed

+58
-83
lines changed

3 files changed

+58
-83
lines changed

mariadb/impl/client/async_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ async def _read_result(self, is_binary: bool, config: 'Configuration' = None, bu
519519
)
520520

521521
# Create completion with streaming result
522-
completion = OkPacket()
522+
completion = OkPacket(0,0,0,0,b'')
523523
completion.result_set = streaming_result
524524
results.append(completion)
525525
return results

mariadb/impl/client/sync_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def _read_result(self, is_binary: bool, config: 'Configuration' = None, buffered
453453
)
454454

455455
# Create completion with streaming result
456-
completion = OkPacket()
456+
completion = OkPacket(0,0,0,0,b'')
457457
completion.result_set = streaming_result
458458
results.append(completion)
459459
return results

mariadb/impl/message/server/ok_packet.py

Lines changed: 56 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
from ...client.context import Context
1717

1818

19+
_PS_OUT_PARAMS_MASK = constants.STATUS.PS_OUT_PARAMS
20+
_SESSION_STATE_CHANGED = constants.STATUS.SESSION_STATE_CHANGED
21+
_SESSION_TRACKING_CAP = constants.CAPABILITY.SESSION_TRACKING
22+
1923
class OkPacket(Completion):
2024
"""
2125
OK Packet from MariaDB server
@@ -33,16 +37,15 @@ class OkPacket(Completion):
3337
'server_status',
3438
'info',
3539
)
40+
3641
def __init__(
3742
self,
38-
affected_rows: int = 0,
39-
insert_id: int = 0,
40-
server_status: int = 0,
41-
warning_count: int = 0,
42-
info: bytes = b'',
43+
affected_rows: int,
44+
insert_id: int,
45+
server_status: int,
46+
warning_count: int,
47+
info: bytes,
4348
):
44-
"""Initialize OK packet with affected rows, insert ID, status, warnings, and info"""
45-
# Direct assignment is faster than super().__init__
4649
self.affected_rows = affected_rows
4750
self.insert_id = insert_id
4851
self.warning_count = warning_count
@@ -51,95 +54,67 @@ def __init__(
5154
self.info = info
5255

5356
def is_output_parameters(self) -> bool:
54-
"""Check if completion has output parameters"""
55-
return (self.server_status & constants.STATUS.PS_OUT_PARAMS) != 0
57+
return (self.server_status & _PS_OUT_PARAMS_MASK) != 0
5658

5759
@staticmethod
5860
def decode(data: memoryview, context: 'Context') -> 'OkPacket':
59-
"""Decode OK packet from bytearray with context"""
6061
parser = PayloadParser(data)
6162

62-
parser.skip(1) # Skip OK marker (0x00 or 0xFE)
63+
parser.skip(1)
6364
affected_rows = parser.read_length_encoded_int()
6465
insert_id = parser.read_length_encoded_int()
65-
66-
# Read server_status and warning_count in one operation (4 bytes total)
6766
server_status = parser.read_int16()
6867
warning_count = parser.read_int16()
6968

70-
# Update context with server status (context is always present)
7169
context.server_status = server_status
7270
context.warning_count = warning_count
7371

74-
# Optional info string and session tracking
72+
# Fast path: no info/tracking (most common case)
73+
if not parser.has_remaining():
74+
return OkPacket(affected_rows, insert_id, server_status, warning_count, b'')
75+
7576
info = b''
76-
if parser.has_remaining():
77-
# Check if session tracking is present
78-
has_session_tracking = (context.has_capability(constants.CAPABILITY.SESSION_TRACKING) and
79-
(server_status & constants.STATUS.SESSION_STATE_CHANGED))
80-
81-
try:
82-
# Read info string length
83-
info_length = parser.read_length_encoded_int()
84-
if info_length > 0:
85-
# Read info bytes (may contain fingerprint validation hash)
86-
info = parser.read_bytes(info_length)
87-
88-
# Process session tracking data if present
89-
if has_session_tracking and parser.has_remaining():
90-
while parser.has_remaining():
91-
# Total length of session tracking data (length-encoded)
92-
total_length = parser.read_length_encoded_int()
93-
if total_length == 0:
94-
break
95-
96-
# Track start position to ensure we don't read beyond this tracking block
97-
start_pos = parser.pos
98-
99-
# Session tracking type (1 byte)
100-
tracking_type = parser.read_byte()
101-
102-
# Data length (length-encoded)
103-
data_length = parser.read_length_encoded_int()
104-
105-
# Process based on tracking type
106-
if tracking_type == constants.SESSION_TRACK.SYSTEM_VARIABLES:
107-
# System variable change
108-
end_pos = start_pos + total_length
109-
while parser.pos < end_pos:
110-
var_name_len = parser.read_length_encoded_int()
111-
var_name = parser.read_bytes(var_name_len).decode('utf-8')
112-
113-
var_value_len = parser.read_length_encoded_int()
114-
var_value = parser.read_bytes(var_value_len).decode('utf-8')
115-
116-
# Update context with system variable change
117-
if hasattr(context, 'update_system_variable'):
118-
context.update_system_variable(var_name, var_value)
77+
info_length = parser.read_length_encoded_int()
78+
if info_length > 0:
79+
info = parser.read_bytes(info_length)
80+
81+
# Session tracking check
82+
if ((server_status & _SESSION_STATE_CHANGED) and
83+
context.has_capability(_SESSION_TRACKING_CAP) and
84+
parser.has_remaining()):
85+
_process_session_tracking(parser, context)
11986

120-
elif tracking_type == constants.SESSION_TRACK.SCHEMA:
121-
# Schema change
122-
schema_len = parser.read_length_encoded_int()
123-
schema = parser.read_bytes(schema_len).decode('utf-8')
124-
if hasattr(context, 'database'):
125-
context.database = schema
126-
else:
127-
# Unknown tracking type - skip data
128-
parser.skip(data_length)
87+
return OkPacket(affected_rows, insert_id, server_status, warning_count, info)
12988

130-
# Ensure we're at the correct position
131-
expected_pos = start_pos + total_length
132-
if parser.pos < expected_pos:
133-
parser.skip(expected_pos - parser.pos)
13489

135-
except Exception:
136-
# Don't fail on info/session tracking errors
137-
pass
90+
def _process_session_tracking(parser: PayloadParser, context: 'Context') -> None:
91+
"""Process session tracking data (separate function for better branch prediction)"""
92+
while parser.has_remaining():
93+
total_length = parser.read_length_encoded_int()
94+
if total_length == 0:
95+
break
96+
97+
start_pos = parser.pos
98+
tracking_type = parser.read_byte()
99+
data_length = parser.read_length_encoded_int()
100+
101+
# NOT NEEDED FOR NOW
102+
#if tracking_type == constants.SESSION_TRACK.SYSTEM_VARIABLES:
103+
# end_pos = start_pos + total_length
104+
# while parser.pos < end_pos:
105+
# var_name_len = parser.read_length_encoded_int()
106+
# var_name = parser.read_bytes(var_name_len).decode('utf-8')
107+
# var_value_len = parser.read_length_encoded_int()
108+
# var_value = parser.read_bytes(var_value_len).decode('utf-8')
109+
# context.update_system_variable(var_name, var_value)
110+
111+
if tracking_type == constants.SESSION_TRACK.SCHEMA:
112+
schema_len = parser.read_length_encoded_int()
113+
context.database = parser.read_bytes(schema_len).decode('utf-8')
114+
115+
else:
116+
parser.skip(data_length)
138117

139-
return OkPacket(
140-
affected_rows,
141-
insert_id,
142-
server_status,
143-
warning_count,
144-
info
145-
)
118+
expected_pos = start_pos + total_length
119+
if parser.pos < expected_pos:
120+
parser.skip(expected_pos - parser.pos)

0 commit comments

Comments
 (0)