Commit 6764e7da authored by JINMEI Tatuya's avatar JINMEI Tatuya
Browse files

[1299] pre-work update: use Serial object for SOA serial throughout the code

and tests.
parent f9cbe6fb
......@@ -342,7 +342,7 @@ class TestXfrinInitialSOA(TestXfrinState):
self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
self.assertEqual(type(XfrinFirstData()),
type(self.conn.get_xfrstate()))
self.assertEqual(1234, self.conn._end_serial)
self.assertEqual(1234, self.conn._end_serial.get_value())
def test_handle_not_soa(self):
# The given RR is not of SOA
......@@ -357,7 +357,8 @@ class TestXfrinFirstData(TestXfrinState):
super().setUp()
self.state = XfrinFirstData()
self.conn._request_type = RRType.IXFR()
self.conn._request_serial = 1230 # arbitrary chosen serial < 1234
# arbitrary chosen serial < 1234:
self.conn._request_serial = isc.dns.Serial(1230)
self.conn._diff = None # should be replaced in the AXFR case
def test_handle_ixfr_begin_soa(self):
......@@ -437,7 +438,7 @@ class TestXfrinIXFRDelete(TestXfrinState):
# false.
self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
self.assertEqual([], self.conn._diff.get_buffer())
self.assertEqual(1234, self.conn._current_serial)
self.assertEqual(1234, self.conn._current_serial.get_value())
self.assertEqual(type(XfrinIXFRAddSOA()),
type(self.conn.get_xfrstate()))
......@@ -468,7 +469,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
# We need record the state in 'conn' to check the case where the
# state doesn't change.
XfrinIXFRAdd().set_xfrstate(self.conn, XfrinIXFRAdd())
self.conn._current_serial = 1230
self.conn._current_serial = isc.dns.Serial(1230)
self.state = self.conn.get_xfrstate()
def test_handle_add_rr(self):
......@@ -480,7 +481,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
self.assertEqual(type(XfrinIXFRAdd()), type(self.conn.get_xfrstate()))
def test_handle_end_soa(self):
self.conn._end_serial = 1234
self.conn._end_serial = isc.dns.Serial(1234)
self.conn._diff.add_data(self.ns_rrset) # put some dummy change
self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
......@@ -489,7 +490,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
self.assertEqual([], self.conn._diff.get_buffer())
def test_handle_new_delete(self):
self.conn._end_serial = 1234
self.conn._end_serial = isc.dns.Serial(1234)
# SOA RR whose serial is the current one means we are going to a new
# difference, starting with removing that SOA.
self.conn._diff.add_data(self.ns_rrset) # put some dummy change
......@@ -500,7 +501,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
def test_handle_out_of_sync(self):
# getting SOA with an inconsistent serial. This is an error.
self.conn._end_serial = 1235
self.conn._end_serial = isc.dns.Serial(1235)
self.assertRaises(XfrinProtocolError, self.state.handle_rr,
self.conn, soa_rrset)
......@@ -523,7 +524,7 @@ class TestXfrinAXFR(TestXfrinState):
def setUp(self):
super().setUp()
self.state = XfrinAXFR()
self.conn._end_serial = 1234
self.conn._end_serial = isc.dns.Serial(1234)
def test_handle_rr(self):
"""
......@@ -781,7 +782,7 @@ class TestAXFR(TestXfrinConnection):
# IXFR query
msg = self.conn._create_query(RRType.IXFR())
check_query(RRType.IXFR(), begin_soa_rrset)
self.assertEqual(1230, self.conn._request_serial)
self.assertEqual(1230, self.conn._request_serial.get_value())
def test_create_ixfr_query_fail(self):
# In these cases _create_query() will fail to find a valid SOA RR to
......@@ -1270,7 +1271,7 @@ class TestIXFRResponse(TestXfrinConnection):
def setUp(self):
super().setUp()
self.conn._query_id = self.conn.qid = 1035
self.conn._request_serial = 1230
self.conn._request_serial = isc.dns.Serial(1230)
self.conn._request_type = RRType.IXFR()
self._zone_name = TEST_ZONE_NAME
self.conn._datasrc_client = MockDataSourceClient()
......@@ -1543,9 +1544,9 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
self.conn.response_generator = create_ixfr_response
# Confirm xfrin succeeds and SOA is updated
self.assertEqual(1230, self.get_zone_serial())
self.assertEqual(1230, self.get_zone_serial().get_value())
self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
self.assertEqual(1234, self.get_zone_serial())
self.assertEqual(1234, self.get_zone_serial().get_value())
# Also confirm the corresponding diffs are stored in the diffs table
conn = sqlite3.connect(self.sqlite3db_obj)
......@@ -1574,9 +1575,9 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
self._create_soa('1235')])
self.conn.response_generator = create_ixfr_response
self.assertEqual(1230, self.get_zone_serial())
self.assertEqual(1230, self.get_zone_serial().get_value())
self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
self.assertEqual(1230, self.get_zone_serial())
self.assertEqual(1230, self.get_zone_serial().get_value())
def test_do_ixfrin_nozone_sqlite3(self):
self.conn._zone_name = Name('nosuchzone.example')
......@@ -1595,11 +1596,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
self.conn.response_generator = create_response
# Confirm xfrin succeeds and SOA is updated, A RR is deleted.
self.assertEqual(1230, self.get_zone_serial())
self.assertEqual(1230, self.get_zone_serial().get_value())
self.assertTrue(self.record_exist(Name('dns01.example.com'),
RRType.A()))
self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, type))
self.assertEqual(1234, self.get_zone_serial())
self.assertEqual(1234, self.get_zone_serial().get_value())
self.assertFalse(self.record_exist(Name('dns01.example.com'),
RRType.A()))
......@@ -1627,11 +1628,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
answers=[soa_rrset, self._create_ns(), soa_rrset, soa_rrset])
self.conn.response_generator = create_response
self.assertEqual(1230, self.get_zone_serial())
self.assertEqual(1230, self.get_zone_serial().get_value())
self.assertTrue(self.record_exist(Name('dns01.example.com'),
RRType.A()))
self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, type))
self.assertEqual(1230, self.get_zone_serial())
self.assertEqual(1230, self.get_zone_serial().get_value())
self.assertTrue(self.record_exist(Name('dns01.example.com'),
RRType.A()))
......@@ -1669,7 +1670,7 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.AXFR()))
self.assertEqual(type(XfrinAXFREnd()),
type(self.conn.get_xfrstate()))
self.assertEqual(1234, self.get_zone_serial())
self.assertEqual(1234, self.get_zone_serial().get_value())
self.assertFalse(self.record_exist(Name('dns01.example.com'),
RRType.A()))
......
......@@ -153,7 +153,7 @@ def format_addrinfo(addrinfo):
"appear to be consisting of (family, socktype, (addr, port))")
def get_soa_serial(soa_rdata):
'''Extract the serial field of an SOA RDATA and returns it as an intger.
'''Extract the serial field of SOA RDATA and return it as a Serial object.
We don't have to be very efficient here, so we first dump the entire RDATA
as a string and convert the first corresponding field. This should be
......@@ -162,7 +162,7 @@ def get_soa_serial(soa_rdata):
should be a more direct and convenient way to get access to the SOA
fields.
'''
return int(soa_rdata.to_text().split()[2])
return Serial(int(soa_rdata.to_text().split()[2]))
class XfrinState:
'''
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment