xfrout.py.in 31.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#!@PYTHON@

# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


import sys; sys.path.append ('@@PYTHONPATH@@')
import isc
import isc.cc
import threading
import struct
import signal
Evan Hunt's avatar
Evan Hunt committed
25
from isc.datasrc import sqlite3_ds
26
27
28
from socketserver import *
import os
from isc.config.ccsession import *
29
from isc.cc import SessionError, SessionTimeout
30
from isc.notify import notify_out
31
import isc.util.process
32
import socket
33
import select
34
import errno
35
from optparse import OptionParser, OptionValueError
Likun Zhang's avatar
Likun Zhang committed
36
from isc.util import socketserver_mixin
37

38
from isc.log_messages.xfrout_messages import *
39
40
41
42

isc.log.init("b10-xfrout")
logger = isc.log.Logger("xfrout")

try:
    from libutil_io_python import *
    from pydnspp import *
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrout process
    # must keep running, so we warn about it and move forward.
    # Report through the module-level 'logger' created above; there is no
    # global named 'log' in this module, so using it here would raise
    # NameError and abort startup after all.
    logger.error(XFROUT_IMPORT, str(e))

from isc.acl.acl import ACCEPT, REJECT, DROP, LoaderError
from isc.acl.dns import REQUEST_LOADER

# Presumably sets the process name shown by tools such as ps -- confirm
# against isc.util.process.
isc.util.process.rename()


class XfroutConfigError(Exception):
    """Unified error raised when an xfrout configuration update is invalid.

    The module-CC layer cannot catch every syntax problem in incoming
    configuration, so xfrout validates the data itself as well.  Whatever
    it finds -- detected directly, or converted from an exception raised
    by another module such as the ACL loader -- is reported to the caller
    as this single exception type.
    """


def init_paths():
    """Compute the spec-file directories and the auth/xfrout socket path.

    Sets the module globals SPECFILE_PATH, AUTH_SPECFILE_PATH and
    UNIX_SOCKET_FILE.  With B10_FROM_BUILD set, every path is derived from
    the build tree (the socket directory may be overridden with
    B10_FROM_SOURCE_LOCALSTATEDIR).  Otherwise the configure-time install
    locations are used, and the socket path may be overridden with
    BIND10_XFROUT_SOCKET_FILE.
    """
    global SPECFILE_PATH
    global AUTH_SPECFILE_PATH
    global UNIX_SOCKET_FILE

    env = os.environ
    build_dir = env.get("B10_FROM_BUILD")
    if build_dir is not None:
        SPECFILE_PATH = build_dir + "/src/bin/xfrout"
        AUTH_SPECFILE_PATH = build_dir + "/src/bin/auth"
        state_dir = env.get("B10_FROM_SOURCE_LOCALSTATEDIR")
        if state_dir is None:
            state_dir = build_dir
        UNIX_SOCKET_FILE = state_dir + "/auth_xfrout_conn"
        return

    # Installed layout: expand the autoconf-style variables by hand.
    install_prefix = "@prefix@"
    data_root = "@datarootdir@"
    spec_dir = "@datadir@/@PACKAGE@"
    spec_dir = spec_dir.replace("${datarootdir}", data_root)
    spec_dir = spec_dir.replace("${prefix}", install_prefix)
    SPECFILE_PATH = spec_dir
    AUTH_SPECFILE_PATH = spec_dir
    UNIX_SOCKET_FILE = env.get("BIND10_XFROUT_SOCKET_FILE",
                               "@@LOCALSTATEDIR@@/auth_xfrout_conn")

init_paths()

# Spec files for this module and for Auth; Auth's spec is registered as a
# remote config so that values such as its database_file can be read.
SPECFILE_LOCATION = "%s/xfrout.spec" % SPECFILE_PATH
AUTH_SPECFILE_LOCATION = os.sep.join((AUTH_SPECFILE_PATH, "auth.spec"))

VERBOSE_MODE = False

# tsig sign every N axfr packets.
TSIG_SIGN_EVERY_NTH = 96

# Length limit applied when rendering each outgoing message.
XFROUT_MAX_MESSAGE_SIZE = 65535


def get_rrset_len(rrset):
    """Return the wire-format length of the given RRset, in bytes.

    The RRset is rendered on its own into a scratch buffer through its
    to_wire() method.  Because no compression against other data in a
    message is possible here, callers use the result as an upper bound for
    the space the RRset takes in a message.
    """
    # Named 'wire' rather than 'bytes' to avoid shadowing the builtin.
    wire = bytearray()
    rrset.to_wire(wire)
    return len(wire)


class XfroutSession:
    '''Handle a single xfr (AXFR) request forwarded by the auth server.

    Instances are created by UnixSockServer.finish_request().  The
    constructor serves the request right away (see handle()), and the
    client socket descriptor is closed before it returns.
    '''

    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
                 default_acl, zone_config):
        '''Store the request context and serve the request immediately.

        sock_fd -- descriptor of the socket connected to the xfr client.
        request_data -- the query in wire format (length prefix removed).
        server -- the UnixSockServer that accepted the request; provides
                  the database file, the transfer counter and the
                  shutdown event.
        tsig_key_ring -- key ring used to verify the request's TSIG.
        remote -- peer address of sock_fd (result of getpeername()).
        default_acl -- transfer ACL used when the zone has no transfer_acl
                       of its own.
        zone_config -- per-zone configuration keyed by (class, origin)
                       text tuples.
        '''
        self._sock_fd = sock_fd
        self._request_data = request_data
        self._server = server
        self._tsig_key_ring = tsig_key_ring
        # Filled in by _check_request_tsig() when the query carries TSIG;
        # used to sign the responses and to reserve space for the signature.
        self._tsig_ctx = None
        self._tsig_len = 0
        self._remote = remote
        self._acl = default_acl
        self._zone_config = zone_config
        self.handle()

    def create_tsig_ctx(self, tsig_record, tsig_key_ring):
        '''Create a TSIGContext from the name and algorithm in tsig_record.'''
        return TSIGContext(tsig_record.get_name(),
                           tsig_record.get_rdata().get_algorithm(),
                           tsig_key_ring)

    def handle(self):
        ''' Handle a xfrout query, send xfrout response.

        Errors are logged, not propagated (except non-Exception errors),
        and the client socket descriptor is always closed afterwards.
        '''
        try:
            self.dns_xfrout_start(self._sock_fd, self._request_data)
        #TODO, avoid catching all exceptions
        except Exception as e:
            logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
        finally:
            # Close in 'finally' so the descriptor is not leaked even if
            # something escapes the handler above.
            os.close(self._sock_fd)

    def _check_request_tsig(self, msg, request_data):
        '''If request has a tsig record, perform tsig related checks.

        Returns Rcode.NOTAUTH() if verification fails, otherwise
        Rcode.NOERROR() (also when there is no TSIG record).  As a side
        effect it records the TSIG length and context for the responses.
        '''
        tsig_record = msg.get_tsig_record()
        if tsig_record is not None:
            self._tsig_len = tsig_record.get_length()
            self._tsig_ctx = self.create_tsig_ctx(tsig_record,
                                                  self._tsig_key_ring)
            tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
            if tsig_error != TSIGError.NOERROR:
                return Rcode.NOTAUTH()

        return Rcode.NOERROR()

    def _parse_query_message(self, mdata):
        '''Parse the query in mdata and run the TSIG and ACL checks.

        Returns an (rcode, message) pair:
          (Rcode.FORMERR(), None)  -- mdata could not be parsed;
          (Rcode.NOTAUTH(), msg)   -- TSIG verification failed;
          (Rcode.FORMERR(), msg)   -- the query has no question;
          (None, None)             -- the ACL says DROP;
          (Rcode.REFUSED(), msg)   -- the ACL says REJECT;
          (Rcode.NOERROR(), msg)   -- the request may proceed.
        '''
        #TODO, need to add parseHeader() in case the message header is invalid
        try:
            msg = Message(Message.PARSE)
            Message.from_wire(msg, mdata)
        except Exception as err: # Exception is too broad
            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
            return Rcode.FORMERR(), None

        # TSIG related checks
        rcode = self._check_request_tsig(msg, mdata)
        if rcode != Rcode.NOERROR():
            return rcode, msg

        # Without a question there is no zone to check or transfer; answer
        # FORMERR instead of failing on the [0] index below.  The TSIG check
        # ran first so that this error response can still be signed.
        questions = msg.get_question()
        if not questions:
            return Rcode.FORMERR(), msg

        # ACL checks
        question = questions[0]
        zone_name = question.get_name()
        zone_class = question.get_class()
        acl = self._get_transfer_acl(zone_name, zone_class)
        acl_result = acl.execute(
            isc.acl.dns.RequestContext(self._remote,
                                       msg.get_tsig_record()))
        if acl_result == DROP:
            logger.info(XFROUT_QUERY_DROPPED, zone_name, zone_class,
                        self._remote[0], self._remote[1])
            return None, None
        elif acl_result == REJECT:
            logger.info(XFROUT_QUERY_REJECTED, zone_name, zone_class,
                        self._remote[0], self._remote[1])
            return Rcode.REFUSED(), msg

        return rcode, msg

    def _get_transfer_acl(self, zone_name, zone_class):
        '''Return the ACL that should be applied for a given zone.

        The zone is identified by a tuple of name and RR class.
        If a per zone configuration for the zone exists and contains
        transfer_acl, that ACL will be used; otherwise, the default
        ACL will be used.

        '''
        # Internally zone names are managed in lower cased label characters,
        # so we first need to convert the name.
        zone_name_lower = Name(zone_name.to_text(), True)
        config_key = (zone_class.to_text(), zone_name_lower.to_text())
        if config_key in self._zone_config and \
                'transfer_acl' in self._zone_config[config_key]:
            return self._zone_config[config_key]['transfer_acl']
        return self._acl

    def _get_query_zone_name(self, msg):
        '''Return the text form of the name in the first question of msg.'''
        question = msg.get_question()[0]
        return question.get_name().to_text()

    def _get_query_zone_class(self, msg):
        '''Return the text form of the class in the first question of msg.'''
        question = msg.get_question()[0]
        return question.get_class().to_text()

    def _send_data(self, sock_fd, data):
        '''Write all of data to sock_fd, retrying after partial writes.'''
        # A memoryview lets each retry slice the unsent tail without
        # copying it again.
        view = memoryview(data)
        total_count = 0
        while total_count < len(view):
            total_count += os.write(sock_fd, view[total_count:])

    def _send_message(self, sock_fd, msg, tsig_ctx=None):
        '''Render msg and send it with DNS-over-TCP framing.

        The rendered message (TSIG-signed when tsig_ctx is given) is
        preceded by its 2-byte length in network byte order.
        '''
        render = MessageRenderer()
        # As defined in RFC5936 section3.4, perform case-preserving name
        # compression for AXFR message.
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)

        # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
        # we should remove the if statement and use a universal interface later.
        if tsig_ctx is not None:
            msg.to_wire(render, tsig_ctx)
        else:
            msg.to_wire(render)

        # '!H' packs in network byte order directly (same bytes as the
        # native 'H' + socket.htons() combination).
        header_len = struct.pack('!H', render.get_length())
        self._send_data(sock_fd, header_len)
        self._send_data(sock_fd, render.get_data())

    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
        '''Answer the query with rcode_, signing when TSIG is in use.'''
        if not msg:
            return # query message is invalid. send nothing back.

        msg.make_response()
        msg.set_rcode(rcode_)
        self._send_message(sock_fd, msg, self._tsig_ctx)

    def _zone_has_soa(self, zone):
        '''Judge if the zone has an SOA record.'''
        # In some sense, the SOA defines a zone.
        # If the current name server has authority for the
        # specific zone, we need to judge if the zone has an SOA record;
        # if not, we consider the zone has incomplete data, so xfrout can't
        # serve for it.
        return bool(sqlite3_ds.get_zone_soa(zone, self._server.get_db_file()))

    def _zone_exist(self, zonename):
        '''Judge if the zone is configured by config manager.'''
        # Currently, if we find the zone in datasource successfully, we
        # consider the zone is configured, and the current name server has
        # authority for the specific zone.
        # TODO: should get zone's configuration from cfgmgr or other place
        # in future.
        return sqlite3_ds.zone_exist(zonename, self._server.get_db_file())

    def _check_xfrout_available(self, zone_name):
        '''Check whether the xfr request can be served.

        Returns Rcode.NOERROR() on success, in which case a slot in the
        server's transfer counter has been taken and must be released with
        decrease_transfers_counter().
        TODO, Get zone's configuration from cfgmgr or some other place
           eg. check allow_transfer setting,
        '''
        # If the current name server does not have authority for the
        # zone, xfrout can't serve for it, return rcode NOTAUTH.
        if not self._zone_exist(zone_name):
            return Rcode.NOTAUTH()

        # If we are an authoritative name server for the zone, but fail
        # to find the zone's SOA record in datasource, xfrout can't
        # provide zone transfer for it.
        if not self._zone_has_soa(zone_name):
            return Rcode.SERVFAIL()

        #TODO, check allow_transfer
        if not self._server.increase_transfers_counter():
            return Rcode.REFUSED()

        return Rcode.NOERROR()

    def dns_xfrout_start(self, sock_fd, msg_query):
        '''Validate the query in msg_query and perform the transfer.'''
        rcode_, msg = self._parse_query_message(msg_query)
        #TODO. create query message and parse header
        if rcode_ is None: # Dropped by ACL
            return
        elif rcode_ == Rcode.NOTAUTH() or rcode_ == Rcode.REFUSED():
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
        elif rcode_ != Rcode.NOERROR():
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.FORMERR())

        zone_name = self._get_query_zone_name(msg)
        zone_class_str = self._get_query_zone_class(msg)
        # TODO: should we not also include class in the check?
        rcode_ = self._check_xfrout_available(zone_name)

        if rcode_ != Rcode.NOERROR():
            logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name,
                        zone_class_str, rcode_.to_text())
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)

        # _check_xfrout_available() took a transfer slot; give it back in
        # 'finally' however the transfer ends.  DONE is only logged when
        # the transfer did not fail.
        try:
            logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name, zone_class_str)
            self._reply_xfrout_query(msg, sock_fd, zone_name)
        except Exception as err:
            logger.error(XFROUT_AXFR_TRANSFER_ERROR, zone_name,
                         zone_class_str, str(err))
        else:
            logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_name, zone_class_str)
        finally:
            self._server.decrease_transfers_counter()

    def _clear_message(self, msg):
        '''Reset msg for reuse as the next response packet.

        The QID, opcode and rcode are preserved and the AA and QR flags
        are set again; everything else is cleared.
        '''
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()

        msg.clear(Message.RENDER)
        msg.set_qid(qid)
        msg.set_opcode(opcode)
        msg.set_rcode(rcode)
        msg.set_header_flag(Message.HEADERFLAG_AA)
        msg.set_header_flag(Message.HEADERFLAG_QR)
        return msg

    def _create_rrset_from_db_record(self, record):
        '''Create one rrset from one record of datasource, if the schema of record is changed,
        This function should be updated first.

        Fields used: record[2] owner name, record[4] TTL, record[5] type,
        record[7:] rdata fields.  The class is always IN here.
        '''
        rrtype_ = RRType(record[5])
        rdata_ = Rdata(rrtype_, RRClass("IN"), " ".join(record[7:]))
        rrset_ = RRset(Name(record[2]), RRClass("IN"), rrtype_,
                       RRTTL(int(record[4])))
        rrset_.add_rdata(rdata_)
        return rrset_

    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len,
                                    count_since_last_tsig_sign):
        '''Add the SOA record to the end of message. If it can't be
        added, a new message should be created to send out the last soa .

        The final packet is always signed when TSIG is in use.
        '''
        rrset_len = get_rrset_len(rrset_soa)

        if (count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH and
            message_upper_len + rrset_len >= XFROUT_MAX_MESSAGE_SIZE):
            # If tsig context exist, sign the packet with serial number TSIG_SIGN_EVERY_NTH
            self._send_message(sock_fd, msg, self._tsig_ctx)
            msg = self._clear_message(msg)
        elif (count_since_last_tsig_sign != TSIG_SIGN_EVERY_NTH and
              message_upper_len + rrset_len + self._tsig_len >= XFROUT_MAX_MESSAGE_SIZE):
            self._send_message(sock_fd, msg)
            msg = self._clear_message(msg)

        # If tsig context exist, sign the last packet
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self._send_message(sock_fd, msg, self._tsig_ctx)

    def _reply_xfrout_query(self, msg, sock_fd, zone_name):
        '''Stream the zone as AXFR: SOA, all non-SOA records, SOA again.

        Records are packed into messages bounded by XFROUT_MAX_MESSAGE_SIZE,
        reserving room for a TSIG signature on every TSIG_SIGN_EVERY_NTH
        packet.  Returns early, without the closing SOA, on shutdown.
        '''
        #TODO, there should be a better way to insert rrset.
        count_since_last_tsig_sign = TSIG_SIGN_EVERY_NTH
        msg.make_response()
        msg.set_header_flag(Message.HEADERFLAG_AA)
        db_file = self._server.get_db_file()
        soa_record = sqlite3_ds.get_zone_soa(zone_name, db_file)
        rrset_soa = self._create_rrset_from_db_record(soa_record)
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)

        message_upper_len = get_rrset_len(rrset_soa) + self._tsig_len

        # TODO: RRType.SOA() ?
        soa_type = RRType("SOA")
        for rr_data in sqlite3_ds.get_zone_datas(zone_name, db_file):
            if self._server._shutdown_event.is_set(): # Check if xfrout is shutdown
                logger.info(XFROUT_STOPPING)
                return

            if RRType(rr_data[5]) == soa_type: #ignore soa record
                continue

            rrset_ = self._create_rrset_from_db_record(rr_data)

            # We calculate the maximum size of the RRset (i.e. the
            # size without compression) and use that to see if we
            # may have reached the limit
            rrset_len = get_rrset_len(rrset_)
            if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
                msg.add_rrset(Message.SECTION_ANSWER, rrset_)
                message_upper_len += rrset_len
                continue

            # If tsig context exist, sign every N packets
            if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
                count_since_last_tsig_sign = 0
                self._send_message(sock_fd, msg, self._tsig_ctx)
            else:
                self._send_message(sock_fd, msg)

            count_since_last_tsig_sign += 1
            msg = self._clear_message(msg)
            msg.add_rrset(Message.SECTION_ANSWER, rrset_) # Add the rrset to the new message

            # Reserve tsig space for signed packet
            if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
                message_upper_len = rrset_len + self._tsig_len
            else:
                message_upper_len = rrset_len

        self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len,
                                         count_since_last_tsig_sign)


class UnixSockServer(socketserver_mixin.NoPollMixIn,
                     ThreadingUnixStreamServer):
    '''The unix domain socket server which accept xfr query sent from auth server.

    Each connection from auth is served by its own thread (see
    _handle_request_noblock() and handle_request()).  Every xfr request
    received on that connection is then processed by a new thread that
    instantiates the request handler class.
    '''

    def __init__(self, sock_file, handle_class, shutdown_event, config_data,
                 cc):
        '''
        sock_file -- path of the unix domain socket to listen on.
        handle_class -- class instantiated per request (XfroutSession).
        shutdown_event -- event that is set when xfrout shuts down.
        config_data -- initial xfrout configuration.
        cc -- module CC session, used for default and remote (Auth)
              configuration values.
        '''
        self._remove_unused_sock_file(sock_file)
        self._sock_file = sock_file
        socketserver_mixin.NoPollMixIn.__init__(self)
        ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
        self._shutdown_event = shutdown_event
        # Writing to _write_sock wakes every connection thread blocked in
        # select() on _read_sock (see shutdown()).
        self._write_sock, self._read_sock = socket.socketpair()
        self._common_init()
        self._cc = cc
        self.update_config_data(config_data)

    def _common_init(self):
        '''Initialization shared with the mock server class used for tests'''
        # Protects _transfers_counter, _acl, _zone_config and tsig_key_ring.
        self._lock = threading.Lock()
        self._transfers_counter = 0
        self._zone_config = {}
        self._acl = None # this will be initialized in update_config_data()

    def _recv_exactly(self, sock, size):
        '''Read exactly size bytes from sock; None if the peer closes first.'''
        chunks = []
        remaining = size
        while remaining > 0:
            data = sock.recv(remaining)
            if not data:
                return None
            chunks.append(data)
            remaining -= len(data)
        return b''.join(chunks)

    def _receive_query_message(self, sock):
        ''' receive request message from sock.

        The message is preceded by a 2-byte length in network byte order.
        Returns the message data (without the length), or None if the
        connection is closed before the whole message has arrived.
        '''
        # A stream recv(2) may legitimately return a single byte, so read
        # the length field with the same exact-size loop as the body.
        data_len = self._recv_exactly(sock, 2)
        if data_len is None:
            return None
        msg_len = struct.unpack('!H', data_len)[0]
        return self._recv_exactly(sock, msg_len)

    def handle_request(self):
        ''' Enable server handle a request until shutdown or auth is closed.'''
        try:
            request, client_address = self.get_request()
        except socket.error:
            logger.error(XFROUT_FETCH_REQUEST_ERROR)
            return

        # Check self._shutdown_event to ensure the real shutdown comes.
        # Linux could trigger a spurious readable event on the _read_sock
        # due to a bug, so we need perform a double check.
        while not self._shutdown_event.is_set(): # Check if xfrout is shutdown
            try:
                (rlist, wlist, xlist) = select.select([self._read_sock, request], [], [])
            except select.error as e:
                if e.args[0] == errno.EINTR:
                    continue
                logger.error(XFROUT_SOCKET_SELECT_ERROR, str(e))
                break

            # self.server._shutdown_event will be set by now, if it is not a false
            # alarm
            if self._read_sock in rlist:
                continue

            try:
                # False means the connection is unusable (typically closed
                # by the peer).  select() would keep reporting it readable,
                # so looping again would spin; stop serving it instead.
                # 'is False' keeps overrides that return None working.
                if self.process_request(request) is False:
                    break
            except Exception as pre:
                logger.error(XFROUT_PROCESS_REQUEST_ERROR, str(pre))
                break

    def _handle_request_noblock(self):
        """Override the function _handle_request_noblock(), it creates a new
        thread to handle requests for each auth"""
        td = threading.Thread(target=self.handle_request)
        td.daemon = True
        td.start()

    def process_request(self, request):
        """Receive socket fd and query message from auth, then
        start a new thread to process the request.

        Returns True if a request was dispatched, or False if no usable
        request could be read, in which case the caller should stop
        serving this connection.
        """
        sock_fd = recv_fd(request.fileno())
        if sock_fd < 0:
            # This may happen when one xfrout process try to connect to
            # xfrout unix socket server, to check whether there is another
            # xfrout running.
            if sock_fd == FD_COMM_ERROR:
                logger.error(XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR)
            return False

        # receive request msg
        request_data = self._receive_query_message(request)
        if not request_data:
            # The descriptor was received but no query follows; close it
            # so it is not leaked.
            os.close(sock_fd)
            return False

        t = threading.Thread(target=self.finish_request,
                             args = (sock_fd, request_data))
        if self.daemon_threads:
            t.daemon = True
        t.start()
        return True

    def _guess_remote(self, sock_fd):
        """
           Guess remote address and port of the socket. The sock_fd must be a
           socket
        """
        # This uses a trick. If the socket is IPv4 in reality and we pretend
        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
        # to care about the SOCK_STREAM parameter at all (which it really is,
        # except for testing)
        if socket.has_ipv6:
            family = socket.AF_INET6
        else:
            # To make it work even on hosts without IPv6 support
            # (Any idea how to simulate this in test?)
            family = socket.AF_INET
        # fromfd() duplicates the descriptor; close the duplicate when done
        # (sock_fd itself stays open).
        sock = socket.fromfd(sock_fd, family, socket.SOCK_STREAM)
        try:
            return sock.getpeername()
        finally:
            sock.close()

    def finish_request(self, sock_fd, request_data):
        '''Finish one request by instantiating RequestHandlerClass.

        This method creates a XfroutSession object.
        '''
        # Snapshot the configuration atomically; the key ring is read under
        # the lock too because update_config_data() replaces it under it.
        with self._lock:
            acl = self._acl
            zone_config = self._zone_config
            tsig_key_ring = self.tsig_key_ring

        try:
            remote = self._guess_remote(sock_fd)
        except socket.error:
            # No session will own (and close) the descriptor.
            os.close(sock_fd)
            raise

        self.RequestHandlerClass(sock_fd, request_data, self,
                                 tsig_key_ring, remote, acl, zone_config)

    def _remove_unused_sock_file(self, sock_file):
        '''Try to remove the socket file. If the file is being used
        by one running xfrout process, exit from python.
        If it's not a socket file or nobody is listening
        , it will be removed. If it can't be removed, exit from python. '''
        if self._sock_file_in_use(sock_file):
            logger.error(XFROUT_UNIX_SOCKET_FILE_IN_USE, sock_file)
            sys.exit(0)
        else:
            if not os.path.exists(sock_file):
                return

            try:
                os.unlink(sock_file)
            except OSError as err:
                logger.error(XFROUT_REMOVE_OLD_UNIX_SOCKET_FILE_ERROR, sock_file, str(err))
                sys.exit(0)

    def _sock_file_in_use(self, sock_file):
        '''Check whether the socket file 'sock_file' exists and
        is being used by one running xfrout process. If it is,
        return True, or else return False. '''
        sock = None
        try:
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(sock_file)
        except socket.error:
            return False
        else:
            return True
        finally:
            # Close the probe right away so the other process sees EOF.
            if sock is not None:
                sock.close()

    def shutdown(self):
        '''Stop the server, wake connection threads and remove the socket.'''
        self._write_sock.send(b"shutdown") #terminate the xfrout session thread
        super().shutdown() # call the shutdown() of class socketserver_mixin.NoPollMixIn
        try:
            os.unlink(self._sock_file)
        except Exception as e:
            logger.error(XFROUT_REMOVE_UNIX_SOCKET_FILE_ERROR, self._sock_file, str(e))

    def update_config_data(self, new_config):
        '''Apply the new config setting of xfrout module.

        The transfer ACL and zone configuration are parsed before anything
        is applied; a parse failure raises XfroutConfigError and leaves the
        ACL and zone configuration unchanged.
        '''
        with self._lock:
            logger.info(XFROUT_NEW_CONFIG)
            new_acl = self._acl
            if 'transfer_acl' in new_config:
                try:
                    new_acl = REQUEST_LOADER.load(new_config['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl: ' +
                                            str(e))

            new_zone_config = self._zone_config
            zconfig_data = new_config.get('zone_config')
            if zconfig_data is not None:
                new_zone_config = self.__create_zone_config(zconfig_data)

            self._acl = new_acl
            self._zone_config = new_zone_config
            self._max_transfers_out = new_config.get('transfers_out')
            self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
        logger.info(XFROUT_NEW_CONFIG_DONE)

    def __create_zone_config(self, zone_config_list):
        '''Build the per-zone config dict keyed by (class, origin) text.

        Raises XfroutConfigError on duplicate zones or invalid ACLs.
        '''
        new_config = {}
        for zconf in zone_config_list:
            # convert the class, origin (name) pair.  First build pydnspp
            # object to reject invalid input.
            zclass_str = zconf.get('class')
            if zclass_str is None:
                zclass_str = self._cc.get_default_value('zone_config/class')
            zclass = RRClass(zclass_str)
            zorigin = Name(zconf['origin'], True)
            config_key = (zclass.to_text(), zorigin.to_text())

            # reject duplicate config
            if config_key in new_config:
                raise XfroutConfigError('Duplicate zone_config for ' +
                                        str(zorigin) + '/' + str(zclass))

            # create a new config entry, build any given (and known) config
            new_config[config_key] = {}
            if 'transfer_acl' in zconf:
                try:
                    new_config[config_key]['transfer_acl'] = \
                        REQUEST_LOADER.load(zconf['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl ' +
                                            'for ' + zorigin.to_text() + '/' +
                                            zclass_str + ': ' + str(e))
        return new_config

    def set_tsig_key_ring(self, key_list):
        """Set the tsig_key_ring , given a TSIG key string list representation. """

        # XXX add values to configure zones/tsig options
        # Build the new ring completely before publishing it, so readers
        # never observe a partially filled key ring.  An empty or missing
        # key list yields an empty key ring.
        key_ring = TSIGKeyRing()
        for key_item in key_list or []:
            try:
                key_ring.add(TSIGKey(key_item))
            except InvalidParameter:
                logger.error(XFROUT_BAD_TSIG_KEY_STRING, str(key_item))
        self.tsig_key_ring = key_ring

    def get_db_file(self):
        '''Return the zone database file from Auth's database_file config.'''
        file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
        # this too should be unnecessary, but currently the
        # 'from build' override isn't stored in the config
        # (and we don't have indirect python access to datasources yet)
        if is_default and "B10_FROM_BUILD" in os.environ:
            file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
        return file

    def increase_transfers_counter(self):
        '''Return False, if counter + 1 > max_transfers_out, or else
        return True
        '''
        with self._lock:
            if self._transfers_counter < self._max_transfers_out:
                self._transfers_counter += 1
                return True
            return False

    def decrease_transfers_counter(self):
        '''Release a slot taken by increase_transfers_counter().'''
        with self._lock:
            self._transfers_counter -= 1

class XfroutServer:
    '''The xfrout process object.

    Owns the ModuleCCSession used to talk to the configuration manager, the
    UnixSockServer that accepts transfer requests (served on a separate
    thread), and the notify_out.NotifyOut instance used to send NOTIFY
    messages.
    '''

    def __init__(self):
        self._unix_socket_server = None
        self._listen_sock_file = UNIX_SOCKET_FILE
        self._shutdown_event = threading.Event()
        self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                              self.config_handler,
                                              self.command_handler)
        self._config_data = self._cc.get_full_config()
        self._cc.start()
        # The socket server's get_db_file() reads Auth's "database_file"
        # setting, so the Auth module's configuration must be tracked too.
        self._cc.add_remote_config(AUTH_SPECFILE_LOCATION)
        self._start_xfr_query_listener()
        self._start_notifier()

    def _start_xfr_query_listener(self):
        '''Start a new thread to accept xfr query. '''
        # The socket server shares this object's shutdown event, current
        # configuration and CC session.
        self._unix_socket_server = UnixSockServer(self._listen_sock_file,
                                                  XfroutSession,
                                                  self._shutdown_event,
                                                  self._config_data,
                                                  self._cc)
        listener = threading.Thread(
            target=self._unix_socket_server.serve_forever)
        listener.start()

    def _start_notifier(self):
        '''Create the NOTIFY sender for the zone database file and start
        its dispatcher.'''
        datasrc = self._unix_socket_server.get_db_file()
        self._notifier = notify_out.NotifyOut(datasrc)
        self._notifier.dispatcher()

    def send_notify(self, zone_name, zone_class):
        '''Hand a NOTIFY request for the given zone to the notifier.'''
        self._notifier.send_notify(zone_name, zone_class)

    def config_handler(self, new_config):
        '''Apply a configuration update from the configuration manager.

        Every key already present in self._config_data is overwritten with
        the new value.  An unknown key turns the answer into an error
        (code 1), but the remaining keys are still applied.  If the socket
        server exists, the merged data is passed to it, and an exception
        from that call is reported in the answer.

        TODO: the values themselves are not validated here.
        '''
        answer = create_answer(0)
        for key in new_config:
            if key not in self._config_data:
                answer = create_answer(1, "Unknown config data: " + str(key))
                continue
            self._config_data[key] = new_config[key]

        if self._unix_socket_server:
            try:
                self._unix_socket_server.update_config_data(self._config_data)
            except Exception as e:
                answer = create_answer(1,
                                       "Failed to handle new configuration: " +
                                       str(e))

        return answer

    def shutdown(self):
        '''Shut down the xfrout process.

        Signals the shutdown event, stops the notifier and the socket
        server, then joins every other thread, so threads doing zone
        transfer-out are expected to terminate.
        '''
        global xfrout_server
        # Reset the global so signal_handler() and the __main__ cleanup do
        # not call shutdown() a second time.
        xfrout_server = None
        self._shutdown_event.set()
        self._notifier.shutdown()
        if self._unix_socket_server:
            self._unix_socket_server.shutdown()

        # Wait for all threads to terminate.  current_thread() is the
        # thread running shutdown() (the modern spelling of the deprecated
        # currentThread() alias); it must not try to join itself.
        main_thread = threading.current_thread()
        for th in threading.enumerate():
            if th is main_thread:
                continue
            th.join()

    def command_handler(self, cmd, args):
        '''Handle a command from the configuration manager.

        Supported commands are "shutdown" and
        notify_out.ZONE_NEW_DATA_READY_CMD.  The latter needs both
        'zone_name' and 'zone_class' in args.  Anything else gets an error
        answer.
        '''
        if cmd == "shutdown":
            logger.info(XFROUT_RECEIVED_SHUTDOWN_COMMAND)
            self.shutdown()
            answer = create_answer(0)

        elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
            zone_name = args.get('zone_name')
            zone_class = args.get('zone_class')
            if zone_name and zone_class:
                logger.info(XFROUT_NOTIFY_COMMAND, zone_name, zone_class)
                self.send_notify(zone_name, zone_class)
                answer = create_answer(0)
            else:
                answer = create_answer(1, "Bad command parameter:" + str(args))

        else:
            answer = create_answer(1, "Unknown command:" + str(cmd))

        return answer

    def run(self):
        '''Get and process all commands sent from cfgmgr or other modules. '''
        # check_command(False) is the non-"nonblock" form of the call; the
        # loop ends once shutdown() has set the event.
        while not self._shutdown_event.is_set():
            self._cc.check_command(False)


# Module-level reference to the running XfroutServer.  It is assigned in the
# __main__ section, checked by signal_handler() and the __main__ cleanup, and
# reset to None by XfroutServer.shutdown() so those callers skip a second
# shutdown.
xfrout_server = None

def signal_handler(signal, frame):
    """Handler installed for SIGTERM and SIGINT by set_signal_handler().

    When a server is running, shut it down and then call sys.exit(0).
    When xfrout_server is None (not created yet, or already shut down),
    the signal is ignored.
    """
    server = xfrout_server
    if not server:
        return
    server.shutdown()
    sys.exit(0)

def set_signal_handler():
    """Route SIGTERM and SIGINT to signal_handler()."""
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, signal_handler)

