xfrout.py.in 37.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#!@PYTHON@

# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


import sys; sys.path.append ('@@PYTHONPATH@@')
import isc
import isc.cc
import threading
import struct
import signal
25
from isc.datasrc import DataSourceClient, ZoneFinder, ZoneJournalReader
26
27
28
from socketserver import *
import os
from isc.config.ccsession import *
29
from isc.cc import SessionError, SessionTimeout
30
from isc.notify import notify_out
31
import isc.util.process
32
import socket
33
import select
34
import errno
35
from optparse import OptionParser, OptionValueError
Likun Zhang's avatar
Likun Zhang committed
36
from isc.util import socketserver_mixin
37

38
from isc.log_messages.xfrout_messages import *
39
40
41
42

isc.log.init("b10-xfrout")
logger = isc.log.Logger("xfrout")

try:
    from libutil_io_python import *
    from pydnspp import *
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrout process
    # must keep running, so we warn about it and move forward.
    # Report through the module-level logger defined above.
    logger.error(XFROUT_IMPORT, str(e))

from isc.acl.acl import ACCEPT, REJECT, DROP, LoaderError
from isc.acl.dns import REQUEST_LOADER

isc.util.process.rename()


class XfroutConfigError(Exception):
    """Raised when xfrout cannot apply a configuration update.

    The module-CC layer does not catch every syntax problem, so xfrout
    validates the configuration data it receives on its own.  Problems it
    detects are reported with this single exception type, either raised
    directly or converted from an exception thrown by another module
    (for example a failure to load the transfer ACL).
    """


def init_paths():
    """Assign the module-level path globals used by xfrout.

    Sets SPECFILE_PATH (directory containing xfrout.spec),
    AUTH_SPECFILE_PATH (directory containing auth.spec) and
    UNIX_SOCKET_FILE (socket used to receive requests from auth).

    With B10_FROM_BUILD set, all paths point into the build tree and the
    socket directory can be overridden by B10_FROM_SOURCE_LOCALSTATEDIR.
    Otherwise the install-time locations substituted by the build system
    are used, and BIND10_XFROUT_SOCKET_FILE may override the socket path.
    """
    global SPECFILE_PATH, AUTH_SPECFILE_PATH, UNIX_SOCKET_FILE

    build_dir = os.environ.get("B10_FROM_BUILD")
    if build_dir is not None:
        SPECFILE_PATH = build_dir + "/src/bin/xfrout"
        AUTH_SPECFILE_PATH = build_dir + "/src/bin/auth"
        state_dir = os.environ.get("B10_FROM_SOURCE_LOCALSTATEDIR")
        socket_dir = build_dir if state_dir is None else state_dir
        UNIX_SOCKET_FILE = socket_dir + "/auth_xfrout_conn"
        return

    # Installed layout: the @...@ tokens are replaced at build time, and
    # datadir may still reference ${datarootdir}/${prefix}.
    prefix = "@prefix@"
    datarootdir = "@datarootdir@"
    SPECFILE_PATH = "@datadir@/@PACKAGE@".replace(
        "${datarootdir}", datarootdir).replace("${prefix}", prefix)
    AUTH_SPECFILE_PATH = SPECFILE_PATH
    UNIX_SOCKET_FILE = os.environ.get(
        "BIND10_XFROUT_SOCKET_FILE",
        "@@LOCALSTATEDIR@@/@PACKAGE_NAME@/auth_xfrout_conn")


init_paths()

# Full paths of the module specification files for xfrout and auth.
SPECFILE_LOCATION = "%s/xfrout.spec" % SPECFILE_PATH
AUTH_SPECFILE_LOCATION = os.sep.join((AUTH_SPECFILE_PATH, "auth.spec"))

VERBOSE_MODE = False

# Upper bound on the size of one rendered response message; it is also
# the largest value the 2-byte length prefix sent before each message
# can carry.
XFROUT_MAX_MESSAGE_SIZE = 0xFFFF


# borrowed from xfrin.py @ #1298.  We should eventually unify it.
def format_zone_str(zone_name, zone_class):
    """Return a '<name>/<class>' string identifying a zone in logs.

       Parameters:
       zone_name (isc.dns.Name) name to format (its to_text() form is used)
       zone_class (isc.dns.RRClass) class to format (its str() form is used)
    """
    return "/".join((zone_name.to_text(), str(zone_class)))


# borrowed from xfrin.py @ #1298.
def format_addrinfo(addrinfo):
    """Return a printable '<addr>:<port>' form of an addrinfo tuple.

       IPv6 addresses are bracketed ('[<addr>]:<port>').  For UNIX domain
       sockets and any other family, the third element is simply
       converted with str().
       Parameters:
       addrinfo: a 3-tuple (family, socket type, location) where location
                 is an (address, port) pair for IP families or a filename
                 otherwise.
       Raises TypeError if addrinfo is too short to contain these fields.
    """
    try:
        family = addrinfo[0]
        location = addrinfo[2]
        if family == socket.AF_INET:
            return "%s:%s" % (str(location[0]), str(location[1]))
        if family == socket.AF_INET6:
            return "[%s]:%s" % (str(location[0]), str(location[1]))
        return str(location)
    except IndexError:
        raise TypeError("addrinfo argument to format_addrinfo() does not "
                        "appear to be consisting of (family, socktype, (addr, port))")


def get_rrset_len(rrset):
    """Return the number of bytes the RRset occupies in wire format.

    The RRset is rendered into a scratch buffer with its to_wire() method
    and the size of that buffer is reported.
    """
    wire = bytearray()
    rrset.to_wire(wire)
    return len(wire)


