Skip to content

Guards

Register and evaluate guards used in transitions.

GuardRegistry

GuardRegistry()

Registry for guard functions (sync and async).

Example

>>> registry = GuardRegistry()
>>> registry.register("is_valid", lambda ctx: ctx.get("valid", False))
>>> registry.evaluate("is_valid", {"valid": True})
True

Source code in src/pystator/guards.py
def __init__(self) -> None:
    """Create an empty guard registry.

    Initialises the three pieces of state the other methods rely on: the
    lock, the set of async guard names, and the name-to-callable mapping.
    """
    # Held by register()/unregister() while they mutate the containers below.
    self._lock = threading.Lock()
    # Names of guards that were detected as coroutine functions on registration.
    self._async_guards: set[str] = set()
    # Guard name -> guard callable (sync or async).
    self._guards: dict[str, AnyGuardFunc] = {}

register

register(name: str, func: AnyGuardFunc) -> None

Register a guard function. Thread-safe.

Source code in src/pystator/guards.py
def register(self, name: str, func: AnyGuardFunc) -> None:
    """Add *func* to the registry under *name*. Thread-safe.

    The guard is additionally recorded as async when it is a coroutine
    function, or a callable object whose ``__call__`` is a coroutine function.

    Args:
        name: Non-empty, not-yet-registered guard name.
        func: The guard callable (sync or async).

    Raises:
        ValueError: If *name* is empty or already registered.
    """
    if not name:
        raise ValueError("Guard name cannot be empty")
    with self._lock:
        if name in self._guards:
            raise ValueError(f"Guard '{name}' is already registered")
        self._guards[name] = func
        # First check the object itself, then (for callable instances) its
        # __call__ method, which is where an async callable class hides.
        runs_async = asyncio.iscoroutinefunction(func)
        if not runs_async and callable(func):
            call_attr = getattr(func, "__call__", None)  # noqa: B004
            runs_async = asyncio.iscoroutinefunction(call_attr)
        if runs_async:
            self._async_guards.add(name)

unregister

unregister(name: str) -> None

Unregister a guard function. Thread-safe.

Source code in src/pystator/guards.py
def unregister(self, name: str) -> None:
    """Remove the guard registered under *name*. Thread-safe.

    Args:
        name: Name of the guard to remove.

    Raises:
        GuardNotFoundError: If no guard is registered under *name*.
    """
    with self._lock:
        if name in self._guards:
            del self._guards[name]
            # discard(): the name is only in the async set for async guards.
            self._async_guards.discard(name)
            return
        raise GuardNotFoundError(f"Guard '{name}' not found", guard_name=name)

get

get(name: str) -> AnyGuardFunc

Look up a registered guard callable by name.

Parameters:

Name Type Description Default
name str

The registered guard name.

required

Returns:

Type Description
AnyGuardFunc

The guard callable.

Raises:

Type Description
GuardNotFoundError

If no guard is registered under name.

Source code in src/pystator/guards.py
def get(self, name: str) -> AnyGuardFunc:
    """Look up a registered guard callable by name.

    Args:
        name: The registered guard name.

    Returns:
        The guard callable.

    Raises:
        GuardNotFoundError: If no guard is registered under *name*.
    """
    # EAFP: one dict lookup instead of "check membership, then subscript".
    # This method does not take the registry lock, so with two lookups a
    # concurrent unregister() between them would surface as a bare KeyError
    # rather than the documented GuardNotFoundError.
    try:
        return self._guards[name]
    except KeyError:
        raise GuardNotFoundError(
            f"Guard '{name}' not found", guard_name=name
        ) from None

has

has(name: str) -> bool

Check whether a guard is registered.

Source code in src/pystator/guards.py
def has(self, name: str) -> bool:
    """Report whether a guard is currently registered under *name*.

    This is a plain membership test on the internal mapping; it does not
    take the registry lock.
    """
    registered = name in self._guards
    return registered

is_async

is_async(name: str) -> bool

Return True if the named guard is an async function.

Source code in src/pystator/guards.py
def is_async(self, name: str) -> bool:
    """Report whether *name* was recorded as an async guard.

    Returns False both for sync guards and for names that are not
    registered at all, since only the async-name set is consulted.
    """
    flagged = name in self._async_guards
    return flagged

evaluate

evaluate(name: str, context: dict[str, Any]) -> bool

Evaluate a single guard synchronously.

Source code in src/pystator/guards.py
def evaluate(self, name: str, context: dict[str, Any]) -> bool:
    """Run the guard registered under *name* synchronously.

    Args:
        name: Registered guard name.
        context: Context dict passed to the guard callable.

    Returns:
        Whatever the guard callable returns for *context*.

    Raises:
        GuardNotFoundError: Propagated from ``get`` when *name* is unknown.

    NOTE(review): the guard is called directly and never awaited, so a guard
    registered as an async function yields a coroutine object here (which is
    truthy) instead of a bool. Use ``async_evaluate`` for async guards.
    """
    guard_fn = self.get(name)
    outcome = guard_fn(context)
    return outcome  # type: ignore[return-value]

evaluate_all

evaluate_all(
    guards: tuple[str, ...] | list[str],
    context: dict[str, Any],
    fail_fast: bool = True,
) -> GuardResult

Evaluate multiple guards. Returns GuardResult.

Source code in src/pystator/guards.py
def evaluate_all(
    self,
    guards: tuple[str, ...] | list[str],
    context: dict[str, Any],
    fail_fast: bool = True,
) -> GuardResult:
    """Evaluate several named guards in order and summarise the outcome.

    Args:
        guards: Guard names to evaluate, in order.
        context: Context dict passed to every guard.
        fail_fast: If True, return on the first guard whose result is falsy;
            later guards are then not evaluated. If False, every guard is
            evaluated and the first failing one is reported.

    Returns:
        ``GuardResult.success()`` when *guards* is empty; otherwise a failure
        naming the first failing guard, or a success, each carrying the
        ``(name, result)`` pairs evaluated so far.
    """
    if not guards:
        return GuardResult.success()

    evaluated: list[tuple[str, bool]] = []
    for guard_name in guards:
        passed = self.evaluate(guard_name, context)
        evaluated.append((guard_name, passed))
        if not passed and fail_fast:
            return GuardResult.failure(guard_name, evaluated=evaluated)

    # Only reachable with a failure when fail_fast is False: report the
    # earliest failing guard together with the complete evaluation list.
    first_failure = next((entry for entry in evaluated if not entry[1]), None)
    if first_failure is not None:
        return GuardResult.failure(first_failure[0], evaluated=evaluated)
    return GuardResult.success(evaluated)

async_evaluate async

async_evaluate(name: str, context: dict[str, Any]) -> bool

Evaluate a single guard asynchronously.

Falls back to synchronous execution if the guard is not async.

Parameters:

Name Type Description Default
name str

Registered guard name.

required
context dict[str, Any]

Context dict passed to the guard callable.

required

Returns:

Type Description
bool

True if the guard passed, False otherwise.

Source code in src/pystator/guards.py
async def async_evaluate(self, name: str, context: dict[str, Any]) -> bool:
    """Run the guard registered under *name*, awaiting it when it is async.

    Guards not recorded as async are simply called synchronously.

    Args:
        name: Registered guard name.
        context: Context dict passed to the guard callable.

    Returns:
        True if the guard passed, False otherwise.
    """
    guard_fn = self.get(name)
    needs_await = self.is_async(name)
    outcome = guard_fn(context)
    if needs_await:
        # Async guards hand back an awaitable; resolve it to the real result.
        outcome = await outcome  # type: ignore[misc]
    return outcome  # type: ignore[return-value]

async_evaluate_all async

async_evaluate_all(
    guards: (
        tuple[GuardSpec, ...] | tuple[str, ...] | list[str]
    ),
    context: dict[str, Any],
    fail_fast: bool = True,
) -> GuardResult

Evaluate multiple guards asynchronously.

Parameters:

Name Type Description Default
guards tuple[GuardSpec, ...] | tuple[str, ...] | list[str]

Guard specs or names to evaluate.

required
context dict[str, Any]

Context dict passed to each guard.

required
fail_fast bool

If True, stop on first failing guard.

True

Returns:

Type Description
GuardResult

GuardResult summarising all evaluations.

Source code in src/pystator/guards.py
async def async_evaluate_all(
    self,
    guards: tuple[GuardSpec, ...] | tuple[str, ...] | list[str],
    context: dict[str, Any],
    fail_fast: bool = True,
) -> GuardResult:
    """Evaluate several guards in order, awaiting the async ones.

    Each entry is either a plain guard name or a ``GuardSpec``. A spec whose
    ``name`` is None is skipped. A spec with params is evaluated against a
    copy of *context* that also carries ``GUARD_PARAMS_KEY`` -> its params;
    the caller's dict is not modified.

    Args:
        guards: Guard specs or names to evaluate.
        context: Context dict passed to each guard.
        fail_fast: If True, stop on first failing guard.

    Returns:
        GuardResult summarising all evaluations.
    """
    if not guards:
        return GuardResult.success()

    evaluated: list[tuple[str, bool]] = []
    for entry in guards:
        is_spec = isinstance(entry, GuardSpec)
        guard_name = entry.name if is_spec else entry
        if guard_name is None:
            continue

        if is_spec and entry.has_params:
            guard_ctx = {**context, GUARD_PARAMS_KEY: entry.params}
        else:
            guard_ctx = context

        passed = await self.async_evaluate(guard_name, guard_ctx)
        evaluated.append((guard_name, passed))
        if not passed and fail_fast:
            return GuardResult.failure(guard_name, evaluated=evaluated)

    # With fail_fast disabled every guard has run; report the earliest failure.
    first_failure = next((item for item in evaluated if not item[1]), None)
    if first_failure is not None:
        return GuardResult.failure(first_failure[0], evaluated=evaluated)
    return GuardResult.success(evaluated)

list_guards

list_guards() -> list[str]

Return all registered guard names.

Source code in src/pystator/guards.py
def list_guards(self) -> list[str]:
    """Return the names of all registered guards.

    The result is a new list built from the internal mapping's keys, so
    changing it does not affect the registry.
    """
    return [*self._guards]

clear

clear() -> None

Remove all registered guards.

Source code in src/pystator/guards.py
def clear(self) -> None:
    """Remove all registered guards. Thread-safe.

    Both the guard mapping and the async-name set are emptied while holding
    the registry lock, the same lock ``register`` and ``unregister`` take.
    Without it, a ``register`` call interleaved between the two ``clear()``
    calls could leave an async guard in the mapping but missing from the
    async set, so it would later be treated as a sync guard.
    """
    with self._lock:
        self._guards.clear()
        self._async_guards.clear()

decorator

decorator(
    name: str | None = None,
) -> Callable[[AnyGuardFunc], AnyGuardFunc]

Decorator to register a guard function.

Example

>>> @registry.decorator()
... def is_valid(ctx: dict) -> bool:
...     return ctx.get("valid", False)

Source code in src/pystator/guards.py
def decorator(
    self, name: str | None = None
) -> Callable[[AnyGuardFunc], AnyGuardFunc]:
    """Decorator to register a guard function.

    Example:
        >>> @registry.decorator()
        ... def is_valid(ctx: dict) -> bool:
        ...     return ctx.get("valid", False)
    """

    def inner(func: AnyGuardFunc) -> AnyGuardFunc:
        self.register(name or func.__name__, func)
        return func

    return inner

GuardEvaluator

GuardEvaluator(
    registry: GuardRegistry | None = None,
    strict: bool = True,
)

Evaluates guard conditions for transitions.

Supports named guards (from registry) and inline expressions.

Parameters:

Name Type Description Default
registry GuardRegistry | None

Optional guard registry for named guard lookup.

None
strict bool

If True, raise on missing guards instead of passing.

True
Source code in src/pystator/guards.py
def __init__(
    self,
    registry: GuardRegistry | None = None,
    strict: bool = True,
) -> None:
    """Initialize the guard evaluator.

    Args:
        registry: Optional guard registry for named guard lookup.
        strict: If True, raise on missing guards instead of passing.
    """
    self.registry = registry
    self.strict = strict

can_transition

can_transition(
    transition: Transition, context: dict[str, Any]
) -> GuardResult

Check if a transition is allowed based on its guards.

Source code in src/pystator/guards.py
def can_transition(
    self,
    transition: Transition,
    context: dict[str, Any],
) -> GuardResult:
    """Decide whether *transition* may fire for the given *context*.

    A transition without guards is always allowed; otherwise the decision
    is delegated to ``_evaluate_guards``.

    Args:
        transition: The transition whose ``guards`` are checked.
        context: Context dict handed to the guard evaluation.

    Returns:
        A successful GuardResult when there are no guards, else the result
        of evaluating them.
    """
    if transition.guards:
        return self._evaluate_guards(transition.guards, context)
    return GuardResult.success()

get_required_guards

get_required_guards(
    transitions: list[Transition] | tuple[Transition, ...],
) -> set[str]

Return the set of named guard names required by the given transitions.

Recurses into composite guards (allOf, anyOf, not). Excludes inline expression guards (only returns guard names from registry).

Source code in src/pystator/guards.py
def get_required_guards(
    self,
    transitions: list[Transition] | tuple[Transition, ...],
) -> set[str]:
    """Collect the named guards referenced by the given transitions.

    Every guard of every transition is passed to ``_collect_guard_names``,
    which adds names to the result set. Recurses into composite guards
    (allOf, anyOf, not) and excludes inline expression guards, so only
    registry guard names are returned.

    Args:
        transitions: Transitions whose guards should be inspected.

    Returns:
        The set of guard names found; empty if there are none.
    """
    collected: set[str] = set()
    every_guard = (
        guard for transition in transitions for guard in transition.guards
    )
    for guard in every_guard:
        _collect_guard_names(guard, collected)
    return collected

Built-in guard helpers:

equals

equals(
    key: str, value: Any
) -> Callable[[dict[str, Any]], bool]

Guard factory: context[key] == value.

Source code in src/pystator/guards.py
def equals(key: str, value: Any) -> Callable[[dict[str, Any]], bool]:
    """Guard factory: passes when ``ctx.get(key) == value``.

    The key is read with ``dict.get``, so a missing key compares as None.

    Args:
        key: Context key to read.
        value: Value the context entry is compared against.

    Returns:
        A guard callable taking the context dict.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        actual = ctx.get(key)
        return actual == value

    return guard

greater_than

greater_than(
    key: str, value: Any
) -> Callable[[dict[str, Any]], bool]

Guard factory: context[key] > value.

Source code in src/pystator/guards.py
def greater_than(key: str, value: Any) -> Callable[[dict[str, Any]], bool]:
    """Guard factory: passes when ``ctx.get(key) > value``.

    The guard fails (returns False) when the key is missing or None, and
    when the comparison raises ``TypeError`` for incomparable operands.

    Args:
        key: Context key to read.
        value: Lower bound (exclusive) to compare against.

    Returns:
        A guard callable taking the context dict.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        current = ctx.get(key)
        if current is not None:
            try:
                return current > value
            except TypeError:
                # Incomparable operands count as "guard not satisfied".
                pass
        return False

    return guard

in_list

in_list(
    key: str, values: list[Any]
) -> Callable[[dict[str, Any]], bool]

Guard factory: context[key] in values.

Source code in src/pystator/guards.py
def in_list(key: str, values: list[Any]) -> Callable[[dict[str, Any]], bool]:
    """Guard factory: passes when ``ctx.get(key)`` is contained in *values*.

    The key is read with ``dict.get``, so a missing key is tested as None.
    *values* is captured by reference, not copied.

    Args:
        key: Context key to read.
        values: Accepted values.

    Returns:
        A guard callable taking the context dict.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        candidate = ctx.get(key)
        return candidate in values

    return guard

all_of

all_of(
    *guards: Callable[[dict[str, Any]], bool]
) -> Callable[[dict[str, Any]], bool]

Compound guard: all must pass.

Source code in src/pystator/guards.py
def all_of(
    *guards: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]:
    """Compound guard: passes only when every given guard passes.

    Guards run in the order given and evaluation stops at the first one
    that returns a falsy value. With no guards the result is True.

    Args:
        *guards: Guard callables to combine.

    Returns:
        A guard callable taking the context dict.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        for check in guards:
            if not check(ctx):
                return False
        return True

    return guard

any_of

any_of(
    *guards: Callable[[dict[str, Any]], bool]
) -> Callable[[dict[str, Any]], bool]

Compound guard: at least one must pass.

Source code in src/pystator/guards.py
def any_of(
    *guards: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]:
    """Compound guard: passes as soon as one of the given guards passes.

    Guards run in the order given and evaluation stops at the first one
    that returns a truthy value. With no guards the result is False.

    Args:
        *guards: Guard callables to combine.

    Returns:
        A guard callable taking the context dict.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        for check in guards:
            if check(ctx):
                return True
        return False

    return guard

negate

negate(
    guard_fn: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]

Guard that negates the result of another guard.

Source code in src/pystator/guards.py
def negate(
    guard_fn: Callable[[dict[str, Any]], bool],
) -> Callable[[dict[str, Any]], bool]:
    """Wrap *guard_fn* in a guard that returns the opposite verdict.

    Args:
        guard_fn: The guard whose result is inverted.

    Returns:
        A guard callable returning ``not guard_fn(ctx)``.
    """

    def guard(ctx: dict[str, Any]) -> bool:
        verdict = guard_fn(ctx)
        return not verdict

    return guard