Commit 2b9742f8 authored by Likun Zhang's avatar Likun Zhang
Browse files

1. Add unittest and docstring to some functions in xfrin.

2. Minor fix to the xfrin.

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1729 e5f2f494-b856-4b98-b285-d166d9295462
parent 62984250
......@@ -18,11 +18,62 @@ import unittest
import socket
from xfrin import *
# An axfr response of the simple zone "example.com(without soa record at the end)."
axfr_response1 = b'\x84\x00\x00\x01\x00\x06\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01\xc0\x0c\x00\x06\x00\x01\x00\x00\x0e\x10\x00$\x05dns01\xc0\x0c\x05admin\xc0\x0c\x00\x00\x04\xd2\x00\x00\x0e\x10\x00\x00\x07\x08\x00$\xea\x00\x00\x00\x1c \xc0\x0c\x00\x02\x00\x01\x00\x00\x0e\x10\x00\x02\xc0)\xc0)\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\xa8\x02\x02\x04sql1\xc0\x0c\x00\x02\x00\x01\x00\x00\x0e\x10\x00\x02\xc0)\x04sql2\xc0\x0c\x00\x02\x00\x01\x00\x00\x0e\x10\x00\x02\xc0)\x03ns1\x07subzone\xc0\x0c\x00\x01\x00\x01\x00\x00\x0e\x10\x00\x04\xc0\xa8\x03\x01'
# The second axfr response with only the end soa record.
axfr_response2 = b'\x84\x00\x00\x00\x00\x01\x00\x00\x00\x00\x07example\x03com\x00\x00\x06\x00\x01\x00\x00\x0e\x10\x00$\x05dns01\xc0\x0c\x05admin\xc0\x0c\x00\x00\x04\xd2\x00\x00\x0e\x10\x00\x00\x07\x08\x00$\xea\x00\x00\x00\x1c '
DB_FILE = 'db_file'
# Rewrite the class for unittest.
class MyXfrin(Xfrin):
def __init__(self):
pass
class MyXfrinConnection(XfrinConnection):
query_data = b''
eply_data = b''
def _handle_xfrin_response(self):
for rr in super()._handle_xfrin_response():
pass
def _get_request_response(self, size):
ret = self.reply_data[:size]
self.reply_data = self.reply_data[size:]
if (len(ret) < size):
raise XfrinException('cannot get reply data')
return ret
def send(self, data):
self.query_data += data
return len(data)
def create_response_data(self, data):
reply_data = self.query_data[2:4] + data
size = socket.htons(len(reply_data))
reply_data = struct.pack('H', size) + reply_data
return reply_data
class TestXfrinConnection(unittest.TestCase):
def setUp(self):
self.conn = MyXfrinConnection('example.com.', DB_FILE, threading.Event(), '1.1.1.1')
def test_response_with_invalid_msg(self):
self.conn.data_exchange = b'aaaxxxx'
self.assertRaises(Exception, self.conn._handle_xfrin_response)
def test_response_without_end_soa(self):
self.conn._send_query(rr_type.AXFR())
self.conn.reply_data = self.conn.create_response_data(axfr_response1)
self.assertRaises(XfrinException, self.conn._handle_xfrin_response)
def test_response(self):
self.conn._send_query(rr_type.AXFR())
self.conn.reply_data = self.conn.create_response_data(axfr_response1)
self.conn.reply_data += self.conn.create_response_data(axfr_response2)
self.conn._handle_xfrin_response()
class TestXfrin(unittest.TestCase):
def test_parse_cmd_params(self):
......@@ -51,6 +102,10 @@ class TestXfrin(unittest.TestCase):
self.assertRaises(XfrinException, xfr._parse_cmd_params, {'zone_name':'ds.cn.'})
self.assertRaises(XfrinException, xfr._parse_cmd_params, {'master':'ds.cn.'})
if __name__== "__main__":
unittest.main()
try:
unittest.main()
os.remove(DB_FILE)
except KeyboardInterrupt as e:
print(e)
......@@ -50,15 +50,6 @@ SPECFILE_LOCATION = SPECFILE_PATH + "/xfrin.spec"
__version__ = 'BIND10'
# define xfrin rcode
XFRIN_OK = 0
XFRIN_RECV_TIMEOUT = 1
XFRIN_NO_NEWDATA = 2
XFRIN_QUOTA_ERROR = 3
XFRIN_IS_DOING = 4
# define xfrin state
XFRIN_QUERY_SOA = 1
XFRIN_FIRST_AXFR = 2
XFRIN_FIRST_IXFR = 3
def log_error(msg):
sys.stderr.write("[b10-xfrin] ")
......@@ -68,21 +59,17 @@ def log_error(msg):
class XfrinException(Exception):
pass
class XfrinConnection(asyncore.dispatcher):
'''Do xfrin in this class. '''
def __init__(self, zone_name, db_file,
shutdown_event,
master_addr,
port = 53,
check_soa = True,
verbose = False,
idle_timeout = 60):
def __init__(self,
zone_name, db_file, shutdown_event, master_addr,
port = 53, verbose = False, idle_timeout = 60):
''' idle_timeout: max idle time for read data from socket.
db_file: specify the data source file.
check_soa: when it's true, check soa first before sending xfr query
'''
asyncore.dispatcher.__init__(self)
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self._zone_name = zone_name
......@@ -92,18 +79,22 @@ class XfrinConnection(asyncore.dispatcher):
self.setblocking(1)
self._shutdown_event = shutdown_event
self._verbose = verbose
self._master_addr = master_addr
self._port = port
def connect_to_master(self, master_addr, port):
def connect_to_master(self):
'''Connect to master in TCP.'''
try:
self.connect((master_addr, port))
self.connect((self._master_addr, self._port))
return True
except socket.error as e:
self.log_msg('Failed to connect:(%s:%d), %s' % (master_addr, port, str(e)))
self.log_msg('Failed to connect:(%s:%d), %s' % (self._master_addr, self._port, str(e)))
return False
def _create_query(self, query_type):
'''Create dns query message. '''
msg = message(message_mode.RENDER)
query_id = random.randint(1, 0xFFFF)
self._query_id = query_id
......@@ -123,6 +114,7 @@ class XfrinConnection(asyncore.dispatcher):
def _send_query(self, query_type):
'''Send query message over TCP. '''
msg = self._create_query(query_type)
obuf = output_buffer(0)
render = message_render(obuf)
......@@ -147,26 +139,23 @@ class XfrinConnection(asyncore.dispatcher):
return data
def handle_read(self):
'''Read query's response from socket. '''
self._recvd_data = self.recv(self._need_recv_size)
self._recvd_size = len(self._recvd_data)
self._recv_time_out = False
def _check_soa_serial(self):
''' Compare the soa serial, if soa serial in master is less than
the soa serial in local, Finish xfrin.
False: soa serial in master is less or equal to the local one.
True: soa serial in master is bigger
'''
self._send_query(rr_type.SOA())
data_size = self._get_request_response(2)
soa_reply = self._get_request_response(int(data_size))
#TODO, need select soa record from data source then compare the two
#serial
#serial, current just return OK, since this function hasn't been used now
return XFRIN_OK
def do_xfrin(self, check_soa, ixfr_first = False):
'''Do xfr by sending xfr request and parsing response. '''
try:
ret = XFRIN_OK
if check_soa:
......@@ -194,6 +183,8 @@ class XfrinConnection(asyncore.dispatcher):
return ret
def _check_response_status(self, msg):
'''Check validation of xfr response. '''
#TODO, check more?
msg_rcode = msg.get_rcode()
if msg_rcode != rcode.NOERROR():
......@@ -212,6 +203,8 @@ class XfrinConnection(asyncore.dispatcher):
raise XfrinException('query section count greater than 1')
def _handle_answer_section(self, rrset_iter):
'''Return a generator for the reponse in one tcp package to a zone transfer.'''
while not rrset_iter.is_last():
rrset = rrset_iter.get_rrset()
rrset_iter.next()
......@@ -242,6 +235,8 @@ class XfrinConnection(asyncore.dispatcher):
rdata_iter.next()
def _handle_xfrin_response(self):
'''Return a generator for the response to a zone transfer. '''
while True:
data_len = self._get_request_response(2)
msg_len = socket.htons(struct.unpack('H', data_len)[0])
......@@ -258,12 +253,18 @@ class XfrinConnection(asyncore.dispatcher):
break
if self._shutdown_event.is_set():
#Check if xfrin process is shutdown.
#TODO, xfrin may be blocked in one loop.
raise XfrinException('xfrin is forced to stop')
def handle_read(self):
'''Read query's response from socket. '''
self._recvd_data = self.recv(self._need_recv_size)
self._recvd_size = len(self._recvd_data)
self._recv_time_out = False
def writable(self):
'''Ignore the writable socket. '''
return False
def log_info(self, msg, type='info'):
......@@ -282,9 +283,9 @@ def process_xfrin(xfrin_recorder, zone_name, db_file,
port = int(port)
xfrin_recorder.increment(zone_name)
conn = XfrinConnection(zone_name, db_file, shutdown_event,
master_addr, port, check_soa, verbose)
if conn.connect_to_master(master_addr, port):
conn.do_xfrin(False)
master_addr, port, verbose)
if conn.connect_to_master():
conn.do_xfrin(check_soa)
xfrin_recorder.decrement(zone_name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment