xpra icon
Bug tracker and wiki

Ticket #2139: cythonize-protocol.patch

File cythonize-protocol.patch, 59.2 KB (added by Antoine Martin, 16 months ago)

Attempt to speed up the protocol layer by compiling it with Cython.

  • setup.py

     
    23322332                **v4l2_pkgconfig))
    23332333
    23342334
     2335if False:
     2336    protocol_pkgconfig = pkgconfig(optimize=3)
     2337    cython_add(Extension("xpra.net.protocol",
     2338                ["xpra/net/protocol.pyx"],
     2339                **protocol_pkgconfig))
     2340
    23352341toggle_packages(bencode_ENABLED, "xpra.net.bencode")
    23362342toggle_packages(bencode_ENABLED and cython_bencode_ENABLED, "xpra.net.bencode.cython_bencode")
    23372343if cython_bencode_ENABLED:
  • unittests/unit/net/protocol_test.py

     
    7676    def do_read_parse_thread_loop(self):
    7777        with self.profiling_context("read-parse-thread"):
    7878            Protocol.do_read_parse_thread_loop(self)
    79        
    8079
    8180
    8281class ProtocolTest(unittest.TestCase):
     
    8988            protocol_class = Protocol
    9089        p = protocol_class(glib, conn, process_packet_cb, get_packet_cb=get_packet_cb)
    9190        #p = Protocol(glib, conn, process_packet_cb, get_packet_cb=get_packet_cb)
    92         p.read_buffer_size = read_buffer_size
    93         p.hangup_delay = hangup_delay
     91        p.set_read_buffer_size(read_buffer_size)
     92        p.set_hangup_delay(hangup_delay)
    9493        return p
    9594
    9695    def test_invalid_data(self):
     
    101100        errs = []
    102101        protocol = self.make_memory_protocol(data)
    103102        def check_failed():
    104             if not protocol._closed:
     103            if not protocol.is_closed():
    105104                errs.append("protocol not closed")
    106             if protocol.input_packetcount>0:
     105            if protocol.get_input_packetcount()>0:
    107106                errs.append("processed %i packets" % protocol.input_packetcount)
    108             if protocol.input_raw_packetcount==0:
     107            if protocol.get_input_packetcount()==0:
    109108                errs.append("not read any raw packets")
    110109        loop = glib.MainLoop()
    111110        glib.timeout_add(200, check_failed)
     
    188187
    189188    def test_format_thread(self):
    190189        packets = self.make_test_packets()
    191         N = 500
     190        N = 5000
    192191        many = self.repeat_list(packets, N)
    193192        def get_packet_cb():
    194193            #log.info("get_packet_cb")
     
    202201            if packet[0]==Protocol.CONNECTION_LOST:
    203202                glib.timeout_add(1000, loop.quit)
    204203        protocol = self.make_memory_protocol(None, process_packet_cb=process_packet_cb, get_packet_cb=get_packet_cb)
    205         conn = protocol._conn
     204        conn = protocol.get_connection()
    206205        loop = glib.MainLoop()
    207206        glib.timeout_add(TIMEOUT*1000, loop.quit)
    208207        start = monotonic_time()
     
    211210        protocol.start()
    212211        protocol.source_has_more()
    213212        loop.run()
    214         assert protocol._closed
     213        assert protocol.is_closed()
    215214        end = monotonic_time()
    216215        log("protocol: %s", protocol)
    217216        log("%s write-data=%s", conn, len(conn.write_data))
  • xpra/net/protocol.py

     
    150150            self._process_packet_cb = process_packet_cb
    151151        self.make_chunk_header = self.make_xpra_header
    152152        self.make_frame_header = self.noframe_header
    153         self._write_queue = Queue(1)
     153        self._write_queue = Queue(2)
    154154        self._read_queue = Queue(20)
    155155        self._process_read = self.read_queue_put
    156156        self._read_queue_put = self.read_queue_put
     
    212212        self.enable_encoder(self.encoder)
    213213
    214214
     215    def set_read_buffer_size(self, n):
     216        self.read_buffer_size = n
     217    def set_hangup_delay(self, n):
     218        self.hangup_delay = n
     219    def get_connection(self):
     220        return self._conn
     221
     222    def get_input_packetcount(self):
     223        return self.input_packetcount
     224
     225    def is_closed(self):
     226        return self._closed
     227
     228
    215229    def wait_for_io_threads_exit(self, timeout=None):
    216230        io_threads = [x for x in (self._read_thread, self._write_thread) if x is not None]
    217231        for t in io_threads:
     
    565579                        #so we must tell it how to do that and pass the level flag
    566580                        il = item.level
    567581                    packets.append((0, i, il, item.data))
    568                     packet[i] = ''
     582                    packet[i] = b''
    569583                else:
    570584                    #data is small enough, inline it:
    571585                    packet[i] = item.data
     
    685699    def write_buffers(self, buf_data, _fail_cb, _synchronous):
    686700        con = self._conn
    687701        if not con:
    688             return 0
     702            return
    689703        for buf in buf_data:
    690704            while buf and not self._closed:
    691705                written = con.write(buf)
  • xpra/net/protocol.pyx

     
     1# This file is part of Xpra.
     2# Copyright (C) 2011-2019 Antoine Martin <antoine@xpra.org>
     3# Copyright (C) 2008, 2009, 2010 Nathaniel Smith <njs@pobox.com>
     4# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
     5# later version. See the file COPYING for details.
     6
     7# oh gods it's threads
     8
     9# but it works on win32, for whatever that's worth.
     10
     11import os
     12from socket import error as socket_error
     13from threading import Lock, Event
     14
     15from xpra.os_util import PYTHON3, Queue, memoryview_to_bytes, strtobytes, bytestostr, hexstr
     16from xpra.util import repr_ellipsized, csv, envint, envbool
     17from xpra.make_thread import make_thread, start_thread
     18from xpra.net.common import ConnectionClosedException          #@UndefinedVariable (pydev false positive)
     19from xpra.net.bytestreams import ABORT
     20from xpra.net import compression
     21from xpra.net.compression import (
     22    decompress, sanity_checks as compression_sanity_checks,
     23    InvalidCompressionException, Compressed, LevelCompressed, Compressible, LargeStructure,
     24    )
     25from xpra.net import packet_encoding
     26from xpra.net.packet_encoding import (
     27    decode, sanity_checks as packet_encoding_sanity_checks,
     28    InvalidPacketEncodingException,
     29    )
     30from xpra.net.header import unpack_header, pack_header, FLAGS_CIPHER, FLAGS_NOHEADER, HEADER_SIZE
     31from xpra.net.crypto import get_encryptor, get_decryptor, pad, INITIAL_PADDING
     32from xpra.log import Logger
     33
     34log = Logger("network", "protocol")
     35cryptolog = Logger("network", "crypto")
     36
     37
     38#stupid python version breakage:
     39JOIN_TYPES = (str, bytes)
     40if PYTHON3:
     41    long = int              #@ReservedAssignment
     42    unicode = str           #@ReservedAssignment
     43    JOIN_TYPES = (bytes, )
     44
     45
     46USE_ALIASES = envbool("XPRA_USE_ALIASES", True)
     47READ_BUFFER_SIZE = envint("XPRA_READ_BUFFER_SIZE", 65536)
     48#merge header and packet if packet is smaller than:
     49PACKET_JOIN_SIZE = envint("XPRA_PACKET_JOIN_SIZE", READ_BUFFER_SIZE)
     50LARGE_PACKET_SIZE = envint("XPRA_LARGE_PACKET_SIZE", 4096)
     51LOG_RAW_PACKET_SIZE = envbool("XPRA_LOG_RAW_PACKET_SIZE", False)
     52#inline compressed data in packet if smaller than:
     53INLINE_SIZE = envint("XPRA_INLINE_SIZE", 32768)
     54FAKE_JITTER = envint("XPRA_FAKE_JITTER", 0)
     55MIN_COMPRESS_SIZE = envint("XPRA_MIN_COMPRESS_SIZE", 378)
     56SEND_INVALID_PACKET = envint("XPRA_SEND_INVALID_PACKET", 0)
     57SEND_INVALID_PACKET_DATA = strtobytes(os.environ.get("XPRA_SEND_INVALID_PACKET_DATA", b"ZZinvalid-packetZZ"))
     58
     59
     60def sanity_checks():
     61    """ warns the user if important modules are missing """
     62    compression_sanity_checks()
     63    packet_encoding_sanity_checks()
     64
     65
     66def exit_queue():
     67    queue = Queue()
     68    for _ in range(10):     #just 2 should be enough!
     69        queue.put(None)
     70    return queue
     71
     72def force_flush_queue(q):
     73    try:
     74        #discard all elements in the old queue and push the None marker:
     75        try:
     76            while q.qsize()>0:
     77                q.read(False)
     78        except:
     79            pass
     80        q.put_nowait(None)
     81    except:
     82        pass
     83
     84
     85def verify_packet(packet):
     86    """ look for None values which may have caused the packet to fail encoding """
     87    if type(packet)!=list:
     88        return False
     89    assert len(packet)>0, "invalid packet: %s" % packet
     90    tree = ["'%s' packet" % packet[0]]
     91    return do_verify_packet(tree, packet)
     92
     93def do_verify_packet(tree, packet):
     94    def err(msg):
     95        log.error("%s in %s", msg, "->".join(tree))
     96    def new_tree(append):
     97        nt = tree[:]
     98        nt.append(append)
     99        return nt
     100    if packet is None:
     101        err("None value")
     102        return False
     103    r = True
     104    if type(packet) in (list, tuple):
     105        for i, x in enumerate(packet):
     106            if not do_verify_packet(new_tree("[%s]" % i), x):
     107                r = False
     108    elif type(packet)==dict:
     109        for k,v in packet.items():
     110            if not do_verify_packet(new_tree("key for value='%s'" % str(v)), k):
     111                r = False
     112            if not do_verify_packet(new_tree("value for key='%s'" % str(k)), v):
     113                r = False
     114    elif type(packet) in (int, bool, str, bytes, unicode):
     115        pass
     116    else:
     117        err("unsupported type: %s" % type(packet))
     118        r = False
     119    return r
     120
     121
     122cdef class Protocol:
     123    cdef object timeout_add
     124    cdef object idle_add
     125    cdef object source_remove
     126    cdef unsigned int read_buffer_size
     127    cdef unsigned int hangup_delay
     128    cdef object _conn
     129    cdef object _process_packet_cb
     130    cdef object make_chunk_header
     131    cdef object make_frame_header
     132    cdef object _write_queue
     133    cdef object _read_queue
     134    cdef object _process_read
     135    cdef object _read_queue_put
     136    cdef object _get_packet_cb
     137    cdef object input_stats
     138    cdef unsigned int input_packetcount
     139    cdef unsigned int input_raw_packetcount
     140    cdef object output_stats
     141    cdef unsigned int output_packetcount
     142    cdef unsigned int output_raw_packetcount
     143    #initial value which may get increased by client/server after handshake:
     144    cdef unsigned int max_packet_size
     145    cdef unsigned int abs_max_packet_size
     146    cdef object large_packets
     147    cdef object send_aliases
     148    cdef object receive_aliases
     149    cdef object _log_stats
     150    cdef object _closed
     151    cdef object encoder
     152    cdef object _encoder
     153    cdef object compressor
     154    cdef object _compress
     155    cdef unsigned int compression_level
     156    cdef object cipher_in
     157    cdef object cipher_in_name
     158    cdef unsigned int cipher_in_block_size
     159    cdef object cipher_in_padding
     160    cdef object cipher_out
     161    cdef object cipher_out_name
     162    cdef unsigned int cipher_out_block_size
     163    cdef object cipher_out_padding
     164    cdef object _write_lock
     165    cdef object _write_thread
     166    cdef object _read_thread
     167    cdef object _read_parser_thread
     168    cdef object _write_format_thread
     169    cdef object _source_has_more_event
     170   
     171    """
     172        This class handles sending and receiving packets,
     173        it will encode and compress them before sending,
     174        and decompress and decode when receiving.
     175    """
     176
     177    CONNECTION_LOST = "connection-lost"
     178    GIBBERISH = "gibberish"
     179    INVALID = "invalid"
     180
    def __init__(self, scheduler, conn, process_packet_cb, get_packet_cb=None):
        """
            You must call this constructor and source_has_more() from the main thread.

            scheduler: provides timeout_add / idle_add / source_remove
            conn: the underlying connection object
            process_packet_cb: invoked for each parsed incoming packet
            get_packet_cb: optional source callback for outgoing packets
        """
        assert scheduler is not None
        assert conn is not None
        self.timeout_add = scheduler.timeout_add
        self.idle_add = scheduler.idle_add
        self.source_remove = scheduler.source_remove
        self.read_buffer_size = READ_BUFFER_SIZE
        self.hangup_delay = 1000
        self._conn = conn
        if FAKE_JITTER>0:
            #debug feature: wrap the callback to add artificial delays
            from xpra.net.fake_jitter import FakeJitter
            fj = FakeJitter(self.timeout_add, process_packet_cb)
            self._process_packet_cb =  fj.process_packet_cb
        else:
            self._process_packet_cb = process_packet_cb
        self.make_chunk_header = self.make_xpra_header
        self.make_frame_header = self.noframe_header
        #NOTE(review): the pure-python protocol.py hunk in this patch
        #changes this to Queue(2) - confirm whether they should match:
        self._write_queue = Queue(1)
        self._read_queue = Queue(20)
        self._process_read = self.read_queue_put
        self._read_queue_put = self.read_queue_put
        # Invariant: if .source is None, then _source_has_more_event == False
        self._get_packet_cb = get_packet_cb
        #counters:
        self.input_stats = {}
        self.input_packetcount = 0
        self.input_raw_packetcount = 0
        self.output_stats = {}
        self.output_packetcount = 0
        self.output_raw_packetcount = 0
        #initial value which may get increased by client/server after handshake:
        self.max_packet_size = 4*1024*1024
        self.abs_max_packet_size = 256*1024*1024
        #packet types allowed to be larger than LARGE_PACKET_SIZE without a warning:
        self.large_packets = [b"hello", b"window-metadata", b"sound-data", b"notify_show"]
        self.send_aliases = {}
        self.receive_aliases = {}
        self._log_stats = None          #None here means auto-detect
        self._closed = False
        #packet encoder / compressor start disabled,
        #enable_encoder() / enable_compressor() install the real ones:
        self.encoder = "none"
        self._encoder = self.noencode
        self.compressor = "none"
        self._compress = compression.nocompress
        self.compression_level = 0
        self.cipher_in = None
        self.cipher_in_name = None
        self.cipher_in_block_size = 0
        self.cipher_in_padding = INITIAL_PADDING
        self.cipher_out = None
        self.cipher_out_name = None
        self.cipher_out_block_size = 0
        self.cipher_out_padding = INITIAL_PADDING
        self._write_lock = Lock()
        self._write_thread = None
        self._read_thread = make_thread(self._read_thread_loop, "read", daemon=True)
        self._read_parser_thread = None         #started when needed
        self._write_format_thread = None        #started when needed
        self._source_has_more_event = Event()
     241
    #fields copied verbatim by save_state() / restore_state()
    #when a connection is handed over to another Protocol instance:
    STATE_FIELDS = ("max_packet_size", "large_packets", "send_aliases", "receive_aliases",
                    "cipher_in", "cipher_in_name", "cipher_in_block_size", "cipher_in_padding",
                    "cipher_out", "cipher_out_name", "cipher_out_block_size", "cipher_out_padding",
                    "compression_level", "encoder", "compressor")

    def save_state(self):
        """Return a dict snapshot of the fields listed in STATE_FIELDS."""
        state = {}
        for x in Protocol.STATE_FIELDS:
            state[x] = getattr(self, x)
        return state

    def restore_state(self, state):
        """Restore a snapshot created by save_state().
        Raises AssertionError if any expected field is missing."""
        assert state is not None
        for x in Protocol.STATE_FIELDS:
            assert x in state, "field %s is missing" % x
            setattr(self, x, state[x])
        #special handling for compressor / encoder which are named objects:
        #(restoring the names is not enough, the _compress/_encoder
        # function references must be re-resolved from those names)
        self.enable_compressor(self.compressor)
        self.enable_encoder(self.encoder)
     261
     262
    #trivial accessors: unlike the pure-python implementation,
    #the attributes below are cdef fields and therefore not directly
    #readable/writable from python callers (ie: the unit tests),
    #so explicit getter/setter methods are needed:

    def set_read_buffer_size(self, n):
        #how many bytes to request per read() call
        self.read_buffer_size = n
    def set_hangup_delay(self, n):
        #delay before signalling connection-lost
        #(presumably milliseconds - default is 1000 - TODO confirm)
        self.hangup_delay = n
    def get_connection(self):
        return self._conn

    def get_input_packetcount(self):
        #number of packets parsed so far
        return self.input_packetcount

    def is_closed(self):
        return self._closed
     275
    def wait_for_io_threads_exit(self, timeout=None):
        """Join the read and write threads, waiting up to `timeout`
        seconds for each.  Returns True if both have exited,
        False otherwise (logging a warning for each straggler).
        """
        io_threads = [x for x in (self._read_thread, self._write_thread) if x is not None]
        for t in io_threads:
            #is_alive() replaces the deprecated camelCase isAlive()
            #alias, which was removed in Python 3.9
            #(get_info() in this class already uses is_alive())
            if t.is_alive():
                t.join(timeout)
        exited = True
        cinfo = self._conn or "cleared connection"
        for t in io_threads:
            if t.is_alive():
                log.warn("Warning: %s thread of %s is still alive (timeout=%s)", t.name, cinfo, timeout)
                exited = False
        return exited
     288
    def set_packet_source(self, get_packet_cb):
        #callback the format thread uses to fetch the next packet(s)
        self._get_packet_cb = get_packet_cb


    def set_cipher_in(self, ciphername, iv, password, key_salt, iterations, padding):
        """Enable decryption of incoming packet data."""
        cryptolog("set_cipher_in%s", (ciphername, iv, password, key_salt, iterations))
        self.cipher_in, self.cipher_in_block_size = get_decryptor(ciphername, iv, password, key_salt, iterations)
        self.cipher_in_padding = padding
        #only log when the cipher actually changes:
        if self.cipher_in_name!=ciphername:
            cryptolog.info("receiving data using %s encryption", ciphername)
            self.cipher_in_name = ciphername

    def set_cipher_out(self, ciphername, iv, password, key_salt, iterations, padding):
        """Enable encryption of outgoing packet data."""
        cryptolog("set_cipher_out%s", (ciphername, iv, password, key_salt, iterations, padding))
        self.cipher_out, self.cipher_out_block_size = get_encryptor(ciphername, iv, password, key_salt, iterations)
        self.cipher_out_padding = padding
        #only log when the cipher actually changes:
        if self.cipher_out_name!=ciphername:
            cryptolog.info("sending data using %s encryption", ciphername)
            self.cipher_out_name = ciphername
     308
     309
    def __repr__(self):
        return "Protocol(%s)" % self._conn

    def get_threads(self):
        #all threads owned by this protocol instance,
        #skipping the ones that have not been started yet:
        return  [x for x in [self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread] if x is not None]

    def accept(self):
        #hook for subclasses, nothing to do by default
        pass

    def parse_remote_caps(self, caps):
        #NOTE(review): dictget() may return None here, whereas __init__
        #initializes send_aliases to {} - confirm callers handle None
        self.send_aliases = caps.dictget("aliases")
     321
    def get_info(self, alias_info=True):
        """Return a nested dict of diagnostic information:
        packet counters, cipher and compression settings, connection
        details and worker thread liveness.

        alias_info: include the (potentially large) packet alias maps.
        """
        info = {
            "large_packets"         : tuple(bytestostr(x) for x in self.large_packets),
            "compression_level"     : self.compression_level,
            "max_packet_size"       : self.max_packet_size,
            "aliases"               : USE_ALIASES,
            "input" : {
                       "buffer-size"            : self.read_buffer_size,
                       "hangup-delay"           : self.hangup_delay,
                       "packetcount"            : self.input_packetcount,
                       "raw_packetcount"        : self.input_raw_packetcount,
                       "count"                  : self.input_stats,
                       "cipher"                 : {"": self.cipher_in_name or "",
                                                   "padding"        : self.cipher_in_padding,
                                                   },
                        },
            "output" : {
                        "packet-join-size"      : PACKET_JOIN_SIZE,
                        "large-packet-size"     : LARGE_PACKET_SIZE,
                        "inline-size"           : INLINE_SIZE,
                        "min-compress-size"     : MIN_COMPRESS_SIZE,
                        "packetcount"           : self.output_packetcount,
                        "raw_packetcount"       : self.output_raw_packetcount,
                        "count"                 : self.output_stats,
                        "cipher"                : {"": self.cipher_out_name or "",
                                                   "padding" : self.cipher_out_padding
                                                   },
                        },
            }
        c = self._compress
        if c:
            info["compressor"] = compression.get_compressor_name(self._compress)
        e = self._encoder
        if e:
            if self._encoder==self.noencode:
                info["encoder"] = "noencode"
            else:
                info["encoder"] = packet_encoding.get_encoder_name(self._encoder)
        if alias_info:
            info["send_alias"] = self.send_aliases
            info["receive_alias"] = self.receive_aliases
        c = self._conn
        if c:
            try:
                info.update(self._conn.get_info())
            except Exception:
                #was a bare "except:", which would also swallow
                #KeyboardInterrupt / SystemExit:
                log.error("error collecting connection information on %s", self._conn, exc_info=True)
        shm = self._source_has_more_event
        info["has_more"] = shm and shm.is_set()
        for t in (self._write_thread, self._read_thread, self._read_parser_thread, self._write_format_thread):
            if t:
                info.setdefault("thread", {})[t.name] = t.is_alive()
        return info
     375
     376
    def start(self):
        """Schedule the network read thread to start from the main loop.
        (deferred via idle_add, so start() itself may be called before
        the main loop is running)"""
        def start_network_read_thread():
            if not self._closed:
                self._read_thread.start()
        self.idle_add(start_network_read_thread)
        if SEND_INVALID_PACKET:
            #debug / test feature (XPRA_SEND_INVALID_PACKET env var):
            #write raw garbage after the given delay
            self.timeout_add(SEND_INVALID_PACKET*1000, self.raw_write, SEND_INVALID_PACKET_DATA)


    def send_disconnect(self, reasons, done_callback=None):
        #queue a "disconnect" packet with the given reasons,
        #then close the connection once it has been flushed out
        self.flush_then_close(["disconnect"]+list(reasons), done_callback=done_callback)
     388
    def send_now(self, packet):
        """Send a single packet immediately.
        Only valid when no packet source callback is installed:
        a one-shot source is installed temporarily for this packet.
        """
        if self._closed:
            log("send_now(%s ...) connection is closed already, not sending", packet[0])
            return
        log("send_now(%s ...)", packet[0])
        assert self._get_packet_cb==None, "cannot use send_now when a packet source exists! (set to %s)" % self._get_packet_cb
        #one-shot source: hands out this packet exactly once,
        #and clears itself on first use:
        tmp_queue = [packet]
        def packet_cb():
            self._get_packet_cb = None
            if not tmp_queue:
                raise Exception("packet callback used more than once!")
            packet = tmp_queue.pop()
            return (packet, )
        self._get_packet_cb = packet_cb
        self.source_has_more()
     404
    def source_has_more(self):
        """Signal that the packet source has packets ready:
        wakes up - and lazily starts - the write/format thread."""
        shm = self._source_has_more_event
        if not shm or self._closed:
            return
        shm.set()
        #start the format thread:
        if not self._write_format_thread and not self._closed:
            self._write_format_thread = make_thread(self._write_format_thread_loop, "format", daemon=True)
            self._write_format_thread.start()
        #from now on, take shortcut:
        #not compatible with cython..
        #self.source_has_more = self._source_has_more_event.set
     417
    def _write_format_thread_loop(self):
        """Loop of the "format" thread: wait for the source-has-more
        event, fetch packets from the source callback, then encode and
        queue them for the write thread."""
        log("write_format_thread_loop starting")
        try:
            while not self._closed:
                self._source_has_more_event.wait()
                gpc = self._get_packet_cb
                if self._closed or not gpc:
                    return
                #gpc() returns the packet plus optional
                #extra arguments for _add_packet_to_queue:
                self._add_packet_to_queue(*gpc())
        except Exception as e:
            if self._closed:
                #errors are expected during teardown, ignore:
                return
            self._internal_error("error in network packet write/format", e, exc_info=True)
     431
    def _add_packet_to_queue(self, packet, start_send_cb=None, end_send_cb=None, fail_cb=None, synchronous=True, has_more=False, wait_for_more=False):
        """Encode a packet and queue its chunks for the write thread.
        Called from the format thread."""
        if not has_more:
            #no more packets pending from the source:
            #clear the event so the format thread goes back to waiting
            shm = self._source_has_more_event
            if shm:
                shm.clear()
        if packet is None:
            return
        #log("add_packet_to_queue(%s ... %s, %s, %s)", packet[0], synchronous, has_more, wait_for_more)
        packet_type = packet[0]
        chunks = self.encode(packet)
        with self._write_lock:
            if self._closed:
                return
            try:
                self._add_chunks_to_queue(packet_type, chunks, start_send_cb, end_send_cb, fail_cb, synchronous, has_more or wait_for_more)
            except:
                #log the failure, but always re-raise:
                log.error("Error: failed to queue '%s' packet", packet[0])
                log("add_chunks_to_queue%s", (chunks, start_send_cb, end_send_cb, fail_cb), exc_info=True)
                raise
     451
    def _add_chunks_to_queue(self, packet_type, chunks, start_send_cb=None, end_send_cb=None, fail_cb=None, synchronous=True, more=False):
        """ the write_lock must be held when calling this function

        Turns encoded chunks into wire buffers:
        encrypts and pads when a cipher is enabled, prepends the xpra
        chunk header (unless FLAGS_NOHEADER), joins header and small
        payloads into one buffer, optionally adds an outer frame header,
        then hands everything to raw_write().
        """
        items = []
        for proto_flags,index,level,data in chunks:
            payload_size = len(data)
            actual_size = payload_size
            if self.cipher_out:
                proto_flags |= FLAGS_CIPHER
                #note: since we are padding: l!=len(data)
                #NOTE(review): with this formula padding_size is always
                #in 1..block_size, so the padding_size==0 branch below
                #looks unreachable - confirm
                padding_size = self.cipher_out_block_size - (payload_size % self.cipher_out_block_size)
                if padding_size==0:
                    padded = data
                else:
                    # pad byte value is number of padding bytes added
                    padded = memoryview_to_bytes(data) + pad(self.cipher_out_padding, padding_size)
                    actual_size += padding_size
                assert len(padded)==actual_size, "expected padded size to be %i, but got %i" % (len(padded), actual_size)
                data = self.cipher_out.encrypt(padded)
                assert len(data)==actual_size, "expected encrypted size to be %i, but got %i" % (len(data), actual_size)
                cryptolog("sending %s bytes %s encrypted with %s padding", payload_size, self.cipher_out_name, padding_size)
            if proto_flags & FLAGS_NOHEADER:
                assert not self.cipher_out
                #for plain/text packets (ie: gibberish response)
                log("sending %s bytes without header", payload_size)
                items.append(data)
            else:
                #the xpra packet header:
                #(WebSocketProtocol may also add a websocket header too)
                header = self.make_chunk_header(packet_type, proto_flags, level, index, actual_size)
                if actual_size<PACKET_JOIN_SIZE:
                    #small payload: join header and data into one buffer
                    #to avoid two separate writes:
                    if not isinstance(data, JOIN_TYPES):
                        data = memoryview_to_bytes(data)
                    items.append(header+data)
                else:
                    items.append(header)
                    items.append(data)
        #WebSocket header may be added here:
        frame_header = self.make_frame_header(packet_type, items)
        if frame_header:
            item0 = items[0]
            if len(item0)<PACKET_JOIN_SIZE:
                if not isinstance(item0, JOIN_TYPES):
                    item0 = memoryview_to_bytes(item0)
                items[0] = frame_header + item0
            else:
                items.insert(0, frame_header)
        self.raw_write(items, start_send_cb, end_send_cb, fail_cb, synchronous, more)
     499
    def make_xpra_header(self, _packet_type, proto_flags, level, index, payload_size):
        #the packet type is not needed by the plain xpra header format:
        return pack_header(proto_flags, level, index, payload_size)

    def noframe_header(self, _packet_type, _items):
        #no outer framing by default
        #(subclasses may override, ie: to add a websocket frame header)
        return None
     505
     506
    def start_write_thread(self):
        #started lazily, on the first write:
        self._write_thread = start_thread(self._write_thread_loop, "write", daemon=True)

    def raw_write(self, items, start_cb=None, end_cb=None, fail_cb=None, synchronous=True, more=False):
        """ Warning: this bypasses the compression and packet encoder! """
        if self._write_thread is None:
            self.start_write_thread()
        #the write thread consumes these 6-tuples:
        self._write_queue.put((items, start_cb, end_cb, fail_cb, synchronous, more))
     515
     516
     517    def enable_default_encoder(self):
     518        opts = packet_encoding.get_enabled_encoders()
     519        assert len(opts)>0, "no packet encoders available!"
     520        self.enable_encoder(opts[0])
     521
     522    def enable_encoder_from_caps(self, caps):
     523        opts = packet_encoding.get_enabled_encoders(order=packet_encoding.PERFORMANCE_ORDER)
     524        log("enable_encoder_from_caps(..) options=%s", opts)
     525        for e in opts:
     526            if caps.boolget(e, e=="bencode"):
     527                self.enable_encoder(e)
     528                return True
     529        log.error("no matching packet encoder found!")
     530        return False
     531
     532    def enable_encoder(self, e):
     533        self._encoder = packet_encoding.get_encoder(e)
     534        self.encoder = e
     535        log("enable_encoder(%s): %s", e, self._encoder)
     536
     537
     538    def enable_default_compressor(self):
     539        opts = compression.get_enabled_compressors()
     540        if len(opts)>0:
     541            self.enable_compressor(opts[0])
     542        else:
     543            self.enable_compressor("none")
     544
    def enable_compressor_from_caps(self, caps):
        """
        Choose the best compressor advertised by the remote end,
        falling back to "none" when compression is disabled locally
        or no common compressor is found.
        """
        if self.compression_level==0:
            #compression disabled locally:
            self.enable_compressor("none")
            return
        options = compression.get_enabled_compressors(order=compression.PERFORMANCE_ORDER)
        log("enable_compressor_from_caps(..) options=%s", options)
        for name in options:        #ie: [zlib, lz4, lzo]
            if caps.boolget(name):
                self.enable_compressor(name)
                return
        log.warn("compression disabled: no matching compressor found")
        self.enable_compressor("none")
     557
    def enable_compressor(self, compressor):
        """ Switch to the compressor with the given name. """
        compress_fn = compression.get_compressor(compressor)
        self._compress = compress_fn
        self.compressor = compressor
        log("enable_compressor(%s): %s", compressor, compress_fn)
     562
     563
    def noencode(self, data):
        #just send data as a string for clients that don't understand xpra packet format:
        payload = ": ".join(str(x) for x in data)+"\n"
        if PYTHON3:
            #the wire needs bytes on python3:
            import codecs
            if type(payload)!=bytes:
                payload = codecs.latin_1_encode(payload)[0]
        return payload, FLAGS_NOHEADER
     576
     577
    def encode(self, packet_in):
        """
        Given a packet (tuple or list of items), converts it for the wire.
        This method returns all the binary packets to send, as an array of:
        (index, compression_level and compression flags, binary_data)
        The index, if positive indicates the item to populate in the packet
        whose index is zero.
        ie: ["blah", [large binary data], "hello", 200]
        may get converted to:
        [
            (1, compression_level, [large binary data now zlib compressed]),
            (0,                 0, bencoded/rencoded(["blah", '', "hello", 200]))
        ]
        """
        cdef object packets = []
        cdef object packet = list(packet_in)
        cdef unsigned int level = self.compression_level
        cdef unsigned int size_check = LARGE_PACKET_SIZE
        cdef unsigned int min_comp_size = MIN_COMPRESS_SIZE
        cdef unsigned int i, l, cl
        cdef object item
        cdef object ti
        cdef object packet_type
        #first pass: extract / compress the binary items
        #(index 0 is the packet type, skip it):
        for i in range(1, len(packet)):
            item = packet[i]
            if item is None:
                raise TypeError("invalid None value in %s packet at index %s" % (packet[0], i))
            ti = type(item)
            if ti in (int, long, bool, dict, list, tuple):
                #NOTE(review): 'long' relies on Cython mapping it on python3 - verify
                #primitive types are serialized directly by the packet encoder:
                continue
            try:
                l = len(item)
            except TypeError as e:
                raise TypeError("invalid type %s in %s packet at index %s: %s" % (ti, packet[0], i, e))
            if ti==LargeStructure:
                #unwrap and skip the large-item / type warnings below:
                item = item.data
                packet[i] = item
                ti = type(item)
                continue
            elif ti==Compressible:
                #this is a marker used to tell us we should compress it now
                #(used by the client for clipboard data)
                item = item.compress()
                packet[i] = item
                ti = type(item)
                #(it may now be a "Compressed" item and be processed further)
            if ti in (Compressed, LevelCompressed):
                #already compressed data (usually pixels, cursors, etc)
                if not item.can_inline or l>INLINE_SIZE:
                    il = 0
                    if ti==LevelCompressed:
                        #unlike Compressed (usually pixels, decompressed in the paint thread),
                        #LevelCompressed is decompressed by the network layer
                        #so we must tell it how to do that and pass the level flag
                        il = item.level
                    packets.append((0, i, il, item.data))
                    packet[i] = ''
                else:
                    #data is small enough, inline it:
                    packet[i] = item.data
                    min_comp_size += l
                    size_check += l
            elif ti in (str, bytes) and level>0 and l>LARGE_PACKET_SIZE:
                log.warn("Warning: found a large uncompressed item")
                log.warn(" in packet '%s' at position %i: %s bytes", packet[0], i, len(item))
                #add new binary packet with large item:
                cl, cdata = self._compress(item, level)
                packets.append((0, i, cl, cdata))
                #replace this item with an empty string placeholder:
                packet[i] = b''
            elif ti not in (str, bytes):
                log.warn("Warning: unexpected data type %s", ti)
                log.warn(" in '%s' packet at position %i: %s", packet[0], i, repr_ellipsized(item))
        #now the main packet (or what is left of it):
        packet_type = packet[0]
        self.output_stats[packet_type] = self.output_stats.get(packet_type, 0)+1
        if USE_ALIASES and self.send_aliases and packet_type in self.send_aliases:
            #replace the packet type with the alias:
            packet[0] = self.send_aliases[packet_type]
        try:
            main_packet, proto_flags = self._encoder(packet)
        except Exception:
            if self._closed:
                #the failure was likely caused by the close, don't raise:
                return [], 0
            log.error("Error: failed to encode packet: %s", packet, exc_info=True)
            #make the error a bit nicer to parse: undo aliases:
            packet[0] = packet_type
            verify_packet(packet)
            raise
        if len(main_packet)>size_check and strtobytes(packet_in[0]) not in self.large_packets:
            log.warn("Warning: found large packet")
            log.warn(" '%s' packet is %s bytes: ", packet_type, len(main_packet))
            log.warn(" argument types: %s", csv(type(x) for x in packet[1:]))
            log.warn(" sizes: %s", csv(len(strtobytes(x)) for x in packet[1:]))
            log.warn(" packet head=%s", repr_ellipsized(packet))
        #compress, but don't bother for small packets:
        if level>0 and len(main_packet)>min_comp_size:
            try:
                cl, cdata = self._compress(main_packet, level)
            except Exception as e:
                log.error("Error compressing '%s' packet", packet_type)
                log.error(" %s", e)
                raise
            packets.append((proto_flags, 0, cl, cdata))
        else:
            packets.append((proto_flags, 0, 0, main_packet))
        return packets
     685
    def set_compression_level(self, level):
        """
        Set the compression level that will be used
        the next time encode() is called.
        `level` must be an integer between 0 (disabled) and 10 inclusive.
        """
        #fixed: the error message was missing its closing parenthesis
        assert 0<=level<=10, "invalid compression level: %s (must be between 0 and 10)" % level
        self.compression_level = level
     690
     691
    def _io_thread_loop(self, name, callback):
        #shared loop for the read and write threads:
        #calls `callback` repeatedly until it returns False or the protocol is closed,
        #mapping the various error conditions to connection-lost handling
        try:
            log("io_thread_loop(%s, %s) loop starting", name, callback)
            while not self._closed and callback():
                pass
            log("io_thread_loop(%s, %s) loop ended, closed=%s", name, callback, self._closed)
        except ConnectionClosedException:
            log("%s closed", self._conn, exc_info=True)
            if not self._closed:
                #ConnectionClosedException means the warning has been logged already
                self._connection_lost("%s connection %s closed" % (name, self._conn))
        except (OSError, IOError, socket_error) as e:
            if not self._closed:
                #NOTE(review): assumes e.args is non-empty (errno first) - verify on all platforms
                self._internal_error("%s connection %s reset" % (name, self._conn), e, exc_info=e.args[0] not in ABORT)
        except Exception as e:
            #can happen during close(), in which case we just ignore:
            if not self._closed:
                log.error("Error: %s on %s failed: %s", name, self._conn, type(e), exc_info=True)
                self.close()
     711
     712
    def _write_thread_loop(self):
        #run the write loop inside the shared error-handling wrapper:
        self._io_thread_loop("write", self._write)
    def _write(self):
        """
        Pop one work item from the write queue and send it,
        returns False to stop the write loop (exit marker received).
        """
        work_item = self._write_queue.get()
        if work_item is None:
            # Used to signal that we should exit:
            log("write thread: empty marker, exiting")
            self.close()
            return False
        return self.write_items(*work_item)
     723
    def write_items(self, object buf_data, object start_cb, object end_cb, object fail_cb, char synchronous, char more):
        #send the given list of buffers on the connection,
        #managing the socket cork / nodelay hints and
        #firing the start / end callbacks with the connection's byte counter
        cdef object conn = self._conn
        if not conn:
            return False
        if more or len(buf_data)>1:
            #more data is coming, no need to flush this write immediately:
            conn.set_nodelay(False)
        if len(buf_data)>1:
            #batch the buffers into fewer network writes:
            conn.set_cork(True)
        if start_cb:
            try:
                start_cb(conn.output_bytecount)
            except:
                if not self._closed:
                    log.error("Error on write start callback %s", start_cb, exc_info=True)
        self.write_buffers(buf_data, fail_cb, synchronous)
        if len(buf_data)>1:
            conn.set_cork(False)
        if not more:
            conn.set_nodelay(True)
        if end_cb:
            try:
                #NOTE(review): uses self._conn rather than the local conn -
                #self._conn may have been cleared by close() by now, verify
                end_cb(self._conn.output_bytecount)
            except:
                if not self._closed:
                    log.error("Error on write end callback %s", end_cb, exc_info=True)
        return True
     750
    cdef write_buffers(self, object buf_data, object _fail_cb, char _synchronous):
        #write each buffer in full to the connection,
        #looping until all the bytes have been sent
        #(_fail_cb and _synchronous are part of the interface but unused here)
        cdef object con = self._conn
        if not con:
            return 0
        for buf in buf_data:
            while buf and not self._closed:
                #NOTE(review): a zero-length write would spin here - verify con.write
                #either blocks or raises rather than returning 0
                written = con.write(buf)
                #example test code, for sending small chunks very slowly:
                #written = con.write(buf[:1024])
                #import time
                #time.sleep(0.05)
                if written:
                    buf = buf[written:]
                    self.output_raw_packetcount += 1
        self.output_packetcount += 1
     766
     767
    def _read_thread_loop(self):
        #run the read loop inside the shared error-handling wrapper:
        self._io_thread_loop("read", self._read)
    def _read(self):
        """
        Read one chunk from the connection and hand it to the read handler,
        returns False to stop the read loop on end-of-file.
        """
        data = self._conn.read(self.read_buffer_size)
        #log("read thread: got data of size %s: %s", len(data), repr_ellipsized(data))
        #add to the read queue (or whatever takes its place - see steal_connection)
        self._process_read(data)
        if data:
            self.input_raw_packetcount += 1
            return True
        log("read thread: eof")
        #give time to the parse thread to call close itself
        #so it has time to parse and process the last packet received
        self.timeout_add(1000, self.close)
        return False
     783
    def _internal_error(self, message="", exc=None, exc_info=False):
        """ Log the error (with optional exception details) then drop the connection. """
        if self._closed:
            return
        #when an exception is supplied, log it separately below
        #instead of attaching exc_info to the first message:
        log.error("Error: %s", message, exc_info=None if exc else exc_info)
        if exc:
            log.error(" %s", exc, exc_info=exc_info)
            exc = None
        self.idle_add(self._connection_lost, message)
     796
    def _connection_lost(self, message="", exc_info=False):
        #log the reason then close the protocol;
        #returns False so it can be scheduled via timeout_add / idle_add
        #and will only run once
        log("connection lost: %s", message, exc_info=exc_info)
        self.close()
        return False
     801
     802
    def invalid(self, msg, data):
        #notify the packet handler of the invalid data from the main loop:
        self.idle_add(self._process_packet_cb, self, [Protocol.INVALID, msg, data])
        # Then hang up:
        self.timeout_add(1000, self._connection_lost, msg)
     807
    def gibberish(self, msg, data):
        #notify the packet handler of the unparseable data from the main loop:
        self.idle_add(self._process_packet_cb, self, [Protocol.GIBBERISH, msg, data])
        # Then hang up:
        self.timeout_add(self.hangup_delay, self._connection_lost, msg)
     812
     813
    #delegates to invalid_header()
    #(so this can more easily be intercepted and overriden
    # see tcp-proxy)
    def _invalid_header(self, data, msg=""):
        self.invalid_header(self, data, msg)
     819
    def invalid_header(self, _proto, data, msg="invalid packet header"):
        """ Format an error describing the bad header then treat the data as gibberish. """
        parts = ["%s: '%s'" % (msg, hexstr(data[:HEADER_SIZE]))]
        if len(data)>1:
            parts.append(" read buffer=%s (%i bytes)" % (repr_ellipsized(data), len(data)))
        self.gibberish("".join(parts), data)
     825
     826
    def process_read(self, data):
        #forward to the current read handler
        #(normally read_queue_put, but it may have been
        # replaced - see steal_connection / read_queue_put shortcut)
        self._read_queue_put(data)
     829
    def read_queue_put(self, data):
        #start the parse thread if needed:
        if not self._read_parser_thread and not self._closed:
            if data is None:
                log("empty marker in read queue, exiting")
                self.idle_add(self.close)
                return
            self.start_read_parser_thread()
        self._read_queue.put(data)
        #from now on, take shortcut:
        #(rebind _read_queue_put straight to the queue's put method,
        # but only if it still points here - so a handler installed
        # by someone else is not clobbered)
        if self._read_queue_put==self.read_queue_put:
            self._read_queue_put = self._read_queue.put
     842
    def start_read_parser_thread(self):
        #spawn the daemon thread that parses the data placed in _read_queue:
        self._read_parser_thread = start_thread(self._read_parse_thread_loop, "parse", daemon=True)
     845
    def _read_parse_thread_loop(self):
        """ Run the parse loop, reporting unexpected errors unless we are closing. """
        log("read_parse_thread_loop starting")
        try:
            self.do_read_parse_thread_loop()
        except Exception as e:
            if not self._closed:
                self._internal_error("error in network packet reading/parsing", e, exc_info=True)
     854
    def do_read_parse_thread_loop(self):
        """
            Process the individual network packets placed in _read_queue.
            Concatenate the raw packet data, then try to parse it.
            Extract the individual packets from the potentially large buffer,
            saving the rest of the buffer for later, and optionally decompress this data
            and re-construct the one python-object-packet from potentially multiple packets (see packet_index).
            The 8 bytes packet header gives us information on the packet index, packet size and compression.
            The actual processing of the packet is done via the callback process_packet_cb,
            this will be called from this parsing thread so any calls that need to be made
            from the UI thread will need to use a callback (usually via 'idle_add')
        """
        cdef object header = b""
        cdef object read_buffers = []
        cdef int payload_size = -1
        cdef unsigned int padding_size = 0
        cdef unsigned int packet_index = 0
        cdef unsigned int compression_level = 0
        cdef object raw_packets = {}
        while not self._closed:
            buf = self._read_queue.get()
            if not buf:
                log("parse thread: empty marker, exiting")
                self.idle_add(self.close)
                return

            read_buffers.append(buf)
            while read_buffers:
                #have we read the header yet?
                if payload_size<0:
                    #try to handle the first buffer:
                    buf = read_buffers[0]
                    if not header and buf[0] not in ("P", ord("P")):
                        self._invalid_header(buf, "invalid packet header byte %s" % buf)
                        return
                    #how much to we need to slice off to complete the header:
                    read = min(len(buf), HEADER_SIZE-len(header))
                    header += buf[:read]
                    if len(header)<HEADER_SIZE:
                        #need to process more buffers to get a full header:
                        read_buffers.pop(0)
                        continue
                    elif len(buf)>read:
                        #got the full header and more, keep the rest of the packet:
                        read_buffers[0] = buf[read:]
                    else:
                        #we only got the header:
                        assert len(buf)==read
                        read_buffers.pop(0)
                        continue
                    #parse the header:
                    # format: struct.pack(b'cBBBL', ...) - HEADER_SIZE bytes
                    _, protocol_flags, compression_level, packet_index, data_size = unpack_header(header)

                    #sanity check size (will often fail if not an xpra client):
                    if data_size>self.abs_max_packet_size:
                        self._invalid_header(header, "invalid size in packet header: %s" % data_size)
                        return

                    if protocol_flags & FLAGS_CIPHER:
                        if self.cipher_in_block_size==0 or not self.cipher_in_name:
                            cryptolog.warn("received cipher block but we don't have a cipher to decrypt it with, not an xpra client?")
                            self._invalid_header(header, "invalid encryption packet flag (no cipher configured)")
                            return
                        padding_size = self.cipher_in_block_size - (data_size % self.cipher_in_block_size)
                        payload_size = data_size + padding_size
                    else:
                        #no cipher, no padding:
                        padding_size = 0
                        payload_size = data_size
                    assert payload_size>0, "invalid payload size: %i" % payload_size

                    if payload_size>self.max_packet_size:
                        #this packet is seemingly too big, but check again from the main UI thread
                        #this gives 'set_max_packet_size' a chance to run from "hello"
                        def check_packet_size(size_to_check, packet_header):
                            if self._closed:
                                return False
                            log("check_packet_size(%s, 0x%s) limit is %s", size_to_check, repr_ellipsized(packet_header), self.max_packet_size)
                            if size_to_check>self.max_packet_size:
                                msg = "packet size requested is %s but maximum allowed is %s" % \
                                              (size_to_check, self.max_packet_size)
                                self.invalid(msg, packet_header)
                            return False
                        self.timeout_add(1000, check_packet_size, payload_size, header)

                #how much data do we have?
                bl = sum(len(v) for v in read_buffers)
                if bl<payload_size:
                    # incomplete packet, wait for the rest to arrive
                    break

                buf = read_buffers[0]
                if len(buf)==payload_size:
                    #exact match, consume it all:
                    data = read_buffers.pop(0)
                elif len(buf)>payload_size:
                    #keep rest of packet for later:
                    read_buffers[0] = buf[payload_size:]
                    data = buf[:payload_size]
                else:
                    #we need to aggregate chunks,
                    #just concatenate them all:
                    data = b"".join(read_buffers)
                    if bl==payload_size:
                        #nothing left:
                        read_buffers = []
                    else:
                        #keep the left over:
                        read_buffers = [data[payload_size:]]
                        data = data[:payload_size]

                #decrypt if needed:
                if self.cipher_in:
                    if not (protocol_flags & FLAGS_CIPHER):
                        self.invalid("unencrypted packet dropped", data)
                        return
                    cryptolog("received %i %s encrypted bytes with %s padding", payload_size, self.cipher_in_name, padding_size)
                    data = self.cipher_in.decrypt(data)
                    if padding_size > 0:
                        def debug_str(s):
                            try:
                                return hexstr(bytearray(s))
                            except:
                                return csv(tuple(s))
                        # pad byte value is number of padding bytes added
                        padtext = pad(self.cipher_in_padding, padding_size)
                        if data.endswith(padtext):
                            cryptolog("found %s %s padding", self.cipher_in_padding, self.cipher_in_name)
                        else:
                            actual_padding = data[-padding_size:]
                            cryptolog.warn("Warning: %s decryption failed: invalid padding", self.cipher_in_name)
                            cryptolog(" data does not end with %s padding bytes %s", self.cipher_in_padding, debug_str(padtext))
                            cryptolog(" but with %s (%s)", debug_str(actual_padding), type(data))
                            cryptolog(" decrypted data: %s", debug_str(data[:128]))
                            self._internal_error("%s encryption padding error - wrong key?" % self.cipher_in_name)
                            return
                        data = data[:-padding_size]
                #uncompress if needed:
                if compression_level>0:
                    try:
                        data = decompress(data, compression_level)
                    except InvalidCompressionException as e:
                        self.invalid("invalid compression: %s" % e, data)
                        return
                    except Exception as e:
                        ctype = compression.get_compression_type(compression_level)
                        log("%s packet decompression failed", ctype, exc_info=True)
                        msg = "%s packet decompression failed" % ctype
                        if self.cipher_in:
                            msg += " (invalid encryption key?)"
                        else:
                            #only include the exception text when not using encryption
                            #as this may leak crypto information:
                            msg += " %s" % e
                        del e
                        return self.gibberish(msg, data)

                if self._closed:
                    return

                #we're processing this packet,
                #make sure we get a new header next time
                header = b""
                if packet_index>0:
                    #raw packet, store it and continue:
                    raw_packets[packet_index] = data
                    payload_size = -1
                    if len(raw_packets)>=4:
                        self.invalid("too many raw packets: %s" % len(raw_packets), data)
                        return
                    continue
                #final packet (packet_index==0), decode it:
                try:
                    packet = decode(data, protocol_flags)
                except InvalidPacketEncodingException as e:
                    self.invalid("invalid packet encoding: %s" % e, data)
                    return
                except ValueError as e:
                    etype = packet_encoding.get_packet_encoding_type(protocol_flags)
                    log.error("Error parsing %s packet:", etype)
                    log.error(" %s", e)
                    if self._closed:
                        return
                    log("failed to parse %s packet: %s", etype, hexstr(data[:128]))
                    log(" %s", e)
                    log(" data: %s", repr_ellipsized(data))
                    log(" packet index=%i, packet size=%i, buffer size=%s", packet_index, payload_size, bl)
                    self.gibberish("failed to parse %s packet" % etype, data)
                    return

                if self._closed:
                    return
                #remember the size for logging below before we reset payload_size
                #(fixed: it used to be logged after the reset, always showing -1):
                packet_payload_size = payload_size
                payload_size = -1
                #add any raw packets back into it:
                if raw_packets:
                    for index,raw_data in raw_packets.items():
                        #replace placeholder with the raw_data packet data:
                        packet[index] = raw_data
                    raw_packets = {}

                packet_type = packet[0]
                if self.receive_aliases and type(packet_type)==int and packet_type in self.receive_aliases:
                    packet_type = self.receive_aliases.get(packet_type)
                    packet[0] = packet_type
                #fixed: the counter was incorrectly read back from output_stats:
                self.input_stats[packet_type] = self.input_stats.get(packet_type, 0)+1
                if LOG_RAW_PACKET_SIZE:
                    log("%s: %i bytes", packet_type, HEADER_SIZE + packet_payload_size)

                self.input_packetcount += 1
                log("processing packet %s", bytestostr(packet_type))
                self._process_packet_cb(self, packet)
                packet = None
     1068
    def flush_then_close(self, last_packet, done_callback=None):
        """ Note: this is best effort only
            the packet may not get sent.

            We try to get the write lock,
            we try to wait for the write queue to flush
            we queue our last packet,
            we wait again for the queue to flush,
            then no matter what, we close the connection and stop the threads.
        """
        log("flush_then_close(%s, %s) closed=%s", last_packet, done_callback, self._closed)
        def done():
            log("flush_then_close: done, callback=%s", done_callback)
            if done_callback:
                done_callback()
        if self._closed:
            log("flush_then_close: already closed")
            return done()
        def wait_for_queue(timeout=10):
            #IMPORTANT: if we are here, we have the write lock held!
            if not self._write_queue.empty():
                #write queue still has stuff in it..
                if timeout<=0:
                    log("flush_then_close: queue still busy, closing without sending the last packet")
                    try:
                        self._write_lock.release()
                    except:
                        pass
                    self.close()
                    done()
                else:
                    log("flush_then_close: still waiting for queue to flush")
                    #retry in 100ms, decrementing the retry counter:
                    self.timeout_add(100, wait_for_queue, timeout-1)
            else:
                log("flush_then_close: queue is now empty, sending the last packet and closing")
                chunks = self.encode(last_packet)
                def close_and_release():
                    log("flush_then_close: wait_for_packet_sent() close_and_release()")
                    self.close()
                    try:
                        self._write_lock.release()
                    except:
                        pass
                    done()
                def wait_for_packet_sent():
                    log("flush_then_close: wait_for_packet_sent() queue.empty()=%s, closed=%s", self._write_queue.empty(), self._closed)
                    if self._write_queue.empty() or self._closed:
                        #it got sent, we're done!
                        close_and_release()
                        return False
                    return not self._closed     #run until we manage to close (here or via the timeout)
                def packet_queued(*_args):
                    #if we're here, we have the lock and the packet is in the write queue
                    log("flush_then_close: packet_queued() closed=%s", self._closed)
                    if wait_for_packet_sent():
                        #check again every 100ms
                        self.timeout_add(100, wait_for_packet_sent)
                self._add_chunks_to_queue(last_packet[0], chunks, start_send_cb=None, end_send_cb=packet_queued, synchronous=False, more=False)
                #just in case wait_for_packet_sent never fires:
                self.timeout_add(5*1000, close_and_release)

        def wait_for_write_lock(timeout=100):
            wl = self._write_lock
            if not wl:
                #cleaned up already
                return
            #non-blocking acquire so the caller's thread is never stalled:
            if not wl.acquire(False):
                if timeout<=0:
                    log("flush_then_close: timeout waiting for the write lock")
                    self.close()
                    done()
                else:
                    log("flush_then_close: write lock is busy, will retry %s more times", timeout)
                    self.timeout_add(10, wait_for_write_lock, timeout-1)
            else:
                log("flush_then_close: acquired the write lock")
                #we have the write lock - we MUST free it!
                wait_for_queue()
        #normal codepath:
        # -> wait_for_write_lock
        # -> wait_for_queue
        # -> _add_chunks_to_queue
        # -> packet_queued
        # -> wait_for_packet_sent
        # -> close_and_release
        log("flush_then_close: wait_for_write_lock()")
        wait_for_write_lock()
     1156
    def close(self):
        """Close the connection and shut the protocol down.

        Idempotent: returns immediately if already closed. Notifies the
        owner via the process-packet callback with a CONNECTION_LOST
        packet, closes the underlying connection (logging transfer
        statistics unless nothing was transferred), then terminates the
        queue threads and schedules clean() to drop all references.
        """
        log("Protocol.close() closed=%s, connection=%s", self._closed, self._conn)
        if self._closed:
            return
        self._closed = True
        #tell the owner the connection is gone (from the main loop):
        self.idle_add(self._process_packet_cb, self, [Protocol.CONNECTION_LOST])
        c = self._conn
        if c:
            self._conn = None
            try:
                log("Protocol.close() calling %s", c.close)
                c.close()
                if self._log_stats is None and c.input_bytecount==0 and c.output_bytecount==0:
                    #no data sent or received, skip logging of stats:
                    self._log_stats = False
                if self._log_stats:
                    from xpra.simple_stats import std_unit, std_unit_dec
                    log.info("connection closed after %s packets received (%s bytes) and %s packets sent (%s bytes)",
                         std_unit(self.input_packetcount), std_unit_dec(c.input_bytecount),
                         std_unit(self.output_packetcount), std_unit_dec(c.output_bytecount)
                         )
            except Exception:
                #narrowed from a bare "except:" so SystemExit/KeyboardInterrupt
                #are not swallowed here; connection close errors are best-effort:
                log.error("error closing %s", c, exc_info=True)
        self.terminate_queue_threads()
        self.idle_add(self.clean)
        log("Protocol.close() done")
     1183
    def steal_connection(self, read_callback=None):
        """Detach the underlying connection so it can be re-used elsewhere.

        Frees all protocol threads and resources and returns the
        connection object (or None if it was already gone).
        Note: this method can only be used with non-blocking sockets,
        and if more than one packet can arrive, the read_callback should
        be used to ensure that no packets get lost.
        The caller must call wait_for_io_threads_exit() to ensure that
        this class is no longer reading from the connection before
        re-using it.
        """
        assert not self._closed, "cannot steal a closed connection"
        if read_callback:
            self._read_queue_put = read_callback
        #detach the connection and mark this protocol as finished:
        conn, self._conn = self._conn, None
        self._closed = True
        if conn:
            #this ensures that we exit the untilConcludes() read/write loop
            conn.set_active(False)
        self.terminate_queue_threads()
        return conn
     1203
    def clean(self):
        """Clear all references so we can get garbage collected quickly.

        Scheduled from close() via idle_add; after this runs the
        instance is inert and must not be used again.
        """
        self._get_packet_cb = None
        self._encoder = None
        self._write_thread = None
        self._read_thread = None
        self._read_parser_thread = None
        self._write_format_thread = None
        self._process_packet_cb = None
        self._process_read = None
        self._read_queue_put = None
        self._compress = None
        self._write_lock = None
        self._source_has_more_event = None
        self._conn = None       #should be redundant
        #NOTE: a local noop() and "self.source_has_more = noop" used to live
        #here but the assignment was disabled ("not compatible with cython"),
        #so the dead helper has been removed
     1224
    def terminate_queue_threads(self):
        """Make all the queue-based worker threads exit.

        Swaps the read and write queues for fresh ones from exit_queue()
        (presumably pre-seeded with an exit marker - confirm against the
        module helper) and flushes the old queues so any thread blocked
        on a get() wakes up. Also sets the source-has-more event twice so
        the format thread cannot stay blocked waiting on it.
        """
        log("terminate_queue_threads()")
        #the format thread will exit:
        self._get_packet_cb = None
        self._source_has_more_event.set()
        #make all the queue based threads exit by adding the empty marker:
        #write queue:
        owq = self._write_queue
        self._write_queue = exit_queue()
        force_flush_queue(owq)
        #read queue:
        orq = self._read_queue
        self._read_queue = exit_queue()
        force_flush_queue(orq)
        #just in case the read thread is waiting again:
        self._source_has_more_event.set()