Rev 2434: Smart server mediums now detect which protocol version a request is and dispatch accordingly. in http://bazaar.launchpad.net/~bzr
Andrew Bennetts
andrew.bennetts at canonical.com
Tue Apr 24 08:22:08 BST 2007
At http://bazaar.launchpad.net/~bzr
------------------------------------------------------------
revno: 2434
revision-id: andrew.bennetts at canonical.com-20070424072048-tgbochqfr1n33bcy
parent: andrew.bennetts at canonical.com-20070424051106-wwlidpflp1rwi3a7
committer: Andrew Bennetts <andrew.bennetts at canonical.com>
branch nick: hpss-protocol2
timestamp: Tue 2007-04-24 17:20:48 +1000
message:
Smart server mediums now detect which protocol version a request is and dispatch accordingly.
modified:
bzrlib/smart/medium.py medium.py-20061103051856-rgu2huy59fkz902q-1
bzrlib/tests/test_smart_transport.py test_ssh_transport.py-20060608202016-c25gvf1ob7ypbus6-2
bzrlib/tests/test_wsgi.py test_wsgi.py-20061005091552-rz8pva0olkxv0sd8-1
bzrlib/transport/http/wsgi.py wsgi.py-20061005091552-rz8pva0olkxv0sd8-2
=== modified file 'bzrlib/smart/medium.py'
--- a/bzrlib/smart/medium.py 2007-04-10 02:31:42 +0000
+++ b/bzrlib/smart/medium.py 2007-04-24 07:20:48 +0000
@@ -28,7 +28,10 @@
import socket
import sys
from bzrlib import errors
-from bzrlib.smart.protocol import SmartServerRequestProtocolOne
+from bzrlib.smart.protocol import (
+ SmartServerRequestProtocolOne,
+ SmartServerRequestProtocolTwo,
+ )
try:
from bzrlib.transport import ssh
@@ -66,13 +69,24 @@
from sys import stderr
try:
while not self.finished:
- protocol = SmartServerRequestProtocolOne(self.backing_transport,
- self._write_out)
+ protocol = self._build_protocol()
self._serve_one_request(protocol)
except Exception, e:
stderr.write("%s terminating on exception %s\n" % (self, e))
raise
+    def _build_protocol(self):
+        # Identify the protocol version.
+        bytes = self._get_bytes(2)
+        if bytes.startswith('2\x01'):
+            protocol_class = SmartServerRequestProtocolTwo
+            bytes = bytes[2:]
+        else:
+            protocol_class = SmartServerRequestProtocolOne
+        protocol = protocol_class(self.backing_transport, self._write_out)
+        protocol.accept_bytes(bytes)
+        return protocol
+
def _serve_one_request(self, protocol):
"""Read one request from input, process, send back a response.
@@ -89,6 +103,13 @@
"""Called when an unhandled exception from the protocol occurs."""
raise NotImplementedError(self.terminate_due_to_error)
+    def _get_bytes(self, desired_count):
+        """Get some bytes from the medium.
+
+        :param desired_count: number of bytes we want to read.
+        """
+        raise NotImplementedError(self._get_bytes)
+
class SmartServerSocketStreamMedium(SmartServerStreamMedium):
@@ -109,13 +130,18 @@
protocol.accept_bytes(self.push_back)
self.push_back = ''
else:
- bytes = self.socket.recv(4096)
+ bytes = self._get_bytes(4096)
if bytes == '':
self.finished = True
return
protocol.accept_bytes(bytes)
self.push_back = protocol.excess_buffer
+
+    def _get_bytes(self, desired_count):
+        # We ignore the desired_count because on sockets it's more efficient to
+        # read 4k at a time.
+        return self.socket.recv(4096)
def terminate_due_to_error(self):
"""Called when an unhandled exception from the protocol occurs."""
@@ -155,7 +181,7 @@
# Finished serving this request.
self._out.flush()
return
- bytes = self._in.read(bytes_to_read)
+ bytes = self._get_bytes(bytes_to_read)
if bytes == '':
# Connection has been closed.
self.finished = True
@@ -163,6 +189,9 @@
return
protocol.accept_bytes(bytes)
+    def _get_bytes(self, desired_count):
+        return self._in.read(desired_count)
+
def terminate_due_to_error(self):
# TODO: This should log to a server log file, but no such thing
# exists yet. Andrew Bennetts 2006-09-29.
=== modified file 'bzrlib/tests/test_smart_transport.py'
--- a/bzrlib/tests/test_smart_transport.py 2007-04-24 05:11:06 +0000
+++ b/bzrlib/tests/test_smart_transport.py 2007-04-24 07:20:48 +0000
@@ -746,9 +746,6 @@
self.assertTrue(server.finished)
def test_socket_stream_error_handling(self):
- # Use plain python StringIO so we can monkey-patch the close method to
- # not discard the contents.
- from StringIO import StringIO
server_sock, client_sock = self.portable_socket_pair()
server = medium.SmartServerSocketStreamMedium(
server_sock, None)
@@ -780,6 +777,62 @@
KeyboardInterrupt, server._serve_one_request, fake_protocol)
server_sock.close()
self.assertEqual('', client_sock.recv(1))
+
+    def build_protocol_pipe_like(self, bytes):
+        to_server = StringIO(bytes)
+        from_server = StringIO()
+        server = medium.SmartServerPipeStreamMedium(
+            to_server, from_server, None)
+        return server._build_protocol()
+
+    def build_protocol_socket(self, bytes):
+        server_sock, client_sock = self.portable_socket_pair()
+        server = medium.SmartServerSocketStreamMedium(
+            server_sock, None)
+        client_sock.sendall(bytes)
+        client_sock.close()
+        return server._build_protocol()
+
+    def assertProtocolOne(self, server_protocol):
+        # Use assertIs because assertIsInstance will wrongly pass
+        # SmartServerRequestProtocolTwo (because it subclasses
+        # SmartServerRequestProtocolOne).
+        self.assertIs(
+            type(server_protocol), protocol.SmartServerRequestProtocolOne)
+
+    def assertProtocolTwo(self, server_protocol):
+        self.assertIsInstance(
+            server_protocol, protocol.SmartServerRequestProtocolTwo)
+
+    def test_pipe_like_build_protocol_empty_bytes(self):
+        # Any empty request (i.e. no bytes) is detected as protocol version one.
+        server_protocol = self.build_protocol_pipe_like('')
+        self.assertProtocolOne(server_protocol)
+
+    def test_socket_like_build_protocol_empty_bytes(self):
+        # Any empty request (i.e. no bytes) is detected as protocol version one.
+        server_protocol = self.build_protocol_socket('')
+        self.assertProtocolOne(server_protocol)
+
+    def test_pipe_like_build_protocol_non_two(self):
+        # A request that doesn't start with "2\x01" is version one.
+        server_protocol = self.build_protocol_pipe_like('2-')
+        self.assertProtocolOne(server_protocol)
+
+    def test_socket_build_protocol_non_two(self):
+        # A request that doesn't start with "2\x01" is version one.
+        server_protocol = self.build_protocol_socket('2-')
+        self.assertProtocolOne(server_protocol)
+
+    def test_pipe_like_build_protocol_two(self):
+        # A request that starts with "2\x01" is version two.
+        server_protocol = self.build_protocol_pipe_like('2\x01')
+        self.assertProtocolTwo(server_protocol)
+
+    def test_socket_build_protocol_two(self):
+        # A request that starts with "2\x01" is version two.
+        server_protocol = self.build_protocol_socket('2\x01')
+        self.assertProtocolTwo(server_protocol)
class TestSmartTCPServer(tests.TestCase):
=== modified file 'bzrlib/tests/test_wsgi.py'
--- a/bzrlib/tests/test_wsgi.py 2007-03-28 03:02:32 +0000
+++ b/bzrlib/tests/test_wsgi.py 2007-04-24 07:20:48 +0000
@@ -82,6 +82,12 @@
self.assertEqual('405 Method not allowed', self.status)
self.assertTrue(('Allow', 'POST') in self.headers)
+    def _fake_make_request(self, transport, write_func, bytes):
+        request = FakeRequest(transport, write_func)
+        request.accept_bytes(bytes)
+        self.request = request
+        return request
+
def test_smart_wsgi_app_uses_given_relpath(self):
# The SmartWSGIApp should use the "bzrlib.relpath" field from the
# WSGI environ to clone from its backing transport to get a specific
@@ -89,11 +95,7 @@
transport = FakeTransport()
wsgi_app = wsgi.SmartWSGIApp(transport)
wsgi_app.backing_transport = transport
- def make_request(transport, write_func):
- request = FakeRequest(transport, write_func)
- self.request = request
- return request
- wsgi_app.make_request = make_request
+ wsgi_app.make_request = self._fake_make_request
fake_input = StringIO('fake request')
environ = self.build_environ({
'REQUEST_METHOD': 'POST',
@@ -112,11 +114,7 @@
transport = memory.MemoryTransport()
transport.put_bytes('foo', 'some bytes')
wsgi_app = wsgi.SmartWSGIApp(transport)
- def make_request(transport, write_func):
- request = FakeRequest(transport, write_func)
- self.request = request
- return request
- wsgi_app.make_request = make_request
+ wsgi_app.make_request = self._fake_make_request
fake_input = StringIO('fake request')
environ = self.build_environ({
'REQUEST_METHOD': 'POST',
@@ -186,8 +184,9 @@
def test_incomplete_request(self):
transport = FakeTransport()
wsgi_app = wsgi.SmartWSGIApp(transport)
-        def make_request(transport, write_func):
+        def make_request(transport, write_func, bytes):
             request = IncompleteRequest(transport, write_func)
+            request.accept_bytes(bytes)
             self.request = request
             return request
wsgi_app.make_request = make_request
@@ -204,6 +203,41 @@
self.assertEqual('200 OK', self.status)
self.assertEqual('error\x01incomplete request\n', response)
+    def test_protocol_version_detection_one(self):
+        # SmartWSGIApp detects requests that don't start with '2\x01' as version
+        # one.
+        transport = memory.MemoryTransport()
+        wsgi_app = wsgi.SmartWSGIApp(transport)
+        fake_input = StringIO('hello\n')
+        environ = self.build_environ({
+            'REQUEST_METHOD': 'POST',
+            'CONTENT_LENGTH': len(fake_input.getvalue()),
+            'wsgi.input': fake_input,
+            'bzrlib.relpath': 'foo',
+            })
+        iterable = wsgi_app(environ, self.start_response)
+        response = self.read_response(iterable)
+        self.assertEqual('200 OK', self.status)
+        # Expect a version 1-encoded response.
+        self.assertEqual('ok\x012\n', response)
+
+    def test_protocol_version_detection_two(self):
+        # SmartWSGIApp detects requests that start with '2\x01' as version two.
+        transport = memory.MemoryTransport()
+        wsgi_app = wsgi.SmartWSGIApp(transport)
+        fake_input = StringIO('2\x01hello\n')
+        environ = self.build_environ({
+            'REQUEST_METHOD': 'POST',
+            'CONTENT_LENGTH': len(fake_input.getvalue()),
+            'wsgi.input': fake_input,
+            'bzrlib.relpath': 'foo',
+            })
+        iterable = wsgi_app(environ, self.start_response)
+        response = self.read_response(iterable)
+        self.assertEqual('200 OK', self.status)
+        # Expect a version 2-encoded response.
+        self.assertEqual('2\x01ok\x012\n', response)
+
class FakeRequest(object):
=== modified file 'bzrlib/transport/http/wsgi.py'
--- a/bzrlib/transport/http/wsgi.py 2007-04-05 15:20:57 +0000
+++ b/bzrlib/transport/http/wsgi.py 2007-04-24 07:20:48 +0000
@@ -22,6 +22,7 @@
from cStringIO import StringIO
+from bzrlib.smart import protocol
from bzrlib.transport import chroot, get_transport, remote
from bzrlib.urlutils import local_path_to_url
@@ -113,10 +114,10 @@
relpath = environ['bzrlib.relpath']
transport = self.backing_transport.clone(relpath)
out_buffer = StringIO()
- smart_protocol_request = self.make_request(transport, out_buffer.write)
request_data_length = int(environ['CONTENT_LENGTH'])
request_data_bytes = environ['wsgi.input'].read(request_data_length)
- smart_protocol_request.accept_bytes(request_data_bytes)
+ smart_protocol_request = self.make_request(
+ transport, out_buffer.write, request_data_bytes)
if smart_protocol_request.next_read_size() != 0:
# The request appears to be incomplete, or perhaps it's just a
# newer version we don't understand. Regardless, all we can do
@@ -130,5 +131,14 @@
start_response('200 OK', headers)
return [response_data]
-    def make_request(self, transport, write_func):
-        return protocol.SmartServerRequestProtocolOne(transport, write_func)
+    def make_request(self, transport, write_func, request_bytes):
+        # XXX: This duplicates the logic in
+        # SmartServerStreamMedium._build_protocol.
+        if request_bytes.startswith('2\x01'):
+            protocol_class = protocol.SmartServerRequestProtocolTwo
+            request_bytes = request_bytes[2:]
+        else:
+            protocol_class = protocol.SmartServerRequestProtocolOne
+        server_protocol = protocol_class(transport, write_func)
+        server_protocol.accept_bytes(request_bytes)
+        return server_protocol
More information about the bazaar-commits mailing list