Skip to content

Commit 2a12fbd

Browse files
committed
refactor: add pipeline configuration option for prepared statement execution
1 parent 96b76cc commit 2a12fbd

File tree

3 files changed

+122
-34
lines changed

3 files changed

+122
-34
lines changed

mariadb/impl/client/async_client.py

Lines changed: 55 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -526,35 +526,74 @@ async def execute_stmt(self, sql: str, messages: List[ClientMessage], config: 'C
526526
# Not in cache, prepare once and execute all
527527
from ..message.client.prepare_packet import PreparePacket
528528
prepare_message = PreparePacket(sql)
529-
await self.write_payload(prepare_message.payload(self.context), prepare_message.type(), True)
530-
for message in messages:
531-
await self.write_payload(message.payload(self.context), message.type(), True)
532-
prepareResult = None
533-
first_error = None
534-
try:
535-
prepareResult = await self._parse_prepare_response(await self.read_payload(), sql)
536-
except DatabaseError as e:
537-
first_error = e
538-
finally:
539-
# Ensure reading, even if prepared has an error
529+
530+
# Check if pipelining is enabled and server supports BULK operations
531+
use_pipeline = (self.configuration.pipeline and
532+
self.context.has_capability(constants.CAPABILITY.BULK_OPERATIONS))
533+
534+
if use_pipeline:
535+
# Pipeline mode: write prepare and all execute messages before reading
536+
await self.write_payload(prepare_message.payload(self.context), prepare_message.type(), True)
537+
for message in messages:
538+
await self.write_payload(message.payload(self.context), message.type(), True)
539+
prepareResult = None
540+
first_error = None
541+
try:
542+
prepareResult = await self._parse_prepare_response(await self.read_payload(), sql)
543+
except DatabaseError as e:
544+
first_error = e
545+
finally:
546+
# Ensure reading, even if prepared has an error
540547

548+
all_completions = []
549+
for message in messages:
550+
try:
551+
completions = await self._read_result(message.is_binary(), config, buffered, prepareResult)
552+
all_completions.append(completions)
553+
except DatabaseError as e:
554+
if not first_error:
555+
first_error = e
556+
557+
if prepareResult:
558+
if self.configuration.cache_prep_stmts:
559+
self.prepared_statement_cache[key] = prepareResult
560+
prepareResult.close()
561+
562+
if first_error:
563+
raise first_error
564+
return all_completions
565+
else:
566+
# Non-pipeline mode: read prepare response before writing execute messages
567+
await self.write_payload(prepare_message.payload(self.context), prepare_message.type(), True)
568+
569+
prepareResult = None
570+
first_error = None
571+
try:
572+
prepareResult = await self._parse_prepare_response(await self.read_payload(), sql)
573+
except DatabaseError as e:
574+
first_error = e
575+
raise
576+
577+
# Now write and read execute messages
541578
all_completions = []
542579
for message in messages:
580+
message.statement_id = prepareResult.statement_id
581+
await self.write_payload(message.payload(self.context), message.type(), True)
543582
try:
544583
completions = await self._read_result(message.is_binary(), config, buffered, prepareResult)
545584
all_completions.append(completions)
546585
except DatabaseError as e:
547586
if not first_error:
548587
first_error = e
549-
588+
550589
if prepareResult:
551590
if self.configuration.cache_prep_stmts:
552591
self.prepared_statement_cache[key] = prepareResult
553592
prepareResult.close()
554-
555-
if first_error:
556-
raise first_error
557-
return all_completions
593+
594+
if first_error:
595+
raise first_error
596+
return all_completions
558597

559598
except DatabaseError as e:
560599
raise e

mariadb/impl/client/sync_client.py

Lines changed: 59 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -530,37 +530,79 @@ def execute_stmt(self, sql: str, messages: List[ClientMessage], config: 'Configu
530530
# Not in cache, prepare once and execute all
531531
from ..message.client.prepare_packet import PreparePacket
532532
prepare_message = PreparePacket(sql)
533-
self.write_payload(prepare_message.payload(self.context), prepare_message.type(), True)
534-
for message in messages:
535-
self.write_payload(message.payload(self.context), message.type(), True)
536-
self.reset_buffer()
537533

538-
prepareResult = None
539-
first_error = None
540-
try:
541-
prepareResult = self._parse_prepare_response(self.read_payload(), sql)
542-
except DatabaseError as e:
543-
first_error = e
544-
finally:
545-
# Ensure reading, even if prepared has an error
534+
# Check if pipelining is enabled and server supports BULK operations
535+
use_pipeline = (self.configuration.pipeline and
536+
self.context.has_capability(constants.CAPABILITY.BULK_OPERATIONS))
537+
538+
if use_pipeline:
539+
# Pipeline mode: write prepare and all execute messages before reading
540+
self.write_payload(prepare_message.payload(self.context), prepare_message.type(), True)
541+
542+
for message in messages:
543+
self.write_payload(message.payload(self.context), message.type(), True)
544+
self.reset_buffer()
545+
546+
prepareResult = None
547+
first_error = None
548+
try:
549+
prepareResult = self._parse_prepare_response(self.read_payload(), sql)
550+
except DatabaseError as e:
551+
first_error = e
552+
finally:
553+
# Ensure reading, even if prepared has an error
546554

555+
all_completions = []
556+
for message in messages:
557+
try:
558+
completions = self._read_result(message.is_binary(), config, buffered, prepareResult)
559+
all_completions.append(completions)
560+
except DatabaseError as e:
561+
if not first_error:
562+
first_error = e
563+
564+
if prepareResult:
565+
if self.configuration.cache_prep_stmts:
566+
self.prepared_statement_cache[key] = prepareResult
567+
prepareResult.close()
568+
569+
if first_error:
570+
raise first_error
571+
return all_completions
572+
else:
573+
# Non-pipeline mode: read prepare response before writing execute messages
574+
self.write_payload(prepare_message.payload(self.context), prepare_message.type(), True)
575+
self.reset_buffer()
576+
577+
prepareResult = None
578+
first_error = None
579+
try:
580+
prepareResult = self._parse_prepare_response(self.read_payload(), sql)
581+
except DatabaseError as e:
582+
first_error = e
583+
raise
584+
585+
# Now write and read execute messages
547586
all_completions = []
548587
for message in messages:
588+
message.statement_id = prepareResult.statement_id
589+
self.write_payload(message.payload(self.context), message.type(), True)
590+
self.reset_buffer()
549591
try:
550592
completions = self._read_result(message.is_binary(), config, buffered, prepareResult)
551593
all_completions.append(completions)
552594
except DatabaseError as e:
553595
if not first_error:
554596
first_error = e
555-
597+
556598
if prepareResult:
557599
if self.configuration.cache_prep_stmts:
558600
self.prepared_statement_cache[key] = prepareResult
559601
prepareResult.close()
560-
561-
if first_error:
562-
raise first_error
563-
return all_completions
602+
603+
if first_error:
604+
raise first_error
605+
return all_completions
564606

565607
except DatabaseError as e:
566608
raise e

mariadb/impl/configuration.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ class Configuration:
6565
cache_prep_stmts: bool = True # Enable prepared statement caching
6666
prep_stmt_cache_size: int = 100 # Maximum number of cached prepared statements
6767

68+
# Pipeline option
69+
pipeline: bool = True # Enable pipelining for prepared statements
70+
6871
# Additional options
6972
non_mapped_options: Dict[str, Any] = field(default_factory=dict)
7073

@@ -198,6 +201,10 @@ def from_dict(cls, params: Dict[str, Any]) -> 'Configuration':
198201
if 'prep_stmt_cache_size' in params:
199202
config.prep_stmt_cache_size = int(params['prep_stmt_cache_size'])
200203

204+
# Pipeline option
205+
if 'pipeline' in params:
206+
config.pipeline = bool(params['pipeline'])
207+
201208
# Store any unmapped options
202209
valid_params = {
203210
'host', 'hostname', 'server', 'user', 'username', 'password', 'passwd',
@@ -210,7 +217,7 @@ def from_dict(cls, params: Dict[str, Any]) -> 'Configuration':
210217
'compress',
211218
'query_timeout', 'max_allowed_packet',
212219
'character_encoding', 'charset', 'init_command', 'converter', 'named_tuple', 'dictionary', 'native_object',
213-
'cache_prep_stmts', 'prep_stmt_cache_size'
220+
'cache_prep_stmts', 'prep_stmt_cache_size', 'pipeline'
214221
}
215222

216223
for key, value in params.items():

0 commit comments

Comments
 (0)