xfrout_test.py.in 30.3 KB
Newer Older
Likun Zhang's avatar
Likun Zhang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

'''Tests for the XfroutSession and UnixSockServer classes '''


import unittest
import os
21
from isc.testutils.tsigctx_mock import MockTSIGContext
Likun Zhang's avatar
Likun Zhang committed
22
from isc.cc.session import *
Jelte Jansen's avatar
Jelte Jansen committed
23
from pydnspp import *
Likun Zhang's avatar
Likun Zhang committed
24
from xfrout import *
25
import xfrout
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
26
import isc.acl.dns
Likun Zhang's avatar
Likun Zhang committed
27

28
29
TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")

Likun Zhang's avatar
Likun Zhang committed
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# our fake socket, where we can read and insert messages
class MySocket():
    """In-memory stand-in for a socket used by the tests.

    Everything passed to send() is appended to ``sendqueue``; the read
    helpers take length-prefixed messages back off the front of it.
    """

    def __init__(self, family, type):
        self.family = family
        self.type = type
        self.sendqueue = bytearray()

    def connect(self, to):
        """Accept and ignore the destination; nothing is connected."""
        pass

    def close(self):
        """Do nothing; there is no underlying resource to release."""
        pass

    def send(self, data):
        """Append data to the queue and report all of it as sent."""
        self.sendqueue.extend(data)
        return len(data)

    def readsent(self):
        """Remove and return the first queued message, prefix included.

        The first two bytes of the queue are read as a network-order
        unsigned 16-bit length.  With fewer than two bytes queued, an
        empty bytearray is returned and the queue is left as it was.
        """
        size = 0
        if len(self.sendqueue) >= 2:
            (payload_len,) = struct.unpack("!H", self.sendqueue[:2])
            size = payload_len + 2
        head, tail = self.sendqueue[:size], self.sendqueue[size:]
        self.sendqueue = tail
        return head

    def read_msg(self):
        """Take the next queued message and parse it into a Message.

        The two-byte length prefix is skipped before parsing.
        """
        wire = self.readsent()[2:]
        parsed = Message(Message.PARSE)
        parsed.from_wire(bytes(wire))
        return parsed

    def clear_send(self):
        """Empty the send queue in place."""
        del self.sendqueue[:]

# We subclass the Session class we're testing here, only
66
# to override the handle() and _send_data() method
Likun Zhang's avatar
Likun Zhang committed
67
68
69
class MyXfroutSession(XfroutSession):
    """XfroutSession variant used by the tests: handle() does nothing
    and _send_data() pushes the data through sock.send()."""

    def handle(self):
        # Deliberately empty.
        pass

    def _send_data(self, sock, data):
        """Call sock.send() repeatedly until all of data has been sent."""
        sent = 0
        total = len(data)
        while sent < total:
            # send() reports how much it took; resume right after that.
            sent += sock.send(data[sent:])

Likun Zhang's avatar
Likun Zhang committed
78
79
80
81
82
83
84
85
86
87
class Dbserver:
    """Minimal server stub handed to the session under test."""

    def __init__(self):
        self._shutdown_event = threading.Event()

    def get_db_file(self):
        """Return None; this stub has no database file."""
        return None

    def decrease_transfers_counter(self):
        """Do nothing; this stub keeps no transfer count."""
        return None

class TestXfroutSession(unittest.TestCase):
    def getmsg(self):
        """Return self.mdata parsed into a new Message."""
        parsed = Message(Message.PARSE)
        parsed.from_wire(self.mdata)
        return parsed

92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
    def create_mock_tsig_ctx(self, error):
        """Return a MockTSIGContext for TSIG_KEY with ``error`` set.

        ``error`` is the TSIG error the mock is meant to use as the
        (normally faked) result of verify.
        """
        ctx = MockTSIGContext(TSIG_KEY)
        ctx.error = error
        return ctx

    def message_has_tsig(self, msg):
        """Return True if msg.get_tsig_record() is not None."""
        tsig_record = msg.get_tsig_record()
        return tsig_record is not None

    def create_request_data_with_tsig(self):
        """Render an AXFR query for example.com. with a MockTSIGContext
        for TSIG_KEY and return the rendered data."""
        request = Message(Message.RENDER)
        request.set_qid(0x1035)
        request.set_opcode(Opcode.QUERY())
        request.set_rcode(Rcode.NOERROR())
        request.add_question(
            Question(Name("example.com."), RRClass.IN(), RRType.AXFR()))

        renderer = MessageRenderer()
        request.to_wire(renderer, MockTSIGContext(TSIG_KEY))
        return renderer.get_data()

