xfrout.py.in 37.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#!@PYTHON@

# Copyright (C) 2010  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


import sys; sys.path.append ('@@PYTHONPATH@@')
import isc
import isc.cc
import threading
import struct
import signal
25
from isc.datasrc import DataSourceClient, ZoneFinder, ZoneJournalReader
26
27
28
from socketserver import *
import os
from isc.config.ccsession import *
29
from isc.cc import SessionError, SessionTimeout
30
from isc.notify import notify_out
31
import isc.util.process
32
import socket
33
import select
34
import errno
35
from optparse import OptionParser, OptionValueError
Likun Zhang's avatar
Likun Zhang committed
36
from isc.util import socketserver_mixin
37

38
from isc.log_messages.xfrout_messages import *
39
40
41
42

isc.log.init("b10-xfrout")
logger = isc.log.Logger("xfrout")

43
try:
44
    from libutil_io_python import *
Jelte Jansen's avatar
Jelte Jansen committed
45
    from pydnspp import *
46
47
48
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrout process
    # must keep running, so we warn about it and move forward.
49
    log.error(XFROUT_IMPORT, str(e))
Michal Vaner's avatar
Michal Vaner committed
50

51
from isc.acl.acl import ACCEPT, REJECT, DROP, LoaderError
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
52
53
from isc.acl.dns import REQUEST_LOADER

54
isc.util.process.rename()
55

56
57
58
59
60
61
62
63
64
65
66
67
class XfroutConfigError(Exception):
    """An exception indicating an error in updating xfrout configuration.

    This exception is raised when the xfrout process encouters an error in
    handling configuration updates.  Not all syntax error can be caught
    at the module-CC layer, so xfrout needs to (explicitly or implicitly)
    validate the given configuration data itself.  When it finds an error
    it raises this exception (either directly or by converting an exception
    from other modules) as a unified error in configuration.
    """
    pass

68
69
70
71
72
73
74
75
76
77
78
79
def init_paths():
    global SPECFILE_PATH
    global AUTH_SPECFILE_PATH
    global UNIX_SOCKET_FILE
    if "B10_FROM_BUILD" in os.environ:
        SPECFILE_PATH = os.environ["B10_FROM_BUILD"] + "/src/bin/xfrout"
        AUTH_SPECFILE_PATH = os.environ["B10_FROM_BUILD"] + "/src/bin/auth"
        if "B10_FROM_SOURCE_LOCALSTATEDIR" in os.environ:
            UNIX_SOCKET_FILE = os.environ["B10_FROM_SOURCE_LOCALSTATEDIR"] + \
                "/auth_xfrout_conn"
        else:
            UNIX_SOCKET_FILE = os.environ["B10_FROM_BUILD"] + "/auth_xfrout_conn"
80
    else:
81
82
83
84
85
86
87
        PREFIX = "@prefix@"
        DATAROOTDIR = "@datarootdir@"
        SPECFILE_PATH = "@datadir@/@PACKAGE@".replace("${datarootdir}", DATAROOTDIR).replace("${prefix}", PREFIX)
        AUTH_SPECFILE_PATH = SPECFILE_PATH
        if "BIND10_XFROUT_SOCKET_FILE" in os.environ:
            UNIX_SOCKET_FILE = os.environ["BIND10_XFROUT_SOCKET_FILE"]
        else:
88
            UNIX_SOCKET_FILE = "@@LOCALSTATEDIR@@/@PACKAGE_NAME@/auth_xfrout_conn"
89
90

init_paths()
91

92
SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
93
AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
Jerry's avatar
Jerry committed
94
VERBOSE_MODE = False
95
96
XFROUT_MAX_MESSAGE_SIZE = 65535

97
98
99
100
101
102
103
104
105
106
# borrowed from xfrin.py @ #1298.  We should eventually unify it.
def format_zone_str(zone_name, zone_class):
    """Helper function to format a zone name and class as a string of
       the form '<name>/<class>'.
       Parameters:
       zone_name (isc.dns.Name) name to format
       zone_class (isc.dns.RRClass) class to format
    """
    return zone_name.to_text() + '/' + str(zone_class)

107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# borrowed from xfrin.py @ #1298.
def format_addrinfo(addrinfo):
    """Helper function to format the addrinfo as a string of the form
       <addr>:<port> (for IPv4) or [<addr>]:port (for IPv6). For unix domain
       sockets, and unknown address families, it returns a basic string
       conversion of the third element of the passed tuple.
       Parameters:
       addrinfo: a 3-tuple consisting of address family, socket type, and,
                 depending on the family, either a 2-tuple with the address
                 and port, or a filename
    """
    try:
        if addrinfo[0] == socket.AF_INET:
            return str(addrinfo[2][0]) + ":" + str(addrinfo[2][1])
        elif addrinfo[0] == socket.AF_INET6:
            return "[" + str(addrinfo[2][0]) + "]:" + str(addrinfo[2][1])
        else:
            return str(addrinfo[2])
    except IndexError:
        raise TypeError("addrinfo argument to format_addrinfo() does not "
                        "appear to be consisting of (family, socktype, (addr, port))")

129
130
131
132
133
134
def get_rrset_len(rrset):
    """Returns the wire length of the given RRset"""
    bytes = bytearray()
    rrset.to_wire(bytes)
    return len(bytes)

135
136
137
138
139
def get_soa_serial(soa_rdata):
    '''Extract the serial field of an SOA RDATA and returns it as an intger.
    (borrowed from xfrin)
    '''
    return int(soa_rdata.to_text().split()[2])
140

141
class XfroutSession():
142
    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
143
                 default_acl, zone_config, client_class=DataSourceClient):
144
145
146
        self._sock_fd = sock_fd
        self._request_data = request_data
        self._server = server
147
148
149
        self._tsig_key_ring = tsig_key_ring
        self._tsig_ctx = None
        self._tsig_len = 0
150
        self._remote = remote
151
152
        self._request_type = None
        self._request_typestr = None
153
154
        self._acl = default_acl
        self._zone_config = zone_config
155
        self.ClientClass = client_class # parameterize this for testing
156
        self._soa = None # will be set in _xfrout_setup or in tests
157
        self._handle()
Jerry's avatar
Jerry committed
158

159
160
161
162
    def create_tsig_ctx(self, tsig_record, tsig_key_ring):
        return TSIGContext(tsig_record.get_name(), tsig_record.get_rdata().get_algorithm(),
                           tsig_key_ring)

163
164
165
166
167
168
169
170
171
172
173
    def _handle(self):
        ''' Handle a xfrout query, send xfrout response(s).

        This is separated from the constructor so that we can override
        it from tests.

        '''
        # Check the xfrout quota.  We do both increase/decrease in this
        # method so it's clear we always release it once acuired.
        quota_ok = self._server.increase_transfers_counter()
        ex = None
174
        try:
175
            self.dns_xfrout_start(self._sock_fd, self._request_data, quota_ok)
176
        except Exception as e:
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
            # To avoid resource leak we need catch all possible exceptions
            # We log it later to exclude the case where even logger raises
            # an exception.
            ex = e

        # Release any critical resources
        if quota_ok:
            self._server.decrease_transfers_counter()
        self._close_socket()

        if ex is not None:
            logger.error(XFROUT_HANDLE_QUERY_ERROR, ex)

    def _close_socket(self):
        '''Simply close the socket via the given FD.
192

193
194
195
196
        This is a dedicated subroutine of handle() and is sepsarated from it
        for the convenience of tests.

        '''
197
        os.close(self._sock_fd)
198

199
200
201
202
203
204
205
206
207
208
209
210
    def _check_request_tsig(self, msg, request_data):
        ''' If request has a tsig record, perform tsig related checks '''
        tsig_record = msg.get_tsig_record()
        if tsig_record is not None:
            self._tsig_len = tsig_record.get_length()
            self._tsig_ctx = self.create_tsig_ctx(tsig_record, self._tsig_key_ring)
            tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
            if tsig_error != TSIGError.NOERROR:
                return Rcode.NOTAUTH()

        return Rcode.NOERROR()

211
212
    def _parse_query_message(self, mdata):
        ''' parse query message to [socket,message]'''
213
        #TODO, need to add parseHeader() in case the message header is invalid
214
        try:
215
            msg = Message(Message.PARSE)
216
            Message.from_wire(msg, mdata)
217
        except Exception as err: # Exception is too broad
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
218
            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
219
            return Rcode.FORMERR(), None
220

221
222
        # TSIG related checks
        rcode = self._check_request_tsig(msg, mdata)
223
224
225
226
227
228
229
230
        if rcode != Rcode.NOERROR():
            return rcode, msg

        # Make sure the question is valid.  This should be ensured by
        # the auth server, but since it's far from our xfrout itself,
        # we check it by ourselves.
        if msg.get_rr_count(Message.SECTION_QUESTION) != 1:
            return Rcode.FORMERR(), msg
231
232
233
234
235
236
237
238
239
        question = msg.get_question()[0]

        # Identify the request type
        self._request_type = question.get_type()
        if self._request_type == RRType.AXFR():
            self._request_typestr = 'AXFR'
        elif self._request_type == RRType.IXFR():
            self._request_typestr = 'IXFR'
        else:
240
241
            # Likewise, this should be impossible.  (TBD: to be tested)
            raise RuntimeError('Unexpected XFR type: ' + \
242
                                   str(self._request_type))
243
244

        # ACL checks
245
246
        zone_name = question.get_name()
        zone_class = question.get_class()
247
248
        acl = self._get_transfer_acl(zone_name, zone_class)
        acl_result = acl.execute(
249
            isc.acl.dns.RequestContext(self._remote[2], msg.get_tsig_record()))
250
        if acl_result == DROP:
251
            logger.info(XFROUT_QUERY_DROPPED, self._request_typestr,
252
253
                        format_addrinfo(self._remote),
                        format_zone_str(zone_name, zone_class))
254
255
            return None, None
        elif acl_result == REJECT:
256
            logger.info(XFROUT_QUERY_REJECTED, self._request_typestr,
257
258
                        format_addrinfo(self._remote),
                        format_zone_str(zone_name, zone_class))
259
            return Rcode.REFUSED(), msg
260

261
        return rcode, msg
262

263
264
265
266
267
268
269
270
271
272
273
274
    def _get_transfer_acl(self, zone_name, zone_class):
        '''Return the ACL that should be applied for a given zone.

        The zone is identified by a tuple of name and RR class.
        If a per zone configuration for the zone exists and contains
        transfer_acl, that ACL will be used; otherwise, the default
        ACL will be used.

        '''
        # Internally zone names are managed in lower cased label characters,
        # so we first need to convert the name.
        zone_name_lower = Name(zone_name.to_text(), True)
275
        config_key = (zone_class.to_text(), zone_name_lower.to_text())
276
277
278
279
280
        if config_key in self._zone_config and \
                'transfer_acl' in self._zone_config[config_key]:
            return self._zone_config[config_key]['transfer_acl']
        return self._acl

281
    def _send_data(self, sock_fd, data):
282
283
284
        size = len(data)
        total_count = 0
        while total_count < size:
285
            count = os.write(sock_fd, data[total_count:])
286
287
288
            total_count += count


289
    def _send_message(self, sock_fd, msg, tsig_ctx=None):
290
        render = MessageRenderer()
291
292
        # As defined in RFC5936 section3.4, perform case-preserving name
        # compression for AXFR message.
293
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
294
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
295
296
297
298
299
300
301
302

        # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
        # we should remove the if statement and use a universal interface later.
        if tsig_ctx is not None:
            msg.to_wire(render, tsig_ctx)
        else:
            msg.to_wire(render)

303
        header_len = struct.pack('H', socket.htons(render.get_length()))
304
305
        self._send_data(sock_fd, header_len)
        self._send_data(sock_fd, render.get_data())
306
307


308
    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
309
        if not msg:
310
            return # query message is invalid. send nothing back.
311
312

        msg.make_response()
313
        msg.set_rcode(rcode_)
314
        self._send_message(sock_fd, msg, self._tsig_ctx)
315

316
    def _get_zone_soa(self, zone_name):
317
318
319
320
321
322
323
324
        '''Retrieve the SOA RR of the given zone.

        It returns a pair of RCODE and the SOA (in the form of RRset).
        On success RCODE is NOERROR and returned SOA is not None;
        on failure RCODE indicates the appropriate code in the context of
        xfr processing, and the returned SOA is None.

        '''
325
326
        result, finder = self._datasrc_client.find_zone(zone_name)
        if result != DataSourceClient.SUCCESS:
327
328
329
            return (Rcode.NOTAUTH(), None)
        result, soa_rrset = finder.find(zone_name, RRType.SOA(), None,
                                        ZoneFinder.FIND_DEFAULT)
330
        if result != ZoneFinder.SUCCESS:
331
            return (Rcode.SERVFAIL(), None)
332
333
334
335
        # Especially for database-based zones, a working zone may be in
        # a broken state where it has more than one SOA RR.  We proactively
        # check the condition and abort the xfr attempt if we identify it.
        if soa_rrset.get_rdata_count() != 1:
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
            return (Rcode.SERVFAIL(), None)
        return (Rcode.NOERROR(), soa_rrset)

    def __setup_axfr(self, zone_name):
        '''Setup a zone iterator for AXFR or AXFR-style IXFR.

        '''
        try:
            # Note that we disable 'adjust_ttl'.  In xfr-out we need to
            # preserve as many things as possible (even if it's half
            # broken) stored in the zone.
            self._iterator = self._datasrc_client.get_iterator(zone_name,
                                                               False)
        except isc.datasrc.Error:
            # If the current name server does not have authority for the
            # zone, xfrout can't serve for it, return rcode NOTAUTH.
            # Note: this exception can happen for other reasons.  We should
            # update get_iterator() API so that we can distinguish "no such
            # zone" and other cases (#1373).  For now we consider all these
            # cases as NOTAUTH.
            return Rcode.NOTAUTH()

        # If we are an authoritative name server for the zone, but fail
        # to find the zone's SOA record in datasource, xfrout can't
        # provide zone transfer for it.
        self._soa = self._iterator.get_soa()
        if self._soa is None or self._soa.get_rdata_count() != 1:
            return Rcode.SERVFAIL()

        return Rcode.NOERROR()

    def __setup_ixfr(self, request_msg, zone_name):
        '''Setup a zone journal reader for IXFR.

        If the underlying data source does not know the requested range
        of zone differences it automatically falls back to AXFR-style
        IXFR by setting up a zone iterator instead of a journal reader.

        '''
        # TODO: more error case handling
        remote_soa = None
        for auth_rrset in request_msg.get_section(Message.SECTION_AUTHORITY):
            if auth_rrset.get_type() != RRType.SOA():
                continue
            remote_soa = auth_rrset
        rcode, self._soa = self._get_zone_soa(zone_name)
        if rcode != Rcode.NOERROR():
            return rcode
        code, self._jnl_reader = self._datasrc_client.get_journal_reader(
            remote_soa.get_name(), get_soa_serial(remote_soa.get_rdata()[0]),
            get_soa_serial(self._soa.get_rdata()[0]))
        if code == ZoneJournalReader.NO_SUCH_VERSION:
            # fallback to AXFR-style IXFR
            self._jnl_reader = None # clear it just in case
            return self.__setup_axfr(zone_name)

        return Rcode.NOERROR()
393

394
395
396
397
398
399
400
    def _xfrout_setup(self, request_msg, zone_name):
        '''Setup a context for xfr responses according to the request type.

        This method identifies the most appropriate data source for the
        request and set up a zone iterator or journal reader depending on
        whether the request is AXFR or IXFR.  If it identifies any protocol
        level error it returns an RCODE other than NOERROR.
401

402
        '''
403

404
405
        # Identify the data source for the requested zone and see if it has
        # SOA while initializing objects used for request processing later.
JINMEI Tatuya's avatar
JINMEI Tatuya committed
406
407
408
        # We should eventually generalize this so that we can choose the
        # appropriate data source from (possible) multiple candidates.
        # We should eventually take into account the RR class here.
409
        # For now, we hardcode a particular type (SQLite3-based), and only
410
        # consider that one.
411
412
        datasrc_config = '{ "database_file": "' + \
            self._server.get_db_file() + '"}'
413
        self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)
414
415

        if self._request_type == RRType.AXFR():
416
            return self.__setup_axfr(zone_name)
417
        else:
418
            return self.__setup_ixfr(request_msg, zone_name)
419

420
    def dns_xfrout_start(self, sock_fd, msg_query, quota_ok=True):
421
422
        rcode_, msg = self._parse_query_message(msg_query)
        #TODO. create query message and parse header
423
424
425
        if rcode_ is None: # Dropped by ACL
            return
        elif rcode_ == Rcode.NOTAUTH() or rcode_ == Rcode.REFUSED():
426
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
427
        elif rcode_ != Rcode.NOERROR():
428
429
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.FORMERR())
430
        elif not quota_ok:
431
            logger.warn(XFROUT_QUERY_QUOTA_EXCCEEDED, self._request_typestr,
432
433
                        format_addrinfo(self._remote),
                        self._server._max_transfers_out)
434
435
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.REFUSED())
436

437
438
439
440
        question = msg.get_question()[0]
        zone_name = question.get_name()
        zone_class = question.get_class()
        zone_str = format_zone_str(zone_name, zone_class) # for logging
441

442
        # TODO: we should also include class in the check
443
        try:
444
            rcode_ = self._xfrout_setup(msg, zone_name)
445
        except Exception as ex:
446
            logger.error(XFROUT_XFR_TRANSFER_CHECK_ERROR, self._request_typestr,
447
448
                         format_addrinfo(self._remote), zone_str, ex)
            rcode_ = Rcode.SERVFAIL()
449
        if rcode_ != Rcode.NOERROR():
450
            logger.info(XFROUT_AXFR_TRANSFER_FAILED, self._request_typestr,
451
                        format_addrinfo(self._remote), zone_str, rcode_)
452
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
453
454

        try:
455
            logger.info(XFROUT_AXFR_TRANSFER_STARTED, self._request_typestr,
456
                        format_addrinfo(self._remote), zone_str)
457
            self._reply_xfrout_query(msg, sock_fd)
458
        except Exception as err:
459
            logger.error(XFROUT_AXFR_TRANSFER_ERROR, self._request_typestr,
460
                    format_addrinfo(self._remote), zone_str, err)
461
            pass
462
        logger.info(XFROUT_AXFR_TRANSFER_DONE, self._request_typestr,
463
                    format_addrinfo(self._remote), zone_str)
464
465
466
467
468

    def _clear_message(self, msg):
        qid = msg.get_qid()
        opcode = msg.get_opcode()
        rcode = msg.get_rcode()
469

470
        msg.clear(Message.RENDER)
471
472
473
        msg.set_qid(qid)
        msg.set_opcode(opcode)
        msg.set_rcode(rcode)
474
475
        msg.set_header_flag(Message.HEADERFLAG_AA)
        msg.set_header_flag(Message.HEADERFLAG_QR)
476
477
        return msg

478
479
    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa,
                                    message_upper_len):
480
481
482
        '''Add the SOA record to the end of message. If it can't be
        added, a new message should be created to send out the last soa .
        '''
483
484
        if (message_upper_len + self._tsig_len + get_rrset_len(rrset_soa) >=
            XFROUT_MAX_MESSAGE_SIZE):
485
486
            self._send_message(sock_fd, msg, self._tsig_ctx)
            msg = self._clear_message(msg)
487

488
489
        # If tsig context exist, sign the last packet
        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
490
        self._send_message(sock_fd, msg, self._tsig_ctx)
491

492
    def _reply_xfrout_query(self, msg, sock_fd):
493
494
        #TODO, there should be a better way to insert rrset.
        msg.make_response()
495
        msg.set_header_flag(Message.HEADERFLAG_AA)
496
        msg.add_rrset(Message.SECTION_ANSWER, self._soa)
497

498
        message_upper_len = get_rrset_len(self._soa) + self._tsig_len
499

500
501
502
        for rrset in self._iterator:
            # Check if xfrout is shutdown
            if  self._server._shutdown_event.is_set():
503
                logger.info(XFROUT_STOPPING)
504
                return
Jelte Jansen's avatar
Jelte Jansen committed
505

506
507
            if rrset.get_type() == RRType.SOA():
                continue
508
509
510
511

            # We calculate the maximum size of the RRset (i.e. the
            # size without compression) and use that to see if we
            # may have reached the limit
512
            rrset_len = get_rrset_len(rrset)
513
            if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
514
                msg.add_rrset(Message.SECTION_ANSWER, rrset)
515
                message_upper_len += rrset_len
516
517
                continue

518
            self._send_message(sock_fd, msg, self._tsig_ctx)
519

520
            msg = self._clear_message(msg)
521
            # Add the RRset to the new message
522
            msg.add_rrset(Message.SECTION_ANSWER, rrset)
523
524

            # Reserve tsig space for signed packet
525
            message_upper_len = rrset_len + self._tsig_len
526

527
        self._send_message_with_last_soa(msg, sock_fd, self._soa,
528
                                         message_upper_len)
529

530
531
class UnixSockServer(socketserver_mixin.NoPollMixIn,
                     ThreadingUnixStreamServer):
532
533
    '''The unix domain socket server which accept xfr query sent from auth server.'''

534
535
    def __init__(self, sock_file, handle_class, shutdown_event, config_data,
                 cc):
536
        self._remove_unused_sock_file(sock_file)
537
        self._sock_file = sock_file
538
        socketserver_mixin.NoPollMixIn.__init__(self)
539
540
        ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
        self._shutdown_event = shutdown_event
541
        self._write_sock, self._read_sock = socket.socketpair()
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
542
        self._common_init()
543
        self._cc = cc
544
        self.update_config_data(config_data)
545

Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
546
    def _common_init(self):
547
        '''Initialization shared with the mock server class used for tests'''
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
548
549
        self._lock = threading.Lock()
        self._transfers_counter = 0
550
551
        self._zone_config = {}
        self._acl = None # this will be initialized in update_config_data()
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
552

553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
    def _receive_query_message(self, sock):
        ''' receive request message from sock'''
        # receive data length
        data_len = sock.recv(2)
        if not data_len:
            return None
        msg_len = struct.unpack('!H', data_len)[0]
        # receive data
        recv_size = 0
        msgdata = b''
        while recv_size < msg_len:
            data = sock.recv(msg_len - recv_size)
            if not data:
                return None
            recv_size += len(data)
            msgdata += data

        return msgdata

572
573
574
575
576
    def handle_request(self):
        ''' Enable server handle a request until shutdown or auth is closed.'''
        try:
            request, client_address = self.get_request()
        except socket.error:
577
            logger.error(XFROUT_FETCH_REQUEST_ERROR)
