Commit 85dc228f authored by JINMEI Tatuya's avatar JINMEI Tatuya
Browse files

added preliminary level truncation support


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1199 e5f2f494-b856-4b98-b285-d166d9295462
parent 207c9929
......@@ -124,6 +124,7 @@ AuthSrv::processMessage(const int fd)
OutputBuffer obuffer(remote_bufsize);
MessageRenderer renderer(obuffer);
renderer.setLengthLimit(remote_bufsize);
msg.toWire(renderer);
cout << "sending a response (" <<
boost::lexical_cast<string>(obuffer.getLength())
......
......@@ -306,7 +306,7 @@ public:
/// exception class of \c InvalidBufferPosition will be thrown.
///
/// \param pos The position in the buffer to be returned.
uint8_t operator[](size_t pos) const
const uint8_t& operator[](size_t pos) const
{
if (pos >= data_.size()) {
isc_throw(InvalidBufferPosition, "read at invalid position");
......@@ -326,6 +326,14 @@ public:
/// that is to be filled in later, e.g, by \ref writeUint16At().
/// \param len The length of the gap to be inserted in bytes.
void skip(size_t len) { data_.insert(data_.end(), len, 0); }
/// \brief TBD
void trim(size_t len)
{
if (len > data_.size()) {
isc_throw(OutOfRange, "trimming too large from output buffer");
}
data_.resize(data_.size() - len);
}
/// \brief Clear buffer content.
///
/// This method can be used to re-initialize and reuse the buffer without
......
......@@ -386,17 +386,31 @@ namespace {
template <typename T>
struct RenderSection
{
RenderSection(MessageRenderer& renderer) :
counter_(0), renderer_(renderer) {}
RenderSection(MessageRenderer& renderer, const bool partial_ok) :
counter_(0), renderer_(renderer), partial_ok_(partial_ok_),
truncated_(false)
{}
void operator()(const T& entry)
{
// TBD: if truncation is necessary, do something special.
// throw an exception, set an internal flag, etc.
// If it's already truncated, ignore the rest of the section.
if (truncated_) {
return;
}
size_t pos0 = renderer_.getLength();
counter_ += entry->toWire(renderer_);
if (renderer_.isTruncated()) {
truncated_ = true;
if (!partial_ok_) {
// roll back to the end of the previous RRset.
renderer_.trim(renderer_.getLength() - pos0);
}
}
}
unsigned int getTotalCount() { return (counter_); }
unsigned int counter_;
MessageRenderer& renderer_;
const bool partial_ok_;
bool truncated_;
};
}
......@@ -421,6 +435,13 @@ addEDNS(MessageImpl* mimpl, MessageRenderer& renderer)
return (false);
}
// If adding the OPT RR would exceed the size limit, don't do it.
// 11 = len(".") + type(2byte) + class(2byte) + TTL(4byte) + RDLEN(2byte)
// (RDATA is empty in this simple implementation)
if (renderer.getLength() + 11 > renderer.getLengthLimit()) {
return (false);
}
// Render EDNS OPT RR
uint32_t extrcode_flags = ((mimpl->rcode_.getCode() & 0xff0) << 24);
if (mimpl->dnssec_ok_) {
......@@ -446,31 +467,41 @@ Message::toWire(MessageRenderer& renderer)
// reserve room for the header
renderer.skip(HEADERLEN);
uint16_t ancount = 0, nscount = 0, arcount = 0;
uint16_t qdcount =
for_each(impl_->questions_.begin(), impl_->questions_.end(),
RenderSection<QuestionPtr>(renderer)).getTotalCount();
RenderSection<QuestionPtr>(renderer, false)).getTotalCount();
// TBD: sort RRsets in each section based on configuration policy.
uint16_t ancount =
for_each(impl_->rrsets_[sectionCodeToId(Section::ANSWER())].begin(),
impl_->rrsets_[sectionCodeToId(Section::ANSWER())].end(),
RenderSection<RRsetPtr>(renderer)).getTotalCount();
uint16_t nscount =
for_each(impl_->rrsets_[sectionCodeToId(Section::AUTHORITY())].begin(),
impl_->rrsets_[sectionCodeToId(Section::AUTHORITY())].end(),
RenderSection<RRsetPtr>(renderer)).getTotalCount();
uint16_t arcount =
for_each(impl_->rrsets_[sectionCodeToId(Section::ADDITIONAL())].begin(),
impl_->rrsets_[sectionCodeToId(Section::ADDITIONAL())].end(),
RenderSection<RRsetPtr>(renderer)).getTotalCount();
if (!renderer.isTruncated()) {
ancount =
for_each(impl_->rrsets_[sectionCodeToId(Section::ANSWER())].begin(),
impl_->rrsets_[sectionCodeToId(Section::ANSWER())].end(),
RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
}
if (!renderer.isTruncated()) {
nscount =
for_each(impl_->rrsets_[sectionCodeToId(Section::AUTHORITY())].begin(),
impl_->rrsets_[sectionCodeToId(Section::AUTHORITY())].end(),
RenderSection<RRsetPtr>(renderer, true)).getTotalCount();
}
if (renderer.isTruncated()) {
setHeaderFlag(MessageFlag::TC());
} else {
arcount =
for_each(impl_->rrsets_[sectionCodeToId(Section::ADDITIONAL())].begin(),
impl_->rrsets_[sectionCodeToId(Section::ADDITIONAL())].end(),
RenderSection<RRsetPtr>(renderer, false)).getTotalCount();
}
// Added EDNS OPT RR if necessary (we want to avoid hardcoding specialized
// logic, see the parser case)
if (addEDNS(this->impl_, renderer)) {
if (!renderer.isTruncated() && addEDNS(this->impl_, renderer)) {
++arcount;
}
// TBD: EDNS, TSIG, etc.
// TBD: TSIG, SIG(0) etc.
// fill in the header
size_t header_pos = 0;
......@@ -766,6 +797,13 @@ Message::clear()
impl_->init();
}
void
Message::clear(Mode mode)
{
impl_->init();
impl_->mode_ = mode;
}
void
Message::makeResponse()
{
......
......@@ -564,6 +564,7 @@ public:
//void removeRR(const Section& section, const RR& rr);
void clear();
void clear(Mode mode);
// prepare for making a response from a request. This will clear the
// DNS header except those fields that should be kept for the response,
......
......@@ -135,7 +135,9 @@ struct MessageRendererImpl {
/// \param buffer An \c OutputBuffer object to which wire format data is
/// written.
MessageRendererImpl(OutputBuffer& buffer) :
buffer_(buffer), nbuffer_(Name::MAX_WIRE) {}
buffer_(buffer), nbuffer_(Name::MAX_WIRE), msglength_limit_(512),
truncated_(false)
{}
/// The buffer that holds the entire DNS message.
OutputBuffer& buffer_;
/// A local working buffer to convert each given name into wire format.
......@@ -145,6 +147,10 @@ struct MessageRendererImpl {
OutputBuffer nbuffer_;
/// A set of compression pointers.
std::set<NameCompressNode, NameCompare> nodeset_;
/// TBD
uint16_t msglength_limit_;
bool truncated_;
};
MessageRenderer::MessageRenderer(OutputBuffer& buffer) :
......@@ -162,6 +168,12 @@ MessageRenderer::skip(size_t len)
impl_->buffer_.skip(len);
}
void
MessageRenderer::trim(size_t len)
{
impl_->buffer_.trim(len);
}
void
MessageRenderer::clear()
{
......@@ -212,6 +224,30 @@ MessageRenderer::getLength() const
return (impl_->buffer_.getLength());
}
size_t
MessageRenderer::getLengthLimit() const
{
return (impl_->msglength_limit_);
}
void
MessageRenderer::setLengthLimit(size_t len)
{
impl_->msglength_limit_ = len;
}
bool
MessageRenderer::isTruncated() const
{
return (impl_->truncated_);
}
void
MessageRenderer::setTruncated()
{
impl_->truncated_ = true;
}
void
MessageRenderer::writeName(const Name& name, bool compress)
{
......
......@@ -99,6 +99,23 @@ public:
const void* getData() const;
/// \brief Return the length of data written in the internal buffer.
size_t getLength() const;
/// \brief TBD
bool isTruncated() const;
/// \brief TBD
size_t getLengthLimit() const;
//@}
///
/// \name Setter Methods
///
//@{
/// \brief TBD
void setLengthLimit(size_t len);
/// \brief TBD
void setTruncated();
//@}
///
......@@ -113,6 +130,9 @@ public:
///
/// \param len The length of the gap to be inserted in bytes.
void skip(size_t len);
/// \brief TBD
void trim(size_t len);
/// \brief Clear the internal buffer and other internal resources.
///
/// This method can be used to re-initialize and reuse the renderer
......
......@@ -64,7 +64,7 @@ AbstractRRset::toText() const
namespace {
template <typename T>
inline unsigned int
rrsetToWire(const AbstractRRset& rrset, T& output)
rrsetToWire(const AbstractRRset& rrset, T& output, const size_t limit)
{
unsigned int n = 0;
RdataIteratorPtr it = rrset.getRdataIterator();
......@@ -77,16 +77,25 @@ rrsetToWire(const AbstractRRset& rrset, T& output)
// sort the set of Rdata based on rrset-order and sortlist, and possible
// other options. Details to be considered.
do {
const size_t pos0 = output.getLength();
assert(pos0 < 65536);
rrset.getName().toWire(output);
rrset.getType().toWire(output);
rrset.getClass().toWire(output);
rrset.getTTL().toWire(output);
size_t pos = output.getLength();
const size_t pos = output.getLength();
output.skip(sizeof(uint16_t)); // leave the space for RDLENGTH
it->getCurrent().toWire(output);
output.writeUint16At(output.getLength() - pos - sizeof(uint16_t), pos);
if (limit > 0 && output.getLength() > limit) {
// truncation is needed
output.trim(output.getLength() - pos0);
return (n);
}
it->next();
++n;
} while (!it->isLast());
......@@ -98,13 +107,18 @@ rrsetToWire(const AbstractRRset& rrset, T& output)
unsigned int
AbstractRRset::toWire(OutputBuffer& buffer) const
{
return (rrsetToWire<OutputBuffer>(*this, buffer));
return (rrsetToWire<OutputBuffer>(*this, buffer, 0));
}
unsigned int
AbstractRRset::toWire(MessageRenderer& renderer) const
{
return (rrsetToWire<MessageRenderer>(*this, renderer));
const unsigned int rrs_written = rrsetToWire<MessageRenderer>(
*this, renderer, renderer.getLengthLimit());
if (getRdataCount() > rrs_written) {
renderer.setTruncated();
}
return (rrs_written);
}
ostream&
......
......@@ -14,10 +14,14 @@
// $Id$
#include <exceptions/exceptions.h>
#include <dns/buffer.h>
#include <gtest/gtest.h>
using namespace isc;
namespace {
using isc::dns::InputBuffer;
......@@ -158,6 +162,20 @@ TEST_F(BufferTest, outputBufferSkip)
EXPECT_EQ(6, obuffer.getLength());
}
TEST_F(BufferTest, outputBufferTrim)
{
obuffer.writeData(testdata, sizeof(testdata));
EXPECT_EQ(5, obuffer.getLength());
obuffer.trim(1);
EXPECT_EQ(4, obuffer.getLength());
obuffer.trim(2);
EXPECT_EQ(2, obuffer.getLength());
EXPECT_THROW(obuffer.trim(3), OutOfRange);
}
TEST_F(BufferTest, outputBufferReadat)
{
obuffer.writeData(testdata, sizeof(testdata));
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment