xfrout_test.py.in 44 KB
Newer Older
Likun Zhang's avatar
Likun Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

'''Tests for the XfroutSession and UnixSockServer classes '''


import os
import socket
import struct
import threading
import unittest

from isc.testutils.tsigctx_mock import MockTSIGContext
from isc.cc.session import *
import isc.config
from isc.dns import *
from xfrout import *
import xfrout
import isc.log
import isc.acl.dns
import isc.datasrc
Likun Zhang's avatar
Likun Zhang committed
29

30
# Directory holding test data, from the TESTDATASRCDIR environment variable
# (None when the variable is not set).
TESTDATA_SRCDIR = os.environ.get("TESTDATASRCDIR")

# TSIG key for example.com shared by the TSIG-related tests.
TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")

Likun Zhang's avatar
Likun Zhang committed
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# our fake socket, where we can read and insert messages
class MySocket():
    '''Fake stream socket that records everything sent through it.

    Sent bytes accumulate in self.sendqueue.  readsent() and read_msg()
    pop one chunk at a time off the front of that queue, using the
    leading 2-byte big-endian length field to find the chunk boundary.

    '''
    def __init__(self, family, type):
        self.family, self.type = family, type
        self.sendqueue = bytearray()

    def connect(self, to):
        pass

    def close(self):
        pass

    def send(self, data):
        '''Append data to the send queue and report it as fully sent.'''
        self.sendqueue.extend(data)
        return len(data)

    def readsent(self):
        '''Pop and return the first length-prefixed chunk, prefix included.

        Returns an empty bytearray when fewer than two bytes are queued.
        '''
        queue = self.sendqueue
        if len(queue) < 2:
            size = 0
        else:
            size = struct.unpack("!H", queue[:2])[0] + 2
        result, self.sendqueue = queue[:size], queue[size:]
        return result

    def read_msg(self, parse_options=Message.PARSE_DEFAULT):
        '''Pop the next sent chunk and parse it (length prefix skipped).'''
        wire = bytes(self.readsent()[2:])
        parsed = Message(Message.PARSE)
        parsed.from_wire(wire, parse_options)
        return parsed

    def clear_send(self):
        # Empty the queue in place.
        del self.sendqueue[:]

68
69
70
71
class MockDataSrcClient:
    '''Fake data source client that also acts as its own zone iterator.

    Some zone names select error cases:
      notauth.example.com  -- get_iterator() raises isc.datasrc.Error
      nosoa.example.com    -- get_soa() returns None
      multisoa.example.com -- get_soa() returns an SOA RRset with two rdata
    Any other zone gets a single-rdata SOA RRset.

    '''
    def __init__(self, type, config):
        pass

    def get_iterator(self, zone_name, adjust_ttl=False):
        if zone_name == Name('notauth.example.com'):
            raise isc.datasrc.Error('no such zone')
        # Remember the zone for get_soa(); this object doubles as iterator.
        self._zone_name = zone_name
        return self

    def get_soa(self):  # emulate ZoneIterator.get_soa()
        zone = self._zone_name
        if zone == Name('nosoa.example.com'):
            return None

        soa = RRset(zone, RRClass.IN(), RRType.SOA(), RRTTL(3600))

        def add_soa_rdata(admin_and_serial):
            soa.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
                                'master.example.com. ' +
                                admin_and_serial +
                                '3600 1800 2419200 7200'))

        add_soa_rdata('admin.example.com. 1234 ')
        if zone == Name('multisoa.example.com'):
            add_soa_rdata('admin.example.com. 1300 ')
        return soa

94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
class MyCCSession(isc.config.ConfigData):
    '''Minimal stand-in for a CC session, backed by the xfrout spec file.'''
    def __init__(self):
        module_spec = isc.config.module_spec_from_file(
            xfrout.SPECFILE_LOCATION)
        # Initialize the declared base class via super(); the bare name
        # ConfigData is only reachable through a star import in this file.
        super().__init__(module_spec)

    def get_remote_config_value(self, module_name, identifier):
        '''Return a (value, flag) pair for another module's config item.

        Only Auth's "database_file" is known ("initdb.file"); anything
        else yields "unknown".  The flag is always False here (presumably
        the "is default" indicator of the real API -- verify).
        '''
        if module_name == "Auth" and identifier == "database_file":
            return "initdb.file", False
        else:
            return "unknown", False

# This constant dictionary stores all default configuration parameters
# defined in the xfrout spec file.
DEFAULT_CONFIG = MyCCSession().get_full_config()

110
111
# We subclass the Session class we're testing here, only overriding a few
# methods
class MyXfroutSession(XfroutSession):
    '''XfroutSession with its I/O-related hooks stubbed out for tests.'''
    def _handle(self):
        # Disabled here; tests call XfroutSession._handle(...) explicitly
        # when they want the real processing to run.
        pass

    def _close_socket(self):
        # Nothing to close for the fake socket used by the tests.
        pass

    def _send_data(self, sock, data):
        '''Send all of data through sock, looping over partial sends.

        Raises RuntimeError if sock.send() reports that it wrote 0 bytes;
        without this check the loop would never make progress and the
        test would hang instead of failing.
        '''
        size = len(data)
        total_count = 0
        while total_count < size:
            count = sock.send(data[total_count:])
            if count == 0:
                raise RuntimeError('socket send made no progress')
            total_count += count

Likun Zhang's avatar
Likun Zhang committed
126
127
128
class Dbserver:
    '''Fake server object handed to the session under test.

    It tracks the number of transfers in progress in transfer_counter and
    always grants the transfer quota.

    '''
    def __init__(self):
        self._shutdown_event = threading.Event()
        self.transfer_counter = 0
        self._max_transfers_out = DEFAULT_CONFIG['transfers_out']

    def get_db_file(self):
        return 'test.sqlite3'

    def increase_transfers_counter(self):
        self.transfer_counter = self.transfer_counter + 1
        return True

    def decrease_transfers_counter(self):
        self.transfer_counter = self.transfer_counter - 1
Likun Zhang's avatar
Likun Zhang committed
138

139
140
141
142
143
144
145
class TestXfroutSessionBase(unittest.TestCase):
    '''Base class for tests related to xfrout sessions.

    It provides the common setUp/tearDown and helper methods; the actual
    tests live in subclasses.

    '''
    def getmsg(self):
        '''Parse the default request wire data (self.mdata) into a Message.'''
        parsed = Message(Message.PARSE)
        parsed.from_wire(self.mdata)
        return parsed

    def create_mock_tsig_ctx(self, error):
        '''Return a MockTSIGContext for TSIG_KEY with its error preset.

        The preset error is what the (normally faked) verification will
        report.
        '''
        ctx = MockTSIGContext(TSIG_KEY)
        ctx.error = error
        return ctx

    def message_has_tsig(self, msg):
        '''Tell whether msg carries a TSIG record.'''
        return msg.get_tsig_record() is not None

    def create_request_data(self, with_question=True, with_tsig=False):
        '''Build the wire data of an AXFR query for example.com.

        with_question=False leaves out the question section (a broken
        request); with_tsig=True signs the query with a MockTSIGContext.
        '''
        request = Message(Message.RENDER)
        request.set_qid(0x1035)
        request.set_opcode(Opcode.QUERY())
        request.set_rcode(Rcode.NOERROR())
        if with_question:
            request.add_question(Question(Name("example.com"), RRClass.IN(),
                                          RRType.AXFR()))

        renderer = MessageRenderer()
        wire_args = [renderer]
        if with_tsig:
            wire_args.append(MockTSIGContext(TSIG_KEY))
        request.to_wire(*wire_args)
        return renderer.get_data()

    def setUp(self):
        self.sock = MySocket(socket.AF_INET, socket.SOCK_STREAM)
        server = Dbserver()
        key_ring = TSIGKeyRing()
        remote = (socket.AF_INET, socket.SOCK_STREAM, ('127.0.0.1', 12345))
        # When not testing ACLs, simply accept
        accept_all = isc.acl.dns.REQUEST_LOADER.load([{"action": "ACCEPT"}])
        self.xfrsess = MyXfroutSession(self.sock, None, server, key_ring,
                                       remote, accept_all, {})
        self.mdata = self.create_request_data()

        soa = RRset(Name('example.com'), RRClass.IN(), RRType.SOA(),
                    RRTTL(3600))
        self.soa_rrset = soa
        soa.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
                            'master.Example.com. ' +
                            'admin.exAmple.com. ' +
                            '1234 3600 1800 2419200 7200'))

        # Some tests replace the module-level get_rrset_len; keep the
        # original so tearDown can put it back for the other tests.
        self.orig_get_rrset_len = xfrout.get_rrset_len

    def tearDown(self):
        xfrout.get_rrset_len = self.orig_get_rrset_len
        # Whatever happened inside the XfroutSession object, the transfer
        # counter must always have been brought back to zero.
        self.assertEqual(0, self.xfrsess._server.transfer_counter)

208
class TestXfroutSession(TestXfroutSessionBase):
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
    def test_quota_error(self):
        '''The request is REFUSED when the server is too busy (no quota).

        '''
        sess = self.xfrsess
        sess._request_data = self.mdata
        # Make every attempt to take a transfer slot fail.
        sess._server.increase_transfers_counter = lambda: False
        XfroutSession._handle(sess)
        reply_rcode = self.sock.read_msg().get_rcode()
        self.assertEqual(reply_rcode, Rcode.REFUSED())

    def test_quota_ok(self):
        '''The default case in terms of the xfrout quota: it is granted.

        '''
        sess = self.xfrsess
        # Any request whose outcome differs from the quota-error case will
        # do; a query without a question section is answered with FORMERR.
        sess._request_data = self.create_request_data(with_question=False)
        # Swap in the mock client class to avoid data source exceptions.
        sess.ClientClass = MockDataSrcClient
        XfroutSession._handle(sess)
        reply_rcode = self.sock.read_msg().get_rcode()
        self.assertEqual(reply_rcode, Rcode.FORMERR())

    def test_exception_from_session(self):
        '''An exception raised by the main processing must be contained.

        We only check that it causes no unexpected disruption here, and
        (in tearDown) that transfer_counter has been correctly reset to 0.

        '''
        def failing_start(fd, msg, quota):
            raise ValueError('fake exception')
        self.xfrsess.dns_xfrout_start = failing_start
        XfroutSession._handle(self.xfrsess)

