Implement RPC notification system

This commit is contained in:
Ivan Kravets
2023-05-03 22:27:00 +03:00
parent f840577066
commit c016d6827b

View File

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from urllib.parse import parse_qs
import click
from ajsonrpc.core import JSONRPC20Error
from ajsonrpc.core import JSONRPC20Error, JSONRPC20Request
from ajsonrpc.dispatcher import Dispatcher
from ajsonrpc.manager import AsyncJSONRPCResponseManager, JSONRPC20Response
from starlette.endpoints import WebSocketEndpoint
@ -32,6 +34,7 @@ class JSONRPCServerFactoryBase:
self.manager = AsyncJSONRPCResponseManager(
Dispatcher(), is_server_error_verbose=True
)
self._clients = {}
def __call__(self, *args, **kwargs):
raise NotImplementedError
@ -40,13 +43,16 @@ class JSONRPCServerFactoryBase:
handler.factory = self
self.manager.dispatcher.add_object(handler, prefix="%s." % namespace)
def on_client_connect(self):
def on_client_connect(self, connection, actor=None):
self._clients[connection] = {"actor": actor}
self.connection_nums += 1
if self.shutdown_timer:
self.shutdown_timer.cancel()
self.shutdown_timer = None
def on_client_disconnect(self):
def on_client_disconnect(self, connection):
if connection in self._clients:
del self._clients[connection]
self.connection_nums -= 1
if self.connection_nums < 1:
self.connection_nums = 0
@ -69,6 +75,14 @@ class JSONRPCServerFactoryBase:
self.shutdown_timeout, _auto_shutdown_server
)
async def notify_clients(self, method, params=None, actor=None):
for client, options in self._clients.items():
if actor and options["actor"] != actor:
continue
request = JSONRPC20Request(method, params, is_notification=True)
await client.send_text(self.manager.serialize(request.body))
return True
class WebSocketJSONRPCServerFactory(JSONRPCServerFactoryBase):
def __call__(self, *args, **kwargs):
@ -83,13 +97,17 @@ class WebSocketJSONRPCServer(WebSocketEndpoint):
async def on_connect(self, websocket):
await websocket.accept()
self.factory.on_client_connect() # pylint: disable=no-member
qs = parse_qs(self.scope.get("query_string", b""))
actors = qs.get(b"actor")
self.factory.on_client_connect( # pylint: disable=no-member
websocket, actor=actors[0].decode() if actors else None
)
async def on_receive(self, websocket, data):
aio_create_task(self._handle_rpc(websocket, data))
async def on_disconnect(self, websocket, close_code):
self.factory.on_client_disconnect() # pylint: disable=no-member
self.factory.on_client_disconnect(websocket) # pylint: disable=no-member
async def _handle_rpc(self, websocket, data):
# pylint: disable=no-member