diff --git a/sqlalchemy_opentracing/__init__.py b/sqlalchemy_opentracing/__init__.py index cb77c1a..ad42b6f 100644 --- a/sqlalchemy_opentracing/__init__.py +++ b/sqlalchemy_opentracing/__init__.py @@ -5,14 +5,19 @@ g_tracer = None g_trace_all_queries = False g_trace_all_engines = False +g_flask_tracer = False +g_origin_tracer = None -def init_tracing(tracer, trace_all_engines=True, trace_all_queries=True): + +def init_tracing(tracer, trace_all_engines=True, trace_all_queries=True, flask_tracer=False): ''' Set our global tracer. Tracer objects from our pyramid/flask/django libraries can be passed as well. ''' - global g_tracer, g_trace_all_engines, g_trace_all_queries + global g_tracer, g_trace_all_engines, g_trace_all_queries, g_flask_tracer, g_origin_tracer + + g_origin_tracer = tracer if hasattr(tracer, '_tracer'): tracer = tracer._tracer @@ -20,10 +25,12 @@ def init_tracing(tracer, trace_all_engines=True, trace_all_queries=True): g_tracer = tracer g_trace_all_queries = trace_all_queries g_trace_all_engines = trace_all_engines + g_flask_tracer = flask_tracer if trace_all_engines: register_engine(Engine) + def get_traced(obj): ''' Gets a bool indicating whether or not this @@ -31,6 +38,7 @@ def get_traced(obj): ''' return getattr(obj, '_traced', False) + def set_traced(obj): ''' Mark a statement/session to be traced. @@ -46,6 +54,7 @@ def set_traced(obj): # after commit/rollback. _register_connection_events(obj) + def clear_traced(obj): ''' Clear an object's decorated tracing fields, @@ -56,11 +65,17 @@ def clear_traced(obj): if hasattr(obj, '_traced'): del obj._traced + def get_parent_span(obj): ''' Gets a parent span for this object, if any. ''' - return getattr(obj, '_parent_span', None) + parent_span = getattr(obj, '_parent_span', None) + # use flask_tracer current span as default parent span + if parent_span is None and g_flask_tracer: + parent_span = g_origin_tracer.get_span() + return parent_span + def set_parent_span(obj, parent_span): ''' @@ -70,6 +85,7 @@ def set_parent_span(obj, parent_span): obj._parent_span = parent_span set_traced(obj) + def has_parent_span(obj): ''' Get whether or not the statement has @@ -77,6 +93,7 @@ def has_parent_span(obj): ''' return hasattr(obj, '_parent_span') + def register_engine(obj): ''' Register an engine to have its events be traced. @@ -90,6 +107,7 @@ def register_engine(obj): listen(obj, 'after_cursor_execute', _engine_after_cursor_handler) listen(obj, 'handle_error', _engine_error_handler) + def unregister_engine(obj): ''' Remove an engine from having its events being traced. @@ -98,6 +116,7 @@ def unregister_engine(obj): remove(obj, 'after_cursor_execute', _engine_after_cursor_handler) remove(obj, 'handle_error', _engine_error_handler) + def _clear_tracer(): ''' Set the tracer to None. For test cases usage. @@ -105,6 +124,7 @@ def _clear_tracer(): global g_tracer g_tracer = None + def _can_operation_be_traced(conn, stmt_obj): ''' Get whether an operation can be traced, depending on its @@ -173,6 +193,7 @@ def _engine_before_cursor_handler(conn, cursor, def _engine_after_cursor_handler(conn, cursor, statement, parameters, context, executemany): + span = getattr(context, '_span', None) if span is None: return @@ -243,3 +264,4 @@ def _session_after_begin_handler(session, transaction, conn): def _session_cleanup_handler(session): clear_traced(session) +