Refactor ThreadSafeStdBuffer

This commit is contained in:
Ivan Kravets
2019-07-02 15:47:03 +03:00
parent 83f64cebbd
commit 148b7dccfd

View File

@ -33,42 +33,42 @@ except ImportError:
class ThreadSafeStdBuffer(object):
def __init__(self, parent_stream, parent_thread_id):
self.parent_stream = parent_stream
self.parent_thread_id = parent_thread_id
self._buffer = {}
def __init__(self, parent_stream):
self._buffers = {thread_get_ident(): parent_stream}
def __getattr__(self, name):
thread_id = thread_get_ident()
if thread_id not in self._buffers:
raise AttributeError(name)
return getattr(self._buffers[thread_id], name)
def write(self, value):
thread_id = thread_get_ident()
if thread_id == self.parent_thread_id:
return self.parent_stream.write(
value if isinstance(value, string_types) else value.decode())
if thread_id not in self._buffer:
self._buffer[thread_id] = BytesIO()
return self._buffer[thread_id].write(value)
if thread_id not in self._buffers:
self._buffers[thread_id] = BytesIO()
try:
return self._buffers[thread_id].write(value)
except TypeError:
return self._buffers[thread_id].write(value.encode())
def flush(self):
return (self.parent_stream.flush()
if thread_get_ident() == self.parent_thread_id else None)
def getvalue_and_close(self, thread_id=None):
thread_id = thread_id or thread_get_ident()
if thread_id not in self._buffer:
return ""
result = self._buffer.get(thread_id).getvalue()
self._buffer.get(thread_id).close()
del self._buffer[thread_id]
def get_value_and_close(self):
thread_id = thread_get_ident()
result = ""
try:
result = self.getvalue()
self.close()
if thread_id in self._buffers:
del self._buffers[thread_id]
except AttributeError:
pass
return result
class PIOCoreRPC(object):
def __init__(self):
cur_thread_id = thread_get_ident()
PIOCoreRPC.thread_stdout = ThreadSafeStdBuffer(sys.stdout,
cur_thread_id)
PIOCoreRPC.thread_stderr = ThreadSafeStdBuffer(sys.stderr,
cur_thread_id)
PIOCoreRPC.thread_stdout = ThreadSafeStdBuffer(sys.stdout)
PIOCoreRPC.thread_stderr = ThreadSafeStdBuffer(sys.stderr)
sys.stdout = PIOCoreRPC.thread_stdout
sys.stderr = PIOCoreRPC.thread_stderr
@ -86,8 +86,8 @@ class PIOCoreRPC(object):
def _call_cli():
with util.cd((options or {}).get("cwd") or os.getcwd()):
exit_code = __main__.main(["-c"] + args)
return (PIOCoreRPC.thread_stdout.getvalue_and_close(),
PIOCoreRPC.thread_stderr.getvalue_and_close(), exit_code)
return (PIOCoreRPC.thread_stdout.get_value_and_close(),
PIOCoreRPC.thread_stderr.get_value_and_close(), exit_code)
d = threads.deferToThread(_call_cli)
d.addCallback(PIOCoreRPC._call_callback, "--json-output" in args)