Likun Zhang's avatar
Likun Zhang committed
118
    def setUp(self):
        """Build a fresh fake socket, session and sample data per test."""
        self.sock = MySocket(socket.AF_INET, socket.SOCK_STREAM)

        server = Dbserver()
        key_ring = TSIGKeyRing()
        remote = ('127.0.0.1', 12345)
        # When not testing ACLs, simply accept
        accept_all = isc.acl.dns.REQUEST_LOADER.load([{"action": "ACCEPT"}])
        self.xfrsess = MyXfroutSession(self.sock, None, server, key_ring,
                                       remote, accept_all)

        # Query wire data for example.com. used throughout the tests.
        self.mdata = b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01'
        # An SOA record in the tuple layout the tests feed to
        # _create_rrset_from_db_record().
        self.soa_record = (
            4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None,
            'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')

    def test_parse_query_message(self):
        # Plain query data
        rcode, _ = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "NOERROR")

        # tsig signed query message
        signed_request = self.create_request_data_with_tsig()
        # BADKEY
        rcode, _ = self.xfrsess._parse_query_message(signed_request)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)
        # NOERROR
        self.xfrsess._tsig_key_ring.add(TSIG_KEY)
        rcode, _ = self.xfrsess._parse_query_message(signed_request)
        self.assertEqual(rcode.to_text(), "NOERROR")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)

        # ACL checks, put some ACL inside
        acl_rules = [
            {"from": "127.0.0.1", "action": "ACCEPT"},
            {"from": "192.0.2.1", "action": "DROP"},
        ]
        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load(acl_rules)
        # Localhost (the default in this test) is accepted
        rcode, _ = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "NOERROR")
        # This should be dropped completely, therefore returning None
        self.xfrsess._remote = ('192.0.2.1', 12345)
        rcode, _ = self.xfrsess._parse_query_message(self.mdata)
        self.assertIsNone(rcode)
        # This should be refused, therefore REFUSED
        self.xfrsess._remote = ('192.0.2.2', 12345)
        rcode, _ = self.xfrsess._parse_query_message(self.mdata)
        self.assertEqual(rcode.to_text(), "REFUSED")
        # If the TSIG check fails, it should not check ACL
        # (If it checked ACL as well, it would just drop the request)
        self.xfrsess._remote = ('192.0.2.1', 12345)
        self.xfrsess._tsig_key_ring = TSIGKeyRing()
        rcode, _ = self.xfrsess._parse_query_message(signed_request)
        self.assertEqual(rcode.to_text(), "NOTAUTH")
        self.assertIsNotNone(self.xfrsess._tsig_ctx)
173

Likun Zhang's avatar
Likun Zhang committed
174
175
176
    def test_get_query_zone_name(self):
        # The zone name is read from the query parsed out of self.mdata.
        zone_name = self.xfrsess._get_query_zone_name(self.getmsg())
        self.assertEqual(zone_name, "example.com.")
177

Likun Zhang's avatar
Likun Zhang committed
178
179
180
181
182
183
184
    def test_send_data(self):
        # readsent() should hand back exactly what _send_data() was given.
        self.xfrsess._send_data(self.sock, self.mdata)
        self.assertEqual(self.sock.readsent(), self.mdata)

    def test_reply_xfrout_query_with_error_rcode(self):
        # First without setting a TSIG context on the session.
        query = self.getmsg()
        self.xfrsess._reply_query_with_error_rcode(query, self.sock, Rcode(3))
        reply = self.sock.read_msg()
        self.assertEqual(reply.get_rcode().to_text(), "NXDOMAIN")

        # tsig signed message
        query = self.getmsg()
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_query_with_error_rcode(query, self.sock, Rcode(3))
        reply = self.sock.read_msg()
        self.assertEqual(reply.get_rcode().to_text(), "NXDOMAIN")
        self.assertTrue(self.message_has_tsig(reply))

197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
    def test_send_message(self):
        response = self.getmsg()
        response.make_response()
        # soa record data with different cases
        mixed_case_soa = (4, 3, 'Example.com.', 'com.Example.', 3600, 'SOA', None, 'master.Example.com. admin.exAmple.com. 1234 3600 1800 2419200 7200')
        response.add_rrset(
            Message.SECTION_ANSWER,
            self.xfrsess._create_rrset_from_db_record(mixed_case_soa))
        self.xfrsess._send_message(self.sock, response)
        # Drop the two-byte length prefix before comparing.
        sent_wire = self.sock.readsent()[2:]

        # CASE_INSENSITIVE compression mode
        renderer = MessageRenderer()
        renderer.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        response.to_wire(renderer)
        self.assertNotEqual(renderer.get_data(), sent_wire)

        # CASE_SENSITIVE compression mode
        renderer.clear()
        renderer.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        renderer.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
        response.to_wire(renderer)
        self.assertEqual(renderer.get_data(), sent_wire)

Likun Zhang's avatar
Likun Zhang committed
220
221
222
223
224
225
226
227
228
229
    def test_clear_message(self):
        msg = self.getmsg()
        # Remember the header fields that must survive the clearing.
        qid_before = msg.get_qid()
        opcode_before = msg.get_opcode()
        rcode_before = msg.get_rcode()

        self.xfrsess._clear_message(msg)

        self.assertEqual(msg.get_qid(), qid_before)
        self.assertEqual(msg.get_opcode(), opcode_before)
        self.assertEqual(msg.get_rcode(), rcode_before)
        # The AA header flag is expected to be set afterwards.
        self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
Likun Zhang's avatar
Likun Zhang committed
231
232
233
234

    def test_create_rrset_from_db_record(self):
        # Name, class, type and rdata must all reflect self.soa_record.
        soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        self.assertEqual(soa.get_name().to_text(), "example.com.")
        self.assertEqual(soa.get_class(), RRClass("IN"))
        self.assertEqual(soa.get_type().to_text(), "SOA")
        first_rdata = soa.get_rdata()[0]
        self.assertEqual(first_rdata.to_text(), self.soa_record[7])
Likun Zhang's avatar
Likun Zhang committed
239
240
241
242
243

    def test_send_message_with_last_soa(self):
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        response = self.getmsg()
        response.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        self.xfrsess._send_message_with_last_soa(response, self.sock,
                                                 rrset_soa, 0,
                                                 packet_need_not_sign)
        reply = self.sock.read_msg()
        # no tsig context was set in this test, so no tsig is expected
        self.assertFalse(self.message_has_tsig(reply))

        self.assertEqual(reply.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(reply.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(reply.get_rr_count(Message.SECTION_AUTHORITY), 0)

        # The single answer must be the SOA that was passed in.
        answer = reply.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "SOA")
        first_rdata = answer.get_rdata()[0]
        self.assertEqual(first_rdata.to_text(), self.soa_record[7])

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(response, self.sock,
                                                 rrset_soa, 0,
                                                 TSIG_SIGN_EVERY_NTH)
        reply = self.sock.read_msg()
        # still no tsig context, so still no tsig
        self.assertFalse(self.message_has_tsig(reply))

273
    def test_send_message_with_last_soa_with_tsig(self):
        # create tsig context
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)

        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        response = self.getmsg()
        response.make_response()

        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1
        # msg is not the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(response, self.sock,
                                                 rrset_soa, 0,
                                                 packet_need_not_sign)
        reply = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply))

        self.assertEqual(reply.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(reply.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(reply.get_rr_count(Message.SECTION_AUTHORITY), 0)

        # msg is the TSIG_SIGN_EVERY_NTH one
        # sending the message with last soa together
        self.xfrsess._send_message_with_last_soa(response, self.sock,
                                                 rrset_soa, 0,
                                                 TSIG_SIGN_EVERY_NTH)
        reply = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply))

301
302
303
304
305
    def test_trigger_send_message_with_last_soa(self):
        rrset_a = RRset(Name("example.com"), RRClass.IN(), RRType.A(), RRTTL(3600))
        rrset_a.add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)

        response = self.getmsg()
        response.make_response()
        response.add_rrset(Message.SECTION_ANSWER, rrset_a)

        # length larger than MAX-len(rrset)
        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - get_rrset_len(rrset_soa) + 1
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(response, self.sock,
                                                 rrset_soa, length_need_split,
                                                 packet_need_not_sign)

        # First message: the question plus the A record added above.
        first = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(first))
        self.assertEqual(first.get_rr_count(Message.SECTION_QUESTION), 1)
        self.assertEqual(first.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(first.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = first.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "A")
        self.assertEqual(answer.get_rdata()[0].to_text(), "192.0.2.1")

        # Second message: no question, only the SOA.
        second = self.sock.read_msg()
        self.assertFalse(self.message_has_tsig(second))
        self.assertEqual(second.get_rr_count(Message.SECTION_QUESTION), 0)
        self.assertEqual(second.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertEqual(second.get_rr_count(Message.SECTION_AUTHORITY), 0)

        answer = second.get_section(Message.SECTION_ANSWER)[0]
        self.assertEqual(answer.get_name().to_text(), "example.com.")
        self.assertEqual(answer.get_class(), RRClass("IN"))
        self.assertEqual(answer.get_type().to_text(), "SOA")
        self.assertEqual(answer.get_rdata()[0].to_text(), self.soa_record[7])

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

351
352
353
354
355
356
    def test_trigger_send_message_with_last_soa_with_tsig(self):
        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        response = self.getmsg()
        response.make_response()
        response.add_rrset(Message.SECTION_ANSWER, rrset_soa)

        # length larger than MAX-len(rrset)
        length_need_split = xfrout.XFROUT_MAX_MESSAGE_SIZE - get_rrset_len(rrset_soa) + 1
        # packet number less than TSIG_SIGN_EVERY_NTH
        packet_need_not_sign = xfrout.TSIG_SIGN_EVERY_NTH - 1

        # give the function a value that is larger than MAX-len(rrset)
        # this should have triggered the sending of two messages
        # (1 with the rrset we added manually, and 1 that triggered
        # the sending in _with_last_soa)
        self.xfrsess._send_message_with_last_soa(response, self.sock,
                                                 rrset_soa, length_need_split,
                                                 packet_need_not_sign)
        first = self.sock.read_msg()
        # msg is not the TSIG_SIGN_EVERY_NTH one, it shouldn't be tsig signed
        self.assertFalse(self.message_has_tsig(first))
        # the last packet should be tsig signed
        last = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(last))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

        # msg is the TSIG_SIGN_EVERY_NTH one, it should be tsig signed
        self.xfrsess._send_message_with_last_soa(response, self.sock,
                                                 rrset_soa, length_need_split,
                                                 xfrout.TSIG_SIGN_EVERY_NTH)
        first = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(first))
        # the last packet should be tsig signed
        last = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(last))
        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

390
391
392
    def test_get_rrset_len(self):
        # The sample SOA rrset is expected to measure 82.
        soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
        measured = get_rrset_len(soa)
        self.assertEqual(82, measured)
Likun Zhang's avatar
Likun Zhang committed
393

JINMEI Tatuya's avatar
JINMEI Tatuya committed
394
    def test_zone_has_soa(self):
        # _zone_has_soa() must be true/false according to what the
        # (stubbed) sqlite3_ds.get_zone_soa() returns.
        #
        # sqlite3_ds is a module shared by every test, so put the real
        # function back when this test ends, pass or fail; otherwise the
        # stub would leak into whatever runs afterwards.
        self.addCleanup(setattr, sqlite3_ds, 'get_zone_soa',
                        sqlite3_ds.get_zone_soa)

        def mydb1(zone, file):
            return True
        sqlite3_ds.get_zone_soa = mydb1
        self.assertTrue(self.xfrsess._zone_has_soa(""))

        def mydb2(zone, file):
            return False
        sqlite3_ds.get_zone_soa = mydb2
        self.assertFalse(self.xfrsess._zone_has_soa(""))
Likun Zhang's avatar
Likun Zhang committed
404
405
406

    def test_zone_exist(self):
        # _zone_exist() must be true/false according to what the
        # (stubbed) sqlite3_ds.zone_exist() returns.
        #
        # sqlite3_ds is a module shared by every test, so put the real
        # function back when this test ends, pass or fail; otherwise the
        # stub would leak into whatever runs afterwards.
        self.addCleanup(setattr, sqlite3_ds, 'zone_exist',
                        sqlite3_ds.zone_exist)

        def zone_exist(zone, file):
            # Echo the "zone" argument so the test controls the answer.
            return zone
        sqlite3_ds.zone_exist = zone_exist
        self.assertTrue(self.xfrsess._zone_exist(True))
        self.assertFalse(self.xfrsess._zone_exist(False))
412

Likun Zhang's avatar
Likun Zhang committed
413
414
415
    def test_check_xfrout_available(self):
        # Both checks are replaced on the session; each stub derives its
        # answer from the value passed in as the zone name.
        self.xfrsess._zone_exist = lambda zone: zone
        self.xfrsess._zone_has_soa = lambda zone: not zone
        self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")

        # From here on both stubs simply return the value passed in.
        self.xfrsess._zone_has_soa = lambda zone: zone

        # increase_transfers_counter() returning False
        self.xfrsess._server.increase_transfers_counter = lambda: False
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")

        # increase_transfers_counter() returning True
        self.xfrsess._server.increase_transfers_counter = lambda: True
        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")

    def test_dns_xfrout_start_formerror(self):
        # formerror
        truncated_query = b"\xd6=\x00\x00\x00\x01\x00"
        self.xfrsess.dns_xfrout_start(self.sock, truncated_query)
        # Nothing is expected to have been sent back.
        self.assertEqual(len(self.sock.readsent()), 0)
440

Likun Zhang's avatar
Likun Zhang committed
441
442
443
444
445
446
    def default(self, param):
        """Return "example.com" whatever param is; the tests install
        this in place of _get_query_zone_name."""
        zone_name = "example.com"
        return zone_name

    def test_dns_xfrout_start_notauth(self):
        # Make the availability check answer NOTAUTH and verify that
        # rcode is what gets sent back.
        self.xfrsess._get_query_zone_name = self.default
        self.xfrsess._check_xfrout_available = lambda formpara: Rcode.NOTAUTH()
        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        reply = self.sock.read_msg()
        self.assertEqual(reply.get_rcode().to_text(), "NOTAUTH")
452

Likun Zhang's avatar
Likun Zhang committed
453
454
455
    def test_dns_xfrout_start_noerror(self):
        # With the availability check answering NOERROR, the (stubbed)
        # _reply_xfrout_query() must be reached.
        self.xfrsess._get_query_zone_name = self.default
        self.xfrsess._check_xfrout_available = lambda form: Rcode.NOERROR()

        def myreply(msg, sock, zonename):
            # Leave a marker on the fake socket to prove we were called.
            self.sock.send(b"success")
        self.xfrsess._reply_xfrout_query = myreply

        self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
        marker = self.sock.readsent()
        self.assertEqual(marker, b"success")
465

Likun Zhang's avatar
Likun Zhang committed
466
467
468
469
470
471
472
473
474
475
476
477
    def test_reply_xfrout_query_noerror(self):
        # With a zone made of a single SOA record, the reply must carry
        # two answer records.
        def get_zone_soa(zonename, file):
            return self.soa_record

        def get_zone_datas(zone, file):
            return [self.soa_record]

        # sqlite3_ds is a module shared by every test, so put the real
        # functions back when this test ends, pass or fail; otherwise
        # the stubs would leak into whatever runs afterwards.
        for name in ('get_zone_soa', 'get_zone_datas'):
            self.addCleanup(setattr, sqlite3_ds, name,
                            getattr(sqlite3_ds, name))
        sqlite3_ds.get_zone_soa = get_zone_soa
        sqlite3_ds.get_zone_datas = get_zone_datas

        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)
Likun Zhang's avatar
Likun Zhang committed
479

480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
    def test_reply_xfrout_query_noerror_with_tsig(self):
        # Check which messages of a long, TSIG-enabled transfer carry a
        # TSIG: the first, every TSIG_SIGN_EVERY_NTH one, and the last.
        rrset_data = (4, 3, 'a.example.com.', 'com.example.', 3600, 'A', None, '192.168.1.1')

        def get_zone_soa(zonename, file):
            return self.soa_record

        def get_zone_datas(zone, file):
            # 100 identical A records
            return [rrset_data] * 100

        def get_rrset_len(rrset):
            # Report every rrset as 65520 long -- presumably so that
            # each one ends up in a message of its own.
            return 65520

        # sqlite3_ds and xfrout are modules shared by every test, so put
        # the real functions back when this test ends, pass or fail.  In
        # particular the fake get_rrset_len() must not stay installed in
        # xfrout for the tests that run afterwards.
        for module, name in ((sqlite3_ds, 'get_zone_soa'),
                             (sqlite3_ds, 'get_zone_datas'),
                             (xfrout, 'get_rrset_len')):
            self.addCleanup(setattr, module, name, getattr(module, name))
        sqlite3_ds.get_zone_soa = get_zone_soa
        sqlite3_ds.get_zone_datas = get_zone_datas
        xfrout.get_rrset_len = get_rrset_len

        self.xfrsess._tsig_ctx = self.create_mock_tsig_ctx(TSIGError.NOERROR)
        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock, "example.com.")

        # tsig signed first package
        reply_msg = self.sock.read_msg()
        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 1)
        self.assertTrue(self.message_has_tsig(reply_msg))
        # (TSIG_SIGN_EVERY_NTH - 1) packets have no tsig
        for _ in range(xfrout.TSIG_SIGN_EVERY_NTH - 1):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # TSIG_SIGN_EVERY_NTH packet has tsig
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        # the remaining data packets have no tsig
        for _ in range(100 - xfrout.TSIG_SIGN_EVERY_NTH):
            reply_msg = self.sock.read_msg()
            self.assertFalse(self.message_has_tsig(reply_msg))
        # tsig signed last package
        reply_msg = self.sock.read_msg()
        self.assertTrue(self.message_has_tsig(reply_msg))

        # and it should not have sent anything else
        self.assertEqual(0, len(self.sock.sendqueue))

525
526
527
528
529
530
531
532
533
class MyCCSession():
    """Fake config session that answers remote-config lookups."""

    def __init__(self):
        pass

    def get_remote_config_value(self, module_name, identifier):
        """Return a (value, False) pair for the requested item.

        The value is "initdb.file" for Auth/database_file and "unknown"
        for everything else.
        """
        if module_name != "Auth" or identifier != "database_file":
            return "unknown", False
        return "initdb.file", False
534

Likun Zhang's avatar
Likun Zhang committed
535
536
537
538
539

class MyUnixSockServer(UnixSockServer):
    """UnixSockServer for the tests, set up without the base __init__."""

    def __init__(self):
        # UnixSockServer.__init__ is not called here; the attributes
        # below are assigned directly and then _common_init() is run.
        self._shutdown_event = threading.Event()
        self._max_transfers_out = 10
        self._cc = MyCCSession()
        self._common_init()
Likun Zhang's avatar
Likun Zhang committed
542
543
544

class TestUnixSockServer(unittest.TestCase):
    def setUp(self):
        """Create a connected socket pair and a fresh server for each test."""
        self.write_sock, self.read_sock = socket.socketpair()
        # Close both ends after every test so file descriptors do not
        # accumulate across the test run.
        self.addCleanup(self.write_sock.close)
        self.addCleanup(self.read_sock.close)
        self.unix = MyUnixSockServer()
547

548
549
550
551
552
553
554
555
556
557
    def test_guess_remote(self):
        """Test we can guess the remote endpoint when we have only the
           file descriptor. This is needed, because we get only that one
           from auth."""
        # We test with UDP, as it can be "connected" without other
        # endpoint
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.addCleanup(sock.close)
        sock.connect(('127.0.0.1', 12345))
        self.assertEqual(('127.0.0.1', 12345),
                         self.unix._guess_remote(sock.fileno()))
        if socket.has_ipv6:
            # Don't check IPv6 address on hosts not supporting them
            sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
            self.addCleanup(sock.close)
            sock.connect(('::1', 12345))
            self.assertEqual(('::1', 12345, 0, 0),
                             self.unix._guess_remote(sock.fileno()))
            # Try when pretending there's no IPv6 support
            # (No need to pretend when there's really no IPv6)
            saved_has_ipv6 = xfrout.socket.has_ipv6
            xfrout.socket.has_ipv6 = False
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
                self.addCleanup(sock.close)
                sock.connect(('127.0.0.1', 12345))
                self.assertEqual(('127.0.0.1', 12345),
                                 self.unix._guess_remote(sock.fileno()))
            finally:
                # Return it back, even when the assertion above fails, so
                # the faked flag cannot leak into the remaining tests.
                xfrout.socket.has_ipv6 = saved_has_ipv6
573

574
575
576
577
578
579
580
581
    def test_receive_query_message(self):
        """Write a length-prefixed message to one end of the socket pair
           and check _receive_query_message() returns the payload read
           from the other end."""
        payload = b"\xd6=\x00\x00\x00\x01\x00"
        # Two-byte length prefix, converted with htons() before packing.
        length_prefix = struct.pack('H', socket.htons(len(payload)))
        for piece in (length_prefix, payload):
            self.write_sock.send(piece)
        received = self.unix._receive_query_message(self.read_sock)
        self.assertEqual(received, payload)

Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
582
583
584
585
    def check_default_ACL(self):
        """Assert that the server's current ACL accepts a request whose
           remote address is 127.0.0.1."""
        # getaddrinfo(...)[0][4] is the socket address of the first result;
        # AI_NUMERICHOST keeps the lookup purely numeric.
        context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
                                             1234, 0, 0, 0,
                                             socket.AI_NUMERICHOST)[0][4])
        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
587

588
589
590
591
592
593
594
595
596
597
    def check_loaded_ACL(self):
        """Assert that the server's current ACL accepts a request from
           127.0.0.1 and rejects one from 192.0.2.1."""
        expectations = (("127.0.0.1", isc.acl.acl.ACCEPT),
                        ("192.0.2.1", isc.acl.acl.REJECT))
        for address, verdict in expectations:
            # Socket address of the first getaddrinfo() result.
            sockaddr = socket.getaddrinfo(address, 1234, 0, 0, 0,
                                          socket.AI_NUMERICHOST)[0][4]
            context = isc.acl.dns.RequestContext(sockaddr)
            self.assertEqual(verdict, self.unix._acl.execute(context))

598
    def test_update_config_data(self):
        """Check that update_config_data() applies 'transfers_out', the
           TSIG key ring and the query ACL, and that a bad ACL does not
           replace the one already loaded."""
        self.check_default_ACL()
        tsig_key_str = 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
        tsig_key_list = [tsig_key_str]
        bad_key_list = ['bad..example.com:SFuWd/q99SzF8Yzd1QbB9g==']
        self.unix.update_config_data({'transfers_out':10 })
        self.assertEqual(self.unix._max_transfers_out, 10)
        self.assertIsNotNone(self.unix.tsig_key_ring)
        self.check_default_ACL()

        self.unix.update_config_data({'transfers_out':9,
                                      'tsig_key_ring':tsig_key_list})
        self.assertEqual(self.unix._max_transfers_out, 9)
        self.assertEqual(self.unix.tsig_key_ring.size(), 1)
        self.unix.tsig_key_ring.remove(Name("example.com."))
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # bad tsig key: the update must go through without raising and
        # leave the key ring empty. (This used to be wrapped in
        # assertRaises(None, <result of the call>), which never checked
        # anything and is a TypeError with current unittest.)
        config_data = {'transfers_out':9, 'tsig_key_ring': bad_key_list}
        self.unix.update_config_data(config_data)
        self.assertEqual(self.unix.tsig_key_ring.size(), 0)

        # Load the ACL
        self.unix.update_config_data({'query_acl': [{'from': '127.0.0.1',
                                               'action': 'ACCEPT'}]})
        self.check_loaded_ACL()
        # Pass a wrong data there and check it does not replace the old one
        self.assertRaises(isc.acl.acl.LoaderError,
                          self.unix.update_config_data,
                          {'query_acl': ['Something bad']})
        self.check_loaded_ACL()

Likun Zhang's avatar
Likun Zhang committed
630
631
632
633
634
635
636
637
638
639
640
641
642
    def test_get_db_file(self):
        """get_db_file() must return the database file name handed out by
           the fake config session."""
        db_file = self.unix.get_db_file()
        self.assertEqual(db_file, "initdb.file")

    def test_increase_transfers_counter(self):
        """increase_transfers_counter() returns True and bumps the counter
           with a limit of 10; with a limit of 0 it returns False and
           leaves the counter unchanged."""
        # (limit, expected return value, expected change of the counter)
        scenarios = ((10, True, 1), (0, False, 0))
        for limit, accepted, growth in scenarios:
            self.unix._max_transfers_out = limit
            before = self.unix._transfers_counter
            self.assertEqual(self.unix.increase_transfers_counter(), accepted)
            self.assertEqual(before + growth, self.unix._transfers_counter)
643

Likun Zhang's avatar
Likun Zhang committed
644
645
646
647
648
    def test_decrease_transfers_counter(self):
        """decrease_transfers_counter() must lower the counter by one."""
        expected = self.unix._transfers_counter - 1
        self.unix.decrease_transfers_counter()
        self.assertEqual(expected, self.unix._transfers_counter)

649
650
651
652
653
    def _remove_file(self, sock_file):
        try:
            os.remove(sock_file)
        except OSError:
            pass
654

