Commit d7fb4b72 authored by Michal 'vorner' Vaner's avatar Michal 'vorner' Vaner

Merge #1601

Conflicts:
	src/bin/auth/auth_srv.cc
	src/bin/auth/tests/auth_srv_unittest.cc
parents 4cee65dd d9ae23e5
......@@ -87,14 +87,14 @@ public:
~AuthSrvImpl();
isc::data::ConstElementPtr setDbFile(isc::data::ConstElementPtr config);
bool processNormalQuery(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer,
bool processNormalQuery(const IOMessage& io_message, Message& message,
OutputBuffer& buffer,
auto_ptr<TSIGContext> tsig_context);
bool processXfrQuery(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer,
bool processXfrQuery(const IOMessage& io_message, Message& message,
OutputBuffer& buffer,
auto_ptr<TSIGContext> tsig_context);
bool processNotify(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer,
bool processNotify(const IOMessage& io_message, Message& message,
OutputBuffer& buffer,
auto_ptr<TSIGContext> tsig_context);
IOService io_service_;
......@@ -142,7 +142,7 @@ public:
/// \param done If true, the Rcode from the given message is counted,
/// this value is then passed to server->resume(bool)
void resumeServer(isc::asiodns::DNSServer* server,
isc::dns::MessagePtr message,
isc::dns::Message& message,
bool done);
private:
std::string db_file_;
......@@ -200,12 +200,11 @@ public:
MessageLookup(AuthSrv* srv) : server_(srv) {}
virtual void operator()(const IOMessage& io_message,
MessagePtr message,
MessagePtr answer_message,
MessagePtr, // Not used here
OutputBufferPtr buffer,
DNSServer* server) const
{
(void) answer_message;
server_->processMessage(io_message, message, buffer, server);
server_->processMessage(io_message, *message, *buffer, server);
}
private:
AuthSrv* server_;
......@@ -266,57 +265,57 @@ AuthSrv::~AuthSrv() {
namespace {
class QuestionInserter {
public:
QuestionInserter(MessagePtr message) : message_(message) {}
QuestionInserter(Message& message) : message_(message) {}
void operator()(const QuestionPtr question) {
message_->addQuestion(question);
message_.addQuestion(question);
}
MessagePtr message_;
Message& message_;
};
void
makeErrorMessage(MessagePtr message, OutputBufferPtr buffer,
const Rcode& rcode,
makeErrorMessage(Message& message, OutputBuffer& buffer,
const Rcode& rcode,
std::auto_ptr<TSIGContext> tsig_context =
std::auto_ptr<TSIGContext>())
{
// extract the parameters that should be kept.
// XXX: with the current implementation, it's not easy to set EDNS0
// depending on whether the query had it. So we'll simply omit it.
const qid_t qid = message->getQid();
const bool rd = message->getHeaderFlag(Message::HEADERFLAG_RD);
const bool cd = message->getHeaderFlag(Message::HEADERFLAG_CD);
const Opcode& opcode = message->getOpcode();
const qid_t qid = message.getQid();
const bool rd = message.getHeaderFlag(Message::HEADERFLAG_RD);
const bool cd = message.getHeaderFlag(Message::HEADERFLAG_CD);
const Opcode& opcode = message.getOpcode();
vector<QuestionPtr> questions;
// If this is an error to a query or notify, we should also copy the
// question section.
if (opcode == Opcode::QUERY() || opcode == Opcode::NOTIFY()) {
questions.assign(message->beginQuestion(), message->endQuestion());
questions.assign(message.beginQuestion(), message.endQuestion());
}
message->clear(Message::RENDER);
message->setQid(qid);
message->setOpcode(opcode);
message->setHeaderFlag(Message::HEADERFLAG_QR);
message.clear(Message::RENDER);
message.setQid(qid);
message.setOpcode(opcode);
message.setHeaderFlag(Message::HEADERFLAG_QR);
if (rd) {
message->setHeaderFlag(Message::HEADERFLAG_RD);
message.setHeaderFlag(Message::HEADERFLAG_RD);
}
if (cd) {
message->setHeaderFlag(Message::HEADERFLAG_CD);
message.setHeaderFlag(Message::HEADERFLAG_CD);
}
for_each(questions.begin(), questions.end(), QuestionInserter(message));
message->setRcode(rcode);
message.setRcode(rcode);
MessageRenderer renderer;
renderer.setBuffer(buffer.get());
renderer.setBuffer(&buffer);
if (tsig_context.get() != NULL) {
message->toWire(renderer, *tsig_context);
message.toWire(renderer, *tsig_context);
} else {
message->toWire(renderer);
message.toWire(renderer);
}
renderer.setBuffer(NULL);
LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_ERROR_RESPONSE)
.arg(renderer.getLength()).arg(*message);
.arg(renderer.getLength()).arg(message);
}
}
......@@ -414,18 +413,18 @@ AuthSrv::setStatisticsTimerInterval(uint32_t interval) {
}
void
AuthSrv::processMessage(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer, DNSServer* server)
AuthSrv::processMessage(const IOMessage& io_message, Message& message,
OutputBuffer& buffer, DNSServer* server)
{
InputBuffer request_buffer(io_message.getData(), io_message.getDataSize());
// First, check the header part. If we fail even for the base header,
// just drop the message.
try {
message->parseHeader(request_buffer);
message.parseHeader(request_buffer);
// Ignore all responses.
if (message->getHeaderFlag(Message::HEADERFLAG_QR)) {
if (message.getHeaderFlag(Message::HEADERFLAG_QR)) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_RESPONSE_RECEIVED);
impl_->resumeServer(server, message, false);
return;
......@@ -439,7 +438,7 @@ AuthSrv::processMessage(const IOMessage& io_message, MessagePtr message,
try {
// Parse the message.
message->fromWire(request_buffer);
message.fromWire(request_buffer);
} catch (const DNSProtocolError& error) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_PACKET_PROTOCOL_ERROR)
.arg(error.getRcode().toText()).arg(error.what());
......@@ -455,13 +454,13 @@ AuthSrv::processMessage(const IOMessage& io_message, MessagePtr message,
} // other exceptions will be handled at a higher layer.
LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_PACKET_RECEIVED)
.arg(message->toText());
.arg(message);
// Perform further protocol-level validation.
// TSIG first
// If this is set to something, we know we need to answer with TSIG as well
std::auto_ptr<TSIGContext> tsig_context;
const TSIGRecord* tsig_record(message->getTSIGRecord());
const TSIGRecord* tsig_record(message.getTSIGRecord());
TSIGError tsig_error(TSIGError::NOERROR());
// Do we do TSIG?
......@@ -485,19 +484,19 @@ AuthSrv::processMessage(const IOMessage& io_message, MessagePtr message,
try {
// update per opcode statistics counter. This can only be reliable
// after TSIG check succeeds.
impl_->counters_.inc(message->getOpcode());
impl_->counters_.inc(message.getOpcode());
if (message->getOpcode() == Opcode::NOTIFY()) {
if (message.getOpcode() == Opcode::NOTIFY()) {
send_answer = impl_->processNotify(io_message, message, buffer,
tsig_context);
} else if (message->getOpcode() != Opcode::QUERY()) {
} else if (message.getOpcode() != Opcode::QUERY()) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_UNSUPPORTED_OPCODE)
.arg(message->getOpcode().toText());
.arg(message.getOpcode().toText());
makeErrorMessage(message, buffer, Rcode::NOTIMP(), tsig_context);
} else if (message->getRRCount(Message::SECTION_QUESTION) != 1) {
} else if (message.getRRCount(Message::SECTION_QUESTION) != 1) {
makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
} else {
ConstQuestionPtr question = *message->beginQuestion();
ConstQuestionPtr question = *message.beginQuestion();
const RRType &qtype = question->getType();
if (qtype == RRType::AXFR()) {
send_answer = impl_->processXfrQuery(io_message, message,
......@@ -522,18 +521,18 @@ AuthSrv::processMessage(const IOMessage& io_message, MessagePtr message,
}
bool
AuthSrvImpl::processNormalQuery(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer,
AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message,
OutputBuffer& buffer,
auto_ptr<TSIGContext> tsig_context)
{
ConstEDNSPtr remote_edns = message->getEDNS();
ConstEDNSPtr remote_edns = message.getEDNS();
const bool dnssec_ok = remote_edns && remote_edns->getDNSSECAwareness();
const uint16_t remote_bufsize = remote_edns ? remote_edns->getUDPSize() :
Message::DEFAULT_MAX_UDPSIZE;
message->makeResponse();
message->setHeaderFlag(Message::HEADERFLAG_AA);
message->setRcode(Rcode::NOERROR());
message.makeResponse();
message.setHeaderFlag(Message::HEADERFLAG_AA);
message.setRcode(Rcode::NOERROR());
// Increment query counter.
incCounter(io_message.getSocket().getProtocol());
......@@ -542,20 +541,20 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, MessagePtr message,
EDNSPtr local_edns = EDNSPtr(new EDNS());
local_edns->setDNSSECAwareness(dnssec_ok);
local_edns->setUDPSize(AuthSrvImpl::DEFAULT_LOCAL_UDPSIZE);
message->setEDNS(local_edns);
message.setEDNS(local_edns);
}
try {
// If a memory data source is configured call the separate
// Query::process()
const ConstQuestionPtr question = *message->beginQuestion();
const ConstQuestionPtr question = *message.beginQuestion();
if (memory_client_ && memory_client_class_ == question->getClass()) {
const RRType& qtype = question->getType();
const Name& qname = question->getName();
auth::Query(*memory_client_, qname, qtype, *message,
auth::Query(*memory_client_, qname, qtype, message,
dnssec_ok).process();
} else {
datasrc::Query query(*message, cache_, dnssec_ok);
datasrc::Query query(message, cache_, dnssec_ok);
data_sources_.doQuery(query);
}
} catch (const Exception& ex) {
......@@ -565,25 +564,25 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, MessagePtr message,
}
MessageRenderer renderer;
renderer.setBuffer(buffer.get());
renderer.setBuffer(&buffer);
const bool udp_buffer =
(io_message.getSocket().getProtocol() == IPPROTO_UDP);
renderer.setLengthLimit(udp_buffer ? remote_bufsize : 65535);
if (tsig_context.get() != NULL) {
message->toWire(renderer, *tsig_context);
message.toWire(renderer, *tsig_context);
} else {
message->toWire(renderer);
message.toWire(renderer);
}
renderer.setBuffer(NULL);
LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_NORMAL_RESPONSE)
.arg(renderer.getLength()).arg(message->toText());
.arg(renderer.getLength()).arg(message);
return (true);
}
bool
AuthSrvImpl::processXfrQuery(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer,
AuthSrvImpl::processXfrQuery(const IOMessage& io_message, Message& message,
OutputBuffer& buffer,
auto_ptr<TSIGContext> tsig_context)
{
// Increment query counter.
......@@ -624,19 +623,19 @@ AuthSrvImpl::processXfrQuery(const IOMessage& io_message, MessagePtr message,
}
bool
AuthSrvImpl::processNotify(const IOMessage& io_message, MessagePtr message,
OutputBufferPtr buffer,
AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
OutputBuffer& buffer,
std::auto_ptr<TSIGContext> tsig_context)
{
// The incoming notify must contain exactly one question for SOA of the
// zone name.
if (message->getRRCount(Message::SECTION_QUESTION) != 1) {
if (message.getRRCount(Message::SECTION_QUESTION) != 1) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_NOTIFY_QUESTIONS)
.arg(message->getRRCount(Message::SECTION_QUESTION));
.arg(message.getRRCount(Message::SECTION_QUESTION));
makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
return (true);
}
ConstQuestionPtr question = *message->beginQuestion();
ConstQuestionPtr question = *message.beginQuestion();
if (question->getType() != RRType::SOA()) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_NOTIFY_RRTYPE)
.arg(question->getType().toText());
......@@ -691,16 +690,16 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, MessagePtr message,
return (false);
}
message->makeResponse();
message->setHeaderFlag(Message::HEADERFLAG_AA);
message->setRcode(Rcode::NOERROR());
message.makeResponse();
message.setHeaderFlag(Message::HEADERFLAG_AA);
message.setRcode(Rcode::NOERROR());
MessageRenderer renderer;
renderer.setBuffer(buffer.get());
renderer.setBuffer(&buffer);
if (tsig_context.get() != NULL) {
message->toWire(renderer, *tsig_context);
message.toWire(renderer, *tsig_context);
} else {
message->toWire(renderer);
message.toWire(renderer);
}
renderer.setBuffer(NULL);
return (true);
......@@ -786,9 +785,9 @@ AuthSrvImpl::setDbFile(ConstElementPtr config) {
}
void
AuthSrvImpl::resumeServer(DNSServer* server, MessagePtr message, bool done) {
AuthSrvImpl::resumeServer(DNSServer* server, Message& message, bool done) {
if (done) {
counters_.inc(message->getRcode());
counters_.inc(message.getRcode());
}
server->resume(done);
}
......
......@@ -115,14 +115,14 @@ public:
/// send the reply.
///
/// \param io_message The raw message received
/// \param message Pointer to the \c Message object
/// \param buffer Pointer to an \c OutputBuffer for the resposne
/// \param message the \c Message object
/// \param buffer an \c OutputBuffer for the resposne
/// \param server Pointer to the \c DNSServer
///
/// \throw isc::Unexpected Protocol type of \a message is unexpected
void processMessage(const isc::asiolink::IOMessage& io_message,
isc::dns::MessagePtr message,
isc::util::OutputBufferPtr buffer,
isc::dns::Message& message,
isc::util::OutputBuffer& buffer,
isc::asiodns::DNSServer* server);
/// \brief Updates the data source for the \c AuthSrv object.
......
......@@ -76,8 +76,8 @@ private:
typedef boost::shared_ptr<const IOEndpoint> IOEndpointPtr;
protected:
QueryBenchMark(const bool enable_cache,
const BenchQueries& queries, MessagePtr query_message,
OutputBufferPtr buffer) :
const BenchQueries& queries, Message& query_message,
OutputBuffer& buffer) :
server_(new AuthSrv(enable_cache, xfrout_client)),
queries_(queries),
query_message_(query_message),
......@@ -95,8 +95,8 @@ public:
for (query = queries_.begin(); query != query_end; ++query) {
IOMessage io_message(&(*query)[0], (*query).size(), dummy_socket,
*dummy_endpoint);
query_message_->clear(Message::PARSE);
buffer_->clear();
query_message_.clear(Message::PARSE);
buffer_.clear();
server_->processMessage(io_message, query_message_, buffer_,
&server);
}
......@@ -107,8 +107,8 @@ protected:
AuthSrvPtr server_;
private:
const BenchQueries& queries_;
MessagePtr query_message_;
OutputBufferPtr buffer_;
Message& query_message_;
OutputBuffer& buffer_;
IOSocket& dummy_socket;
IOEndpointPtr dummy_endpoint;
};
......@@ -118,8 +118,8 @@ public:
Sqlite3QueryBenchMark(const int cache_slots,
const char* const datasrc_file,
const BenchQueries& queries,
MessagePtr query_message,
OutputBufferPtr buffer) :
Message& query_message,
OutputBuffer& buffer) :
QueryBenchMark(cache_slots >= 0 ? true : false, queries,
query_message, buffer)
{
......@@ -136,8 +136,8 @@ public:
MemoryQueryBenchMark(const char* const zone_file,
const char* const zone_origin,
const BenchQueries& queries,
MessagePtr query_message,
OutputBufferPtr buffer) :
Message& query_message,
OutputBuffer& buffer) :
QueryBenchMark(false, queries, query_message, buffer)
{
configureAuthServer(*server_,
......@@ -255,8 +255,8 @@ main(int argc, char* argv[]) {
BenchQueries queries;
loadQueryData(query_data_file, queries, RRClass::IN());
OutputBufferPtr buffer(new OutputBuffer(4096));
MessagePtr message(new Message(Message::PARSE));
OutputBuffer buffer(4096);
Message message(Message::PARSE);
cout << "Parameters:" << endl;
cout << " Iterations: " << iteration << endl;
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment