diff --git a/platformio/commands/home/command.py b/platformio/commands/home/command.py index abb13074..28cbfef4 100644 --- a/platformio/commands/home/command.py +++ b/platformio/commands/home/command.py @@ -12,17 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# pylint: disable=too-many-locals,too-many-statements - import mimetypes -import os -import socket import click -from platformio import exception -from platformio.compat import WINDOWS, ensure_python3 -from platformio.package.manager.core import get_core_package_dir +from platformio.commands.home.helpers import is_port_used +from platformio.compat import ensure_python3 @click.command("home", short_help="GUI to manage PlatformIO") @@ -46,6 +41,8 @@ from platformio.package.manager.core import get_core_package_dir ), ) def cli(port, host, no_open, shutdown_timeout): + ensure_python3() + # Ensure PIO Home mimetypes are known mimetypes.add_type("text/html", ".html") mimetypes.add_type("text/css", ".css") @@ -79,6 +76,9 @@ def cli(port, host, no_open, shutdown_timeout): click.launch(home_url) return + # pylint: disable=import-outside-toplevel + from platformio.commands.home.run import run_server + run_server( host=host, port=port, @@ -86,74 +86,3 @@ def cli(port, host, no_open, shutdown_timeout): shutdown_timeout=shutdown_timeout, home_url=home_url, ) - - -def is_port_used(host, port): - socket.setdefaulttimeout(1) - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if WINDOWS: - try: - s.bind((host, port)) - s.close() - return False - except (OSError, socket.error): - pass - else: - try: - s.connect((host, port)) - s.close() - except socket.error: - return False - - return True - - -def run_server(host, port, no_open, shutdown_timeout, home_url): - # pylint: disable=import-error, import-outside-toplevel - - ensure_python3() - - import uvicorn - from starlette.applications import Starlette - from starlette.routing import Mount, WebSocketRoute - from starlette.staticfiles import StaticFiles - - from platformio.commands.home.rpc.handlers.account import AccountRPC - from platformio.commands.home.rpc.handlers.app import AppRPC - from platformio.commands.home.rpc.handlers.ide import IDERPC - from platformio.commands.home.rpc.handlers.misc import MiscRPC - from platformio.commands.home.rpc.handlers.os import OSRPC - from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC - from platformio.commands.home.rpc.handlers.project import ProjectRPC - from platformio.commands.home.rpc.server import WebSocketJSONRPCServerFactory - - contrib_dir = get_core_package_dir("contrib-piohome") - if not os.path.isdir(contrib_dir): - raise exception.PlatformioException("Invalid path to PIO Home Contrib") - - ws_rpc_factory = WebSocketJSONRPCServerFactory(shutdown_timeout) - ws_rpc_factory.addHandler(AccountRPC(), namespace="account") - ws_rpc_factory.addHandler(AppRPC(), namespace="app") - ws_rpc_factory.addHandler(IDERPC(), namespace="ide") - ws_rpc_factory.addHandler(MiscRPC(), namespace="misc") - ws_rpc_factory.addHandler(OSRPC(), namespace="os") - ws_rpc_factory.addHandler(PIOCoreRPC(), namespace="core") - ws_rpc_factory.addHandler(ProjectRPC(), namespace="project") - - uvicorn.run( - Starlette( - routes=[ - WebSocketRoute("/wsrpc", ws_rpc_factory, name="wsrpc"), - Mount("/", StaticFiles(directory=contrib_dir, html=True)), - ], - on_startup=[ - lambda: click.echo( - "PIO Home has been started. Press Ctrl+C to shutdown." - ), - lambda: None if no_open else click.launch(home_url), - ], - ), - host=host, - port=port, - log_level="warning", - ) diff --git a/platformio/commands/home/helpers.py b/platformio/commands/home/helpers.py index dfd57c25..5c6e0c88 100644 --- a/platformio/commands/home/helpers.py +++ b/platformio/commands/home/helpers.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import socket + import requests from starlette.concurrency import run_in_threadpool from platformio import util +from platformio.compat import WINDOWS from platformio.proc import where_is_program @@ -37,3 +40,23 @@ def get_core_fullpath(): return where_is_program( "platformio" + (".exe" if "windows" in util.get_systype() else "") ) + + +def is_port_used(host, port): + socket.setdefaulttimeout(1) + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if WINDOWS: + try: + s.bind((host, port)) + s.close() + return False + except (OSError, socket.error): + pass + else: + try: + s.connect((host, port)) + s.close() + except socket.error: + return False + + return True diff --git a/platformio/commands/home/rpc/server.py b/platformio/commands/home/rpc/server.py index a1448566..8ba41da9 100644 --- a/platformio/commands/home/rpc/server.py +++ b/platformio/commands/home/rpc/server.py @@ -14,13 +14,13 @@ import inspect import json -import sys import click import jsonrpc from starlette.endpoints import WebSocketEndpoint from platformio.compat import create_task, get_running_loop, is_bytes +from platformio.proc import force_exit class JSONRPCServerFactoryBase: @@ -61,12 +61,7 @@ class JSONRPCServerFactoryBase: def _auto_shutdown_server(): click.echo("Automatically shutdown server on timeout") - try: - get_running_loop().stop() - except: # pylint: disable=bare-except - pass - finally: - sys.exit(0) + force_exit() self.shutdown_timer = get_running_loop().call_later( self.shutdown_timeout, _auto_shutdown_server diff --git a/platformio/commands/home/run.py b/platformio/commands/home/run.py new file mode 100644 index 00000000..33096233 --- /dev/null +++ b/platformio/commands/home/run.py @@ -0,0 +1,79 @@ +# Copyright (c) 2014-present PlatformIO +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import click +import uvicorn +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.routing import Mount, WebSocketRoute +from starlette.staticfiles import StaticFiles + +from platformio.commands.home.rpc.handlers.account import AccountRPC +from platformio.commands.home.rpc.handlers.app import AppRPC +from platformio.commands.home.rpc.handlers.ide import IDERPC +from platformio.commands.home.rpc.handlers.misc import MiscRPC +from platformio.commands.home.rpc.handlers.os import OSRPC +from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC +from platformio.commands.home.rpc.handlers.project import ProjectRPC +from platformio.commands.home.rpc.server import WebSocketJSONRPCServerFactory +from platformio.compat import get_running_loop +from platformio.exception import PlatformioException +from platformio.package.manager.core import get_core_package_dir +from platformio.proc import force_exit + + +class ShutdownMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] == "http" and b"__shutdown__" in scope.get("query_string", {}): + get_running_loop().call_later(0.5, force_exit) + await self.app(scope, receive, send) + + +def run_server(host, port, no_open, shutdown_timeout, home_url): + contrib_dir = get_core_package_dir("contrib-piohome") + if not os.path.isdir(contrib_dir): + raise PlatformioException("Invalid path to PIO Home Contrib") + + ws_rpc_factory = WebSocketJSONRPCServerFactory(shutdown_timeout) + ws_rpc_factory.addHandler(AccountRPC(), namespace="account") + ws_rpc_factory.addHandler(AppRPC(), namespace="app") + ws_rpc_factory.addHandler(IDERPC(), namespace="ide") + ws_rpc_factory.addHandler(MiscRPC(), namespace="misc") + ws_rpc_factory.addHandler(OSRPC(), namespace="os") + ws_rpc_factory.addHandler(PIOCoreRPC(), namespace="core") + ws_rpc_factory.addHandler(ProjectRPC(), namespace="project") + + uvicorn.run( + Starlette( + middleware=[Middleware(ShutdownMiddleware)], + routes=[ + WebSocketRoute("/wsrpc", ws_rpc_factory, name="wsrpc"), + Mount("/", StaticFiles(directory=contrib_dir, html=True)), + ], + on_startup=[ + lambda: click.echo( + "PIO Home has been started. Press Ctrl+C to shutdown." + ), + lambda: None if no_open else click.launch(home_url), + ], + ), + host=host, + port=port, + log_level="warning", + ) diff --git a/platformio/commands/home/web.py b/platformio/commands/home/web.py deleted file mode 100644 index 32bf0692..00000000 --- a/platformio/commands/home/web.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright (c) 2014-present PlatformIO -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet import reactor # pylint: disable=import-error -from twisted.web import static # pylint: disable=import-error - - -class WebRoot(static.File): - def render_GET(self, request): - if request.args.get(b"__shutdown__", False): - reactor.stop() - return "Server has been stopped" - - request.setHeader("cache-control", "no-cache, no-store, must-revalidate") - request.setHeader("pragma", "no-cache") - request.setHeader("expires", "0") - return static.File.render_GET(self, request) diff --git a/platformio/proc.py b/platformio/proc.py index 8db7153e..d9df0a3b 100644 --- a/platformio/proc.py +++ b/platformio/proc.py @@ -24,6 +24,7 @@ from platformio.compat import ( WINDOWS, get_filesystem_encoding, get_locale_encoding, + get_running_loop, string_types, ) @@ -214,3 +215,12 @@ def append_env_path(name, value): return cur_value os.environ[name] = os.pathsep.join([cur_value, value]) return os.environ[name] + + +def force_exit(code=0): + try: + get_running_loop().stop() + except: # pylint: disable=bare-except + pass + finally: + sys.exit(code)