Rev 4902: Implement retrying a request as long as we haven't started consuming the body stream. in http://bazaar.launchpad.net/~jameinel/bzr/2.1-client-stream-started-819604

John Arbash Meinel john at arbash-meinel.com
Fri Oct 7 16:12:06 UTC 2011


At http://bazaar.launchpad.net/~jameinel/bzr/2.1-client-stream-started-819604

------------------------------------------------------------
revno: 4902
revision-id: john at arbash-meinel.com-20111007161140-yew7yg5n6bajyejm
parent: john at arbash-meinel.com-20111007121231-v1oo7htubsh9qiro
committer: John Arbash Meinel <john at arbash-meinel.com>
branch nick: 2.1-client-stream-started-819604
timestamp: Fri 2011-10-07 18:11:40 +0200
message:
  Implement retrying a request as long as we haven't started consuming the body stream.
  
  Also, add a flush() before we start consuming the stream, so that any disconnect, etc.,
  can be detected early. I wonder if it would also let us detect UnknownSmartVerb faster.
  
  But as the important streaming verb was introduced in 1.19 (pre-2.0), I'm not sure we care.
-------------- next part --------------
=== modified file 'bzrlib/smart/client.py'
--- a/bzrlib/smart/client.py	2011-10-07 12:12:31 +0000
+++ b/bzrlib/smart/client.py	2011-10-07 16:11:40 +0000
@@ -77,13 +77,11 @@
 
             # Connection is dead, so close our end of it.
             self._medium.reset()
-            if body_stream is not None:
-                # We can't determine how much of body_stream got consumed
-                # before we noticed the connection is down, so we don't retry
-                # here.
+            if body_stream is not None and encoder.body_stream_started:
+                # We consumed some of body_stream, so it isn't safe to retry
                 raise
+            trace.warning('ConnectionReset calling %s, retrying' % (method,))
             trace.log_exception_quietly()
-            trace.warning('ConnectionReset calling %s, retrying' % (method,))
             encoder, response_handler = self._construct_protocol(
                 protocol_version)
             self._send_request_no_retry(encoder, method, args, body=body,

=== modified file 'bzrlib/smart/medium.py'
--- a/bzrlib/smart/medium.py	2011-10-06 14:01:32 +0000
+++ b/bzrlib/smart/medium.py	2011-10-07 16:11:40 +0000
@@ -739,8 +739,7 @@
         except IOError, e:
             if e.errno in (errno.EINVAL, errno.EPIPE):
                 raise errors.ConnectionReset(
-                    "Error trying to write to subprocess:\n%s"
-                    % (e,))
+                    "Error trying to write to subprocess:\n%s" % (e,))
             raise
         self._report_activity(len(bytes), 'write')
 

=== modified file 'bzrlib/smart/protocol.py'
--- a/bzrlib/smart/protocol.py	2011-09-26 15:18:14 +0000
+++ b/bzrlib/smart/protocol.py	2011-10-07 16:11:40 +0000
@@ -1282,6 +1282,7 @@
         _ProtocolThreeEncoder.__init__(self, medium_request.accept_bytes)
         self._medium_request = medium_request
         self._headers = {}
+        self.body_stream_started = None
 
     def set_headers(self, headers):
         self._headers = headers.copy()
@@ -1347,6 +1348,7 @@
             if path is not None:
                 mutter('                  (to %s)', path)
             self._request_start_time = osutils.timer_func()
+        self.body_stream_started = False
         self._write_protocol_version()
         self._write_headers(self._headers)
         self._write_structure(args)
@@ -1354,6 +1356,9 @@
         #       have finished sending the stream.  We would notice at the end
         #       anyway, but if the medium can deliver it early then it's good
         #       to short-circuit the whole request...
+        # Provoke any ConnectionReset failures before we start the body stream.
+        self.flush()
+        self.body_stream_started = True
         for exc_info, part in _iter_with_errors(stream):
             if exc_info is not None:
                 # Iterating the stream failed.  Cleanly abort the request.

=== modified file 'bzrlib/tests/test_smart_transport.py'
--- a/bzrlib/tests/test_smart_transport.py	2011-10-07 12:12:31 +0000
+++ b/bzrlib/tests/test_smart_transport.py	2011-10-07 16:11:40 +0000
@@ -98,7 +98,7 @@
     def __init__(self, read_from, write_to, fail_at_write=True):
         super(FirstRejectedStringIOSSHVendor, self).__init__(read_from,
             write_to)
-        self.fail_at_write= fail_at_write
+        self.fail_at_write = fail_at_write
         self._first = True
 
     def connect_ssh(self, username, password, host, port, command):
@@ -2943,6 +2943,33 @@
             'e', # end
             output.getvalue())
 
+    def test_records_start_of_body_stream(self):
+        requester, output = self.make_client_encoder_and_output()
+        requester.set_headers({})
+        in_stream = [False]
+        def stream_checker():
+            self.assertTrue(requester.body_stream_started)
+            in_stream[0] = True
+            yield 'content'
+        flush_called = []
+        orig_flush = requester.flush
+        def tracked_flush():
+            flush_called.append(in_stream[0])
+            if in_stream[0]:
+                self.assertTrue(requester.body_stream_started)
+            else:
+                self.assertFalse(requester.body_stream_started)
+            return orig_flush()
+        requester.flush = tracked_flush
+        requester.call_with_body_stream(('one arg',), stream_checker())
+        self.assertEqual(
+            'bzr message 3 (bzr 1.6)\n' # protocol version
+            '\x00\x00\x00\x02de' # headers
+            's\x00\x00\x00\x0bl7:one arge' # args
+            'b\x00\x00\x00\x07content' # body
+            'e', output.getvalue())
+        self.assertEqual([False, True, True], flush_called)
+
 
 class StubMediumRequest(object):
     """A stub medium request that tracks the number of times accept_bytes is
@@ -3414,11 +3441,12 @@
         client_medium = medium.SmartSSHClientMedium(
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
-        smart_client = client._SmartClient(client_medium)
+        smart_client = client._SmartClient(client_medium, headers={})
         handler = smart_client._send_request(3, 'hello', ())
-        message_sent = output.getvalue()
-        self.assertStartsWith(message_sent, 'bzr message 3 (bzr 1.6)\n')
-        self.assertEndsWith(message_sent, 's\x00\x00\x00\tl5:helloee')
+        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
+                         '\x00\x00\x00\x02de'   # empty headers
+                         's\x00\x00\x00\tl5:helloee',
+                         output.getvalue())
         self.assertEqual(
             [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
               ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
@@ -3436,11 +3464,12 @@
         client_medium = medium.SmartSSHClientMedium(
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
-        smart_client = client._SmartClient(client_medium)
+        smart_client = client._SmartClient(client_medium, headers={})
         handler = smart_client._send_request(3, 'hello', ())
-        message_sent = output.getvalue()
-        self.assertStartsWith(message_sent, 'bzr message 3 (bzr 1.6)\n')
-        self.assertEndsWith(message_sent, 's\x00\x00\x00\tl5:helloee')
+        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
+                         '\x00\x00\x00\x02de'   # empty headers
+                         's\x00\x00\x00\tl5:helloee',
+                         output.getvalue())
         self.assertEqual(
             [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
               ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
@@ -3448,27 +3477,71 @@
             vendor.calls)
         self.assertRaises(errors.ConnectionReset, handler.read_response_tuple)
 
-    def test__send_request_doesnt_retry_body_stream(self):
-        # We don't know how much of body_stream would get iterated as part of
-        # _send_request before it failed to actually send the request, so we
-        # just always fail in this condition.
+    def test__send_request_retries_body_stream_if_not_started(self):
         response = StringIO()
         output = StringIO()
         vendor = FirstRejectedStringIOSSHVendor(response, output)
         client_medium = medium.SmartSSHClientMedium(
             'a host', 'a port', 'a user', 'a pass', 'base', vendor,
             'bzr')
-        smart_client = client._SmartClient(client_medium)
+        smart_client = client._SmartClient(client_medium, headers={})
+        smart_client._send_request(3, 'hello', (), body_stream=['a', 'b'])
+        # We connect, get disconnected, and notice before consuming the stream,
+        # so we try again one time and succeed.
+        self.assertEqual(
+            [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+             ('close',),
+             ('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
+              ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
+            ],
+            vendor.calls)
+        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
+                         '\x00\x00\x00\x02de'   # empty headers
+                         's\x00\x00\x00\tl5:helloe'
+                         'b\x00\x00\x00\x01a'
+                         'b\x00\x00\x00\x01b'
+                         'e',
+                         output.getvalue())
+
+    def test__send_request_stops_if_body_started(self):
+        # We intentionally use the python StringIO so that we can subclass it.
+        from StringIO import StringIO
+        response = StringIO()
+
+        class FailAfterFirstWrite(StringIO):
+            """Allow one 'write' call to pass, fail the rest"""
+            def __init__(self):
+                StringIO.__init__(self)
+                self._first = True
+
+            def write(self, s):
+                if self._first:
+                    self._first = False
+                    return StringIO.write(self, s)
+                raise IOError(errno.EINVAL, 'invalid file handle')
+        output = FailAfterFirstWrite()
+
+        vendor = FirstRejectedStringIOSSHVendor(response, output,
+            fail_at_write=False)
+        client_medium = medium.SmartSSHClientMedium(
+            'a host', 'a port', 'a user', 'a pass', 'base', vendor,
+            'bzr')
+        smart_client = client._SmartClient(client_medium, headers={})
         self.assertRaises(errors.ConnectionReset,
             smart_client._send_request, 3, 'hello', (), body_stream=['a', 'b'])
-        # We got one connect, but it fails, so we disconnect, but we don't
-        # retry it
+        # We connect, and manage to get to the point that we start consuming
+        # the body stream. The next write fails, so we just stop.
         self.assertEqual(
             [('connect_ssh', 'a user', 'a pass', 'a host', 'a port',
               ['bzr', 'serve', '--inet', '--directory=/', '--allow-writes']),
              ('close',),
             ],
             vendor.calls)
+        self.assertEqual('bzr message 3 (bzr 1.6)\n' # protocol
+                         '\x00\x00\x00\x02de'   # empty headers
+                         's\x00\x00\x00\tl5:helloe',
+                         output.getvalue())
 
 
 class LengthPrefixedBodyDecoder(tests.TestCase):



More information about the bazaar-commits mailing list