Source code for noob.input
import re
import warnings
from collections import ChainMap, defaultdict
from enum import StrEnum
from functools import cached_property
from typing import Any, ClassVar, overload
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from noob.edge import Edge
from noob.exceptions import ExtraInputWarning, InputMissingError
from noob.types import AbsoluteIdentifier, PythonIdentifier
from noob.yaml import id_optional_json_schema
[docs]
class InputScope(StrEnum):
"""The scope that input must be provided in"""
tube = "tube"
process = "process"
[docs]
class InputSpecification(BaseModel):
    """
    Declaration of one input accepted by a noob tube.

    The ``scope`` field decides when the value has to be handed over:

    - ``tube``: once, when the tube is instantiated
    - ``process``: on every call to :meth:`.TubeRunner.process`

    A ``tube``-scoped input can be referenced from a node's ``param``
    specification; a ``process``-scoped input can appear among a node's
    ``depends``.

    Lookups fall back from narrower to broader scopes: an input declared with
    ``process`` scope that is not passed to ``process`` may still be satisfied
    by a value given when the tube was instantiated.
    """

    # unknown keys in a specification are rejected rather than ignored
    model_config = ConfigDict(extra="forbid")

    # name the input is referred to by
    id: PythonIdentifier
    # absolute identifier of the expected type; written as ``type`` in specs
    type_: AbsoluteIdentifier = Field(..., alias="type")
    # when the input must be supplied (defaults to tube instantiation)
    scope: InputScope = InputScope.tube
    # optional fallback value; an explicit ``None`` counts as "set"
    default: Any | None = None
    description: str | None = None
    """An optional description of the input value"""

    __get_pydantic_json_schema__ = classmethod(id_optional_json_schema)  # type: ignore[var-annotated]
[docs]
class InputCollection(BaseModel):
    """
    A collection of input specifications used during runtime, split by scope,
    to validate presence of and to combine inputs.

    Values are looked up in layers: input passed for a single call first,
    then stored tube-scoped input, then any defaults declared in the specs.
    """

    # matches a whole-string reference of the form ``input.<key>`` in node params
    INPUT_PATTERN: ClassVar[re.Pattern] = re.compile(r"input\.(?P<key>.*)")
    # input specifications, grouped by the scope they must be provided in
    specs: dict[InputScope, dict[PythonIdentifier, InputSpecification]] = Field(
        default_factory=lambda: defaultdict(dict)  # type: ignore[arg-type]
    )

    # store long-lived scope inputs
    _input: dict[InputScope, dict] = PrivateAttr(default_factory=lambda: defaultdict(dict))
    # lazily built lookup chain; reset to None whenever stored input changes
    _chain: ChainMap | None = None

    @property
    def chain(self) -> ChainMap:
        """
        Make a chainmap of inputs at different scopes
        (for possible expansion of number of scopes, to e.g. a runner scope).

        Stored tube-scoped input is consulted before spec defaults.
        """
        if self._chain is None:
            self._chain = ChainMap(self._input[InputScope.tube], self.defaults)
        return self._chain

    @cached_property
    def defaults(self) -> dict:
        """
        A dict of any of the defaults set in the input specs, across all scopes.

        Only defaults that were explicitly set (including an explicit ``None``)
        are included. Specs that leave ``default`` unset contribute nothing.
        When the same key appears in several scopes, the later scope wins.
        """
        defaults: dict[str, Any] = {}
        for scope_specs in self.specs.values():
            for key, spec in scope_specs.items():
                # ``model_fields_set`` distinguishes an explicit ``default: None``
                # from an unset default, and returns the raw value rather than
                # a serialized (model_dump-ed) copy of it
                if "default" in spec.model_fields_set:
                    defaults[key] = spec.default
        return defaults

    def get(self, key: str, input: dict | None = None) -> Any:
        """
        Get a value from the inputs at any scope, if present.

        Args:
            key: Name of the input to look up.
            input: Optional extra input (e.g. for a single ``process`` call) that
                takes precedence over stored tube input and defaults.

        Raises:
            KeyError: If ``key`` is not present at any scope.
        """
        if input is None:
            input = {}
        return self.chain.new_child(input)[key]

    @overload
    def get_node_params(self, params: dict) -> dict: ...

    @overload
    def get_node_params(self, params: list) -> list: ...

    def get_node_params(self, params: dict | list) -> dict | list:
        """
        Get tube-scoped params specified as inputs needed when instantiating a node.

        Every string value that fully matches ``input.<key>`` is replaced by the
        value of ``<key>`` (see :meth:`get`). Nested lists and dicts are resolved
        recursively. A new container is returned, so the caller's ``params`` (e.g.
        a node specification shared between tubes) is never modified.

        Raises:
            InputMissingError: If a referenced input is not present at any scope.
        """
        if isinstance(params, list):
            return [self._resolve_param(value) for value in params]
        return {key: self._resolve_param(value) for key, value in params.items()}

    def _resolve_param(self, value: Any) -> Any:
        """Resolve a single param value: recurse into containers, substitute input refs."""
        if isinstance(value, (list, dict)):
            return self.get_node_params(value)
        if not isinstance(value, str):
            return value
        match = self.INPUT_PATTERN.fullmatch(value)
        if match is None:
            return value
        input_key = match.group("key")
        try:
            return self.get(input_key)
        except KeyError as e:
            raise InputMissingError(
                f"Node params requested {input_key}, but not present in input"
            ) from e

    def collect(self, edges: list[Edge], input: dict) -> dict:
        """
        Gather the values a node receives from ``input`` edges.

        Args:
            edges: Edges into the node. Only those whose ``source_node`` is
                ``"input"`` are considered.
            input: Input supplied for this call. It takes precedence over stored
                tube input and defaults.

        Returns:
            A dict mapping each edge's ``target_slot`` to the resolved input value.

        Raises:
            ValueError: If an input edge does not name a ``source_signal``.
            InputMissingError: If a requested input is not present at any scope.
        """
        args = {}
        for edge in edges:
            if edge.source_node != "input":
                continue
            # explicit check instead of ``assert`` so it survives ``python -O``
            if edge.source_signal is None:
                raise ValueError(
                    "Must specify an input signal (there is no generic 'value' signal for inputs)"
                )
            try:
                args[edge.target_slot] = self.get(edge.source_signal, input)
            except KeyError as e:
                raise InputMissingError(
                    f"Node depends on input {edge.source_signal}, "
                    "but not provided in any input scope"
                ) from e
        return args

    def add_input(self, scope: InputScope, input: dict) -> None:
        """
        Add some scope's input to the input collection.

        New values are merged over previously stored input for the scope, then
        filtered and validated before being stored.

        Raises:
            ValueError: If ``scope`` is unknown, or is the ephemeral ``process`` scope.
            InputMissingError: If a required input of ``scope`` is still missing.
        """
        try:
            scope = InputScope(scope)
        except ValueError as e:
            raise ValueError(f"Unknown scope: {scope}") from e
        if scope == InputScope.process:
            raise ValueError("Can't store process-scoped input, since it is ephemeral")
        merged = {**self._input[scope], **input}
        self._input[scope] = self.validate_input(scope, merged)
        # stored input changed, so the cached lookup chain is stale
        self._chain = None

    def validate_input(self, scope: InputScope, input: dict) -> dict:
        """
        Check that the required inputs are present in one of several input dicts,
        and then filter to only specified input.

        Values may come from ``input`` itself, stored tube input, or spec defaults.

        Returns:
            ``input`` restricted to the keys specified for ``scope``. Returns ``{}``
            when ``scope`` has no specifications.

        Raises:
            InputMissingError: If a specified input of ``scope`` is not available.
        """
        if scope not in self.specs:
            # no input specs for this scope
            if input:
                warnings.warn(
                    f"Ignoring extra input for a scope `{scope}` "
                    "without any input specifications.",
                    ExtraInputWarning,
                    stacklevel=3,
                )
            return {}
        filtered = self.filter_input(scope, input)
        available = self.chain.new_child(filtered)
        for spec in self.specs[scope].values():
            if spec.id not in available:
                raise InputMissingError(
                    f"Requested input {spec.id} not present in inputs at scope {scope}"
                )
        return filtered

    def filter_input(self, scope: InputScope, input: dict) -> dict:
        """
        Filter input to only specified keys, emitting an ExtraInput warning if found.

        Returns:
            The subset of ``input`` whose keys are specified for ``scope``.
            Returns ``{}`` when ``scope`` has no specifications. In that case a
            warning is emitted only if there was actually input to ignore.
        """
        if scope not in self.specs:
            if input:
                warnings.warn(
                    f"Ignoring extra input for a scope `{scope}` "
                    "without any input specifications.",
                    ExtraInputWarning,
                    stacklevel=3,
                )
            return {}
        scope_specs = self.specs[scope]
        filtered = {k: v for k, v in input.items() if k in scope_specs}
        if len(input) > len(filtered):
            extra = set(input.keys()) - set(filtered.keys())
            warnings.warn(
                f"Ignoring extra input without a specification provided to scope "
                f"`{scope}`: {extra}",
                ExtraInputWarning,
                stacklevel=3,
            )
        return filtered