Commit a6823843 authored by Michael Graff's avatar Michael Graff
Browse files

checkpoint work; Python-based msgq mostly works. Bad input will crash it,...

checkpoint work; Python-based msgq mostly works.  Bad input will crash it, which should be fixed, probably by wrapping the entire message processing in a try loop.  Gross, but...

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/parkinglot@333 e5f2f494-b856-4b98-b285-d166d9295462
parent 6383f1fd
......@@ -15,15 +15,60 @@ import errno
import time
import select
import pprint
import random
from optparse import OptionParser, OptionValueError
import ISC.CC
class MsgQReceiveError(Exception): pass
# This is the version that gets displayed to the user.
__version__ = "v20091030 (Paving the DNS Parking Lot)"
class MsgQReceiveError(Exception): pass
class SubscriptionManager:
def __init__(self):
self.subscriptions = {}
def subscribe(self, group, instance, socket):
"""Add a subscription."""
target = ( group, instance )
if target in self.subscriptions:
print("Appending to existing target")
self.subscriptions[target].append(socket)
else:
print("Creating new target")
self.subscriptions[target] = [ socket ]
def unsubscribe(self, group, instance, socket):
"""Remove the socket from the one specific subscription."""
target = ( group, instance )
if target in self.subscriptions:
while socket in self.subscriptions[target]:
self.subscriptions[target].remove(socket)
def unsubscribe_all(self, socket):
"""Remove the socket from all subscriptions."""
for socklist in self.subscriptions.values():
while socket in socklist:
socklist.remove(socket)
def find_sub(self, group, instance):
"""Return an array of sockets which want this specific group,
instance."""
target = (group, instance)
if target in self.subscriptions:
return self.subscriptions[target]
else:
return []
def find(self, group, instance):
"""Return an array of sockets who should get something sent to
this group, instance pair. This includes wildcard subscriptions."""
target = (group, instance)
partone = self.find_sub(group, instance)
parttwo = self.find_sub(group, "*")
return list(set(partone + parttwo))
class MsgQ:
"""Message Queue class."""
def __init__(self, c_channel_port=9912, verbose=False):
......@@ -39,6 +84,9 @@ class MsgQ:
self.runnable = False
self.listen_socket = False
self.sockets = {}
self.connection_counter = random.random()
self.hostname = socket.gethostname()
self.subs = SubscriptionManager()
def setup_poller(self):
"""Set up the poll thing. Internal function."""
......@@ -77,12 +125,13 @@ class MsgQ:
if sock == None:
sys.stderr.write("Got read on Strange Socket fd %d\n" % fd)
return
sys.stderr.write("Got read on fd %d\n" %fd)
# sys.stderr.write("Got read on fd %d\n" %fd)
self.process_packet(fd, sock)
def kill_socket(self, fd, sock):
"""Fully close down the socket."""
self.poller.unregister(sock)
self.subs.unsubscribe_all(sock)
sock.close()
self.sockets[fd] = None
sys.stderr.write("Closing socket fd %d\n" % fd)
......@@ -106,8 +155,6 @@ class MsgQ:
if overall_length < 2:
raise MsgQReceiveError("overall_length < 2")
overall_length -= 2
sys.stderr.write("overall length: %d, routing_length %d\n"
% (overall_length, routing_length))
if routing_length > overall_length:
raise MsgQReceiveError("routing_length > overall_length")
if routing_length == 0:
......@@ -137,8 +184,8 @@ class MsgQ:
sys.stderr.write("Routing decode error: %s\n" % err)
return
sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
sys.stdout.write("\t" + pprint.pformat(data) + "\n")
# sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
# sys.stdout.write("\t" + pprint.pformat(data) + "\n")
self.process_command(fd, sock, routingmsg, data)
......@@ -146,29 +193,77 @@ class MsgQ:
"""Process a single command. This will split out into one of the
other functions, above."""
cmd = routing["type"]
if cmd == 'getlname':
self.process_command_getlname(sock, routing, data)
elif cmd == 'send':
if cmd == 'send':
self.process_command_send(sock, routing, data)
elif cmd == 'subscribe':
self.process_command_subscribe(sock, routing, data)
elif cmd == 'unsubscribe':
self.process_command_unsubscribe(sock, routing, data)
elif cmd == 'getlname':
self.process_command_getlname(sock, routing, data)
else:
sys.stderr.write("Invalid command: %s\n" % cmd)
def sendmsg(self, sock, env, msg = None):
def preparemsg(self, env, msg = None):
if type(env) == dict:
env = ISC.CC.Message.to_wire(env)
if type(msg) == dict:
msg = ISC.CC.Message.to_wire(msg)
sock.setblocking(1)
length = 2 + len(env);
if msg:
length += len(msg)
sock.send(struct.pack("!IH", length, len(env)))
sock.send(env)
ret = struct.pack("!IH", length, len(env))
ret += env
if msg:
sock.send(msg)
ret += msg
return ret
def sendmsg(self, sock, env, msg = None):
sock.send(self.preparemsg(env, msg))
def send_prepared_msg(self, sock, msg):
sock.send(msg)
def newlname(self):
"""Generate a unique conenction identifier for this socket.
This is done by using an increasing counter and the current
time."""
self.connection_counter += 1
return "%x_%x@%s" % (time.time(), self.connection_counter, self.hostname)
def process_command_getlname(self, sock, routing, data):
self.sendmsg(sock, { "type" : "getlname" }, { "lname" : "staticlname" })
env = { "type" : "getlname" }
reply = { "lname" : self.newlname() }
self.sendmsg(sock, env, reply)
def process_command_send(self, sock, routing, data):
group = routing["group"]
instance = routing["instance"]
if group == None or instance == None:
return # ignore invalid packets entirely
sockets = self.subs.find(group, instance)
msg = self.preparemsg(routing, data)
if sock in sockets:
sockets.remove(sock)
for socket in sockets:
self.send_prepared_msg(socket, msg)
def process_command_subscribe(self, sock, routing, data):
group = routing["group"]
instance = routing["instance"]
subtype = routing["subtype"]
if group == None or instance == None or subtype == None:
return # ignore invalid packets entirely
self.subs.subscribe(group, instance, sock)
def process_command_unsubscribe(self, sock, routing, data):
group = routing["group"]
instance = routing["instance"]
if group == None or instance == None:
return # ignore invalid packets entirely
self.subs.unsubscribe(group, instance, sock)
def run(self):
"""Process messages. Forever. Mostly."""
......
from msgq import SubscriptionManager, MsgQ
import unittest
#
# Currently only the subscription part is implemented... I'd have to mock
# out a socket, which, while not impossible, is not trivial.
#
class TestSubscriptionManager(unittest.TestCase):
def setUp(self):
self.sm = SubscriptionManager()
def test_subscription_add_delete_manager(self):
self.sm.subscribe("a", "*", 'sock1')
self.assertEqual(self.sm.find_sub("a", "*"), [ 'sock1' ])
def test_subscription_add_delete_other(self):
self.sm.subscribe("a", "*", 'sock1')
self.sm.unsubscribe("a", "*", 'sock2')
self.assertEqual(self.sm.find_sub("a", "*"), [ 'sock1' ])
def test_subscription_add_several_sockets(self):
socks = [ 's1', 's2', 's3', 's4', 's5' ]
for s in socks:
self.sm.subscribe("a", "*", s)
self.assertEqual(self.sm.find_sub("a", "*"), socks)
def test_unsubscribe(self):
socks = [ 's1', 's2', 's3', 's4', 's5' ]
for s in socks:
self.sm.subscribe("a", "*", s)
self.sm.unsubscribe("a", "*", 's3')
self.assertEqual(self.sm.find_sub("a", "*"), [ 's1', 's2', 's4', 's5' ])
def test_unsubscribe_all(self):
self.sm.subscribe('g1', 'i1', 's1')
self.sm.subscribe('g1', 'i1', 's2')
self.sm.subscribe('g1', 'i2', 's1')
self.sm.subscribe('g1', 'i2', 's2')
self.sm.subscribe('g2', 'i1', 's1')
self.sm.subscribe('g2', 'i1', 's2')
self.sm.subscribe('g2', 'i2', 's1')
self.sm.subscribe('g2', 'i2', 's2')
self.sm.unsubscribe_all('s1')
self.assertEqual(self.sm.find_sub("g1", "i1"), [ 's2' ])
self.assertEqual(self.sm.find_sub("g1", "i2"), [ 's2' ])
self.assertEqual(self.sm.find_sub("g2", "i1"), [ 's2' ])
self.assertEqual(self.sm.find_sub("g2", "i2"), [ 's2' ])
def test_find(self):
self.sm.subscribe('g1', 'i1', 's1')
self.sm.subscribe('g1', '*', 's2')
self.assertEqual(set(self.sm.find("g1", "i1")), set([ 's1', 's2' ]))
def test_find_sub(self):
self.sm.subscribe('g1', 'i1', 's1')
self.sm.subscribe('g1', '*', 's2')
self.assertEqual(self.sm.find_sub("g1", "i1"), [ 's1' ])
if __name__ == '__main__':
unittest.main()
import ISC
import time
import pprint
import unittest
#
# This test requires the MsgQ daemon to be running. We are doing nasty
# tricks here, and so insert sleeps to give things time to migrate from
# this process, to the MsgQ, and back to this process.
#
class TestCCWireEncoding(unittest.TestCase):
def setUp(self):
self.s1 = ISC.CC.Session()
self.s2 = ISC.CC.Session()
def test_lname(self):
self.assertTrue(self.s1.lname)
self.assertTrue(self.s2.lname)
def test_subscribe(self):
self.s1.group_subscribe("g1", "i1")
self.s2.group_subscribe("g1", "i1")
time.sleep(0.5)
outmsg = { "data" : "foo" }
self.s1.group_sendmsg(outmsg, "g1", "i1")
time.sleep(0.5)
msg, env = self.s2.group_recvmsg()
self.assertEqual(env["from"], self.s1.lname)
def test_unsubscribe(self):
self.s1.group_subscribe("g1", "i1")
self.s2.group_subscribe("g1", "i1")
time.sleep(0.5)
self.s2.group_unsubscribe("g1", "i1")
outmsg = { "data" : "foo" }
self.s1.group_sendmsg(outmsg, "g1", "i1")
time.sleep(0.5)
msg, env = self.s2.group_recvmsg()
self.assertFalse(env)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment