diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java index 4cea712833e2..ca382dcbb114 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DatagramDnsQueryContext.java @@ -15,6 +15,7 @@ */ package io.netty.resolver.dns; +import io.netty.bootstrap.Bootstrap; import io.netty.channel.AddressedEnvelope; import io.netty.channel.Channel; import io.netty.handler.codec.dns.DatagramDnsQuery; @@ -33,10 +34,12 @@ final class DatagramDnsQueryContext extends DnsQueryContext { InetSocketAddress nameServerAddr, DnsQueryContextManager queryContextManager, int maxPayLoadSize, boolean recursionDesired, + long queryTimeoutMillis, DnsQuestion question, DnsRecord[] additionals, - Promise> promise) { + Promise> promise, + Bootstrap socketBootstrap, boolean retryWithTcpOnTimeout) { super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired, - question, additionals, promise); + queryTimeoutMillis, question, additionals, promise, socketBootstrap, retryWithTcpOnTimeout); } @Override diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java index 6fbf86fc29d9..d286bac27832 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java @@ -23,6 +23,8 @@ import io.netty.channel.ChannelFactory; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInitializer; @@ -43,8 +45,6 @@ import io.netty.handler.codec.dns.DnsRecord; import io.netty.handler.codec.dns.DnsRecordType; import io.netty.handler.codec.dns.DnsResponse; -import io.netty.handler.codec.dns.TcpDnsQueryEncoder; -import io.netty.handler.codec.dns.TcpDnsResponseDecoder; import io.netty.resolver.DefaultHostsFileEntriesResolver; import io.netty.resolver.HostsFileEntries; import io.netty.resolver.HostsFileEntriesResolver; @@ -121,6 +121,13 @@ public class DnsNameResolver extends InetNameResolver { private static final InternetProtocolFamily[] IPV6_PREFERRED_RESOLVED_PROTOCOL_FAMILIES = {InternetProtocolFamily.IPv6, InternetProtocolFamily.IPv4}; + private static final ChannelHandler NOOP_HANDLER = new ChannelHandlerAdapter() { + @Override + public boolean isSharable() { + return true; + } + }; + static final ResolvedAddressTypes DEFAULT_RESOLVE_ADDRESS_TYPES; static final String[] DEFAULT_SEARCH_DOMAINS; private static final UnixResolverOptions DEFAULT_OPTIONS; @@ -227,7 +234,6 @@ protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket p } }; private static final DatagramDnsQueryEncoder DATAGRAM_ENCODER = new DatagramDnsQueryEncoder(); - private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder(); private final Promise channelReadyPromise; private final Channel ch; @@ -273,6 +279,7 @@ protected DnsServerAddressStream initialValue() { private final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory; private final boolean completeOncePreferredResolved; private final Bootstrap socketBootstrap; + private final boolean retryWithTcpOnTimeout; private final int maxNumConsolidation; private final Map>> inflightLookups; @@ -376,44 +383,18 @@ public DnsNameResolver( String[] searchDomains, int ndots, boolean decodeIdn) { - this(eventLoop, channelFactory, null, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache, + this(eventLoop, channelFactory, null, false, resolveCache, + NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache, null, dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes, recursionDesired, maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, hostsFileEntriesResolver, - dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false); - } - - DnsNameResolver( - EventLoop eventLoop, - ChannelFactory channelFactory, - ChannelFactory socketChannelFactory, - final DnsCache resolveCache, - final DnsCnameCache cnameCache, - final AuthoritativeDnsServerCache authoritativeDnsServerCache, - DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory, - long queryTimeoutMillis, - ResolvedAddressTypes resolvedAddressTypes, - boolean recursionDesired, - int maxQueriesPerResolve, - boolean traceEnabled, - int maxPayloadSize, - boolean optResourceEnabled, - HostsFileEntriesResolver hostsFileEntriesResolver, - DnsServerAddressStreamProvider dnsServerAddressStreamProvider, - String[] searchDomains, - int ndots, - boolean decodeIdn, - boolean completeOncePreferredResolved) { - this(eventLoop, channelFactory, socketChannelFactory, resolveCache, cnameCache, authoritativeDnsServerCache, - null, dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes, - recursionDesired, maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, - hostsFileEntriesResolver, dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, - completeOncePreferredResolved, 0); + dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false, 0); } DnsNameResolver( EventLoop eventLoop, ChannelFactory channelFactory, ChannelFactory socketChannelFactory, + boolean retryWithTcpOnTimeout, final DnsCache resolveCache, final DnsCnameCache cnameCache, final AuthoritativeDnsServerCache authoritativeDnsServerCache, @@ -457,6 +438,7 @@ public DnsNameResolver( this.ndots = ndots >= 0 ? ndots : DEFAULT_OPTIONS.ndots(); this.decodeIdn = decodeIdn; this.completeOncePreferredResolved = completeOncePreferredResolved; + this.retryWithTcpOnTimeout = retryWithTcpOnTimeout; if (socketChannelFactory == null) { socketBootstrap = null; } else { @@ -465,7 +447,7 @@ public DnsNameResolver( .group(executor()) .channelFactory(socketChannelFactory) .attr(DNS_PIPELINE_ATTRIBUTE, Boolean.TRUE) - .handler(TCP_ENCODER); + .handler(NOOP_HANDLER); if (queryTimeoutMillis > 0 && queryTimeoutMillis <= Integer.MAX_VALUE) { // Set the connect timeout to the same as queryTimeout as otherwise it might take a long // time for the query to fail in case of a connection timeout. @@ -1354,8 +1336,9 @@ final Future> query0( final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0; try { DnsQueryContext queryContext = new DatagramDnsQueryContext(ch, channelReadyPromise, nameServerAddr, - queryContextManager, payloadSize, isRecursionDesired(), question, additionals, castPromise); - ChannelFuture future = queryContext.writeQuery(queryTimeoutMillis(), flush); + queryContextManager, payloadSize, isRecursionDesired(), queryTimeoutMillis(), question, additionals, + castPromise, socketBootstrap, retryWithTcpOnTimeout); + ChannelFuture future = queryContext.writeQuery(flush); queryLifecycleObserver.queryWritten(nameServerAddr, future); return castPromise; } catch (Exception e) { @@ -1400,94 +1383,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { return; } - // Check if the response was truncated and if we can fallback to TCP to retry. - if (!res.isTruncated() || socketBootstrap == null) { - qCtx.finishSuccess(res); - return; - } - - socketBootstrap.connect(res.sender()).addListener(new ChannelFutureListener() { - @Override - public void operationComplete(ChannelFuture future) { - if (!future.isSuccess()) { - logger.debug("{} Unable to fallback to TCP [{}: {}]", - ch, queryId, res.sender(), future.cause()); - - // TCP fallback failed, just use the truncated response. - qCtx.finishSuccess(res); - return; - } - final Channel tcpCh = future.channel(); - - Promise> promise = - tcpCh.eventLoop().newPromise(); - final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0; - final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh, channelReadyPromise, - (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, payloadSize, - isRecursionDesired(), qCtx.question(), EMPTY_ADDITIONALS, promise); - - tcpCh.pipeline().addLast(new TcpDnsResponseDecoder()); - tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() { - @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - Channel tcpCh = ctx.channel(); - DnsResponse response = (DnsResponse) msg; - int queryId = response.id(); - - if (logger.isDebugEnabled()) { - logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId, - tcpCh.remoteAddress(), response); - } - - DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId); - if (foundCtx != null && foundCtx.isDone()) { - logger.debug("{} Received a DNS response for a query that was timed out or cancelled " + - ": TCP [{}: {}]", tcpCh, queryId, res.sender()); - response.release(); - } else if (foundCtx == tcpCtx) { - tcpCtx.finishSuccess(new AddressedEnvelopeAdapter( - (InetSocketAddress) ctx.channel().remoteAddress(), - (InetSocketAddress) ctx.channel().localAddress(), - response)); - } else { - response.release(); - tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false); - if (logger.isDebugEnabled()) { - logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]", - tcpCh, queryId, tcpCh.remoteAddress()); - } - } - } - - @Override - public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - if (tcpCtx.finishFailure( - "TCP fallback error", cause, false) && logger.isDebugEnabled()) { - logger.debug("{} Error during processing response: TCP [{}: {}]", - ctx.channel(), queryId, - ctx.channel().remoteAddress(), cause); - } - } - }); - - promise.addListener( - new FutureListener>() { - @Override - public void operationComplete( - Future> future) { - if (future.isSuccess()) { - qCtx.finishSuccess(future.getNow()); - res.release(); - } else { - // TCP fallback failed, just use the truncated response. - qCtx.finishSuccess(res); - } - tcpCh.close(); - } - }); - tcpCtx.writeQuery(queryTimeoutMillis(), true); - } - }); + // The context will handle truncation itself. + qCtx.finishSuccess(res, res.isTruncated()); } @Override @@ -1505,113 +1402,4 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { } } } - - private static final class AddressedEnvelopeAdapter implements AddressedEnvelope { - private final InetSocketAddress sender; - private final InetSocketAddress recipient; - private final DnsResponse response; - - AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) { - this.sender = sender; - this.recipient = recipient; - this.response = response; - } - - @Override - public DnsResponse content() { - return response; - } - - @Override - public InetSocketAddress sender() { - return sender; - } - - @Override - public InetSocketAddress recipient() { - return recipient; - } - - @Override - public AddressedEnvelope retain() { - response.retain(); - return this; - } - - @Override - public AddressedEnvelope retain(int increment) { - response.retain(increment); - return this; - } - - @Override - public AddressedEnvelope touch() { - response.touch(); - return this; - } - - @Override - public AddressedEnvelope touch(Object hint) { - response.touch(hint); - return this; - } - - @Override - public int refCnt() { - return response.refCnt(); - } - - @Override - public boolean release() { - return response.release(); - } - - @Override - public boolean release(int decrement) { - return response.release(decrement); - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { - return true; - } - - if (!(obj instanceof AddressedEnvelope)) { - return false; - } - - @SuppressWarnings("unchecked") - final AddressedEnvelope that = (AddressedEnvelope) obj; - if (sender() == null) { - if (that.sender() != null) { - return false; - } - } else if (!sender().equals(that.sender())) { - return false; - } - - if (recipient() == null) { - if (that.recipient() != null) { - return false; - } - } else if (!recipient().equals(that.recipient())) { - return false; - } - - return response.equals(obj); - } - - @Override - public int hashCode() { - int hashCode = response.hashCode(); - if (sender() != null) { - hashCode = hashCode * 31 + sender().hashCode(); - } - if (recipient() != null) { - hashCode = hashCode * 31 + recipient().hashCode(); - } - return hashCode; - } - } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java index ddd26d501b9b..0745ac45c285 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolverBuilder.java @@ -47,6 +47,8 @@ public final class DnsNameResolverBuilder { volatile EventLoop eventLoop; private ChannelFactory channelFactory; private ChannelFactory socketChannelFactory; + private boolean retryOnTimeout; + private DnsCache resolveCache; private DnsCnameCache cnameCache; private AuthoritativeDnsServerCache authoritativeDnsServerCache; @@ -143,7 +145,44 @@ public DnsNameResolverBuilder channelType(Class chann * @return {@code this} */ public DnsNameResolverBuilder socketChannelFactory(ChannelFactory channelFactory) { + return socketChannelFactory(channelFactory, false); + } + + /** + * Sets the {@link ChannelFactory} as a {@link ReflectiveChannelFactory} of this type for + * TCP fallback if needed. + * Use as an alternative to {@link #socketChannelFactory(ChannelFactory)}. + * + * TCP fallback is not enabled by default and must be enabled by providing a non-null + * {@code channelType} for this method. + * + * @param channelType the type or {@code null} if TCP fallback + * should not be supported. By default, TCP fallback is not enabled. + * @return {@code this} + */ + public DnsNameResolverBuilder socketChannelType(Class channelType) { + return socketChannelType(channelType, false); + } + + /** + * Sets the {@link ChannelFactory} that will create a {@link SocketChannel} for + * TCP fallback if needed. + * + * TCP fallback is not enabled by default and must be enabled by providing a non-null + * {@link ChannelFactory} for this method. + * + * @param channelFactory the {@link ChannelFactory} or {@code null} + * if TCP fallback should not be supported. + * By default, TCP fallback is not enabled. + * @param retryOnTimeout if {@code true} the {@link DnsNameResolver} will also fallback to TCP if a timeout + * was detected, if {@code false} it will only try to use TCP if the response was marked + * as truncated. + * @return {@code this} + */ + public DnsNameResolverBuilder socketChannelFactory( + ChannelFactory channelFactory, boolean retryOnTimeout) { this.socketChannelFactory = channelFactory; + this.retryOnTimeout = retryOnTimeout; return this; } @@ -157,13 +196,17 @@ public DnsNameResolverBuilder socketChannelFactory(ChannelFactoryTCP fallback * should not be supported. By default, TCP fallback is not enabled. + * @param retryOnTimeout if {@code true} the {@link DnsNameResolver} will also fallback to TCP if a timeout + * was detected, if {@code false} it will only try to use TCP if the response was marked + * as truncated. * @return {@code this} */ - public DnsNameResolverBuilder socketChannelType(Class channelType) { + public DnsNameResolverBuilder socketChannelType( + Class channelType, boolean retryOnTimeout) { if (channelType == null) { - return socketChannelFactory(null); + return socketChannelFactory(null, retryOnTimeout); } - return socketChannelFactory(new ReflectiveChannelFactory(channelType)); + return socketChannelFactory(new ReflectiveChannelFactory(channelType), retryOnTimeout); } /** @@ -528,6 +571,7 @@ public DnsNameResolver build() { eventLoop, channelFactory, socketChannelFactory, + retryOnTimeout, resolveCache, cnameCache, authoritativeDnsServerCache, @@ -565,9 +609,7 @@ public DnsNameResolverBuilder copy() { copiedBuilder.channelFactory(channelFactory); } - if (socketChannelFactory != null) { - copiedBuilder.socketChannelFactory(socketChannelFactory); - } + copiedBuilder.socketChannelFactory(socketChannelFactory, retryOnTimeout); if (resolveCache != null) { copiedBuilder.resolveCache(resolveCache); diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java index e3db5f807f6d..da1091b1d160 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/DnsQueryContext.java @@ -15,10 +15,13 @@ */ package io.netty.resolver.dns; +import io.netty.bootstrap.Bootstrap; import io.netty.channel.AddressedEnvelope; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelPromise; import io.netty.handler.codec.dns.AbstractDnsOptPseudoRrRecord; import io.netty.handler.codec.dns.DnsQuery; @@ -27,16 +30,20 @@ import io.netty.handler.codec.dns.DnsRecordType; import io.netty.handler.codec.dns.DnsResponse; import io.netty.handler.codec.dns.DnsSection; +import io.netty.handler.codec.dns.TcpDnsQueryEncoder; +import io.netty.handler.codec.dns.TcpDnsResponseDecoder; import io.netty.util.ReferenceCountUtil; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.FutureListener; import io.netty.util.concurrent.GenericFutureListener; import io.netty.util.concurrent.Promise; import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.ThrowableUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import java.net.InetSocketAddress; +import java.net.SocketAddress; import java.util.concurrent.CancellationException; import java.util.concurrent.TimeUnit; @@ -53,6 +60,8 @@ abstract class DnsQueryContext { logger.debug("-Dio.netty.resolver.dns.idReuseOnTimeoutDelayMillis: {}", ID_REUSE_ON_TIMEOUT_DELAY_MILLIS); } + private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder(); + private final Future channelReadyFuture; private final Channel channel; private final InetSocketAddress nameServerAddr; @@ -64,6 +73,12 @@ abstract class DnsQueryContext { private final DnsRecord optResource; private final boolean recursionDesired; + + private final Bootstrap socketBootstrap; + + private final boolean retryWithTcpOnTimeout; + private final long queryTimeoutMillis; + private volatile Future timeoutFuture; private int id = -1; @@ -74,9 +89,12 @@ abstract class DnsQueryContext { DnsQueryContextManager queryContextManager, int maxPayLoadSize, boolean recursionDesired, + long queryTimeoutMillis, DnsQuestion question, DnsRecord[] additionals, - Promise> promise) { + Promise> promise, + Bootstrap socketBootstrap, + boolean retryWithTcpOnTimeout) { this.channel = checkNotNull(channel, "channel"); this.queryContextManager = checkNotNull(queryContextManager, "queryContextManager"); this.channelReadyFuture = checkNotNull(channelReadyFuture, "channelReadyFuture"); @@ -85,6 +103,9 @@ abstract class DnsQueryContext { this.additionals = checkNotNull(additionals, "additionals"); this.promise = checkNotNull(promise, "promise"); this.recursionDesired = recursionDesired; + this.queryTimeoutMillis = queryTimeoutMillis; + this.socketBootstrap = socketBootstrap; + this.retryWithTcpOnTimeout = retryWithTcpOnTimeout; if (maxPayLoadSize > 0 && // Only add the extra OPT record if there is not already one. This is required as only one is allowed @@ -147,12 +168,10 @@ final DnsQuestion question() { /** * Write the query and return the {@link ChannelFuture} that is completed once the write completes. * - * @param queryTimeoutMillis the timeout after which the query is considered timeout and the original - * {@link Promise} will be failed. * @param flush {@code true} if {@link Channel#flush()} should be called as well. * @return the {@link ChannelFuture} that is notified once once the write completes. */ - final ChannelFuture writeQuery(long queryTimeoutMillis, boolean flush) { + final ChannelFuture writeQuery(boolean flush) { assert id == -1 : this.getClass().getSimpleName() + ".writeQuery(...) can only be executed once."; id = queryContextManager.add(nameServerAddr, this); @@ -205,7 +224,7 @@ public void run() { channel, protocol(), id, nameServerAddr, question); } - return sendQuery(nameServerAddr, query, queryTimeoutMillis, flush); + return sendQuery(query, flush); } private void removeFromContextManager(InetSocketAddress nameServerAddr) { @@ -214,11 +233,10 @@ private void removeFromContextManager(InetSocketAddress nameServerAddr) { assert self == this : "Removed DnsQueryContext is not the correct instance"; } - private ChannelFuture sendQuery(final InetSocketAddress nameServerAddr, final DnsQuery query, - final long queryTimeoutMillis, final boolean flush) { + private ChannelFuture sendQuery(final DnsQuery query, final boolean flush) { final ChannelPromise writePromise = channel.newPromise(); if (channelReadyFuture.isSuccess()) { - writeQuery(nameServerAddr, query, queryTimeoutMillis, flush, writePromise); + writeQuery(query, flush, writePromise); } else { Throwable cause = channelReadyFuture.cause(); if (cause != null) { @@ -233,7 +251,7 @@ public void operationComplete(Future future) { // If the query is done in a late fashion (as the channel was not ready yet) we always flush // to ensure we did not race with a previous flush() that was done when the Channel was not // ready yet. - writeQuery(nameServerAddr, query, queryTimeoutMillis, true, writePromise); + writeQuery(query, true, writePromise); } else { Throwable cause = future.cause(); failQuery(query, cause, writePromise); @@ -254,7 +272,7 @@ private void failQuery(DnsQuery query, Throwable cause, ChannelPromise writeProm } } - private void writeQuery(final InetSocketAddress nameServerAddr, final DnsQuery query, final long queryTimeoutMillis, + private void writeQuery(final DnsQuery query, final boolean flush, ChannelPromise promise) { final ChannelFuture writeFuture = flush ? channel.writeAndFlush(query, promise) : channel.write(query, promise); @@ -298,18 +316,21 @@ public void run() { * Notifies the original {@link Promise} that the response for the query was received. * This method takes ownership of passed {@link AddressedEnvelope}. */ - void finishSuccess(AddressedEnvelope envelope) { - final DnsResponse res = envelope.content(); - if (res.count(DnsSection.QUESTION) != 1) { - logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}", - channel, envelope); - } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) { - logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}", - channel, question(), envelope); - } else if (trySuccess(envelope)) { - return; // Ownership transferred, don't release + void finishSuccess(AddressedEnvelope envelope, boolean truncated) { + // Check if the response was not truncated or if a fallback to TCP is possible. + if (!truncated || !retryWithTcp(envelope)) { + final DnsResponse res = envelope.content(); + if (res.count(DnsSection.QUESTION) != 1) { + logger.warn("{} Received a DNS response with invalid number of questions. Expected: 1, found: {}", + channel, envelope); + } else if (!question().equals(res.recordAt(DnsSection.QUESTION))) { + logger.warn("{} Received a mismatching DNS response. Expected: [{}], found: {}", + channel, question(), envelope); + } else if (trySuccess(envelope)) { + return; // Ownership transferred, don't release + } + envelope.release(); } - envelope.release(); } @SuppressWarnings("unchecked") @@ -342,9 +363,229 @@ final boolean finishFailure(String message, Throwable cause, boolean timeout) { // This was caused by a timeout so use DnsNameResolverTimeoutException to allow the user to // handle it special (like retry the query). e = new DnsNameResolverTimeoutException(nameServerAddr, question, buf.toString()); + if (retryWithTcpOnTimeout && retryWithTcp(e)) { + // We did successfully retry with TCP. + return false; + } } else { e = new DnsNameResolverException(nameServerAddr, question, buf.toString(), cause); } return promise.tryFailure(e); } + + /** + * Retry the original query with TCP if possible. + * + * @param originalResult the result of the original {@link DnsQueryContext}. + * @return {@code true} if retry via TCP is supported and so the ownership of + * {@code originalResult} was transferred, {@code false} otherwise. + */ + private boolean retryWithTcp(final Object originalResult) { + if (socketBootstrap == null) { + return false; + } + + socketBootstrap.connect(nameServerAddr).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + logger.debug("{} Unable to fallback to TCP [{}: {}]", + future.channel(), id, nameServerAddr, future.cause()); + + // TCP fallback failed, just use the truncated response or error. + finishOriginal(originalResult, future); + return; + } + final Channel tcpCh = future.channel(); + Promise> promise = + tcpCh.eventLoop().newPromise(); + final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh, channelReadyFuture, + (InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, 0, + recursionDesired, queryTimeoutMillis, question(), additionals, promise); + tcpCh.pipeline().addLast(TCP_ENCODER); + tcpCh.pipeline().addLast(new TcpDnsResponseDecoder()); + tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + Channel tcpCh = ctx.channel(); + DnsResponse response = (DnsResponse) msg; + int queryId = response.id(); + + if (logger.isDebugEnabled()) { + logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId, + tcpCh.remoteAddress(), response); + } + + DnsQueryContext foundCtx = queryContextManager.get(nameServerAddr, queryId); + if (foundCtx != null && foundCtx.isDone()) { + logger.debug("{} Received a DNS response for a query that was timed out or cancelled " + + ": TCP [{}: {}]", tcpCh, queryId, nameServerAddr); + response.release(); + } else if (foundCtx == tcpCtx) { + tcpCtx.finishSuccess(new AddressedEnvelopeAdapter( + (InetSocketAddress) ctx.channel().remoteAddress(), + (InetSocketAddress) ctx.channel().localAddress(), + response), false); + } else { + response.release(); + tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false); + if (logger.isDebugEnabled()) { + logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]", + tcpCh, queryId, tcpCh.remoteAddress()); + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (tcpCtx.finishFailure( + "TCP fallback error", cause, false) && logger.isDebugEnabled()) { + logger.debug("{} Error during processing response: TCP [{}: {}]", + ctx.channel(), id, + ctx.channel().remoteAddress(), cause); + } + } + }); + + promise.addListener( + new FutureListener>() { + @Override + public void operationComplete( + Future> future) { + if (future.isSuccess()) { + finishSuccess(future.getNow(), false); + // Release the original result. + ReferenceCountUtil.release(originalResult); + } else { + // TCP fallback failed, just use the truncated response or error. + finishOriginal(originalResult, future); + } + tcpCh.close(); + } + }); + tcpCtx.writeQuery(true); + } + }); + return true; + } + + @SuppressWarnings("unchecked") + private void finishOriginal(Object originalResult, Future future) { + if (originalResult instanceof Throwable) { + Throwable error = (Throwable) originalResult; + ThrowableUtil.addSuppressed(error, future.cause()); + promise.tryFailure(error); + } else { + finishSuccess((AddressedEnvelope) originalResult, false); + } + } + + private static final class AddressedEnvelopeAdapter implements AddressedEnvelope { + private final InetSocketAddress sender; + private final InetSocketAddress recipient; + private final DnsResponse response; + + AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) { + this.sender = sender; + this.recipient = recipient; + this.response = response; + } + + @Override + public DnsResponse content() { + return response; + } + + @Override + public InetSocketAddress sender() { + return sender; + } + + @Override + public InetSocketAddress recipient() { + return recipient; + } + + @Override + public AddressedEnvelope retain() { + response.retain(); + return this; + } + + @Override + public AddressedEnvelope retain(int increment) { + response.retain(increment); + return this; + } + + @Override + public AddressedEnvelope touch() { + response.touch(); + return this; + } + + @Override + public AddressedEnvelope touch(Object hint) { + response.touch(hint); + return this; + } + + @Override + public int refCnt() { + return response.refCnt(); + } + + @Override + public boolean release() { + return response.release(); + } + + @Override + public boolean release(int decrement) { + return response.release(decrement); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (!(obj instanceof AddressedEnvelope)) { + return false; + } + + @SuppressWarnings("unchecked") + final AddressedEnvelope that = (AddressedEnvelope) obj; + if (sender() == null) { + if (that.sender() != null) { + return false; + } + } else if (!sender().equals(that.sender())) { + return false; + } + + if (recipient() == null) { + if (that.recipient() != null) { + return false; + } + } else if (!recipient().equals(that.recipient())) { + return false; + } + + return response.equals(obj); + } + + @Override + public int hashCode() { + int hashCode = response.hashCode(); + if (sender() != null) { + hashCode = hashCode * 31 + sender().hashCode(); + } + if (recipient() != null) { + hashCode = hashCode * 31 + recipient().hashCode(); + } + return hashCode; + } + } } diff --git a/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java b/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java index 8f25ab8664e7..f8022a8c6a77 100644 --- a/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java +++ b/resolver-dns/src/main/java/io/netty/resolver/dns/TcpDnsQueryContext.java @@ -33,10 +33,12 @@ final class TcpDnsQueryContext extends DnsQueryContext { InetSocketAddress nameServerAddr, DnsQueryContextManager queryContextManager, int maxPayLoadSize, boolean recursionDesired, + long queryTimeoutMillis, DnsQuestion question, DnsRecord[] additionals, Promise> promise) { super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired, - question, additionals, promise); + // No retry via TCP. + queryTimeoutMillis, question, additionals, promise, null, false); } @Override diff --git a/resolver-dns/src/main/resources/META-INF/native-image/io.netty/netty-resolver-dns/generated/handlers/reflect-config.json b/resolver-dns/src/main/resources/META-INF/native-image/io.netty/netty-resolver-dns/generated/handlers/reflect-config.json index 5960b0047047..68508d9de7b0 100644 --- a/resolver-dns/src/main/resources/META-INF/native-image/io.netty/netty-resolver-dns/generated/handlers/reflect-config.json +++ b/resolver-dns/src/main/resources/META-INF/native-image/io.netty/netty-resolver-dns/generated/handlers/reflect-config.json @@ -7,9 +7,16 @@ "queryAllPublicMethods": true }, { - "name": "io.netty.resolver.dns.DnsNameResolver$3", + "name": "io.netty.resolver.dns.DnsNameResolver$2", "condition": { - "typeReachable": "io.netty.resolver.dns.DnsNameResolver$3" + "typeReachable": "io.netty.resolver.dns.DnsNameResolver$2" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.resolver.dns.DnsNameResolver$4", + "condition": { + "typeReachable": "io.netty.resolver.dns.DnsNameResolver$4" }, "queryAllPublicMethods": true }, @@ -21,9 +28,9 @@ "queryAllPublicMethods": true }, { - "name": "io.netty.resolver.dns.DnsNameResolver$DnsResponseHandler$1$1", + "name": "io.netty.resolver.dns.DnsQueryContext$6$1", "condition": { - "typeReachable": "io.netty.resolver.dns.DnsNameResolver$DnsResponseHandler$1$1" + "typeReachable": "io.netty.resolver.dns.DnsQueryContext$6$1" }, "queryAllPublicMethods": true } diff --git a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java index 4977981a53b3..b873a2f55510 100644 --- a/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java +++ b/resolver-dns/src/test/java/io/netty/resolver/dns/DnsNameResolverTest.java @@ -3293,30 +3293,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { if (tcpFallback) { // If we are configured to use TCP as a fallback lets replay the dns message over TCP Socket socket = serverSocket.accept(); + responseViaSocket(socket, messageRef.get()); - InputStream in = socket.getInputStream(); - assertTrue((in.read() << 8 | (in.read() & 0xff)) > 2); // skip length field - int txnId = in.read() << 8 | (in.read() & 0xff); - - IoBuffer ioBuffer = IoBuffer.allocate(1024); - // Must replace the transactionId with the one from the TCP request - DnsMessageModifier modifier = modifierFrom(messageRef.get()); - modifier.setTransactionId(txnId); - new DnsMessageEncoder().encode(ioBuffer, modifier.getDnsMessage()); - ioBuffer.flip(); - - ByteBuffer lenBuffer = ByteBuffer.allocate(2); - lenBuffer.putShort((short) ioBuffer.remaining()); - lenBuffer.flip(); - - while (lenBuffer.hasRemaining()) { - socket.getOutputStream().write(lenBuffer.get()); - } - - while (ioBuffer.hasRemaining()) { - socket.getOutputStream().write(ioBuffer.get()); - } - socket.getOutputStream().flush(); // Let's wait until we received the envelope before closing the socket. envelopeFuture.syncUninterruptibly(); @@ -3352,6 +3330,137 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { } } + private static void responseViaSocket(Socket socket, DnsMessage message) throws IOException { + InputStream in = socket.getInputStream(); + assertTrue((in.read() << 8 | (in.read() & 0xff)) > 2); // skip length field + int txnId = in.read() << 8 | (in.read() & 0xff); + + IoBuffer ioBuffer = IoBuffer.allocate(1024); + // Must replace the transactionId with the one from the TCP request + DnsMessageModifier modifier = modifierFrom(message); + modifier.setTransactionId(txnId); + new DnsMessageEncoder().encode(ioBuffer, modifier.getDnsMessage()); + ioBuffer.flip(); + + ByteBuffer lenBuffer = ByteBuffer.allocate(2); + lenBuffer.putShort((short) ioBuffer.remaining()); + lenBuffer.flip(); + + while (lenBuffer.hasRemaining()) { + socket.getOutputStream().write(lenBuffer.get()); + } + + while (ioBuffer.hasRemaining()) { + socket.getOutputStream().write(ioBuffer.get()); + } + socket.getOutputStream().flush(); + } + + @Test + public void testTcpFallbackWhenTimeout() throws IOException { + testTcpFallbackWhenTimeout(true); + } + + @Test + public void testTcpFallbackFailedWhenTimeout() throws IOException { + testTcpFallbackWhenTimeout(false); + } + + private void testTcpFallbackWhenTimeout(boolean tcpSuccess) throws IOException { + ServerSocket serverSocket = new ServerSocket(); + serverSocket.setReuseAddress(true); + serverSocket.bind(new InetSocketAddress(NetUtil.LOCALHOST4, 0)); + + final String host = "somehost.netty.io"; + final String txt = "this is a txt record"; + final AtomicReference messageRef = new AtomicReference(); + + TestDnsServer dnsServer2 = new TestDnsServer(new RecordStore() { + @Override + public Set getRecords(QuestionRecord question) { + String name = question.getDomainName(); + if (name.equals(host)) { + return Collections.singleton( + new TestDnsServer.TestResourceRecord(name, RecordType.TXT, + Collections.singletonMap( + DnsAttribute.CHARACTER_STRING.toLowerCase(), txt))); + } + return null; + } + }) { + @Override + protected DnsMessage filterMessage(DnsMessage message) { + // Store a original message so we can replay it later on. + messageRef.set(message); + return null; + } + }; + DnsNameResolver resolver = null; + try { + DnsNameResolverBuilder builder = newResolver(); + final DatagramChannel datagramChannel = new NioDatagramChannel(); + ChannelFactory channelFactory = new ChannelFactory() { + @Override + public DatagramChannel newChannel() { + return datagramChannel; + } + }; + builder.channelFactory(channelFactory); + dnsServer2.start(null, (InetSocketAddress) serverSocket.getLocalSocketAddress()); + // If we are configured to use TCP as a fallback also bind a TCP socket + builder.socketChannelType(NioSocketChannel.class, true); + + builder.queryTimeoutMillis(1000) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_PREFERRED) + .maxQueriesPerResolve(16) + .nameServerProvider(new SingletonDnsServerAddressStreamProvider(dnsServer2.localAddress())); + resolver = builder.build(); + Future> envelopeFuture = resolver.query( + new DefaultDnsQuestion(host, DnsRecordType.TXT)); + + // If we are configured to use TCP as a fallback lets replay the dns message over TCP + Socket socket = serverSocket.accept(); + + if (tcpSuccess) { + responseViaSocket(socket, messageRef.get()); + + // Let's wait until we received the envelope before closing the socket. + envelopeFuture.syncUninterruptibly(); + socket.close(); + + AddressedEnvelope envelope = + envelopeFuture.syncUninterruptibly().getNow(); + assertNotNull(envelope.sender()); + + DnsResponse response = envelope.content(); + assertNotNull(response); + + assertEquals(DnsResponseCode.NOERROR, response.code()); + int count = response.count(DnsSection.ANSWER); + + assertEquals(1, count); + List texts = decodeTxt(response.recordAt(DnsSection.ANSWER, 0)); + assertEquals(1, texts.size()); + assertEquals(txt, texts.get(0)); + + assertFalse(envelope.content().isTruncated()); + assertTrue(envelope.release()); + } else { + // Just close the socket. This should cause the original exception to be used. + socket.close(); + Throwable error = envelopeFuture.awaitUninterruptibly().cause(); + assertThat(error, instanceOf(DnsNameResolverTimeoutException.class)); + assertThat(error.getSuppressed().length, greaterThanOrEqualTo(1)); + } + } finally { + dnsServer2.stop(); + if (resolver != null) { + resolver.close(); + } + serverSocket.close(); + } + } + @Test public void testCancelPromise() throws Exception { final EventLoop eventLoop = group.next();