Skip to content

Commit 0cf4776

Browse files
authored
Merge pull request #127 from dreadnode/feat/agents
feat: Agents
2 parents f6645d0 + c09e891 commit 0cf4776

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+5697
-1554
lines changed

agent.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from pathlib import Path
2+
3+
from dreadnode.agent.agent import TaskAgent
4+
from dreadnode.agent.hooks import summarize_when_long
5+
from dreadnode.agent.tools import tool
6+
7+
8+
@tool(truncate=1000, catch=True)
9+
async def read_file(path: str) -> str:
10+
"Read the contents of a file."
11+
return (Path("../") / path).read_text()
12+
13+
14+
agent = TaskAgent(
15+
name="basic",
16+
description="A basic agent that can handle simple tasks.",
17+
model="gpt-4o-mini",
18+
hooks=[summarize_when_long(max_tokens=1000)],
19+
tools=[read_file],
20+
)

docs/sdk/api.mdx

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,10 @@ def __init__(
8282
}
8383

8484
if api_key:
85-
headers["X-Api-Key"] = api_key
85+
headers["X-API-Key"] = api_key
8686

8787
self._client = httpx.Client(
8888
headers=headers,
89-
cookies=_cookies,
9089
base_url=self._base_url,
9190
timeout=30,
9291
)
@@ -156,7 +155,7 @@ def export_metrics(
156155
status: StatusFilter = "completed",
157156
metrics: list[str] | None = None,
158157
aggregations: list[MetricAggregationType] | None = None,
159-
) -> pd.DataFrame:
158+
) -> "pd.DataFrame":
160159
"""
161160
Exports metric data for a specific project.
162161
@@ -170,6 +169,8 @@ def export_metrics(
170169
Returns:
171170
A DataFrame containing the exported metric data.
172171
"""
172+
import pandas as pd
173+
173174
response = self.request(
174175
"GET",
175176
f"/strikes/projects/{project!s}/export/metrics",
@@ -248,7 +249,7 @@ def export_parameters(
248249
parameters: list[str] | None = None,
249250
metrics: list[str] | None = None,
250251
aggregations: list[MetricAggregationType] | None = None,
251-
) -> pd.DataFrame:
252+
) -> "pd.DataFrame":
252253
"""
253254
Exports parameter data for a specific project.
254255
@@ -263,6 +264,8 @@ def export_parameters(
263264
Returns:
264265
A DataFrame containing the exported parameter data.
265266
"""
267+
import pandas as pd
268+
266269
response = self.request(
267270
"GET",
268271
f"/strikes/projects/{project!s}/export/parameters",
@@ -331,7 +334,7 @@ def export_runs(
331334
# format: ExportFormat = "parquet",
332335
status: StatusFilter = "completed",
333336
aggregations: list[MetricAggregationType] | None = None,
334-
) -> pd.DataFrame:
337+
) -> "pd.DataFrame":
335338
"""
336339
Exports run data for a specific project.
337340
@@ -344,6 +347,8 @@ def export_runs(
344347
Returns:
345348
A DataFrame containing the exported run data.
346349
"""
350+
import pandas as pd
351+
347352
response = self.request(
348353
"GET",
349354
f"/strikes/projects/{project!s}/export",
@@ -424,7 +429,7 @@ def export_timeseries(
424429
metrics: list[str] | None = None,
425430
time_axis: TimeAxisType = "relative",
426431
aggregations: list[TimeAggregationType] | None = None,
427-
) -> pd.DataFrame:
432+
) -> "pd.DataFrame":
428433
"""
429434
Exports timeseries data for a specific project.
430435
@@ -439,6 +444,8 @@ def export_timeseries(
439444
Returns:
440445
A DataFrame containing the exported timeseries data.
441446
"""
447+
import pandas as pd
448+
442449
response = self.request(
443450
"GET",
444451
f"/strikes/projects/{project!s}/export/timeseries",

docs/sdk/data_types.mdx

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,7 @@ def __init__(
7272
caption: Optional caption for the audio
7373
format: Optional format to use (default is wav for numpy arrays)
7474
"""
75-
if sf is None:
76-
raise ImportError(
77-
"Audio processing requires optional dependencies. "
78-
"Install with: pip install dreadnode[multimodal]"
79-
)
75+
check_imports()
8076
self._data = data
8177
self._sample_rate = sample_rate
8278
self._caption = caption
@@ -211,10 +207,7 @@ def __init__(
211207
caption: Optional caption for the image
212208
format: Optional format to use when saving (png, jpg, etc.)
213209
"""
214-
if PILImage is None:
215-
raise ImportError(
216-
"Image processing requires PIL (Pillow). Install with: pip install dreadnode[multimodal]"
217-
)
210+
check_imports()
218211
self._data = data
219212
self._mode = mode
220213
self._caption = caption
@@ -650,6 +643,13 @@ def to_serializable(self) -> tuple[bytes, dict[str, t.Any]]:
650643
Returns:
651644
A tuple of (video_bytes, metadata_dict)
652645
"""
646+
import numpy as np # type: ignore[import,unused-ignore]
647+
648+
try:
649+
from moviepy.video.VideoClip import VideoClip # type: ignore[import,unused-ignore]
650+
except ImportError:
651+
VideoClip = None # noqa: N806
652+
653653
if isinstance(self._data, (str, Path)) and Path(self._data).exists():
654654
return self._process_file_path()
655655
if isinstance(self._data, bytes):

docs/sdk/main.mdx

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,8 @@ def initialize(self) -> None:
460460
461461
This method is called automatically when you call `configure()`.
462462
"""
463+
from s3fs import S3FileSystem # type: ignore [import-untyped]
464+
463465
if self._initialized:
464466
return
465467

docs/sdk/metric.mdx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ def apply_mode(self, mode: MetricAggMode, others: "list[Metric]") -> "Metric":
136136
self.value = len(others) + 1
137137
elif mode == "avg" and prior_values:
138138
current_avg = prior_values[-1]
139-
self.value = current_avg + (self.value - current_avg) / (len(prior_values) + 1)
139+
self.value = current_avg + (self.value - current_avg) / (
140+
len(prior_values) + 1
141+
)
140142

141143
return self
142144
```

docs/sdk/scorers.mdx

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,11 +122,20 @@ def zero_shot_classification(
122122
model_name: The name of the zero-shot model from Hugging Face Hub.
123123
name: Name of the scorer.
124124
"""
125-
if not _TRANSFORMERS_AVAILABLE:
126-
warn_at_user_stacklevel(_TRANSFORMERS_ERROR_MSG, UserWarning)
125+
transformers_error_msg = (
126+
"Hugging Face transformers dependency is not installed. "
127+
"Please install with: pip install transformers torch"
128+
)
129+
130+
try:
131+
from transformers import ( # type: ignore [attr-defined,import-not-found,unused-ignore]
132+
pipeline,
133+
)
134+
except ImportError:
135+
warn_at_user_stacklevel(transformers_error_msg, UserWarning)
127136

128137
def disabled_evaluate(_: t.Any) -> Metric:
129-
return Metric(value=0.0, attributes={"error": _TRANSFORMERS_ERROR_MSG})
138+
return Metric(value=0.0, attributes={"error": transformers_error_msg})
130139

131140
return Scorer.from_callable(disabled_evaluate, name=name)
132141

@@ -816,7 +825,7 @@ def detect_harm_with_openai(
816825
*,
817826
api_key: str | None = None,
818827
model: t.Literal["text-moderation-stable", "text-moderation-latest"] = "text-moderation-stable",
819-
client: openai.AsyncOpenAI | None = None,
828+
client: "openai.AsyncOpenAI | None" = None,
820829
name: str = "openai_harm",
821830
) -> "Scorer[t.Any]":
822831
"""
@@ -837,6 +846,7 @@ def detect_harm_with_openai(
837846
model: The moderation model to use.
838847
name: Name of the scorer.
839848
"""
849+
import openai
840850

841851
async def evaluate(data: t.Any) -> Metric:
842852
text = str(data)
@@ -1800,12 +1810,18 @@ def detect_pii_with_presidio(
18001810
invert: Invert the score (1.0 for no PII, 0.0 for PII detected).
18011811
name: Name of the scorer.
18021812
"""
1813+
presidio_import_error_msg = (
1814+
"Presidio dependencies are not installed. "
1815+
"Please install them with: pip install presidio-analyzer presidio-anonymizer 'spacy[en_core_web_lg]'"
1816+
)
18031817

1804-
if not _PRESIDIO_AVAILABLE:
1805-
warn_at_user_stacklevel(_PRESIDIO_ERROR_MSG, UserWarning)
1818+
try:
1819+
import presidio_analyzer # type: ignore[import-not-found,unused-ignore] # noqa: F401
1820+
except ImportError:
1821+
warn_at_user_stacklevel(presidio_import_error_msg, UserWarning)
18061822

18071823
def disabled_evaluate(_: t.Any) -> Metric:
1808-
return Metric(value=0.0, attributes={"error": _PRESIDIO_ERROR_MSG})
1824+
return Metric(value=0.0, attributes={"error": presidio_import_error_msg})
18091825

18101826
return Scorer.from_callable(disabled_evaluate, name=name)
18111827

dreadnode/__init__.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from dreadnode import convert, data_types, scorers
1+
import importlib
2+
import typing as t
3+
4+
from dreadnode import convert, data_types
25
from dreadnode.data_types import Audio, Code, Image, Markdown, Object3D, Table, Text, Video
36
from dreadnode.lookup import Lookup, lookup_input, lookup_output, lookup_param, resolve_lookup
47
from dreadnode.main import DEFAULT_INSTANCE, Dreadnode
@@ -8,6 +11,9 @@
811
from dreadnode.tracing.span import RunSpan, Span, TaskSpan
912
from dreadnode.version import VERSION
1013

14+
if t.TYPE_CHECKING:
15+
from dreadnode import scorers # noqa: F401
16+
1117
configure = DEFAULT_INSTANCE.configure
1218
shutdown = DEFAULT_INSTANCE.shutdown
1319

@@ -77,11 +83,24 @@
7783
"resolve_lookup",
7884
"run",
7985
"scorer",
80-
"scorers",
8186
"shutdown",
8287
"span",
8388
"tag",
8489
"task",
8590
"task_span",
8691
"task_span",
8792
]
93+
94+
__lazy_submodules__ = ["scorers"]
95+
96+
97+
def __getattr__(name: str) -> t.Any:
98+
if name in __lazy_submodules__:
99+
module = importlib.import_module(f".{name}", __name__)
100+
globals()[name] = module
101+
return module
102+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
103+
104+
105+
def __dir__() -> list[str]:
106+
return sorted(list(globals().keys()) + __lazy_submodules__)

dreadnode/agent/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from pydantic.dataclasses import rebuild_dataclass
2+
3+
from dreadnode.agent.agent import Agent
4+
from dreadnode.agent.events import rebuild_event_models
5+
from dreadnode.agent.result import AgentResult
6+
from dreadnode.agent.thread import Thread
7+
8+
Agent.model_rebuild()
9+
Thread.model_rebuild()
10+
11+
rebuild_event_models()
12+
13+
rebuild_dataclass(AgentResult) # type: ignore[arg-type]

0 commit comments

Comments
 (0)