655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
    def test_sock_file_in_use_file_exist(self):
        """After removing the socket file, _sock_file_in_use() must report
           False and must not create the file."""
        # NOTE(review): despite the method name, the file is removed before
        # the check, so this runs with the file absent - confirm intent.
        path = 'temp.sock.file'
        self._remove_file(path)
        in_use = self.unix._sock_file_in_use(path)
        self.assertFalse(in_use)
        self.assertFalse(os.path.exists(path))

    def test_sock_file_in_use_file_not_exist(self):
        """_sock_file_in_use() must report False for 'temp.sock.file'."""
        in_use = self.unix._sock_file_in_use('temp.sock.file')
        self.assertFalse(in_use)

    def _start_unix_sock_server(self, sock_file):
        """Start a ThreadingUnixStreamServer bound to sock_file in a
           daemon thread, and stop it again when the test finishes."""
        serv = ThreadingUnixStreamServer(sock_file, BaseRequestHandler)
        serv_thread = threading.Thread(target=serv.serve_forever)
        # The 'daemon' attribute replaces the deprecated setDaemon().
        serv_thread.daemon = True
        serv_thread.start()
        # Cleanups run in reverse order of registration: first stop the
        # serve_forever() loop, then close the listening socket.
        self.addCleanup(serv.server_close)
        self.addCleanup(serv.shutdown)

    def test_sock_file_in_use(self):
        """_sock_file_in_use() must report False before a server is
           started on the socket file and True once one is running."""
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)

        # stdout is redirected to os.devnull for the duration of the call.
        old_stdout = sys.stdout
        devnull = open(os.devnull, 'w')
        sys.stdout = devnull
        try:
            self.assertTrue(self.unix._sock_file_in_use(sock_file))
        finally:
            # Restore stdout and release the handle even if the assertion
            # fails, so later tests are not silenced.
            sys.stdout = old_stdout
            devnull.close()

    def test_remove_unused_sock_file_in_use(self):
        """_remove_unused_sock_file() must raise SystemExit when a server
           is running on the socket file."""
        sock_file = 'temp.sock.file'
        self._remove_file(sock_file)
        self.assertFalse(self.unix._sock_file_in_use(sock_file))
        self._start_unix_sock_server(sock_file)

        # stdout is redirected to os.devnull for the duration of the call.
        old_stdout = sys.stdout
        devnull = open(os.devnull, 'w')
        sys.stdout = devnull
        try:
            self.assertRaises(SystemExit,
                              self.unix._remove_unused_sock_file, sock_file)
        finally:
            # Restore stdout and release the handle even if the assertion
            # fails, so later tests are not silenced.
            sys.stdout = old_stdout
            devnull.close()

    def test_remove_unused_sock_file_dir(self):
        """_remove_unused_sock_file() must raise SystemExit when the given
           path is a directory."""
        import tempfile
        dir_name = tempfile.mkdtemp()

        # stdout is redirected to os.devnull for the duration of the call.
        old_stdout = sys.stdout
        devnull = open(os.devnull, 'w')
        sys.stdout = devnull
        try:
            self.assertRaises(SystemExit,
                              self.unix._remove_unused_sock_file, dir_name)
        finally:
            # Undo everything even if the assertion fails: restore stdout,
            # release the handle and remove the temporary directory.
            sys.stdout = old_stdout
            devnull.close()
            os.rmdir(dir_name)
Likun Zhang's avatar
Likun Zhang committed
713

714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
class TestInitialization(unittest.TestCase):
    '''Tests for xfrout.init_paths() under different environment settings.

    The two environment variables touched by the tests are saved in setUp()
    and put back in tearDown().
    '''

    def setEnv(self, name, value):
        '''Set environment variable name to value; a value of None removes
        the variable if it is currently set.'''
        if value is not None:
            os.environ[name] = value
        elif name in os.environ:
            del os.environ[name]

    def setUp(self):
        # Remember the current values (None when unset) for tearDown().
        self._oldFromBuild = os.getenv("B10_FROM_BUILD")
        self._oldSocket = os.getenv("BIND10_XFROUT_SOCKET_FILE")

    def tearDown(self):
        saved = (("B10_FROM_BUILD", self._oldFromBuild),
                 ("BIND10_XFROUT_SOCKET_FILE", self._oldSocket))
        for name, value in saved:
            self.setEnv(name, value)
        # Make sure even the computed values are back
        xfrout.init_paths()

    def testNoEnv(self):
        '''With both variables unset, the socket file is the default one.'''
        for name in ("B10_FROM_BUILD", "BIND10_XFROUT_SOCKET_FILE"):
            self.setEnv(name, None)
        xfrout.init_paths()
        expected = "@@LOCALSTATEDIR@@/auth_xfrout_conn"
        self.assertEqual(xfrout.UNIX_SOCKET_FILE, expected)

    def testProvidedSocket(self):
        '''BIND10_XFROUT_SOCKET_FILE, when set, is used as the socket file.'''
        self.setEnv("B10_FROM_BUILD", None)
        self.setEnv("BIND10_XFROUT_SOCKET_FILE", "The/Socket/File")
        xfrout.init_paths()
        self.assertEqual(xfrout.UNIX_SOCKET_FILE, "The/Socket/File")

Likun Zhang's avatar
Likun Zhang committed
745
746
# Run every test case in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()