Rev 5299: (spiv) Use a socketpair to talk to the SSH client process. (Andrew Bennetts) in file:///home/pqm/archives/thelove/bzr/%2Btrunk/

Canonical.com Patch Queue Manager pqm at pqm.ubuntu.com
Thu Jun 17 03:06:25 BST 2010


At file:///home/pqm/archives/thelove/bzr/%2Btrunk/

------------------------------------------------------------
revno: 5299 [merge]
revision-id: pqm at pqm.ubuntu.com-20100617020620-8if212a9t349vbku
parent: pqm at pqm.ubuntu.com-20100616155835-w1l604cbw3du1jwr
parent: andrew.bennetts at canonical.com-20100616082019-juv3jrcr242oik98
committer: Canonical.com Patch Queue Manager <pqm at pqm.ubuntu.com>
branch nick: +trunk
timestamp: Thu 2010-06-17 03:06:20 +0100
message:
  (spiv) Use a socketpair to talk to the SSH client process. (Andrew Bennetts)
modified:
  NEWS                           NEWS-20050323055033-4e00b5db738777ff
  bzrlib/smart/medium.py         medium.py-20061103051856-rgu2huy59fkz902q-1
  bzrlib/tests/test_smart_transport.py test_ssh_transport.py-20060608202016-c25gvf1ob7ypbus6-2
  bzrlib/transport/remote.py     ssh.py-20060608202016-c25gvf1ob7ypbus6-1
  bzrlib/transport/ssh.py        ssh.py-20060824042150-0s9787kng6zv1nwq-1
=== modified file 'NEWS'
--- a/NEWS	2010-06-16 12:47:51 +0000
+++ b/NEWS	2010-06-17 02:06:20 +0000
@@ -69,6 +69,10 @@
 Improvements
 ************
 
