# Copyright (C) 2009-2011  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import unittest
17
import shutil
Likun Zhang's avatar
Likun Zhang committed
18
import socket
19
import io
20
from isc.testutils.tsigctx_mock import MockTSIGContext
Likun Zhang's avatar
Likun Zhang committed
21
from xfrin import *
22
from isc.xfrin.diff import Diff
23
import isc.log
Likun Zhang's avatar
Likun Zhang committed
24

#
# Commonly used (mostly constant) test parameters
#

TEST_ZONE_NAME_STR = "example.com."
TEST_ZONE_NAME = Name(TEST_ZONE_NAME_STR)
TEST_RRCLASS = RRClass.IN()
TEST_RRCLASS_STR = 'IN'
TEST_DB_FILE = 'db_file'
TEST_MASTER_IPV4_ADDRESS = '127.0.0.1'
TEST_MASTER_IPV4_ADDRINFO = (socket.AF_INET, socket.SOCK_STREAM,
                             socket.IPPROTO_TCP, '',
                             (TEST_MASTER_IPV4_ADDRESS, 53))
TEST_MASTER_IPV6_ADDRESS = '::1'
TEST_MASTER_IPV6_ADDRINFO = (socket.AF_INET6, socket.SOCK_STREAM,
                             socket.IPPROTO_TCP, '',
                             (TEST_MASTER_IPV6_ADDRESS, 53))

TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
TESTDATA_OBJDIR = os.getenv("TESTDATAOBJDIR")
# XXX: This should be a non-privileged port that is unlikely to be used.
# If some other process uses this port, the tests will fail.
TEST_MASTER_PORT = '53535'

TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")

# SOA intended to be used for the new SOA as a result of transfer
# (serial 1234).
soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS,
                  'master.example.com. admin.example.com ' +
                  '1234 3600 1800 2419200 7200')
soa_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(), RRTTL(3600))
soa_rrset.add_rdata(soa_rdata)

# SOA intended to be used for the current SOA at the secondary side.
# Note that its serial (1230) is smaller than that of soa_rdata.
begin_soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS,
                        'master.example.com. admin.example.com ' +
                        '1230 3600 1800 2419200 7200')
begin_soa_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(), RRTTL(3600))
begin_soa_rrset.add_rdata(begin_soa_rdata)

example_axfr_question = Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.AXFR())
example_soa_question = Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA())
# Defaults used by MockXfrinConnection.create_response_data().
default_questions = [example_axfr_question]
default_answers = [soa_rrset]


def check_diffs(assert_fn, expected, actual):
    '''Compare the differences recorded by an XFR session with expectations.

    'expected' and 'actual' are sequences of diff groups; each group is a
    sequence of (operation, rrset) pairs.  Every individual comparison is
    delegated to assert_fn, which is normally the 'assertEqual' method of
    the calling unittest.TestCase subclass.  For simplicity each RRset is
    assumed to hold exactly one RDATA, and that assumption is checked too.

    '''
    assert_fn(len(expected), len(actual))
    for exp_group, act_group in zip(expected, actual):
        assert_fn(len(exp_group), len(act_group))
        for exp_diff, act_diff in zip(exp_group, act_group):
            # First the operation ('add'/'delete') has to agree.
            assert_fn(exp_diff[0], act_diff[0])
            exp_rr = exp_diff[1]
            act_rr = act_diff[1]
            # Then owner name, type and class of the RRsets.
            assert_fn(exp_rr.get_name(), act_rr.get_name())
            assert_fn(exp_rr.get_type(), act_rr.get_type())
            assert_fn(exp_rr.get_class(), act_rr.get_class())
            # Same RDATA count on both sides, and that count must be one so
            # that comparing the first RDATA covers the whole RRset.
            assert_fn(exp_rr.get_rdata_count(), act_rr.get_rdata_count())
            assert_fn(1, exp_rr.get_rdata_count())
            assert_fn(exp_rr.get_rdata()[0], act_rr.get_rdata()[0])


class XfrinTestException(Exception):
    '''Generic failure raised by the test mocks.

    For example, MockXfrinConnection.recv() raises it when fewer reply
    octets remain than were requested.

    '''


class XfrinTestTimeoutException(Exception):
    '''Raised by the test mocks to emulate a timeout.

    MockXfrinConnection.recv() raises it when no reply data is left.

    '''


class MockCC():
    '''Mock of the module CC session used by MockXfrin._cc_setup().

    Only get_default_value() is provided.

    '''
    def get_default_value(self, identifier):
        '''Return the mock default configuration value for 'identifier'.

        "zones/master_port" yields TEST_MASTER_PORT and "zones/class" yields
        TEST_RRCLASS_STR; any other identifier yields None.

        '''
        if identifier == "zones/master_port":
            return TEST_MASTER_PORT
        if identifier == "zones/class":
            return TEST_RRCLASS_STR
        return None


class MockDataSourceClient():
    '''A simple mock data source client.

    This class provides a minimal set of wrappers related to the data source
    API that would be used by Diff objects.  For our testing purposes they
    only keep track of the history of the changes.

    '''
    def __init__(self):
        self.force_fail = False # if True, raise an exception on commit
        # List of diff lists, one per successful commit().
        self.committed_diffs = []
        # Pending (operation, rrset) pairs since the last commit().
        self.diffs = []

    def get_class(self):
        '''Mock version of get_class().

        We simply return the commonly used constant RR class.  If and when
        we use this mock for a different RR class we need to adjust it
        accordingly.

        '''
        return TEST_RRCLASS

    def find_zone(self, zone_name):
        '''Mock version of find_zone().

        It returns itself (subsequently acting as a mock ZoneFinder) for
        some test zone names.  For some others it returns either NOTFOUND
        or PARTIALMATCH.  Any other name raises ValueError.

        '''
        if zone_name == TEST_ZONE_NAME or \
                zone_name == Name('no-soa.example') or \
                zone_name == Name('dup-soa.example'):
            return (DataSourceClient.SUCCESS, self)
        elif zone_name == Name('no-such-zone.example'):
            return (DataSourceClient.NOTFOUND, None)
        elif zone_name == Name('partial-match-zone.example'):
            return (DataSourceClient.PARTIALMATCH, self)
        raise ValueError('Unexpected input to mock client: bug in test case?')

    def find(self, name, rrtype, target, options):
        '''Mock ZoneFinder.find().

        It returns the predefined SOA RRset to queries for SOA of the common
        test zone name.  It also emulates some unusual cases for special
        zone names: 'no-soa.example' yields NXDOMAIN, and 'dup-soa.example'
        yields an SOA RRset with two RDATAs.  Any other input raises
        ValueError.

        '''
        if name == TEST_ZONE_NAME and rrtype == RRType.SOA():
            return (ZoneFinder.SUCCESS, begin_soa_rrset)
        if name == Name('no-soa.example'):
            return (ZoneFinder.NXDOMAIN, None)
        if name == Name('dup-soa.example'):
            dup_soa_rrset = RRset(name, TEST_RRCLASS, RRType.SOA(), RRTTL(0))
            dup_soa_rrset.add_rdata(begin_soa_rdata)
            dup_soa_rrset.add_rdata(soa_rdata)
            return (ZoneFinder.SUCCESS, dup_soa_rrset)
        raise ValueError('Unexpected input to mock finder: bug in test case?')

    def get_updater(self, zone_name, replace):
        '''Mock get_updater(): this object also acts as the updater.'''
        return self

    def add_rrset(self, rrset):
        '''Record the addition of 'rrset' as a pending diff.'''
        self.diffs.append(('add', rrset))

    def delete_rrset(self, rrset):
        '''Record the deletion of 'rrset' as a pending diff.'''
        self.diffs.append(('delete', rrset))

    def commit(self):
        '''Move the pending diffs to committed_diffs.

        Raises isc.datasrc.Error instead if force_fail is set.

        '''
        if self.force_fail:
            raise isc.datasrc.Error('Updater.commit() failed')
        self.committed_diffs.append(self.diffs)
        self.diffs = []


class MockXfrin(Xfrin):
    '''Xfrin with its command channel setup and DB file lookup stubbed out.'''

    # This is a class attribute of a callable object that specifies a non
    # default behavior triggered in _cc_check_command().  Specific test methods
    # are expected to explicitly set this attribute before creating a
    # MockXfrin object (when it needs a non default behavior).
    # See the TestMain class.
    check_command_hook = None

    def _cc_setup(self):
        # Use a MockCC in place of a real module CC session; no TSIG key.
        self._tsig_key = None
        self._module_cc = MockCC()

    def _get_db_file(self):
        # Stub: no DB file is used by these tests.
        return None

    def _cc_check_command(self):
        # Set the shutdown event, then run the optional per-test hook.
        self._shutdown_event.set()
        if MockXfrin.check_command_hook:
            MockXfrin.check_command_hook()

    def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo,
                    tsig_key, request_type, check_soa=True):
        # store some of the arguments for verification, then call this
        # method in the superclass (with db_file replaced by None).
        # master_addrinfo[2] is read as an (address, port) pair.
        self.xfrin_started_master_addr = master_addrinfo[2][0]
        self.xfrin_started_master_port = master_addrinfo[2][1]
        self.xfrin_started_request_type = request_type
        return Xfrin.xfrin_start(self, zone_name, rrclass, None,
                                 master_addrinfo, tsig_key,
                                 request_type, check_soa)

class MockXfrinConnection(XfrinConnection):
    '''XfrinConnection that exchanges canned data instead of using a socket.

    Data passed to send() accumulates in query_data; recv() consumes
    reply_data, which tests fill in (directly or via response_generator),
    typically with create_response_data().

    '''
    def __init__(self, sock_map, zone_name, rrclass, shutdown_event,
                 master_addr):
        super().__init__(sock_map, zone_name, rrclass, MockDataSourceClient(),
                         shutdown_event, master_addr)
        self.query_data = b''
        self.reply_data = b''
        self.force_time_out = False
        self.force_close = False
        # Length (excluding the 2-octet prefix) and QID of the current
        # outgoing query; None until enough of it has been sent.
        self.qlen = None
        self.qid = None
        # Optional callable invoked once the QID of a query is known.
        self.response_generator = None

    def _asyncore_loop(self):
        # Emulate one I/O loop run: close if force_close is set, otherwise
        # read unless a timeout is being forced (then do nothing).
        if self.force_close:
            self.handle_close()
        elif not self.force_time_out:
            self.handle_read()

    def connect_to_master(self):
        # Pretend the connection to the master always succeeds.
        return True

    def recv(self, size):
        '''Consume and return the next 'size' octets of reply_data.

        Raises XfrinTestTimeoutException if no data is left, and
        XfrinTestException if fewer than 'size' octets remain.

        '''
        data = self.reply_data[:size]
        self.reply_data = self.reply_data[size:]
        if len(data) == 0:
            raise XfrinTestTimeoutException('Emulated timeout')
        if len(data) < size:
            raise XfrinTestException('cannot get reply data (' + str(size) +
                                     ' bytes)')
        return data

    def send(self, data):
        '''Accumulate outgoing query data; always reports full transmission.'''
        # query_data includes the 2-octet length prefix, so the previous
        # query is complete once qlen + 2 octets have been sent.  Anything
        # sent after that starts a new query: reset the internal state.
        if self.qlen is not None and len(self.query_data) >= self.qlen + 2:
            self.qlen = None
            self.qid = None
            self.query_data = b''
        self.query_data += data

        # When the outgoing data is sufficiently large to contain the length
        # and the QID fields (4 octets or more), extract these fields (both
        # in network byte order).  The length is used above to detect the
        # start of the next query, supporting multiple queries in a single
        # test.  The QID will be used to construct a matching response.
        if len(self.query_data) >= 4 and self.qid is None:
            self.qlen = struct.unpack('!H', self.query_data[0:2])[0]
            self.qid = struct.unpack('!H', self.query_data[2:4])[0]
            # if the response generator method is specified, invoke it now.
            if self.response_generator is not None:
                self.response_generator()
        return len(data)

    def create_response_data(self, response=True, bad_qid=False,
                             rcode=Rcode.NOERROR(),
                             questions=default_questions,
                             answers=default_answers,
                             tsig_ctx=None):
        '''Build a length-prefixed wire-format response to the last query.

        The QID is copied from the last query; with bad_qid=True it is
        changed so that it no longer matches.  The QR flag is set only if
        'response' is True.  'questions' and 'answers' go into the question
        and answer sections.  If tsig_ctx is given, it is used to sign the
        message.

        '''
        resp = Message(Message.RENDER)
        qid = self.qid
        if bad_qid:
            # Wrap around so the mismatching QID stays within 16 bits even
            # when the original QID is 0xffff.
            qid = (qid + 1) & 0xffff
        resp.set_qid(qid)
        resp.set_opcode(Opcode.QUERY())
        resp.set_rcode(rcode)
        if response:
            resp.set_header_flag(Message.HEADERFLAG_QR)
        for question in questions:
            resp.add_question(question)
        for answer in answers:
            resp.add_rrset(Message.SECTION_ANSWER, answer)

        renderer = MessageRenderer()
        if tsig_ctx is not None:
            resp.to_wire(renderer, tsig_ctx)
        else:
            resp.to_wire(renderer)
        # TCP framing: 2-octet message length in network byte order.
        reply_data = struct.pack('!H', renderer.get_length())
        reply_data += renderer.get_data()
        return reply_data


class TestXfrinState(unittest.TestCase):
    '''Common fixture for the XfrinState tests.

    It provides a MockXfrinConnection (self.conn) that uses a fresh
    MockDataSourceClient and Diff, plus some RRsets used as test input:
    begin_soa (serial 1230), ns_rrset and a_rrset.

    '''
    def setUp(self):
        self.sock_map = {}
        self.conn = MockXfrinConnection(self.sock_map, TEST_ZONE_NAME,
                                        TEST_RRCLASS, threading.Event(),
                                        TEST_MASTER_IPV4_ADDRINFO)
        self.begin_soa = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(),
                               RRTTL(3600))
        self.begin_soa.add_rdata(Rdata(RRType.SOA(), TEST_RRCLASS,
                                       'm. r. 1230 0 0 0 0'))
        self.ns_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.NS(),
                              RRTTL(3600))
        self.ns_rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS,
                                      'ns.example.com'))
        self.a_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.A(),
                             RRTTL(3600))
        self.a_rrset.add_rdata(Rdata(RRType.A(), TEST_RRCLASS, '192.0.2.1'))

        # Replace the data source client and Diff with fresh ones so that
        # the tests can inspect the recorded changes.
        self.conn._datasrc_client = MockDataSourceClient()
        self.conn._diff = Diff(self.conn._datasrc_client, TEST_ZONE_NAME)


class TestXfrinStateBase(TestXfrinState):
    '''Tests for the XfrinState base class itself.'''
    def setUp(self):
        # The common fixture is all that is needed here.
        super().setUp()

    def test_handle_rr_on_base(self):
        # handle_rr() of the base class must not be called directly; the RR
        # argument is irrelevant for this check, so None is passed.
        base_state = XfrinState()
        with self.assertRaises(XfrinException):
            base_state.handle_rr(None)


class TestXfrinInitialSOA(TestXfrinState):
    '''Tests for XfrinInitialSOA, the state expecting the first SOA.'''
    def setUp(self):
        super().setUp()
        self.state = XfrinInitialSOA()

    def test_handle_rr(self):
        # normal case: the SOA is consumed, its serial becomes the end
        # serial, and the state moves to XfrinFirstData.
        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
        self.assertEqual(type(XfrinFirstData()),
                         type(self.conn.get_xfrstate()))
        self.assertEqual(1234, self.conn._end_serial)

    def test_handle_not_soa(self):
        # The given RR is not of SOA
        self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
                          self.ns_rrset)

    def test_finish_message(self):
        self.assertTrue(self.state.finish_message(self.conn))


class TestXfrinFirstData(TestXfrinState):
    '''Tests for XfrinFirstData, the state handling the RR that follows the
    initial SOA and choosing between IXFR and AXFR processing.

    '''
    def setUp(self):
        super().setUp()
        self.state = XfrinFirstData()
        self.conn._request_type = RRType.IXFR()
        self.conn._request_serial = 1230 # arbitrarily chosen serial < 1234
        self.conn._diff = None           # should be replaced in the AXFR case

    def test_handle_ixfr_begin_soa(self):
        # An SOA whose serial equals the requested one (1230) starts IXFR
        # processing.  handle_rr() returns False, as the RR is to be
        # processed again by the next state.
        self.conn._request_type = RRType.IXFR()
        self.assertFalse(self.state.handle_rr(self.conn, self.begin_soa))
        self.assertEqual(type(XfrinIXFRDeleteSOA()),
                         type(self.conn.get_xfrstate()))

    def test_handle_axfr(self):
        # If the original type is AXFR, other conditions aren't considered,
        # and AXFR processing will continue
        self.conn._request_type = RRType.AXFR()
        self.assertFalse(self.state.handle_rr(self.conn, self.begin_soa))
        self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))

    def test_handle_ixfr_to_axfr(self):
        # Detecting AXFR-compatible IXFR response by seeing a non SOA RR after
        # the initial SOA.  Should switch to AXFR.
        self.assertFalse(self.state.handle_rr(self.conn, self.ns_rrset))
        self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
        # The Diff for AXFR should be created at this point
        self.assertNotEqual(None, self.conn._diff)

    def test_handle_ixfr_to_axfr_by_different_soa(self):
        # An unusual case: Response contains two consecutive SOA but the
        # serial of the second does not match the requested one.  See
        # the documentation for XfrinFirstData.handle_rr().
        self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
        self.assertEqual(type(XfrinAXFR()), type(self.conn.get_xfrstate()))
        self.assertNotEqual(None, self.conn._diff)

    def test_finish_message(self):
        self.assertTrue(self.state.finish_message(self.conn))

class TestXfrinIXFRDeleteSOA(TestXfrinState):
    '''Tests for XfrinIXFRDeleteSOA, the state expecting the SOA that starts
    the deletion part of an IXFR difference sequence.

    '''
    def setUp(self):
        super().setUp()
        self.state = XfrinIXFRDeleteSOA()
        # In this state a new Diff object is expected to be created.  To
        # confirm it, we nullify it beforehand.
        self.conn._diff = None

    def test_handle_rr(self):
        # The SOA is recorded as a deletion in the new Diff, and the state
        # moves to XfrinIXFRDelete.
        self.assertTrue(self.state.handle_rr(self.conn, self.begin_soa))
        self.assertEqual(type(XfrinIXFRDelete()),
                         type(self.conn.get_xfrstate()))
        self.assertEqual([('delete', self.begin_soa)],
                         self.conn._diff.get_buffer())

    def test_handle_non_soa(self):
        # Anything other than an SOA is an error in this state.
        self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
                          self.ns_rrset)

    def test_finish_message(self):
        self.assertTrue(self.state.finish_message(self.conn))

class TestXfrinIXFRDelete(TestXfrinState):
    '''Tests for XfrinIXFRDelete, the state collecting RRs to be deleted.'''
    def setUp(self):
        super().setUp()
        # We need to record the state in 'conn' to check the case where the
        # state doesn't change.
        XfrinIXFRDelete().set_xfrstate(self.conn, XfrinIXFRDelete())
        self.state = self.conn.get_xfrstate()

    def test_handle_delete_rr(self):
        # Non SOA RRs are simply (going to be) deleted in this state
        self.assertTrue(self.state.handle_rr(self.conn, self.ns_rrset))
        self.assertEqual([('delete', self.ns_rrset)],
                         self.conn._diff.get_buffer())
        # The state shouldn't change
        self.assertEqual(type(XfrinIXFRDelete()),
                         type(self.conn.get_xfrstate()))

    def test_handle_soa(self):
        # SOA in this state means the beginning of added RRs.  This SOA
        # should also be added in the next state, so handle_rr() should return
        # false.
        self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
        self.assertEqual([], self.conn._diff.get_buffer())
        self.assertEqual(1234, self.conn._current_serial)
        self.assertEqual(type(XfrinIXFRAddSOA()),
                         type(self.conn.get_xfrstate()))

    def test_finish_message(self):
        self.assertTrue(self.state.finish_message(self.conn))