578
579
580
581
582
583
584
585
586
587
588
589
590
            return

        # Check self._shutdown_event to ensure the real shutdown comes.
        # Linux could trigger a spurious readable event on the _read_sock
        # due to a bug, so we need perform a double check.
        while not self._shutdown_event.is_set(): # Check if xfrout is shutdown
            try:
                (rlist, wlist, xlist) = select.select([self._read_sock, request], [], [])
            except select.error as e:
                if e.args[0] == errno.EINTR:
                    (rlist, wlist, xlist) = ([], [], [])
                    continue
                else:
591
                    logger.error(XFROUT_SOCKET_SELECT_ERROR, str(e))
592
593
594
595
596
597
598
599
600
                    break

            # self.server._shutdown_event will be set by now, if it is not a false
            # alarm
            if self._read_sock in rlist:
                continue

            try:
                self.process_request(request)
601
            except Exception as pre:
602
                log.error(XFROUT_PROCESS_REQUEST_ERROR, str(pre))
603
604
                break

605
    def _handle_request_noblock(self):
606
607
        """Override the function _handle_request_noblock(), it creates a new
        thread to handle requests for each auth"""
608
609
610
611
        td = threading.Thread(target=self.handle_request)
        td.setDaemon(True)
        td.start()

612
    def process_request(self, request):
613
614
615
616
617
618
619
        """Receive socket fd and query message from auth, then
        start a new thread to process the request."""
        sock_fd = recv_fd(request.fileno())
        if sock_fd < 0:
            # This may happen when one xfrout process try to connect to
            # xfrout unix socket server, to check whether there is another
            # xfrout running.
620
            if sock_fd == FD_COMM_ERROR:
621
                logger.error(XFROUT_RECEIVE_FILE_DESCRIPTOR_ERROR)
622
623
624
625
626
627
628
            return

        # receive request msg
        request_data = self._receive_query_message(request)
        if not request_data:
            return

629
        t = threading.Thread(target=self.finish_request,
630
                             args = (sock_fd, request_data))
631
632
633
634
        if self.daemon_threads:
            t.daemon = True
        t.start()

635
    def _guess_remote(self, sock_fd):
636
637
638
639
640
641
        """Guess remote address and port of the socket.

        The sock_fd must be a file descriptor of a socket.
        This method retuns a 3-tuple consisting of address family,
        socket type, and a 2-tuple with the address (string) and port (int).

642
643
        """
        # This uses a trick. If the socket is IPv4 in reality and we pretend
644
        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
645
646
647
648
649
650
651
652
        # to care about the SOCK_STREAM parameter at all (which it really is,
        # except for testing)
        if socket.has_ipv6:
            sock = socket.fromfd(sock_fd, socket.AF_INET6, socket.SOCK_STREAM)
        else:
            # To make it work even on hosts without IPv6 support
            # (Any idea how to simulate this in test?)
            sock = socket.fromfd(sock_fd, socket.AF_INET, socket.SOCK_STREAM)
653
654
655
656
657
658
659
660
661
662
        peer = sock.getpeername()

        # Identify the correct socket family.  Due to the above "trick",
        # we cannot simply use sock.family.
        family = socket.AF_INET6
        try:
            socket.inet_pton(socket.AF_INET6, peer[0])
        except socket.error:
            family = socket.AF_INET
        return (family, socket.SOCK_STREAM, peer)
663

664
    def finish_request(self, sock_fd, request_data):
665
666
        '''Finish one request by instantiating RequestHandlerClass.

667
668
669
        This is an entry point of a separate thread spawned in
        UnixSockServer.process_request().

670
671
        This method creates a XfroutSession object.
        '''
672
673
674
675
        self._lock.acquire()
        acl = self._acl
        zone_config = self._zone_config
        self._lock.release()
676
677
        self.RequestHandlerClass(sock_fd, request_data, self,
                                 self.tsig_key_ring,
678
                                 self._guess_remote(sock_fd), acl, zone_config)
679
680

    def _remove_unused_sock_file(self, sock_file):
681
682
        '''Try to remove the socket file. If the file is being used
        by one running xfrout process, exit from python.
683
684
685
        If it's not a socket file or nobody is listening
        , it will be removed. If it can't be removed, exit from python. '''
        if self._sock_file_in_use(sock_file):
686
            logger.error(XFROUT_UNIX_SOCKET_FILE_IN_USE, sock_file)
687
688
689
690
691
692
693
694
            sys.exit(0)
        else:
            if not os.path.exists(sock_file):
                return

            try:
                os.unlink(sock_file)
            except OSError as err:
695
                logger.error(XFROUT_REMOVE_OLD_UNIX_SOCKET_FILE_ERROR, sock_file, str(err))
696
                sys.exit(0)
697

698
    def _sock_file_in_use(self, sock_file):
699
700
        '''Check whether the socket file 'sock_file' exists and
        is being used by one running xfrout process. If it is,
701
702
703
704
705
706
707
        return True, or else return False. '''
        try:
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(sock_file)
        except socket.error as err:
            return False
        else:
708
            return True
709

710
    def shutdown(self):
711
        self._write_sock.send(b"shutdown") #terminate the xfrout session thread
712
        super().shutdown() # call the shutdown() of class socketserver_mixin.NoPollMixIn
713
714
        try:
            os.unlink(self._sock_file)
Jerry's avatar
Jerry committed
715
        except Exception as e:
Jelte Jansen's avatar
Jelte Jansen committed
716
            logger.error(XFROUT_REMOVE_UNIX_SOCKET_FILE_ERROR, self._sock_file, str(e))
717
            pass
718
719

    def update_config_data(self, new_config):
720
721
722
        '''Apply the new config setting of xfrout module.

        '''
723
        self._lock.acquire()
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
        try:
            logger.info(XFROUT_NEW_CONFIG)
            new_acl = self._acl
            if 'transfer_acl' in new_config:
                try:
                    new_acl = REQUEST_LOADER.load(new_config['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl: ' +
                                            str(e))

            new_zone_config = self._zone_config
            zconfig_data = new_config.get('zone_config')
            if zconfig_data is not None:
                new_zone_config = self.__create_zone_config(zconfig_data)

            self._acl = new_acl
            self._zone_config = new_zone_config
            self._max_transfers_out = new_config.get('transfers_out')
            self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
        except Exception as e:
            self._lock.release()
            raise e
746
        self._lock.release()
747
        logger.info(XFROUT_NEW_CONFIG_DONE)
748

749
750
751
752
753
    def __create_zone_config(self, zone_config_list):
        new_config = {}
        for zconf in zone_config_list:
            # convert the class, origin (name) pair.  First build pydnspp
            # object to reject invalid input.
754
755
756
757
758
            zclass_str = zconf.get('class')
            if zclass_str is None:
                #zclass_str = 'IN' # temporary
                zclass_str = self._cc.get_default_value('zone_config/class')
            zclass = RRClass(zclass_str)
759
760
761
762
763
            zorigin = Name(zconf['origin'], True)
            config_key = (zclass.to_text(), zorigin.to_text())

            # reject duplicate config
            if config_key in new_config:
764
                raise XfroutConfigError('Duplicate zone_config for ' +
765
                                        str(zorigin) + '/' + str(zclass))
766
767
768
769

            # create a new config entry, build any given (and known) config
            new_config[config_key] = {}
            if 'transfer_acl' in zconf:
770
771
772
773
774
775
776
                try:
                    new_config[config_key]['transfer_acl'] = \
                        REQUEST_LOADER.load(zconf['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl ' +
                                            'for ' + zorigin.to_text() + '/' +
                                            zclass_str + ': ' + str(e))
777
778
        return new_config

779
    def set_tsig_key_ring(self, key_list):
780
781
782
        """Set the tsig_key_ring , given a TSIG key string list representation. """

        # XXX add values to configure zones/tsig options
783
        self.tsig_key_ring = TSIGKeyRing()
784
        # If key string list is empty, create a empty tsig_key_ring
785
786
787
788
789
790
791
        if not key_list:
            return

        for key_item in key_list:
            try:
                self.tsig_key_ring.add(TSIGKey(key_item))
            except InvalidParameter as ipe:
792
                logger.error(XFROUT_BAD_TSIG_KEY_STRING, str(key_item))
793

794
    def get_db_file(self):
795
796
797
798
799
800
        file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
        # this too should be unnecessary, but currently the
        # 'from build' override isn't stored in the config
        # (and we don't have indirect python access to datasources yet)
        if is_default and "B10_FROM_BUILD" in os.environ:
            file = os.environ["B10_FROM_BUILD"] + os.sep + "bind10_zones.sqlite3"
801
802
        return file

803

804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
    def increase_transfers_counter(self):
        '''Return False, if counter + 1 > max_transfers_out, or else
        return True
        '''
        ret = False
        self._lock.acquire()
        if self._transfers_counter < self._max_transfers_out:
            self._transfers_counter += 1
            ret = True
        self._lock.release()
        return ret

    def decrease_transfers_counter(self):
        self._lock.acquire()
        self._transfers_counter -= 1
        self._lock.release()

class XfroutServer:
    def __init__(self):
823
        self._unix_socket_server = None
824
        self._listen_sock_file = UNIX_SOCKET_FILE
825
        self._shutdown_event = threading.Event()
826
        self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler)
827
828
        self._config_data = self._cc.get_full_config()
        self._cc.start()
829
        self._cc.add_remote_config(AUTH_SPECFILE_LOCATION);
830
        self._start_xfr_query_listener()
831
        self._start_notifier()
832

833
834
    def _start_xfr_query_listener(self):
        '''Start a new thread to accept xfr query. '''
835
836
837
838
        self._unix_socket_server = UnixSockServer(self._listen_sock_file,
                                                  XfroutSession,
                                                  self._shutdown_event,
                                                  self._config_data,
839
                                                  self._cc)
840
        listener = threading.Thread(target=self._unix_socket_server.serve_forever)
841
        listener.start()
842

843
844
    def _start_notifier(self):
        datasrc = self._unix_socket_server.get_db_file()
845
        self._notifier = notify_out.NotifyOut(datasrc)
Michal Vaner's avatar
Michal Vaner committed
846
        self._notifier.dispatcher()
847

848
849
    def send_notify(self, zone_name, zone_class):
        self._notifier.send_notify(zone_name, zone_class)
850
851
852
853
854
855
856
857
858

    def config_handler(self, new_config):
        '''Update config data. TODO. Do error check'''
        answer = create_answer(0)
        for key in new_config:
            if key not in self._config_data:
                answer = create_answer(1, "Unknown config data: " + str(key))
                continue
            self._config_data[key] = new_config[key]
Michal Vaner's avatar
Michal Vaner committed
859

860
        if self._unix_socket_server:
861
862
863
            try:
                self._unix_socket_server.update_config_data(self._config_data)
            except Exception as e:
864
865
866
                answer = create_answer(1,
                                       "Failed to handle new configuration: " +
                                       str(e))
867

868
869
870
871
        return answer


    def shutdown(self):
872
        ''' shutdown the xfrout process. The thread which is doing zone transfer-out should be
873
874
        terminated.
        '''
875
876
877

        global xfrout_server
        xfrout_server = None #Avoid shutdown is called twice
878
        self._shutdown_event.set()
879
        self._notifier.shutdown()
880
881
        if self._unix_socket_server:
            self._unix_socket_server.shutdown()
882

883
        # Wait for all threads to terminate
884
885
886
887
888
889
890
891
        main_thread = threading.currentThread()
        for th in threading.enumerate():
            if th is main_thread:
                continue
            th.join()

    def command_handler(self, cmd, args):
        if cmd == "shutdown":
892
            logger.info(XFROUT_RECEIVED_SHUTDOWN_COMMAND)
893
894
            self.shutdown()
            answer = create_answer(0)
Michal Vaner's avatar
Michal Vaner committed
895

896
        elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
897
            zone_name = args.get('zone_name')
898
899
            zone_class = args.get('zone_class')
            if zone_name and zone_class:
900
                logger.info(XFROUT_NOTIFY_COMMAND, zone_name, zone_class)
901
                self.send_notify(zone_name, zone_class)
902
903
904
905
                answer = create_answer(0)
            else:
                answer = create_answer(1, "Bad command parameter:" + str(args))

906
        else:
907
908
            answer = create_answer(1, "Unknown command:" + str(cmd))

Michal Vaner's avatar
Michal Vaner committed
909
        return answer
910
911
912
913

    def run(self):
        '''Get and process all commands sent from cfgmgr or other modules. '''
        while not self._shutdown_event.is_set():
914
            self._cc.check_command(False)
915
916
917
918
919


xfrout_server = None

def signal_handler(signal, frame):
920
    if xfrout_server:
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
        xfrout_server.shutdown()
        sys.exit(0)

def set_signal_handler():
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

def set_cmd_options(parser):
    parser.add_option("-v", "--verbose", dest="verbose", action="store_true",
            help="display more about what is going on")

if '__main__' == __name__:
    try:
        parser = OptionParser()
        set_cmd_options(parser)
        (options, args) = parser.parse_args()
Jerry's avatar
Jerry committed
937
        VERBOSE_MODE = options.verbose
938
939
940
941
942

        set_signal_handler()
        xfrout_server = XfroutServer()
        xfrout_server.run()
    except KeyboardInterrupt:
943
        logger.INFO(XFROUT_STOPPED_BY_KEYBOARD)
944
    except SessionError as e:
945
        logger.error(XFROUT_CC_SESSION_ERROR, str(e))
946
947
948
949
    except ModuleCCSessionError as e:
        logger.error(XFROUT_MODULECC_SESSION_ERROR, str(e))
    except XfroutConfigError as e:
        logger.error(XFROUT_CONFIG_ERROR, str(e))
950
    except SessionTimeout as e:
951
        logger.error(XFROUT_CC_SESSION_TIMEOUT_ERROR)
952

953
954
955
    if xfrout_server:
        xfrout_server.shutdown()