diff --git a/.github/workflows/docs-update.yaml b/.github/workflows/docs-update.yaml
index a25107c..a1edc9a 100644
--- a/.github/workflows/docs-update.yaml
+++ b/.github/workflows/docs-update.yaml
@@ -1,5 +1,5 @@
---
-name: Notify Documentation Update
+name: Trigger Docs Update
on:
push:
@@ -21,15 +21,14 @@ jobs:
private-key: ${{ secrets.UPDATE_DOCS_PRIVATE_KEY }}
owner: "${{ github.repository_owner }}"
repositories: |
- sdk
- prod-docs
+ docs
- name: Trigger docs repository workflow
uses: peter-evans/repository-dispatch@ff45666b9427631e3450c54a1bcbee4d9ff4d7c0 # v3.0.0
with:
token: ${{ steps.app-token.outputs.token }}
- repository: dreadnode/prod-docs
- event-type: code-update
+ repository: dreadnode/docs
+ event-type: docs-update
client-payload: |
{
"repository": "${{ github.repository }}",
diff --git a/docs/api/chat.mdx b/docs/api/chat.mdx
index 41523aa..bfe6da3 100644
--- a/docs/api/chat.mdx
+++ b/docs/api/chat.mdx
@@ -1263,7 +1263,7 @@ def __init__(
"""How to handle failures in the pipeline unless overridden in calls."""
self.caching: CacheMode | None = None
"""How to handle cache_control entries on messages."""
- self.task_name: str = generator.to_identifier(short=True)
+ self.task_name: str = f"Chat with {generator.to_identifier(short=True)}"
"""The name of the pipeline task, used for logging and debugging."""
self.scorers: list[Scorer[Chat]] = []
"""List of dreadnode scorers to evaluate the generated chat upon completion."""
@@ -1360,7 +1360,7 @@ List of dreadnode scorers to evaluate the generated chat upon completion.
### task\_name
```python
-task_name: str = to_identifier(short=True)
+task_name: str = f'Chat with {to_identifier(short=True)}'
```
The name of the pipeline task, used for logging and debugging.
@@ -1935,7 +1935,7 @@ def map(
if callback in [c[0] for c in self.map_callbacks]:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.map_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks])
@@ -2146,8 +2146,9 @@ async def run(
last: PipelineStep | None = None
with dn.task_span(
- name or f"pipeline - {self.task_name}",
+ name or self.task_name,
label=name or f"pipeline_{self.task_name}",
+ tags=["rigging/pipeline"],
attributes={"rigging.type": "chat_pipeline.run"},
) as task:
dn.log_inputs(
@@ -2279,8 +2280,9 @@ async def run_batch(
last: PipelineStep | None = None
with dn.task_span(
- name or f"pipeline - {self.task_name} (batch x{count})",
+ name or f"{self.task_name} (batch x{count})",
label=name or f"pipeline_batch_{self.task_name}",
+ tags=["rigging/pipeline"],
attributes={"rigging.type": "chat_pipeline.run_batch"},
) as task:
dn.log_inputs(
@@ -2426,8 +2428,9 @@ async def run_many(
last: PipelineStep | None = None
with dn.task_span(
- name or f"pipeline - {self.task_name} (x{count})",
+ name or f"{self.task_name} (x{count})",
label=name or f"pipeline_many_{self.task_name}",
+ tags=["rigging/pipeline"],
attributes={"rigging.type": "chat_pipeline.run_many"},
) as task:
dn.log_inputs(
@@ -2968,7 +2971,7 @@ def then(
if callback in [c[0] for c in self.then_callbacks]:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.then_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks])
@@ -3068,7 +3071,7 @@ def transform(
for callback in callbacks:
if not allow_duplicates and callback in self.transforms:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.transforms.extend(callbacks)
@@ -3442,7 +3445,7 @@ def watch(
for callback in callbacks:
if not allow_duplicates and callback in self.watch_callbacks:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.watch_callbacks.extend(callbacks)
diff --git a/docs/api/message.mdx b/docs/api/message.mdx
index 87fe06b..20f9d3d 100644
--- a/docs/api/message.mdx
+++ b/docs/api/message.mdx
@@ -1924,6 +1924,51 @@ def replace_with_slice(
```
+
+
+### shorten
+
+```python
+shorten(max_length: int, sep: str = '...') -> Message
+```
+
+Shortens the message content to at most max\_length characters long by removing the middle of the string
+
+**Parameters:**
+
+* **`max_length`**
+ (`int`)
+ –The maximum length of the message content.
+* **`sep`**
+ (`str`, default:
+ `'...'`
+ )
+ –The separator to use when shortening the content.
+
+**Returns:**
+
+* `Message`
+ –The shortened message.
+
+
+```python
+def shorten(self, max_length: int, sep: str = "...") -> "Message":
+ """
+ Shortens the message content to at most max_length characters long by removing the middle of the string
+
+ Args:
+ max_length: The maximum length of the message content.
+ sep: The separator to use when shortening the content.
+
+ Returns:
+ The shortened message.
+ """
+ new = self.clone()
+ new.content = shorten_string(new.content, max_length, sep=sep)
+ return new
+```
+
+
### strip
@@ -2442,8 +2487,9 @@ Returns a string representation of the slice.
```python
def __str__(self) -> str:
"""Returns a string representation of the slice."""
- content_preview = self.content if self._message else "[detached]"
- return f""
+ content = shorten_string(self.content if self._message else "[detached]", 50)
+ obj = self.obj.__class__.__name__ if self.obj else None
+ return f"MessageSlice(type='{self.type}', start={self.start}, stop={self.stop} obj={obj} content='{content}')"
```
diff --git a/docs/api/model.mdx b/docs/api/model.mdx
index 83284d0..b110948 100644
--- a/docs/api/model.mdx
+++ b/docs/api/model.mdx
@@ -207,7 +207,13 @@ def from_text(
try:
model = (
- cls(**{next(iter(cls.model_fields)): unescape_xml(inner)})
+ cls(
+ **{
+ next(iter(cls.model_fields)): unescape_xml(
+ textwrap.dedent(inner).strip()
+ )
+ }
+ )
if cls.is_simple()
else cls.from_xml(
cls.preprocess_with_cdata(full_text),
@@ -217,7 +223,7 @@ def from_text(
# If our model is relatively simple (only attributes and a single non-element field)
# we should go back and update our non-element field with the extracted content.
- if cls.is_simple_with_attrs():
+ if not cls.is_simple() and cls.is_simple_with_attrs():
name, field = next(
(name, field)
for name, field in cls.model_fields.items()
@@ -228,6 +234,14 @@ def from_text(
unescape_xml(inner).strip(),
)
+ # Walk through any fields which are strings, and dedent them
+
+ for field_name, field_info in cls.model_fields.items():
+ if isinstance(field_info, XmlEntityInfo) and field_info.annotation == str: # noqa: E721
+ model.__dict__[field_name] = textwrap.dedent(
+ model.__dict__[field_name]
+ ).strip()
+
extracted.append((model, slice_))
except Exception as e: # noqa: BLE001
extracted.append((e, slice_))
@@ -485,7 +499,7 @@ def preprocess_with_cdata(cls, content: str) -> str:
needs_escaping = escape_xml(unescape_xml(content)) != content
if is_basic_field and not is_already_cdata and needs_escaping:
- content = f""
+ content = f""
return f"<{field_name}{tag_attrs}>{content}{field_name}>"
@@ -514,7 +528,7 @@ to_pretty_xml(
skip_empty: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
- **kwargs: Any,
+ **_: Any,
) -> str
```
@@ -533,7 +547,7 @@ def to_pretty_xml(
skip_empty: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
- **kwargs: t.Any,
+ **_: t.Any,
) -> str:
"""
Converts the model to a pretty XML string with indents and newlines.
@@ -546,22 +560,7 @@ def to_pretty_xml(
exclude_none=exclude_none,
exclude_unset=exclude_unset,
)
- tree = self._postprocess_with_cdata(tree)
-
- ET.indent(tree, " ")
- pretty_encoded_xml = str(
- ET.tostring(
- tree,
- short_empty_elements=False,
- encoding="utf-8",
- **kwargs,
- ).decode(),
- )
-
- # Now we can go back and safely unescape the XML
- # that we observe between any CDATA tags
-
- return unescape_cdata_tags(pretty_encoded_xml)
+ return self._serialize_tree_prettily(tree)
```
@@ -676,14 +675,19 @@ xml_example() -> str
Returns an example XML representation of the given class.
-Models should typically override this method to provide a more complex example.
+This method generates a pretty-printed XML string that includes:
+- Example values for each field, taken from the `example` argument
+in a field constructor.
+- Field descriptions as XML comments, derived from the field's
+docstring or the `description` argument.
-By default, this method returns a hollow XML scaffold one layer deep.
+Note: This implementation is designed for models with flat structures
+and does not recursively generate examples for nested models.
**Returns:**
* `str`
- –A string containing the XML representation of the class.
+ –A string containing the pretty-printed XML example.
```python
@@ -692,27 +696,55 @@ def xml_example(cls) -> str:
"""
Returns an example XML representation of the given class.
- Models should typically override this method to provide a more complex example.
+ This method generates a pretty-printed XML string that includes:
+ - Example values for each field, taken from the `example` argument
+ in a field constructor.
+ - Field descriptions as XML comments, derived from the field's
+ docstring or the `description` argument.
- By default, this method returns a hollow XML scaffold one layer deep.
+ Note: This implementation is designed for models with flat structures
+ and does not recursively generate examples for nested models.
Returns:
- A string containing the XML representation of the class.
+ A string containing the pretty-printed XML example.
"""
if cls.is_simple():
- return cls.xml_tags()
-
- schema = cls.model_json_schema()
- properties = schema["properties"]
- structure = {cls.__xml_tag__: dict.fromkeys(properties)}
- xml_string = xmltodict.unparse(
- structure,
- pretty=True,
- full_document=False,
- indent=" ",
- short_empty_elements=True,
- )
- return t.cast("str", xml_string) # Bad type hints in xmltodict
+ field_info = next(iter(cls.model_fields.values()))
+ example = str(next(iter(field_info.examples or []), ""))
+ return f"<{cls.__xml_tag__}>{escape_xml(example)}{cls.__xml_tag__}>"
+
+ lines = []
+ attribute_parts = []
+ element_fields = {}
+
+ for field_name, field_info in cls.model_fields.items():
+ if (
+ isinstance(field_info, XmlEntityInfo)
+ and field_info.location == EntityLocation.ATTRIBUTE
+ ):
+ path = field_info.path or field_name
+ example = str(next(iter(field_info.examples or []), "")).replace('"', """)
+ attribute_parts.append(f'{path}="{example}"')
+ else:
+ element_fields[field_name] = field_info
+
+ attr_string = (" " + " ".join(attribute_parts)) if attribute_parts else ""
+ lines.append(f"<{cls.__xml_tag__}{attr_string}>")
+
+ for field_name, field_info in element_fields.items():
+ path = (isinstance(field_info, XmlEntityInfo) and field_info.path) or field_name
+ description = field_info.description
+ example = str(next(iter(field_info.examples or []), ""))
+
+ if description:
+ lines.append(f" ")
+ if example:
+ lines.append(f" <{path}>{escape_xml(example)}{path}>")
+ else:
+ lines.append(f" <{path}/>")
+
+ lines.append(f"{cls.__xml_tag__}>")
+ return "\n".join(lines)
```
diff --git a/docs/api/prompt.mdx b/docs/api/prompt.mdx
index bbf3ef9..68c010b 100644
--- a/docs/api/prompt.mdx
+++ b/docs/api/prompt.mdx
@@ -277,9 +277,10 @@ def bind(
)
async def run(*args: P.args, **kwargs: P.kwargs) -> R:
- name = get_qualified_name(self.func) if self.func else ""
+ name = get_callable_name(self.func, short=True) if self.func else ""
with dn.task_span(
- f"prompt - {name}",
+ name,
+ tags=["rigging/prompt"],
attributes={"prompt_name": name, "rigging.type": "prompt.run"},
):
dn.log_inputs(**self._bind_args(*args, **kwargs))
@@ -380,10 +381,11 @@ def bind_many(
)
async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]:
- name = get_qualified_name(self.func) if self.func else ""
+ name = get_callable_name(self.func, short=True) if self.func else ""
with dn.task_span(
- f"prompt - {name} (x{count})",
+ f"{name} (x{count})",
label=f"prompt_{name}",
+ tags=["rigging/prompt"],
attributes={"prompt_name": name, "rigging.type": "prompt.run_many"},
) as span:
dn.log_inputs(**self._bind_args(*args, **kwargs))
@@ -659,7 +661,7 @@ def map(
for callback in callbacks:
if not asyncio.iscoroutinefunction(callback):
raise TypeError(
- f"Callback '{get_qualified_name(callback)}' must be an async function",
+ f"Callback '{get_callable_name(callback)}' must be an async function",
)
if allow_duplicates:
@@ -667,7 +669,7 @@ def map(
if callback in self.map_callbacks:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.map_callbacks.extend(callbacks)
@@ -1038,7 +1040,7 @@ def then(
for callback in callbacks:
if not asyncio.iscoroutinefunction(callback):
raise TypeError(
- f"Callback '{get_qualified_name(callback)}' must be an async function",
+ f"Callback '{get_callable_name(callback)}' must be an async function",
)
if allow_duplicates:
@@ -1046,7 +1048,7 @@ def then(
if callback in self.then_callbacks:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.then_callbacks.extend(callbacks)
@@ -1135,7 +1137,7 @@ def watch(
for callback in callbacks:
if not allow_duplicates and callback in self.watch_callbacks:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.watch_callbacks.extend(callbacks)
diff --git a/docs/api/tools.mdx b/docs/api/tools.mdx
index 3627efa..d307c08 100644
--- a/docs/api/tools.mdx
+++ b/docs/api/tools.mdx
@@ -147,7 +147,8 @@ async def handle_tool_call( # noqa: PLR0912
from rigging.message import ContentText, ContentTypes, Message
with dn.task_span(
- f"tool - {self.name}",
+ self.name,
+ tags=["rigging/tool"],
attributes={"tool_name": self.name, "rigging.type": "tool"},
) as task:
dn.log_input("tool_call", tool_call)
@@ -221,7 +222,9 @@ async def handle_tool_call( # noqa: PLR0912
message.content_parts = [ContentText(text=str(result))]
if self.truncate:
- message = message.truncate(self.truncate)
+ # Use shorten instead of truncate to try and preserve
+ # the most context possible.
+ message = message.shorten(self.truncate)
return message, stop
```
diff --git a/docs/api/util.mdx b/docs/api/util.mdx
index 9f6bd12..bf24e42 100644
--- a/docs/api/util.mdx
+++ b/docs/api/util.mdx
@@ -131,9 +131,7 @@ def escape_xml(xml_string: str) -> str:
"""
escaped = xml_string.replace(r"&", "&")
escaped = escaped.replace(r"<", "<")
- escaped = escaped.replace(r">", ">")
- escaped = escaped.replace(r"'", "'")
- return escaped.replace(r'"', """)
+ return escaped.replace(r">", ">")
```
@@ -234,58 +232,90 @@ def flatten_list(nested_list: t.Iterable[t.Iterable[t.Any] | t.Any]) -> list[t.A
-get\_qualified\_name
---------------------
+get\_callable\_name
+-------------------
```python
-get_qualified_name(obj: Callable[..., Any]) -> str
+get_callable_name(
+ obj: Callable[..., Any], *, short: bool = False
+) -> str
```
-Return a best guess at the qualified name of a callable object.
-This includes functions, methods, and callable classes.
+Return a best-effort, comprehensive name for a callable object.
+
+This function handles a wide variety of callables, including regular
+functions, methods, lambdas, partials, wrapped functions, and callable
+class instances.
+
+**Parameters:**
+
+* **`obj`**
+ (`Callable[..., Any]`)
+ –The callable object to name.
+* **`short`**
+ (`bool`, default:
+ `False`
+ )
+ –If True, returns a shorter name suitable for logs or UI,
+ typically omitting the module path. The class name is
+ retained for methods.
+
+**Returns:**
+
+* `str`
+ –A string representing the callable's name.
```python
-def get_qualified_name(obj: t.Callable[..., t.Any]) -> str:
- """
- Return a best guess at the qualified name of a callable object.
- This includes functions, methods, and callable classes.
+def get_callable_name(obj: t.Callable[..., t.Any], *, short: bool = False) -> str:
"""
- if obj is None or not callable(obj):
- return "unknown"
+ Return a best-effort, comprehensive name for a callable object.
- module = inspect.getmodule(obj)
- module_name = module.__name__ if module else ""
+ This function handles a wide variety of callables, including regular
+ functions, methods, lambdas, partials, wrapped functions, and callable
+ class instances.
+
+ Args:
+ obj: The callable object to name.
+ short: If True, returns a shorter name suitable for logs or UI,
+ typically omitting the module path. The class name is
+ retained for methods.
+
+ Returns:
+ A string representing the callable's name.
+ """
+ if not callable(obj):
+ return repr(obj)
- # Partial functions
if isinstance(obj, functools.partial):
- base_name = get_qualified_name(obj.func)
- return f"partial({base_name})"
-
- # Methods
- if isinstance(obj, types.MethodType):
- class_name = obj.__self__.__class__.__name__
- method_name = obj.__func__.__name__
- return f"{class_name}.{method_name}"
-
- # Functions
- if isinstance(obj, types.FunctionType):
- # Check if it's a wrapped function
- if hasattr(obj, "__wrapped__"):
- original_name = get_qualified_name(obj.__wrapped__)
- return f"wrapped({original_name})"
-
- name = obj.__qualname__ or obj.__name__
- return f"{module_name}.{name}" if module_name != "__main__" else name
-
- # Callable classes
- if callable(obj):
- if isinstance(obj, type):
- return obj.__qualname__
- return f"{obj.__class__.__qualname__}.__call__"
-
- # Fallback
- return obj.__class__.__qualname__
+ inner_name = get_callable_name(obj.func, short=short)
+ return f"partial({inner_name})"
+
+ unwrapped = obj
+ with contextlib.suppress(Exception):
+ unwrapped = inspect.unwrap(obj)
+
+ name = getattr(unwrapped, "__qualname__", None)
+
+ if name is None:
+ name = getattr(unwrapped, "__name__", None)
+
+ if name is None:
+ if hasattr(obj, "__class__"):
+ name = getattr(obj.__class__, "__qualname__", obj.__class__.__name__)
+ else:
+ return repr(obj)
+
+ if short:
+ return str(name).split(".")[-1] # Return only the last part of the name
+
+ with contextlib.suppress(Exception):
+ if module := inspect.getmodule(unwrapped):
+ module_name = module.__name__
+ if module_name and module_name not in ("builtins", "__main__"):
+ return f"{module_name}.{name}"
+
+ return str(name)
```
diff --git a/rigging/chat.py b/rigging/chat.py
index 8f5f964..de72021 100644
--- a/rigging/chat.py
+++ b/rigging/chat.py
@@ -55,7 +55,7 @@
tools_to_json_transform,
tools_to_json_with_tag_transform,
)
-from rigging.util import flatten_list, get_qualified_name
+from rigging.util import flatten_list, get_callable_name
if t.TYPE_CHECKING:
from dreadnode.metric import Scorer, ScorerCallable
@@ -217,6 +217,20 @@ def conversation(self) -> str:
conversation += f"\n\n[error]: {self.error}"
return conversation
+ def __str__(self) -> str:
+ formatted = f"--- Chat {self.uuid}"
+ formatted += f"\n |- timestamp: {self.timestamp.isoformat()}"
+ if self.usage:
+ formatted += f"\n |- usage: {self.usage}"
+ if self.generator:
+ formatted += f"\n |- generator: {self.generator.to_identifier(short=True)}"
+ if self.stop_reason:
+ formatted += f"\n |- stop_reason: {self.stop_reason}"
+ if self.metadata:
+ formatted += f"\n |- metadata: {self.metadata}"
+ formatted += f"\n\n{self.conversation}\n"
+ return formatted
+
@property
def message_dicts(self) -> list[MessageDict]:
"""Returns the chat as a minimal message dictionaries."""
@@ -732,7 +746,7 @@ def with_parent(self, parent: "PipelineStep") -> "PipelineStep":
raise RuntimeError("Unable to set parent step")
def __str__(self) -> str:
- callback_name = get_qualified_name(self.callback) if self.callback else "None"
+ callback_name = get_callable_name(self.callback) if self.callback else "None"
self_str = f"PipelineStep(pipeline={id(self.pipeline)}, state={self.state}, chats={len(self.chats)}, callback={callback_name})"
if self.parent is not None:
self_str += f" <- {self.parent!s}"
@@ -762,13 +776,16 @@ def depth(self) -> int:
def _wrap_watch_callback(callback: WatchChatCallback) -> WatchChatCallback:
import dreadnode as dn
- callback_name = get_qualified_name(callback)
- return dn.task(
- name=f"watch - {callback_name}",
- attributes={"rigging.type": "chat_pipeline.watch_callback"},
- log_inputs=True,
- log_output=False,
- )(callback)
+ callback_name = get_callable_name(callback)
+
+ async def wrapped_callback(chats: list[Chat]) -> None:
+ with dn.span(
+ name=callback_name,
+ attributes={"rigging.type": "chat_pipeline.watch_callback"},
+ ):
+ await callback(chats)
+
+ return wrapped_callback
# Pipeline
@@ -803,7 +820,7 @@ def __init__(
"""How to handle failures in the pipeline unless overridden in calls."""
self.caching: CacheMode | None = None
"""How to handle cache_control entries on messages."""
- self.task_name: str = generator.to_identifier(short=True)
+ self.task_name: str = f"Chat with {generator.to_identifier(short=True)}"
"""The name of the pipeline task, used for logging and debugging."""
self.scorers: list[Scorer[Chat]] = []
"""List of dreadnode scorers to evaluate the generated chat upon completion."""
@@ -897,7 +914,7 @@ async def log(chats: list[Chat]) -> None:
for callback in callbacks:
if not allow_duplicates and callback in self.watch_callbacks:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.watch_callbacks.extend(callbacks)
@@ -1120,7 +1137,7 @@ async def process(chat: Chat) -> Chat | None:
if callback in [c[0] for c in self.then_callbacks]:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.then_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks])
@@ -1164,7 +1181,7 @@ async def process(chats: list[Chat]) -> list[Chat]:
if callback in [c[0] for c in self.map_callbacks]:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.map_callbacks.extend([(callback, max_depth, as_task) for callback in callbacks])
@@ -1208,7 +1225,7 @@ async def post_transform(chat: Chat) -> Chat | None:
for callback in callbacks:
if not allow_duplicates and callback in self.transforms:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.transforms.extend(callbacks)
@@ -1576,8 +1593,6 @@ async def _process_then_callback(
callback: ThenChatCallback,
state: CallbackState,
) -> None:
- callback_name = get_qualified_name(callback)
-
async def complete() -> None:
state.completed = True
state.ready_event.set()
@@ -1598,7 +1613,7 @@ async def complete() -> None:
if not inspect.isasyncgen(result):
raise TypeError(
- f"Callback '{callback_name}' must return a Chat, PipelineStepGenerator, or None",
+ f"Callback '{get_callable_name(callback)}' must return a Chat, PipelineStepGenerator, or None",
)
generator = t.cast(
@@ -1702,14 +1717,14 @@ async def _step( # noqa: PLR0915, PLR0912
try:
messages = apply_cache_mode_to_messages(self.caching, messages)
- with dn.task_span(
+ with dn.span(
f"generate - {self.generator.to_identifier(short=True)}",
attributes={"rigging.type": "chat_pipeline.generate"},
):
- dn.log_input("messages", messages)
- dn.log_input("params", params)
+ # dn.log_input("messages", messages)
+ # dn.log_input("params", params)
generated = await self.generator.generate_messages(messages, params)
- dn.log_output("generated", generated)
+ # dn.log_output("generated", generated)
# If we got a total failure here for generation as a whole,
# we can't distinguish between incoming messages in terms
@@ -1812,7 +1827,7 @@ async def _step( # noqa: PLR0915, PLR0912
# Then callbacks
for then_callback, max_depth, as_task in self.then_callbacks:
- callback_name = get_qualified_name(then_callback)
+ callback_name = get_callable_name(then_callback, short=True)
states = [
self.CallbackState(
@@ -1825,7 +1840,7 @@ async def _step( # noqa: PLR0915, PLR0912
callback_task = (
dn.task(
- name=f"then - {callback_name}",
+ name=callback_name,
attributes={"rigging.type": "chat_pipeline.then_callback"},
log_inputs=True,
log_output=True,
@@ -1904,11 +1919,11 @@ async def _step( # noqa: PLR0915, PLR0912
# Map callbacks
for map_callback, max_depth, as_task in self.map_callbacks:
- callback_name = get_qualified_name(map_callback)
+ callback_name = get_callable_name(map_callback, short=True)
map_task = (
dn.task(
- name=f"map - {callback_name}",
+ name=callback_name,
attributes={"rigging.type": "chat_pipeline.map_callback"},
log_inputs=True,
log_output=True,
@@ -2056,8 +2071,9 @@ async def run(
last: PipelineStep | None = None
with dn.task_span(
- name or f"pipeline - {self.task_name}",
+ name or self.task_name,
label=name or f"pipeline_{self.task_name}",
+ tags=["rigging/pipeline"],
attributes={"rigging.type": "chat_pipeline.run"},
) as task:
dn.log_inputs(
@@ -2156,8 +2172,9 @@ async def run_many(
last: PipelineStep | None = None
with dn.task_span(
- name or f"pipeline - {self.task_name} (x{count})",
+ name or f"{self.task_name} (x{count})",
label=name or f"pipeline_many_{self.task_name}",
+ tags=["rigging/pipeline"],
attributes={"rigging.type": "chat_pipeline.run_many"},
) as task:
dn.log_inputs(
@@ -2316,8 +2333,9 @@ async def run_batch(
last: PipelineStep | None = None
with dn.task_span(
- name or f"pipeline - {self.task_name} (batch x{count})",
+ name or f"{self.task_name} (batch x{count})",
label=name or f"pipeline_batch_{self.task_name}",
+ tags=["rigging/pipeline"],
attributes={"rigging.type": "chat_pipeline.run_batch"},
) as task:
dn.log_inputs(
diff --git a/rigging/completion.py b/rigging/completion.py
index 2f0826b..ee8c96e 100644
--- a/rigging/completion.py
+++ b/rigging/completion.py
@@ -18,7 +18,7 @@
from rigging.generator import GenerateParams, Generator, get_generator
from rigging.generator.base import GeneratedText, StopReason, Usage
from rigging.parsing import parse_many
-from rigging.util import get_qualified_name
+from rigging.util import get_callable_name
if t.TYPE_CHECKING:
from dreadnode import Span
@@ -577,13 +577,16 @@ async def _watch_callback(self, completions: list[Completion]) -> None:
def wrap_watch_callback(
callback: WatchCompletionCallback,
) -> t.Callable[[list[Completion]], t.Awaitable[None]]:
- callback_name = get_qualified_name(callback)
- return dn.task(
- name=f"watch - {callback_name}",
- attributes={"rigging.type": "completion_pipeline.watch_callback"},
- log_inputs=True,
- log_output=False,
- )(callback)
+ callback_name = get_callable_name(callback, short=True)
+
+ async def wrapped_callback(completions: list[Completion]) -> None:
+ with dn.span(
+ name=callback_name,
+ attributes={"rigging.type": "chat_pipeline.watch_callback"},
+ ):
+ await callback(completions)
+
+ return wrapped_callback
traced_callbacks = [wrap_watch_callback(callback) for callback in self.watch_callbacks]
coros = [callback(completions) for callback in traced_callbacks]
@@ -636,9 +639,9 @@ async def _post_run(
# previous calls being ran.
for map_callback in self.map_callbacks:
- callback_name = get_qualified_name(map_callback)
+ callback_name = get_callable_name(map_callback)
traced_map_callback = dn.task(
- name=f"map - {callback_name}",
+ name=callback_name,
attributes={"rigging.type": "completion_pipeline.map_callback"},
log_inputs=True,
log_output=True,
@@ -650,9 +653,9 @@ async def _post_run(
)
for then_callback in self.then_callbacks:
- callback_name = get_qualified_name(then_callback)
+ callback_name = get_callable_name(then_callback, short=True)
traced_then_callback = dn.task(
- name=f"then - {callback_name}",
+ name=callback_name,
attributes={"rigging.type": "completion_pipeline.then_callback"},
log_inputs=True,
log_output=True,
diff --git a/rigging/generator/base.py b/rigging/generator/base.py
index 336b15d..4233b3f 100644
--- a/rigging/generator/base.py
+++ b/rigging/generator/base.py
@@ -310,6 +310,9 @@ def __add__(self, other: "Usage") -> "Usage":
total_tokens=self.total_tokens + other.total_tokens,
)
+ def __str__(self) -> str:
+ return f"in: {self.input_tokens} | out: {self.output_tokens} | total: {self.total_tokens}"
+
GeneratedT = t.TypeVar("GeneratedT", Message, str)
diff --git a/rigging/message.py b/rigging/message.py
index 6741817..03d31d6 100644
--- a/rigging/message.py
+++ b/rigging/message.py
@@ -28,7 +28,12 @@
from rigging.model import Model, ModelT
from rigging.parsing import try_parse_many
from rigging.tools.base import ToolCall
-from rigging.util import AudioFormat, identify_audio_format, shorten_string, truncate_string
+from rigging.util import (
+ AudioFormat,
+ identify_audio_format,
+ shorten_string,
+ truncate_string,
+)
Role = t.Literal["system", "user", "assistant", "tool"]
"""The role of a message. Can be 'system', 'user', 'assistant', or 'tool'."""
@@ -111,8 +116,9 @@ def __len__(self) -> int:
def __str__(self) -> str:
"""Returns a string representation of the slice."""
- content_preview = self.content if self._message else "[detached]"
- return f""
+ content = shorten_string(self.content if self._message else "[detached]", 50)
+ obj = self.obj.__class__.__name__ if self.obj else None
+ return f"MessageSlice(type='{self.type}', start={self.start}, stop={self.stop} obj={obj} content='{content}')"
def clone(self) -> "MessageSlice":
"""
@@ -164,7 +170,7 @@ class ImageUrl(BaseModel):
"""Cache control entry for prompt caching."""
def __str__(self) -> str:
- return f""
+ return f"ContentImageUrl(url='{shorten_string(self.image_url.url, 50)}')"
@classmethod
def from_file(
@@ -1124,6 +1130,21 @@ def truncate(self, max_length: int, suffix: str = "\n[truncated]") -> "Message":
new.content = truncate_string(new.content, max_length, suf=suffix)
return new
+ def shorten(self, max_length: int, sep: str = "...") -> "Message":
+ """
+ Shortens the message content to at most max_length characters long by removing the middle of the string
+
+ Args:
+ max_length: The maximum length of the message content.
+ sep: The separator to use when shortening the content.
+
+ Returns:
+ The shortened message.
+ """
+ new = self.clone()
+ new.content = shorten_string(new.content, max_length, sep=sep)
+ return new
+
@property
def models(self) -> list[Model]:
"""
diff --git a/rigging/model.py b/rigging/model.py
index cd07488..bb145cf 100644
--- a/rigging/model.py
+++ b/rigging/model.py
@@ -5,14 +5,15 @@
import dataclasses
import inspect
import re
+import textwrap
import typing as t
from xml.etree import ElementTree as ET # nosec
import typing_extensions as te
-import xmltodict # type: ignore [import-untyped]
from pydantic import (
BaseModel,
BeforeValidator,
+ ConfigDict,
Field,
SerializationInfo,
ValidationError,
@@ -68,6 +69,8 @@ def __get__(self, _: t.Any, owner: t.Any) -> str:
class Model(BaseXmlModel):
+ model_config = ConfigDict(use_attribute_docstrings=True)
+
def __init_subclass__(
cls,
tag: str | None = None,
@@ -135,16 +138,66 @@ def _postprocess_with_cdata(self, tree: ET.Element) -> ET.Element:
return tree
- # to_xml() doesn't prettify normally, and extended
- # requirements like lxml seemed like poor form for
- # just this feature
+ def _serialize_tree_prettily(
+ self, element: ET.Element, level: int = 0, indent_str: str = " "
+ ) -> str:
+ # Essentially a custom work of ET.indent to better
+ # handle multi-line text so we get:
+ #
+ #
+ # Some text
+ # More text
+ #
+ #
+ # instead of:
+ #
+ # Some text
+ # More text
+
+ indent = indent_str * level
+ lines = []
+
+ attrs = "".join(f' {k}="{escape_xml(v)}"' for k, v in element.attrib.items())
+ tag_with_attrs = f"{element.tag}{attrs}"
+
+ text = element.text and element.text.strip()
+ multiline_text = text and "\n" in text
+ has_children = len(element) > 0
+
+ if not has_children and not multiline_text:
+ if level == 0 and not text:
+ lines.append(f"<{tag_with_attrs}>{element.tag}>")
+ elif not text:
+ lines.append(f"{indent}<{tag_with_attrs} />")
+ else:
+ lines.append(f"{indent}<{tag_with_attrs}>{text}{element.tag}>")
+ return "\n".join(lines)
+
+ lines.append(f"{indent}<{tag_with_attrs}>")
+
+ if text and multiline_text:
+ dedented_text = textwrap.dedent(text).strip()
+ content_indent = indent_str * (level + 1)
+ for line in dedented_text.split("\n"):
+ lines.append(f"{content_indent}{line}" if line.strip() else "") # noqa: PERF401
+
+ for child in element:
+ lines.append(self._serialize_tree_prettily(child, level + 1, indent_str)) # noqa: PERF401
+
+ lines.append(f"{indent}{element.tag}>")
+
+ return "\n".join(lines)
+
+ # to_xml() doesn't prettify normally, and extended requirements
+ # like lxml seemed like poor form for just this feature
+
def to_pretty_xml(
self,
*,
skip_empty: bool = False,
exclude_none: bool = False,
exclude_unset: bool = False,
- **kwargs: t.Any,
+ **_: t.Any,
) -> str:
"""
Converts the model to a pretty XML string with indents and newlines.
@@ -157,22 +210,7 @@ def to_pretty_xml(
exclude_none=exclude_none,
exclude_unset=exclude_unset,
)
- tree = self._postprocess_with_cdata(tree)
-
- ET.indent(tree, " ")
- pretty_encoded_xml = str(
- ET.tostring(
- tree,
- short_empty_elements=False,
- encoding="utf-8",
- **kwargs,
- ).decode(),
- )
-
- # Now we can go back and safely unescape the XML
- # that we observe between any CDATA tags
-
- return unescape_cdata_tags(pretty_encoded_xml)
+ return self._serialize_tree_prettily(tree)
def to_xml(
self,
@@ -293,27 +331,55 @@ def xml_example(cls) -> str:
"""
Returns an example XML representation of the given class.
- Models should typically override this method to provide a more complex example.
+ This method generates a pretty-printed XML string that includes:
+ - Example values for each field, taken from the `example` argument
+ in a field constructor.
+ - Field descriptions as XML comments, derived from the field's
+ docstring or the `description` argument.
- By default, this method returns a hollow XML scaffold one layer deep.
+ Note: This implementation is designed for models with flat structures
+ and does not recursively generate examples for nested models.
Returns:
- A string containing the XML representation of the class.
+ A string containing the pretty-printed XML example.
"""
if cls.is_simple():
- return cls.xml_tags()
-
- schema = cls.model_json_schema()
- properties = schema["properties"]
- structure = {cls.__xml_tag__: dict.fromkeys(properties)}
- xml_string = xmltodict.unparse(
- structure,
- pretty=True,
- full_document=False,
- indent=" ",
- short_empty_elements=True,
- )
- return t.cast("str", xml_string) # Bad type hints in xmltodict
+ field_info = next(iter(cls.model_fields.values()))
+ example = str(next(iter(field_info.examples or []), ""))
+ return f"<{cls.__xml_tag__}>{escape_xml(example)}{cls.__xml_tag__}>"
+
+ lines = []
+ attribute_parts = []
+ element_fields = {}
+
+ for field_name, field_info in cls.model_fields.items():
+ if (
+ isinstance(field_info, XmlEntityInfo)
+ and field_info.location == EntityLocation.ATTRIBUTE
+ ):
+ path = field_info.path or field_name
+ example = str(next(iter(field_info.examples or []), "")).replace('"', """)
+ attribute_parts.append(f'{path}="{example}"')
+ else:
+ element_fields[field_name] = field_info
+
+ attr_string = (" " + " ".join(attribute_parts)) if attribute_parts else ""
+ lines.append(f"<{cls.__xml_tag__}{attr_string}>")
+
+ for field_name, field_info in element_fields.items():
+ path = (isinstance(field_info, XmlEntityInfo) and field_info.path) or field_name
+ description = field_info.description
+ example = str(next(iter(field_info.examples or []), ""))
+
+ if description:
+ lines.append(f" ")
+ if example:
+ lines.append(f" <{path}>{escape_xml(example)}{path}>")
+ else:
+ lines.append(f" <{path}/>")
+
+ lines.append(f"{cls.__xml_tag__}>")
+ return "\n".join(lines)
@classmethod
def ensure_valid(cls) -> None:
@@ -383,7 +449,7 @@ def wrap_with_cdata(match: re.Match[str]) -> str:
needs_escaping = escape_xml(unescape_xml(content)) != content
if is_basic_field and not is_already_cdata and needs_escaping:
- content = f""
+ content = f""
return f"<{field_name}{tag_attrs}>{content}{field_name}>"
@@ -504,7 +570,13 @@ def from_text(
try:
model = (
- cls(**{next(iter(cls.model_fields)): unescape_xml(inner)})
+ cls(
+ **{
+ next(iter(cls.model_fields)): unescape_xml(
+ textwrap.dedent(inner).strip()
+ )
+ }
+ )
if cls.is_simple()
else cls.from_xml(
cls.preprocess_with_cdata(full_text),
@@ -514,7 +586,7 @@ def from_text(
# If our model is relatively simple (only attributes and a single non-element field)
# we should go back and update our non-element field with the extracted content.
- if cls.is_simple_with_attrs():
+ if not cls.is_simple() and cls.is_simple_with_attrs():
name, field = next(
(name, field)
for name, field in cls.model_fields.items()
@@ -525,6 +597,14 @@ def from_text(
unescape_xml(inner).strip(),
)
+ # Walk through any fields which are strings, and dedent them
+
+ for field_name, field_info in cls.model_fields.items():
+ if isinstance(field_info, XmlEntityInfo) and field_info.annotation == str: # noqa: E721
+ model.__dict__[field_name] = textwrap.dedent(
+ model.__dict__[field_name]
+ ).strip()
+
extracted.append((model, slice_))
except Exception as e: # noqa: BLE001
extracted.append((e, slice_))
diff --git a/rigging/prompt.py b/rigging/prompt.py
index 601be2f..5942435 100644
--- a/rigging/prompt.py
+++ b/rigging/prompt.py
@@ -25,7 +25,7 @@
from rigging.message import Message
from rigging.model import Model, SystemErrorModel, ValidationErrorModel, make_primitive
from rigging.tools import Tool
-from rigging.util import escape_xml, get_qualified_name, to_snake, to_xml_tag
+from rigging.util import escape_xml, get_callable_name, to_snake, to_xml_tag
DEFAULT_DOC = "Convert the following inputs to outputs ({func_name})."
"""Default docstring if none is provided to a prompt function."""
@@ -563,10 +563,10 @@ async def _then_parse(self, chat: Chat) -> PipelineStepContextManager | None:
# A bit weird, but we need from_chat to properly handle
# wrapping Chat output types inside lists/dataclasses
with dn.task_span(
- f"prompt parse - {self.output.tag}",
+ f"parse - {self.output.tag}",
attributes={"rigging.type": "prompt.parse"},
):
- dn.log_input("message", chat.last)
+ dn.log_input("message", str(chat.last))
self.output.from_chat(chat)
except ValidationError as e:
next_pipeline.add(
@@ -642,7 +642,7 @@ async def summarize(text: str) -> str:
for callback in callbacks:
if not allow_duplicates and callback in self.watch_callbacks:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.watch_callbacks.extend(callbacks)
@@ -680,7 +680,7 @@ async def summarize(text: str) -> str:
for callback in callbacks:
if not asyncio.iscoroutinefunction(callback):
raise TypeError(
- f"Callback '{get_qualified_name(callback)}' must be an async function",
+ f"Callback '{get_callable_name(callback)}' must be an async function",
)
if allow_duplicates:
@@ -688,7 +688,7 @@ async def summarize(text: str) -> str:
if callback in self.then_callbacks:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.then_callbacks.extend(callbacks)
@@ -726,7 +726,7 @@ async def summarize(text: str) -> str:
for callback in callbacks:
if not asyncio.iscoroutinefunction(callback):
raise TypeError(
- f"Callback '{get_qualified_name(callback)}' must be an async function",
+ f"Callback '{get_callable_name(callback)}' must be an async function",
)
if allow_duplicates:
@@ -734,7 +734,7 @@ async def summarize(text: str) -> str:
if callback in self.map_callbacks:
raise ValueError(
- f"Callback '{get_qualified_name(callback)}' is already registered.",
+ f"Callback '{get_callable_name(callback)}' is already registered.",
)
self.map_callbacks.extend(callbacks)
@@ -850,9 +850,10 @@ def say_hello(name: str) -> str:
)
async def run(*args: P.args, **kwargs: P.kwargs) -> R:
- name = get_qualified_name(self.func) if self.func else ""
+ name = get_callable_name(self.func, short=True) if self.func else ""
with dn.task_span(
- f"prompt - {name}",
+ name,
+ tags=["rigging/prompt"],
attributes={"prompt_name": name, "rigging.type": "prompt.run"},
):
dn.log_inputs(**self._bind_args(*args, **kwargs))
@@ -913,10 +914,11 @@ def say_hello(name: str) -> str:
)
async def run_many(count: int, /, *args: P.args, **kwargs: P.kwargs) -> list[R]:
- name = get_qualified_name(self.func) if self.func else ""
+ name = get_callable_name(self.func, short=True) if self.func else ""
with dn.task_span(
- f"prompt - {name} (x{count})",
+ f"{name} (x{count})",
label=f"prompt_{name}",
+ tags=["rigging/prompt"],
attributes={"prompt_name": name, "rigging.type": "prompt.run_many"},
) as span:
dn.log_inputs(**self._bind_args(*args, **kwargs))
diff --git a/rigging/tools/base.py b/rigging/tools/base.py
index e6d0ac6..c428df7 100644
--- a/rigging/tools/base.py
+++ b/rigging/tools/base.py
@@ -30,7 +30,7 @@
make_from_schema,
make_from_signature,
)
-from rigging.util import deref_json
+from rigging.util import deref_json, shorten_string
if t.TYPE_CHECKING:
from rigging.message import Message
@@ -116,7 +116,8 @@ class ToolCall(BaseModel):
function: FunctionCall
def __str__(self) -> str:
- return f""
+ arguments = shorten_string(self.function.arguments, max_length=50)
+ return f"ToolCall({self.function.name}({arguments}), id='{self.id}')"
@property
def name(self) -> str:
@@ -359,7 +360,8 @@ async def handle_tool_call( # noqa: PLR0912
from rigging.message import ContentText, ContentTypes, Message
with dn.task_span(
- f"tool - {self.name}",
+ self.name,
+ tags=["rigging/tool"],
attributes={"tool_name": self.name, "rigging.type": "tool"},
) as task:
dn.log_input("tool_call", tool_call)
@@ -433,7 +435,9 @@ async def handle_tool_call( # noqa: PLR0912
message.content_parts = [ContentText(text=str(result))]
if self.truncate:
- message = message.truncate(self.truncate)
+ # Use shorten instead of truncate to try and preserve
+ # the most context possible.
+ message = message.shorten(self.truncate)
return message, stop
diff --git a/rigging/transform/json_tools.py b/rigging/transform/json_tools.py
index 459978f..a95dabc 100644
--- a/rigging/transform/json_tools.py
+++ b/rigging/transform/json_tools.py
@@ -16,7 +16,7 @@
from rigging.model import Model
from rigging.tools.base import FunctionCall, ToolCall, ToolDefinition, ToolResponse
from rigging.transform.base import PostTransform, Transform
-from rigging.util import extract_json_objects
+from rigging.util import extract_json_objects, shorten_string
if t.TYPE_CHECKING:
from rigging.chat import Chat
@@ -54,7 +54,8 @@ class JsonInXmlToolCall(Model):
parameters: str
def __str__(self) -> str:
- return f""
+ parameters = shorten_string(self.parameters, max_length=50)
+ return f"JsonInXmlToolCall(name='{self.name}', parameters='{parameters}', id='{self.id}')"
class JsonToolCall(Model):
@@ -62,7 +63,8 @@ class JsonToolCall(Model):
content: str
def __str__(self) -> str:
- return f""
+ content = shorten_string(self.content, max_length=50)
+ return f"JsonToolCall(content='{content}', id='{self.id}')"
# Prompts
diff --git a/rigging/transform/xml_tools.py b/rigging/transform/xml_tools.py
index 1baf998..0f43d28 100644
--- a/rigging/transform/xml_tools.py
+++ b/rigging/transform/xml_tools.py
@@ -18,6 +18,7 @@
from rigging.model import Model
from rigging.tools.base import FunctionCall, Tool, ToolCall, ToolResponse
from rigging.transform.base import PostTransform, Transform
+from rigging.util import shorten_string
if t.TYPE_CHECKING:
from rigging.chat import Chat
@@ -114,7 +115,8 @@ class XmlToolCall(Model, tag=TOOL_CALL_TAG):
parameters: str
def __str__(self) -> str:
- return f""
+ parameters = shorten_string(self.parameters, max_length=50)
+ return f"XmlToolCall(name='{self.name}' parameters='{parameters}')"
XML_TOOLS_PREFIX = f"""\
diff --git a/rigging/util.py b/rigging/util.py
index cfb6630..7f0cce4 100644
--- a/rigging/util.py
+++ b/rigging/util.py
@@ -3,10 +3,10 @@
"""
import asyncio
+import contextlib
import functools
import inspect
import re
-import types
import typing as t
from json import JSONDecoder
from threading import Thread
@@ -145,9 +145,7 @@ def escape_xml(xml_string: str) -> str:
"""
escaped = xml_string.replace(r"&", "&")
escaped = escaped.replace(r"<", "<")
- escaped = escaped.replace(r">", ">")
- escaped = escaped.replace(r"'", "'")
- return escaped.replace(r'"', """)
+ return escaped.replace(r">", ">")
def unescape_cdata_tags(xml_string: str) -> str:
@@ -183,46 +181,55 @@ def to_xml_tag(text: str) -> str:
# Name resolution
-def get_qualified_name(obj: t.Callable[..., t.Any]) -> str:
+def get_callable_name(obj: t.Callable[..., t.Any], *, short: bool = False) -> str:
"""
- Return a best guess at the qualified name of a callable object.
- This includes functions, methods, and callable classes.
- """
- if obj is None or not callable(obj):
- return "unknown"
+ Return a best-effort, comprehensive name for a callable object.
+
+ This function handles a wide variety of callables, including regular
+ functions, methods, lambdas, partials, wrapped functions, and callable
+ class instances.
- module = inspect.getmodule(obj)
- module_name = module.__name__ if module else ""
+ Args:
+ obj: The callable object to name.
+ short: If True, returns a shorter name suitable for logs or UI,
+ typically omitting the module path. The class name is
+ retained for methods.
+
+ Returns:
+ A string representing the callable's name.
+ """
+ if not callable(obj):
+ return repr(obj)
- # Partial functions
if isinstance(obj, functools.partial):
- base_name = get_qualified_name(obj.func)
- return f"partial({base_name})"
-
- # Methods
- if isinstance(obj, types.MethodType):
- class_name = obj.__self__.__class__.__name__
- method_name = obj.__func__.__name__
- return f"{class_name}.{method_name}"
-
- # Functions
- if isinstance(obj, types.FunctionType):
- # Check if it's a wrapped function
- if hasattr(obj, "__wrapped__"):
- original_name = get_qualified_name(obj.__wrapped__)
- return f"wrapped({original_name})"
-
- name = obj.__qualname__ or obj.__name__
- return f"{module_name}.{name}" if module_name != "__main__" else name
-
- # Callable classes
- if callable(obj):
- if isinstance(obj, type):
- return obj.__qualname__
- return f"{obj.__class__.__qualname__}.__call__"
-
- # Fallback
- return obj.__class__.__qualname__
+ inner_name = get_callable_name(obj.func, short=short)
+ return f"partial({inner_name})"
+
+ unwrapped = obj
+ with contextlib.suppress(Exception):
+ unwrapped = inspect.unwrap(obj)
+
+ name = getattr(unwrapped, "__qualname__", None)
+
+ if name is None:
+ name = getattr(unwrapped, "__name__", None)
+
+ if name is None:
+ if hasattr(obj, "__class__"):
+ name = getattr(obj.__class__, "__qualname__", obj.__class__.__name__)
+ else:
+ return repr(obj)
+
+ if short:
+ return str(name).split(".")[-1] # Return only the last part of the name
+
+ with contextlib.suppress(Exception):
+ if module := inspect.getmodule(unwrapped):
+ module_name = module.__name__
+ if module_name and module_name not in ("builtins", "__main__"):
+ return f"{module_name}.{name}"
+
+ return str(name)
# Formatting
diff --git a/tests/test_model.py b/tests/test_model.py
new file mode 100644
index 0000000..a3a8385
--- /dev/null
+++ b/tests/test_model.py
@@ -0,0 +1,161 @@
+import typing as t
+from textwrap import dedent
+from xml.sax.saxutils import escape
+
+import pytest
+
+from rigging.model import Model, attr, element
+
+# mypy: disable-error-code=empty-body
+# ruff: noqa: S101, PLR2004, ARG001, PT011, SLF001
+
+
+class SimpleModel(Model):
+ """A simple model to test basic element generation."""
+
+ content: str = element(examples=["Hello, World!"])
+ """The main content of the model."""
+
+
+class NoExampleModel(Model):
+ """A model to test fallback to an empty element when no example is given."""
+
+ name: str
+ """The name of the entity."""
+
+
+class AttrAndElementModel(Model):
+ """Tests a model with both an attribute and a child element."""
+
+ id: int = attr(examples=[123])
+ """The unique identifier (attribute)."""
+ value: str = element(examples=["Some value"])
+ """The primary value (element)."""
+
+
+class DocstringDescriptionModel(Model):
+ """Tests that field docstrings are correctly used as descriptions."""
+
+ field1: str = element(examples=["val1"])
+ """This is the description for field1."""
+ field2: bool = element(examples=[True])
+ """This is the description for field2."""
+
+
+class ParameterDescriptionModel(Model):
+ """Tests that the `description` parameter overrides a field's docstring."""
+
+ param: str = element(
+ examples=["override"], description="This description is from the `description` parameter."
+ )
+ """This docstring should be ignored in the XML example."""
+
+
+class SpecialCharsModel(Model):
+ """Tests proper escaping of special XML characters in examples and comments."""
+
+ comment: str = element(examples=["ok"])
+ """This comment contains < and > & special characters."""
+ data: str = element(examples=["&'"])
+ """This element's example contains special XML characters."""
+
+
+# This class definition is based on the one you provided in the prompt.
+class Analysis(Model, tag="analysis"):
+ """A model to validate the exact output requested in the prompt."""
+
+ priority: t.Literal["low", "medium", "high", "critical"] = element(examples=["medium"])
+ """Triage priority for human follow-up."""
+ tags: str = element("tags", examples=["admin panel, error message, legacy"])
+ """A list of specific areas within the screenshot that are noteworthy or require further examination."""
+ summary: str = element()
+ """A markdown summary explaining *why* the screenshot is interesting and what a human should investigate next."""
+
+
+@pytest.mark.parametrize(
+ ("model_cls", "expected_xml"),
+ [
+ pytest.param(
+ SimpleModel,
+ """
+
+
+ Hello, World!
+
+ """,
+ id="simple_model",
+ ),
+ pytest.param(
+ NoExampleModel,
+ """
+
+ """,
+ id="model_with_no_example",
+ ),
+ pytest.param(
+ AttrAndElementModel,
+ """
+
+
+ Some value
+
+ """,
+ id="model_with_attribute_and_element",
+ ),
+ pytest.param(
+ DocstringDescriptionModel,
+ """
+
+
+ val1
+
+ True
+
+ """,
+ id="descriptions_from_docstrings",
+ ),
+ pytest.param(
+ ParameterDescriptionModel,
+ """
+
+
+ override
+
+ """,
+ id="description_from_parameter_overrides_docstring",
+ ),
+ pytest.param(
+ SpecialCharsModel,
+ f"""
+
+
+ ok
+
+ {escape("&'")}
+
+ """,
+ id="escaping_of_special_characters",
+ ),
+ pytest.param(
+ Analysis,
+ """
+
+
+ medium
+
+ admin panel, error message, legacy
+
+
+
+ """,
+ id="user_provided_analysis_model",
+ ),
+ ],
+)
+def test_xml_example_generation(model_cls: type[Model], expected_xml: str) -> None:
+ """
+ Validates that the `xml_example()` class method produces the correct
+ pretty-printed XML with examples and descriptions as comments.
+ """
+ actual_xml = model_cls.xml_example()
+ assert dedent(actual_xml).strip() == dedent(expected_xml).strip()
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index 5c446c9..d530786 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -240,22 +240,19 @@ async def register_user(username: str, email: str, age: int) -> User:
def test_prompt_parse_fail_nested_input() -> None:
- async def foo(arg: list[list[str]]) -> Chat:
- ...
+ async def foo(arg: list[list[str]]) -> Chat: ...
with pytest.raises(TypeError):
rg.prompt(foo)
- async def bar(arg: tuple[int, str, tuple[str]]) -> Chat:
- ...
+ async def bar(arg: tuple[int, str, tuple[str]]) -> Chat: ...
with pytest.raises(TypeError):
rg.prompt(bar)
def test_prompt_parse_fail_unique_ouput() -> None:
- async def foo(arg: int) -> tuple[str, str]:
- ...
+ async def foo(arg: int) -> tuple[str, str]: ...
with pytest.raises(TypeError):
rg.prompt(foo)
diff --git a/tests/test_xml_parsing.py b/tests/test_xml_parsing.py
index a958118..dba74a8 100644
--- a/tests/test_xml_parsing.py
+++ b/tests/test_xml_parsing.py
@@ -1,5 +1,6 @@
import typing as t
from contextlib import nullcontext as does_not_raise
+from textwrap import dedent
import pytest
@@ -70,6 +71,43 @@ class MixedModel(Model):
nested: ContentAsElement = element()
+MULTI_LINE_TEXT = """\
+Multiline content with indentation
+some extra spaces
+
+Some more text\
+"""
+MULTI_LINE_CONTENT_TAG = """\
+
+ Multiline content with indentation
+ some extra spaces
+
+ Some more text
+
+"""
+MULTI_LINE_CONTENT_AS_ELEMENT = """\
+
+
+ Multiline content with indentation
+ some extra spaces
+
+ Some more text
+
+
+"""
+MULTI_LINE_CONTENT_MULTI_FIELD = """\
+
+ Process terminated
+
+ Multiline content with indentation
+ some extra spaces
+
+ Some more text
+
+
+"""
+
+
@pytest.mark.parametrize(
("content", "models"),
[
@@ -115,13 +153,23 @@ class MixedModel(Model):
id="question_with_answer_tag_2",
),
pytest.param(
- "helloworld",
+ dedent("""\
+
+ hello
+ world
+ \
+ """),
[QuestionAnswer(question=Question(content="hello"), answer=Answer(content="world"))],
id="question_answer",
),
pytest.param(
- "- hello\n - world",
- [DelimitedAnswer(content="- hello\n - world", _items=["hello", "world"])],
+ dedent("""\
+
+ - hello
+ - world
+ \
+ """),
+ [DelimitedAnswer(content="- hello\n- world", _items=["hello", "world"])],
id="newline_delimited_answer",
),
pytest.param(
@@ -145,12 +193,22 @@ class MixedModel(Model):
id="slash_delimited_answer",
),
pytest.param(
- 'ab',
+ dedent("""\
+
+ a
+ b
+ \
+ """),
[NameWithThings(name="test", things=["a", "b"])],
id="name_with_things",
),
pytest.param(
- 'meowbark',
+ dedent("""\
+
+ meow
+ bark
+ \
+ """),
[
Wrapped(
inners=[
@@ -195,6 +253,27 @@ class MixedModel(Model):
[ContentTag(content="actual content")],
id="tag_in_text_and_xml",
),
+ pytest.param(
+ MULTI_LINE_CONTENT_TAG,
+ [ContentTag(content=MULTI_LINE_TEXT)],
+ id="indented_multiline_content_tag",
+ ),
+ pytest.param(
+ MULTI_LINE_CONTENT_AS_ELEMENT,
+ [ContentAsElement(content=MULTI_LINE_TEXT)],
+ id="indented_multiline_content_as_element",
+ ),
+ pytest.param(
+ MULTI_LINE_CONTENT_MULTI_FIELD,
+ [
+ MultiFieldModel(
+ type_="cmd",
+ foo_field="Process terminated",
+ bar_field=MULTI_LINE_TEXT,
+ ),
+ ],
+ id="indented_multiline_content_multi_field",
+ ),
pytest.param(
"first text second more text third",
[Answer(content="first"), Answer(content="second"), Answer(content="third")],
@@ -211,7 +290,12 @@ class MixedModel(Model):
id="xml_breaking_chars_with_ampersand",
),
pytest.param(
- 'Process terminatedExit code <1>',
+ dedent("""\
+
+ Process terminated
+ Exit code <1>
+ \
+ """),
[
MultiFieldModel(
type_="cmd",
@@ -222,7 +306,13 @@ class MixedModel(Model):
id="multi_field_with_xml_breaking_chars",
),
pytest.param(
- "Volume in drive C:\n Directory of C:\\\n 01/02/2024 Program Files",
+ dedent("""\
+
+ Volume in drive C:
+ Directory of C:\\
+ 01/02/2024 Program Files
+ \
+ """),
[
Content(
content="Volume in drive C:\n Directory of C:\\\n 01/02/2024 Program Files",
@@ -231,7 +321,14 @@ class MixedModel(Model):
id="shell_output_simulation",
),
pytest.param(
- "Error in at line <42>normal nested content",
+ dedent("""\
+
+ Error in at line <42>
+
+ normal nested content
+
+ \
+ """),
[
MixedModel(
content_field="Error in at line <42>",
@@ -249,7 +346,7 @@ def test_xml_parsing(content: str, models: list[Model]) -> None:
assert obj.model_dump() == expected.model_dump(), (
f"Failed to parse model {expected.__class__.__name__} <- {obj!s} ({parsed})"
)
- xml = obj.to_xml()
+ xml = obj.to_pretty_xml()
assert xml == content[slice_], (
f"Failed to serialize model {expected.__class__.__name__} back to XML: {xml!r} != {content!r}"
)
@@ -258,7 +355,7 @@ def test_xml_parsing(content: str, models: list[Model]) -> None:
@pytest.mark.parametrize(
("content", "models"),
[
- # These cases parse correctly, but their XML representation doesn't yeild
+ # These cases parse correctly, but their XML representation doesn't yield
# the original content as some escape sequences are not preserved.
pytest.param(
"Text with <inner> as escaped HTML entities",