Commit 0785c84f authored by JINMEI Tatuya's avatar JINMEI Tatuya
Browse files

[2380] handle signals

parent 9d1e869b
......@@ -18,6 +18,7 @@
import sys
sys.path.append('@@PYTHONPATH@@')
import time
import signal
from optparse import OptionParser
from isc.dns import *
from isc.datasrc import *
......@@ -85,6 +86,7 @@ class LoadZoneRunner:
def __init__(self, command_args):
self.__command_args = command_args
self.__loaded_rrs = 0
self.__interrupted = False # will be set to True on receiving signal
# system-wide log configuration. We need to configure logging this
# way so that the logging policy applies to underlying libraries, too.
......@@ -199,6 +201,11 @@ class LoadZoneRunner:
[self._zone_name.to_text()])
def _report_progress(self, loaded_rrs):
'''Dump the current progress report to stdout.
This is essentially private, but defined as "protected" for tests.
'''
elapsed = time.time() - self.__start_time
sys.stdout.write("\r" + (80 * " "))
sys.stdout.write("\r%d RRs loaded in %.2f seconds" %
......@@ -225,10 +232,13 @@ class LoadZoneRunner:
limit = self._load_iteration_limit
else:
limit = LOAD_INTERVAL_DEFAULT
while not loader.load_incremental(limit):
while (not self.__interrupted and
not loader.load_incremental(limit)):
self.__loaded_rrs += self._load_iteration_limit
if self._load_iteration_limit > 0:
self._report_progress(self.__loaded_rrs)
if self.__interrupted:
raise LoadFailure('loading interrupted by signal')
except Exception as ex:
# release any remaining lock held in the client/loader
loader, datasrc_client = None, None
......@@ -260,10 +270,18 @@ class LoadZoneRunner:
logger.warn(LOADZONE_POSTLOAD_ISSUE, self._zone_name,
self._zone_class, msg)
def _set_signal_handlers(self):
signal.signal(signal.SIGINT, self._interrupt_handler)
signal.signal(signal.SIGTERM, self._interrupt_handler)
def _interrupt_handler(self, signal, frame):
self.__interrupted = True
def run(self):
'''Top-level method, simply calling other helpers'''
try:
self._set_signal_handlers()
self._parse_args()
self._do_load()
logger.info(LOADZONE_DONE, self._zone_name, self._zone_class)
......
......@@ -243,7 +243,7 @@ class TestLoadZoneRunner(unittest.TestCase):
self.__runner._do_load()
self.__runner._post_load_checks()
def test_load_fail_create_cancel(self):
def test_load_post_check_fail_soa(self):
'''Load succeeds but warns about missing SOA, should cause warn'''
self.__common_load_setup()
self.__common_post_load_setup(LOCAL_TESTDATA_PATH +
......@@ -252,7 +252,7 @@ class TestLoadZoneRunner(unittest.TestCase):
self.assertEqual(1, len(self.__warnings))
self.assertEqual('zone has no SOA', self.__warnings[0])
def test_load_fail_create_cancel(self):
def test_load_post_check_fail_ns(self):
'''Load succeeds but warns about missing NS, should cause warn'''
self.__common_load_setup()
self.__common_post_load_setup(LOCAL_TESTDATA_PATH +
......@@ -261,6 +261,43 @@ class TestLoadZoneRunner(unittest.TestCase):
self.assertEqual(1, len(self.__warnings))
self.assertEqual('zone has no NS', self.__warnings[0])
def __interrupt_progress(self, loaded_rrs):
'''A helper emulating a signal in the middle of loading.
On the second progress report, it internally invokes the signal
handler to see if it stops the loading.
'''
self.__reports.append(loaded_rrs)
if len(self.__reports) == 2:
self.__runner._interrupt_handler()
def test_load_interrupted(self):
'''Load attempt fails due to signal interruption'''
self.__common_load_setup()
self.__runner._report_progress = lambda x: self.__interrupt_progress(x)
# The interrupting _report_progress() will terminate the loading
# in the middle. the number of reports is smaller, and the zone
# won't be changed.
self.assertRaises(LoadFailure, self.__runner._do_load)
self.assertEqual([1, 2], self.__reports)
self.__check_zone_soa(ORIG_SOA_TXT)
def test_load_interrupted_create_cancel(self):
'''Load attempt for a new zone fails due to signal interruption
It cancels the zone creation.
'''
self.__common_load_setup()
self.__runner._report_progress = lambda x: self.__interrupt_progress(x)
self.__runner._zone_name = Name('example.com')
self.__runner._zone_file = ALT_NEW_ZONE_TXT_FILE
self.__check_zone_soa(None, zone_name=Name('example.com'))
self.assertRaises(LoadFailure, self.__runner._do_load)
self.assertEqual([1, 2], self.__reports)
self.__check_zone_soa(None, zone_name=Name('example.com'))
def test_run_success(self):
'''Check for the top-level method.
......@@ -291,4 +328,7 @@ if __name__== "__main__":
# Disable the internal logging setup so the test output won't be too
# verbose by default.
LoadZoneRunner._config_log = lambda x: None
# Cancel signal handlers so we can stop tests when they hang
LoadZoneRunner._set_signal_handlers = lambda x: None
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment