Significantly speedup PlatformIO Home loading time by migrating to native Python 3 Asynchronous I/O

This commit is contained in:
Ivan Kravets
2021-01-18 18:20:26 +02:00
parent dd7d282d17
commit 6ff67aeadf
13 changed files with 253 additions and 241 deletions

View File

@@ -11,6 +11,7 @@ PlatformIO Core 5
5.0.5 (2021-??-??) 5.0.5 (2021-??-??)
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
* Significantly speedup PlatformIO Home loading time by migrating to native Python 3 Asynchronous I/O
* Improved listing of `multicast DNS services <https://docs.platformio.org/page/core/userguide/device/cmd_list.html>`_ * Improved listing of `multicast DNS services <https://docs.platformio.org/page/core/userguide/device/cmd_list.html>`_
* Check for debug server's "ready_pattern" in "stderr" * Check for debug server's "ready_pattern" in "stderr"

View File

@@ -15,17 +15,17 @@
# pylint: disable=too-many-locals,too-many-statements # pylint: disable=too-many-locals,too-many-statements
import mimetypes import mimetypes
import os
import socket import socket
from os.path import isdir
import click import click
from platformio import exception from platformio import exception
from platformio.compat import WINDOWS from platformio.compat import WINDOWS, ensure_python3
from platformio.package.manager.core import get_core_package_dir, inject_contrib_pysite from platformio.package.manager.core import get_core_package_dir
@click.command("home", short_help="UI to manage PlatformIO") @click.command("home", short_help="GUI to manage PlatformIO")
@click.option("--port", type=int, default=8008, help="HTTP port, default=8008") @click.option("--port", type=int, default=8008, help="HTTP port, default=8008")
@click.option( @click.option(
"--host", "--host",
@@ -46,60 +46,16 @@ from platformio.package.manager.core import get_core_package_dir, inject_contrib
), ),
) )
def cli(port, host, no_open, shutdown_timeout): def cli(port, host, no_open, shutdown_timeout):
# pylint: disable=import-error, import-outside-toplevel
# import contrib modules
inject_contrib_pysite()
from autobahn.twisted.resource import WebSocketResource
from twisted.internet import reactor
from twisted.web import server
from twisted.internet.error import CannotListenError
from platformio.commands.home.rpc.handlers.app import AppRPC
from platformio.commands.home.rpc.handlers.ide import IDERPC
from platformio.commands.home.rpc.handlers.misc import MiscRPC
from platformio.commands.home.rpc.handlers.os import OSRPC
from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC
from platformio.commands.home.rpc.handlers.project import ProjectRPC
from platformio.commands.home.rpc.handlers.account import AccountRPC
from platformio.commands.home.rpc.server import JSONRPCServerFactory
from platformio.commands.home.web import WebRoot
factory = JSONRPCServerFactory(shutdown_timeout)
factory.addHandler(AppRPC(), namespace="app")
factory.addHandler(IDERPC(), namespace="ide")
factory.addHandler(MiscRPC(), namespace="misc")
factory.addHandler(OSRPC(), namespace="os")
factory.addHandler(PIOCoreRPC(), namespace="core")
factory.addHandler(ProjectRPC(), namespace="project")
factory.addHandler(AccountRPC(), namespace="account")
contrib_dir = get_core_package_dir("contrib-piohome")
if not isdir(contrib_dir):
raise exception.PlatformioException("Invalid path to PIO Home Contrib")
# Ensure PIO Home mimetypes are known # Ensure PIO Home mimetypes are known
mimetypes.add_type("text/html", ".html") mimetypes.add_type("text/html", ".html")
mimetypes.add_type("text/css", ".css") mimetypes.add_type("text/css", ".css")
mimetypes.add_type("application/javascript", ".js") mimetypes.add_type("application/javascript", ".js")
root = WebRoot(contrib_dir)
root.putChild(b"wsrpc", WebSocketResource(factory))
site = server.Site(root)
# hook for `platformio-node-helpers` # hook for `platformio-node-helpers`
if host == "__do_not_start__": if host == "__do_not_start__":
return return
already_started = is_port_used(host, port)
home_url = "http://%s:%d" % (host, port) home_url = "http://%s:%d" % (host, port)
if not no_open:
if already_started:
click.launch(home_url)
else:
reactor.callLater(1, lambda: click.launch(home_url))
click.echo( click.echo(
"\n".join( "\n".join(
[ [
@@ -115,21 +71,21 @@ def cli(port, host, no_open, shutdown_timeout):
click.echo("") click.echo("")
click.echo("Open PlatformIO Home in your browser by this URL => %s" % home_url) click.echo("Open PlatformIO Home in your browser by this URL => %s" % home_url)
try: if is_port_used(host, port):
reactor.listenTCP(port, site, interface=host)
except CannotListenError as e:
click.secho(str(e), fg="red", err=True)
already_started = True
if already_started:
click.secho( click.secho(
"PlatformIO Home server is already started in another process.", fg="yellow" "PlatformIO Home server is already started in another process.", fg="yellow"
) )
if not no_open:
click.launch(home_url)
return return
click.echo("PIO Home has been started. Press Ctrl+C to shutdown.") run_server(
host=host,
reactor.run() port=port,
no_open=no_open,
shutdown_timeout=shutdown_timeout,
home_url=home_url,
)
def is_port_used(host, port): def is_port_used(host, port):
@@ -150,3 +106,54 @@ def is_port_used(host, port):
return False return False
return True return True
def run_server(host, port, no_open, shutdown_timeout, home_url):
# pylint: disable=import-error, import-outside-toplevel
ensure_python3()
import uvicorn
from starlette.applications import Starlette
from starlette.routing import Mount, WebSocketRoute
from starlette.staticfiles import StaticFiles
from platformio.commands.home.rpc.handlers.account import AccountRPC
from platformio.commands.home.rpc.handlers.app import AppRPC
from platformio.commands.home.rpc.handlers.ide import IDERPC
from platformio.commands.home.rpc.handlers.misc import MiscRPC
from platformio.commands.home.rpc.handlers.os import OSRPC
from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC
from platformio.commands.home.rpc.handlers.project import ProjectRPC
from platformio.commands.home.rpc.server import WebSocketJSONRPCServerFactory
contrib_dir = get_core_package_dir("contrib-piohome")
if not os.path.isdir(contrib_dir):
raise exception.PlatformioException("Invalid path to PIO Home Contrib")
ws_rpc_factory = WebSocketJSONRPCServerFactory(shutdown_timeout)
ws_rpc_factory.addHandler(AccountRPC(), namespace="account")
ws_rpc_factory.addHandler(AppRPC(), namespace="app")
ws_rpc_factory.addHandler(IDERPC(), namespace="ide")
ws_rpc_factory.addHandler(MiscRPC(), namespace="misc")
ws_rpc_factory.addHandler(OSRPC(), namespace="os")
ws_rpc_factory.addHandler(PIOCoreRPC(), namespace="core")
ws_rpc_factory.addHandler(ProjectRPC(), namespace="project")
uvicorn.run(
Starlette(
routes=[
WebSocketRoute("/wsrpc", ws_rpc_factory, name="wsrpc"),
Mount("/", StaticFiles(directory=contrib_dir, html=True)),
],
on_startup=[
lambda: click.echo(
"PIO Home has been started. Press Ctrl+C to shutdown."
),
lambda: None if no_open else click.launch(home_url),
],
),
host=host,
port=port,
log_level="warning",
)

View File

@@ -12,36 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=keyword-arg-before-vararg,arguments-differ,signature-differs
import requests import requests
from twisted.internet import defer # pylint: disable=import-error from starlette.concurrency import run_in_threadpool
from twisted.internet import reactor # pylint: disable=import-error
from twisted.internet import threads # pylint: disable=import-error
from platformio import util from platformio import util
from platformio.proc import where_is_program from platformio.proc import where_is_program
class AsyncSession(requests.Session): class AsyncSession(requests.Session):
def __init__(self, n=None, *args, **kwargs): async def request( # pylint: disable=signature-differs,invalid-overridden-method
if n: self, *args, **kwargs
pool = reactor.getThreadPool() ):
pool.adjustPoolsize(0, n)
super(AsyncSession, self).__init__(*args, **kwargs)
def request(self, *args, **kwargs):
func = super(AsyncSession, self).request func = super(AsyncSession, self).request
return threads.deferToThread(func, *args, **kwargs) return await run_in_threadpool(func, *args, **kwargs)
def wrap(self, *args, **kwargs): # pylint: disable=no-self-use
return defer.ensureDeferred(*args, **kwargs)
@util.memoized(expire="60s") @util.memoized(expire="60s")
def requests_session(): def requests_session():
return AsyncSession(n=5) return AsyncSession()
@util.memoized(expire="60s") @util.memoized(expire="60s")

View File

@@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import jsonrpc # pylint: disable=import-error import jsonrpc
from platformio.clients.account import AccountClient from platformio.clients.account import AccountClient
class AccountRPC(object): class AccountRPC:
@staticmethod @staticmethod
def call_client(method, *args, **kwargs): def call_client(method, *args, **kwargs):
try: try:

View File

@@ -20,7 +20,7 @@ from platformio import __version__, app, fs, util
from platformio.project.helpers import get_project_core_dir, is_platformio_project from platformio.project.helpers import get_project_core_dir, is_platformio_project
class AppRPC(object): class AppRPC:
APPSTATE_PATH = join(get_project_core_dir(), "homestate.json") APPSTATE_PATH = join(get_project_core_dir(), "homestate.json")

View File

@@ -14,11 +14,12 @@
import time import time
import jsonrpc # pylint: disable=import-error import jsonrpc
from twisted.internet import defer # pylint: disable=import-error
from platformio.compat import get_running_loop
class IDERPC(object): class IDERPC:
def __init__(self): def __init__(self):
self._queue = {} self._queue = {}
@@ -28,14 +29,14 @@ class IDERPC(object):
code=4005, message="PIO Home IDE agent is not started" code=4005, message="PIO Home IDE agent is not started"
) )
while self._queue[sid]: while self._queue[sid]:
self._queue[sid].pop().callback( self._queue[sid].pop().set_result(
{"id": time.time(), "method": command, "params": params} {"id": time.time(), "method": command, "params": params}
) )
def listen_commands(self, sid=0): def listen_commands(self, sid=0):
if sid not in self._queue: if sid not in self._queue:
self._queue[sid] = [] self._queue[sid] = []
self._queue[sid].append(defer.Deferred()) self._queue[sid].append(get_running_loop().create_future())
return self._queue[sid][-1] return self._queue[sid][-1]
def open_project(self, sid, project_dir): def open_project(self, sid, project_dir):

View File

@@ -15,14 +15,13 @@
import json import json
import time import time
from twisted.internet import defer, reactor # pylint: disable=import-error
from platformio.cache import ContentCache from platformio.cache import ContentCache
from platformio.commands.home.rpc.handlers.os import OSRPC from platformio.commands.home.rpc.handlers.os import OSRPC
from platformio.compat import create_task
class MiscRPC(object): class MiscRPC:
def load_latest_tweets(self, data_url): async def load_latest_tweets(self, data_url):
cache_key = ContentCache.key_from_args(data_url, "tweets") cache_key = ContentCache.key_from_args(data_url, "tweets")
cache_valid = "180d" cache_valid = "180d"
with ContentCache() as cc: with ContentCache() as cc:
@@ -31,22 +30,20 @@ class MiscRPC(object):
cache_data = json.loads(cache_data) cache_data = json.loads(cache_data)
# automatically update cache in background every 12 hours # automatically update cache in background every 12 hours
if cache_data["time"] < (time.time() - (3600 * 12)): if cache_data["time"] < (time.time() - (3600 * 12)):
reactor.callLater( create_task(
5, self._preload_latest_tweets, data_url, cache_key, cache_valid self._preload_latest_tweets(data_url, cache_key, cache_valid)
) )
return cache_data["result"] return cache_data["result"]
result = self._preload_latest_tweets(data_url, cache_key, cache_valid) return await self._preload_latest_tweets(data_url, cache_key, cache_valid)
return result
@staticmethod @staticmethod
@defer.inlineCallbacks async def _preload_latest_tweets(data_url, cache_key, cache_valid):
def _preload_latest_tweets(data_url, cache_key, cache_valid): result = json.loads((await OSRPC.fetch_content(data_url)))
result = json.loads((yield OSRPC.fetch_content(data_url)))
with ContentCache() as cc: with ContentCache() as cc:
cc.set( cc.set(
cache_key, cache_key,
json.dumps({"time": int(time.time()), "result": result}), json.dumps({"time": int(time.time()), "result": result}),
cache_valid, cache_valid,
) )
defer.returnValue(result) return result

View File

@@ -14,25 +14,23 @@
from __future__ import absolute_import from __future__ import absolute_import
import glob
import io import io
import os import os
import shutil import shutil
from functools import cmp_to_key from functools import cmp_to_key
import click import click
from twisted.internet import defer # pylint: disable=import-error
from platformio import __default_requests_timeout__, fs, util from platformio import __default_requests_timeout__, fs, util
from platformio.cache import ContentCache from platformio.cache import ContentCache
from platformio.clients.http import ensure_internet_on from platformio.clients.http import ensure_internet_on
from platformio.commands.home import helpers from platformio.commands.home import helpers
from platformio.compat import PY2, get_filesystem_encoding, glob_recursive
class OSRPC(object): class OSRPC:
@staticmethod @staticmethod
@defer.inlineCallbacks async def fetch_content(uri, data=None, headers=None, cache_valid=None):
def fetch_content(uri, data=None, headers=None, cache_valid=None):
if not headers: if not headers:
headers = { headers = {
"User-Agent": ( "User-Agent": (
@@ -46,18 +44,18 @@ class OSRPC(object):
if cache_key: if cache_key:
result = cc.get(cache_key) result = cc.get(cache_key)
if result is not None: if result is not None:
defer.returnValue(result) return result
# check internet before and resolve issue with 60 seconds timeout # check internet before and resolve issue with 60 seconds timeout
ensure_internet_on(raise_exception=True) ensure_internet_on(raise_exception=True)
session = helpers.requests_session() session = helpers.requests_session()
if data: if data:
r = yield session.post( r = await session.post(
uri, data=data, headers=headers, timeout=__default_requests_timeout__ uri, data=data, headers=headers, timeout=__default_requests_timeout__
) )
else: else:
r = yield session.get( r = await session.get(
uri, headers=headers, timeout=__default_requests_timeout__ uri, headers=headers, timeout=__default_requests_timeout__
) )
@@ -66,11 +64,11 @@ class OSRPC(object):
if cache_valid: if cache_valid:
with ContentCache() as cc: with ContentCache() as cc:
cc.set(cache_key, result, cache_valid) cc.set(cache_key, result, cache_valid)
defer.returnValue(result) return result
def request_content(self, uri, data=None, headers=None, cache_valid=None): async def request_content(self, uri, data=None, headers=None, cache_valid=None):
if uri.startswith("http"): if uri.startswith("http"):
return self.fetch_content(uri, data, headers, cache_valid) return await self.fetch_content(uri, data, headers, cache_valid)
if os.path.isfile(uri): if os.path.isfile(uri):
with io.open(uri, encoding="utf-8") as fp: with io.open(uri, encoding="utf-8") as fp:
return fp.read() return fp.read()
@@ -82,13 +80,11 @@ class OSRPC(object):
@staticmethod @staticmethod
def reveal_file(path): def reveal_file(path):
return click.launch( return click.launch(path, locate=True)
path.encode(get_filesystem_encoding()) if PY2 else path, locate=True
)
@staticmethod @staticmethod
def open_file(path): def open_file(path):
return click.launch(path.encode(get_filesystem_encoding()) if PY2 else path) return click.launch(path)
@staticmethod @staticmethod
def is_file(path): def is_file(path):
@@ -121,7 +117,9 @@ class OSRPC(object):
result = set() result = set()
for pathname in pathnames: for pathname in pathnames:
result |= set( result |= set(
glob_recursive(os.path.join(root, pathname) if root else pathname) glob.glob(
os.path.join(root, pathname) if root else pathname, recursive=True
)
) )
return list(result) return list(result)

View File

@@ -17,23 +17,15 @@ from __future__ import absolute_import
import json import json
import os import os
import sys import sys
from io import BytesIO, StringIO from io import StringIO
import click import click
import jsonrpc # pylint: disable=import-error import jsonrpc
from twisted.internet import defer # pylint: disable=import-error from starlette.concurrency import run_in_threadpool
from twisted.internet import threads # pylint: disable=import-error
from twisted.internet import utils # pylint: disable=import-error
from platformio import __main__, __version__, fs from platformio import __main__, __version__, fs, proc
from platformio.commands.home import helpers from platformio.commands.home import helpers
from platformio.compat import ( from platformio.compat import get_locale_encoding, is_bytes
PY2,
get_filesystem_encoding,
get_locale_encoding,
is_bytes,
string_types,
)
try: try:
from thread import get_ident as thread_get_ident from thread import get_ident as thread_get_ident
@@ -52,13 +44,11 @@ class MultiThreadingStdStream(object):
def _ensure_thread_buffer(self, thread_id): def _ensure_thread_buffer(self, thread_id):
if thread_id not in self._buffers: if thread_id not in self._buffers:
self._buffers[thread_id] = BytesIO() if PY2 else StringIO() self._buffers[thread_id] = StringIO()
def write(self, value): def write(self, value):
thread_id = thread_get_ident() thread_id = thread_get_ident()
self._ensure_thread_buffer(thread_id) self._ensure_thread_buffer(thread_id)
if PY2 and isinstance(value, unicode): # pylint: disable=undefined-variable
value = value.encode()
return self._buffers[thread_id].write( return self._buffers[thread_id].write(
value.decode() if is_bytes(value) else value value.decode() if is_bytes(value) else value
) )
@@ -74,7 +64,7 @@ class MultiThreadingStdStream(object):
return result return result
class PIOCoreRPC(object): class PIOCoreRPC:
@staticmethod @staticmethod
def version(): def version():
return __version__ return __version__
@@ -89,16 +79,9 @@ class PIOCoreRPC(object):
sys.stderr = PIOCoreRPC.thread_stderr sys.stderr = PIOCoreRPC.thread_stderr
@staticmethod @staticmethod
def call(args, options=None): async def call(args, options=None):
return defer.maybeDeferred(PIOCoreRPC._call_generator, args, options)
@staticmethod
@defer.inlineCallbacks
def _call_generator(args, options=None):
for i, arg in enumerate(args): for i, arg in enumerate(args):
if isinstance(arg, string_types): if not isinstance(arg, str):
args[i] = arg.encode(get_filesystem_encoding()) if PY2 else arg
else:
args[i] = str(arg) args[i] = str(arg)
options = options or {} options = options or {}
@@ -106,27 +89,34 @@ class PIOCoreRPC(object):
try: try:
if options.get("force_subprocess"): if options.get("force_subprocess"):
result = yield PIOCoreRPC._call_subprocess(args, options) result = await PIOCoreRPC._call_subprocess(args, options)
defer.returnValue(PIOCoreRPC._process_result(result, to_json)) return PIOCoreRPC._process_result(result, to_json)
else: result = await PIOCoreRPC._call_inline(args, options)
result = yield PIOCoreRPC._call_inline(args, options)
try: try:
defer.returnValue(PIOCoreRPC._process_result(result, to_json)) return PIOCoreRPC._process_result(result, to_json)
except ValueError: except ValueError:
# fall-back to subprocess method # fall-back to subprocess method
result = yield PIOCoreRPC._call_subprocess(args, options) result = await PIOCoreRPC._call_subprocess(args, options)
defer.returnValue(PIOCoreRPC._process_result(result, to_json)) return PIOCoreRPC._process_result(result, to_json)
except Exception as e: # pylint: disable=bare-except except Exception as e: # pylint: disable=bare-except
raise jsonrpc.exceptions.JSONRPCDispatchException( raise jsonrpc.exceptions.JSONRPCDispatchException(
code=4003, message="PIO Core Call Error", data=str(e) code=4003, message="PIO Core Call Error", data=str(e)
) )
@staticmethod @staticmethod
def _call_inline(args, options): async def _call_subprocess(args, options):
PIOCoreRPC.setup_multithreading_std_streams() result = await run_in_threadpool(
cwd = options.get("cwd") or os.getcwd() proc.exec_command,
[helpers.get_core_fullpath()] + args,
cwd=options.get("cwd") or os.getcwd(),
)
return (result["out"], result["err"], result["returncode"])
def _thread_task(): @staticmethod
async def _call_inline(args, options):
PIOCoreRPC.setup_multithreading_std_streams()
def _thread_safe_call(args, cwd):
with fs.cd(cwd): with fs.cd(cwd):
exit_code = __main__.main(["-c"] + args) exit_code = __main__.main(["-c"] + args)
return ( return (
@@ -135,16 +125,8 @@ class PIOCoreRPC(object):
exit_code, exit_code,
) )
return threads.deferToThread(_thread_task) return await run_in_threadpool(
_thread_safe_call, args=args, cwd=options.get("cwd") or os.getcwd()
@staticmethod
def _call_subprocess(args, options):
cwd = (options or {}).get("cwd") or os.getcwd()
return utils.getProcessOutputAndValue(
helpers.get_core_fullpath(),
args,
path=cwd,
env={k: v for k, v in os.environ.items() if "%" not in k},
) )
@staticmethod @staticmethod

View File

@@ -18,12 +18,11 @@ import os
import shutil import shutil
import time import time
import jsonrpc # pylint: disable=import-error import jsonrpc
from platformio import exception, fs from platformio import exception, fs
from platformio.commands.home.rpc.handlers.app import AppRPC from platformio.commands.home.rpc.handlers.app import AppRPC
from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC
from platformio.compat import PY2, get_filesystem_encoding
from platformio.ide.projectgenerator import ProjectGenerator from platformio.ide.projectgenerator import ProjectGenerator
from platformio.package.manager.platform import PlatformPackageManager from platformio.package.manager.platform import PlatformPackageManager
from platformio.project.config import ProjectConfig from platformio.project.config import ProjectConfig
@@ -32,7 +31,7 @@ from platformio.project.helpers import get_project_dir, is_platformio_project
from platformio.project.options import get_config_options_schema from platformio.project.options import get_config_options_schema
class ProjectRPC(object): class ProjectRPC:
@staticmethod @staticmethod
def config_call(init_kwargs, method, *args): def config_call(init_kwargs, method, *args):
assert isinstance(init_kwargs, dict) assert isinstance(init_kwargs, dict)
@@ -254,8 +253,6 @@ class ProjectRPC(object):
def import_arduino(self, board, use_arduino_libs, arduino_project_dir): def import_arduino(self, board, use_arduino_libs, arduino_project_dir):
board = str(board) board = str(board)
if arduino_project_dir and PY2:
arduino_project_dir = arduino_project_dir.encode(get_filesystem_encoding())
# don't import PIO Project # don't import PIO Project
if is_platformio_project(arduino_project_dir): if is_platformio_project(arduino_project_dir):
return arduino_project_dir return arduino_project_dir

View File

@@ -12,90 +12,112 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=import-error import inspect
import json
import sys
import click import click
import jsonrpc import jsonrpc
from autobahn.twisted.websocket import WebSocketServerFactory, WebSocketServerProtocol from starlette.endpoints import WebSocketEndpoint
from jsonrpc.exceptions import JSONRPCDispatchException
from twisted.internet import defer, reactor
from platformio.compat import PY2, dump_json_to_unicode, is_bytes from platformio.compat import create_task, get_running_loop, is_bytes
class JSONRPCServerProtocol(WebSocketServerProtocol): class JSONRPCServerFactoryBase:
def onOpen(self):
self.factory.connection_nums += 1
if self.factory.shutdown_timer:
self.factory.shutdown_timer.cancel()
self.factory.shutdown_timer = None
def onClose(self, wasClean, code, reason): # pylint: disable=unused-argument
self.factory.connection_nums -= 1
if self.factory.connection_nums == 0:
self.factory.shutdownByTimeout()
def onMessage(self, payload, isBinary): # pylint: disable=unused-argument
# click.echo("> %s" % payload)
response = jsonrpc.JSONRPCResponseManager.handle(
payload, self.factory.dispatcher
).data
# if error
if "result" not in response:
self.sendJSONResponse(response)
return None
d = defer.maybeDeferred(lambda: response["result"])
d.addCallback(self._callback, response)
d.addErrback(self._errback, response)
return None
def _callback(self, result, response):
response["result"] = result
self.sendJSONResponse(response)
def _errback(self, failure, response):
if isinstance(failure.value, JSONRPCDispatchException):
e = failure.value
else:
e = JSONRPCDispatchException(code=4999, message=failure.getErrorMessage())
del response["result"]
response["error"] = e.error._data # pylint: disable=protected-access
self.sendJSONResponse(response)
def sendJSONResponse(self, response):
# click.echo("< %s" % response)
if "error" in response:
click.secho("Error: %s" % response["error"], fg="red", err=True)
response = dump_json_to_unicode(response)
if not PY2 and not is_bytes(response):
response = response.encode("utf-8")
self.sendMessage(response)
class JSONRPCServerFactory(WebSocketServerFactory):
protocol = JSONRPCServerProtocol
connection_nums = 0 connection_nums = 0
shutdown_timer = 0 shutdown_timer = None
def __init__(self, shutdown_timeout=0): def __init__(self, shutdown_timeout=0):
super(JSONRPCServerFactory, self).__init__()
self.shutdown_timeout = shutdown_timeout self.shutdown_timeout = shutdown_timeout
self.dispatcher = jsonrpc.Dispatcher() self.dispatcher = jsonrpc.Dispatcher()
def shutdownByTimeout(self): def __call__(self, *args, **kwargs):
raise NotImplementedError
def addHandler(self, handler, namespace):
self.dispatcher.build_method_map(handler, prefix="%s." % namespace)
def on_client_connect(self):
self.connection_nums += 1
if self.shutdown_timer:
self.shutdown_timer.cancel()
self.shutdown_timer = None
def on_client_disconnect(self):
self.connection_nums -= 1
if self.connection_nums < 1:
self.connection_nums = 0
if self.connection_nums == 0:
self.shutdown_by_timeout()
async def on_shutdown(self):
pass
def shutdown_by_timeout(self):
if self.shutdown_timeout < 1: if self.shutdown_timeout < 1:
return return
def _auto_shutdown_server(): def _auto_shutdown_server():
click.echo("Automatically shutdown server on timeout") click.echo("Automatically shutdown server on timeout")
reactor.stop() try:
get_running_loop().stop()
except: # pylint: disable=bare-except
pass
finally:
sys.exit(0)
self.shutdown_timer = reactor.callLater( self.shutdown_timer = get_running_loop().call_later(
self.shutdown_timeout, _auto_shutdown_server self.shutdown_timeout, _auto_shutdown_server
) )
def addHandler(self, handler, namespace):
self.dispatcher.build_method_map(handler, prefix="%s." % namespace) class WebSocketJSONRPCServerFactory(JSONRPCServerFactoryBase):
def __call__(self, *args, **kwargs):
ws = WebSocketJSONRPCServer(*args, **kwargs)
ws.factory = self
return ws
class WebSocketJSONRPCServer(WebSocketEndpoint):
encoding = "text"
factory: WebSocketJSONRPCServerFactory = None
async def on_connect(self, websocket):
await websocket.accept()
self.factory.on_client_connect() # pylint: disable=no-member
async def on_receive(self, websocket, data):
create_task(self._handle_rpc(websocket, data))
async def on_disconnect(self, websocket, close_code):
self.factory.on_client_disconnect() # pylint: disable=no-member
async def _handle_rpc(self, websocket, data):
response = jsonrpc.JSONRPCResponseManager.handle(
data, self.factory.dispatcher # pylint: disable=no-member
)
if response.result and inspect.isawaitable(response.result):
try:
response.result = await response.result
response.data["result"] = response.result
response.error = None
except Exception as exc: # pylint: disable=broad-except
if not isinstance(exc, jsonrpc.exceptions.JSONRPCDispatchException):
exc = jsonrpc.exceptions.JSONRPCDispatchException(
code=4999, message=str(exc)
)
response.result = None
response.error = exc.error._data # pylint: disable=protected-access
new_data = response.data.copy()
new_data["error"] = response.error
del new_data["result"]
response.data = new_data
if response.error:
click.secho("Error: %s" % response.error, fg="red", err=True)
if "result" in response.data and is_bytes(response.data["result"]):
response.data["result"] = response.data["result"].decode("utf-8")
await websocket.send_text(json.dumps(response.data))

View File

@@ -78,6 +78,12 @@ if PY2:
string_types = (str, unicode) string_types = (str, unicode)
def create_task(coro, name=None):
raise NotImplementedError
def get_running_loop():
raise NotImplementedError
def is_bytes(x): def is_bytes(x):
return isinstance(x, (buffer, bytearray)) return isinstance(x, (buffer, bytearray))
@@ -129,6 +135,12 @@ else:
import importlib.util import importlib.util
from glob import escape as glob_escape from glob import escape as glob_escape
if sys.version_info >= (3, 7):
from asyncio import create_task, get_running_loop
else:
from asyncio import ensure_future as create_task
from asyncio import get_event_loop as get_running_loop
string_types = (str,) string_types = (str,)
def is_bytes(x): def is_bytes(x):

View File

@@ -26,19 +26,26 @@ from platformio import (
from platformio.compat import PY2, WINDOWS from platformio.compat import PY2, WINDOWS
install_requires = [ minimal_requirements = [
"bottle<0.13", "bottle==0.12.*",
"click>=5,<8%s" % (",!=7.1,!=7.1.1" if WINDOWS else ""), "click>=5,<8%s" % (",!=7.1,!=7.1.1" if WINDOWS else ""),
"colorama", "colorama",
"pyserial>=3,<4,!=3.3", "marshmallow%s" % (">=2,<3" if PY2 else ">=2,<4"),
"requests>=2.4.0,<3", "pyelftools>=0.27,<1",
"semantic_version>=2.8.1,<3", "pyserial==3.*",
"tabulate>=0.8.3,<1", "requests==2.*",
"pyelftools>=0.25,<1", "semantic_version==2.8.*",
"marshmallow%s" % (">=2,<3" if PY2 else ">=2"), "tabulate==0.8.*",
"zeroconf==%s" % ("0.19.*" if PY2 else "0.28.*"), "zeroconf==%s" % ("0.19.*" if PY2 else "0.28.*"),
] ]
home_requirements = [
"aiofiles==0.6.*",
"json-rpc==1.13.*",
"starlette==0.14.*",
"uvicorn==0.13.*",
"wsproto==1.0.*",
]
setup( setup(
name=__title__, name=__title__,
@@ -52,7 +59,7 @@ setup(
python_requires=", ".join( python_requires=", ".join(
[">=2.7", "!=3.0.*", "!=3.1.*", "!=3.2.*", "!=3.3.*", "!=3.4.*"] [">=2.7", "!=3.0.*", "!=3.1.*", "!=3.2.*", "!=3.3.*", "!=3.4.*"]
), ),
install_requires=install_requires, install_requires=minimal_requirements + ([] if PY2 else home_requirements),
packages=find_packages(exclude=["tests.*", "tests"]) + ["scripts"], packages=find_packages(exclude=["tests.*", "tests"]) + ["scripts"],
package_data={ package_data={
"platformio": [ "platformio": [