xfrout.py.in 47.9 KB
Newer Older
1
2
#!@PYTHON@

3
# Copyright (C) 2010-2012  Internet Systems Consortium.
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.


import sys; sys.path.append ('@@PYTHONPATH@@')
import isc
import isc.cc
import threading
import struct
import signal
25
from isc.datasrc import DataSourceClient, ZoneFinder, ZoneJournalReader
26
27
28
from socketserver import *
import os
from isc.config.ccsession import *
29
from isc.cc import SessionError, SessionTimeout
30
from isc.statistics import Counters
31
from isc.notify import notify_out
32
import isc.util.process
33
import fcntl
34
import socket
35
import select
36
import errno
37
from optparse import OptionParser, OptionValueError
Likun Zhang's avatar
Likun Zhang committed
38
from isc.util import socketserver_mixin
39
import isc.server_common.tsig_keyring
40

41
from isc.log_messages.xfrout_messages import *
42

43
isc.log.init("b10-xfrout", buffer=True)
44
logger = isc.log.Logger("xfrout")
45
46
47
48
49
50

# Pending system-wide debug level definitions, the ones we
# use here are hardcoded for now
DBG_PROCESS = logger.DBGLVL_TRACE_BASIC
DBG_COMMANDS = logger.DBGLVL_TRACE_DETAIL

Jelte Jansen's avatar
Jelte Jansen committed
51
DBG_XFROUT_TRACE = logger.DBGLVL_TRACE_BASIC
52

53
try:
54
    from libutil_io_python import *
Jelte Jansen's avatar
Jelte Jansen committed
55
    from pydnspp import *
56
57
58
except ImportError as e:
    # C++ loadable module may not be installed; even so the xfrout process
    # must keep running, so we warn about it and move forward.
59
    logger.error(XFROUT_IMPORT, str(e))
Michal Vaner's avatar
Michal Vaner committed
60

61
from isc.acl.acl import ACCEPT, REJECT, DROP, LoaderError
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
62
63
from isc.acl.dns import REQUEST_LOADER

64
isc.util.process.rename()
65

66
67
68
69
70
71
72
73
74
75
76
77
class XfroutConfigError(Exception):
    """Signal a failure while applying an xfrout configuration update.

    The module-CC layer cannot catch every kind of syntax error, so xfrout
    has to validate the given configuration data itself (explicitly or
    implicitly).  When that validation fails, this exception is raised,
    either directly or as a conversion of an exception coming from another
    module, so that configuration errors are reported in one unified form.
    """

78
79
80
81
82
class XfroutSessionError(Exception):
    '''Signal an unexpected event that occurred during an xfrout session.'''

83
84
85
86
87
88
89
90
91
92
93
94
def init_paths():
    """Set the module-level spec file and unix socket path globals.

    Three globals are assigned: SPECFILE_PATH, AUTH_SPECFILE_PATH and
    UNIX_SOCKET_FILE.

    If the B10_FROM_BUILD environment variable is set, the spec paths
    point into that tree and the socket file lives under
    B10_FROM_SOURCE_LOCALSTATEDIR when that variable is set, or directly
    under B10_FROM_BUILD otherwise.

    Without B10_FROM_BUILD, the paths are built from the @...@ placeholder
    strings below, and the socket file can be overridden through the
    BIND10_XFROUT_SOCKET_FILE environment variable.
    """
    global SPECFILE_PATH, AUTH_SPECFILE_PATH, UNIX_SOCKET_FILE

    env = os.environ
    if "B10_FROM_BUILD" not in env:
        PREFIX = "@prefix@"
        DATAROOTDIR = "@datarootdir@"
        # The data directory string may still refer to ${datarootdir} and
        # ${prefix}; expand them in that order.
        spec_dir = "@datadir@/@PACKAGE@"
        spec_dir = spec_dir.replace("${datarootdir}", DATAROOTDIR)
        spec_dir = spec_dir.replace("${prefix}", PREFIX)
        SPECFILE_PATH = spec_dir
        AUTH_SPECFILE_PATH = SPECFILE_PATH
        UNIX_SOCKET_FILE = env.get(
            "BIND10_XFROUT_SOCKET_FILE",
            "@@LOCALSTATEDIR@@/@PACKAGE_NAME@/auth_xfrout_conn")
        return

    build_root = env["B10_FROM_BUILD"]
    SPECFILE_PATH = build_root + "/src/bin/xfrout"
    AUTH_SPECFILE_PATH = build_root + "/src/bin/auth"
    socket_dir = env.get("B10_FROM_SOURCE_LOCALSTATEDIR", build_root)
    UNIX_SOCKET_FILE = socket_dir + "/auth_xfrout_conn"
104
105

init_paths()
106

107
SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
108
AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
Jerry's avatar
Jerry committed
109
VERBOSE_MODE = False
110
111
XFROUT_DNS_HEADER_SIZE = 12     # protocol constant
XFROUT_MAX_MESSAGE_SIZE = 65535 # ditto
112

113
114
115
116
117
118
119
120
# borrowed from xfrin.py @ #1298.  We should eventually unify it.
def format_zone_str(zone_name, zone_class):
    """Render a zone as a string of the form '<name>/<class>'.

       Parameters:
       zone_name (isc.dns.Name) name to format; rendered with to_text(True)
       zone_class (isc.dns.RRClass) class to format; rendered with str()
    """
    name_text = zone_name.to_text(True)
    class_text = str(zone_class)
    return '/'.join((name_text, class_text))
122

123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# borrowed from xfrin.py @ #1298.
def format_addrinfo(addrinfo):
    """Render an address tuple as a printable string.

       The result is '<addr>:<port>' for AF_INET and '[<addr>]:<port>' for
       AF_INET6.  For any other address family (e.g. unix domain sockets)
       it is the plain string conversion of the third element.

       Parameters:
       addrinfo: a 3-tuple consisting of address family, socket type, and,
                 depending on the family, either a 2-tuple with the address
                 and port, or a filename

       Raises TypeError if an expected element is missing from addrinfo.
    """
    try:
        family = addrinfo[0]
        if family == socket.AF_INET:
            return "%s:%s" % (addrinfo[2][0], addrinfo[2][1])
        if family == socket.AF_INET6:
            return "[%s]:%s" % (addrinfo[2][0], addrinfo[2][1])
        return str(addrinfo[2])
    except IndexError:
        raise TypeError("addrinfo argument to format_addrinfo() does not "
                        "appear to be consisting of (family, socktype, (addr, port))")

145
146
147
148
149
150
def get_rrset_len(rrset):
    """Return the number of bytes rrset.to_wire() emits into an empty buffer."""
    wire_buf = bytearray()
    rrset.to_wire(wire_buf)
    return len(wire_buf)

151
def get_soa_serial(soa_rdata):
    '''Extract the serial field of an SOA RDATA and return it as a Serial.

    The serial is taken as the third whitespace-separated field of the
    RDATA's text form.
    '''
    fields = soa_rdata.to_text().split()
    serial_text = fields[2]
    return Serial(int(serial_text))
155

156
def make_blocking(filenum, on):
    """Switch the socket behind a descriptor to blocking or non-blocking I/O.

    The O_NONBLOCK status flag of the descriptor is cleared when 'on' is
    true (the socket becomes blocking) and set otherwise (the socket becomes
    non-blocking).  All other status flags are left as they were.

    The given filenum must be a descriptor of a socket (not an ordinary file
    etc), but this function doesn't check that condition.

    filenum(int): file number (descriptor) of the socket to update.
    on(bool): whether to enable (True) or disable (False) blocking I/O.

    """
    current = fcntl.fcntl(filenum, fcntl.F_GETFL)
    updated = (current & ~os.O_NONBLOCK) if on else (current | os.O_NONBLOCK)
    fcntl.fcntl(filenum, fcntl.F_SETFL, updated)

177
class XfroutSession():
178
    def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
                 default_acl, zone_config, client_class=DataSourceClient):
        '''Record the session context and immediately handle the request.

        The arguments are stored on the instance as given; the per-session
        state (TSIG context, request type, SOA, journal reader) starts out
        empty.  The constructor ends by calling self._handle().
        '''
        # Connection and raw request.
        self._sock_fd = sock_fd
        self._request_data = request_data
        self._remote = remote
        self._server = server

        # TSIG state; context and length are filled in by
        # _check_request_tsig() when the request carries a TSIG record.
        self._tsig_key_ring = tsig_key_ring
        self._tsig_ctx = None
        self._tsig_len = 0

        # Request type, identified in _parse_query_message().
        self._request_type = None
        self._request_typestr = None

        # Access control configuration.
        self._acl = default_acl
        self._zone_config = zone_config

        self.ClientClass = client_class # parameterize this for testing
        self._soa = None # will be set in _xfrout_setup or in tests
        self._jnl_reader = None # will be set to a reader for IXFR

        # The counters must exist before self._handle() is invoked.
        self._counters = Counters(SPECFILE_LOCATION)
        self._handle()
Jerry's avatar
Jerry committed
198

199
200
201
202
    def create_tsig_ctx(self, tsig_record, tsig_key_ring):
        '''Build a TSIGContext from a TSIG record and a key ring.

        The context is constructed from the record's name, the algorithm
        taken from the record's RDATA, and the given key ring.
        '''
        key_name = tsig_record.get_name()
        algorithm = tsig_record.get_rdata().get_algorithm()
        return TSIGContext(key_name, algorithm, tsig_key_ring)

203
204
205
206
207
208
209
210
    def _handle(self):
        '''Handle an xfrout query and send the xfrout response(s).

        This is separated from the constructor so that tests can override
        it.

        The transfer quota is both acquired and released in this method so
        that it is evident the quota is always given back once taken.  Any
        exception raised while processing is held back until the quota and
        the socket have been released, and is only logged afterwards.
        '''
        quota_ok = self._server.increase_transfers_counter()
        pending_error = None
        try:
            # Responses are sent in blocking mode, so force the socket to
            # blocking I/O first; otherwise a write could fail with
            # EWOULDBLOCK and disrupt the session.
            make_blocking(self._sock_fd, True)
            self.dns_xfrout_start(self._sock_fd, self._request_data, quota_ok)
        except Exception as e:
            # Catch everything to avoid a resource leak.  Logging is
            # deferred so that cleanup still happens even if the logger
            # itself raises.
            pending_error = e

        # Release any critical resources
        if quota_ok:
            self._server.decrease_transfers_counter()
        self._close_socket()

        if pending_error is not None:
            logger.error(XFROUT_HANDLE_QUERY_ERROR, pending_error)

    def _close_socket(self):
        '''Simply close the socket via the given FD.
236

237
238
239
240
        This is a dedicated subroutine of handle() and is sepsarated from it
        for the convenience of tests.

        '''
241
        os.close(self._sock_fd)
242

243
244
245
246
247
    def _check_request_tsig(self, msg, request_data):
        '''Verify the request's TSIG record, if it has one.

        Without a TSIG record nothing is checked and Rcode.NOERROR is
        returned.  Otherwise the record length and a freshly created TSIG
        context are stored on the session, and the record is verified
        against request_data: Rcode.NOTAUTH is returned if verification
        reports anything other than TSIGError.NOERROR, Rcode.NOERROR
        otherwise.
        '''
        tsig_record = msg.get_tsig_record()
        if tsig_record is None:
            return Rcode.NOERROR

        self._tsig_len = tsig_record.get_length()
        self._tsig_ctx = self.create_tsig_ctx(tsig_record,
                                              self._tsig_key_ring)
        verify_result = self._tsig_ctx.verify(tsig_record, request_data)
        if verify_result != TSIGError.NOERROR:
            return Rcode.NOTAUTH
        return Rcode.NOERROR
255

256
257
    def _parse_query_message(self, mdata):
        '''Parse the wire-format request and apply the TSIG and ACL checks.

        Returns a pair (rcode, message):
        - (Rcode.FORMERR, None) if the data cannot be parsed as a message;
        - (rcode, msg) with the failing rcode if the TSIG check fails;
        - (None, None) if the ACL result is DROP;
        - (Rcode.REFUSED, msg) if the ACL result is REJECT (the per-zone
          'xfrrej' counter is incremented in this case);
        - (Rcode.NOERROR, msg) otherwise.

        As a side effect self._request_type and self._request_typestr are
        set from the question.

        Raises RuntimeError if the message does not have exactly one
        question or if the question type is neither AXFR nor IXFR.
        '''
        #TODO, need to add parseHeader() in case the message header is invalid
        try:
            msg = Message(Message.PARSE)
            Message.from_wire(msg, mdata)
        except Exception as err: # Exception is too broad
            logger.error(XFROUT_PARSE_QUERY_ERROR, err)
            return Rcode.FORMERR, None

        # TSIG related checks
        rcode = self._check_request_tsig(msg, mdata)
        if rcode != Rcode.NOERROR:
            return rcode, msg

        # The auth server should already guarantee a valid question, but
        # since it is far from xfrout itself we check it ourselves.  A
        # violation would be an internal bug, so we raise and stop here
        # rather than returning a FORMERR or SERVFAIL.
        question_count = msg.get_rr_count(Message.SECTION_QUESTION)
        if question_count != 1:
            raise RuntimeError('Invalid number of question for XFR: ' +
                               str(question_count))
        question = msg.get_question()[0]

        # Identify the request type
        request_type = question.get_type()
        self._request_type = request_type
        if request_type == RRType.AXFR:
            self._request_typestr = 'AXFR'
        elif request_type == RRType.IXFR:
            self._request_typestr = 'IXFR'
        else:
            # Likewise, this should be impossible.
            raise RuntimeError('Unexpected XFR type: ' +
                               str(self._request_type))

        # ACL checks
        zone_name = question.get_name()
        zone_class = question.get_class()
        acl = self._get_transfer_acl(zone_name, zone_class)
        request_context = isc.acl.dns.RequestContext(self._remote[2],
                                                     msg.get_tsig_record())
        acl_result = acl.execute(request_context)

        if acl_result == DROP:
            logger.debug(DBG_XFROUT_TRACE, XFROUT_QUERY_DROPPED,
                         self._request_type, format_addrinfo(self._remote),
                         format_zone_str(zone_name, zone_class))
            return None, None

        if acl_result == REJECT:
            # count rejected Xfr request by each zone name
            self._counters.inc('zones', zone_class.to_text(),
                               zone_name.to_text(), 'xfrrej')
            logger.debug(DBG_XFROUT_TRACE, XFROUT_QUERY_REJECTED,
                         self._request_type, format_addrinfo(self._remote),
                         format_zone_str(zone_name, zone_class))
            return Rcode.REFUSED, msg

        return rcode, msg
312

313
314
315
316
317
318
319
320
321
322
323
324
    def _get_transfer_acl(self, zone_name, zone_class):
        '''Return the ACL that should be applied for a given zone.

        The zone is identified by its name and RR class.  If the per-zone
        configuration has an entry for the zone and that entry contains
        'transfer_acl', that ACL is returned; otherwise the default ACL is
        returned.

        '''
        # Zone names are managed internally with lower-cased label
        # characters, so the name is converted before the lookup.
        lowered_name = Name(zone_name.to_text(), True)
        config_key = (zone_class.to_text(), lowered_name.to_text())

        if config_key not in self._zone_config:
            return self._acl
        zone_entry = self._zone_config[config_key]
        if 'transfer_acl' not in zone_entry:
            return self._acl
        return zone_entry['transfer_acl']

331
    def _send_data(self, sock_fd, data):
332
333
334
        size = len(data)
        total_count = 0
        while total_count < size:
335
            count = os.write(sock_fd, data[total_count:])
336
337
338
            total_count += count


339
    def _send_message(self, sock_fd, msg, tsig_ctx=None):
        '''Render msg to wire format and send it with a length prefix.

        The message is rendered with case-sensitive name compression and a
        length limit of XFROUT_MAX_MESSAGE_SIZE, signed with tsig_ctx when
        one is given.  The rendered data is written to sock_fd preceded by
        its length as a 2-byte unsigned integer in network byte order.
        '''
        render = MessageRenderer()
        # As defined in RFC5936 section3.4, perform case-preserving name
        # compression for AXFR message.
        render.set_compress_mode(MessageRenderer.CASE_SENSITIVE)
        render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)

        # XXX Currently, python wrapper doesn't accept 'None' parameter in this case,
        # we should remove the if statement and use a universal interface later.
        if tsig_ctx is not None:
            msg.to_wire(render, tsig_ctx)
        else:
            msg.to_wire(render)

        # '!H' packs the length directly in network byte order.  This
        # yields the same two bytes as packing socket.htons(length) in
        # native order, without depending on htons() integer handling.
        length_prefix = struct.pack('!H', render.get_length())
        self._send_data(sock_fd, length_prefix)
        self._send_data(sock_fd, render.get_data())
356
357


358
    def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
359
        if not msg:
360
            return # query message is invalid. send nothing back.
361
362

        msg.make_response()
363
        msg.set_rcode(rcode_)
364
        self._send_message(sock_fd, msg, self._tsig_ctx)
365

366
    def _get_zone_soa(self, zone_name):
        '''Retrieve the SOA RR of the given zone.

        Returns a pair of RCODE and the SOA (in the form of an RRset):
        - (Rcode.NOERROR, soa_rrset) on success;
        - (Rcode.NOTAUTH, None) if the zone is not found in the data source;
        - (Rcode.SERVFAIL, None) if the SOA lookup fails or the SOA RRset
          does not contain exactly one RDATA.

        '''
        zone_result, finder = self._datasrc_client.find_zone(zone_name)
        if zone_result != DataSourceClient.SUCCESS:
            return (Rcode.NOTAUTH, None)

        soa_result, soa_rrset, _ = finder.find(zone_name, RRType.SOA)
        # Especially for database-based zones, a working zone may be in
        # a broken state where it has more than one SOA RR.  We proactively
        # check that condition too and abort the xfr attempt if we see it.
        if soa_result != ZoneFinder.SUCCESS or \
                soa_rrset.get_rdata_count() != 1:
            return (Rcode.SERVFAIL, None)
        return (Rcode.NOERROR, soa_rrset)
387

388
    def __axfr_setup(self, zone_name):
        '''Set up a zone iterator for AXFR or AXFR-style IXFR.

        On success self._iterator and self._soa are set and Rcode.NOERROR
        is returned.  Rcode.NOTAUTH is returned if the iterator cannot be
        obtained; Rcode.SERVFAIL if the iterator has no SOA or the SOA does
        not consist of exactly one RDATA.

        '''
        try:
            # 'separate_rrs' is enabled here: in xfr-out we need to preserve
            # as many things as possible (even if it's half broken) stored
            # in the zone.
            zone_iterator = self._datasrc_client.get_iterator(zone_name, True)
            self._iterator = zone_iterator
        except isc.datasrc.Error:
            # If the current name server does not have authority for the
            # zone, xfrout can't serve for it, return rcode NOTAUTH.
            # Note: this exception can happen for other reasons.  We should
            # update get_iterator() API so that we can distinguish "no such
            # zone" and other cases (#1373).  For now we consider all these
            # cases as NOTAUTH.
            return Rcode.NOTAUTH

        # Even if we are authoritative for the zone, we cannot provide a
        # zone transfer for it unless its SOA record is found in the data
        # source.
        self._soa = self._iterator.get_soa()
        if self._soa is None:
            return Rcode.SERVFAIL
        if self._soa.get_rdata_count() != 1:
            return Rcode.SERVFAIL
        return Rcode.NOERROR
415

416
    def __ixfr_setup(self, request_msg, zone_name, zone_class):
        '''Set up a zone journal reader for IXFR.

        The client's SOA is taken from the authority section of the request
        and compared with the local SOA.  Depending on the outcome:
        - Rcode.FORMERR is returned if the request has no usable SOA, or a
          matching SOA RRset with other than one RDATA;
        - the failing rcode is returned if the local SOA is unavailable;
        - if the client's serial is the same or newer, both the iterator
          and the journal reader are cleared and Rcode.NOERROR is returned;
        - if the data source does not support journaling or does not have
          the requested range of differences, this falls back to AXFR-style
          IXFR by setting up a zone iterator instead;
        - Rcode.NOTAUTH is returned if the journal reports no such zone;
        - otherwise the journal reader is used as the response iterator
          and Rcode.NOERROR is returned.

        '''
        # Look through the authority section for a SOA record with the
        # same name and class as the question.
        remote_soa = None
        for candidate in request_msg.get_section(Message.SECTION_AUTHORITY):
            # Skip data whose owner name is not the zone apex, and skip
            # non-SOA records or records of a different class.
            if candidate.get_name() != zone_name:
                continue
            if candidate.get_type() != RRType.SOA:
                continue
            if candidate.get_class() != zone_class:
                continue
            if candidate.get_rdata_count() != 1:
                logger.info(XFROUT_IXFR_MULTIPLE_SOA,
                            format_addrinfo(self._remote))
                return Rcode.FORMERR
            remote_soa = candidate
        if remote_soa is None:
            logger.info(XFROUT_IXFR_NO_SOA, format_addrinfo(self._remote))
            return Rcode.FORMERR

        # Retrieve the local SOA
        rcode, self._soa = self._get_zone_soa(zone_name)
        if rcode != Rcode.NOERROR:
            return rcode

        # RFC1995 says "If an IXFR query with the same or newer version
        # number than that of the server is received, it is replied to with
        # a single SOA record of the server's current version, just as
        # in AXFR".  The claim about AXFR is incorrect, but other than that,
        # we do as the RFC says.
        begin_serial = get_soa_serial(remote_soa.get_rdata()[0])
        end_serial = get_soa_serial(self._soa.get_rdata()[0])
        if begin_serial >= end_serial:
            # Clearing both the iterator and the journal reader signals
            # that response generation must not iterate.
            self._iterator = None
            self._jnl_reader = None
            logger.info(XFROUT_IXFR_UPTODATE, format_addrinfo(self._remote),
                        format_zone_str(zone_name, zone_class),
                        begin_serial, end_serial)
            return Rcode.NOERROR

        # Set up the journal reader or fall back to AXFR-style IXFR
        try:
            code, self._jnl_reader = self._datasrc_client.get_journal_reader(
                zone_name, begin_serial.get_value(), end_serial.get_value())
        except isc.datasrc.NotImplemented:
            # The underlying data source doesn't support journaling.
            # Fall back to AXFR-style IXFR.
            logger.info(XFROUT_IXFR_NO_JOURNAL_SUPPORT,
                        format_addrinfo(self._remote),
                        format_zone_str(zone_name, zone_class))
            return self.__axfr_setup(zone_name)

        if code == ZoneJournalReader.NO_SUCH_VERSION:
            logger.info(XFROUT_IXFR_NO_VERSION, format_addrinfo(self._remote),
                        format_zone_str(zone_name, zone_class),
                        begin_serial, end_serial)
            return self.__axfr_setup(zone_name)

        if code == ZoneJournalReader.NO_SUCH_ZONE:
            # This is quite unexpected as we know the zone's SOA exists.
            # It might be a bug or the data source is somehow broken,
            # but it can still happen if someone has removed the zone
            # between these two operations.  We treat it as NOTAUTH.
            logger.warn(XFROUT_IXFR_NO_ZONE, format_addrinfo(self._remote),
                        format_zone_str(zone_name, zone_class))
            return Rcode.NOTAUTH

        # Use the reader as the iterator to generate the response.
        self._iterator = self._jnl_reader
        return Rcode.NOERROR
494

495
    def _xfrout_setup(self, request_msg, zone_name, zone_class):
        '''Setup a context for xfr responses according to the request type.

        A data source client is created for the server's database file and
        stored in self._datasrc_client.  Then a zone iterator (for AXFR) or
        a journal reader (for IXFR) is set up through the type specific
        setup method, and that method's result is returned: an RCODE that
        is other than NOERROR if a protocol level error was identified.

        '''
        # Function-scope import: json is needed only to build the data
        # source configuration below.
        import json

        # Identify the data source for the requested zone and see if it has
        # SOA while initializing objects used for request processing later.
        # We should eventually generalize this so that we can choose the
        # appropriate data source from (possible) multiple candidates.
        # We should eventually take into account the RR class here.
        # For now, we hardcode a particular type (SQLite3-based), and only
        # consider that one.
        #
        # The configuration is serialized with json.dumps() rather than by
        # string concatenation so that a file name containing '"' or '\'
        # is escaped and still results in well-formed JSON.  Non-ASCII
        # characters are kept as they are (ensure_ascii=False).
        datasrc_config = json.dumps(
            {"database_file": self._server.get_db_file()},
            ensure_ascii=False)
        self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)

        if self._request_type == RRType.AXFR:
            return self.__axfr_setup(zone_name)
        return self.__ixfr_setup(request_msg, zone_name, zone_class)
520

521
    def dns_xfrout_start(self, sock_fd, msg_query, quota_ok=True):
        '''Process one xfr request and send the response(s) to sock_fd.

        The request is parsed and checked first.  Nothing is sent if it was
        dropped by ACL; an error response is sent if the checks fail, if
        the transfer quota is exceeded (quota_ok is false), or if the
        transfer context cannot be set up.  Otherwise the zone data is
        sent, with the 'axfr_running'/'ixfr_running' counter incremented
        for the duration of the transfer.
        '''
        rcode_, msg = self._parse_query_message(msg_query)
        #TODO. create query message and parse header
        if rcode_ is None:
            # Dropped by ACL
            return
        if rcode_ == Rcode.NOTAUTH or rcode_ == Rcode.REFUSED:
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
        if rcode_ != Rcode.NOERROR:
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.FORMERR)
        if not quota_ok:
            logger.warn(XFROUT_QUERY_QUOTA_EXCEEDED, self._request_typestr,
                        format_addrinfo(self._remote),
                        self._server._max_transfers_out)
            return self._reply_query_with_error_rcode(msg, sock_fd,
                                                      Rcode.REFUSED)

        question = msg.get_question()[0]
        zone_name = question.get_name()
        zone_class = question.get_class()
        zone_str = format_zone_str(zone_name, zone_class) # for logging

        try:
            rcode_ = self._xfrout_setup(msg, zone_name, zone_class)
        except Exception as ex:
            logger.error(XFROUT_XFR_TRANSFER_CHECK_ERROR, self._request_typestr,
                         format_addrinfo(self._remote), zone_str, ex)
            rcode_ = Rcode.SERVFAIL
        if rcode_ != Rcode.NOERROR:
            logger.info(XFROUT_XFR_TRANSFER_FAILED, self._request_typestr,
                        format_addrinfo(self._remote), zone_str, rcode_)
            return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)

        # The "running" counter to adjust depends on the request type.
        if self._request_type == RRType.AXFR:
            running_counter = 'axfr_running'
        else:
            running_counter = 'ixfr_running'

        try:
            # increment Xfr starts by RRType
            self._counters.inc(running_counter)
            logger.info(XFROUT_XFR_TRANSFER_STARTED, self._request_typestr,
                        format_addrinfo(self._remote), zone_str)
            self._reply_xfrout_query(msg, sock_fd)
        except Exception as err:
            # count unixsockets send errors
            self._counters.inc('socket', 'unixdomain', 'senderr')
            logger.error(XFROUT_XFR_TRANSFER_ERROR, self._request_typestr,
                    format_addrinfo(self._remote), zone_str, err)
        finally:
            # decrement Xfr starts by RRType
            self._counters.dec(running_counter)

        # count done Xfr requests by each zone name
        # NOTE(review): this point is also reached after the error branch
        # above, so a failed transfer is counted and logged as done too;
        # confirm that this is intended.
        self._counters.inc('zones', zone_class.to_text(),
                           zone_name.to_text(), 'xfrreqdone')
        logger.info(XFROUT_XFR_TRANSFER_DONE, self._request_typestr,
                    format_addrinfo(self._remote), zone_str)
579
580
581
582
583

    def _clear_message(self, msg):
        '''Reset msg for rendering the next message of the same response.

        The message is cleared into RENDER mode, and then its previous
        query ID, opcode and rcode are restored and the AA and QR header
        flags are set.  The same message object is returned.
        '''
        # Remember the header fields that clear() would discard.
        saved_qid = msg.get_qid()
        saved_opcode = msg.get_opcode()
        saved_rcode = msg.get_rcode()

        msg.clear(Message.RENDER)

        msg.set_qid(saved_qid)
        msg.set_opcode(saved_opcode)
        msg.set_rcode(saved_rcode)
        for flag in (Message.HEADERFLAG_AA, Message.HEADERFLAG_QR):
            msg.set_header_flag(flag)
        return msg

593
594
    def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa,
                                    message_upper_len):
        '''Append the SOA record to the message and send it.

        If adding the SOA to a message of message_upper_len bytes would
        exceed the maximum allowable message size, the current message is
        sent first and the SOA goes out in a new message of its own.

        We assume a message with a single SOA can always fit the buffer
        with or without TSIG.  In theory this could be wrong if TSIG is
        stupidly large, but in practice this assumption should be reasonable.
        '''
        soa_len = get_rrset_len(rrset_soa)
        if message_upper_len + soa_len > XFROUT_MAX_MESSAGE_SIZE:
            # No room left: flush what we have and start a fresh message.
            self._send_message(sock_fd, msg, self._tsig_ctx)
            msg = self._clear_message(msg)

        msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
        self._send_message(sock_fd, msg, self._tsig_ctx)
611

612
    def _reply_xfrout_query(self, msg, sock_fd):
        '''Send the xfr response message(s) for the request msg.

        If self._iterator is None, the response consists of the single SOA
        in self._soa.  Otherwise the response starts with the SOA, followed
        by the RRsets from the iterator, and ends with the SOA again; it is
        split into several messages whenever the next RRset would not fit
        within XFROUT_MAX_MESSAGE_SIZE (based on uncompressed sizes).

        The method returns early, without the trailing SOA, when the
        server's shutdown event is set.  XfroutSessionError is raised if a
        single RRset is too large for a message by itself.
        '''
        msg.make_response()
        msg.set_header_flag(Message.HEADERFLAG_AA)

        # Reserved space for the fixed header size, the size of the question
        # section, and TSIG size (when included).  The size of the question
        # section is the sum of the qname length and the size of the
        # fixed-length fields (type and class, 2 bytes each).
        qname_len = msg.get_question()[0].get_name().get_length()
        message_upper_len = XFROUT_DNS_HEADER_SIZE + qname_len + 4 + \
            self._tsig_len

        # If the iterator is None, we are responding to IXFR with a single
        # SOA RR.
        if self._iterator is None:
            self._send_message_with_last_soa(msg, sock_fd, self._soa,
                                             message_upper_len)
            return

        # Add the beginning SOA
        msg.add_rrset(Message.SECTION_ANSWER, self._soa)
        message_upper_len += get_rrset_len(self._soa)

        # Add the rest of the zone/diff contents
        for rrset in self._iterator:
            # Check if xfrout is shutdown
            if self._server._shutdown_event.is_set():
                logger.info(XFROUT_STOPPING)
                return

            # For AXFR (or AXFR-style IXFR), in which case _jnl_reader is None,
            # we should skip SOAs from the iterator.
            if self._jnl_reader is None and rrset.get_type() == RRType.SOA:
                continue

            # The maximum size of the RRset (i.e. its size without
            # compression) is used to see if we may have reached the limit.
            rrset_len = get_rrset_len(rrset)

            if message_upper_len + rrset_len > XFROUT_MAX_MESSAGE_SIZE:
                # The RR would not fit: send the current message now and
                # carry this RR over to the next one.
                self._send_message(sock_fd, msg, self._tsig_ctx)

                # Create a new message and reserve space for the
                # carried-over RR (and TSIG space in case it's to be TSIG
                # signed)
                msg = self._clear_message(msg)
                message_upper_len = XFROUT_DNS_HEADER_SIZE + rrset_len + \
                    self._tsig_len

                # If this RR overflows the buffer all by itself, fail.  In
                # theory some RRs might fit in a TCP message when compressed
                # even if they do not fit when uncompressed, but surely we
                # don't want to send such monstrosities to an unsuspecting
                # slave.
                if message_upper_len > XFROUT_MAX_MESSAGE_SIZE:
                    raise XfroutSessionError('RR too large for zone transfer (' +
                                             str(rrset_len) + ' bytes)')

                # Add the RRset to the new message
                msg.add_rrset(Message.SECTION_ANSWER, rrset)
            else:
                msg.add_rrset(Message.SECTION_ANSWER, rrset)
                message_upper_len += rrset_len

        # Add and send the trailing SOA
        self._send_message_with_last_soa(msg, sock_fd, self._soa,
                                         message_upper_len)
680

681
682
class UnixSockServer(socketserver_mixin.NoPollMixIn,
                     ThreadingUnixStreamServer):
683
684
    '''The unix domain socket server which accept xfr query sent from auth server.'''

685
    def __init__(self, sock_file, handle_class, shutdown_event, config_data,
                 cc):
        '''Set up the unix domain socket server listening on sock_file.

        A stale socket file is removed first (see
        _remove_unused_sock_file()).  The outcome of opening the listening
        socket is recorded in the 'socket'/'unixdomain' counters as either
        'open' or 'openfail'.

        Parameters:
          sock_file: path of the unix domain socket file to listen on.
          handle_class: request handler class, passed through to
              ThreadingUnixStreamServer.
          shutdown_event: event object checked by the session loop to
              detect shutdown.
          config_data: initial configuration, applied through
              update_config_data().
          cc: config session object; stored as self._cc.

        '''
        self._remove_unused_sock_file(sock_file)
        self._sock_file = sock_file
        socketserver_mixin.NoPollMixIn.__init__(self)
        # The counters must exist before the base class constructor runs,
        # because the overridden server_bind() uses them.
        self._counters = Counters(SPECFILE_LOCATION)
        try:
            ThreadingUnixStreamServer.__init__(self, sock_file, handle_class)
        except:
            self._counters.inc('socket', 'unixdomain', 'openfail')
            raise
        self._counters.inc('socket', 'unixdomain', 'open')
        self._shutdown_event = shutdown_event
        # Socket pair used by shutdown() to wake up the session loop.
        sock_pair = socket.socketpair()
        self._write_sock = sock_pair[0]
        self._read_sock = sock_pair[1]
        self._common_init()
        self._cc = cc
        self.update_config_data(config_data)
704
705

    def server_bind(self):
        """Bind the server socket, counting bind() failures.

        Delegates to the parent class's server_bind().  If that raises,
        the 'bindfail' counter of 'socket'/'unixdomain' is incremented and
        the exception is re-raised unchanged.

        """
        try:
            bound = super().server_bind()
        except:
            # count unix domain sockets that failed to bind
            self._counters.inc('socket', 'unixdomain', 'bindfail')
            raise
        return bound

    def get_request(self):
        """Accept a connection, counting accept() successes and failures.

        Delegates to the parent class's get_request() and returns its
        result.  On success the 'accept' counter of 'socket'/'unixdomain'
        is incremented; if anything in that sequence raises, 'acceptfail'
        is incremented and the exception is re-raised.

        """
        try:
            accepted = super().get_request()
            # count successfully accepted unix domain sockets
            self._counters.inc('socket', 'unixdomain', 'accept')
            return accepted
        except:
            # count unix domain sockets that failed to be accepted
            self._counters.inc('socket', 'unixdomain', 'acceptfail')
            raise
733

Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
734
    def _common_init(self):
735
        '''Initialization shared with the mock server class used for tests'''
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
736
737
        self._lock = threading.Lock()
        self._transfers_counter = 0
738
739
        self._zone_config = {}
        self._acl = None # this will be initialized in update_config_data()
Michal 'vorner' Vaner's avatar
Michal 'vorner' Vaner committed
740

741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
    def _receive_query_message(self, sock):
        ''' receive request message from sock'''
        # receive data length
        data_len = sock.recv(2)
        if not data_len:
            return None
        msg_len = struct.unpack('!H', data_len)[0]
        # receive data
        recv_size = 0
        msgdata = b''
        while recv_size < msg_len:
            data = sock.recv(msg_len - recv_size)
            if not data:
                return None
            recv_size += len(data)
            msgdata += data

        return msgdata

760
761
762
763
764
    def handle_request(self):
        '''Accept one connection and serve it until shutdown or close.

        If accepting the connection fails with a socket error, the error
        is logged and the method returns without doing anything else.
        Otherwise the accepted socket is handed to _select_loop().
        '''
        try:
            request_sock, _peer_address = self.get_request()
        except socket.error:
            logger.error(XFROUT_FETCH_REQUEST_ERROR)
            return
        self._select_loop(request_sock)

    def _select_loop(self, request_sock):
        '''Main loop for a single session between xfrout and auth.

        This is a dedicated subroutine of handle_request(), but is defined
        as a separate "protected" method for the convenience of tests.

        The loop waits until either the shutdown notification socket or
        request_sock becomes readable, and ends when the shutdown event is
        set, when select() fails with anything other than EINTR, when
        process_request() returns a false value, or when it raises.
        '''

        # Check self._shutdown_event to ensure the real shutdown comes.
        # Linux could trigger a spurious readable event on the _read_sock
        # due to a bug, so we need perform a double check.
        while not self._shutdown_event.is_set():
            watched = [self._read_sock, request_sock]
            try:
                readable, _writable, _exceptional = \
                    select.select(watched, [], [])
            except select.error as err:
                if err.args[0] != errno.EINTR:
                    logger.error(XFROUT_SOCKET_SELECT_ERROR, err)
                    break
                # Interrupted by a signal: just retry.
                continue

            # self._shutdown_event will be set by now, if it is not a
            # false alarm; the loop condition re-checks it.
            if self._read_sock in readable:
                continue

            try:
                keep_session = self.process_request(request_sock)
            except Exception as pre:
                # count unixsockets receive errors
                self._counters.inc('socket', 'unixdomain', 'recverr')
                logger.error(XFROUT_PROCESS_REQUEST_ERROR, pre)
                break
            if not keep_session:
                break

805
    def _handle_request_noblock(self):
806
807
        """Override the function _handle_request_noblock(), it creates a new
        thread to handle requests for each auth"""
808
809
810
811
        td = threading.Thread(target=self.handle_request)
        td.setDaemon(True)
        td.start()

812
    def process_request(self, request):
        """Receive socket fd and query message from auth, then
        start a new thread to process the request.

        Parameters:
          request: the socket connected to auth, from which the passed
              file descriptor and the request message are read.

        Return: True if everything is okay; otherwise False, in which case
        the calling thread will terminate.

        Raises: XfroutSessionError if the request message cannot be
        received completely.  Exceptions from receiving the message or
        from starting the thread propagate to the caller.

        """
        sock_fd = recv_fd(request.fileno())
        if sock_fd < 0:
            logger.warn(XFROUT_RECEIVE_FD_FAILED)
            return False

        # From here on this method owns sock_fd until the worker thread has
        # been started.  If anything fails before that, close it so the
        # descriptor is not leaked.
        try:
            # receive request msg.  If it fails we simply terminate the
            # thread; it might be possible to recover from this state, but
            # it's more likely that auth and xfrout are in inconsistent
            # states.  So it will make more sense to restart in a new
            # session.
            request_data = self._receive_query_message(request)
            if request_data is None:
                # The specific exception type doesn't matter so we use
                # session error.
                raise XfroutSessionError('Failed to get complete xfr request')

            t = threading.Thread(target=self.finish_request,
                                 args=(sock_fd, request_data))
            if self.daemon_threads:
                t.daemon = True
            t.start()
        except BaseException:
            try:
                os.close(sock_fd)
            except OSError:
                # Best effort only; the original error is the one to report.
                pass
            raise
        return True
841

842
    def _guess_remote(self, sock_fd):
843
844
845
846
847
848
        """Guess remote address and port of the socket.

        The sock_fd must be a file descriptor of a socket.
        This method retuns a 3-tuple consisting of address family,
        socket type, and a 2-tuple with the address (string) and port (int).

849
850
        """
        # This uses a trick. If the socket is IPv4 in reality and we pretend
851
        # it to be IPv6, it returns IPv4 address anyway. This doesn't seem
852
853
854
        # to care about the SOCK_STREAM parameter at all (which it really is,
        # except for testing)
        if socket.has_ipv6:
855
            sock_domain = socket.AF_INET6
856
857
858
        else:
            # To make it work even on hosts without IPv6 support
            # (Any idea how to simulate this in test?)
859
            sock_domain = socket.AF_INET
860

861
862
863
        sock = socket.fromfd(sock_fd, sock_domain, socket.SOCK_STREAM)
        peer = sock.getpeername()
        sock.close()
864
865
866
867
868
869
870
871

        # Identify the correct socket family.  Due to the above "trick",
        # we cannot simply use sock.family.
        family = socket.AF_INET6
        try:
            socket.inet_pton(socket.AF_INET6, peer[0])
        except socket.error:
            family = socket.AF_INET
872

873
        return (family, socket.SOCK_STREAM, peer)
874

875
    def finish_request(self, sock_fd, request_data):
        '''Finish one request by instantiating RequestHandlerClass.

        This is an entry point of a separate thread spawned in
        UnixSockServer.process_request().

        The current ACL and zone configuration are read under the lock and
        passed to the handler together with the TSIG key ring and the
        guessed remote address of sock_fd.
        '''
        # Take a consistent snapshot of the configuration; it may be
        # replaced concurrently by update_config_data().
        with self._lock:
            acl = self._acl
            zone_config = self._zone_config

        keyring = isc.server_common.tsig_keyring.get_keyring()
        remote = self._guess_remote(sock_fd)
        self.RequestHandlerClass(sock_fd, request_data, self, keyring,
                                 remote, acl, zone_config)
890
891

    def _remove_unused_sock_file(self, sock_file):
        '''Remove a leftover socket file before the server binds to it.

        If something is listening on sock_file (see _sock_file_in_use()),
        an error is logged and the process exits.  Otherwise the file is
        removed if it exists; if removal fails, an error is logged and the
        process exits.
        '''
        if not self._sock_file_in_use(sock_file):
            if os.path.exists(sock_file):
                try:
                    os.unlink(sock_file)
                except OSError as err:
                    logger.error(XFROUT_REMOVE_OLD_UNIX_SOCKET_FILE_ERROR, sock_file, str(err))
                    sys.exit(0)
            return

        logger.error(XFROUT_UNIX_SOCKET_FILE_IN_USE, sock_file)
        sys.exit(0)
908

909
    def _sock_file_in_use(self, sock_file):
910
911
        '''Check whether the socket file 'sock_file' exists and
        is being used by one running xfrout process. If it is,
912
        return True, or else return False. '''
913
914
915
916
917
918
919
920
921
        try:
            sock = socket.socket(socket.AF_UNIX)
            sock.connect(sock_file)
        except socket.error as err:
            sock.close()
            return False
        else:
            sock.close()
            return True
922

923
    def shutdown(self):
        '''Stop the server and remove its socket file.

        Wakes up the session loop through the notification socket, calls
        the parent class's shutdown(), counts the close in the
        'socket'/'unixdomain' counters and finally unlinks the socket
        file.  A failure to unlink is logged, not raised.
        '''
        # terminate the xfrout session thread
        self._write_sock.send(b"shutdown")
        # call the shutdown() of class socketserver_mixin.NoPollMixIn
        super().shutdown()
        # count closed unixsockets
        self._counters.inc('socket', 'unixdomain', 'close')
        try:
            os.unlink(self._sock_file)
        except Exception as err:
            logger.error(XFROUT_REMOVE_UNIX_SOCKET_FILE_ERROR, self._sock_file, str(err))
932
933

    def update_config_data(self, new_config):
        '''Apply the new config setting of xfrout module.

        Parameters:
          new_config: a dict-like configuration.  The keys read here are
              'transfer_acl', 'zone_config' and 'transfers_out'; an absent
              'transfer_acl' or 'zone_config' leaves the current value of
              the corresponding setting unchanged.

        Raises: XfroutConfigError if 'transfer_acl' cannot be parsed;
        exceptions from building the zone configuration propagate.  In
        either case none of the settings are modified.

        The update is done while holding the lock, which is released on
        every exit path, including non-Exception errors.

        '''
        with self._lock:
            logger.info(XFROUT_NEW_CONFIG)
            new_acl = self._acl
            if 'transfer_acl' in new_config:
                try:
                    new_acl = REQUEST_LOADER.load(new_config['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl: ' +
                                            str(e))

            new_zone_config = self._zone_config
            zconfig_data = new_config.get('zone_config')
            if zconfig_data is not None:
                new_zone_config = self.__create_zone_config(zconfig_data)

            # Commit only after everything has been validated, so a failure
            # above leaves the previous configuration intact.
            self._acl = new_acl
            self._zone_config = new_zone_config
            self._max_transfers_out = new_config.get('transfers_out')
        logger.info(XFROUT_NEW_CONFIG_DONE)
961

962
963
964
965
966
    def __create_zone_config(self, zone_config_list):
        '''Build the per-zone configuration from its list form.

        Parameters:
          zone_config_list: an iterable of dict-like entries.  Each entry
              must have 'origin' and may have 'class' and 'transfer_acl'.
              When 'class' is absent the default value of
              'zone_config/class' is obtained from the config session.

        Return: a dict keyed by (class text, origin text) tuples; each value
        is a dict holding the loaded 'transfer_acl' if one was given.

        Raises: XfroutConfigError for a duplicate class/origin pair or an
        unparsable 'transfer_acl'.
        '''
        result = {}
        for entry in zone_config_list:
            # convert the class, origin (name) pair.  First build pydnspp
            # object to reject invalid input.
            zclass_str = entry.get('class')
            if zclass_str is None:
                zclass_str = self._cc.get_default_value('zone_config/class')
            zclass = RRClass(zclass_str)
            zorigin = Name(entry['origin'], True)
            key = (zclass.to_text(), zorigin.to_text())

            # reject duplicate config
            if key in result:
                raise XfroutConfigError('Duplicate zone_config for ' +
                                        str(zorigin) + '/' + str(zclass))

            # create a new config entry, build any given (and known) config
            zone_entry = {}
            if 'transfer_acl' in entry:
                try:
                    zone_entry['transfer_acl'] = \
                        REQUEST_LOADER.load(entry['transfer_acl'])
                except LoaderError as e:
                    raise XfroutConfigError('Failed to parse transfer_acl ' +
                                            'for ' + zorigin.to_text() + '/' +
                                            zclass_str + ': ' + str(e))
            result[key] = zone_entry
        return result

992
    def get_db_file(self):
993
994
995
996
997
998