Likun Zhang's avatar
Likun Zhang committed
243
244
245
246
    def test_parse_query_message(self):
        parse = self.xfrsess._parse_query_message

        # A well-formed, unsigned query
        rcode, _ = parse(self.mdata)
        self.assertEqual(rcode.to_text(), "NOERROR")

        # Broken request: no question
        rcode, _ = parse(self.create_request_data(with_question=False))
        self.assertEqual(Rcode.FORMERR(), rcode)

        # TSIG signed query whose key is not in the (empty) key ring:
        # BADKEY, reported as NOTAUTH
        signed_request = self.create_request_data(with_tsig=True)
        rcode, _ = parse(signed_request)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

        # After registering the key the same query is accepted
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
        rcode, _ = parse(signed_request)
        self.assertEqual(rcode.to_text(), "NOERROR")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

265
    def check_transfer_acl(self, acl_setter):
        '''Run the shared ACL scenarios; each ACL is installed via acl_setter.

        acl_setter receives a loaded ACL and decides where to put it
        (default ACL or a per-zone ACL).
        '''
        sess = self.xfrsess

        def set_acl(rules):
            acl_setter(isc.acl.dns.REQUEST_LOADER.load(rules))

        def set_remote(address):
            sess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                            (address, 12345))

        def expect(request, rcode_text):
            rcode, _ = sess._parse_query_message(request)
            self.assertEqual(rcode.to_text(), rcode_text)

        # Address-based ACL
        set_acl([
            {"from": "127.0.0.1", "action": "ACCEPT"},
            {"from": "192.0.2.1", "action": "DROP"}
        ])
        # Localhost (the default remote set in setUp) is accepted
        expect(self.mdata, "NOERROR")
        # A DROP match discards the request entirely: no rcode at all
        set_remote('192.0.2.1')
        rcode, _ = sess._parse_query_message(self.mdata)
        self.assertEqual(None, rcode)
        # An address matching neither rule is expected to be REFUSED
        set_remote('192.0.2.2')
        expect(self.mdata, "REFUSED")

        # TSIG signed request.  If the TSIG check fails, the ACL must not be
        # consulted (if it were, this address would simply be dropped).
        signed_request = self.create_request_data(with_tsig=True)
        set_remote('192.0.2.1')
        sess._tsig_key_ring = TSIGKeyRing()
        expect(signed_request, "NOTAUTH")
        self.assertIsNotNone(sess._tsig_ctx)

        # ACL on the TSIG key: successful case
        set_acl([
            {"key": "example.com", "action": "ACCEPT"}, {"action": "REJECT"}
        ])
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         sess._tsig_key_ring.add(TSIG_KEY))
        expect(signed_request, "NOERROR")

        # ACL on the TSIG key: the key name doesn't match
        set_acl([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ])
        expect(signed_request, "REFUSED")

        # ACL on the TSIG key: the request isn't signed at all
        set_acl([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ])
        expect(self.mdata, "REFUSED")

        # ACL requiring both the source address and the TSIG key
        set_acl([
            {"ALL": [{"key": "example.com"}, {"from": "192.0.2.1"}],
             "action": "ACCEPT"},
            {"action": "REJECT"}
        ])
        scenarios = (
            ('192.0.2.1', signed_request, "NOERROR"),  # both match
            ('192.0.2.2', signed_request, "REFUSED"),  # address doesn't
            ('192.0.2.1', self.mdata, "REFUSED"),      # no TSIG included
            ('192.0.2.2', self.mdata, "REFUSED"),      # neither matches
        )
        for address, request, rcode_text in scenarios:
            set_remote(address)
            expect(request, rcode_text)

355
    def test_transfer_acl(self):
        # ACL checks involving only the default (session-wide) ACL
        def install_default_acl(acl):
            self.xfrsess._acl = acl
        self.check_transfer_acl(install_default_acl)

361
362
363
364
    def test_transfer_zoneacl(self):
        # Per-zone ACL plus a default ACL.  The per-zone ACL matches the
        # queried zone, so it is the one that should be used.  The default
        # ACL drops localhost, so the scenarios fail if it is used instead.
        def install_zone_acl(acl):
            zone_key = ('IN', 'example.com.')
            self.xfrsess._zone_config[zone_key] = {'transfer_acl': acl}
            self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
                    {"from": "127.0.0.1", "action": "DROP"}])
        self.check_transfer_acl(install_zone_acl)

    def test_transfer_zoneacl_nomatch(self):
        # Like test_transfer_zoneacl, but the per-zone ACL belongs to a zone
        # other than the queried one, so the default ACL must be used.
        def install_default_acl(acl):
            zone_key = ('IN', 'example.org.')
            self.xfrsess._zone_config[zone_key] = zone_conf = {}
            zone_conf['transfer_acl'] = isc.acl.dns.REQUEST_LOADER.load([
                    {"from": "127.0.0.1", "action": "DROP"}])
            self.xfrsess._acl = acl
        self.check_transfer_acl(install_default_acl)

384
385
386
387
388
389
390
391
392
393
394
395
396
    def test_get_transfer_acl(self):
        sess = self.xfrsess

        def acl_for(zone_text):
            return sess._get_transfer_acl(Name(zone_text), RRClass.IN())

        # With no zone-specific ACL configured, the default one is returned.
        sess._acl = isc.acl.dns.REQUEST_LOADER.load([
                {"from": "127.0.0.1", "action": "ACCEPT"}])
        self.assertEqual(acl_for('example.com'), sess._acl)

        # Give example.com its own transfer ACL: it is used for that zone,
        # while other zones still get the default ACL.
        com_acl = isc.acl.dns.REQUEST_LOADER.load([
                {"from": "127.0.0.1", "action": "REJECT"}])
        sess._zone_config[('IN', 'example.com.')] = {'transfer_acl': com_acl}
        self.assertEqual(com_acl, acl_for('example.com'))
        self.assertEqual(sess._acl, acl_for('example.org'))

        # Zone name matching must ignore case.
        self.assertEqual(com_acl, acl_for('EXAMPLE.COM'))

Likun Zhang's avatar
Likun Zhang committed
412
413
414
415
416
417
418
    def test_send_data(self):
        # Whatever goes through _send_data must come out of the socket as-is.
        self.xfrsess._send_data(self.sock, self.mdata)
        self.assertEqual(self.sock.readsent(), self.mdata)

    def test_reply_xfrout_query_with_error_rcode(self):
        def reply_nxdomain():
            # Rcode(3) is expected to come back as NXDOMAIN.
            self.xfrsess._reply_query_with_error_rcode(self.getmsg(),
                                                       self.sock, Rcode(3))
            reply = self.sock.read_msg()
            self.assertEqual(reply.get_rcode().to_text(), "NXDOMAIN")
            return reply

        # Plain (unsigned) error reply
        reply_nxdomain()

        # With a TSIG context set, the error reply must be signed
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.assertTrue(self.message_has_tsig(reply_nxdomain()))

431
432
433
    def test_send_message(self):
        msg = self.getmsg()
        msg.make_response()

        # An SOA whose names use differing letter case, so that
        # case-insensitive name compression produces different wire data.
        mixed_case_soa = RRset(Name('Example.com.'), RRClass.IN(),
                               RRType.SOA(), RRTTL(3600))
        mixed_case_soa.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
                                       'master.Example.com. admin.exAmple.com. ' +
                                       '1234 3600 1800 2419200 7200'))
        msg.add_rrset(Message.SECTION_ANSWER, mixed_case_soa)

        self.xfrsess._send_message(self.sock, msg)
        sent_wire = self.sock.readsent()[2:]

        renderer = MessageRenderer()

        # Default (CASE_INSENSITIVE) compression: output must differ
        renderer.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(renderer)
        self.assertNotEqual(renderer.get_data(), sent_wire)

        # CASE_SENSITIVE compression: output must match what was sent
        renderer.clear()
        renderer.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        renderer.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(renderer)
        self.assertEqual(renderer.get_data(), sent_wire)

Likun Zhang's avatar
Likun Zhang committed
457
458
459
460
461
462
463
464
465
466
    def test_clear_message(self):
        msg = self.getmsg()
        qid, opcode, rcode = msg.get_qid(), msg.get_opcode(), msg.get_rcode()

        self.xfrsess._clear_message(msg)

        # QID, opcode and rcode survive the clearing ...
        for getter, expected in ((msg.get_qid, qid),
                                 (msg.get_opcode, opcode),
                                 (msg.get_rcode, rcode)):
            self.assertEqual(getter(), expected)
        # ... and the AA flag must be set afterwards.
        self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
Likun Zhang's avatar
Likun Zhang committed
468
469
470
471

    def test_send_message_with_last_soa(self):
        msg = self.getmsg()
        msg.make_response()

        # No TSIG context is set up in this test, so no message may carry
        # a TSIG record, whatever the packet count.

        # Packet count below TSIG_SIGN_EVERY_NTH
        unsigned_count = xfrout.TSIG_SIGN_EVERY_NTH - 1
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset, 0,
                                                 unsigned_count)
        reply = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(reply))

        for section, count in ((Message.SECTION_QUESTION, 1),
                               (Message.SECTION_ANSWER, 1),
                               (Message.SECTION_AUTHORITY, 0)):
            self.assertEqual(reply.get_rr_count(section), count)

        soa = reply.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(soa.get_name().to_text(), "example.com.")
        self.assertEqual(soa.get_class(), RRClass("IN"))
        self.assertEqual(soa.get_type().to_text(), "SOA")
        self.assertEqual(soa.get_rdata()[0], self.soa_rrset.get_rdata()[0])

        # Exactly the TSIG_SIGN_EVERY_NTH packet, sent with the last SOA
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset, 0,
                                                 TSIG_SIGN_EVERY_NTH)
        reply = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(reply))

503
    def test_send_message_with_last_soa_with_tsig(self):
        # With a TSIG context, the message carrying the last SOA is expected
        # to be signed both below and at the TSIG_SIGN_EVERY_NTH count.
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)

        msg = self.getmsg()
        msg.make_response()

        def send_signed(packet_count):
            self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                     self.soa_rrset, 0,
                                                     packet_count)
            reply = self.sock.read_msg()
            self.assertTrue(self.message_has_tsig(reply))
            return reply

        # Not the TSIG_SIGN_EVERY_NTH packet
        reply = send_signed(xfrout.TSIG_SIGN_EVERY_NTH - 1)
        for section, count in ((Message.SECTION_QUESTION, 1),
                               (Message.SECTION_ANSWER, 1),
                               (Message.SECTION_AUTHORITY, 0)):
            self.assertEqual(reply.get_rr_count(section), count)

        # The TSIG_SIGN_EVERY_NTH packet
        send_signed(TSIG_SIGN_EVERY_NTH)

532
533
534
535
    def test_trigger_send_message_with_last_soa(self):
        a_rrset = RRset(Name("example.com"), RRClass.IN(), RRType.A(),
                        RRTTL(3600))
        a_rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))

        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, a_rrset)

        # A current message length one byte too large to also fit the SOA,
        # and a packet count below TSIG_SIGN_EVERY_NTH.
        split_length = (xfrout.XFROUT_MAX_MESSAGE_SIZE -
                        get_rrset_len(self.soa_rrset) + 1)
        unsigned_count = xfrout.TSIG_SIGN_EVERY_NTH - 1

        def check_single_answer(reply, question_count, rrtype_text):
            '''Check reply's header counts and its only answer RRset.

            Returns that RRset's rdata list for further checks.
            '''
            self.assertFalse(self.message_has_tsig(reply))
            for section, count in ((Message.SECTION_QUESTION, question_count),
                                   (Message.SECTION_ANSWER, 1),
                                   (Message.SECTION_AUTHORITY, 0)):
                self.assertEqual(reply.get_rr_count(section), count)
            answer = reply.get_section(Message.SECTION_ANSWER)[0]
            self.assertEqual(answer.get_name().to_text(), "example.com.")
            self.assertEqual(answer.get_class(), RRClass("IN"))
            self.assertEqual(answer.get_type().to_text(), rrtype_text)
            return answer.get_rdata()

        # This must produce two messages: one with the A RRset added above,
        # and one that carries only the last SOA.
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset,
                                                 split_length,
                                                 unsigned_count)

        rdata = check_single_answer(self.sock.read_msg(), 1, "A")
        self.assertEqual(rdata[0].to_text(), "192.0.2.1")

        rdata = check_single_answer(self.sock.read_msg(), 0, "SOA")
        self.assertEqual(rdata[0], self.soa_rrset.get_rdata()[0])

        # Nothing beyond those two messages may have been sent.
        self.assertEqual(0, len(self.sock.sendqueue))

584
585
586
587
    def test_trigger_send_message_with_last_soa_with_tsig(self):
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, self.soa_rrset)

        # A current message length one byte too large to also fit the SOA,
        # so each call below sends two messages: the pending one, and one
        # that carries the last SOA.
        split_length = (xfrout.XFROUT_MAX_MESSAGE_SIZE -
                        get_rrset_len(self.soa_rrset) + 1)

        def send_split(packet_count, first_signed):
            self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                     self.soa_rrset,
                                                     split_length,
                                                     packet_count)
            first = self.sock.read_msg()
            if first_signed:
                self.assertTrue(self.message_has_tsig(first))
            else:
                self.assertFalse(self.message_has_tsig(first))
            # The last packet is signed in either case
            self.assertTrue(self.message_has_tsig(self.sock.read_msg()))
            # and nothing else may have been sent
            self.assertEqual(0, len(self.sock.sendqueue))

        # Not the TSIG_SIGN_EVERY_NTH packet: the first message is unsigned
        send_split(xfrout.TSIG_SIGN_EVERY_NTH - 1, False)
        # The TSIG_SIGN_EVERY_NTH packet: the first message is signed too
        send_split(xfrout.TSIG_SIGN_EVERY_NTH, True)

627
    def test_get_rrset_len(self):
        # Expected wire length of the SOA RRset built in setUp
        expected_length = 82
        self.assertEqual(expected_length, get_rrset_len(self.soa_rrset))
Likun Zhang's avatar
Likun Zhang committed
629
630

    def test_check_xfrout_available(self):
        self.xfrsess.ClientClass = MockDataSrcClient
        expectations = (
            ('example.com', Rcode.NOERROR()),
            # the mock client raises a data source error for this zone
            ('notauth.example.com', Rcode.NOTAUTH()),
            # the mock iterator reports no SOA for this zone
            ('nosoa.example.com', Rcode.SERVFAIL()),
            # the mock iterator returns an SOA RRset with two rdata
            ('multisoa.example.com', Rcode.SERVFAIL()),
        )
        for zone_text, expected_rcode in expectations:
            self.assertEqual(
                self.xfrsess._check_xfrout_available(Name(zone_text)),
                expected_rcode)
640

Likun Zhang's avatar
Likun Zhang committed
641
642
643
644
645
    def test_dns_xfrout_start_formerror(self):
        # A malformed request: nothing at all should be sent back.
        garbage_request = b"\xd6=\x00\x00\x00\x01\x00"
        self.xfrsess.dns_xfrout_start(self.sock, garbage_request)
        self.assertEqual(len(self.sock.readsent()), 0)
646

Likun Zhang's avatar
Likun Zhang committed
647
648
649
650
651
    def default(self, param):
        # NOTE(review): no caller of default() is visible in this chunk.  It
        # ignores param and always yields the test zone name.
        zone_name = "example.com"
        return zone_name

    def test_dns_xfrout_start_notauth(self):
        # Force the availability check to report NOTAUTH; the response sent
        # for the request must then carry the NOTAUTH rcode.
        def notauth(formpara):
            return Rcode.NOTAUTH()

        self.xfrsess._check_xfrout_available = notauth
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")

    def test_dns_xfrout_start_datasrc_servfail(self):
        # Replace the data source client class with a callable (taking two
        # arguments) that raises isc.datasrc.Error.  The session must answer
        # with SERVFAIL instead of propagating the exception.
        def internal_raise(x, y):
            raise isc.datasrc.Error('exception for the sake of test')

        self.xfrsess.ClientClass = internal_raise
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.SERVFAIL())

    def test_dns_xfrout_start_noerror(self):
        # With the availability check forced to NOERROR, dns_xfrout_start()
        # is expected to call _reply_xfrout_query().  A stub reply that sends
        # a marker lets us observe that call.
        def noerror(form):
            return Rcode.NOERROR()

        self.xfrsess._check_xfrout_available = noerror

        def myreply(msg, sock):
            self.sock.send(b"success")

        self.xfrsess._reply_xfrout_query = myreply
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        self.assertEqual(self.sock.readsent(), b"success")

    def test_reply_xfrout_query_noerror(self):
        # The zone iterator yields only the SOA RRset.  The reply is expected
        # to hold 2 answer RRs (presumably the leading and trailing SOA).
        self.xfrsess._soa = self.soa_rrset
        self.xfrsess._iterator = [self.soa_rrset]
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)

    def test_reply_xfrout_query_noerror_with_tsig(self):
        """Check which messages of a long TSIG-protected AXFR are signed.

        The zone is faked as 100 copies of an A RRset between the SOAs.
        xfrout.get_rrset_len() is patched to report 65520 for every RRset,
        which (as the assertions below expect) puts each RRset into its own
        message.  The first message, the TSIG_SIGN_EVERY_NTH-th one and the
        last one must be signed; all others must not.
        """
        rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.A(),
                      RRTTL(3600))
        rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), '192.0.2.1'))

        self.xfrsess._soa = self.soa_rrset
        self.xfrsess._iterator = [rrset] * 100
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)

        # Patch the module-level helper only for the duration of the reply
        # and always restore it; leaving it patched would change message
        # splitting for every later test that goes through xfrout.
        original_get_rrset_len = xfrout.get_rrset_len
        xfrout.get_rrset_len = lambda rrset: 65520
        try:
            self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)
        finally:
            xfrout.get_rrset_len = original_get_rrset_len

        every_nth = xfrout.TSIG_SIGN_EVERY_NTH
        # tsig signed first message
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertTrue(self.message_has_tsig(reply_msg))
        # (TSIG_SIGN_EVERY_NTH - 1) messages have no tsig
        for _ in range(every_nth - 1):
            self.assertFalse(self.message_has_tsig(self.sock.read_msg()))
        # the TSIG_SIGN_EVERY_NTH-th message has tsig
        self.assertTrue(self.message_has_tsig(self.sock.read_msg()))
        # the remaining A RRset messages have no tsig
        for _ in range(100 - every_nth):
            self.assertFalse(self.message_has_tsig(self.sock.read_msg()))
        # tsig signed last message
        self.assertTrue(self.message_has_tsig(self.sock.read_msg()))

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))


class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
    '''Tests for XFR-out sessions using an SQLite3 DB.

    These are provided mainly to confirm the implementation actually works
    in an environment closer to actual operational environments.  So we
    only check a few common cases; other details are tested using mock
    data sources.

    '''
    def setUp(self):
        super().setUp()
        self.xfrsess._request_data = self.mdata
        # os.path.join() is correct whether or not TESTDATA_SRCDIR ends with
        # a path separator; plain concatenation silently requires one.
        self.xfrsess._server.get_db_file = lambda : os.path.join(
            TESTDATA_SRCDIR, 'test.sqlite3')

    def test_axfr_normal_session(self):
        XfroutSession._handle(self.xfrsess)
        response = self.sock.read_msg(Message.PRESERVE_ORDER)
        self.assertEqual(Rcode.NOERROR(), response.get_rcode())
        # This zone contains two A RRs for the same name with different TTLs.
        # These TTLs should be preserved in the AXFR stream.
        # Collect the distinct TTL values as ints.  The membership test must
        # compare int with int; comparing an RRTTL object against the int
        # list would never match and the de-duplication would be dead.
        actual_ttls = []
        for rr in response.get_section(Message.SECTION_ANSWER):
            if rr.get_type() == RRType.A():
                ttl = rr.get_ttl().get_value()
                if ttl not in actual_ttls:
                    actual_ttls.append(ttl)
        self.assertEqual([3600, 7200], sorted(actual_ttls))


class MyUnixSockServer(UnixSockServer):
    '''UnixSockServer variant for unit tests.

    UnixSockServer.__init__ is not called (presumably to avoid binding a real
    UNIX socket -- confirm against UnixSockServer).  Only the shutdown event,
    the state set up by _common_init() and the configuration obtained from
    the fake MyCCSession are initialized.
    '''
    def __init__(self):
        self._shutdown_event = threading.Event()
        self._common_init()
        self._cc = MyCCSession()
        self.update_config_data(self._cc.get_full_config())


class TestUnixSockServer(unittest.TestCase):
    def setUp(self):
        # A connected socket pair: tests write into write_sock and let the
        # server under test read from read_sock.
        self.write_sock, self.read_sock = socket.socketpair()
        # Close both ends after each test instead of leaking descriptors.
        self.addCleanup(self.write_sock.close)
        self.addCleanup(self.read_sock.close)
        self.unix = MyUnixSockServer()

    def test_guess_remote(self):
        """Test we can guess the remote endpoint when we have only the
           file descriptor. This is needed, because we get only that one
           from auth."""
        # We test with UDP, as it can be "connected" without other
        # endpoint.  Note that in the current implementation _guess_remote()
        # unconditionally returns SOCK_STREAM.
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.addCleanup(sock.close)
        sock.connect(('127.0.0.1', 12345))
        self.assertEqual((socket.AF_INET, socket.SOCK_STREAM,
                          ('127.0.0.1', 12345)),
                         self.unix._guess_remote(sock.fileno()))
        if socket.has_ipv6:
            # Don't check IPv6 address on hosts not supporting them
            sock6 = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
            self.addCleanup(sock6.close)
            sock6.connect(('::1', 12345))
            self.assertEqual((socket.AF_INET6, socket.SOCK_STREAM,
                              ('::1', 12345, 0, 0)),
                             self.unix._guess_remote(sock6.fileno()))
            # Try when pretending there's no IPv6 support
            # (No need to pretend when there's really no IPv6)
            xfrout.socket.has_ipv6 = False
            try:
                sock4 = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                self.addCleanup(sock4.close)
                sock4.connect(('127.0.0.1', 12345))
                self.assertEqual((socket.AF_INET, socket.SOCK_STREAM,
                                  ('127.0.0.1', 12345)),
                                 self.unix._guess_remote(sock4.fileno()))
            finally:
                # Return it back even if the check above failed, so that
                # later tests see the real value.
                xfrout.socket.has_ipv6 = True

    def test_receive_query_message(self):
        # The query is written as a 2-byte length prefix in network byte
        # order followed by the message; _receive_query_message() must hand
        # back exactly the message bytes.
        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
        msg_len = struct.pack('!H', len(send_msg))
        self.write_sock.send(msg_len)
        self.write_sock.send(send_msg)
        recv_msg = self.unix._receive_query_message(self.read_sock)
        self.assertEqual(recv_msg, send_msg)

    def check_default_ACL(self):
        """Check that the server's current ACL accepts a UDP query from
        127.0.0.1 port 1234."""
        sockaddr = socket.getaddrinfo("127.0.0.1", 1234, 0, socket.SOCK_DGRAM,
                                      socket.IPPROTO_UDP,
                                      socket.AI_NUMERICHOST)[0][4]
        context = isc.acl.dns.RequestContext(sockaddr)
        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))

    def check_loaded_ACL(self, acl):
        """Check that acl accepts a UDP query from 127.0.0.1 and rejects one
        from 192.0.2.1 (both from port 1234)."""
        def request_from(address):
            # Build an ACL request context for a UDP query from address:1234.
            sockaddr = socket.getaddrinfo(address, 1234, 0, socket.SOCK_DGRAM,
                                          socket.IPPROTO_UDP,
                                          socket.AI_NUMERICHOST)[0][4]
            return isc.acl.dns.RequestContext(sockaddr)

        self.assertEqual(isc.acl.acl.ACCEPT,
                         acl.execute(request_from("127.0.0.1")))
        self.assertEqual(isc.acl.acl.REJECT,
                         acl.execute(request_from("192.0.2.1")))

    def test_update_config_data(self):
        self.check_default_ACL()
        tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
        tsig_key_list = [tsig_key_str]
        bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
        self.unix.update_config_data({'transfers_out':10 })
        self.assertEqual(self.unix._max_transfers_out, 10)
        self.assertIsNotNone(self.unix.tsig_key_ring)
        self.check_default_ACL()

        self.unix.update_config_data({'transfers_out':9,
                                      'tsig_key_ring':tsig_key_list})
        self.assertEqual(self.unix._max_transfers_out, 9)
        self.assertEqual(self.unix.tsig_key_ring.size(), 1)
        self.unix.tsig_key_ring.remove(Name("example.com."))
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # bad tsig key: the update must not raise (an exception here fails
        # the test) and the bad key must not end up in the key ring.
        config_data = {'transfers_out':9, 'tsig_key_ring': bad_key_list}
        self.unix.update_config_data(config_data)
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # Load the ACL
        self.unix.update_config_data({'transfer_acl': [{'from': '127.0.0.1',
                                               'action': 'ACCEPT'}]})
        self.check_loaded_ACL(self.unix._acl)
        # Pass a wrong data there and check it does not replace the old one
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'transfer_acl': ['Something bad']})
        self.check_loaded_ACL(self.unix._acl)

    def test_zone_config_data(self):
        # By default, there's no specific zone config
        self.assertEqual({}, self.unix._zone_config)

        # Adding config for a specific zone.  The config is empty unless
        # explicitly specified.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'class': 'IN'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class can be omitted
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class, name are stored in the "normalized" form.  class
        # strings are upper cased, names are down cased.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'EXAMPLE.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # invalid zone class, name will result in exceptions
        self.assertRaises(EmptyLabel,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'bad..example'}]})
        self.assertRaises(InvalidRRClass,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'class': 'badclass'}]})

        # Configuring a couple of more zones
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'},
                                           {'origin': 'example.com',
                                            'class': 'CH'},
                                           {'origin': 'example.org'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('CH', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.org.')])

        # Duplicate data: should be rejected with an exception
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com'},
                                           {'origin': 'example.org'},
                                           {'origin': 'example.com'}]})

    def test_zone_config_data_with_acl(self):
        # Similar to the previous test, but with transfer_acl config
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'from': '127.0.0.1',
                                                  'action': 'ACCEPT'}]}]})
        acl = self.unix._zone_config[('IN', 'example.com.')]['transfer_acl']
        self.check_loaded_ACL(acl)

        # invalid ACL syntax will be rejected with exception
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'action': 'BADACTION'}]}]})

    def test_get_db_file(self):
        # The MyUnixSockServer fixture is expected to report this DB file
        # name (presumably from MyCCSession's configuration -- confirm).
        expected_db_file = "initdb.file"
        self.assertEqual(self.unix.get_db_file(), expected_db_file)

    def test_increase_transfers_counter(self):
        # Below the limit: the call returns True and bumps the counter.
        self.unix._max_transfers_out = 10
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), True)
        self.assertEqual(count + 1, self.unix._transfers_counter)

        # With a limit of 0: the call returns False and the counter stays.
        self.unix._max_transfers_out = 0
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), False)
        self.assertEqual(count, self.unix._transfers_counter)

    def test_decrease_transfers_counter(self):
        # decrease_transfers_counter() lowers the counter by exactly one.
        count = self.unix._transfers_counter
        self.unix.decrease_transfers_counter()
        self.assertEqual(count - 1, self.unix._transfers_counter)

    def _remove_file(self, sock_file):
        """Remove sock_file; any OSError (e.g. the file is absent) is
        ignored."""
        try:
            os.remove(sock_file)
        except OSError:
            pass

    def test_sock_file_in_use_file_exist(self):
        # After removing any stale file, the path must be reported as unused,
        # and probing it must not leave a file behind.
        path = 'temp.sock.file'
        self._remove_file(path)
        in_use = self.unix._sock_file_in_use(path)
        self.assertFalse(in_use)
        self.assertFalse(os.path.exists(path))

    def test_sock_file_in_use_file_not_exist(self):
        # No server is started on this path here, so it must not be
        # reported as in use.
        in_use = self.unix._sock_file_in_use('temp.sock.file')
        self.assertFalse(in_use)

    def _start_unix_sock_server(self, sock_file):
        """Serve sock_file from a daemon thread for the rest of the test.

        Cleanups registered here stop the server loop, close the listening
        socket and remove sock_file once the test finishes, so neither the
        thread nor the file outlives the test.

        Returns the server object (existing callers ignore it).
        """
        serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
        serv_thread = threading.Thread(target=serv.serve_forever)
        # Attribute form; Thread.setDaemon() is deprecated.
        serv_thread.daemon = True
        serv_thread.start()
        # Cleanups run last-in first-out: shutdown, close, remove the file.
        self.addCleanup(self._remove_file, sock_file)
        self.addCleanup(serv.server_close)
        self.addCleanup(serv.shutdown)
        return serv

    def test_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)

        # Silence anything printed while probing.  Restore stdout and close
        # the devnull handle even when the assertion fails.
        old_stdout = sys.stdout
        with open(os.devnull, 'w') as devnull:
            sys.stdout = devnull
            try:
                self.assertTrue(self.unix._sock_file_in_use(sock_file))
            finally:
                sys.stdout = old_stdout

    def test_remove_unused_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)
        # Silence output; restore stdout and close devnull even on failure.
        old_stdout = sys.stdout
        with open(os.devnull, 'w') as devnull:
            sys.stdout = devnull
            try:
                # The file is in use by the server started above, so the
                # call must end with SystemExit rather than return.
                with self.assertRaises(SystemExit):
                    self.unix._remove_unused_sock_file(sock_file)
            finally:
                sys.stdout = old_stdout

    def test_remove_unused_sock_file_dir(self):
        import tempfile
        dir_name = tempfile.mkdtemp()
        # Remove the temporary directory however the test ends.
        self.addCleanup(os.rmdir, dir_name)
        # Silence output; restore stdout and close devnull even on failure.
        old_stdout = sys.stdout
        with open(os.devnull, 'w') as devnull:
            sys.stdout = devnull
            try:
                # A directory in place of the socket file must make the call
                # end with SystemExit.
                with self.assertRaises(SystemExit):
                    self.unix._remove_unused_sock_file(dir_name)
            finally:
                sys.stdout = old_stdout
