Skip to content

Guards

Register and evaluate guards used in transitions.

GuardRegistry

GuardRegistry()

Registry for guard functions (sync and async).

Example

>>> registry = GuardRegistry()
>>> registry.register("is_valid", lambda ctx: ctx.get("valid", False))
>>> registry.evaluate("is_valid", {"valid": True})
True

Source code in src/pystator/guards.py
def __init__(self) -> None:
    # Guards register/unregister so the two collections below change together.
    self._lock = threading.Lock()
    # Maps each registered guard name to its callable (sync or async).
    self._guards: dict[str, AnyGuardFunc] = dict()
    # Names whose callables were detected as coroutine functions on register.
    self._async_guards: set[str] = set()

register

register(name: str, func: AnyGuardFunc) -> None

Register a guard function. Thread-safe.

Source code in src/pystator/guards.py
def register(self, name: str, func: AnyGuardFunc) -> None:
    """Register a guard function. Thread-safe.

    Args:
        name: Non-empty registry key for the guard.
        func: The guard callable. Coroutine functions, and objects whose
            ``__call__`` is a coroutine function, are also recorded as async.

    Raises:
        ValueError: If ``name`` is empty or already registered.
    """
    if not name:
        raise ValueError("Guard name cannot be empty")
    with self._lock:
        if name in self._guards:
            raise ValueError(f"Guard '{name}' is already registered")
        self._guards[name] = func
        # getattr(..., None) yields None for objects without __call__, and
        # iscoroutinefunction(None) is False, so no hasattr pre-check is needed.
        is_async = asyncio.iscoroutinefunction(func) or asyncio.iscoroutinefunction(
            getattr(func, "__call__", None)
        )
        if is_async:
            self._async_guards.add(name)

unregister

unregister(name: str) -> None

Unregister a guard function. Thread-safe.

Source code in src/pystator/guards.py
def unregister(self, name: str) -> None:
    """Remove a previously registered guard. Thread-safe.

    Args:
        name: Registry key of the guard to remove.

    Raises:
        GuardNotFoundError: If no guard is registered under ``name``.
    """
    with self._lock:
        if name in self._guards:
            self._guards.pop(name)
            # Keep the async-name index in step with the guard table.
            self._async_guards.discard(name)
            return
        raise GuardNotFoundError(f"Guard '{name}' not found", guard_name=name)

evaluate

evaluate(name: str, context: dict[str, Any]) -> bool

Evaluate a single guard synchronously.

Source code in src/pystator/guards.py
def evaluate(self, name: str, context: dict[str, Any]) -> bool:
    """Evaluate a single guard synchronously.

    Args:
        name: Registered guard name, resolved through ``self.get``.
        context: Mapping passed as the sole argument to the guard.

    Returns:
        The guard's result coerced to ``bool``.

    Raises:
        TypeError: If the guard returns a coroutine, i.e. it is an async
            guard. A coroutine object is always truthy, so returning it would
            let the guard "pass" without its body ever running.
    """
    result = self.get(name)(context)
    if asyncio.iscoroutine(result):
        # Close the un-awaited coroutine so Python does not emit a
        # "coroutine ... was never awaited" RuntimeWarning.
        result.close()
        raise TypeError(
            f"Guard '{name}' is async and cannot be evaluated synchronously"
        )
    return bool(result)

evaluate_all

evaluate_all(
    guards: tuple[str, ...] | list[str],
    context: dict[str, Any],
    fail_fast: bool = True,
) -> GuardResult

Evaluate multiple guards. Returns GuardResult.

