Skip to content

Commit 65beeea

Browse files
committed
Fix session_debug_test.py (patched from #754)
1 parent 0f372bc commit 65beeea

File tree

1 file changed

+5
-9
lines changed

1 file changed

+5
-9
lines changed

tensorboard/plugins/debugger/session_debug_test.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,15 +205,13 @@ def testRunSimpleNetworkoWithInfAndNaNWorks(self):
205205

206206
report = self._debug_data_server.numerics_alert_report()
207207
self.assertEqual(2, len(report))
208-
self.assertEqual("/job:localhost/replica:0/task:0/cpu:0",
209-
report[0].device_name)
208+
self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
210209
self.assertEqual("u:0", report[0].tensor_name)
211210
self.assertGreater(report[0].first_timestamp, 0)
212211
self.assertEqual(0, report[0].nan_event_count)
213212
self.assertEqual(0, report[0].neg_inf_event_count)
214213
self.assertEqual(1, report[0].pos_inf_event_count)
215-
self.assertEqual("/job:localhost/replica:0/task:0/cpu:0",
216-
report[1].device_name)
214+
self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
217215
self.assertEqual("u:0", report[0].tensor_name)
218216
self.assertGreaterEqual(report[1].first_timestamp,
219217
report[0].first_timestamp)
@@ -299,7 +297,7 @@ def testConcurrentNumericsAlertsAreRegisteredCorrectly(self):
299297

300298
def run_v(thread_id):
301299
for _ in range(num_runs_per_thread):
302-
sess.run(v, options=run_options_list[thread_id]) # DEBUG
300+
sess.run(v, options=run_options_list[thread_id])
303301

304302
run_threads = []
305303
for thread_id in range(num_threads):
@@ -312,15 +310,13 @@ def run_v(thread_id):
312310

313311
report = self._debug_data_server.numerics_alert_report()
314312
self.assertEqual(2, len(report))
315-
self.assertEqual("/job:localhost/replica:0/task:0/cpu:0",
316-
report[0].device_name)
313+
self.assertTrue(report[0].device_name.lower().endswith("cpu:0"))
317314
self.assertEqual("u:0", report[0].tensor_name)
318315
self.assertGreater(report[0].first_timestamp, 0)
319316
self.assertEqual(0, report[0].nan_event_count)
320317
self.assertEqual(0, report[0].neg_inf_event_count)
321318
self.assertEqual(total_num_runs, report[0].pos_inf_event_count)
322-
self.assertEqual("/job:localhost/replica:0/task:0/cpu:0",
323-
report[1].device_name)
319+
self.assertTrue(report[1].device_name.lower().endswith("cpu:0"))
324320
self.assertEqual("u:0", report[0].tensor_name)
325321
self.assertGreaterEqual(report[1].first_timestamp,
326322
report[0].first_timestamp)

0 commit comments

Comments
 (0)