Refactor state to a proxied dictionary

This commit is contained in:
Ivan Kravets
2019-07-02 00:41:47 +03:00
parent 6d9de80f12
commit d2c86ab71c
4 changed files with 50 additions and 25 deletions

View File

@ -14,7 +14,6 @@
import codecs import codecs
import hashlib import hashlib
import json
import os import os
import uuid import uuid
from os import environ, getenv, listdir, remove from os import environ, getenv, listdir, remove
@ -93,29 +92,26 @@ class State(object):
self.lock = lock self.lock = lock
if not self.path: if not self.path:
self.path = join(get_project_core_dir(), "appstate.json") self.path = join(get_project_core_dir(), "appstate.json")
self._state = {} self._storage = {}
self._prev_state_raw = ""
self._lockfile = None self._lockfile = None
self._modified = False
def __enter__(self): def __enter__(self):
try: try:
self._lock_state_file() self._lock_state_file()
if isfile(self.path): if isfile(self.path):
with open(self.path) as fp: self._storage = util.load_json(self.path)
self._prev_state_raw = fp.read().strip() assert isinstance(self._storage, dict)
self._state = json.loads(self._prev_state_raw) except (AssertionError, ValueError, UnicodeDecodeError,
assert isinstance(self._state, dict) exception.InvalidJSONFile):
except (AssertionError, ValueError, UnicodeDecodeError): self._storage = {}
self._state = {} return self
self._prev_state_raw = ""
return self._state
def __exit__(self, type_, value, traceback): def __exit__(self, type_, value, traceback):
new_state_raw = dump_json_to_unicode(self._state) if self._modified:
if self._prev_state_raw != new_state_raw:
try: try:
with open(self.path, "w") as fp: with open(self.path, "w") as fp:
fp.write(new_state_raw) fp.write(dump_json_to_unicode(self._storage))
except IOError: except IOError:
raise exception.HomeDirPermissionsError(get_project_core_dir()) raise exception.HomeDirPermissionsError(get_project_core_dir())
self._unlock_state_file() self._unlock_state_file()
@ -133,8 +129,31 @@ class State(object):
if hasattr(self, "_lockfile") and self._lockfile: if hasattr(self, "_lockfile") and self._lockfile:
self._lockfile.release() self._lockfile.release()
def __del__(self): # Dictionary Proxy
self._unlock_state_file()
def as_dict(self):
return self._storage
def get(self, key, default=True):
return self._storage.get(key, default)
def update(self, *args, **kwargs):
self._modified = True
return self._storage.update(*args, **kwargs)
def __getitem__(self, key):
return self._storage[key]
def __setitem__(self, key, value):
self._modified = True
self._storage[key] = value
def __delitem__(self, key):
self._modified = True
del self._storage[key]
def __contains__(self, item):
return item in self._storage
class ContentCache(object): class ContentCache(object):

View File

@ -57,7 +57,7 @@ class AppRPC(object):
] ]
state['storage'] = storage state['storage'] = storage
return state return state.as_dict()
@staticmethod @staticmethod
def get_state(): def get_state():
@ -66,6 +66,6 @@ class AppRPC(object):
@staticmethod @staticmethod
def save_state(state): def save_state(state):
with app.State(AppRPC.APPSTATE_PATH, lock=True) as s: with app.State(AppRPC.APPSTATE_PATH, lock=True) as s:
s.clear() # s.clear()
s.update(state) s.update(state)
return True return True

View File

@ -17,7 +17,6 @@ from __future__ import absolute_import
import json import json
import os import os
import sys import sys
import thread
from io import BytesIO from io import BytesIO
import jsonrpc # pylint: disable=import-error import jsonrpc # pylint: disable=import-error
@ -26,6 +25,11 @@ from twisted.internet import threads # pylint: disable=import-error
from platformio import __main__, __version__, util from platformio import __main__, __version__, util
from platformio.compat import string_types from platformio.compat import string_types
try:
from thread import get_ident as thread_get_ident
except ImportError:
from threading import get_ident as thread_get_ident
class ThreadSafeStdBuffer(object): class ThreadSafeStdBuffer(object):
@ -35,19 +39,20 @@ class ThreadSafeStdBuffer(object):
self._buffer = {} self._buffer = {}
def write(self, value): def write(self, value):
thread_id = thread.get_ident() thread_id = thread_get_ident()
if thread_id == self.parent_thread_id: if thread_id == self.parent_thread_id:
return self.parent_stream.write(value) return self.parent_stream.write(
value if isinstance(value, string_types) else value.decode())
if thread_id not in self._buffer: if thread_id not in self._buffer:
self._buffer[thread_id] = BytesIO() self._buffer[thread_id] = BytesIO()
return self._buffer[thread_id].write(value) return self._buffer[thread_id].write(value)
def flush(self): def flush(self):
return (self.parent_stream.flush() return (self.parent_stream.flush()
if thread.get_ident() == self.parent_thread_id else None) if thread_get_ident() == self.parent_thread_id else None)
def getvalue_and_close(self, thread_id=None): def getvalue_and_close(self, thread_id=None):
thread_id = thread_id or thread.get_ident() thread_id = thread_id or thread_get_ident()
if thread_id not in self._buffer: if thread_id not in self._buffer:
return "" return ""
result = self._buffer.get(thread_id).getvalue() result = self._buffer.get(thread_id).getvalue()
@ -59,7 +64,7 @@ class ThreadSafeStdBuffer(object):
class PIOCoreRPC(object): class PIOCoreRPC(object):
def __init__(self): def __init__(self):
cur_thread_id = thread.get_ident() cur_thread_id = thread_get_ident()
PIOCoreRPC.thread_stdout = ThreadSafeStdBuffer(sys.stdout, PIOCoreRPC.thread_stdout = ThreadSafeStdBuffer(sys.stdout,
cur_thread_id) cur_thread_id)
PIOCoreRPC.thread_stderr = ThreadSafeStdBuffer(sys.stderr, PIOCoreRPC.thread_stderr = ThreadSafeStdBuffer(sys.stderr,

View File

@ -53,11 +53,12 @@ class JSONRPCServerProtocol(WebSocketServerProtocol):
message=failure.getErrorMessage()) message=failure.getErrorMessage())
del response["result"] del response["result"]
response['error'] = e.error._data # pylint: disable=protected-access response['error'] = e.error._data # pylint: disable=protected-access
click.secho(str(response['error']), fg="red", err=True)
self.sendJSONResponse(response) self.sendJSONResponse(response)
def sendJSONResponse(self, response): def sendJSONResponse(self, response):
# click.echo("< %s" % response) # click.echo("< %s" % response)
if "error" in response:
click.secho("Error: %s" % response['error'], fg="red", err=True)
response = dump_json_to_unicode(response) response = dump_json_to_unicode(response)
if not PY2 and not is_bytes(response): if not PY2 and not is_bytes(response):
response = response.encode("utf-8") response = response.encode("utf-8")