def set_cmd_options(parser):
    """Register xfrout's command-line options (-v/--verbose) on *parser*."""
    parser.add_option(
        "-v", "--verbose",
        action="store_true",
        dest="verbose",
        help="display more about what is going on",
    )

if '__main__' == __name__:
    try:
        parser = OptionParser()
        set_cmd_options(parser)
        (options, args) = parser.parse_args()
        VERBOSE_MODE = options.verbose

        set_signal_handler()
        xfrout_server = XfroutServer()
        xfrout_server.run()
    except KeyboardInterrupt:
        # The logger method is lower-case ``info`` (as used elsewhere in
        # this file); an upper-case ``INFO`` attribute does not exist.
        logger.info(XFROUT_STOPPED_BY_KEYBOARD)
    except SessionError as e:
        logger.error(XFROUT_CC_SESSION_ERROR, str(e))
    except ModuleCCSessionError as e:
        logger.error(XFROUT_MODULECC_SESSION_ERROR, str(e))
    except XfroutConfigError as e:
        logger.error(XFROUT_CONFIG_ERROR, str(e))
    except SessionTimeout:
        logger.error(XFROUT_CC_SESSION_TIMEOUT_ERROR)
    finally:
        # Runs on every exit path, including exceptions not handled above,
        # so the server's threads are told to stop.  XfroutServer.shutdown()
        # resets xfrout_server to None, so this is skipped when a shutdown
        # has already happened (shutdown command or signal).
        if xfrout_server:
            xfrout_server.shutdown()