Rev 4903: Merge the refactored SmartClientRequest code, and update to match it. in http://bazaar.launchpad.net/~jameinel/bzr/2.1-client-stream-started-819604

John Arbash Meinel john at arbash-meinel.com
Sat Oct 8 10:36:46 UTC 2011


At http://bazaar.launchpad.net/~jameinel/bzr/2.1-client-stream-started-819604

------------------------------------------------------------
revno: 4903 [merge]
revision-id: john at arbash-meinel.com-20111008103614-zx2lroae5c1re300
parent: john at arbash-meinel.com-20111007161140-yew7yg5n6bajyejm
parent: john at arbash-meinel.com-20111008102315-5y0yk3bnxhpa86el
committer: John Arbash Meinel <john at arbash-meinel.com>
branch nick: 2.1-client-stream-started-819604
timestamp: Sat 2011-10-08 12:36:14 +0200
message:
  Merge the refactored SmartClientRequest code, and update to match it.
modified:
  bzrlib/smart/client.py         client.py-20061116014825-2k6ada6xgulslami-1
  bzrlib/tests/test_smart_transport.py test_ssh_transport.py-20060608202016-c25gvf1ob7ypbus6-2
-------------- next part --------------
=== modified file 'bzrlib/smart/client.py'
--- a/bzrlib/smart/client.py	2011-10-07 16:11:40 +0000
+++ b/bzrlib/smart/client.py	2011-10-08 10:36:14 +0000
@@ -16,7 +16,6 @@
 
 import bzrlib
 from bzrlib.smart import message, protocol
