Rev 1419: Save outstanding changes in http://people.canonical.com/~robertc/baz2.0/twisted

Robert Collins robertc at robertcollins.net
Sat Aug 15 08:27:39 BST 2009


At http://people.canonical.com/~robertc/baz2.0/twisted

------------------------------------------------------------
revno: 1419
revision-id: robertc at robertcollins.net-20090815072738-p5p926ype3ezh4ry
parent: robertc at robertcollins.net-20051026121752-66f676b7a725d06b
committer: Robert Collins <robertc at robertcollins.net>
branch nick: twisted
timestamp: Sat 2009-08-15 17:27:38 +1000
message:
  Save outstanding changes
=== modified file 'bzrlib/selftest/testtransport.py'
--- a/bzrlib/selftest/testtransport.py	2005-10-26 08:06:44 +0000
+++ b/bzrlib/selftest/testtransport.py	2009-08-15 07:27:38 +0000
@@ -106,8 +106,7 @@
             open('a', 'wb').write('some text for a\n')
         else:
             t.put('a', 'some text for a\n')
-        self.assert_(os.path.exists('a'))
-        self.check_file_contents('a', 'some text for a\n')
+        self.assertFileEqual('some text for a\n', 'a')
         self.assertEqual(t.get('a').read(), 'some text for a\n')
         # Make sure 'has' is updated
         self.assertEqual(list(t.has_multi(['a', 'b', 'c', 'd', 'e'])),
@@ -325,8 +324,9 @@
         else:
             t.append('a', 'add\nsome\nmore\ncontents\n')
 
-        self.check_file_contents('a', 
-            'diff\ncontents for\na\nadd\nsome\nmore\ncontents\n')
+        self.assertFileEqual(
+            'diff\ncontents for\na\nadd\nsome\nmore\ncontents\n',
+            'a')
 
         if self.readonly:
             self.assertRaises(TransportNotPossible,
@@ -470,14 +470,36 @@
 
 class SFTPTransportTest(TestCaseWithSFTPServer, TestTransportMixIn):
 
-    #readonly = True
-
     def get_transport(self):
         from bzrlib.transport.sftp import SFTPTransport
         url = self.get_remote_url('')
         return SFTPTransport(url)
 
 
+class SFTP2BaseTransportTest(TestCaseWithSFTPServer):
+
+    def get_transport(self, url=None):
+        from bzrlib.transport.sftp import SFTP2Transport
+        if url is None:
+            url = self.get_remote_url('')
+        result = SFTP2Transport(url)
+        result._accept_all_host_keys()
+        return result
+
+
+class SFTP2StockTransportTest(SFTP2BaseTransportTest, TestTransportMixIn):
+    """Test SFTP2 for transport protocol compliance."""
+
+
+class SFTP2TransportTest(SFTP2BaseTransportTest):
+    """Checks how SFTP2Transport splits the URL it is constructed with."""
+    def test_init(self):
+        instance = self.get_transport(
+            'sftp://someuser:somepassword@localhost:1234/foo/bar')
+        self.assertEqual('/foo/bar/', instance._path)
+        self.assertEqual('someuser:somepassword@localhost:1234', instance._netloc)
+
+
 class TestMemoryTransport(TestCase):
 
     def test_get_transport(self):

=== modified file 'bzrlib/transport/sftp.py'
--- a/bzrlib/transport/sftp.py	2005-10-26 08:38:05 +0000
+++ b/bzrlib/transport/sftp.py	2009-08-15 07:27:38 +0000
@@ -23,6 +23,7 @@
 import stat
 import sys
 import urllib
+import urlparse
 
 from bzrlib.errors import (FileExists,
                            NoSuchFile,
@@ -32,6 +33,7 @@
 from bzrlib.config import config_dir
 from bzrlib.trace import mutter, warning, error
 from bzrlib.transport import Transport, register_transport
+from bzrlib.twisted_support import *
 
 try:
     import paramiko
@@ -39,6 +41,11 @@
     error('The SFTP transport requires paramiko.')
     raise
 
+if not 'sftp' in urlparse.uses_netloc:
+    urlparse.uses_netloc.append('sftp')
+
+if not 'sftp' in urlparse.uses_relative:
+    urlparse.uses_relative.append('sftp')
 
 SYSTEM_HOSTKEYS = {}
 BZR_HOSTKEYS = {}
@@ -84,7 +91,7 @@
     pass
 
 
-class SFTPTransport (Transport):
+class SFTPTransport(Transport):
     """
     Transport implementation for SFTP access.
     """
@@ -460,3 +467,809 @@
         except IOError:
             pass
         return False
+
+# -*- test-case-name: twisted.conch.test.test_cftp -*-
+#
+# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
+# See LICENSE for details.
+
+#
+# $Id: cftp.py,v 1.65 2004/03/11 00:29:14 z3p Exp $
+
+#""" Implementation module for the `cftp` command.
+#"""
+from twisted.conch.client import agent, connect, default, options
+from twisted.conch.error import ConchError
+from twisted.conch.ssh import connection, common
+from twisted.conch.ssh import channel, filetransfer
+from twisted.protocols import basic
+from twisted.internet import reactor, defer, utils
+from twisted.python import log, usage, failure
+
+import os, sys, getpass, struct, tty, fcntl, base64, signal, stat, errno
+import fnmatch, pwd, time, glob
+
+class SFTP2Transport(Transport):
+    """A twisted conch based SFTP transport."""
+
+    def fileTransferError(self, error):
+        # what about after its connected ?
+        print "error:", error
+        while self._connected_actions:
+            d = self._connected_actions.pop()
+            d.errback(error)
+        
+    def fileTransferClientConnected(self, client):
+        print "client:", client
+        if self._ftc is not None:
+            self.fileTransferError(ValueError("self._ftc already set."))
+            return
+        self._ftc = client
+        while self._connected_actions:
+            d = self._connected_actions.pop()
+            d.callback(None)
+
+    def _accept_all_host_keys(self):
+        self._verify_host_key = False
+
+    def append(self, relpath, f):
+        """See Transport.append."""
+        path = self._remote_path(relpath)
+        def open_append():
+            script = Script(self.client.openFile,
+                            path,
+                            filetransfer.FXF_WRITE|filetransfer.FXF_CREAT|
+                                filetransfer.FXF_APPEND, 
+                            {})
+            return script
+        script = Script(self._ensureClient)
+        script.nextAction(open_append)
+        script.nextAction(self._cbPutOpenFile, f)
+        script.handleException(self._ebCloseLf, f)
+        return script
+
+    def _ebCloseLf(self, error, f):
+        print 'woo', f
+
+    def _ensureClient(self):
+        pass
+        
+    def _cbPutOpenFile(self, rf, lf):
+        numRequests = self.client.transport.conn.options['requests']
+        dList = []
+        chunks = []
+        startTime = time.time()
+        for i in range(numRequests):
+            d = self._cbPutWrite(None, rf, lf, chunks, startTime)
+            if d:
+                dList.append(d)
+        dl = defer.DeferredList(dList, fireOnOneErrback=1)
+        dl.addCallback(self._cbPutDone, rf, lf)
+        return dl
+        
+    def _cbPutWrite(self, ignored, rf, lf, chunks, startTime):
+        chunk = self._getNextChunk(chunks)
+        start, size = chunk
+        lf.seek(start)
+        data = lf.read(size)
+        if data:
+            d = rf.writeChunk(start, data)
+            d.addCallback(self._cbPutWrite, rf, lf, chunks, startTime)
+            return d
+        else:
+            return
+
+    def _cbPutDone(self, ignored, rf, lf):
+        lf.close()
+        rf.close()
+
+    def _getNextChunk(self, chunks):
+        end = 0
+        for chunk in chunks:
+            if end == 'eof':
+                return # nothing more to get
+            if end != chunk[0]:
+                i = chunks.index(chunk)
+                chunks.insert(i, (end, chunk[0]))
+                return (end, chunk[0] - end)
+            end = chunk[1]
+        bufSize = int(self.client.transport.conn.options['buffersize'])
+        chunks.append((end, end + bufSize))
+        return (end, bufSize)
+
+    _url_matcher = re.compile(r'^([^@]*@)?(.*?)(:\d+)?$')
+    def __init__(self, base):
+        """Set the base path where files will be stored."""
+        self._verify_host_key = True
+        assert base.startswith('sftp://')
+        super(SFTP2Transport, self).__init__(base)
+        (self._proto, self._netloc,
+            self._path, self._parameters,
+            self._query, self._fragment) = urlparse.urlparse(self.base)
+        match = self._url_matcher.match(self._netloc)
+        if match is None:
+            raise SFTPTransportError('Unable to parse SFTP host %r' % (self._netloc,))
+        self._username, self._host, self._port = match.groups()
+        if self._port is None:
+            self._port = 22
+        else:
+            self._port = int(self._port[1:])
+        self._password = None
+        if self._username is None:
+            self._username = getpass.getuser()
+        else:
+            self._username = self._username[:-1]
+            if ':' in self._username:
+                self._username, self._password = self._username.split(':', 1)
+        self._path = urllib.unquote(self._path)
+        if self._path[-1] != '/':
+            self._path += '/'
+        self._ftc = None
+        self._connection = None
+        self._connected_actions = []
+
+    def put_async(self, relpath, f):
+        """See Transport.put."""
+        # FIXME: get a client async call, connects if needed
+        remote_path = self._remote_path(relpath)
+        def write(result):
+            print "writing"
+        def failed_to_open(error):
+            print "failed to open"
+        def do_put_async(result):
+            print "putting"
+            print self._ftc
+            script = Script(self._ftc.openFile, remote_path, filetransfer.FXF_WRITE|filetransfer.FXF_CREAT, {})
+            script.nextAction(write)
+            script.handleException(failed_to_open)
+            return script
+        if self._ftc is not None:
+            # do stuff
+            return
+        if self._connection is None:
+            # so we can cancel ?
+            self._connection = doConnect(self._host, self._port, self._username,
+                                         not self._verify_host_key, self)
+        # pending connection ensured
+        self._connected_actions.append(defer.Deferred())
+        self._connected_actions[-1].addCallback(do_put_async)
+        return self._connected_actions[-1]
+        print remote_path, "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", self._connection
+        def connected(result):
+            print "YYYYYYYYYYY", result
+        # FIXME: should do something atomic or locking here, this is unsafe
+        #try:
+            #fout = self._sftp.file(path, 'wb')
+        #except IOError, e:
+        #    self._translate_io_exception(e, relpath)
+        #except (IOError, paramiko.SSHException), x:
+        #    raise SFTPTransportError('Unable to write file %r' % (path,), x)
+        #try:
+        #    self._pump(f, fout)
+        #finally:
+        #    fout.close()
+        # return script
+        return defer.Deferred()
+    put = make_synchronous(put_async)
+
+    def _remote_path(self, relpath):
+        """Get the remote path for relpath.
+        
+        Note that we dont believe .. exists. But it may. Whatever.
+        """
+        return self._path + relpath
+
+    
+def doConnect(host, port, user, accept_all_hosts, transport):
+    conn = SSHConnection(transport)
+    conn.options = options.ConchOptions()
+    if not accept_all_hosts:
+        vhk = default.verifyHostKey
+    else:
+        def _say_yes(transport, host, pubKey, fingerprint):
+            return defer.succeed(None)
+        vhk = _say_yes
+    uao = default.SSHUserAuthClient(user, conn.options, conn)
+    connect.connect(host, port, conn.options, vhk, uao)
+    conn._client = transport
+    return conn
+
+def _ignore(*args): pass
+
+class FileWrapper:
+
+    def __init__(self, f):
+        self.f = f
+        self.total = 0.0
+        f.seek(0, 2) # seek to the end
+        self.size = f.tell()
+
+    def __getattr__(self, attr):
+        return getattr(self.f, attr)
+
+
+class StdioClient(basic.LineReceiver):
+
+    def connectionMade(self):
+        self.client.realPath('').addCallback(self._cbSetCurDir)
+
+    def _cbSetCurDir(self, path):
+        self.currentDirectory = path
+
+    def lineReceived(self, line):
+        if self.client.transport.localClosed:
+            return
+        log.msg('got line %s' % repr(line))
+        line = line.lstrip()
+        if not line:
+            return
+        if self.file and line.startswith('-'):
+            self.ignoreErrors = 1
+            line = line[1:]
+        else:
+            self.ignoreErrors = 0
+        if ' ' in line:
+            command, rest = line.split(' ', 1)
+            rest = rest.lstrip()
+        else:
+            command, rest = line, ''
+        if command.startswith('!'): # command
+            f = self.cmd_EXEC
+            rest = (command[1:] + ' ' + rest).strip()
+        else:
+            command = command.upper()
+            log.msg('looking up cmd %s' % command)
+            f = getattr(self, 'cmd_%s' % command, None)
+        if f is not None:
+            d = defer.maybeDeferred(f, rest)
+            d.addCallback(self._cbCommand)
+            d.addErrback(self._ebCommand)
+        else:
+            self._ebCommand(failure.Failure(NotImplementedError(
+                "No command called `%s'" % command)))
+
+    def _printFailure(self, f):
+        log.msg(f)
+        e = f.trap(NotImplementedError, filetransfer.SFTPError, OSError, IOError)
+        if e == filetransfer.SFTPError:
+            self.transport.write("remote error %i: %s\n" % 
+                    (f.value.code, f.value.message))
+        elif e in (OSError, IOError):
+            self.transport.write("local error %i: %s\n" %
+                    (f.value.errno, f.value.strerror))
+
+    def _cbCommand(self, result):
+        if result is not None:
+            self.transport.write(result)
+            if not result.endswith('\n'):
+                self.transport.write('\n')
+
+    def _ebCommand(self, f):
+        self._printFailure(f)
+        if self.file and not self.ignoreErrors:
+            self.client.transport.loseConnection()
+
+    def cmd_CD(self, path):
+        path, rest = self._getFilename(path)
+        if not path.endswith('/'):
+            path += '/'
+        newPath = path and os.path.join(self.currentDirectory, path) or ''
+        d = self.client.openDirectory(newPath)
+        d.addCallback(self._cbCd)
+        d.addErrback(self._ebCommand)
+        return d
+
+    def _cbCd(self, directory):
+        directory.close()
+        d = self.client.realPath(directory.name)
+        d.addCallback(self._cbCurDir)
+        return d
+
+    def _cbCurDir(self, path):
+        self.currentDirectory = path
+
+    def cmd_CHGRP(self, rest):
+        grp, rest = rest.split(None, 1)
+        path, rest = self._getFilename(rest)
+        grp = int(grp)
+        d = self.client.getAttrs(path)
+        d.addCallback(self._cbSetUsrGrp, path, grp=grp)
+        return d
+    
+    def cmd_CHMOD(self, rest):
+        mod, rest = rest.split(None, 1)
+        path, rest = self._getFilename(rest)
+        mod = int(mod, 8)
+        d = self.client.setAttrs(path, {'permissions':mod})
+        d.addCallback(_ignore)
+        return d
+    
+    def cmd_CHOWN(self, rest):
+        usr, rest = rest.split(None, 1)
+        path, rest = self._getFilename(rest)
+        usr = int(usr)
+        d = self.client.getAttrs(path)
+        d.addCallback(self._cbSetUsrGrp, path, usr=usr)
+        return d
+    
+    def _cbSetUsrGrp(self, attrs, path, usr=None, grp=None):
+        """Set uid/gid on path; whichever of usr/grp is None keeps attrs' value (0 is honoured)."""
+        if usr is None:
+            usr = attrs['uid']
+        if grp is None:
+            grp = attrs['gid']
+        return self.client.setAttrs(path, {'uid': usr, 'gid': grp}).addCallback(_ignore)
+
+    def cmd_GET(self, rest):
+        remote, rest = self._getFilename(rest)
+        if '*' in remote or '?' in remote: # wildcard
+            if rest:
+                local, rest = self._getFilename(rest)
+                if not os.path.isdir(local):
+                    return "Wildcard get with non-directory target."
+            else:
+                local = ''
+            d = self._remoteGlob(remote)
+            d.addCallback(self._cbGetMultiple, local)
+            return d
+        if rest:
+            local, rest = self._getFilename(rest)
+        else:
+            local = os.path.split(remote)[1]
+        log.msg((remote, local))
+        lf = file(local, 'w', 0)
+        path = os.path.join(self.currentDirectory, remote)
+        d = self.client.openFile(path, filetransfer.FXF_READ, {})
+        d.addCallback(self._cbGetOpenFile, lf)
+        d.addErrback(self._ebCloseLf, lf)
+        return d
+
+    def _cbGetMultiple(self, files, local):
+        #if self._useProgressBar: # one at a time
+        # XXX this can be optimized for times w/o progress bar
+        return self._cbGetMultipleNext(None, files, local)
+
+    def _cbGetMultipleNext(self, res, files, local):
+        if isinstance(res, failure.Failure):
+            self._printFailure(res)
+        elif res:
+            self.transport.write(res)
+            if not res.endswith('\n'):
+                self.transport.write('\n')
+        if not files:
+            return
+        f = files.pop(0)[0]
+        lf = file(os.path.join(local, os.path.split(f)[1]), 'w', 0)
+        path = os.path.join(self.currentDirectory, f)
+        d = self.client.openFile(path, filetransfer.FXF_READ, {})
+        d.addCallback(self._cbGetOpenFile, lf)
+        d.addErrback(self._ebCloseLf, lf)
+        d.addBoth(self._cbGetMultipleNext, files, local)
+        return d
+
+    def _ebCloseLf(self, f, lf):
+        lf.close()
+        return f
+
+    def _cbGetOpenFile(self, rf, lf):
+        return rf.getAttrs().addCallback(self._cbGetFileSize, rf, lf)
+
+    def _cbGetFileSize(self, attrs, rf, lf):
+        if not stat.S_ISREG(attrs['permissions']):
+            rf.close()
+            lf.close()
+            return "Can't get non-regular file: %s" % rf.name
+        rf.size = attrs['size']
+        bufferSize = self.client.transport.conn.options['buffersize']
+        numRequests = self.client.transport.conn.options['requests']
+        rf.total = 0.0
+        dList = []
+        chunks = []
+        startTime = time.time()
+        for i in range(numRequests):            
+            d = self._cbGetRead('', rf, lf, chunks, 0, bufferSize, startTime)
+            dList.append(d)
+        dl = defer.DeferredList(dList, fireOnOneErrback=1)
+        dl.addCallback(self._cbGetDone, rf, lf)
+        return dl
+
+    def _getNextChunk(self, chunks):
+        end = 0
+        for chunk in chunks:
+            if end == 'eof':
+                return # nothing more to get
+            if end != chunk[0]:
+                i = chunks.index(chunk)
+                chunks.insert(i, (end, chunk[0]))
+                return (end, chunk[0] - end)
+            end = chunk[1]
+        bufSize = int(self.client.transport.conn.options['buffersize'])
+        chunks.append((end, end + bufSize))
+        return (end, bufSize)
+   
+    def _cbGetRead(self, data, rf, lf, chunks, start, size, startTime):
+        if data and isinstance(data, failure.Failure):
+            log.msg('get read err: %s' % data)
+            reason = data
+            reason.trap(EOFError)
+            i = chunks.index((start, start + size))
+            del chunks[i]
+            chunks.insert(i, (start, 'eof'))
+        elif data:
+            log.msg('get read data: %i' % len(data))
+            lf.seek(start)
+            lf.write(data)
+            if len(data) != size:
+                log.msg('got less than we asked for: %i < %i' % 
+                        (len(data), size))
+                i = chunks.index((start, start + size))
+                del chunks[i]
+                chunks.insert(i, (start, start + len(data)))
+            rf.total += len(data)
+        if self.useProgressBar:
+            self._printProgessBar(rf, startTime)
+        chunk = self._getNextChunk(chunks)
+        if not chunk:
+            return
+        else:
+            start, length = chunk
+        log.msg('asking for %i -> %i' % (start, start+length))
+        d = rf.readChunk(start, length)
+        d.addBoth(self._cbGetRead, rf, lf, chunks, start, length, startTime)
+        return d
+
+    def _cbGetDone(self, ignored, rf, lf):
+        log.msg('get done')
+        rf.close()
+        lf.close()
+        if self.useProgressBar:
+            self.transport.write('\n')
+        return "Transferred %s to %s" % (rf.name, lf.name)
+   
+    def cmd_PUT(self, rest):
+        local, rest = self._getFilename(rest)
+        if '*' in local or '?' in local: # wildcard
+            if rest:
+                remote, rest = self._getFilename(rest)
+                path = os.path.join(self.currentDirectory, remote)
+                d = self.client.getAttrs(path)
+                d.addCallback(self._cbPutTargetAttrs, remote, local)
+                return d
+            else:
+                remote = ''
+                files = glob.glob(local)
+                return self._cbPutMultipleNext(None, files, remote)
+        if rest:
+            remote, rest = self._getFilename(rest)
+        else:
+            remote = os.path.split(local)[1]
+        lf = file(local, 'r')
+        path = os.path.join(self.currentDirectory, remote)
+        d = self.client.openFile(path, filetransfer.FXF_WRITE|filetransfer.FXF_CREAT, {})
+        d.addCallback(self._cbPutOpenFile, lf)
+        d.addErrback(self._ebCloseLf, lf)
+        return d
+
+    def _cbPutTargetAttrs(self, attrs, path, local):  # wildcard PUT: path must be a remote directory
+        if not stat.S_ISDIR(attrs['permissions']):
+            return "Wildcard put with non-directory target."
+        return self._cbPutMultipleNext(None, glob.glob(local), path)  # expand the glob here; `files` was never bound
+
+    def _cbPutMultipleNext(self, res, files, path):
+        if isinstance(res, failure.Failure):
+            self._printFailure(res)
+        elif res:
+            self.transport.write(res)
+            if not res.endswith('\n'):
+                self.transport.write('\n')
+        f = None
+        while files and not f:
+            try: 
+                f = files.pop(0)
+                lf = file(f, 'r')
+            except:
+                self._printFailure(failure.Failure())
+                f = None
+        if not f:
+            return
+        name = os.path.split(f)[1]
+        remote = os.path.join(self.currentDirectory, path, name)
+        log.msg((name, remote, path))
+        d = self.client.openFile(remote, filetransfer.FXF_WRITE|filetransfer.FXF_CREAT, {})
+        d.addCallback(self._cbPutOpenFile, lf)
+        d.addErrback(self._ebCloseLf, lf)
+        d.addBoth(self._cbPutMultipleNext, files, path)
+        return d
+
+    def _cbPutOpenFile(self, rf, lf):
+        numRequests = self.client.transport.conn.options['requests']
+        if self.useProgressBar:
+            lf = FileWrapper(lf)
+        dList = []
+        chunks = []
+        startTime = time.time()
+        for i in range(numRequests):
+            d = self._cbPutWrite(None, rf, lf, chunks, startTime)
+            if d:
+                dList.append(d)
+        dl = defer.DeferredList(dList, fireOnOneErrback=1)
+        dl.addCallback(self._cbPutDone, rf, lf)
+        return dl
+
+    def _cbPutWrite(self, ignored, rf, lf, chunks, startTime):
+        chunk = self._getNextChunk(chunks)
+        start, size = chunk
+        lf.seek(start)
+        data = lf.read(size)
+        if self.useProgressBar:
+            lf.total += len(data)
+            self._printProgessBar(lf, startTime)
+        if data:
+            d = rf.writeChunk(start, data)
+            d.addCallback(self._cbPutWrite, rf, lf, chunks, startTime)
+            return d
+        else:
+            return
+
+    def _cbPutDone(self, ignored, rf, lf):
+        lf.close()
+        rf.close()
+        if self.useProgressBar:
+            self.transport.write('\n')
+        return 'Transferred %s to %s' % (lf.name, rf.name)
+
+    def cmd_LS(self, rest):
+        # possible lines:
+        # ls                    current directory
+        # ls name_of_file       that file
+        # ls name_of_directory  that directory
+        # ls some_glob_string   current directory, globbed for that string
+        options = []
+        rest = rest.split()
+        while rest and rest[0] and rest[0][0] == '-':
+            opts = rest.pop(0)[1:]
+            for o in opts:
+                if o == 'l':
+                    options.append('verbose')
+                elif o == 'a':
+                    options.append('all')
+        rest = ' '.join(rest)
+        path, rest = self._getFilename(rest)
+        if not path:
+            fullPath = self.currentDirectory + '/'
+        else:
+            fullPath = os.path.join(self.currentDirectory, path)
+        d = self._remoteGlob(fullPath)
+        d.addCallback(self._cbDisplayFiles, options)
+        return d
+
+    def _cbDisplayFiles(self, files, options):
+        files.sort()
+        if 'all' not in options:
+            files = [f for f in files if not f[0].startswith('.')]
+        if 'verbose' in options:
+            lines = [f[1] for f in files]
+        else:
+            lines = [f[0] for f in files]
+        if not lines:
+            return None
+        else:
+            return '\n'.join(lines)
+
+    def cmd_MKDIR(self, path):
+        path, rest = self._getFilename(path)
+        path = os.path.join(self.currentDirectory, path)
+        return self.client.makeDirectory(path, {}).addCallback(_ignore)
+
+    def cmd_RMDIR(self, path):
+        path, rest = self._getFilename(path)
+        path = os.path.join(self.currentDirectory, path)
+        return self.client.removeDirectory(path).addCallback(_ignore)
+
+    def cmd_RM(self, path):
+        path, rest = self._getFilename(path)
+        path = os.path.join(self.currentDirectory, path)
+        return self.client.removeFile(path).addCallback(_ignore)
+
+    def cmd_RENAME(self, rest):
+        oldpath, rest = self._getFilename(rest)
+        newpath, rest = self._getFilename(rest)
+        oldpath, newpath = map (
+                lambda x: os.path.join(self.currentDirectory, x),
+                (oldpath, newpath))
+        return self.client.renameFile(oldpath, newpath).addCallback(_ignore)
+
+    def cmd_EXIT(self, ignored):
+        self.client.transport.loseConnection()
+
+    def cmd_VERSION(self, ignored):
+        return "SFTP version %i" % self.client.version
+    
+    def cmd_PWD(self, ignored):
+        return self.currentDirectory
+
+    def cmd_PROGRESS(self, ignored):
+        self.useProgressBar = not self.useProgressBar
+        return "%ssing progess bar." % (self.useProgressBar and "U" or "Not u")
+
+    def cmd_EXEC(self, rest):
+        shell = pwd.getpwnam(getpass.getuser())[6]
+        print repr(rest)
+        if rest:
+            cmds = ['-c', rest]
+            return utils.getProcessOutput(shell, cmds, errortoo=1)
+        else:
+            os.system(shell)
+
+    # accessory functions
+
+    def _remoteGlob(self, fullPath):
+        log.msg('looking up %s' % fullPath)
+        head, tail = os.path.split(fullPath)
+        if '*' in tail or '?' in tail:
+            glob = 1
+        else:
+            glob = 0
+        if tail and not glob: # could be file or directory
+           # try directory first
+           d = self.client.openDirectory(fullPath)
+           d.addCallback(self._cbOpenList, '')
+           d.addErrback(self._ebNotADirectory, head, tail)
+        else:
+            d = self.client.openDirectory(head)
+            d.addCallback(self._cbOpenList, tail)
+        return d
+
+    def _cbOpenList(self, directory, glob):
+        files = []
+        d = directory.read()
+        d.addBoth(self._cbReadFile, files, directory, glob)
+        return d
+
+    def _ebNotADirectory(self, reason, path, glob):
+        d = self.client.openDirectory(path)
+        d.addCallback(self._cbOpenList, glob)
+        return d
+
+    def _cbReadFile(self, files, l, directory, glob):
+        if not isinstance(files, failure.Failure):
+            if glob:
+                l.extend([f for f in files if fnmatch.fnmatch(f[0], glob)])
+            else:
+                l.extend(files)
+            d = directory.read()
+            d.addBoth(self._cbReadFile, l, directory, glob)
+            return d
+        else:
+            reason = files
+            reason.trap(EOFError)
+            directory.close()
+            return l
+
+    def _abbrevSize(self, size):
+        # from http://mail.python.org/pipermail/python-list/1999-December/018395.html
+        _abbrevs = [
+            (1<<50L, 'PB'),
+            (1<<40L, 'TB'), 
+            (1<<30L, 'GB'), 
+            (1<<20L, 'MB'), 
+            (1<<10L, 'kb'),
+            (1, '')
+            ]
+
+        for factor, suffix in _abbrevs:
+            if size > factor:
+                break
+        return '%.1f' % (size/factor) + suffix
+
+    def _abbrevTime(self, t):
+        if t > 3600: # 1 hour
+            hours = int(t / 3600)
+            t -= (3600 * hours)
+            mins = int(t / 60)
+            t -= (60 * mins)
+            return "%i:%02i:%02i" % (hours, mins, t)
+        else:
+            mins = int(t/60)
+            t -= (60 * mins)
+            return "%02i:%02i" % (mins, t)
+
+    def _printProgessBar(self, f, startTime):
+        diff = time.time() - startTime
+        total = f.total
+        try:
+            winSize = struct.unpack('4H', 
+                fcntl.ioctl(0, tty.TIOCGWINSZ, '12345679'))
+        except IOError:
+            winSize = [None, 80]
+        speed = total/diff
+        if speed:
+            timeLeft = (f.size - total) / speed
+        else:
+            timeLeft = 0
+        front = f.name
+        back = '%3i%% %s %sps %s ' % ((total/f.size)*100, self._abbrevSize(total),
+                self._abbrevSize(total/diff), self._abbrevTime(timeLeft))
+        spaces = (winSize[1] - (len(front) + len(back) + 1)) * ' '
+        self.transport.write('\r%s%s%s' % (front, spaces, back)) 
+
+    def _getFilename(self, line):
+        """Return (first filename in line, remainder); the name may be quoted."""
+        line = line.lstrip()  # keep the stripped copy; it used to be discarded
+        if not line:
+            return None, ''
+        if line[0] in '\'"':
+            ret = []
+            line = list(line)
+            try:
+                for i in range(1,len(line)):
+                    c = line[i]
+                    if c == line[0]:
+                        return ''.join(ret), ''.join(line[i+1:]).lstrip()
+                    elif c == '\\': # quoted character
+                        del line[i]
+                        if line[i] not in '\'"\\':
+                            raise IndexError, "bad quote: \\%s" % line[i]
+                        ret.append(line[i])
+                    else:
+                        ret.append(line[i])
+            except IndexError:
+                raise IndexError, "unterminated quote"
+        ret = line.split(None, 1)
+        if len(ret) == 1:
+            return ret[0], ''
+        return ret
+
+
+class SSHConnection(connection.SSHConnection):
+
+    def __init__(self, client):
+        self._client = client
+        connection.SSHConnection.__init__(self)
+        
+    def serviceStarted(self):
+        session = SSHSession()
+        session._client = self._client
+        self.openChannel(session)
+
+
+class SSHSession(channel.SSHChannel):
+
+    name = 'session'
+
+    def channelOpen(self, foo):
+        """Request the 'sftp' subsystem and pass the outcome on to self._client."""
+        log.msg('session %s open' % self.id)
+        d = self.conn.sendRequest(self, 'subsystem', common.NS('sftp'),
+                                  wantReply=1)
+        d.addCallback(self._cbSubsystem)
+        d.addCallback(self._cbTellClient)
+        d.addErrback(self._cbTellClientError)
+
+    def closed(self):
+        print "closed session, should inform _client"
+        channel.SSHChannel.closed(self)
+
+    def _cbSubsystem(self, result):
+        self.client = filetransfer.FileTransferClient()
+        self.client.makeConnection(self)
+        self.dataReceived = self.client.dataReceived
+
+    def _cbTellClientError(self, result):
+        self._client.fileTransferError(result)
+        
+    def _cbTellClient(self, result):
+        self._client.fileTransferClientConnected(self.client)
+
+    def extReceived(self, t, data):
+        if t==connection.EXTENDED_DATA_STDERR:
+            log.msg('got %s stderr data' % len(data))
+            sys.stderr.write(data)
+            sys.stderr.flush()
+
+    def eofReceived(self):
+        log.msg('got eof')
+    
+    def closeReceived(self):
+        log.msg('remote side closed %s' % self)
+        self.conn.sendClose(self)

=== modified file 'twisted/conch/client/default.py'
--- a/twisted/conch/client/default.py	2005-10-26 09:06:33 +0000
+++ b/twisted/conch/client/default.py	2009-08-15 07:27:38 +0000
@@ -44,9 +44,10 @@
             return defer.fail(ConchError('bad host key'))
         print "Warning: Permanently added '%s' (%s) to the list of known hosts." % (khHost, {'ssh-dss':'DSA', 'ssh-rsa':'RSA'}[keyType])
         known_hosts = open(os.path.expanduser('~/.ssh/known_hosts'), 'a+')
-        known_hosts.seek(-1, 2)
-        if known_hosts.read(1) != '\n':
-            known_hosts.write('\n')
+        if known_hosts.tell() > 1:
+            known_hosts.seek(-1, 2)
+            if known_hosts.read(1) != '\n':
+                known_hosts.write('\n')
         encodedKey = base64.encodestring(pubKey).replace('\n', '')
         known_hosts.write('%s %s %s\n' % (khHost, keyType, encodedKey))
         known_hosts.close()




More information about the bazaar-commits mailing list