Commit 7b84de4c authored by Jelte Jansen's avatar Jelte Jansen
Browse files

Merge branch 'experiments/resolver'

Conflicts:
	src/bin/resolver/resolver.cc
parents dd37e953 b9b7ef39
......@@ -159,9 +159,13 @@ AuthSrvImpl::~AuthSrvImpl() {
class MessageLookup : public DNSLookup {
public:
MessageLookup(AuthSrv* srv) : server_(srv) {}
virtual void operator()(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer, DNSServer* server) const
virtual void operator()(const IOMessage& io_message,
MessagePtr message,
MessagePtr answer_message,
OutputBufferPtr buffer,
DNSServer* server) const
{
(void) answer_message;
server_->processMessage(io_message, message, buffer, server);
}
private:
......@@ -180,7 +184,7 @@ class MessageAnswer : public DNSAnswer {
public:
MessageAnswer(AuthSrv*) {}
virtual void operator()(const IOMessage&, MessagePtr,
OutputBufferPtr) const
MessagePtr, OutputBufferPtr) const
{}
};
......
......@@ -138,9 +138,10 @@ TEST_F(AuthSrvTest, builtInQueryViaDNSServer) {
createRequestPacket(request_message, IPPROTO_UDP);
(*server.getDNSLookupProvider())(*io_message, parse_message,
response_message,
response_obuffer, &dnsserv);
(*server.getDNSAnswerProvider())(*io_message, parse_message,
response_obuffer);
response_message, response_obuffer);
createBuiltinVersionResponse(default_qid, response_data);
EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData,
......@@ -153,9 +154,10 @@ TEST_F(AuthSrvTest, builtInQueryViaDNSServer) {
TEST_F(AuthSrvTest, iqueryViaDNSServer) {
createDataFromFile("iquery_fromWire.wire");
(*server.getDNSLookupProvider())(*io_message, parse_message,
response_message,
response_obuffer, &dnsserv);
(*server.getDNSAnswerProvider())(*io_message, parse_message,
response_obuffer);
response_message, response_obuffer);
UnitTestUtil::readWireData("iquery_response_fromWire.wire",
response_data);
......
......@@ -95,20 +95,20 @@ public:
{
upstream_ = upstream;
if (dnss) {
if (upstream_.empty()) {
dlog("Asked to do full recursive, but not implemented yet. "
"I'll do nothing.",true);
} else {
if (!upstream_.empty()) {
dlog("Setting forward addresses:");
BOOST_FOREACH(const addr_t& address, upstream) {
dlog(" " + address.first + ":" +
boost::lexical_cast<string>(address.second));
}
} else {
dlog("No forward addresses, running in recursive mode");
}
}
}
void processNormalQuery(const Question& question, MessagePtr message,
void processNormalQuery(const Question& question,
MessagePtr answer_message,
OutputBufferPtr buffer,
DNSServer* server);
......@@ -149,20 +149,6 @@ public:
MessagePtr message_;
};
class SectionInserter {
public:
SectionInserter(MessagePtr message, const Message::Section sect) :
message_(message), section_(sect)
{}
void operator()(const RRsetPtr rrset) {
//dlog("Adding RRSet to message section " +
// boost::lexical_cast<string>(section_));
message_->addRRset(section_, rrset, true);
}
MessagePtr message_;
const Message::Section section_;
};
void
makeErrorMessage(MessagePtr message, OutputBufferPtr buffer,
const Rcode& rcode)
......@@ -210,10 +196,14 @@ public:
MessageLookup(Resolver* srv) : server_(srv) {}
// \brief Handle the DNS Lookup
virtual void operator()(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer, DNSServer* server) const
virtual void operator()(const IOMessage& io_message,
MessagePtr query_message,
MessagePtr answer_message,
OutputBufferPtr buffer,
DNSServer* server) const
{
server_->processMessage(io_message, message, buffer, server);
server_->processMessage(io_message, query_message,
answer_message, buffer, server);
}
private:
Resolver* server_;
......@@ -226,76 +216,62 @@ private:
class MessageAnswer : public DNSAnswer {
public:
virtual void operator()(const IOMessage& io_message,
MessagePtr message,
MessagePtr query_message,
MessagePtr answer_message,
OutputBufferPtr buffer) const
{
const qid_t qid = message->getQid();
const bool rd = message->getHeaderFlag(Message::HEADERFLAG_RD);
const bool cd = message->getHeaderFlag(Message::HEADERFLAG_CD);
const Opcode& opcode = message->getOpcode();
const Rcode& rcode = message->getRcode();
vector<QuestionPtr> questions;
questions.assign(message->beginQuestion(), message->endQuestion());
const qid_t qid = query_message->getQid();
const bool rd = query_message->getHeaderFlag(Message::HEADERFLAG_RD);
const bool cd = query_message->getHeaderFlag(Message::HEADERFLAG_CD);
const Opcode& opcode = query_message->getOpcode();
message->clear(Message::RENDER);
message->setQid(qid);
message->setOpcode(opcode);
message->setRcode(rcode);
// Fill in the final details of the answer message
answer_message->setQid(qid);
answer_message->setOpcode(opcode);
message->setHeaderFlag(Message::HEADERFLAG_QR);
message->setHeaderFlag(Message::HEADERFLAG_RA);
answer_message->setHeaderFlag(Message::HEADERFLAG_QR);
answer_message->setHeaderFlag(Message::HEADERFLAG_RA);
if (rd) {
message->setHeaderFlag(Message::HEADERFLAG_RD);
answer_message->setHeaderFlag(Message::HEADERFLAG_RD);
}
if (cd) {
message->setHeaderFlag(Message::HEADERFLAG_CD);
}
// Copy the question section.
for_each(questions.begin(), questions.end(), QuestionInserter(message));
// If the buffer already has an answer in it, copy RRsets from
// that into the new message, then clear the buffer and render
// the new message into it.
if (buffer->getLength() != 0) {
try {
Message incoming(Message::PARSE);
InputBuffer ibuf(buffer->getData(), buffer->getLength());
incoming.fromWire(ibuf);
message->setRcode(incoming.getRcode());
for_each(incoming.beginSection(Message::SECTION_ANSWER),
incoming.endSection(Message::SECTION_ANSWER),
SectionInserter(message, Message::SECTION_ANSWER));
for_each(incoming.beginSection(Message::SECTION_AUTHORITY),
incoming.endSection(Message::SECTION_AUTHORITY),
SectionInserter(message, Message::SECTION_AUTHORITY));
for_each(incoming.beginSection(Message::SECTION_ADDITIONAL),
incoming.endSection(Message::SECTION_ADDITIONAL),
SectionInserter(message, Message::SECTION_ADDITIONAL));
} catch (const Exception& ex) {
// Incoming message couldn't be read, we just SERVFAIL
message->setRcode(Rcode::SERVFAIL());
}
answer_message->setHeaderFlag(Message::HEADERFLAG_CD);
}
vector<QuestionPtr> questions;
questions.assign(query_message->beginQuestion(), query_message->endQuestion());
for_each(questions.begin(), questions.end(), QuestionInserter(answer_message));
// Now we can clear the buffer and render the new message into it
buffer->clear();
MessageRenderer renderer(*buffer);
ConstEDNSPtr edns(query_message->getEDNS());
const bool dnssec_ok = edns && edns->getDNSSECAwareness();
if (edns) {
EDNSPtr edns_response(new EDNS());
edns_response->setDNSSECAwareness(dnssec_ok);
// TODO: We should make our own edns bufsize length configurable
edns_response->setUDPSize(Message::DEFAULT_MAX_EDNS0_UDPSIZE);
answer_message->setEDNS(edns_response);
}
if (io_message.getSocket().getProtocol() == IPPROTO_UDP) {
ConstEDNSPtr edns(message->getEDNS());
renderer.setLengthLimit(edns ? edns->getUDPSize() :
Message::DEFAULT_MAX_UDPSIZE);
if (edns) {
renderer.setLengthLimit(edns->getUDPSize());
} else {
renderer.setLengthLimit(Message::DEFAULT_MAX_UDPSIZE);
}
} else {
renderer.setLengthLimit(65535);
}
message->toWire(renderer);
answer_message->toWire(renderer);
dlog(string("sending a response (") +
boost::lexical_cast<string>(renderer.getLength()) + "bytes): \n" +
message->toText());
answer_message->toText());
}
};
......@@ -345,18 +321,21 @@ Resolver::getConfigSession() const {
}
void
Resolver::processMessage(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer, DNSServer* server)
Resolver::processMessage(const IOMessage& io_message,
MessagePtr query_message,
MessagePtr answer_message,
OutputBufferPtr buffer,
DNSServer* server)
{
dlog("Got a DNS message");
InputBuffer request_buffer(io_message.getData(), io_message.getDataSize());
// First, check the header part. If we fail even for the base header,
// just drop the message.
try {
message->parseHeader(request_buffer);
query_message->parseHeader(request_buffer);
// Ignore all responses.
if (message->getHeaderFlag(Message::HEADERFLAG_QR)) {
if (query_message->getHeaderFlag(Message::HEADERFLAG_QR)) {
dlog("Received unexpected response, ignoring");
server->resume(false);
return;
......@@ -369,52 +348,53 @@ Resolver::processMessage(const IOMessage& io_message, MessagePtr message,
// Parse the message. On failure, return an appropriate error.
try {
message->fromWire(request_buffer);
query_message->fromWire(request_buffer);
} catch (const DNSProtocolError& error) {
dlog(string("returning ") + error.getRcode().toText() + ": " +
error.what());
makeErrorMessage(message, buffer, error.getRcode());
makeErrorMessage(query_message, buffer, error.getRcode());
server->resume(true);
return;
} catch (const Exception& ex) {
dlog(string("returning SERVFAIL: ") + ex.what());
makeErrorMessage(message, buffer, Rcode::SERVFAIL());
makeErrorMessage(query_message, buffer, Rcode::SERVFAIL());
server->resume(true);
return;
} // other exceptions will be handled at a higher layer.
dlog("received a message:\n" + message->toText());
dlog("received a message:\n" + query_message->toText());
// Perform further protocol-level validation.
bool sendAnswer = true;
if (message->getOpcode() == Opcode::NOTIFY()) {
makeErrorMessage(message, buffer, Rcode::NOTAUTH());
if (query_message->getOpcode() == Opcode::NOTIFY()) {
makeErrorMessage(query_message, buffer, Rcode::NOTAUTH());
dlog("Notify arrived, but we are not authoritative");
} else if (message->getOpcode() != Opcode::QUERY()) {
dlog("Unsupported opcode (got: " + message->getOpcode().toText() +
} else if (query_message->getOpcode() != Opcode::QUERY()) {
dlog("Unsupported opcode (got: " + query_message->getOpcode().toText() +
", expected: " + Opcode::QUERY().toText());
makeErrorMessage(message, buffer, Rcode::NOTIMP());
} else if (message->getRRCount(Message::SECTION_QUESTION) != 1) {
makeErrorMessage(query_message, buffer, Rcode::NOTIMP());
} else if (query_message->getRRCount(Message::SECTION_QUESTION) != 1) {
dlog("The query contained " +
boost::lexical_cast<string>(message->getRRCount(
boost::lexical_cast<string>(query_message->getRRCount(
Message::SECTION_QUESTION) + " questions, exactly one expected"));
makeErrorMessage(message, buffer, Rcode::FORMERR());
makeErrorMessage(query_message, buffer, Rcode::FORMERR());
} else {
ConstQuestionPtr question = *message->beginQuestion();
ConstQuestionPtr question = *query_message->beginQuestion();
const RRType &qtype = question->getType();
if (qtype == RRType::AXFR()) {
if (io_message.getSocket().getProtocol() == IPPROTO_UDP) {
makeErrorMessage(message, buffer, Rcode::FORMERR());
makeErrorMessage(query_message, buffer, Rcode::FORMERR());
} else {
makeErrorMessage(message, buffer, Rcode::NOTIMP());
makeErrorMessage(query_message, buffer, Rcode::NOTIMP());
}
} else if (qtype == RRType::IXFR()) {
makeErrorMessage(message, buffer, Rcode::NOTIMP());
makeErrorMessage(query_message, buffer, Rcode::NOTIMP());
} else {
// The RecursiveQuery object will post the "resume" event to the
// DNSServer when an answer arrives, so we don't have to do it now.
sendAnswer = false;
impl_->processNormalQuery(*question, message, buffer, server);
impl_->processNormalQuery(*question, answer_message,
buffer, server);
}
}
......@@ -424,23 +404,13 @@ Resolver::processMessage(const IOMessage& io_message, MessagePtr message,
}
void
ResolverImpl::processNormalQuery(const Question& question, MessagePtr message,
OutputBufferPtr buffer, DNSServer* server)
ResolverImpl::processNormalQuery(const Question& question,
MessagePtr answer_message,
OutputBufferPtr buffer,
DNSServer* server)
{
dlog("Processing normal query");
ConstEDNSPtr edns(message->getEDNS());
const bool dnssec_ok = edns && edns->getDNSSECAwareness();
message->makeResponse();
message->setHeaderFlag(Message::HEADERFLAG_RA);
message->setRcode(Rcode::NOERROR());
if (edns) {
EDNSPtr edns_response(new EDNS());
edns_response->setDNSSECAwareness(dnssec_ok);
edns_response->setUDPSize(ResolverImpl::DEFAULT_LOCAL_UDPSIZE);
message->setEDNS(edns_response);
}
rec_query_->sendQuery(question, buffer, server);
rec_query_->sendQuery(question, answer_message, buffer, server);
}
namespace {
......
......@@ -63,7 +63,8 @@ public:
/// \param buffer Pointer to an \c OutputBuffer for the resposne
/// \param server Pointer to the \c DNSServer
void processMessage(const asiolink::IOMessage& io_message,
isc::dns::MessagePtr message,
isc::dns::MessagePtr query_message,
isc::dns::MessagePtr answer_message,
isc::dns::OutputBufferPtr buffer,
asiolink::DNSServer* server);
......
......@@ -134,8 +134,8 @@ TEST_F(ResolverConfig, listenAddresses) {
// Try putting there some addresses
vector<pair<string, uint16_t> > addresses;
addresses.push_back(pair<string, uint16_t>("127.0.0.1", 5300));
addresses.push_back(pair<string, uint16_t>("::1", 5300));
addresses.push_back(pair<string, uint16_t>("127.0.0.1", 5321));
addresses.push_back(pair<string, uint16_t>("::1", 5321));
server.setListenAddresses(addresses);
EXPECT_EQ(2, server.getListenAddresses().size());
EXPECT_EQ("::1", server.getListenAddresses()[1].first);
......@@ -155,7 +155,7 @@ TEST_F(ResolverConfig, DISABLED_listenAddressConfig) {
"\"listen_on\": ["
" {"
" \"address\": \"127.0.0.1\","
" \"port\": 5300"
" \"port\": 5321"
" }"
"]"
"}"));
......@@ -163,7 +163,7 @@ TEST_F(ResolverConfig, DISABLED_listenAddressConfig) {
EXPECT_EQ(result->toWire(), isc::config::createAnswer()->toWire());
ASSERT_EQ(1, server.getListenAddresses().size());
EXPECT_EQ("127.0.0.1", server.getListenAddresses()[0].first);
EXPECT_EQ(5300, server.getListenAddresses()[0].second);
EXPECT_EQ(5321, server.getListenAddresses()[0].second);
// As this is example address, the machine should not have it on
// any interface
......@@ -174,7 +174,7 @@ TEST_F(ResolverConfig, DISABLED_listenAddressConfig) {
"\"listen_on\": ["
" {"
" \"address\": \"192.0.2.0\","
" \"port\": 5300"
" \"port\": 5321"
" }"
"]"
"}");
......@@ -182,7 +182,7 @@ TEST_F(ResolverConfig, DISABLED_listenAddressConfig) {
EXPECT_FALSE(result->equals(*isc::config::createAnswer()));
ASSERT_EQ(1, server.getListenAddresses().size());
EXPECT_EQ("127.0.0.1", server.getListenAddresses()[0].first);
EXPECT_EQ(5300, server.getListenAddresses()[0].second);
EXPECT_EQ(5321, server.getListenAddresses()[0].second);
}
TEST_F(ResolverConfig, invalidListenAddresses) {
......
......@@ -30,7 +30,10 @@ class ResolverTest : public SrvTestBase{
protected:
ResolverTest() : server(){}
virtual void processMessage() {
server.processMessage(*io_message, parse_message, response_obuffer,
server.processMessage(*io_message,
parse_message,
response_message,
response_obuffer,
&dnsserv);
}
Resolver server;
......@@ -83,7 +86,11 @@ TEST_F(ResolverTest, AXFRFail) {
RRType::AXFR());
createRequestPacket(request_message, IPPROTO_TCP);
// AXFR is not implemented and should always send NOTIMP.
server.processMessage(*io_message, parse_message, response_obuffer, &dnsserv);
server.processMessage(*io_message,
parse_message,
response_message,
response_obuffer,
&dnsserv);
EXPECT_TRUE(dnsserv.hasAnswer());
headerCheck(*parse_message, default_qid, Rcode::NOTIMP(), opcode.getCode(),
QR_FLAG, 1, 0, 0, 0);
......@@ -98,7 +105,11 @@ TEST_F(ResolverTest, notifyFail) {
request_message.setQid(default_qid);
request_message.setHeaderFlag(Message::HEADERFLAG_AA);
createRequestPacket(request_message, IPPROTO_UDP);
server.processMessage(*io_message, parse_message, response_obuffer, &dnsserv);
server.processMessage(*io_message,
parse_message,
response_message,
response_obuffer,
&dnsserv);
EXPECT_TRUE(dnsserv.hasAnswer());
headerCheck(*parse_message, default_qid, Rcode::NOTAUTH(),
Opcode::NOTIFY().getCode(), QR_FLAG, 0, 0, 0, 0);
......
......@@ -30,6 +30,7 @@
#include <dns/buffer.h>
#include <dns/message.h>
#include <dns/rcode.h>
#include <asiolink/asiolink.h>
#include <asiolink/internal/tcpdns.h>
......@@ -37,6 +38,7 @@
#include <log/dummylog.h>
using namespace asio;
using asio::ip::udp;
using asio::ip::tcp;
......@@ -46,8 +48,46 @@ using namespace isc::dns;
using isc::log::dlog;
using namespace boost;
// Is this something we can use in libdns++?
namespace {
class SectionInserter {
public:
SectionInserter(MessagePtr message, const Message::Section sect) :
message_(message), section_(sect)
{}
void operator()(const RRsetPtr rrset) {
message_->addRRset(section_, rrset, true);
}
MessagePtr message_;
const Message::Section section_;
};
/// \brief Copies the parts relevant for a DNS answer to the
/// target message
///
/// This adds all the RRsets in the answer, authority and
/// additional sections to the target, as well as the response
/// code
void copyAnswerMessage(const Message& source, MessagePtr target) {
target->setRcode(source.getRcode());
for_each(source.beginSection(Message::SECTION_ANSWER),
source.endSection(Message::SECTION_ANSWER),
SectionInserter(target, Message::SECTION_ANSWER));
for_each(source.beginSection(Message::SECTION_AUTHORITY),
source.endSection(Message::SECTION_AUTHORITY),
SectionInserter(target, Message::SECTION_AUTHORITY));
for_each(source.beginSection(Message::SECTION_ADDITIONAL),
source.endSection(Message::SECTION_ADDITIONAL),
SectionInserter(target, Message::SECTION_ADDITIONAL));
}
}
namespace asiolink {
typedef pair<string, uint16_t> addr_t;
class IOServiceImpl {
private:
IOServiceImpl(const IOService& source);
......@@ -296,6 +336,12 @@ private:
// Info for (re)sending the query (the question and destination)
Question question_;
// This is where we build and store our final answer
MessagePtr answer_message_;
// currently we use upstream as the current list of NS records
// we should differentiate between forwarding and resolving
shared_ptr<AddressVector> upstream_;
// Buffer to store the result.
......@@ -312,9 +358,26 @@ private:
int timeout_;
unsigned retries_;
// normal query state
// if we change this to running and add a sent, we can do
// decoupled timeouts i think
bool done;
// Not using NSAS at this moment, so we keep a list
// of 'current' zone servers
vector<addr_t> zone_servers_;
// Update the question that will be sent to the server
void setQuestion(const Question& new_question) {
question_ = new_question;
}
// (re)send the query to the server.
void send() {
const int uc = upstream_->size();
const int zs = zone_servers_.size();
buffer_->clear();
if (uc > 0) {
int serverIndex = rand() % uc;
dlog("Sending upstream query (" + question_.toText() +
......@@ -324,34 +387,138 @@ private:
upstream_->at(serverIndex).second, buffer_, this,
timeout_);
io_.post(query);
} else if (zs > 0) {
int serverIndex = rand() % zs;
dlog("Sending query to zone server (" + question_.toText() +
") to " + zone_servers_.at(serverIndex).first);
UDPQuery query(io_, question_,
zone_servers_.at(serverIndex).first,
zone_servers_.at(serverIndex).second, buffer_, this,
timeout_);
io_.post(query);
} else {
dlog("Error, no upstream servers to send to.");
}
}
// This function is called by operator() if there is an actual
// answer from a server and we are in recursive mode
// depending on the contents, we go on recursing or return
//
// Note that the footprint may change as this function may
// need to append data to the answer we are building later.
//
// returns true if we are done
// returns false if we are not done
bool handleRecursiveAnswer(const Message& incoming) {
if (incoming.getRRCount(Message::SECTION_ANSWER) > 0) {
dlog("Got final result, copying answer.");
copyAnswerMessage(incoming, answer_message_);
return true;
} else {
dlog("Got delegation, continuing");
// ok we need to do some more processing.
// the ns list should contain all nameservers
// while the additional may contain addresses for
// them.
// this needs to tie into NSAS of course
// for this very first mockup, hope there is an
// address in additional and just use that
// send query to the addresses in the delegation
bool found_ns_address = false;
zone_servers_.clear();
for (RRsetIterator rrsi = incoming.beginSection(Message::SECTION_ADDITIONAL);
rrsi != incoming.endSection(Message::SECTION_ADDITIONAL) && !found_ns_address;
rrsi++) {
ConstRRsetPtr rrs = *rrsi;
if (rrs->getType() == RRType::A()) {
// found address
RdataIteratorPtr rdi = rrs->getRdataIterator();
// just use the first for now
if (!rdi->isLast()) {
std::string addr_str = rdi->getCurrent().toText();
dlog("[XX] first address found: " + addr_str);
// now we have one address, simply
// resend that exact same query