-from bzrlib.trace import warning
 from bzrlib import (
     errors,
     hooks,
@@ -40,117 +39,12 @@
     def __repr__(self):
         return '%s(%r)' % (self.__class__.__name__, self._medium)
 
-    def _send_request_no_retry(self, encoder, method, args, body=None,
-                               readv_body=None, body_stream=None):
-        encoder.set_headers(self._headers)
-        if body is not None:
-            if readv_body is not None:
-                raise AssertionError(
-                    "body and readv_body are mutually exclusive.")
-            if body_stream is not None:
-                raise AssertionError(
-                    "body and body_stream are mutually exclusive.")
-            encoder.call_with_body_bytes((method, ) + args, body)
-        elif readv_body is not None:
-            if body_stream is not None:
-                raise AssertionError(
-                    "readv_body and body_stream are mutually exclusive.")
-            encoder.call_with_body_readv_array((method, ) + args, readv_body)
-        elif body_stream is not None:
-            encoder.call_with_body_stream((method, ) + args, body_stream)
-        else:
-            encoder.call(method, *args)
-
-    def _send_request(self, protocol_version, method, args, body=None,
-                      readv_body=None, body_stream=None):
-        encoder, response_handler = self._construct_protocol(
-            protocol_version)
-        try:
-            self._send_request_no_retry(encoder, method, args, body=body,
-                readv_body=readv_body, body_stream=body_stream)
-        except errors.ConnectionReset, e:
-            # If we fail during the _send_request_no_retry phase, then we can
-            # be confident that the server did not get our request, because we
-            # haven't started waiting for the reply yet. So try the request
-            # again. We only issue a single retry, because if the connection
-            # really is down, there is no reason to loop endlessly.
-
-            # Connection is dead, so close our end of it.
-            self._medium.reset()
-            if body_stream is not None and encoder.body_stream_started:
-                # We consumed some of body_stream, so it isn't safe to retry
-                raise
-            trace.warning('ConnectionReset calling %s, retrying' % (method,))
-            trace.log_exception_quietly()
-            encoder, response_handler = self._construct_protocol(
-                protocol_version)
-            self._send_request_no_retry(encoder, method, args, body=body,
-                readv_body=readv_body, body_stream=body_stream)
-        return response_handler
-
-    def _run_call_hooks(self, method, args, body, readv_body):
-        if not _SmartClient.hooks['call']:
-            return
-        params = CallHookParams(method, args, body, readv_body, self._medium)
-        for hook in _SmartClient.hooks['call']:
-            hook(params)
-
     def _call_and_read_response(self, method, args, body=None, readv_body=None,
             body_stream=None, expect_response_body=True):
-        self._run_call_hooks(method, args, body, readv_body)
-        if self._medium._protocol_version is not None:
-            response_handler = self._send_request(
-                self._medium._protocol_version, method, args, body=body,
-                readv_body=readv_body, body_stream=body_stream)
-            return (response_handler.read_response_tuple(
-                        expect_body=expect_response_body),
-                    response_handler)
-        else:
-            for protocol_version in [3, 2]:
-                if protocol_version == 2:
-                    # If v3 doesn't work, the remote side is older than 1.6.
-                    self._medium._remember_remote_is_before((1, 6))
-                response_handler = self._send_request(
-                    protocol_version, method, args, body=body,
-                    readv_body=readv_body, body_stream=body_stream)
-                try:
-                    response_tuple = response_handler.read_response_tuple(
-                        expect_body=expect_response_body)
-                except errors.UnexpectedProtocolVersionMarker, err:
-                    # TODO: We could recover from this without disconnecting if
-                    # we recognise the protocol version.
-                    warning(
-                        'Server does not understand Bazaar network protocol %d,'
-                        ' reconnecting.  (Upgrade the server to avoid this.)'
-                        % (protocol_version,))
-                    self._medium.disconnect()
-                    continue
-                except errors.ErrorFromSmartServer:
-                    # If we received an error reply from the server, then it
-                    # must be ok with this protocol version.
-                    self._medium._protocol_version = protocol_version
-                    raise
-                else:
-                    self._medium._protocol_version = protocol_version
-                    return response_tuple, response_handler
-            raise errors.SmartProtocolError(
-                'Server is not a Bazaar server: ' + str(err))
-
-    def _construct_protocol(self, version):
-        request = self._medium.get_request()
-        if version == 3:
-            request_encoder = protocol.ProtocolThreeRequester(request)
-            response_handler = message.ConventionalResponseHandler()
-            response_proto = protocol.ProtocolThreeDecoder(
-                response_handler, expect_version_marker=True)
-            response_handler.setProtoAndMediumRequest(response_proto, request)
-        elif version == 2:
-            request_encoder = protocol.SmartClientRequestProtocolTwo(request)
-            response_handler = request_encoder
-        else:
-            request_encoder = protocol.SmartClientRequestProtocolOne(request)
-            response_handler = request_encoder
-        return request_encoder, response_handler
+        request = _SmartClientRequest(self, method, args, body=body,
+            readv_body=readv_body, body_stream=body_stream,
+            expect_response_body=expect_response_body)
+        return request.call_and_read_response()
 
     def call(self, method, *args):
         """Call a method on the remote server."""
@@ -216,6 +110,172 @@
         return self._medium.remote_path_from_transport(transport)
 
 
+class _SmartClientRequest(object):
+    """Encapsulate the logic for a single request.
+
+    This class handles things like reconnecting and sending the request a
+    second time if the connection is reset while it is being written. It
+    also handles the multiple requests that get made if we don't know what
+    protocol the server supports yet.
+
+    Generally, you build up one of these objects, passing in the arguments that
+    you want to send to the server, and then use 'call_and_read_response' to
+    get the response from the server.
+    """
+
+    def __init__(self, client, method, args, body=None, readv_body=None,
+                 body_stream=None, expect_response_body=True):
+        self.client = client  # owner; supplies _medium and _headers
+        self.method = method  # name of the remote method to call
+        self.args = args  # tuple, appended after method when sent
+        self.body = body  # excludes readv_body and body_stream
+        self.readv_body = readv_body  # sent via call_with_body_readv_array
+        self.body_stream = body_stream  # can't be replayed once consumed
+        self.expect_response_body = expect_response_body  # used by _call
+
+    def call_and_read_response(self):
+        """Send the request to the server, and read the initial response.
+
+        Runs the 'call' hooks first. This doesn't read all of the body content
+        of the response; it returns (response_tuple, response_handler).
+        response_tuple holds the 'ok' or 'error' information; response_handler
+        can be used to get the content stream out.
+        """
+        self._run_call_hooks()
+        protocol_version = self.client._medium._protocol_version
+        if protocol_version is None:
+            return self._call_determining_protocol_version()
+        else:
+            return self._call(protocol_version)
+
+    def _run_call_hooks(self):  # notify registered 'call' hooks, if any
+        if not _SmartClient.hooks['call']:
+            return
+        params = CallHookParams(self.method, self.args, self.body,
+                                self.readv_body, self.client._medium)
+        for hook in _SmartClient.hooks['call']:
+            hook(params)
+
+    def _call(self, protocol_version):
+        """We know the protocol version.
+
+        So this just sends the request, and then reads the response. Any
+        retry on ConnectionReset happens in _send, while writing the request.
+        """
+        response_handler = self._send(protocol_version)
+        response_tuple = response_handler.read_response_tuple(
+            expect_body=self.expect_response_body)
+        return (response_tuple, response_handler)
+
+    def _call_determining_protocol_version(self):
+        """Determine what protocol the remote server supports.
+
+        We do this by placing a request in the most recent protocol, and
+        handling the UnexpectedProtocolVersionMarker from the server.
+        """
+        for protocol_version in [3, 2]:
+            if protocol_version == 2:
+                # If v3 doesn't work, the remote side is older than 1.6.
+                self.client._medium._remember_remote_is_before((1, 6))
+            try:
+                response_tuple, response_handler = self._call(protocol_version)
+            except errors.UnexpectedProtocolVersionMarker, err:
+                # TODO: We could recover from this without disconnecting if
+                # we recognise the protocol version.
+                trace.warning(
+                    'Server does not understand Bazaar network protocol %d,'
+                    ' reconnecting.  (Upgrade the server to avoid this.)'
+                    % (protocol_version,))
+                self.client._medium.disconnect()
+                continue  # try the next (older) protocol version
+            except errors.ErrorFromSmartServer:
+                # If we received an error reply from the server, then it
+                # must be ok with this protocol version.
+                self.client._medium._protocol_version = protocol_version
+                raise
+            else:
+                self.client._medium._protocol_version = protocol_version
+                return response_tuple, response_handler
+        raise errors.SmartProtocolError(  # 'err' outlives the loop in Py2
+            'Server is not a Bazaar server: ' + str(err))
+
+    def _construct_protocol(self, version):
+        """Build the encoding stack for a given protocol version."""
+        request = self.client._medium.get_request()
+        if version == 3:
+            request_encoder = protocol.ProtocolThreeRequester(request)
+            response_handler = message.ConventionalResponseHandler()
+            response_proto = protocol.ProtocolThreeDecoder(
+                response_handler, expect_version_marker=True)
+            response_handler.setProtoAndMediumRequest(response_proto, request)
+        elif version == 2:
+            request_encoder = protocol.SmartClientRequestProtocolTwo(request)
+            response_handler = request_encoder
+        else:  # any other version: protocol one
+            request_encoder = protocol.SmartClientRequestProtocolOne(request)
+            response_handler = request_encoder
+        return request_encoder, response_handler
+
+    def _send(self, protocol_version):
+        """Encode the request, and send it to the server.
+
+        This will retry a request if we get a ConnectionReset while sending the
+        request to the server. (Unless we have a body_stream that we have
+        already started consuming, since we can't restart body_streams)
+
+        :return: response_handler as defined by _construct_protocol
+        """
+        encoder, response_handler = self._construct_protocol(protocol_version)
+        try:
+            self._send_no_retry(encoder)
+        except errors.ConnectionReset, e:
+            # If we fail during the _send_no_retry phase, then we can
+            # be confident that the server did not get our request, because we
+            # haven't started waiting for the reply yet. So try the request
+            # again. We only issue a single retry, because if the connection
+            # really is down, there is no reason to loop endlessly.
+
+            # Connection is dead, so close our end of it.
+            self.client._medium.reset()
+            if self.body_stream is not None and encoder.body_stream_started:
+                # We can't restart a body_stream that has been partially
+                # consumed, so we don't retry.
+                # Note: We don't have to worry about
+                #   SmartClientRequestProtocolOne or Two, because they don't
+                #   support client-side body streams.
+                raise
+            trace.log_exception_quietly()
+            trace.warning('ConnectionReset calling %s, retrying'
+                          % (self.method,))
+            encoder, response_handler = self._construct_protocol(
+                protocol_version)
+            self._send_no_retry(encoder)
+        return response_handler
+
+    def _send_no_retry(self, encoder):
+        """Encode and send the request once, with no retry."""
+        encoder.set_headers(self.client._headers)
+        if self.body is not None:
+            if self.readv_body is not None:
+                raise AssertionError(
+                    "body and readv_body are mutually exclusive.")
+            if self.body_stream is not None:
+                raise AssertionError(
+                    "body and body_stream are mutually exclusive.")
+            encoder.call_with_body_bytes((self.method, ) + self.args, self.body)
+        elif self.readv_body is not None:
+            if self.body_stream is not None:
+                raise AssertionError(
+                    "readv_body and body_stream are mutually exclusive.")
+            encoder.call_with_body_readv_array((self.method, ) + self.args,
+                                               self.readv_body)
+        elif self.body_stream is not None:
+            encoder.call_with_body_stream((self.method, ) + self.args,
+                                          self.body_stream)
+        else:
+            encoder.call(self.method, *self.args)
+
+
 class SmartClientHooks(hooks.Hooks):
 
     def __init__(self):

=== modified file 'bzrlib/tests/test_smart_transport.py'
--- a/bzrlib/tests/test_smart_transport.py	2011-10-07 16:11:40 +0000
+++ b/bzrlib/tests/test_smart_transport.py	2011-10-08 10:36:14 +0000
@@ -3405,19 +3405,24 @@
         # XXX: need a test that smart_client._headers is passed to the request
         # encoder.
 
-    def test__send_request_no_retry_pipes(self):
+
+class Test_SmartClientRequest(tests.TestCase):
+
+    def test__send_no_retry_pipes(self):
         client_read, server_write = create_file_pipes()
         server_read, client_write = create_file_pipes()
         client_medium = medium.SmartSimplePipesClientMedium(client_read,
             client_write, base='/')
         smart_client = client._SmartClient(client_medium)
+        smart_request = client._SmartClientRequest(smart_client,
+            'hello', ())
         # Close the server side
         server_read.close()
-        encoder, response_handler = smart_client._construct_protocol(3)
+        encoder, response_handler = smart_request._construct_protocol(3)
         self.assertRaises(errors.ConnectionReset,
-            smart_client._send_request_no_retry, encoder, 'hello', ())
+            smart_request._send_no_retry, encoder)
 