Source code in src/pystator/guards.py
def evaluate_all(
    self,
    guards: tuple[str, ...] | list[str],
    context: dict[str, Any],
    fail_fast: bool = True,
) -> GuardResult:
    """Evaluate several named guards against one context.

    Args:
        guards: Guard names, evaluated in order via ``self.evaluate``.
        context: Mapping handed to every guard.
        fail_fast: When true, stop at the first guard that fails; otherwise
            evaluate every guard before reporting.

    Returns:
        ``GuardResult.success()`` for an empty ``guards``. Otherwise, a
        failure naming the first guard that failed, or a success. Both carry
        the ``(name, result)`` pairs evaluated so far.
    """
    if not guards:
        return GuardResult.success()
    outcomes: list[tuple[str, bool]] = []
    for guard_name in guards:
        passed = self.evaluate(guard_name, context)
        outcomes.append((guard_name, passed))
        if fail_fast and not passed:
            break
    # With fail_fast the failing guard is the last entry; without it there may
    # be several, and the first one is reported.
    for guard_name, passed in outcomes:
        if not passed:
            return GuardResult.failure(guard_name, evaluated=outcomes)
    return GuardResult.success(outcomes)

decorator

decorator(
    name: str | None = None,
) -> Callable[[AnyGuardFunc], AnyGuardFunc]

Decorator to register a guard function.

Example

>>> @registry.decorator()
... def is_valid(ctx: dict) -> bool:
...     return ctx.get("valid", False)

Source code in src/pystator/guards.py
def decorator(
    self, name: str | None = None
) -> Callable[[AnyGuardFunc], AnyGuardFunc]:
    """Decorator to register a guard function.

    Args:
        name: Registry key to use. When omitted or empty, the decorated
            function's ``__name__`` is used.

    Returns:
        A decorator that registers the function via ``self.register`` and
        returns it unchanged, so it stays directly callable.

    Example:
        >>> @registry.decorator()
        ... def is_valid(ctx: dict) -> bool:
        ...     return ctx.get("valid", False)
    """

    def inner(func: AnyGuardFunc) -> AnyGuardFunc:
        guard_name = name if name else func.__name__
        self.register(guard_name, func)
        return func

    return inner

GuardEvaluator

GuardEvaluator(
    registry: GuardRegistry | None = None,
    strict: bool = True,
)

Evaluates guard conditions for transitions.

Supports named guards (from registry) and inline expressions.

Source code in src/pystator/guards.py
def __init__(
    self,
    registry: GuardRegistry | None = None,
    strict: bool = True,
) -> None:
    # Both values are stored as given; nothing is validated here.
    self.strict = strict
    self.registry = registry

can_transition

can_transition(
    transition: Transition, context: dict[str, Any]
) -> GuardResult

Check if a transition is allowed based on its guards.

Source code in src/pystator/guards.py
def can_transition(
    self,
    transition: Transition,
    context: dict[str, Any],
) -> GuardResult:
    """Check whether ``transition`` is allowed by its guards.

    Args:
        transition: Object whose ``guards`` attribute lists the guards to check.
        context: Mapping forwarded to guard evaluation.

    Returns:
        ``GuardResult.success()`` when the transition has no guards, otherwise
        whatever ``self._evaluate_guards`` returns.
    """
    guard_list = transition.guards
    if guard_list:
        return self._evaluate_guards(guard_list, context)
    return GuardResult.success()

get_required_guards

get_required_guards(
    transitions: list[Transition] | tuple[Transition, ...],
) -> set[str]

Return the set of named guard names required by the given transitions.

Excludes inline expression guards (only returns guard names from registry).

Source code in src/pystator/guards.py
def get_required_guards(
    self,
    transitions: list[Transition] | tuple[Transition, ...],
) -> set[str]:
    """Return the set of named guard names required by the given transitions.

    Guards whose ``name`` is None are skipped. Per the original contract,
    these are inline expression guards, so only registry names are returned.
    """
    return {
        guard.name
        for trans in transitions
        for guard in trans.guards
        if guard.name is not None
    }

evaluate_and_raise

evaluate_and_raise(
    transition: Transition,
    current_state: str,
    context: dict[str, Any],
) -> None

Evaluate guards and raise GuardRejectedError if blocked.

Source code in src/pystator/guards.py
def evaluate_and_raise(
    self,
    transition: Transition,
    current_state: str,
    context: dict[str, Any],
) -> None:
    """Evaluate ``transition``'s guards and raise if any of them blocks it.

    Args:
        transition: The transition being attempted; its ``trigger`` is
            reported in the error.
        current_state: State name reported in the error.
        context: Mapping forwarded to ``self.can_transition``.

    Raises:
        GuardRejectedError: When the guard result is not ``passed``. The
            guard name falls back to ``"unknown"`` if the result has none.
    """
    outcome = self.can_transition(transition, context)
    if outcome.passed:
        return
    raise GuardRejectedError(
        message=outcome.message,
        current_state=current_state,
        trigger=transition.trigger,
        guard_name=outcome.guard_name or "unknown",
        guard_result=outcome.evaluated_guards,
    )

Built-in guard helpers:

equals

equals(
    key: str, value: Any
) -> Callable[[dict[str, Any]], bool]

Guard factory: context[key] == value.

Source code in src/pystator/guards.py
def equals(key: str, value: Any) -> Callable[[dict[str, Any]], bool]:
    """Guard factory: pass when ``ctx.get(key) == value``.

    A missing key reads as None, so ``equals(k, None)`` passes when ``k``
    is absent.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        actual = ctx.get(key)
        return actual == value

    return guard

greater_than

greater_than(
    key: str, value: Any
) -> Callable[[dict[str, Any]], bool]

Guard factory: context[key] > value.

Source code in src/pystator/guards.py
def greater_than(key: str, value: Any) -> Callable[[dict[str, Any]], bool]:
    """Guard factory: pass when ``ctx[key] > value``.

    A missing or None value fails. So does a value that cannot be compared
    with ``value`` (the comparison raises TypeError).
    """

    def guard(ctx: dict[str, Any]) -> bool:
        current = ctx.get(key)
        if current is not None:
            try:
                return current > value
            except TypeError:
                pass
        return False

    return guard

in_list

in_list(
    key: str, values: list[Any]
) -> Callable[[dict[str, Any]], bool]

Guard factory: context[key] in values.

Source code in src/pystator/guards.py
def in_list(key: str, values: list[Any]) -> Callable[[dict[str, Any]], bool]:
    """Guard factory: pass when ``ctx.get(key)`` is a member of ``values``.

    ``values`` is captured by reference and checked with ``in`` on each call.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        candidate = ctx.get(key)
        return candidate in values

    return guard

all_of

all_of(
    *guards: Callable[[dict[str, Any]], bool]
) -> Callable[[dict[str, Any]], bool]

Compound guard: all must pass.

Source code in src/pystator/guards.py
def all_of(
    *guards: Callable[[dict[str, Any]], bool]
) -> Callable[[dict[str, Any]], bool]:
    """Compound guard: pass only if every guard passes.

    Guards run in order and evaluation stops at the first failure. With no
    guards the compound guard passes.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        for check in guards:
            if not check(ctx):
                return False
        return True

    return guard

any_of

any_of(
    *guards: Callable[[dict[str, Any]], bool]
) -> Callable[[dict[str, Any]], bool]

Compound guard: at least one must pass.

Source code in src/pystator/guards.py
def any_of(
    *guards: Callable[[dict[str, Any]], bool]
) -> Callable[[dict[str, Any]], bool]:
    """Compound guard: pass if at least one guard passes.

    Guards run in order and evaluation stops at the first success. With no
    guards the compound guard fails.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        for check in guards:
            if check(ctx):
                return True
        return False

    return guard

negate

negate(
    guard_fn: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]

Guard that negates the result of another guard.

Source code in src/pystator/guards.py
def negate(
    guard_fn: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]:
    """Guard that passes exactly when ``guard_fn`` fails (by truthiness)."""

    def guard(ctx: dict[str, Any]) -> bool:
        inner_result = guard_fn(ctx)
        return False if inner_result else True

    return guard