diff --git a/.github/workflows/docs-update.yaml b/.github/workflows/docs-update.yaml index a25107c..a1edc9a 100644 --- a/.github/workflows/docs-update.yaml +++ b/.github/workflows/docs-update.yaml @@ -1,5 +1,5 @@ --- -name: Notify Documentation Update +name: Trigger Docs Update on: push: @@ -21,15 +21,14 @@ jobs: private-key: ${{ secrets.UPDATE_DOCS_PRIVATE_KEY }} owner: "${{ github.repository_owner }}" repositories: | - sdk - prod-docs + docs - name: Trigger docs repository workflow uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0 with: token: ${{ steps.app-token.outputs.token }} - repository: dreadnode/prod-docs - event-type: code-update + repository: dreadnode/docs + event-type: docs-update client-payload: | { "repository": "${{ github.repository }}", diff --git a/docs/api/chat.mdx b/docs/api/chat.mdx index 41523aa..bfe6da3 100644 --- a/docs/api/chat.mdx +++ b/docs/api/chat.mdx @@ -1263,7 +1263,7 @@ def __init__( """How to handle failures in the pipeline unless overridden in calls.""" self.caching: CacheMode | None = None """How to handle cache_control entries on messages.""" - self.task_name: str = generator.to_identifier(short=True) + self.task_name: str = f"Chat with {generator.to_identifier(short=True)}" """The name of the pipeline task, used for logging and debugging.""" self.scorers: list[Scorer[Chat]] = [] """List of dreadnode scorers to evaluate the generated chat upon completion.""" @@ -1360,7 +1360,7 @@ List of dreadnode scorers to evaluate the generated chat upon completion. ### task\_name ```python -task_name: str = to_identifier(short=True) +task_name: str = f'Chat with {to_identifier(short=True)}' ``` The name of the pipeline task, used for logging and debugging. @@ -1935,7 +1935,7 @@ def map( if callback in [c[0] for c in self.map_callbacks]: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.map_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks]) @@ -2146,8 +2146,9 @@ async def run( last: PipelineStep | None = None with dn.task_span( - name or f"pipeline - {self.task_name}", + name or self.task_name, label=name or f"pipeline_{self.task_name}", + tags=["rigging/pipeline"], attributes={"rigging.type": "chat_pipeline.run"}, ) as task: dn.log_inputs( @@ -2279,8 +2280,9 @@ async def run_batch( last: PipelineStep | None = None with dn.task_span( - name or f"pipeline - {self.task_name} (batch x{count})", + name or f"{self.task_name} (batch x{count})", label=name or f"pipeline_batch_{self.task_name}", + tags=["rigging/pipeline"], attributes={"rigging.type": "chat_pipeline.run_batch"}, ) as task: dn.log_inputs( @@ -2426,8 +2428,9 @@ async def run_many( last: PipelineStep | None = None with dn.task_span( - name or f"pipeline - {self.task_name} (x{count})", + name or f"{self.task_name} (x{count})", label=name or f"pipeline_many_{self.task_name}", + tags=["rigging/pipeline"], attributes={"rigging.type": "chat_pipeline.run_many"}, ) as task: dn.log_inputs( @@ -2968,7 +2971,7 @@ def then( if callback in [c[0] for c in self.then_callbacks]: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.then_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks]) @@ -3068,7 +3071,7 @@ def transform( for callback in callbacks: if not allow_duplicates and callback in self.transforms: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.transforms.extend(callbacks) @@ -3442,7 +3445,7 @@ def watch( for callback in callbacks: if not allow_duplicates and callback in self.watch_callbacks: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.watch_callbacks.extend(callbacks) diff --git a/docs/api/message.mdx b/docs/api/message.mdx index 87fe06b..20f9d3d 100644 --- a/docs/api/message.mdx +++ b/docs/api/message.mdx @@ -1924,6 +1924,51 @@ def replace_with_slice( ``` + + +### shorten + +```python +shorten(max_length: int, sep: str = '...') -> Message +``` + +Shortens the message content to at most max\_length characters long by removing the middle of the string + +**Parameters:** + +* **`max_length`** + (`int`) + –The maximum length of the message content. +* **`sep`** + (`str`, default: + `'...'` + ) + –The separator to use when shortening the content. + +**Returns:** + +* `Message` + –The shortened message. + + +```python +def shorten(self, max_length: int, sep: str = "...") -> "Message": + """ + Shortens the message content to at most max_length characters long by removing the middle of the string + + Args: + max_length: The maximum length of the message content. + sep: The separator to use when shortening the content. + + Returns: + The shortened message. + """ + new = self.clone() + new.content = shorten_string(new.content, max_length, sep=sep) + return new +``` + + ### strip @@ -2442,8 +2487,9 @@ Returns a string representation of the slice. ```python def __str__(self) -> str: """Returns a string representation of the slice.""" - content_preview = self.content if self._message else "[detached]" - return f"" + content = shorten_string(self.content if self._message else "[detached]", 50) + obj = self.obj.__class__.__name__ if self.obj else None + return f"MessageSlice(type='{self.type}', start={self.start}, stop={self.stop} obj={obj} content='{content}')" ``` diff --git a/docs/api/model.mdx b/docs/api/model.mdx index 83284d0..b110948 100644 --- a/docs/api/model.mdx +++ b/docs/api/model.mdx @@ -207,7 +207,13 @@ def from_text( try: model = ( - cls(**{next(iter(cls.model_fields)): unescape_xml(inner)}) + cls( + **{ + next(iter(cls.model_fields)): unescape_xml( + textwrap.dedent(inner).strip() + ) + } + ) if cls.is_simple() else cls.from_xml( cls.preprocess_with_cdata(full_text), @@ -217,7 +223,7 @@ def from_text( # If our model is relatively simple (only attributes and a single non-element field) # we should go back and update our non-element field with the extracted content. - if cls.is_simple_with_attrs(): + if not cls.is_simple() and cls.is_simple_with_attrs(): name, field = next( (name, field) for name, field in cls.model_fields.items() @@ -228,6 +234,14 @@ def from_text( unescape_xml(inner).strip(), ) + # Walk through any fields which are strings, and dedent them + + for field_name, field_info in cls.model_fields.items(): + if isinstance(field_info, XmlEntityInfo) and field_info.annotation == str: # noqa: E721 + model.__dict__[field_name] = textwrap.dedent( + model.__dict__[field_name] + ).strip() + extracted.append((model, slice_)) except Exception as e: # noqa: BLE001 extracted.append((e, slice_)) @@ -485,7 +499,7 @@ def preprocess_with_cdata(cls, content: str) -> str: needs_escaping = escape_xml(unescape_xml(content)) != content if is_basic_field and not is_already_cdata and needs_escaping: - content = f"" + content = f"" return f"<{field_name}{tag_attrs}>{content}" @@ -514,7 +528,7 @@ to_pretty_xml( skip_empty: bool = False, exclude_none: bool = False, exclude_unset: bool = False, - **kwargs: Any, + **_: Any, ) -> str ``` @@ -533,7 +547,7 @@ def to_pretty_xml( skip_empty: bool = False, exclude_none: bool = False, exclude_unset: bool = False, - **kwargs: t.Any, + **_: t.Any, ) -> str: """ Converts the model to a pretty XML string with indents and newlines. @@ -546,22 +560,7 @@ def to_pretty_xml( exclude_none=exclude_none, exclude_unset=exclude_unset, ) - tree = self._postprocess_with_cdata(tree) - - ET.indent(tree, " ") - pretty_encoded_xml = str( - ET.tostring( - tree, - short_empty_elements=False, - encoding="utf-8", - **kwargs, - ).decode(), - ) - - # Now we can go back and safely unescape the XML - # that we observe between any CDATA tags - - return unescape_cdata_tags(pretty_encoded_xml) + return self._serialize_tree_prettily(tree) ``` @@ -676,14 +675,19 @@ xml_example() -> str Returns an example XML representation of the given class. -Models should typically override this method to provide a more complex example. +This method generates a pretty-printed XML string that includes: +- Example values for each field, taken from the `example` argument +in a field constructor. +- Field descriptions as XML comments, derived from the field's +docstring or the `description` argument. -By default, this method returns a hollow XML scaffold one layer deep. +Note: This implementation is designed for models with flat structures +and does not recursively generate examples for nested models. **Returns:** * `str` - –A string containing the XML representation of the class. + –A string containing the pretty-printed XML example. ```python @@ -692,27 +696,55 @@ def xml_example(cls) -> str: """ Returns an example XML representation of the given class. - Models should typically override this method to provide a more complex example. + This method generates a pretty-printed XML string that includes: + - Example values for each field, taken from the `example` argument + in a field constructor. + - Field descriptions as XML comments, derived from the field's + docstring or the `description` argument. - By default, this method returns a hollow XML scaffold one layer deep. + Note: This implementation is designed for models with flat structures + and does not recursively generate examples for nested models. Returns: - A string containing the XML representation of the class. + A string containing the pretty-printed XML example. """ if cls.is_simple(): - return cls.xml_tags() - - schema = cls.model_json_schema() - properties = schema["properties"] - structure = {cls.__xml_tag__: dict.fromkeys(properties)} - xml_string = xmltodict.unparse( - structure, - pretty=True, - full_document=False, - indent=" ", - short_empty_elements=True, - ) - return t.cast("str", xml_string) # Bad type hints in xmltodict + field_info = next(iter(cls.model_fields.values())) + example = str(next(iter(field_info.examples or []), "")) + return f"<{cls.__xml_tag__}>{escape_xml(example)}" + + lines = [] + attribute_parts = [] + element_fields = {} + + for field_name, field_info in cls.model_fields.items(): + if ( + isinstance(field_info, XmlEntityInfo) + and field_info.location == EntityLocation.ATTRIBUTE + ): + path = field_info.path or field_name + example = str(next(iter(field_info.examples or []), "")).replace('"', """) + attribute_parts.append(f'{path}="{example}"') + else: + element_fields[field_name] = field_info + + attr_string = (" " + " ".join(attribute_parts)) if attribute_parts else "" + lines.append(f"<{cls.__xml_tag__}{attr_string}>") + + for field_name, field_info in element_fields.items(): + path = (isinstance(field_info, XmlEntityInfo) and field_info.path) or field_name + description = field_info.description + example = str(next(iter(field_info.examples or []), "")) + + if description: + lines.append(f" ") + if example: + lines.append(f" <{path}>{escape_xml(example)}") + else: + lines.append(f" <{path}/>") + + lines.append(f"") + return "\n".join(lines) ``` diff --git a/docs/api/prompt.mdx b/docs/api/prompt.mdx index bbf3ef9..68c010b 100644 --- a/docs/api/prompt.mdx +++ b/docs/api/prompt.mdx @@ -277,9 +277,10 @@ def bind( ) async def run(*args: P.args, **kwargs: P.kwargs) -> R: - name = get_qualified_name(self.func) if self.func else "" + name = get_callable_name(self.func, short=True) if self.func else "" with dn.task_span( - f"prompt - {name}", + name, + tags=["rigging/prompt"], attributes={"prompt_name": name, "rigging.type": "prompt.run"}, ): dn.log_inputs(**self._bind_args(*args, **kwargs)) @@ -380,10 +381,11 @@ def bind_many( ) async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]: - name = get_qualified_name(self.func) if self.func else "" + name = get_callable_name(self.func, short=True) if self.func else "" with dn.task_span( - f"prompt - {name} (x{count})", + f"{name} (x{count})", label=f"prompt_{name}", + tags=["rigging/prompt"], attributes={"prompt_name": name, "rigging.type": "prompt.run_many"}, ) as span: dn.log_inputs(**self._bind_args(*args, **kwargs)) @@ -659,7 +661,7 @@ def map( for callback in callbacks: if not asyncio.iscoroutinefunction(callback): raise TypeError( - f"Callback '{get_qualified_name(callback)}' must be an async function", + f"Callback '{get_callable_name(callback)}' must be an async function", ) if allow_duplicates: @@ -667,7 +669,7 @@ def map( if callback in self.map_callbacks: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.map_callbacks.extend(callbacks) @@ -1038,7 +1040,7 @@ def then( for callback in callbacks: if not asyncio.iscoroutinefunction(callback): raise TypeError( - f"Callback '{get_qualified_name(callback)}' must be an async function", + f"Callback '{get_callable_name(callback)}' must be an async function", ) if allow_duplicates: @@ -1046,7 +1048,7 @@ def then( if callback in self.then_callbacks: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.then_callbacks.extend(callbacks) @@ -1135,7 +1137,7 @@ def watch( for callback in callbacks: if not allow_duplicates and callback in self.watch_callbacks: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.watch_callbacks.extend(callbacks) diff --git a/docs/api/tools.mdx b/docs/api/tools.mdx index 3627efa..d307c08 100644 --- a/docs/api/tools.mdx +++ b/docs/api/tools.mdx @@ -147,7 +147,8 @@ async def handle_tool_call( # noqa: PLR0912 from rigging.message import ContentText, ContentTypes, Message with dn.task_span( - f"tool - {self.name}", + self.name, + tags=["rigging/tool"], attributes={"tool_name": self.name, "rigging.type": "tool"}, ) as task: dn.log_input("tool_call", tool_call) @@ -221,7 +222,9 @@ async def handle_tool_call( # noqa: PLR0912 message.content_parts = [ContentText(text=str(result))] if self.truncate: - message = message.truncate(self.truncate) + # Use shorten instead of truncate to try and preserve + # the most context possible. + message = message.shorten(self.truncate) return message, stop ``` diff --git a/docs/api/util.mdx b/docs/api/util.mdx index 9f6bd12..bf24e42 100644 --- a/docs/api/util.mdx +++ b/docs/api/util.mdx @@ -131,9 +131,7 @@ def escape_xml(xml_string: str) -> str: """ escaped = xml_string.replace(r"&", "&") escaped = escaped.replace(r"<", "<") - escaped = escaped.replace(r">", ">") - escaped = escaped.replace(r"'", "'") - return escaped.replace(r'"', """) + return escaped.replace(r">", ">") ``` @@ -234,58 +232,90 @@ def flatten_list(nested_list: t.Iterable[t.Iterable[t.Any] | t.Any]) -> list[t.A -get\_qualified\_name --------------------- +get\_callable\_name +------------------- ```python -get_qualified_name(obj: Callable[..., Any]) -> str +get_callable_name( + obj: Callable[..., Any], *, short: bool = False +) -> str ``` -Return a best guess at the qualified name of a callable object. -This includes functions, methods, and callable classes. +Return a best-effort, comprehensive name for a callable object. + +This function handles a wide variety of callables, including regular +functions, methods, lambdas, partials, wrapped functions, and callable +class instances. + +**Parameters:** + +* **`obj`** + (`Callable[..., Any]`) + –The callable object to name. +* **`short`** + (`bool`, default: + `False` + ) + –If True, returns a shorter name suitable for logs or UI, + typically omitting the module path. The class name is + retained for methods. + +**Returns:** + +* `str` + –A string representing the callable's name. ```python -def get_qualified_name(obj: t.Callable[..., t.Any]) -> str: - """ - Return a best guess at the qualified name of a callable object. - This includes functions, methods, and callable classes. +def get_callable_name(obj: t.Callable[..., t.Any], *, short: bool = False) -> str: """ - if obj is None or not callable(obj): - return "unknown" + Return a best-effort, comprehensive name for a callable object. - module = inspect.getmodule(obj) - module_name = module.__name__ if module else "" + This function handles a wide variety of callables, including regular + functions, methods, lambdas, partials, wrapped functions, and callable + class instances. + + Args: + obj: The callable object to name. + short: If True, returns a shorter name suitable for logs or UI, + typically omitting the module path. The class name is + retained for methods. + + Returns: + A string representing the callable's name. + """ + if not callable(obj): + return repr(obj) - # Partial functions if isinstance(obj, functools.partial): - base_name = get_qualified_name(obj.func) - return f"partial({base_name})" - - # Methods - if isinstance(obj, types.MethodType): - class_name = obj.__self__.__class__.__name__ - method_name = obj.__func__.__name__ - return f"{class_name}.{method_name}" - - # Functions - if isinstance(obj, types.FunctionType): - # Check if it's a wrapped function - if hasattr(obj, "__wrapped__"): - original_name = get_qualified_name(obj.__wrapped__) - return f"wrapped({original_name})" - - name = obj.__qualname__ or obj.__name__ - return f"{module_name}.{name}" if module_name != "__main__" else name - - # Callable classes - if callable(obj): - if isinstance(obj, type): - return obj.__qualname__ - return f"{obj.__class__.__qualname__}.__call__" - - # Fallback - return obj.__class__.__qualname__ + inner_name = get_callable_name(obj.func, short=short) + return f"partial({inner_name})" + + unwrapped = obj + with contextlib.suppress(Exception): + unwrapped = inspect.unwrap(obj) + + name = getattr(unwrapped, "__qualname__", None) + + if name is None: + name = getattr(unwrapped, "__name__", None) + + if name is None: + if hasattr(obj, "__class__"): + name = getattr(obj.__class__, "__qualname__", obj.__class__.__name__) + else: + return repr(obj) + + if short: + return str(name).split(".")[-1] # Return only the last part of the name + + with contextlib.suppress(Exception): + if module := inspect.getmodule(unwrapped): + module_name = module.__name__ + if module_name and module_name not in ("builtins", "__main__"): + return f"{module_name}.{name}" + + return str(name) ``` diff --git a/rigging/chat.py b/rigging/chat.py index 8f5f964..de72021 100644 --- a/rigging/chat.py +++ b/rigging/chat.py @@ -55,7 +55,7 @@ tools_to_json_transform, tools_to_json_with_tag_transform, ) -from rigging.util import flatten_list, get_qualified_name +from rigging.util import flatten_list, get_callable_name if t.TYPE_CHECKING: from dreadnode.metric import Scorer, ScorerCallable @@ -217,6 +217,20 @@ def conversation(self) -> str: conversation += f"\n\n[error]: {self.error}" return conversation + def __str__(self) -> str: + formatted = f"--- Chat {self.uuid}" + formatted += f"\n |- timestamp: {self.timestamp.isoformat()}" + if self.usage: + formatted += f"\n |- usage: {self.usage}" + if self.generator: + formatted += f"\n |- generator: {self.generator.to_identifier(short=True)}" + if self.stop_reason: + formatted += f"\n |- stop_reason: {self.stop_reason}" + if self.metadata: + formatted += f"\n |- metadata: {self.metadata}" + formatted += f"\n\n{self.conversation}\n" + return formatted + @property def message_dicts(self) -> list[MessageDict]: """Returns the chat as a minimal message dictionaries.""" @@ -732,7 +746,7 @@ def with_parent(self, parent: "PipelineStep") -> "PipelineStep": raise RuntimeError("Unable to set parent step") def __str__(self) -> str: - callback_name = get_qualified_name(self.callback) if self.callback else "None" + callback_name = get_callable_name(self.callback) if self.callback else "None" self_str = f"PipelineStep(pipeline={id(self.pipeline)}, state={self.state}, chats={len(self.chats)}, callback={callback_name})" if self.parent is not None: self_str += f" <- {self.parent!s}" @@ -762,13 +776,16 @@ def depth(self) -> int: def _wrap_watch_callback(callback: WatchChatCallback) -> WatchChatCallback: import dreadnode as dn - callback_name = get_qualified_name(callback) - return dn.task( - name=f"watch - {callback_name}", - attributes={"rigging.type": "chat_pipeline.watch_callback"}, - log_inputs=True, - log_output=False, - )(callback) + callback_name = get_callable_name(callback) + + async def wrapped_callback(chats: list[Chat]) -> None: + with dn.span( + name=callback_name, + attributes={"rigging.type": "chat_pipeline.watch_callback"}, + ): + await callback(chats) + + return wrapped_callback # Pipeline @@ -803,7 +820,7 @@ def __init__( """How to handle failures in the pipeline unless overridden in calls.""" self.caching: CacheMode | None = None """How to handle cache_control entries on messages.""" - self.task_name: str = generator.to_identifier(short=True) + self.task_name: str = f"Chat with {generator.to_identifier(short=True)}" """The name of the pipeline task, used for logging and debugging.""" self.scorers: list[Scorer[Chat]] = [] """List of dreadnode scorers to evaluate the generated chat upon completion.""" @@ -897,7 +914,7 @@ async def log(chats: list[Chat]) -> None: for callback in callbacks: if not allow_duplicates and callback in self.watch_callbacks: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.watch_callbacks.extend(callbacks) @@ -1120,7 +1137,7 @@ async def process(chat: Chat) -> Chat | None: if callback in [c[0] for c in self.then_callbacks]: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.then_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks]) @@ -1164,7 +1181,7 @@ async def process(chats: list[Chat]) -> list[Chat]: if callback in [c[0] for c in self.map_callbacks]: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.map_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks]) @@ -1208,7 +1225,7 @@ async def post_transform(chat: Chat) -> Chat | None: for callback in callbacks: if not allow_duplicates and callback in self.transforms: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.transforms.extend(callbacks) @@ -1576,8 +1593,6 @@ async def _process_then_callback( callback: ThenChatCallback, state: CallbackState, ) -> None: - callback_name = get_qualified_name(callback) - async def complete() -> None: state.completed = True state.ready_event.set() @@ -1598,7 +1613,7 @@ async def complete() -> None: if not inspect.isasyncgen(result): raise TypeError( - f"Callback '{callback_name}' must return a Chat, PipelineStepGenerator, or None", + f"Callback '{get_callable_name(callback)}' must return a Chat, PipelineStepGenerator, or None", ) generator = t.cast( @@ -1702,14 +1717,14 @@ async def _step( # noqa: PLR0915, PLR0912 try: messages = apply_cache_mode_to_messages(self.caching, messages) - with dn.task_span( + with dn.span( f"generate - {self.generator.to_identifier(short=True)}", attributes={"rigging.type": "chat_pipeline.generate"}, ): - dn.log_input("messages", messages) - dn.log_input("params", params) + # dn.log_input("messages", messages) + # dn.log_input("params", params) generated = await self.generator.generate_messages(messages, params) - dn.log_output("generated", generated) + # dn.log_output("generated", generated) # If we got a total failure here for generation as a whole, # we can't distinguish between incoming messages in terms @@ -1812,7 +1827,7 @@ async def _step( # noqa: PLR0915, PLR0912 # Then callbacks for then_callback, max_depth, as_task in self.then_callbacks: - callback_name = get_qualified_name(then_callback) + callback_name = get_callable_name(then_callback, short=True) states = [ self.CallbackState( @@ -1825,7 +1840,7 @@ async def _step( # noqa: PLR0915, PLR0912 callback_task = ( dn.task( - name=f"then - {callback_name}", + name=callback_name, attributes={"rigging.type": "chat_pipeline.then_callback"}, log_inputs=True, log_output=True, @@ -1904,11 +1919,11 @@ async def _step( # noqa: PLR0915, PLR0912 # Map callbacks for map_callback, max_depth, as_task in self.map_callbacks: - callback_name = get_qualified_name(map_callback) + callback_name = get_callable_name(map_callback, short=True) map_task = ( dn.task( - name=f"map - {callback_name}", + name=callback_name, attributes={"rigging.type": "chat_pipeline.map_callback"}, log_inputs=True, log_output=True, @@ -2056,8 +2071,9 @@ async def run( last: PipelineStep | None = None with dn.task_span( - name or f"pipeline - {self.task_name}", + name or self.task_name, label=name or f"pipeline_{self.task_name}", + tags=["rigging/pipeline"], attributes={"rigging.type": "chat_pipeline.run"}, ) as task: dn.log_inputs( @@ -2156,8 +2172,9 @@ async def run_many( last: PipelineStep | None = None with dn.task_span( - name or f"pipeline - {self.task_name} (x{count})", + name or f"{self.task_name} (x{count})", label=name or f"pipeline_many_{self.task_name}", + tags=["rigging/pipeline"], attributes={"rigging.type": "chat_pipeline.run_many"}, ) as task: dn.log_inputs( @@ -2316,8 +2333,9 @@ async def run_batch( last: PipelineStep | None = None with dn.task_span( - name or f"pipeline - {self.task_name} (batch x{count})", + name or f"{self.task_name} (batch x{count})", label=name or f"pipeline_batch_{self.task_name}", + tags=["rigging/pipeline"], attributes={"rigging.type": "chat_pipeline.run_batch"}, ) as task: dn.log_inputs( diff --git a/rigging/completion.py b/rigging/completion.py index 2f0826b..ee8c96e 100644 --- a/rigging/completion.py +++ b/rigging/completion.py @@ -18,7 +18,7 @@ from rigging.generator import GenerateParams, Generator, get_generator from rigging.generator.base import GeneratedText, StopReason, Usage from rigging.parsing import parse_many -from rigging.util import get_qualified_name +from rigging.util import get_callable_name if t.TYPE_CHECKING: from dreadnode import Span @@ -577,13 +577,16 @@ async def _watch_callback(self, completions: list[Completion]) -> None: def wrap_watch_callback( callback: WatchCompletionCallback, ) -> t.Callable[[list[Completion]], t.Awaitable[None]]: - callback_name = get_qualified_name(callback) - return dn.task( - name=f"watch - {callback_name}", - attributes={"rigging.type": "completion_pipeline.watch_callback"}, - log_inputs=True, - log_output=False, - )(callback) + callback_name = get_callable_name(callback, short=True) + + async def wrapped_callback(completions: list[Completion]) -> None: + with dn.span( + name=callback_name, + attributes={"rigging.type": "chat_pipeline.watch_callback"}, + ): + await callback(completions) + + return wrapped_callback traced_callbacks = [wrap_watch_callback(callback) for callback in self.watch_callbacks] coros = [callback(completions) for callback in traced_callbacks] @@ -636,9 +639,9 @@ async def _post_run( # previous calls being ran. for map_callback in self.map_callbacks: - callback_name = get_qualified_name(map_callback) + callback_name = get_callable_name(map_callback) traced_map_callback = dn.task( - name=f"map - {callback_name}", + name=callback_name, attributes={"rigging.type": "completion_pipeline.map_callback"}, log_inputs=True, log_output=True, @@ -650,9 +653,9 @@ async def _post_run( ) for then_callback in self.then_callbacks: - callback_name = get_qualified_name(then_callback) + callback_name = get_callable_name(then_callback, short=True) traced_then_callback = dn.task( - name=f"then - {callback_name}", + name=callback_name, attributes={"rigging.type": "completion_pipeline.then_callback"}, log_inputs=True, log_output=True, diff --git a/rigging/generator/base.py b/rigging/generator/base.py index 336b15d..4233b3f 100644 --- a/rigging/generator/base.py +++ b/rigging/generator/base.py @@ -310,6 +310,9 @@ def __add__(self, other: "Usage") -> "Usage": total_tokens=self.total_tokens + other.total_tokens, ) + def __str__(self) -> str: + return f"in: {self.input_tokens} | out: {self.output_tokens} | total: {self.total_tokens}" + GeneratedT = t.TypeVar("GeneratedT", Message, str) diff --git a/rigging/message.py b/rigging/message.py index 6741817..03d31d6 100644 --- a/rigging/message.py +++ b/rigging/message.py @@ -28,7 +28,12 @@ from rigging.model import Model, ModelT from rigging.parsing import try_parse_many from rigging.tools.base import ToolCall -from rigging.util import AudioFormat, identify_audio_format, shorten_string, truncate_string +from rigging.util import ( + AudioFormat, + identify_audio_format, + shorten_string, + truncate_string, +) Role = t.Literal["system", "user", "assistant", "tool"] """The role of a message. Can be 'system', 'user', 'assistant', or 'tool'.""" @@ -111,8 +116,9 @@ def __len__(self) -> int: def __str__(self) -> str: """Returns a string representation of the slice.""" - content_preview = self.content if self._message else "[detached]" - return f"" + content = shorten_string(self.content if self._message else "[detached]", 50) + obj = self.obj.__class__.__name__ if self.obj else None + return f"MessageSlice(type='{self.type}', start={self.start}, stop={self.stop} obj={obj} content='{content}')" def clone(self) -> "MessageSlice": """ @@ -164,7 +170,7 @@ class ImageUrl(BaseModel): """Cache control entry for prompt caching.""" def __str__(self) -> str: - return f"" + return f"ContentImageUrl(url='{shorten_string(self.image_url.url, 50)}')" @classmethod def from_file( @@ -1124,6 +1130,21 @@ def truncate(self, max_length: int, suffix: str = "\n[truncated]") -> "Message": new.content = truncate_string(new.content, max_length, suf=suffix) return new + def shorten(self, max_length: int, sep: str = "...") -> "Message": + """ + Shortens the message content to at most max_length characters long by removing the middle of the string + + Args: + max_length: The maximum length of the message content. + sep: The separator to use when shortening the content. + + Returns: + The shortened message. + """ + new = self.clone() + new.content = shorten_string(new.content, max_length, sep=sep) + return new + @property def models(self) -> list[Model]: """ diff --git a/rigging/model.py b/rigging/model.py index cd07488..bb145cf 100644 --- a/rigging/model.py +++ b/rigging/model.py @@ -5,14 +5,15 @@ import dataclasses import inspect import re +import textwrap import typing as t from xml.etree import ElementTree as ET # nosec import typing_extensions as te -import xmltodict # type: ignore [import-untyped] from pydantic import ( BaseModel, BeforeValidator, + ConfigDict, Field, SerializationInfo, ValidationError, @@ -68,6 +69,8 @@ def __get__(self, _: t.Any, owner: t.Any) -> str: class Model(BaseXmlModel): + model_config = ConfigDict(use_attribute_docstrings=True) + def __init_subclass__( cls, tag: str | None = None, @@ -135,16 +138,66 @@ def _postprocess_with_cdata(self, tree: ET.Element) -> ET.Element: return tree - # to_xml() doesn't prettify normally, and extended - # requirements like lxml seemed like poor form for - # just this feature + def _serialize_tree_prettily( + self, element: ET.Element, level: int = 0, indent_str: str = " " + ) -> str: + # Essentially a custom work of ET.indent to better + # handle multi-line text so we get: + # + # + # Some text + # More text + # + # + # instead of: + # + # Some text + # More text + + indent = indent_str * level + lines = [] + + attrs = "".join(f' {k}="{escape_xml(v)}"' for k, v in element.attrib.items()) + tag_with_attrs = f"{element.tag}{attrs}" + + text = element.text and element.text.strip() + multiline_text = text and "\n" in text + has_children = len(element) > 0 + + if not has_children and not multiline_text: + if level == 0 and not text: + lines.append(f"<{tag_with_attrs}>") + elif not text: + lines.append(f"{indent}<{tag_with_attrs} />") + else: + lines.append(f"{indent}<{tag_with_attrs}>{text}") + return "\n".join(lines) + + lines.append(f"{indent}<{tag_with_attrs}>") + + if text and multiline_text: + dedented_text = textwrap.dedent(text).strip() + content_indent = indent_str * (level + 1) + for line in dedented_text.split("\n"): + lines.append(f"{content_indent}{line}" if line.strip() else "") # noqa: PERF401 + + for child in element: + lines.append(self._serialize_tree_prettily(child, level + 1, indent_str)) # noqa: PERF401 + + lines.append(f"{indent}") + + return "\n".join(lines) + + # to_xml() doesn't prettify normally, and extended requirements + # like lxml seemed like poor form for just this feature + def to_pretty_xml( self, *, skip_empty: bool = False, exclude_none: bool = False, exclude_unset: bool = False, - **kwargs: t.Any, + **_: t.Any, ) -> str: """ Converts the model to a pretty XML string with indents and newlines. @@ -157,22 +210,7 @@ def to_pretty_xml( exclude_none=exclude_none, exclude_unset=exclude_unset, ) - tree = self._postprocess_with_cdata(tree) - - ET.indent(tree, " ") - pretty_encoded_xml = str( - ET.tostring( - tree, - short_empty_elements=False, - encoding="utf-8", - **kwargs, - ).decode(), - ) - - # Now we can go back and safely unescape the XML - # that we observe between any CDATA tags - - return unescape_cdata_tags(pretty_encoded_xml) + return self._serialize_tree_prettily(tree) def to_xml( self, @@ -293,27 +331,55 @@ def xml_example(cls) -> str: """ Returns an example XML representation of the given class. - Models should typically override this method to provide a more complex example. + This method generates a pretty-printed XML string that includes: + - Example values for each field, taken from the `example` argument + in a field constructor. + - Field descriptions as XML comments, derived from the field's + docstring or the `description` argument. - By default, this method returns a hollow XML scaffold one layer deep. + Note: This implementation is designed for models with flat structures + and does not recursively generate examples for nested models. Returns: - A string containing the XML representation of the class. + A string containing the pretty-printed XML example. """ if cls.is_simple(): - return cls.xml_tags() - - schema = cls.model_json_schema() - properties = schema["properties"] - structure = {cls.__xml_tag__: dict.fromkeys(properties)} - xml_string = xmltodict.unparse( - structure, - pretty=True, - full_document=False, - indent=" ", - short_empty_elements=True, - ) - return t.cast("str", xml_string) # Bad type hints in xmltodict + field_info = next(iter(cls.model_fields.values())) + example = str(next(iter(field_info.examples or []), "")) + return f"<{cls.__xml_tag__}>{escape_xml(example)}" + + lines = [] + attribute_parts = [] + element_fields = {} + + for field_name, field_info in cls.model_fields.items(): + if ( + isinstance(field_info, XmlEntityInfo) + and field_info.location == EntityLocation.ATTRIBUTE + ): + path = field_info.path or field_name + example = str(next(iter(field_info.examples or []), "")).replace('"', """) + attribute_parts.append(f'{path}="{example}"') + else: + element_fields[field_name] = field_info + + attr_string = (" " + " ".join(attribute_parts)) if attribute_parts else "" + lines.append(f"<{cls.__xml_tag__}{attr_string}>") + + for field_name, field_info in element_fields.items(): + path = (isinstance(field_info, XmlEntityInfo) and field_info.path) or field_name + description = field_info.description + example = str(next(iter(field_info.examples or []), "")) + + if description: + lines.append(f" ") + if example: + lines.append(f" <{path}>{escape_xml(example)}") + else: + lines.append(f" <{path}/>") + + lines.append(f"") + return "\n".join(lines) @classmethod def ensure_valid(cls) -> None: @@ -383,7 +449,7 @@ def wrap_with_cdata(match: re.Match[str]) -> str: needs_escaping = escape_xml(unescape_xml(content)) != content if is_basic_field and not is_already_cdata and needs_escaping: - content = f"" + content = f"" return f"<{field_name}{tag_attrs}>{content}" @@ -504,7 +570,13 @@ def from_text( try: model = ( - cls(**{next(iter(cls.model_fields)): unescape_xml(inner)}) + cls( + **{ + next(iter(cls.model_fields)): unescape_xml( + textwrap.dedent(inner).strip() + ) + } + ) if cls.is_simple() else cls.from_xml( cls.preprocess_with_cdata(full_text), @@ -514,7 +586,7 @@ def from_text( # If our model is relatively simple (only attributes and a single non-element field) # we should go back and update our non-element field with the extracted content. - if cls.is_simple_with_attrs(): + if not cls.is_simple() and cls.is_simple_with_attrs(): name, field = next( (name, field) for name, field in cls.model_fields.items() @@ -525,6 +597,14 @@ def from_text( unescape_xml(inner).strip(), ) + # Walk through any fields which are strings, and dedent them + + for field_name, field_info in cls.model_fields.items(): + if isinstance(field_info, XmlEntityInfo) and field_info.annotation == str: # noqa: E721 + model.__dict__[field_name] = textwrap.dedent( + model.__dict__[field_name] + ).strip() + extracted.append((model, slice_)) except Exception as e: # noqa: BLE001 extracted.append((e, slice_)) diff --git a/rigging/prompt.py b/rigging/prompt.py index 601be2f..5942435 100644 --- a/rigging/prompt.py +++ b/rigging/prompt.py @@ -25,7 +25,7 @@ from rigging.message import Message from rigging.model import Model, SystemErrorModel, ValidationErrorModel, make_primitive from rigging.tools import Tool -from rigging.util import escape_xml, get_qualified_name, to_snake, to_xml_tag +from rigging.util import escape_xml, get_callable_name, to_snake, to_xml_tag DEFAULT_DOC = "Convert the following inputs to outputs ({func_name})." """Default docstring if none is provided to a prompt function.""" @@ -563,10 +563,10 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None: # A bit weird, but we need from_chat to properly handle # wrapping Chat output types inside lists/dataclasses with dn.task_span( - f"prompt parse - {self.output.tag}", + f"parse - {self.output.tag}", attributes={"rigging.type": "prompt.parse"}, ): - dn.log_input("message", chat.last) + dn.log_input("message", str(chat.last)) self.output.from_chat(chat) except ValidationError as e: next_pipeline.add( @@ -642,7 +642,7 @@ async def summarize(text: str) -> str: for callback in callbacks: if not allow_duplicates and callback in self.watch_callbacks: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.watch_callbacks.extend(callbacks) @@ -680,7 +680,7 @@ async def summarize(text: str) -> str: for callback in callbacks: if not asyncio.iscoroutinefunction(callback): raise TypeError( - f"Callback '{get_qualified_name(callback)}' must be an async function", + f"Callback '{get_callable_name(callback)}' must be an async function", ) if allow_duplicates: @@ -688,7 +688,7 @@ async def summarize(text: str) -> str: if callback in self.then_callbacks: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.then_callbacks.extend(callbacks) @@ -726,7 +726,7 @@ async def summarize(text: str) -> str: for callback in callbacks: if not asyncio.iscoroutinefunction(callback): raise TypeError( - f"Callback '{get_qualified_name(callback)}' must be an async function", + f"Callback '{get_callable_name(callback)}' must be an async function", ) if allow_duplicates: @@ -734,7 +734,7 @@ async def summarize(text: str) -> str: if callback in self.map_callbacks: raise ValueError( - f"Callback '{get_qualified_name(callback)}' is already registered.", + f"Callback '{get_callable_name(callback)}' is already registered.", ) self.map_callbacks.extend(callbacks) @@ -850,9 +850,10 @@ def say_hello(name: str) -> str: ) async def run(*args: P.args, **kwargs: P.kwargs) -> R: - name = get_qualified_name(self.func) if self.func else "" + name = get_callable_name(self.func, short=True) if self.func else "" with dn.task_span( - f"prompt - {name}", + name, + tags=["rigging/prompt"], attributes={"prompt_name": name, "rigging.type": "prompt.run"}, ): dn.log_inputs(**self._bind_args(*args, **kwargs)) @@ -913,10 +914,11 @@ def say_hello(name: str) -> str: ) async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]: - name = get_qualified_name(self.func) if self.func else "" + name = get_callable_name(self.func, short=True) if self.func else "" with dn.task_span( - f"prompt - {name} (x{count})", + f"{name} (x{count})", label=f"prompt_{name}", + tags=["rigging/prompt"], attributes={"prompt_name": name, "rigging.type": "prompt.run_many"}, ) as span: dn.log_inputs(**self._bind_args(*args, **kwargs)) diff --git a/rigging/tools/base.py b/rigging/tools/base.py index e6d0ac6..c428df7 100644 --- a/rigging/tools/base.py +++ b/rigging/tools/base.py @@ -30,7 +30,7 @@ make_from_schema, make_from_signature, ) -from rigging.util import deref_json +from rigging.util import deref_json, shorten_string if t.TYPE_CHECKING: from rigging.message import Message @@ -116,7 +116,8 @@ class ToolCall(BaseModel): function: FunctionCall def __str__(self) -> str: - return f"" + arguments = shorten_string(self.function.arguments, max_length=50) + return f"ToolCall({self.function.name}({arguments}), id='{self.id}')" @property def name(self) -> str: @@ -359,7 +360,8 @@ async def handle_tool_call( # noqa: PLR0912 from rigging.message import ContentText, ContentTypes, Message with dn.task_span( - f"tool - {self.name}", + self.name, + tags=["rigging/tool"], attributes={"tool_name": self.name, "rigging.type": "tool"}, ) as task: dn.log_input("tool_call", tool_call) @@ -433,7 +435,9 @@ async def handle_tool_call( # noqa: PLR0912 message.content_parts = [ContentText(text=str(result))] if self.truncate: - message = message.truncate(self.truncate) + # Use shorten instead of truncate to try and preserve + # the most context possible. + message = message.shorten(self.truncate) return message, stop diff --git a/rigging/transform/json_tools.py b/rigging/transform/json_tools.py index 459978f..a95dabc 100644 --- a/rigging/transform/json_tools.py +++ b/rigging/transform/json_tools.py @@ -16,7 +16,7 @@ from rigging.model import Model from rigging.tools.base import FunctionCall, ToolCall, ToolDefinition, ToolResponse from rigging.transform.base import PostTransform, Transform -from rigging.util import extract_json_objects +from rigging.util import extract_json_objects, shorten_string if t.TYPE_CHECKING: from rigging.chat import Chat @@ -54,7 +54,8 @@ class JsonInXmlToolCall(Model): parameters: str def __str__(self) -> str: - return f"" + parameters = shorten_string(self.parameters, max_length=50) + return f"JsonInXmlToolCall(name='{self.name}', parameters='{parameters}', id='{self.id}')" class JsonToolCall(Model): @@ -62,7 +63,8 @@ class JsonToolCall(Model): content: str def __str__(self) -> str: - return f"" + content = shorten_string(self.content, max_length=50) + return f"JsonToolCall(content='{content}', id='{self.id}')" # Prompts diff --git a/rigging/transform/xml_tools.py b/rigging/transform/xml_tools.py index 1baf998..0f43d28 100644 --- a/rigging/transform/xml_tools.py +++ b/rigging/transform/xml_tools.py @@ -18,6 +18,7 @@ from rigging.model import Model from rigging.tools.base import FunctionCall, Tool, ToolCall, ToolResponse from rigging.transform.base import PostTransform, Transform +from rigging.util import shorten_string if t.TYPE_CHECKING: from rigging.chat import Chat @@ -114,7 +115,8 @@ class XmlToolCall(Model, tag=TOOL_CALL_TAG): parameters: str def __str__(self) -> str: - return f"" + parameters = shorten_string(self.parameters, max_length=50) + return f"XmlToolCall(name='{self.name}' parameters='{parameters}')" XML_TOOLS_PREFIX = f"""\ diff --git a/rigging/util.py b/rigging/util.py index cfb6630..7f0cce4 100644 --- a/rigging/util.py +++ b/rigging/util.py @@ -3,10 +3,10 @@ """ import asyncio +import contextlib import functools import inspect import re -import types import typing as t from json import JSONDecoder from threading import Thread @@ -145,9 +145,7 @@ def escape_xml(xml_string: str) -> str: """ escaped = xml_string.replace(r"&", "&") escaped = escaped.replace(r"<", "<") - escaped = escaped.replace(r">", ">") - escaped = escaped.replace(r"'", "'") - return escaped.replace(r'"', """) + return escaped.replace(r">", ">") def unescape_cdata_tags(xml_string: str) -> str: @@ -183,46 +181,55 @@ def to_xml_tag(text: str) -> str: # Name resolution -def get_qualified_name(obj: t.Callable[..., t.Any]) -> str: +def get_callable_name(obj: t.Callable[..., t.Any], *, short: bool = False) -> str: """ - Return a best guess at the qualified name of a callable object. - This includes functions, methods, and callable classes. - """ - if obj is None or not callable(obj): - return "unknown" + Return a best-effort, comprehensive name for a callable object. + + This function handles a wide variety of callables, including regular + functions, methods, lambdas, partials, wrapped functions, and callable + class instances. - module = inspect.getmodule(obj) - module_name = module.__name__ if module else "" + Args: + obj: The callable object to name. + short: If True, returns a shorter name suitable for logs or UI, + typically omitting the module path. The class name is + retained for methods. + + Returns: + A string representing the callable's name. + """ + if not callable(obj): + return repr(obj) - # Partial functions if isinstance(obj, functools.partial): - base_name = get_qualified_name(obj.func) - return f"partial({base_name})" - - # Methods - if isinstance(obj, types.MethodType): - class_name = obj.__self__.__class__.__name__ - method_name = obj.__func__.__name__ - return f"{class_name}.{method_name}" - - # Functions - if isinstance(obj, types.FunctionType): - # Check if it's a wrapped function - if hasattr(obj, "__wrapped__"): - original_name = get_qualified_name(obj.__wrapped__) - return f"wrapped({original_name})" - - name = obj.__qualname__ or obj.__name__ - return f"{module_name}.{name}" if module_name != "__main__" else name - - # Callable classes - if callable(obj): - if isinstance(obj, type): - return obj.__qualname__ - return f"{obj.__class__.__qualname__}.__call__" - - # Fallback - return obj.__class__.__qualname__ + inner_name = get_callable_name(obj.func, short=short) + return f"partial({inner_name})" + + unwrapped = obj + with contextlib.suppress(Exception): + unwrapped = inspect.unwrap(obj) + + name = getattr(unwrapped, "__qualname__", None) + + if name is None: + name = getattr(unwrapped, "__name__", None) + + if name is None: + if hasattr(obj, "__class__"): + name = getattr(obj.__class__, "__qualname__", obj.__class__.__name__) + else: + return repr(obj) + + if short: + return str(name).split(".")[-1] # Return only the last part of the name + + with contextlib.suppress(Exception): + if module := inspect.getmodule(unwrapped): + module_name = module.__name__ + if module_name and module_name not in ("builtins", "__main__"): + return f"{module_name}.{name}" + + return str(name) # Formatting diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..a3a8385 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,161 @@ +import typing as t +from textwrap import dedent +from xml.sax.saxutils import escape + +import pytest + +from rigging.model import Model, attr, element + +# mypy: disable-error-code=empty-body +# ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001 + + +class SimpleModel(Model): + """A simple model to test basic element generation.""" + + content: str = element(examples=["Hello, World!"]) + """The main content of the model.""" + + +class NoExampleModel(Model): + """A model to test fallback to an empty element when no example is given.""" + + name: str + """The name of the entity.""" + + +class AttrAndElementModel(Model): + """Tests a model with both an attribute and a child element.""" + + id: int = attr(examples=[123]) + """The unique identifier (attribute).""" + value: str = element(examples=["Some value"]) + """The primary value (element).""" + + +class DocstringDescriptionModel(Model): + """Tests that field docstrings are correctly used as descriptions.""" + + field1: str = element(examples=["val1"]) + """This is the description for field1.""" + field2: bool = element(examples=[True]) + """This is the description for field2.""" + + +class ParameterDescriptionModel(Model): + """Tests that the `description` parameter overrides a field's docstring.""" + + param: str = element( + examples=["override"], description="This description is from the `description` parameter." + ) + """This docstring should be ignored in the XML example.""" + + +class SpecialCharsModel(Model): + """Tests proper escaping of special XML characters in examples and comments.""" + + comment: str = element(examples=["ok"]) + """This comment contains < and > & special characters.""" + data: str = element(examples=["&'"]) + """This element's example contains special XML characters.""" + + +# This class definition is based on the one you provided in the prompt. +class Analysis(Model, tag="analysis"): + """A model to validate the exact output requested in the prompt.""" + + priority: t.Literal["low", "medium", "high", "critical"] = element(examples=["medium"]) + """Triage priority for human follow-up.""" + tags: str = element("tags", examples=["admin panel, error message, legacy"]) + """A list of specific areas within the screenshot that are noteworthy or require further examination.""" + summary: str = element() + """A markdown summary explaining *why* the screenshot is interesting and what a human should investigate next.""" + + +@pytest.mark.parametrize( + ("model_cls", "expected_xml"), + [ + pytest.param( + SimpleModel, + """ + + + Hello, World! + + """, + id="simple_model", + ), + pytest.param( + NoExampleModel, + """ + + """, + id="model_with_no_example", + ), + pytest.param( + AttrAndElementModel, + """ + + + Some value + + """, + id="model_with_attribute_and_element", + ), + pytest.param( + DocstringDescriptionModel, + """ + + + val1 + + True + + """, + id="descriptions_from_docstrings", + ), + pytest.param( + ParameterDescriptionModel, + """ + + + override + + """, + id="description_from_parameter_overrides_docstring", + ), + pytest.param( + SpecialCharsModel, + f""" + + + ok + + {escape("&'")} + + """, + id="escaping_of_special_characters", + ), + pytest.param( + Analysis, + """ + + + medium + + admin panel, error message, legacy + + + + """, + id="user_provided_analysis_model", + ), + ], +) +def test_xml_example_generation(model_cls: type[Model], expected_xml: str) -> None: + """ + Validates that the `xml_example()` class method produces the correct + pretty-printed XML with examples and descriptions as comments. + """ + actual_xml = model_cls.xml_example() + assert dedent(actual_xml).strip() == dedent(expected_xml).strip() diff --git a/tests/test_prompt.py b/tests/test_prompt.py index 5c446c9..d530786 100644 --- a/tests/test_prompt.py +++ b/tests/test_prompt.py @@ -240,22 +240,19 @@ async def register_user(username: str, email: str, age: int) -> User: def test_prompt_parse_fail_nested_input() -> None: - async def foo(arg: list[list[str]]) -> Chat: - ... + async def foo(arg: list[list[str]]) -> Chat: ... with pytest.raises(TypeError): rg.prompt(foo) - async def bar(arg: tuple[int, str, tuple[str]]) -> Chat: - ... + async def bar(arg: tuple[int, str, tuple[str]]) -> Chat: ... with pytest.raises(TypeError): rg.prompt(bar) def test_prompt_parse_fail_unique_ouput() -> None: - async def foo(arg: int) -> tuple[str, str]: - ... + async def foo(arg: int) -> tuple[str, str]: ... with pytest.raises(TypeError): rg.prompt(foo) diff --git a/tests/test_xml_parsing.py b/tests/test_xml_parsing.py index a958118..dba74a8 100644 --- a/tests/test_xml_parsing.py +++ b/tests/test_xml_parsing.py @@ -1,5 +1,6 @@ import typing as t from contextlib import nullcontext as does_not_raise +from textwrap import dedent import pytest @@ -70,6 +71,43 @@ class MixedModel(Model): nested: ContentAsElement = element() +MULTI_LINE_TEXT = """\ +Multiline content with indentation +some extra spaces + +Some more text\ +""" +MULTI_LINE_CONTENT_TAG = """\ + + Multiline content with indentation + some extra spaces + + Some more text + +""" +MULTI_LINE_CONTENT_AS_ELEMENT = """\ + + + Multiline content with indentation + some extra spaces + + Some more text + + +""" +MULTI_LINE_CONTENT_MULTI_FIELD = """\ + + Process terminated + + Multiline content with indentation + some extra spaces + + Some more text + + +""" + + @pytest.mark.parametrize( ("content", "models"), [ @@ -115,13 +153,23 @@ class MixedModel(Model): id="question_with_answer_tag_2", ), pytest.param( - "helloworld", + dedent("""\ + + hello + world + \ + """), [QuestionAnswer(question=Question(content="hello"), answer=Answer(content="world"))], id="question_answer", ), pytest.param( - "- hello\n - world", - [DelimitedAnswer(content="- hello\n - world", _items=["hello", "world"])], + dedent("""\ + + - hello + - world + \ + """), + [DelimitedAnswer(content="- hello\n- world", _items=["hello", "world"])], id="newline_delimited_answer", ), pytest.param( @@ -145,12 +193,22 @@ class MixedModel(Model): id="slash_delimited_answer", ), pytest.param( - 'ab', + dedent("""\ + + a + b + \ + """), [NameWithThings(name="test", things=["a", "b"])], id="name_with_things", ), pytest.param( - 'meowbark', + dedent("""\ + + meow + bark + \ + """), [ Wrapped( inners=[ @@ -195,6 +253,27 @@ class MixedModel(Model): [ContentTag(content="actual content")], id="tag_in_text_and_xml", ), + pytest.param( + MULTI_LINE_CONTENT_TAG, + [ContentTag(content=MULTI_LINE_TEXT)], + id="indented_multiline_content_tag", + ), + pytest.param( + MULTI_LINE_CONTENT_AS_ELEMENT, + [ContentAsElement(content=MULTI_LINE_TEXT)], + id="indented_multiline_content_as_element", + ), + pytest.param( + MULTI_LINE_CONTENT_MULTI_FIELD, + [ + MultiFieldModel( + type_="cmd", + foo_field="Process terminated", + bar_field=MULTI_LINE_TEXT, + ), + ], + id="indented_multiline_content_multi_field", + ), pytest.param( "first text second more text third", [Answer(content="first"), Answer(content="second"), Answer(content="third")], @@ -211,7 +290,12 @@ class MixedModel(Model): id="xml_breaking_chars_with_ampersand", ), pytest.param( - 'Process terminatedExit code <1>', + dedent("""\ + + Process terminated + Exit code <1> + \ + """), [ MultiFieldModel( type_="cmd", @@ -222,7 +306,13 @@ class MixedModel(Model): id="multi_field_with_xml_breaking_chars", ), pytest.param( - "Volume in drive C:\n Directory of C:\\\n 01/02/2024 Program Files", + dedent("""\ + + Volume in drive C: + Directory of C:\\ + 01/02/2024 Program Files + \ + """), [ Content( content="Volume in drive C:\n Directory of C:\\\n 01/02/2024 Program Files", @@ -231,7 +321,14 @@ class MixedModel(Model): id="shell_output_simulation", ), pytest.param( - "Error in at line <42>normal nested content", + dedent("""\ + + Error in at line <42> + + normal nested content + + \ + """), [ MixedModel( content_field="Error in at line <42>", @@ -249,7 +346,7 @@ def test_xml_parsing(content: str, models: list[Model]) -> None: assert obj.model_dump() == expected.model_dump(), ( f"Failed to parse model {expected.__class__.__name__} <- {obj!s} ({parsed})" ) - xml = obj.to_xml() + xml = obj.to_pretty_xml() assert xml == content[slice_], ( f"Failed to serialize model {expected.__class__.__name__} back to XML: {xml!r} != {content!r}" ) @@ -258,7 +355,7 @@ def test_xml_parsing(content: str, models: list[Model]) -> None: @pytest.mark.parametrize( ("content", "models"), [ - # These cases parse correctly, but their XML representation doesn't yeild + # These cases parse correctly, but their XML representation doesn't yield # the original content as some escape sequences are not preserved. pytest.param( "Text with <inner> as escaped HTML entities",