def get_soa_serial(soa_rdata):
    '''Return the SERIAL field of an SOA RDATA as an integer.

    The serial is the third whitespace-separated field of the RDATA's
    text form (after MNAME and RNAME).  (borrowed from xfrin)
    '''
    fields = soa_rdata.to_text().split()
    return int(fields[2])


class XfroutSession():
    '''Serve a single xfr (AXFR or IXFR) request forwarded by b10-auth.

    Constructing an instance immediately handles the request via
    _handle(): the query is parsed and checked (TSIG, question section,
    transfer ACL, transfer quota), a zone iterator or journal reader is
    set up from the data source, the response message(s) are written to
    the client socket, and finally the socket is closed.
    '''

    # Size of the fixed-length DNS message header (RFC 1035, 4.1.1).
    # get_rrset_len() measures RRsets only, so the header is accounted
    # for separately when estimating the size of a response message.
    _DNS_HEADER_SIZE = 12

    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
                 default_acl, zone_config, client_class=DataSourceClient):
        '''Store the request context and handle the request.

        Parameters:
          sock_fd: file descriptor of the client socket; _handle() closes
                   it once the request has been processed.
          request_data: the wire-format query received from auth.
          server: the server object; it provides the transfer counter,
                  the database file name and the shutdown event.
          tsig_key_ring: key ring used to verify TSIG-signed requests.
          remote: (family, socktype, sockaddr) tuple describing the client,
                  in the form accepted by format_addrinfo().
          default_acl: ACL applied when the zone has no transfer_acl of
                       its own.
          zone_config: per-zone configuration dict keyed by
                       (class text, lower-cased zone name text).
          client_class: data source client class (parameterized for tests).
        '''
        self._sock_fd = sock_fd
        self._request_data = request_data
        self._server = server
        self._tsig_key_ring = tsig_key_ring
        self._tsig_ctx = None
        self._tsig_len = 0
        self._remote = remote
        self._request_type = None
        self._request_typestr = None
        self._acl = default_acl
        self._zone_config = zone_config
        self.ClientClass = client_class # parameterize this for testing
        self._soa = None # will be set in _xfrout_setup or in tests
        # Set up by _xfrout_setup(): a zone iterator (AXFR and AXFR-style
        # IXFR) and/or a journal reader (IXFR).
        self._iterator = None
        self._jnl_reader = None
        self._handle()

    def create_tsig_ctx(self, tsig_record, tsig_key_ring):
        '''Create a TSIGContext for the key named in tsig_record, using the
        algorithm given in its RDATA and keys from tsig_key_ring.'''
        return TSIGContext(tsig_record.get_name(),
                           tsig_record.get_rdata().get_algorithm(),
                           tsig_key_ring)

    def _handle(self):
        ''' Handle a xfrout query, send xfrout response(s).

        This is separated from the constructor so that we can override
        it from tests.

        '''
        # Check the xfrout quota.  We do both increase/decrease in this
        # method so it's clear we always release it once acquired.
        quota_ok = self._server.increase_transfers_counter()
        ex = None
        try:
            self.dns_xfrout_start(self._sock_fd, self._request_data, quota_ok)
        except Exception as e:
            # To avoid resource leak we need catch all possible exceptions
            # We log it later to exclude the case where even logger raises
            # an exception.
            ex = e

        # Release any critical resources
        if quota_ok:
            self._server.decrease_transfers_counter()
        self._close_socket()

        if ex is not None:
            logger.error(XFROUT_HANDLE_QUERY_ERROR, ex)

    def _close_socket(self):
        '''Simply close the socket via the given FD.

        This is a dedicated subroutine of handle() and is separated from it
        for the convenience of tests.

        '''
        os.close(self._sock_fd)

    def _check_request_tsig(self, msg, request_data):
        '''If the request has a TSIG record, verify it.

        Records the TSIG record length (so that space can be reserved when
        signing responses) and creates self._tsig_ctx.  Returns
        Rcode.NOTAUTH() if verification fails and Rcode.NOERROR()
        otherwise, including when the request is not signed.
        '''
        tsig_record = msg.get_tsig_record()
        if tsig_record is not None:
            self._tsig_len = tsig_record.get_length()
            self._tsig_ctx = self.create_tsig_ctx(tsig_record, self._tsig_key_ring)
            tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
            if tsig_error != TSIGError.NOERROR:
                return Rcode.NOTAUTH()

        return Rcode.NOERROR()

    def _parse_query_message(self, mdata):
        '''Parse and validate the wire-format query 'mdata'.

        Returns a pair (rcode, msg):
          - (FORMERR, None) if the data cannot be parsed;
          - (NOTAUTH, msg) if TSIG verification fails;
          - (FORMERR, msg) if there is not exactly one question;
          - (None, None) if the transfer ACL says DROP;
          - (REFUSED, msg) if the transfer ACL says REJECT;
          - (NOERROR, msg) otherwise.
        Sets self._request_type and self._request_typestr as a side
        effect, and raises RuntimeError if the question type is neither
        AXFR nor IXFR.
        '''
        #TODO, need to add parseHeader() in case the message header is invalid
        try:
            msg = Message(Message.PARSE)
            Message.from_wire(msg, mdata)
        except Exception as err: # Exception is too broad
            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
            return Rcode.FORMERR(), None

        # TSIG related checks
        rcode = self._check_request_tsig(msg, mdata)
        if rcode != Rcode.NOERROR():
            return rcode, msg

        # Make sure the question is valid.  This should be ensured by
        # the auth server, but since it's far from our xfrout itself,
        # we check it by ourselves.
        if msg.get_rr_count(Message.SECTION_QUESTION) != 1:
            return Rcode.FORMERR(), msg
        question = msg.get_question()[0]

        # Identify the request type
        self._request_type = question.get_type()
        if self._request_type == RRType.AXFR():
            self._request_typestr = 'AXFR'
        elif self._request_type == RRType.IXFR():
            self._request_typestr = 'IXFR'
        else:
            # Likewise, this should be impossible.  (TBD: to be tested)
            raise RuntimeError('Unexpected XFR type: ' + \
                                   str(self._request_type))

        # ACL checks
        zone_name = question.get_name()
        zone_class = question.get_class()
        acl = self._get_transfer_acl(zone_name, zone_class)
        acl_result = acl.execute(
            isc.acl.dns.RequestContext(self._remote[2], msg.get_tsig_record()))
        if acl_result == DROP:
            logger.info(XFROUT_QUERY_DROPPED, self._request_typestr,
                        format_addrinfo(self._remote),
                        format_zone_str(zone_name, zone_class))
            return None, None
        elif acl_result == REJECT:
            logger.info(XFROUT_QUERY_REJECTED, self._request_typestr,
                        format_addrinfo(self._remote),
                        format_zone_str(zone_name, zone_class))
            return Rcode.REFUSED(), msg

        return rcode, msg

    def _get_transfer_acl(self, zone_name, zone_class):
        '''Return the ACL that should be applied for a given zone.

        The zone is identified by a tuple of name and RR class.
        If a per zone configuration for the zone exists and contains
        transfer_acl, that ACL will be used; otherwise, the default
        ACL will be used.

        '''
        # Internally zone names are managed in lower cased label characters,
        # so we first need to convert the name.
        zone_name_lower = Name(zone_name.to_text(), True)
        config_key = (zone_class.to_text(), zone_name_lower.to_text())
        if config_key in self._zone_config and \
                'transfer_acl' in self._zone_config[config_key]:
            return self._zone_config[config_key]['transfer_acl']
        return self._acl

    def _send_data(self, sock_fd, data):
        '''Write all of 'data' to sock_fd, retrying after partial writes.'''
        size = len(data)
        total_count = 0
        while total_count < size:
            count = os.write(sock_fd, data[total_count:])
            total_count += count

    def _send_message(self, sock_fd, msg, tsig_ctx=None):
        '''Render msg (signed with tsig_ctx if given) and send it to
        sock_fd, preceded by its 2-byte length.'''
        render = MessageRenderer()
        # As defined in RFC5936 section3.4, perform case-preserving name
        # compression for AXFR message.
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)

        # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
        # we should remove the if statement and use a universal interface later.
        if tsig_ctx is not None:
            msg.to_wire(render, tsig_ctx)
        else:
            msg.to_wire(render)

        # '!H': unsigned 16-bit length in network byte order (the same
        # bytes as htons() followed by a native-order 'H').
        header_len = struct.pack('!H', render.get_length())
        self._send_data(sock_fd, header_len)
        self._send_data(sock_fd, render.get_data())

    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
        '''Turn msg into a response carrying rcode_ and send it.

        Nothing is sent if msg is None (the query could not be parsed or
        was dropped).
        '''
        if not msg:
            return # query message is invalid. send nothing back.

        msg.make_response()
        msg.set_rcode(rcode_)
        self._send_message(sock_fd, msg, self._tsig_ctx)

    def _get_zone_soa(self, zone_name):
        '''Retrieve the SOA RR of the given zone.

        It returns a pair of RCODE and the SOA (in the form of RRset).
        On success RCODE is NOERROR and returned SOA is not None;
        on failure RCODE indicates the appropriate code in the context of
        xfr processing, and the returned SOA is None.

        '''
        result, finder = self._datasrc_client.find_zone(zone_name)
        if result != DataSourceClient.SUCCESS:
            return (Rcode.NOTAUTH(), None)
        result, soa_rrset = finder.find(zone_name, RRType.SOA(), None,
                                        ZoneFinder.FIND_DEFAULT)
        if result != ZoneFinder.SUCCESS:
            return (Rcode.SERVFAIL(), None)
        # Especially for database-based zones, a working zone may be in
        # a broken state where it has more than one SOA RR.  We proactively
        # check the condition and abort the xfr attempt if we identify it.
        if soa_rrset.get_rdata_count() != 1:
            return (Rcode.SERVFAIL(), None)
        return (Rcode.NOERROR(), soa_rrset)

    def __setup_axfr(self, zone_name):
        '''Setup a zone iterator for AXFR or AXFR-style IXFR.

        Sets self._iterator and self._soa.  Returns Rcode.NOERROR() on
        success, Rcode.NOTAUTH() if no iterator can be created for the
        zone, and Rcode.SERVFAIL() if the zone does not have exactly one
        SOA RR.
        '''
        try:
            # Note that we disable 'adjust_ttl'.  In xfr-out we need to
            # preserve as many things as possible (even if it's half
            # broken) stored in the zone.
            self._iterator = self._datasrc_client.get_iterator(zone_name,
                                                               False)
        except isc.datasrc.Error:
            # If the current name server does not have authority for the
            # zone, xfrout can't serve for it, return rcode NOTAUTH.
            # Note: this exception can happen for other reasons.  We should
            # update get_iterator() API so that we can distinguish "no such
            # zone" and other cases (#1373).  For now we consider all these
            # cases as NOTAUTH.
            return Rcode.NOTAUTH()

        # If we are an authoritative name server for the zone, but fail
        # to find the zone's SOA record in datasource, xfrout can't
        # provide zone transfer for it.
        self._soa = self._iterator.get_soa()
        if self._soa is None or self._soa.get_rdata_count() != 1:
            return Rcode.SERVFAIL()

        return Rcode.NOERROR()

    def __setup_ixfr(self, request_msg, zone_name):
        '''Setup a zone journal reader for IXFR.

        If the underlying data source does not know the requested range
        of zone differences it automatically falls back to AXFR-style
        IXFR by setting up a zone iterator instead of a journal reader.

        Returns Rcode.FORMERR() if the request's authority section does
        not carry exactly one SOA RR, since the client's serial can't be
        determined then.
        '''
        # TODO: more error case handling
        remote_soa = None
        for auth_rrset in request_msg.get_section(Message.SECTION_AUTHORITY):
            if auth_rrset.get_type() != RRType.SOA():
                continue
            remote_soa = auth_rrset
        # An IXFR query carries the client's current SOA in its authority
        # section (RFC 1995).  If it is missing or malformed we can't tell
        # which version the client has: that's a malformed request, not a
        # server failure.
        if remote_soa is None or remote_soa.get_rdata_count() != 1:
            return Rcode.FORMERR()

        rcode, self._soa = self._get_zone_soa(zone_name)
        if rcode != Rcode.NOERROR():
            return rcode

        try:
            code, self._jnl_reader = self._datasrc_client.get_journal_reader(
                zone_name, get_soa_serial(remote_soa.get_rdata()[0]),
                get_soa_serial(self._soa.get_rdata()[0]))
        except isc.datasrc.NotImplemented:
            # The underlying data source doesn't support journaling.
            # Fallback to AXFR-style IXFR.
            # TBD: log it.
            return self.__setup_axfr(zone_name)
        if code == ZoneJournalReader.NO_SUCH_VERSION:
            # fallback to AXFR-style IXFR
            # TBD: log it.
            return self.__setup_axfr(zone_name)
        if code == ZoneJournalReader.NO_SUCH_ZONE:
            # this is quite unexpected as we know zone's SOA exists.
            # It might be a bug or the data source is somehow broken,
            # but it can still happen if someone has removed the zone
            # between these two operations.  We treat it as NOTAUTH.
            return Rcode.NOTAUTH()

        # NOTE(review): on this path only self._jnl_reader is set, while
        # _reply_xfrout_query() builds responses from self._iterator.
        return Rcode.NOERROR()

    def _xfrout_setup(self, request_msg, zone_name):
        '''Setup a context for xfr responses according to the request type.

        This method identifies the most appropriate data source for the
        request and set up a zone iterator or journal reader depending on
        whether the request is AXFR or IXFR.  If it identifies any protocol
        level error it returns an RCODE other than NOERROR.

        '''

        # Identify the data source for the requested zone and see if it has
        # SOA while initializing objects used for request processing later.
        # We should eventually generalize this so that we can choose the
        # appropriate data source from (possible) multiple candidates.
        # We should eventually take into account the RR class here.
        # For now, we hardcode a particular type (SQLite3-based), and only
        # consider that one.
        datasrc_config = '{ "database_file": "' + \
            self._server.get_db_file() + '"}'
        self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)

        if self._request_type == RRType.AXFR():
            return self.__setup_axfr(zone_name)
        else:
            return self.__setup_ixfr(request_msg, zone_name)

    def dns_xfrout_start(self, sock_fd, msg_query, quota_ok=True):
        '''Process one xfr query and send the response(s) to sock_fd.

        msg_query is the wire-format query.  When quota_ok is False the
        transfer limit has been reached and the query is refused.
        '''
        rcode_, msg = self._parse_query_message(msg_query)
        #TODO. create query message and parse header
        if rcode_ is None: # Dropped by ACL
            return
        elif rcode_ == Rcode.NOTAUTH() or rcode_ == Rcode.REFUSED():
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
        elif rcode_ != Rcode.NOERROR():
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.FORMERR())
        elif not quota_ok:
            logger.warn(XFROUT_QUERY_QUOTA_EXCCEEDED, self._request_typestr,
                        format_addrinfo(self._remote),
                        self._server._max_transfers_out)
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.REFUSED())

        question = msg.get_question()[0]
        zone_name = question.get_name()
        zone_class = question.get_class()
        zone_str = format_zone_str(zone_name, zone_class) # for logging

        # TODO: we should also include class in the check
        try:
            rcode_ = self._xfrout_setup(msg, zone_name)
        except Exception as ex:
            logger.error(XFROUT_XFR_TRANSFER_CHECK_ERROR, self._request_typestr,
                         format_addrinfo(self._remote), zone_str, ex)
            rcode_ = Rcode.SERVFAIL()
        if rcode_ != Rcode.NOERROR():
            logger.info(XFROUT_AXFR_TRANSFER_FAILED, self._request_typestr,
                        format_addrinfo(self._remote), zone_str, rcode_)
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)

        try:
            logger.info(XFROUT_AXFR_TRANSFER_STARTED, self._request_typestr,
                        format_addrinfo(self._remote), zone_str)
            self._reply_xfrout_query(msg, sock_fd)
        except Exception as err:
            logger.error(XFROUT_AXFR_TRANSFER_ERROR, self._request_typestr,
                    format_addrinfo(self._remote), zone_str, err)
        else:
            # Only report completion when the transfer didn't fail.
            logger.info(XFROUT_AXFR_TRANSFER_DONE, self._request_typestr,
                        format_addrinfo(self._remote), zone_str)

    def _clear_message(self, msg):
        '''Reset msg so it can be reused for the next response message.

        Everything is cleared except the QID, opcode and RCODE, which are
        restored, and the AA and QR flags are set.  Returns msg.
        '''
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()

        msg.clear(Message.RENDER)
        msg.set_qid(qid)
        msg.set_opcode(opcode)
        msg.set_rcode(rcode)
        msg.set_header_flag(Message.HEADERFLAG_AA)
        msg.set_header_flag(Message.HEADERFLAG_QR)
        return msg

    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa,
                                    message_upper_len):
        '''Add the closing SOA record to msg and send it.

        If the SOA might not fit (judged from message_upper_len), msg is
        sent as it is and the SOA goes out in a new message of its own.
        '''
        # message_upper_len already includes self._tsig_len, so adding it
        # again here only makes the size check more conservative.
        if (message_upper_len + self._tsig_len + get_rrset_len(rrset_soa) >=
            XFROUT_MAX_MESSAGE_SIZE):
            self._send_message(sock_fd, msg, self._tsig_ctx)
            msg = self._clear_message(msg)

        # If tsig context exist, sign the last packet
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self._send_message(sock_fd, msg, self._tsig_ctx)

    def _reply_xfrout_query(self, msg, sock_fd):
        '''Send the response built from self._iterator to sock_fd.

        The response begins and ends with the zone's SOA (self._soa); the
        other RRsets from the iterator are packed into as few messages as
        fit within XFROUT_MAX_MESSAGE_SIZE, and SOA RRsets returned by the
        iterator are skipped.  Returns early, without sending the closing
        SOA, if the server is shutting down.
        '''
        if self._iterator is None:
            # Only a journal reader was set up (IXFR); generating a response
            # from it is not implemented here.  Fail with a clear message
            # instead of an obscure error when iterating None.
            raise RuntimeError('No zone iterator is set up for the response')

        #TODO, there should be a better way to insert rrset.
        msg.make_response()
        msg.set_header_flag(Message.HEADERFLAG_AA)
        msg.add_rrset(Message.SECTION_ANSWER, self._soa)

        # Size estimate of the first message: the fixed header, the
        # question (qname plus 2-byte type and 2-byte class), the leading
        # SOA, and room for the TSIG record.  Without the header and
        # question the estimate could undercount and the message could
        # exceed the renderer's length limit.
        question_len = msg.get_question()[0].get_name().get_length() + 4
        message_upper_len = self._DNS_HEADER_SIZE + question_len + \
            get_rrset_len(self._soa) + self._tsig_len

        for rrset in self._iterator:
            # Check if xfrout is shutdown
            if self._server._shutdown_event.is_set():
                logger.info(XFROUT_STOPPING)
                return

            if rrset.get_type() == RRType.SOA():
                continue

            # We calculate the maximum size of the RRset (i.e. the
            # size without compression) and use that to see if we
            # may have reached the limit
            rrset_len = get_rrset_len(rrset)
            if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
                msg.add_rrset(Message.SECTION_ANSWER, rrset)
                message_upper_len += rrset_len
                continue

            self._send_message(sock_fd, msg, self._tsig_ctx)

            msg = self._clear_message(msg)
            # Add the RRset to the new message
            msg.add_rrset(Message.SECTION_ANSWER, rrset)

            # _clear_message() doesn't re-add the question, but the header
            # is still there; also reserve tsig space for signed packet.
            message_upper_len = self._DNS_HEADER_SIZE + rrset_len + \
                self._tsig_len

        self._send_message_with_last_soa(msg, sock_fd, self._soa,
                                         message_upper_len)


class UnixSockServer(socketserver_mixin.NoPollMixIn,
                     ThreadingUnixStreamServer):
544
545
    '''The unix domain socket server which accept xfr query sent from auth server.'''

546
547
    def __init__(self, sock_file, handle_class, shutdown_event, config_data,
                 cc):
        '''Create the server listening on the UNIX socket 'sock_file'.

        A stale socket file is removed first; the process exits if another
        xfrout is still using it.  handle_class is the per-request handler
        class, shutdown_event is the event the handler threads check,
        config_data is the initial module configuration (applied through
        update_config_data()), and cc is stored as self._cc.
        '''
        self._remove_unused_sock_file(sock_file)
        self._sock_file = sock_file
        socketserver_mixin.NoPollMixIn.__init__(self)
        ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
        self._shutdown_event = shutdown_event
        # shutdown() writes to the first end so that threads selecting on
        # the second end wake up.
        socket_pair = socket.socketpair()
        self._write_sock = socket_pair[0]
        self._read_sock = socket_pair[1]
        self._common_init()
        self._cc = cc
        self.update_config_data(config_data)

    def _common_init(self):
        '''Set up state shared with the mock server class used for tests.

        Creates the lock guarding the transfer counter and configuration,
        resets the counter and the per-zone configuration, and leaves the
        ACL unset until update_config_data() provides one.
        '''
        self._lock = threading.Lock()
        self._transfers_counter = 0
        self._zone_config = dict()
        self._acl = None

    def _receive_query_message(self, sock):
        '''Receive one length-prefixed request message from sock.

        The message is preceded by a 2-byte length in network byte order.
        Returns the message body as bytes, or None if the peer closes the
        connection before the complete length field or body has arrived.
        '''
        def recv_exactly(size):
            # sock.recv() may return fewer bytes than requested, so keep
            # reading until 'size' bytes arrive or the peer closes.
            chunks = []
            remaining = size
            while remaining > 0:
                data = sock.recv(remaining)
                if not data:
                    return None
                chunks.append(data)
                remaining -= len(data)
            return b''.join(chunks)

        # receive data length
        data_len = recv_exactly(2)
        if data_len is None:
            return None
        msg_len = struct.unpack('!H', data_len)[0]
        # receive data
        return recv_exactly(msg_len)

    def handle_request(self):
        '''Serve one auth connection until shutdown or until it is closed.

        Requests arriving on the connection are dispatched with
        process_request(); the loop stops when the shutdown event is set,
        when select() fails, or when process_request() reports that the
        connection is no longer usable (or raises).
        '''
        try:
            request, client_address = self.get_request()
        except socket.error:
            logger.error(XFROUT_FETCH_REQUEST_ERROR)
            return

        # Check self._shutdown_event to ensure the real shutdown comes.
        # Linux could trigger a spurious readable event on the _read_sock
        # due to a bug, so we need perform a double check.
        while not self._shutdown_event.is_set(): # Check if xfrout is shutdown
            try:
                (rlist, wlist, xlist) = select.select([self._read_sock, request], [], [])
            except select.error as e:
                if e.args[0] == errno.EINTR:
                    continue
                logger.error(XFROUT_SOCKET_SELECT_ERROR, str(e))
                break

            # self._shutdown_event will be set by now, if it is not a false
            # alarm
            if self._read_sock in rlist:
                continue

            try:
                if not self.process_request(request):
                    # The connection can't be used any more (e.g. auth
                    # closed it).  Stop here instead of spinning on a
                    # socket that select() keeps reporting as readable.
                    break
            except Exception as pre:
                logger.error(XFROUT_PROCESS_REQUEST_ERROR, str(pre))
                break

    def _handle_request_noblock(self):
        """Override the function _handle_request_noblock(), it creates a new
        thread to handle requests for each auth"""
        td = threading.Thread(target=self.handle_request)
        td.daemon = True
        td.start()

    def process_request(self, request):
        """Receive socket fd and query message from auth, then
        start a new thread to process the request.

        Returns True if the connection with auth can still be used, and
        False if it should be abandoned (no file descriptor or no complete
        message could be received).
        """
        sock_fd = recv_fd(request.fileno())
        if sock_fd < 0:
            # This may happen when one xfrout process try to connect to
            # xfrout unix socket server, to check whether there is another
            # xfrout running.  Presumably it also happens when auth has
            # closed the connection.
            if sock_fd == FD_COMM_ERROR:
                logger.error(XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR)
            return False

        # receive request msg
        request_data = self._receive_query_message(request)
        if request_data is None:
            # Incomplete message: the received fd would otherwise leak.
            os.close(sock_fd)
            return False
        if not request_data:
            # Empty message: nothing to serve, but don't leak the fd.
            os.close(sock_fd)
            return True

        t = threading.Thread(target=self.finish_request,
                             args=(sock_fd, request_data))
        if self.daemon_threads:
            t.daemon = True
        t.start()
        return True

    def _guess_remote(self, sock_fd):
        """Guess remote address and port of the socket.

        The sock_fd must be a file descriptor of a socket.
        This method returns a 3-tuple consisting of address family,
        socket type, and the peer address as returned by getpeername()
        (address string first, then port).

        """
        # This uses a trick. If the socket is IPv4 in reality and we pretend
        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
        # to care about the SOCK_STREAM parameter at all (which it really is,
        # except for testing)
        if socket.has_ipv6:
            sock = socket.fromfd(sock_fd, socket.AF_INET6, socket.SOCK_STREAM)
        else:
            # To make it work even on hosts without IPv6 support
            # (Any idea how to simulate this in test?)
            sock = socket.fromfd(sock_fd, socket.AF_INET, socket.SOCK_STREAM)
        try:
            peer = sock.getpeername()
        finally:
            # socket.fromfd() works on a duplicate of sock_fd; close the
            # duplicate so that every request doesn't leak a descriptor.
            # sock_fd itself stays open for the xfr session.
            sock.close()

        # Identify the correct socket family.  Due to the above "trick",
        # we cannot simply use sock.family.
        family = socket.AF_INET6
        try:
            socket.inet_pton(socket.AF_INET6, peer[0])
        except socket.error:
            family = socket.AF_INET
        return (family, socket.SOCK_STREAM, peer)

    def finish_request(self, sock_fd, request_data):
        '''Process one request by instantiating RequestHandlerClass.

        This runs in the per-request thread started by process_request().
        The current ACL and per-zone configuration are read under the lock,
        and a handler (an XfroutSession) is created with them; the handler
        processes the request while it is constructed.
        '''
        with self._lock:
            acl, zone_config = self._acl, self._zone_config
        key_ring = self.tsig_key_ring
        remote = self._guess_remote(sock_fd)
        self.RequestHandlerClass(sock_fd, request_data, self, key_ring,
                                 remote, acl, zone_config)

    def _remove_unused_sock_file(self, sock_file):
        '''Make sure no stale file blocks binding to sock_file.

        If another running xfrout process is listening on the socket file,
        the error is logged and the process exits.  Otherwise an existing
        file is removed; if that removal fails, the error is logged and the
        process exits as well.
        '''
        if self._sock_file_in_use(sock_file):
            logger.error(XFROUT_UNIX_SOCKET_FILE_IN_USE, sock_file)
            sys.exit(0)
        elif os.path.exists(sock_file):
            try:
                os.unlink(sock_file)
            except OSError as err:
                logger.error(XFROUT_REMOVE_OLD_UNIX_SOCKET_FILE_ERROR, sock_file, str(err))
                sys.exit(0)

    def _sock_file_in_use(self, sock_file):
711
712
        '''Check whether the socket file 'sock_file' exists and
        is being used by one running xfrout process. If it is,
713
714
715
716
717
718
719
        return True, or else return False. '''
        try:
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(sock_file)
        except socket.error as err:
            return False
        else:
720
            return True
721

722
    def shutdown(self):
        '''Stop the server and remove its UNIX domain socket file.'''
        # Terminate the xfrout session thread.
        self._write_sock.send(b"shutdown")
        # Call the shutdown() of class socketserver_mixin.NoPollMixIn.
        super().shutdown()
        try:
            os.unlink(self._sock_file)
        except Exception as e:
            # Best effort: failing to remove the file is logged, not fatal.
            logger.error(XFROUT_REMOVE_UNIX_SOCKET_FILE_ERROR,
                         self._sock_file, str(e))

    def update_config_data(self, new_config):
        '''Apply the new config setting of xfrout module.

        The whole update runs while holding self._lock.  The new ACL and
        zone config are built first and only assigned once both succeeded.
        If building either one raises, the exception propagates and the old
        values stay in place (a LoaderError from the global transfer_acl is
        reported as XfroutConfigError).
        '''
        self._lock.acquire()
        try:
            logger.info(XFROUT_NEW_CONFIG)
            new_acl = self._acl
            if 'transfer_acl' in new_config:
                try:
                    new_acl = REQUEST_LOADER.load(new_config['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl: ' +
                                            str(e))

            new_zone_config = self._zone_config
            zconfig_data = new_config.get('zone_config')
            if zconfig_data is not None:
                new_zone_config = self.__create_zone_config(zconfig_data)

            # Everything parsed; commit the new settings.
            self._acl = new_acl
            self._zone_config = new_zone_config
            self._max_transfers_out = new_config.get('transfers_out')
            self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
        finally:
            # Released on success and on error alike; the exception (if any)
            # propagates with its original traceback.
            self._lock.release()
        logger.info(XFROUT_NEW_CONFIG_DONE)

    def __create_zone_config(self, zone_config_list):
        '''Build the per-zone configuration from 'zone_config_list'.

        Return a dict keyed by (class text, origin text) tuples.  Each value
        is a dict that holds a loaded 'transfer_acl' when the zone entry
        provides one.  Raise XfroutConfigError on a duplicate zone entry or
        on a per-zone transfer_acl that fails to load.
        '''
        new_config = {}
        for zconf in zone_config_list:
            # convert the class, origin (name) pair.  First build pydnspp
            # object to reject invalid input.
            zclass_str = zconf.get('class')
            if zclass_str is None:
                # No class given: use the module default for this item.
                zclass_str = self._cc.get_default_value('zone_config/class')
            zclass = RRClass(zclass_str)
            zorigin = Name(zconf['origin'], True)
            config_key = (zclass.to_text(), zorigin.to_text())

            # reject duplicate config
            if config_key in new_config:
                raise XfroutConfigError('Duplicate zone_config for ' +
                                        str(zorigin) + '/' + str(zclass))

            # create a new config entry, build any given (and known) config
            new_config[config_key] = {}
            if 'transfer_acl' in zconf:
                try:
                    new_config[config_key]['transfer_acl'] = \
                        REQUEST_LOADER.load(zconf['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl ' +
                                            'for ' + zorigin.to_text() + '/' +
                                            zclass_str + ': ' + str(e))
        return new_config

    def set_tsig_key_ring(self, key_list):
        """Set the tsig_key_ring, given a TSIG key string list representation.

        The ring is always replaced by a fresh TSIGKeyRing.  An empty (or
        None) 'key_list' leaves it empty.  Key strings that TSIGKey rejects
        with InvalidParameter are logged and skipped.
        """
        # XXX add values to configure zones/tsig options
        self.tsig_key_ring = TSIGKeyRing()
        if not key_list:
            return

        for key_item in key_list:
            try:
                self.tsig_key_ring.add(TSIGKey(key_item))
            except InvalidParameter:
                logger.error(XFROUT_BAD_TSIG_KEY_STRING, str(key_item))

    def get_db_file(self):
        '''Return the path of the zone database file.

        The path is the "database_file" setting of the Auth module.  When
        that value is the default and B10_FROM_BUILD is set in the
        environment, bind10_zones.sqlite3 under $B10_FROM_BUILD is used
        instead.
        '''
        db_file, is_default = self._cc.get_remote_config_value(
            "Auth", "database_file")
        # this too should be unnecessary, but currently the
        # 'from build' override isn't stored in the config
        # (and we don't have indirect python access to datasources yet)
        if is_default and "B10_FROM_BUILD" in os.environ:
            db_file = (os.environ["B10_FROM_BUILD"] + os.sep +
                       "bind10_zones.sqlite3")
        return db_file

    def increase_transfers_counter(self):
        '''Try to reserve a slot for one more outgoing transfer.

        Return False, if counter + 1 > max_transfers_out (the counter is
        left unchanged), or else increment the counter and return True.
        '''
        self._lock.acquire()
        try:
            if self._transfers_counter < self._max_transfers_out:
                self._transfers_counter += 1
                return True
            return False
        finally:
            # Release even if the comparison raises (e.g. max_transfers_out
            # is None when 'transfers_out' was absent from the config);
            # otherwise every other thread would block on self._lock.
            self._lock.release()

    def decrease_transfers_counter(self):
        '''Release one transfer slot reserved by increase_transfers_counter().'''
        self._lock.acquire()
        try:
            self._transfers_counter -= 1
        finally:
            # Never leave self._lock held, even if the update raises.
            self._lock.release()

class XfroutServer:
    '''The xfrout server process.

    Holds the command channel session (self._cc), the UnixSockServer
    that accepts xfr queries in a listener thread, and the NotifyOut
    sender.
    '''

    def __init__(self):
        self._unix_socket_server = None
        self._listen_sock_file = UNIX_SOCKET_FILE
        self._shutdown_event = threading.Event()
        self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                              self.config_handler,
                                              self.command_handler)
        self._config_data = self._cc.get_full_config()
        self._cc.start()
        self._cc.add_remote_config(AUTH_SPECFILE_LOCATION)
        # The notifier asks the socket server for the database file, so
        # the listener must be started first.
        self._start_xfr_query_listener()
        self._start_notifier()

    def _start_xfr_query_listener(self):
        '''Start a new thread to accept xfr query. '''
        self._unix_socket_server = UnixSockServer(self._listen_sock_file,
                                                  XfroutSession,
                                                  self._shutdown_event,
                                                  self._config_data,
                                                  self._cc)
        listener = threading.Thread(
            target=self._unix_socket_server.serve_forever)
        listener.start()

    def _start_notifier(self):
        '''Create the NotifyOut sender for the zone database and start
        its dispatcher.'''
        datasrc = self._unix_socket_server.get_db_file()
        self._notifier = notify_out.NotifyOut(datasrc)
        self._notifier.dispatcher()

    def send_notify(self, zone_name, zone_class):
        '''Ask the notifier to send NOTIFY for the given zone.'''
        self._notifier.send_notify(zone_name, zone_class)

    def config_handler(self, new_config):
        '''Update config data and pass it on to the socket server.

        Unknown keys produce an error answer, but the known keys are
        still applied.  TODO: do more error checking.
        '''
        answer = create_answer(0)
        for key in new_config:
            if key not in self._config_data:
                answer = create_answer(1, "Unknown config data: " + str(key))
                continue
            self._config_data[key] = new_config[key]

        if self._unix_socket_server:
            try:
                self._unix_socket_server.update_config_data(self._config_data)
            except Exception as e:
                answer = create_answer(1,
                                       "Failed to handle new configuration: " +
                                       str(e))

        return answer

    def shutdown(self):
        ''' shutdown the xfrout process. The thread which is doing zone
        transfer-out should be terminated.

        Stops the notifier and the socket server, then waits for every
        other thread to finish.
        '''
        global xfrout_server
        xfrout_server = None  # Avoid shutdown being called twice
        self._shutdown_event.set()
        self._notifier.shutdown()
        if self._unix_socket_server:
            self._unix_socket_server.shutdown()

        # Wait for all threads to terminate.  The calling thread is
        # skipped: a thread cannot join itself (join() raises RuntimeError).
        this_thread = threading.current_thread()
        for th in threading.enumerate():
            if th is this_thread:
                continue
            th.join()

    def command_handler(self, cmd, args):
        '''Handle a command from the command channel; return the answer.'''
        if cmd == "shutdown":
            logger.info(XFROUT_RECEIVED_SHUTDOWN_COMMAND)
            self.shutdown()
            answer = create_answer(0)

        elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
            zone_name = args.get('zone_name')
            zone_class = args.get('zone_class')
            if zone_name and zone_class:
                logger.info(XFROUT_NOTIFY_COMMAND, zone_name, zone_class)
                self.send_notify(zone_name, zone_class)
                answer = create_answer(0)
            else:
                answer = create_answer(1, "Bad command parameter:" + str(args))

        else:
            answer = create_answer(1, "Unknown command:" + str(cmd))

        return answer

    def run(self):
        '''Get and process all commands sent from cfgmgr or other modules. '''
        while not self._shutdown_event.is_set():
            self._cc.check_command(False)


# Module-level handle to the running XfroutServer.  It is set by the
# __main__ block, used by signal_handler(), and reset to None by
# XfroutServer.shutdown() so that shutdown is not run twice.
xfrout_server = None

def signal_handler(signal, frame):
    '''Shut down the running xfrout server and exit the process.

    Installed for SIGTERM and SIGINT by set_signal_handler().  Does
    nothing if no server is running (or it has already been shut down).
    '''
    if xfrout_server:
        xfrout_server.shutdown()
        sys.exit(0)

def set_signal_handler():
    '''Route SIGTERM and SIGINT to signal_handler().'''
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, signal_handler)

