Implement native async RPC core.exec

This commit is contained in:
Ivan Kravets
2023-05-03 22:25:44 +03:00
parent f9cbf6cb97
commit 475c5d2a3c
4 changed files with 105 additions and 51 deletions

View File

@ -13,10 +13,11 @@
# limitations under the License.
import mimetypes
import socket
import click
from platformio.home.helpers import is_port_used
from platformio.compat import IS_WINDOWS
from platformio.home.run import run_server
from platformio.package.manager.core import get_core_package_dir
@ -95,3 +96,23 @@ def cli(port, host, no_open, shutdown_timeout, session_id):
shutdown_timeout=shutdown_timeout,
home_url=home_url,
)
def is_port_used(host, port):
socket.setdefaulttimeout(1)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if IS_WINDOWS:
try:
s.bind((host, port))
s.close()
return False
except (OSError, socket.error):
pass
else:
try:
s.connect((host, port))
s.close()
except socket.error:
return False
return True

View File

@ -1,44 +0,0 @@
# Copyright (c) 2014-present PlatformIO <contact@platformio.org>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
from platformio import util
from platformio.compat import IS_WINDOWS
from platformio.proc import where_is_program
@util.memoized(expire="60s")
def get_core_fullpath():
return where_is_program("platformio" + (".exe" if IS_WINDOWS else ""))
def is_port_used(host, port):
socket.setdefaulttimeout(1)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if IS_WINDOWS:
try:
s.bind((host, port))
s.close()
return False
except (OSError, socket.error):
pass
else:
try:
s.connect((host, port))
s.close()
except socket.error:
return False
return True

View File

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import functools
import io
import json
import os
@ -22,10 +24,45 @@ import click
from ajsonrpc.core import JSONRPC20DispatchException
from starlette.concurrency import run_in_threadpool
from platformio import __main__, __version__, fs, proc
from platformio.compat import get_locale_encoding, is_bytes
from platformio import __main__, __version__, app, fs, proc, util
from platformio.compat import (
IS_WINDOWS,
aio_create_task,
aio_get_running_loop,
get_locale_encoding,
is_bytes,
)
from platformio.exception import PlatformioException
from platformio.home import helpers
from platformio.home.rpc.handlers.base import BaseRPCHandler
class PIOCoreProtocol(asyncio.SubprocessProtocol):
def __init__(self, exit_future, on_data_callback=None):
self.exit_future = exit_future
self.on_data_callback = on_data_callback
self.stdout = ""
self.stderr = ""
self._is_exited = False
self._encoding = get_locale_encoding()
def pipe_data_received(self, fd, data):
data = data.decode(self._encoding, "replace")
pipe = ["stdin", "stdout", "stderr"][fd]
if pipe == "stdout":
self.stdout += data
if pipe == "stderr":
self.stderr += data
if self.on_data_callback:
self.on_data_callback(pipe=pipe, data=data)
def connection_lost(self, exc):
self.process_exited()
def process_exited(self):
if self._is_exited:
return
self.exit_future.set_result(True)
self._is_exited = True
class MultiThreadingStdStream:
@ -59,11 +96,51 @@ class MultiThreadingStdStream:
return result
class PIOCoreRPC:
@util.memoized(expire="60s")
def get_core_fullpath():
return proc.where_is_program("platformio" + (".exe" if IS_WINDOWS else ""))
class PIOCoreRPC(BaseRPCHandler):
@staticmethod
def version():
return __version__
async def exec(self, args, options=None):
loop = aio_get_running_loop()
exit_future = loop.create_future()
data_callback = functools.partial(
self._on_exec_data_received, exec_options=options
)
if args[0] != "--caller" and app.get_session_var("caller_id"):
args = ["--caller", app.get_session_var("caller_id")] + args
transport, protocol = await loop.subprocess_exec(
lambda: PIOCoreProtocol(exit_future, data_callback),
get_core_fullpath(),
*args,
stdin=None,
**options.get("spawn", {}),
)
await exit_future
transport.close()
return {
"stdout": protocol.stdout,
"stderr": protocol.stderr,
"returncode": transport.get_returncode(),
}
def _on_exec_data_received(self, exec_options, pipe, data):
notification_method = exec_options.get(f"{pipe}NotificationMethod")
if not notification_method:
return
aio_create_task(
self.factory.notify_clients(
method=notification_method,
params=[data],
actor="frontend",
)
)
@staticmethod
def setup_multithreading_std_streams():
if isinstance(sys.stdout, MultiThreadingStdStream):
@ -102,7 +179,7 @@ class PIOCoreRPC:
async def _call_subprocess(args, options):
result = await run_in_threadpool(
proc.exec_command,
[helpers.get_core_fullpath()] + args,
[get_core_fullpath()] + args,
cwd=options.get("cwd") or os.getcwd(),
)
return (result["out"], result["err"], result["returncode"])

View File

@ -45,7 +45,7 @@ class ShutdownMiddleware:
self.app = app
async def __call__(self, scope, receive, send):
if scope["type"] == "http" and b"__shutdown__" in scope.get("query_string", {}):
if scope["type"] == "http" and b"__shutdown__" in scope.get("query_string", ""):
await shutdown_server()
await self.app(scope, receive, send)