Commit 6383f1fd authored by Michael Graff's avatar Michael Graff
Browse files

New wire format, which makes things more sane for processing envelope apart...

New wire format, which makes things more sane for processing envelope apart from messages.  No API changes.  The current msgq does not support this, but the pymsgq I'm hoping to finish up tomorrow will.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/parkinglot@332 e5f2f494-b856-4b98-b285-d166d9295462
parent 83553f96
......@@ -10,7 +10,7 @@ import signal
import os
import socket
import sys
import re
import struct
import errno
import time
import select
......@@ -19,6 +19,8 @@ from optparse import OptionParser, OptionValueError
import ISC.CC
class MsgQReceiveError(Exception): pass
# This is the version that gets displayed to the user.
__version__ = "v20091030 (Paving the DNS Parking Lot)"
......@@ -63,12 +65,14 @@ class MsgQ:
self.runnable = True
def process_accept(self):
"""Process an accept on the listening socket."""
newsocket, ipaddr = self.listen_socket.accept()
sys.stderr.write("Connection\n")
self.sockets[newsocket.fileno()] = newsocket
self.poller.register(newsocket, select.POLLIN)
def process_socket(self, fd):
"""Process a read on a socket."""
sock = self.sockets[fd]
if sock == None:
sys.stderr.write("Got read on Strange Socket fd %d\n" % fd)
......@@ -76,19 +80,98 @@ class MsgQ:
sys.stderr.write("Got read on fd %d\n" %fd)
self.process_packet(fd, sock)
def kill_socket(self, fd, sock):
"""Fully close down the socket."""
self.poller.unregister(sock)
sock.close()
self.sockets[fd] = None
sys.stderr.write("Closing socket fd %d\n" % fd)
def getbytes(self, fd, sock, length):
"""Get exactly the requested bytes, or raise an exception if
EOF."""
received = b''
while len(received) < length:
data = sock.recv(length - len(received))
if len(data) == 0:
raise MsgQReceiveError("EOF")
received += data
return received
def read_packet(self, fd, sock):
"""Read a correctly formatted packet. Will raise exceptions if
something fails."""
lengths = self.getbytes(fd, sock, 6)
overall_length, routing_length = struct.unpack(">IH", lengths)
if overall_length < 2:
raise MsgQReceiveError("overall_length < 2")
overall_length -= 2
sys.stderr.write("overall length: %d, routing_length %d\n"
% (overall_length, routing_length))
if routing_length > overall_length:
raise MsgQReceiveError("routing_length > overall_length")
if routing_length == 0:
raise MsgQReceiveError("routing_length == 0")
data_length = overall_length - routing_length
# probably need to sanity check lengths here...
routing = self.getbytes(fd, sock, routing_length)
if data_length > 0:
data = self.getbytes(fd, sock, data_length)
else:
data = None
return (routing, data)
def process_packet(self, fd, sock):
data = sock.recv(4)
if len(data) == 0:
self.poller.unregister(sock)
sock.close()
self.sockets[fd] = None
sys.stderr.write("Closing socket fd %d\n" % fd)
"""Process one packet."""
try:
routing, data = self.read_packet(fd, sock)
except MsgQReceiveError as err:
self.kill_socket(fd, sock)
sys.stderr.write("Receive error: %s\n" % err)
return
try:
routingmsg = ISC.CC.Message.from_wire(routing)
except DecodeError as err:
self.kill_socket(fd, sock)
sys.stderr.write("Routing decode error: %s\n" % err)
return
sys.stderr.write("Got data: %s\n" % data)
sys.stdout.write("\t" + pprint.pformat(routingmsg) + "\n")
sys.stdout.write("\t" + pprint.pformat(data) + "\n")
self.process_command(fd, sock, routingmsg, data)
def process_command(self, fd, sock, routing, data):
"""Process a single command. This will split out into one of the
other functions, above."""
cmd = routing["type"]
if cmd == 'getlname':
self.process_command_getlname(sock, routing, data)
elif cmd == 'send':
self.process_command_send(sock, routing, data)
else:
sys.stderr.write("Invalid command: %s\n" % cmd)
def sendmsg(self, sock, env, msg = None):
if type(env) == dict:
env = ISC.CC.Message.to_wire(env)
if type(msg) == dict:
msg = ISC.CC.Message.to_wire(msg)
sock.setblocking(1)
length = 2 + len(env);
if msg:
length += len(msg)
sock.send(struct.pack("!IH", length, len(env)))
sock.send(env)
if msg:
sock.send(msg)
def process_command_getlname(self, sock, routing, data):
self.sendmsg(sock, { "type" : "getlname" }, { "lname" : "staticlname" })
def run(self):
"""Process messages. Forever. Mostly."""
while True:
try:
events = self.poller.poll()
......
......@@ -73,28 +73,70 @@ Session::sendmsg(ElementPtr& msg)
std::string wire = msg->to_wire();
unsigned int length = wire.length();
unsigned int length_net = htonl(length);
unsigned short header_length_net = htons(length);
unsigned int ret;
ret = write(sock, &length_net, 4);
if (ret != 4)
throw SessionError("Short write");
ret = write(sock, &header_length_net, 2);
if (ret != 2)
throw SessionError("Short write");
ret = write(sock, wire.c_str(), length);
if (ret != length)
throw SessionError("Short write");
}
void
Session::sendmsg(ElementPtr& env, ElementPtr& msg)
{
std::string header_wire = env->to_wire();
std::string body_wire = msg->to_wire();
unsigned int length = 2 + header_wire.length() + body_wire.length();
unsigned int length_net = htonl(length);
unsigned short header_length = header_wire.length();
unsigned short header_length_net = htons(header_length);
unsigned int ret;
ret = write(sock, &length_net, 4);
if (ret != 4)
throw SessionError("Short write");
ret = write(sock, &header_length_net, 2);
if (ret != 2)
throw SessionError("Short write");
std::cout << "[XX] Header length sending: " << header_length << std::endl;
ret = write(sock, header_wire.c_str(), header_length);
ret = write(sock, body_wire.c_str(), body_wire.length());
if (ret != length)
throw SessionError("Short write");
}
bool
Session::recvmsg(ElementPtr& msg, bool nonblock)
{
unsigned int length_net;
unsigned short header_length_net;
unsigned int ret;
ret = read(sock, &length_net, 4);
if (ret != 4)
throw SessionError("Short read");
unsigned int length = ntohl(length_net);
ret = read(sock, &header_length_net, 2);
if (ret != 2)
throw SessionError("Short read");
unsigned int length = ntohl(length_net) - 2;
unsigned short header_length = ntohs(header_length_net);
if (header_length != length) {
throw SessionError("Received non-empty body where only a header expected");
}
char *buffer = new char[length];
ret = read(sock, buffer, length);
if (ret != length)
......@@ -112,6 +154,48 @@ Session::recvmsg(ElementPtr& msg, bool nonblock)
// XXXMLG handle non-block here, and return false for short reads
}
bool
Session::recvmsg(ElementPtr& env, ElementPtr& msg, bool nonblock)
{
unsigned int length_net;
unsigned short header_length_net;
unsigned int ret;
ret = read(sock, &length_net, 4);
if (ret != 4)
throw SessionError("Short read");
ret = read(sock, &header_length_net, 2);
if (ret != 2)
throw SessionError("Short read");
unsigned int length = ntohl(length_net);
unsigned short header_length = ntohs(header_length_net);
if (header_length > length)
throw SessionError("Bad header length");
char *buffer = new char[length];
ret = read(sock, buffer, length);
if (ret != length)
throw SessionError("Short read");
std::string header_wire = std::string(buffer, header_length);
std::string body_wire = std::string(buffer, length - header_length);
delete [] buffer;
std::stringstream header_wire_stream;
header_wire_stream << header_wire;
env = Element::from_wire(header_wire_stream, length);
std::stringstream body_wire_stream;
body_wire_stream << body_wire;
msg = Element::from_wire(body_wire_stream, length - header_length);
return (true);
// XXXMLG handle non-block here, and return false for short reads
}
void
Session::subscribe(std::string group, std::string instance, std::string subtype)
{
......@@ -148,9 +232,9 @@ Session::group_sendmsg(ElementPtr& msg, std::string group, std::string instance,
env->set("group", Element::create(group));
env->set("instance", Element::create(instance));
env->set("seq", Element::create(sequence));
env->set("msg", Element::create(msg->to_wire()));
//env->set("msg", Element::create(msg->to_wire()));
sendmsg(env);
sendmsg(env, msg);
return (sequence++);
}
......@@ -158,14 +242,11 @@ Session::group_sendmsg(ElementPtr& msg, std::string group, std::string instance,
bool
Session::group_recvmsg(ElementPtr& envelope, ElementPtr& msg, bool nonblock)
{
bool got_message = recvmsg(envelope, nonblock);
bool got_message = recvmsg(envelope, msg, nonblock);
if (!got_message) {
return false;
}
msg = Element::from_wire(envelope->get("msg")->string_value());
envelope->remove("msg");
return (true);
}
......@@ -180,10 +261,9 @@ Session::reply(ElementPtr& envelope, ElementPtr& newmsg)
env->set("group", Element::create(envelope->get("group")->string_value()));
env->set("instance", Element::create(envelope->get("instance")->string_value()));
env->set("seq", Element::create(sequence));
env->set("msg", Element::create(newmsg->to_wire()));
env->set("reply", Element::create(envelope->get("seq")->string_value()));
sendmsg(env);
sendmsg(env, newmsg);
return (sequence++);
}
......@@ -36,8 +36,12 @@ namespace ISC {
void establish();
void disconnect();
void sendmsg(ISC::Data::ElementPtr& msg);
void sendmsg(ISC::Data::ElementPtr& env, ISC::Data::ElementPtr& msg);
bool recvmsg(ISC::Data::ElementPtr& msg,
bool nonblock = true);
bool recvmsg(ISC::Data::ElementPtr& env,
ISC::Data::ElementPtr& msg,
bool nonblock = true);
void subscribe(std::string group,
std::string instance = "*",
std::string subtype = "normal");
......
......@@ -37,7 +37,7 @@ class Session:
self._socket.connect(tuple(['127.0.0.1', port]))
self.sendmsg({ "type": "getlname" })
msg = self.recvmsg(False)
env, msg = self.recvmsg(False)
self._lname = msg["lname"]
if not self._lname:
raise ProtocolError("Could not get local name")
......@@ -48,18 +48,31 @@ class Session:
def lname(self):
return self._lname
def sendmsg(self, msg):
def sendmsg(self, env, msg = None):
if type(env) == dict:
env = Message.to_wire(env)
if type(msg) == dict:
msg = Message.to_wire(msg)
self._socket.setblocking(1)
self._socket.send(struct.pack("!I", len(msg)))
self._socket.send(msg)
length = 2 + len(env);
if msg:
length += len(msg)
self._socket.send(struct.pack("!I", length))
self._socket.send(struct.pack("!H", len(env)))
self._socket.send(env)
if msg:
self._socket.send(msg)
def recvmsg(self, nonblock = True):
data = self._receive_full_buffer(nonblock)
if data:
return Message.from_wire(data)
return None
if data and len(data) > 2:
header_length = struct.unpack('>H', data[0:2])[0]
data_length = len(data) - 2 - header_length
if data_length > 0:
return Message.from_wire(data[2:header_length+2]), Message.from_wire(data[header_length + 2:])
else:
return Message.from_wire(data[2:header_length+2]), None
return None, None
def _receive_full_buffer(self, nonblock):
if nonblock:
......@@ -127,20 +140,15 @@ class Session:
"group": group,
"instance": instance,
"seq": seq,
"msg": Message.to_wire(msg),
})
}, Message.to_wire(msg))
return seq
def group_recvmsg(self, nonblock = True):
env = self.recvmsg(nonblock)
env, msg = self.recvmsg(nonblock)
if env == None:
# return none twice to match normal return value
# (so caller won't get a type error on no data)
return (None, None)
if type(env["msg"]) != bytearray:
msg = Message.from_wire(env["msg"].encode('ascii'))
else:
msg = Message.from_wire(env["msg"])
return (msg, env)
def group_reply(self, routing, msg):
......@@ -153,8 +161,7 @@ class Session:
"instance": routing["instance"],
"seq": seq,
"reply": routing["seq"],
"msg": Message.to_wire(msg),
})
}, Message.to_wire(msg))
return seq
if __name__ == "__main__":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment