Commit f7c04626 authored by Jerry's avatar Jerry
Browse files

merge trac299: Xfrout and Auth will communicate by long tcp connection,

Auth needs to make a new connection only on the first time or if an error occurred.


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@3482 e5f2f494-b856-4b98-b285-d166d9295462
parent 921de30e
116. [bug] jerry
src/bin/xfrout: Xfrout and Auth will communicate by long tcp
connection, Auth needs to make a new connection only on the first
time or if an error occurred.
(Trac #299, svn r3482)
115. [func]* jinmei
src/lib/dns: Changed DNS message flags and section names from
separate classes to simpler enums, considering the balance between
......@@ -10,7 +16,7 @@
(Trac #365, svn r3383)
113. [func]* zhanglikun
Folder name 'utils'(the folder in /src/lib/python/isc/) has been
Folder name 'utils'(the folder in /src/lib/python/isc/) has been
renamed to 'util'. Programs that used 'import isc.utils.process'
now need to use 'import isc.util.process'. The folder
/src/lib/python/isc/Util is removed since it isn't used by any
......
......@@ -77,7 +77,7 @@ public:
MessageRenderer& response_renderer);
bool processAxfrQuery(const IOMessage& io_message, Message& message,
MessageRenderer& response_renderer);
bool processNotify(const IOMessage& io_message, Message& message,
bool processNotify(const IOMessage& io_message, Message& message,
MessageRenderer& response_renderer);
std::string db_file_;
ModuleCCSession* config_session_;
......@@ -307,7 +307,7 @@ AuthSrvImpl::processNormalQuery(const IOMessage& io_message, Message& message,
ConstEDNSPtr remote_edns = message.getEDNS();
const bool dnssec_ok = remote_edns && remote_edns->getDNSSECAwareness();
const uint16_t remote_bufsize = remote_edns ? remote_edns->getUDPSize() :
Message::DEFAULT_MAX_UDPSIZE;
Message::DEFAULT_MAX_UDPSIZE;
message.makeResponse();
message.setHeaderFlag(Message::HEADERFLAG_AA);
......@@ -360,8 +360,10 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
}
try {
xfrout_client_.connect();
xfrout_connected_ = true;
if (!xfrout_connected_) {
xfrout_client_.connect();
xfrout_connected_ = true;
}
xfrout_client_.sendXfroutRequestInfo(
io_message.getSocket().getNative(),
io_message.getData(),
......@@ -375,7 +377,7 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
xfrout_client_.disconnect();
xfrout_connected_ = false;
}
if (verbose_mode_) {
cerr << "[b10-auth] Error in handling XFR request: " << err.what()
<< endl;
......@@ -385,15 +387,12 @@ AuthSrvImpl::processAxfrQuery(const IOMessage& io_message, Message& message,
return (true);
}
xfrout_client_.disconnect();
xfrout_connected_ = false;
return (false);
}
bool
AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
MessageRenderer& response_renderer)
AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
MessageRenderer& response_renderer)
{
// The incoming notify must contain exactly one question for SOA of the
// zone name.
......@@ -435,7 +434,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
}
return (false);
}
const string remote_ip_address =
io_message.getRemoteEndpoint().getAddress().toText();
static const string command_template_start =
......@@ -446,7 +445,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
try {
ConstElementPtr notify_command = Element::fromJSON(
command_template_start + question->getName().toText() +
command_template_start + question->getName().toText() +
command_template_master + remote_ip_address +
command_template_rrclass + question->getClass().toText() +
command_template_end);
......@@ -460,7 +459,7 @@ AuthSrvImpl::processNotify(const IOMessage& io_message, Message& message,
if (rcode != 0) {
if (verbose_mode_) {
cerr << "[b10-auth] failed to notify Zonemgr: "
<< parsed_answer->str() << endl;
<< parsed_answer->str() << endl;
}
return (false);
}
......
......@@ -489,7 +489,7 @@ TEST_F(AuthSrvTest, AXFRSuccess) {
// so we shouldn't have to respond.
EXPECT_FALSE(server.processMessage(*io_message, parse_message,
response_renderer));
EXPECT_FALSE(xfrout.isConnected());
EXPECT_TRUE(xfrout.isConnected());
}
TEST_F(AuthSrvTest, AXFRConnectFail) {
......@@ -501,8 +501,6 @@ TEST_F(AuthSrvTest, AXFRConnectFail) {
response_renderer));
headerCheck(parse_message, default_qid, Rcode::SERVFAIL(),
opcode.getCode(), QR_FLAG, 1, 0, 0, 0);
// For a shot term workaround with xfrout we currently close the connection
// for each AXFR attempt
EXPECT_FALSE(xfrout.isConnected());
}
......@@ -512,7 +510,7 @@ TEST_F(AuthSrvTest, AXFRSendFail) {
createRequestPacket(opcode, Name("example.com"), RRClass::IN(),
RRType::AXFR(), IPPROTO_TCP);
server.processMessage(*io_message, parse_message, response_renderer);
EXPECT_FALSE(xfrout.isConnected()); // see above
EXPECT_TRUE(xfrout.isConnected());
xfrout.disableSend();
parse_message.clear(Message::PARSE);
......
......@@ -47,22 +47,29 @@ class MySocket():
result = self.sendqueue[:size]
self.sendqueue = self.sendqueue[size:]
return result
def read_msg(self):
sent_data = self.readsent()
get_msg = Message(Message.PARSE)
get_msg.from_wire(bytes(sent_data[2:]))
return get_msg
def clear_send(self):
del self.sendqueue[:]
# We subclass the Session class we're testing here, only
# to override the __init__() method, which wants a socket,
# to override the handle() and _send_data() method
class MyXfroutSession(XfroutSession):
def handle(self):
pass
def _send_data(self, sock, data):
size = len(data)
total_count = 0
while total_count < size:
count = sock.send(data[total_count:])
total_count += count
class Dbserver:
def __init__(self):
self._shutdown_event = threading.Event()
......@@ -80,12 +87,21 @@ class TestXfroutSession(unittest.TestCase):
def setUp(self):
request = MySocket(socket.AF_INET,socket.SOCK_STREAM)
self.log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False )
self.xfrsess = MyXfroutSession(request, None, None, self.log)
(self.write_sock, self.read_sock) = socket.socketpair()
self.xfrsess = MyXfroutSession(request, None, None, self.log, self.read_sock)
self.xfrsess.server = Dbserver()
self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
def test_receive_query_message(self):
send_msg = b"\xd6=\x00\x00\x00\x01\x00"
msg_len = struct.pack('H', socket.htons(len(send_msg)))
self.write_sock.send(msg_len)
self.write_sock.send(send_msg)
recv_msg = self.xfrsess._receive_query_message(self.read_sock)
self.assertEqual(recv_msg, send_msg)
def test_parse_query_message(self):
[get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
self.assertEqual(get_rcode.to_text(), "NOERROR")
......@@ -93,7 +109,7 @@ class TestXfroutSession(unittest.TestCase):
def test_get_query_zone_name(self):
msg = self.getmsg()
self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
def test_send_data(self):
self.xfrsess._send_data(self.sock, self.mdata)
senddata = self.sock.readsent()
......@@ -103,8 +119,8 @@ class TestXfroutSession(unittest.TestCase):
msg = self.getmsg()
self.xfrsess._reply_query_with_error_rcode(msg, self.sock, Rcode(3))
get_msg = self.sock.read_msg()
self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
self.assertEqual(get_msg.get_rcode().to_text(), "NXDOMAIN")
def test_clear_message(self):
msg = self.getmsg()
qid = msg.get_qid()
......@@ -118,7 +134,7 @@ class TestXfroutSession(unittest.TestCase):
self.assertTrue(msg.get_header_flag(Message.HEADERFLAG_AA))
def test_reply_query_with_format_error(self):
msg = self.getmsg()
self.xfrsess._reply_query_with_format_error(msg, self.sock)
get_msg = self.sock.read_msg()
......@@ -217,7 +233,7 @@ class TestXfroutSession(unittest.TestCase):
sqlite3_ds.get_zone_soa = zone_soa
self.assertEqual(self.xfrsess._zone_exist(True), True)
self.assertEqual(self.xfrsess._zone_exist(False), False)
def test_check_xfrout_available(self):
def zone_exist(zone):
return zone
......@@ -243,7 +259,7 @@ class TestXfroutSession(unittest.TestCase):
self.xfrsess.dns_xfrout_start(self.sock, b"\xd6=\x00\x00\x00\x01\x00")
sent_data = self.sock.readsent()
self.assertEqual(len(sent_data), 0)
def default(self, param):
return "example.com"
......@@ -255,20 +271,20 @@ class TestXfroutSession(unittest.TestCase):
self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
get_msg = self.sock.read_msg()
self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
def test_dns_xfrout_start_noerror(self):
self.xfrsess._get_query_zone_name = self.default
def noerror(form):
return Rcode.NOERROR()
return Rcode.NOERROR()
self.xfrsess._check_xfrout_available = noerror
def myreply(msg, sock, zonename):
self.sock.send(b"success")
self.xfrsess._reply_xfrout_query = myreply
self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
self.assertEqual(self.sock.readsent(), b"success")
def test_reply_xfrout_query_noerror(self):
global sqlite3_ds
def get_zone_soa(zonename, file):
......@@ -292,7 +308,7 @@ class MyCCSession():
return "initdb.file", False
else:
return "unknown", False
class MyUnixSockServer(UnixSockServer):
def __init__(self):
......@@ -306,7 +322,7 @@ class MyUnixSockServer(UnixSockServer):
class TestUnixSockServer(unittest.TestCase):
def setUp(self):
self.unix = MyUnixSockServer()
def test_updata_config_data(self):
self.unix.update_config_data({'transfers_out':10 })
self.assertEqual(self.unix._max_transfers_out, 10)
......@@ -324,7 +340,7 @@ class TestUnixSockServer(unittest.TestCase):
count = self.unix._transfers_counter
self.assertEqual(self.unix.increase_transfers_counter(), False)
self.assertEqual(count, self.unix._transfers_counter)
def test_decrease_transfers_counter(self):
count = self.unix._transfers_counter
self.unix.decrease_transfers_counter()
......@@ -335,7 +351,7 @@ class TestUnixSockServer(unittest.TestCase):
os.remove(sock_file)
except OSError:
pass
def test_sock_file_in_use_file_exist(self):
sock_file = 'temp.sock.file'
self._remove_file(sock_file)
......
......@@ -63,6 +63,7 @@ AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
MAX_TRANSFERS_OUT = 10
VERBOSE_MODE = False
XFROUT_MAX_MESSAGE_SIZE = 65535
def get_rrset_len(rrset):
......@@ -73,46 +74,78 @@ def get_rrset_len(rrset):
class XfroutSession(BaseRequestHandler):
def __init__(self, request, client_address, server, log):
def __init__(self, request, client_address, server, log, sock):
# The initializer for the superclass may call functions
# that need _log to be set, so we set it first
self._log = log
self._shutdown_sock = sock
BaseRequestHandler.__init__(self, request, client_address, server)
def handle(self):
fd = recv_fd(self.request.fileno())
if fd < 0:
# This may happen when one xfrout process try to connect to
# xfrout unix socket server, to check whether there is another
# xfrout running.
self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
return
data_len = self.request.recv(2)
msg_len = struct.unpack('!H', data_len)[0]
msgdata = self.request.recv(msg_len)
sock = socket.fromfd(fd, socket.AF_INET, socket.SOCK_STREAM)
try:
self.dns_xfrout_start(sock, msgdata)
#TODO, avoid catching all exceptions
except Exception as e:
self._log.log_message("error", str(e))
'''Handle a request until shutdown or xfrout client is closed.'''
# check self.server._shutdown_event to ensure the real shutdown comes.
# Linux could trigger a spurious readable event on the _shutdown_sock
# due to a bug, so we need perform a double check.
while not self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
try:
(rlist, wlist, xlist) = select.select([self._shutdown_sock, self.request], [], [])
except select.error as e:
if e.args[0] == errno.EINTR:
(rlist, wlist, xlist) = ([], [], [])
continue
else:
self._log.log_message("error", "Error with select(): %s" %e)
break
# self.server._shutdown_evnet will be set by now, if it is not a false
# alarm
if self._shutdown_sock in rlist:
continue
try:
sock.shutdown(socket.SHUT_RDWR)
except socket.error:
# Avoid socket error caused by shutting down
# one non-connected socket.
pass
sock_fd = recv_fd(self.request.fileno())
if sock_fd < 0:
# This may happen when one xfrout process try to connect to
# xfrout unix socket server, to check whether there is another
# xfrout running.
if sock_fd == XFR_FD_RECEIVE_FAIL:
self._log.log_message("error", "Failed to receive the file descriptor for XFR connection")
break
sock.close()
os.close(fd)
pass
# receive query msg
msgdata = self._receive_query_message(self.request)
if not msgdata:
break
try:
self.dns_xfrout_start(sock_fd, msgdata)
#TODO, avoid catching all exceptions
except Exception as e:
self._log.log_message("error", str(e))
os.close(sock_fd)
def _receive_query_message(self, sock):
''' receive query message from sock'''
# receive data length
data_len = sock.recv(2)
if not data_len:
return None
msg_len = struct.unpack('!H', data_len)[0]
# receive data
recv_size = 0
msgdata = b''
while recv_size < msg_len:
data = sock.recv(msg_len - recv_size)
if not data:
return None
recv_size += len(data)
msgdata += data
return msgdata
def _parse_query_message(self, mdata):
''' parse query message to [socket,message]'''
#TODO, need to add parseHeader() in case the message header is invalid
#TODO, need to add parseHeader() in case the message header is invalid
try:
msg = Message(Message.PARSE)
Message.from_wire(msg, mdata)
......@@ -127,37 +160,37 @@ class XfroutSession(BaseRequestHandler):
return question.get_name().to_text()
def _send_data(self, sock, data):
def _send_data(self, sock_fd, data):
size = len(data)
total_count = 0
while total_count < size:
count = sock.send(data[total_count:])
count = os.write(sock_fd, data[total_count:])
total_count += count
def _send_message(self, sock, msg):
def _send_message(self, sock_fd, msg):
render = MessageRenderer()
render.set_length_limit(XFROUT_MAX_MESSAGE_SIZE)
msg.to_wire(render)
header_len = struct.pack('H', socket.htons(render.get_length()))
self._send_data(sock, header_len)
self._send_data(sock, render.get_data())
self._send_data(sock_fd, header_len)
self._send_data(sock_fd, render.get_data())
def _reply_query_with_error_rcode(self, msg, sock, rcode_):
def _reply_query_with_error_rcode(self, msg, sock_fd, rcode_):
msg.make_response()
msg.set_rcode(rcode_)
self._send_message(sock, msg)
self._send_message(sock_fd, msg)
def _reply_query_with_format_error(self, msg, sock):
def _reply_query_with_format_error(self, msg, sock_fd):
'''query message format isn't legal.'''
if not msg:
return # query message is invalid. send nothing back.
return # query message is invalid. send nothing back.
msg.make_response()
msg.set_rcode(Rcode.FORMERR())
self._send_message(sock, msg)
self._send_message(sock_fd, msg)
def _zone_is_empty(self, zone):
......@@ -167,24 +200,24 @@ class XfroutSession(BaseRequestHandler):
return True
def _zone_exist(self, zonename):
# Find zone in datasource, should this works? maybe should ask
# Find zone in datasource, should this works? maybe should ask
# config manager.
soa = sqlite3_ds.get_zone_soa(zonename, self.server.get_db_file())
if soa:
return True
return False
def _check_xfrout_available(self, zone_name):
'''Check if xfr request can be responsed.
TODO, Get zone's configuration from cfgmgr or some other place
eg. check allow_transfer setting,
eg. check allow_transfer setting,
'''
if not self._zone_exist(zone_name):
return Rcode.NOTAUTH()
if self._zone_is_empty(zone_name):
return Rcode.SERVFAIL()
return Rcode.SERVFAIL()
#TODO, check allow_transfer
if not self.server.increase_transfers_counter():
......@@ -193,35 +226,35 @@ class XfroutSession(BaseRequestHandler):
return Rcode.NOERROR()
def dns_xfrout_start(self, sock, msg_query):
def dns_xfrout_start(self, sock_fd, msg_query):
rcode_, msg = self._parse_query_message(msg_query)
#TODO. create query message and parse header
if rcode_ != Rcode.NOERROR():
return self._reply_query_with_format_error(msg, sock)
return self._reply_query_with_format_error(msg, sock_fd)
zone_name = self._get_query_zone_name(msg)
rcode_ = self._check_xfrout_available(zone_name)
if rcode_ != Rcode.NOERROR():
self._log.log_message("info", "transfer of '%s/IN' failed: %s",
zone_name, rcode_.to_text())
return self. _reply_query_with_error_rcode(msg, sock, rcode_)
return self. _reply_query_with_error_rcode(msg, sock_fd, rcode_)
try:
self._log.log_message("info", "transfer of '%s/IN': AXFR started" % zone_name)
self._reply_xfrout_query(msg, sock, zone_name)
self._reply_xfrout_query(msg, sock_fd, zone_name)
self._log.log_message("info", "transfer of '%s/IN': AXFR end" % zone_name)
except Exception as err:
self._log.log_message("error", str(err))
self.server.decrease_transfers_counter()
return
return
def _clear_message(self, msg):
qid = msg.get_qid()
opcode = msg.get_opcode()
rcode = msg.get_rcode()
msg.clear(Message.RENDER)
msg.set_qid(qid)
msg.set_opcode(opcode)
......@@ -231,7 +264,7 @@ class XfroutSession(BaseRequestHandler):
return msg
def _create_rrset_from_db_record(self, record):
'''Create one rrset from one record of datasource, if the schema of record is changed,
'''Create one rrset from one record of datasource, if the schema of record is changed,
This function should be updated first.
'''
rrtype_ = RRType(record[5])
......@@ -239,8 +272,8 @@ class XfroutSession(BaseRequestHandler):
rrset_ = RRset(Name(record[2]), RRClass("IN"), rrtype_, RRTTL( int(record[4])))
rrset_.add_rdata(rdata_)
return rrset_
def _send_message_with_last_soa(self, msg, sock, rrset_soa, message_upper_len):
def _send_message_with_last_soa(self, msg, sock_fd, rrset_soa, message_upper_len):
'''Add the SOA record to the end of message. If it can't be
added, a new message should be created to send out the last soa .
'''
......@@ -249,14 +282,14 @@ class XfroutSession(BaseRequestHandler):
if message_upper_len + rrset_len < XFROUT_MAX_MESSAGE_SIZE:
msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
else:
self._send_message(sock, msg)
self._send_message(sock_fd, msg)
msg = self._clear_message(msg)
msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
self._send_message(sock, msg)
self._send_message(sock_fd, msg)
def _reply_xfrout_query(self, msg, sock, zone_name):
def _reply_xfrout_query(self, msg, sock_fd, zone_name):
#TODO, there should be a better way to insert rrset.
msg.make_response()
msg.set_header_flag(Message.HEADERFLAG_AA)
......@@ -286,12 +319,12 @@ class XfroutSession(BaseRequestHandler):
message_upper_len += rrset_len
continue
self._send_message(sock, msg)
self._send_message(sock_fd, msg)
msg = self._clear_message(msg)
msg.add_rrset(Message.SECTION_ANSWER, rrset_) # Add the rrset to the new message
message_upper_len = rrset_len
self._send_message_with_last_soa(msg, sock, rrset_soa, message_upper_len)
self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len)
class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
'''The unix domain socket server which accept xfr query sent from auth server.'''
......@@ -304,22 +337,23 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
self._lock = threading.Lock()
self._transfers_counter = 0
self._shutdown_event = shutdown_event
self._write_sock, self._read_sock = socket.socketpair()
self._log = log
self.update_config_data(config_data)
self._cc = cc
def finish_request(self, request, client_address):
'''Finish one request by instantiating RequestHandlerClass.'''
self.RequestHandlerClass(request, client_address, self, self._log)
self.RequestHandlerClass(request, client_address, self, self._log, self._read_sock)
def _remove_unused_sock_file(self, sock_file):
'''Try to remove the socket file. If the file is being used
by one running xfrout process, exit from python.
'''Try to remove the socket file. If the file is being used
by one running xfrout process, exit from python.
If it's not a socket file or nobody is listening
, it will be removed. If it can't be removed, exit from python. '''
if self._sock_file_in_use(sock_file):
sys.stderr.write("[b10-xfrout] Fail to start xfrout process, unix socket"
" file '%s' is being used by another xfrout process\n" % sock_file)
self._log.log_message("error", "Fail to start xfrout process, unix socket file '%s'"
" is being used by another xfrout process\n" % sock_file)
sys.exit(0)
else:
if not os.path.exists(sock_file):
......@@ -328,12 +362,12 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
try:
os.unlink(sock_file)
except OSError as err:
sys.stderr.write('[b10-xfrout] Fail to remove file %s: %s\n' % (sock_file, err))
self._log.log_message("error", '[b10-xfrout] Fail to remove file %s: %s\n' % (sock_file, err))
sys.exit(0)
def _sock_file_in_use(self, sock_file):
'''Check whether the socket file 'sock_file' exists and
is being used by one running xfrout process. If it is,
'''Check whether the socket file 'sock_file' exists and
is being used by one running xfrout process. If it is,
return True, or else return False. '''
try:
sock = socket.socket(socket.AF_UNIX)
......@@ -341,9 +375,10 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
except socket.error as err:
return False
else:
return True
return True
def shutdown(self):
self._write_sock.send(b"shutdown") #terminate the xfrout session thread
super().shutdown() # call the shutdown() of class socketserver_mixin.NoPollMixIn
try:
os.unlink(self._sock_file)
......@@ -390,7 +425,7 @@ class XfroutServer:
def __init__(self):
self._unix_socket_server = None
self._log = None
self._listen_sock_file = UNIX_SOCKET_FILE
self._listen_sock_file = UNIX_SOCKET_FILE
self._shutdown_event = threading.Event()
self._cc = isc.config.ModuleCCSession(SPECFILE_LOCATION, self.config_handler, self.command_handler)