Skip to content

Guards

Register and evaluate guards used in transitions.

GuardRegistry

GuardRegistry()

Registry for guard functions (sync and async).

Example

>>> registry = GuardRegistry()
>>> registry.register("is_valid", lambda ctx: ctx.get("valid", False))
>>> registry.evaluate("is_valid", {"valid": True})
True

Source code in src/pystator/guards.py
def __init__(self) -> None:
    """Create an empty guard registry.

    ``_guards`` maps each registered name to its guard callable,
    ``_async_guards`` holds the names detected as async at registration
    time, and ``_lock`` serialises the mutations done by ``register`` and
    ``unregister``.
    """
    self._lock = threading.Lock()
    self._guards: dict[str, AnyGuardFunc] = {}
    self._async_guards: set[str] = set()

register

register(name: str, func: AnyGuardFunc) -> None

Register a guard function under the given name. Thread-safe. Raises ValueError if the name is empty or already registered.

Source code in src/pystator/guards.py
def register(self, name: str, func: AnyGuardFunc) -> None:
    """Register a guard function under ``name``. Thread-safe.

    The guard is recorded as async when ``func`` is a coroutine function,
    or a callable object whose ``__call__`` is a coroutine function.

    Args:
        name: Registry key for the guard; must be non-empty.
        func: Sync or async guard callable that receives the context dict.

    Raises:
        ValueError: If ``name`` is empty or is already registered.
    """
    # Function-scope import: ``asyncio.iscoroutinefunction`` is deprecated
    # since Python 3.14; ``inspect.iscoroutinefunction`` is the supported API.
    import inspect

    if not name:
        raise ValueError("Guard name cannot be empty")
    # Async detection only inspects ``func``, so it needs no lock.
    is_async = inspect.iscoroutinefunction(func) or (
        callable(func)
        and inspect.iscoroutinefunction(getattr(func, "__call__", None))  # noqa: B004
    )
    with self._lock:
        if name in self._guards:
            raise ValueError(f"Guard '{name}' is already registered")
        self._guards[name] = func
        if is_async:
            self._async_guards.add(name)

unregister

unregister(name: str) -> None

Unregister a guard function. Thread-safe. Raises GuardNotFoundError if no guard is registered under the name.

Source code in src/pystator/guards.py
def unregister(self, name: str) -> None:
    """Remove the guard registered as ``name``. Thread-safe.

    The name is also dropped from the set of async guards, if present.

    Raises:
        GuardNotFoundError: If nothing is registered under ``name``.
    """
    with self._lock:
        known = name in self._guards
        if not known:
            raise GuardNotFoundError(f"Guard '{name}' not found", guard_name=name)
        self._guards.pop(name)
        self._async_guards.discard(name)

evaluate

evaluate(name: str, context: dict[str, Any]) -> bool

Evaluate a single guard synchronously.

Source code in src/pystator/guards.py
def evaluate(self, name: str, context: dict[str, Any]) -> bool:
    """Evaluate a single guard synchronously.

    Args:
        name: Name of the guard, resolved via ``self.get``.
        context: Context dict passed to the guard.

    Returns:
        The guard's result coerced to ``bool``.

    Raises:
        TypeError: If the guard returns an awaitable (an async guard), which
            cannot be evaluated synchronously.
        Any exception raised by ``self.get`` for an unknown ``name`` is
        propagated unchanged.
    """
    import inspect

    outcome = self.get(name)(context)
    if inspect.isawaitable(outcome):
        # An un-awaited coroutine is truthy: returning it would make an async
        # guard silently "pass". Close it to avoid a "never awaited" warning,
        # then fail loudly instead.
        close = getattr(outcome, "close", None)
        if close is not None:
            close()
        raise TypeError(
            f"Guard '{name}' is async and cannot be evaluated synchronously"
        )
    return bool(outcome)

evaluate_all

evaluate_all(
    guards: tuple[str, ...] | list[str],
    context: dict[str, Any],
    fail_fast: bool = True,
) -> GuardResult

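Evaluate multiple guards in order and return a GuardResult. With fail_fast=True (the default), evaluation stops at the first failing guard. With fail_fast=False, every guard is evaluated and the first failing one is reported. An empty guard list always succeeds.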

Source code in src/pystator/guards.py
def evaluate_all(
    self,
    guards: tuple[str, ...] | list[str],
    context: dict[str, Any],
    fail_fast: bool = True,
) -> GuardResult:
    """Run several named guards in order and summarise the outcome.

    Args:
        guards: Guard names, evaluated in the given order.
        context: Context dict handed to every guard.
        fail_fast: When true, stop at the first failing guard; otherwise
            evaluate every guard before reporting.

    Returns:
        ``GuardResult.success()`` for an empty ``guards``. If any guard
        failed, ``GuardResult.failure(<first failing name>, evaluated=...)``.
        Otherwise, ``GuardResult.success(evaluated)``. ``evaluated`` holds
        ``(name, result)`` pairs for every guard that actually ran.
    """
    if not guards:
        return GuardResult.success()
    evaluated: list[tuple[str, bool]] = []
    first_failure: str | None = None
    for guard_name in guards:
        passed = self.evaluate(guard_name, context)
        evaluated.append((guard_name, passed))
        if passed:
            continue
        if fail_fast:
            return GuardResult.failure(guard_name, evaluated=evaluated)
        if first_failure is None:
            # Remember only the earliest failure; later ones still get recorded.
            first_failure = guard_name
    if first_failure is not None:
        return GuardResult.failure(first_failure, evaluated=evaluated)
    return GuardResult.success(evaluated)

decorator

decorator(
    name: str | None = None,
) -> Callable[[AnyGuardFunc], AnyGuardFunc]

Decorator to register a guard function.

Example

>>> @registry.decorator()
... def is_valid(ctx: dict) -> bool:
...     return ctx.get("valid", False)

Source code in src/pystator/guards.py
def decorator(
    self, name: str | None = None
) -> Callable[[AnyGuardFunc], AnyGuardFunc]:
    """Build a decorator that registers the decorated guard and returns it as-is.

    Args:
        name: Registry key to use. When falsy, the decorated function's
            ``__name__`` is used instead.

    Raises (when applied):
        ValueError: Propagated from ``register`` for an empty or duplicate name.

    Example:
        >>> @registry.decorator()
        ... def is_valid(ctx: dict) -> bool:
        ...     return ctx.get("valid", False)
    """
    registry = self

    def inner(func: AnyGuardFunc) -> AnyGuardFunc:
        guard_name = name if name else func.__name__
        registry.register(guard_name, func)
        return func

    return inner

GuardEvaluator

GuardEvaluator(
    registry: GuardRegistry | None = None,
    strict: bool = True,
)

Evaluates guard conditions for transitions.

Supports named guards (from registry) and inline expressions.

Source code in src/pystator/guards.py
def __init__(
    self,
    registry: GuardRegistry | None = None,
    strict: bool = True,
) -> None:
    """Create a guard evaluator.

    Args:
        registry: Registry for resolving named guards; may be ``None``.
        strict: Stored as ``self.strict``. Its effect lives in code not shown
            here (presumably how unknown guards are treated — confirm).
    """
    self.strict = strict
    self.registry = registry

can_transition

can_transition(
    transition: Transition, context: dict[str, Any]
) -> GuardResult

Check if a transition is allowed based on its guards.

Source code in src/pystator/guards.py
def can_transition(
    self,
    transition: Transition,
    context: dict[str, Any],
) -> GuardResult:
    """Decide whether ``transition`` may fire given ``context``.

    A transition without guards always succeeds. Otherwise the decision is
    delegated to ``self._evaluate_guards``.
    """
    guard_specs = transition.guards
    if guard_specs:
        return self._evaluate_guards(guard_specs, context)
    return GuardResult.success()

get_required_guards

get_required_guards(
    transitions: list[Transition] | tuple[Transition, ...],
) -> set[str]

Return the set of named guard names required by the given transitions.

Recurses into composite guards (allOf, anyOf, not). Excludes inline expression guards (only returns guard names from registry).

Source code in src/pystator/guards.py
def get_required_guards(
    self,
    transitions: list[Transition] | tuple[Transition, ...],
) -> set[str]:
    """Collect the named guards needed by ``transitions``.

    Every guard of every transition is handed to ``_collect_guard_names``,
    which adds names to the shared result set. Per this module's
    documentation, it recurses into composite guards (allOf, anyOf, not)
    and skips inline expression guards.
    """
    collected: set[str] = set()
    every_guard = (guard for trans in transitions for guard in trans.guards)
    for guard in every_guard:
        _collect_guard_names(guard, collected)
    return collected

Built-in guard helpers:

equals

equals(
    key: str, value: Any
) -> Callable[[dict[str, Any]], bool]

Guard factory: passes when context.get(key) == value (a missing key is treated as None).

Source code in src/pystator/guards.py
def equals(key: str, value: Any) -> Callable[[dict[str, Any]], bool]:
    """Build a guard that passes when ``ctx.get(key) == value``.

    A missing key reads as ``None``, so ``equals(key, None)`` also passes
    when ``key`` is absent from the context.
    """
    expected = value

    def guard(ctx: dict[str, Any]) -> bool:
        actual = ctx.get(key)
        return actual == expected

    return guard

greater_than

greater_than(
    key: str, value: Any
) -> Callable[[dict[str, Any]], bool]

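Guard factory: passes when context[key] > value. Returns False if the key is missing, its value is None, or the comparison raises TypeError.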

Source code in src/pystator/guards.py
def greater_than(key: str, value: Any) -> Callable[[dict[str, Any]], bool]:
    """Build a guard that passes when ``ctx[key] > value``.

    The guard returns ``False`` when the key is absent, its value is
    ``None``, or the two operands cannot be compared (``TypeError``).
    """
    threshold = value

    def guard(ctx: dict[str, Any]) -> bool:
        current = ctx.get(key)
        if current is not None:
            try:
                return current > threshold
            except TypeError:
                pass
        return False

    return guard

in_list

in_list(
    key: str, values: list[Any]
) -> Callable[[dict[str, Any]], bool]

Guard factory: passes when context.get(key) is in values (a missing key is treated as None).

Source code in src/pystator/guards.py
def in_list(key: str, values: list[Any]) -> Callable[[dict[str, Any]], bool]:
    """Build a guard that passes when ``ctx.get(key)`` is a member of ``values``.

    ``values`` is referenced rather than copied, so later mutations of the
    list are visible to the guard. A missing key reads as ``None``.
    """
    allowed = values

    def guard(ctx: dict[str, Any]) -> bool:
        candidate = ctx.get(key)
        return candidate in allowed

    return guard

all_of

all_of(
    *guards: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]

Compound guard: all must pass.

Source code in src/pystator/guards.py
def all_of(
    *guards: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]:
    """Combine guards into one that passes only if every guard passes.

    Guards run left to right and evaluation stops at the first falsy
    result. With no guards, the combined guard always returns ``True``.
    """
    checks = guards

    def guard(ctx: dict[str, Any]) -> bool:
        for check in checks:
            if not check(ctx):
                return False
        return True

    return guard

any_of

any_of(
    *guards: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]

Compound guard: at least one must pass.

Source code in src/pystator/guards.py
def any_of(
    *guards: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]:
    """Combine guards into one that passes if at least one guard passes.

    Guards run left to right and evaluation stops at the first truthy
    result. With no guards, the combined guard always returns ``False``.
    """
    checks = guards

    def guard(ctx: dict[str, Any]) -> bool:
        for check in checks:
            if check(ctx):
                return True
        return False

    return guard

negate

negate(
    guard_fn: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]

Guard that negates the result of another guard.

Source code in src/pystator/guards.py
def negate(
    guard_fn: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]:
    """Wrap ``guard_fn`` so the new guard passes exactly when it fails.

    The wrapped result is judged by truthiness, and the returned value is
    always a ``bool``.
    """
    wrapped = guard_fn

    def guard(ctx: dict[str, Any]) -> bool:
        return False if wrapped(ctx) else True

    return guard