Rev 2434: Smart server mediums now detect which protocol version a request uses and dispatch accordingly. in [branch URL elided by the list archive]

Andrew Bennetts andrew.bennetts at [domain elided by the list archive]
Tue Apr 24 08:22:08 BST 2007


revno: 2434
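revision-id: andrew.bennetts at [remainder of revision id elided by the list archive]
parent: andrew.bennetts at [remainder of parent revision id elided by the list archive]
committer: Andrew Bennetts <andrew.bennetts at [domain elided by the list archive]>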
branch nick: hpss-protocol2
timestamp: Tue 2007-04-24 17:20:48 +1000
message:
  Smart server mediums now detect which protocol version a request is and dispatch accordingly.
=== modified file 'bzrlib/smart/medium.py'
--- a/bzrlib/smart/medium.py	2007-04-10 02:31:42 +0000
+++ b/bzrlib/smart/medium.py	2007-04-24 07:20:48 +0000
@@ -28,7 +28,10 @@
 import socket
 import sys
 from bzrlib import errors
-from import SmartServerRequestProtocolOne
+from import (
+    SmartServerRequestProtocolOne,
+    SmartServerRequestProtocolTwo,
+    )
     from bzrlib.transport import ssh
@@ -66,13 +69,24 @@
         from sys import stderr
             while not self.finished:
-                protocol = SmartServerRequestProtocolOne(self.backing_transport,
-                                                         self._write_out)
+                protocol = self._build_protocol()
         except Exception, e:
             stderr.write("%s terminating on exception %s\n" % (self, e))
+    def _build_protocol(self):
+        # Identify the protocol version.
+        bytes = self._get_bytes(2)
+        if bytes.startswith('2\x01'):
+            protocol_class = SmartServerRequestProtocolTwo
+            bytes = bytes[2:]
+        else:
+            protocol_class = SmartServerRequestProtocolOne
+        protocol = protocol_class(self.backing_transport, self._write_out)
+        protocol.accept_bytes(bytes)
+        return protocol
     def _serve_one_request(self, protocol):
         """Read one request from input, process, send back a response.
@@ -89,6 +103,13 @@
         """Called when an unhandled exception from the protocol occurs."""
         raise NotImplementedError(self.terminate_due_to_error)
+    def _get_bytes(self, desired_count):
+        """Get some bytes from the medium.
+        :param desired_count: number of bytes we want to read.
+        """
+        raise NotImplementedError(self._get_bytes)
 class SmartServerSocketStreamMedium(SmartServerStreamMedium):
@@ -109,13 +130,18 @@
                 self.push_back = ''
-                bytes = self.socket.recv(4096)
+                bytes = self._get_bytes(4096)
                 if bytes == '':
                     self.finished = True
         self.push_back = protocol.excess_buffer
+    def _get_bytes(self, desired_count):
+        # We ignore the desired_count because on sockets it's more efficient to
+        # read 4k at a time.
+        return self.socket.recv(4096)
     def terminate_due_to_error(self):
         """Called when an unhandled exception from the protocol occurs."""
@@ -155,7 +181,7 @@
                 # Finished serving this request.
-            bytes =
+            bytes = self._get_bytes(bytes_to_read)
             if bytes == '':
                 # Connection has been closed.
                 self.finished = True
@@ -163,6 +189,9 @@
+    def _get_bytes(self, desired_count):
+        return
     def terminate_due_to_error(self):
         # TODO: This should log to a server log file, but no such thing
         # exists yet.  Andrew Bennetts 2006-09-29.

=== modified file 'bzrlib/tests/test_smart_transport.py'
--- a/bzrlib/tests/test_smart_transport.py	2007-04-24 05:11:06 +0000
+++ b/bzrlib/tests/test_smart_transport.py	2007-04-24 07:20:48 +0000
@@ -746,9 +746,6 @@
     def test_socket_stream_error_handling(self):
-        # Use plain python StringIO so we can monkey-patch the close method to
-        # not discard the contents.
-        from StringIO import StringIO
         server_sock, client_sock = self.portable_socket_pair()
         server = medium.SmartServerSocketStreamMedium(
             server_sock, None)
@@ -780,6 +777,62 @@
             KeyboardInterrupt, server._serve_one_request, fake_protocol)
         self.assertEqual('', client_sock.recv(1))
+    def build_protocol_pipe_like(self, bytes):
+        to_server = StringIO(bytes)
+        from_server = StringIO()
+        server = medium.SmartServerPipeStreamMedium(
+            to_server, from_server, None)
+        return server._build_protocol()
+    def build_protocol_socket(self, bytes):
+        server_sock, client_sock = self.portable_socket_pair()
+        server = medium.SmartServerSocketStreamMedium(
+            server_sock, None)
+        client_sock.sendall(bytes)
+        client_sock.close()
+        return server._build_protocol()
+    def assertProtocolOne(self, server_protocol):
+        # Use assertIs because assertIsInstance will wrongly pass
+        # SmartServerRequestProtocolTwo (because it subclasses
+        # SmartServerRequestProtocolOne).
+        self.assertIs(
+            type(server_protocol), protocol.SmartServerRequestProtocolOne)
+    def assertProtocolTwo(self, server_protocol):
+        self.assertIsInstance(
+            server_protocol, protocol.SmartServerRequestProtocolTwo)
+    def test_pipe_like_build_protocol_empty_bytes(self):
+        # Any empty request (i.e. no bytes) is detected as protocol version one.
+        server_protocol = self.build_protocol_pipe_like('')
+        self.assertProtocolOne(server_protocol)
+    def test_socket_like_build_protocol_empty_bytes(self):
+        # Any empty request (i.e. no bytes) is detected as protocol version one.
+        server_protocol = self.build_protocol_socket('')
+        self.assertProtocolOne(server_protocol)
+    def test_pipe_like_build_protocol_non_two(self):
+        # A request that doesn't start with "2\x01" is version one.
+        server_protocol = self.build_protocol_pipe_like('2-')
+        self.assertProtocolOne(server_protocol)
+    def test_socket_build_protocol_non_two(self):
+        # A request that doesn't start with "2\x01" is version one.
+        server_protocol = self.build_protocol_socket('2-')
+        self.assertProtocolOne(server_protocol)
+    def test_pipe_like_build_protocol_two(self):
+        # A request that starts with "2\x01" is version two.
+        server_protocol = self.build_protocol_pipe_like('2\x01')
+        self.assertProtocolTwo(server_protocol)
+    def test_socket_build_protocol_two(self):
+        # A request that starts with "2\x01" is version two.
+        server_protocol = self.build_protocol_socket('2\x01')
+        self.assertProtocolTwo(server_protocol)
 class TestSmartTCPServer(tests.TestCase):

=== modified file 'bzrlib/tests/test_wsgi.py'
--- a/bzrlib/tests/test_wsgi.py	2007-03-28 03:02:32 +0000
+++ b/bzrlib/tests/test_wsgi.py	2007-04-24 07:20:48 +0000
@@ -82,6 +82,12 @@
         self.assertEqual('405 Method not allowed', self.status)
         self.assertTrue(('Allow', 'POST') in self.headers)
+    def _fake_make_request(self, transport, write_func, bytes):
+        request = FakeRequest(transport, write_func)
+        request.accept_bytes(bytes)
+        self.request = request
+        return request
     def test_smart_wsgi_app_uses_given_relpath(self):
         # The SmartWSGIApp should use the "bzrlib.relpath" field from the
         # WSGI environ to clone from its backing transport to get a specific
@@ -89,11 +95,7 @@
         transport = FakeTransport()
         wsgi_app = wsgi.SmartWSGIApp(transport)
         wsgi_app.backing_transport = transport
-        def make_request(transport, write_func):
-            request = FakeRequest(transport, write_func)
-            self.request = request
-            return request
-        wsgi_app.make_request = make_request
+        wsgi_app.make_request = self._fake_make_request
         fake_input = StringIO('fake request')
         environ = self.build_environ({
             'REQUEST_METHOD': 'POST',
@@ -112,11 +114,7 @@
         transport = memory.MemoryTransport()
         transport.put_bytes('foo', 'some bytes')
         wsgi_app = wsgi.SmartWSGIApp(transport)
-        def make_request(transport, write_func):
-            request = FakeRequest(transport, write_func)
-            self.request = request
-            return request
-        wsgi_app.make_request = make_request
+        wsgi_app.make_request = self._fake_make_request
         fake_input = StringIO('fake request')
         environ = self.build_environ({
             'REQUEST_METHOD': 'POST',
@@ -186,8 +184,9 @@
     def test_incomplete_request(self):
         transport = FakeTransport()
         wsgi_app = wsgi.SmartWSGIApp(transport)
-        def make_request(transport, write_func):
+        def make_request(transport, write_func, bytes):
             request = IncompleteRequest(transport, write_func)
+            request.accept_bytes(bytes)
             self.request = request
             return request
         wsgi_app.make_request = make_request
@@ -204,6 +203,41 @@
         self.assertEqual('200 OK', self.status)
         self.assertEqual('error\x01incomplete request\n', response)
+    def test_protocol_version_detection_one(self):
+        # SmartWSGIApp detects requests that don't start with '2\x01' as version
+        # one.
+        transport = memory.MemoryTransport()
+        wsgi_app = wsgi.SmartWSGIApp(transport)
+        fake_input = StringIO('hello\n')
+        environ = self.build_environ({
+            'REQUEST_METHOD': 'POST',
+            'CONTENT_LENGTH': len(fake_input.getvalue()),
+            'wsgi.input': fake_input,
+            'bzrlib.relpath': 'foo',
+        })
+        iterable = wsgi_app(environ, self.start_response)
+        response = self.read_response(iterable)
+        self.assertEqual('200 OK', self.status)
+        # Expect a version 1-encoded response.
+        self.assertEqual('ok\x012\n', response)
+    def test_protocol_version_detection_two(self):
+        # SmartWSGIApp detects requests that start with '2\x01' as version two.
+        transport = memory.MemoryTransport()
+        wsgi_app = wsgi.SmartWSGIApp(transport)
+        fake_input = StringIO('2\x01hello\n')
+        environ = self.build_environ({
+            'REQUEST_METHOD': 'POST',
+            'CONTENT_LENGTH': len(fake_input.getvalue()),
+            'wsgi.input': fake_input,
+            'bzrlib.relpath': 'foo',
+        })
+        iterable = wsgi_app(environ, self.start_response)
+        response = self.read_response(iterable)
+        self.assertEqual('200 OK', self.status)
+        # Expect a version 2-encoded response.
+        self.assertEqual('2\x01ok\x012\n', response)
 class FakeRequest(object):

=== modified file 'bzrlib/transport/http/wsgi.py'
--- a/bzrlib/transport/http/wsgi.py	2007-04-05 15:20:57 +0000
+++ b/bzrlib/transport/http/wsgi.py	2007-04-24 07:20:48 +0000
@@ -22,6 +22,7 @@
 from cStringIO import StringIO
+from import protocol
 from bzrlib.transport import chroot, get_transport, remote
 from bzrlib.urlutils import local_path_to_url
@@ -113,10 +114,10 @@
         relpath = environ['bzrlib.relpath']
         transport = self.backing_transport.clone(relpath)
         out_buffer = StringIO()
-        smart_protocol_request = self.make_request(transport, out_buffer.write)
         request_data_length = int(environ['CONTENT_LENGTH'])
         request_data_bytes = environ['wsgi.input'].read(request_data_length)
-        smart_protocol_request.accept_bytes(request_data_bytes)
+        smart_protocol_request = self.make_request(
+            transport, out_buffer.write, request_data_bytes)
         if smart_protocol_request.next_read_size() != 0:
             # The request appears to be incomplete, or perhaps it's just a
             # newer version we don't understand.  Regardless, all we can do
@@ -130,5 +131,14 @@
         start_response('200 OK', headers)
         return [response_data]
-    def make_request(self, transport, write_func):
-        return protocol.SmartServerRequestProtocolOne(transport, write_func)
+    def make_request(self, transport, write_func, request_bytes):
+        # XXX: This duplicates the logic in
+        # SmartServerStreamMedium._build_protocol.
+        if request_bytes.startswith('2\x01'):
+            protocol_class = protocol.SmartServerRequestProtocolTwo
+            request_bytes = request_bytes[2:]
+        else:
+            protocol_class = protocol.SmartServerRequestProtocolOne
+        server_protocol = protocol_class(transport, write_func)
+        server_protocol.accept_bytes(request_bytes)
+        return server_protocol

More information about the bazaar-commits mailing list