Run sync RPC methods in thread

This commit is contained in:
Ivan Kravets
2023-07-21 15:27:31 +03:00
parent 9affc023a2
commit 9deb7f4275
12 changed files with 64 additions and 52 deletions

View File

@@ -19,11 +19,12 @@ from platformio.home.rpc.handlers.base import BaseRPCHandler
class AccountRPC(BaseRPCHandler): class AccountRPC(BaseRPCHandler):
NAMESPACE = "account"
@staticmethod @staticmethod
def call_client(method, *args, **kwargs): def call_client(method, *args, **kwargs):
try: try:
client = AccountClient() return getattr(AccountClient(), method)(*args, **kwargs)
return getattr(client, method)(*args, **kwargs)
except Exception as exc: # pylint: disable=bare-except except Exception as exc: # pylint: disable=bare-except
raise JSONRPC20DispatchException( raise JSONRPC20DispatchException(
code=5000, message="PIO Account Call Error", data=str(exc) code=5000, message="PIO Account Call Error", data=str(exc)

View File

@@ -20,6 +20,7 @@ from platformio.project.helpers import is_platformio_project
class AppRPC(BaseRPCHandler): class AppRPC(BaseRPCHandler):
NAMESPACE = "app"
IGNORE_STORAGE_KEYS = [ IGNORE_STORAGE_KEYS = [
"cid", "cid",
"coreVersion", "coreVersion",

View File

@@ -14,4 +14,6 @@
class BaseRPCHandler: class BaseRPCHandler:
NAMESPACE = None
factory = None factory = None

View File

@@ -22,6 +22,7 @@ from platformio.home.rpc.handlers.base import BaseRPCHandler
class IDERPC(BaseRPCHandler): class IDERPC(BaseRPCHandler):
NAMESPACE = "ide"
COMMAND_TIMEOUT = 1.5 # in seconds COMMAND_TIMEOUT = 1.5 # in seconds
def __init__(self): def __init__(self):

View File

@@ -22,6 +22,8 @@ from platformio.home.rpc.handlers.os import OSRPC
class MiscRPC(BaseRPCHandler): class MiscRPC(BaseRPCHandler):
NAMESPACE = "misc"
async def load_latest_tweets(self, data_url): async def load_latest_tweets(self, data_url):
cache_key = ContentCache.key_from_args(data_url, "tweets") cache_key = ContentCache.key_from_args(data_url, "tweets")
cache_valid = "180d" cache_valid = "180d"

View File

@@ -22,25 +22,18 @@ import click
from platformio import fs from platformio import fs
from platformio.cache import ContentCache from platformio.cache import ContentCache
from platformio.compat import aio_to_thread
from platformio.device.list.util import list_logical_devices from platformio.device.list.util import list_logical_devices
from platformio.home.rpc.handlers.base import BaseRPCHandler from platformio.home.rpc.handlers.base import BaseRPCHandler
from platformio.http import HTTPSession, ensure_internet_on from platformio.http import HTTPSession, ensure_internet_on
class HTTPAsyncSession(HTTPSession):
async def request( # pylint: disable=signature-differs,invalid-overridden-method
self, *args, **kwargs
):
func = super().request
return await aio_to_thread(func, *args, **kwargs)
class OSRPC(BaseRPCHandler): class OSRPC(BaseRPCHandler):
NAMESPACE = "os"
_http_session = None _http_session = None
@classmethod @classmethod
async def fetch_content(cls, url, data=None, headers=None, cache_valid=None): def fetch_content(cls, url, data=None, headers=None, cache_valid=None):
if not headers: if not headers:
headers = { headers = {
"User-Agent": ( "User-Agent": (
@@ -60,12 +53,12 @@ class OSRPC(BaseRPCHandler):
ensure_internet_on(raise_exception=True) ensure_internet_on(raise_exception=True)
if not cls._http_session: if not cls._http_session:
cls._http_session = HTTPAsyncSession() cls._http_session = HTTPSession()
if data: if data:
r = await cls._http_session.post(url, data=data, headers=headers) r = cls._http_session.post(url, data=data, headers=headers)
else: else:
r = await cls._http_session.get(url, headers=headers) r = cls._http_session.get(url, headers=headers)
r.raise_for_status() r.raise_for_status()
result = r.text result = r.text
@@ -74,13 +67,12 @@ class OSRPC(BaseRPCHandler):
cc.set(cache_key, result, cache_valid) cc.set(cache_key, result, cache_valid)
return result return result
async def request_content(self, uri, data=None, headers=None, cache_valid=None): def request_content(self, uri, data=None, headers=None, cache_valid=None):
if uri.startswith("http"): if uri.startswith("http"):
return await self.fetch_content(uri, data, headers, cache_valid) return self.fetch_content(uri, data, headers, cache_valid)
local_path = uri[7:] if uri.startswith("file://") else uri local_path = uri[7:] if uri.startswith("file://") else uri
with io.open(local_path, encoding="utf-8") as fp: with io.open(local_path, encoding="utf-8") as fp:
return fp.read() return fp.read()
return None
@staticmethod @staticmethod
def open_url(url): def open_url(url):

View File

@@ -102,6 +102,8 @@ def get_core_fullpath():
class PIOCoreRPC(BaseRPCHandler): class PIOCoreRPC(BaseRPCHandler):
NAMESPACE = "core"
@staticmethod @staticmethod
def version(): def version():
return __version__ return __version__

View File

@@ -14,8 +14,8 @@
import os.path import os.path
from platformio.compat import aio_to_thread
from platformio.home.rpc.handlers.base import BaseRPCHandler from platformio.home.rpc.handlers.base import BaseRPCHandler
from platformio.home.rpc.handlers.registry import RegistryRPC
from platformio.package.manager.platform import PlatformPackageManager from platformio.package.manager.platform import PlatformPackageManager
from platformio.package.manifest.parser import ManifestParserFactory from platformio.package.manifest.parser import ManifestParserFactory
from platformio.package.meta import PackageSpec from platformio.package.meta import PackageSpec
@@ -23,15 +23,13 @@ from platformio.platform.factory import PlatformFactory
class PlatformRPC(BaseRPCHandler): class PlatformRPC(BaseRPCHandler):
async def fetch_platforms(self, search_query=None, page=0, force_installed=False): NAMESPACE = "platform"
if force_installed:
return {
"items": await aio_to_thread(
self._load_installed_platforms, search_query
)
}
search_result = await self.factory.manager.dispatcher["registry.call_client"]( def fetch_platforms(self, search_query=None, page=0, force_installed=False):
if force_installed:
return {"items": self._load_installed_platforms(search_query)}
search_result = RegistryRPC.call_client(
method="list_packages", method="list_packages",
query=search_query, query=search_query,
qualifiers={ qualifiers={
@@ -88,17 +86,17 @@ class PlatformRPC(BaseRPCHandler):
) )
return items return items
async def fetch_boards(self, platform_spec): def fetch_boards(self, platform_spec):
spec = PackageSpec(platform_spec) spec = PackageSpec(platform_spec)
if spec.owner: if spec.owner:
return await self.factory.manager.dispatcher["registry.call_client"]( return RegistryRPC.call_client(
method="get_package", method="get_package",
typex="platform", typex="platform",
owner=spec.owner, owner=spec.owner,
name=spec.name, name=spec.name,
extra_path="/boards", extra_path="/boards",
) )
return await aio_to_thread(self._load_installed_boards, spec) return self._load_installed_boards(spec)
@staticmethod @staticmethod
def _load_installed_boards(platform_spec): def _load_installed_boards(platform_spec):
@@ -108,17 +106,17 @@ class PlatformRPC(BaseRPCHandler):
key=lambda item: item["name"], key=lambda item: item["name"],
) )
async def fetch_examples(self, platform_spec): def fetch_examples(self, platform_spec):
spec = PackageSpec(platform_spec) spec = PackageSpec(platform_spec)
if spec.owner: if spec.owner:
return await self.factory.manager.dispatcher["registry.call_client"]( return RegistryRPC.call_client(
method="get_package", method="get_package",
typex="platform", typex="platform",
owner=spec.owner, owner=spec.owner,
name=spec.name, name=spec.name,
extra_path="/examples", extra_path="/examples",
) )
return await aio_to_thread(self._load_installed_examples, spec) return self._load_installed_examples(spec)
@staticmethod @staticmethod
def _load_installed_examples(platform_spec): def _load_installed_examples(platform_spec):

View File

@@ -34,6 +34,8 @@ from platformio.project.options import get_config_options_schema
class ProjectRPC(BaseRPCHandler): class ProjectRPC(BaseRPCHandler):
NAMESPACE = "project"
@staticmethod @staticmethod
def config_call(init_kwargs, method, *args): def config_call(init_kwargs, method, *args):
assert isinstance(init_kwargs, dict) assert isinstance(init_kwargs, dict)

View File

@@ -14,17 +14,17 @@
from ajsonrpc.core import JSONRPC20DispatchException from ajsonrpc.core import JSONRPC20DispatchException
from platformio.compat import aio_to_thread
from platformio.home.rpc.handlers.base import BaseRPCHandler from platformio.home.rpc.handlers.base import BaseRPCHandler
from platformio.registry.client import RegistryClient from platformio.registry.client import RegistryClient
class RegistryRPC(BaseRPCHandler): class RegistryRPC(BaseRPCHandler):
NAMESPACE = "registry"
@staticmethod @staticmethod
async def call_client(method, *args, **kwargs): def call_client(method, *args, **kwargs):
try: try:
client = RegistryClient() return getattr(RegistryClient(), method)(*args, **kwargs)
return await aio_to_thread(getattr(client, method), *args, **kwargs)
except Exception as exc: # pylint: disable=bare-except except Exception as exc: # pylint: disable=bare-except
raise JSONRPC20DispatchException( raise JSONRPC20DispatchException(
code=5000, message="Registry Call Error", data=str(exc) code=5000, message="Registry Call Error", data=str(exc)

View File

@@ -12,22 +12,24 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import functools
import inspect
from urllib.parse import parse_qs from urllib.parse import parse_qs
import ajsonrpc.utils import ajsonrpc.manager
import click import click
from ajsonrpc.core import JSONRPC20Error, JSONRPC20Request from ajsonrpc.core import JSONRPC20Error, JSONRPC20Request
from ajsonrpc.dispatcher import Dispatcher from ajsonrpc.dispatcher import Dispatcher
from ajsonrpc.manager import AsyncJSONRPCResponseManager, JSONRPC20Response from ajsonrpc.manager import AsyncJSONRPCResponseManager, JSONRPC20Response
from starlette.endpoints import WebSocketEndpoint from starlette.endpoints import WebSocketEndpoint
from platformio.compat import aio_create_task, aio_get_running_loop from platformio.compat import aio_create_task, aio_get_running_loop, aio_to_thread
from platformio.http import InternetConnectionError from platformio.http import InternetConnectionError
from platformio.proc import force_exit from platformio.proc import force_exit
# Remove this line when PR is merged # Remove this line when PR is merged
# https://github.com/pavlov99/ajsonrpc/pull/22 # https://github.com/pavlov99/ajsonrpc/pull/22
ajsonrpc.utils.is_invalid_params = lambda: False ajsonrpc.manager.is_invalid_params = lambda *args, **kwargs: False
class JSONRPCServerFactoryBase: class JSONRPCServerFactoryBase:
@@ -44,9 +46,18 @@ class JSONRPCServerFactoryBase:
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
def add_object_handler(self, handler, namespace): def add_object_handler(self, obj):
handler.factory = self obj.factory = self
self.manager.dispatcher.add_object(handler, prefix="%s." % namespace) namespace = obj.NAMESPACE or obj.__class__.__name__
for name in dir(obj):
method = getattr(obj, name)
if name.startswith("_") or not (
inspect.ismethod(method) or inspect.isfunction(method)
):
continue
if not inspect.iscoroutinefunction(method):
method = functools.partial(aio_to_thread, method)
self.manager.dispatcher.add_function(method, name=f"{namespace}.{name}")
def on_client_connect(self, connection, actor=None): def on_client_connect(self, connection, actor=None):
self._clients[connection] = {"actor": actor} self._clients[connection] = {"actor": actor}

View File

@@ -67,15 +67,15 @@ def run_server(host, port, no_open, shutdown_timeout, home_url):
raise PlatformioException("Invalid path to PIO Home Contrib") raise PlatformioException("Invalid path to PIO Home Contrib")
ws_rpc_factory = WebSocketJSONRPCServerFactory(shutdown_timeout) ws_rpc_factory = WebSocketJSONRPCServerFactory(shutdown_timeout)
ws_rpc_factory.add_object_handler(AccountRPC(), namespace="account") ws_rpc_factory.add_object_handler(AccountRPC())
ws_rpc_factory.add_object_handler(AppRPC(), namespace="app") ws_rpc_factory.add_object_handler(AppRPC())
ws_rpc_factory.add_object_handler(IDERPC(), namespace="ide") ws_rpc_factory.add_object_handler(IDERPC())
ws_rpc_factory.add_object_handler(MiscRPC(), namespace="misc") ws_rpc_factory.add_object_handler(MiscRPC())
ws_rpc_factory.add_object_handler(OSRPC(), namespace="os") ws_rpc_factory.add_object_handler(OSRPC())
ws_rpc_factory.add_object_handler(PIOCoreRPC(), namespace="core") ws_rpc_factory.add_object_handler(PIOCoreRPC())
ws_rpc_factory.add_object_handler(ProjectRPC(), namespace="project") ws_rpc_factory.add_object_handler(ProjectRPC())
ws_rpc_factory.add_object_handler(PlatformRPC(), namespace="platform") ws_rpc_factory.add_object_handler(PlatformRPC())
ws_rpc_factory.add_object_handler(RegistryRPC(), namespace="registry") ws_rpc_factory.add_object_handler(RegistryRPC())
path = urlparse(home_url).path path = urlparse(home_url).path
routes = [ routes = [