Refactor ThreadSafeStdBuffer

This commit is contained in:
Ivan Kravets
2019-07-02 15:47:03 +03:00
parent 83f64cebbd
commit 148b7dccfd

View File

@ -33,42 +33,42 @@ except ImportError:
class ThreadSafeStdBuffer(object): class ThreadSafeStdBuffer(object):
def __init__(self, parent_stream, parent_thread_id): def __init__(self, parent_stream):
self.parent_stream = parent_stream self._buffers = {thread_get_ident(): parent_stream}
self.parent_thread_id = parent_thread_id
self._buffer = {} def __getattr__(self, name):
thread_id = thread_get_ident()
if thread_id not in self._buffers:
raise AttributeError(name)
return getattr(self._buffers[thread_id], name)
def write(self, value): def write(self, value):
thread_id = thread_get_ident() thread_id = thread_get_ident()
if thread_id == self.parent_thread_id: if thread_id not in self._buffers:
return self.parent_stream.write( self._buffers[thread_id] = BytesIO()
value if isinstance(value, string_types) else value.decode()) try:
if thread_id not in self._buffer: return self._buffers[thread_id].write(value)
self._buffer[thread_id] = BytesIO() except TypeError:
return self._buffer[thread_id].write(value) return self._buffers[thread_id].write(value.encode())
def flush(self): def get_value_and_close(self):
return (self.parent_stream.flush() thread_id = thread_get_ident()
if thread_get_ident() == self.parent_thread_id else None) result = ""
try:
def getvalue_and_close(self, thread_id=None): result = self.getvalue()
thread_id = thread_id or thread_get_ident() self.close()
if thread_id not in self._buffer: if thread_id in self._buffers:
return "" del self._buffers[thread_id]
result = self._buffer.get(thread_id).getvalue() except AttributeError:
self._buffer.get(thread_id).close() pass
del self._buffer[thread_id]
return result return result
class PIOCoreRPC(object): class PIOCoreRPC(object):
def __init__(self): def __init__(self):
cur_thread_id = thread_get_ident() PIOCoreRPC.thread_stdout = ThreadSafeStdBuffer(sys.stdout)
PIOCoreRPC.thread_stdout = ThreadSafeStdBuffer(sys.stdout, PIOCoreRPC.thread_stderr = ThreadSafeStdBuffer(sys.stderr)
cur_thread_id)
PIOCoreRPC.thread_stderr = ThreadSafeStdBuffer(sys.stderr,
cur_thread_id)
sys.stdout = PIOCoreRPC.thread_stdout sys.stdout = PIOCoreRPC.thread_stdout
sys.stderr = PIOCoreRPC.thread_stderr sys.stderr = PIOCoreRPC.thread_stderr
@ -86,8 +86,8 @@ class PIOCoreRPC(object):
def _call_cli(): def _call_cli():
with util.cd((options or {}).get("cwd") or os.getcwd()): with util.cd((options or {}).get("cwd") or os.getcwd()):
exit_code = __main__.main(["-c"] + args) exit_code = __main__.main(["-c"] + args)
return (PIOCoreRPC.thread_stdout.getvalue_and_close(), return (PIOCoreRPC.thread_stdout.get_value_and_close(),
PIOCoreRPC.thread_stderr.getvalue_and_close(), exit_code) PIOCoreRPC.thread_stderr.get_value_and_close(), exit_code)
d = threads.deferToThread(_call_cli) d = threads.deferToThread(_call_cli)
d.addCallback(PIOCoreRPC._call_callback, "--json-output" in args) d.addCallback(PIOCoreRPC._call_callback, "--json-output" in args)