Rev 103: Override _clear_cached_state to destroy the Querier object. in http://bzr.arbash-meinel.com/branches/bzr/history_db/trunk

John Arbash Meinel john at arbash-meinel.com
Wed Apr 21 21:29:22 BST 2010


At http://bzr.arbash-meinel.com/branches/bzr/history_db/trunk

------------------------------------------------------------
revno: 103
revision-id: john at arbash-meinel.com-20100421202914-7n1bxltjw0e42bdf
parent: john at arbash-meinel.com-20100421201704-j04x52cdkyvargz9
committer: John Arbash Meinel <john at arbash-meinel.com>
branch nick: trunk
timestamp: Wed 2010-04-21 15:29:14 -0500
message:
  Override _clear_cached_state to destroy the Querier object.
  
  This lets the test suite clean up normally: tests no longer need to close
  the Querier's database connection by hand.
-------------- next part --------------
=== modified file '__init__.py'
--- a/__init__.py	2010-04-21 20:17:04 +0000
+++ b/__init__.py	2010-04-21 20:29:14 +0000
@@ -343,6 +343,8 @@
     '_do_revision_id_to_dotted_revno', None)
 _orig_iter_merge_sorted = getattr(branch.Branch,
     'iter_merge_sorted_revisions', None)
+_orig_clear_cached_state = getattr(branch.Branch,
+    '_clear_cached_state', None)
 
 
 def _get_history_db_path(a_branch):
@@ -380,6 +382,21 @@
     return query


+def _history_db_clear_cached_state(a_branch):
+    """Close and drop a_branch's cached Querier, then call the original.
+
+    Querier now opens its sqlite connection lazily, so _db_conn is still
+    None if the querier was created but never ran a query; guard for that.
+    """
+    query = getattr(a_branch, '_history_db_querier', _singleton)
+    if query is not _singleton:
+        if query is not None and query._db_conn is not None:
+            query._db_conn.close()
+        del a_branch._history_db_querier
+    if _orig_clear_cached_state is not None:
+        return _orig_clear_cached_state(a_branch)
+
+
 def _history_db_iter_merge_sorted_revisions(self, start_revision_id=None,
     stop_revision_id=None, stop_rule='exclude', direction='reverse'):
     """See Branch.iter_merge_sorted_revisions()
@@ -524,6 +541,8 @@
         _history_db_revision_id_to_dotted_revno
     branch.Branch.iter_merge_sorted_revisions = \
         _history_db_iter_merge_sorted_revisions
+    branch.Branch._clear_cached_state = \
+        _history_db_clear_cached_state
     branch.Branch.hooks.install_named_hook('post_change_branch_tip',
         _history_db_post_change_branch_tip_hook, 'history_db')

=== modified file 'history_db.py'
--- a/history_db.py	2010-04-15 17:36:53 +0000
+++ b/history_db.py	2010-04-21 20:29:14 +0000
@@ -1164,23 +1164,32 @@
     """Perform queries on an existing history db."""
 
     def __init__(self, db_path, a_branch):
-        db_conn = dbapi2.connect(db_path)
-        self._db_conn = db_conn
-        self._cursor = self._db_conn.cursor()
+        self._db_path = db_path
+        self._db_conn = None
+        self._cursor = None
         self._branch = a_branch
         self._branch_tip_rev_id = a_branch.last_revision()
         self._stats = defaultdict(lambda: 0)
 
+    def _get_cursor(self):
+        if self._cursor is not None:
+            return self._cursor
+        db_conn = dbapi2.connect(self._db_path)
+        self._db_conn = db_conn
+        self._cursor = self._db_conn.cursor()
+        return self._cursor
+
     def _get_db_id(self, revision_id):
-        db_res = self._cursor.execute('SELECT db_id FROM revision'
-                                      ' WHERE revision_id = ?',
-                                      [revision_id]).fetchone()
+        db_res = self._get_cursor().execute(
+            'SELECT db_id FROM revision'
+            ' WHERE revision_id = ?',
+            [revision_id]).fetchone()
         if db_res is None:
             return None
         return db_res[0]
 
     def _get_lh_parent_rev_id(self, revision_id):
-        parent_res = self._cursor.execute("""
+        parent_res = self._get_cursor().execute("""
             SELECT p.revision_id
               FROM parent, revision as c, revision as p
              WHERE parent.child = c.db_id
@@ -1194,7 +1203,7 @@
         return parent_res[0]
 
     def _get_lh_parent_db_id(self, revision_db_id):
-        parent_res = self._cursor.execute("""
+        parent_res = self._get_cursor().execute("""
             SELECT parent.parent
               FROM parent
              WHERE parent.child = ?
@@ -1207,7 +1216,7 @@
 
     def _get_possible_dotted_revno(self, tip_revision_id, merged_revision_id):
         """Given a possible tip revision, try to determine the dotted revno."""
-        revno = self._cursor.execute("""
+        revno = self._get_cursor().execute("""
             SELECT revno FROM dotted_revno, revision t, revision m
              WHERE t.revision_id = ?
                AND t.db_id = dotted_revno.tip_revision
@@ -1222,7 +1231,7 @@
 
     def _get_possible_dotted_revno_db_id(self, tip_db_id, merged_db_id):
         """Get a dotted revno if we have it."""
-        revno = self._cursor.execute("""
+        revno = self._get_cursor().execute("""
             SELECT revno FROM dotted_revno
              WHERE tip_revision = ?
                AND merged_revision = ?
@@ -1254,7 +1263,7 @@
         t = time.time()
         rev_id_to_db_id = {}
         db_id_to_rev_id = {}
-        schema.ensure_revisions(self._cursor,
+        schema.ensure_revisions(self._get_cursor(),
                                 [revision_id, self._branch_tip_rev_id],
                                 rev_id_to_db_id, db_id_to_rev_id, graph=None)
         tip_db_id = rev_id_to_db_id[self._branch_tip_rev_id]
@@ -1277,21 +1286,21 @@
         revno = None
         while tip_db_id is not None:
             self._stats['num_steps'] += 1
-            range_res = self._cursor.execute(
+            range_res = self._get_cursor().execute(
                 "SELECT pkey, tail"
                 "  FROM mainline_parent_range"
                 " WHERE head = ?"
                 " ORDER BY count DESC LIMIT 1",
                 (tip_db_id,)).fetchone()
             if range_res is None:
-                revno_res = self._cursor.execute(
+                revno_res = self._get_cursor().execute(
                     "SELECT revno FROM dotted_revno"
                     " WHERE tip_revision = ? AND merged_revision = ?",
                     (tip_db_id, rev_db_id)).fetchone()
                 next_db_id = self._get_lh_parent_db_id(tip_db_id)
             else:
                 pkey, next_db_id = range_res
-                revno_res = self._cursor.execute(
+                revno_res = self._get_cursor().execute(
                     "SELECT revno FROM dotted_revno, mainline_parent"
                     " WHERE tip_revision = mainline_parent.revision"
                     "   AND mainline_parent.range = ?"
@@ -1307,6 +1316,7 @@
 
     def get_dotted_revno_range_multi(self, revision_ids):
         """Determine the dotted revno, using the range info, etc."""
+        cursor = self._get_cursor()
         t = time.time()
         tip_db_id = self._get_db_id(self._branch_tip_rev_id)
         db_ids = set()
@@ -1320,14 +1330,14 @@
         revnos = {}
         while tip_db_id is not None and db_ids:
             self._stats['num_steps'] += 1
-            range_res = self._cursor.execute(
+            range_res = cursor.execute(
                 "SELECT pkey, tail"
                 "  FROM mainline_parent_range"
                 " WHERE head = ?"
                 " ORDER BY count DESC LIMIT 1",
                 (tip_db_id,)).fetchone()
             if range_res is None:
-                revno_res = self._cursor.execute(_add_n_params(
+                revno_res = cursor.execute(_add_n_params(
                     "SELECT merged_revision, revno FROM dotted_revno"
                     " WHERE tip_revision = ?"
                     "   AND merged_revision IN (%s)",
@@ -1336,7 +1346,7 @@
                 next_db_id = self._get_lh_parent_db_id(tip_db_id)
             else:
                 pkey, next_db_id = range_res
-                revno_res = self._cursor.execute(_add_n_params(
+                revno_res = cursor.execute(_add_n_params(
                     "SELECT merged_revision, revno"
                     "  FROM dotted_revno, mainline_parent"
                     " WHERE tip_revision = mainline_parent.revision"
@@ -1360,16 +1370,17 @@
         #       To indicate that the branch has not been imported yet
         revno_strs = set(['.'.join(map(str, revno)) for revno in revnos])
         revno_map = {}
+        cursor = self._get_cursor()
         while tip_db_id is not None and revno_strs:
             self._stats['num_steps'] += 1
-            range_res = self._cursor.execute(
+            range_res = cursor.execute(
                 "SELECT pkey, tail"
                 "  FROM mainline_parent_range"
                 " WHERE head = ?"
                 " ORDER BY count DESC LIMIT 1",
                 (tip_db_id,)).fetchone()
             if range_res is None:
-                revision_res = self._cursor.execute(_add_n_params(
+                revision_res = cursor.execute(_add_n_params(
                     "SELECT revision_id, revno"
                     "  FROM dotted_revno, revision"
                     " WHERE merged_revision = revision.db_id"
@@ -1379,7 +1390,7 @@
                 next_db_id = self._get_lh_parent_db_id(tip_db_id)
             else:
                 pkey, next_db_id = range_res
-                revision_res = self._cursor.execute(_add_n_params(
+                revision_res = cursor.execute(_add_n_params(
                     "SELECT revision_id, revno"
                     "  FROM dotted_revno, mainline_parent, revision"
                     " WHERE tip_revision = mainline_parent.revision"
@@ -1423,7 +1434,8 @@
         If a range cannot be found, just find the next parent.
         :return: (range_or_None, next_db_id)
         """
-        range_res = self._cursor.execute(
+        cursor = self._get_cursor()
+        range_res = cursor.execute(
             "SELECT pkey, tail"
             "  FROM mainline_parent_range"
             " WHERE head = ?"
@@ -1435,7 +1447,7 @@
         range_key, tail_db_id = range_res
         # TODO: Is ORDER BY dist ASC expensive? We know a priori that the list
         #       is probably already in sorted order, but does sqlite know that?
-        range_db_ids = self._cursor.execute(
+        range_db_ids = cursor.execute(
             "SELECT revision FROM mainline_parent"
             " WHERE range = ? ORDER BY dist ASC",
             (range_key,)).fetchall()
@@ -1465,7 +1477,7 @@
         all = set(remaining)
         while remaining:
             next = remaining.popleft()
-            parents = self._cursor.execute("""
+            parents = self._get_cursor().execute("""
                 SELECT p.revision_id
                   FROM parent, revision p, revision c
                  WHERE parent.child = c.db_id
@@ -1479,7 +1491,6 @@
         return all
 
     def walk_ancestry_db_ids(self):
-        _exec = self._cursor.execute
         all_ancestors = set()
         db_id = self._get_db_id(self._branch_tip_rev_id)
         all_ancestors.add(db_id)
@@ -1488,7 +1499,7 @@
             self._stats['num_steps'] += 1
             next = remaining[:100]
             remaining = remaining[len(next):]
-            res = _exec(_add_n_params(
+            res = self._get_cursor().execute(_add_n_params(
                 "SELECT parent FROM parent WHERE child in (%s)",
                 len(db_ids)), next)
             next_p = [p[0] for p in res if p[0] not in all_ancestors]
@@ -1501,7 +1512,6 @@
         
         Use the mainline_parent_range/mainline_parent table to speed things up.
         """
-        _exec = self._cursor.execute
         # All we are doing is pre-seeding the search with all the mainline
         # revisions, we could probably do more with interleaving calls to
         # mainline with calls to parents but this is easier to write :)
@@ -1513,7 +1523,7 @@
             self._stats['num_steps'] += 1
             next = remaining[:100]
             remaining = remaining[len(next):]
-            res = _exec(_add_n_params(
+            res = self._get_cursor().execute(_add_n_params(
                 "SELECT parent FROM parent WHERE child in (%s)",
                 len(next)), next)
             next_p = [p[0] for p in res if p[0] not in all_ancestors]
@@ -1539,9 +1549,10 @@
         db_id = self._get_db_id(self._branch_tip_rev_id)
         all_ancestors = set()
         t = time.time()
+        cursor = self._get_cursor()
         while db_id is not None:
             self._stats['num_steps'] += 1
-            range_res = self._cursor.execute(
+            range_res = cursor.execute(
                 "SELECT pkey, tail"
                 "  FROM mainline_parent_range"
                 " WHERE head = ?"
@@ -1549,14 +1560,14 @@
                 (db_id,)).fetchone()
             if range_res is None:
                 next_db_id = self._get_lh_parent_db_id(db_id)
-                merged_revs = self._cursor.execute(
+                merged_revs = cursor.execute(
                     "SELECT merged_revision FROM dotted_revno"
                     " WHERE tip_revision = ?",
                     (db_id,)).fetchall()
                 all_ancestors.update([r[0] for r in merged_revs])
             else:
                 pkey, next_db_id = range_res
-                merged_revs = self._cursor.execute(
+                merged_revs = cursor.execute(
                     "SELECT merged_revision FROM dotted_revno, mainline_parent"
                     " WHERE tip_revision = mainline_parent.revision"
                     "   AND mainline_parent.range = ?",
@@ -1568,6 +1579,7 @@
 
     def _find_tip_containing(self, tip_db_id, merged_db_id):
         """Walk backwards until you find the tip that contains the given id."""
+        cursor = self._get_cursor()
         while tip_db_id is not None:
             if tip_db_id == merged_db_id:
                 # A tip obviously contains itself
@@ -1575,14 +1587,14 @@
                 return tip_db_id
             self._stats['num_steps'] += 1
             self._stats['step_find_tip_containing'] += 1
-            range_res = self._cursor.execute(
+            range_res = cursor.execute(
                 "SELECT pkey, tail"
                 "  FROM mainline_parent_range"
                 " WHERE head = ?"
                 " ORDER BY count DESC LIMIT 1",
                 (tip_db_id,)).fetchone()
             if range_res is None:
-                present_res = self._cursor.execute(
+                present_res = cursor.execute(
                     "SELECT 1 FROM dotted_revno"
                     " WHERE tip_revision = ?"
                     "   AND merged_revision = ?",
@@ -1590,7 +1602,7 @@
                 next_db_id = self._get_lh_parent_db_id(tip_db_id)
             else:
                 pkey, next_db_id = range_res
-                present_res = self._cursor.execute(
+                present_res = cursor.execute(
                     "SELECT 1"
                     "  FROM dotted_revno, mainline_parent"
                     " WHERE tip_revision = mainline_parent.revision"
@@ -1609,17 +1621,18 @@
             found_start = True
         else:
             found_start = False
+        cursor = self._get_cursor()
         while tip_db_id is not None:
             self._stats['num_steps'] += 1
             self._stats['step_get_merge_sorted'] += 1
-            range_res = self._cursor.execute(
+            range_res = cursor.execute(
                 "SELECT pkey, tail"
                 "  FROM mainline_parent_range"
                 " WHERE head = ?"
                 " ORDER BY count DESC LIMIT 1",
                 (tip_db_id,)).fetchone()
             if range_res is None:
-                merged_res = self._cursor.execute(
+                merged_res = cursor.execute(
                     "SELECT db_id, revision_id, merge_depth, revno,"
                     "       end_of_merge"
                     "  FROM dotted_revno, revision"
@@ -1638,7 +1651,7 @@
                 # At the moment, SELECT order == INSERT order, so we don't
                 # strictly need it. I don't know that we can trust that,
                 # though.
-                merged_res = self._cursor.execute(
+                merged_res = cursor.execute(
                     "SELECT db_id, revision_id, merge_depth, revno,"
                     "       end_of_merge"
                     # "       , mainline_parent.dist as mp_dist"

=== modified file 'test_hooks.py'
--- a/test_hooks.py	2010-04-21 20:17:04 +0000
+++ b/test_hooks.py	2010-04-21 20:29:14 +0000
@@ -64,7 +64,6 @@
         # TODO: It should populate the cache before running, so check that the
         #       cache is filled
         self.assertIsNot(None, b._history_db_querier)
-        b._history_db_querier._db_conn.close()
 
     def test_iter_merge_sorted_cached(self):
         history_db_path = osutils.getcwd() + '/history.db'
@@ -77,4 +76,3 @@
         self.assertEqual(merge_sorted,
                 list(history_db._history_db_iter_merge_sorted_revisions(b)))
         self.assertIsNot(None, b._history_db_querier)
-        b._history_db_querier._db_conn.close()



More information about the bazaar-commits mailing list