diff --git a/src/java/org/apache/cassandra/net/OutboundConnection.java b/src/java/org/apache/cassandra/net/OutboundConnection.java index 4aa754d8aa31..aa63441acf95 100644 --- a/src/java/org/apache/cassandra/net/OutboundConnection.java +++ b/src/java/org/apache/cassandra/net/OutboundConnection.java @@ -1127,42 +1127,39 @@ void onCompletedHandshake(Result result) // it is expected that close, if successful, has already cancelled us; so we do not need to worry about leaking connections assert !state.isClosed(); - MessagingSuccess success = result.success(); - debug.onConnect(success.messagingVersion, settings); - state.disconnected().maintenance.cancel(false); - - FrameEncoder.PayloadAllocator payloadAllocator = success.allocator; - Channel channel = success.channel; - Established established = new Established(success.messagingVersion, channel, payloadAllocator, settings); - state = established; - channel.pipeline().addLast("handleExceptionalStates", new ChannelInboundHandlerAdapter() { - @Override - public void channelInactive(ChannelHandlerContext ctx) - { - disconnectNow(established); - ctx.fireChannelInactive(); - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) - { - try - { - invalidateChannel(established, cause); + if (result.success() != null) { + MessagingSuccess success = (MessagingSuccess) result.success(); + debug.onConnect(success.messagingVersion, settings); + state.disconnected().maintenance.cancel(false); + + FrameEncoder.PayloadAllocator payloadAllocator = success.allocator; + Channel channel = success.channel; + Established established = new Established(success.messagingVersion, channel, payloadAllocator, settings); + state = established; + channel.pipeline().addLast("handleExceptionalStates", new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) { + disconnectNow(established); + ctx.fireChannelInactive(); } - catch (Throwable t) - { - logger.error("Unexpected exception in {}.exceptionCaught", this.getClass().getSimpleName(), t); + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + try { + invalidateChannel(established, cause); + } catch (Throwable t) { + logger.error("Unexpected exception in {}.exceptionCaught", this.getClass().getSimpleName(), t); + } } - } - }); - ++successfulConnections; + }); + ++successfulConnections; - logger.info("{} successfully connected, version = {}, framing = {}, encryption = {}", + logger.info("{} successfully connected, version = {}, framing = {}, encryption = {}", id(true), success.messagingVersion, settings.framing, encryptionConnectionSummary(channel)); + } break; case RETRY: diff --git a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java index 2bddc174f2e8..218bd17f42d8 100644 --- a/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java +++ b/src/java/org/apache/cassandra/net/OutboundConnectionInitiator.java @@ -525,14 +525,27 @@ private Result(Outcome outcome) } boolean isSuccess() { return outcome == Outcome.SUCCESS; } - public SuccessType success() { return (SuccessType) this; } + public Success success() { + if (this.outcome == outcome.SUCCESS) + return (Success) this; + return null; + } static MessagingSuccess messagingSuccess(Channel channel, int messagingVersion, FrameEncoder.PayloadAllocator allocator) { return new MessagingSuccess(channel, messagingVersion, allocator); } static StreamingSuccess streamingSuccess(Channel channel, int messagingVersion) { return new StreamingSuccess(channel, messagingVersion); } - public Retry retry() { return (Retry) this; } + public Retry retry() { + if (this.outcome == outcome.RETRY) + return (Retry) this; + return null; + } static Result retry(int withMessagingVersion) { return new Retry<>(withMessagingVersion); } - public Incompatible incompatible() { return (Incompatible) this; } + public Incompatible incompatible() + { + if (this.outcome == outcome.INCOMPATIBLE) + return (Incompatible) this; + return null; + } static Result incompatible(int closestSupportedVersion, int maxMessagingVersion) { return new Incompatible(closestSupportedVersion, maxMessagingVersion); } } diff --git a/test/unit/org/apache/cassandra/net/StreamingTest.java b/test/unit/org/apache/cassandra/net/StreamingTest.java new file mode 100644 index 000000000000..e2d2a8527ae4 --- /dev/null +++ b/test/unit/org/apache/cassandra/net/StreamingTest.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.net; + +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; + +import com.google.common.net.InetAddresses; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import io.netty.channel.EventLoop; +import io.netty.util.concurrent.Future; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions; +import org.apache.cassandra.config.EncryptionOptions.ServerEncryptionOptions.Builder; +import org.apache.cassandra.config.ParameterizedClass; +import org.apache.cassandra.db.commitlog.CommitLog; +import org.apache.cassandra.gms.GossipDigestSyn; +import org.apache.cassandra.locator.InetAddressAndPort; +import org.apache.cassandra.security.DefaultSslContextFactory; +import org.apache.cassandra.transport.TlsTestUtils; + +import static org.apache.cassandra.net.OutboundConnectionInitiator.Result; +import static org.apache.cassandra.net.OutboundConnectionInitiator.SslFallbackConnectionType; +import static org.apache.cassandra.net.OutboundConnectionInitiator.initiateStreaming; +import static org.apache.cassandra.net.MessagingService.current_version; +import static org.apache.cassandra.net.MessagingService.minimum_version; +import static org.apache.cassandra.config.EncryptionOptions.ClientEncryptionOptions.ClientAuth.NOT_REQUIRED; +import static org.apache.cassandra.config.EncryptionOptions.ClientEncryptionOptions.ClientAuth.REQUIRED; +import static org.apache.cassandra.tcm.ClusterMetadata.EMPTY_METADATA_IDENTIFIER; +public class StreamingTest +{ + private static final SocketFactory factory = new SocketFactory(); + static final InetAddressAndPort TO_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.2"), 7012); + static final InetAddressAndPort FROM_ADDR = InetAddressAndPort.getByAddressOverrideDefaults(InetAddresses.forString("127.0.0.1"), 7012); + private volatile Throwable handshakeEx; + @BeforeClass + public static void startup() + { + DatabaseDescriptor.daemonInitialization(); + CommitLog.instance.start(); + } + + @AfterClass + public static void cleanup() throws InterruptedException + { + factory.shutdownNow(); + } + + @Before + public void setup() + { + handshakeEx = null; + } + + private Result streamingConnect(AcceptVersions acceptOutbound, AcceptVersions acceptInbound) throws ExecutionException, InterruptedException + { + InboundSockets inbound = new InboundSockets(new InboundConnectionSettings().withAcceptMessaging(acceptInbound)); + try + { + inbound.open(); + InetAddressAndPort endpoint = inbound.sockets().stream().map(s -> s.settings.bindAddress).findFirst().get(); + EventLoop eventLoop = factory.defaultGroup().next(); + Future> result = initiateStreaming(eventLoop, + new OutboundConnectionSettings(endpoint) + .withAcceptVersions(acceptOutbound) + .withDefaults(ConnectionCategory.STREAMING), + SslFallbackConnectionType.SERVER_CONFIG + ); + result.awaitUninterruptibly(); + Assert.assertTrue(result.isSuccess()); + + return result.getNow(); + } + finally + { + inbound.close().await(1L, TimeUnit.SECONDS); + } + } + + @Test + public void testIncompatibleVersion() throws InterruptedException, ExecutionException + { + Result nowResult = streamingConnect(new AcceptVersions(current_version + 1, current_version + 1), new AcceptVersions(minimum_version + 2, current_version + 3)); + Assert.assertNull(nowResult.success()); + Assert.assertEquals(Result.Outcome.INCOMPATIBLE, nowResult.outcome); + Assert.assertEquals(current_version, nowResult.incompatible().closestSupportedVersion); + Assert.assertEquals(current_version, nowResult.incompatible().maxMessagingVersion); + } + + @Test + public void testCompatibleVersion() throws InterruptedException, ExecutionException + { + Result nowResult = streamingConnect(new AcceptVersions(MessagingService.minimum_version, current_version + 1), new AcceptVersions(minimum_version + 2, current_version + 3)); + Assert.assertNotNull(nowResult.success()); + Assert.assertNotNull(nowResult.success().channel); + Assert.assertEquals(Result.Outcome.SUCCESS, nowResult.outcome); + Assert.assertEquals(current_version, nowResult.success().messagingVersion); + } + + private ServerEncryptionOptions getServerEncryptionOptions(SslFallbackConnectionType sslConnectionType, boolean optional) + { + Builder serverEncryptionOptionsBuilder = new Builder(); + + serverEncryptionOptionsBuilder.withOutboundKeystore(TlsTestUtils.SERVER_OUTBOUND_KEYSTORE_PATH) + .withOutboundKeystorePassword(TlsTestUtils.SERVER_OUTBOUND_KEYSTORE_PASSWORD) + .withOptional(optional) + .withKeyStore(TlsTestUtils.SERVER_KEYSTORE_PATH) + .withKeyStorePassword(TlsTestUtils.SERVER_KEYSTORE_PASSWORD) + .withTrustStore(TlsTestUtils.SERVER_TRUSTSTORE_PATH).withTrustStorePassword(TlsTestUtils.SERVER_TRUSTSTORE_PASSWORD) + .withSslContextFactory((new ParameterizedClass(DefaultSslContextFactory.class.getName(), + new HashMap<>()))); + + if (sslConnectionType == SslFallbackConnectionType.MTLS) + { + serverEncryptionOptionsBuilder.withInternodeEncryption(ServerEncryptionOptions.InternodeEncryption.all) + .withRequireClientAuth(REQUIRED); + } + else if (sslConnectionType == SslFallbackConnectionType.SSL) + { + serverEncryptionOptionsBuilder.withInternodeEncryption(ServerEncryptionOptions.InternodeEncryption.all) + .withRequireClientAuth(NOT_REQUIRED); + } + return serverEncryptionOptionsBuilder.build(); + } + + private OutboundConnection initiateOutbound(InetAddressAndPort endpoint, SslFallbackConnectionType connectionType, boolean optional) throws ClosedChannelException + { + final OutboundConnectionSettings settings = new OutboundConnectionSettings(endpoint) + .withAcceptVersions(new AcceptVersions(minimum_version, current_version)) + .withDefaults(ConnectionCategory.MESSAGING) + .withEncryption(getServerEncryptionOptions(connectionType, optional)) + .withDebugCallbacks(new HandshakeAcknowledgeChecker(t -> handshakeEx = t)) + .withFrom(FROM_ADDR); + OutboundConnections outboundConnections = OutboundConnections.tryRegister(new ConcurrentHashMap<>(), TO_ADDR, settings); + GossipDigestSyn syn = new GossipDigestSyn("cluster", "partitioner", EMPTY_METADATA_IDENTIFIER, new ArrayList<>(0)); + Message message = Message.out(Verb.GOSSIP_DIGEST_SYN, syn); + OutboundConnection outboundConnection = outboundConnections.connectionFor(message); + outboundConnection.enqueue(message); + return outboundConnection; + } + private static class HandshakeAcknowledgeChecker implements OutboundDebugCallbacks + { + private final AtomicInteger acks = new AtomicInteger(0); + private final Consumer fail; + + private HandshakeAcknowledgeChecker(Consumer fail) + { + this.fail = fail; + } + + @Override + public void onSendSmallFrame(int messageCount, int payloadSizeInBytes) + { + } + + @Override + public void onSentSmallFrame(int messageCount, int payloadSizeInBytes) + { + } + + @Override + public void onFailedSmallFrame(int messageCount, int payloadSizeInBytes) + { + } + + @Override + public void onConnect(int messagingVersion, OutboundConnectionSettings settings) + { + if (acks.incrementAndGet() > 1) + fail.accept(new AssertionError("Handshake was acknowledged more than once")); + } + } +} \ No newline at end of file