Rev 1419: Save outstanding changes in http://people.canonical.com/~robertc/baz2.0/twisted
Robert Collins
robertc at robertcollins.net
Sat Aug 15 08:27:39 BST 2009
At http://people.canonical.com/~robertc/baz2.0/twisted
------------------------------------------------------------
revno: 1419
revision-id: robertc at robertcollins.net-20090815072738-p5p926ype3ezh4ry
parent: robertc at robertcollins.net-20051026121752-66f676b7a725d06b
committer: Robert Collins <robertc at robertcollins.net>
branch nick: twisted
timestamp: Sat 2009-08-15 17:27:38 +1000
message:
Save outstanding changes
=== modified file 'bzrlib/selftest/testtransport.py'
--- a/bzrlib/selftest/testtransport.py 2005-10-26 08:06:44 +0000
+++ b/bzrlib/selftest/testtransport.py 2009-08-15 07:27:38 +0000
@@ -106,8 +106,7 @@
open('a', 'wb').write('some text for a\n')
else:
t.put('a', 'some text for a\n')
- self.assert_(os.path.exists('a'))
- self.check_file_contents('a', 'some text for a\n')
+ self.assertFileEqual('some text for a\n', 'a')
self.assertEqual(t.get('a').read(), 'some text for a\n')
# Make sure 'has' is updated
self.assertEqual(list(t.has_multi(['a', 'b', 'c', 'd', 'e'])),
@@ -325,8 +324,9 @@
else:
t.append('a', 'add\nsome\nmore\ncontents\n')
- self.check_file_contents('a',
- 'diff\ncontents for\na\nadd\nsome\nmore\ncontents\n')
+ self.assertFileEqual(
+ 'diff\ncontents for\na\nadd\nsome\nmore\ncontents\n',
+ 'a')
if self.readonly:
self.assertRaises(TransportNotPossible,
@@ -470,14 +470,36 @@
class SFTPTransportTest(TestCaseWithSFTPServer, TestTransportMixIn):
- #readonly = True
-
+ """Run the shared transport tests against the paramiko-based SFTPTransport."""
def get_transport(self):
from bzrlib.transport.sftp import SFTPTransport
url = self.get_remote_url('')
return SFTPTransport(url)
+class SFTP2BaseTransportTest(TestCaseWithSFTPServer):
+    """Shared fixture that builds SFTP2Transport instances for the tests.
+
+    Every transport it returns has host key checking switched off.
+    """
+
+    def get_transport(self, url=None):
+        """Return an SFTP2Transport for url, or for get_remote_url('')."""
+        from bzrlib.transport.sftp import SFTP2Transport
+        target = url
+        if target is None:
+            target = self.get_remote_url('')
+        transport = SFTP2Transport(target)
+        transport._accept_all_host_keys()
+        return transport
+
+
+class SFTP2StockTransportTest(SFTP2BaseTransportTest, TestTransportMixIn):
+    """Run the shared transport conformance suite against SFTP2Transport."""
+
+
+class SFTP2TransportTest(SFTP2BaseTransportTest):
+    """Unit tests for SFTP2Transport URL parsing (no connection is made)."""
+
+    def test_init(self):
+        instance = self.get_transport(
+            'sftp://someuser:somepassword@localhost:1234/foo/bar')
+        # The path always gains a trailing '/'.
+        self.assertEqual('/foo/bar/', instance._path)
+        # The raw netloc keeps credentials, '@' separator and port as given.
+        self.assertEqual('someuser:somepassword@localhost:1234',
+                         instance._netloc)
+        # ...and is split into its parts, with the port converted to int.
+        self.assertEqual('someuser', instance._username)
+        self.assertEqual('somepassword', instance._password)
+        self.assertEqual('localhost', instance._host)
+        self.assertEqual(1234, instance._port)
+
+
class TestMemoryTransport(TestCase):
def test_get_transport(self):
=== modified file 'bzrlib/transport/sftp.py'
--- a/bzrlib/transport/sftp.py 2005-10-26 08:38:05 +0000
+++ b/bzrlib/transport/sftp.py 2009-08-15 07:27:38 +0000
@@ -23,6 +23,7 @@
import stat
import sys
import urllib
+import urlparse
from bzrlib.errors import (FileExists,
NoSuchFile,
@@ -32,6 +33,7 @@
from bzrlib.config import config_dir
from bzrlib.trace import mutter, warning, error
from bzrlib.transport import Transport, register_transport
+from bzrlib.twisted_support import *
try:
import paramiko
@@ -39,6 +41,11 @@
error('The SFTP transport requires paramiko.')
raise
+# Register 'sftp' with urlparse so that urlparse() splits out the netloc of
+# sftp:// URLs and urljoin() resolves relative references against them.
+if 'sftp' not in urlparse.uses_netloc:
+    urlparse.uses_netloc.append('sftp')
+
+if 'sftp' not in urlparse.uses_relative:
+    urlparse.uses_relative.append('sftp')
SYSTEM_HOSTKEYS = {}
BZR_HOSTKEYS = {}
@@ -84,7 +91,7 @@
pass
-class SFTPTransport (Transport):
+class SFTPTransport(Transport):
"""
Transport implementation for SFTP access.
"""
@@ -460,3 +467,809 @@
except IOError:
pass
return False
+
+# -*- test-case-name: twisted.conch.test.test_cftp -*-
+#
+# Copyright (c) 2001-2004 Twisted Matrix Laboratories.
+# See LICENSE for details.
+
+#
+# $Id: cftp.py,v 1.65 2004/03/11 00:29:14 z3p Exp $
+
+# The code below is adapted from Twisted's implementation module for the
+# `cftp` command.
+from twisted.conch.client import agent, connect, default, options
+from twisted.conch.error import ConchError
+from twisted.conch.ssh import connection, common
+from twisted.conch.ssh import channel, filetransfer
+from twisted.protocols import basic
+from twisted.internet import reactor, defer, utils
+from twisted.python import log, usage, failure
+
+import os, sys, getpass, struct, tty, fcntl, base64, signal, stat, errno
+import fnmatch, pwd, time, glob
+
+class SFTP2Transport(Transport):
+ """A twisted conch based SFTP transport.
+
+ Work in progress: connection set-up is asynchronous (doConnect ->
+ SSHConnection -> SSHSession), and operations requested before the SFTP
+ client is ready wait as Deferreds in self._connected_actions.  Several
+ methods below are incomplete; see the NOTE(review) comments.
+ """
+
+ def fileTransferError(self, error):
+ # Called by SSHSession when starting the sftp subsystem fails; errbacks
+ # every Deferred still waiting in self._connected_actions with error.
+ # what about after its connected ?
+ print "error:", error
+ while self._connected_actions:
+ d = self._connected_actions.pop()
+ d.errback(error)
+
+ def fileTransferClientConnected(self, client):
+ # Called by SSHSession once the sftp subsystem is running: stores the
+ # FileTransferClient in self._ftc and fires each waiting Deferred with
+ # None.  NOTE(review): pop() fires them in reverse (LIFO) order.
+ print "client:", client
+ if self._ftc is not None:
+ self.fileTransferError(ValueError("self._ftc already set."))
+ return
+ self._ftc = client
+ while self._connected_actions:
+ d = self._connected_actions.pop()
+ d.callback(None)
+
+ def _accept_all_host_keys(self):
+ # Disable host key verification: put_async then passes
+ # accept_all_hosts=True to doConnect.  Used by the tests.
+ self._verify_host_key = False
+
+ def append(self, relpath, f):
+ """See Transport.append."""
+ # NOTE(review): self.client is never assigned in this class (the SFTP
+ # client is stored in self._ftc), so open_append and _cbPutOpenFile would
+ # raise AttributeError as written -- TODO confirm.  Script is presumably
+ # provided by the star import of bzrlib.twisted_support.
+ path = self._remote_path(relpath)
+ def open_append():
+ script = Script(self.client.openFile,
+ path,
+ filetransfer.FXF_WRITE|filetransfer.FXF_CREAT|
+ filetransfer.FXF_APPEND,
+ {})
+ return script
+ script = Script(self._ensureClient)
+ script.nextAction(open_append)
+ script.nextAction(self._cbPutOpenFile, f)
+ script.handleException(self._ebCloseLf, f)
+ return script
+
+ def _ebCloseLf(self, error, f):
+ # NOTE(review): only prints; f is not closed, and if Script follows
+ # Deferred errback semantics, returning None swallows the error -- confirm.
+ print 'woo', f
+
+ def _ensureClient(self):
+ # Placeholder: currently does nothing (no connection is made here).
+ pass
+
+ def _cbPutOpenFile(self, rf, lf):
+ # Start up to options['requests'] chained chunk writes of lf into rf and
+ # close both once every chain has finished.
+ numRequests = self.client.transport.conn.options['requests']
+ dList = []
+ chunks = []
+ startTime = time.time()
+ for i in range(numRequests):
+ d = self._cbPutWrite(None, rf, lf, chunks, startTime)
+ if d:
+ dList.append(d)
+ dl = defer.DeferredList(dList, fireOnOneErrback=1)
+ dl.addCallback(self._cbPutDone, rf, lf)
+ return dl
+
+ def _cbPutWrite(self, ignored, rf, lf, chunks, startTime):
+ # Write the next chunk of lf; each completed write chains the next one
+ # until lf.read() returns no data.
+ chunk = self._getNextChunk(chunks)
+ start, size = chunk
+ lf.seek(start)
+ data = lf.read(size)
+ if data:
+ d = rf.writeChunk(start, data)
+ d.addCallback(self._cbPutWrite, rf, lf, chunks, startTime)
+ return d
+ else:
+ return
+
+ def _cbPutDone(self, ignored, rf, lf):
+ lf.close()
+ rf.close()
+
+ def _getNextChunk(self, chunks):
+ # Return the first gap between the (start, end) ranges in chunks as
+ # (start, size); otherwise append and return a new 'buffersize' range
+ # after the last one.  Returns None if an 'eof' range precedes others.
+ end = 0
+ for chunk in chunks:
+ if end == 'eof':
+ return # nothing more to get
+ if end != chunk[0]:
+ i = chunks.index(chunk)
+ chunks.insert(i, (end, chunk[0]))
+ return (end, chunk[0] - end)
+ end = chunk[1]
+ bufSize = int(self.client.transport.conn.options['buffersize'])
+ chunks.append((end, end + bufSize))
+ return (end, bufSize)
+
+ # Splits a netloc into an optional "user[:password]@", the host, and an
+ # optional ":port".  NOTE(review): relies on a module-level "import re",
+ # which is not visible in this diff -- confirm it exists.
+ _url_matcher = re.compile(r'^([^@]*@)?(.*?)(:\d+)?$')
+ def __init__(self, base):
+ """Set the base path where files will be stored."""
+ self._verify_host_key = True
+ # NOTE(review): this check is skipped entirely under python -O.
+ assert base.startswith('sftp://')
+ super(SFTP2Transport, self).__init__(base)
+ (self._proto, self._netloc,
+ self._path, self._parameters,
+ self._query, self._fragment) = urlparse.urlparse(self.base)
+ match = self._url_matcher.match(self._netloc)
+ if match is None:
+ raise SFTPTransportError('Unable to parse SFTP host %r' % (self._netloc,))
+ self._username, self._host, self._port = match.groups()
+ if self._port is None:
+ self._port = 22
+ else:
+ # strip the leading ':'
+ self._port = int(self._port[1:])
+ self._password = None
+ if self._username is None:
+ # no "user@" in the URL: default to the local login name
+ self._username = getpass.getuser()
+ else:
+ # drop the trailing '@', then split off an optional ":password"
+ self._username = self._username[:-1]
+ if ':' in self._username:
+ self._username, self._password = self._username.split(':', 1)
+ self._path = urllib.unquote(self._path)
+ if self._path[-1] != '/':
+ self._path += '/'
+ # SFTP client (set by fileTransferClientConnected), pending SSH
+ # connection (set by put_async), and Deferreds waiting for the client.
+ self._ftc = None
+ self._connection = None
+ self._connected_actions = []
+
+ def put_async(self, relpath, f):
+ """See Transport.put."""
+ # FIXME: get a client async call, connects if needed
+ remote_path = self._remote_path(relpath)
+ def write(result):
+ print "writing"
+ def failed_to_open(error):
+ print "failed to open"
+ def do_put_async(result):
+ print "putting"
+ print self._ftc
+ script = Script(self._ftc.openFile, remote_path, filetransfer.FXF_WRITE|filetransfer.FXF_CREAT, {})
+ script.nextAction(write)
+ script.handleException(failed_to_open)
+ return script
+ # NOTE(review): indentation is lost in this copy, but as it reads:
+ # - when self._ftc is already set this returns None and writes nothing;
+ # - write() ignores f, so no file data is ever sent;
+ # - everything after "return self._connected_actions[-1]" below
+ #   appears unreachable.
+ if self._ftc is not None:
+ # do stuff
+ return
+ if self._connection is None:
+ # so we can cancel ?
+ self._connection = doConnect(self._host, self._port, self._username,
+ not self._verify_host_key, self)
+ # pending connection ensured
+ self._connected_actions.append(defer.Deferred())
+ self._connected_actions[-1].addCallback(do_put_async)
+ return self._connected_actions[-1]
+ print remote_path, "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX", self._connection
+ def connected(result):
+ print "YYYYYYYYYYY", result
+ # FIXME: should do something atomic or locking here, this is unsafe
+ #try:
+ #fout = self._sftp.file(path, 'wb')
+ #except IOError, e:
+ # self._translate_io_exception(e, relpath)
+ #except (IOError, paramiko.SSHException), x:
+ # raise SFTPTransportError('Unable to write file %r' % (path,), x)
+ #try:
+ # self._pump(f, fout)
+ #finally:
+ # fout.close()
+ # return script
+ return defer.Deferred()
+ # make_synchronous presumably comes from bzrlib.twisted_support and wraps
+ # put_async as a blocking call -- TODO confirm.
+ put = make_synchronous(put_async)
+
+ def _remote_path(self, relpath):
+ """Get the remote path for relpath.
+
+ relpath is appended to self._path unchanged; '..' segments are not
+ normalised away.
+ """
+ return self._path + relpath
+
+
+def doConnect(host, port, user, accept_all_hosts, transport):
+    """Start an SSH connection that opens an SFTP session for transport.
+
+    :param host: remote host name.
+    :param port: remote SSH port.
+    :param user: user name to authenticate as.
+    :param accept_all_hosts: if true, accept any host key without checking.
+    :param transport: the SFTP2Transport to notify; it receives
+        fileTransferClientConnected() or fileTransferError().
+    :return: the SSHConnection; connecting continues asynchronously.
+    """
+    conn = SSHConnection(transport)
+    conn.options = options.ConchOptions()
+    if not accept_all_hosts:
+        vhk = default.verifyHostKey
+    else:
+        def _say_yes(transport, host, pubKey, fingerprint):
+            return defer.succeed(None)
+        vhk = _say_yes
+    uao = default.SSHUserAuthClient(user, conn.options, conn)
+    d = connect.connect(host, port, conn.options, vhk, uao)
+    # A failure to connect or authenticate never reaches SSHSession, so it
+    # must be reported to the transport here; otherwise the Deferreds waiting
+    # in transport._connected_actions would never fire.
+    d.addErrback(transport.fileTransferError)
+    conn._client = transport
+    return conn
+
+def _ignore(*args):
+    """Callback that discards all of its arguments and returns None."""
+    return None
+
+class FileWrapper:
+    """Proxy for a local file that also records its size and bytes sent.
+
+    ``size`` is measured once, at construction.  ``total`` is a float
+    counter that callers increase as data is transferred.  Any other
+    attribute is looked up on the wrapped file.
+    """
+
+    def __init__(self, f):
+        self.f = f
+        self.total = 0.0
+        # Measure the file by seeking to its end (whence=2).  This leaves the
+        # position at EOF, so callers must seek before reading.
+        f.seek(0, 2)
+        self.size = f.tell()
+
+    def __getattr__(self, attr):
+        # Only reached for names the wrapper itself does not define.
+        wrapped = self.f
+        return getattr(wrapped, attr)
+
+
+class StdioClient(basic.LineReceiver):
+    """Interactive, line-oriented SFTP command interpreter (from Twisted cftp).
+
+    Each received line is split into a command word and an argument string
+    and dispatched to the matching ``cmd_<WORD>`` method.  A leading ``!``
+    runs the rest of the line in a local shell.  Results and error messages
+    are written to ``self.transport``.
+
+    NOTE(review): ``client`` (used as an SFTP FileTransferClient), ``file``
+    and ``useProgressBar`` are read here but never assigned by this class;
+    presumably the code that creates the instance sets them -- TODO confirm.
+    """
+
+    def connectionMade(self):
+        # Ask the server for the initial working directory.
+        self.client.realPath('').addCallback(self._cbSetCurDir)
+
+    def _cbSetCurDir(self, path):
+        self.currentDirectory = path
+
+    def lineReceived(self, line):
+        """Parse one command line and dispatch it to a ``cmd_*`` method."""
+        if self.client.transport.localClosed:
+            return
+        log.msg('got line %s' % repr(line))
+        line = line.lstrip()
+        if not line:
+            return
+        # When self.file is set, a leading '-' means "keep going even if this
+        # command fails" (see _ebCommand).
+        if self.file and line.startswith('-'):
+            self.ignoreErrors = 1
+            line = line[1:]
+        else:
+            self.ignoreErrors = 0
+        if ' ' in line:
+            command, rest = line.split(' ', 1)
+            rest = rest.lstrip()
+        else:
+            command, rest = line, ''
+        if command.startswith('!'): # local shell command
+            f = self.cmd_EXEC
+            rest = (command[1:] + ' ' + rest).strip()
+        else:
+            command = command.upper()
+            log.msg('looking up cmd %s' % command)
+            f = getattr(self, 'cmd_%s' % command, None)
+        if f is not None:
+            d = defer.maybeDeferred(f, rest)
+            d.addCallback(self._cbCommand)
+            d.addErrback(self._ebCommand)
+        else:
+            self._ebCommand(failure.Failure(NotImplementedError(
+                "No command called `%s'" % command)))
+
+    def _printFailure(self, f):
+        """Write a one-line description of a command failure.
+
+        Failure types not listed in the ``f.trap`` call are re-raised.
+        """
+        log.msg(f)
+        e = f.trap(NotImplementedError, filetransfer.SFTPError, OSError, IOError)
+        if e == NotImplementedError:
+            # e.g. an unknown command name (see lineReceived)
+            self.transport.write("%s\n" % (f.value,))
+        elif e == filetransfer.SFTPError:
+            self.transport.write("remote error %i: %s\n" %
+                (f.value.code, f.value.message))
+        elif e in (OSError, IOError):
+            self.transport.write("local error %i: %s\n" %
+                (f.value.errno, f.value.strerror))
+
+    def _cbCommand(self, result):
+        """Echo a command's string result, ensuring a trailing newline."""
+        if result is not None:
+            self.transport.write(result)
+            if not result.endswith('\n'):
+                self.transport.write('\n')
+
+    def _ebCommand(self, f):
+        """Report a failed command.
+
+        When ``self.file`` is set, the connection is also dropped unless the
+        line was prefixed with '-'.
+        """
+        self._printFailure(f)
+        if self.file and not self.ignoreErrors:
+            self.client.transport.loseConnection()
+
+    def cmd_CD(self, path):
+        """Change the remote working directory."""
+        path, rest = self._getFilename(path)
+        if path:
+            if not path.endswith('/'):
+                path += '/'
+            newPath = os.path.join(self.currentDirectory, path)
+        else:
+            # No argument: open '', the same path connectionMade resolves
+            # as the starting directory.
+            newPath = ''
+        d = self.client.openDirectory(newPath)
+        d.addCallback(self._cbCd)
+        d.addErrback(self._ebCommand)
+        return d
+
+    def _cbCd(self, directory):
+        directory.close()
+        d = self.client.realPath(directory.name)
+        d.addCallback(self._cbCurDir)
+        return d
+
+    def _cbCurDir(self, path):
+        self.currentDirectory = path
+
+    def cmd_CHGRP(self, rest):
+        """chgrp GID PATH -- numeric group id only."""
+        grp, rest = rest.split(None, 1)
+        path, rest = self._getFilename(rest)
+        grp = int(grp)
+        d = self.client.getAttrs(path)
+        d.addCallback(self._cbSetUsrGrp, path, grp=grp)
+        return d
+
+    def cmd_CHMOD(self, rest):
+        """chmod MODE PATH -- MODE is parsed as octal."""
+        mod, rest = rest.split(None, 1)
+        path, rest = self._getFilename(rest)
+        mod = int(mod, 8)
+        d = self.client.setAttrs(path, {'permissions':mod})
+        d.addCallback(_ignore)
+        return d
+
+    def cmd_CHOWN(self, rest):
+        """chown UID PATH -- numeric user id only."""
+        usr, rest = rest.split(None, 1)
+        path, rest = self._getFilename(rest)
+        usr = int(usr)
+        d = self.client.getAttrs(path)
+        d.addCallback(self._cbSetUsrGrp, path, usr=usr)
+        return d
+
+    def _cbSetUsrGrp(self, attrs, path, usr=None, grp=None):
+        """Set uid/gid on path, keeping the current value for whichever of
+        usr/grp is None."""
+        new = {}
+        # Compare with None explicitly: uid/gid 0 (root) is a valid target
+        # and must not fall back to the file's current value.
+        if usr is not None:
+            new['uid'] = usr
+        else:
+            new['uid'] = attrs['uid']
+        if grp is not None:
+            new['gid'] = grp
+        else:
+            new['gid'] = attrs['gid']
+        d = self.client.setAttrs(path, new)
+        d.addCallback(_ignore)
+        return d
+
+    def cmd_GET(self, rest):
+        """get REMOTE [LOCAL] -- REMOTE may contain '*' or '?' wildcards."""
+        remote, rest = self._getFilename(rest)
+        if '*' in remote or '?' in remote: # wildcard
+            if rest:
+                local, rest = self._getFilename(rest)
+                if not os.path.isdir(local):
+                    return "Wildcard get with non-directory target."
+            else:
+                local = ''
+            d = self._remoteGlob(remote)
+            d.addCallback(self._cbGetMultiple, local)
+            return d
+        if rest:
+            local, rest = self._getFilename(rest)
+        else:
+            local = os.path.split(remote)[1]
+        log.msg((remote, local))
+        lf = file(local, 'w', 0)
+        path = os.path.join(self.currentDirectory, remote)
+        d = self.client.openFile(path, filetransfer.FXF_READ, {})
+        d.addCallback(self._cbGetOpenFile, lf)
+        d.addErrback(self._ebCloseLf, lf)
+        return d
+
+    def _cbGetMultiple(self, files, local):
+        #if self._useProgressBar: # one at a time
+        # XXX this can be optimized for times w/o progress bar
+        return self._cbGetMultipleNext(None, files, local)
+
+    def _cbGetMultipleNext(self, res, files, local):
+        """Report the previous download's result, then start the next one."""
+        if isinstance(res, failure.Failure):
+            self._printFailure(res)
+        elif res:
+            self.transport.write(res)
+            if not res.endswith('\n'):
+                self.transport.write('\n')
+        if not files:
+            return
+        f = files.pop(0)[0]
+        lf = file(os.path.join(local, os.path.split(f)[1]), 'w', 0)
+        path = os.path.join(self.currentDirectory, f)
+        d = self.client.openFile(path, filetransfer.FXF_READ, {})
+        d.addCallback(self._cbGetOpenFile, lf)
+        d.addErrback(self._ebCloseLf, lf)
+        d.addBoth(self._cbGetMultipleNext, files, local)
+        return d
+
+    def _ebCloseLf(self, f, lf):
+        # Close the local file, then pass the failure on unchanged.
+        lf.close()
+        return f
+
+    def _cbGetOpenFile(self, rf, lf):
+        return rf.getAttrs().addCallback(self._cbGetFileSize, rf, lf)
+
+    def _cbGetFileSize(self, attrs, rf, lf):
+        """Start options['requests'] chained chunk reads of rf into lf."""
+        if not stat.S_ISREG(attrs['permissions']):
+            rf.close()
+            lf.close()
+            return "Can't get non-regular file: %s" % rf.name
+        rf.size = attrs['size']
+        bufferSize = self.client.transport.conn.options['buffersize']
+        numRequests = self.client.transport.conn.options['requests']
+        rf.total = 0.0
+        dList = []
+        chunks = []
+        startTime = time.time()
+        for i in range(numRequests):
+            d = self._cbGetRead('', rf, lf, chunks, 0, bufferSize, startTime)
+            dList.append(d)
+        dl = defer.DeferredList(dList, fireOnOneErrback=1)
+        dl.addCallback(self._cbGetDone, rf, lf)
+        return dl
+
+    def _getNextChunk(self, chunks):
+        """Reserve the next byte range to transfer and return (start, size).
+
+        ``chunks`` holds the (start, end) ranges already requested; an end
+        of 'eof' marks where the remote file ended.  The first gap between
+        ranges is filled first; otherwise a new 'buffersize' range is
+        appended.  Returns None once an 'eof' range has been reached.
+        """
+        end = 0
+        for chunk in chunks:
+            if end == 'eof':
+                return # nothing more to get
+            if end != chunk[0]:
+                i = chunks.index(chunk)
+                chunks.insert(i, (end, chunk[0]))
+                return (end, chunk[0] - end)
+            end = chunk[1]
+        if end == 'eof':
+            # The last range already ended at EOF; appending after it would
+            # attempt 'eof' + bufSize.
+            return
+        bufSize = int(self.client.transport.conn.options['buffersize'])
+        chunks.append((end, end + bufSize))
+        return (end, bufSize)
+
+    def _cbGetRead(self, data, rf, lf, chunks, start, size, startTime):
+        """Store one read result and request the next chunk, if any."""
+        if data and isinstance(data, failure.Failure):
+            log.msg('get read err: %s' % data)
+            reason = data
+            reason.trap(EOFError)
+            # mark this range as the end of the remote file
+            i = chunks.index((start, start + size))
+            del chunks[i]
+            chunks.insert(i, (start, 'eof'))
+        elif data:
+            log.msg('get read data: %i' % len(data))
+            lf.seek(start)
+            lf.write(data)
+            if len(data) != size:
+                log.msg('got less than we asked for: %i < %i' %
+                        (len(data), size))
+                # shrink the range to what actually arrived
+                i = chunks.index((start, start + size))
+                del chunks[i]
+                chunks.insert(i, (start, start + len(data)))
+            rf.total += len(data)
+        if self.useProgressBar:
+            self._printProgessBar(rf, startTime)
+        chunk = self._getNextChunk(chunks)
+        if not chunk:
+            return
+        else:
+            start, length = chunk
+        log.msg('asking for %i -> %i' % (start, start+length))
+        d = rf.readChunk(start, length)
+        d.addBoth(self._cbGetRead, rf, lf, chunks, start, length, startTime)
+        return d
+
+    def _cbGetDone(self, ignored, rf, lf):
+        log.msg('get done')
+        rf.close()
+        lf.close()
+        if self.useProgressBar:
+            self.transport.write('\n')
+        return "Transferred %s to %s" % (rf.name, lf.name)
+
+    def cmd_PUT(self, rest):
+        """put LOCAL [REMOTE] -- LOCAL may contain '*' or '?' wildcards."""
+        local, rest = self._getFilename(rest)
+        if '*' in local or '?' in local: # wildcard
+            if rest:
+                remote, rest = self._getFilename(rest)
+                path = os.path.join(self.currentDirectory, remote)
+                d = self.client.getAttrs(path)
+                d.addCallback(self._cbPutTargetAttrs, remote, local)
+                return d
+            else:
+                remote = ''
+                files = glob.glob(local)
+                return self._cbPutMultipleNext(None, files, remote)
+        if rest:
+            remote, rest = self._getFilename(rest)
+        else:
+            remote = os.path.split(local)[1]
+        lf = file(local, 'r')
+        path = os.path.join(self.currentDirectory, remote)
+        d = self.client.openFile(path, filetransfer.FXF_WRITE|filetransfer.FXF_CREAT, {})
+        d.addCallback(self._cbPutOpenFile, lf)
+        d.addErrback(self._ebCloseLf, lf)
+        return d
+
+    def _cbPutTargetAttrs(self, attrs, path, local):
+        """Upload every local file matching the glob ``local`` into the
+        remote directory ``path``."""
+        if not stat.S_ISDIR(attrs['permissions']):
+            return "Wildcard put with non-directory target."
+        # Expand the local wildcard now that the target is a directory.
+        files = glob.glob(local)
+        return self._cbPutMultipleNext(None, files, path)
+
+    def _cbPutMultipleNext(self, res, files, path):
+        """Report the previous upload's result, then start the next one."""
+        if isinstance(res, failure.Failure):
+            self._printFailure(res)
+        elif res:
+            self.transport.write(res)
+            if not res.endswith('\n'):
+                self.transport.write('\n')
+        f = None
+        while files and not f:
+            try:
+                f = files.pop(0)
+                lf = file(f, 'r')
+            except (IOError, OSError):
+                # report the unreadable file and move on to the next one
+                self._printFailure(failure.Failure())
+                f = None
+        if not f:
+            return
+        name = os.path.split(f)[1]
+        remote = os.path.join(self.currentDirectory, path, name)
+        log.msg((name, remote, path))
+        d = self.client.openFile(remote, filetransfer.FXF_WRITE|filetransfer.FXF_CREAT, {})
+        d.addCallback(self._cbPutOpenFile, lf)
+        d.addErrback(self._ebCloseLf, lf)
+        d.addBoth(self._cbPutMultipleNext, files, path)
+        return d
+
+    def _cbPutOpenFile(self, rf, lf):
+        """Start options['requests'] chained chunk writes of lf into rf."""
+        numRequests = self.client.transport.conn.options['requests']
+        if self.useProgressBar:
+            lf = FileWrapper(lf)
+        dList = []
+        chunks = []
+        startTime = time.time()
+        for i in range(numRequests):
+            d = self._cbPutWrite(None, rf, lf, chunks, startTime)
+            if d:
+                dList.append(d)
+        dl = defer.DeferredList(dList, fireOnOneErrback=1)
+        dl.addCallback(self._cbPutDone, rf, lf)
+        return dl
+
+    def _cbPutWrite(self, ignored, rf, lf, chunks, startTime):
+        """Write the next chunk of lf; stops when lf.read() returns nothing."""
+        chunk = self._getNextChunk(chunks)
+        start, size = chunk
+        lf.seek(start)
+        data = lf.read(size)
+        if self.useProgressBar:
+            lf.total += len(data)
+            self._printProgessBar(lf, startTime)
+        if data:
+            d = rf.writeChunk(start, data)
+            d.addCallback(self._cbPutWrite, rf, lf, chunks, startTime)
+            return d
+        else:
+            return
+
+    def _cbPutDone(self, ignored, rf, lf):
+        lf.close()
+        rf.close()
+        if self.useProgressBar:
+            self.transport.write('\n')
+        return 'Transferred %s to %s' % (lf.name, rf.name)
+
+    def cmd_LS(self, rest):
+        """ls [-l] [-a] [PATH-OR-GLOB]"""
+        # possible lines:
+        # ls                    current directory
+        # ls name_of_file       that file
+        # ls name_of_directory  that directory
+        # ls some_glob_string   current directory, globbed for that string
+        lsOptions = []
+        rest = rest.split()
+        while rest and rest[0] and rest[0][0] == '-':
+            opts = rest.pop(0)[1:]
+            for o in opts:
+                if o == 'l':
+                    lsOptions.append('verbose')
+                elif o == 'a':
+                    lsOptions.append('all')
+        rest = ' '.join(rest)
+        path, rest = self._getFilename(rest)
+        if not path:
+            fullPath = self.currentDirectory + '/'
+        else:
+            fullPath = os.path.join(self.currentDirectory, path)
+        d = self._remoteGlob(fullPath)
+        d.addCallback(self._cbDisplayFiles, lsOptions)
+        return d
+
+    def _cbDisplayFiles(self, files, options):
+        # files are (name, long-listing line, ...) entries from directory reads
+        files.sort()
+        if 'all' not in options:
+            files = [f for f in files if not f[0].startswith('.')]
+        if 'verbose' in options:
+            lines = [f[1] for f in files]
+        else:
+            lines = [f[0] for f in files]
+        if not lines:
+            return None
+        else:
+            return '\n'.join(lines)
+
+    def cmd_MKDIR(self, path):
+        path, rest = self._getFilename(path)
+        path = os.path.join(self.currentDirectory, path)
+        return self.client.makeDirectory(path, {}).addCallback(_ignore)
+
+    def cmd_RMDIR(self, path):
+        path, rest = self._getFilename(path)
+        path = os.path.join(self.currentDirectory, path)
+        return self.client.removeDirectory(path).addCallback(_ignore)
+
+    def cmd_RM(self, path):
+        path, rest = self._getFilename(path)
+        path = os.path.join(self.currentDirectory, path)
+        return self.client.removeFile(path).addCallback(_ignore)
+
+    def cmd_RENAME(self, rest):
+        oldpath, rest = self._getFilename(rest)
+        newpath, rest = self._getFilename(rest)
+        oldpath, newpath = [os.path.join(self.currentDirectory, p)
+                            for p in (oldpath, newpath)]
+        return self.client.renameFile(oldpath, newpath).addCallback(_ignore)
+
+    def cmd_EXIT(self, ignored):
+        self.client.transport.loseConnection()
+
+    def cmd_VERSION(self, ignored):
+        return "SFTP version %i" % self.client.version
+
+    def cmd_PWD(self, ignored):
+        return self.currentDirectory
+
+    def cmd_PROGRESS(self, ignored):
+        self.useProgressBar = not self.useProgressBar
+        return "%ssing progess bar." % (self.useProgressBar and "U" or "Not u")
+
+    def cmd_EXEC(self, rest):
+        """Run rest with the local user's login shell, or start it
+        interactively when rest is empty."""
+        shell = pwd.getpwnam(getpass.getuser())[6]
+        print repr(rest)
+        if rest:
+            cmds = ['-c', rest]
+            return utils.getProcessOutput(shell, cmds, errortoo=1)
+        else:
+            os.system(shell)
+
+    # accessory functions
+
+    def _remoteGlob(self, fullPath):
+        """List fullPath: a directory's contents, a single file, or the
+        entries of its parent directory that match a wildcard tail."""
+        log.msg('looking up %s' % fullPath)
+        head, tail = os.path.split(fullPath)
+        hasWildcard = '*' in tail or '?' in tail
+        if tail and not hasWildcard: # could be file or directory
+            # try directory first
+            d = self.client.openDirectory(fullPath)
+            d.addCallback(self._cbOpenList, '')
+            d.addErrback(self._ebNotADirectory, head, tail)
+        else:
+            d = self.client.openDirectory(head)
+            d.addCallback(self._cbOpenList, tail)
+        return d
+
+    def _cbOpenList(self, directory, glob):
+        files = []
+        d = directory.read()
+        d.addBoth(self._cbReadFile, files, directory, glob)
+        return d
+
+    def _ebNotADirectory(self, reason, path, glob):
+        # fullPath was not a directory: list its parent, filtered by the name
+        d = self.client.openDirectory(path)
+        d.addCallback(self._cbOpenList, glob)
+        return d
+
+    def _cbReadFile(self, files, l, directory, glob):
+        """Accumulate directory entries into l until EOFError ends the read."""
+        if not isinstance(files, failure.Failure):
+            if glob:
+                l.extend([f for f in files if fnmatch.fnmatch(f[0], glob)])
+            else:
+                l.extend(files)
+            d = directory.read()
+            d.addBoth(self._cbReadFile, l, directory, glob)
+            return d
+        else:
+            reason = files
+            reason.trap(EOFError)
+            directory.close()
+            return l
+
+    def _abbrevSize(self, size):
+        """Format a byte count with a binary-multiple suffix, e.g. '1.5MB'."""
+        # from http://mail.python.org/pipermail/python-list/1999-December/018395.html
+        _abbrevs = [
+            (1<<50L, 'PB'),
+            (1<<40L, 'TB'),
+            (1<<30L, 'GB'),
+            (1<<20L, 'MB'),
+            (1<<10L, 'kb'),
+            (1, '')
+            ]
+
+        for factor, suffix in _abbrevs:
+            # >= so that exact multiples (e.g. 1024) use the larger unit
+            if size >= factor:
+                break
+        # float() so integer sizes are not truncated by Python 2 division
+        return '%.1f' % (float(size) / factor) + suffix
+
+    def _abbrevTime(self, t):
+        """Format a number of seconds as H:MM:SS (over an hour) or MM:SS."""
+        if t > 3600: # 1 hour
+            hours = int(t / 3600)
+            t -= (3600 * hours)
+            mins = int(t / 60)
+            t -= (60 * mins)
+            return "%i:%02i:%02i" % (hours, mins, t)
+        else:
+            mins = int(t/60)
+            t -= (60 * mins)
+            return "%02i:%02i" % (mins, t)
+
+    def _printProgessBar(self, f, startTime):
+        """Redraw a one-line progress bar for f (which has name/size/total)."""
+        diff = time.time() - startTime
+        total = f.total
+        try:
+            winSize = struct.unpack('4H',
+                fcntl.ioctl(0, tty.TIOCGWINSZ, '12345679'))
+        except IOError:
+            winSize = [None, 80]
+        # Guard the divisions: the first chunk can finish with no measurable
+        # elapsed time, and an empty file has size 0.
+        if diff:
+            speed = total / diff
+        else:
+            speed = 0.0
+        if speed:
+            timeLeft = (f.size - total) / speed
+        else:
+            timeLeft = 0
+        if f.size:
+            percentage = (total / f.size) * 100
+        else:
+            percentage = 100
+        front = f.name
+        back = '%3i%% %s %sps %s ' % (percentage, self._abbrevSize(total),
+                self._abbrevSize(speed), self._abbrevTime(timeLeft))
+        spaces = (winSize[1] - (len(front) + len(back) + 1)) * ' '
+        self.transport.write('\r%s%s%s' % (front, spaces, back))
+
+    def _getFilename(self, line):
+        """Split the first (optionally quoted) filename off line.
+
+        Returns (name, rest) -- rest is a list when produced by split() --
+        or (None, '') for a blank line.  Inside quotes, a backslash may
+        escape only a quote character or a backslash.
+
+        :raises IndexError: on a bad escape or an unterminated quote.
+        """
+        line = line.lstrip()
+        if not line:
+            return None, ''
+        if line[0] in '\'"':
+            quote = line[0]
+            ret = []
+            i = 1
+            while i < len(line):
+                c = line[i]
+                if c == quote:
+                    return ''.join(ret), line[i+1:].lstrip()
+                if c == '\\': # quoted character
+                    i += 1
+                    if i >= len(line):
+                        break
+                    if line[i] not in '\'"\\':
+                        raise IndexError("bad quote: \\%s" % line[i])
+                ret.append(line[i])
+                i += 1
+            raise IndexError("unterminated quote")
+        ret = line.split(None, 1)
+        if len(ret) == 1:
+            return ret[0], ''
+        else:
+            return ret
+
+
+class SSHConnection(connection.SSHConnection):
+    """SSH connection service that opens a single SFTP session channel.
+
+    :param client: object handed to the session as ``_client`` so it can be
+        told when the SFTP subsystem is ready or has failed.
+    """
+
+    def __init__(self, client):
+        self._client = client
+        connection.SSHConnection.__init__(self)
+
+    def serviceStarted(self):
+        # The connection service is up: open the session channel that will
+        # request the sftp subsystem on behalf of our client.
+        sftpSession = SSHSession()
+        sftpSession._client = self._client
+        self.openChannel(sftpSession)
+
+
+class SSHSession(channel.SSHChannel):
+    """SSH session channel that starts the remote ``sftp`` subsystem.
+
+    ``_client`` is assigned by SSHConnection.serviceStarted.  When the
+    subsystem request succeeds, a FileTransferClient is attached to this
+    channel and handed to ``_client.fileTransferClientConnected``; if the
+    request or that setup fails, ``_client.fileTransferError`` is called.
+    """
+
+    name = 'session'
+
+    def channelOpen(self, foo):
+        log.msg('session %s open' % self.id)
+        request = 'subsystem'
+        d = self.conn.sendRequest(self, request,
+                                  common.NS('sftp'), wantReply=1)
+        d.addCallback(self._cbSubsystem)
+        d.addCallback(self._cbTellClient)
+        # Report a refused subsystem request, or a failure in either
+        # callback above, to the owning transport.
+        d.addErrback(self._cbTellClientError)
+
+    def closed(self):
+        print "closed session, should inform _client"
+        channel.SSHChannel.closed(self)
+
+    def _cbSubsystem(self, result):
+        # From now on, all channel data is fed to the SFTP protocol client.
+        self.client = filetransfer.FileTransferClient()
+        self.client.makeConnection(self)
+        self.dataReceived = self.client.dataReceived
+
+    def _cbTellClientError(self, result):
+        self._client.fileTransferError(result)
+
+    def _cbTellClient(self, result):
+        self._client.fileTransferClientConnected(self.client)
+
+    def extReceived(self, t, data):
+        # Relay the server's stderr stream to our own stderr.
+        if t==connection.EXTENDED_DATA_STDERR:
+            log.msg('got %s stderr data' % len(data))
+            sys.stderr.write(data)
+            sys.stderr.flush()
+
+    def eofReceived(self):
+        log.msg('got eof')
+
+    def closeReceived(self):
+        log.msg('remote side closed %s' % self)
+        self.conn.sendClose(self)
=== modified file 'twisted/conch/client/default.py'
--- a/twisted/conch/client/default.py 2005-10-26 09:06:33 +0000
+++ b/twisted/conch/client/default.py 2009-08-15 07:27:38 +0000
@@ -44,9 +44,10 @@
return defer.fail(ConchError('bad host key'))
print "Warning: Permanently added '%s' (%s) to the list of known hosts." % (khHost, {'ssh-dss':'DSA', 'ssh-rsa':'RSA'}[keyType])
known_hosts = open(os.path.expanduser('~/.ssh/known_hosts'), 'a+')
- known_hosts.seek(-1, 2)
- if known_hosts.read(1) != '\n':
- known_hosts.write('\n')
+ if known_hosts.tell() > 1:
+ known_hosts.seek(-1, 2)
+ if known_hosts.read(1) != '\n':
+ known_hosts.write('\n')
encodedKey = base64.encodestring(pubKey).replace('\n', '')
known_hosts.write('%s %s %s\n' % (khHost, keyType, encodedKey))
known_hosts.close()
More information about the bazaar-commits
mailing list