99import warnings
1010import configparser
1111import getpass
12+ import ssl as ssllib
1213from functools import partial
1314
1415from pymysql .charset import charset_by_name , charset_by_id
@@ -53,7 +54,7 @@ def connect(host="localhost", user=None, password="",
5354 connect_timeout = None , read_default_group = None ,
5455 autocommit = False , echo = False ,
5556 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
56- program_name = '' , server_public_key = None ):
57+ program_name = '' , server_public_key = None , implicit_tls = False ):
5758 """See connections.Connection.__init__() for information about
5859 defaults."""
5960 coro = _connect (host = host , user = user , password = password , db = db ,
@@ -66,7 +67,8 @@ def connect(host="localhost", user=None, password="",
6667 read_default_group = read_default_group ,
6768 autocommit = autocommit , echo = echo ,
6869 local_infile = local_infile , loop = loop , ssl = ssl ,
69- auth_plugin = auth_plugin , program_name = program_name )
70+ auth_plugin = auth_plugin , program_name = program_name ,
71+ implicit_tls = implicit_tls )
7072 return _ConnectionContextManager (coro )
7173
7274
@@ -142,7 +144,7 @@ def __init__(self, host="localhost", user=None, password="",
142144 connect_timeout = None , read_default_group = None ,
143145 autocommit = False , echo = False ,
144146 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
145- program_name = '' , server_public_key = None ):
147+ program_name = '' , server_public_key = None , implicit_tls = False ):
146148 """
147149 Establish a connection to the MySQL database. Accepts several
148150 arguments:
@@ -184,6 +186,9 @@ def __init__(self, host="localhost", user=None, password="",
184186 handshaking with MySQL. (omitted by default)
185187 :param server_public_key: SHA256 authentication plugin public
186188 key value.
189+ :param implicit_tls: Establish TLS immediately, skipping non-TLS
190+ preamble before upgrading to TLS.
191+ (default: False)
187192 :param loop: asyncio loop
188193 """
189194 self ._loop = loop or asyncio .get_event_loop ()
@@ -218,6 +223,7 @@ def __init__(self, host="localhost", user=None, password="",
218223 self ._auth_plugin_used = ""
219224 self ._secure = False
220225 self .server_public_key = server_public_key
226+ self ._implicit_tls = implicit_tls
221227 self .salt = None
222228
223229 from . import __version__
@@ -241,7 +247,10 @@ def __init__(self, host="localhost", user=None, password="",
241247 self .use_unicode = use_unicode
242248
243249 self ._ssl_context = ssl
244- if ssl :
250+ # TLS is required when implicit_tls is True
251+ if implicit_tls and not self ._ssl_context :
252+ self ._ssl_context = ssllib .create_default_context ()
253+ if ssl and not implicit_tls :
245254 client_flag |= CLIENT .SSL
246255
247256 self ._encoding = charset_by_name (self ._charset ).encoding
@@ -536,7 +545,8 @@ async def _connect(self):
536545
537546 self ._next_seq_id = 0
538547
539- await self ._get_server_information ()
548+ if not self ._implicit_tls :
549+ await self ._get_server_information ()
540550 await self ._request_authentication ()
541551
542552 self .connected_time = self ._loop .time ()
@@ -738,7 +748,8 @@ async def _execute_command(self, command, sql):
738748
739749 async def _request_authentication (self ):
740750 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
741- if int (self .server_version .split ('.' , 1 )[0 ]) >= 5 :
751+ # FIXME: change this before merge
752+ if self ._implicit_tls or int (self .server_version .split ('.' , 1 )[0 ]) >= 5 :
742753 self .client_flag |= CLIENT .MULTI_RESULTS
743754
744755 if self .user is None :
@@ -748,8 +759,10 @@ async def _request_authentication(self):
748759 data_init = struct .pack ('<iIB23s' , self .client_flag , MAX_PACKET_LEN ,
749760 charset_id , b'' )
750761
751- if self ._ssl_context and self .server_capabilities & CLIENT .SSL :
752- self .write_packet (data_init )
762+ if self ._ssl_context and \
763+ (self ._implicit_tls or self .server_capabilities & CLIENT .SSL ):
764+ if not self ._implicit_tls :
765+ self .write_packet (data_init )
753766
754767 # Stop sending events to data_received
755768 self ._writer .transport .pause_reading ()
@@ -771,6 +784,9 @@ async def _request_authentication(self):
771784 server_hostname = self ._host
772785 )
773786
787+ if self ._implicit_tls :
788+ await self ._get_server_information ()
789+
774790 self ._secure = True
775791
776792 if isinstance (self .user , str ):
0 commit comments