Compare commits

...

1 Commit

Author SHA1 Message Date
Robert Resch
de020143f1 Use duration based splitting if possible 2026-03-03 16:09:59 +01:00
3 changed files with 355 additions and 20 deletions

View File

@@ -40,6 +40,8 @@ env:
CACHE_VERSION: 3
UV_CACHE_VERSION: 1
MYPY_CACHE_VERSION: 1
PYTEST_DURATIONS_CACHE_VERSION: 1
PYTEST_DURATIONS_FILE: .ci/pytest_durations.json
HA_SHORT_VERSION: "2026.4"
DEFAULT_PYTHON: "3.14.2"
ALL_PYTHON_VERSIONS: "['3.14.2']"
@@ -894,12 +896,27 @@ jobs:
key: >-
${{ runner.os }}-${{ runner.arch }}-${{ steps.python.outputs.python-version }}-${{
needs.info.outputs.python_cache_key }}
- name: Restore pytest durations cache
uses: actions/cache/restore@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
with:
path: ${{ env.PYTEST_DURATIONS_FILE }}
key: >-
${{ runner.os }}-${{ runner.arch }}-pytest-durations-${{ env.PYTEST_DURATIONS_CACHE_VERSION }}-${{
github.base_ref || github.ref_name }}
restore-keys: |
${{ runner.os }}-${{ runner.arch }}-pytest-durations-${{ env.PYTEST_DURATIONS_CACHE_VERSION }}-${{ github.base_ref || github.ref_name }}-
${{ runner.os }}-${{ runner.arch }}-pytest-durations-${{ env.PYTEST_DURATIONS_CACHE_VERSION }}-dev
- name: Run split_tests.py
env:
TEST_GROUP_COUNT: ${{ needs.info.outputs.test_group_count }}
PYTEST_DURATIONS_FILE: ${{ env.PYTEST_DURATIONS_FILE }}
run: |
. venv/bin/activate
python -m script.split_tests ${TEST_GROUP_COUNT} tests
split_args=("${TEST_GROUP_COUNT}" tests)
if [[ -f "${PYTEST_DURATIONS_FILE}" ]]; then
split_args+=(--durations-file "${PYTEST_DURATIONS_FILE}")
fi
python -m script.split_tests "${split_args[@]}"
- name: Upload pytest_buckets
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
@@ -997,11 +1014,10 @@ jobs:
. venv/bin/activate
python --version
set -o pipefail
cov_params=()
params=(--junitxml=junit.xml -o junit_family=legacy)
if [[ "${SKIP_COVERAGE}" != "true" ]]; then
cov_params+=(--cov="homeassistant")
cov_params+=(--cov-report=xml)
cov_params+=(--junitxml=junit.xml -o junit_family=legacy)
params+=(--cov="homeassistant")
params+=(--cov-report=xml)
fi
echo "Test group ${TEST_GROUP}: $(sed -n "${TEST_GROUP},1p" pytest_buckets.txt)"
@@ -1012,7 +1028,7 @@ jobs:
--numprocesses auto \
--snapshot-details \
--dist=loadfile \
${cov_params[@]} \
${params[@]} \
-o console_output_style=count \
-p no:sugar \
--exclude-warning-annotations \
@@ -1032,6 +1048,25 @@ jobs:
name: coverage-${{ matrix.python-version }}-${{ matrix.group }}
path: coverage.xml
overwrite: true
- name: Collect pytest durations
env:
TEST_GROUP: ${{ matrix.group }}
PYTHON_VERSION: ${{ matrix.python-version }}
run: |
. venv/bin/activate
output="pytest-durations-${PYTHON_VERSION}-${TEST_GROUP}.json"
if [[ -f junit.xml ]]; then
python -m script.collect_test_durations --output "${output}" junit.xml
else
echo "::error::Missing junit.xml, cannot collect pytest durations"
exit 1
fi
- name: Upload pytest durations
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6.0.0
with:
name: pytest-durations-${{ matrix.python-version }}-${{ matrix.group }}
path: pytest-durations-*.json
overwrite: true
- name: Beautify test results
# For easier identification of parsing errors
if: needs.info.outputs.skip_coverage != 'true'
@@ -1050,6 +1085,68 @@ jobs:
run: |
./script/check_dirty
update-pytest-duration-cache:
name: Update pytest durations cache
runs-on: ubuntu-24.04
permissions:
contents: read
needs:
- info
- prepare-pytest-full
- pytest-full
if: |
needs.info.outputs.lint_only != 'true'
&& needs.info.outputs.test_full_suite == 'true'
steps:
- name: Set up Python ${{ env.DEFAULT_PYTHON }}
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
with:
python-version: ${{ env.DEFAULT_PYTHON }}
check-latest: true
- name: Restore pytest durations cache
uses: actions/cache/restore@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
with:
path: ${{ env.PYTEST_DURATIONS_FILE }}
key: >-
${{ runner.os }}-${{ runner.arch }}-pytest-durations-${{ env.PYTEST_DURATIONS_CACHE_VERSION }}-${{
github.base_ref || github.ref_name }}
restore-keys: |
${{ runner.os }}-${{ runner.arch }}-pytest-durations-${{ env.PYTEST_DURATIONS_CACHE_VERSION }}-${{ github.base_ref || github.ref_name }}-
${{ runner.os }}-${{ runner.arch }}-pytest-durations-${{ env.PYTEST_DURATIONS_CACHE_VERSION }}-dev
- name: Download pytest durations artifacts
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
with:
pattern: pytest-durations-*
merge-multiple: true
path: .ci/pytest-durations
- name: Merge pytest durations
run: |
input_files=()
for file in .ci/pytest-durations/*.json; do
[[ -f "${file}" ]] || continue
input_files+=("${file}")
done
if [[ ${#input_files[@]} -eq 0 ]]; then
mkdir -p "$(dirname "${PYTEST_DURATIONS_FILE}")"
if [[ ! -f "${PYTEST_DURATIONS_FILE}" ]]; then
echo "{}" > "${PYTEST_DURATIONS_FILE}"
fi
exit 0
fi
python -m script.collect_test_durations \
--existing "${PYTEST_DURATIONS_FILE}" \
--output "${PYTEST_DURATIONS_FILE}" \
"${input_files[@]}"
- name: Save pytest durations cache
uses: actions/cache/save@cdf6c1fa76f9f475f3d7449005a359c84ca0f306 # v5.0.3
with:
path: ${{ env.PYTEST_DURATIONS_FILE }}
key: >-
${{ runner.os }}-${{ runner.arch }}-pytest-durations-${{ env.PYTEST_DURATIONS_CACHE_VERSION }}-${{
github.base_ref || github.ref_name }}-${{ github.run_id }}-${{ github.run_attempt }}
pytest-mariadb:
name: Run ${{ matrix.mariadb-group }} tests Python ${{ matrix.python-version }}
runs-on: ubuntu-24.04

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
"""Collect and merge pytest durations per test file."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from defusedxml import ElementTree as ET
def _load_json(path: Path) -> dict[str, float]:
"""Load durations from a JSON file."""
with path.open("r", encoding="utf-8") as file:
payload = json.load(file)
if not isinstance(payload, dict):
raise TypeError(f"Expected JSON object in {path}")
result: dict[str, float] = {}
for file_path, duration in payload.items():
if not isinstance(file_path, str) or not isinstance(duration, int | float):
continue
if duration <= 0:
continue
result[file_path] = float(duration)
return result
def _load_junit(path: Path) -> dict[str, float]:
"""Load durations from a JUnit XML file."""
tree = ET.parse(path)
root = tree.getroot()
result: dict[str, float] = {}
for testcase in root.iter("testcase"):
file_path = testcase.attrib.get("file")
if not file_path:
continue
raw_duration = testcase.attrib.get("time", "0")
try:
duration = float(raw_duration)
except ValueError:
continue
if duration <= 0:
continue
normalized = Path(file_path).as_posix()
result[normalized] = result.get(normalized, 0.0) + duration
return result
def _load_input(path: Path) -> dict[str, float]:
    """Dispatch to the JSON or JUnit XML loader based on file extension."""
    loaders = {".json": _load_json, ".xml": _load_junit}
    loader = loaders.get(path.suffix.lower())
    if loader is None:
        raise ValueError(f"Unsupported file type for {path}")
    return loader(path)
def merge_durations(
    existing: dict[str, float],
    incoming: dict[str, float],
    smoothing: float,
) -> dict[str, float]:
    """Blend new measurements into historical durations.

    Files already present use an exponential moving average
    (``old * (1 - smoothing) + new * smoothing``); files seen for the
    first time take the new measurement verbatim. ``existing`` is not
    mutated.
    """
    merged = dict(existing)
    for key, fresh in incoming.items():
        if key in merged:
            merged[key] = merged[key] * (1 - smoothing) + fresh * smoothing
        else:
            merged[key] = fresh
    return merged
def main() -> None:
    """Run the duration collector.

    Parses CLI arguments, loads the optional historical durations file,
    aggregates durations from all existing input files, merges the two,
    and writes the sorted result as JSON to ``--output``.
    """
    parser = argparse.ArgumentParser(
        description="Collect and merge test durations from JUnit XML or JSON files"
    )
    parser.add_argument(
        "inputs",
        nargs="*",
        type=Path,
        help="Input files (.xml or .json)",
    )
    parser.add_argument(
        "--existing",
        type=Path,
        help="Existing durations JSON file",
    )
    parser.add_argument(
        "--output",
        required=True,
        type=Path,
        help="Output JSON file",
    )
    parser.add_argument(
        "--smoothing",
        type=float,
        default=0.35,
        help="Weight for newly measured durations (0.0 to 1.0)",
    )
    args = parser.parse_args()
    if not 0 <= args.smoothing <= 1:
        raise ValueError("--smoothing must be between 0.0 and 1.0")

    # Historical baseline, if a previous durations file is available.
    baseline: dict[str, float] = {}
    if args.existing and args.existing.exists():
        baseline = _load_json(args.existing)

    # Aggregate durations measured in the current run; missing inputs are
    # skipped so partial artifact downloads do not fail the merge.
    measured: dict[str, float] = {}
    for source in args.inputs:
        if not source.exists():
            continue
        for key, seconds in _load_input(source).items():
            measured[key] = measured.get(key, 0.0) + seconds

    combined = merge_durations(baseline, measured, args.smoothing)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(dict(sorted(combined.items())), indent=2)
    args.output.write_text(serialized + "\n", encoding="utf-8")
    print(
        f"Wrote {len(combined)} file durations "
        f"(updated {len(measured)} from current run) to {args.output}"
    )
if __name__ == "__main__":
main()

View File

@@ -5,8 +5,9 @@ from __future__ import annotations
import argparse
from dataclasses import dataclass, field
from math import ceil
import json
from pathlib import Path
from statistics import fmean
import subprocess
import sys
from typing import Final
@@ -20,12 +21,14 @@ class Bucket:
):
"""Initialize bucket."""
self.total_tests = 0
self.total_duration = 0.0
self._paths: list[str] = []
def add(self, part: TestFolder | TestFile) -> None:
"""Add tests to bucket."""
part.add_to_bucket()
self.total_tests += part.total_tests
self.total_duration += part.total_duration
self._paths.append(str(part.path))
def get_paths_line(self) -> str:
@@ -36,9 +39,9 @@ class Bucket:
class BucketHolder:
"""Class to hold buckets."""
def __init__(self, tests_per_bucket: int, bucket_count: int) -> None:
def __init__(self, duration_per_bucket: float, bucket_count: int) -> None:
"""Initialize bucket holder."""
self._tests_per_bucket = tests_per_bucket
self._duration_per_bucket = duration_per_bucket
self._bucket_count = bucket_count
self._buckets: list[Bucket] = [Bucket() for _ in range(bucket_count)]
@@ -46,18 +49,26 @@ class BucketHolder:
"""Split tests into buckets."""
digits = len(str(test_folder.total_tests))
sorted_tests = sorted(
test_folder.get_all_flatten(), reverse=True, key=lambda x: x.total_tests
test_folder.get_all_flatten(),
reverse=True,
key=lambda test: (test.total_duration, test.total_tests),
)
for tests in sorted_tests:
if tests.added_to_bucket:
# Already added to bucket
continue
print(f"{tests.total_tests:>{digits}} tests in {tests.path}")
smallest_bucket = min(self._buckets, key=lambda x: x.total_tests)
print(
f"{tests.total_tests:>{digits}} tests in {tests.path} "
f"(~{tests.total_duration:.2f}s)"
)
smallest_bucket = min(
self._buckets, key=lambda bucket: bucket.total_duration
)
is_file = isinstance(tests, TestFile)
if (
smallest_bucket.total_tests + tests.total_tests < self._tests_per_bucket
smallest_bucket.total_duration + tests.total_duration
< self._duration_per_bucket
) or is_file:
smallest_bucket.add(tests)
# Ensure all files from the same folder are in the same bucket
@@ -67,7 +78,9 @@ class BucketHolder:
if other_test is tests or isinstance(other_test, TestFolder):
continue
print(
f"{other_test.total_tests:>{digits}} tests in {other_test.path} (same bucket)"
f"{other_test.total_tests:>{digits}} tests in "
f"{other_test.path} (same bucket, "
f"~{other_test.total_duration:.2f}s)"
)
smallest_bucket.add(other_test)
@@ -79,7 +92,10 @@ class BucketHolder:
"""Create output file."""
with Path("pytest_buckets.txt").open("w") as file:
for idx, bucket in enumerate(self._buckets):
print(f"Bucket {idx + 1} has {bucket.total_tests} tests")
print(
f"Bucket {idx + 1} has {bucket.total_tests} tests "
f"(~{bucket.total_duration:.2f}s)"
)
file.write(bucket.get_paths_line())
@@ -88,6 +104,7 @@ class TestFile:
"""Class represents a single test file and the number of tests it has."""
total_tests: int
total_duration: float
path: Path
added_to_bucket: bool = field(default=False, init=False)
parent: TestFolder | None = field(default=None, init=False)
@@ -100,7 +117,7 @@ class TestFile:
def __gt__(self, other: TestFile) -> bool:
"""Return if greater than."""
return self.total_tests > other.total_tests
return self.total_duration > other.total_duration
class TestFolder:
@@ -116,6 +133,11 @@ class TestFolder:
"""Return total tests."""
return sum([test.total_tests for test in self.children.values()])
@property
def total_duration(self) -> float:
"""Return total estimated duration in seconds."""
return sum(test.total_duration for test in self.children.values())
@property
def added_to_bucket(self) -> bool:
"""Return if added to bucket."""
@@ -189,12 +211,66 @@ def collect_tests(path: Path) -> TestFolder:
print(f"Unexpected line: {line}")
sys.exit(1)
file = TestFile(int(total_tests), Path(file_path))
file = TestFile(int(total_tests), 0.0, Path(file_path))
folder.add_test_file(file)
return folder
def load_test_durations(path: Path | None) -> dict[str, float]:
"""Load known test durations keyed by file path."""
if path is None or not path.exists():
return {}
with path.open("r", encoding="utf-8") as file:
raw_data = json.load(file)
if not isinstance(raw_data, dict):
raise TypeError("Durations file should contain a JSON object")
durations: dict[str, float] = {}
for file_path, duration in raw_data.items():
if not isinstance(file_path, str) or not isinstance(duration, int | float):
continue
if duration <= 0:
continue
durations[file_path] = float(duration)
return durations
def assign_estimated_durations(
    tests: TestFolder, known_durations: dict[str, float]
) -> tuple[float, int, int]:
    """Assign estimated durations to all test files.

    Files with known timings use those values. New files (without timings)
    receive an estimate based on the average seconds per collected test
    across the measured files. Files with zero collected tests are left
    untouched.

    Returns a tuple of (fallback seconds per test, number of files that
    lacked a known duration, total number of test files).
    """
    test_files = [
        item for item in tests.get_all_flatten() if isinstance(item, TestFile)
    ]
    per_test_rates: list[float] = []
    unmeasured: list[TestFile] = []
    for candidate in test_files:
        if candidate.total_tests <= 0:
            continue
        known = known_durations.get(str(candidate.path))
        if known is None:
            unmeasured.append(candidate)
        else:
            candidate.total_duration = known
            per_test_rates.append(known / candidate.total_tests)
    # Fall back to 0.1s per test when there is no history at all.
    fallback_rate = fmean(per_test_rates) if per_test_rates else 0.1
    for candidate in unmeasured:
        candidate.total_duration = candidate.total_tests * fallback_rate
    return fallback_rate, len(unmeasured), len(test_files)
def main() -> None:
"""Execute script."""
parser = argparse.ArgumentParser(description="Split tests into n buckets.")
@@ -217,19 +293,33 @@ def main() -> None:
help="Path to the test files to split into buckets",
type=Path,
)
parser.add_argument(
"--durations-file",
help="JSON file with per-test-file durations in seconds",
type=Path,
)
arguments = parser.parse_args()
print("Collecting tests...")
tests = collect_tests(arguments.path)
tests_per_bucket = ceil(tests.total_tests / arguments.bucket_count)
known_durations = load_test_durations(arguments.durations_file)
default_seconds_per_test, files_missing_durations, total_files = (
assign_estimated_durations(tests, known_durations)
)
bucket_holder = BucketHolder(tests_per_bucket, arguments.bucket_count)
duration_per_bucket = tests.total_duration / arguments.bucket_count
bucket_holder = BucketHolder(duration_per_bucket, arguments.bucket_count)
print("Splitting tests...")
bucket_holder.split_tests(tests)
print(f"Total tests: {tests.total_tests}")
print(f"Estimated tests per bucket: {tests_per_bucket}")
print(f"Files missing durations: {files_missing_durations}")
print(f"Total files: {total_files}")
print(f"Fallback seconds per test: {default_seconds_per_test:.4f}")
print(f"Estimated total duration: {tests.total_duration:.2f}s")
print(f"Estimated duration per bucket: {duration_per_bucket:.2f}s")
bucket_holder.create_ouput_file()