xfrout.py.in 30.2 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#!@PYTHON@

# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


import sys; sys.path.append ('@@PYTHONPATH@@')
import isc
import isc.cc
import threading
import struct
import signal
Evan Hunt's avatar
Evan Hunt committed
25
from isc.datasrc import sqlite3_ds
26
27
28
from socketserver import *
import os
from isc.config.ccsession import *
29
from isc.cc import SessionError, SessionTimeout
30
from isc.notify import notify_out
31
import isc.util.process
32
import socket
33
import select
34
import errno
35
from optparse import OptionParser, OptionValueError
Likun Zhang's avatar
Likun Zhang committed
36
from isc.util import socketserver_mixin
37

38
from isc.log_messages.xfrout_messages import *
39
40
41
42

isc.log.init("b10-xfrout")
logger = isc.log.Logger("xfrout")

43
try:
    from libutil_io_python import *
    from pydnspp import *
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrout process
    # must keep running, so we warn about it and move forward.
    # The module-level logger object is named 'logger'; the previous code
    # referenced an undefined name 'log' here, which raised NameError
    # instead of reporting the import problem.
    logger.error(XFROUT_IMPORT, str(e))
Michal Vaner's avatar
Michal Vaner committed
50

Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
51
52
53
from isc.acl.acl import ACCEPT, REJECT, DROP
from isc.acl.dns import REQUEST_LOADER

54
isc.util.process.rename()
55

56
57
58
59
60
61
62
63
64
65
66
67
def init_paths():
    """Set the SPECFILE_PATH, AUTH_SPECFILE_PATH and UNIX_SOCKET_FILE globals.

    When B10_FROM_BUILD is set in the environment, all three are derived
    from it (B10_FROM_SOURCE_LOCALSTATEDIR, if set, overrides the directory
    of the socket file).  Otherwise the configure-time substitution values
    are used, and BIND10_XFROUT_SOCKET_FILE, if set, overrides the socket
    file path.
    """
    global SPECFILE_PATH
    global AUTH_SPECFILE_PATH
    global UNIX_SOCKET_FILE

    env = os.environ
    if "B10_FROM_BUILD" not in env:
        # Installed layout: expand the placeholders left in the datadir value.
        PREFIX = "@prefix@"
        DATAROOTDIR = "@datarootdir@"
        SPECFILE_PATH = "@datadir@/@PACKAGE@".replace("${datarootdir}", DATAROOTDIR).replace("${prefix}", PREFIX)
        AUTH_SPECFILE_PATH = SPECFILE_PATH
        if "BIND10_XFROUT_SOCKET_FILE" in env:
            UNIX_SOCKET_FILE = env["BIND10_XFROUT_SOCKET_FILE"]
        else:
            UNIX_SOCKET_FILE = "@@LOCALSTATEDIR@@/auth_xfrout_conn"
        return

    # Build-tree layout.
    build_root = env["B10_FROM_BUILD"]
    SPECFILE_PATH = build_root + "/src/bin/xfrout"
    AUTH_SPECFILE_PATH = build_root + "/src/bin/auth"
    socket_dir = env.get("B10_FROM_SOURCE_LOCALSTATEDIR", build_root)
    UNIX_SOCKET_FILE = socket_dir + "/auth_xfrout_conn"

init_paths()
79

80
SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
81
AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
82
MAX_TRANSFERS_OUT = 10
Jerry's avatar
Jerry committed
83
VERBOSE_MODE = False
84
85
# tsig sign every N axfr packets.
TSIG_SIGN_EVERY_NTH = 96
86

87
88
XFROUT_MAX_MESSAGE_SIZE = 65535

89
90
# In practice, RR class is almost always fixed, so if and when we allow
# it to be configured, it's convenient to make it optional.
91
DEFAULT_RRCLASS = RRClass.IN()
92

93
94
95
96
97
98
99
def get_rrset_len(rrset):
    """Return how many bytes rrset.to_wire() writes into an empty buffer."""
    wire = bytearray()
    rrset.to_wire(wire)
    return len(wire)


100
class XfroutSession():
101
102
    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
                 acl):
103
104
105
        self._sock_fd = sock_fd
        self._request_data = request_data
        self._server = server
106
107
108
        self._tsig_key_ring = tsig_key_ring
        self._tsig_ctx = None
        self._tsig_len = 0
109
        self._remote = remote
110
        self._acl = acl
111
        self._zone_config = {}
112
        self.handle()
