Rev 2207: Initial support for obtaining a graph directly from the repository using a repository-supplied description of the graph, in file:///home/robertc/source/baz/hpss-remotegraph/

Robert Collins robertc at robertcollins.net
Fri Apr 13 08:40:37 BST 2007


At file:///home/robertc/source/baz/hpss-remotegraph/

------------------------------------------------------------
revno: 2207
revision-id: robertc at robertcollins.net-20070413073942-qvb37falcu2b5dr9
parent: robertc at robertcollins.net-20070409221942-d0t2lemx16nyn2ro
committer: Robert Collins <robertc at robertcollins.net>
branch nick: hpss-remotegraph
timestamp: Fri 2007-04-13 17:39:42 +1000
message:
  Initial support for obtaining a graph directly from the repository using a
  repository-supplied description of the graph. This is the output format
  that smart-negotiation routines will use.
modified:
  bzrlib/graph.py                graph.py-20050905070950-b47dce53236c5e48
  bzrlib/remote.py               remote.py-20060720103555-yeeg2x51vn0rbtdp-1
  bzrlib/smart/repository.py     repository.py-20061128022038-vr5wy5bubyb8xttk-1
  bzrlib/smart/request.py        request.py-20061108095550-gunadhxmzkdjfeek-1
  bzrlib/tests/test_remote.py    test_remote.py-20060720103555-yeeg2x51vn0rbtdp-2
  bzrlib/tests/test_smart.py     test_smart.py-20061122024551-ol0l0o0oofsu9b3t-2
=== modified file 'bzrlib/graph.py'
--- a/bzrlib/graph.py	2007-04-03 07:43:02 +0000
+++ b/bzrlib/graph.py	2007-04-13 07:39:42 +0000
@@ -120,11 +120,16 @@
         self._graph_descendants = None
 
     def __eq__(self, other):
+        """Two Graphs are equal when have the same contents."""
         if not isinstance(other, Graph):
             return False
-        if self._graph_descendants and not other._graph_descendants:
-            self._generate_descendants()
-        return self.__dict__ == other.__dict__
+        if self.get_ancestors() != other.get_ancestors():
+            return False
+        if self.ghosts != other.ghosts:
+            return False
+        if self.get_descendants() != other.get_descendants():
+            return False
+        return True
 
     def add_ghost(self, node_id):
         """Add a ghost to the graph."""

=== modified file 'bzrlib/remote.py'
--- a/bzrlib/remote.py	2007-04-05 09:35:26 +0000
+++ b/bzrlib/remote.py	2007-04-13 07:39:42 +0000
@@ -26,6 +26,7 @@
 from bzrlib.config import BranchConfig, TreeConfig
 from bzrlib.decorators import needs_read_lock, needs_write_lock
 from bzrlib.errors import NoSuchRevision
+from bzrlib.graph import Graph
 from bzrlib.lockable_files import LockableFiles
 from bzrlib.revision import NULL_REVISION
 from bzrlib.smart import client, vfs
@@ -257,7 +258,7 @@
         elif revision_id == NULL_REVISION:
             return {}
 
-        path = self.bzrdir._path_for_remote_call(self._client)
+        path = self._path_for_remote_call()
         assert type(revision_id) is str
         response = self._client.call2(
             'Repository.get_revision_graph', path, revision_id)
@@ -285,14 +286,14 @@
         if revision_id is None:
             # The null revision is always present.
             return True
-        path = self.bzrdir._path_for_remote_call(self._client)
+        path = self._path_for_remote_call()
         response = self._client.call('Repository.has_revision', path, revision_id)
         assert response[0] in ('ok', 'no'), 'unexpected response code %s' % (response,)
         return response[0] == 'ok'
 
     def gather_stats(self, revid=None, committers=None):
         """See Repository.gather_stats()."""
-        path = self.bzrdir._path_for_remote_call(self._client)
+        path = self._path_for_remote_call()
         if revid in (None, NULL_REVISION):
             fmt_revid = ''
         else:
@@ -326,7 +327,7 @@
 
     def is_shared(self):
         """See Repository.is_shared()."""
-        path = self.bzrdir._path_for_remote_call(self._client)
+        path = self._path_for_remote_call()
         response = self._client.call('Repository.is_shared', path)
         assert response[0] in ('yes', 'no'), 'unexpected response code %s' % (response,)
         return response[0] == 'yes'
@@ -342,7 +343,7 @@
             self._lock_count += 1
 
     def _remote_lock_write(self, token):
-        path = self.bzrdir._path_for_remote_call(self._client)
+        path = self._path_for_remote_call()
         if token is None:
             token = ''
         response = self._client.call('Repository.lock_write', path, token)
@@ -377,6 +378,9 @@
     def leave_lock_in_place(self):
         self._leave_lock = True
 
+    def _path_for_remote_call(self):
+        return self.bzrdir._path_for_remote_call(self._client)
+
     def dont_leave_lock_in_place(self):
         self._leave_lock = False
 
@@ -396,7 +400,7 @@
             self._real_repository.lock_read()
 
     def _unlock(self, token):
-        path = self.bzrdir._path_for_remote_call(self._client)
+        path = self._path_for_remote_call()
         response = self._client.call('Repository.unlock', path, token)
         if response == ('ok',):
             return
@@ -962,3 +966,66 @@
             self._branch_data_config = TreeConfig(self.branch._real_branch)
         return self._branch_data_config
 
+
+class RemoteGraph(Graph):
+    """A graph whose data is held by the remote server.
+    
+    :var name: The name of the graph.
+    """
+
+    def __init__(self, repository, name, remote_memo):
+        """Construct a remote graph on repository.
+        
+        :param repository: A RemoteRepository.
+        :param name: The graphs name in the repository.
+        :param remote_memo: The memoised graph in the repository.
+        """
+        Graph.__init__(self)
+        self.repository = repository
+        self.name = name
+        self.remote_memo = remote_memo
+        self._loaded = False
+
+    def get_ancestors(self):
+        if not self._loaded:
+            self._load()
+        return dict(self._graph_ancestors.items())
+
+    def _load(self):
+        """Load the graph from the repository."""
+        path = self.repository._path_for_remote_call()
+        response, protocol = self.repository._client.call2(
+            'Repository.get_graph_content', path, self.name, self.remote_memo)
+        # should this really read unconditionally ?
+        coded_content = protocol.read_body_bytes()
+        if response[0] == 'NoSuchFile':
+            raise errors.NoSuchFile(self, self.name)
+        else:
+            if coded_content == '':
+                self._graph_ancestors = {}
+                self.ghosts = set()
+            else:
+                # wire protocol for this verb:
+                # node_id := UTF8TEXT
+                # parent_id := node_id
+                # ghost_id := node_id
+                # ghostlist := ghost_id*
+                # nodeline := node_id parent_id*\n 
+                #
+                # content_content = ghostlist '\0' nodeline*
+                ghosts, nodes = coded_content.split('\0')
+                if ghosts:
+                    self.ghosts = set(ghosts.split(' '))
+                if nodes:
+                    lines = nodes.split('\n')
+                    revision_graph = {}
+                    for line in lines:
+                        node_ids = list(line.split())
+                        self.add_node(node_ids[0], node_ids[1:])
+        self._loaded = True
+
+    def __repr__(self):
+        return "RemoteGraph repository=%s name=%s ancestors=%s, "\
+            "descendants=%r, ghosts=%r" % (
+            self.repository, self.name, self._graph_ancestors,
+            self._graph_descendants, self.ghosts)

=== modified file 'bzrlib/smart/repository.py'
--- a/bzrlib/smart/repository.py	2007-03-13 05:52:01 +0000
+++ b/bzrlib/smart/repository.py	2007-04-13 07:39:42 +0000
@@ -17,7 +17,7 @@
 """Server-side repository related request implmentations."""
 
 
-from bzrlib import errors
+from bzrlib import errors, graph
 from bzrlib.bzrdir import BzrDir
 from bzrlib.smart.request import SmartServerRequest, SmartServerResponse
 
@@ -41,6 +41,51 @@
         return self.do_repository_request(repository, *args)
 
 
+class SmartServerRepositoryGetGraphContent(SmartServerRepositoryRequest):
+    
+    def do_repository_request(self, repository, graph_name, graph_description):
+        """Return the contents of a graph previously promised to the client.
+
+        :param repository: The repository to query in.
+        :param graph_name: The subgraph from the repository to retrieve.
+            Currently only 'revisions' is supported.
+        :param graph_description: A description of the graph previously sent
+            by this SmartRespository to the client : it is opaque to the client.
+        :return: A smart server response where the body contains the graph
+            content encoded:
+                # wire protocol for this verb:
+                # node_id := UTF8TEXT
+                # parent_id := node_id
+                # ghost_id := node_id
+                # ghostlist := ghost_id*
+                # nodeline := node_id parent_id*\n 
+                #
+                # content_content = ghostlist '\0' nodeline*
+        """
+        if graph_name != 'revisions':
+            return SmartServerResponse(('NoSuchFile', graph_name))
+        description = GraphDescription.deserialise(graph_description)
+        # hmm, how to factor this cleanly into the GraphDescriptionType
+        # API ? We dont want to call get_revision_graph_with_ghosts 
+        # unnecessarily - we can use the heads list from the 
+        # delta *when* the description is a delta to short circuit.
+        assert isinstance(description, GraphDeltaDescription), \
+            'Expected a GraphDeltaDescription.'
+        full_graph = repository.get_revision_graph_with_ghosts(
+            description.delta.heads)
+        revision_graph = description.delta.apply_to(full_graph)
+        # now encode the contents of revision graph:
+        # ghosts first
+        ghosts = ' '.join(revision_graph.ghosts)
+        node_lines = []
+        for revision, parents in revision_graph.get_ancestors().items():
+            node_lines.append(' '.join([revision,] + parents))
+        encoded_graph = '\0'.join((ghosts, '\n'.join(node_lines)))
+
+        return SmartServerResponse(('ok', ), encoded_graph)
+
+
+
 class SmartServerRepositoryGetRevisionGraph(SmartServerRepositoryRequest):
     
     def do_repository_request(self, repository, revision_id):
@@ -173,3 +218,77 @@
         repository.unlock()
         return SmartServerResponse(('ok',))
 
+
+# XXX: Should this be in graph.py? It is smart server only, but ...
+GraphDescriptionTypes = {}
+
+
+class GraphDescription(object):
+    """A description of a graph for sending to a client and reconstructing later."""
+
+    @staticmethod
+    def deserialise(encoded_description):
+        """Decode encoded_description into a GraphDescription."""
+        type_string, encoded_details = encoded_description.split('\0', 1)
+        for klass, registered_type_string in GraphDescriptionTypes.items():
+            if type_string == registered_type_string:
+                result = klass()
+                result.decode(encoded_details)
+                return result
+        raise KeyError(type_string)
+
+    def serialise(self):
+        """Serialise this GraphDescription.
+
+        Serialisation consists of writing the type string, a \0, and then the
+        output of self.encode().
+        """
+        return '\0'.join([self.type_string(), self.encode()])
+
+    def type_string(self):
+        """Return a string that represents the type of this graph description."""
+        return GraphDescriptionTypes[self.__class__]
+
+
+class GraphDeltaDescription(GraphDescription):
+    """A GraphDescription of a graph using a GraphDelta to model the graph.
+    
+    This uses the GraphDelta to describe the graph in terms of a super graph
+    held by the server.
+    """
+
+    def __init__(self, full_graph=None, delta=None):
+        """Constructor.
+
+        :param full_graph: The full graph the delta is to be applied to.
+        :param delta: A graph delta used to obtain the desired final graph
+            from full_graph.
+        """
+        self.full_graph = full_graph
+        self.delta = delta
+
+    def decode(self, coded_content):
+        """Decode coded_content into a GraphDelta.
+
+        :param coded_content: The content as encoded by self.encode.
+        """
+        self.delta = graph.GraphDelta()
+        heads, cut_from = coded_content.split('\0', 1)
+        if heads:
+            self.delta.heads = set(heads.split(' '))
+        if cut_from:
+            self.delta.cut_from = set(cut_from.split(' '))
+
+    def encode(self):
+        """Encodes the graph delta for transmission.
+
+        If the delta does not have a heads list, one is obtained from the
+        fullgraph, to allow the fullgraph to grow without confusing things.
+        """
+        if not self.delta.heads:
+            self.delta.heads = self.full_graph.get_heads()
+        return ' '.join(self.delta.heads) + '\0' + ' '.join(self.delta.cut_from)
+
+GraphDescriptionTypes[GraphDeltaDescription] = 'Delta'
+
+

=== modified file 'bzrlib/smart/request.py'
--- a/bzrlib/smart/request.py	2007-03-13 05:52:01 +0000
+++ b/bzrlib/smart/request.py	2007-04-13 07:39:42 +0000
@@ -292,6 +292,8 @@
                                'bzrlib.smart.repository',
                                'SmartServerRepositoryGatherStats')
 request_handlers.register_lazy(
+    'Repository.get_graph_content', 'bzrlib.smart.repository', 'SmartServerRepositoryGetGraphContent')
+request_handlers.register_lazy(
     'Repository.get_revision_graph', 'bzrlib.smart.repository', 'SmartServerRepositoryGetRevisionGraph')
 request_handlers.register_lazy(
     'Repository.has_revision', 'bzrlib.smart.repository', 'SmartServerRequestHasRevision')

=== modified file 'bzrlib/tests/test_remote.py'
--- a/bzrlib/tests/test_remote.py	2007-04-05 09:35:26 +0000
+++ b/bzrlib/tests/test_remote.py	2007-04-13 07:39:42 +0000
@@ -31,10 +31,12 @@
     )
 from bzrlib.branch import Branch
 from bzrlib.bzrdir import BzrDir, BzrDirFormat
+from bzrlib.graph import Graph
 from bzrlib.remote import (
     RemoteBranch,
     RemoteBzrDir,
     RemoteBzrDirFormat,
+    RemoteGraph,
     RemoteRepository,
     )
 from bzrlib.revision import NULL_REVISION
@@ -600,3 +602,97 @@
 
         # The remote repo shouldn't be accessed.
         self.assertEqual([], client._calls)
+
+
+class TestRemoteGraph(TestRemoteRepository):
+    """Tests for RemoteGraph.
+
+    RemoteGraphs represent a named graph in a remote repository. As such there
+    are issues with server state changing, and with graphs representing
+    absent remote graphs.
+    """
+
+    def test_get_ancestors_no_graph(self):
+        """When there is no graph 'name', any operation will fail.
+
+        This error is raised by the current Repository stores, so seems a 
+        reasonable fit.
+        """
+        responses = [(('NoSuchFile', 'foo'), '')]
+        transport_path = 'quack'
+        repo, client = self.setup_fake_client_and_repository(
+            responses, transport_path)
+        graph = RemoteGraph(repo, 'name', '')
+        self.assertRaises(errors.NoSuchFile, graph.get_ancestors)
+        self.assertEqual(
+            [('call2', 'Repository.get_graph_content', ('///quack/', 'name', ''))],
+            client._calls)
+
+    def test_empty_graph(self):
+        """When the graph is completely empty, it should still load once."""
+        responses = [(('ok', ), '\0')]
+        transport_path = 'quack'
+        repo, client = self.setup_fake_client_and_repository(
+            responses, transport_path)
+        graph = RemoteGraph(repo, 'name', '')
+        result_graph = Graph()
+        self.assertEqual(result_graph, graph)
+        # and access the graph again to ensure it cached the response
+        graph.get_ancestors()
+        graph.get_descendants()
+        self.assertEqual(
+            [('call2', 'Repository.get_graph_content', ('///quack/', 'name', ''))],
+            client._calls)
+
+    def test_ghosts_only_graph(self):
+        """When there are only ghosts, it should parse ok."""
+        responses = [(('ok', ), 'a_ghost\0')]
+        transport_path = 'quack'
+        repo, client = self.setup_fake_client_and_repository(
+            responses, transport_path)
+        graph = RemoteGraph(repo, 'name', '')
+        result_graph = Graph()
+        result_graph.add_ghost('a_ghost')
+        self.assertEqual(result_graph, graph)
+        # and access the graph again to ensure it cached the response
+        graph.get_ancestors()
+        self.assertEqual(
+            [('call2', 'Repository.get_graph_content', ('///quack/', 'name', ''))],
+            client._calls)
+
+    def test_no_ghosts_trivial_graph(self):
+        """When there are no ghosts, a trivial graph with one node - parses."""
+        responses = [(('ok', ), '\0node_id')]
+        transport_path = 'quack'
+        repo, client = self.setup_fake_client_and_repository(
+            responses, transport_path)
+        graph = RemoteGraph(repo, 'name', '')
+        result_graph = Graph()
+        result_graph.add_node('node_id', [])
+        self.assertEqual(result_graph, graph)
+        # and access the graph again to ensure it cached the response
+        graph.get_ancestors()
+        self.assertEqual(
+            [('call2', 'Repository.get_graph_content', ('///quack/', 'name', ''))],
+            client._calls)
+
+    def test_ghosts_and_nodes(self):
+        """Graphs with ghosts and nodes should parse ok."""
+        responses = [(('ok', ),
+            'ghost1 ghost2\0node_id1 parent1\n'
+            'node_id2 parent1 parent2')]
+        transport_path = 'quack'
+        repo, client = self.setup_fake_client_and_repository(
+            responses, transport_path)
+        graph = RemoteGraph(repo, 'name', '')
+        result_graph = Graph()
+        result_graph.add_ghost('ghost1')
+        result_graph.add_ghost('ghost2')
+        result_graph.add_node('node_id1', ['parent1'])
+        result_graph.add_node('node_id2', ['parent1', 'parent2'])
+        self.assertEqual(result_graph, graph)
+        # and access the graph again to ensure it cached the response
+        graph.get_ancestors()
+        self.assertEqual(
+            [('call2', 'Repository.get_graph_content', ('///quack/', 'name', ''))],
+            client._calls)

=== modified file 'bzrlib/tests/test_smart.py'
--- a/bzrlib/tests/test_smart.py	2007-04-05 09:35:26 +0000
+++ b/bzrlib/tests/test_smart.py	2007-04-13 07:39:42 +0000
@@ -17,11 +17,16 @@
 """Tests for the smart wire/domain protococl."""
 
 from bzrlib import bzrdir, errors, smart, tests
+from bzrlib.graph import Graph, GraphDelta
 from bzrlib.smart.request import SmartServerResponse
 import bzrlib.smart.bzrdir
 import bzrlib.smart.branch
 import bzrlib.smart.repository
-
+from bzrlib.smart.repository import (
+    GraphDeltaDescription,
+    GraphDescription,
+    GraphDescriptionTypes,
+    )
 
 class TestCaseWithSmartMedium(tests.TestCaseWithTransport):
 
@@ -725,6 +730,61 @@
             SmartServerResponse(('TokenMismatch',)), response)
 
 
+class TestSmartServerRepositoryGetGraphContent(tests.TestCaseWithTransport):
+
+    def test_get_graph_content_bad_graph_name(self):
+        """When a bad graph name is asked for, NoSuchFile is returned."""
+        backing = self.get_transport()
+        request = smart.repository.SmartServerRepositoryGetGraphContent(backing)
+        repository = self.make_repository('.')
+        # make a request for the repository with no graph name.
+        response = request.execute(backing.local_abspath(''), '', '')
+        self.assertEqual(
+            SmartServerResponse(('NoSuchFile', '')), response)
+
+    def test_get_graph_content_empty_graph(self):
+        """When a graph with no content is asked for, its encoded correctly."""
+        backing = self.get_transport()
+        request = smart.repository.SmartServerRepositoryGetGraphContent(backing)
+        repository = self.make_repository('.')
+        # make a request for the entire repository revision graph.
+        # get a description to ask for this
+        description = GraphDeltaDescription(
+            repository.get_revision_graph_with_ghosts(),
+            GraphDelta()).serialise()
+        response = request.execute(backing.local_abspath(''), 'revisions', description)
+        self.assertEqual(
+            SmartServerResponse(('ok', ), '\0'), response)
+
+    def test_get_graph_content_normal_graph(self):
+        """When a regular graph is asked for, its encoded correctly."""
+        backing = self.get_transport()
+        request = smart.repository.SmartServerRepositoryGetGraphContent(backing)
+        tree = self.make_branch_and_memory_tree('.')
+        tree.lock_write()
+        tree.add('')
+        r1 = tree.commit('1st commit')
+        r2 = tree.commit('2nd commit', rev_id=u'\xc8'.encode('utf8'))
+        tree.unlock()
+        repository = tree.branch.repository
+        # make a request for the entire repository revision graph.
+        # get a description to ask for this
+        description = GraphDeltaDescription(
+            repository.get_revision_graph_with_ghosts(),
+            GraphDelta()).serialise()
+        response = request.execute(backing.local_abspath(''), 'revisions',
+            description)
+        # encoding order doesn't matter, so we can safely use the dict here:
+        # the ordering we get will match the server logic within this test.
+        # the client is already tested for order-safeness.
+        expected_lines = []
+        for revision, parents in {r2:[r1], r1:[]}.items():
+            expected_lines.append(' '.join([revision,] + parents))
+        expected_body = '\0' + '\n'.join(expected_lines)
+        self.assertEqual(
+            SmartServerResponse(('ok', ), expected_body), response)
+
+
 class TestSmartServerIsReadonly(tests.TestCaseWithTransport):
 
     def test_is_readonly_no(self):
@@ -778,6 +838,9 @@
             smart.request.request_handlers.get('Repository.gather_stats'),
             smart.repository.SmartServerRepositoryGatherStats)
         self.assertEqual(
+            smart.request.request_handlers.get('Repository.get_graph_content'),
+            smart.repository.SmartServerRepositoryGetGraphContent)
+        self.assertEqual(
             smart.request.request_handlers.get('Repository.get_revision_graph'),
             smart.repository.SmartServerRepositoryGetRevisionGraph)
         self.assertEqual(
@@ -795,3 +858,151 @@
         self.assertEqual(
             smart.request.request_handlers.get('Transport.is_readonly'),
             smart.request.SmartServerIsReadonly)
+
+
+class SampleGraphDescription(GraphDescription):
+    """A sample GraphDescription for testing common logic."""
+
+    def __init__(self):
+        self.calls = []
+
+    def decode(self, encoded_content):
+        self.calls.append(('decode',))
+        self.coded_content = encoded_content
+
+    def encode(self):
+        self.calls.append(('encode',))
+        return 'encoded'
+
+
+class TestGraphDescription(tests.TestCase):
+    """Tests for the GraphDescription coding logic."""
+
+    def setUp(self):
+        GraphDescriptionTypes[SampleGraphDescription] = 'sample'
+        def removeSample():
+            del GraphDescriptionTypes[SampleGraphDescription]
+        self.addCleanup(removeSample)
+
+    def test_construct(self):
+        # A GraphDescription has no args to the cosntructor
+        # This is because it is intended to be subclassed.
+        GraphDescription()
+
+    def test_type_string_bogus(self):
+        # A GraphDescription subclass that is not registered in
+        # GraphDescriptionTypes causes a KeyError during type_string()
+        class UnregisteredGraphDescription(GraphDescription):
+            pass
+        self.assertRaises(KeyError,
+            UnregisteredGraphDescription().type_string)
+
+    def test_type_string_registered(self):
+        # A GraphDescription subclass that is registered in
+        # GraphDescriptionTypes gets the value back.
+        class RegisteredGraphDescription(GraphDescription):
+            pass
+        GraphDescriptionTypes[RegisteredGraphDescription] = 'value'
+        try:
+            self.assertEqual('value', RegisteredGraphDescription().type_string())
+        except:
+            del GraphDescriptionTypes[RegisteredGraphDescription]
+
+    def test_serialise(self):
+        # A GraphDescription object has its 'encode' method called when it
+        # needs to be serialised to a byte string. This is appended after
+        # the type string by the base class.
+        sample = SampleGraphDescription()
+        self.assertEqual('sample\0encoded', sample.serialise())
+        self.assertEqual([('encode',)], sample.calls)
+
+    def test_deserialise_bad_type_prefix(self):
+        # deserialisation when a bad type prefix is given raises KeyError.
+        self.assertRaises(KeyError, GraphDescription.deserialise, 'unregistered\0')
+
+    def test_deserialise_calls_decode(self):
+        # deserialisation calls decode on a fresh instance of the type it
+        # found.
+        result = GraphDescription.deserialise('sample\0something here')
+        self.assertEqual([('decode', )], result.calls)
+        self.assertEqual('something here', result.coded_content)
+
+
+class TestGraphDeltaDescription(tests.TestCase):
+    """Tests for the GraphDeltaDescription functionality."""
+
+    def test_registered(self):
+        self.assertSubset([GraphDeltaDescription], GraphDescriptionTypes)
+
+    def test_construct(self):
+        # GraphDeltaDescription can be constructed by default:
+        GraphDeltaDescription()
+        # or with a graph and delta
+        description = GraphDeltaDescription('a', 'b')
+        self.assertEqual('a', description.full_graph)
+        self.assertEqual('b', description.delta)
+
+    def get_vicious_graph(self):
+        """Create a graph with every feature that we can encounter.
+
+        That is multiple heads, multiple tails, ghosts, and a disjoint side
+        graph.
+        """
+        full_graph = Graph()
+        full_graph.add_ghost('a-ghost-1')
+        full_graph.add_ghost('de-second-ghost')
+        full_graph.add_node('head-1', ['common-child', 'unique-child-1'])
+        full_graph.add_node('head-2', ['unique-child-2', 'common-child'])
+        full_graph.add_node('common-child', [])
+        full_graph.add_node('unique-child-1', [])
+        full_graph.add_node('unique-child-2', [])
+        full_graph.add_node('disjoint-head', ['disjoint-child'])
+        full_graph.add_node('disjoint-child', [])
+        return full_graph
+
+    def test_encode_everything(self):
+        # A smoke test - encode with a full graph having multiple heads,
+        # multiple tails, ghosts, and a disjoint side graph.
+        full_graph = self.get_vicious_graph()
+        delta = GraphDelta()
+        description = GraphDeltaDescription(full_graph, delta)
+        encoded = description.encode()
+        # the encoding will always ensure the heads are listed, to 
+        # prevent race conditions with concurrent commits. As the delta
+        # was empty, the full_graph heads are used, and there are no cut_from
+        # elements.
+        expected_result = ' '.join(full_graph.get_heads()) + '\0'
+        self.assertEqual(expected_result, encoded)
+
+    def test_encode_subgraph(self):
+        # Encode two subgraphs from a full graph having multiple heads,
+        # multiple tails, ghosts, and a disjoint side graph.
+        # 
+        # this should encode only the necessary heads and the cut points.
+        full_graph = self.get_vicious_graph()
+        delta = GraphDelta()
+        delta.heads = set(['de-second-ghost', 'disjoint-head', 'head-1', 'unique-child-2'])
+        delta.cut_from = set(['common-child', 'disjoint-child'])
+        description = GraphDeltaDescription(full_graph, delta)
+        # the encoding should not alter the delta's heads list,
+        # and should encode the cut_from list.
+        expected_result = ' '.join(delta.heads) + '\0' + ' '.join(delta.cut_from)
+        encoded = description.encode()
+        self.assertEqual(expected_result, encoded)
+
+    def test_decode_empty(self):
+        # decoding puts the decoded elements into a graph delta.
+        expected_delta = GraphDelta()
+        description = GraphDeltaDescription()
+        description.decode('\0')
+        self.assertEqual(expected_delta, description.delta)
+
+    def test_decode_non_empty(self):
+        # decoding puts the heads and cutfrom elements into a graph delta.
+        expected_delta = GraphDelta()
+        expected_delta.heads = set(['foo', 'bar', 'gam'])
+        expected_delta.cut_from = set(['quux', 'frazzle'])
+        description = GraphDeltaDescription()
+        description.decode('foo bar gam\0quux frazzle')
+        self.assertEqual(expected_delta, description.delta)
+



More information about the bazaar-commits mailing list