Skip to content

Commit 2e7a25b

Browse files
zoyahavtf-transform-team
authored andcommitted
Avoid erroneous warnings regarding table initialization inside tf.init_scope when calling InitializableGraphAnalyzer from private methods (such as get_dependent_inputs).
get_dependent_inputs is mostly called with a partial graph where assets are of type tf.Tensor even though it's running using the native TF2 environment. This change also updates the expected type of the parameter `InitializableGraphAnalyzer` to `tft.apply_vocabulary`. PiperOrigin-RevId: 387127346
1 parent ad64484 commit 2e7a25b

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

tensorflow_transform/graph_tools.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import collections
2727
import itertools
2828
import uuid
29+
from absl import logging
2930

3031
import tensorflow as tf
3132
from tensorflow_transform import analyzer_nodes
@@ -569,13 +570,12 @@ def __init__(self,
569570
continue
570571

571572
if isinstance(graph, tf_func_graph.FuncGraph):
572-
tf.compat.v1.logging.warning('Tables initialized inside a tf.function '
573-
'will be re-initialized on every '
574-
'invocation of the function. This '
575-
're-initialization can have significant '
576-
'impact on performance. Consider lifting '
577-
'them out of the graph context using '
578-
'`tf.init_scope`.')
573+
self._log_warning('Tables initialized inside a tf.function will be'
574+
' re-initialized on every invocation of the function.'
575+
' This re-initialization can have significant impact'
576+
' on performance. Consider lifting them out of the'
577+
' graph context using `tf.init_scope`.: {}'.format(
578+
table_init_op_or_tensor.name))
579579

580580
table_init_op, table_input_ops = (
581581
self._get_table_init_op_and_inputs(table_init_op_or_tensor))
@@ -593,6 +593,9 @@ def __init__(self,
593593
self._graph_analyzer = _GraphAnalyzer(complete_source_info_dict,
594594
translate_path_fn, graph)
595595

596+
def _log_warning(self, message: str):
597+
logging.warning(message)
598+
596599
def _get_table_init_op_and_inputs(self, table_init_op_or_tensor):
597600
"""Get a tuple of table init op and keys for its input ops."""
598601
# If a TF2 exported SavedModel with a table is loaded inside the
@@ -759,6 +762,13 @@ def get_dependent_inputs(self, tensor_or_op):
759762
return result
760763

761764

765+
class _QuietInitializableGraphAnalyzer(InitializableGraphAnalyzer):
766+
"""A `InitializableGraphAnalyzer` which doesn't log any warnings."""
767+
768+
def _log_warning(self, message: str):
769+
pass
770+
771+
762772
def get_dependent_inputs(graph, input_tensors, output_tensors):
763773
"""Returns tensors in input_tensors that (transitively) produce output_tensors.
764774
@@ -790,8 +800,8 @@ def get_dependent_inputs(graph, input_tensors, output_tensors):
790800
# tensors doesn't affect the correctness of dependencies tracing.
791801
tensor_sinks = graph.get_collection(analyzer_nodes.TENSOR_REPLACEMENTS)
792802
sink_tensors_ready = [(sink.tensor, False) for sink in tensor_sinks]
793-
graph_analyzer = InitializableGraphAnalyzer(graph, input_tensors,
794-
sink_tensors_ready)
803+
graph_analyzer = _QuietInitializableGraphAnalyzer(graph, input_tensors,
804+
sink_tensors_ready)
795805
dependent_inputs = {}
796806
for output_tensor in output_container:
797807
dependent_inputs.update(graph_analyzer.get_dependent_inputs(output_tensor))

tensorflow_transform/mappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1041,7 +1041,7 @@ def compute_and_apply_vocabulary(
10411041
@common.log_api_use(common.MAPPER_COLLECTION)
10421042
def apply_vocabulary(
10431043
x: common_types.ConsistentTensorType,
1044-
deferred_vocab_filename_tensor: tf.Tensor,
1044+
deferred_vocab_filename_tensor: common_types.TemporaryAnalyzerOutputType,
10451045
default_value: Optional[Any] = -1,
10461046
num_oov_buckets: Optional[int] = 0,
10471047
lookup_fn: Optional[Callable[[common_types.TensorType, tf.Tensor],

0 commit comments

Comments
 (0)