Jerry's avatar
Jerry committed
113

114
115
116
117
    def create_tsig_ctx(self, tsig_record, tsig_key_ring):
        '''Return a TSIGContext built from the record's name, the algorithm
        of its RDATA, and the given key ring.'''
        key_name = tsig_record.get_name()
        algorithm = tsig_record.get_rdata().get_algorithm()
        return TSIGContext(key_name, algorithm, tsig_key_ring)

118
    def handle(self):
119
        ''' Handle a xfrout query, send xfrout response '''
120
        try:
121
            self.dns_xfrout_start(self._sock_fd, self._request_data)
122
123
            #TODO, avoid catching all exceptions
        except Exception as e:
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
124
            logger.error(XFROUT_HANDLE_QUERY_ERROR, e)
125
            pass
126

127
        os.close(self._sock_fd)
128

129
130
131
132
133
134
135
136
137
138
139
140
    def _check_request_tsig(self, msg, request_data):
        '''Verify the TSIG record of the request, if it has one.

        Returns Rcode.NOERROR() when there is no TSIG record or when
        verification succeeds, and Rcode.NOTAUTH() when it fails.  When a
        record is present its length and a new TSIG context are remembered
        on this session.
        '''
        tsig_record = msg.get_tsig_record()
        if tsig_record is None:
            return Rcode.NOERROR()

        self._tsig_len = tsig_record.get_length()
        self._tsig_ctx = self.create_tsig_ctx(tsig_record, self._tsig_key_ring)
        verdict = self._tsig_ctx.verify(tsig_record, request_data)
        if verdict != TSIGError.NOERROR:
            return Rcode.NOTAUTH()
        return Rcode.NOERROR()

141
142
    def _parse_query_message(self, mdata):
        '''Parse the wire format query and run the TSIG and ACL checks.

        Returns a (rcode, message) pair:
        - (None, None) if the ACL result is DROP
        - (Rcode.REFUSED(), message) if the ACL result is REJECT
        - (Rcode.FORMERR(), None) if anything raises an exception
        - otherwise the rcode of the TSIG check and the parsed message
        '''
        #TODO, need to add parseHeader() in case the message header is invalid
        try:
            msg = Message(Message.PARSE)
            Message.from_wire(msg, mdata)

            # TSIG related checks
            rcode = self._check_request_tsig(msg, mdata)

            if rcode == Rcode.NOERROR():
                # ACL checks
                context = isc.acl.dns.RequestContext(self._remote,
                                                     msg.get_tsig_record())
                acl_result = self._acl.execute(context)

                # Map a negative ACL result to its log message and the
                # value to hand back to the caller.
                if acl_result == DROP:
                    outcome = (XFROUT_QUERY_DROPPED, (None, None))
                elif acl_result == REJECT:
                    outcome = (XFROUT_QUERY_REJECTED,
                               (Rcode.REFUSED(), msg))
                else:
                    outcome = None

                if outcome is not None:
                    log_id, result = outcome
                    logger.info(log_id,
                                self._get_query_zone_name(msg),
                                self._get_query_zone_class(msg),
                                self._remote[0], self._remote[1])
                    return result

        except Exception as err:
            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
            return Rcode.FORMERR(), None

        return rcode, msg
174

175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
    def _get_transfer_acl(self, zone_name, zone_class):
        '''Return the ACL that should be applied for a given zone.

        The zone is identified by its name and RR class (objects providing
        to_text()).  If a per zone configuration for the zone exists and
        contains transfer_acl, that ACL will be used; otherwise, the
        default ACL will be used.

        '''
        # Internally zone names are managed in lower cased label characters,
        # so we first need to convert the name.
        zone_name_lower = Name(zone_name.to_text(), True)
        # The per zone configuration table built by the server's
        # __create_zone_config() is keyed by (class text, name text), in that
        # order; the lookup key must use the same order or it never matches.
        config_key = (zone_class.to_text(), zone_name_lower.to_text())
        zone_conf = self._zone_config.get(config_key)
        if zone_conf is not None and 'transfer_acl' in zone_conf:
            return zone_conf['transfer_acl']
        return self._acl

193
    def _get_query_zone_name(self, msg):
194
        question = msg.get_question()[0]
195
196
        return question.get_name().to_text()

197
198
199
    def _get_query_zone_class(self, msg):
        question = msg.get_question()[0]
        return question.get_class().to_text()
200

201
    def _send_data(self, sock_fd, data):
202
203
204
        size = len(data)
        total_count = 0
        while total_count < size:
205
            count = os.write(sock_fd, data[total_count:])
206
207
208
            total_count += count


209
    def _send_message(self, sock_fd, msg, tsig_ctx=None):
        '''Render msg to wire format and send it on sock_fd.

        The rendered data is preceded by its length packed as an unsigned
        16-bit value in network byte order.  If tsig_ctx is given it is
        passed to the message's to_wire().
        '''
        renderer = MessageRenderer()
        # As defined in RFC5936 section3.4, perform case-preserving name
        # compression for AXFR message.
        renderer.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        renderer.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)

        # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
        # we should remove the if statement and use a universal interface later.
        if tsig_ctx is None:
            msg.to_wire(renderer)
        else:
            msg.to_wire(renderer, tsig_ctx)

        wire_length = socket.htons(renderer.get_length())
        length_prefix = struct.pack('H', wire_length)
        self._send_data(sock_fd, length_prefix)
        self._send_data(sock_fd, renderer.get_data())
226
227


228
    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
229
        if not msg:
230
            return # query message is invalid. send nothing back.
231
232

        msg.make_response()
233
        msg.set_rcode(rcode_)
234
        self._send_message(sock_fd, msg, self._tsig_ctx)
235

JINMEI Tatuya's avatar
JINMEI Tatuya committed
236
237
238
239
240
241
242
    def _zone_has_soa(self, zone):
        '''Return True if the data source returns an SOA for the zone.'''
        # In some sense, the SOA defines a zone.
        # If the current name server has authority for the
        # specific zone, we need to judge if the zone has an SOA record;
        # if not, we consider the zone has incomplete data, so xfrout can't
        # serve for it.
        soa = sqlite3_ds.get_zone_soa(zone, self._server.get_db_file())
        return bool(soa)

JINMEI Tatuya's avatar
JINMEI Tatuya committed
248
249
250
251
252
253
254
    def _zone_exist(self, zonename):
        '''Return whether the data source reports that the zone exists.'''
        # Currently, if we find the zone in datasource successfully, we
        # consider the zone is configured, and the current name server has
        # authority for the specific zone.
        # TODO: should get zone's configuration from cfgmgr or other place
        # in future.
        db_file = self._server.get_db_file()
        return sqlite3_ds.zone_exist(zonename, db_file)
256

257
258
259
    def _check_xfrout_available(self, zone_name):
        '''Decide whether a transfer of zone_name can be served now.

        Returns the rcode describing the decision.  When the result is
        Rcode.NOERROR() the server's transfer counter has been increased.

        TODO, Get zone's configuration from cfgmgr or some other place
        eg. check allow_transfer setting,
        '''
        if not self._zone_exist(zone_name):
            # If the current name server does not have authority for the
            # zone, xfrout can't serve for it, return rcode NOTAUTH.
            rcode = Rcode.NOTAUTH()
        elif not self._zone_has_soa(zone_name):
            # If we are an authoritative name server for the zone, but fail
            # to find the zone's SOA record in datasource, xfrout can't
            # provide zone transfer for it.
            rcode = Rcode.SERVFAIL()
        #TODO, check allow_transfer
        elif not self._server.increase_transfers_counter():
            # The server refused to account for one more transfer.
            rcode = Rcode.REFUSED()
        else:
            rcode = Rcode.NOERROR()
        return rcode
278
279


280
    def dns_xfrout_start(self, sock_fd, msg_query):
        '''Process one transfer request received as wire data msg_query.

        The query is parsed and checked; on failure an error response is
        sent (or nothing at all, if the ACL dropped the query).  Otherwise
        the zone is streamed out on sock_fd and the server's transfer
        counter is decreased afterwards.
        '''
        rcode_, msg = self._parse_query_message(msg_query)
        #TODO. create query message and parse header
        if rcode_ is None: # Dropped by ACL
            return
        if rcode_ != Rcode.NOERROR():
            # NOTAUTH and REFUSED are reported as they are; any other
            # failure is reported as FORMERR.
            if rcode_ == Rcode.NOTAUTH() or rcode_ == Rcode.REFUSED():
                reply_rcode = rcode_
            else:
                reply_rcode = Rcode.FORMERR()
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      reply_rcode)

        zone_name = self._get_query_zone_name(msg)
        zone_class_str = self._get_query_zone_class(msg)
        # TODO: should we not also include class in the check?
        rcode_ = self._check_xfrout_available(zone_name)

        if rcode_ != Rcode.NOERROR():
            logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name,
                        zone_class_str, rcode_.to_text())
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)

        try:
            logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name, zone_class_str)
            self._reply_xfrout_query(msg, sock_fd, zone_name)
        except Exception as err:
            logger.error(XFROUT_AXFR_TRANSFER_ERROR, zone_name,
                         zone_class_str, str(err))
        logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_name, zone_class_str)

        # Matches the increase done by _check_xfrout_available().
        self._server.decrease_transfers_counter()
        return
312
313
314
315
316
317


    def _clear_message(self, msg):
        '''Reset msg for rendering a new message and return it.

        The query ID, opcode and rcode are carried over from the old
        content, and the AA and QR header flags are set.
        '''
        saved_header = (msg.get_qid(), msg.get_opcode(), msg.get_rcode())

        msg.clear(Message.RENDER)

        qid, opcode, rcode = saved_header
        msg.set_qid(qid)
        msg.set_opcode(opcode)
        msg.set_rcode(rcode)
        for flag in (Message.HEADERFLAG_AA, Message.HEADERFLAG_QR):
            msg.set_header_flag(flag)
        return msg

    def _create_rrset_from_db_record(self, record):
        '''Create one rrset holding one rdata from one data source record.

        The record is read positionally: item 2 is the owner name, item 4
        the TTL, item 5 the RR type, and items 7 and later are joined with
        spaces to form the rdata text.  The RR class is always "IN".
        If the schema of record is changed, this function should be
        updated first.
        '''
        rr_type = RRType(record[5])
        rr_class = RRClass("IN")
        rdata_text = " ".join(record[7:])
        rdata = Rdata(rr_type, rr_class, rdata_text)

        owner = Name(record[2])
        ttl = RRTTL(int(record[4]))
        rrset = RRset(owner, rr_class, rr_type, ttl)
        rrset.add_rdata(rdata)
        return rrset
336

337
338
    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len,
                                    count_since_last_tsig_sign):
        '''Send the closing SOA, in the pending message if it still fits.

        If adding the SOA would reach the message size limit, the pending
        message is sent first and the SOA goes out in a fresh message.
        The message that carries the SOA is always sent with the TSIG
        context.
        '''
        soa_len = get_rrset_len(rrset_soa)
        projected_len = message_upper_len + soa_len

        if count_since_last_tsig_sign == TSIG_SIGN_EVERY_NTH:
            # The pending message is one to be sent with the TSIG context;
            # message_upper_len is used as is for it.
            if projected_len >= XFROUT_MAX_MESSAGE_SIZE:
                self._send_message(sock_fd, msg, self._tsig_ctx)
                msg = self._clear_message(msg)
        else:
            # The pending message would go out without the TSIG context,
            # but the SOA message is sent with it, so account for its length.
            if projected_len + self._tsig_len >= XFROUT_MAX_MESSAGE_SIZE:
                self._send_message(sock_fd, msg)
                msg = self._clear_message(msg)

        # If tsig context exist, sign the last packet
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self._send_message(sock_fd, msg, self._tsig_ctx)
357
358


359
    def _reply_xfrout_query(self, msg, sock_fd, zone_name):
        '''Stream the content of zone_name on sock_fd as response messages.

        The response starts with the zone's SOA, continues with every
        non-SOA record from the data source, and ends with the SOA again.
        A message is sent whenever adding the next rrset would reach the
        size limit.  The loop stops early, without sending the closing SOA,
        if the server's shutdown event gets set.
        '''
        #TODO, there should be a better way to insert rrset.
        msgs_since_signed = TSIG_SIGN_EVERY_NTH
        msg.make_response()
        msg.set_header_flag(Message.HEADERFLAG_AA)
        soa_record = sqlite3_ds.get_zone_soa(zone_name, self._server.get_db_file())
        rrset_soa = self._create_rrset_from_db_record(soa_record)
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)

        upper_len = get_rrset_len(rrset_soa) + self._tsig_len

        for row in sqlite3_ds.get_zone_datas(zone_name, self._server.get_db_file()):
            if self._server._shutdown_event.is_set(): # Check if xfrout is shutdown
                logger.info(XFROUT_STOPPING)
                return
            # TODO: RRType.SOA() ?
            if RRType(row[5]) == RRType("SOA"): #ignore soa record
                continue

            rrset = self._create_rrset_from_db_record(row)

            # We calculate the maximum size of the RRset (i.e. the
            # size without compression) and use that to see if we
            # may have reached the limit
            rrset_len = get_rrset_len(rrset)
            if upper_len + rrset_len >= XFROUT_MAX_MESSAGE_SIZE:
                # The pending message is full: send it.
                # If tsig context exist, sign every N packets
                if msgs_since_signed == TSIG_SIGN_EVERY_NTH:
                    msgs_since_signed = 0
                    self._send_message(sock_fd, msg, self._tsig_ctx)
                else:
                    self._send_message(sock_fd, msg)
                msgs_since_signed += 1

                msg = self._clear_message(msg)
                # Reserve tsig space for signed packet
                if msgs_since_signed == TSIG_SIGN_EVERY_NTH:
                    upper_len = self._tsig_len
                else:
                    upper_len = 0

            # Add the rrset to the (possibly new) pending message
            msg.add_rrset(Message.SECTION_ANSWER, rrset)
            upper_len += rrset_len

        self._send_message_with_last_soa(msg, sock_fd, rrset_soa, upper_len,
                                         msgs_since_signed)
408

409
class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
410
411
    '''The unix domain socket server which accept xfr query sent from auth server.'''

412
    def __init__(self, sock_file, handle_class, shutdown_event, config_data, cc):
        '''Set up the server listening on the unix socket file sock_file.

        A stale socket file is removed first (see _remove_unused_sock_file).
        config_data is applied through update_config_data(), and cc is kept
        for later lookups of remote configuration values.
        '''
        self._remove_unused_sock_file(sock_file)
        self._sock_file = sock_file
        socketserver_mixin.NoPollMixIn.__init__(self)
        ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
        self._shutdown_event = shutdown_event
        # Socket pair used by shutdown() to wake up threads blocked in
        # select() inside handle_request().
        wakeup_pair = socket.socketpair()
        self._write_sock, self._read_sock = wakeup_pair
        self._common_init()
        self.update_config_data(config_data)
        self._cc = cc
422

Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
423
424
425
    def _common_init(self):
        '''Initialize the transfer counter, its lock, and default settings.'''
        self._transfers_counter = 0
        self._lock = threading.Lock()
        # These default values will probably get overwritten by the (same)
        # default value from the spec file. These are here just to make
        # sure and to make the default values in tests consistent.
        self._zone_config = {}
        self._acl = REQUEST_LOADER.load('[{"action": "ACCEPT"}]')
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
431

432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
    def _receive_query_message(self, sock):
        ''' receive request message from sock'''
        # receive data length
        data_len = sock.recv(2)
        if not data_len:
            return None
        msg_len = struct.unpack('!H', data_len)[0]
        # receive data
        recv_size = 0
        msgdata = b''
        while recv_size < msg_len:
            data = sock.recv(msg_len - recv_size)
            if not data:
                return None
            recv_size += len(data)
            msgdata += data

        return msgdata

451
452
453
454
455
    def handle_request(self):
        ''' Enable server handle a request until shutdown or auth is closed.

        Accepts one connection, then repeatedly waits for it (or the
        internal wake-up socket) to become readable and processes the
        pending request.  The loop ends when the shutdown event is set, on
        a select error other than EINTR, or when processing a request
        raises an exception.
        '''
        try:
            request, client_address = self.get_request()
        except socket.error:
            logger.error(XFROUT_FETCH_REQUEST_ERROR)
            return

        # Check self._shutdown_event to ensure the real shutdown comes.
        # Linux could trigger a spurious readable event on the _read_sock
        # due to a bug, so we need perform a double check.
        while not self._shutdown_event.is_set(): # Check if xfrout is shutdown
            try:
                (rlist, wlist, xlist) = select.select([self._read_sock, request], [], [])
            except select.error as e:
                if e.args[0] == errno.EINTR:
                    # Interrupted by a signal: just wait again.
                    continue
                logger.error(XFROUT_SOCKET_SELECT_ERROR, str(e))
                break

            # self.server._shutdown_event will be set by now, if it is not a false
            # alarm
            if self._read_sock in rlist:
                continue

            try:
                self.process_request(request)
            except Exception as pre:
                # The logger object of this module is 'logger'; the name
                # 'log' used here before is not defined and raised NameError.
                logger.error(XFROUT_PROCESS_REQUEST_ERROR, str(pre))
                break

484
    def _handle_request_noblock(self):
485
486
        """Override the function _handle_request_noblock(), it creates a new
        thread to handle requests for each auth"""
487
488
489
490
        td = threading.Thread(target=self.handle_request)
        td.setDaemon(True)
        td.start()

491
    def process_request(self, request):
        """Receive socket fd and query message from auth, then
        start a new thread to process the request.

        Returns without starting a thread if no descriptor could be
        received or if the query message could not be read.  In the latter
        case the already received descriptor is closed here, because no
        session will be created to take ownership of it.
        """
        sock_fd = recv_fd(request.fileno())
        if sock_fd < 0:
            # This may happen when one xfrout process try to connect to
            # xfrout unix socket server, to check whether there is another
            # xfrout running.
            if sock_fd == FD_COMM_ERROR:
                logger.error(XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR)
            return

        # receive request msg
        request_data = self._receive_query_message(request)
        if not request_data:
            # Nobody else knows about sock_fd yet; close it so that it
            # does not leak.
            os.close(sock_fd)
            return

        t = threading.Thread(target = self.finish_request,
                             args = (sock_fd, request_data))
        if self.daemon_threads:
            t.daemon = True
        t.start()

514
515
    def _guess_remote(self, sock_fd):
        """
516
           Guess remote address and port of the socket. The sock_fd must be a
517
518
519
           socket
        """
        # This uses a trick. If the socket is IPv4 in reality and we pretend
520
        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
521
522
523
524
525
526
527
528
529
        # to care about the SOCK_STREAM parameter at all (which it really is,
        # except for testing)
        if socket.has_ipv6:
            sock = socket.fromfd(sock_fd, socket.AF_INET6, socket.SOCK_STREAM)
        else:
            # To make it work even on hosts without IPv6 support
            # (Any idea how to simulate this in test?)
            sock = socket.fromfd(sock_fd, socket.AF_INET, socket.SOCK_STREAM)
        return sock.getpeername()
530

531
    def finish_request(self, sock_fd, request_data):
Jerry's avatar
Jerry committed
532
        '''Finish one request by instantiating RequestHandlerClass.'''
533
534
        self.RequestHandlerClass(sock_fd, request_data, self,
                                 self.tsig_key_ring,
535
                                 self._guess_remote(sock_fd), self._acl)
536
537

    def _remove_unused_sock_file(self, sock_file):
538
539
        '''Try to remove the socket file. If the file is being used
        by one running xfrout process, exit from python.
540
541
542
        If it's not a socket file or nobody is listening
        , it will be removed. If it can't be removed, exit from python. '''
        if self._sock_file_in_use(sock_file):
543
            logger.error(XFROUT_UNIX_SOCKET_FILE_IN_USE, sock_file)
544
545
546
547
548
549
550
551
            sys.exit(0)
        else:
            if not os.path.exists(sock_file):
                return

            try:
                os.unlink(sock_file)
            except OSError as err:
552
                logger.error(XFROUT_REMOVE_OLD_UNIX_SOCKET_FILE_ERROR, sock_file, str(err))
553
                sys.exit(0)
554

555
    def _sock_file_in_use(self, sock_file):
556
557
        '''Check whether the socket file 'sock_file' exists and
        is being used by one running xfrout process. If it is,
558
559
560
561
562
563
564
        return True, or else return False. '''
        try:
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(sock_file)
        except socket.error as err:
            return False
        else:
565
            return True
566

567
    def shutdown(self):
        '''Stop the server and remove its unix socket file.

        A failure to remove the socket file is logged and otherwise ignored.
        '''
        # Wake up the threads waiting in select() on the read end of the
        # socket pair, to terminate the xfrout session thread.
        self._write_sock.send(b"shutdown")
        # call the shutdown() of class socketserver_mixin.NoPollMixIn
        super().shutdown()
        try:
            os.unlink(self._sock_file)
        except Exception as unlink_error:
            logger.error(XFROUT_REMOVE_UNIX_SOCKET_FILE_ERROR, self._sock_file, str(unlink_error))
575
576

    def update_config_data(self, new_config):
        '''Apply the new config setting of xfrout module.

        'query_acl' and 'zone_config' are applied only if present in
        new_config; 'transfers_out' and 'tsig_key_ring' are always read
        with get(), so they become None / an empty key ring when absent.

        Note: this method does not provide strong exception guarantee;
        if an exception is raised in the middle of parsing and building the
        given config data, the incomplete set of new configuration will
        remain.  This should be fixed.
        '''
        logger.info(XFROUT_NEW_CONFIG)
        if 'query_acl' in new_config:
            self._acl = REQUEST_LOADER.load(new_config['query_acl'])
        if 'zone_config' in new_config:
            self._zone_config = \
                self.__create_zone_config(new_config.get('zone_config'))
        # Hold the lock through a context manager so that it is released
        # even if building the key ring raises; otherwise every later
        # user of the lock would block forever.
        with self._lock:
            self._max_transfers_out = new_config.get('transfers_out')
            self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
        logger.info(XFROUT_NEW_CONFIG_DONE)
595

596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
    def __create_zone_config(self, zone_config_list):
        '''Build the per zone configuration table from a list of dicts.

        Each entry must have 'origin' and may have 'class' (DEFAULT_RRCLASS
        is used otherwise) and 'transfer_acl'.  The result maps
        (class text, origin text) to a dict which holds the loaded ACL
        under 'transfer_acl' when one was given.

        Raises ValueError if the same class/origin pair appears twice.
        Whatever the RRClass, Name and ACL loader raise for invalid input
        is propagated.
        '''
        new_config = {}
        for zconf in zone_config_list:
            # convert the class, origin (name) pair.  First build pydnspp
            # object to reject invalid input.
            if 'class' in zconf:
                zclass = RRClass(zconf['class'])
            else:
                zclass = DEFAULT_RRCLASS
            zorigin = Name(zconf['origin'], True)
            config_key = (zclass.to_text(), zorigin.to_text())

            # reject duplicate config
            if config_key in new_config:
                raise ValueError('Duplicate zone_config for ' +
                                 str(zorigin) + '/' + str(zclass))

            # create a new config entry, build any given (and known) config
            new_config[config_key] = {}
            if 'transfer_acl' in zconf:
                new_config[config_key]['transfer_acl'] = \
                    REQUEST_LOADER.load(zconf['transfer_acl'])
        return new_config

620
    def set_tsig_key_ring(self, key_list):
        """Replace tsig_key_ring with a new ring built from key_list.

        key_list is a list of TSIG key strings; a false value (None or an
        empty list) results in an empty key ring.  A key string rejected
        with InvalidParameter is logged and skipped.
        """
        # XXX add values to configure zones/tsig options
        self.tsig_key_ring = TSIGKeyRing()
        for key_item in key_list or ():
            try:
                self.tsig_key_ring.add(TSIGKey(key_item))
            except InvalidParameter:
                logger.error(XFROUT_BAD_TSIG_KEY_STRING, str(key_item))
634

635
    def get_db_file(self):
636
637
638
639
640
641
        file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
        # this too should be unnecessary, but currently the
        # 'from build' override isn't stored in the config
        # (and we don't have indirect python access to datasources yet)
        if is_default and "B10_FROM_BUILD" in os.environ:
            file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
642
643
        return file

644

645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
    def increase_transfers_counter(self):
        '''Return False, if counter + 1 > max_transfers_out, or else
        return True
        '''
        ret = False
        self._lock.acquire()
        if self._transfers_counter < self._max_transfers_out:
            self._transfers_counter += 1
            ret = True
        self._lock.release()
        return ret

    def decrease_transfers_counter(self):
        self._lock.acquire()
        self._transfers_counter -= 1
        self._lock.release()

class XfroutServer:
    def __init__(self):
        '''Create the config/command session, then start the listener for
        transfer requests and the notifier.'''
        self._unix_socket_server = None
        self._listen_sock_file = UNIX_SOCKET_FILE
        self._shutdown_event = threading.Event()

        # Session with the config manager; our own configuration is read
        # before the session is started.
        session = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                             self.config_handler,
                                             self.command_handler)
        self._cc = session
        self._config_data = session.get_full_config()
        session.start()
        session.add_remote_config(AUTH_SPECFILE_LOCATION)

        self._start_xfr_query_listener()
        self._start_notifier()
673

674
675
    def _start_xfr_query_listener(self):
        '''Create the unix socket server and run its serve_forever() in a
        new thread to accept xfr query.'''
        server_args = (self._listen_sock_file, XfroutSession,
                       self._shutdown_event, self._config_data, self._cc)
        self._unix_socket_server = UnixSockServer(*server_args)
        serve = self._unix_socket_server.serve_forever
        threading.Thread(target=serve).start()
681

682
683
    def _start_notifier(self):
        '''Create the NotifyOut object for the current database file and
        call its dispatcher().'''
        db_file = self._unix_socket_server.get_db_file()
        notifier = notify_out.NotifyOut(db_file)
        self._notifier = notifier
        notifier.dispatcher()
686

687
688
    def send_notify(self, zone_name, zone_class):
        self._notifier.send_notify(zone_name, zone_class)
689
690
691
692
693
694
695
696
697

    def config_handler(self, new_config):
        '''Update config data. TODO. Do error check

        Known keys of new_config are copied into the stored configuration;
        an unknown key results in an error answer (the remaining keys are
        still processed).  The resulting configuration is then applied to
        the unix socket server, if there is one; a failure to apply it also
        results in an error answer.
        '''
        answer = create_answer(0)
        for key in new_config:
            if key in self._config_data:
                self._config_data[key] = new_config[key]
            else:
                answer = create_answer(1, "Unknown config data: " + str(key))

        if self._unix_socket_server:
            try:
                self._unix_socket_server.update_config_data(self._config_data)
            except Exception as update_error:
                answer = create_answer(1,
                                       "Failed to handle new configuration: " +
                                       str(update_error))

        return answer


    def shutdown(self):
711
        ''' shutdown the xfrout process. The thread which is doing zone transfer-out should be
712
713
        terminated.
        '''
714
715
716

        global xfrout_server
        xfrout_server = None #Avoid shutdown is called twice
717
        self._shutdown_event.set()
718
        self._notifier.shutdown()
719
720
        if self._unix_socket_server:
            self._unix_socket_server.shutdown()
721

722
        # Wait for all threads to terminate
723
724
725
726
727
728
729
730
        main_thread = threading.currentThread()
        for th in threading.enumerate():
            if th is main_thread:
                continue
            th.join()

    def command_handler(self, cmd, args):
        '''Handle a command and return the answer for it.

        "shutdown" shuts this process down.  The notify_out
        ZONE_NEW_DATA_READY_CMD command sends notify for the zone given by
        the 'zone_name' and 'zone_class' items of args.  Any other command,
        or missing zone parameters, gives an error answer.
        '''
        if cmd == "shutdown":
            logger.info(XFROUT_RECEIVED_SHUTDOWN_COMMAND)
            self.shutdown()
            return create_answer(0)

        if cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
            zone_name = args.get('zone_name')
            zone_class = args.get('zone_class')
            if not (zone_name and zone_class):
                return create_answer(1, "Bad command parameter:" + str(args))
            logger.info(XFROUT_NOTIFY_COMMAND, zone_name, zone_class)
            self.send_notify(zone_name, zone_class)
            return create_answer(0)

        return create_answer(1, "Unknown command:" + str(cmd))
749
750
751
752

    def run(self):
        '''Get and process all commands sent from cfgmgr or other modules. '''
        while not self._shutdown_event.is_set():
753
            self._cc.check_command(False)
754
755
756
757
758


xfrout_server = None

def signal_handler(signal, frame):
    """Shut the running server down and exit with status 0.

    Does nothing if there is no server (the xfrout_server global is false).
    """
    if not xfrout_server:
        return
    xfrout_server.shutdown()
    sys.exit(0)

def set_signal_handler():
    """Install signal_handler for SIGTERM and SIGINT."""
    for signum in (signal.SIGTERM, signal.SIGINT):
        signal.signal(signum, signal_handler)

def set_cmd_options(parser):
    """Register the -v/--verbose flag on the given option parser."""
    verbose_help = "display more about what is going on"
    parser.add_option("-v", "--verbose", action="store_true",
                      dest="verbose", help=verbose_help)

# Entry point: parse the command line, install the signal handlers, then
# create the server and process commands until it is shut down.
if '__main__' == __name__:
    try:
        parser = OptionParser()
        set_cmd_options(parser)
        (options, args) = parser.parse_args()
        VERBOSE_MODE = options.verbose

        set_signal_handler()
        xfrout_server = XfroutServer()
        xfrout_server.run()
    except KeyboardInterrupt:
        # The logging method is info(); the previous spelling 'INFO' is
        # not a method used anywhere else on the logger in this file.
        logger.info(XFROUT_STOPPED_BY_KEYBOARD)
    except SessionError as e:
        logger.error(XFROUT_CC_SESSION_ERROR, str(e))
    except SessionTimeout as e:
        logger.error(XFROUT_CC_SESSION_TIMEOUT_ERROR)

    # xfrout_server is reset to None by XfroutServer.shutdown(), so this
    # only runs if the server has not been shut down yet.
    if xfrout_server:
        xfrout_server.shutdown()