Skip to content

Commit 7e7fc67

Browse files
feat(semconv): expand genai span kind
1 parent fd9626b commit 7e7fc67

File tree

6 files changed

+519
-35
lines changed

6 files changed

+519
-35
lines changed

packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/__init__.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,13 @@ def _instrument(self, **kwargs):
8383
wrapper=_BaseCallbackManagerInitWrapper(traceloopCallbackHandler),
8484
)
8585

86+
# Wrap CallbackManager.configure to ensure our handler is included
87+
wrap_function_wrapper(
88+
module="langchain_core.callbacks.manager",
89+
name="CallbackManager.configure",
90+
wrapper=_CallbackManagerConfigureWrapper(traceloopCallbackHandler),
91+
)
92+
8693
if not self.disable_trace_context_propagation:
8794
self._wrap_openai_functions_for_tracing(traceloopCallbackHandler)
8895

@@ -168,6 +175,7 @@ def _wrap_openai_functions_for_tracing(self, traceloopCallbackHandler):
168175

169176
def _uninstrument(self, **kwargs):
170177
unwrap("langchain_core.callbacks", "BaseCallbackManager.__init__")
178+
unwrap("langchain_core.callbacks.manager", "CallbackManager.configure")
171179
if not self.disable_trace_context_propagation:
172180
if is_package_available("langchain_community"):
173181
unwrap("langchain_community.llms.openai", "BaseOpenAI._generate")
@@ -208,6 +216,30 @@ def __call__(
208216
instance.add_handler(self._callback_handler, True)
209217

210218

219+
class _CallbackManagerConfigureWrapper:
220+
def __init__(self, callback_handler: "TraceloopCallbackHandler"):
221+
self._callback_handler = callback_handler
222+
223+
def __call__(
224+
self,
225+
wrapped,
226+
instance,
227+
args,
228+
kwargs,
229+
):
230+
result = wrapped(*args, **kwargs)
231+
232+
if result and hasattr(result, 'add_handler'):
233+
for handler in result.inheritable_handlers:
234+
if isinstance(handler, type(self._callback_handler)):
235+
break
236+
else:
237+
self._callback_handler._callback_manager = result
238+
result.add_handler(self._callback_handler, True)
239+
240+
return result
241+
242+
211243
# This class wraps a function call to inject tracing information (trace headers) into
212244
# OpenAI client requests. It assumes the following:
213245
# 1. The wrapped function includes a `run_manager` keyword argument that contains a `run_id`.

packages/opentelemetry-instrumentation-langchain/opentelemetry/instrumentation/langchain/callback_handler.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,9 @@ def _create_llm_span(
359359
_set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
360360
_set_span_attribute(span, SpanAttributes.LLM_REQUEST_TYPE, request_type.value)
361361

362+
span_kind = self._determine_llm_span_kind(serialized)
363+
_set_span_attribute(span, SpanAttributes.TRACELOOP_SPAN_KIND, span_kind.value)
364+
362365
# we already have an LLM span by this point,
363366
# so skip any downstream instrumentation from here
364367
try:
@@ -375,6 +378,69 @@ def _create_llm_span(
375378

376379
return span
377380

381+
def _determine_llm_span_kind(self, serialized: Optional[dict[str, Any]]) -> TraceloopSpanKindValues:
382+
"""Determine the appropriate span kind for LLM operations based on model type."""
383+
if not serialized:
384+
return TraceloopSpanKindValues.GENERATION
385+
386+
class_name = _extract_class_name_from_serialized(serialized)
387+
class_name_lower = class_name.lower()
388+
389+
if any(keyword in class_name_lower for keyword in ['embedding', 'embed']):
390+
return TraceloopSpanKindValues.EMBEDDING
391+
392+
# Default to generation for other LLM operations
393+
return TraceloopSpanKindValues.GENERATION
394+
395+
def _determine_chain_span_kind(
396+
self,
397+
serialized: dict[str, Any],
398+
name: str,
399+
tags: Optional[list[str]] = None
400+
) -> TraceloopSpanKindValues:
401+
if serialized and "id" in serialized:
402+
class_path = serialized["id"]
403+
if any("agent" in part.lower() for part in class_path):
404+
return TraceloopSpanKindValues.AGENT
405+
406+
if "agent" in name.lower():
407+
return TraceloopSpanKindValues.AGENT
408+
409+
class_name = _extract_class_name_from_serialized(serialized)
410+
name_lower = name.lower()
411+
412+
# Tool detection for RunnableLambda and custom tool chains
413+
if any(keyword in class_name.lower() for keyword in ['tool']):
414+
return TraceloopSpanKindValues.TOOL
415+
416+
if any(keyword in name_lower for keyword in ['tool', 'function']):
417+
return TraceloopSpanKindValues.TOOL
418+
419+
if tags and any('tool' in tag.lower() for tag in tags):
420+
return TraceloopSpanKindValues.TOOL
421+
422+
# Retriever detection for RunnableLambda and custom tool chains
423+
if any(keyword in class_name.lower() for keyword in ['retriever', 'retrieve', 'vectorstore']):
424+
return TraceloopSpanKindValues.RETRIEVER
425+
426+
if any(keyword in name_lower for keyword in ['retriever', 'retrieve', 'search']):
427+
return TraceloopSpanKindValues.RETRIEVER
428+
429+
# Embedding detection for RunnableLambda and custom chains
430+
if any(keyword in class_name.lower() for keyword in ['embedding', 'embed']):
431+
return TraceloopSpanKindValues.EMBEDDING
432+
433+
if any(keyword in name_lower for keyword in ['embedding', 'embed']):
434+
return TraceloopSpanKindValues.EMBEDDING
435+
436+
if any(keyword in class_name.lower() for keyword in ['rerank', 'reorder']):
437+
return TraceloopSpanKindValues.RERANKER
438+
439+
if any(keyword in name_lower for keyword in ['rerank', 'reorder']):
440+
return TraceloopSpanKindValues.RERANKER
441+
442+
return TraceloopSpanKindValues.TASK
443+
378444
@dont_throw
379445
def on_chain_start(
380446
self,
@@ -395,12 +461,18 @@ def on_chain_start(
395461
entity_path = ""
396462

397463
name = self._get_name_from_callback(serialized, **kwargs)
398-
kind = (
464+
465+
base_kind = (
399466
TraceloopSpanKindValues.WORKFLOW
400467
if parent_run_id is None or parent_run_id not in self.spans
401468
else TraceloopSpanKindValues.TASK
402469
)
403470

471+
if base_kind == TraceloopSpanKindValues.TASK:
472+
kind = self._determine_chain_span_kind(serialized, name, tags)
473+
else:
474+
kind = base_kind
475+
404476
if kind == TraceloopSpanKindValues.WORKFLOW:
405477
workflow_name = name
406478
else:
@@ -710,6 +782,73 @@ def on_tool_end(
710782
)
711783
self._end_span(span, run_id)
712784

785+
@dont_throw
786+
def on_retriever_start(
787+
self,
788+
serialized: dict[str, Any],
789+
query: str,
790+
*,
791+
run_id: UUID,
792+
parent_run_id: Optional[UUID] = None,
793+
tags: Optional[list[str]] = None,
794+
metadata: Optional[dict[str, Any]] = None,
795+
**kwargs: Any,
796+
) -> None:
797+
"""Run when retriever starts running."""
798+
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
799+
return
800+
801+
name = self._get_name_from_callback(serialized, kwargs=kwargs)
802+
workflow_name = self.get_workflow_name(parent_run_id)
803+
entity_path = self.get_entity_path(parent_run_id)
804+
805+
span = self._create_task_span(
806+
run_id,
807+
parent_run_id,
808+
name,
809+
TraceloopSpanKindValues.RETRIEVER,
810+
workflow_name,
811+
name,
812+
entity_path,
813+
)
814+
if not should_emit_events() and should_send_prompts():
815+
span.set_attribute(
816+
SpanAttributes.TRACELOOP_ENTITY_INPUT,
817+
json.dumps(
818+
{
819+
"query": query,
820+
"tags": tags,
821+
"metadata": metadata,
822+
"kwargs": kwargs,
823+
},
824+
cls=CallbackFilteredJSONEncoder,
825+
),
826+
)
827+
828+
@dont_throw
829+
def on_retriever_end(
830+
self,
831+
documents: Any,
832+
*,
833+
run_id: UUID,
834+
parent_run_id: Optional[UUID] = None,
835+
**kwargs: Any,
836+
) -> None:
837+
"""Run when retriever ends running."""
838+
if context_api.get_value(_SUPPRESS_INSTRUMENTATION_KEY):
839+
return
840+
841+
span = self._get_span(run_id)
842+
if not should_emit_events() and should_send_prompts():
843+
span.set_attribute(
844+
SpanAttributes.TRACELOOP_ENTITY_OUTPUT,
845+
json.dumps(
846+
{"documents": str(documents)[:1000], "kwargs": kwargs}, # Limit output size
847+
cls=CallbackFilteredJSONEncoder,
848+
),
849+
)
850+
self._end_span(span, run_id)
851+
713852
def get_parent_span(self, parent_run_id: Optional[str] = None):
714853
if parent_run_id is None:
715854
return None

0 commit comments

Comments
 (0)