Rev 4903: Move all of the send logic into the _SmartClientRequest class (in http://bazaar.launchpad.net/~jameinel/bzr/2.1-client-read-reconnect-819604)

John Arbash Meinel john at arbash-meinel.com
Sat Oct 8 09:43:45 UTC 2011


At http://bazaar.launchpad.net/~jameinel/bzr/2.1-client-read-reconnect-819604

------------------------------------------------------------
revno: 4903
revision-id: john at arbash-meinel.com-20111008094314-yh16binq9pf9fdu3
parent: john at arbash-meinel.com-20111007184018-klup501y7ncps2k8
committer: John Arbash Meinel <john at arbash-meinel.com>
branch nick: 2.1-client-read-reconnect-819604
timestamp: Sat 2011-10-08 11:43:14 +0200
message:
  Move all of the send logic into the _SmartClientRequest class.
  
  This shortens many of the function names, and avoids having to pass the
  same 6 arguments through several different functions just so that we can
  retry requests, etc.
-------------- next part --------------
=== modified file 'bzrlib/smart/client.py'
--- a/bzrlib/smart/client.py	2011-10-07 18:40:18 +0000
+++ b/bzrlib/smart/client.py	2011-10-08 09:43:14 +0000
@@ -40,56 +40,6 @@
     def __repr__(self):
         return '%s(%r)' % (self.__class__.__name__, self._medium)
 
-    def _send_request_no_retry(self, encoder, method, args, body=None,
-                               readv_body=None, body_stream=None):
-        encoder.set_headers(self._headers)
-        if body is not None:
-            if readv_body is not None:
-                raise AssertionError(
-                    "body and readv_body are mutually exclusive.")
-            if body_stream is not None:
-                raise AssertionError(
-                    "body and body_stream are mutually exclusive.")
-            encoder.call_with_body_bytes((method, ) + args, body)
-        elif readv_body is not None:
-            if body_stream is not None:
-                raise AssertionError(
-                    "readv_body and body_stream are mutually exclusive.")
-            encoder.call_with_body_readv_array((method, ) + args, readv_body)
-        elif body_stream is not None:
-            encoder.call_with_body_stream((method, ) + args, body_stream)
-        else:
-            encoder.call(method, *args)
-
-    def _send_request(self, protocol_version, method, args, body=None,
-                      readv_body=None, body_stream=None):
-        encoder, response_handler = self._construct_protocol(
-            protocol_version)
-        try:
-            self._send_request_no_retry(encoder, method, args, body=body,
-                readv_body=readv_body, body_stream=body_stream)
-        except errors.ConnectionReset, e:
-            # If we fail during the _send_request_no_retry phase, then we can
-            # be confident that the server did not get our request, because we
-            # haven't started waiting for the reply yet. So try the request
-            # again. We only issue a single retry, because if the connection
-            # really is down, there is no reason to loop endlessly.
-
-            # Connection is dead, so close our end of it.
-            self._medium.reset()
-            if body_stream is not None:
-                # We can't determine how much of body_stream got consumed
-                # before we noticed the connection is down, so we don't retry
-                # here.
-                raise
-            trace.log_exception_quietly()
-            trace.warning('ConnectionReset calling %s, retrying' % (method,))
-            encoder, response_handler = self._construct_protocol(
-                protocol_version)
-            self._send_request_no_retry(encoder, method, args, body=body,
-                readv_body=readv_body, body_stream=body_stream)
-        return response_handler
-
     def _run_call_hooks(self, method, args, body, readv_body):
         if not _SmartClient.hooks['call']:
             return
@@ -97,38 +47,6 @@
         for hook in _SmartClient.hooks['call']:
             hook(params)
 
-    def _determine_protocol_version(self, method, args, body=None,
-        readv_body=None, body_stream=None, expect_response_body=True):
-        for protocol_version in [3, 2]:
-            if protocol_version == 2:
-                # If v3 doesn't work, the remote side is older than 1.6.
-                self._medium._remember_remote_is_before((1, 6))
-            response_handler = self._send_request(
-                protocol_version, method, args, body=body,
-                readv_body=readv_body, body_stream=body_stream)
-            try:
-                response_tuple = response_handler.read_response_tuple(
-                    expect_body=expect_response_body)
-            except errors.UnexpectedProtocolVersionMarker, err:
-                # TODO: We could recover from this without disconnecting if
-                # we recognise the protocol version.
-                warning(
-                    'Server does not understand Bazaar network protocol %d,'
-                    ' reconnecting.  (Upgrade the server to avoid this.)'
-                    % (protocol_version,))
-                self._medium.disconnect()
-                continue
-            except errors.ErrorFromSmartServer:
-                # If we received an error reply from the server, then it
-                # must be ok with this protocol version.
-                self._medium._protocol_version = protocol_version
-                raise
-            else:
-                self._medium._protocol_version = protocol_version
-                return response_tuple, response_handler
-        raise errors.SmartProtocolError(
-            'Server is not a Bazaar server: ' + str(err))
-
     def _call_and_read_response(self, method, args, body=None, readv_body=None,
             body_stream=None, expect_response_body=True):
         request = _SmartClientRequest(self, method, args, body=body,
@@ -136,22 +54,6 @@
             expect_response_body=expect_response_body)
         return request.call_and_read_response()
 
-    def _construct_protocol(self, version):
-        request = self._medium.get_request()
-        if version == 3:
-            request_encoder = protocol.ProtocolThreeRequester(request)
-            response_handler = message.ConventionalResponseHandler()
-            response_proto = protocol.ProtocolThreeDecoder(
-                response_handler, expect_version_marker=True)
-            response_handler.setProtoAndMediumRequest(response_proto, request)
-        elif version == 2:
-            request_encoder = protocol.SmartClientRequestProtocolTwo(request)
-            response_handler = request_encoder
-        else:
-            request_encoder = protocol.SmartClientRequestProtocolOne(request)
-            response_handler = request_encoder
-        return request_encoder, response_handler
-
     def call(self, method, *args):
         """Call a method on the remote server."""
         result, protocol = self.call_expecting_body(method, *args)
@@ -238,10 +140,8 @@
             return self._call()
 
     def _call(self):
-        response_handler = self.client._send_request(
-            self.client._medium._protocol_version, self.method, self.args,
-            body=self.body, readv_body=self.readv_body,
-            body_stream=self.body_stream)
+        response_handler = self._send(
+            self.client._medium._protocol_version)
         response_tuple = response_handler.read_response_tuple(
             expect_body=self.expect_response_body)
         return (response_tuple, response_handler)
@@ -251,9 +151,7 @@
             if protocol_version == 2:
                 # If v3 doesn't work, the remote side is older than 1.6.
                 self.client._medium._remember_remote_is_before((1, 6))
-            response_handler = self.client._send_request(
-                protocol_version, self.method, self.args, body=self.body,
-                readv_body=self.readv_body, body_stream=self.body_stream)
+            response_handler = self._send(protocol_version)
             try:
                 response_tuple = response_handler.read_response_tuple(
                     expect_body=self.expect_response_body)
@@ -277,6 +175,70 @@
         raise errors.SmartProtocolError(
             'Server is not a Bazaar server: ' + str(err))
 
+    def _construct_protocol(self, version):
+        request = self.client._medium.get_request()
+        if version == 3:
+            request_encoder = protocol.ProtocolThreeRequester(request)
+            response_handler = message.ConventionalResponseHandler()
+            response_proto = protocol.ProtocolThreeDecoder(
+                response_handler, expect_version_marker=True)
+            response_handler.setProtoAndMediumRequest(response_proto, request)
+        elif version == 2:
+            request_encoder = protocol.SmartClientRequestProtocolTwo(request)
+            response_handler = request_encoder
+        else:
+            request_encoder = protocol.SmartClientRequestProtocolOne(request)
+            response_handler = request_encoder
+        return request_encoder, response_handler
+
+    def _send(self, protocol_version):
+        encoder, response_handler = self._construct_protocol(protocol_version)
+        try:
+            self._send_no_retry(encoder)
+        except errors.ConnectionReset, e:
+            # If we fail during the _send_no_retry phase, then we can
+            # be confident that the server did not get our request, because we
+            # haven't started waiting for the reply yet. So try the request
+            # again. We only issue a single retry, because if the connection
+            # really is down, there is no reason to loop endlessly.
+
+            # Connection is dead, so close our end of it.
+            self.client._medium.reset()
+            if self.body_stream is not None:
+                # We can't determine how much of body_stream got consumed
+                # before we noticed the connection is down, so we don't retry
+                # here.
+                raise
+            trace.log_exception_quietly()
+            trace.warning('ConnectionReset calling %s, retrying'
+                          % (self.method,))
+            encoder, response_handler = self._construct_protocol(
+                protocol_version)
+            self._send_no_retry(encoder)
+        return response_handler
+
+    def _send_no_retry(self, encoder):
+        encoder.set_headers(self.client._headers)
+        if self.body is not None:
+            if self.readv_body is not None:
+                raise AssertionError(
+                    "body and readv_body are mutually exclusive.")
+            if self.body_stream is not None:
+                raise AssertionError(
+                    "body and body_stream are mutually exclusive.")
+            encoder.call_with_body_bytes((self.method, ) + self.args, self.body)
+        elif self.readv_body is not None:
+            if self.body_stream is not None:
+                raise AssertionError(
+                    "readv_body and body_stream are mutually exclusive.")
+            encoder.call_with_body_readv_array((self.method, ) + self.args,
+                                               self.readv_body)
+        elif self.body_stream is not None:
+            encoder.call_with_body_stream((self.method, ) + self.args,
+                                          self.body_stream)
+        else:
+            encoder.call(self.method, *self.args)
+
 
 class SmartClientHooks(hooks.Hooks):
 

=== modified file 'bzrlib/tests/test_smart_transport.py'
--- a/bzrlib/tests/test_smart_transport.py	2011-10-07 12:12:31 +0000
+++ b/bzrlib/tests/test_smart_transport.py	2011-10-08 09:43:14 +0000
@@ -3378,19 +3378,24 @@
         # XXX: need a test that smart_client._headers is passed to the request
         # encoder.
 
-    def test__send_request_no_retry_pipes(self):
+
+class Test_SmartClientRequest(tests.TestCase):
+
+    def test__send_no_retry_pipes(self):
         client_read, server_write = create_file_pipes()
         server_read, client_write = create_file_pipes()
         client_medium = medium.SmartSimplePipesClientMedium(client_read,
             client_write, base='/')
         smart_client = client._SmartClient(client_medium)
+        smart_request = client._SmartClientRequest(smart_client,
+            'hello', ())
         # Close the server side
         server_read.close()
-        encoder, response_handler = smart_client._construct_protocol(3)
+        encoder, response_handler = smart_request._construct_protocol(3)
         self.assertRaises(errors.ConnectionReset,
-            smart_client._send_request_no_retry, encoder, 'hello', ())
+            smart_request._send_no_retry, encoder)
 
-    def test__send_request_read_response_sockets(self):
+    def test__send_read_response_sockets(self):
         listen_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         listen_sock.bind(('127.0.0.1', 0))
         listen_sock.listen(1)
@@ -3398,16 +3403,17 @@
         client_medium = medium.SmartTCPClientMedium(host, port, '/')
         client_medium._ensure_connection()
         smart_client = client._SmartClient(client_medium)
+        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
         # Accept the connection, but don't actually talk to the client.
         server_sock, _ = listen_sock.accept()
         server_sock.close()
         # Sockets buffer and don't really notice that the server has closed the
         # connection until we try to read again.
-        handler = smart_client._send_request(3, 'hello', ())
+        handler = smart_request._send(3)
         self.assertRaises(errors.ConnectionReset,
             handler.read_response_tuple, expect_body=False)
 
-    def test__send_request_retries_on_write(self):
+    def test__send_retries_on_write(self):
         response = StringIO()
         output = StringIO()
         vendor = FirstRejectedStringIOSSHVendor(response, output)
@@ -3415,7 +3421,8 @@
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
         smart_client = client._SmartClient(client_medium)
-        handler = smart_client._send_request(3, 'hello', ())
+        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
+        handler = smart_request._send(3)
         message_sent = output.getvalue()
         self.assertStartsWith(message_sent, 'bzr message 3 (bzr 1.6)\n')
         self.assertEndsWith(message_sent, 's\x00\x00\x00\tl5:helloee')
@@ -3428,7 +3435,7 @@
             ],
             vendor.calls)
 
-    def test__send_request_doesnt_retry_read_failure(self):
+    def test__send_doesnt_retry_read_failure(self):
         response = StringIO()
         output = StringIO()
         vendor = FirstRejectedStringIOSSHVendor(response, output,
@@ -3437,7 +3444,8 @@
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
         smart_client = client._SmartClient(client_medium)
-        handler = smart_client._send_request(3, 'hello', ())
+        smart_request = client._SmartClientRequest(smart_client, 'hello', ())
+        handler = smart_request._send(3)
         message_sent = output.getvalue()
         self.assertStartsWith(message_sent, 'bzr message 3 (bzr 1.6)\n')
         self.assertEndsWith(message_sent, 's\x00\x00\x00\tl5:helloee')
@@ -3448,9 +3456,9 @@
             vendor.calls)
         self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)
 
-    def test__send_request_doesnt_retry_body_stream(self):
+    def test__send_doesnt_retry_body_stream(self):
         # We don't know how much of body_stream would get iterated as part of
-        # _send_request before it failed to actually send the request, so we
+        # _send before it failed to actually send the request, so we
         # just always fail in this condition.
         response = StringIO()
         output = StringIO()
@@ -3459,8 +3467,9 @@
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
         smart_client = client._SmartClient(client_medium)
-        self.assertRaises(errors.ConnectionReset,
-            smart_client._send_request, 3, 'hello', (), body_stream=['a', 'b'])
+        smart_request = client._SmartClientRequest(smart_client, 'hello', (),
+            body_stream=['a', 'b'])
+        self.assertRaises(errors.ConnectionReset, smart_request._send, 3)
         # We got one connect, but it fails, so we disconnect, but we don't
         # retry it
         self.assertEqual(



More information about the bazaar-commits mailing list