@@ -205,15 +205,13 @@ def testRunSimpleNetworkoWithInfAndNaNWorks(self):
205
205
206
206
report = self ._debug_data_server .numerics_alert_report ()
207
207
self .assertEqual (2 , len (report ))
208
- self .assertEqual ("/job:localhost/replica:0/task:0/cpu:0" ,
209
- report [0 ].device_name )
208
+ self .assertTrue (report [0 ].device_name .lower ().endswith ("cpu:0" ))
210
209
self .assertEqual ("u:0" , report [0 ].tensor_name )
211
210
self .assertGreater (report [0 ].first_timestamp , 0 )
212
211
self .assertEqual (0 , report [0 ].nan_event_count )
213
212
self .assertEqual (0 , report [0 ].neg_inf_event_count )
214
213
self .assertEqual (1 , report [0 ].pos_inf_event_count )
215
- self .assertEqual ("/job:localhost/replica:0/task:0/cpu:0" ,
216
- report [1 ].device_name )
214
+ self .assertTrue (report [1 ].device_name .lower ().endswith ("cpu:0" ))
217
215
self .assertEqual ("u:0" , report [0 ].tensor_name )
218
216
self .assertGreaterEqual (report [1 ].first_timestamp ,
219
217
report [0 ].first_timestamp )
@@ -299,7 +297,7 @@ def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
299
297
300
298
def run_v (thread_id ):
301
299
for _ in range (num_runs_per_thread ):
302
- sess .run (v , options = run_options_list [thread_id ]) # DEBUG
300
+ sess .run (v , options = run_options_list [thread_id ])
303
301
304
302
run_threads = []
305
303
for thread_id in range (num_threads ):
@@ -312,15 +310,13 @@ def run_v(thread_id):
312
310
313
311
report = self ._debug_data_server .numerics_alert_report ()
314
312
self .assertEqual (2 , len (report ))
315
- self .assertEqual ("/job:localhost/replica:0/task:0/cpu:0" ,
316
- report [0 ].device_name )
313
+ self .assertTrue (report [0 ].device_name .lower ().endswith ("cpu:0" ))
317
314
self .assertEqual ("u:0" , report [0 ].tensor_name )
318
315
self .assertGreater (report [0 ].first_timestamp , 0 )
319
316
self .assertEqual (0 , report [0 ].nan_event_count )
320
317
self .assertEqual (0 , report [0 ].neg_inf_event_count )
321
318
self .assertEqual (total_num_runs , report [0 ].pos_inf_event_count )
322
- self .assertEqual ("/job:localhost/replica:0/task:0/cpu:0" ,
323
- report [1 ].device_name )
319
+ self .assertTrue (report [1 ].device_name .lower ().endswith ("cpu:0" ))
324
320
self .assertEqual ("u:0" , report [0 ].tensor_name )
325
321
self .assertGreaterEqual (report [1 ].first_timestamp ,
326
322
report [0 ].first_timestamp )
0 commit comments