Skip to content

Commit 0303cdc

Browse files
committed
fixed the handshake
1 parent d3fcb5d commit 0303cdc

File tree

10 files changed

+112
-90
lines changed

10 files changed

+112
-90
lines changed

src/SuperSocket.MySQL/MySQLConnection.cs

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,6 @@ private byte[] GenerateAuthResponse(HandshakePacket handshakePacket)
9797
if (string.IsNullOrEmpty(_password))
9898
return Array.Empty<byte>();
9999

100-
// Combine auth plugin data parts to form the complete salt
101-
var salt = new byte[20];
102-
handshakePacket.AuthPluginDataPart1?.CopyTo(salt, 0);
103-
if (handshakePacket.AuthPluginDataPart2 != null)
104-
{
105-
var part2Length = Math.Min(handshakePacket.AuthPluginDataPart2.Length, 12);
106-
Array.Copy(handshakePacket.AuthPluginDataPart2, 0, salt, 8, part2Length);
107-
}
108-
109100
// MySQL native password authentication algorithm:
110101
// SHA1(password) XOR SHA1(salt + SHA1(SHA1(password)))
111102
using (var sha1 = SHA1.Create())
@@ -114,13 +105,20 @@ private byte[] GenerateAuthResponse(HandshakePacket handshakePacket)
114105
var sha1Password = sha1.ComputeHash(passwordBytes);
115106
var sha1Sha1Password = sha1.ComputeHash(sha1Password);
116107

117-
var combined = new byte[salt.Length + sha1Sha1Password.Length];
118-
salt.CopyTo(combined, 0);
119-
sha1Sha1Password.CopyTo(combined, salt.Length);
108+
sha1.TransformBlock(handshakePacket.AuthPluginDataPart1, 0, handshakePacket.AuthPluginDataPart1.Length, null, 0);
109+
110+
if (handshakePacket.AuthPluginDataPart2 != null)
111+
{
112+
var part2Length = Math.Min(handshakePacket.AuthPluginDataPart2.Length, 12);
113+
sha1.TransformBlock(handshakePacket.AuthPluginDataPart2, 0, part2Length, null, 0);
114+
}
120115

121-
var sha1Combined = sha1.ComputeHash(combined);
116+
sha1.TransformFinalBlock(sha1Sha1Password, 0, sha1Sha1Password.Length);
117+
118+
var sha1Combined = sha1.Hash;
122119

123120
var result = new byte[sha1Password.Length];
121+
124122
for (int i = 0; i < sha1Password.Length; i++)
125123
{
126124
result[i] = (byte)(sha1Password[i] ^ sha1Combined[i]);
@@ -169,9 +167,5 @@ public async Task DisconnectAsync()
169167
IsAuthenticated = false;
170168
}
171169
}
172-
173-
protected override void OnError(string message, Exception exception)
174-
{
175-
}
176170
}
177171
}

src/SuperSocket.MySQL/MySQLPacketDecoder.cs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,6 @@ public MySQLPacket Decode(ref ReadOnlySequence<byte> buffer, object context)
4343
packetType = (int)packetTypeByte;
4444
}
4545

46-
// Reset reader to beginning
47-
reader = new SequenceReader<byte>(buffer);
48-
4946
var package = _packetFactory.Create(packetType);
5047

5148
package.Decode(ref reader, context);
Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Buffers;
3+
using SuperSocket.MySQL.Packets;
34
using SuperSocket.ProtoBase;
45

