diff --git a/lib/pavilion/cmd_utils.py b/lib/pavilion/cmd_utils.py index 7376446f2..2a9cb6a59 100644 --- a/lib/pavilion/cmd_utils.py +++ b/lib/pavilion/cmd_utils.py @@ -8,7 +8,7 @@ import sys import time from pathlib import Path -from typing import List, TextIO, Union, Iterator +from typing import List, TextIO, Union, Iterator, Optional from collections import defaultdict from pavilion import config @@ -19,67 +19,18 @@ from pavilion import series from pavilion import sys_vars from pavilion import utils +from pavilion import test_ids from pavilion.errors import TestRunError, CommandError, TestSeriesError, \ PavilionError, TestGroupError from pavilion.test_run import TestRun, load_tests, TestAttributes from pavilion.types import ID_Pair -from pavilion.micro import flatten +from pavilion.test_ids import TestID, SeriesID +from pavilion.micro import listmap LOGGER = logging.getLogger(__name__) -def expand_range(test_range: str) -> List[str]: - """Expand a given test or series range into a list of the individual - tests or series in that range""" - - tests = [] - - if test_range == "all": - return ["all"] - - elif '-' in test_range: - id_start, id_end = test_range.split('-', 1) - - if id_start.startswith('s'): - series_range_start = int(id_start.replace('s','')) - - if id_end.startswith('s'): - series_range_end = int(id_end.replace('s','')) - else: - series_range_end = int(id_end) - - series_ids = range(series_range_start, series_range_end+1) - - for sid in series_ids: - tests.append('s' + str(sid)) - else: - test_range_start = int(id_start) - test_range_end = int(id_end) - test_ids = range(test_range_start, test_range_end+1) - - for tid in test_ids: - tests.append(str(tid)) - else: - tests.append(test_range) - - return tests - - -def expand_ranges(ranges: Iterator[str]) -> Iterator[str]: - """Given a sequence of test and series ranges, expand them - into a sequence of individual tests and series.""" - - return flatten(map(expand_range, ranges)) - - -#pylint: disable=C0103 -def is_series_id(id: str) -> bool: - """Determine whether the given ID is a series ID.""" - - return len(id) > 0 and id[0].lower() == 's' - - -def load_last_series(pav_cfg, errfile: TextIO) -> Union[series.TestSeries, None]: +def load_last_series(pav_cfg, errfile: TextIO) -> Optional[series.TestSeries]: """Load the series object for the last series run by this user on this system.""" try: @@ -88,6 +39,9 @@ def load_last_series(pav_cfg, errfile: TextIO) -> Union[series.TestSeries, None] output.fprint("Failed to find last series: {}".format(err.args[0]), file=errfile) return None + if series_id is None: + return None + try: return series.TestSeries.load(pav_cfg, series_id) except series.TestSeriesError as err: @@ -133,30 +87,26 @@ def arg_filtered_tests(pav_cfg: "PavConfig", args: argparse.Namespace, sys_name = getattr(args, 'sys_name', sys_vars.get_vars(defer=True).get('sys_name')) sort_by = getattr(args, 'sort_by', 'created') - ids = [] + ids = test_ids.resolve_ids(args.tests) + test_filter = args.filter - for test_range in args.tests: - ids.extend(expand_range(test_range)) - - args.tests = ids - - if 'all' in args.tests: + if SeriesID('all') in ids: for arg, default in filters.TEST_FILTER_DEFAULTS.items(): if hasattr(args, arg) and default != getattr(args, arg): break else: output.fprint(verbose, "Using default search filters: The current system, user, and " "created less than 1 day ago.", color=output.CYAN) - args.filter = make_filter_query() + test_filter = make_filter_query() - if args.filter is None: + if test_filter is None: filter_func = filters.const(True) # Always return True else: - filter_func = filters.parse_query(args.filter) + filter_func = filters.parse_query(test_filter) order_func, order_asc = filters.get_sort_opts(sort_by, "TEST") - if 'all' in args.tests: + if SeriesID('all') in ids: tests = dir_db.SelectItems([], []) working_dirs = set(map(lambda cfg: cfg['working_dir'], pav_cfg.configs.values())) @@ -177,10 +127,10 @@ def arg_filtered_tests(pav_cfg: "PavConfig", args: argparse.Namespace, return tests - if not args.tests: - args.tests.append('last') + if len(ids) == 0: + ids.append(SeriesID('last')) - test_paths = test_list_to_paths(pav_cfg, args.tests, verbose) + test_paths = test_list_to_paths(pav_cfg, ids, verbose) return dir_db.select_from( pav_cfg, @@ -194,7 +144,7 @@ def arg_filtered_tests(pav_cfg: "PavConfig", args: argparse.Namespace, def make_filter_query() -> str: - template = 'user={} and created<{}' + template = 'user={} and created>{}' user = utils.get_login() time = (dt.datetime.now() - dt.timedelta(days=1)).isoformat() @@ -216,13 +166,15 @@ def arg_filtered_series(pav_cfg: config.PavConfig, args: argparse.Namespace, search all series (with a default current user/system/1-day filter) and additonally filtered by args attributes provied via filters.add_series_filter_args().""" + args.series = listmap(SeriesID, args.series) + limit = getattr(args, 'limit', filters.SERIES_FILTER_DEFAULTS['limit']) verbose = verbose or io.StringIO() - if not args.series: - args.series = ['last'] + if len(args.series) == 0: + args.series = [SeriesID('last')] - if 'all' in args.series: + if SeriesID('all') in args.series: for arg, default in filters.SERIES_FILTER_DEFAULTS.items(): if hasattr(args, arg) and default != getattr(args, arg): break @@ -236,14 +188,14 @@ def arg_filtered_series(pav_cfg: config.PavConfig, args: argparse.Namespace, for sid in args.series: # Go through each provided sid (including last and all) and find all # matching series. Then only add them if we haven't seen them yet. - if sid == 'last': + if sid == SeriesID('last'): last_series = load_last_series(pav_cfg, verbose) if last_series is None: return [] found_series.append(last_series.info()) - elif sid == 'all': + elif sid == SeriesID('all'): sort_by = getattr(args, 'sort_by', filters.SERIES_FILTER_DEFAULTS['sort_by']) order_func, order_asc = filters.get_sort_opts(sort_by, 'SERIES') @@ -264,7 +216,7 @@ def arg_filtered_series(pav_cfg: config.PavConfig, args: argparse.Namespace, limit=limit, ).data else: - found_series.append(series.SeriesInfo.load(pav_cfg, sid)) + found_series.append(series.SeriesInfo.load(pav_cfg, sid.id_str)) matching_series = [] for sinfo in found_series: @@ -321,7 +273,8 @@ def get_collection_path(pav_cfg, collection) -> Union[Path, None]: return None -def test_list_to_paths(pav_cfg, req_tests, errfile=None) -> List[Path]: +def test_list_to_paths(pav_cfg: "PavConfig", req_tests: Union["TestID", "SeriesID"], + errfile: "StringIO" = None) -> List[Path]: """Given a list of raw test id's and series id's, return a list of paths to those tests. The keyword 'last' may also be given to get the last series run by @@ -337,20 +290,23 @@ def test_list_to_paths(pav_cfg, req_tests, errfile=None) -> List[Path]: errfile = io.StringIO() test_paths = [] + for raw_id in req_tests: - if raw_id == 'last': + if raw_id == SeriesID('last'): raw_id = series.load_user_series_id(pav_cfg, errfile) + if raw_id is None: output.fprint(errfile, "User has no 'last' series for this machine.", color=output.YELLOW) continue + raw_id = SeriesID(raw_id) + if raw_id is None or not raw_id: continue - if '.' in raw_id or utils.is_int(raw_id): - # This is a test id. + if isinstance(raw_id, TestID): try: test_wd, _id = TestRun.parse_raw_id(pav_cfg, raw_id) except TestRunError as err: @@ -361,31 +317,29 @@ def test_list_to_paths(pav_cfg, req_tests, errfile=None) -> List[Path]: test_paths.append(test_path) if not test_path.exists(): output.fprint(errfile, - "Test run with id '{}' could not be found.".format(raw_id), + "Test run with id '{}' could not be found.".format(raw_id.id_str), color=output.YELLOW) - elif raw_id[0] == 's' and utils.is_int(raw_id[1:]): - # A series. + elif isinstance(raw_id, SeriesID): try: test_paths.extend( - series.list_series_tests(pav_cfg, raw_id)) + series.list_series_tests(pav_cfg, raw_id.id_str)) except TestSeriesError: - output.fprint(errfile, "Invalid series id '{}'".format(raw_id), + output.fprint(errfile, "Invalid series id '{}'".format(raw_id.id_str), color=output.YELLOW) else: - # A group try: - group = groups.TestGroup(pav_cfg, raw_id) + group = groups.TestGroup(pav_cfg, raw_id.id_str) except TestGroupError as err: output.fprint( errfile, "Invalid test group id '{}'.\n{}" - .format(raw_id, err.pformat())) + .format(raw_id.id_str, err.pformat())) continue if not group.exists(): output.fprint( errfile, - "Group '{}' does not exist.".format(raw_id)) + "Group '{}' does not exist.".format(raw_id.id_str)) continue try: @@ -394,7 +348,7 @@ def test_list_to_paths(pav_cfg, req_tests, errfile=None) -> List[Path]: output.fprint( errfile, "Invalid test group id '{}', could not get tests from group." - .format(raw_id)) + .format(raw_id.id_str)) return test_paths @@ -464,7 +418,7 @@ def get_tests_by_paths(pav_cfg, test_paths: List[Path], errfile: TextIO, return load_tests(pav_cfg, test_pairs, errfile) -def get_tests_by_id(pav_cfg, test_ids: List['str'], errfile: TextIO, +def get_tests_by_id(pav_cfg, ids: List['str'], errfile: TextIO, exclude_ids: List[str] = None) -> List[TestRun]: """Convert a list of raw test id's and series id's into a list of test objects. @@ -476,26 +430,25 @@ def get_tests_by_id(pav_cfg, test_ids: List['str'], errfile: TextIO, :return: List of test objects """ - test_ids = [str(test) for test in test_ids.copy()] + tids = test_ids.resolve_ids(ids) - if not test_ids: + if len(tids) == 0: # Get the last series ran by this user series_id = series.load_user_series_id(pav_cfg) if series_id is not None: - test_ids.append(series_id) + tids.append(SeriesID(series_id)) else: raise CommandError("No tests specified and no last series was found.") # Convert series and test ids into test paths. test_id_pairs = [] - for raw_id in test_ids: - # Series start with 's' (like 'snake') and never have labels - if '.' not in raw_id and raw_id.startswith('s'): + for raw_id in tids: + if SeriesID.is_valid_id(raw_id.id_str): try: - series_obj = series.TestSeries.load(pav_cfg, raw_id) + series_obj = series.TestSeries.load(pav_cfg, raw_id.id_str) except TestSeriesError as err: output.fprint(errfile, "Suite {} could not be found.\n{}" - .format(raw_id, err), color=output.RED) + .format(raw_id.id_str, err), color=output.RED) continue test_id_pairs.extend(list(series_obj.tests.keys())) @@ -506,7 +459,7 @@ def get_tests_by_id(pav_cfg, test_ids: List['str'], errfile: TextIO, except TestRunError as err: output.fprint(sys.stdout, "Error loading test '{}': {}" - .format(raw_id, err)) + .format(raw_id.id_str, err)) if exclude_ids: test_id_pairs = _filter_tests_by_raw_id(pav_cfg, test_id_pairs, exclude_ids) diff --git a/lib/pavilion/commands/_run.py b/lib/pavilion/commands/_run.py index 597192886..66673131f 100644 --- a/lib/pavilion/commands/_run.py +++ b/lib/pavilion/commands/_run.py @@ -17,6 +17,7 @@ from pavilion.sys_vars import base_classes from pavilion.test_run import TestRun, mass_status_update from pavilion.variables import VariableSetManager +from pavilion.test_ids import resolve_ids from .base_classes import Command # We need to catch pretty much all exceptions to cleanly report errors. @@ -41,6 +42,8 @@ def run(self, pav_cfg, args): """Load and run an already prepped test.""" tests = [] + args.test_ids = resolve_ids(args.test_ids) + for test_id in args.test_ids: try: tests.append(TestRun.load_from_raw_id(pav_cfg, test_id)) diff --git a/lib/pavilion/commands/cancel.py b/lib/pavilion/commands/cancel.py index 4299b55fc..dd0ffc105 100644 --- a/lib/pavilion/commands/cancel.py +++ b/lib/pavilion/commands/cancel.py @@ -12,7 +12,8 @@ from pavilion.errors import TestSeriesError from pavilion.test_run import TestRun from pavilion.config import PavConfig -from pavilion.micro import partition +from pavilion.test_ids import TestID, SeriesID +from pavilion.micro import partition, listmap from .base_classes import Command from ..errors import TestRunError @@ -53,10 +54,10 @@ def run(self, pav_cfg: PavConfig, args: Namespace) -> int: args.tests.append(series_id) # Separate out into tests and series - series_ids, test_ids = partition(cmd_utils.is_series_id, args.tests) + series_ids, test_ids = partition(SeriesID.is_valid_id, args.tests) - args.tests = test_ids - args.series = series_ids + args.tests = list(test_ids) + args.series = list(series_ids) # Get TestRun and TestSeries objects test_paths = cmd_utils.arg_filtered_tests(pav_cfg, args, verbose=self.errfile).paths diff --git a/lib/pavilion/commands/log.py b/lib/pavilion/commands/log.py index 20670f31a..8b5d61cf9 100644 --- a/lib/pavilion/commands/log.py +++ b/lib/pavilion/commands/log.py @@ -10,6 +10,7 @@ from pavilion import output from pavilion import series, series_config from pavilion.test_run import TestRun +from pavilion.test_ids import TestID from .base_classes import Command @@ -145,7 +146,7 @@ def run(self, pav_cfg, args): cmd_name = args.log_cmd if cmd_name == 'states': - return self._states(pav_cfg, args.id, raw=args.raw, raw_time=args.raw_time) + return self._states(pav_cfg, TestID(args.id), raw=args.raw, raw_time=args.raw_time) if cmd_name in ['global', 'all_results', 'allresults', 'all-results']: if 'results' in cmd_name: @@ -158,7 +159,7 @@ def run(self, pav_cfg, args): if cmd_name == 'series': test = series.TestSeries.load(pav_cfg, args.id) else: - test = TestRun.load_from_raw_id(pav_cfg, args.id) + test = TestRun.load_from_raw_id(pav_cfg, TestID(args.id)) except errors.TestRunError as err: output.fprint(self.errfile, "Error loading test.", err, color=output.RED) return 1 @@ -219,7 +220,7 @@ def run(self, pav_cfg, args): break return 0 - def _states(self, pav_cfg, test_id: str, raw: bool = False, raw_time: bool = False): + def _states(self, pav_cfg, test_id: TestID, raw: bool = False, raw_time: bool = False): """Print the states for a test.""" try: diff --git a/lib/pavilion/commands/status.py b/lib/pavilion/commands/status.py index 035a2d2a1..d3823c6a2 100644 --- a/lib/pavilion/commands/status.py +++ b/lib/pavilion/commands/status.py @@ -48,7 +48,7 @@ def _setup_arguments(self, parser): filters.add_test_filter_args(parser) - def run(self, pav_cfg, args): + def run(self, pav_cfg: "PavConfig", args: "Namespace") -> int: """Gathers and prints the statuses from the specified test runs and/or series.""" try: diff --git a/lib/pavilion/filters/parse_time.py b/lib/pavilion/filters/parse_time.py index 2a38a6acc..fb4e1e33a 100644 --- a/lib/pavilion/filters/parse_time.py +++ b/lib/pavilion/filters/parse_time.py @@ -69,7 +69,7 @@ def parse_duration(rval: str, now: datetime) -> datetime: dyear, dmonth = divmod(mag, MONTHS_PER_YEAR) new_day = now.day - new_month = now.month - dmonth + new_month = (now.month - dmonth) % 12 new_year = now.year - dyear return safe_update(now, year=new_year, month=new_month, day=new_day) diff --git a/lib/pavilion/test_ids.py b/lib/pavilion/test_ids.py new file mode 100644 index 000000000..f0439bfdf --- /dev/null +++ b/lib/pavilion/test_ids.py @@ -0,0 +1,248 @@ +from typing import Union, Tuple, List, NewType, Iterator, TypeVar, Iterable +from abc import abstractmethod + +from pavilion.micro import flatten +from pavilion.utils import is_int + + +class ID: + """Base class for IDs""" + + def __init__(self, id_str: str): + self.id_str = id_str + + # pylint: disable=no-self-argument + @abstractmethod + def is_valid_id(id_str: str) -> bool: + """Determine whether the given string constitutes a valid ID.""" + ... + + def __str__(self) -> str: + return self.id_str + + def __eq__(self, other: "ID") -> bool: + return self.id_str == other.id_str + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.id_str})" + + +class TestID(ID): + """Represents a single test ID.""" + + @staticmethod + def is_valid_id(id_str: str) -> bool: + """Determine whether the given string constitutes a valid test ID.""" + + return '.' in id_str or (is_int(id_str) and int(id_str) > 0) + + def is_int(self): + """Determine whether the test ID is an integer value.""" + + return is_int(self.id_str) + + def as_int(self): + """Convert the test ID into an integer, if possible.""" + + try: + return int(self.id_str) + except: + raise ValueError(f"Test with ID {self.id_str} cannot be converted to an integer.") + + @property + def parts(self) -> List[str]: + """Return a list of components of the test ID, where components are separated by + periods.""" + + return self.id_str.split('.', 1) + + +class SeriesID(ID): + """Represents a single series ID.""" + + @staticmethod + def is_valid_id(id_str: str) -> bool: + """Determine whether the given string constitutes a valid series ID.""" + + return id_str == 'all' or id_str == 'last' or (len(id_str) > 0 and id_str[0] == 's' \ + and is_int(id_str[1:]) and int(id_str[1:]) > 0) + + def is_int(self): + """Determine whether the series ID is an integer value.""" + + return len(self.id_str) > 0 and is_int(self.id_str[1:]) + + def as_int(self): + """Convert the series ID into an integer, if possible.""" + + if self.all() or self.last(): + raise ValueError(f"Series with ID {self.id_str} cannot be converted to an integer.") + + return int(self.id_str[1:]) + + def all(self): + """Determine whether the series is the set of all tests.""" + + return self.id_str == "all" + + def last(self): + """Determine whether the series is the most recently run.""" + + return self.id_str == "last" + + +class GroupID(ID): + """Represents a single group ID.""" + + @staticmethod + def is_valid_id(id_str: str) -> bool: + """Determine whether the given string constitutes a valid group ID.""" + return len(id_str) > 0 and not (TestID.is_valid_id(id_str) or SeriesID.is_valid_id(id_str)) + + +class Range: + """Represents a contiguous sequence of IDs.""" + + def __init__(self, start: int, end: int): + self.start = start + self.end = end + + # pylint: disable=no-self-argument + @abstractmethod + def is_valid_range_str(rng_str: str) -> bool: + """Determine whether the given string constitutes a valid range.""" + ... + + # pylint: disable=no-self-argument + @abstractmethod + def from_str(rng_str: str) -> "Range": + """Produce a new range object from a string. + + NOTE: This method should not perform validation. It assumes that validation has been + performed prior to being called.""" + ... + + @abstractmethod + def expand(self) -> Iterator: + """Get the sequence of all values in the range.""" + ... + + def __eq__(self, other: "Range") -> bool: + if not isinstance(other, type(self)): + return False + + return self.start == other.start and self.end == other.end + + @abstractmethod + def __str__(self) -> str: + ... + + def __repr__(self) -> str: + return f"{type(self).__name__}({self.start}, {self.end})" + + +class TestRange(Range): + """Represents a contiguous sequence of test IDs.""" + + @staticmethod + def is_valid_range_str(rng_str: str) -> bool: + """Determine whether the given string constitutes a valid test range.""" + + rng_str = rng_str.split('-') + + if len(rng_str) != 2: + return False + + start, end = rng_str + + if not (is_int(start) and is_int(end)): + return False + if not (int(start) > 0 and int(end) > 0): + return False + + # Allow degenerate ranges + return int(end) - int(start) >= 0 + + @staticmethod + def from_str(rng_str: str) -> "TestRange": + """Produce a new test range object from a string. Assumes validation has been performed + prior to being called.""" + + start, end = rng_str.split('-') + + return TestRange(int(start), int(end)) + + def expand(self) -> Iterator["TestRange"]: + """Get the sequence of all series IDs in the range.""" + + return map(TestID, map(str, range(self.start, self.end + 1))) + + def __str__(self) -> str: + return f"{self.start}-{self.end}" + + +class SeriesRange(Range): + """Represents a contiguous sequence of series IDs.""" + + @staticmethod + def is_valid_range_str(rng_str: str) -> bool: + rng_str = rng_str.split('-') + + if len(rng_str) != 2: + return False + + start, end = rng_str + + if not (is_int(start[1:]) and is_int(end[1:])): + return False + if not (int(start[1:]) > 0 and int(end[1:]) > 0): + return False + + # Allow degenerate ranges + return int(end[1:]) - int(start[1:]) >= 0 + + @staticmethod + def from_str(rng_str: str) -> "SeriesRange": + """Produce a new series range object from a string. Assumes validation has been performed + prior to being called.""" + + start, end = rng_str.split('-') + + return SeriesRange(int(start[1:]), int(end[1:])) + + def expand(self) -> Iterator["TestRange"]: + """Get the sequence of all series IDs in the range.""" + + return map(SeriesID, map(lambda x: f"s{x}", range(self.start, self.end + 1))) + + def __str__(self) -> str: + return f"s{self.start}-s{self.end}" + + +def multi_convert(id_str: str) -> Union[List[TestID], List[SeriesID], List[GroupID]]: + """Convert a string into a list (possibly a singleton list) of either a TestID, SeriesID, + or GroupID as appropriate.""" + + if TestRange.is_valid_range_str(id_str): + return list(TestRange.from_str(id_str).expand()) + if SeriesRange.is_valid_range_str(id_str): + return list(SeriesRange.from_str(id_str).expand()) + if TestID.is_valid_id(id_str): + return [TestID(id_str)] + if SeriesID.is_valid_id(id_str): + return [SeriesID(id_str)] + + return [GroupID(id_str)] + + +def resolve_ids(ids: Iterable[str]) -> List[Union[TestID, SeriesID, GroupID]]: + """Fully resolve all IDs in the given list into either test IDs, series IDs, or group IDs.""" + + ids = list(ids) + + if "all" in ids: + return [SeriesID("all")] + + ids = (i for i in ids if i != "all") + + return list(flatten(map(multi_convert, ids))) diff --git a/lib/pavilion/test_run/test_run.py b/lib/pavilion/test_run/test_run.py index 320940f98..7a7568f10 100644 --- a/lib/pavilion/test_run/test_run.py +++ b/lib/pavilion/test_run/test_run.py @@ -422,12 +422,13 @@ def _validate_config(self): "being defined in the pavilion config.") @classmethod - def parse_raw_id(cls, pav_cfg, raw_test_id: str) -> ID_Pair: + def parse_raw_id(cls, pav_cfg, raw_test_id: "TestID") -> ID_Pair: """Parse a raw test run id and return the label, working_dir, and id for that test. The test run need not exist, but the label must.""" - parts = raw_test_id.split('.', 1) - if not parts: + parts = raw_test_id.parts + + if len(parts) == 0: raise TestRunNotFoundError("Blank test run id given") elif len(parts) == 1: cfg_label = 'main' @@ -451,7 +452,7 @@ def parse_raw_id(cls, pav_cfg, raw_test_id: str) -> ID_Pair: return ID_Pair((working_dir, test_id)) @classmethod - def load_from_raw_id(cls, pav_cfg, raw_test_id: str) -> 'TestRun': + def load_from_raw_id(cls, pav_cfg, raw_test_id: "TestID") -> 'TestRun': """Load a test given a raw test id string, in the form [label].test_id. The optional label will allow us to look up the config path for the test.""" @@ -655,7 +656,6 @@ def build(self, cancel_event=None, tracker: BuildTracker = None): :returns: True if build successful """ - if tracker is None and self.builder is not None: tracker = MultiBuildTracker().register(self) diff --git a/test/tests/cancel_tests.py b/test/tests/cancel_tests.py index b0aa9e735..c730987c3 100644 --- a/test/tests/cancel_tests.py +++ b/test/tests/cancel_tests.py @@ -4,6 +4,7 @@ from pavilion import schedulers from pavilion import unittest from pavilion.status_file import STATES +from pavilion.timing import wait class CancelTests(unittest.PavTestCase): @@ -26,12 +27,10 @@ def test_cancel_jobs(self): test1.cancel("For fun") - # Wait till we know test2 is running - while not test1.complete: - time.sleep(0.1) + wait(lambda: test1.complete, interval=0.1, timeout=30) - while not test2.status.has_state(STATES.RUNNING): - time.sleep(0.1) + # Wait till we know test2 is running + wait(lambda: test2.status.has_state("RUNNING"), interval=0.1, timeout=30) jobs = cancel_utils.cancel_jobs(self.pav_cfg, [test1, test2]) self.assertEqual(test2.status.current().state, STATES.RUNNING) diff --git a/test/tests/general_tests.py b/test/tests/general_tests.py index 09e5aae37..85abd1b86 100644 --- a/test/tests/general_tests.py +++ b/test/tests/general_tests.py @@ -9,6 +9,7 @@ import yc_yaml as yaml from pavilion.test_run import TestRun from pavilion import utils +from pavilion.test_ids import TestID from pavilion.unittest import PavTestCase @@ -128,7 +129,7 @@ def test_legacy_runs(self): build_dst = dst_path/build_dst (dst_path/'build_dir').rename(build_dst) - test = TestRun.load_from_raw_id(self.pav_cfg, run_id) + test = TestRun.load_from_raw_id(self.pav_cfg, TestID(run_id)) self.assertTrue(test.results) self.assertTrue(test.complete) diff --git a/test/tests/sched_tests.py b/test/tests/sched_tests.py index 4cd0d0217..2a17f1592 100644 --- a/test/tests/sched_tests.py +++ b/test/tests/sched_tests.py @@ -491,7 +491,7 @@ def test_kickoff_flex(self): for test in tests: try: - test.wait(timeout=20) + test.wait(timeout=10) except TimeoutError: run_log_path = test.path / 'run.log' if run_log_path.exists(): @@ -589,8 +589,8 @@ def test_task_based(self): test = self._quick_test(test_cfg, finalize=False) test2 = self._quick_test(test_cfg, finalize=False) dummy.schedule_tests(self.pav_cfg, [test, test2]) - test.wait() - test2.wait() + test.wait(timeout=10) + test2.wait(timeout=10) self.assertIn("tasks: 21", (test.path/'run.log').open().read()) self.assertIn("tasks: 21", (test.path/'run.log').open().read()) @@ -614,7 +614,7 @@ def test_wrapper(self): dummy = pavilion.schedulers.get_plugin('dummy') dummy.schedule_tests(self.pav_cfg, [test]) # Wait few seconds for the test to be scheduled to run. - test.wait() + test.wait(timeout=10) # Check if it actually echoed to log with (test.path/'run.log').open('r') as runlog: diff --git a/test/tests/status_cmd_tests.py b/test/tests/status_cmd_tests.py index ddf93e01b..f3f600c16 100644 --- a/test/tests/status_cmd_tests.py +++ b/test/tests/status_cmd_tests.py @@ -263,6 +263,8 @@ def test_status_summary(self): test.RUN_SILENT_TIMEOUT = 1 # Testing that summary flags return correctly + arg_list = ['-s'] + args = parser.parse_args(arg_list) self.assertEqual(status_cmd.run(self.pav_cfg, args), 0) def test_status_history(self): diff --git a/test/tests/test_id_tests.py b/test/tests/test_id_tests.py new file mode 100644 index 000000000..794e53ccc --- /dev/null +++ b/test/tests/test_id_tests.py @@ -0,0 +1,105 @@ +from pavilion.unittest import PavTestCase +from pavilion.test_ids import * + +class TestIDTests(PavTestCase): + + def test_string_conversion(self): + """Test that ID objects convert to the correct strings.""" + + tid = TestID("foo") + + self.assertEqual(str(tid), "foo") + + sid = SeriesID("s5") + + self.assertEqual(str(sid), "s5") + + def test_validate_id(self): + """Test that the ID objects correctly validate IDs.""" + + self.assertTrue(TestID.is_valid_id("1")) + self.assertTrue(TestID.is_valid_id("foo.bar")) + self.assertFalse(TestID.is_valid_id("foobar")) + self.assertFalse(TestID.is_valid_id("-7")) + self.assertFalse(TestID.is_valid_id("0")) + self.assertFalse(TestID.is_valid_id("s1")) + self.assertFalse(TestID.is_valid_id("")) + + self.assertTrue(SeriesID.is_valid_id("s1")) + self.assertTrue(SeriesID.is_valid_id("all")) + self.assertTrue(SeriesID.is_valid_id("last")) + self.assertFalse(SeriesID.is_valid_id("1")) + self.assertFalse(SeriesID.is_valid_id("sh")) + self.assertFalse(SeriesID.is_valid_id("s0")) + self.assertFalse(SeriesID.is_valid_id("s-1")) + self.assertFalse(SeriesID.is_valid_id("")) + + self.assertTrue(GroupID.is_valid_id("foo")) + self.assertFalse(GroupID.is_valid_id("foo.bar")) + self.assertFalse(GroupID.is_valid_id("1")) + self.assertFalse(GroupID.is_valid_id("s1")) + self.assertFalse(GroupID.is_valid_id("")) + + self.assertTrue(TestRange.is_valid_range_str("1-3")) + self.assertTrue(TestRange.is_valid_range_str("1-1")) + self.assertFalse(TestRange.is_valid_range_str("foo")) + self.assertFalse(TestRange.is_valid_range_str("")) + self.assertFalse(TestRange.is_valid_range_str("1")) + self.assertFalse(TestRange.is_valid_range_str("-")) + self.assertFalse(TestRange.is_valid_range_str("1-")) + self.assertFalse(TestRange.is_valid_range_str("0-2")) + self.assertFalse(TestRange.is_valid_range_str("2-1")) + self.assertFalse(TestRange.is_valid_range_str("-3-1")) + + self.assertTrue(SeriesRange.is_valid_range_str("s1-s3")) + self.assertTrue(SeriesRange.is_valid_range_str("s1-s1")) + self.assertFalse(SeriesRange.is_valid_range_str("s")) + self.assertFalse(SeriesRange.is_valid_range_str("")) + self.assertFalse(SeriesRange.is_valid_range_str("1")) + self.assertFalse(SeriesRange.is_valid_range_str("s1")) + self.assertFalse(SeriesRange.is_valid_range_str("-")) + self.assertFalse(SeriesRange.is_valid_range_str("s1-")) + self.assertFalse(SeriesRange.is_valid_range_str("s0-s2")) + self.assertFalse(SeriesRange.is_valid_range_str("s0-2")) + self.assertFalse(SeriesRange.is_valid_range_str("s2-s1")) + self.assertFalse(SeriesRange.is_valid_range_str("s-3-s1")) + + def test_range_from_string(self): + """Test that ranges are correctly create from strings.""" + + self.assertEqual(TestRange.from_str("1-3"), TestRange(1, 3)) + self.assertEqual(SeriesRange.from_str("s1-s3"), SeriesRange(1, 3)) + + def test_expand_range(self): + """Test that range expansion produces the correct sequence of test or series IDs.""" + + test_range = TestRange(1, 3) + self.assertEqual(list(test_range.expand()), [TestID("1"), TestID("2"), TestID("3")]) + + test_range = TestRange(1, 1) + self.assertEqual(list(test_range.expand()), [TestID("1")]) + + series_range = SeriesRange(1, 3) + self.assertEqual(list(series_range.expand()), [SeriesID("s1"), SeriesID("s2"), SeriesID("s3")]) + + series_range = SeriesRange(1, 1) + self.assertEqual(list(series_range.expand()), [SeriesID("s1")]) + + def test_resolve_ids(self): + """Test that heterogeneous lists of tests, series, and range strings are resolved to + the correct objects.""" + + inputs = [ + ["1"], ["1-3"], ["s1"], ["s1-s3"], ["all"], ["last"], ["1", "2", "all"], + ["1", "s3", "s4-s6", "2-3"], [] + ] + expected = [ + [TestID("1")], [TestID("1"), TestID("2"), TestID("3")], [SeriesID("s1")], + [SeriesID("s1"), SeriesID("s2"), SeriesID("s3")], [SeriesID("all")], + [SeriesID("last")], [SeriesID("all")], + [TestID("1"), SeriesID("s3"), SeriesID("s4"), SeriesID("s5"), SeriesID("s6"), + TestID("2"), TestID("3")], [] + ] + + for inp, exp in zip(inputs, expected): + self.assertEqual(resolve_ids(inp), exp)