Rev 2433: Add Smart{Client,Server}RequestProtocolTwo, which prefix args tuples with a version marker. in http://bazaar.launchpad.net/~bzr
Andrew Bennetts
andrew.bennetts at canonical.com
Tue Apr 24 06:12:27 BST 2007
At http://bazaar.launchpad.net/~bzr
------------------------------------------------------------
revno: 2433
revision-id: andrew.bennetts at canonical.com-20070424051106-wwlidpflp1rwi3a7
parent: pqm at pqm.ubuntu.com-20070419224637-jvlshh6kibtj43a5
committer: Andrew Bennetts <andrew.bennetts at canonical.com>
branch nick: hpss-protocol2
timestamp: Tue 2007-04-24 15:11:06 +1000
message:
Add Smart{Client,Server}RequestProtocolTwo, that prefix args tuples with a version marker.
modified:
bzrlib/smart/protocol.py protocol.py-20061108035435-ot0lstk2590yqhzr-1
bzrlib/smart/request.py request.py-20061108095550-gunadhxmzkdjfeek-1
bzrlib/tests/test_smart_transport.py test_ssh_transport.py-20060608202016-c25gvf1ob7ypbus6-2
=== modified file 'bzrlib/smart/protocol.py'
--- a/bzrlib/smart/protocol.py 2007-04-10 14:12:35 +0000
+++ b/bzrlib/smart/protocol.py 2007-04-24 05:11:06 +0000
@@ -135,12 +135,19 @@
"""Send a smart server response down the output stream."""
assert not self._finished, 'response already sent'
self._finished = True
+ self._write_protocol_version()
self._write_func(_encode_tuple(args))
if body is not None:
assert isinstance(body, str), 'body must be a str'
bytes = self._encode_bulk_data(body)
self._write_func(bytes)
+ def _write_protocol_version(self):
+ """Write any prefixes this protocol requires.
+
+ Version one sends no version prefix, so this is a no-op.
+ """
+
def next_read_size(self):
if self._finished:
return 0
@@ -150,6 +157,20 @@
return self._body_decoder.next_read_size()
+class SmartServerRequestProtocolTwo(SmartServerRequestProtocolOne):
+ r"""Version two of the server side of the smart protocol.
+
+ Responses get the "2\x01" prefix; requests are still parsed as in version one (no prefix).
+ """
+
+ def _write_protocol_version(self):
+ r"""Write any prefixes this protocol requires.
+
+ Version two sends "2\x01".
+ """
+ self._write_func('2\x01')
+
+
class LengthPrefixedBodyDecoder(object):
"""Decodes the length-prefixed bulk data."""
@@ -254,8 +275,7 @@
self._body_buffer = None
def call(self, *args):
- bytes = _encode_tuple(args)
- self._request.accept_bytes(bytes)
+ self._write_args(args)
self._request.finished_writing()
def call_with_body_bytes(self, args, body):
@@ -263,8 +283,7 @@
After calling this, call read_response_tuple to find the result out.
"""
- bytes = _encode_tuple(args)
- self._request.accept_bytes(bytes)
+ self._write_args(args)
bytes = self._encode_bulk_data(body)
self._request.accept_bytes(bytes)
self._request.finished_writing()
@@ -275,8 +294,7 @@
The body is encoded with one line per readv offset pair. The numbers in
each pair are separated by a comma, and no trailing \n is emitted.
"""
- bytes = _encode_tuple(args)
- self._request.accept_bytes(bytes)
+ self._write_args(args)
readv_bytes = self._serialise_offsets(body)
bytes = self._encode_bulk_data(readv_bytes)
self._request.accept_bytes(bytes)
@@ -336,8 +354,45 @@
resp = self.read_response_tuple()
if resp == ('ok', '1'):
return 1
+ elif resp == ('ok', '2'):
+ return 2
else:
raise errors.SmartProtocolError("bad response %r" % (resp,))
-
+ def _write_args(self, args):
+ self._write_protocol_version()
+ bytes = _encode_tuple(args)
+ self._request.accept_bytes(bytes)
+
+ def _write_protocol_version(self):
+ """Write any prefixes this protocol requires.
+
+ Version one sends no version prefix, so this is a no-op.
+ """
+
+
+class SmartClientRequestProtocolTwo(SmartClientRequestProtocolOne):
+ r"""Version two of the client side of the smart protocol.
+
+ This prefixes requests with "2\x01" and expects the same marker on responses.
+ """
+
+ _version_string = '2\x01'
+
+ def read_response_tuple(self, expect_body=False):
+ """Read the 2-byte version marker, then a response tuple, from the wire.
+
+ This should only be called once; a wrong marker raises SmartProtocolError.
+ """
+ version = self._request.read_bytes(2)
+ if version != SmartClientRequestProtocolTwo._version_string:
+ raise errors.SmartProtocolError('bad protocol marker %r' % version)
+ return SmartClientRequestProtocolOne.read_response_tuple(self, expect_body)
+
+ def _write_protocol_version(self):
+ r"""Write any prefixes this protocol requires.
+
+ Version two sends "2\x01".
+ """
+ self._request.accept_bytes(SmartClientRequestProtocolTwo._version_string)
=== modified file 'bzrlib/smart/request.py'
--- a/bzrlib/smart/request.py 2007-04-11 02:01:18 +0000
+++ b/bzrlib/smart/request.py 2007-04-24 05:11:06 +0000
@@ -203,7 +203,7 @@
"""Answer a version request with my version."""
def do(self):
- return SmartServerResponse(('ok', '1'))
+ return SmartServerResponse(('ok', '2'))
class GetBundleRequest(SmartServerRequest):
=== modified file 'bzrlib/tests/test_smart_transport.py'
--- a/bzrlib/tests/test_smart_transport.py 2007-04-16 18:08:53 +0000
+++ b/bzrlib/tests/test_smart_transport.py 2007-04-24 05:11:06 +0000
@@ -605,7 +605,7 @@
smart_protocol = protocol.SmartServerRequestProtocolOne(transport,
from_server.write)
server._serve_one_request(smart_protocol)
- self.assertEqual('ok\0011\n',
+ self.assertEqual('ok\0012\n',
from_server.getvalue())
def test_response_to_canned_get(self):
@@ -992,12 +992,15 @@
class SmartServerCommandTests(tests.TestCaseWithTransport):
"""Tests that call directly into the command objects, bypassing the network
and the request dispatching.
+
+ Note: these tests are rudimentary versions of the command object tests in
+ test_remote.py.
"""
def test_hello(self):
cmd = request.HelloRequest(None)
response = cmd.execute()
- self.assertEqual(('ok', '1'), response.args)
+ self.assertEqual(('ok', '2'), response.args)
self.assertEqual(None, response.body)
def test_get_bundle(self):
@@ -1032,7 +1035,7 @@
def test_hello(self):
handler = self.build_handler(None)
handler.dispatch_command('hello', ())
- self.assertEqual(('ok', '1'), handler.response.args)
+ self.assertEqual(('ok', '2'), handler.response.args)
self.assertEqual(None, handler.response.body)
def test_disable_vfs_handler_classes_via_environment(self):
@@ -1157,7 +1160,7 @@
self._write_output_list = write_output_list
-class TestSmartProtocol(tests.TestCase):
+class TestSmartProtocolOne(tests.TestCase):
"""Tests for the smart protocol.
Each test case gets a smart_server and smart_client created during setUp().
@@ -1171,7 +1174,7 @@
"""
def setUp(self):
- super(TestSmartProtocol, self).setUp()
+ super(TestSmartProtocolOne, self).setUp()
# XXX: self.server_to_client doesn't seem to be used. If so,
# InstrumentedServerProtocol is redundant too.
self.server_to_client = []
@@ -1292,7 +1295,7 @@
smart_protocol = protocol.SmartServerRequestProtocolOne(
None, out_stream.write)
smart_protocol.accept_bytes('hello\nhello\n')
- self.assertEqual("ok\x011\n", out_stream.getvalue())
+ self.assertEqual("ok\x012\n", out_stream.getvalue())
self.assertEqual("hello\n", smart_protocol.excess_buffer)
self.assertEqual("", smart_protocol.in_buffer)
@@ -1311,7 +1314,7 @@
smart_protocol = protocol.SmartServerRequestProtocolOne(
None, out_stream.write)
smart_protocol.accept_bytes('hello\n')
- self.assertEqual("ok\x011\n", out_stream.getvalue())
+ self.assertEqual("ok\x012\n", out_stream.getvalue())
smart_protocol.accept_bytes('hel')
self.assertEqual("hel", smart_protocol.excess_buffer)
smart_protocol.accept_bytes('lo\n')
@@ -1335,12 +1338,12 @@
# accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
# response of tuple-encoded (ok, 1). Also, seperately we should test
# the error if the response is a non-understood version.
- input = StringIO('ok\x011\n')
+ input = StringIO('ok\x012\n')
output = StringIO()
client_medium = medium.SmartSimplePipesClientMedium(input, output)
request = client_medium.get_request()
smart_protocol = protocol.SmartClientRequestProtocolOne(request)
- self.assertEqual(1, smart_protocol.query_version())
+ self.assertEqual(2, smart_protocol.query_version())
def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
input_tuples):
@@ -1452,6 +1455,304 @@
errors.ReadingCompleted, smart_protocol.read_body_bytes)
+class TestSmartProtocolTwo(tests.TestCase):
+ """Tests for the smart protocol version two.
+
+ Each test case gets a smart_server and smart_client created during setUp().
+
+ It is planned that the client can be called with self.call_client() giving
+ it an expected server response, which will be fed into it when it tries to
+ read. Likewise, self.call_server will call a server's method with a canned
+ serialised client request. Output done by the client or server for these
+ calls will be captured to self.to_server and self.to_client. Each element
+ in the list is a write call from the client or server respectively.
+
+ This test case is mostly the same as TestSmartProtocolOne.
+ """
+
+ def setUp(self):
+ super(TestSmartProtocolTwo, self).setUp()
+ # XXX: self.server_to_client doesn't seem to be used. If so,
+ # InstrumentedServerProtocol is redundant too.
+ self.server_to_client = []
+ self.to_server = StringIO()
+ self.to_client = StringIO()
+ self.client_medium = medium.SmartSimplePipesClientMedium(self.to_client,
+ self.to_server)
+ self.client_protocol = protocol.SmartClientRequestProtocolTwo(
+ self.client_medium)
+ self.smart_server = InstrumentedServerProtocol(self.server_to_client)
+ self.smart_server_request = request.SmartServerRequestHandler(
+ None, request.request_handlers)
+
+ def assertOffsetSerialisation(self, expected_offsets, expected_serialised,
+ client):
+ """Check that smart (de)serialises offsets as expected.
+
+ We check both serialisation and deserialisation at the same time
+ to ensure that the round tripping cannot skew: both directions should
+ be as expected.
+
+ :param expected_offsets: a readv offset list.
+ :param expected_serialised: an expected serial form of the offsets.
+ """
+ # XXX: '_deserialise_offsets' should be a method of the
+ # SmartServerRequestProtocol in future.
+ readv_cmd = vfs.ReadvRequest(None)
+ offsets = readv_cmd._deserialise_offsets(expected_serialised)
+ self.assertEqual(expected_offsets, offsets)
+ serialised = client._serialise_offsets(offsets)
+ self.assertEqual(expected_serialised, serialised)
+
+ def build_protocol_waiting_for_body(self):
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(None,
+ out_stream.write)
+ smart_protocol.has_dispatched = True
+ smart_protocol.request = self.smart_server_request
+ class FakeCommand(object):
+ def do_body(cmd, body_bytes):
+ self.end_received = True
+ self.assertEqual('abcdefg', body_bytes)
+ return request.SmartServerResponse(('ok', ))
+ smart_protocol.request._command = FakeCommand()
+ # Call accept_bytes to make sure that internal state like _body_decoder
+ # is initialised. This test should probably be given a clearer
+ # interface to work with that will not cause this inconsistency.
+ # -- Andrew Bennetts, 2006-09-28
+ smart_protocol.accept_bytes('')
+ return smart_protocol
+
+ def test_construct_version_two_server_protocol(self):
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(None, None)
+ self.assertEqual('', smart_protocol.excess_buffer)
+ self.assertEqual('', smart_protocol.in_buffer)
+ self.assertFalse(smart_protocol.has_dispatched)
+ self.assertEqual(1, smart_protocol.next_read_size())
+
+ def test_construct_version_two_client_protocol(self):
+ # we can construct a client protocol from a client medium request
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(None, output)
+ request = client_medium.get_request()
+ client_protocol = protocol.SmartClientRequestProtocolTwo(request)
+
+ def test_server_offset_serialisation(self):
+ """The Smart protocol serialises offsets as a comma and \n string.
+
+ We check a number of boundary cases are as expected: empty, one offset,
+ one with the order of reads not increasing (an out of order read), and
+ one that should coalesce.
+ """
+ self.assertOffsetSerialisation([], '', self.client_protocol)
+ self.assertOffsetSerialisation([(1,2)], '1,2', self.client_protocol)
+ self.assertOffsetSerialisation([(10,40), (0,5)], '10,40\n0,5',
+ self.client_protocol)
+ self.assertOffsetSerialisation([(1,2), (3,4), (100, 200)],
+ '1,2\n3,4\n100,200', self.client_protocol)
+
+ def test_accept_bytes_of_bad_request_to_protocol(self):
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(
+ None, out_stream.write)
+ smart_protocol.accept_bytes('abc')
+ self.assertEqual('abc', smart_protocol.in_buffer)
+ smart_protocol.accept_bytes('\n')
+ self.assertEqual(
+ "2\x01error\x01Generic bzr smart protocol error: bad request 'abc'\n",
+ out_stream.getvalue())
+ self.assertTrue(smart_protocol.has_dispatched)
+ self.assertEqual(0, smart_protocol.next_read_size())
+
+ def test_accept_body_bytes_to_protocol(self):
+ protocol = self.build_protocol_waiting_for_body()
+ self.assertEqual(6, protocol.next_read_size())
+ protocol.accept_bytes('7\nabc')
+ self.assertEqual(9, protocol.next_read_size())
+ protocol.accept_bytes('defgd')
+ protocol.accept_bytes('one\n')
+ self.assertEqual(0, protocol.next_read_size())
+ self.assertTrue(self.end_received)
+
+ def test_accept_request_and_body_all_at_once(self):
+ self._captureVar('BZR_NO_SMART_VFS', None)
+ mem_transport = memory.MemoryTransport()
+ mem_transport.put_bytes('foo', 'abcdefghij')
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(mem_transport,
+ out_stream.write)
+ smart_protocol.accept_bytes('readv\x01foo\n3\n3,3done\n')
+ self.assertEqual(0, smart_protocol.next_read_size())
+ self.assertEqual('2\x01readv\n3\ndefdone\n', out_stream.getvalue())
+ self.assertEqual('', smart_protocol.excess_buffer)
+ self.assertEqual('', smart_protocol.in_buffer)
+
+ def test_accept_excess_bytes_are_preserved(self):
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(
+ None, out_stream.write)
+ smart_protocol.accept_bytes('hello\nhello\n')
+ self.assertEqual("2\x01ok\x012\n", out_stream.getvalue())
+ self.assertEqual("hello\n", smart_protocol.excess_buffer)
+ self.assertEqual("", smart_protocol.in_buffer)
+
+ def test_accept_excess_bytes_after_body(self):
+ # The excess bytes look like the start of another request.
+ protocol = self.build_protocol_waiting_for_body()
+ protocol.accept_bytes('7\nabcdefgdone\n2\x01')
+ self.assertTrue(self.end_received)
+ self.assertEqual("2\x01", protocol.excess_buffer)
+ self.assertEqual("", protocol.in_buffer)
+ protocol.accept_bytes('Y')
+ self.assertEqual("2\x01Y", protocol.excess_buffer)
+ self.assertEqual("", protocol.in_buffer)
+
+ def test_accept_excess_bytes_after_dispatch(self):
+ out_stream = StringIO()
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(
+ None, out_stream.write)
+ smart_protocol.accept_bytes('hello\n')
+ self.assertEqual("2\x01ok\x012\n", out_stream.getvalue())
+ smart_protocol.accept_bytes('2\x01hel')
+ self.assertEqual("2\x01hel", smart_protocol.excess_buffer)
+ smart_protocol.accept_bytes('lo\n')
+ self.assertEqual("2\x01hello\n", smart_protocol.excess_buffer)
+ self.assertEqual("", smart_protocol.in_buffer)
+
+ def test__send_response_sets_finished_reading(self):
+ smart_protocol = protocol.SmartServerRequestProtocolTwo(
+ None, lambda x: None)
+ self.assertEqual(1, smart_protocol.next_read_size())
+ smart_protocol._send_response(('x',))
+ self.assertEqual(0, smart_protocol.next_read_size())
+
+ def test_query_version(self):
+ """query_version on a SmartClientRequestProtocolTwo should return a number.
+
+ The protocol provides the query_version because the domain level clients
+ may all need to be able to probe for capabilities.
+ """
+ # What we really want to test here is that SmartClientRequestProtocolTwo calls
+ # accept_bytes(tuple_based_encoding_of_hello) and reads and parses the
+ # response of tuple-encoded (ok, 2). Also, separately we should test
+ # the error if the response is a non-understood version.
+ input = StringIO('2\x01ok\x012\n')
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(input, output)
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ self.assertEqual(2, smart_protocol.query_version())
+
+ def assertServerToClientEncoding(self, expected_bytes, expected_tuple,
+ input_tuples):
+ """Assert that each input_tuple serialises as expected_bytes, and the
+ bytes deserialise as expected_tuple.
+ """
+ # check the encoding of the server for all input_tuples matches
+ # expected bytes
+ for input_tuple in input_tuples:
+ server_output = StringIO()
+ server_protocol = protocol.SmartServerRequestProtocolTwo(
+ None, server_output.write)
+ server_protocol._send_response(input_tuple)
+ self.assertEqual(expected_bytes, server_output.getvalue())
+ # check the decoding of the client smart_protocol from expected_bytes:
+ input = StringIO(expected_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(input, output)
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call('foo')
+ self.assertEqual(expected_tuple, smart_protocol.read_response_tuple())
+
+ def test_client_call_empty_response(self):
+ # protocol.call() can get back an empty tuple as a response. This occurs
+ # when the parsed line is an empty line, and results in a tuple with
+ # one element - an empty string.
+ self.assertServerToClientEncoding('2\x01\n', ('', ), [(), ('', )])
+
+ def untest_client_call_three_element_response(self):
+ # protocol.call() can get back tuples of other lengths. A three element
+ # tuple should be unpacked as three strings. (Disabled: no 'test' prefix.)
+ self.assertServerToClientEncoding('2\x01a\x01b\x0134\n', ('a', 'b', '34'),
+ [('a', 'b', '34')])
+
+ def test_client_call_with_body_bytes_uploads(self):
+ # protocol.call_with_body_bytes should length-prefix the bytes onto the
+ # wire.
+ expected_bytes = "2\x01foo\n7\nabcdefgdone\n"
+ input = StringIO("\n")
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(input, output)
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call_with_body_bytes(('foo', ), "abcdefg")
+ self.assertEqual(expected_bytes, output.getvalue())
+
+ def test_client_call_with_body_readv_array(self):
+ # protocol.call_with_body_readv_array should encode the readv array and then
+ # length-prefix the bytes onto the wire.
+ expected_bytes = "2\x01foo\n7\n1,2\n5,6done\n"
+ input = StringIO("\n")
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(input, output)
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call_with_body_readv_array(('foo', ), [(1,2),(5,6)])
+ self.assertEqual(expected_bytes, output.getvalue())
+
+ def test_client_read_body_bytes_all(self):
+ # read_body_bytes should decode the body bytes from the wire into
+ # a response.
+ expected_bytes = "1234567"
+ server_bytes = "2\x01ok\n7\n1234567done\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(input, output)
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ self.assertEqual(expected_bytes, smart_protocol.read_body_bytes())
+
+ def test_client_read_body_bytes_incremental(self):
+ # test reading a few bytes at a time from the body
+ # XXX: possibly we should test dribbling the bytes into the stringio
+ # to make the state machine work harder: however, as we use the
+ # LengthPrefixedBodyDecoder that is already well tested - we can skip
+ # that.
+ expected_bytes = "1234567"
+ server_bytes = "2\x01ok\n7\n1234567done\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(input, output)
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ self.assertEqual(expected_bytes[0:2], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[2:4], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[4:6], smart_protocol.read_body_bytes(2))
+ self.assertEqual(expected_bytes[6], smart_protocol.read_body_bytes())
+
+ def test_client_cancel_read_body_does_not_eat_body_bytes(self):
+ # cancelling the expected body needs to finish the request, but not
+ # read any more bytes.
+ expected_bytes = "1234567"
+ server_bytes = "2\x01ok\n7\n1234567done\n"
+ input = StringIO(server_bytes)
+ output = StringIO()
+ client_medium = medium.SmartSimplePipesClientMedium(input, output)
+ request = client_medium.get_request()
+ smart_protocol = protocol.SmartClientRequestProtocolTwo(request)
+ smart_protocol.call('foo')
+ smart_protocol.read_response_tuple(True)
+ smart_protocol.cancel_read_body()
+ self.assertEqual(5, input.tell())  # len("2\x01ok\n"): marker + tuple only
+ self.assertRaises(
+ errors.ReadingCompleted, smart_protocol.read_body_bytes)
+
+
class TestSmartClientUnicode(tests.TestCase):
"""_SmartClient tests for unicode arguments.
@@ -1611,7 +1912,7 @@
self.addCleanup(http_server.tearDown)
post_body = 'hello\n'
- expected_reply_body = 'ok\x011\n'
+ expected_reply_body = 'ok\x012\n'
http_transport = get_transport(http_server.get_url())
medium = http_transport.get_smart_medium()
@@ -1632,7 +1933,7 @@
self.transport_readonly_server = HTTPServerWithSmarts
post_body = 'hello\n'
- expected_reply_body = 'ok\x011\n'
+ expected_reply_body = 'ok\x012\n'
smart_server_url = self.get_readonly_url('.bzr/smart')
reply = urllib2.urlopen(smart_server_url, post_body).read()
@@ -1658,7 +1959,7 @@
response = socket.writefile.getvalue()
self.assertStartsWith(response, 'HTTP/1.0 200 ')
# This includes the end of the HTTP headers, and all the body.
- expected_end_of_response = '\r\n\r\nok\x011\n'
+ expected_end_of_response = '\r\n\r\nok\x012\n'
self.assertEndsWith(response, expected_end_of_response)
More information about the bazaar-commits
mailing list