def set_cmd_options(parser):
    '''Register the command line options understood by b10-xfrout.'''
    verbose_flags = ("-v", "--verbose")
    parser.add_option(*verbose_flags,
                      action="store_true",
                      dest="verbose",
                      help="display more about what is going on")

if '__main__' == __name__:
    try:
        parser = OptionParser()
        set_cmd_options(parser)
        (options, args) = parser.parse_args()
        VERBOSE_MODE = options.verbose

        set_signal_handler()
        xfrout_server = XfroutServer()
        xfrout_server.run()
    except KeyboardInterrupt:
        # Use info(), as everywhere else in this module; the previous
        # logger.INFO() call is not a logging method used anywhere else
        # here and would presumably fail with AttributeError.
        logger.info(XFROUT_STOPPED_BY_KEYBOARD)
    except SessionError as e:
        logger.error(XFROUT_CC_SESSION_ERROR, str(e))
    except ModuleCCSessionError as e:
        logger.error(XFROUT_MODULECC_SESSION_ERROR, str(e))
    except XfroutConfigError as e:
        logger.error(XFROUT_CONFIG_ERROR, str(e))
    except SessionTimeout as e:
        logger.error(XFROUT_CC_SESSION_TIMEOUT_ERROR)

    # xfrout_server is None if startup failed or shutdown() already ran.
    if xfrout_server:
        xfrout_server.shutdown()