
Commit d73aa6c (1 parent: dbcae9d)

refactor: optimize packet writing by pre-allocating header space in payload buffers

19 files changed: +206, -128 lines

mariadb/impl/client/socket/write_stream.py
Lines changed: 71 additions & 49 deletions
```diff
@@ -65,50 +65,64 @@ async def write_payload(self, payload: bytes, packet_type: str = "", reset_sequence: bool = True) -> None:
         Write payload with MariaDB packet framing (async version)
 
         Args:
-            payload: Payload bytes to send
+            payload: Payload bytes with first 4 bytes reserved for header
             packet_type: Packet type for logging (e.g., "COM_QUERY")
             reset_sequence: Whether to reset sequence number before sending
         """
         if reset_sequence:
             self.sequence.set(-1)
 
-        payload_len = len(payload)
-        offset = 0
+        # Payload has 4 bytes reserved at start for header
+        payload_len = len(payload) - 4
+        data_offset = 4  # Data starts after reserved header space
 
         # Handle empty payload - still need to send header
         if payload_len == 0:
             seq = self.sequence.increment_and_get()
-            header = b'\x00\x00\x00' + bytes([seq])
+            # Write header into first 4 bytes
+            payload_buf = bytearray(payload)
+            payload_buf[0:4] = b'\x00\x00\x00' + bytes([seq])
 
             if logger.isEnabledFor(logging.DEBUG):
                 conn_id_str = f"[conn_id={self.connection_id}]" if self.connection_id >= 0 else ""
                 packet_type_str = f" {packet_type}" if packet_type else ""
-                logger.debug(hex_dump(header, f"SEND async: {conn_id_str}{packet_type_str}"))
+                logger.debug(hex_dump(bytes(payload_buf[0:4]), f"SEND async: {conn_id_str}{packet_type_str}"))
 
-            self.writer.write(header)
+            self.writer.write(payload_buf[0:4])
             await self.writer.drain()
             return
 
+        # Convert to bytearray for in-place header writing
+        payload_buf = bytearray(payload)
+
         # Handle packet splitting for large payloads
-        while offset < payload_len:
-            chunk_size = min(MAX_PACKET_SIZE, payload_len - offset)
+        sent = 0
+
+        while sent < payload_len:
+            chunk_size = min(MAX_PACKET_SIZE, payload_len - sent)
             seq = self.sequence.increment_and_get()
 
-            # Build header: 3-byte length + 1-byte sequence
-            header = chunk_size.to_bytes(3, 'little') + bytes([seq])
+            # Data for this chunk starts at data_offset + sent
+            chunk_start = data_offset + sent
+            chunk_end = chunk_start + chunk_size
+
+            # Write header 4 bytes before the chunk data
+            header_pos = chunk_start - 4
+            payload_buf[header_pos] = chunk_size & 0xff
+            payload_buf[header_pos + 1] = (chunk_size >> 8) & 0xff
+            payload_buf[header_pos + 2] = (chunk_size >> 16) & 0xff
+            payload_buf[header_pos + 3] = seq
 
-            # Log if debug enabled (need to build full packet for logging)
+            # Log if debug enabled
             if logger.isEnabledFor(logging.DEBUG):
-                chunk = payload[offset:offset + chunk_size]
-                packet = header + chunk
+                packet = bytes(payload_buf[header_pos:chunk_end])
                 conn_id_str = f"[conn_id={self.connection_id}]" if self.connection_id >= 0 else ""
                 packet_type_str = f" {packet_type}" if packet_type else ""
                 logger.debug(hex_dump(packet, f"SEND async: {conn_id_str}{packet_type_str}"))
 
-            # Send header and chunk separately (more efficient - no concatenation)
-            self.writer.write(header)
-            self.writer.write(payload[offset:offset + chunk_size])
-            offset += chunk_size
+            # Send packet: header + chunk data
+            self.writer.write(payload_buf[header_pos:chunk_end])
+            sent += chunk_size
 
         # Flush all buffered data
         await self.writer.drain()
```
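The async path now assumes the caller hands over payloads with 4 spare bytes at the front, so the 3-byte little-endian length and 1-byte sequence can be stamped in place and each packet leaves as a single contiguous `write()` instead of a header write followed by a chunk write. A minimal sketch of that framing convention outside the stream class (the reserved-prefix layout follows the diff; `frame_in_place` and the sample query bytes are illustrative):

```python
MAX_PACKET_SIZE = 0xFFFFFF  # MariaDB's per-packet payload limit (16 MB - 1)

def frame_in_place(payload: bytes, seq: int) -> bytearray:
    """Stamp a MariaDB header (3-byte LE length + sequence) into reserved space."""
    buf = bytearray(payload)
    length = len(buf) - 4             # actual data length, header excluded
    buf[0] = length & 0xFF            # low byte first (little-endian)
    buf[1] = (length >> 8) & 0xFF
    buf[2] = (length >> 16) & 0xFF
    buf[3] = seq & 0xFF               # sequence number wraps at 256
    return buf

# The caller reserves the header up front instead of concatenating one later:
packet = frame_in_place(b"\x00\x00\x00\x00" + b"\x03SELECT 1", seq=0)
assert packet[:4] == bytes([9, 0, 0, 0])   # 9 payload bytes, sequence 0
```

Building the header where the reserved bytes already sit avoids both the `bytes` concatenation of the old code and the second `write()` call per packet.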
```diff
@@ -124,71 +138,79 @@ async def write_payload(self, payload: bytes, packet_type: str = "", reset_sequence: bool = True) -> None:
 class SyncWriteStream(BaseWriteStream):
     """Sync write stream implementation using blocking socket operations"""
 
-    def __init__(self, sock: socket.socket, connection_id: int = -1):
-        """
-        Initialize sync write stream
-
-        Args:
-            sock: Blocking socket
-            connection_id: Connection ID for logging
-        """
-        self.socket: socket.socket = sock
-        super().__init__(connection_id)
+    def __init__(self, socket: socket.socket, connection_id: int = -1):
+        """Initialize write stream with socket"""
+        self.socket = socket
+        self.sequence = MutableInt(-1)
+        self.connection_id = connection_id
+        # Check once if sendmsg is supported (Unix) or if we need sendall (Windows)
+        self.has_sendmsg = hasattr(socket, 'sendmsg')
 
     def write_payload(self, payload: bytes, packet_type: str = "", reset_sequence: bool = True) -> None:
         """
         Write payload with MariaDB packet framing (sync version)
 
         Args:
-            payload: Payload bytes to send
+            payload: Payload bytes with first 4 bytes reserved for header
             packet_type: Packet type for logging (e.g., "COM_QUERY")
             reset_sequence: Whether to reset sequence number before sending
         """
         if reset_sequence:
             self.sequence.set(-1)
 
-        payload_len = len(payload)
-        offset = 0
+        # Payload has 4 bytes reserved at start for header
+        payload_len = len(payload) - 4
+        data_offset = 4  # Data starts after reserved header space
 
         # Handle empty payload - still need to send header
         if payload_len == 0:
             seq = self.sequence.increment_and_get()
-            header = b'\x00\x00\x00' + bytes([seq])
+            # Write header into first 4 bytes
+            payload_buf = bytearray(payload)
+            payload_buf[0:4] = b'\x00\x00\x00' + bytes([seq])
 
             if logger.isEnabledFor(logging.DEBUG):
                 conn_id_str = f"[conn_id={self.connection_id}]" if self.connection_id >= 0 else ""
                 packet_type_str = f" {packet_type}" if packet_type else ""
-                logger.debug(hex_dump(header, f"SEND sync: {conn_id_str}{packet_type_str}"))
+                logger.debug(hex_dump(bytes(payload_buf[0:4]), f"SEND sync: {conn_id_str}{packet_type_str}"))
 
-            self.socket.sendall(header)
+            self.socket.sendall(payload_buf[0:4])
             return
 
+        # Convert to bytearray for in-place header writing
+        payload_buf = bytearray(payload)
+
         # Handle packet splitting for large payloads
-        while offset < payload_len:
-            chunk_size = min(MAX_PACKET_SIZE, payload_len - offset)
+        sent = 0  # Track how much data we've sent
+
+        while sent < payload_len:
+            chunk_size = min(MAX_PACKET_SIZE, payload_len - sent)
             seq = self.sequence.increment_and_get()
 
-            # Build header: 3-byte length + 1-byte sequence
-            header = chunk_size.to_bytes(3, 'little') + bytes([seq])
-            chunk = payload[offset:offset + chunk_size]
+            # Data for this chunk starts at data_offset + sent
+            chunk_start = data_offset + sent
+            chunk_end = chunk_start + chunk_size
+
+            # Write header 4 bytes before the chunk data
+            header_pos = chunk_start - 4
+
+            payload_buf[header_pos] = chunk_size & 0xff
+            payload_buf[header_pos + 1] = (chunk_size >> 8) & 0xff
+            payload_buf[header_pos + 2] = (chunk_size >> 16) & 0xff
+            payload_buf[header_pos + 3] = seq
 
-            # Log if debug enabled (need full packet for logging)
+            # Log if debug enabled
             if logger.isEnabledFor(logging.DEBUG):
-                packet = header + chunk
+                packet = bytes(payload_buf[header_pos:chunk_end])
                 conn_id_str = f"[conn_id={self.connection_id}]" if self.connection_id >= 0 else ""
                 packet_type_str = f" {packet_type}" if packet_type else ""
                 logger.debug(hex_dump(packet, f"SEND sync: {conn_id_str}{packet_type_str}"))
 
-            # Send header and chunk in a single syscall using scatter-gather I/O
-            # sendmsg() is available on Unix and sends multiple buffers efficiently
-            try:
-                self.socket.sendmsg([header, chunk])
-            except Exception:
-                # Fallback for platforms without sendmsg (e.g., Windows)
-                self.socket.sendall(header)
-                self.socket.sendall(chunk)
+            # Send packet: header + chunk data
+            self.socket.sendall(payload_buf[header_pos:chunk_end])
 
-            offset += chunk_size
+            sent += chunk_size
 
         # If last packet was exactly MAX_PACKET_SIZE, send empty packet to signal end
         if payload_len % MAX_PACKET_SIZE == 0:
```
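The sync version retires the scatter-gather path along the way: once the header lives inside the payload buffer there is nothing left to gather, so a single `sendall()` replaces the `sendmsg([header, chunk])` call and its Windows fallback (note the `has_sendmsg` flag probed in the new `__init__` is not consulted in this hunk). A sketch contrasting the two strategies under that reading; both helper names are illustrative:

```python
import socket

def send_scatter(sock: socket.socket, header: bytes, chunk: bytes) -> None:
    """Old approach: two buffers joined at the syscall via scatter-gather I/O."""
    try:
        sock.sendmsg([header, chunk])   # one syscall, no copy (Unix)
    except (AttributeError, OSError):
        sock.sendall(header)            # fallback (e.g., Windows): two syscalls
        sock.sendall(chunk)

def send_contiguous(sock: socket.socket, packet: bytearray) -> None:
    """New approach: header pre-written into the buffer, one portable syscall."""
    sock.sendall(packet)
```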

mariadb/impl/client/sync_client.py
Lines changed: 74 additions & 22 deletions
```diff
@@ -64,7 +64,9 @@ def __init__(self, configuration: Configuration) -> None:
         self.sequence: MutableInt = MutableInt(-1)
 
         # Read buffer management
-        self._recv_buf: bytearray = bytearray(8192)
+        self._default_recv_buf: bytearray = bytearray(8192)
+        self._recv_buf: bytearray = self._default_recv_buf
+
         self._recv_pos = 0
         self._recv_len = 0
 
@@ -79,7 +81,8 @@ def _ensure_space(self, needed):
         ALIGN = 16384
         if (len(self._recv_buf) - self._recv_len >= needed):
             return
-        self._recv_buf.extend(bytearray((needed + ALIGN - 1) & ~(ALIGN - 1)))
+        self._recv_buf = self._recv_buf + bytearray((needed + ALIGN - 1) & ~(ALIGN - 1))
+
 
     def _recv_into_buffer(self, size=0):
         """
```
```diff
@@ -97,20 +100,29 @@ def _recv_into_buffer(self, size=0):
 
         # Keep trying to read until we have enough data or there's nothing left
         try:
+            if self.logger.isEnabledFor(logging.DEBUG):
+                self.logger.debug(f"_recv_into_buffer: requesting size={size}, buffer_len={len(self._recv_buf)}, recv_len={self._recv_len}, recv_pos={self._recv_pos}")
+
             if size == 0:
                 n = self.socket.recv_into(mv[self._recv_len + received:])
+                if self.logger.isEnabledFor(logging.DEBUG):
+                    self.logger.debug(f"_recv_into_buffer: received {n} bytes (no size limit)")
                 if n == 0:
                     raise ConnectionError("Connection reset by peer")
                 return n
             while received < size:
                 n = self.socket.recv_into(mv[self._recv_len + received:], size - received)
+                if self.logger.isEnabledFor(logging.DEBUG):
+                    self.logger.debug(f"_recv_into_buffer: received {n} bytes, total {received + n}/{size}")
                 if n == 0:
                     raise ConnectionError("Connection reset by peer")
                 received += n
             return received
 
         except socket.timeout:
-            raise TimeoutError("Socket recv timed out")
+            if self.logger.isEnabledFor(logging.DEBUG):
+                self.logger.debug(f"_recv_into_buffer: TIMEOUT after receiving {received} bytes (requested {size})")
+            raise TimeoutError("Socket recv timed out")
 
         except ConnectionResetError:
             raise ConnectionError("Connection reset by peer")
```
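Each new log statement is gated on `logger.isEnabledFor(logging.DEBUG)` before the f-string exists, so the cost of interpolating buffer lengths and offsets is only paid when debug logging is on, which matters on a per-recv hot path. The same guard pattern in isolation (logger name is illustrative):

```python
import logging

logger = logging.getLogger("mariadb.recv")

def log_chunk(n: int, received: int, size: int) -> None:
    # Guard first: skip building the f-string (and the arithmetic inside it)
    # entirely when DEBUG is off.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"received {n} bytes, total {received + n}/{size}")
```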
```diff
@@ -130,6 +142,7 @@ def read_payload(self):
         of the buffer
 
         """
+        from ..debug_utils import hex_dump
 
         # for faster local lookup
         PKT_HDR_SIZE=4
@@ -149,7 +162,7 @@ def read_payload(self):
 
         first_pos = self._recv_pos
         total_size = 0
-        payload_write_pos = None  # Track where to write compacted payload
+        packet_count = 0
 
         while True:
             bytes_in_buffer = self._recv_len - self._recv_pos
@@ -180,31 +193,57 @@ def read_payload(self):
                     continue
 
                # We have complete packet (header + payload)
-                if payload_write_pos is None:
-                    # First packet - payload starts after first header
-                    payload_write_pos = first_pos + PKT_HDR_SIZE
-                elif self._recv_pos != payload_write_pos:
+                packet_count += 1
+
+                # Log complete packet with data
+                if self.logger.isEnabledFor(logging.DEBUG):
+                    packet_data = bytes(self._recv_buf[self._recv_pos:self._recv_pos + PKT_HDR_SIZE + packet_length])
+                    conn_id_str = f"[conn_id={self.connection_id}]" if hasattr(self, 'connection_id') and self.connection_id >= 0 else ""
+                    self.logger.debug(hex_dump(packet_data, f"RECV sync: {conn_id_str} packet {packet_count} complete"))
+
+                if packet_count > 1:
                     # Multi-packet: compact by removing intermediate header
-                    # Move this packet's payload to the write position
-                    payload_start = self._recv_pos + PKT_HDR_SIZE
-                    self._recv_buf[payload_write_pos:payload_write_pos + packet_length] = \
-                        self._recv_buf[payload_start:payload_start + packet_length]
+                    # Move this packet's payload immediately after previous payload
+                    payload_src = self._recv_pos + PKT_HDR_SIZE
+                    payload_dst = first_pos + PKT_HDR_SIZE + total_size
+                    if payload_src != payload_dst:
+                        # Calculate how much data is after this packet
+                        data_after_packet = self._recv_len - (self._recv_pos + PKT_HDR_SIZE + packet_length)
+                        # Move this packet's payload
+                        self._recv_buf[payload_dst:payload_dst + packet_length] = \
+                            self._recv_buf[payload_src:payload_src + packet_length]
+                        # Move any data after this packet
+                        if data_after_packet > 0:
+                            self._recv_buf[payload_dst + packet_length:payload_dst + packet_length + data_after_packet] = \
+                                self._recv_buf[self._recv_pos + PKT_HDR_SIZE + packet_length:self._recv_len]
+                        # After compaction, adjust buffer length to account for removed header
+                        self._recv_len -= PKT_HDR_SIZE
 
-                payload_write_pos += packet_length
                 total_size += packet_length
 
                 # Check if this is the last packet
                 if packet_length < MAX_PKT_SIZE:
                     # Last packet - return accumulated payload
+                    if self.logger.isEnabledFor(logging.DEBUG):
+                        conn_id_str = f"[conn_id={self.connection_id}]" if hasattr(self, 'connection_id') and self.connection_id >= 0 else ""
+                        self.logger.debug(f"RECV sync: {conn_id_str} complete multi-packet message: {packet_count} packets, {total_size} bytes total")
+
                     self._recv_pos = first_pos + PKT_HDR_SIZE + total_size
                     return memoryview(self._recv_buf[first_pos + PKT_HDR_SIZE:first_pos + PKT_HDR_SIZE + total_size])
 
                 # Multi-packet: advance to next packet header
-                self._recv_pos += PKT_HDR_SIZE + packet_length
+                # After compaction, the next header is immediately after current payload
+                if packet_count > 1:
+                    # After compaction, next header is at: first_pos + PKT_HDR_SIZE + total_size
+                    self._recv_pos = first_pos + PKT_HDR_SIZE + total_size
+                else:
+                    # First packet, no compaction yet
+                    self._recv_pos += PKT_HDR_SIZE + packet_length
             else:
                 self._recv_len += self._recv_into_buffer()
 
     def reset_buffer(self):
+        self._recv_buf = self._default_recv_buf
         self._recv_pos = 0
         self._recv_len = 0
 
```
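The rewritten compaction slides each continuation packet's payload, plus any bytes already buffered beyond it, left over the 4-byte intermediate header, shrinking `_recv_len` accordingly, so the finished message ends up contiguous right after the first header. The idea in isolation on a toy buffer (a standalone sketch, not the method above; the real `MAX_PKT_SIZE` is 0xFFFFFF, shrunk to 8 here so the example actually splits):

```python
HDR = 4
MAX_PKT = 8   # toy split threshold; the real MAX_PKT_SIZE is 0xFFFFFF

def compact(buf: bytearray, length: int) -> memoryview:
    """Collapse [hdr|data][hdr|data]... in buf[:length] into one payload."""
    pos = total = 0
    first = True
    while True:
        pkt_len = int.from_bytes(buf[pos:pos + 3], "little")
        if first:
            pos += HDR                       # keep the first header in place
            first = False
        else:
            # Slide this packet (and everything after it) left by one header,
            # erasing the intermediate header in place.
            buf[pos:length - HDR] = buf[pos + HDR:length]
            length -= HDR
        total += pkt_len
        pos += pkt_len                       # next header, if any, starts here
        if pkt_len < MAX_PKT:                # short packet ends the message
            return memoryview(buf)[HDR:HDR + total]

def pkt(data: bytes, seq: int) -> bytes:
    return len(data).to_bytes(3, "little") + bytes([seq]) + data

raw = bytearray(pkt(b"A" * 8, 0) + pkt(b"B" * 8, 1) + pkt(b"CD", 2))
assert bytes(compact(raw, len(raw))) == b"A" * 8 + b"B" * 8 + b"CD"
```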

```diff
@@ -449,16 +488,29 @@ def execute_many(self, messages: List[ClientMessage], config: 'Configuration' =
             BATCH_SIZE = 1000
 
             self.reset_buffer()
-            for i in range(0, len(messages), BATCH_SIZE):
-                batch = messages[i:i + BATCH_SIZE]
-
-                # Write batch
-                for message in batch:
+
+            # For large payloads (>1MB), process one at a time to avoid buffer issues
+            # For small payloads, batch for performance
+            has_large_payload = any(len(msg.payload(self.context)) > 1024 * 1024 for msg in messages[:min(10, len(messages))])
+
+            if has_large_payload:
+                # Process one command at a time for large payloads
+                # This prevents TCP buffer issues and command mixing with multi-MB payloads
+                for message in messages:
                     self.write_stream.write_payload(message.payload(self.context), message.type(), True)
-
-                # Read responses for this batch
-                for message in batch:
                     results.append(self._read_result(message.is_binary(), config, buffered, prepare_stmt_packet))
+            else:
+                # Batch processing for small payloads
+                for i in range(0, len(messages), BATCH_SIZE):
+                    batch = messages[i:i + BATCH_SIZE]
+
+                    # Write batch
+                    for message in batch:
+                        self.write_stream.write_payload(message.payload(self.context), message.type(), True)
+
+                    # Read responses for this batch
+                    for message in batch:
+                        results.append(self._read_result(message.is_binary(), config, buffered, prepare_stmt_packet))
 
         except DatabaseError as e:
             raise e
```
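`execute_many` now chooses a strategy by payload size: small payloads keep the pipelined write-all-then-read-all batching in groups of 1000, while any payload over 1 MB among the first ten sampled switches the whole run to strict lock-step. The control flow in isolation (a sketch; `send_fn` and `recv_fn` stand in for `write_stream.write_payload` and `_read_result`):

```python
from typing import Any, Callable, List

def run_many(payloads: List[bytes],
             send_fn: Callable[[bytes], None],
             recv_fn: Callable[[], Any],
             batch_size: int = 1000,
             large: int = 1024 * 1024) -> List[Any]:
    results: List[Any] = []
    if any(len(p) > large for p in payloads[:10]):
        # Lock-step for multi-MB payloads: reading each response immediately
        # keeps unread responses from piling up in kernel socket buffers.
        for p in payloads:
            send_fn(p)
            results.append(recv_fn())
    else:
        # Pipelined: flush a batch of writes, then drain the batch of reads.
        for i in range(0, len(payloads), batch_size):
            batch = payloads[i:i + batch_size]
            for p in batch:
                send_fn(p)
            for _ in batch:
                results.append(recv_fn())
    return results
```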

mariadb/impl/debug_utils.py
Lines changed: 13 additions & 0 deletions
```diff
@@ -25,6 +25,15 @@ def hex_dump(data: Union[bytes, bytearray], descr: str = "") -> str:
     if not data:
         return ""
 
+    MAX_DUMP_SIZE = 1024
+    original_len = len(data)
+    truncated = False
+
+    # Truncate if data is too large
+    if len(data) > MAX_DUMP_SIZE:
+        data = data[:MAX_DUMP_SIZE]
+        truncated = True
+
     lines = [f"{descr}"]
 
     # Header
@@ -75,4 +84,8 @@ def hex_dump(data: Union[bytes, bytearray], descr: str = "") -> str:
     # Footer
     lines.append("+------+---------------------------------------------------+------------------+")
 
+    # Add truncation notice if data was truncated
+    if truncated:
+        lines.append(f"[DATA TRUNCATED: showing {MAX_DUMP_SIZE} of {original_len} bytes]")
+
     return "\n".join(lines)
```
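With the 1 KB cap, hex-dumping a multi-megabyte packet now costs a bounded amount of formatting and log volume, and the notice preserves the original size. A usage sketch (import path as used by `sync_client.py` above; requires the package to be importable):

```python
from mariadb.impl.debug_utils import hex_dump

dump = hex_dump(b"\xaa" * (5 * 1024 * 1024), "RECV sync: big packet")
# Only the first 1024 bytes are rendered; the final line records the rest.
assert dump.splitlines()[-1] == "[DATA TRUNCATED: showing 1024 of 5242880 bytes]"
```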
