Commit ca42fb64 authored by JINMEI Tatuya's avatar JINMEI Tatuya
Browse files

[1288] make sure transfers_counter is always reset whateven happens within

XfroutSession().  The original code was already buggy in this sense, but
with the newer data source API it will be a bit more likely to happen
due to the generality of the API, so it would make sense to fix it here.
parent bcb37a2f
......@@ -64,10 +64,40 @@ class MySocket():
def clear_send(self):
del self.sendqueue[:]
# We subclass the Session class we're testing here, only
# to override the handle() and _send_data() method
class MockDataSrcClient:
def __init__(self, type, config):
pass
def get_iterator(self, zone_name):
if zone_name == Name('notauth.example.com'):
raise isc.datasrc.Error('no such zone')
self._zone_name = zone_name
return self
def get_soa(self): # emulate ZoneIterator.get_soa()
if self._zone_name == Name('nosoa.example.com'):
return None
soa_rrset = RRset(Name('multisoa.example.com'), RRClass.IN(),
RRType.SOA(), RRTTL(3600))
soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
'master.example.com. ' +
'admin.example.com. 1234 ' +
'3600 1800 2419200 7200'))
if self._zone_name == Name('multisoa.example.com'):
soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
'master.example.com. ' +
'admin.example.com. 1300 ' +
'3600 1800 2419200 7200'))
return soa_rrset
return soa_rrset
# We subclass the Session class we're testing here, only overriding a few
# methods
class MyXfroutSession(XfroutSession):
def handle(self):
def _handle(self):
pass
def _close_socket(self):
pass
def _send_data(self, sock, data):
......@@ -80,12 +110,14 @@ class MyXfroutSession(XfroutSession):
class Dbserver:
def __init__(self):
self._shutdown_event = threading.Event()
self.transfer_counter = 0
def get_db_file(self):
return 'test.sqlite3'
def increase_transfers_counter(self):
self.transfer_counter += 1
return True
def decrease_transfers_counter(self):
pass
self.transfer_counter -= 1
class TestXfroutSession(unittest.TestCase):
def getmsg(self):
......@@ -139,6 +171,45 @@ class TestXfroutSession(unittest.TestCase):
'admin.exAmple.com. ' +
'1234 3600 1800 2419200 7200'))
def tearDown(self):
# transfer_counter must be always be reset no matter happens within
# the XfroutSession object. We check the condition here.
self.assertEqual(0, self.xfrsess._server.transfer_counter)
def test_quota_error(self):
'''Emulating the server being too busy.
'''
self.xfrsess._request_data = self.mdata
self.xfrsess._server.increase_transfers_counter = lambda : False
XfroutSession._handle(self.xfrsess)
self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.REFUSED())
def test_quota_ok(self):
'''The default case in terms of the xfrout quota.
'''
# set up a bogus request, which should result in FORMERR. (it only
# has to be something that is different from the previous case)
self.xfrsess._request_data = \
self.create_request_data(with_question=False)
# Replace the data source client to avoid datasrc related exceptions
self.xfrsess.ClientClass = MockDataSrcClient
XfroutSession._handle(self.xfrsess)
self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.FORMERR())
def test_exception_from_session(self):
'''Test the case where the main processing raises an exception.
We just check it doesn't any unexpected disruption and (in teraDown)
transfer_counter is correctly reset to 0.
'''
def dns_xfrout_start(fd, msg, quota):
raise ValueError('fake exception')
self.xfrsess.dns_xfrout_start = dns_xfrout_start
XfroutSession._handle(self.xfrsess)
def test_parse_query_message(self):
[get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
self.assertEqual(get_rcode.to_text(), "NOERROR")
......@@ -520,32 +591,6 @@ class TestXfroutSession(unittest.TestCase):
self.assertEqual(82, get_rrset_len(self.soa_rrset))
def test_check_xfrout_available(self):
class MockDataSrcClient:
def __init__(self, type, config): pass
def get_iterator(self, zone_name):
if zone_name == Name('notauth.example.com'):
raise isc.datasrc.Error('no such zone')
self._zone_name = zone_name
return self
def get_soa(self): # emulate ZoneIterator.get_soa()
if self._zone_name == Name('nosoa.example.com'):
return None
soa_rrset = RRset(Name('multisoa.example.com'), RRClass.IN(),
RRType.SOA(), RRTTL(3600))
soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
'master.example.com. ' +
'admin.example.com. 1234 ' +
'3600 1800 2419200 7200'))
if self._zone_name == Name('multisoa.example.com'):
soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
'master.example.com. ' +
'admin.example.com. 1300 ' +
'3600 1800 2419200 7200'))
return soa_rrset
return soa_rrset
self.xfrsess.ClientClass = MockDataSrcClient
self.assertEqual(self.xfrsess._check_xfrout_available(
Name('notauth.example.com')), Rcode.NOTAUTH())
......@@ -554,13 +599,6 @@ class TestXfroutSession(unittest.TestCase):
self.assertEqual(self.xfrsess._check_xfrout_available(
Name('multisoa.example.com')), Rcode.SERVFAIL())
self.xfrsess._server.increase_transfers_counter = lambda : False
self.assertEqual(self.xfrsess._check_xfrout_available(
Name('example.com')), Rcode.REFUSED())
self.xfrsess._server.increase_transfers_counter = lambda : True
self.assertEqual(self.xfrsess._check_xfrout_available(
Name('example.com')), Rcode.NOERROR())
def test_dns_xfrout_start_formerror(self):
# formerror
self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
......
......@@ -128,21 +128,46 @@ class XfroutSession():
self._zone_config = zone_config
self.ClientClass = client_class # parameterize this for testing
self._soa = None # will be set in _check_xfrout_available or in tests
self.handle()
self._handle()
def create_tsig_ctx(self, tsig_record, tsig_key_ring):
return TSIGContext(tsig_record.get_name(), tsig_record.get_rdata().get_algorithm(),
tsig_key_ring)
def handle(self):
''' Handle a xfrout query, send xfrout response '''
def _handle(self):
''' Handle a xfrout query, send xfrout response(s).
This is separated from the constructor so that we can override
it from tests.
'''
# Check the xfrout quota. We do both increase/decrease in this
# method so it's clear we always release it once acuired.
quota_ok = self._server.increase_transfers_counter()
ex = None
try:
self.dns_xfrout_start(self._sock_fd, self._request_data)
#TODO, avoid catching all exceptions
self.dns_xfrout_start(self._sock_fd, self._request_data, quota_ok)
except Exception as e:
logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
pass
# To avoid resource leak we need catch all possible exceptions
# We log it later to exclude the case where even logger raises
# an exception.
ex = e
# Release any critical resources
if quota_ok:
self._server.decrease_transfers_counter()
self._close_socket()
if ex is not None:
logger.error(XFROUT_HANDLE_QUERY_ERROR, ex)
def _close_socket(self):
'''Simply close the socket via the given FD.
This is a dedicated subroutine of handle() and is sepsarated from it
for the convenience of tests.
'''
os.close(self._sock_fd)
def _check_request_tsig(self, msg, request_data):
......@@ -252,12 +277,8 @@ class XfroutSession():
'''Check if xfr request can be responsed.
TODO, Get zone's configuration from cfgmgr or some other place
eg. check allow_transfer setting,
'''
# Reject the attempt if we are too busy. Check this first to avoid
# unnecessary resource consumption even if we discard it soon.
if not self._server.increase_transfers_counter():
return Rcode.REFUSED()
'''
# Identify the data source for the requested zone and see if it has
# SOA while initializing objects used for request processing later.
......@@ -292,7 +313,7 @@ class XfroutSession():
return Rcode.NOERROR()
def dns_xfrout_start(self, sock_fd, msg_query):
def dns_xfrout_start(self, sock_fd, msg_query, quota_ok=True):
rcode_, msg = self._parse_query_message(msg_query)
#TODO. create query message and parse header
if rcode_ is None: # Dropped by ACL
......@@ -302,6 +323,9 @@ class XfroutSession():
elif rcode_ != Rcode.NOERROR():
return self._reply_query_with_error_rcode(msg, sock_fd,
Rcode.FORMERR())
elif not quota_ok:
return self._reply_query_with_error_rcode(msg, sock_fd,
Rcode.REFUSED())
question = msg.get_question()[0]
zone_name = question.get_name()
......@@ -322,8 +346,6 @@ class XfroutSession():
pass
logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_str)
self._server.decrease_transfers_counter()
def _clear_message(self, msg):
qid = msg.get_qid()
opcode = msg.get_opcode()
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment