Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
/// Error thrown when `PostgresCopyFromWriter.write` is called while the backend is not expecting `CopyData` messages.
///
/// This should only happen if the user explicitly escapes the `PostgresCopyFromWriter` from the `copyFrom` closure and
/// tries to write data to it when tries to write data after the `copyFrom` call has terminated.
struct NotInCopyFromModeError: Error, CustomStringConvertible {
var description: String {
"Writing COPY FROM data when backend is not in COPY mode"
}
}

/// Handle to send data for a `COPY ... FROM STDIN` query to the backend.
public struct PostgresCopyFromWriter: Sendable {
private let channelHandler: NIOLoopBound<PostgresChannelHandler>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ struct ConnectionStateMachine {
/// should be aborted to avoid unnecessary work.
mutating func checkBackendCanReceiveCopyData(channelIsWritable: Bool, promise: EventLoopPromise<Void>) -> CheckBackendCanReceiveCopyDataAction {
guard case .extendedQuery(var queryState, let connectionContext) = self.state else {
preconditionFailure("Copy mode is only supported for extended queries")
return .failPromise(promise, error: NotInCopyFromModeError())
}

self.state = .modifying // avoid CoW
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ struct ExtendedQueryStateMachine {
return .failPromise(promise, error: error)
}
guard case .copyingData(.readyToSend) = self.state else {
preconditionFailure("Not ready to send data")
return .failPromise(promise, error: NotInCopyFromModeError())
}
if channelIsWritable {
return .succeedPromise(promise)
Expand Down
30 changes: 30 additions & 0 deletions Tests/PostgresNIOTests/New/PostgresConnectionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,36 @@ import Synchronization
}
}

@Test func testWritingDataAfterCopyFromHasFinishedThrowsError() async throws {
try await self.withAsyncTestingChannel { connection, channel in
try await withThrowingTaskGroup(of: Void.self) { taskGroup async throws -> () in
taskGroup.addTask {
var escapedWriter: PostgresCopyFromWriter?
try await connection.copyFrom(table: "test", logger: .psqlTest) { writer in
escapedWriter = writer
}
let writer = try #require(escapedWriter)
await #expect(throws: (any Error).self) {
try await writer.write(ByteBuffer(string: "oops"))
}
}

_ = try await channel.waitForUnpreparedRequest()

try await channel.sendUnpreparedRequestWithNoParametersBindResponse()
try await channel.writeInbound(PostgresBackendMessage.copyInResponse(.init(format: .textual, columnFormats: Array(repeating: .textual, count: 2))))

_ = try await channel.waitForCopyData()
try await channel.writeInbound(PostgresBackendMessage.commandComplete("COPY 0"))

try await channel.waitForPostgresFrontendMessage(\.sync)
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))

try await taskGroup.waitForAll()
}
}
}

func withAsyncTestingChannel(_ body: (PostgresConnection, NIOAsyncTestingChannel) async throws -> ()) async throws {
let eventLoop = NIOAsyncTestingEventLoop()
let channel = try await NIOAsyncTestingChannel(loop: eventLoop) { channel in
Expand Down
Loading