Skip to content

Commit 0754516

Browse files
committed
tried to impelement query result packet decoding correctly
1 parent 4d6cf66 commit 0754516

File tree

12 files changed

+308
-203
lines changed

12 files changed

+308
-203
lines changed

src/SuperSocket.MySQL/MySQLConnection.cs

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System;
22
using System.Buffers;
33
using System.Collections.Generic;
4+
using System.Linq;
45
using System.Net;
56
using System.Net.Sockets;
67
using System.Security.Cryptography;
@@ -26,15 +27,23 @@ public class MySQLConnection : EasyClient<MySQLPacket>
2627

2728
public bool IsAuthenticated { get; private set; }
2829

30+
private readonly MySQLFilterContext filterContext;
31+
2932
public MySQLConnection(string host, int port, string userName, string password, ILogger logger = null)
30-
: base(new MySQLPacketFilter(MySQLPacketDecoder.ClientInstance), logger)
33+
: this(new MySQLPacketFilter(MySQLPacketDecoder.ClientInstance), logger)
3134
{
3235
_host = host ?? throw new ArgumentNullException(nameof(host));
3336
_port = port > 0 ? port : DefaultPort;
3437
_userName = userName ?? throw new ArgumentNullException(nameof(userName));
3538
_password = password ?? throw new ArgumentNullException(nameof(password));
3639
}
3740

41+
private MySQLConnection(MySQLPacketFilter packetFilter, ILogger logger)
42+
: base(packetFilter, logger)
43+
{
44+
filterContext = packetFilter.Context as MySQLFilterContext;
45+
}
46+
3847
public async Task ConnectAsync(CancellationToken cancellationToken = default)
3948
{
4049
if (string.IsNullOrEmpty(_host))
@@ -83,6 +92,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
8392
case OKPacket okPacket:
8493
// Authentication successful
8594
IsAuthenticated = true;
95+
filterContext.State = MySQLConnectionState.Authenticated;
8696
break;
8797
case ErrorPacket errorPacket:
8898
// Authentication failed
@@ -95,6 +105,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken = default)
95105
if ((eofPacket.StatusFlags & 0x0002) != 0)
96106
{
97107
IsAuthenticated = true;
108+
filterContext.State = MySQLConnectionState.Authenticated;
98109
break;
99110
}
100111
else
@@ -164,12 +175,27 @@ public async Task<QueryResultPacket> ExecuteQueryAsync(string query, Cancellatio
164175
SequenceId = 0
165176
};
166177

178+
filterContext.State = MySQLConnectionState.CommandPhase;
167179
await SendAsync(PacketEncoder, commandPacket).ConfigureAwait(false);
168180

169181
// Read response
170182
var response = await ReceiveAsync().ConfigureAwait(false);
171183

172-
return (QueryResultPacket)response;
184+
// Handle different response types
185+
switch (response)
186+
{
187+
case ErrorPacket errorPacket:
188+
// Query failed
189+
return QueryResultPacket.FromError((short)errorPacket.ErrorCode, errorPacket.ErrorMessage);
190+
191+
case QueryResultPacket queryResult:
192+
// Already a query result packet
193+
return queryResult;
194+
195+
default:
196+
// Handle result set responses (SELECT queries)
197+
return await ReadResultSetAsync(response).ConfigureAwait(false);
198+
}
173199
}
174200
catch (Exception ex)
175201
{
@@ -186,22 +212,17 @@ public async Task<QueryResultPacket> ExecuteQueryAsync(string query, Cancellatio
186212
public async Task<string> ExecuteQueryStringAsync(string query, CancellationToken cancellationToken = default)
187213
{
188214
var result = await ExecuteQueryAsync(query, cancellationToken).ConfigureAwait(false);
189-
215+
190216
if (!result.IsSuccess)
191217
{
192218
return $"Error {result.ErrorCode}: {result.ErrorMessage}";
193219
}
194220

195-
if (result.Columns == null || result.Columns.Count == 0)
196-
{
197-
return $"Query executed successfully. {result.AffectedRows} rows affected.";
198-
}
199-
200221
var sb = new StringBuilder();
201-
222+
202223
// Add column headers
203224
sb.AppendLine(string.Join("\t", result.Columns));
204-
225+
205226
// Add separator line
206227
sb.AppendLine(new string('-', result.Columns.Count * 10));
207228

@@ -213,9 +234,9 @@ public async Task<string> ExecuteQueryStringAsync(string query, CancellationToke
213234
sb.AppendLine(string.Join("\t", row ?? new string[result.Columns.Count]));
214235
}
215236
}
216-
237+
217238
sb.AppendLine($"\n{result.Rows?.Count ?? 0} rows returned.");
218-
239+
219240
return sb.ToString();
220241
}
221242

@@ -235,5 +256,50 @@ public async Task DisconnectAsync()
235256
IsAuthenticated = false;
236257
}
237258
}
259+
260+
/// <summary>
261+
/// Reads a complete result set from the MySQL server
262+
/// </summary>
263+
/// <param name="firstPacket">The first packet received after sending the query</param>
264+
/// <returns>A QueryResultPacket containing the complete result set</returns>
265+
private Task<QueryResultPacket> ReadResultSetAsync(MySQLPacket firstPacket)
266+
{
267+
try
268+
{
269+
// If the first packet is already a QueryResultPacket (decoded by UnknownPacket), return it
270+
if (firstPacket is QueryResultPacket queryResult)
271+
{
272+
return Task.FromResult(queryResult);
273+
}
274+
275+
// If the first packet is an UnknownPacket, it should have been decoded to QueryResultPacket
276+
// but if that failed, we'll create a minimal fallback
277+
if (firstPacket is UnknownPacket)
278+
{
279+
// Try to read additional packets to build a result set
280+
// This is a simplified implementation that attempts to handle basic SELECT queries
281+
282+
var columns = new List<ColumnDefinitionPacket>();
283+
var rows = new List<IReadOnlyList<string>>();
284+
285+
// For now, create a minimal successful result
286+
// This could be enhanced to parse more complex result sets in the future
287+
return Task.FromResult(QueryResultPacket.FromResultSet(columns.AsReadOnly(), rows.AsReadOnly()));
288+
}
289+
290+
// For any other packet type, treat as an unexpected response
291+
return Task.FromResult(QueryResultPacket.FromError(-1, $"Unexpected packet type in result set: {firstPacket?.GetType().Name ?? "null"}"));
292+
}
293+
catch (Exception ex)
294+
{
295+
return Task.FromResult(QueryResultPacket.FromError(-1, $"Failed to read result set: {ex.Message}"));
296+
}
297+
}
298+
299+
protected override void OnClosed(object sender, EventArgs e)
300+
{
301+
filterContext.State = MySQLConnectionState.Closed;
302+
base.OnClosed(sender, e);
303+
}
238304
}
239305
}
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using SuperSocket.MySQL.Packets;
4+
5+
namespace SuperSocket.MySQL
6+
{
7+
internal class MySQLFilterContext
8+
{
9+
public int SequenceId { get; set; }
10+
11+
public bool IsHandshakeCompleted { get; set; }
12+
13+
public MySQLConnectionState State { get; set; }
14+
15+
public MySQLPacket NextPacket { get; set; }
16+
17+
public int QueryResultColumnCount { get; set; }
18+
19+
public List<ColumnDefinitionPacket> ColumnDefinitionPackets { get; set; }
20+
21+
public MySQLFilterContext()
22+
{
23+
SequenceId = 0;
24+
IsHandshakeCompleted = false;
25+
State = MySQLConnectionState.Initial;
26+
}
27+
28+
public void Reset()
29+
{
30+
SequenceId = 0;
31+
IsHandshakeCompleted = false;
32+
State = MySQLConnectionState.Initial;
33+
}
34+
35+
public void IncrementSequenceId()
36+
{
37+
SequenceId = (SequenceId + 1) % 256;
38+
}
39+
}
40+
41+
public enum MySQLConnectionState
42+
{
43+
Initial,
44+
HandshakeInitiated,
45+
Authenticated,
46+
CommandPhase,
47+
Closed
48+
}
49+
}

src/SuperSocket.MySQL/MySQLPacketDecoder.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,14 @@ public MySQLPacket Decode(ref ReadOnlySequence<byte> buffer, object context)
3030
reader.Advance(3); // Skip the first 3 bytes of the header
3131
reader.TryRead(out var sequenceId); // Read the sequence ID
3232

33-
var filter = context as MySQLPacketFilter;
33+
var filterContext = context as MySQLFilterContext;
3434

3535
var packetType = -1;
3636

3737
// Read the first byte to determine packet type
38-
if (filter.ReceivedHandshake)
38+
if (filterContext.State != MySQLConnectionState.Initial)
3939
{
40+
// In handshake state, we expect the first byte to be the packet type
4041
if (!reader.TryRead(out var packetTypeByte))
4142
return null;
4243

@@ -48,8 +49,10 @@ public MySQLPacket Decode(ref ReadOnlySequence<byte> buffer, object context)
4849
package = package.Decode(ref reader, context);
4950
package.SequenceId = sequenceId;
5051

51-
if (!filter.ReceivedHandshake)
52-
filter.ReceivedHandshake = true;
52+
if (filterContext.State == MySQLConnectionState.Initial)
53+
{
54+
filterContext.State = MySQLConnectionState.HandshakeInitiated;
55+
}
5356

5457
return package;
5558
}

src/SuperSocket.MySQL/MySQLPacketFilter.cs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@ internal class MySQLPacketFilter : FixedHeaderPipelineFilter<MySQLPacket>
88
{
99
private const int headerSize = 4; // MySQL package header size is 4 bytes
1010

11-
internal bool ReceivedHandshake { get; set; }
12-
1311
public MySQLPacketFilter(IPackageDecoder<MySQLPacket> decoder)
1412
: base(headerSize)
1513
{
1614
this.Decoder = decoder ?? throw new ArgumentNullException(nameof(decoder));
17-
this.Context = this;
15+
this.Context = new MySQLFilterContext();
1816
}
1917

2018
protected override int GetBodyLengthFromHeader(ref ReadOnlySequence<byte> buffer)
@@ -27,5 +25,19 @@ protected override int GetBodyLengthFromHeader(ref ReadOnlySequence<byte> buffer
2725

2826
return byte2 * 256 * 256 + byte1 * 256 + byte0;
2927
}
28+
29+
public override MySQLPacket Filter(ref SequenceReader<byte> reader)
30+
{
31+
var packet = base.Filter(ref reader);
32+
33+
if (packet == null || packet.IsPartialPacket)
34+
{
35+
// If the packet is null or a partial packet, we cannot return it yet
36+
// We will wait for more data to complete the packet
37+
return null;
38+
}
39+
40+
return packet;
41+
}
3042
}
3143
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
using System;
2+
using System.Buffers;
3+
using System.Collections.Generic;
4+
using SuperSocket.ProtoBase;
5+
6+
namespace SuperSocket.MySQL.Packets
7+
{
8+
public class ColumnCountPacket : MySQLPacket
9+
{
10+
public ulong ColumnCount { get; set; }
11+
12+
protected internal override MySQLPacket Decode(ref SequenceReader<byte> reader, object context)
13+
{
14+
ColumnCount = reader.ReadLengthEncodedInteger();
15+
16+
var filterContext = context as MySQLFilterContext;
17+
filterContext.NextPacket = new ColumnDefinitionPacket();
18+
filterContext.QueryResultColumnCount = (int)ColumnCount;
19+
filterContext.ColumnDefinitionPackets = new List<ColumnDefinitionPacket>(filterContext.QueryResultColumnCount);
20+
21+
return this;
22+
}
23+
24+
protected internal override int Encode(IBufferWriter<byte> writer)
25+
{
26+
throw new NotImplementedException();
27+
}
28+
29+
internal override bool IsPartialPacket => true;
30+
}
31+
}

src/SuperSocket.MySQL/Packets/ColumnDefinitionPacket.cs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ protected internal override MySQLPacket Decode(ref SequenceReader<byte> reader,
3232
// Read catalog (length-encoded string)
3333
if (!reader.TryReadLengthEncodedString(out string catalog))
3434
throw new InvalidOperationException("Failed to read catalog");
35+
3536
Catalog = catalog;
3637

3738
// Read schema (length-encoded string)
@@ -90,6 +91,18 @@ protected internal override MySQLPacket Decode(ref SequenceReader<byte> reader,
9091
// Skip the two null bytes that follow
9192
reader.Advance(2);
9293

94+
var filterContext = context as MySQLFilterContext;
95+
filterContext.ColumnDefinitionPackets.Add(this);
96+
97+
if (filterContext.QueryResultColumnCount > filterContext.ColumnDefinitionPackets.Count)
98+
{
99+
filterContext.NextPacket = new ColumnDefinitionPacket();
100+
}
101+
else
102+
{
103+
filterContext.NextPacket = new ResultRowsPacket();
104+
}
105+
93106
return this;
94107
}
95108

@@ -156,5 +169,7 @@ protected internal override int Encode(IBufferWriter<byte> writer)
156169

157170
return bytesWritten;
158171
}
172+
173+
internal override bool IsPartialPacket => true;
159174
}
160175
}

src/SuperSocket.MySQL/Packets/MySQLPacket.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,7 @@ public abstract class MySQLPacket
1515
protected internal abstract MySQLPacket Decode(ref SequenceReader<byte> reader, object context);
1616

1717
protected internal abstract int Encode(IBufferWriter<byte> writer);
18+
19+
internal virtual bool IsPartialPacket => false;
1820
}
1921
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
using System;
2+
using System.Buffers;
3+
4+
namespace SuperSocket.MySQL.Packets
5+
{
6+
public class NonQueryResultPacket : MySQLPacket
7+
{
8+
public long AffectedRows { get; private set; }
9+
public long LastInsertId { get; private set; }
10+
11+
protected internal override MySQLPacket Decode(ref SequenceReader<byte> reader, object context)
12+
{
13+
// Read affected rows and last insert ID for non-SELECT queries
14+
if (reader.TryReadLengthEncodedInteger(out long affectedRows))
15+
{
16+
AffectedRows = affectedRows;
17+
}
18+
19+
if (reader.TryReadLengthEncodedInteger(out long lastInsertId))
20+
{
21+
LastInsertId = lastInsertId;
22+
}
23+
24+
return this;
25+
}
26+
27+
protected internal override int Encode(IBufferWriter<byte> writer)
28+
{
29+
throw new NotImplementedException();
30+
}
31+
}
32+
}

0 commit comments

Comments
 (0)