Commit da4d1f0f authored by JINMEI Tatuya's avatar JINMEI Tatuya
Browse files

[master] Merge branch 'trac1783'

parents 3dfd3b25 e0cee375
......@@ -78,6 +78,35 @@ using namespace isc::asiolink;
using namespace isc::asiodns;
using namespace isc::server_common::portconfig;
namespace {
// A helper class for cleaning up message renderer.
//
// A temporary object of this class is expected to be created before starting
// response message rendering. On construction, it (re)initialize the given
// message renderer with the given buffer. On destruction, it releases
// the previously set buffer and then release any internal resource in the
// renderer, no matter what happened during the rendering, especially even
// when it resulted in an exception.
//
// Note: if we need this helper in many other places we might consider making
// it visible to other modules. As of this implementation this is the only
// user of this class, so we hide it within the implementation.
class RendererHolder {
public:
RendererHolder(MessageRenderer& renderer, OutputBuffer* buffer) :
renderer_(renderer)
{
renderer.setBuffer(buffer);
}
~RendererHolder() {
renderer_.setBuffer(NULL);
renderer_.clear();
}
private:
MessageRenderer& renderer_;
};
}
class AuthSrvImpl {
private:
// prohibit copy
......@@ -277,8 +306,8 @@ public:
};
void
makeErrorMessage(Message& message, OutputBuffer& buffer,
const Rcode& rcode,
makeErrorMessage(MessageRenderer& renderer, Message& message,
OutputBuffer& buffer, const Rcode& rcode,
std::auto_ptr<TSIGContext> tsig_context =
std::auto_ptr<TSIGContext>())
{
......@@ -311,14 +340,12 @@ makeErrorMessage(Message& message, OutputBuffer& buffer,
message.setRcode(rcode);
MessageRenderer renderer;
renderer.setBuffer(&buffer);
RendererHolder holder(renderer, &buffer);
if (tsig_context.get() != NULL) {
message.toWire(renderer, *tsig_context);
} else {
message.toWire(renderer);
}
renderer.setBuffer(NULL);
LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_ERROR_RESPONSE)
.arg(renderer.getLength()).arg(message);
}
......@@ -447,13 +474,13 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
} catch (const DNSProtocolError& error) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_PACKET_PROTOCOL_ERROR)
.arg(error.getRcode().toText()).arg(error.what());
makeErrorMessage(message, buffer, error.getRcode());
makeErrorMessage(impl_->renderer_, message, buffer, error.getRcode());
impl_->resumeServer(server, message, true);
return;
} catch (const Exception& ex) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_PACKET_PARSE_ERROR)
.arg(ex.what());
makeErrorMessage(message, buffer, Rcode::SERVFAIL());
makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL());
impl_->resumeServer(server, message, true);
return;
} // other exceptions will be handled at a higher layer.
......@@ -480,7 +507,8 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
}
if (tsig_error != TSIGError::NOERROR()) {
makeErrorMessage(message, buffer, tsig_error.toRcode(), tsig_context);
makeErrorMessage(impl_->renderer_, message, buffer,
tsig_error.toRcode(), tsig_context);
impl_->resumeServer(server, message, true);
return;
}
......@@ -497,9 +525,11 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
} else if (message.getOpcode() != Opcode::QUERY()) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_UNSUPPORTED_OPCODE)
.arg(message.getOpcode().toText());
makeErrorMessage(message, buffer, Rcode::NOTIMP(), tsig_context);
makeErrorMessage(impl_->renderer_, message, buffer,
Rcode::NOTIMP(), tsig_context);
} else if (message.getRRCount(Message::SECTION_QUESTION) != 1) {
makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
makeErrorMessage(impl_->renderer_, message, buffer,
Rcode::FORMERR(), tsig_context);
} else {
ConstQuestionPtr question = *message.beginQuestion();
const RRType &qtype = question->getType();
......@@ -517,10 +547,10 @@ AuthSrv::processMessage(const IOMessage& io_message, Message& message,
} catch (const std::exception& ex) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_RESPONSE_FAILURE)
.arg(ex.what());
makeErrorMessage(message, buffer, Rcode::SERVFAIL());
makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL());
} catch (...) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_RESPONSE_FAILURE_UNKNOWN);
makeErrorMessage(message, buffer, Rcode::SERVFAIL());
makeErrorMessage(impl_->renderer_, message, buffer, Rcode::SERVFAIL());
}
impl_->resumeServer(server, message, send_answer);
}
......@@ -563,13 +593,11 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message,
}
} catch (const Exception& ex) {
LOG_ERROR(auth_logger, AUTH_PROCESS_FAIL).arg(ex.what());
makeErrorMessage(message, buffer, Rcode::SERVFAIL());
makeErrorMessage(renderer_, message, buffer, Rcode::SERVFAIL());
return (true);
}
renderer_.clear();
renderer_.setBuffer(&buffer);
RendererHolder holder(renderer_, &buffer);
const bool udp_buffer =
(io_message.getSocket().getProtocol() == IPPROTO_UDP);
renderer_.setLengthLimit(udp_buffer ? remote_bufsize : 65535);
......@@ -578,7 +606,6 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message,
} else {
message.toWire(renderer_);
}
renderer_.setBuffer(NULL);
LOG_DEBUG(auth_logger, DBG_AUTH_MESSAGES, AUTH_SEND_NORMAL_RESPONSE)
.arg(renderer_.getLength()).arg(message);
return (true);
......@@ -594,7 +621,8 @@ AuthSrvImpl::processXfrQuery(const IOMessage& io_message, Message& message,
if (io_message.getSocket().getProtocol() == IPPROTO_UDP) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_AXFR_UDP);
makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(),
tsig_context);
return (true);
}
......@@ -619,7 +647,8 @@ AuthSrvImpl::processXfrQuery(const IOMessage& io_message, Message& message,
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_AXFR_ERROR)
.arg(err.what());
makeErrorMessage(message, buffer, Rcode::SERVFAIL(), tsig_context);
makeErrorMessage(renderer_, message, buffer, Rcode::SERVFAIL(),
tsig_context);
return (true);
}
......@@ -636,14 +665,16 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
if (message.getRRCount(Message::SECTION_QUESTION) != 1) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_NOTIFY_QUESTIONS)
.arg(message.getRRCount(Message::SECTION_QUESTION));
makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(),
tsig_context);
return (true);
}
ConstQuestionPtr question = *message.beginQuestion();
if (question->getType() != RRType::SOA()) {
LOG_DEBUG(auth_logger, DBG_AUTH_DETAIL, AUTH_NOTIFY_RRTYPE)
.arg(question->getType().toText());
makeErrorMessage(message, buffer, Rcode::FORMERR(), tsig_context);
makeErrorMessage(renderer_, message, buffer, Rcode::FORMERR(),
tsig_context);
return (true);
}
......@@ -698,14 +729,12 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
message.setHeaderFlag(Message::HEADERFLAG_AA);
message.setRcode(Rcode::NOERROR());
renderer_.clear();
renderer_.setBuffer(&buffer);
RendererHolder holder(renderer_, &buffer);
if (tsig_context.get() != NULL) {
message.toWire(renderer_, *tsig_context);
} else {
message.toWire(renderer_);
}
renderer_.setBuffer(NULL);
return (true);
}
......
......@@ -1138,11 +1138,12 @@ checkThrow(ThrowWhen method, ThrowWhen throw_at, bool isc_exception) {
class FakeZoneFinder : public isc::datasrc::ZoneFinder {
public:
FakeZoneFinder(isc::datasrc::ZoneFinderPtr zone_finder,
ThrowWhen throw_when,
bool isc_exception) :
ThrowWhen throw_when, bool isc_exception,
ConstRRsetPtr fake_rrset) :
real_zone_finder_(zone_finder),
throw_when_(throw_when),
isc_exception_(isc_exception)
isc_exception_(isc_exception),
fake_rrset_(fake_rrset)
{}
virtual isc::dns::Name
......@@ -1162,7 +1163,18 @@ public:
const isc::dns::RRType& type,
isc::datasrc::ZoneFinder::FindOptions options)
{
using namespace isc::datasrc;
checkThrow(THROW_AT_FIND, throw_when_, isc_exception_);
// If faked RRset was specified on construction and it matches the
// query, return it instead of searching the real data source.
if (fake_rrset_ && fake_rrset_->getName() == name &&
fake_rrset_->getType() == type)
{
return (ZoneFinderContextPtr(new ZoneFinder::Context(
*this, options,
ResultContext(SUCCESS,
fake_rrset_))));
}
return (real_zone_finder_->find(name, type, options));
}
......@@ -1190,6 +1202,7 @@ private:
isc::datasrc::ZoneFinderPtr real_zone_finder_;
ThrowWhen throw_when_;
bool isc_exception_;
ConstRRsetPtr fake_rrset_;
};
/// \brief Proxy InMemoryClient that can throw exceptions at specified times
......@@ -1206,12 +1219,15 @@ public:
/// class or the related FakeZoneFinder)
/// \param isc_exception if true, throw isc::Exception, otherwise,
/// throw std::exception
/// \param fake_rrset If non NULL, it will be used as an answer to
/// find() for that name and type.
FakeInMemoryClient(AuthSrv::InMemoryClientPtr real_client,
ThrowWhen throw_when,
bool isc_exception) :
ThrowWhen throw_when, bool isc_exception,
ConstRRsetPtr fake_rrset = ConstRRsetPtr()) :
real_client_(real_client),
throw_when_(throw_when),
isc_exception_(isc_exception)
isc_exception_(isc_exception),
fake_rrset_(fake_rrset)
{}
/// \brief proxy call for findZone
......@@ -1226,14 +1242,16 @@ public:
const FindResult result = real_client_->findZone(name);
return (FindResult(result.code, isc::datasrc::ZoneFinderPtr(
new FakeZoneFinder(result.zone_finder,
throw_when_,
isc_exception_))));
throw_when_,
isc_exception_,
fake_rrset_))));
}
private:
AuthSrv::InMemoryClientPtr real_client_;
ThrowWhen throw_when_;
bool isc_exception_;
ConstRRsetPtr fake_rrset_;
};
} // end anonymous namespace for throwing proxy classes
......@@ -1248,9 +1266,7 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxy) {
AuthSrv::InMemoryClientPtr fake_client(
new FakeInMemoryClient(server.getInMemoryClient(rrclass),
THROW_NEVER,
false));
THROW_NEVER, false));
ASSERT_NE(AuthSrv::InMemoryClientPtr(), server.getInMemoryClient(rrclass));
server.setInMemoryClient(rrclass, fake_client);
......@@ -1267,9 +1283,11 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxy) {
// to throw in the given method
// If isc_exception is true, it will throw isc::Exception, otherwise
// it will throw std::exception
// If non null rrset is given, it will be passed to the proxy so it can
// return some faked response.
void
setupThrow(AuthSrv* server, const char *config, ThrowWhen throw_when,
bool isc_exception)
bool isc_exception, ConstRRsetPtr rrset = ConstRRsetPtr())
{
// Set real inmem client to proxy
updateConfig(server, config, true);
......@@ -1279,8 +1297,7 @@ setupThrow(AuthSrv* server, const char *config, ThrowWhen throw_when,
AuthSrv::InMemoryClientPtr fake_client(
new FakeInMemoryClient(
server->getInMemoryClient(isc::dns::RRClass::IN()),
throw_when,
isc_exception));
throw_when, isc_exception, rrset));
ASSERT_NE(AuthSrv::InMemoryClientPtr(),
server->getInMemoryClient(isc::dns::RRClass::IN()));
......@@ -1324,4 +1341,45 @@ TEST_F(AuthSrvTest, queryWithInMemoryClientProxyGetClass) {
opcode.getCode(), QR_FLAG | AA_FLAG, 1, 1, 2, 1);
}
TEST_F(AuthSrvTest, queryWithThrowingInToWire) {
// Set up a faked data source. It will return an empty RRset for the
// query.
ConstRRsetPtr empty_rrset(new RRset(Name("foo.example"),
RRClass::IN(), RRType::TXT(),
RRTTL(0)));
setupThrow(&server, CONFIG_INMEMORY_EXAMPLE, THROW_NEVER, true,
empty_rrset);
// Repeat the query processing two times. Due to the faked RRset,
// toWire() should throw, and it should result in SERVFAIL.
OutputBufferPtr orig_buffer;
for (int i = 0; i < 2; ++i) {
UnitTestUtil::createDNSSECRequestMessage(request_message, opcode,
default_qid,
Name("foo.example."),
RRClass::IN(), RRType::TXT());
createRequestPacket(request_message, IPPROTO_UDP);
server.processMessage(*io_message, *parse_message, *response_obuffer,
&dnsserv);
headerCheck(*parse_message, default_qid, Rcode::SERVFAIL(),
opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
// Make a backup of the original buffer for latest tests and replace
// it with a new one
if (!orig_buffer) {
orig_buffer = response_obuffer;
response_obuffer.reset(new OutputBuffer(0));
}
request_message.clear(Message::RENDER);
parse_message->clear(Message::PARSE);
}
// Now check if the original buffer is intact
parse_message->clear(Message::PARSE);
InputBuffer ibuffer(orig_buffer->getData(), orig_buffer->getLength());
parse_message->fromWire(ibuffer);
headerCheck(*parse_message, default_qid, Rcode::SERVFAIL(),
opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
}
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment