xfrout_test.py.in 41.9 KB
Newer Older
Likun Zhang's avatar
Likun Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

'''Tests for the XfroutSession and UnixSockServer classes '''


import unittest
import os
21
from isc.testutils.tsigctx_mock import MockTSIGContext
Likun Zhang's avatar
Likun Zhang committed
22
from isc.cc.session import *
23
import isc.config
24
from isc.dns import *
Likun Zhang's avatar
Likun Zhang committed
25
from xfrout import *
26
import xfrout
27
import isc.log
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
28
import isc.acl.dns
Likun Zhang's avatar
Likun Zhang committed
29

30
31
# TSIG key (named example.com) shared by the tests that build or verify
# TSIG-signed messages.
TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")

# our fake socket, where we can read and insert messages
class MySocket:
    '''In-memory stand-in for a stream socket.

    Everything passed to send() is appended to ``sendqueue`` so tests can
    read it back one length-prefixed DNS message at a time.
    '''

    def __init__(self, family, type):
        self.family = family
        self.type = type
        self.sendqueue = bytearray()

    def connect(self, to):
        pass

    def close(self):
        pass

    def send(self, data):
        '''Queue all of data and report it as fully sent.'''
        self.sendqueue.extend(data)
        return len(data)

    def readsent(self):
        '''Pop and return the next length-prefixed chunk from sendqueue.

        The first two bytes are read as a big-endian length; the chunk
        returned includes those two bytes.  If fewer than two bytes are
        queued, an empty bytearray is returned and the queue is unchanged.
        '''
        if len(self.sendqueue) >= 2:
            size = 2 + struct.unpack("!H", self.sendqueue[:2])[0]
        else:
            size = 0
        result = self.sendqueue[:size]
        self.sendqueue = self.sendqueue[size:]
        return result

    def read_msg(self):
        '''Pop the next sent chunk and parse it (minus its 2-byte length
        prefix) into a Message.'''
        sent_data = self.readsent()
        get_msg = Message(Message.PARSE)
        get_msg.from_wire(bytes(sent_data[2:]))
        return get_msg

    def clear_send(self):
        '''Discard everything that has been sent so far.'''
        del self.sendqueue[:]

class MockDataSrcClient:
    '''Fake data source client that doubles as its own zone iterator.

    Zone behaviour is selected by name:
    - notauth.example.com: get_iterator() raises isc.datasrc.Error
    - nosoa.example.com: get_soa() returns None
    - multisoa.example.com: get_soa() returns an RRset with two SOA rdatas
    - anything else: get_soa() returns a single-SOA RRset
    '''

    def __init__(self, type, config):
        pass

    def get_iterator(self, zone_name):
        '''Remember zone_name and return self as the "iterator".'''
        if zone_name == Name('notauth.example.com'):
            raise isc.datasrc.Error('no such zone')
        self._zone_name = zone_name
        return self

    def get_soa(self):  # emulate ZoneIterator.get_soa()
        # Must be called after get_iterator(), which sets _zone_name.
        if self._zone_name == Name('nosoa.example.com'):
            return None
        soa_rrset = RRset(self._zone_name, RRClass.IN(), RRType.SOA(),
                          RRTTL(3600))
        soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
                                  'master.example.com. ' +
                                  'admin.example.com. 1234 ' +
                                  '3600 1800 2419200 7200'))
        if self._zone_name == Name('multisoa.example.com'):
            soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
                                      'master.example.com. ' +
                                      'admin.example.com. 1300 ' +
                                      '3600 1800 2419200 7200'))
        return soa_rrset

class MyCCSession(isc.config.ConfigData):
    '''Config data loaded from the real xfrout spec file, with canned
    answers for remote configuration lookups.'''

    def __init__(self):
        module_spec = isc.config.module_spec_from_file(
            xfrout.SPECFILE_LOCATION)
        # Initialise via the declared base class rather than a bare
        # ``ConfigData`` name that only exists thanks to a star import.
        super().__init__(module_spec)

    def get_remote_config_value(self, module_name, identifier):
        '''Return (value, is_default) for a remote module's setting.

        Only Auth/database_file is known; everything else yields "unknown".
        The second element is always False.
        '''
        if module_name == "Auth" and identifier == "database_file":
            return "initdb.file", False
        else:
            return "unknown", False

# Full configuration as seen by a freshly-loaded xfrout spec file, i.e. every
# parameter at its spec default.  Fakes below read defaults from here instead
# of hard-coding them.
DEFAULT_CONFIG = MyCCSession().get_full_config()

# We subclass the Session class we're testing here, only overriding a few
# methods
class MyXfroutSession(XfroutSession):
    '''XfroutSession with request handling and socket closing disabled.'''

    def _handle(self):
        # Do nothing on construction-time handling; tests invoke
        # XfroutSession._handle explicitly when they need it.
        pass

    def _close_socket(self):
        pass

    def _send_data(self, sock, data):
        '''Write all of data to sock, retrying on partial sends.'''
        size = len(data)
        total_count = 0
        while total_count < size:
            count = sock.send(data[total_count:])
            total_count += count

class Dbserver:
    '''Minimal stand-in for the xfrout server object handed to a session.

    Tracks how many transfers are in progress so tests can check it is
    balanced.
    '''

    def __init__(self):
        self._shutdown_event = threading.Event()
        self.transfer_counter = 0
        self._max_transfers_out = DEFAULT_CONFIG['transfers_out']

    def get_db_file(self):
        return 'test.sqlite3'

    def increase_transfers_counter(self):
        '''Count one more transfer; always grants the quota.'''
        self.transfer_counter += 1
        return True

    def decrease_transfers_counter(self):
        self.transfer_counter -= 1

class TestXfroutSession(unittest.TestCase):
    '''Tests for XfroutSession, driven through MyXfroutSession/MySocket.'''

    def getmsg(self):
        '''Return the default request (self.mdata) parsed as a Message.'''
        msg = Message(Message.PARSE)
        msg.from_wire(self.mdata)
        return msg

    def create_mock_tsig_ctx(self, error):
        # Create a MockTSIGContext for TSIG_KEY whose ``error`` attribute is
        # set to the given TSIG error (presumably reported by the mock as
        # the result of verify).
        mock_ctx = MockTSIGContext(TSIG_KEY)
        mock_ctx.error = error
        return mock_ctx

    def message_has_tsig(self, msg):
        return msg.get_tsig_record() is not None

    def create_request_data(self, with_question=True, with_tsig=False):
        '''Render an AXFR query for example.com to wire format.

        with_question=False omits the question section (a malformed query);
        with_tsig=True signs the query with a mock context for TSIG_KEY.
        '''
        msg = Message(Message.RENDER)
        query_id = 0x1035
        msg.set_qid(query_id)
        msg.set_opcode(Opcode.QUERY())
        msg.set_rcode(Rcode.NOERROR())
        if with_question:
            msg.add_question(Question(Name("example.com"), RRClass.IN(),
                                      RRType.AXFR()))

        renderer = MessageRenderer()
        if with_tsig:
            tsig_ctx = MockTSIGContext(TSIG_KEY)
            msg.to_wire(renderer, tsig_ctx)
        else:
            msg.to_wire(renderer)
        request_data = renderer.get_data()
        return request_data

    def setUp(self):
        self.sock = MySocket(socket.AF_INET, socket.SOCK_STREAM)
        self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
                                       TSIGKeyRing(),
                                       (socket.AF_INET, socket.SOCK_STREAM,
                                        ('127.0.0.1', 12345)),
                                       # When not testing ACLs, simply accept
                                       isc.acl.dns.REQUEST_LOADER.load(
                                           [{"action": "ACCEPT"}]),
                                       {})
        self.mdata = self.create_request_data()
        self.soa_rrset = RRset(Name('example.com'), RRClass.IN(), RRType.SOA(),
                               RRTTL(3600))
        self.soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
                                       'master.Example.com. ' +
                                       'admin.exAmple.com. ' +
                                       '1234 3600 1800 2419200 7200'))

    def tearDown(self):
        # transfer_counter must always be reset, no matter what happens
        # within the XfroutSession object.  We check the condition here.
        self.assertEqual(0, self.xfrsess._server.transfer_counter)

    def test_quota_error(self):
        '''Emulating the server being too busy.

        '''
        self.xfrsess._request_data = self.mdata
        self.xfrsess._server.increase_transfers_counter = lambda : False
        XfroutSession._handle(self.xfrsess)
        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.REFUSED())

    def test_quota_ok(self):
        '''The default case in terms of the xfrout quota.

        '''
        # set up a bogus request, which should result in FORMERR. (it only
        # has to be something that is different from the previous case)
        self.xfrsess._request_data = \
            self.create_request_data(with_question=False)
        # Replace the data source client to avoid datasrc related exceptions
        self.xfrsess.ClientClass = MockDataSrcClient
        XfroutSession._handle(self.xfrsess)
        self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.FORMERR())

    def test_exception_from_session(self):
        '''Test the case where the main processing raises an exception.

        We just check it doesn't cause any unexpected disruption and (in
        tearDown) that transfer_counter is correctly reset to 0.

        '''
        def dns_xfrout_start(fd, msg, quota):
            raise ValueError('fake exception')
        self.xfrsess.dns_xfrout_start = dns_xfrout_start
        XfroutSession._handle(self.xfrsess)

    def test_parse_query_message(self):
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "NOERROR")

        # Broken request: no question
        request_data = self.create_request_data(with_question=False)
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(Rcode.FORMERR(), rcode)

        # tsig signed query message
        request_data = self.create_request_data(with_tsig=True)
        # BADKEY
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)
        # NOERROR
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

    def check_transfer_acl(self, acl_setter):
        '''Run the common ACL scenarios; acl_setter installs each ACL.'''
        # ACL checks, put some ACL inside
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {
                "from": "127.0.0.1",
                "action": "ACCEPT"
            },
            {
                "from": "192.0.2.1",
                "action": "DROP"
            }
        ]))
        # Localhost (the default in this test) is accepted
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # This should be dropped completely, therefore returning None
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.1', 12345))
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertIsNone(rcode)
        # This should be refused, therefore REFUSED
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.2', 12345))
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

        # TSIG signed request
        request_data = self.create_request_data(with_tsig=True)

        # If the TSIG check fails, it should not check ACL
        # (If it checked ACL as well, it would just drop the request)
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.1', 12345))
        self.xfrsess._tsig_key_ring = TSIGKeyRing()
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

        # ACL using TSIG: successful case
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.com", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        self.assertEqual(TSIGKeyRing.SUCCESS,
                         self.xfrsess._tsig_key_ring.add(TSIG_KEY))
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")

        # ACL using TSIG: key name doesn't match; should be rejected
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "REFUSED")

        # ACL using TSIG: no TSIG; should be rejected
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
            {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
        ]))
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

        #
        # ACL using IP + TSIG: both should match
        #
        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
                {"ALL": [{"key": "example.com"}, {"from": "192.0.2.1"}],
                 "action": "ACCEPT"},
                {"action": "REJECT"}
        ]))
        # both matches
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.1', 12345))
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # TSIG matches, but address doesn't
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.2', 12345))
        rcode, msg = self.xfrsess._parse_query_message(request_data)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # Address matches, but TSIG doesn't (not included)
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.1', 12345))
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # Neither address nor TSIG matches
        self.xfrsess._remote = (socket.AF_INET, socket.SOCK_STREAM,
                                ('192.0.2.2', 12345))
        rcode, msg = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")

    def test_transfer_acl(self):
        # ACL checks only with the default ACL
        def acl_setter(acl):
            self.xfrsess._acl = acl
        self.check_transfer_acl(acl_setter)

    def test_transfer_zoneacl(self):
        # ACL check with a per zone ACL + default ACL.  The per zone ACL
        # should match the queried zone, so it should be used.
        def acl_setter(acl):
            zone_key = ('IN', 'example.com.')
            self.xfrsess._zone_config[zone_key] = {}
            self.xfrsess._zone_config[zone_key]['transfer_acl'] = acl
            self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
                    {"from": "127.0.0.1", "action": "DROP"}])
        self.check_transfer_acl(acl_setter)

    def test_transfer_zoneacl_nomatch(self):
        # similar to the previous one, but the per zone doesn't match the
        # query.  The default should be used.
        def acl_setter(acl):
            zone_key = ('IN', 'example.org.')
            self.xfrsess._zone_config[zone_key] = {}
            self.xfrsess._zone_config[zone_key]['transfer_acl'] = \
                isc.acl.dns.REQUEST_LOADER.load([
                    {"from": "127.0.0.1", "action": "DROP"}])
            self.xfrsess._acl = acl
        self.check_transfer_acl(acl_setter)

    def test_get_transfer_acl(self):
        # set the default ACL.  If there's no specific zone ACL, this one
        # should be used.
        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
                {"from": "127.0.0.1", "action": "ACCEPT"}])
        acl = self.xfrsess._get_transfer_acl(Name('example.com'), RRClass.IN())
        self.assertEqual(acl, self.xfrsess._acl)

        # install a per zone config with transfer ACL for example.com.  Then
        # that ACL will be used for example.com; for others the default ACL
        # will still be used.
        com_acl = isc.acl.dns.REQUEST_LOADER.load([
                {"from": "127.0.0.1", "action": "REJECT"}])
        self.xfrsess._zone_config[('IN', 'example.com.')] = {}
        self.xfrsess._zone_config[('IN', 'example.com.')]['transfer_acl'] = \
            com_acl
        self.assertEqual(com_acl,
                         self.xfrsess._get_transfer_acl(Name('example.com'),
                                                        RRClass.IN()))
        self.assertEqual(self.xfrsess._acl,
                         self.xfrsess._get_transfer_acl(Name('example.org'),
                                                        RRClass.IN()))

        # Name matching should be case insensitive.
        self.assertEqual(com_acl,
                         self.xfrsess._get_transfer_acl(Name('EXAMPLE.COM'),
                                                        RRClass.IN()))

    def test_send_data(self):
        self.xfrsess._send_data(self.sock, self.mdata)
        senddata = self.sock.readsent()
        self.assertEqual(senddata, self.mdata)

    def test_reply_xfrout_query_with_error_rcode(self):
        msg = self.getmsg()
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")

        # tsig signed message
        msg = self.getmsg()
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_send_message(self):
        msg = self.getmsg()
        msg.make_response()
        # SOA record data with different cases
        soa_rrset = RRset(Name('Example.com.'), RRClass.IN(), RRType.SOA(),
                          RRTTL(3600))
        soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
                                  'master.Example.com. admin.exAmple.com. ' +
                                  '1234 3600 1800 2419200 7200'))
        msg.add_rrset(Message.SECTION_ANSWER, soa_rrset)
        self.xfrsess._send_message(self.sock, msg)
        send_out_data = self.sock.readsent()[2:]

        # CASE_INSENSITIVE compression mode
        render = MessageRenderer()
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(render)
        self.assertNotEqual(render.get_data(), send_out_data)

        # CASE_SENSITIVE compression mode
        render.clear()
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        msg.to_wire(render)
        self.assertEqual(render.get_data(), send_out_data)

    def test_clear_message(self):
        msg = self.getmsg()
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()

        self.xfrsess._clear_message(msg)
        self.assertEqual(msg.get_qid(), qid)
        self.assertEqual(msg.get_opcode(), opcode)
        self.assertEqual(msg.get_rcode(), rcode)
        self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))

    def test_send_message_with_last_soa(self):
        msg = self.getmsg()
        msg.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset, 0,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        # no TSIG context, so the message must not be signed
        self.assertFalse(self.message_has_tsig(get_msg))

        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "SOA")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0], self.soa_rrset.get_rdata()[0])

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset, 0,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        # no TSIG context, so the message must not be signed
        self.assertFalse(self.message_has_tsig(get_msg))

    def test_send_message_with_last_soa_with_tsig(self):
        # create tsig context
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)

        msg = self.getmsg()
        msg.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        # msg is not the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset, 0,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))

        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset, 0,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))

    def test_trigger_send_message_with_last_soa(self):
        rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))
        rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))

        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, rrset_a)

        # length larger than MAX-len(rrset)
        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - \
            get_rrset_len(self.soa_rrset) + 1
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset,
                                                 length_need_split,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(get_msg))
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "A")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0].to_text(), "192.0.2.1")

        get_msg = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(get_msg))
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_QUESTION), 0)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(get_msg.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = get_msg.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "SOA")
        rdata = answer.get_rdata()
        self.assertEqual(rdata[0], self.soa_rrset.get_rdata()[0])

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

    def test_trigger_send_message_with_last_soa_with_tsig(self):
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        msg = self.getmsg()
        msg.make_response()
        msg.add_rrset(Message.SECTION_ANSWER, self.soa_rrset)

        # length larger than MAX-len(rrset)
        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - \
            get_rrset_len(self.soa_rrset) + 1
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset,
                                                 length_need_split,
                                                 packet_need_not_sign)
        get_msg = self.sock.read_msg()
        # msg is not the TSIG_SIGN_EVERY_NTH one, it shouldn't be tsig signed
        self.assertFalse(self.message_has_tsig(get_msg))
        # the last packet should be tsig signed
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

        # msg is the TSIG_SIGN_EVERY_NTH one, it should be tsig signed
        self.xfrsess._send_message_with_last_soa(msg, self.sock,
                                                 self.soa_rrset,
                                                 length_need_split,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # the last packet should be tsig signed
        get_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(get_msg))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

    def test_get_rrset_len(self):
        self.assertEqual(82, get_rrset_len(self.soa_rrset))

    def test_check_xfrout_available(self):
        self.xfrsess.ClientClass = MockDataSrcClient
        self.assertEqual(self.xfrsess._check_xfrout_available(
                Name('notauth.example.com')), Rcode.NOTAUTH())
        self.assertEqual(self.xfrsess._check_xfrout_available(
                Name('nosoa.example.com')), Rcode.SERVFAIL())
        self.assertEqual(self.xfrsess._check_xfrout_available(
                Name('multisoa.example.com')), Rcode.SERVFAIL())

    def test_dns_xfrout_start_formerror(self):
        # formerror
        self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
        sent_data = self.sock.readsent()
        self.assertEqual(len(sent_data), 0)

    def default(self, param):
        '''Stand-in for _get_query_zone_name: always "example.com".'''
        return "example.com"

    def test_dns_xfrout_start_notauth(self):
        self.xfrsess._get_query_zone_name = self.default
        def notauth(formpara):
            return Rcode.NOTAUTH()
        self.xfrsess._check_xfrout_available = notauth
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        get_msg = self.sock.read_msg()
        self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")

    def test_dns_xfrout_start_noerror(self):
        self.xfrsess._get_query_zone_name = self.default
        def noerror(form):
            return Rcode.NOERROR()
        self.xfrsess._check_xfrout_available = noerror

        def myreply(msg, sock):
            self.sock.send(b"success")

        self.xfrsess._reply_xfrout_query = myreply
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        self.assertEqual(self.sock.readsent(), b"success")

    def test_reply_xfrout_query_noerror(self):
        self.xfrsess._soa = self.soa_rrset
        self.xfrsess._iterator = [self.soa_rrset]
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)

    def test_reply_xfrout_query_noerror_with_tsig(self):
        rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.A(),
                      RRTTL(3600))
        rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), '192.0.2.1'))

        # Pretend every RRset is huge so that each one ends up in its own
        # message.  This patches the xfrout module itself, so restore the
        # real function afterwards; otherwise later tests that go through
        # xfrout internals would silently see the fake one.
        def fake_get_rrset_len(rrset):
            return 65520
        self.addCleanup(setattr, xfrout, 'get_rrset_len',
                        xfrout.get_rrset_len)
        xfrout.get_rrset_len = fake_get_rrset_len

        self.xfrsess._soa = self.soa_rrset
        self.xfrsess._iterator = [rrset] * 100

        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)

        # the first packet is tsig signed
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertTrue(self.message_has_tsig(reply_msg))
        # (TSIG_SIGN_EVERY_NTH - 1) packets have no tsig
        for _ in range(0, xfrout.TSIG_SIGN_EVERY_NTH - 1):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # TSIG_SIGN_EVERY_NTH packet has tsig
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        for _ in range(0, 100 - xfrout.TSIG_SIGN_EVERY_NTH):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # the last packet is tsig signed
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

class MyUnixSockServer(UnixSockServer):
    '''UnixSockServer test double that does not run the base constructor.

    Only the state the tests rely on is initialised here.  The
    configuration is taken from a MyCCSession stub instead of a real
    command channel.  (Presumably the base __init__ is skipped to avoid
    binding a real socket; confirm against UnixSockServer.)'''

    def __init__(self):
        self._shutdown_event = threading.Event()
        self._common_init()
        self._cc = MyCCSession()
        initial_config = self._cc.get_full_config()
        self.update_config_data(initial_config)
Likun Zhang's avatar
Likun Zhang committed
709
710
711

class TestUnixSockServer(unittest.TestCase):
    '''Tests for UnixSockServer, exercised through MyUnixSockServer.'''

    def setUp(self):
        # A connected socket pair used to feed raw bytes to
        # _receive_query_message().
        self.write_sock, self.read_sock = socket.socketpair()
        self.unix = MyUnixSockServer()

    def tearDown(self):
        # Release the socket pair created in setUp so it does not leak.
        self.write_sock.close()
        self.read_sock.close()

    def _check_guess_remote(self, family, address, expected_address):
        '''Connect a UDP socket of the given family to address and check
           that _guess_remote() reports (family, SOCK_STREAM,
           expected_address) for its file descriptor.  The socket is
           closed afterwards even if the check fails.'''
        sock = socket.socket(family, socket.SOCK_DGRAM)
        try:
            sock.connect(address)
            self.assertEqual((family, socket.SOCK_STREAM, expected_address),
                             self.unix._guess_remote(sock.fileno()))
        finally:
            sock.close()

    def test_guess_remote(self):
        """Test we can guess the remote endpoint when we have only the
           file descriptor. This is needed, because we get only that one
           from auth."""
        # We test with UDP, as it can be "connected" without other
        # endpoint.  Note that in the current implementation _guess_remote()
        # unconditionally returns SOCK_STREAM.
        self._check_guess_remote(socket.AF_INET, ('127.0.0.1', 12345),
                                 ('127.0.0.1', 12345))
        if socket.has_ipv6:
            # Don't check IPv6 address on hosts not supporting them
            self._check_guess_remote(socket.AF_INET6, ('::1', 12345),
                                     ('::1', 12345, 0, 0))
            # Try when pretending there's no IPv6 support
            # (No need to pretend when there's really no IPv6)
            saved_has_ipv6 = xfrout.socket.has_ipv6
            xfrout.socket.has_ipv6 = False
            try:
                self._check_guess_remote(socket.AF_INET,
                                         ('127.0.0.1', 12345),
                                         ('127.0.0.1', 12345))
            finally:
                # Restore it even on failure so later tests still see the
                # real IPv6 support flag.
                xfrout.socket.has_ipv6 = saved_has_ipv6

    def test_receive_query_message(self):
        send_msg = b"\xd6=\x00\x00\x00\x01\x00"
        # The message is preceded by its length as a 16-bit unsigned
        # integer in network byte order.
        msg_len = struct.pack('!H', len(send_msg))
        self.write_sock.send(msg_len)
        self.write_sock.send(send_msg)
        recv_msg = self.unix._receive_query_message(self.read_sock)
        self.assertEqual(recv_msg, send_msg)

    def _create_request_context(self, address):
        '''Return an isc.acl.dns.RequestContext for a UDP client at the
           given numeric address, port 1234.'''
        sockaddr = socket.getaddrinfo(address, 1234, 0, socket.SOCK_DGRAM,
                                      socket.IPPROTO_UDP,
                                      socket.AI_NUMERICHOST)[0][4]
        return isc.acl.dns.RequestContext(sockaddr)

    def check_default_ACL(self):
        '''Check that the server's current ACL accepts 127.0.0.1.'''
        context = self._create_request_context("127.0.0.1")
        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))

    def check_loaded_ACL(self, acl):
        '''Check that acl accepts 127.0.0.1 and rejects 192.0.2.1.'''
        context = self._create_request_context("127.0.0.1")
        self.assertEqual(isc.acl.acl.ACCEPT, acl.execute(context))
        context = self._create_request_context("192.0.2.1")
        self.assertEqual(isc.acl.acl.REJECT, acl.execute(context))

    def test_update_config_data(self):
        self.check_default_ACL()
        tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
        tsig_key_list = [tsig_key_str]
        bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
        self.unix.update_config_data({'transfers_out':10 })
        self.assertEqual(self.unix._max_transfers_out, 10)
        self.assertIsNotNone(self.unix.tsig_key_ring)
        self.check_default_ACL()

        self.unix.update_config_data({'transfers_out':9,
                                      'tsig_key_ring':tsig_key_list})
        self.assertEqual(self.unix._max_transfers_out, 9)
        self.assertEqual(self.unix.tsig_key_ring.size(), 1)
        self.unix.tsig_key_ring.remove(Name("example.com."))
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # bad tsig key: the update must not raise, and the bad key must
        # not end up in the key ring.  (Calling it directly is the correct
        # way to assert "does not raise"; an unexpected exception makes the
        # test error out.)
        config_data = {'transfers_out':9, 'tsig_key_ring': bad_key_list}
        self.unix.update_config_data(config_data)
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # Load the ACL
        self.unix.update_config_data({'transfer_acl': [{'from': '127.0.0.1',
                                               'action': 'ACCEPT'}]})
        self.check_loaded_ACL(self.unix._acl)
        # Pass a wrong data there and check it does not replace the old one
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'transfer_acl': ['Something bad']})
        self.check_loaded_ACL(self.unix._acl)

    def test_zone_config_data(self):
        # By default, there's no specific zone config
        self.assertEqual({}, self.unix._zone_config)

        # Adding config for a specific zone.  The config is empty unless
        # explicitly specified.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'class': 'IN'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class can be omitted
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # zone class, name are stored in the "normalized" form.  class
        # strings are upper cased, names are down cased.
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'EXAMPLE.com'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])

        # invalid zone class, name will result in exceptions
        self.assertRaises(EmptyLabel,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'bad..example'}]})
        self.assertRaises(InvalidRRClass,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'class': 'badclass'}]})

        # Configuring a couple of more zones
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com'},
                                           {'origin': 'example.com',
                                            'class': 'CH'},
                                           {'origin': 'example.org'}]})
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('CH', 'example.com.')])
        self.assertEqual({}, self.unix._zone_config[('IN', 'example.org.')])

        # Duplicate data: should be rejected with an exception
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com'},
                                           {'origin': 'example.org'},
                                           {'origin': 'example.com'}]})

    def test_zone_config_data_with_acl(self):
        # Similar to the previous test, but with transfer_acl config
        self.unix.update_config_data({'zone_config':
                                          [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'from': '127.0.0.1',
                                                  'action': 'ACCEPT'}]}]})
        acl = self.unix._zone_config[('IN', 'example.com.')]['transfer_acl']
        self.check_loaded_ACL(acl)

        # invalid ACL syntax will be rejected with exception
        self.assertRaises(XfroutConfigError,
                          self.unix.update_config_data,
                          {'zone_config': [{'origin': 'example.com',
                                            'transfer_acl':
                                                [{'action': 'BADACTION'}]}]})

    def test_get_db_file(self):
        self.assertEqual(self.unix.get_db_file(), "initdb.file")

    def test_increase_transfers_counter(self):
        self.unix._max_transfers_out = 10
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), True)
        self.assertEqual(count + 1, self.unix._transfers_counter)

        self.unix._max_transfers_out = 0
        count = self.unix._transfers_counter
        self.assertEqual(self.unix.increase_transfers_counter(), False)
        self.assertEqual(count, self.unix._transfers_counter)

    def test_decrease_transfers_counter(self):
        count = self.unix._transfers_counter
        self.unix.decrease_transfers_counter()
        self.assertEqual(count - 1, self.unix._transfers_counter)

    def _remove_file(self, sock_file):
        '''Remove sock_file, ignoring the error if it cannot be removed
           (e.g. it does not exist).'''
        try:
            os.remove(sock_file)
        except OSError:
            pass

    def _call_with_stdout_silenced(self, func, *args):
        '''Call func(*args) with sys.stdout redirected to os.devnull and
           return its result.  sys.stdout is restored and the null device
           closed even if func raises (SystemExit included).'''
        old_stdout = sys.stdout
        with open(os.devnull, 'w') as devnull:
            sys.stdout = devnull
            try:
                return func(*args)
            finally:
                sys.stdout = old_stdout

    def test_sock_file_in_use_file_exist(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self.assertFalse(os.path.exists(sock_file))

    def test_sock_file_in_use_file_not_exist(self):
        self.assertFalse(self.unix._sock_file_in_use('temp.sock.file'))

    def _start_unix_sock_server(self, sock_file):
        '''Serve sock_file with a ThreadingUnixStreamServer in a daemon
           thread and return the server.  When the test finishes, the
           server is shut down and closed, and sock_file is removed.'''
        serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
        serv_thread = threading.Thread(target=serv.serve_forever)
        serv_thread.daemon = True
        serv_thread.start()
        self.addCleanup(self._stop_unix_sock_server, serv, sock_file)
        return serv

    def _stop_unix_sock_server(self, serv, sock_file):
        '''Stop a server started by _start_unix_sock_server and remove its
           socket file.'''
        serv.shutdown()
        serv.server_close()
        self._remove_file(sock_file)

    def test_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)

        self.assertTrue(self._call_with_stdout_silenced(
            self.unix._sock_file_in_use, sock_file))

    def test_remove_unused_sock_file_in_use(self):
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)
        # The file is in use, so the server must refuse to continue.
        self.assertRaises(SystemExit, self._call_with_stdout_silenced,
                          self.unix._remove_unused_sock_file, sock_file)

    def test_remove_unused_sock_file_dir(self):
        import tempfile
        dir_name = tempfile.mkdtemp()
        try:
            # A directory in place of the socket file must make it exit.
            self.assertRaises(SystemExit, self._call_with_stdout_silenced,
                              self.unix._remove_unused_sock_file, dir_name)
        finally:
            os.rmdir(dir_name)
