diff --git a/test.sh b/test.sh index 3bf4a4d8..6f62db36 100755 --- a/test.sh +++ b/test.sh @@ -4,8 +4,8 @@ OUTPUT_FILE=$1 touch $OUTPUT_FILE -echo "| Name | Data Type | Input Shapes | torch2trt kwargs | Max Error | Throughput (PyTorch) | Throughput (TensorRT) | Latency (PyTorch) | Latency (TensorRT) |" >> $OUTPUT_FILE -echo "|------|-----------|--------------|------------------|-----------|----------------------|-----------------------|-------------------|--------------------|" >> $OUTPUT_FILE +echo "| Name | Data Type | Input Shapes | torch2trt kwargs | Max Error | Peak Signal to Noise Ratio | Mean Squared Error | Throughput (PyTorch) | Throughput (TensorRT) | Latency (PyTorch) | Latency (TensorRT) |" >> $OUTPUT_FILE +echo "|------|-----------|--------------|------------------|-----------|----------------------------|--------------------|----------------------|-----------------------|-------------------|--------------------|" >> $OUTPUT_FILE python3 -m torch2trt.test -o $OUTPUT_FILE --name alexnet --include=torch2trt.tests.torchvision.classification python3 -m torch2trt.test -o $OUTPUT_FILE --name squeezenet1_0 --include=torch2trt.tests.torchvision.classification diff --git a/torch2trt/test.py b/torch2trt/test.py index ce348c99..f298ebb1 100644 --- a/torch2trt/test.py +++ b/torch2trt/test.py @@ -159,7 +159,7 @@ def run(self): max_error,psnr_db,mse, fps, fps_trt, ms, ms_trt = run(test) # write entry - line = '| %70s | %s | %25s | %s | %.2E | %.2f | %.2E | %.3g | %.3g | %.3g | %.3g |' % (name, test.dtype.__repr__().split('.')[-1], str(test.input_shapes), str(test.torch2trt_kwargs), max_error,psnr_db,mse, fps, fps_trt, ms, ms_trt) + line = '| %70s | %s | %25s | %s | %.2E | %.2f | %.2E | %.5g | %.5g | %.5g | %.5g |' % (name, test.dtype.__repr__().split('.')[-1], str(test.input_shapes), str(test.torch2trt_kwargs), max_error,psnr_db,mse, fps, fps_trt, ms, ms_trt) if args.tolerance >= 0 and max_error > args.tolerance: print(colored(line, 'yellow')) @@ -171,7 +171,7 @@ def run(self): print(line) num_success += 1 except: - line = '| %s | %s | %s | %s | N/A | N/A | N/A | N/A | N/A |' % (name, test.dtype.__repr__().split('.')[-1], str(test.input_shapes), str(test.torch2trt_kwargs)) + line = '| %s | %s | %s | %s | N/A | N/A | N/A | N/A | N/A | N/A | N/A |' % (name, test.dtype.__repr__().split('.')[-1], str(test.input_shapes), str(test.torch2trt_kwargs)) print(colored(line, 'red')) num_error += 1 tb = traceback.format_exc()