+* Bazaar now reads data from SSH connections more efficiently on platforms
+  that provide the ``socketpair`` function, and when using paramiko.
+  (Andrew Bennetts, #590637)
+
 * ``Branch.copy_content_into`` is now a convenience method dispatching to
   a ``InterBranch`` multi-method. This permits ``bzr-loom`` and other
   plugins to intercept this even when a ``RemoteBranch`` proxy is in use.
@@ -97,6 +101,16 @@
   ``bzrlib.patiencediff`` instead.
   (Andrew Bennetts)
 
+* ``bzrlib.transport.ssh.SSHVendor.connect_ssh`` now returns an object
+  that implements the interface of ``bzrlib.transport.ssh.SSHConnection``.
+  Third-party implementations of ``SSHVendor`` may need to be updated
+  accordingly.  Similarly, any code using ``SSHConnection`` directly will
+  need to be updated.  (Andrew Bennetts)
+
+* The constructor of ``bzrlib.smart.medium.SmartSSHClientMedium`` has
+  changed to take an ``SSHParams`` instance (replacing many individual
+  values).  (Andrew Bennetts)
+
 Internals
 *********
 

=== modified file 'bzrlib/smart/medium.py'
--- a/bzrlib/smart/medium.py	2010-05-03 04:08:50 +0000
+++ b/bzrlib/smart/medium.py	2010-06-16 07:45:53 +0000
@@ -715,10 +715,6 @@
     """A client medium using simple pipes.
 
     This client does not manage the pipes: it assumes they will always be open.
-
-    Note that if readable_pipe.read might raise IOError or OSError with errno
-    of EINTR, it must be safe to retry the read.  Plain CPython fileobjects
-    (such as used for sys.stdin) are safe.
     """
 
     def __init__(self, readable_pipe, writeable_pipe, base):
@@ -737,26 +733,40 @@
 
     def _read_bytes(self, count):
         """See SmartClientStreamMedium._read_bytes."""
-        bytes = osutils.until_no_eintr(self._readable_pipe.read, count)
+        bytes_to_read = min(count, _MAX_READ_SIZE)
+        bytes = self._readable_pipe.read(bytes_to_read)
         self._report_activity(len(bytes), 'read')
         return bytes
 
 
+class SSHParams(object):
+    """A set of parameters for starting a remote bzr via SSH."""
+
+    def __init__(self, host, port=None, username=None, password=None,
+            bzr_remote_path='bzr'):
+        self.host = host
+        self.port = port
+        self.username = username
+        self.password = password
+        self.bzr_remote_path = bzr_remote_path
+
+
 class SmartSSHClientMedium(SmartClientStreamMedium):
-    """A client medium using SSH."""
+    """A client medium using SSH.
+    
+    It delegates IO to a SmartClientSocketMedium or
+    SmartClientAlreadyConnectedSocketMedium (depending on platform).
+    """
 
-    def __init__(self, host, port=None, username=None, password=None,
-            base=None, vendor=None, bzr_remote_path=None):
+    def __init__(self, base, ssh_params, vendor=None):
         """Creates a client that will connect on the first use.
 
+        :param ssh_params: A SSHParams instance.
         :param vendor: An optional override for the ssh vendor to use. See
             bzrlib.transport.ssh for details on ssh vendors.
         """
-        self._connected = False
-        self._host = host
-        self._password = password
-        self._port = port
-        self._username = username
+        self._real_medium = None
+        self._ssh_params = ssh_params
         # for the benefit of progress making a short description of this
         # transport
         self._scheme = 'bzr+ssh'
@@ -764,67 +774,70 @@
         # _DebugCounter so we have to store all the values used in our repr
         # method before calling the super init.
         SmartClientStreamMedium.__init__(self, base)
-        self._read_from = None
+        self._vendor = vendor
         self._ssh_connection = None
-        self._vendor = vendor
-        self._write_to = None
-        self._bzr_remote_path = bzr_remote_path
 
     def __repr__(self):
-        if self._port is None:
+        if self._ssh_params.port is None:
             maybe_port = ''
         else:
-            maybe_port = ':%s' % self._port
+            maybe_port = ':%s' % self._ssh_params.port
         return "%s(%s://%s@%s%s/)" % (
             self.__class__.__name__,
             self._scheme,
-            self._username,
-            self._host,
+            self._ssh_params.username,
+            self._ssh_params.host,
             maybe_port)
 
     def _accept_bytes(self, bytes):
         """See SmartClientStreamMedium.accept_bytes."""
         self._ensure_connection()
-        self._write_to.write(bytes)
-        self._report_activity(len(bytes), 'write')
+        self._real_medium.accept_bytes(bytes)
 
     def disconnect(self):
         """See SmartClientMedium.disconnect()."""
-        if not self._connected:
-            return
-        self._read_from.close()
-        self._write_to.close()
-        self._ssh_connection.close()
-        self._connected = False
+        if self._real_medium is not None:
+            self._real_medium.disconnect()
+            self._real_medium = None
+        if self._ssh_connection is not None:
+            self._ssh_connection.close()
+            self._ssh_connection = None
 
     def _ensure_connection(self):
         """Connect this medium if not already connected."""
-        if self._connected:
+        if self._real_medium is not None:
             return
         if self._vendor is None:
             vendor = ssh._get_ssh_vendor()
         else:
             vendor = self._vendor
-        self._ssh_connection = vendor.connect_ssh(self._username,
-                self._password, self._host, self._port,
-                command=[self._bzr_remote_path, 'serve', '--inet',
+        self._ssh_connection = vendor.connect_ssh(self._ssh_params.username,
+                self._ssh_params.password, self._ssh_params.host,
+                self._ssh_params.port,
+                command=[self._ssh_params.bzr_remote_path, 'serve', '--inet',
                          '--directory=/', '--allow-writes'])
-        self._read_from, self._write_to = \
-            self._ssh_connection.get_filelike_channels()
-        self._connected = True
+        io_kind, io_object = self._ssh_connection.get_sock_or_pipes()
+        if io_kind == 'socket':
+            self._real_medium = SmartClientAlreadyConnectedSocketMedium(
+                self.base, io_object)
+        elif io_kind == 'pipes':
+            read_from, write_to = io_object
+            self._real_medium = SmartSimplePipesClientMedium(
+                read_from, write_to, self.base)
+        else:
+            raise AssertionError(
+                "Unexpected io_kind %r from %r"
+                % (io_kind, self._ssh_connection))
 
     def _flush(self):
         """See SmartClientStreamMedium._flush()."""
-        self._write_to.flush()
+        self._real_medium._flush()
 
     def _read_bytes(self, count):
         """See SmartClientStreamMedium.read_bytes."""
-        if not self._connected:
+        if self._real_medium is None:
             raise errors.MediumNotConnected(self)
-        bytes_to_read = min(count, _MAX_READ_SIZE)
-        bytes = self._read_from.read(bytes_to_read)
-        self._report_activity(len(bytes), 'read')
-        return bytes
+        return self._real_medium.read_bytes(count)
 
 
 # Port 4155 is the default port for bzr://, registered with IANA.
@@ -832,22 +845,41 @@
 BZR_DEFAULT_PORT = 4155
 
 
-class SmartTCPClientMedium(SmartClientStreamMedium):
-    """A client medium using TCP."""
+class SmartClientSocketMedium(SmartClientStreamMedium):
+    """A client medium using a socket.
+    
+    This class isn't usable directly.  Use one of its subclasses instead.
+    """
 
-    def __init__(self, host, port, base):
-        """Creates a client that will connect on the first use."""
+    def __init__(self, base):
         SmartClientStreamMedium.__init__(self, base)
+        self._socket = None
         self._connected = False
-        self._host = host
-        self._port = port
-        self._socket = None
 
     def _accept_bytes(self, bytes):
         """See SmartClientMedium.accept_bytes."""
         self._ensure_connection()
         osutils.send_all(self._socket, bytes, self._report_activity)
 
+    def _ensure_connection(self):
+        """Connect this medium if not already connected."""
+        raise NotImplementedError(self._ensure_connection)
+
+    def _flush(self):
+        """See SmartClientStreamMedium._flush().
+
+        For sockets we do no flushing. For TCP sockets we may want to turn off
+        TCP_NODELAY and add a means to do a flush, but that can be done in the
+        future.
+        """
+
+    def _read_bytes(self, count):
+        """See SmartClientMedium.read_bytes."""
+        if not self._connected:
+            raise errors.MediumNotConnected(self)
+        return osutils.read_bytes_from_socket(
+            self._socket, self._report_activity)
+
     def disconnect(self):
         """See SmartClientMedium.disconnect()."""
         if not self._connected:
@@ -856,6 +888,16 @@
         self._socket = None
         self._connected = False
 
+
+class SmartTCPClientMedium(SmartClientSocketMedium):
+    """A client medium that creates a TCP connection."""
+
+    def __init__(self, host, port, base):
+        """Creates a client that will connect on the first use."""
+        SmartClientSocketMedium.__init__(self, base)
+        self._host = host
+        self._port = port
+
     def _ensure_connection(self):
         """Connect this medium if not already connected."""
         if self._connected:
@@ -895,19 +937,22 @@
                     (self._host, port, err_msg))
         self._connected = True
 
-    def _flush(self):
-        """See SmartClientStreamMedium._flush().
-
-        For TCP we do no flushing. We may want to turn off TCP_NODELAY and
-        add a means to do a flush, but that can be done in the future.
-        """
-
-    def _read_bytes(self, count):
-        """See SmartClientMedium.read_bytes."""
-        if not self._connected:
-            raise errors.MediumNotConnected(self)
-        return osutils.read_bytes_from_socket(
-            self._socket, self._report_activity)
+
+class SmartClientAlreadyConnectedSocketMedium(SmartClientSocketMedium):
+    """A client medium for an already connected socket.
+    
+    Note that this class will assume it "owns" the socket, so it will close it
+    when its disconnect method is called.
+    """
+
+    def __init__(self, base, sock):
+        SmartClientSocketMedium.__init__(self, base)
+        self._socket = sock
+        self._connected = True
+
+    def _ensure_connection(self):
+        # Already connected, by definition!  So nothing to do.
+        pass
 
 
 class SmartClientStreamMediumRequest(SmartClientMediumRequest):

=== modified file 'bzrlib/tests/test_smart_transport.py'
--- a/bzrlib/tests/test_smart_transport.py	2010-03-18 23:11:15 +0000
+++ b/bzrlib/tests/test_smart_transport.py	2010-06-16 05:47:02 +0000
@@ -46,6 +46,7 @@
         local,
         memory,
         remote,
+        ssh,
         )
 
 
@@ -63,7 +64,7 @@
         return StringIOSSHConnection(self)
 
 
-class StringIOSSHConnection(object):
+class StringIOSSHConnection(ssh.SSHConnection):
     """A SSH connection that uses StringIO to buffer writes and answer reads."""
 
     def __init__(self, vendor):
@@ -71,9 +72,11 @@
 
     def close(self):
         self.vendor.calls.append(('close', ))
+        self.vendor.read_from.close()
+        self.vendor.write_to.close()
 
-    def get_filelike_channels(self):
-        return self.vendor.read_from, self.vendor.write_to
+    def get_sock_or_pipes(self):
+        return 'pipes', (self.vendor.read_from, self.vendor.write_to)
 
 
 class _InvalidHostnameFeature(tests.Feature):
@@ -243,9 +246,9 @@
         unopened_port = sock.getsockname()[1]
         # having vendor be invalid means that if it tries to connect via the
         # vendor it will blow up.
-        client_medium = medium.SmartSSHClientMedium('127.0.0.1', unopened_port,
-            username=None, password=None, base='base', vendor="not a vendor",
-            bzr_remote_path='bzr')
+        ssh_params = medium.SSHParams('127.0.0.1', unopened_port, None, None)
+        client_medium = medium.SmartSSHClientMedium(
+            'base', ssh_params, "not a vendor")
         sock.close()
 
     def test_ssh_client_connects_on_first_use(self):
@@ -253,9 +256,9 @@
         # it bytes.
         output = StringIO()
         vendor = StringIOSSHVendor(StringIO(), output)
-        client_medium = medium.SmartSSHClientMedium(
-            'a hostname', 'a port', 'a username', 'a password', 'base', vendor,
-            'bzr')
+        ssh_params = medium.SSHParams(
+            'a hostname', 'a port', 'a username', 'a password', 'bzr')
+        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
         client_medium._accept_bytes('abc')
         self.assertEqual('abc', output.getvalue())
         self.assertEqual([('connect_ssh', 'a username', 'a password',
@@ -268,8 +271,10 @@
         # it bytes.
         output = StringIO()
         vendor = StringIOSSHVendor(StringIO(), output)
-        client_medium = medium.SmartSSHClientMedium('a hostname', 'a port',
-            'a username', 'a password', 'base', vendor, bzr_remote_path='fugly')
+        ssh_params = medium.SSHParams(
+            'a hostname', 'a port', 'a username', 'a password',
+            bzr_remote_path='fugly')
+        client_medium = medium.SmartSSHClientMedium('base', ssh_params, vendor)
         client_medium._accept_bytes('abc')
         self.assertEqual('abc', output.getvalue())
         self.assertEqual([('connect_ssh', 'a username', 'a password',
@@ -284,7 +289,7 @@
         output = StringIO()
         vendor = StringIOSSHVendor(input, output)
         client_medium = medium.SmartSSHClientMedium(
-            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
+            'base', medium.SSHParams('a hostname'), vendor)
         client_medium._accept_bytes('abc')
         client_medium.disconnect()
         self.assertTrue(input.closed)
@@ -305,7 +310,7 @@
         output = StringIO()
         vendor = StringIOSSHVendor(input, output)
         client_medium = medium.SmartSSHClientMedium(
-            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
+            'base', medium.SSHParams('a hostname'), vendor)
         client_medium._accept_bytes('abc')
         client_medium.disconnect()
         # the disconnect has closed output, so we need a new output for the
@@ -334,14 +339,14 @@
         # Doing a disconnect on a new (and thus unconnected) SSH medium
         # does not fail.  It's ok to disconnect an unconnected medium.
         client_medium = medium.SmartSSHClientMedium(
-            None, base='base', bzr_remote_path='bzr')
+            'base', medium.SSHParams(None))
         client_medium.disconnect()
 
     def test_ssh_client_raises_on_read_when_not_connected(self):
         # Doing a read on a new (and thus unconnected) SSH medium raises
         # MediumNotConnected.
         client_medium = medium.SmartSSHClientMedium(
-            None, base='base', bzr_remote_path='bzr')
+            'base', medium.SSHParams(None))
         self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
                           0)
         self.assertRaises(errors.MediumNotConnected, client_medium.read_bytes,
@@ -359,7 +364,7 @@
         output.flush = logging_flush
         vendor = StringIOSSHVendor(input, output)
         client_medium = medium.SmartSSHClientMedium(
-            'a hostname', base='base', vendor=vendor, bzr_remote_path='bzr')
+            'base', medium.SSHParams('a hostname'), vendor=vendor)
         # this call is here to ensure we only flush once, not on every
         # _accept_bytes call.
         client_medium._accept_bytes('abc')

=== modified file 'bzrlib/transport/remote.py'
--- a/bzrlib/transport/remote.py	2010-02-09 20:28:26 +0000
+++ b/bzrlib/transport/remote.py	2010-06-16 05:47:02 +0000
@@ -514,9 +514,9 @@
         if user is None:
             auth = config.AuthenticationConfig()
             user = auth.get_user('ssh', self._host, self._port)
-        client_medium = medium.SmartSSHClientMedium(self._host, self._port,
-            user, self._password, self.base,
-            bzr_remote_path=bzr_remote_path)
+        ssh_params = medium.SSHParams(self._host, self._port, user,
+            self._password, bzr_remote_path)
+        client_medium = medium.SmartSSHClientMedium(self.base, ssh_params)
         return client_medium, (user, self._password)
 
 

=== modified file 'bzrlib/transport/ssh.py'
--- a/bzrlib/transport/ssh.py	2010-03-02 06:41:36 +0000
+++ b/bzrlib/transport/ssh.py	2010-06-16 06:21:34 +0000
@@ -239,8 +239,7 @@
     def connect_ssh(self, username, password, host, port, command):
         """Make an SSH connection.
 
-        :returns: something with a `close` method, and a `get_filelike_channels`
-            method that returns a pair of (read, write) filelike objects.
+        :returns: an SSHConnection.
         """
         raise NotImplementedError(self.connect_ssh)
 
@@ -269,17 +268,6 @@
 register_ssh_vendor('loopback', LoopbackVendor())
 
 
-class _ParamikoSSHConnection(object):
-    def __init__(self, channel):
-        self.channel = channel
-
-    def get_filelike_channels(self):
-        return self.channel.makefile('rb'), self.channel.makefile('wb')
-
-    def close(self):
-        return self.channel.close()
-
-
 class ParamikoVendor(SSHVendor):
     """Vendor that uses paramiko."""
 
@@ -363,11 +351,23 @@
     """Abstract base class for vendors that use pipes to a subprocess."""
 
     def _connect(self, argv):
-        proc = subprocess.Popen(argv,
-                                stdin=subprocess.PIPE,
-                                stdout=subprocess.PIPE,
+        # Attempt to make a socketpair to use as stdin/stdout for the SSH
+        # subprocess.  We prefer sockets to pipes because they support
+        # non-blocking short reads, allowing us to optimistically read 64k (or
+        # whatever) chunks.
+        try:
+            my_sock, subproc_sock = socket.socketpair()
+        except (AttributeError, socket.error):
+            # This platform doesn't support socketpair(), so just use ordinary
+            # pipes instead.
+            stdin = stdout = subprocess.PIPE
+            sock = None
+        else:
+            stdin = stdout = subproc_sock
+            sock = my_sock
+        proc = subprocess.Popen(argv, stdin=stdin, stdout=stdout,
                                 **os_specific_subprocess_params())
-        return SSHSubprocess(proc)
+        return SSHSubprocessConnection(proc, sock=sock)
 
     def connect_sftp(self, username, password, host, port):
         try:
@@ -652,11 +652,40 @@
             pass
 
 
-class SSHSubprocess(object):
-    """A socket-like object that talks to an ssh subprocess via pipes."""
-
-    def __init__(self, proc):
+class SSHConnection(object):
+    """Abstract base class for SSH connections."""
+
+    def get_sock_or_pipes(self):
+        """Returns a (kind, io_object) pair.
+
+        If kind == 'socket', then io_object is a socket.
+
+        If kind == 'pipes', then io_object is a pair of file-like objects
+        (read_from, write_to).
+        """
+        raise NotImplementedError(self.get_sock_or_pipes)
+
+    def close(self):
+        raise NotImplementedError(self.close)
+
+
+class SSHSubprocessConnection(SSHConnection):
+    """A connection to an ssh subprocess via pipes or a socket.
+
+    This class is also socket-like enough to be used with
+    SocketAsChannelAdapter (it has 'send' and 'recv' methods).
+    """
+
+    def __init__(self, proc, sock=None):
+        """Constructor.
+
+        :param proc: a subprocess.Popen
+        :param sock: if proc.stdin/out is a socket from a socketpair, then sock
+            should be bzrlib's half of that socketpair.  If not passed, proc's
+            stdin/out is assumed to be ordinary pipes.
+        """
         self.proc = proc
+        self._sock = sock
         # Add a weakref to proc that will attempt to do the same as self.close
         # to avoid leaving processes lingering indefinitely.
         def terminate(ref):
@@ -665,14 +694,37 @@
         _subproc_weakrefs.add(weakref.ref(self, terminate))
 
     def send(self, data):
-        return os.write(self.proc.stdin.fileno(), data)
+        if self._sock is not None:
+            return self._sock.send(data)
+        else:
+            return os.write(self.proc.stdin.fileno(), data)
 
     def recv(self, count):
-        return os.read(self.proc.stdout.fileno(), count)
+        if self._sock is not None:
+            return self._sock.recv(count)  # sockets have no .read()
+        else:
+            return os.read(self.proc.stdout.fileno(), count)
 
     def close(self):
         _close_ssh_proc(self.proc)
 
-    def get_filelike_channels(self):
-        return (self.proc.stdout, self.proc.stdin)
+    def get_sock_or_pipes(self):
+        if self._sock is not None:
+            return 'socket', self._sock
+        else:
+            return 'pipes', (self.proc.stdout, self.proc.stdin)
+
+
+class _ParamikoSSHConnection(SSHConnection):
+    """An SSH connection via paramiko."""
+
+    def __init__(self, channel):
+        self.channel = channel
+
+    def get_sock_or_pipes(self):
+        return ('socket', self.channel)
+
+    def close(self):
+        return self.channel.close()
+
 




More information about the bazaar-commits mailing list