Commit 613b2880 authored by JINMEI Tatuya's avatar JINMEI Tatuya
Browse files

cleanup: eliminated the need for isc.auth.sqlite3_ds.AXFRInDB by directly...

cleanup: eliminated the need for isc.auth.sqlite3_ds.AXFRInDB by directly calling sqlite3_ds.load() from the xfrin module.

Other cleanups:
- cosmetic: removed redundant blank lines and white spaces after EOL
- grammar fix in comments
- catch Sqlite3DSError explicitly (but I suspect the exception handling
  in the xfrin module is naive overall, which should be fixed)


git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@1684 e5f2f494-b856-4b98-b285-d166d9295462
parent 4b29535a
......@@ -85,7 +85,6 @@ class XfrinConnection(asyncore.dispatcher):
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
self._zone_name = zone_name
self._db_file = db_file
self._axfrin_db = isc.auth.sqlite3_ds.AXFRInDB(self._db_file, self._zone_name)
self._soa_rr_count = 0
self._idle_timeout = idle_timeout
self.setblocking(1)
......@@ -112,7 +111,6 @@ class XfrinConnection(asyncore.dispatcher):
count = self.send(data[total_count:])
total_count += count
def _send_query(self, query_type):
'''Send query message over TCP. '''
msg = self._create_query(query_type)
......@@ -123,7 +121,6 @@ class XfrinConnection(asyncore.dispatcher):
self._send_data(header_len)
self._send_data(obuf.get_data())
def _get_request_response(self, size):
recv_size = 0
......@@ -140,14 +137,12 @@ class XfrinConnection(asyncore.dispatcher):
return data
def handle_read(self):
'''Read query's response from socket. '''
self._recvd_data = self.recv(self._need_recv_size)
self._recvd_size = len(self._recvd_data)
self._recv_time_out = False
def _check_soa_serial(self):
''' Compare the soa serial, if soa serial in master is less than
the soa serial in local, Finish xfrin.
......@@ -169,9 +164,9 @@ class XfrinConnection(asyncore.dispatcher):
self.log_msg('transfer of \'%s\': AXFR started' % self._zone_name)
if ret == XFRIN_OK:
self._axfrin_db.prepare_axfrin()
self._send_query(rr_type.AXFR())
ret = self._handle_xfrin_response()
isc.auth.sqlite3_ds.load(self._db_file, self._zone_name,
self._handle_xfrin_response)
endmsg = 'succeeded' if ret == XFRIN_OK else 'failed'
self.log_msg('transfer of \'%s\' AXFR %s' % (self._zone_name,
......@@ -179,11 +174,11 @@ class XfrinConnection(asyncore.dispatcher):
except XfrinException as e:
self.log_msg(e)
self.log_msg('Error happened during xfrin!')
#TODO, recover data source.
#TODO, recover data source.
except isc.auth.sqlite3_ds.Sqlite3DSError as e:
self.log_msg(e)
finally:
self.close()
if ret == XFRIN_OK:
self._axfrin_db.finish_axfrin()
return ret
......@@ -204,7 +199,6 @@ class XfrinConnection(asyncore.dispatcher):
if msg.get_rr_count(section.QUESTION()) > 1:
raise XfrinException('query section count greater than 1')
def _handle_answer_section(self, rrset_iter):
while not rrset_iter.is_last():
rrset = rrset_iter.get_rrset()
......@@ -231,11 +225,10 @@ class XfrinConnection(asyncore.dispatcher):
break
rdata_text = rdata_iter.get_current().to_text()
rr_data = (rrset_name, rrset_ttl, rrset_class, rrset_type, rdata_text)
self._axfrin_db.insert_axfr_record([rr_data])
yield (rrset_name, rrset_ttl, rrset_class, rrset_type,
rdata_text)
rdata_iter.next()
def _handle_xfrin_response(self):
while True:
data_len = self._get_request_response(2)
......@@ -246,23 +239,21 @@ class XfrinConnection(asyncore.dispatcher):
self._check_response_status(msg)
rrset_iter = section_iter(msg, section.ANSWER())
self._handle_answer_section(rrset_iter)
for rr in self._handle_answer_section(rrset_iter):
yield rr
if self._soa_rr_count == 2:
return XFRIN_OK
break
if self._shutdown_event.is_set():
#Check if xfrin process is shutdown.
#TODO, xfrin may be blocked in one loop.
raise XfrinException('xfrin is forced to stop')
return XFRIN_OK
def writable(self):
'''Ignore the writable socket. '''
return False
def log_info(self, msg, type='info'):
# Overwrite the log function, log nothing
pass
......@@ -276,7 +267,7 @@ class XfrinConnection(asyncore.dispatcher):
def process_xfrin(xfrin_recorder, zone_name, db_file,
shutdown_event, master_addr, port, check_soa, verbose):
xfrin_recorder.increment(zone_name)
xfrin_recorder.increment(name)
try:
conn = XfrinConnection(zone_name, db_file, shutdown_event,
master_addr, int(port), check_soa, verbose)
......@@ -405,12 +396,12 @@ class Xfrin():
if self.recorder.xfrin_in_progress(zone_name):
return (1, 'zone xfrin is in progress')
xfrin_thread = threading.Thread(target = process_xfrin,
args = (self.recorder,
zone_name,
db_file,
xfrin_thread = threading.Thread(target = process_xfrin,
args = (self.recorder,
zone_name,
db_file,
self._shutdown_event,
master_addr,
master_addr,
port, check_soa, self._verbose))
xfrin_thread.start()
......
......@@ -148,17 +148,20 @@ def reverse_name(name):
new.pop(0)
return '.'.join(new)+'.'
#########################################################################
# load:
# load a zone into the SQL database.
# input:
# dbfile: the sqlite3 database fileanme
# zone: the zone origin
# reader: an generator function producing an iterable set of
# reader: a generator function producing an iterable set of
# name/ttl/class/rrtype/rdata-text tuples
#########################################################################
def load(dbfile, zone, reader):
# if the zone name doesn't contain the trailing dot, automatically add it.
if zone[-1] != '.':
zone += '.'
conn, cur = open(dbfile)
old_zone_id = get_zoneid(zone, cur)
......@@ -184,13 +187,13 @@ def load(dbfile, zone, reader):
rdtype, sigtype, rdata)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
[new_zone_id, name, reverse_name(name), ttl,
rdtype, sigtype, rdata])
rdtype, sigtype, rdata])
else:
cur.execute("""INSERT INTO records
(zone_id, name, rname, ttl, rdtype, rdata)
VALUES (?, ?, ?, ?, ?, ?)""",
[new_zone_id, name, reverse_name(name), ttl,
rdtype, rdata])
rdtype, rdata])
except Exception as e:
fail = "Error while loading " + zone + ": " + e.args[0]
raise Sqlite3DSError(fail)
......@@ -208,78 +211,3 @@ def load(dbfile, zone, reader):
cur.close()
conn.close()
#########################################################################
# temp sqlite3 datasource backend for axfr in. The code should be refectored
# later.
#########################################################################
class AXFRInDB:
def __init__(self, dbfile, zone_name):
self._dbfile = dbfile
self._zone_name = zone_name
# if the zone name doesn't contain the trailing dot, automatically
# add it.
if self._zone_name[-1] != '.':
self._zone_name += '.'
self._old_zone_id = None
self._new_zone_id = None
def prepare_axfrin(self):
self._conn, self._cur = open(self._dbfile)
self._old_zone_id = get_zoneid(self._zone_name, self._cur)
temp = str(random.randrange(100000))
self._cur.execute("INSERT INTO zones (name, rdclass) VALUES (?, 'IN')", [temp])
self._new_zone_id = self._cur.lastrowid
def insert_axfr_record(self, rrsets):
'''insert zone records to sqlite3 database'''
try:
for name, ttl, rdclass, rdtype, rdata in rrsets:
sigtype = ''
if rdtype.lower() == 'rrsig':
sigtype = rdata.split()[0]
if rdtype.lower() == 'nsec3' or sigtype.lower() == 'nsec3':
hash = name.split('.')[0]
self._cur.execute("""INSERT INTO nsec3
(zone_id, hash, owner, ttl, rdtype, rdata)
VALUES (?, ?, ?, ?, ?, ?)""",
[self._new_zone_id, hash, name, ttl, rdtype, rdata])
elif rdtype.lower() == 'rrsig':
self._cur.execute("""INSERT INTO records
(zone_id, name, rname, ttl,
rdtype, sigtype, rdata)
VALUES (?, ?, ?, ?, ?, ?, ?)""",
[self._new_zone_id, name, reverse_name(name), ttl,
rdtype, sigtype, rdata])
else:
self._cur.execute("""INSERT INTO records
(zone_id, name, rname, ttl, rdtype, rdata)
VALUES (?, ?, ?, ?, ?, ?)""",
[self._new_zone_id, name, reverse_name(name), ttl,
rdtype, rdata])
except Exception as e:
fail = "Error while loading " + self._zone_name + ": " + e.args[0]
raise Sqlite3DSError(fail)
def finish_axfrin(self):
'''commit changes and close sqlite3 database'''
if self._old_zone_id:
self._cur.execute("DELETE FROM zones WHERE id=?", [self._old_zone_id])
self._cur.execute("UPDATE zones SET name=? WHERE id=?", [self._zone_name, self._new_zone_id])
self._conn.commit()
self._cur.execute("DELETE FROM records WHERE zone_id=?", [self._old_zone_id])
self._cur.execute("DELETE FROM nsec3 WHERE zone_id=?", [self._old_zone_id])
self._conn.commit()
else:
self._cur.execute("UPDATE zones SET name=? WHERE id=?", [self._zone_name, self._new_zone_id])
self._conn.commit()
self._cur.close()
self._conn.close()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment