diff --git a/ddtrace/internal/ci_visibility/encoder.py b/ddtrace/internal/ci_visibility/encoder.py index 8b8acec2eef..e8fc8cceec3 100644 --- a/ddtrace/internal/ci_visibility/encoder.py +++ b/ddtrace/internal/ci_visibility/encoder.py @@ -1,7 +1,14 @@ +from __future__ import annotations + import json import os import threading from typing import TYPE_CHECKING # noqa:F401 +from typing import Any # noqa:F401 +from typing import Dict # noqa:F401 +from typing import List # noqa:F401 +from typing import Optional # noqa:F401 +from typing import Tuple # noqa:F401 from uuid import uuid4 from ddtrace.ext import SpanTypes @@ -28,12 +35,6 @@ log = get_logger(__name__) if TYPE_CHECKING: # pragma: no cover - from typing import Any # noqa:F401 - from typing import Dict # noqa:F401 - from typing import List # noqa:F401 - from typing import Optional # noqa:F401 - from typing import Tuple # noqa:F401 - from ddtrace._trace.span import Span # noqa:F401 @@ -43,79 +44,153 @@ class CIVisibilityEncoderV01(BufferedEncoder): TEST_SUITE_EVENT_VERSION = 1 TEST_EVENT_VERSION = 2 ENDPOINT_TYPE = ENDPOINT.TEST_CYCLE + _MAX_PAYLOAD_SIZE = 5 * 1024 * 1024 # 5MB def __init__(self, *args): # DEV: args are not used here, but are used by BufferedEncoder's __cinit__() method, # which is called implicitly by Cython. super(CIVisibilityEncoderV01, self).__init__() + self._metadata: Dict[str, Dict[str, str]] = {} self._lock = threading.RLock() - self._metadata = {} + self._is_xdist_worker = os.getenv("PYTEST_XDIST_WORKER") is not None self._init_buffer() def __len__(self): with self._lock: return len(self.buffer) - def set_metadata(self, event_type, metadata): - # type: (str, Dict[str, str]) -> None + def set_metadata(self, event_type: str, metadata: Dict[str, str]): self._metadata.setdefault(event_type, {}).update(metadata) def _init_buffer(self): with self._lock: self.buffer = [] - def put(self, spans): + def put(self, item): with self._lock: - self.buffer.append(spans) + self.buffer.append(item) def encode_traces(self, traces): - return self._build_payload(traces=traces) + """ + Only used for LogWriter, not called for CI Visibility currently + """ + raise NotImplementedError() - def encode(self): + def encode(self) -> List[Tuple[Optional[bytes], int]]: with self._lock: + if not self.buffer: + return [] + payloads = [] with StopWatch() as sw: - result_payloads = self._build_payload(self.buffer) + payloads = self._build_payload(self.buffer) record_endpoint_payload_events_serialization_time(endpoint=self.ENDPOINT_TYPE, seconds=sw.elapsed()) self._init_buffer() - return result_payloads + return payloads - def _get_parent_session(self, traces): + def _get_parent_session(self, traces: List[List[Span]]) -> int: for trace in traces: for span in trace: if span.get_tag(EVENT_TYPE) == SESSION_TYPE and span.parent_id is not None and span.parent_id != 0: return span.parent_id return 0 - def _build_payload(self, traces): - # type: (List[List[Span]]) -> List[Tuple[Optional[bytes], int]] + def _build_payload(self, traces: List[List[Span]]) -> List[Tuple[Optional[bytes], int]]: + """ + Build multiple payloads from traces, splitting when necessary to stay under size limits. + Uses index-based recursive approach to avoid copying slices. + + Returns a list of (payload_bytes, trace_count) tuples, where each payload contains + as many traces as possible without exceeding _MAX_PAYLOAD_SIZE. + """ + if not traces: + return [] + new_parent_session_span_id = self._get_parent_session(traces) - is_not_xdist_worker = os.getenv("PYTEST_XDIST_WORKER") is None - normalized_spans = [ - self._convert_span(span, trace[0].context.dd_origin, new_parent_session_span_id) - for trace in traces - for span in trace - if (is_not_xdist_worker or span.get_tag(EVENT_TYPE) != SESSION_TYPE) - ] - if not normalized_spans: + return self._build_payloads_recursive(traces, 0, len(traces), new_parent_session_span_id) + + def _build_payloads_recursive( + self, traces: List[List[Span]], start_idx: int, end_idx: int, new_parent_session_span_id: int + ) -> List[Tuple[Optional[bytes], int]]: + """ + Recursively build payloads using start/end indexes to avoid slice copying. + + Args: + traces: Full list of traces + start_idx: Start index (inclusive) + end_idx: End index (exclusive) + new_parent_session_span_id: Parent session span ID + + Returns: + List of (payload_bytes, trace_count) tuples + """ + if start_idx >= end_idx: return [] - record_endpoint_payload_events_count(endpoint=ENDPOINT.TEST_CYCLE, count=len(normalized_spans)) - # TODO: Split the events in several payloads as needed to avoid hitting the intake's maximum payload size. - return [ - ( - CIVisibilityEncoderV01._pack_payload( - {"version": self.PAYLOAD_FORMAT_VERSION, "metadata": self._metadata, "events": normalized_spans} - ), - len(traces), - ) - ] + trace_count = end_idx - start_idx + + # Convert traces to spans with filtering (using indexes) + all_spans_with_trace_info = self._convert_traces_to_spans_indexed( + traces, start_idx, end_idx, new_parent_session_span_id + ) + + # Get all spans (flattened) + all_spans = [span for _, trace_spans in all_spans_with_trace_info for span in trace_spans] + + if not all_spans: + log.debug("No spans to encode after filtering, skipping chunk") + return [] + + # Try to create payload from all spans + payload = self._create_payload_from_spans(all_spans) + + if len(payload) <= self._MAX_PAYLOAD_SIZE or trace_count == 1: + # Payload fits or we can't split further (single trace) + record_endpoint_payload_events_count(endpoint=self.ENDPOINT_TYPE, count=len(all_spans)) + return [(payload, trace_count)] + else: + # Payload is too large, split in half recursively + mid_idx = start_idx + (trace_count + 1) // 2 + + # Process both halves recursively + left_payloads = self._build_payloads_recursive(traces, start_idx, mid_idx, new_parent_session_span_id) + right_payloads = self._build_payloads_recursive(traces, mid_idx, end_idx, new_parent_session_span_id) + + # Combine results + return left_payloads + right_payloads + + def _convert_traces_to_spans_indexed( + self, traces: List[List[Span]], start_idx: int, end_idx: int, new_parent_session_span_id: int + ) -> List[Tuple[int, List[Dict[str, Any]]]]: + """Convert traces to spans with xdist filtering applied, using indexes to avoid slicing.""" + all_spans_with_trace_info = [] + for trace_idx in range(start_idx, end_idx): + trace = traces[trace_idx] + trace_spans = [ + self._convert_span(span, trace[0].context.dd_origin, new_parent_session_span_id) + for span in trace + if (not self._is_xdist_worker) or (span.get_tag(EVENT_TYPE) != SESSION_TYPE) + ] + all_spans_with_trace_info.append((trace_idx, trace_spans)) + + return all_spans_with_trace_info + + def _create_payload_from_spans(self, spans: List[Dict[str, Any]]) -> bytes: + """Create a payload from the given spans.""" + return CIVisibilityEncoderV01._pack_payload( + { + "version": self.PAYLOAD_FORMAT_VERSION, + "metadata": self._metadata, + "events": spans, + } + ) @staticmethod def _pack_payload(payload): return msgpack_packb(payload) - def _convert_span(self, span, dd_origin, new_parent_session_span_id=0): - # type: (Span, Optional[str], Optional[int]) -> Dict[str, Any] + def _convert_span( + self, span: Span, dd_origin: Optional[str] = None, new_parent_session_span_id: int = 0 + ) -> Dict[str, Any]: sp = JSONEncoderV2._span_to_dict(span) sp = JSONEncoderV2._normalize_span(sp) sp["type"] = span.get_tag(EVENT_TYPE) or span.span_type @@ -183,18 +258,17 @@ class CIVisibilityCoverageEncoderV02(CIVisibilityEncoderV01): def _set_itr_suite_skipping_mode(self, new_value): self.itr_suite_skipping_mode = new_value - def put(self, spans): + def put(self, item): spans_with_coverage = [ span - for span in spans + for span in item if COVERAGE_TAG_NAME in span.get_tags() or span.get_struct_tag(COVERAGE_TAG_NAME) is not None ] if not spans_with_coverage: raise NoEncodableSpansError() return super(CIVisibilityCoverageEncoderV02, self).put(spans_with_coverage) - def _build_coverage_attachment(self, data): - # type: (bytes) -> List[bytes] + def _build_coverage_attachment(self, data: bytes) -> List[bytes]: return [ b"--%s" % self.boundary.encode("utf-8"), b'Content-Disposition: form-data; name="coverage1"; filename="coverage1.msgpack"', @@ -203,8 +277,7 @@ def _build_coverage_attachment(self, data): data, ] - def _build_event_json_attachment(self): - # type: () -> List[bytes] + def _build_event_json_attachment(self) -> List[bytes]: return [ b"--%s" % self.boundary.encode("utf-8"), b'Content-Disposition: form-data; name="event"; filename="event.json"', @@ -213,18 +286,16 @@ def _build_event_json_attachment(self): b'{"dummy":true}', ] - def _build_body(self, data): - # type: (bytes) -> List[bytes] + def _build_body(self, data: bytes) -> List[bytes]: return ( self._build_coverage_attachment(data) + self._build_event_json_attachment() + [b"--%s--" % self.boundary.encode("utf-8")] ) - def _build_data(self, traces): - # type: (List[List[Span]]) -> Optional[bytes] + def _build_data(self, traces: List[List[Span]]) -> Optional[bytes]: normalized_covs = [ - self._convert_span(span, "") + self._convert_span(span) for trace in traces for span in trace if (COVERAGE_TAG_NAME in span.get_tags() or span.get_struct_tag(COVERAGE_TAG_NAME) is not None) @@ -235,17 +306,17 @@ def _build_data(self, traces): # TODO: Split the events in several payloads as needed to avoid hitting the intake's maximum payload size. return msgpack_packb({"version": self.PAYLOAD_FORMAT_VERSION, "coverages": normalized_covs}) - def _build_payload(self, traces): - # type: (List[List[Span]]) -> List[Tuple[Optional[bytes], int]] + def _build_payload(self, traces: List[List[Span]]) -> List[Tuple[Optional[bytes], int]]: data = self._build_data(traces) if not data: return [] - return [(b"\r\n".join(self._build_body(data)), len(traces))] + return [(b"\r\n".join(self._build_body(data)), len(data))] - def _convert_span(self, span, dd_origin, new_parent_session_span_id=0): - # type: (Span, Optional[str], Optional[int]) -> Dict[str, Any] + def _convert_span( + self, span: Span, dd_origin: Optional[str] = None, new_parent_session_span_id: int = 0 + ) -> Dict[str, Any]: # DEV: new_parent_session_span_id is unused here, but it is used in super class - files: Dict[str, Any] = {} + files: dict[str, Any] = {} files_struct_tag_value = span.get_struct_tag(COVERAGE_TAG_NAME) if files_struct_tag_value is not None and "files" in files_struct_tag_value: diff --git a/tests/ci_visibility/test_encoder.py b/tests/ci_visibility/test_encoder.py index 4624061ba29..2a90122308d 100644 --- a/tests/ci_visibility/test_encoder.py +++ b/tests/ci_visibility/test_encoder.py @@ -1,6 +1,7 @@ import json import os +import mock import msgpack import pytest @@ -120,7 +121,197 @@ def test_encode_traces_civisibility_v01_empty_traces(): for trace in traces: encoder.put(trace) encoded_traces = encoder.encode() - assert encoded_traces == [], "Expected empty list when no content" + assert encoded_traces == [], "Expected empty list when payload is None" + + +def test_build_payload_empty_traces(): + """Test _build_payload with empty traces list.""" + encoder = CIVisibilityEncoderV01(0, 0) + payloads = encoder._build_payload([]) + assert payloads == [], "Expected empty list when payload is None" + + +def test_build_payload_single_trace(): + """Test _build_payload with a single trace.""" + trace = [Span(name="test", span_id=0x123456, service="test_service")] + encoder = CIVisibilityEncoderV01(0, 0) + payloads = encoder._build_payload([trace]) + + assert len(payloads) == 1 + payload, count = payloads[0] + assert count == 1 + assert payload is not None + assert isinstance(payload, bytes) + + +def test_build_payload_multiple_small_traces(): + """Test _build_payload with multiple traces that fit in one payload.""" + traces = [ + [Span(name="test1", span_id=0x111111, service="test")], + [Span(name="test2", span_id=0x222222, service="test")], + [Span(name="test3", span_id=0x333333, service="test")], + ] + encoder = CIVisibilityEncoderV01(0, 0) + payloads = encoder._build_payload(traces) + + assert len(payloads) == 1 + payload, count = payloads[0] + assert count == 3 + assert payload is not None + + +def test_build_payload_large_trace_splitting(): + """Test _build_payload with traces that exceed max payload size.""" + # Create large traces that will exceed the 5MB limit + large_traces = [] + for i in range(100): # Create many traces + trace = [] + for j in range(50): # Each trace has many spans + span = Span(name=f"large_test_{i}_{j}", span_id=0x100000 + i * 100 + j, service="test") + # Add large metadata to increase payload size + span.set_tag_str("large_data", "x" * 1000) # 1KB per span + trace.append(span) + large_traces.append(trace) + + encoder = CIVisibilityEncoderV01(0, 0) + # Use monkeypatch to temporarily reduce max payload size for testing + with mock.patch.object(encoder, "_MAX_PAYLOAD_SIZE", 50 * 1024): # 50KB to force splitting + payloads = encoder._build_payload(large_traces) + + # Should have multiple payloads + assert len(payloads) > 1 + + # All payloads should be under the size limit (except single traces that can't be split) + total_traces_processed = 0 + for payload, count in payloads: + assert count > 0 + if count > 1 and payload is not None: # Multi-trace payloads should be under limit + assert len(payload) <= 50 * 1024 + total_traces_processed += count + + # All traces should be processed + assert total_traces_processed == len(large_traces) + + +def test_build_payload_recursive_splitting(): + """Test that recursive splitting works correctly and terminates.""" + # Create traces that will require multiple levels of splitting + traces = [] + for i in range(16): # 16 traces + trace = [] + for j in range(10): # Each with 10 spans + span = Span(name=f"test_{i}_{j}", span_id=0x200000 + i * 100 + j, service="test") + span.set_tag_str("data", "x" * 500) # Make each span moderately large + trace.append(span) + traces.append(trace) + + encoder = CIVisibilityEncoderV01(0, 0) + # Set a small payload size to force multiple splits + with mock.patch.object(encoder, "_MAX_PAYLOAD_SIZE", 10 * 1024): # 10KB + payloads = encoder._build_payload(traces) + + # Should have multiple payloads due to splitting + assert len(payloads) > 1 + + # Verify all traces are processed + total_traces = sum(count for _, count in payloads) + assert total_traces == len(traces) + + # Verify no infinite recursion (should complete in reasonable time) + # If we get here, recursion terminated properly + + +def test_build_payload_with_filtered_spans(): + """Test _build_payload with spans that get filtered out.""" + traces = [ + [ + Span(name="session", span_id=0x111111, service="test"), + Span(name="regular", span_id=0x222222, service="test"), + ], + [ + Span(name="test", span_id=0x333333, service="test", span_type="test"), + ], + ] + + # Set up xdist worker environment to trigger filtering + original_env = os.environ.get("PYTEST_XDIST_WORKER") + os.environ["PYTEST_XDIST_WORKER"] = "gw0" + + try: + # Add session type tag to trigger filtering + traces[0][0].set_tag_str(EVENT_TYPE, SESSION_TYPE) + + encoder = CIVisibilityEncoderV01(0, 0) + payloads = encoder._build_payload(traces) + + assert len(payloads) == 1 + payload, count = payloads[0] + assert count == 2 # Both traces processed + assert payload is not None + + # Decode and verify that session spans were filtered out + decoded = msgpack.unpackb(payload, raw=True, strict_map_key=False) + events = decoded[b"events"] + # Should have 2 events (1 regular span + 1 test span), session span filtered + assert len(events) == 2 + + finally: + if original_env is None: + os.environ.pop("PYTEST_XDIST_WORKER", None) + else: + os.environ["PYTEST_XDIST_WORKER"] = original_env + + +def test_build_payload_all_spans_filtered(): + """Test _build_payload when all spans get filtered out.""" + traces = [ + [Span(name="session1", span_id=0x111111, service="test")], + [Span(name="session2", span_id=0x222222, service="test")], + ] + + # Set up xdist worker environment to trigger filtering + original_env = os.environ.get("PYTEST_XDIST_WORKER") + os.environ["PYTEST_XDIST_WORKER"] = "gw0" + + try: + # Make both spans session types to trigger filtering + for trace in traces: + trace[0].set_tag_str(EVENT_TYPE, SESSION_TYPE) + + encoder = CIVisibilityEncoderV01(0, 0) + payloads = encoder._build_payload(traces) + + # Should return empty list when no spans remain after filtering + assert payloads == [] + + finally: + if original_env is None: + os.environ.pop("PYTEST_XDIST_WORKER", None) + else: + os.environ["PYTEST_XDIST_WORKER"] = original_env + + +def test_build_payload_no_infinite_recursion(): + """Test that recursion always terminates, even with edge cases.""" + # Single large trace that can't be split further + large_trace = [] + for i in range(100): + span = Span(name=f"large_span_{i}", span_id=0x400000 + i, service="test") + span.set_tag_str("large_data", "x" * 1000) + large_trace.append(span) + + encoder = CIVisibilityEncoderV01(0, 0) + # Set very small payload size + with mock.patch.object(encoder, "_MAX_PAYLOAD_SIZE", 1024): # 1KB - much smaller than the trace + # This should not hang due to infinite recursion + payloads = encoder._build_payload([large_trace]) + + # Should return exactly one payload (can't split single trace) + assert len(payloads) == 1 + payload, count = payloads[0] + assert count == 1 + assert payload is not None + # Payload can exceed max size when it's a single unsplittable trace def test_encode_traces_civisibility_v2_coverage_per_test():