class TestXfrinIXFRAddSOA(TestXfrinState):
    '''Tests for XfrinIXFRAddSOA, the state expecting the SOA that starts
    the addition part of an IXFR difference sequence.

    '''
    def setUp(self):
        super().setUp()
        self.state = XfrinIXFRAddSOA()

    def test_handle_rr(self):
        # The SOA is recorded as an addition and the state moves to
        # XfrinIXFRAdd.
        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
        self.assertEqual(type(XfrinIXFRAdd()), type(self.conn.get_xfrstate()))
        self.assertEqual([('add', soa_rrset)],
                         self.conn._diff.get_buffer())

    def test_handle_non_soa(self):
        # Anything other than an SOA is an error in this state.
        self.assertRaises(XfrinException, self.state.handle_rr, self.conn,
                          self.ns_rrset)

    def test_finish_message(self):
        self.assertTrue(self.state.finish_message(self.conn))


class TestXfrinIXFRAdd(TestXfrinState):
    '''Tests for XfrinIXFRAdd, the state collecting RRs to be added.'''
    def setUp(self):
        super().setUp()
        # We need to record the state in 'conn' to check the case where the
        # state doesn't change.
        XfrinIXFRAdd().set_xfrstate(self.conn, XfrinIXFRAdd())
        self.conn._current_serial = 1230
        self.state = self.conn.get_xfrstate()

    def test_handle_add_rr(self):
        # Non SOA RRs are simply (going to be) added in this state
        self.assertTrue(self.state.handle_rr(self.conn, self.ns_rrset))
        self.assertEqual([('add', self.ns_rrset)],
                         self.conn._diff.get_buffer())
        # The state shouldn't change
        self.assertEqual(type(XfrinIXFRAdd()), type(self.conn.get_xfrstate()))

    def test_handle_end_soa(self):
        # An SOA whose serial equals the end serial finishes the transfer.
        self.conn._end_serial = 1234
        self.conn._diff.add_data(self.ns_rrset) # put some dummy change
        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
        self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
        # handle_rr should have caused commit, and the buffer should now be
        # empty.
        self.assertEqual([], self.conn._diff.get_buffer())

    def test_handle_new_delete(self):
        self.conn._end_serial = 1234
        # SOA RR whose serial is the current one means we are going to a new
        # difference, starting with removing that SOA.
        self.conn._diff.add_data(self.ns_rrset) # put some dummy change
        self.assertFalse(self.state.handle_rr(self.conn, self.begin_soa))
        self.assertEqual([], self.conn._diff.get_buffer())
        self.assertEqual(type(XfrinIXFRDeleteSOA()),
                         type(self.conn.get_xfrstate()))

    def test_handle_out_of_sync(self):
        # getting SOA with an inconsistent serial.  This is an error.
        self.conn._end_serial = 1235
        self.assertRaises(XfrinProtocolError, self.state.handle_rr,
                          self.conn, soa_rrset)

    def test_finish_message(self):
        self.assertTrue(self.state.finish_message(self.conn))


class TestXfrinIXFREnd(TestXfrinState):
    '''Tests for XfrinIXFREnd, the state after the final IXFR SOA.'''
    def setUp(self):
        super().setUp()
        self.state = XfrinIXFREnd()

    def test_handle_rr(self):
        # No further RR is acceptable in this state.
        self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
                          self.ns_rrset)

    def test_finish_message(self):
        # finish_message() returns False in this state.
        self.assertFalse(self.state.finish_message(self.conn))

class TestXfrinAXFR(TestXfrinState):
    '''Tests for XfrinAXFR, the state adding each received RR until the
    terminating SOA arrives.

    '''
    def setUp(self):
        super().setUp()
        self.state = XfrinAXFR()
        self.conn._end_serial = 1234

    def test_handle_rr(self):
        """
        Test we can put data inside.
        """
        # Put some data inside
        self.assertTrue(self.state.handle_rr(self.conn, self.a_rrset))
        # This test uses internal Diff structure to check the behaviour of
        # XfrinAXFR. Maybe there could be a cleaner way, but it would be more
        # complicated.
        self.assertEqual([('add', self.a_rrset)], self.conn._diff.get_buffer())
        # This SOA terminates the transfer
        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
        # It should have changed the state
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        # At this point, the data haven't been committed yet
        self.assertEqual([('add', self.a_rrset), ('add', soa_rrset)],
                         self.conn._diff.get_buffer())

    def test_handle_rr_mismatch_soa(self):
        """ SOA with inconsistent serial - unexpected, but we accept it.

        """
        self.assertTrue(self.state.handle_rr(self.conn, begin_soa_rrset))
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))

    def test_finish_message(self):
        """
        Check normal end of message.
        """
        # When a message ends, nothing happens usually
        self.assertTrue(self.state.finish_message(self.conn))


class TestXfrinAXFREnd(TestXfrinState):
    '''Tests for XfrinAXFREnd, the state after the terminating AXFR SOA.'''
    def setUp(self):
        super().setUp()
        self.state = XfrinAXFREnd()

    def test_handle_rr(self):
        # No further RR is acceptable in this state.
        self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
                          self.ns_rrset)

    def test_finish_message(self):
        self.conn._diff.add_data(self.a_rrset)
        self.conn._diff.add_data(soa_rrset)
        self.assertFalse(self.state.finish_message(self.conn))

        # The data should have been committed
        self.assertEqual([], self.conn._diff.get_buffer())
        check_diffs(self.assertEqual, [[('add', self.a_rrset),
                                        ('add', soa_rrset)]],
                    self.conn._datasrc_client.committed_diffs)
        # Committing the same Diff again is expected to raise ValueError.
        self.assertRaises(ValueError, self.conn._diff.commit)

