Commit 53d9f469 authored by JINMEI Tatuya's avatar JINMEI Tatuya
Browse files

[trac812next] implemented TSIG signing main part: add TSIGRecord::toWire() and...

[trac812next] implemented TSIG signing main part: add TSIGRecord::toWire() and have the Message class use it with a TSIGContext.
(There are some other small cleanups in this commit)
parent 2842dda9
......@@ -15,6 +15,7 @@
#include <stdint.h>
#include <algorithm>
#include <cassert>
#include <string>
#include <sstream>
#include <vector>
......@@ -40,6 +41,7 @@
#include <dns/rrtype.h>
#include <dns/rrttl.h>
#include <dns/rrset.h>
#include <dns/tsig.h>
using namespace std;
using namespace boost;
......@@ -123,6 +125,7 @@ public:
void setRcode(const Rcode& rcode);
int parseQuestion(InputBuffer& buffer);
int parseSection(const Message::Section section, InputBuffer& buffer);
void toWire(MessageRenderer& renderer, TSIGContext* tsig_ctx);
};
MessageImpl::MessageImpl(Message::Mode mode) :
......@@ -164,6 +167,139 @@ MessageImpl::setRcode(const Rcode& rcode) {
rcode_ = &rcode_placeholder_;
}
namespace {
template <typename T>
struct RenderSection {
RenderSection(MessageRenderer& renderer, const bool partial_ok) :
counter_(0), renderer_(renderer), partial_ok_(partial_ok),
truncated_(false)
{}
void operator()(const T& entry) {
// If it's already truncated, ignore the rest of the section.
if (truncated_) {
return;
}
const size_t pos0 = renderer_.getLength();
counter_ += entry->toWire(renderer_);
if (renderer_.isTruncated()) {
truncated_ = true;
if (!partial_ok_) {
// roll back to the end of the previous RRset.
renderer_.trim(renderer_.getLength() - pos0);
}
}
}
unsigned int getTotalCount() { return (counter_); }
unsigned int counter_;
MessageRenderer& renderer_;
const bool partial_ok_;
bool truncated_;
};
}
void
MessageImpl::toWire(MessageRenderer& renderer, TSIGContext* tsig_ctx) {
if (mode_ != Message::RENDER) {
isc_throw(InvalidMessageOperation,
"Message rendering attempted in non render mode");
}
if (rcode_ == NULL) {
isc_throw(InvalidMessageOperation,
"Message rendering attempted without Rcode set");
}
if (opcode_ == NULL) {
isc_throw(InvalidMessageOperation,
"Message rendering attempted without Opcode set");
}
// reserve room for the header
renderer.skip(HEADERLEN);
uint16_t qdcount =
for_each(questions_.begin(), questions_.end(),
RenderSection<QuestionPtr>(renderer, false)).getTotalCount();
// TBD: sort RRsets in each section based on configuration policy.
uint16_t ancount = 0;
if (!renderer.isTruncated()) {
ancount =
for_each(rrsets_[Message::SECTION_ANSWER].begin(),
rrsets_[Message::SECTION_ANSWER].end(),
RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
}
uint16_t nscount = 0;
if (!renderer.isTruncated()) {
nscount =
for_each(rrsets_[Message::SECTION_AUTHORITY].begin(),
rrsets_[Message::SECTION_AUTHORITY].end(),
RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
}
uint16_t arcount = 0;
if (renderer.isTruncated()) {
flags_ |= Message::HEADERFLAG_TC;
} else {
arcount =
for_each(rrsets_[Message::SECTION_ADDITIONAL].begin(),
rrsets_[Message::SECTION_ADDITIONAL].end(),
RenderSection<RRsetPtr>(renderer, false)).getTotalCount();
}
// Add EDNS OPT RR if necessary. Basically, we add it only when EDNS
// has been explicitly set. However, if the RCODE would require it and
// no EDNS has been set we generate a temporary local EDNS and use it.
if (!renderer.isTruncated()) {
ConstEDNSPtr local_edns = edns_;
if (!local_edns && rcode_->getExtendedCode() != 0) {
local_edns = ConstEDNSPtr(new EDNS());
}
if (local_edns) {
arcount += local_edns->toWire(renderer, rcode_->getExtendedCode());
}
}
// Adjust the counter buffer.
// XXX: these may not be equal to the number of corresponding entries
// in rrsets_[] or questions_ if truncation occurred or an EDNS OPT RR
// was inserted. This is not good, and we should revisit the entire
// design.
counts_[Message::SECTION_QUESTION] = qdcount;
counts_[Message::SECTION_ANSWER] = ancount;
counts_[Message::SECTION_AUTHORITY] = nscount;
counts_[Message::SECTION_ADDITIONAL] = arcount;
// fill in the header
size_t header_pos = 0;
renderer.writeUint16At(qid_, header_pos);
header_pos += sizeof(uint16_t);
uint16_t codes_and_flags =
(opcode_->getCode() << OPCODE_SHIFT) & OPCODE_MASK;
codes_and_flags |= (rcode_->getCode() & RCODE_MASK);
codes_and_flags |= (flags_ & HEADERFLAG_MASK);
renderer.writeUint16At(codes_and_flags, header_pos);
header_pos += sizeof(uint16_t);
// XXX: should avoid repeated pattern (TODO)
renderer.writeUint16At(qdcount, header_pos);
header_pos += sizeof(uint16_t);
renderer.writeUint16At(ancount, header_pos);
header_pos += sizeof(uint16_t);
renderer.writeUint16At(nscount, header_pos);
header_pos += sizeof(uint16_t);
renderer.writeUint16At(arcount, header_pos);
// Add TSIG, if necessary, at the end of the message.
// TBD: truncate case consideration
if (tsig_ctx != NULL) {
tsig_ctx->sign(qid_, renderer.getData(),
renderer.getLength())->toWire(renderer);
// update the ARCOUNT for the TSIG RR
++arcount;
assert(arcount != 0); // this should never happen for a sane message
renderer.writeUint16At(arcount, header_pos);
}
}
Message::Message(Mode mode) :
impl_(new MessageImpl(mode))
{}
......@@ -363,129 +499,14 @@ Message::addQuestion(const Question& question) {
addQuestion(QuestionPtr(new Question(question)));
}
namespace {
template <typename T>
struct RenderSection {
RenderSection(MessageRenderer& renderer, const bool partial_ok) :
counter_(0), renderer_(renderer), partial_ok_(partial_ok),
truncated_(false)
{}
void operator()(const T& entry) {
// If it's already truncated, ignore the rest of the section.
if (truncated_) {
return;
}
const size_t pos0 = renderer_.getLength();
counter_ += entry->toWire(renderer_);
if (renderer_.isTruncated()) {
truncated_ = true;
if (!partial_ok_) {
// roll back to the end of the previous RRset.
renderer_.trim(renderer_.getLength() - pos0);
}
}
}
unsigned int getTotalCount() { return (counter_); }
unsigned int counter_;
MessageRenderer& renderer_;
const bool partial_ok_;
bool truncated_;
};
}
void
Message::toWire(MessageRenderer& renderer) {
if (impl_->mode_ != Message::RENDER) {
isc_throw(InvalidMessageOperation,
"Message rendering attempted in non render mode");
}
if (impl_->rcode_ == NULL) {
isc_throw(InvalidMessageOperation,
"Message rendering attempted without Rcode set");
}
if (impl_->opcode_ == NULL) {
isc_throw(InvalidMessageOperation,
"Message rendering attempted without Opcode set");
}
// reserve room for the header
renderer.skip(HEADERLEN);
uint16_t qdcount =
for_each(impl_->questions_.begin(), impl_->questions_.end(),
RenderSection<QuestionPtr>(renderer, false)).getTotalCount();
// TBD: sort RRsets in each section based on configuration policy.
uint16_t ancount = 0;
if (!renderer.isTruncated()) {
ancount =
for_each(impl_->rrsets_[SECTION_ANSWER].begin(),
impl_->rrsets_[SECTION_ANSWER].end(),
RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
}
uint16_t nscount = 0;
if (!renderer.isTruncated()) {
nscount =
for_each(impl_->rrsets_[SECTION_AUTHORITY].begin(),
impl_->rrsets_[SECTION_AUTHORITY].end(),
RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
}
uint16_t arcount = 0;
if (renderer.isTruncated()) {
setHeaderFlag(HEADERFLAG_TC, true);
} else {
arcount =
for_each(impl_->rrsets_[SECTION_ADDITIONAL].begin(),
impl_->rrsets_[SECTION_ADDITIONAL].end(),
RenderSection<RRsetPtr>(renderer, false)).getTotalCount();
}
// Add EDNS OPT RR if necessary. Basically, we add it only when EDNS
// has been explicitly set. However, if the RCODE would require it and
// no EDNS has been set we generate a temporary local EDNS and use it.
if (!renderer.isTruncated()) {
ConstEDNSPtr local_edns = impl_->edns_;
if (!local_edns && impl_->rcode_->getExtendedCode() != 0) {
local_edns = ConstEDNSPtr(new EDNS());
}
if (local_edns) {
arcount += local_edns->toWire(renderer,
impl_->rcode_->getExtendedCode());
}
}
// Adjust the counter buffer.
// XXX: these may not be equal to the number of corresponding entries
// in rrsets_[] or questions_ if truncation occurred or an EDNS OPT RR
// was inserted. This is not good, and we should revisit the entire
// design.
impl_->counts_[SECTION_QUESTION] = qdcount;
impl_->counts_[SECTION_ANSWER] = ancount;
impl_->counts_[SECTION_AUTHORITY] = nscount;
impl_->counts_[SECTION_ADDITIONAL] = arcount;
// TBD: TSIG, SIG(0) etc.
// fill in the header
size_t header_pos = 0;
renderer.writeUint16At(impl_->qid_, header_pos);
header_pos += sizeof(uint16_t);
impl_->toWire(renderer, NULL);
}
uint16_t codes_and_flags =
(impl_->opcode_->getCode() << OPCODE_SHIFT) & OPCODE_MASK;
codes_and_flags |= (impl_->rcode_->getCode() & RCODE_MASK);
codes_and_flags |= (impl_->flags_ & HEADERFLAG_MASK);
renderer.writeUint16At(codes_and_flags, header_pos);
header_pos += sizeof(uint16_t);
// XXX: should avoid repeated pattern (TODO)
renderer.writeUint16At(qdcount, header_pos);
header_pos += sizeof(uint16_t);
renderer.writeUint16At(ancount, header_pos);
header_pos += sizeof(uint16_t);
renderer.writeUint16At(nscount, header_pos);
header_pos += sizeof(uint16_t);
renderer.writeUint16At(arcount, header_pos);
header_pos += sizeof(uint16_t);
void
Message::toWire(MessageRenderer& renderer, TSIGContext& tsig_ctx) {
impl_->toWire(renderer, &tsig_ctx);
}
void
......
......@@ -33,6 +33,7 @@ class InputBuffer;
}
namespace dns {
class TSIGContext;
///
/// \brief A standard DNS module exception that is thrown if a wire format
......@@ -531,6 +532,9 @@ public:
/// class \c InvalidMessageOperation will be thrown.
void toWire(MessageRenderer& renderer);
// TBD
void toWire(MessageRenderer& renderer, TSIGContext& tsig_ctx);
/// \brief Parse the header section of the \c Message.
void parseHeader(isc::util::InputBuffer& buffer);
......
......@@ -24,7 +24,7 @@
#include <dns/messagerenderer.h>
#include <dns/rdata.h>
#include <dns/rdataclass.h>
#include <dns/tsigerror.h>
using namespace std;
using namespace boost;
......@@ -313,15 +313,7 @@ TSIG::toText() const {
result += encodeBase64(impl_->mac_) + " ";
}
result += lexical_cast<string>(impl_->original_id_) + " ";
if (impl_->error_ == 16) { // XXX: we'll soon introduce generic converter.
result += "BADSIG ";
} else if (impl_->error_ == 17) {
result += "BADKEY ";
} else if (impl_->error_ == 18) {
result += "BADTIME ";
} else {
result += lexical_cast<string>(impl_->error_) + " ";
}
result += TSIGError(impl_->error_).toText() + " ";
result += lexical_cast<string>(impl_->other_data_.size());
if (impl_->other_data_.size() > 0) {
result += " " + encodeBase64(impl_->other_data_);
......
......@@ -50,6 +50,7 @@ run_unittests_SOURCES += message_unittest.cc
run_unittests_SOURCES += tsig_unittest.cc
run_unittests_SOURCES += tsigerror_unittest.cc
run_unittests_SOURCES += tsigkey_unittest.cc
run_unittests_SOURCES += tsigrecord_unittest.cc
run_unittests_SOURCES += run_unittests.cc
run_unittests_CPPFLAGS = $(AM_CPPFLAGS) $(GTEST_INCLUDES)
run_unittests_LDFLAGS = $(AM_LDFLAGS) $(GTEST_LDFLAGS)
......
......@@ -12,6 +12,8 @@
// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
// PERFORMANCE OF THIS SOFTWARE.
#include <boost/scoped_ptr.hpp>
#include <exceptions/exceptions.h>
#include <util/buffer.h>
......@@ -26,6 +28,8 @@
#include <dns/rrclass.h>
#include <dns/rrttl.h>
#include <dns/rrtype.h>
#include <dns/tsig.h>
#include <dns/tsigkey.h>
#include <gtest/gtest.h>
......@@ -53,6 +57,17 @@ using namespace isc::dns::rdata;
const uint16_t Message::DEFAULT_MAX_UDPSIZE;
const Name test_name("test.example.com");
// See dnssectime.cc
namespace isc {
namespace dns {
namespace tsig {
namespace detail {
extern int64_t (*gettimeFunction)();
}
}
}
}
namespace {
class MessageTest : public ::testing::Test {
protected:
......@@ -60,7 +75,9 @@ protected:
message_parse(Message::PARSE),
message_render(Message::RENDER),
bogus_section(static_cast<Message::Section>(
Message::SECTION_ADDITIONAL + 1))
Message::SECTION_ADDITIONAL + 1)),
tsig_ctx(TSIGKey("www.example.com:"
"SFuWd/q99SzF8Yzd1QbB9g=="))
{
rrset_a = RRsetPtr(new RRset(test_name, RRClass::IN(),
RRType::A(), RRTTL(3600)));
......@@ -88,6 +105,9 @@ protected:
RRsetPtr rrset_a; // A RRset with two RDATAs
RRsetPtr rrset_aaaa; // AAAA RRset with one RDATA with RRSIG
RRsetPtr rrset_rrsig; // RRSIG for the AAAA RRset
TSIGContext tsig_ctx;
vector<unsigned char> expected_data;
static void factoryFromFile(Message& message, const char* datafile);
};
......@@ -519,6 +539,65 @@ TEST_F(MessageTest, toWireInParseMode) {
EXPECT_THROW(message_parse.toWire(renderer), InvalidMessageOperation);
}
// See dnssectime_unittest.cc
template <int64_t NOW>
int64_t
testGetTime() {
return (NOW);
}
void
commonTSIGToWireCheck(Message& message, MessageRenderer& renderer,
TSIGContext& tsig_ctx, const char* const expected_file)
{
message.setOpcode(Opcode::QUERY());
message.setRcode(Rcode::NOERROR());
message.setHeaderFlag(Message::HEADERFLAG_RD, true);
message.addQuestion(Question(Name("www.example.com"), RRClass::IN(),
RRType::A()));
message.toWire(renderer, tsig_ctx);
vector<unsigned char> expected_data;
UnitTestUtil::readWireData(expected_file, expected_data);
EXPECT_PRED_FORMAT4(UnitTestUtil::matchWireData, renderer.getData(),
renderer.getLength(),
&expected_data[0], expected_data.size());
}
TEST_F(MessageTest, toWireWithTSIG) {
// Rendering a message with TSIG. Various special cases specific to
// TSIG are tested in the tsig tests. We only check the message contains
// a TSIG at the end and the ARCOUNT of the header is updated.
tsig::detail::gettimeFunction = testGetTime<0x4da8877a>;
message_render.setQid(0x2d65);
{
SCOPED_TRACE("Message sign with TSIG");
commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
"message_toWire2.wire");
}
}
TEST_F(MessageTest, toWireWithEDNSAndTSIG) {
// Similar to the previous test, but with an EDNS before TSIG.
// The wire data check will confirm the ordering.
tsig::detail::gettimeFunction = testGetTime<0x4db60d1f>;
message_render.setQid(0x6cd);
EDNSPtr edns(new EDNS());
edns->setUDPSize(4096);
message_render.setEDNS(edns);
{
SCOPED_TRACE("Message sign with TSIG and EDNS");
commonTSIGToWireCheck(message_render, renderer, tsig_ctx,
"message_toWire3.wire");
}
}
TEST_F(MessageTest, toWireWithoutOpcode) {
message_render.setRcode(Rcode::NOERROR());
EXPECT_THROW(message_render.toWire(renderer), InvalidMessageOperation);
......
......@@ -3,6 +3,7 @@ CLEANFILES = *.wire
BUILT_SOURCES = edns_toWire1.wire edns_toWire2.wire edns_toWire3.wire
BUILT_SOURCES += edns_toWire4.wire
BUILT_SOURCES += message_fromWire10.wire message_fromWire11.wire
BUILT_SOURCES += message_toWire2.wire message_toWire3.wire
BUILT_SOURCES += name_toWire5.wire name_toWire6.wire
BUILT_SOURCES += rdatafields1.wire rdatafields2.wire rdatafields3.wire
BUILT_SOURCES += rdatafields4.wire rdatafields5.wire rdatafields6.wire
......@@ -33,6 +34,7 @@ BUILT_SOURCES += rdata_tsig_fromWire9.wire
BUILT_SOURCES += rdata_tsig_toWire1.wire rdata_tsig_toWire2.wire
BUILT_SOURCES += rdata_tsig_toWire3.wire rdata_tsig_toWire4.wire
BUILT_SOURCES += rdata_tsig_toWire5.wire
BUILT_SOURCES += tsigrecord_toWire1.wire tsigrecord_toWire2.wire
# NOTE: keep this in sync with real file listing
# so is included in tarball
......@@ -46,7 +48,7 @@ EXTRA_DIST += message_fromWire5 message_fromWire6
EXTRA_DIST += message_fromWire7 message_fromWire8
EXTRA_DIST += message_fromWire9 message_fromWire10.spec
EXTRA_DIST += message_fromWire11.spec
EXTRA_DIST += message_toWire1
EXTRA_DIST += message_toWire1 message_toWire2.spec message_toWire3.spec
EXTRA_DIST += name_fromWire1 name_fromWire2 name_fromWire3_1 name_fromWire3_2
EXTRA_DIST += name_fromWire4 name_fromWire6 name_fromWire7 name_fromWire8
EXTRA_DIST += name_fromWire9 name_fromWire10 name_fromWire11 name_fromWire12
......@@ -66,7 +68,8 @@ EXTRA_DIST += rdata_nsec_fromWire6.spec rdata_nsec_fromWire7.spec
EXTRA_DIST += rdata_nsec_fromWire8.spec rdata_nsec_fromWire9.spec
EXTRA_DIST += rdata_nsec_fromWire10.spec
EXTRA_DIST += rdata_nsec3param_fromWire1
EXTRA_DIST += rdata_nsec3_fromWire1 rdata_nsec3_fromWire3
EXTRA_DIST += rdata_nsec3_fromWire1
EXTRA_DIST += rdata_nsec3_fromWire2.spec rdata_nsec3_fromWire3
EXTRA_DIST += rdata_nsec3_fromWire4.spec rdata_nsec3_fromWire5.spec
EXTRA_DIST += rdata_nsec3_fromWire6.spec rdata_nsec3_fromWire7.spec
EXTRA_DIST += rdata_nsec3_fromWire8.spec rdata_nsec3_fromWire9.spec
......@@ -94,7 +97,7 @@ EXTRA_DIST += rdata_tsig_fromWire9.spec
EXTRA_DIST += rdata_tsig_toWire1.spec rdata_tsig_toWire2.spec
EXTRA_DIST += rdata_tsig_toWire3.spec rdata_tsig_toWire4.spec
EXTRA_DIST += rdata_tsig_toWire5.spec
EXTRA_DIST += rdata_nsec3_fromWire2.spec
EXTRA_DIST += tsigrecord_toWire1.spec tsigrecord_toWire2.spec
.spec.wire:
./gen-wiredata.py -o $@ $<
......@@ -433,6 +433,11 @@ class RRSIG:
f.write('%04x %s %s\n' % (self.tag, name_wire, sig_wire))
class TSIG:
as_rr = False
rr_name = 'example.com' # only when as_rr is True, same for class/TTL
rr_class = parse_value('ANY', dict_rrclass)
rr_ttl = 0
rdlen = None # auto-calculate
algorithm = 'hmac-sha256'
time_signed = 1286978795 # arbitrarily chosen default
......@@ -471,8 +476,16 @@ class TSIG:
if rdlen is None:
rdlen = int(len(name_wire) / 2 + 16 + len(mac) / 2 + \
len(other_data) / 2)
f.write('\n# TSIG RDATA (RDLEN=%d)\n' % rdlen)
f.write('%04x\n' % rdlen);
if self.as_rr:
f.write('\n# TSIG RR (QNAME=%s Class=%s TTL=%d RDLEN=%d)\n' %
(self.rr_name, rdict_rrclass[self.rr_class],
self.rr_ttl, rdlen))
f.write('%s %04x %04x %08x %04x\n' %
(encode_name(self.rr_name), dict_rrtype['tsig'],
self.rr_class, self.rr_ttl, rdlen))
else:
f.write('\n# TSIG RDATA (RDLEN=%d)\n' % rdlen)
f.write('%04x\n' % rdlen);
f.write('# Algorithm=%s Time-Signed=%d Fudge=%d\n' %
(self.algorithm, self.time_signed, self.fudge))
f.write('%s %012x %04x\n' % (name_wire, self.time_signed, self.fudge))
......
#
# A simple DNS response message with TSIG signed
# A simple TSIG RR (some of the parameters are taken from a live example
# and don't have a specific meaning)
#
[custom]
......
#
# TSIG RR after some names that could (unexpectedly) cause name compression
#
[custom]
sections: name/1:name/2:tsig
[name/1]
name: hmac-md5.sig-alg.reg.int
[name/2]
name: foo.example.com
[tsig]
as_rr: True
# TSIG QNAME won't be compressed
rr_name: www.example.com
algorithm: hmac-md5
time_signed: 0x4da8877a
mac_size: 16
mac: 0xdadadadadadadadadadadadadadadada
original_id: 0x2d65
......@@ -118,8 +118,9 @@ TSIGContext::sign(const uint16_t qid, const void* const data,
// specified in Section 4.3 of RFC2845.
if (error == TSIGError::BAD_SIG() || error == TSIGError::BAD_KEY()) {
ConstTSIGRecordPtr tsig(new TSIGRecord(
impl_->key_.getKeyName(),
any::TSIG(impl_->key_.getAlgorithmName(),
now, DEFAULT_FUDGE, NULL, 0,
now, DEFAULT_FUDGE, 0, NULL,
qid, error.getCode(), 0, NULL)));
impl_->previous_digest_.clear();
impl_->state_ = SIGNED;
......@@ -187,6 +188,7 @@ TSIGContext::sign(const uint16_t qid, const void* const data,
// Get the final digest, update internal state, then finish.
vector<uint8_t> digest = hmac->sign();
ConstTSIGRecordPtr tsig(new TSIGRecord(
impl_->key_.getKeyName(),
any::TSIG(impl_->key_.getAlgorithmName(),
time_signed, DEFAULT_FUDGE,
digest.size(), &digest[0],
......
......@@ -12,16 +12,91 @@
// OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
// PERFORMANCE OF THIS SOFTWARE.
#include <ostream>
#include <string>
#include <util/buffer.h>
#include <dns/messagerenderer.h>
#include <dns/rrclass.h>
#include <dns/rrttl.h>
#include <dns/tsigrecord.h>
using namespace isc::util;
namespace {
// Internally used constants:
// Size in octets for the RR type, class TTL fields.
const size_t RR_COMMON_LEN = 8;
// Size in octets for the fixed part of TSIG RDATAs.
// - Time Signed (6)