56
namespace SuperSocket.MySQL
@@ -8,7 +9,31 @@ internal class MySQLPacketEncoder : IPackageEncoder<MySQLPacket>
89
{
910
public int Encode(IBufferWriter<byte> bufferWriter, MySQLPacket package)
1011
{
11-
return package.Encode(bufferWriter);
12+
var packetContentWriter = new ArrayBufferWriter<byte>();
13+
var packetLen = package.Encode(packetContentWriter);
14+
15+
var headerSpan = bufferWriter.GetSpan(4);
16+
headerSpan[0] = (byte)(packetLen & 0xFF);
17+
headerSpan[1] = (byte)((packetLen >> 8) & 0xFF);
18+
headerSpan[2] = (byte)((packetLen >> 16) & 0xFF);
19+
headerSpan[3] = (byte)package.SequenceId; // Sequence ID, typically starts at 0 for the first packet
20+
21+
bufferWriter.Advance(4);
22+
23+
packetLen += 4;
24+
25+
if (package is IPacketWithHeaderByte packetWithHeader)
26+
{
27+
// If the packet type byte is to be encoded, write it as the first byte of the content
28+
var contentSpan = packetContentWriter.GetSpan(1);
29+
contentSpan[0] = packetWithHeader.Header; // Example: using first character of type name as packet type byte
30+
packetContentWriter.Advance(1);
31+
32+
packetLen++;
33+
}
34+
35+
bufferWriter.Write(packetContentWriter.WrittenSpan);
36+
return packetLen;
1237
}
1338
}
1439
}

src/SuperSocket.MySQL/MySQLPacketFactory.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,15 @@ public MySQLPacket Create(int packageType)
3939
throw new InvalidDataException($"No packet registered for package type {packageType}");
4040
}
4141

42-
return creator();
42+
var packet = creator();
43+
44+
if (packet is IPacketWithHeaderByte packetWithHeader)
45+
{
46+
// If the packet type byte is to be encoded, set it as the first byte of the content
47+
packetWithHeader.Header = (byte)packageType;
48+
}
49+
50+
return packet;
4351
}
4452
}
4553
}

src/SuperSocket.MySQL/Packets/ErrorPacket.cs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,16 @@
33

44
namespace SuperSocket.MySQL.Packets
55
{
6-
public class ErrorPacket : MySQLPacket
6+
public class ErrorPacket : MySQLPacket, IPacketWithHeaderByte
77
{
8-
public byte Header { get; set; } = 0xFF; // Error packet identifier
8+
public byte Header { get; set; }
99
public ushort ErrorCode { get; set; }
1010
public string SqlStateMarker { get; set; } = "#";
1111
public string SqlState { get; set; }
1212
public string ErrorMessage { get; set; }
1313

1414
protected internal override void Decode(ref SequenceReader<byte> reader, object context)
1515
{
16-
// Read header (should be 0xFF for Error packet)
17-
reader.TryRead(out byte header);
18-
Header = header;
19-
2016
// Read error code (2 bytes)
2117
reader.TryReadLittleEndian(out short errorCode);
2218
ErrorCode = (ushort)errorCode;

src/SuperSocket.MySQL/Packets/HandshakePacket.cs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,57 +25,57 @@ protected internal override void Decode(ref SequenceReader<byte> reader, object
2525
// Read protocol version (1 byte)
2626
reader.TryRead(out byte protocolVersion);
2727
ProtocolVersion = protocolVersion;
28-
28+
2929
// Read null-terminated server version string
3030
ServerVersion = reader.TryReadNullTerminatedString(out string serverVersion) ? serverVersion : string.Empty;
31-
31+
3232
// Read connection ID (4 bytes)
3333
reader.TryReadLittleEndian(out int connectionId);
3434
ConnectionId = (uint)connectionId;
35-
35+
3636
// Read auth plugin data part 1 (8 bytes)
3737
AuthPluginDataPart1 = new byte[8];
3838
reader.TryCopyTo(AuthPluginDataPart1);
3939
reader.Advance(8);
40-
40+
4141
// Skip filler byte (1 byte)
4242
reader.Advance(1);
4343

4444
// Read capability flags lower (2 bytes)
4545
reader.TryReadLittleEndian(out short capabilityFlagsLower);
4646
CapabilityFlagsLower = (uint)(ushort)capabilityFlagsLower;
47-
47+
4848
// Check if more data is available (for MySQL 4.1+)
4949
if (reader.Remaining > 0)
5050
{
5151
// Read character set (1 byte)
5252
reader.TryRead(out byte characterSet);
5353
CharacterSet = characterSet;
54-
54+
5555
// Read status flags (2 bytes)
5656
reader.TryReadLittleEndian(out short statusFlags);
5757
StatusFlags = (ushort)statusFlags;
58-
58+
5959
// Read capability flags upper (2 bytes)
6060
reader.TryReadLittleEndian(out short capabilityFlagsUpper);
6161
CapabilityFlagsUpper = (uint)(ushort)capabilityFlagsUpper;
62-
62+
6363
// Read auth plugin data length (1 byte)
6464
reader.TryRead(out byte authPluginDataLength);
6565
AuthPluginDataLength = authPluginDataLength;
66-
66+
6767
// Skip reserved bytes (10 bytes)
6868
reader.Advance(10);
69-
69+
7070
// Read auth plugin data part 2 if present
71-
if ((CapabilityFlags & 0x00008000) != 0) // CLIENT_SECURE_CONNECTION
71+
if (AuthPluginDataLength > 8)
7272
{
7373
var part2Length = Math.Max(13, AuthPluginDataLength - 8);
7474
AuthPluginDataPart2 = new byte[part2Length];
7575
reader.TryCopyTo(AuthPluginDataPart2);
7676
reader.Advance(part2Length);
7777
}
78-
78+
7979
// Read auth plugin name if present
8080
if ((CapabilityFlags & 0x00080000) != 0) // CLIENT_PLUGIN_AUTH
8181
{
Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Buffers;
3+
using System.Buffers.Binary;
34
using System.Text;
45

56
namespace SuperSocket.MySQL.Packets
@@ -13,26 +14,27 @@ public class HandshakeResponsePacket : MySQLPacket
1314
public byte[] AuthResponse { get; set; }
1415
public string Database { get; set; }
1516
public string AuthPluginName { get; set; }
17+
1618
protected internal override void Decode(ref SequenceReader<byte> reader, object context)
1719
{
1820
// Read capability flags (4 bytes)
1921
reader.TryReadLittleEndian(out int capabilityFlags);
2022
CapabilityFlags = (uint)capabilityFlags;
21-
23+
2224
// Read max packet size (4 bytes)
2325
reader.TryReadLittleEndian(out int maxPacketSize);
2426
MaxPacketSize = (uint)maxPacketSize;
25-
27+
2628
// Read character set (1 byte)
2729
reader.TryRead(out byte characterSet);
2830
CharacterSet = characterSet;
29-
31+
3032
// Skip reserved bytes (23 bytes)
3133
reader.Advance(23);
32-
34+
3335
// Read null-terminated username
3436
Username = reader.TryReadNullTerminatedString(out string username) ? username : string.Empty;
35-
37+
3638
// Read auth response length and data
3739
if ((CapabilityFlags & 0x00200000) != 0) // CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
3840
{
@@ -52,13 +54,13 @@ protected internal override void Decode(ref SequenceReader<byte> reader, object
5254
{
5355
AuthResponse = reader.TryReadNullTerminatedString(out string authResponseString) ? Encoding.UTF8.GetBytes(authResponseString) : Array.Empty<byte>();
5456
}
55-
57+
5658
// Read database name if present
5759
if ((CapabilityFlags & 0x00000008) != 0) // CLIENT_CONNECT_WITH_DB
5860
{
5961
Database = reader.TryReadNullTerminatedString(out string database) ? database : string.Empty;
6062
}
61-
63+
6264
// Read auth plugin name if present
6365
if ((CapabilityFlags & 0x00080000) != 0) // CLIENT_PLUGIN_AUTH
6466
{
@@ -68,90 +70,80 @@ protected internal override void Decode(ref SequenceReader<byte> reader, object
6870

6971
protected internal override int Encode(IBufferWriter<byte> writer)
7072
{
71-
var contentWriter = new ArrayBufferWriter<byte>();
72-
7373
var bytesWritten = 0;
74-
var span = contentWriter.GetSpan(4);
75-
74+
var span = writer.GetSpan(4);
75+
76+
BinaryPrimitives.WriteUInt32LittleEndian(span, CapabilityFlags);
77+
7678
// Write capability flags (4 bytes)
77-
BitConverter.TryWriteBytes(span, CapabilityFlags);
78-
contentWriter.Advance(4);
79+
//BitConverter.TryWriteBytes(span, CapabilityFlags);
80+
writer.Advance(4);
7981
bytesWritten += 4;
8082

8183
// Write max packet size (4 bytes)
82-
span = contentWriter.GetSpan(4);
84+
span = writer.GetSpan(4);
8385
BitConverter.TryWriteBytes(span, MaxPacketSize);
84-
contentWriter.Advance(4);
86+
writer.Advance(4);
8587
bytesWritten += 4;
8688

8789
// Write character set (1 byte)
88-
span = contentWriter.GetSpan(1);
90+
span = writer.GetSpan(1);
8991
span[0] = CharacterSet;
90-
contentWriter.Advance(1);
92+
writer.Advance(1);
9193
bytesWritten += 1;
9294

9395
// Write reserved bytes (23 bytes of zeros)
94-
span = contentWriter.GetSpan(23);
96+
span = writer.GetSpan(23);
9597
span.Slice(0, 23).Clear();
96-
contentWriter.Advance(23);
98+
writer.Advance(23);
9799
bytesWritten += 23;
98100

99-
bytesWritten += contentWriter.WriteNullTerminatedString(Username);
101+
bytesWritten += writer.WriteNullTerminatedString(Username);
100102

101103
// Write auth response
102104
if (AuthResponse != null)
103105
{
104106
if ((CapabilityFlags & (uint)ClientCapabilities.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA) != 0) // CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA
105107
{
106-
bytesWritten += contentWriter.WriteLengthEncodedInteger((ulong)AuthResponse.Length);
107-
contentWriter.Write(AuthResponse);
108+
bytesWritten += writer.WriteLengthEncodedInteger((ulong)AuthResponse.Length);
109+
writer.Write(AuthResponse);
108110
bytesWritten += AuthResponse.Length;
109111
}
110112
else if ((CapabilityFlags & (uint)ClientCapabilities.CLIENT_SECURE_CONNECTION) != 0) // CLIENT_SECURE_CONNECTION
111113
{
112-
span = contentWriter.GetSpan(1);
114+
span = writer.GetSpan(1);
113115
span[0] = (byte)AuthResponse.Length;
114-
contentWriter.Advance(1);
116+
writer.Advance(1);
115117
bytesWritten += 1;
116118

117-
contentWriter.Write(AuthResponse);
119+
writer.Write(AuthResponse);
118120
bytesWritten += AuthResponse.Length;
119121
}
120122
else
121123
{
122-
contentWriter.Write(AuthResponse);
124+
writer.Write(AuthResponse);
123125
bytesWritten += AuthResponse.Length;
124126

125-
span = contentWriter.GetSpan(1);
127+
span = writer.GetSpan(1);
126128
span[0] = 0; // null terminator
127-
contentWriter.Advance(1);
129+
writer.Advance(1);
128130
bytesWritten += 1;
129131
}
130132
}
131133

132134
// Write database name if present
133135
if ((CapabilityFlags & 0x00000008) != 0 && !string.IsNullOrEmpty(Database))
134136
{
135-
bytesWritten += contentWriter.WriteNullTerminatedString(Database);
137+
bytesWritten += writer.WriteNullTerminatedString(Database);
136138
}
137139

138140
// Write auth plugin name if present
139141
if ((CapabilityFlags & 0x00080000) != 0 && !string.IsNullOrEmpty(AuthPluginName))
140142
{
141-
bytesWritten += contentWriter.WriteNullTerminatedString(AuthPluginName);
143+
bytesWritten += writer.WriteNullTerminatedString(AuthPluginName);
142144
}
143145

144-
var headerSpan = writer.GetSpan(4);
145-
headerSpan[0] = (byte)(bytesWritten & 0xFF);
146-
headerSpan[1] = (byte)((bytesWritten >> 8) & 0xFF);
147-
headerSpan[2] = (byte)((bytesWritten >> 16) & 0xFF);
148-
headerSpan[3] = (byte)SequenceId; // Sequence ID, typically starts at 0 for the first packet
149-
//headerSpan[4] = 0x00; // Set header byte
150-
writer.Advance(4);
151-
152-
writer.Write(contentWriter.WrittenSpan);
153-
154-
return bytesWritten + 4;
146+
return bytesWritten;
155147
}
156148
}
157149
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
namespace SuperSocket.MySQL.Packets
2+
{
3+
internal interface IPacketWithHeaderByte
4+
{
5+
byte Header { get; set; }
6+
}
7+
}

0 commit comments

Comments
 (0)