xfrout.py.in 31.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#!@PYTHON@

# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


import sys; sys.path.append ('@@PYTHONPATH@@')
import isc
import isc.cc
import threading
import struct
import signal
25
from isc.datasrc import DataSourceClient
26
27
28
from socketserver import *
import os
from isc.config.ccsession import *
29
from isc.cc import SessionError, SessionTimeout
30
from isc.notify import notify_out
31
import isc.util.process
32
import socket
33
import select
34
import errno
35
from optparse import OptionParser, OptionValueError
Likun Zhang's avatar
Likun Zhang committed
36
from isc.util import socketserver_mixin
37

38
from isc.log_messages.xfrout_messages import *
39
40
41
42

isc.log.init("b10-xfrout")
logger = isc.log.Logger("xfrout")

try:
    from libutil_io_python import *
    from pydnspp import *
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrout process
    # must keep running, so we warn about it and move forward.
    # The module-level logger defined above is the only logging object
    # available here.
    logger.error(XFROUT_IMPORT, str(e))

from isc.acl.acl import ACCEPT, REJECT, DROP, LoaderError
from isc.acl.dns import REQUEST_LOADER

isc.util.process.rename()
55

56
57
58
59
60
61
62
63
64
65
66
67
class XfroutConfigError(Exception):
    """Unified error for failures while applying an xfrout configuration.

    The module-CC layer cannot catch every syntax problem, so xfrout has to
    validate (explicitly or implicitly) the configuration data it is given.
    Whenever that validation fails, the failure is reported through this
    exception: either raised directly, or converted from an exception that
    another module raised.
    """

68
69
70
71
72
73
74
75
76
77
78
79
def init_paths():
    """Compute the spec-file directories and the auth/xfrout socket path.

    Sets the module globals SPECFILE_PATH, AUTH_SPECFILE_PATH and
    UNIX_SOCKET_FILE.  When B10_FROM_BUILD is in the environment, the paths
    point into that build tree; the socket goes under
    B10_FROM_SOURCE_LOCALSTATEDIR if that is set.  Otherwise the
    configure-time install locations are used, and the socket path can be
    overridden with BIND10_XFROUT_SOCKET_FILE.
    """
    global SPECFILE_PATH
    global AUTH_SPECFILE_PATH
    global UNIX_SOCKET_FILE
    env = os.environ

    if "B10_FROM_BUILD" not in env:
        # Installed layout: expand the autoconf placeholders that are
        # substituted into this template at configure time.
        prefix = "@prefix@"
        datarootdir = "@datarootdir@"
        SPECFILE_PATH = ("@datadir@/@PACKAGE@"
                         .replace("${datarootdir}", datarootdir)
                         .replace("${prefix}", prefix))
        AUTH_SPECFILE_PATH = SPECFILE_PATH
        UNIX_SOCKET_FILE = env.get(
            "BIND10_XFROUT_SOCKET_FILE",
            "@@LOCALSTATEDIR@@/@PACKAGE_NAME@/auth_xfrout_conn")
        return

    # Running from the build tree.
    build_root = env["B10_FROM_BUILD"]
    SPECFILE_PATH = build_root + "/src/bin/xfrout"
    AUTH_SPECFILE_PATH = build_root + "/src/bin/auth"
    if "B10_FROM_SOURCE_LOCALSTATEDIR" in env:
        socket_dir = env["B10_FROM_SOURCE_LOCALSTATEDIR"]
    else:
        socket_dir = build_root
    UNIX_SOCKET_FILE = socket_dir + "/auth_xfrout_conn"
89
90

init_paths()

# Spec files for this module and for b10-auth; the latter is registered as
# a remote configuration so the "Auth" database_file can be read.
SPECFILE_LOCATION = "%s/xfrout.spec" % SPECFILE_PATH
AUTH_SPECFILE_LOCATION = os.sep.join((AUTH_SPECFILE_PATH, "auth.spec"))
VERBOSE_MODE = False

# tsig sign every N axfr packets.
TSIG_SIGN_EVERY_NTH = 96

# Upper limit for a rendered response; messages are sent with a 16-bit
# length prefix.
XFROUT_MAX_MESSAGE_SIZE = 0xFFFF

100
101
102
103
104
105
106
107
108
109
# borrowed from xfrin.py @ #1298.  We should eventually unify it.
def format_zone_str(zone_name, zone_class):
    """Render a zone as '<name>/<class>' for log messages.

       Parameters:
       zone_name (isc.dns.Name) name to format; its to_text() is used
       zone_class (isc.dns.RRClass) class to format; its str() is used
    """
    return '/'.join((zone_name.to_text(), str(zone_class)))

110
111
112
113
114
115
116
def get_rrset_len(rrset):
    """Return the number of bytes rrset.to_wire() produces for rrset.

    The RRset is rendered into a scratch buffer, which is discarded; callers
    use the result as the RRset's maximum (uncompressed) size.
    """
    scratch = bytearray()
    rrset.to_wire(scratch)
    wire_length = len(scratch)
    return wire_length


117
class XfroutSession():
    '''Serve one zone transfer (AXFR) request handed over by b10-auth.

    All the work happens inside the constructor, which calls handle().  The
    client socket descriptor is closed once processing ends.
    '''

    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
                 default_acl, zone_config, client_class=DataSourceClient):
        '''Set up the session and immediately process the request.

        Parameters:
          sock_fd: descriptor of the client connection to answer on.  This
                   session closes it.
          request_data: the query message in wire format.
          server: the owning server.  It provides the transfer quota
                  (increase/decrease_transfers_counter), _shutdown_event and
                  get_db_file().
          tsig_key_ring: keys used to verify a TSIG-signed request.
          remote: the client address, as a tuple indexed [0] and [1] when
                  logging.
          default_acl: ACL applied when no per-zone transfer_acl exists.
          zone_config: per-zone settings keyed by (class text, lower-cased
                       name text).
          client_class: data source client class (parameterized for tests).
        '''
        self._sock_fd = sock_fd
        self._request_data = request_data
        self._server = server
        self._tsig_key_ring = tsig_key_ring
        self._tsig_ctx = None
        self._tsig_len = 0
        self._remote = remote
        self._acl = default_acl
        self._zone_config = zone_config
        self.ClientClass = client_class # parameterize this for testing
        # Set while checking availability (_check_xfrout_available) or in
        # tests.
        self._soa = None
        self.handle()

    def create_tsig_ctx(self, tsig_record, tsig_key_ring):
        '''Return a TSIGContext for the key and algorithm of tsig_record.'''
        return TSIGContext(tsig_record.get_name(),
                           tsig_record.get_rdata().get_algorithm(),
                           tsig_key_ring)

    def handle(self):
        ''' Handle a xfrout query, send xfrout response '''
        try:
            self.dns_xfrout_start(self._sock_fd, self._request_data)
        except Exception as e:
            #TODO, avoid catching all exceptions
            logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
        finally:
            # This session owns the descriptor: release it on every path,
            # even if logging the error above fails.
            os.close(self._sock_fd)

    def _check_request_tsig(self, msg, request_data):
        '''If the request has a TSIG record, verify it.

        When a TSIG record is present, this also sets self._tsig_ctx (later
        passed when rendering responses).  It also sets self._tsig_len, the
        record's length, which is reserved as TSIG space in response
        messages.  Returns NOTAUTH if verification fails, NOERROR otherwise.
        '''
        tsig_record = msg.get_tsig_record()
        if tsig_record is not None:
            self._tsig_len = tsig_record.get_length()
            self._tsig_ctx = self.create_tsig_ctx(tsig_record,
                                                  self._tsig_key_ring)
            tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
            if tsig_error != TSIGError.NOERROR:
                return Rcode.NOTAUTH()

        return Rcode.NOERROR()

    def _parse_query_message(self, mdata):
        '''Parse the query and apply the TSIG and ACL checks.

        Returns a (rcode, msg) pair:
          (FORMERR, None)   the wire data could not be parsed;
          (None, None)      the ACL says DROP;
          (REFUSED, msg)    the ACL says REJECT;
          (FORMERR, msg)    the query has no question section;
          otherwise the TSIG check result (NOERROR or NOTAUTH) and msg.
        '''
        #TODO, need to add parseHeader() in case the message header is invalid
        try:
            msg = Message(Message.PARSE)
            Message.from_wire(msg, mdata)
        except Exception as err: # Exception is too broad
            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
            return Rcode.FORMERR(), None

        # TSIG related checks
        rcode = self._check_request_tsig(msg, mdata)

        if rcode == Rcode.NOERROR():
            questions = msg.get_question()
            if not questions:
                # There is no zone to look up; answer FORMERR rather than
                # failing with an IndexError below.
                return Rcode.FORMERR(), msg

            # ACL checks
            zone_name = questions[0].get_name()
            zone_class = questions[0].get_class()
            acl = self._get_transfer_acl(zone_name, zone_class)
            acl_result = acl.execute(
                isc.acl.dns.RequestContext(self._remote,
                                           msg.get_tsig_record()))
            if acl_result == DROP:
                logger.info(XFROUT_QUERY_DROPPED, zone_name, zone_class,
                            self._remote[0], self._remote[1])
                return None, None
            elif acl_result == REJECT:
                logger.info(XFROUT_QUERY_REJECTED, zone_name, zone_class,
                            self._remote[0], self._remote[1])
                return Rcode.REFUSED(), msg

        return rcode, msg

    def _get_transfer_acl(self, zone_name, zone_class):
        '''Return the ACL that should be applied for a given zone.

        The zone is identified by a tuple of name and RR class.
        If a per zone configuration for the zone exists and contains
        transfer_acl, that ACL will be used; otherwise, the default
        ACL will be used.

        '''
        # Internally zone names are managed in lower cased label characters,
        # so we first need to convert the name.
        zone_name_lower = Name(zone_name.to_text(), True)
        config_key = (zone_class.to_text(), zone_name_lower.to_text())
        if config_key in self._zone_config and \
                'transfer_acl' in self._zone_config[config_key]:
            return self._zone_config[config_key]['transfer_acl']
        return self._acl

    def _send_data(self, sock_fd, data):
        '''Write all of data to sock_fd, retrying after partial writes.'''
        # A memoryview lets us resend the unwritten tail without copying.
        view = memoryview(data)
        total_count = 0
        while total_count < len(view):
            total_count += os.write(sock_fd, view[total_count:])

    def _send_message(self, sock_fd, msg, tsig_ctx=None):
        '''Render msg and send it with a 2-byte network-order length prefix.

        If tsig_ctx is given, it is passed to msg.to_wire() for signing.
        '''
        render = MessageRenderer()
        # As defined in RFC5936 section3.4, perform case-preserving name
        # compression for AXFR message.
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)

        # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
        # we should remove the if statement and use a universal interface later.
        if tsig_ctx is not None:
            msg.to_wire(render, tsig_ctx)
        else:
            msg.to_wire(render)

        header_len = struct.pack('!H', render.get_length())
        self._send_data(sock_fd, header_len)
        self._send_data(sock_fd, render.get_data())

    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
        '''Answer msg with rcode_, or send nothing if msg is None.'''
        if not msg:
            return # query message is invalid. send nothing back.

        msg.make_response()
        msg.set_rcode(rcode_)
        self._send_message(sock_fd, msg, self._tsig_ctx)

    def _check_xfrout_available(self, zone_name):
        '''Check whether the transfer request for zone_name can be served.

        Returns NOERROR when it can.  In that case a slot of the server's
        transfer quota is held, and the caller must release it with
        decrease_transfers_counter() when the transfer ends.  Otherwise it
        returns REFUSED (server too busy), NOTAUTH or SERVFAIL, and no slot
        is held.

        TODO, Get zone's configuration from cfgmgr or some other place
        eg. check allow_transfer setting,
        '''
        # Reject the attempt if we are too busy.  Check this first to avoid
        # unnecessary resource consumption even if we discard it soon.
        if not self._server.increase_transfers_counter():
            return Rcode.REFUSED()

        # We now hold a transfer slot.  Give it back unless we succeed,
        # including when the data source raises an exception.
        available = False
        try:
            rcode = self._prepare_zone_transfer(zone_name)
            available = (rcode == Rcode.NOERROR())
            return rcode
        finally:
            if not available:
                self._server.decrease_transfers_counter()

    def _prepare_zone_transfer(self, zone_name):
        '''Open the data source iterator and fetch the SOA of zone_name.

        Sets self._datasrc_client, self._iterator and self._soa.  Returns
        NOERROR on success, NOTAUTH if no iterator could be created for the
        zone, or SERVFAIL if the zone has no SOA.
        '''
        # Identify the data source for the requested zone and see if it has
        # SOA while initializing objects used for request processing later.
        # We should eventually generalize this so that we can choose the
        # appropriate data source from (possible) multiple candidates.
        # We should eventually take into account the RR class here.
        # For now, we  hardcode a particular type (SQLite3-based), and only
        # consider that one.
        datasrc_config = '{ \"database_file\": \"' + \
            self._server.get_db_file() + '\"}'
        self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)
        try:
            self._iterator = self._datasrc_client.get_iterator(zone_name)
        except isc.datasrc.Error as error:
            # If the current name server does not have authority for the
            # zone, xfrout can't serve for it, return rcode NOTAUTH.
            # Note: this exception can happen for other reasons.  We should
            # update get_iterator() API so that we can distinguish "no such
            # zone" and other cases (#1373).  For now we consider all these
            # cases as NOTAUTH.
            return Rcode.NOTAUTH()

        # If we are an authoritative name server for the zone, but fail
        # to find the zone's SOA record in datasource, xfrout can't
        # provide zone transfer for it.
        self._soa = self._iterator.get_soa()
        if self._soa is None:
            return Rcode.SERVFAIL()

        #TODO, check allow_transfer

        return Rcode.NOERROR()

    def dns_xfrout_start(self, sock_fd, msg_query):
        '''Process one query: validate it, then transfer the zone or reply
        with an error rcode.'''
        rcode_, msg = self._parse_query_message(msg_query)
        #TODO. create query message and parse header
        if rcode_ is None: # Dropped by ACL
            return
        elif rcode_ == Rcode.NOTAUTH() or rcode_ == Rcode.REFUSED():
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
        elif rcode_ != Rcode.NOERROR():
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.FORMERR())

        question = msg.get_question()[0]
        zone_name = question.get_name()
        zone_class = question.get_class()
        zone_str = format_zone_str(zone_name, zone_class) # for logging

        # TODO: we should also include class in the check
        rcode_ = self._check_xfrout_available(zone_name)
        if rcode_ != Rcode.NOERROR():
            logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_str, rcode_)
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)

        # NOERROR above means we hold a transfer slot.  Release it however
        # the transfer ends.
        try:
            try:
                logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_str)
                self._reply_xfrout_query(msg, sock_fd)
            except Exception as err:
                logger.error(XFROUT_AXFR_TRANSFER_ERROR, zone_str, err)
            logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_str)
        finally:
            self._server.decrease_transfers_counter()

    def _clear_message(self, msg):
        '''Reset msg for the next response, keeping qid, opcode and rcode
        and setting the AA and QR flags.'''
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()

        msg.clear(Message.RENDER)
        msg.set_qid(qid)
        msg.set_opcode(opcode)
        msg.set_rcode(rcode)
        msg.set_header_flag(Message.HEADERFLAG_AA)
        msg.set_header_flag(Message.HEADERFLAG_QR)
        return msg

    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len,
                                    count_since_last_tsig_sign):
        '''Add the SOA record to the end of message. If it can't be
        added, a new message should be created to send out the last soa .

        The final message is always sent with self._tsig_ctx.  When
        count_since_last_tsig_sign equals TSIG_SIGN_EVERY_NTH, TSIG space is
        already included in message_upper_len.  Otherwise it is added here
        (self._tsig_len) before checking the size limit.
        '''
        rrset_len = get_rrset_len(rrset_soa)

        if (count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH and
            message_upper_len + rrset_len >= XFROUT_MAX_MESSAGE_SIZE):
            # If tsig context exist, sign the packet with serial number TSIG_SIGN_EVERY_NTH
            self._send_message(sock_fd, msg, self._tsig_ctx)
            msg = self._clear_message(msg)
        elif (count_since_last_tsig_sign != TSIG_SIGN_EVERY_NTH and
              message_upper_len + rrset_len + self._tsig_len >= XFROUT_MAX_MESSAGE_SIZE):
            self._send_message(sock_fd, msg)
            msg = self._clear_message(msg)

        # If tsig context exist, sign the last packet
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self._send_message(sock_fd, msg, self._tsig_ctx)

    def _reply_xfrout_query(self, msg, sock_fd):
        '''Stream the zone as a sequence of response messages.

        The first message starts with the SOA and the last one ends with it.
        Messages are filled up to XFROUT_MAX_MESSAGE_SIZE using each RRset's
        uncompressed size.  One message is sent with self._tsig_ctx, then
        TSIG_SIGN_EVERY_NTH - 1 without it, and so on.  The transfer stops
        early if the server's shutdown event is set.
        '''
        #TODO, there should be a better way to insert rrset.
        count_since_last_tsig_sign = TSIG_SIGN_EVERY_NTH
        msg.make_response()
        msg.set_header_flag(Message.HEADERFLAG_AA)
        msg.add_rrset(Message.SECTION_ANSWER, self._soa)

        message_upper_len = get_rrset_len(self._soa) + self._tsig_len

        for rrset in self._iterator:
            # Check if xfrout is shutdown
            if self._server._shutdown_event.is_set():
                logger.info(XFROUT_STOPPING)
                return

            # The SOA is sent explicitly at the start and the end.
            if rrset.get_type() == RRType.SOA():
                continue

            # We calculate the maximum size of the RRset (i.e. the
            # size without compression) and use that to see if we
            # may have reached the limit
            rrset_len = get_rrset_len(rrset)
            if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
                msg.add_rrset(Message.SECTION_ANSWER, rrset)
                message_upper_len += rrset_len
                continue

            # If tsig context exist, sign every N packets
            if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
                count_since_last_tsig_sign = 0
                self._send_message(sock_fd, msg, self._tsig_ctx)
            else:
                self._send_message(sock_fd, msg)

            count_since_last_tsig_sign += 1
            msg = self._clear_message(msg)
            # Add the RRset to the new message
            msg.add_rrset(Message.SECTION_ANSWER, rrset)

            # Reserve tsig space for signed packet
            if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
                message_upper_len = rrset_len + self._tsig_len
            else:
                message_upper_len = rrset_len

        self._send_message_with_last_soa(msg, sock_fd, self._soa,
                                         message_upper_len,
                                         count_since_last_tsig_sign)
404

405
406
class UnixSockServer(socketserver_mixin.NoPollMixIn,
                     ThreadingUnixStreamServer):
    '''The unix domain socket server which accepts xfr queries sent from the
    auth server.'''

    def __init__(self, sock_file, handle_class, shutdown_event, config_data,
                 cc):
        '''Bind sock_file and apply the initial configuration.

        handle_class is stored by ThreadingUnixStreamServer and instantiated
        in finish_request() for every query.  shutdown_event is checked by
        the serving threads and sessions.  cc is the module CC session used
        for default and remote ("Auth") configuration values.
        '''
        self._remove_unused_sock_file(sock_file)
        self._sock_file = sock_file
        socketserver_mixin.NoPollMixIn.__init__(self)
        ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
        self._shutdown_event = shutdown_event
        # Writing to _write_sock (see shutdown()) wakes up handle_request()
        # threads blocked in select() on _read_sock.
        self._write_sock, self._read_sock = socket.socketpair()
        self._common_init()
        self._cc = cc
        self.update_config_data(config_data)

    def _common_init(self):
        '''Initialization shared with the mock server class used for tests'''
        # Protects _transfers_counter, _acl and _zone_config.
        self._lock = threading.Lock()
        self._transfers_counter = 0
        self._zone_config = {}
        self._acl = None # this will be initialized in update_config_data()

    def _recv_exactly(self, sock, size):
        '''Read exactly size bytes from sock.

        Returns the bytes read (b'' when size is 0), or None if the peer
        closes the connection first.
        '''
        chunks = []
        remaining = size
        while remaining > 0:
            data = sock.recv(remaining)
            if not data:
                return None
            chunks.append(data)
            remaining -= len(data)
        return b''.join(chunks)

    def _receive_query_message(self, sock):
        '''Receive one request message from sock.

        The message is preceded by its length as a 2-byte network-order
        integer.  Returns the message data, or None if the connection is
        closed before the complete message has arrived.
        '''
        # receive data length; it may arrive in more than one piece
        data_len = self._recv_exactly(sock, 2)
        if data_len is None:
            return None
        msg_len = struct.unpack('!H', data_len)[0]
        # receive data
        return self._recv_exactly(sock, msg_len)

    def handle_request(self):
        ''' Enable server handle a request until shutdown or auth is closed.'''
        try:
            request, client_address = self.get_request()
        except socket.error:
            logger.error(XFROUT_FETCH_REQUEST_ERROR)
            return

        # Check self._shutdown_event to ensure the real shutdown comes.
        # Linux could trigger a spurious readable event on the _read_sock
        # due to a bug, so we need perform a double check.
        while not self._shutdown_event.is_set(): # Check if xfrout is shutdown
            try:
                (rlist, wlist, xlist) = select.select([self._read_sock, request], [], [])
            except select.error as e:
                if e.args[0] == errno.EINTR:
                    continue
                logger.error(XFROUT_SOCKET_SELECT_ERROR, str(e))
                break

            # self._shutdown_event will be set by now, if it is not a false
            # alarm
            if self._read_sock in rlist:
                continue

            try:
                # An explicit False means the connection with auth is no
                # longer usable (e.g. it was closed).  Stop serving it rather
                # than spinning on a socket that stays readable.  Other
                # return values (including None from overriding subclasses)
                # keep the loop going.
                if self.process_request(request) is False:
                    break
            except Exception as pre:
                logger.error(XFROUT_PROCESS_REQUEST_ERROR, str(pre))
                break

    def _handle_request_noblock(self):
        """Override the function _handle_request_noblock(), it creates a new
        thread to handle requests for each auth"""
        td = threading.Thread(target=self.handle_request)
        td.setDaemon(True)
        td.start()

    def process_request(self, request):
        """Receive socket fd and query message from auth, then
        start a new thread to process the request.

        Returns True if the connection with auth can be used for further
        requests, and False if it should be abandoned.
        """
        sock_fd = recv_fd(request.fileno())
        if sock_fd < 0:
            # This may happen when one xfrout process try to connect to
            # xfrout unix socket server, to check whether there is another
            # xfrout running, or when auth closes the connection.
            if sock_fd == FD_COMM_ERROR:
                logger.error(XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR)
            return False

        # receive request msg
        request_data = self._receive_query_message(request)
        if not request_data:
            # No session will take ownership of the descriptor; close it.
            os.close(sock_fd)
            # None means the connection closed mid-message, so nothing more
            # can be read from it.
            return request_data is not None

        t = threading.Thread(target=self.finish_request,
                             args = (sock_fd, request_data))
        if self.daemon_threads:
            t.daemon = True
        t.start()
        return True

    def _guess_remote(self, sock_fd):
        """
           Guess remote address and port of the socket. The sock_fd must be a
           socket
        """
        # This uses a trick. If the socket is IPv4 in reality and we pretend
        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
        # to care about the SOCK_STREAM parameter at all (which it really is,
        # except for testing)
        if socket.has_ipv6:
            sock = socket.fromfd(sock_fd, socket.AF_INET6, socket.SOCK_STREAM)
        else:
            # To make it work even on hosts without IPv6 support
            # (Any idea how to simulate this in test?)
            sock = socket.fromfd(sock_fd, socket.AF_INET, socket.SOCK_STREAM)
        try:
            return sock.getpeername()
        finally:
            # fromfd() duplicated the descriptor; closing the duplicate
            # leaves sock_fd itself open.
            sock.close()

    def finish_request(self, sock_fd, request_data):
        '''Finish one request by instantiating RequestHandlerClass.

        This is an entry point of a separate thread spawned in
        UnixSockServer.process_request().

        This method creates a XfroutSession object.
        '''
        # Take a consistent snapshot of the configuration; it may be
        # replaced concurrently by update_config_data().
        with self._lock:
            acl = self._acl
            zone_config = self._zone_config
        self.RequestHandlerClass(sock_fd, request_data, self,
                                 self.tsig_key_ring,
                                 self._guess_remote(sock_fd), acl, zone_config)

    def _remove_unused_sock_file(self, sock_file):
        '''Try to remove the socket file. If the file is being used
        by one running xfrout process, exit from python.

        If it's not a socket file or nobody is listening, it will be
        removed. If it can't be removed, exit from python. '''
        if self._sock_file_in_use(sock_file):
            logger.error(XFROUT_UNIX_SOCKET_FILE_IN_USE, sock_file)
            sys.exit(0)
        else:
            if not os.path.exists(sock_file):
                return

            try:
                os.unlink(sock_file)
            except OSError as err:
                logger.error(XFROUT_REMOVE_OLD_UNIX_SOCKET_FILE_ERROR, sock_file, str(err))
                sys.exit(0)

    def _sock_file_in_use(self, sock_file):
        '''Check whether the socket file 'sock_file' exists and
        is being used by one running xfrout process. If it is,
        return True, or else return False. '''
        try:
            sock = socket.socket(socket.AF_UNIX)
        except socket.error:
            return False
        try:
            sock.connect(sock_file)
        except socket.error:
            return False
        finally:
            # The probe connection is only needed to see whether anyone
            # is listening.
            sock.close()
        return True

    def shutdown(self):
        '''Wake up the serving threads, stop the server and remove the
        socket file.'''
        self._write_sock.send(b"shutdown") #terminate the xfrout session thread
        super().shutdown() # call the shutdown() of class socketserver_mixin.NoPollMixIn
        try:
            os.unlink(self._sock_file)
        except Exception as e:
            logger.error(XFROUT_REMOVE_UNIX_SOCKET_FILE_ERROR, self._sock_file, str(e))

    def update_config_data(self, new_config):
        '''Apply the new config setting of xfrout module.

        Raises XfroutConfigError if an ACL or zone_config entry is invalid.
        '''
        with self._lock:
            logger.info(XFROUT_NEW_CONFIG)
            new_acl = self._acl
            if 'transfer_acl' in new_config:
                try:
                    new_acl = REQUEST_LOADER.load(new_config['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl: ' +
                                            str(e))

            new_zone_config = self._zone_config
            zconfig_data = new_config.get('zone_config')
            if zconfig_data is not None:
                new_zone_config = self.__create_zone_config(zconfig_data)

            self._acl = new_acl
            self._zone_config = new_zone_config
            self._max_transfers_out = new_config.get('transfers_out')
            self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
        logger.info(XFROUT_NEW_CONFIG_DONE)

    def __create_zone_config(self, zone_config_list):
        '''Build the per-zone configuration dictionary.

        Keys are (class text, lower-cased origin text); values hold the
        parsed per-zone settings (currently only 'transfer_acl').  Raises
        XfroutConfigError on duplicate zones or unparsable ACLs.
        '''
        new_config = {}
        for zconf in zone_config_list:
            # convert the class, origin (name) pair.  First build pydnspp
            # object to reject invalid input.
            zclass_str = zconf.get('class')
            if zclass_str is None:
                zclass_str = self._cc.get_default_value('zone_config/class')
            zclass = RRClass(zclass_str)
            zorigin = Name(zconf['origin'], True)
            config_key = (zclass.to_text(), zorigin.to_text())

            # reject duplicate config
            if config_key in new_config:
                raise XfroutConfigError('Duplicate zone_config for ' +
                                        str(zorigin) + '/' + str(zclass))

            # create a new config entry, build any given (and known) config
            new_config[config_key] = {}
            if 'transfer_acl' in zconf:
                try:
                    new_config[config_key]['transfer_acl'] = \
                        REQUEST_LOADER.load(zconf['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl ' +
                                            'for ' + zorigin.to_text() + '/' +
                                            zclass_str + ': ' + str(e))
        return new_config

    def set_tsig_key_ring(self, key_list):
        """Set the tsig_key_ring , given a TSIG key string list representation. """

        # XXX add values to configure zones/tsig options
        self.tsig_key_ring = TSIGKeyRing()
        # If key string list is empty, create a empty tsig_key_ring
        if not key_list:
            return

        for key_item in key_list:
            try:
                self.tsig_key_ring.add(TSIGKey(key_item))
            except InvalidParameter:
                # Skip the bad key but keep loading the others.
                logger.error(XFROUT_BAD_TSIG_KEY_STRING, str(key_item))

    def get_db_file(self):
        '''Return the path of the zone database used by b10-auth.'''
        db_file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
        # this too should be unnecessary, but currently the
        # 'from build' override isn't stored in the config
        # (and we don't have indirect python access to datasources yet)
        if is_default and "B10_FROM_BUILD" in os.environ:
            db_file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
        return db_file

    def increase_transfers_counter(self):
        '''Return False, if counter + 1 > max_transfers_out, or else
        return True
        '''
        with self._lock:
            if self._transfers_counter < self._max_transfers_out:
                self._transfers_counter += 1
                return True
        return False

    def decrease_transfers_counter(self):
        '''Release a transfer slot taken by increase_transfers_counter().'''
        with self._lock:
            self._transfers_counter -= 1

class XfroutServer:
    '''Top-level object of the xfrout process.

    Owns the config/command session (ModuleCCSession), the unix socket
    server that accepts transfer-out requests (served from its own
    thread), and the NOTIFY sender.
    '''

    def __init__(self):
        self._unix_socket_server = None
        self._listen_sock_file = UNIX_SOCKET_FILE
        self._shutdown_event = threading.Event()
        self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                              self.config_handler,
                                              self.command_handler)
        self._config_data = self._cc.get_full_config()
        self._cc.start()
        # Auth's config is read by get_db_file() ("database_file").
        self._cc.add_remote_config(AUTH_SPECFILE_LOCATION)
        self._start_xfr_query_listener()
        self._start_notifier()

    def _start_xfr_query_listener(self):
        '''Start a new thread to accept xfr queries on the unix socket.'''
        self._unix_socket_server = UnixSockServer(self._listen_sock_file,
                                                  XfroutSession,
                                                  self._shutdown_event,
                                                  self._config_data,
                                                  self._cc)
        listener = threading.Thread(
            target=self._unix_socket_server.serve_forever)
        listener.start()

    def _start_notifier(self):
        '''Create the NOTIFY sender for the zone database file and call its
        dispatcher().'''
        db_file = self._unix_socket_server.get_db_file()
        self._notifier = notify_out.NotifyOut(db_file)
        self._notifier.dispatcher()

    def send_notify(self, zone_name, zone_class):
        '''Ask the notifier to send NOTIFY messages for the given zone.'''
        self._notifier.send_notify(zone_name, zone_class)

    def config_handler(self, new_config):
        '''Merge new_config into the stored config and push it to the server.

        Keys not already present in the config are reported in an error
        answer, but the remaining keys are still applied. If the unix
        socket server fails to apply the merged config, an error answer is
        returned. TODO: validate the values themselves.
        '''
        answer = create_answer(0)
        for key in new_config:
            if key not in self._config_data:
                answer = create_answer(1, "Unknown config data: " + str(key))
                continue
            self._config_data[key] = new_config[key]

        if self._unix_socket_server:
            try:
                self._unix_socket_server.update_config_data(self._config_data)
            except Exception as e:
                answer = create_answer(1,
                                       "Failed to handle new configuration: " +
                                       str(e))

        return answer

    def shutdown(self):
        '''Shut down the xfrout process.

        Sets the shutdown event, stops the notifier and the unix socket
        server, then joins every other thread, so threads doing zone
        transfer-out must terminate before this returns.
        '''
        global xfrout_server
        xfrout_server = None  # Avoid shutdown being called twice
        self._shutdown_event.set()
        self._notifier.shutdown()
        if self._unix_socket_server:
            self._unix_socket_server.shutdown()

        # Wait for all threads to terminate.  The calling thread is skipped
        # because a thread cannot join itself.
        current = threading.current_thread()
        for th in threading.enumerate():
            if th is current:
                continue
            th.join()

    def command_handler(self, cmd, args):
        '''Handle a command received from cfgmgr or another module.

        Supports "shutdown" and notify_out.ZONE_NEW_DATA_READY_CMD (which
        needs non-empty 'zone_name' and 'zone_class' in args); any other
        command gets an error answer.
        '''
        if cmd == "shutdown":
            logger.info(XFROUT_RECEIVED_SHUTDOWN_COMMAND)
            self.shutdown()
            answer = create_answer(0)
        elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
            # Treat a missing (None) args like missing parameters instead of
            # raising AttributeError on .get().
            params = args if args is not None else {}
            zone_name = params.get('zone_name')
            zone_class = params.get('zone_class')
            if zone_name and zone_class:
                logger.info(XFROUT_NOTIFY_COMMAND, zone_name, zone_class)
                self.send_notify(zone_name, zone_class)
                answer = create_answer(0)
            else:
                answer = create_answer(1, "Bad command parameter:" + str(args))
        else:
            answer = create_answer(1, "Unknown command:" + str(cmd))

        return answer

    def run(self):
        '''Get and process all commands sent from cfgmgr or other modules
        until the shutdown event is set.'''
        while not self._shutdown_event.is_set():
            self._cc.check_command(False)
778
779
780
781
782


# The running XfroutServer instance: assigned by the main program and
# reset to None by XfroutServer.shutdown().
xfrout_server = None


def signal_handler(signal, frame):
    '''Shut down the running xfrout server, if any, and exit with status 0.

    Installed for SIGTERM and SIGINT.  Does nothing when no server is
    currently set.
    '''
    server = xfrout_server
    if not server:
        return
    server.shutdown()
    sys.exit(0)

def set_signal_handler():
    '''Route SIGTERM and SIGINT to signal_handler.'''
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, signal_handler)

def set_cmd_options(parser):
    '''Register xfrout's command-line options on the given OptionParser.

    Adds -v/--verbose, stored as the boolean option "verbose".
    '''
    parser.add_option("-v", "--verbose",
                      action="store_true",
                      dest="verbose",
                      help="display more about what is going on")

if '__main__' == __name__:
    try:
        parser = OptionParser()
        set_cmd_options(parser)
        (options, args) = parser.parse_args()
        VERBOSE_MODE = options.verbose

        set_signal_handler()
        xfrout_server = XfroutServer()
        xfrout_server.run()
    except KeyboardInterrupt:
        # Logger level methods are lower-case (logger.info / logger.error,
        # as used elsewhere in this file); 'logger.INFO' is not one of them.
        logger.info(XFROUT_STOPPED_BY_KEYBOARD)
    except SessionError as e:
        logger.error(XFROUT_CC_SESSION_ERROR, str(e))
    except ModuleCCSessionError as e:
        logger.error(XFROUT_MODULECC_SESSION_ERROR, str(e))
    except XfroutConfigError as e:
        logger.error(XFROUT_CONFIG_ERROR, str(e))
    except SessionTimeout:
        logger.error(XFROUT_CC_SESSION_TIMEOUT_ERROR)

    # Stop the server if it is still set; shutdown() resets the global to
    # None, so this is skipped when a shutdown already happened.
    if xfrout_server:
        xfrout_server.shutdown()