Commit 67801f23 authored by Evan Hunt's avatar Evan Hunt
Browse files

checkpoint:

 - refactored NSEC support
 - fixed bug in which NSEC records could be duplicated
 - added NSEC3 code -- please note NSEC3 is COMPLETELY untested


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1145 e5f2f494-b856-4b98-b285-d166d9295462
parent 4ef47c4c
......@@ -249,6 +249,149 @@ hasDelegation(const DataSrc* ds, const Name* zonename, Query& q,
return (false);
}
static inline DataSrc::Result
addSOA(Query& q, const Name* zonename, const DataSrc* ds) {
Message& m = q.message();
DataSrc::Result result;
RRsetList soa;
QueryTask newtask(*zonename, q.qclass(), RRType::SOA(),
QueryTask::SIMPLE_QUERY);
result = doQueryTask(ds, zonename, q, newtask, soa);
if (result != DataSrc::SUCCESS || newtask.flags != 0) {
return (DataSrc::ERROR);
}
m.addRRset(Section::AUTHORITY(), soa[RRType::SOA()], q.wantDnssec());
return (DataSrc::SUCCESS);
}
static inline DataSrc::Result
addNSEC(Query& q, const QueryTaskPtr task, const Name& name,
const Name& zonename, const DataSrc* ds)
{
RRsetList nsec;
Message& m = q.message();
DataSrc::Result result;
QueryTask newtask(name, task->qclass, RRType::NSEC(),
QueryTask::SIMPLE_QUERY);
result = doQueryTask(ds, &zonename, q, newtask, nsec);
if (result != DataSrc::SUCCESS) {
return (DataSrc::ERROR);
}
if (newtask.flags == 0) {
m.addRRset(Section::AUTHORITY(), nsec[RRType::NSEC()], true);
}
return (DataSrc::SUCCESS);
}
static inline DataSrc::Result
addNSEC3(const string& hash, Query& q, const DataSrc* ds, const Name& zonename)
{
RRsetList nsec3;
Message& m = q.message();
DataSrc::Result result;
result = ds->findCoveringNSEC3(q, hash, zonename, nsec3);
if (result != DataSrc::SUCCESS) {
return (DataSrc::ERROR);
}
m.addRRset(Section::AUTHORITY(), nsec3[RRType::NSEC3()], true);
return (DataSrc::SUCCESS);
}
static Nsec3Param*
getNsec3Param(Query& q, const DataSrc* ds, const Name& zonename)
{
DataSrc::Result result;
RRsetList nsec3param;
QueryTask newtask(zonename, q.qclass(), RRType::NSEC3PARAM(),
QueryTask::SIMPLE_QUERY);
result = doQueryTask(ds, &zonename, q, newtask, nsec3param);
newtask.flags &= ~DataSrc::REFERRAL;
if (result != DataSrc::SUCCESS || newtask.flags != 0) {
return (NULL);
}
RRsetPtr rrset = nsec3param[RRType::NSEC3PARAM()];
if (!rrset) {
return (NULL);
}
// XXX: currently only one NSEC3 chain per zone is supported;
// we will need to revisit this.
RdataIteratorPtr it = rrset->getRdataIterator();
it->first();
if (it->isLast()) {
return (NULL);
}
const generic::NSEC3PARAM& np =
dynamic_cast<const generic::NSEC3PARAM&>(it->getCurrent());
return (new Nsec3Param(np.getHashalg(), np.getFlags(),
np.getIterations(), np.getSalt()));
}
static inline DataSrc::Result
proveNX(Query& q, QueryTaskPtr task, const DataSrc* ds, const Name& zonename)
{
DataSrc::Result result;
Nsec3Param* nsec3 = getNsec3Param(q, ds, zonename);
if (nsec3) {
string node = nsec3->getHash(task->qname);
string apex = nsec3->getHash(zonename);
string wild = nsec3->getHash(Name("*").concatenate(zonename));
delete nsec3;
result = addNSEC3(node, q, ds, zonename);
if (result != DataSrc::SUCCESS) {
return (result);
}
if (node != apex) {
result = addNSEC3(apex, q, ds, zonename);
if (result != DataSrc::SUCCESS) {
return (result);
}
}
if ((task->flags & DataSrc::NAME_NOT_FOUND) != 0 && node != wild) {
result = addNSEC3(wild, q, ds, zonename);
if (result != DataSrc::SUCCESS) {
return (result);
}
}
} else {
Name nsecname(task->qname);
if ((task->flags & DataSrc::NAME_NOT_FOUND) != 0) {
ds->findPreviousName(q, task->qname, nsecname, &zonename);
}
result = addNSEC(q, task, nsecname, zonename, ds);
if (result != DataSrc::SUCCESS) {
return (result);
}
if ((task->flags & DataSrc::TYPE_NOT_FOUND) != 0 ||
nsecname == zonename)
{
return (DataSrc::SUCCESS);
}
result = addNSEC(q, task, zonename, zonename, ds);
if (result != DataSrc::SUCCESS) {
return (result);
}
}
return (DataSrc::SUCCESS);
}
// Attempt a wildcard lookup
static inline DataSrc::Result
tryWildcard(Query& q, QueryTaskPtr task, const DataSrc* ds,
......@@ -312,19 +455,6 @@ tryWildcard(Query& q, QueryTaskPtr task, const DataSrc* ds,
copyAuth(q, auth);
}
} else if (q.wantDnssec()) {
// No wildcard found; add an NSEC to prove it
RRsetList nsec;
QueryTask newtask(*zonename, task->qclass, RRType::NSEC(),
QueryTask::SIMPLE_QUERY);
result = doQueryTask(ds, zonename, q, newtask, nsec);
if (result != DataSrc::SUCCESS) {
return (DataSrc::ERROR);
}
if (newtask.flags == 0) {
m.addRRset(Section::AUTHORITY(), nsec[RRType::NSEC()], true);
}
}
return (DataSrc::SUCCESS);
......@@ -510,44 +640,30 @@ DataSrc::doQuery(Query& q)
// NXDOMAIN, and also add the previous NSEC to the authority
// section. For TYPE_NOT_FOUND, do not set an error rcode,
// and send the current NSEC in the authority section.
Name nsecname(task->qname);
if ((task->flags & NAME_NOT_FOUND) != 0) {
datasource->findPreviousName(q, task->qname, nsecname,
zonename);
}
if (task->state == QueryTask::GETANSWER) {
if ((task->flags & NAME_NOT_FOUND) != 0) {
m.setRcode(Rcode::NXDOMAIN());
}
RRsetList soa;
QueryTask newtask(*zonename, task->qclass, RRType::SOA(),
QueryTask::SIMPLE_QUERY);
result = doQueryTask(datasource, zonename, q, newtask, soa);
if (result != SUCCESS || newtask.flags != 0) {
result = addSOA(q, zonename, datasource);
if (result != SUCCESS) {
m.setRcode(Rcode::SERVFAIL());
return;
}
}
m.addRRset(Section::AUTHORITY(), soa[RRType::SOA()],
q.wantDnssec());
Name nsecname(task->qname);
if ((task->flags & NAME_NOT_FOUND) != 0) {
datasource->findPreviousName(q, task->qname, nsecname,
zonename);
}
if (q.wantDnssec()) {
RRsetList nsec;
QueryTask newtask(nsecname, task->qclass,
RRType::NSEC(), QueryTask::SIMPLE_QUERY);
result = doQueryTask(datasource, zonename, q, newtask, nsec);
if (result != SUCCESS) {
result = proveNX(q, task, datasource, *zonename);
if (result != DataSrc::SUCCESS) {
m.setRcode(Rcode::SERVFAIL());
return;
}
if (newtask.flags == 0) {
m.addRRset(Section::AUTHORITY(), nsec[RRType::NSEC()],
true);
}
}
return;
......@@ -694,7 +810,7 @@ NameMatch::update(const DataSrc& new_source, const Name& container)
}
Nsec3Param::Nsec3Param(uint8_t a, uint8_t f, uint16_t i,
std::vector<uint8_t>& s) :
const std::vector<uint8_t>& s) :
algorithm(a), flags(f), iterations(i), salt(s)
{}
......
......@@ -145,8 +145,7 @@ public:
// This MUST be implemented by concrete data sources which support
// NSEC3, but is optional for others
virtual Result findCoveringNSEC3(const Query& q,
const Nsec3Param& param,
const isc::dns::Name& qname,
const std::string& hash,
const isc::dns::Name& zonename,
isc::dns::RRsetList& target) const = 0;
};
......@@ -215,8 +214,7 @@ public:
const isc::dns::Name* zonename) const = 0;
virtual Result findCoveringNSEC3(const Query& q,
const Nsec3Param& param,
const isc::dns::Name& qname,
const std::string& hash,
const isc::dns::Name& zonename,
isc::dns::RRsetList& target) const = 0;
......@@ -294,8 +292,7 @@ public:
}
virtual Result findCoveringNSEC3(const Query& q,
const Nsec3Param& param,
const isc::dns::Name& qname,
const std::string& qname,
const isc::dns::Name& zonename,
isc::dns::RRsetList& target) const
{
......@@ -326,7 +323,7 @@ private:
class Nsec3Param {
public:
Nsec3Param(uint8_t a, uint8_t f, uint16_t i, std::vector<uint8_t>& s);
Nsec3Param(uint8_t a, uint8_t f, uint16_t i, const std::vector<uint8_t>& s);
const uint8_t algorithm;
const uint8_t flags;
......
......@@ -370,7 +370,6 @@ Sqlite3DataSrc::setupPreparedStatements(void) {
throw(e);
}
#if 0 // XXX
const char* q_nsec3_str = "SELECT rdtype, ttl, rdata FROM nsec3 "
"WHERE zone_id=?1 AND hash == $2";
try {
......@@ -391,7 +390,6 @@ Sqlite3DataSrc::setupPreparedStatements(void) {
cout << sqlite3_errmsg(db) << endl;
throw(e);
}
#endif
}
void
......@@ -525,8 +523,7 @@ Sqlite3DataSrc::findPreviousName(const Query& q,
DataSrc::Result
Sqlite3DataSrc::findCoveringNSEC3(const Query& q,
const Nsec3Param& nsec3param,
const Name& qname,
const string& hashstr,
const Name& zonename,
RRsetList& target) const
{
......@@ -535,8 +532,6 @@ Sqlite3DataSrc::findCoveringNSEC3(const Query& q,
return (ERROR);
}
string hashstr = nsec3param.getHash(qname);
sqlite3_reset(q_prevnsec3);
sqlite3_clear_bindings(q_prevnsec3);
......@@ -558,9 +553,9 @@ Sqlite3DataSrc::findCoveringNSEC3(const Query& q,
// We need to find the final NSEC3 in the chain.
// A valid NSEC3 hash is in base32, which contains no
// letters higher than V, so a search for the previous
// NSEC3 from "W" will always find it.
// NSEC3 from "w" will always find it.
sqlite3_reset(q_prevnsec3);
rc = sqlite3_bind_text(q_prevnsec3, 2, "W", -1, SQLITE_STATIC);
rc = sqlite3_bind_text(q_prevnsec3, 2, "w", -1, SQLITE_STATIC);
if (rc != SQLITE_OK) {
throw ("Could not bind 2 (last NSEC3)");
}
......
......@@ -90,8 +90,7 @@ public:
const isc::dns::Name* zonename) const;
Result findCoveringNSEC3(const Query& q,
const Nsec3Param& param,
const isc::dns::Name& qname,
const std::string& hash,
const isc::dns::Name& zonename,
isc::dns::RRsetList& target) const;
......
......@@ -175,9 +175,8 @@ StaticDataSrc::findPreviousName(const Query& q, const Name& qname,
}
DataSrc::Result
StaticDataSrc::findCoveringNSEC3(const Query& q, const Nsec3Param& param,
const Name& qname, const Name& zonename,
RRsetList& target) const
StaticDataSrc::findCoveringNSEC3(const Query& q, const string& hash,
const Name& zonename, RRsetList& target) const
{
return (NOT_IMPLEMENTED);
}
......
......@@ -82,8 +82,7 @@ public:
const isc::dns::Name* zonename) const;
Result findCoveringNSEC3(const Query& q,
const Nsec3Param& param,
const isc::dns::Name& qname,
const std::string& hash,
const isc::dns::Name& zonename,
isc::dns::RRsetList& target) const;
......
......@@ -189,7 +189,7 @@ TEST_F(DataSrcTest, Nxdomain) {
RRsetIterator rit = msg.beginSection(Section::AUTHORITY());
RRsetPtr rrset = *rit;
EXPECT_EQ(Name("example.com"), rrset->getName());
EXPECT_EQ(RRType::NSEC(), rrset->getType());
EXPECT_EQ(RRType::SOA(), rrset->getType());
EXPECT_EQ(RRClass::IN(), rrset->getClass());
// XXX: check for other authority section answers
}
......
......@@ -28,10 +28,6 @@ using namespace isc::dns;
namespace isc {
namespace auth {
// Destructors defined here to avoid confusing the linker
Query::~Query() {}
QueryTask::~QueryTask() {}
QueryTask::QueryTask(const isc::dns::Name& n, const isc::dns::RRClass& c,
const isc::dns::RRType& t, const isc::dns::Section& sect) :
qname(n), qclass(c), qtype(t), section(sect), op(AUTH_QUERY),
......@@ -90,6 +86,8 @@ QueryTask::QueryTask(const isc::dns::Name& n, const isc::dns::RRClass& c,
}
}
QueryTask::~QueryTask() {}
Query::Query(Message& m, bool dnssec) :
status_(PENDING), qname_(NULL), qclass_(NULL), qtype_(NULL),
message_(&m), want_additional_(true), want_dnssec_(dnssec)
......@@ -109,5 +107,7 @@ Query::Query(Message& m, bool dnssec) :
Section::ANSWER())));
}
Query::~Query() {}
}
}
......@@ -221,6 +221,7 @@ private:
}
}
#endif
// Local Variables:
......
......@@ -738,7 +738,7 @@ TestDataSrc::findPreviousName(const Query& q,
assert(zonename != NULL);
if (*zonename == example) {
if (qname >= example || qname < cnameext) {
if (qname >= example && qname < cnameext) {
target = example;
} else if (qname < cnameint) {
target = cnameext;
......@@ -775,8 +775,7 @@ TestDataSrc::findPreviousName(const Query& q,
DataSrc::Result
TestDataSrc::findCoveringNSEC3(const Query& q,
const Nsec3Param& param,
const Name& qname,
const string& hash,
const Name& zonename,
RRsetList& target) const
{
......
......@@ -86,8 +86,7 @@ public:
const isc::dns::Name* zonename) const;
Result findCoveringNSEC3(const Query& q,
const Nsec3Param& param,
const isc::dns::Name& qname,
const std::string& hash,
const isc::dns::Name& zonename,
isc::dns::RRsetList& target) const;
......
......@@ -41,14 +41,14 @@ using namespace std;
struct NSEC3Impl {
// straightforward representation of NSEC3 RDATA fields
NSEC3Impl(uint8_t hash, uint8_t flags, uint16_t iterations,
NSEC3Impl(uint8_t hashalg, uint8_t flags, uint16_t iterations,
vector<uint8_t>salt, vector<uint8_t>next,
vector<uint8_t> typebits) :
hash_(hash), flags_(flags), iterations_(iterations),
hashalg_(hashalg), flags_(flags), iterations_(iterations),
salt_(salt), next_(next), typebits_(typebits)
{}
uint8_t hash_;
uint8_t hashalg_;
uint8_t flags_;
uint16_t iterations_;
vector<uint8_t> salt_;
......@@ -60,15 +60,15 @@ NSEC3::NSEC3(const string& nsec3_str) :
impl_(NULL)
{
istringstream iss(nsec3_str);
unsigned int hash, flags, iterations;
unsigned int hashalg, flags, iterations;
string salthex;
iss >> hash >> flags >> iterations >> salthex;
iss >> hashalg >> flags >> iterations >> salthex;
if (iss.bad() || iss.fail()) {
dns_throw(InvalidRdataText, "Invalid NSEC3 text");
}
if (hash > 0xf) {
dns_throw(InvalidRdataText, "NSEC3 hash out of range");
if (hashalg > 0xf) {
dns_throw(InvalidRdataText, "NSEC3 hash algorithm out of range");
}
if (flags > 0xff) {
dns_throw(InvalidRdataText, "NSEC3 flags out of range");
......@@ -84,7 +84,7 @@ NSEC3::NSEC3(const string& nsec3_str) :
iss >> setw(32) >> nextstr;
vector<uint8_t> next;
if (iss.bad() || iss.fail()) {
dns_throw(InvalidRdataText, "Invalid NSEC3 hash");
dns_throw(InvalidRdataText, "Invalid NSEC3 hash algorithm");
}
decodeBase32(nextstr, next);
......@@ -116,7 +116,7 @@ NSEC3::NSEC3(const string& nsec3_str) :
}
}
impl_ = new NSEC3Impl(hash, flags, iterations, salt, next, typebits);
impl_ = new NSEC3Impl(hashalg, flags, iterations, salt, next, typebits);
}
NSEC3::NSEC3(InputBuffer& buffer, size_t rdata_len)
......@@ -125,7 +125,7 @@ NSEC3::NSEC3(InputBuffer& buffer, size_t rdata_len)
dns_throw(InvalidRdataLength, "NSEC3 too short");
}
uint8_t hash = buffer.readUint8();
uint8_t hashalg = buffer.readUint8();
uint8_t flags = buffer.readUint8();
uint16_t iterations = buffer.readUint16();
rdata_len -= 4;
......@@ -161,7 +161,7 @@ NSEC3::NSEC3(InputBuffer& buffer, size_t rdata_len)
vector<uint8_t> typebits(rdata_len);
buffer.readData(&typebits[0], rdata_len);
impl_ = new NSEC3Impl(hash, flags, iterations, salt, next, typebits);
impl_ = new NSEC3Impl(hashalg, flags, iterations, salt, next, typebits);
}
NSEC3::NSEC3(const NSEC3& source) :
......@@ -213,7 +213,7 @@ NSEC3::toText() const
}
using namespace boost;
return (lexical_cast<string>(static_cast<int>(impl_->hash_)) +
return (lexical_cast<string>(static_cast<int>(impl_->hashalg_)) +
" " + lexical_cast<string>(static_cast<int>(impl_->flags_)) +
" " + lexical_cast<string>(static_cast<int>(impl_->iterations_)) +
" " + encodeHex(impl_->salt_) +
......@@ -223,7 +223,7 @@ NSEC3::toText() const
void
NSEC3::toWire(OutputBuffer& buffer) const
{
buffer.writeUint8(impl_->hash_);
buffer.writeUint8(impl_->hashalg_);
buffer.writeUint8(impl_->flags_);
buffer.writeUint16(impl_->iterations_);
buffer.writeUint8(impl_->salt_.size());
......@@ -236,7 +236,7 @@ NSEC3::toWire(OutputBuffer& buffer) const
void
NSEC3::toWire(MessageRenderer& renderer) const
{
renderer.writeUint8(impl_->hash_);
renderer.writeUint8(impl_->hashalg_);
renderer.writeUint8(impl_->flags_);
renderer.writeUint16(impl_->iterations_);
renderer.writeUint8(impl_->salt_.size());
......@@ -251,8 +251,8 @@ NSEC3::compare(const Rdata& other) const
{
const NSEC3& other_nsec3 = dynamic_cast<const NSEC3&>(other);
if (impl_->hash_ != other_nsec3.impl_->hash_) {
return (impl_->hash_ < other_nsec3.impl_->hash_ ? -1 : 1);
if (impl_->hashalg_ != other_nsec3.impl_->hashalg_) {
return (impl_->hashalg_ < other_nsec3.impl_->hashalg_ ? -1 : 1);
}
if (impl_->flags_ != other_nsec3.impl_->flags_) {
return (impl_->flags_ < other_nsec3.impl_->flags_ ? -1 : 1);
......@@ -302,8 +302,8 @@ NSEC3::compare(const Rdata& other) const
}
uint8_t
NSEC3::getHash() const {
return impl_->hash_;
NSEC3::getHashalg() const {
return impl_->hashalg_;
}
uint8_t
......@@ -316,7 +316,7 @@ NSEC3::getIterations() const {
return impl_->iterations_;
}
vector<uint8_t>
vector<uint8_t>&
NSEC3::getSalt() const {
return impl_->salt_;
}
......
......@@ -42,10 +42,10 @@ public:
NSEC3& operator=(const NSEC3& source);
~NSEC3();
uint8_t getHash() const;
uint8_t getHashalg() const;
uint8_t getFlags() const;
uint16_t getIterations() const;
std::vector<uint8_t> getSalt() const;
std::vector<uint8_t>& getSalt() const;
private:
NSEC3Impl* impl_;
......
......@@ -37,30 +37,30 @@ using namespace std;
struct NSEC3PARAMImpl {
// straightforward representation of NSEC3PARAM RDATA fields
NSEC3PARAMImpl(uint8_t hash, uint8_t flags, uint16_t iterations,
NSEC3PARAMImpl(uint8_t hashalg, uint8_t flags, uint16_t iterations,
vector<uint8_t>salt) :
hash_(hash), flags_(flags), iterations_(iterations), salt_(salt)
hashalg_(hashalg), flags_(flags), iterations_(iterations), salt_(salt)
{}
uint8_t hash_;
uint8_t hashalg_;
uint8_t flags_;
uint16_t iterations_;
const vector<uint8_t> salt_;
const vector<uint8_t>& salt_;
};
NSEC3PARAM::NSEC3PARAM(const string& nsec3param_str) :
impl_(NULL)
{
istringstream iss(nsec3param_str);
uint16_t hash, flags, iterations;
uint16_t hashalg, flags, iterations;
stringbuf saltbuf;
iss >> hash >> flags >> iterations >> &saltbuf;
iss >> hashalg >> flags >> iterations >> &saltbuf;
if (iss.bad() || iss.fail()) {
dns_throw(InvalidRdataText, "Invalid NSEC3PARAM text");
}
if (hash > 0xf) {
dns_throw(InvalidRdataText, "NSEC3PARAM hash out of range");
if (hashalg > 0xf) {