From 6ff67aeadfc63b421c0f652c70ac0df5b8bb4098 Mon Sep 17 00:00:00 2001 From: Ivan Kravets Date: Mon, 18 Jan 2021 18:20:26 +0200 Subject: [PATCH] Significantly speedup PlatformIO Home loading time by migrating to native Python 3 Asynchronous I/O --- HISTORY.rst | 1 + platformio/commands/home/command.py | 123 +++++++------- platformio/commands/home/helpers.py | 24 +-- .../commands/home/rpc/handlers/account.py | 4 +- platformio/commands/home/rpc/handlers/app.py | 2 +- platformio/commands/home/rpc/handlers/ide.py | 11 +- platformio/commands/home/rpc/handlers/misc.py | 21 ++- platformio/commands/home/rpc/handlers/os.py | 30 ++-- .../commands/home/rpc/handlers/piocore.py | 82 ++++------ .../commands/home/rpc/handlers/project.py | 7 +- platformio/commands/home/rpc/server.py | 152 ++++++++++-------- platformio/compat.py | 12 ++ setup.py | 25 +-- 13 files changed, 253 insertions(+), 241 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index 0b16b2e8..0bca88c8 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -11,6 +11,7 @@ PlatformIO Core 5 5.0.5 (2021-??-??) ~~~~~~~~~~~~~~~~~~ +* Significantly speedup PlatformIO Home loading time by migrating to native Python 3 Asynchronous I/O * Improved listing of `multicast DNS services `_ * Check for debug server's "ready_pattern" in "stderr" diff --git a/platformio/commands/home/command.py b/platformio/commands/home/command.py index 6cb26ed9..abb13074 100644 --- a/platformio/commands/home/command.py +++ b/platformio/commands/home/command.py @@ -15,17 +15,17 @@ # pylint: disable=too-many-locals,too-many-statements import mimetypes +import os import socket -from os.path import isdir import click from platformio import exception -from platformio.compat import WINDOWS -from platformio.package.manager.core import get_core_package_dir, inject_contrib_pysite +from platformio.compat import WINDOWS, ensure_python3 +from platformio.package.manager.core import get_core_package_dir -@click.command("home", short_help="UI to manage PlatformIO") +@click.command("home", short_help="GUI to manage PlatformIO") @click.option("--port", type=int, default=8008, help="HTTP port, default=8008") @click.option( "--host", @@ -46,60 +46,16 @@ from platformio.package.manager.core import get_core_package_dir, inject_contrib ), ) def cli(port, host, no_open, shutdown_timeout): - # pylint: disable=import-error, import-outside-toplevel - - # import contrib modules - inject_contrib_pysite() - - from autobahn.twisted.resource import WebSocketResource - from twisted.internet import reactor - from twisted.web import server - from twisted.internet.error import CannotListenError - - from platformio.commands.home.rpc.handlers.app import AppRPC - from platformio.commands.home.rpc.handlers.ide import IDERPC - from platformio.commands.home.rpc.handlers.misc import MiscRPC - from platformio.commands.home.rpc.handlers.os import OSRPC - from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC - from platformio.commands.home.rpc.handlers.project import ProjectRPC - from platformio.commands.home.rpc.handlers.account import AccountRPC - from platformio.commands.home.rpc.server import JSONRPCServerFactory - from platformio.commands.home.web import WebRoot - - factory = JSONRPCServerFactory(shutdown_timeout) - factory.addHandler(AppRPC(), namespace="app") - factory.addHandler(IDERPC(), namespace="ide") - factory.addHandler(MiscRPC(), namespace="misc") - factory.addHandler(OSRPC(), namespace="os") - factory.addHandler(PIOCoreRPC(), namespace="core") - factory.addHandler(ProjectRPC(), namespace="project") - factory.addHandler(AccountRPC(), namespace="account") - - contrib_dir = get_core_package_dir("contrib-piohome") - if not isdir(contrib_dir): - raise exception.PlatformioException("Invalid path to PIO Home Contrib") - # Ensure PIO Home mimetypes are known mimetypes.add_type("text/html", ".html") mimetypes.add_type("text/css", ".css") mimetypes.add_type("application/javascript", ".js") - root = WebRoot(contrib_dir) - root.putChild(b"wsrpc", WebSocketResource(factory)) - site = server.Site(root) - # hook for `platformio-node-helpers` if host == "__do_not_start__": return - already_started = is_port_used(host, port) home_url = "http://%s:%d" % (host, port) - if not no_open: - if already_started: - click.launch(home_url) - else: - reactor.callLater(1, lambda: click.launch(home_url)) - click.echo( "\n".join( [ @@ -115,21 +71,21 @@ def cli(port, host, no_open, shutdown_timeout): click.echo("") click.echo("Open PlatformIO Home in your browser by this URL => %s" % home_url) - try: - reactor.listenTCP(port, site, interface=host) - except CannotListenError as e: - click.secho(str(e), fg="red", err=True) - already_started = True - - if already_started: + if is_port_used(host, port): click.secho( "PlatformIO Home server is already started in another process.", fg="yellow" ) + if not no_open: + click.launch(home_url) return - click.echo("PIO Home has been started. Press Ctrl+C to shutdown.") - - reactor.run() + run_server( + host=host, + port=port, + no_open=no_open, + shutdown_timeout=shutdown_timeout, + home_url=home_url, + ) def is_port_used(host, port): @@ -150,3 +106,54 @@ def is_port_used(host, port): return False return True + + +def run_server(host, port, no_open, shutdown_timeout, home_url): + # pylint: disable=import-error, import-outside-toplevel + + ensure_python3() + + import uvicorn + from starlette.applications import Starlette + from starlette.routing import Mount, WebSocketRoute + from starlette.staticfiles import StaticFiles + + from platformio.commands.home.rpc.handlers.account import AccountRPC + from platformio.commands.home.rpc.handlers.app import AppRPC + from platformio.commands.home.rpc.handlers.ide import IDERPC + from platformio.commands.home.rpc.handlers.misc import MiscRPC + from platformio.commands.home.rpc.handlers.os import OSRPC + from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC + from platformio.commands.home.rpc.handlers.project import ProjectRPC + from platformio.commands.home.rpc.server import WebSocketJSONRPCServerFactory + + contrib_dir = get_core_package_dir("contrib-piohome") + if not os.path.isdir(contrib_dir): + raise exception.PlatformioException("Invalid path to PIO Home Contrib") + + ws_rpc_factory = WebSocketJSONRPCServerFactory(shutdown_timeout) + ws_rpc_factory.addHandler(AccountRPC(), namespace="account") + ws_rpc_factory.addHandler(AppRPC(), namespace="app") + ws_rpc_factory.addHandler(IDERPC(), namespace="ide") + ws_rpc_factory.addHandler(MiscRPC(), namespace="misc") + ws_rpc_factory.addHandler(OSRPC(), namespace="os") + ws_rpc_factory.addHandler(PIOCoreRPC(), namespace="core") + ws_rpc_factory.addHandler(ProjectRPC(), namespace="project") + + uvicorn.run( + Starlette( + routes=[ + WebSocketRoute("/wsrpc", ws_rpc_factory, name="wsrpc"), + Mount("/", StaticFiles(directory=contrib_dir, html=True)), + ], + on_startup=[ + lambda: click.echo( + "PIO Home has been started. Press Ctrl+C to shutdown." + ), + lambda: None if no_open else click.launch(home_url), + ], + ), + host=host, + port=port, + log_level="warning", + ) diff --git a/platformio/commands/home/helpers.py b/platformio/commands/home/helpers.py index aff92281..dfd57c25 100644 --- a/platformio/commands/home/helpers.py +++ b/platformio/commands/home/helpers.py @@ -12,36 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=keyword-arg-before-vararg,arguments-differ,signature-differs - import requests -from twisted.internet import defer # pylint: disable=import-error -from twisted.internet import reactor # pylint: disable=import-error -from twisted.internet import threads # pylint: disable=import-error +from starlette.concurrency import run_in_threadpool from platformio import util from platformio.proc import where_is_program class AsyncSession(requests.Session): - def __init__(self, n=None, *args, **kwargs): - if n: - pool = reactor.getThreadPool() - pool.adjustPoolsize(0, n) - - super(AsyncSession, self).__init__(*args, **kwargs) - - def request(self, *args, **kwargs): + async def request( # pylint: disable=signature-differs,invalid-overridden-method + self, *args, **kwargs + ): func = super(AsyncSession, self).request - return threads.deferToThread(func, *args, **kwargs) - - def wrap(self, *args, **kwargs): # pylint: disable=no-self-use - return defer.ensureDeferred(*args, **kwargs) + return await run_in_threadpool(func, *args, **kwargs) @util.memoized(expire="60s") def requests_session(): - return AsyncSession(n=5) + return AsyncSession() @util.memoized(expire="60s") diff --git a/platformio/commands/home/rpc/handlers/account.py b/platformio/commands/home/rpc/handlers/account.py index d28379f8..337d780a 100644 --- a/platformio/commands/home/rpc/handlers/account.py +++ b/platformio/commands/home/rpc/handlers/account.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import jsonrpc # pylint: disable=import-error +import jsonrpc from platformio.clients.account import AccountClient -class AccountRPC(object): +class AccountRPC: @staticmethod def call_client(method, *args, **kwargs): try: diff --git a/platformio/commands/home/rpc/handlers/app.py b/platformio/commands/home/rpc/handlers/app.py index 1fd49e22..3c0ce465 100644 --- a/platformio/commands/home/rpc/handlers/app.py +++ b/platformio/commands/home/rpc/handlers/app.py @@ -20,7 +20,7 @@ from platformio import __version__, app, fs, util from platformio.project.helpers import get_project_core_dir, is_platformio_project -class AppRPC(object): +class AppRPC: APPSTATE_PATH = join(get_project_core_dir(), "homestate.json") diff --git a/platformio/commands/home/rpc/handlers/ide.py b/platformio/commands/home/rpc/handlers/ide.py index e3ad75f3..ed95b738 100644 --- a/platformio/commands/home/rpc/handlers/ide.py +++ b/platformio/commands/home/rpc/handlers/ide.py @@ -14,11 +14,12 @@ import time -import jsonrpc # pylint: disable=import-error -from twisted.internet import defer # pylint: disable=import-error +import jsonrpc + +from platformio.compat import get_running_loop -class IDERPC(object): +class IDERPC: def __init__(self): self._queue = {} @@ -28,14 +29,14 @@ class IDERPC(object): code=4005, message="PIO Home IDE agent is not started" ) while self._queue[sid]: - self._queue[sid].pop().callback( + self._queue[sid].pop().set_result( {"id": time.time(), "method": command, "params": params} ) def listen_commands(self, sid=0): if sid not in self._queue: self._queue[sid] = [] - self._queue[sid].append(defer.Deferred()) + self._queue[sid].append(get_running_loop().create_future()) return self._queue[sid][-1] def open_project(self, sid, project_dir): diff --git a/platformio/commands/home/rpc/handlers/misc.py b/platformio/commands/home/rpc/handlers/misc.py index a4bdc652..c16a6cc9 100644 --- a/platformio/commands/home/rpc/handlers/misc.py +++ b/platformio/commands/home/rpc/handlers/misc.py @@ -15,14 +15,13 @@ import json import time -from twisted.internet import defer, reactor # pylint: disable=import-error - from platformio.cache import ContentCache from platformio.commands.home.rpc.handlers.os import OSRPC +from platformio.compat import create_task -class MiscRPC(object): - def load_latest_tweets(self, data_url): +class MiscRPC: + async def load_latest_tweets(self, data_url): cache_key = ContentCache.key_from_args(data_url, "tweets") cache_valid = "180d" with ContentCache() as cc: @@ -31,22 +30,20 @@ class MiscRPC(object): cache_data = json.loads(cache_data) # automatically update cache in background every 12 hours if cache_data["time"] < (time.time() - (3600 * 12)): - reactor.callLater( - 5, self._preload_latest_tweets, data_url, cache_key, cache_valid + create_task( + self._preload_latest_tweets(data_url, cache_key, cache_valid) ) return cache_data["result"] - result = self._preload_latest_tweets(data_url, cache_key, cache_valid) - return result + return await self._preload_latest_tweets(data_url, cache_key, cache_valid) @staticmethod - @defer.inlineCallbacks - def _preload_latest_tweets(data_url, cache_key, cache_valid): - result = json.loads((yield OSRPC.fetch_content(data_url))) + async def _preload_latest_tweets(data_url, cache_key, cache_valid): + result = json.loads((await OSRPC.fetch_content(data_url))) with ContentCache() as cc: cc.set( cache_key, json.dumps({"time": int(time.time()), "result": result}), cache_valid, ) - defer.returnValue(result) + return result diff --git a/platformio/commands/home/rpc/handlers/os.py b/platformio/commands/home/rpc/handlers/os.py index 448c633a..f1042978 100644 --- a/platformio/commands/home/rpc/handlers/os.py +++ b/platformio/commands/home/rpc/handlers/os.py @@ -14,25 +14,23 @@ from __future__ import absolute_import +import glob import io import os import shutil from functools import cmp_to_key import click -from twisted.internet import defer # pylint: disable=import-error from platformio import __default_requests_timeout__, fs, util from platformio.cache import ContentCache from platformio.clients.http import ensure_internet_on from platformio.commands.home import helpers -from platformio.compat import PY2, get_filesystem_encoding, glob_recursive -class OSRPC(object): +class OSRPC: @staticmethod - @defer.inlineCallbacks - def fetch_content(uri, data=None, headers=None, cache_valid=None): + async def fetch_content(uri, data=None, headers=None, cache_valid=None): if not headers: headers = { "User-Agent": ( @@ -46,18 +44,18 @@ class OSRPC(object): if cache_key: result = cc.get(cache_key) if result is not None: - defer.returnValue(result) + return result # check internet before and resolve issue with 60 seconds timeout ensure_internet_on(raise_exception=True) session = helpers.requests_session() if data: - r = yield session.post( + r = await session.post( uri, data=data, headers=headers, timeout=__default_requests_timeout__ ) else: - r = yield session.get( + r = await session.get( uri, headers=headers, timeout=__default_requests_timeout__ ) @@ -66,11 +64,11 @@ class OSRPC(object): if cache_valid: with ContentCache() as cc: cc.set(cache_key, result, cache_valid) - defer.returnValue(result) + return result - def request_content(self, uri, data=None, headers=None, cache_valid=None): + async def request_content(self, uri, data=None, headers=None, cache_valid=None): if uri.startswith("http"): - return self.fetch_content(uri, data, headers, cache_valid) + return await self.fetch_content(uri, data, headers, cache_valid) if os.path.isfile(uri): with io.open(uri, encoding="utf-8") as fp: return fp.read() @@ -82,13 +80,11 @@ class OSRPC(object): @staticmethod def reveal_file(path): - return click.launch( - path.encode(get_filesystem_encoding()) if PY2 else path, locate=True - ) + return click.launch(path, locate=True) @staticmethod def open_file(path): - return click.launch(path.encode(get_filesystem_encoding()) if PY2 else path) + return click.launch(path) @staticmethod def is_file(path): @@ -121,7 +117,9 @@ class OSRPC(object): result = set() for pathname in pathnames: result |= set( - glob_recursive(os.path.join(root, pathname) if root else pathname) + glob.glob( + os.path.join(root, pathname) if root else pathname, recursive=True + ) ) return list(result) diff --git a/platformio/commands/home/rpc/handlers/piocore.py b/platformio/commands/home/rpc/handlers/piocore.py index 7a16f9c6..d74095ab 100644 --- a/platformio/commands/home/rpc/handlers/piocore.py +++ b/platformio/commands/home/rpc/handlers/piocore.py @@ -17,23 +17,15 @@ from __future__ import absolute_import import json import os import sys -from io import BytesIO, StringIO +from io import StringIO import click -import jsonrpc # pylint: disable=import-error -from twisted.internet import defer # pylint: disable=import-error -from twisted.internet import threads # pylint: disable=import-error -from twisted.internet import utils # pylint: disable=import-error +import jsonrpc +from starlette.concurrency import run_in_threadpool -from platformio import __main__, __version__, fs +from platformio import __main__, __version__, fs, proc from platformio.commands.home import helpers -from platformio.compat import ( - PY2, - get_filesystem_encoding, - get_locale_encoding, - is_bytes, - string_types, -) +from platformio.compat import get_locale_encoding, is_bytes try: from thread import get_ident as thread_get_ident @@ -52,13 +44,11 @@ class MultiThreadingStdStream(object): def _ensure_thread_buffer(self, thread_id): if thread_id not in self._buffers: - self._buffers[thread_id] = BytesIO() if PY2 else StringIO() + self._buffers[thread_id] = StringIO() def write(self, value): thread_id = thread_get_ident() self._ensure_thread_buffer(thread_id) - if PY2 and isinstance(value, unicode): # pylint: disable=undefined-variable - value = value.encode() return self._buffers[thread_id].write( value.decode() if is_bytes(value) else value ) @@ -74,7 +64,7 @@ class MultiThreadingStdStream(object): return result -class PIOCoreRPC(object): +class PIOCoreRPC: @staticmethod def version(): return __version__ @@ -89,16 +79,9 @@ class PIOCoreRPC(object): sys.stderr = PIOCoreRPC.thread_stderr @staticmethod - def call(args, options=None): - return defer.maybeDeferred(PIOCoreRPC._call_generator, args, options) - - @staticmethod - @defer.inlineCallbacks - def _call_generator(args, options=None): + async def call(args, options=None): for i, arg in enumerate(args): - if isinstance(arg, string_types): - args[i] = arg.encode(get_filesystem_encoding()) if PY2 else arg - else: + if not isinstance(arg, str): args[i] = str(arg) options = options or {} @@ -106,27 +89,34 @@ class PIOCoreRPC(object): try: if options.get("force_subprocess"): - result = yield PIOCoreRPC._call_subprocess(args, options) - defer.returnValue(PIOCoreRPC._process_result(result, to_json)) - else: - result = yield PIOCoreRPC._call_inline(args, options) - try: - defer.returnValue(PIOCoreRPC._process_result(result, to_json)) - except ValueError: - # fall-back to subprocess method - result = yield PIOCoreRPC._call_subprocess(args, options) - defer.returnValue(PIOCoreRPC._process_result(result, to_json)) + result = await PIOCoreRPC._call_subprocess(args, options) + return PIOCoreRPC._process_result(result, to_json) + result = await PIOCoreRPC._call_inline(args, options) + try: + return PIOCoreRPC._process_result(result, to_json) + except ValueError: + # fall-back to subprocess method + result = await PIOCoreRPC._call_subprocess(args, options) + return PIOCoreRPC._process_result(result, to_json) except Exception as e: # pylint: disable=bare-except raise jsonrpc.exceptions.JSONRPCDispatchException( code=4003, message="PIO Core Call Error", data=str(e) ) @staticmethod - def _call_inline(args, options): - PIOCoreRPC.setup_multithreading_std_streams() - cwd = options.get("cwd") or os.getcwd() + async def _call_subprocess(args, options): + result = await run_in_threadpool( + proc.exec_command, + [helpers.get_core_fullpath()] + args, + cwd=options.get("cwd") or os.getcwd(), + ) + return (result["out"], result["err"], result["returncode"]) - def _thread_task(): + @staticmethod + async def _call_inline(args, options): + PIOCoreRPC.setup_multithreading_std_streams() + + def _thread_safe_call(args, cwd): with fs.cd(cwd): exit_code = __main__.main(["-c"] + args) return ( @@ -135,16 +125,8 @@ class PIOCoreRPC(object): exit_code, ) - return threads.deferToThread(_thread_task) - - @staticmethod - def _call_subprocess(args, options): - cwd = (options or {}).get("cwd") or os.getcwd() - return utils.getProcessOutputAndValue( - helpers.get_core_fullpath(), - args, - path=cwd, - env={k: v for k, v in os.environ.items() if "%" not in k}, + return await run_in_threadpool( + _thread_safe_call, args=args, cwd=options.get("cwd") or os.getcwd() ) @staticmethod diff --git a/platformio/commands/home/rpc/handlers/project.py b/platformio/commands/home/rpc/handlers/project.py index eb9cd237..20b1dcc8 100644 --- a/platformio/commands/home/rpc/handlers/project.py +++ b/platformio/commands/home/rpc/handlers/project.py @@ -18,12 +18,11 @@ import os import shutil import time -import jsonrpc # pylint: disable=import-error +import jsonrpc from platformio import exception, fs from platformio.commands.home.rpc.handlers.app import AppRPC from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC -from platformio.compat import PY2, get_filesystem_encoding from platformio.ide.projectgenerator import ProjectGenerator from platformio.package.manager.platform import PlatformPackageManager from platformio.project.config import ProjectConfig @@ -32,7 +31,7 @@ from platformio.project.helpers import get_project_dir, is_platformio_project from platformio.project.options import get_config_options_schema -class ProjectRPC(object): +class ProjectRPC: @staticmethod def config_call(init_kwargs, method, *args): assert isinstance(init_kwargs, dict) @@ -254,8 +253,6 @@ class ProjectRPC(object): def import_arduino(self, board, use_arduino_libs, arduino_project_dir): board = str(board) - if arduino_project_dir and PY2: - arduino_project_dir = arduino_project_dir.encode(get_filesystem_encoding()) # don't import PIO Project if is_platformio_project(arduino_project_dir): return arduino_project_dir diff --git a/platformio/commands/home/rpc/server.py b/platformio/commands/home/rpc/server.py index 1924754f..a1448566 100644 --- a/platformio/commands/home/rpc/server.py +++ b/platformio/commands/home/rpc/server.py @@ -12,90 +12,112 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=import-error +import inspect +import json +import sys import click import jsonrpc -from autobahn.twisted.websocket import WebSocketServerFactory, WebSocketServerProtocol -from jsonrpc.exceptions import JSONRPCDispatchException -from twisted.internet import defer, reactor +from starlette.endpoints import WebSocketEndpoint -from platformio.compat import PY2, dump_json_to_unicode, is_bytes +from platformio.compat import create_task, get_running_loop, is_bytes -class JSONRPCServerProtocol(WebSocketServerProtocol): - def onOpen(self): - self.factory.connection_nums += 1 - if self.factory.shutdown_timer: - self.factory.shutdown_timer.cancel() - self.factory.shutdown_timer = None +class JSONRPCServerFactoryBase: - def onClose(self, wasClean, code, reason): # pylint: disable=unused-argument - self.factory.connection_nums -= 1 - if self.factory.connection_nums == 0: - self.factory.shutdownByTimeout() - - def onMessage(self, payload, isBinary): # pylint: disable=unused-argument - # click.echo("> %s" % payload) - response = jsonrpc.JSONRPCResponseManager.handle( - payload, self.factory.dispatcher - ).data - # if error - if "result" not in response: - self.sendJSONResponse(response) - return None - - d = defer.maybeDeferred(lambda: response["result"]) - d.addCallback(self._callback, response) - d.addErrback(self._errback, response) - - return None - - def _callback(self, result, response): - response["result"] = result - self.sendJSONResponse(response) - - def _errback(self, failure, response): - if isinstance(failure.value, JSONRPCDispatchException): - e = failure.value - else: - e = JSONRPCDispatchException(code=4999, message=failure.getErrorMessage()) - del response["result"] - response["error"] = e.error._data # pylint: disable=protected-access - self.sendJSONResponse(response) - - def sendJSONResponse(self, response): - # click.echo("< %s" % response) - if "error" in response: - click.secho("Error: %s" % response["error"], fg="red", err=True) - response = dump_json_to_unicode(response) - if not PY2 and not is_bytes(response): - response = response.encode("utf-8") - self.sendMessage(response) - - -class JSONRPCServerFactory(WebSocketServerFactory): - - protocol = JSONRPCServerProtocol connection_nums = 0 - shutdown_timer = 0 + shutdown_timer = None def __init__(self, shutdown_timeout=0): - super(JSONRPCServerFactory, self).__init__() self.shutdown_timeout = shutdown_timeout self.dispatcher = jsonrpc.Dispatcher() - def shutdownByTimeout(self): + def __call__(self, *args, **kwargs): + raise NotImplementedError + + def addHandler(self, handler, namespace): + self.dispatcher.build_method_map(handler, prefix="%s." % namespace) + + def on_client_connect(self): + self.connection_nums += 1 + if self.shutdown_timer: + self.shutdown_timer.cancel() + self.shutdown_timer = None + + def on_client_disconnect(self): + self.connection_nums -= 1 + if self.connection_nums < 1: + self.connection_nums = 0 + + if self.connection_nums == 0: + self.shutdown_by_timeout() + + async def on_shutdown(self): + pass + + def shutdown_by_timeout(self): if self.shutdown_timeout < 1: return def _auto_shutdown_server(): click.echo("Automatically shutdown server on timeout") - reactor.stop() + try: + get_running_loop().stop() + except: # pylint: disable=bare-except + pass + finally: + sys.exit(0) - self.shutdown_timer = reactor.callLater( + self.shutdown_timer = get_running_loop().call_later( self.shutdown_timeout, _auto_shutdown_server ) - def addHandler(self, handler, namespace): - self.dispatcher.build_method_map(handler, prefix="%s." % namespace) + +class WebSocketJSONRPCServerFactory(JSONRPCServerFactoryBase): + def __call__(self, *args, **kwargs): + ws = WebSocketJSONRPCServer(*args, **kwargs) + ws.factory = self + return ws + + +class WebSocketJSONRPCServer(WebSocketEndpoint): + encoding = "text" + factory: WebSocketJSONRPCServerFactory = None + + async def on_connect(self, websocket): + await websocket.accept() + self.factory.on_client_connect() # pylint: disable=no-member + + async def on_receive(self, websocket, data): + create_task(self._handle_rpc(websocket, data)) + + async def on_disconnect(self, websocket, close_code): + self.factory.on_client_disconnect() # pylint: disable=no-member + + async def _handle_rpc(self, websocket, data): + response = jsonrpc.JSONRPCResponseManager.handle( + data, self.factory.dispatcher # pylint: disable=no-member + ) + if response.result and inspect.isawaitable(response.result): + try: + response.result = await response.result + response.data["result"] = response.result + response.error = None + except Exception as exc: # pylint: disable=broad-except + if not isinstance(exc, jsonrpc.exceptions.JSONRPCDispatchException): + exc = jsonrpc.exceptions.JSONRPCDispatchException( + code=4999, message=str(exc) + ) + response.result = None + response.error = exc.error._data # pylint: disable=protected-access + new_data = response.data.copy() + new_data["error"] = response.error + del new_data["result"] + response.data = new_data + + if response.error: + click.secho("Error: %s" % response.error, fg="red", err=True) + if "result" in response.data and is_bytes(response.data["result"]): + response.data["result"] = response.data["result"].decode("utf-8") + + await websocket.send_text(json.dumps(response.data)) diff --git a/platformio/compat.py b/platformio/compat.py index e25fbb6a..53c1507c 100644 --- a/platformio/compat.py +++ b/platformio/compat.py @@ -78,6 +78,12 @@ if PY2: string_types = (str, unicode) + def create_task(coro, name=None): + raise NotImplementedError + + def get_running_loop(): + raise NotImplementedError + def is_bytes(x): return isinstance(x, (buffer, bytearray)) @@ -129,6 +135,12 @@ else: import importlib.util from glob import escape as glob_escape + if sys.version_info >= (3, 7): + from asyncio import create_task, get_running_loop + else: + from asyncio import ensure_future as create_task + from asyncio import get_event_loop as get_running_loop + string_types = (str,) def is_bytes(x): diff --git a/setup.py b/setup.py index 21d3465c..577c2951 100644 --- a/setup.py +++ b/setup.py @@ -26,19 +26,26 @@ from platformio import ( from platformio.compat import PY2, WINDOWS -install_requires = [ - "bottle<0.13", +minimal_requirements = [ + "bottle==0.12.*", "click>=5,<8%s" % (",!=7.1,!=7.1.1" if WINDOWS else ""), "colorama", - "pyserial>=3,<4,!=3.3", - "requests>=2.4.0,<3", - "semantic_version>=2.8.1,<3", - "tabulate>=0.8.3,<1", - "pyelftools>=0.25,<1", - "marshmallow%s" % (">=2,<3" if PY2 else ">=2"), + "marshmallow%s" % (">=2,<3" if PY2 else ">=2,<4"), + "pyelftools>=0.27,<1", + "pyserial==3.*", + "requests==2.*", + "semantic_version==2.8.*", + "tabulate==0.8.*", "zeroconf==%s" % ("0.19.*" if PY2 else "0.28.*"), ] +home_requirements = [ + "aiofiles==0.6.*", + "json-rpc==1.13.*", + "starlette==0.14.*", + "uvicorn==0.13.*", + "wsproto==1.0.*", +] setup( name=__title__, @@ -52,7 +59,7 @@ setup( python_requires=", ".join( [">=2.7", "!=3.0.*", "!=3.1.*", "!=3.2.*", "!=3.3.*", "!=3.4.*"] ), - install_requires=install_requires, + install_requires=minimal_requirements + ([] if PY2 else home_requirements), packages=find_packages(exclude=["tests.*", "tests"]) + ["scripts"], package_data={ "platformio": [