Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ def _instrument(self, **kwargs):
wrapper=_BaseCallbackManagerInitWrapper(traceloopCallbackHandler),
)

# Wrap CallbackManager.configure to ensure our handler is included
wrap_function_wrapper(
module="langchain_core.callbacks.manager",
name="CallbackManager.configure",
wrapper=_CallbackManagerConfigureWrapper(traceloopCallbackHandler),
)

if not self.disable_trace_context_propagation:
self._wrap_openai_functions_for_tracing(traceloopCallbackHandler)

Expand Down Expand Up @@ -168,6 +175,7 @@ def _wrap_openai_functions_for_tracing(self, traceloopCallbackHandler):

def _uninstrument(self, **kwargs):
unwrap("langchain_core.callbacks", "BaseCallbackManager.__init__")
unwrap("langchain_core.callbacks.manager", "CallbackManager.configure")
if not self.disable_trace_context_propagation:
if is_package_available("langchain_community"):
unwrap("langchain_community.llms.openai", "BaseOpenAI._generate")
Expand Down Expand Up @@ -208,6 +216,30 @@ def __call__(
instance.add_handler(self._callback_handler, True)


class _CallbackManagerConfigureWrapper:
def __init__(self, callback_handler: "TraceloopCallbackHandler"):
self._callback_handler = callback_handler

def __call__(
self,
wrapped,
instance,
args,
kwargs,
):
result = wrapped(*args, **kwargs)

if result and hasattr(result, 'add_handler'):
for handler in result.inheritable_handlers:
if isinstance(handler, type(self._callback_handler)):
break
else:
self._callback_handler._callback_manager = result
result.add_handler(self._callback_handler, True)

return result


# This class wraps a function call to inject tracing information (trace headers) into
# OpenAI client requests. It assumes the following:
# 1. The wrapped function includes a `run_manager` keyword argument that contains a `run_id`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ def _create_llm_span(
_set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, request_type.value)

span_kind = self._determine_llm_span_kind(serialized)
_set_span_attribute(span, SpanAttributes.TRACELOOP_SPAN_KIND, span_kind.value)

# we already have an LLM span by this point,
# so skip any downstream instrumentation from here
try:
Expand All @@ -375,6 +378,72 @@ def _create_llm_span(

return span

def _determine_llm_span_kind(self, serialized: Optional[dict[str, Any]]) -> TraceloopSpanKindValues:
    """Choose the span kind for an LLM run from its serialized class name.

    Returns ``EMBEDDING`` when the class name extracted from ``serialized``
    contains an embedding keyword (compared case-insensitively). In every
    other case, including a missing or empty ``serialized``, returns
    ``GENERATION``.
    """
    embedding_markers = ("embedding", "embed")

    if serialized:
        lowered_class = _extract_class_name_from_serialized(serialized).lower()
        for marker in embedding_markers:
            if marker in lowered_class:
                return TraceloopSpanKindValues.EMBEDDING

    # Anything not recognised as an embedding model counts as generation.
    return TraceloopSpanKindValues.GENERATION

def _determine_chain_span_kind(
    self,
    serialized: dict[str, Any],
    name: str,
    tags: Optional[list[str]] = None
) -> TraceloopSpanKindValues:
    """Classify a chain run by keyword heuristics on its class, name and tags.

    All comparisons are case-insensitive substring checks. The first matching
    category wins, in this order: AGENT, TOOL, RETRIEVER, EMBEDDING, RERANKER.
    ``TASK`` is returned when nothing matches.

    Args:
        serialized: Serialized description of the runnable; its ``"id"``
            entry (when present) is scanned part by part for "agent".
        name: Run name used alongside the class name for keyword matching.
        tags: Optional run tags; any tag containing "tool" marks a tool.
    """

    def mentions(text, keywords):
        # True when at least one keyword occurs somewhere in ``text``.
        return any(keyword in text for keyword in keywords)

    # Agent detection: any component of the class path, then the run name.
    if serialized and "id" in serialized:
        for part in serialized["id"]:
            if "agent" in part.lower():
                return TraceloopSpanKindValues.AGENT

    lowered_name = name.lower()
    if "agent" in lowered_name:
        return TraceloopSpanKindValues.AGENT

    lowered_class = _extract_class_name_from_serialized(serialized).lower()

    # Tool detection for RunnableLambda and custom tool chains. A name that
    # mentions "function" also counts, unless it is a parser (for example an
    # output-functions parser), which is not a tool invocation.
    looks_like_tool = (
        "tool" in lowered_class
        or "tool" in lowered_name
        or ("function" in lowered_name and "parser" not in lowered_name)
        or (tags and any("tool" in tag.lower() for tag in tags))
    )
    if looks_like_tool:
        return TraceloopSpanKindValues.TOOL

    # Retriever detection: class keywords and name keywords differ slightly
    # ("vectorstore" only for classes, "search" only for names).
    if mentions(lowered_class, ("retriever", "retrieve", "vectorstore")) or mentions(
        lowered_name, ("retriever", "retrieve", "search")
    ):
        return TraceloopSpanKindValues.RETRIEVER

    # Embedding detection for RunnableLambda and custom chains.
    embedding_keywords = ("embedding", "embed")
    if mentions(lowered_class, embedding_keywords) or mentions(lowered_name, embedding_keywords):
        return TraceloopSpanKindValues.EMBEDDING

    # Reranker detection.
    rerank_keywords = ("rerank", "reorder")
    if mentions(lowered_class, rerank_keywords) or mentions(lowered_name, rerank_keywords):
        return TraceloopSpanKindValues.RERANKER

    return TraceloopSpanKindValues.TASK

@dont_throw
def on_chain_start(
self,
Expand All @@ -395,12 +464,18 @@ def on_chain_start(
entity_path = ""

name = self._get_name_from_callback(serialized, **kwargs)
kind = (

base_kind = (
TraceloopSpanKindValues.WORKFLOW
if parent_run_id is None or parent_run_id not in self.spans
else TraceloopSpanKindValues.TASK
)

if base_kind == TraceloopSpanKindValues.TASK:
kind = self._determine_chain_span_kind(serialized, name, tags)
else:
kind = base_kind

if kind == TraceloopSpanKindValues.WORKFLOW:
workflow_name = name
else:
Expand Down Expand Up @@ -710,6 +785,73 @@ def on_tool_end(
)
self._end_span(span, run_id)

@dont_throw
def on_retriever_start(
    self,
    serialized: dict[str, Any],
    query: str,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    tags: Optional[list[str]] = None,
    metadata: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> None:
    """Run when retriever starts running.

    Opens a span of kind ``RETRIEVER`` for this run and, when prompt
    capture is enabled and events are not being emitted, records the query
    together with the tags, metadata and remaining callback kwargs as the
    span's JSON-encoded entity input.

    Args:
        serialized: Serialized description of the retriever.
        query: The query string passed to the retriever.
        run_id: Identifier of this retriever run.
        parent_run_id: Identifier of the parent run, if any.
        tags: Optional tags attached to the run.
        metadata: Optional metadata attached to the run.
        **kwargs: Any further callback arguments.
    """
    if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
        return

    # Unpack the callback kwargs, matching how on_chain_start calls this
    # helper. Passing ``kwargs=kwargs`` would nest them under a single
    # "kwargs" key, hiding entries such as a run name from the lookup.
    name = self._get_name_from_callback(serialized, **kwargs)
    workflow_name = self.get_workflow_name(parent_run_id)
    entity_path = self.get_entity_path(parent_run_id)

    span = self._create_task_span(
        run_id,
        parent_run_id,
        name,
        TraceloopSpanKindValues.RETRIEVER,
        workflow_name,
        name,
        entity_path,
    )

    # Inputs go on the span only when prompts may be sent and the
    # event-based path is not in use.
    if not should_emit_events() and should_send_prompts():
        span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_INPUT,
            json.dumps(
                {
                    "query": query,
                    "tags": tags,
                    "metadata": metadata,
                    "kwargs": kwargs,
                },
                cls=CallbackFilteredJSONEncoder,
            ),
        )

@dont_throw
def on_retriever_end(
    self,
    documents: Any,
    *,
    run_id: UUID,
    parent_run_id: Optional[UUID] = None,
    **kwargs: Any,
) -> None:
    """Run when retriever ends running.

    Looks up the span for ``run_id``, optionally records the retrieved
    documents (as a string truncated to 1000 characters) and the callback
    kwargs as the span's JSON-encoded entity output, then ends the span.
    """
    suppressed = context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY)
    if suppressed:
        return

    span = self._get_span(run_id)

    # Outputs go on the span only when prompts may be sent and the
    # event-based path is not in use.
    record_output = not should_emit_events() and should_send_prompts()
    if record_output:
        # Limit output size: keep only the first 1000 characters.
        payload = {"documents": str(documents)[:1000], "kwargs": kwargs}
        span.set_attribute(
            SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
            json.dumps(payload, cls=CallbackFilteredJSONEncoder),
        )

    self._end_span(span, run_id)

def get_parent_span(self, parent_run_id: Optional[str] = None):
if parent_run_id is None:
return None
Expand Down
Loading