=== modified file 'bzrlib/transport/sftp.py' --- bzrlib/transport/sftp.py 2006-06-12 14:39:27 +0000 +++ bzrlib/transport/sftp.py 2006-06-17 07:03:02 +0000 @@ -22,6 +22,7 @@ import os import random import re +import select import stat import subprocess import sys @@ -302,6 +303,7 @@ # What specific errors should we catch here? pass + class SFTPTransport (Transport): """ Transport implementation for SFTP access. @@ -887,9 +889,9 @@ nvuQES5C9BMHjF39LZiGH1iLQy7FgdHyoP+eodI7 -----END RSA PRIVATE KEY----- """ - - -class SingleListener(threading.Thread): + + +class SocketListener(threading.Thread): def __init__(self, callback): threading.Thread.__init__(self) @@ -899,25 +901,35 @@ self._socket.bind(('localhost', 0)) self._socket.listen(1) self.port = self._socket.getsockname()[1] - self.stop_event = threading.Event() - - def run(self): - s, _ = self._socket.accept() - # now close the listen socket - self._socket.close() - try: - self._callback(s, self.stop_event) - except socket.error: - pass #Ignore socket errors - except Exception, x: - # probably a failed test - warning('Exception from within unit test server thread: %r' % x) + self._stop_event = threading.Event() def stop(self): - self.stop_event.set() + # called from outside this thread + self._stop_event.set() # use a timeout here, because if the test fails, the server thread may # never notice the stop_event. self.join(5.0) + self._socket.close() + + def run(self): + while True: + readable, writable_unused, exception_unused = select.select([self._socket], [], [], 0.1) + if self._stop_event.isSet(): + return + if len(readable) == 0: + continue + try: + s, addr_unused = self._socket.accept() + # because the loopback socket is inline, and transports are + # never explicitly closed, best to launch a new thread. + threading.Thread(target=self._callback, args=(s,)).start() + except socket.error, x: + sys.excepthook(*sys.exc_info()) + warning('Socket error during accept() within unit test server thread: %r' % x) + except Exception, x: + # probably a failed test; unit test thread will log the failure/error + sys.excepthook(*sys.exc_info()) + warning('Exception from within unit test server thread: %r' % x) class SFTPServer(Server): @@ -941,10 +953,12 @@ """StubServer uses this to log when a new server is created.""" self.logs.append(message) - def _run_server(self, s, stop_event): + def _run_server(self, s): ssh_server = paramiko.Transport(s) key_file = os.path.join(self._homedir, 'test_rsa.key') - file(key_file, 'w').write(STUB_SERVER_KEY) + f = open(key_file, 'w') + f.write(STUB_SERVER_KEY) + f.close() host_key = paramiko.RSAKey.from_private_key_file(key_file) ssh_server.add_server_key(host_key) server = StubServer(self) @@ -954,7 +968,6 @@ event = threading.Event() ssh_server.start_server(event, server) event.wait(5.0) - stop_event.wait(30.0) def setUp(self): global _ssh_vendor @@ -965,7 +978,7 @@ self._server_homedir = self._homedir self._root = '/' # FIXME WINDOWS: _root should be _server_homedir[0]:/ - self._listener = SingleListener(self._run_server) + self._listener = SocketListener(self._run_server) self._listener.setDaemon(True) self._listener.start() @@ -998,7 +1011,7 @@ super(SFTPServerWithoutSSH, self).__init__() self._vendor = 'loopback' - def _run_server(self, sock, stop_event): + def _run_server(self, sock): class FakeChannel(object): def get_transport(self): return self