Implement "__shutdown__" endpoint for PIO Home server

This commit is contained in:
Ivan Kravets
2021-01-18 21:19:15 +02:00
parent 429065d2b9
commit bd897d780b

View File

@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import os
import click
import uvicorn
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.routing import Mount, WebSocketRoute
from starlette.responses import PlainTextResponse
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.staticfiles import StaticFiles
from platformio.commands.home.rpc.handlers.account import AccountRPC
@ -29,7 +31,6 @@ from platformio.commands.home.rpc.handlers.os import OSRPC
from platformio.commands.home.rpc.handlers.piocore import PIOCoreRPC
from platformio.commands.home.rpc.handlers.project import ProjectRPC
from platformio.commands.home.rpc.server import WebSocketJSONRPCServerFactory
from platformio.compat import get_running_loop
from platformio.exception import PlatformioException
from platformio.package.manager.core import get_core_package_dir
from platformio.proc import force_exit
@ -41,10 +42,15 @@ class ShutdownMiddleware:
async def __call__(self, scope, receive, send):
if scope["type"] == "http" and b"__shutdown__" in scope.get("query_string", {}):
get_running_loop().call_later(0.5, force_exit)
await shutdown_server()
await self.app(scope, receive, send)
async def shutdown_server(_=None):
asyncio.get_event_loop().call_later(0.5, force_exit)
return PlainTextResponse("Server has been shutdown!")
def run_server(host, port, no_open, shutdown_timeout, home_url):
contrib_dir = get_core_package_dir("contrib-piohome")
if not os.path.isdir(contrib_dir):
@ -64,6 +70,7 @@ def run_server(host, port, no_open, shutdown_timeout, home_url):
middleware=[Middleware(ShutdownMiddleware)],
routes=[
WebSocketRoute("/wsrpc", ws_rpc_factory, name="wsrpc"),
Route("/__shutdown__", shutdown_server, methods=["POST"]),
Mount("/", StaticFiles(directory=contrib_dir, html=True)),
],
on_startup=[