class TestXfrinConnection(unittest.TestCase):
    '''Convenient parent class for XFR-protocol tests.

    This class provides common setups and helper methods for protocol related
    tests on AXFR and IXFR.

    '''

    def setUp(self):
        if os.path.exists(TEST_DB_FILE):
            os.remove(TEST_DB_FILE)
        self.sock_map = {}
        self.conn = MockXfrinConnection(self.sock_map, TEST_ZONE_NAME,
                                        TEST_RRCLASS, threading.Event(),
                                        TEST_MASTER_IPV4_ADDRINFO)
        # Parameters read by _create_soa_response_data().
        self.soa_response_params = {
            'questions': [example_soa_question],
            'bad_qid': False,
            'response': True,
            'rcode': Rcode.NOERROR(),
            'tsig': False,
            'axfr_after_soa': self._create_normal_response_data
            }
        # Parameters read by _create_normal_response_data().
        self.axfr_response_params = {
            'question_1st': default_questions,
            'question_2nd': default_questions,
            'answer_1st': [soa_rrset, self._create_ns()],
            'answer_2nd': default_answers,
            'tsig_1st': None,
            'tsig_2nd': None
            }

    def tearDown(self):
        self.conn.close()
        if os.path.exists(TEST_DB_FILE):
            os.remove(TEST_DB_FILE)

    def _create_normal_response_data(self):
        # This helper method creates a simple sequence of DNS messages that
        # forms a valid AXFR transaction.  It consists of two messages: the
        # first one containing SOA, NS, the second containing the trailing SOA.
        question_1st = self.axfr_response_params['question_1st']
        question_2nd = self.axfr_response_params['question_2nd']
        answer_1st = self.axfr_response_params['answer_1st']
        answer_2nd = self.axfr_response_params['answer_2nd']
        tsig_1st = self.axfr_response_params['tsig_1st']
        tsig_2nd = self.axfr_response_params['tsig_2nd']
        self.conn.reply_data = self.conn.create_response_data(
            questions=question_1st, answers=answer_1st,
            tsig_ctx=tsig_1st)
        self.conn.reply_data += \
            self.conn.create_response_data(questions=question_2nd,
                                           answers=answer_2nd,
                                           tsig_ctx=tsig_2nd)

    def _create_soa_response_data(self):
        # This helper method creates a DNS message that is supposed to be
        # used as a valid response to SOA queries prior to XFR.
        # If tsig is True, it tries to verify the query with a locally
        # created TSIG context (which may or may not succeed) so that the
        # response will include a TSIG.
        # If axfr_after_soa is not None (it is expected to be a callable), it
        # is installed as the connection's response_generator so that valid
        # XFR messages will follow.

        verify_ctx = None
        if self.soa_response_params['tsig']:
            # xfrin (currently) always uses TCP.  strip off the length field.
            query_data = self.conn.query_data[2:]
            query_message = Message(Message.PARSE)
            query_message.from_wire(query_data)
            verify_ctx = TSIGContext(TSIG_KEY)
            verify_ctx.verify(query_message.get_tsig_record(), query_data)

        self.conn.reply_data = self.conn.create_response_data(
            bad_qid=self.soa_response_params['bad_qid'],
            response=self.soa_response_params['response'],
            rcode=self.soa_response_params['rcode'],
            questions=self.soa_response_params['questions'],
            tsig_ctx=verify_ctx)
        if self.soa_response_params['axfr_after_soa'] is not None:
            self.conn.response_generator = \
                self.soa_response_params['axfr_after_soa']

    def _create_broken_response_data(self):
        # This helper method creates a bogus "DNS message" that only contains
        # 4 octets of data.  The DNS message parser will raise an exception.
        bogus_data = b'xxxx'
        # 2-octet length prefix in network byte order, as for TCP.
        self.conn.reply_data = struct.pack('!H', len(bogus_data))
        self.conn.reply_data += bogus_data

    def _create_a(self, address):
        # Return an A RRset for a.example.com with the given address.
        rrset = RRset(Name('a.example.com'), TEST_RRCLASS, RRType.A(),
                      RRTTL(3600))
        rrset.add_rdata(Rdata(RRType.A(), TEST_RRCLASS, address))
        return rrset

    def _create_soa(self, serial):
        # Return an SOA RRset for the test zone with the given serial
        # (either a str or an int).
        rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(),
                      RRTTL(3600))
        rdata_str = 'm. r. ' + str(serial) + ' 3600 1800 2419200 7200'
        rrset.add_rdata(Rdata(RRType.SOA(), TEST_RRCLASS, rdata_str))
        return rrset

    def _create_ns(self, nsname='ns.'+TEST_ZONE_NAME_STR):
        # Return an NS RRset for the test zone pointing to nsname.
        rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.NS(), RRTTL(3600))
        rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS, nsname))
        return rrset


class TestAXFR(TestXfrinConnection):
    def setUp(self):
        super().setUp()
        # Every AXFR test starts with the connection in the initial SOA
        # state.
        state_setter = XfrinInitialSOA()
        state_setter.set_xfrstate(self.conn, XfrinInitialSOA())
682

683
684
685
686
687
688
689
690
    def __create_mock_tsig(self, key, error):
        '''Return a MockTSIGContext for key that reports the given error.

        The error is the (normally faked) result of verification, which
        lets tests emulate arbitrary TSIG verification outcomes.

        '''
        ctx = MockTSIGContext(key)
        ctx.error = error
        return ctx

chenzhengzhang's avatar
chenzhengzhang committed
691
    def __match_exception(self, expected_exception, expected_msg, expression):
        '''Assert that calling expression() raises a specific exception.

        This is a finer-grained version of assertRaises().  Besides the
        exception class, it checks that str() of the raised exception equals
        expected_msg.  That helps when the same exception type can be raised
        from many places.  Exceptions of any other type propagate unchanged.

        '''
        try:
            expression()
        except expected_exception as ex:
            self.assertEqual(str(ex), expected_msg)
        else:
            # fail() reports this message directly.  assertFalse() on a
            # non-empty string would also fail, but with a misleading
            # "... is not false" report.
            self.fail('exception is expected, but not raised')

703
704
705
706
707
708
709
710
711
    def test_close(self):
        # The connection must be registered in the test's private socket
        # map (exactly one entry), not in asyncore's global map.
        self.assertEqual(len(asyncore.socket_map), 0)
        self.assertEqual(len(self.sock_map), 1)
        # Closing the dispatcher must remove it from that private map.
        self.conn.close()
        self.assertEqual(len(self.sock_map), 0)

712
713
714
715
716
717
    def test_init_ip6(self):
        # Create an XfrinConnection for an IPv6 master and bind it to the
        # IPv6 wildcard address/port.  This only works if an AF_INET6 socket
        # was created, so it uncovers code that naively assumes IPv4 and
        # hardcodes AF_INET.
        c = MockXfrinConnection({}, TEST_ZONE_NAME, TEST_RRCLASS,
                                threading.Event(), TEST_MASTER_IPV6_ADDRINFO)
        try:
            c.bind(('::', 0))
        finally:
            # Close the socket even if bind() fails so it isn't leaked.
            c.close()

    def test_init_chclass(self):
        # A connection created for class CH must put CH into the question
        # of the queries it builds.
        c = MockXfrinConnection({}, TEST_ZONE_NAME, RRClass.CH(),
                                threading.Event(), TEST_MASTER_IPV4_ADDRINFO)
        try:
            axfrmsg = c._create_query(RRType.AXFR())
            self.assertEqual(axfrmsg.get_question()[0].get_class(),
                             RRClass.CH())
        finally:
            # Release the socket even when the assertion fails.
            c.close()
730