-    def test__send_request_read_response_sockets(self):
+    def test__send_read_response_sockets(self):
         listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         listen_sock.bind(('127.0.0.1', 0))
         listen_sock.listen(1)
@@ -3425,16 +3430,17 @@
         client_medium = medium.SmartTCPClientMedium(host, port, '/')
         client_medium._ensure_connection()
         smart_client = client._SmartClient(client_medium)
+        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
         # Accept the connection, but don't actually talk to the client.
         server_sock, _ = listen_sock.accept()
         server_sock.close()
         # Sockets buffer and don't really notice that the server has closed the
         # connection until we try to read again.
-        handler = smart_client._send_request(3, 'hello', ())
+        handler = smart_request._send(3)
         self.assertRaises(errors.ConnectionReset,
             handler.read_response_tuple, expect_body=False)
 
-    def test__send_request_retries_on_write(self):
+    def test__send_retries_on_write(self):
         response = StringIO()
         output = StringIO()
         vendor = FirstRejectedStringIOSSHVendor(response, output)
@@ -3442,7 +3448,8 @@
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
         smart_client = client._SmartClient(client_medium, headers={})
-        handler = smart_client._send_request(3, 'hello', ())
+        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
+        handler = smart_request._send(3)
         self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                          '\x00\x00\x00\x02de'   # empty headers
                          's\x00\x00\x00\tl5:helloee',
@@ -3456,7 +3463,7 @@
             ],
             vendor.calls)
 
-    def test__send_request_doesnt_retry_read_failure(self):
+    def test__send_doesnt_retry_read_failure(self):
         response = StringIO()
         output = StringIO()
         vendor = FirstRejectedStringIOSSHVendor(response, output,
@@ -3465,7 +3472,8 @@
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
         smart_client = client._SmartClient(client_medium, headers={})
-        handler = smart_client._send_request(3, 'hello', ())
+        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
+        handler = smart_request._send(3)
         self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
                          '\x00\x00\x00\x02de'   # empty headers
                          's\x00\x00\x00\tl5:helloee',
@@ -3485,7 +3493,9 @@
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
         smart_client = client._SmartClient(client_medium, headers={})
-        smart_client._send_request(3, 'hello', (), body_stream=['a', 'b'])
+        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
+            body_stream=['a', 'b'])
+        response_handler = smart_request._send(3)
         # We connect, get disconnected, and notice before consuming the stream,
         # so we try again one time and succeed.
         self.assertEqual(
@@ -3528,8 +3538,9 @@
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
         smart_client = client._SmartClient(client_medium, headers={})
-        self.assertRaises(errors.ConnectionReset,
-            smart_client._send_request, 3, 'hello', (), body_stream=['a', 'b'])
+        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
+            body_stream=['a', 'b'])
+        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
         # We connect, and manage to get to the point that we start consuming
         # the body stream. The next write fails, so we just stop.
         self.assertEqual(



More information about the bazaar-commits mailing list