Use device finder for automatic detection of upload port

This commit is contained in:
Ivan Kravets
2022-05-15 13:46:44 +03:00
parent d01d314f47
commit 7a100fb0b0
5 changed files with 63 additions and 82 deletions

View File

@ -12,13 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# pylint: disable=unused-argument
from __future__ import absolute_import from __future__ import absolute_import
import os
import re import re
import sys import sys
from fnmatch import fnmatch
from os import environ
from os.path import isfile, join
from shutil import copyfile from shutil import copyfile
from time import sleep from time import sleep
@ -26,12 +26,10 @@ from SCons.Script import ARGUMENTS # pylint: disable=import-error
from serial import Serial, SerialException from serial import Serial, SerialException
from platformio import exception, fs from platformio import exception, fs
from platformio.compat import IS_WINDOWS from platformio.device.finder import find_mbed_disk, find_serial_port, is_pattern_port
from platformio.device.list import list_logical_devices, list_serial_ports from platformio.device.list import list_serial_ports
from platformio.proc import exec_command from platformio.proc import exec_command
# pylint: disable=unused-argument
def FlushSerialBuffer(env, port): def FlushSerialBuffer(env, port):
s = Serial(env.subst(port)) s = Serial(env.subst(port))
@ -98,67 +96,28 @@ def WaitForNewSerialPort(env, before):
def AutodetectUploadPort(*args, **kwargs): def AutodetectUploadPort(*args, **kwargs):
env = args[0] env = args[0]
initial_port = env.subst("$UPLOAD_PORT")
def _get_pattern():
if "UPLOAD_PORT" not in env:
return None
if set(["*", "?", "[", "]"]) & set(env["UPLOAD_PORT"]):
return env["UPLOAD_PORT"]
return None
def _is_match_pattern(port):
pattern = _get_pattern()
if not pattern:
return True
return fnmatch(port, pattern)
def _look_for_mbed_disk():
msdlabels = ("mbed", "nucleo", "frdm", "microbit")
for item in list_logical_devices():
if item["path"].startswith("/net") or not _is_match_pattern(item["path"]):
continue
mbed_pages = [join(item["path"], n) for n in ("mbed.htm", "mbed.html")]
if any(isfile(p) for p in mbed_pages):
return item["path"]
if item["name"] and any(l in item["name"].lower() for l in msdlabels):
return item["path"]
return None
def _look_for_serial_port():
port = None
board_hwids = []
upload_protocol = env.subst("$UPLOAD_PROTOCOL") upload_protocol = env.subst("$UPLOAD_PROTOCOL")
if "BOARD" in env and "build.hwids" in env.BoardConfig(): if initial_port and not is_pattern_port(initial_port):
board_hwids = env.BoardConfig().get("build.hwids") print(env.subst("Using manually specified: $UPLOAD_PORT"))
for item in list_serial_ports(filter_hwid=True):
if not _is_match_pattern(item["port"]):
continue
port = item["port"]
if upload_protocol.startswith("blackmagic"):
if IS_WINDOWS and port.startswith("COM") and len(port) > 4:
port = "\\\\.\\%s" % port
if "GDB" in item["description"]:
return port
for hwid in board_hwids:
hwid_str = ("%s:%s" % (hwid[0], hwid[1])).replace("0x", "")
if hwid_str in item["hwid"]:
return port
return port
if "UPLOAD_PORT" in env and not _get_pattern():
print(env.subst("Use manually specified: $UPLOAD_PORT"))
return return
if env.subst("$UPLOAD_PROTOCOL") == "mbed" or ( if upload_protocol == "mbed" or (
"mbed" in env.subst("$PIOFRAMEWORK") and not env.subst("$UPLOAD_PROTOCOL") "mbed" in env.subst("$PIOFRAMEWORK") and not upload_protocol
): ):
env.Replace(UPLOAD_PORT=_look_for_mbed_disk()) env.Replace(UPLOAD_PORT=find_mbed_disk(initial_port))
else: else:
try: try:
fs.ensure_udev_rules() fs.ensure_udev_rules()
except exception.InvalidUdevRules as e: except exception.InvalidUdevRules as e:
sys.stderr.write("\n%s\n\n" % e) sys.stderr.write("\n%s\n\n" % e)
env.Replace(UPLOAD_PORT=_look_for_serial_port()) env.Replace(
UPLOAD_PORT=find_serial_port(
initial_port=initial_port,
board_config=env.BoardConfig() if "BOARD" in env else None,
upload_protocol=upload_protocol,
)
)
if env.subst("$UPLOAD_PORT"): if env.subst("$UPLOAD_PORT"):
print(env.subst("Auto-detected: $UPLOAD_PORT")) print(env.subst("Auto-detected: $UPLOAD_PORT"))
@ -176,10 +135,12 @@ def UploadToDisk(_, target, source, env):
assert "UPLOAD_PORT" in env assert "UPLOAD_PORT" in env
progname = env.subst("$PROGNAME") progname = env.subst("$PROGNAME")
for ext in ("bin", "hex"): for ext in ("bin", "hex"):
fpath = join(env.subst("$BUILD_DIR"), "%s.%s" % (progname, ext)) fpath = os.path.join(env.subst("$BUILD_DIR"), "%s.%s" % (progname, ext))
if not isfile(fpath): if not os.path.isfile(fpath):
continue continue
copyfile(fpath, join(env.subst("$UPLOAD_PORT"), "%s.%s" % (progname, ext))) copyfile(
fpath, os.path.join(env.subst("$UPLOAD_PORT"), "%s.%s" % (progname, ext))
)
print( print(
"Firmware has been successfully uploaded.\n" "Firmware has been successfully uploaded.\n"
"(Some boards may require manual hard reset)" "(Some boards may require manual hard reset)"
@ -212,7 +173,7 @@ def CheckUploadSize(_, target, source, env):
if not isinstance(cmd, list): if not isinstance(cmd, list):
cmd = cmd.split() cmd = cmd.split()
cmd = [arg.replace("$SOURCES", str(source[0])) for arg in cmd if arg] cmd = [arg.replace("$SOURCES", str(source[0])) for arg in cmd if arg]
sysenv = environ.copy() sysenv = os.environ.copy()
sysenv["PATH"] = str(env["ENV"]["PATH"]) sysenv["PATH"] = str(env["ENV"]["PATH"])
result = exec_command(env.subst(cmd), env=sysenv) result = exec_command(env.subst(cmd), env=sysenv)
if result["returncode"] != 0: if result["returncode"] != 0:

View File

@ -20,7 +20,7 @@ from serial.tools import miniterm
from platformio import exception, fs from platformio import exception, fs
from platformio.device.filters.base import register_filters from platformio.device.filters.base import register_filters
from platformio.device.serial import scan_serial_port from platformio.device.finder import find_serial_port
from platformio.platform.factory import PlatformFactory from platformio.platform.factory import PlatformFactory
from platformio.project.config import ProjectConfig from platformio.project.config import ProjectConfig
from platformio.project.exception import NotPlatformIOProjectError from platformio.project.exception import NotPlatformIOProjectError
@ -108,7 +108,7 @@ def device_monitor_cmd(**kwargs): # pylint: disable=too-many-branches
except NotPlatformIOProjectError: except NotPlatformIOProjectError:
pass pass
register_filters(platform=platform, options=kwargs) register_filters(platform=platform, options=kwargs)
kwargs["port"] = scan_serial_port( kwargs["port"] = find_serial_port(
initial_port=kwargs["port"], initial_port=kwargs["port"],
board_config=platform.board_config(project_options.get("board")) board_config=platform.board_config(project_options.get("board"))
if platform and project_options.get("board") if platform and project_options.get("board")

View File

@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
from fnmatch import fnmatch from fnmatch import fnmatch
import serial import serial
from platformio.compat import IS_WINDOWS from platformio.compat import IS_WINDOWS
from platformio.device.list import list_serial_ports from platformio.device.list import list_logical_devices, list_serial_ports
def is_pattern_port(port): def is_pattern_port(port):
@ -26,6 +27,13 @@ def is_pattern_port(port):
return set(["*", "?", "[", "]"]) & set(port) return set(["*", "?", "[", "]"]) & set(port)
def match_serial_port(pattern):
for item in list_serial_ports():
if fnmatch(item["port"], pattern):
return item["port"]
return None
def is_serial_port_ready(port, timeout=1): def is_serial_port_ready(port, timeout=1):
try: try:
serial.Serial(port, timeout=timeout).close() serial.Serial(port, timeout=timeout).close()
@ -35,7 +43,7 @@ def is_serial_port_ready(port, timeout=1):
return False return False
def scan_serial_port( def find_serial_port(
initial_port, board_config=None, upload_protocol=None, ensure_ready=False initial_port, board_config=None, upload_protocol=None, ensure_ready=False
): ):
if initial_port: if initial_port:
@ -44,9 +52,9 @@ def scan_serial_port(
return match_serial_port(initial_port) return match_serial_port(initial_port)
port = None port = None
if upload_protocol and upload_protocol.startswith("blackmagic"): if upload_protocol and upload_protocol.startswith("blackmagic"):
port = scan_blackmagic_serial_port() port = find_blackmagic_serial_port()
if not port and board_config: if not port and board_config:
port = scan_board_serial_port(board_config) port = find_board_serial_port(board_config)
if port: if port:
return port return port
@ -61,14 +69,7 @@ def scan_serial_port(
return usb_port or port return usb_port or port
def match_serial_port(pattern): def find_blackmagic_serial_port():
for item in list_serial_ports():
if fnmatch(item["port"], pattern):
return item["port"]
return None
def scan_blackmagic_serial_port():
for item in list_serial_ports(): for item in list_serial_ports():
port = item["port"] port = item["port"]
if IS_WINDOWS and port.startswith("COM") and len(port) > 4: if IS_WINDOWS and port.startswith("COM") and len(port) > 4:
@ -78,7 +79,7 @@ def scan_blackmagic_serial_port():
return None return None
def scan_board_serial_port(board_config): def find_board_serial_port(board_config):
board_hwids = board_config.get("build.hwids", []) board_hwids = board_config.get("build.hwids", [])
if not board_hwids: if not board_hwids:
return None return None
@ -89,3 +90,22 @@ def scan_board_serial_port(board_config):
if hwid_str in item["hwid"]: if hwid_str in item["hwid"]:
return port return port
return None return None
def find_mbed_disk(initial_port):
msdlabels = ("mbed", "nucleo", "frdm", "microbit")
for item in list_logical_devices():
if item["path"].startswith("/net"):
continue
if (
initial_port
and is_pattern_port(initial_port)
and not fnmatch(item["path"], initial_port)
):
continue
mbed_pages = [os.path.join(item["path"], n) for n in ("mbed.htm", "mbed.html")]
if any(os.path.isfile(p) for p in mbed_pages):
return item["path"]
if item["name"] and any(l in item["name"].lower() for l in msdlabels):
return item["path"]
return None

View File

@ -17,7 +17,7 @@ from time import sleep
import click import click
import serial import serial
from platformio.device.serial import scan_serial_port from platformio.device.finder import find_serial_port
from platformio.exception import UserSideException from platformio.exception import UserSideException
@ -77,7 +77,7 @@ class SerialTestOutputReader:
elapsed = 0 elapsed = 0
while elapsed < 5: while elapsed < 5:
port = scan_serial_port(**scan_options) port = find_serial_port(**scan_options)
if port: if port:
return port return port
sleep(0.25) sleep(0.25)