Likun Zhang's avatar
Likun Zhang committed
952

953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
class TestInitialization(unittest.TestCase):
    '''Tests for xfrout.init_paths() and the environment variables it
       consults.'''

    def setEnv(self, name, value):
        '''Set environment variable name to value; a value of None removes
           the variable (if present) instead.'''
        if value is not None:
            os.environ[name] = value
        elif name in os.environ:
            del os.environ[name]

    def setUp(self):
        # Remember the caller's environment so tearDown can put it back.
        self._oldSocket = os.getenv("BIND10_XFROUT_SOCKET_FILE")
        self._oldFromBuild = os.getenv("B10_FROM_BUILD")

    def tearDown(self):
        saved = (("B10_FROM_BUILD", self._oldFromBuild),
                 ("BIND10_XFROUT_SOCKET_FILE", self._oldSocket))
        for env_name, env_value in saved:
            self.setEnv(env_name, env_value)
        # Make sure even the computed values are back
        xfrout.init_paths()

    def _init_paths_with(self, from_build, socket_file):
        '''Set (or clear, when None) B10_FROM_BUILD and
           BIND10_XFROUT_SOCKET_FILE, then recompute xfrout's paths.'''
        self.setEnv("B10_FROM_BUILD", from_build)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", socket_file)
        xfrout.init_paths()

    def testNoEnv(self):
        self._init_paths_with(None, None)
        self.assertEqual(xfrout.UNIX_SOCKET_FILE,
                         "@@LOCALSTATEDIR@@/@PACKAGE_NAME@/auth_xfrout_conn")

    def testProvidedSocket(self):
        self._init_paths_with(None, "The/Socket/File")
        self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")

Likun Zhang's avatar
Likun Zhang committed
984
if __name__ == "__main__":
    # Reset the root logger to its unit-test setup
    # (isc.log.resetUnitTestRootLogger), then run every test in this module.
    isc.log.resetUnitTestRootLogger()
    unittest.main()