Skip to content

Commit

Permalink
Be able to retry the query via TCP if a query failed because of a timeout when using UDP (netty#13757)
Browse files Browse the repository at this point in the history

Motivation:

We should allow people to retry the query via TCP if the query failed because of a
timeout when using UDP.

Modifications:

- Move all the retry code for TCP into DnsQueryContext so we can reuse
the same code for handling truncation and retry.
- Retry on timeout if configured by user
- Add unit tests

Result:

More robust resolver
  • Loading branch information
normanmaurer authored Jan 12, 2024
1 parent 47e1f43 commit 684dfd8
Show file tree
Hide file tree
Showing 7 changed files with 482 additions and 290 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package io.netty.resolver.dns;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.handler.codec.dns.DatagramDnsQuery;
Expand All @@ -33,10 +34,12 @@ final class DatagramDnsQueryContext extends DnsQueryContext {
InetSocketAddress nameServerAddr,
DnsQueryContextManager queryContextManager,
int maxPayLoadSize, boolean recursionDesired,
long queryTimeoutMillis,
DnsQuestion question, DnsRecord[] additionals,
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise) {
Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise,
Bootstrap socketBootstrap, boolean retryWithTcpOnTimeout) {
super(channel, channelReadyFuture, nameServerAddr, queryContextManager, maxPayLoadSize, recursionDesired,
question, additionals, promise);
queryTimeoutMillis, question, additionals, promise, socketBootstrap, retryWithTcpOnTimeout);
}

@Override
Expand Down
254 changes: 21 additions & 233 deletions resolver-dns/src/main/java/io/netty/resolver/dns/DnsNameResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import io.netty.channel.ChannelFactory;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerAdapter;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
Expand All @@ -43,8 +45,6 @@
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.handler.codec.dns.DnsRecordType;
import io.netty.handler.codec.dns.DnsResponse;
import io.netty.handler.codec.dns.TcpDnsQueryEncoder;
import io.netty.handler.codec.dns.TcpDnsResponseDecoder;
import io.netty.resolver.DefaultHostsFileEntriesResolver;
import io.netty.resolver.HostsFileEntries;
import io.netty.resolver.HostsFileEntriesResolver;
Expand Down Expand Up @@ -121,6 +121,13 @@ public class DnsNameResolver extends InetNameResolver {
private static final InternetProtocolFamily[] IPV6_PREFERRED_RESOLVED_PROTOCOL_FAMILIES =
{InternetProtocolFamily.IPv6, InternetProtocolFamily.IPv4};

// Placeholder handler: the anonymous ChannelHandlerAdapter below overrides nothing except
// isSharable(), so it contributes no event handling of its own in this file.
// NOTE(review): presumably passed to the TCP socket bootstrap's handler(...) in place of the
// previously used TCP encoder, because a bootstrap requires some handler - confirm against the caller.
private static final ChannelHandler NOOP_HANDLER = new ChannelHandlerAdapter() {
@Override
public boolean isSharable() {
// Always report this single static instance as sharable.
return true;
}
};

static final ResolvedAddressTypes DEFAULT_RESOLVE_ADDRESS_TYPES;
static final String[] DEFAULT_SEARCH_DOMAINS;
private static final UnixResolverOptions DEFAULT_OPTIONS;
Expand Down Expand Up @@ -227,7 +234,6 @@ protected DnsResponse decodeResponse(ChannelHandlerContext ctx, DatagramPacket p
}
};
private static final DatagramDnsQueryEncoder DATAGRAM_ENCODER = new DatagramDnsQueryEncoder();
private static final TcpDnsQueryEncoder TCP_ENCODER = new TcpDnsQueryEncoder();

private final Promise<Channel> channelReadyPromise;
private final Channel ch;
Expand Down Expand Up @@ -273,6 +279,7 @@ protected DnsServerAddressStream initialValue() {
private final DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory;
private final boolean completeOncePreferredResolved;
private final Bootstrap socketBootstrap;
private final boolean retryWithTcpOnTimeout;

private final int maxNumConsolidation;
private final Map<String, Future<List<InetAddress>>> inflightLookups;
Expand Down Expand Up @@ -376,44 +383,18 @@ public DnsNameResolver(
String[] searchDomains,
int ndots,
boolean decodeIdn) {
this(eventLoop, channelFactory, null, resolveCache, NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache,
this(eventLoop, channelFactory, null, false, resolveCache,
NoopDnsCnameCache.INSTANCE, authoritativeDnsServerCache, null,
dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes, recursionDesired,
maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled, hostsFileEntriesResolver,
dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false);
}

DnsNameResolver(
EventLoop eventLoop,
ChannelFactory<? extends DatagramChannel> channelFactory,
ChannelFactory<? extends SocketChannel> socketChannelFactory,
final DnsCache resolveCache,
final DnsCnameCache cnameCache,
final AuthoritativeDnsServerCache authoritativeDnsServerCache,
DnsQueryLifecycleObserverFactory dnsQueryLifecycleObserverFactory,
long queryTimeoutMillis,
ResolvedAddressTypes resolvedAddressTypes,
boolean recursionDesired,
int maxQueriesPerResolve,
boolean traceEnabled,
int maxPayloadSize,
boolean optResourceEnabled,
HostsFileEntriesResolver hostsFileEntriesResolver,
DnsServerAddressStreamProvider dnsServerAddressStreamProvider,
String[] searchDomains,
int ndots,
boolean decodeIdn,
boolean completeOncePreferredResolved) {
this(eventLoop, channelFactory, socketChannelFactory, resolveCache, cnameCache, authoritativeDnsServerCache,
null, dnsQueryLifecycleObserverFactory, queryTimeoutMillis, resolvedAddressTypes,
recursionDesired, maxQueriesPerResolve, traceEnabled, maxPayloadSize, optResourceEnabled,
hostsFileEntriesResolver, dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn,
completeOncePreferredResolved, 0);
dnsServerAddressStreamProvider, searchDomains, ndots, decodeIdn, false, 0);
}

DnsNameResolver(
EventLoop eventLoop,
ChannelFactory<? extends DatagramChannel> channelFactory,
ChannelFactory<? extends SocketChannel> socketChannelFactory,
boolean retryWithTcpOnTimeout,
final DnsCache resolveCache,
final DnsCnameCache cnameCache,
final AuthoritativeDnsServerCache authoritativeDnsServerCache,
Expand Down Expand Up @@ -457,6 +438,7 @@ public DnsNameResolver(
this.ndots = ndots >= 0 ? ndots : DEFAULT_OPTIONS.ndots();
this.decodeIdn = decodeIdn;
this.completeOncePreferredResolved = completeOncePreferredResolved;
this.retryWithTcpOnTimeout = retryWithTcpOnTimeout;
if (socketChannelFactory == null) {
socketBootstrap = null;
} else {
Expand All @@ -465,7 +447,7 @@ public DnsNameResolver(
.group(executor())
.channelFactory(socketChannelFactory)
.attr(DNS_PIPELINE_ATTRIBUTE, Boolean.TRUE)
.handler(TCP_ENCODER);
.handler(NOOP_HANDLER);
if (queryTimeoutMillis > 0 && queryTimeoutMillis <= Integer.MAX_VALUE) {
// Set the connect timeout to the same as queryTimeout as otherwise it might take a long
// time for the query to fail in case of a connection timeout.
Expand Down Expand Up @@ -1354,8 +1336,9 @@ final Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> query0(
final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0;
try {
DnsQueryContext queryContext = new DatagramDnsQueryContext(ch, channelReadyPromise, nameServerAddr,
queryContextManager, payloadSize, isRecursionDesired(), question, additionals, castPromise);
ChannelFuture future = queryContext.writeQuery(queryTimeoutMillis(), flush);
queryContextManager, payloadSize, isRecursionDesired(), queryTimeoutMillis(), question, additionals,
castPromise, socketBootstrap, retryWithTcpOnTimeout);
ChannelFuture future = queryContext.writeQuery(flush);
queryLifecycleObserver.queryWritten(nameServerAddr, future);
return castPromise;
} catch (Exception e) {
Expand Down Expand Up @@ -1400,94 +1383,8 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
return;
}

// Check if the response was truncated and if we can fallback to TCP to retry.
if (!res.isTruncated() || socketBootstrap == null) {
qCtx.finishSuccess(res);
return;
}

socketBootstrap.connect(res.sender()).addListener(new ChannelFutureListener() {
@Override
public void operationComplete(ChannelFuture future) {
if (!future.isSuccess()) {
logger.debug("{} Unable to fallback to TCP [{}: {}]",
ch, queryId, res.sender(), future.cause());

// TCP fallback failed, just use the truncated response.
qCtx.finishSuccess(res);
return;
}
final Channel tcpCh = future.channel();

Promise<AddressedEnvelope<DnsResponse, InetSocketAddress>> promise =
tcpCh.eventLoop().newPromise();
final int payloadSize = isOptResourceEnabled() ? maxPayloadSize() : 0;
final TcpDnsQueryContext tcpCtx = new TcpDnsQueryContext(tcpCh, channelReadyPromise,
(InetSocketAddress) tcpCh.remoteAddress(), queryContextManager, payloadSize,
isRecursionDesired(), qCtx.question(), EMPTY_ADDITIONALS, promise);

tcpCh.pipeline().addLast(new TcpDnsResponseDecoder());
tcpCh.pipeline().addLast(new ChannelInboundHandlerAdapter() {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
Channel tcpCh = ctx.channel();
DnsResponse response = (DnsResponse) msg;
int queryId = response.id();

if (logger.isDebugEnabled()) {
logger.debug("{} RECEIVED: TCP [{}: {}], {}", tcpCh, queryId,
tcpCh.remoteAddress(), response);
}

DnsQueryContext foundCtx = queryContextManager.get(res.sender(), queryId);
if (foundCtx != null && foundCtx.isDone()) {
logger.debug("{} Received a DNS response for a query that was timed out or cancelled " +
": TCP [{}: {}]", tcpCh, queryId, res.sender());
response.release();
} else if (foundCtx == tcpCtx) {
tcpCtx.finishSuccess(new AddressedEnvelopeAdapter(
(InetSocketAddress) ctx.channel().remoteAddress(),
(InetSocketAddress) ctx.channel().localAddress(),
response));
} else {
response.release();
tcpCtx.finishFailure("Received TCP DNS response with unexpected ID", null, false);
if (logger.isDebugEnabled()) {
logger.debug("{} Received a DNS response with an unexpected ID: TCP [{}: {}]",
tcpCh, queryId, tcpCh.remoteAddress());
}
}
}

@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
if (tcpCtx.finishFailure(
"TCP fallback error", cause, false) && logger.isDebugEnabled()) {
logger.debug("{} Error during processing response: TCP [{}: {}]",
ctx.channel(), queryId,
ctx.channel().remoteAddress(), cause);
}
}
});

promise.addListener(
new FutureListener<AddressedEnvelope<DnsResponse, InetSocketAddress>>() {
@Override
public void operationComplete(
Future<AddressedEnvelope<DnsResponse, InetSocketAddress>> future) {
if (future.isSuccess()) {
qCtx.finishSuccess(future.getNow());
res.release();
} else {
// TCP fallback failed, just use the truncated response.
qCtx.finishSuccess(res);
}
tcpCh.close();
}
});
tcpCtx.writeQuery(queryTimeoutMillis(), true);
}
});
// The context will handle truncation itself.
qCtx.finishSuccess(res, res.isTruncated());
}

@Override
Expand All @@ -1505,113 +1402,4 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
}
}
}

/**
 * Presents a {@link DnsResponse} together with an explicit sender and recipient address as an
 * {@link AddressedEnvelope}. Every reference-counting call ({@code retain}, {@code touch},
 * {@code release}, {@code refCnt}) is forwarded to the wrapped response.
 */
private static final class AddressedEnvelopeAdapter implements AddressedEnvelope<DnsResponse, InetSocketAddress> {
    private final InetSocketAddress sender;
    private final InetSocketAddress recipient;
    private final DnsResponse response;

    /**
     * @param sender    address reported by {@link #sender()}
     * @param recipient address reported by {@link #recipient()}
     * @param response  the wrapped response, returned by {@link #content()}
     */
    AddressedEnvelopeAdapter(InetSocketAddress sender, InetSocketAddress recipient, DnsResponse response) {
        this.response = response;
        this.recipient = recipient;
        this.sender = sender;
    }

    /** Returns the wrapped response. */
    @Override
    public DnsResponse content() {
        return response;
    }

    /** Returns the sender address given at construction time. */
    @Override
    public InetSocketAddress sender() {
        return sender;
    }

    /** Returns the recipient address given at construction time. */
    @Override
    public InetSocketAddress recipient() {
        return recipient;
    }

    /** Retains the wrapped response and returns this envelope. */
    @Override
    public AddressedEnvelope<DnsResponse, InetSocketAddress> retain() {
        response.retain();
        return this;
    }

    /** Retains the wrapped response by {@code increment} and returns this envelope. */
    @Override
    public AddressedEnvelope<DnsResponse, InetSocketAddress> retain(int increment) {
        response.retain(increment);
        return this;
    }

    /** Touches the wrapped response and returns this envelope. */
    @Override
    public AddressedEnvelope<DnsResponse, InetSocketAddress> touch() {
        response.touch();
        return this;
    }

    /** Touches the wrapped response with the given hint and returns this envelope. */
    @Override
    public AddressedEnvelope<DnsResponse, InetSocketAddress> touch(Object hint) {
        response.touch(hint);
        return this;
    }

    /** Returns the reference count of the wrapped response. */
    @Override
    public int refCnt() {
        return response.refCnt();
    }

    /** Releases the wrapped response and returns its result. */
    @Override
    public boolean release() {
        return response.release();
    }

    /** Releases the wrapped response by {@code decrement} and returns its result. */
    @Override
    public boolean release(int decrement) {
        return response.release(decrement);
    }

    /**
     * Equal to itself; otherwise {@code obj} must be an {@link AddressedEnvelope} whose sender and
     * recipient match (null-safe), and the wrapped response must report itself equal to {@code obj}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AddressedEnvelope)) {
            return false;
        }

        @SuppressWarnings("unchecked")
        final AddressedEnvelope<?, SocketAddress> other = (AddressedEnvelope<?, SocketAddress>) obj;
        if (!sameAddress(sender, other.sender()) || !sameAddress(recipient, other.recipient())) {
            return false;
        }

        // The final check hands the whole envelope (obj), not its content, to the response.
        return response.equals(obj);
    }

    /** Null-safe address comparison: two nulls match, otherwise {@code mine.equals(theirs)} decides. */
    private static boolean sameAddress(InetSocketAddress mine, SocketAddress theirs) {
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Combines the response hash with the non-null addresses using the factor 31. */
    @Override
    public int hashCode() {
        int result = response.hashCode();
        result = sender == null ? result : result * 31 + sender.hashCode();
        result = recipient == null ? result : result * 31 + recipient.hashCode();
        return result;
    }
}
}
Loading

0 comments on commit 684dfd8

Please sign in to comment.