mirror of
https://github.com/home-assistant/core.git
synced 2025-08-02 20:25:07 +02:00
Replace pytest-test-groups by custom tests splitter (#114381)
Co-authored-by: Martin Hjelmare <marhje52@gmail.com> Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
225
script/split_tests.py
Executable file
225
script/split_tests.py
Executable file
@@ -0,0 +1,225 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Helper script to split test into n buckets."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass, field
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Final
|
||||
|
||||
|
||||
class Bucket:
|
||||
"""Class to hold bucket."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
):
|
||||
"""Initialize bucket."""
|
||||
self.total_tests = 0
|
||||
self._paths: list[str] = []
|
||||
|
||||
def add(self, part: TestFolder | TestFile) -> None:
|
||||
"""Add tests to bucket."""
|
||||
part.add_to_bucket()
|
||||
self.total_tests += part.total_tests
|
||||
self._paths.append(str(part.path))
|
||||
|
||||
def get_paths_line(self) -> str:
|
||||
"""Return paths."""
|
||||
return " ".join(self._paths) + "\n"
|
||||
|
||||
|
||||
class BucketHolder:
|
||||
"""Class to hold buckets."""
|
||||
|
||||
def __init__(self, tests_per_bucket: int, bucket_count: int) -> None:
|
||||
"""Initialize bucket holder."""
|
||||
self._tests_per_bucket = tests_per_bucket
|
||||
self._bucket_count = bucket_count
|
||||
self._buckets: list[Bucket] = [Bucket() for _ in range(bucket_count)]
|
||||
|
||||
def split_tests(self, test_folder: TestFolder) -> None:
|
||||
"""Split tests into buckets."""
|
||||
digits = len(str(test_folder.total_tests))
|
||||
sorted_tests = sorted(
|
||||
test_folder.get_all_flatten(), reverse=True, key=lambda x: x.total_tests
|
||||
)
|
||||
for tests in sorted_tests:
|
||||
print(f"{tests.total_tests:>{digits}} tests in {tests.path}")
|
||||
if tests.added_to_bucket:
|
||||
# Already added to bucket
|
||||
continue
|
||||
|
||||
smallest_bucket = min(self._buckets, key=lambda x: x.total_tests)
|
||||
if (
|
||||
smallest_bucket.total_tests + tests.total_tests < self._tests_per_bucket
|
||||
) or isinstance(tests, TestFile):
|
||||
smallest_bucket.add(tests)
|
||||
|
||||
# verify that all tests are added to a bucket
|
||||
if not test_folder.added_to_bucket:
|
||||
raise ValueError("Not all tests are added to a bucket")
|
||||
|
||||
def create_ouput_file(self) -> None:
|
||||
"""Create output file."""
|
||||
with open("pytest_buckets.txt", "w") as file:
|
||||
for idx, bucket in enumerate(self._buckets):
|
||||
print(f"Bucket {idx+1} has {bucket.total_tests} tests")
|
||||
file.write(bucket.get_paths_line())
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestFile:
|
||||
"""Class represents a single test file and the number of tests it has."""
|
||||
|
||||
total_tests: int
|
||||
path: Path
|
||||
added_to_bucket: bool = field(default=False, init=False)
|
||||
|
||||
def add_to_bucket(self) -> None:
|
||||
"""Add test file to bucket."""
|
||||
if self.added_to_bucket:
|
||||
raise ValueError("Already added to bucket")
|
||||
self.added_to_bucket = True
|
||||
|
||||
def __gt__(self, other: TestFile) -> bool:
|
||||
"""Return if greater than."""
|
||||
return self.total_tests > other.total_tests
|
||||
|
||||
|
||||
class TestFolder:
|
||||
"""Class to hold a folder with test files and folders."""
|
||||
|
||||
def __init__(self, path: Path) -> None:
|
||||
"""Initialize test folder."""
|
||||
self.path: Final = path
|
||||
self.children: dict[Path, TestFolder | TestFile] = {}
|
||||
|
||||
@property
|
||||
def total_tests(self) -> int:
|
||||
"""Return total tests."""
|
||||
return sum([test.total_tests for test in self.children.values()])
|
||||
|
||||
@property
|
||||
def added_to_bucket(self) -> bool:
|
||||
"""Return if added to bucket."""
|
||||
return all(test.added_to_bucket for test in self.children.values())
|
||||
|
||||
def add_to_bucket(self) -> None:
|
||||
"""Add test file to bucket."""
|
||||
if self.added_to_bucket:
|
||||
raise ValueError("Already added to bucket")
|
||||
for child in self.children.values():
|
||||
child.add_to_bucket()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Return representation."""
|
||||
return (
|
||||
f"TestFolder(total_tests={self.total_tests}, children={len(self.children)})"
|
||||
)
|
||||
|
||||
def add_test_file(self, file: TestFile) -> None:
|
||||
"""Add test file to folder."""
|
||||
path = file.path
|
||||
relative_path = path.relative_to(self.path)
|
||||
if not relative_path.parts:
|
||||
raise ValueError("Path is not a child of this folder")
|
||||
|
||||
if len(relative_path.parts) == 1:
|
||||
self.children[path] = file
|
||||
return
|
||||
|
||||
child_path = self.path / relative_path.parts[0]
|
||||
if (child := self.children.get(child_path)) is None:
|
||||
self.children[child_path] = child = TestFolder(child_path)
|
||||
elif not isinstance(child, TestFolder):
|
||||
raise ValueError("Child is not a folder")
|
||||
child.add_test_file(file)
|
||||
|
||||
def get_all_flatten(self) -> list[TestFolder | TestFile]:
|
||||
"""Return self and all children as flatten list."""
|
||||
result: list[TestFolder | TestFile] = [self]
|
||||
for child in self.children.values():
|
||||
if isinstance(child, TestFolder):
|
||||
result.extend(child.get_all_flatten())
|
||||
else:
|
||||
result.append(child)
|
||||
return result
|
||||
|
||||
|
||||
def collect_tests(path: Path) -> TestFolder:
|
||||
"""Collect all tests."""
|
||||
result = subprocess.run(
|
||||
["pytest", "--collect-only", "-qq", "-p", "no:warnings", path],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
print("Failed to collect tests:")
|
||||
print(result.stderr)
|
||||
print(result.stdout)
|
||||
sys.exit(1)
|
||||
|
||||
folder = TestFolder(path)
|
||||
|
||||
for line in result.stdout.splitlines():
|
||||
if not line.strip():
|
||||
continue
|
||||
file_path, _, total_tests = line.partition(": ")
|
||||
if not path or not total_tests:
|
||||
print(f"Unexpected line: {line}")
|
||||
sys.exit(1)
|
||||
|
||||
file = TestFile(int(total_tests), Path(file_path))
|
||||
folder.add_test_file(file)
|
||||
|
||||
return folder
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Execute script."""
|
||||
parser = argparse.ArgumentParser(description="Split tests into n buckets.")
|
||||
|
||||
def check_greater_0(value: str) -> int:
|
||||
ivalue = int(value)
|
||||
if ivalue <= 0:
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"{value} is an invalid. Must be greater than 0"
|
||||
)
|
||||
return ivalue
|
||||
|
||||
parser.add_argument(
|
||||
"bucket_count",
|
||||
help="Number of buckets to split tests into",
|
||||
type=check_greater_0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"path",
|
||||
help="Path to the test files to split into buckets",
|
||||
type=Path,
|
||||
)
|
||||
|
||||
arguments = parser.parse_args()
|
||||
|
||||
print("Collecting tests...")
|
||||
tests = collect_tests(arguments.path)
|
||||
tests_per_bucket = ceil(tests.total_tests / arguments.bucket_count)
|
||||
|
||||
bucket_holder = BucketHolder(tests_per_bucket, arguments.bucket_count)
|
||||
print("Splitting tests...")
|
||||
bucket_holder.split_tests(tests)
|
||||
|
||||
print(f"Total tests: {tests.total_tests}")
|
||||
print(f"Estimated tests per bucket: {tests_per_bucket}")
|
||||
|
||||
bucket_holder.create_ouput_file()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Reference in New Issue
Block a user