Commit 5e96173c authored by JINMEI Tatuya's avatar JINMEI Tatuya
Browse files

partially supported fromWire() rdata (A/AAA/NS only).

now allowed parsing a full DNS message.


git-svn-id: svn://bind10.isc.org/svn/bind10/branches/f2f200910@202 e5f2f494-b856-4b98-b285-d166d9295462
parent 35eba18b
......@@ -49,8 +49,11 @@ public:
virtual size_t getSize() const = 0;
virtual size_t getSpace() const = 0;
virtual size_t getCurrent() const = 0;
virtual void setCurrent(size_t pos) = 0;
virtual uint8_t readUint8() = 0;
virtual uint16_t readUint16() = 0;
virtual uint32_t readUint32() = 0;
virtual void readData(void* data, size_t len) = 0;
virtual int recvFrom(int s, struct sockaddr *from,
socklen_t *from_len) = 0;
};
......@@ -99,8 +102,16 @@ public:
size_t getSize() const { return (buf_.size()); }
size_t getSpace() const { return (buf_.size() - _readpos); }
size_t getCurrent() const { return (_readpos); }
void setCurrent(size_t pos)
{
if (pos >= buf_.size())
throw isc::ISCBufferInvalidPosition();
_readpos = pos;
}
uint8_t readUint8();
uint16_t readUint16();
uint32_t readUint32();
void readData(void* data, size_t len);
int recvFrom(int s, struct sockaddr* from, socklen_t* from_len);
private:
......@@ -130,6 +141,30 @@ SingleBuffer::readUint16()
return (ntohs(data));
}
inline uint32_t
SingleBuffer::readUint32()
{
uint32_t data;
if (_readpos + sizeof(data) > buf_.size())
throw ISCBufferInvalidPosition();
memcpy((void*)&data, &buf_[_readpos], sizeof(data));
_readpos += sizeof(data);
return (ntohl(data));
}
inline void
SingleBuffer::readData(void *data, size_t len)
{
if (_readpos + len > buf_.size())
throw ISCBufferInvalidPosition();
memcpy(data, &buf_[_readpos], len);
_readpos += len;
}
}
#endif // __BUFFER_HH
......
......@@ -47,6 +47,9 @@ class DNSInvalidMessageSection : public DNSException {};
class DNSInvalidRendererPosition : public DNSException {};
class DNSMessageTooShort : public DNSException {};
class DNSCharStringTooLong : public DNSException {};
class DNSNameDecompressionProhibited : public DNSException {};
class DNSNameBadPointer : public DNSException {};
class DNSInvalidRdata : public DNSException {};
}
}
#endif // __EXCEPTIONS_HH
......
......@@ -24,16 +24,19 @@
#include <boost/lexical_cast.hpp>
#include <dns/buffer.h>
#include <dns/name.h>
#include <dns/rrset.h>
#include <dns/message.h>
using isc::dns::Name;
using isc::dns::Message;
using isc::dns::RRType;
using isc::dns::RRClass;
using isc::dns::TTL;
using isc::dns::Message;
using isc::dns::Rdata::Rdata;
using isc::dns::Rdata::RdataPtr;
using isc::dns::RRsetPtr;
using isc::dns::RR;
using isc::dns::TTL;
Message::Message()
{
......@@ -161,7 +164,9 @@ Message::fromWire()
counts_[SECTION_ADDITIONAL] = buffer_->readUint16();
parse_question();
// parse other sections (TBD)
for (int section = SECTION_ANSWER; section < SECTION_MAX; ++section) {
parse_section(static_cast<section_t>(section)); // XXX cast
}
}
void
......@@ -190,6 +195,30 @@ Message::parse_question()
}
}
void
Message::parse_section(section_t section)
{
if (buffer_ == NULL)
throw DNSNoMessageBuffer();
for (int count = 0; count < this->counts_[section]; count++) {
Name name(*buffer_, getDecompressor());
// Get type, class, TTL
if (buffer_->getSpace() < 2 * sizeof(uint16_t) + sizeof(uint32_t))
throw DNSMessageTooShort();
RRType rrtype(buffer_->readUint16());
RRClass rrclass(buffer_->readUint16());
TTL ttl(buffer_->readUint32());
addRR(section, RR(name, rrclass, rrtype, ttl,
RdataPtr(isc::dns::Rdata::Rdata::fromWire(rrclass,
rrtype,
*buffer_,
getDecompressor()))));
}
}
static const char *opcodetext[] = {
"QUERY",
"IQUERY",
......
......@@ -162,6 +162,7 @@ public:
private:
void initialize();
void parse_question();
void parse_section(section_t section);
private:
// Open issues: should we rather have a header in wire-format
......
......@@ -269,7 +269,8 @@ Name::Name(const std::string& namestr)
Name::Name(Buffer& buffer, NameDecompressor& decompressor)
{
unsigned int nused, labels, n, nmax;
unsigned int current;
unsigned int cused; /* Bytes of compressed name data used */
unsigned int current, new_current, biggest_pointer, pos_begin;
bool done;
fw_state state = fw_start;
unsigned int c;
......@@ -287,6 +288,7 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
labels = 0;
done = false;
nused = 0;
seen_pointer = false;
/*
* Find the maximum number of uncompressed target name
......@@ -296,7 +298,10 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
*/
nmax = MAXWIRE;
cused = 0;
current = buffer.getCurrent();
pos_begin = current;
biggest_pointer = current;
/*
* Note: The following code is not optimized for speed, but
......@@ -305,6 +310,8 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
while (current < buffer.getSize() && !done) {
c = buffer.readUint8();
current++;
if (!seen_pointer)
cused++;
switch (state) {
case fw_start:
......@@ -333,7 +340,11 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
/*
* Ordinary 14-bit pointer.
*/
throw DNSBadLabelType(); // XXX not implemented
if (!decompressor.isAllowed())
throw DNSNameDecompressionProhibited();
new_current = c & 0x3F;
n = 1;
state = fw_newcurrent;
} else
throw DNSBadLabelType();
break;
......@@ -348,6 +359,21 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
state = fw_start;
break;
case fw_newcurrent:
new_current *= 256;
new_current += c;
n--;
if (n != 0)
break;
if (new_current >= biggest_pointer)
throw DNSNameBadPointer();
biggest_pointer = new_current;
current = new_current;
buffer.setCurrent(current);
seen_pointer = true;
state = fw_start;
break;
// XXX not implemented, fall through
default:
throw ISCUnexpected();
......@@ -359,6 +385,7 @@ Name::Name(Buffer& buffer, NameDecompressor& decompressor)
labels_ = labels;
length_ = nused;
buffer.setCurrent(pos_begin + cused);
}
string
......
......@@ -24,9 +24,13 @@
namespace isc {
namespace dns {
// Define them as an empty class for rapid prototyping
// Define it as an empty class for rapid prototyping
class NameCompressor {};
class NameDecompressor {};
// Define it as an almost-empty class for rapid prototyping
class NameDecompressor {
public:
bool isAllowed() { return (true); }
};
class NameComparisonResult {
public:
......
......@@ -31,6 +31,8 @@
using std::pair;
using std::map;
using isc::Buffer;
using isc::dns::NameDecompressor;
using isc::dns::RRClass;
using isc::dns::RRType;
using isc::dns::TTL;
......@@ -127,9 +129,12 @@ TTL::toWire(Buffer& buffer) const
buffer.writeUint32(ttlval_);
}
typedef Rdata* (*RdataFactory)(const std::string& text_rdata);
typedef Rdata* (*TextRdataFactory)(const std::string& text_rdata);
typedef Rdata* (*WireRdataFactory)(Buffer& buffer,
NameDecompressor& decompressor);
typedef pair<RRClass, RRType> RRClassTypePair;
static map<RRClassTypePair, RdataFactory> rdata_factory_repository;
static map<RRClassTypePair, TextRdataFactory> text_rdata_factory_repository;
static map<RRClassTypePair, WireRdataFactory> wire_rdata_factory_repository;
struct RdataFactoryRegister {
public:
......@@ -140,73 +145,91 @@ private:
static RdataFactoryRegister rdata_factory;
Rdata *
createADataFromText(const std::string& text_rdata)
{
return (new A(text_rdata));
}
Rdata *
createAAAADataFromText(const std::string& text_rdata)
{
return (new AAAA(text_rdata));
}
Rdata *
createNSDataFromText(const std::string& text_rdata)
template <typename T>
Rdata*
createDataFromText(const std::string& text_rdata)
{
return (new NS(text_rdata));
return (new T(text_rdata));
}
Rdata *
createTXTDataFromText(const std::string& text_rdata)
template <typename T>
Rdata*
createDataFromWire(Buffer& buffer, NameDecompressor& decompressor)
{
return (new TXT(text_rdata));
return (new T(buffer, decompressor));
}
RdataFactoryRegister::RdataFactoryRegister()
{
rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
(RRClassTypePair(RRClass::IN, RRType::A),
createADataFromText));
rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
(RRClassTypePair(RRClass::IN, RRType::AAAA),
createAAAADataFromText));
text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
(RRClassTypePair(RRClass::IN, RRType::A),
createDataFromText<isc::dns::Rdata::IN::A>));
text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
(RRClassTypePair(RRClass::IN, RRType::AAAA),
createDataFromText<isc::dns::Rdata::IN::AAAA>));
//XXX: NS/TXT belongs to the 'generic' class. should revisit it.
rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
(RRClassTypePair(RRClass::IN, RRType::NS),
createNSDataFromText));
rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
(RRClassTypePair(RRClass::IN, RRType::TXT),
createTXTDataFromText));
text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
(RRClassTypePair(RRClass::IN, RRType::NS),
createDataFromText<isc::dns::Rdata::Generic::NS>));
text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
(RRClassTypePair(RRClass::IN, RRType::TXT),
createDataFromText<isc::dns::Rdata::Generic::TXT>));
// XXX: we should treat class-agnostic type accordingly.
rdata_factory_repository.insert(pair<RRClassTypePair, RdataFactory>
(RRClassTypePair(RRClass::CH, RRType::TXT),
createTXTDataFromText));}
text_rdata_factory_repository.insert(pair<RRClassTypePair, TextRdataFactory>
(RRClassTypePair(RRClass::CH, RRType::TXT),
createDataFromText<isc::dns::Rdata::Generic::TXT>));
wire_rdata_factory_repository.insert(pair<RRClassTypePair, WireRdataFactory>
(RRClassTypePair(RRClass::IN, RRType::A),
createDataFromWire<isc::dns::Rdata::IN::A>));
wire_rdata_factory_repository.insert(pair<RRClassTypePair, WireRdataFactory>
(RRClassTypePair(RRClass::IN, RRType::AAAA),
createDataFromWire<isc::dns::Rdata::IN::AAAA>));
wire_rdata_factory_repository.insert(pair<RRClassTypePair, WireRdataFactory>
(RRClassTypePair(RRClass::IN, RRType::NS),
createDataFromWire<isc::dns::Rdata::Generic::NS>));
}
Rdata *
Rdata::fromText(const RRClass& rrclass, const RRType& rrtype,
const std::string& text_rdata)
{
map<RRClassTypePair, RdataFactory>::const_iterator entry;
entry = rdata_factory_repository.find(RRClassTypePair(rrclass, rrtype));
if (entry != rdata_factory_repository.end()) {
map<RRClassTypePair, TextRdataFactory>::const_iterator entry;
entry = text_rdata_factory_repository.find(RRClassTypePair(rrclass,
rrtype));
if (entry != text_rdata_factory_repository.end()) {
return (entry->second(text_rdata));
}
throw DNSInvalidRRType();
}
Rdata *
Rdata::fromWire(const RRClass& rrclass, const RRType& rrtype,
Buffer& buffer, NameDecompressor& decompressor)
{
map<RRClassTypePair, WireRdataFactory>::const_iterator entry;
entry = wire_rdata_factory_repository.find(RRClassTypePair(rrclass,
rrtype));
if (entry != wire_rdata_factory_repository.end()) {
return (entry->second(buffer, decompressor));
}
throw DNSInvalidRRType();
}
A::A(const std::string& addrstr)
{
if (inet_pton(AF_INET, addrstr.c_str(), &addr_) != 1)
throw ISCInvalidAddressString();
}
void
A::fromWire(Buffer& buffer, NameDecompressor& decompressor)
A::A(Buffer& buffer, NameDecompressor& decompressor)
{
//TBD
size_t len = buffer.readUint16();
if (len != sizeof(addr_))
throw DNSInvalidRdata();
buffer.readData(&addr_, sizeof(addr_));
}
void
......@@ -239,10 +262,11 @@ AAAA::AAAA(const std::string& addrstr)
throw ISCInvalidAddressString();
}
void
AAAA::fromWire(Buffer& buffer, NameDecompressor& decompressor)
AAAA::AAAA(Buffer& buffer, NameDecompressor& decompressor)
{
//TBD
if (buffer.readUint16() != sizeof(addr_))
throw DNSInvalidRdata();
buffer.readData(&addr_, sizeof(addr_));
}
void
......@@ -269,10 +293,12 @@ AAAA::copy() const
return (new AAAA(toText()));
}
void
NS::fromWire(Buffer& buffer, NameDecompressor& decompressor)
NS::NS(Buffer& buffer, NameDecompressor& decompressor)
{
//TBD
size_t len = buffer.readUint16();
nsname_ = Name(buffer, decompressor);
if (nsname_.getLength() < len)
throw DNSInvalidRdata();
}
void
......
......@@ -131,14 +131,15 @@ public:
virtual unsigned int count() const = 0;
virtual const RRType& getType() const = 0;
virtual std::string toText() const = 0;
virtual void fromWire(Buffer& b, NameDecompressor& c) = 0;
virtual void toWire(Buffer& b, NameCompressor& c) const = 0;
// need generic method for getting n-th field? c.f. ldns
// e.g. string getField(int n);
// A semi polymorphic factory.
// semi-polymorphic factories.
static Rdata* fromText(const RRClass& rrclass, const RRType& rrtype,
const std::string& text_rdata);
static Rdata* fromWire(const RRClass& rrclass, const RRType& rrtype,
Buffer& b, NameDecompressor& d);
// polymorphic copy constructor (XXX should revisit it)
virtual Rdata* copy() const = 0;
......@@ -150,11 +151,11 @@ public:
NS() {}
explicit NS(const std::string& namestr) : nsname_(namestr) {}
explicit NS(const Name& nsname) : nsname_(nsname) {}
explicit NS(Buffer& buffer, NameDecompressor& decompressor);
unsigned int count() const { return (1); }
const RRType& getType() const { return (RRType::NS); }
static const RRType& getTypeStatic() { return (RRType::NS); }
std::string toText() const;
void fromWire(Buffer& b, NameDecompressor& c);
void toWire(Buffer& b, NameCompressor& c) const;
const std::string getNsname() const { return (nsname_.toText(false)); }
bool operator==(const NS &other) const
......@@ -194,11 +195,11 @@ public:
A() {}
// constructor from a textual IPv4 address
explicit A(const std::string& addrstr);
explicit A(Buffer& buffer, NameDecompressor& decompressor);
unsigned int count() const { return (1); }
const RRType& getType() const { return (RRType::A); }
static const RRType& getTypeStatic() { return (RRType::A); }
std::string toText() const;
void fromWire(Buffer& b, NameDecompressor& c);
void toWire(Buffer& b, NameCompressor& c) const;
const struct in_addr& getAddress() const { return (addr_); }
bool operator==(const A &other) const
......@@ -216,11 +217,11 @@ public:
AAAA() {}
// constructor from a textual IPv6 address
explicit AAAA(const std::string& addrstr);
explicit AAAA(Buffer& buffer, NameDecompressor& decompressor);
unsigned int count() const { return (1); }
std::string toText() const;
const RRType& getType() const { return (RRType::AAAA); }
static const RRType& getTypeStatic() { return (RRType::AAAA); }
void fromWire(Buffer& b, NameDecompressor& c);
void toWire(Buffer& b, NameCompressor& c) const;
const struct in6_addr& getAddress() const { return (addr_); }
bool operator==(const AAAA &other) const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment