Rev 2330: Update _KnitData parser to raise more helpful errors when it detects corruption. in http://bzr.arbash-meinel.com/branches/bzr/0.16-dev/knit_corrupt

John Arbash Meinel john at arbash-meinel.com
Fri Mar 9 22:15:05 GMT 2007


At http://bzr.arbash-meinel.com/branches/bzr/0.16-dev/knit_corrupt

------------------------------------------------------------
revno: 2330
revision-id: john at arbash-meinel.com-20070309221455-xh1sb3i6v07a7mak
parent: pqm at pqm.ubuntu.com-20070309212715-07aacbc78b3e2dc0
committer: John Arbash Meinel <john at arbash-meinel.com>
branch nick: knit_corrupt
timestamp: Fri 2007-03-09 16:14:55 -0600
message:
  Update _KnitData parser to raise more helpful errors when it detects corruption.
modified:
  bzrlib/knit.py                 knit.py-20051212171256-f056ac8f0fbe1bd9
  bzrlib/tests/test_knit.py      test_knit.py-20051212171302-95d4c00dd5f11f2b
-------------- next part --------------
=== modified file 'bzrlib/knit.py'
--- a/bzrlib/knit.py	2007-02-13 20:33:57 +0000
+++ b/bzrlib/knit.py	2007-03-09 22:14:55 +0000
@@ -1427,7 +1427,7 @@
 
     def add_raw_record(self, raw_data):
         """Append a prepared record to the data file.
-        
+
         :return: the offset in the data file raw_data was written.
         """
         assert isinstance(raw_data, str), 'data must be plain bytes'
@@ -1440,7 +1440,7 @@
                                    dir_mode=self._dir_mode)
             self._need_to_create = False
             return 0
-        
+
     def add_record(self, version_id, digest, lines):
         """Write new text record to disk.  Returns the position in the
         file where it was written."""
@@ -1466,7 +1466,12 @@
                  as (stream, header_record)
         """
         df = GzipFile(mode='rb', fileobj=StringIO(raw_data))
-        rec = self._check_header(version_id, df.readline())
+        try:
+            rec = self._check_header(version_id, df.readline())
+        except Exception, e:
+            raise KnitCorrupt(self._filename,
+                              "While reading {%s} got %s(%s)"
+                              % (version_id, e.__class__.__name__, str(e)))
         return df, rec
 
     def _check_header(self, version_id, line):
@@ -1487,12 +1492,22 @@
         # 4168 calls to readlines in 330
         df = GzipFile(mode='rb', fileobj=StringIO(data))
 
-        record_contents = df.readlines()
+        try:
+            record_contents = df.readlines()
+        except Exception, e:
+            raise KnitCorrupt(self._filename,
+                              "While reading {%s} got %s(%s)"
+                              % (version_id, e.__class__.__name__, str(e)))
         header = record_contents.pop(0)
         rec = self._check_header(version_id, header)
 
         last_line = record_contents.pop()
-        assert len(record_contents) == int(rec[2])
+        if len(record_contents) != int(rec[2]):
+            raise KnitCorrupt(self._filename,
+                              'incorrect number of lines %s != %s'
+                              ' for version {%s}'
+                              % (len(record_contents), int(rec[2]),
+                                 version_id))
         if last_line != 'end %s\n' % rec[1]:
             raise KnitCorrupt(self._filename,
                               'unexpected version end line %r, wanted %r' 

=== modified file 'bzrlib/tests/test_knit.py'
--- a/bzrlib/tests/test_knit.py	2007-02-10 02:48:43 +0000
+++ b/bzrlib/tests/test_knit.py	2007-03-09 22:14:55 +0000
@@ -18,6 +18,8 @@
 
 from cStringIO import StringIO
 import difflib
+import gzip
+import sha
 
 from bzrlib import (
     errors,
@@ -33,6 +35,7 @@
     KnitVersionedFile,
     KnitPlainFactory,
     KnitAnnotateFactory,
+    _KnitData,
     _KnitIndex,
     WeaveToKnit,
     )
@@ -109,12 +112,132 @@
         else:
             return StringIO("\n".join(self.file_lines))
 
+    def readv(self, relpath, offsets):  # presumably mirrors Transport.readv
+        fp = self.get(relpath)  # file-like object from this mock's get()
+        for offset, size in offsets:  # offsets: iterable of (offset, size)
+            fp.seek(offset)
+            yield offset, fp.read(size)  # (offset, data) pairs, in request order
+
     def __getattr__(self, name):
         def queue_call(*args, **kwargs):
             self.calls.append((name, args, kwargs))
         return queue_call
 
 
+class LowLevelKnitDataTests(TestCase):
+    """_KnitData record reading: valid data and each kind of corruption."""
+    def create_gz_content(self, text):  # return *text* gzip-compressed in memory
+        sio = StringIO()
+        gz_file = gzip.GzipFile(mode='wb', fileobj=sio)
+        gz_file.write(text)
+        gz_file.close()
+        return sio.getvalue()
+
+    def test_valid_knit_data(self):
+        sha1sum = sha.new('foo\nbar\n').hexdigest()
+        gz_txt = self.create_gz_content('version rev-id-1 2 %s\n'
+                                        'foo\n'
+                                        'bar\n'
+                                        'end rev-id-1\n'
+                                        % (sha1sum,))
+        transport = MockTransport([gz_txt])
+        data = _KnitData(transport, 'filename', mode='r')
+        records = [('rev-id-1', 0, len(gz_txt))]
+
+        contents = data.read_records(records)
+        self.assertEqual({'rev-id-1':(['foo\n', 'bar\n'], sha1sum)}, contents)
+
+        raw_contents = list(data.read_records_iter_raw(records))
+        self.assertEqual([('rev-id-1', gz_txt)], raw_contents)
+
+    def test_not_enough_lines(self):
+        sha1sum = sha.new('foo\n').hexdigest()
+        # record says 2 lines, data says 1
+        gz_txt = self.create_gz_content('version rev-id-1 2 %s\n'
+                                        'foo\n'
+                                        'end rev-id-1\n'
+                                        % (sha1sum,))
+        transport = MockTransport([gz_txt])
+        data = _KnitData(transport, 'filename', mode='r')
+        records = [('rev-id-1', 0, len(gz_txt))]
+        self.assertRaises(errors.KnitCorrupt, data.read_records, records)
+
+        # read_records_iter_raw won't detect that sort of mismatch/corruption
+        raw_contents = list(data.read_records_iter_raw(records))
+        self.assertEqual([('rev-id-1', gz_txt)], raw_contents)
+
+    def test_too_many_lines(self):
+        sha1sum = sha.new('foo\nbar\n').hexdigest()
+        # record says 1 line, data says 2
+        gz_txt = self.create_gz_content('version rev-id-1 1 %s\n'
+                                        'foo\n'
+                                        'bar\n'
+                                        'end rev-id-1\n'
+                                        % (sha1sum,))
+        transport = MockTransport([gz_txt])
+        data = _KnitData(transport, 'filename', mode='r')
+        records = [('rev-id-1', 0, len(gz_txt))]
+        self.assertRaises(errors.KnitCorrupt, data.read_records, records)
+
+        # read_records_iter_raw won't detect that sort of mismatch/corruption
+        raw_contents = list(data.read_records_iter_raw(records))
+        self.assertEqual([('rev-id-1', gz_txt)], raw_contents)
+
+    def test_mismatched_version_id(self):
+        sha1sum = sha.new('foo\nbar\n').hexdigest()
+        gz_txt = self.create_gz_content('version rev-id-1 2 %s\n'
+                                        'foo\n'
+                                        'bar\n'
+                                        'end rev-id-1\n'
+                                        % (sha1sum,))
+        transport = MockTransport([gz_txt])
+        data = _KnitData(transport, 'filename', mode='r')
+        # We are asking for rev-id-2, but the data is rev-id-1
+        records = [('rev-id-2', 0, len(gz_txt))]
+        self.assertRaises(errors.KnitCorrupt, data.read_records, records)
+
+        # read_records_iter_raw will notice if we request the wrong version.
+        self.assertRaises(errors.KnitCorrupt, list,
+                          data.read_records_iter_raw(records))
+
+    def test_uncompressed_data(self):
+        sha1sum = sha.new('foo\nbar\n').hexdigest()
+        txt = ('version rev-id-1 2 %s\n'
+               'foo\n'
+               'bar\n'
+               'end rev-id-1\n'
+               % (sha1sum,))
+        transport = MockTransport([txt])
+        data = _KnitData(transport, 'filename', mode='r')
+        records = [('rev-id-1', 0, len(txt))]
+
+        # We don't have valid gzip data ==> corrupt
+        self.assertRaises(errors.KnitCorrupt, data.read_records, records)
+
+        # read_records_iter_raw will notice the bad data
+        self.assertRaises(errors.KnitCorrupt, list,
+                          data.read_records_iter_raw(records))
+
+    def test_corrupted_data(self):
+        sha1sum = sha.new('foo\nbar\n').hexdigest()
+        gz_txt = self.create_gz_content('version rev-id-1 2 %s\n'
+                                        'foo\n'
+                                        'bar\n'
+                                        'end rev-id-1\n'
+                                        % (sha1sum,))
+        # Overwrite bytes 10-11 with \xff (presumably just past the gzip header)
+        gz_txt = gz_txt[:10] + '\xff\xff' + gz_txt[12:]
+        transport = MockTransport([gz_txt])
+        data = _KnitData(transport, 'filename', mode='r')
+        records = [('rev-id-1', 0, len(gz_txt))]
+
+        self.assertRaises(errors.KnitCorrupt, data.read_records, records)
+
+        # read_records_iter_raw will notice the corrupted compressed data.
+        self.assertRaises(errors.KnitCorrupt, list,
+                          data.read_records_iter_raw(records))
+
+
 class LowLevelKnitIndexTests(TestCase):
 
     def test_no_such_file(self):



More information about the bazaar-commits mailing list