731
    def test_create_query(self):
        def check_query(msg, expected_qtype, expected_auth):
            '''Check a query message built by _create_query().

            The message under test is msg. It must be a QUERY with RCODE
            NOERROR, exactly one question for the test zone of type
            expected_qtype, and no answer or additional RRs.  If
            expected_auth is None, the authority section must be empty.
            Otherwise it must hold a single RR matching the single-RDATA
            RRset expected_auth.

            '''
            self.assertEqual(Opcode.QUERY(), msg.get_opcode())
            self.assertEqual(Rcode.NOERROR(), msg.get_rcode())
            self.assertEqual(1, msg.get_rr_count(Message.SECTION_QUESTION))
            self.assertEqual(TEST_ZONE_NAME, msg.get_question()[0].get_name())
            self.assertEqual(expected_qtype, msg.get_question()[0].get_type())
            self.assertEqual(0, msg.get_rr_count(Message.SECTION_ANSWER))
            self.assertEqual(0, msg.get_rr_count(Message.SECTION_ADDITIONAL))
            if expected_auth is None:
                self.assertEqual(0,
                                 msg.get_rr_count(Message.SECTION_AUTHORITY))
            else:
                self.assertEqual(1,
                                 msg.get_rr_count(Message.SECTION_AUTHORITY))
                auth_rr = msg.get_section(Message.SECTION_AUTHORITY)[0]
                self.assertEqual(expected_auth.get_name(), auth_rr.get_name())
                self.assertEqual(expected_auth.get_type(), auth_rr.get_type())
                self.assertEqual(expected_auth.get_class(),
                                 auth_rr.get_class())
                # These scenarios use single-RDATA RRsets only, so comparing
                # the first RDATA compares all of it.
                self.assertEqual(1, expected_auth.get_rdata_count())
                self.assertEqual(1, auth_rr.get_rdata_count())
                self.assertEqual(expected_auth.get_rdata()[0],
                                 auth_rr.get_rdata()[0])

        # SOA and AXFR queries carry nothing in the authority section.
        msg = self.conn._create_query(RRType.SOA())
        check_query(msg, RRType.SOA(), None)

        msg = self.conn._create_query(RRType.AXFR())
        check_query(msg, RRType.AXFR(), None)

        # An IXFR query is expected to carry begin_soa_rrset in its
        # authority section and to record 1230 as the request serial.
        msg = self.conn._create_query(RRType.IXFR())
        check_query(msg, RRType.IXFR(), begin_soa_rrset)
        self.assertEqual(1230, self.conn._request_serial)

    def test_create_ixfr_query_fail(self):
        # For each of these zones, _create_query() cannot find a valid SOA
        # RR to insert into the IXFR query, so it must raise XfrinException.
        failing_zones = ('no-such-zone.example', 'partial-match-zone.example',
                         'no-soa.example', 'dup-soa.example')
        for zone in failing_zones:
            self.conn._zone_name = Name(zone)
            self.assertRaises(XfrinException, self.conn._create_query,
                              RRType.IXFR())
791

792
    def test_send_query(self):
        def has_tsig(wire):
            # Only checks that a TSIG RR is present.  The detailed TSIG
            # protocol tests live in pydnspp.
            parsed = Message(Message.PARSE)
            parsed.from_wire(wire)
            return parsed.get_tsig_record() is not None

        self.conn._tsig_key = TSIG_KEY

        # Both the SOA and the AXFR request must be signed.  The first two
        # octets of query_data (presumably the length prefix) are skipped.
        for qtype in (RRType.SOA(), RRType.AXFR()):
            self.conn._send_query(qtype)
            self.assertTrue(has_tsig(self.conn.query_data[2:]))
809

810
    def test_response_with_invalid_msg(self):
        # Feed garbage as the reply; the mock connection is expected to
        # raise XfrinTestException.
        self.conn.reply_data = b'aaaxxxx'
        with self.assertRaises(XfrinTestException):
            self.conn._handle_xfrin_responses()
814

815
    def test_response_with_tsigfail(self):
        self.conn._tsig_key = TSIG_KEY
        # Emulate a failed TSIG check on the server side: the server answers
        # with RCODE 9 (NOTAUTH).
        self.conn._send_query(RRType.SOA())
        self.conn.reply_data = self.conn.create_response_data(
            rcode=Rcode.NOTAUTH())
        with self.assertRaises(XfrinException):
            self.conn._handle_xfrin_responses()
821

822
    def test_response_without_end_soa(self):
        self.conn._send_query(RRType.AXFR())
        self.conn.reply_data = self.conn.create_response_data()
        # Without a trailing SOA the asyncore loop would eventually time
        # out.  recv() emulates that by emptying the reply data buffer.
        with self.assertRaises(XfrinTestTimeoutException):
            self.conn._handle_xfrin_responses()
829
830

    def test_response_bad_qid(self):
        # A response whose QID doesn't match the query must be rejected.
        self.conn._send_query(RRType.AXFR())
        bad_response = self.conn.create_response_data(bad_qid=True)
        self.conn.reply_data = bad_response
        with self.assertRaises(XfrinException):
            self.conn._handle_xfrin_responses()
834

835
836
837
838
839
840
841
842
843
    def test_response_error_code_bad_sig(self):
        # Fake TSIG verification so that it reports BADSIG.
        self.conn._tsig_key = TSIG_KEY
        self.conn._tsig_ctx_creator = \
            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_SIG)
        self.conn._send_query(RRType.AXFR())
        self.conn.reply_data = self.conn.create_response_data(
            rcode=Rcode.SERVFAIL())
        # The RCODE is SERVFAIL too, but xfrin should check TSIG before the
        # rest of the incoming message.  The XfrinException must therefore
        # report the TSIG failure.
        self.__match_exception(XfrinException, "TSIG verify fail: BADSIG",
                               self.conn._handle_xfrin_responses)
847
848
849
850
851
852

    def test_response_bad_qid_bad_key(self):
        # Fake TSIG verification so that it reports BADKEY.
        self.conn._tsig_key = TSIG_KEY
        self.conn._tsig_ctx_creator = \
            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_KEY)
        self.conn._send_query(RRType.AXFR())
        self.conn.reply_data = self.conn.create_response_data(bad_qid=True)
        # The QID is wrong as well, but xfrin should check TSIG before the
        # rest of the incoming message.  The XfrinException must therefore
        # report the TSIG failure.
        self.__match_exception(XfrinException, "TSIG verify fail: BADKEY",
                               self.conn._handle_xfrin_responses)
859

860
    def test_response_non_response(self):
        # A reply that is not marked as a response must be rejected.
        self.conn._send_query(RRType.AXFR())
        non_response = self.conn.create_response_data(response=False)
        self.conn.reply_data = non_response
        with self.assertRaises(XfrinException):
            self.conn._handle_xfrin_responses()
864
865

    def test_response_error_code(self):
        # A SERVFAIL response must make response handling raise
        # XfrinException.
        self.conn._send_query(RRType.AXFR())
        servfail = self.conn.create_response_data(rcode=Rcode.SERVFAIL())
        self.conn.reply_data = servfail
        with self.assertRaises(XfrinException):
            self.conn._handle_xfrin_responses()
870
871

    def test_response_multi_question(self):
        # A response carrying two questions must be rejected.
        self.conn._send_query(RRType.AXFR())
        two_questions = [example_axfr_question, example_axfr_question]
        self.conn.reply_data = self.conn.create_response_data(
            questions=two_questions)
        with self.assertRaises(XfrinException):
            self.conn._handle_xfrin_responses()
876
877

    # test_response_non_response is defined once, above.  An identical
    # second definition here would silently shadow it (flake8 F811), so
    # only one of the two would ever run.
881
882
883
884
885
886
887
888
889

    def test_soacheck(self):
        # The QID is only known once _check_soa_serial() has built its
        # query, so the response is created lazily via response_generator.
        self.conn.response_generator = self._create_soa_response_data
        result = self.conn._check_soa_serial()
        self.assertEqual(result, XFRIN_OK)

    def test_soacheck_with_bad_response(self):
        # A 4-octet bogus reply must make message parsing fail.
        self.conn.response_generator = self._create_broken_response_data
        with self.assertRaises(MessageTooShort):
            self.conn._check_soa_serial()
891
892
893
894
895
896

    def test_soacheck_badqid(self):
        # An SOA response with a mismatched QID must be rejected.
        self.soa_response_params['bad_qid'] = True
        self.conn.response_generator = self._create_soa_response_data
        with self.assertRaises(XfrinException):
            self.conn._check_soa_serial()

897
898
899
900
901
902
903
904
    def test_soacheck_bad_qid_bad_sig(self):
        # Fake TSIG verification so that it reports BADSIG, and make the
        # QID mismatch as well.
        self.conn._tsig_key = TSIG_KEY
        self.conn._tsig_ctx_creator = \
            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_SIG)
        self.soa_response_params['bad_qid'] = True
        self.conn.response_generator = self._create_soa_response_data
        # xfrin should check TSIG before the rest of the incoming message,
        # so the XfrinException must report the TSIG failure.
        self.__match_exception(XfrinException, "TSIG verify fail: BADSIG",
                               self.conn._check_soa_serial)
908

909
910
911
912
913
914
    def test_soacheck_non_response(self):
        # An SOA reply not marked as a response must be rejected.
        self.soa_response_params['response'] = False
        self.conn.response_generator = self._create_soa_response_data
        with self.assertRaises(XfrinException):
            self.conn._check_soa_serial()

    def test_soacheck_error_code(self):
        # A SERVFAIL SOA response must be rejected.
        self.soa_response_params['rcode'] = Rcode.SERVFAIL()
        self.conn.response_generator = self._create_soa_response_data
        with self.assertRaises(XfrinException):
            self.conn._check_soa_serial()
918

919
    def test_soacheck_with_tsig(self):
        # A mock TSIG context emulates a validly signed response.
        self.conn._tsig_key = TSIG_KEY
        self.conn._tsig_ctx_creator = \
            lambda key: self.__create_mock_tsig(key, TSIGError.NOERROR)
        self.conn.response_generator = self._create_soa_response_data
        self.assertEqual(self.conn._check_soa_serial(), XFRIN_OK)
        # The context used during the check must report no TSIG error.
        tsig_error = self.conn._tsig_ctx.get_error()
        self.assertEqual(tsig_error, TSIGError.NOERROR)

    def test_soacheck_with_tsig_notauth(self):
        # Emulate a valid error response: RCODE NOTAUTH, with a TSIG
        # context that reports BADSIG.
        self.conn._tsig_key = TSIG_KEY
        self.conn._tsig_ctx_creator = \
            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_SIG)
        self.soa_response_params['rcode'] = Rcode.NOTAUTH()
        self.conn.response_generator = self._create_soa_response_data

        with self.assertRaises(XfrinException):
            self.conn._check_soa_serial()

    def test_soacheck_with_tsig_noerror_badsig(self):
        self.conn._tsig_key = TSIG_KEY
        self.conn._tsig_ctx_creator = \
            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_SIG)

        # Emulate a normal response whose TSIG verification fails with
        # BADSIG.  According to RFC 2845 the client should ignore such a
        # response and keep waiting for a valid one until a timeout.  Like
        # BIND 9, we instead treat it as a final failure immediately.
        self.conn.response_generator = self._create_soa_response_data

        with self.assertRaises(XfrinException):
            self.conn._check_soa_serial()

    def test_soacheck_with_tsig_unsigned_response(self):
        # The query is signed but the response carries no TSIG.  RFC 2845
        # says we should keep waiting for a valid response; we treat it as
        # a fatal transaction failure instead.  A real TSIGContext suffices
        # here, so this test installs no mock context creator.
        self.conn._tsig_key = TSIG_KEY
        self.conn.response_generator = self._create_soa_response_data
        with self.assertRaises(XfrinException):
            self.conn._check_soa_serial()

    def test_soacheck_with_unexpected_tsig_response(self):
        # This test sets no TSIG key, yet the response is signed.  Following
        # BIND 9's behavior, such an unexpected TSIG is rejected.
        self.soa_response_params['tsig'] = True
        self.conn.response_generator = self._create_soa_response_data
        with self.assertRaises(XfrinException):
            self.conn._check_soa_serial()

967
968
969
    def test_response_shutdown(self):
        # With the shutdown event already set, response handling must end
        # with XfrinException.
        self.conn.response_generator = self._create_normal_response_data
        self.conn._shutdown_event.set()
        self.conn._send_query(RRType.AXFR())
        with self.assertRaises(XfrinException):
            self.conn._handle_xfrin_responses()
972
973
974
975

    def test_response_timeout(self):
        # The mock connection is told to emulate a timeout.
        self.conn.response_generator = self._create_normal_response_data
        self.conn.force_time_out = True
        with self.assertRaises(XfrinException):
            self.conn._handle_xfrin_responses()
977
978
979
980

    def test_response_remote_close(self):
        # The mock connection is told to emulate the remote end closing.
        self.conn.response_generator = self._create_normal_response_data
        self.conn.force_close = True
        with self.assertRaises(XfrinException):
            self.conn._handle_xfrin_responses()
982

983
984
    def test_response_bad_message(self):
        # An unparsable reply; any exception type is accepted here.
        self.conn.response_generator = self._create_broken_response_data
        self.conn._send_query(RRType.AXFR())
        with self.assertRaises(Exception):
            self.conn._handle_xfrin_responses()
987

988
    def test_axfr_response(self):
        # The simplest normal case: the AXFR answer is SOA, NS, then the
        # trailing SOA.
        self.conn.response_generator = self._create_normal_response_data
        self.conn._send_query(RRType.AXFR())
        self.conn._handle_xfrin_responses()
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        # A single changeset adding the NS and then the SOA is expected.
        expected_diffs = [[('add', self._create_ns()), ('add', soa_rrset)]]
        check_diffs(self.assertEqual, expected_diffs,
                    self.conn._datasrc_client.committed_diffs)
Likun Zhang's avatar
Likun Zhang committed
997

998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
    def test_response_empty_answer(self):
        '''Test an AXFR whose first message has an empty answer section.

        Such a response is unusual, but there is no reason to reject it.
        The second message holds a complete AXFR answer, so the transfer
        should succeed just like in the normal case.

        '''
        params = self.axfr_response_params
        params['answer_1st'] = []
        params['answer_2nd'] = [soa_rrset, self._create_ns(), soa_rrset]
        self.conn.response_generator = self._create_normal_response_data
        self.conn._send_query(RRType.AXFR())
        self.conn._handle_xfrin_responses()
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        expected_diffs = [[('add', self._create_ns()), ('add', soa_rrset)]]
        check_diffs(self.assertEqual, expected_diffs,
                    self.conn._datasrc_client.committed_diffs)

1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
    def test_axfr_response_soa_mismatch(self):
        '''AXFR response whose begin and end SOAs differ.

        What to do in this case is debatable.  For now we accept it, as
        BIND 9 does.

        '''
        ns_rr = self._create_ns()
        a_rr = self._create_a('192.0.2.1')
        self.conn._send_query(RRType.AXFR())
        axfr_question = Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.AXFR())
        # The begin SOA has serial 1230 and the end SOA serial 1234; the end
        # one is the one expected in the committed diff.
        self.conn.reply_data = self.conn.create_response_data(
            questions=[axfr_question],
            answers=[begin_soa_rrset, ns_rr, a_rr, soa_rrset])
        self.conn._handle_xfrin_responses()
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        expected_diffs = [[('add', ns_rr), ('add', a_rr), ('add', soa_rrset)]]
        check_diffs(self.assertEqual, expected_diffs,
                    self.conn._datasrc_client.committed_diffs)

    def test_axfr_response_extra(self):
        '''Test an extra RR after the end of the AXFR session.

        The session must be rejected, and nothing committed.

        '''
        ns_rr = self._create_ns()
        a_rr = self._create_a('192.0.2.1')
        self.conn._send_query(RRType.AXFR())
        axfr_question = Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.AXFR())
        # The final a_rr follows the closing SOA.
        self.conn.reply_data = self.conn.create_response_data(
            questions=[axfr_question],
            answers=[soa_rrset, ns_rr, a_rr, soa_rrset, a_rr])
        with self.assertRaises(XfrinProtocolError):
            self.conn._handle_xfrin_responses()
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        self.assertEqual([], self.conn._datasrc_client.committed_diffs)

    def test_axfr_response_qname_mismatch(self):
        '''AXFR response whose question name differs from the query's.

        Our implementation accepts it, as does BIND 9.

        '''
        mismatched = Question(Name('mismatch.example'), TEST_RRCLASS,
                              RRType.AXFR())
        self.axfr_response_params['question_1st'] = [mismatched]
        self.conn.response_generator = self._create_normal_response_data
        self.conn._send_query(RRType.AXFR())
        self.conn._handle_xfrin_responses()
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        expected_diffs = [[('add', self._create_ns()), ('add', soa_rrset)]]
        check_diffs(self.assertEqual, expected_diffs,
                    self.conn._datasrc_client.committed_diffs)

    def test_axfr_response_qclass_mismatch(self):
        '''AXFR response whose question class differs from the query's.

        Our implementation accepts it, as does BIND 9.

        '''
        mismatched = Question(TEST_ZONE_NAME, RRClass.CH(), RRType.AXFR())
        self.axfr_response_params['question_1st'] = [mismatched]
        self.conn.response_generator = self._create_normal_response_data
        self.conn._send_query(RRType.AXFR())
        self.conn._handle_xfrin_responses()
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        expected_diffs = [[('add', self._create_ns()), ('add', soa_rrset)]]
        check_diffs(self.assertEqual, expected_diffs,
                    self.conn._datasrc_client.committed_diffs)

    def test_axfr_response_qtype_mismatch(self):
        '''AXFR response whose question type differs from the query's.

        Our implementation accepts it, as does BIND 9.

        '''
        # Return IXFR in the question of the response to an AXFR query.
        # Only the type differs from the query.  The class stays
        # TEST_RRCLASS so this test doesn't also exercise a class mismatch,
        # which test_axfr_response_qclass_mismatch covers separately.
        self.axfr_response_params['question_1st'] = \
            [Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())]
        self.conn.response_generator = self._create_normal_response_data
        self.conn._send_query(RRType.AXFR())
        self.conn._handle_xfrin_responses()
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        check_diffs(self.assertEqual,
                    [[('add', self._create_ns()), ('add', soa_rrset)]],
                    self.conn._datasrc_client.committed_diffs)
1105

1106
1107
1108
1109
1110
1111
1112
    def test_axfr_response_empty_question(self):
        '''AXFR response with an empty question section.

        Our implementation accepts it, as does BIND 9.

        '''
        self.axfr_response_params['question_1st'] = []
        self.conn.response_generator = self._create_normal_response_data
        self.conn._send_query(RRType.AXFR())
        self.conn._handle_xfrin_responses()
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        expected_diffs = [[('add', self._create_ns()), ('add', soa_rrset)]]
        check_diffs(self.assertEqual, expected_diffs,
                    self.conn._datasrc_client.committed_diffs)
Likun Zhang's avatar
Likun Zhang committed
1120

1121
1122
1123
1124
    def test_do_xfrin(self):
        # A full transfer driven by the normal responses must succeed.
        self.conn.response_generator = self._create_normal_response_data
        result = self.conn.do_xfrin(False)
        self.assertEqual(result, XFRIN_OK)

1125
1126
1127
    def test_do_xfrin_with_tsig(self):
        # Use TSIG with a mock context whose verify results are all faked
        # as successful.
        self.conn._tsig_key = TSIG_KEY
        self.conn._tsig_ctx_creator = \
            lambda key: self.__create_mock_tsig(key, TSIGError.NOERROR)
        self.conn.response_generator = self._create_normal_response_data
        result = self.conn.do_xfrin(False)
        self.assertEqual(result, XFRIN_OK)
        self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
        expected_diffs = [[('add', self._create_ns()), ('add', soa_rrset)]]
        check_diffs(self.assertEqual, expected_diffs,
                    self.conn._datasrc_client.committed_diffs)
1137
1138
1139
1140

    def test_do_xfrin_with_tsig_fail(self):
        # TSIG verify will fail for the first message.  xfrin should fail
        # immediately.
1141
1142
1143
        self.conn._tsig_key = TSIG_KEY
        self.conn._tsig_ctx_creator = \
            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_SIG)
1144
1145
1146
1147
1148
1149
1150