Skip to content

Commit ac79c87

Browse files
HaomiaoLiu730facebook-github-bot
authored andcommittedJun 26, 2020
[PyTorch Operator] [2/n] Adding python test
Summary: Adding python test file with image files wit the input image being p.jpg. Test for the quality difference between the raw image and the decoded image Test Plan: Parsing buck files: finished in 1.5 sec Building: finished in 6.4 sec (100%) 10241/10241 jobs, 2 updated Total time: 8.0 sec More details at https://www.internalfb.com/intern/buck/build/387cb1c1-2902-4f90-ae9f-83fb6d473487 Tpx test run coordinator for Facebook. See https://fburl.com/tpx for details. Running with tpx session id: 93e6ef88-ec68-41cb-9de7-7868a14e6d65 Trace available for this run at /tmp/tpx-20200623-055836.283269/trace.log Started reporting to test run: https://our.intern.facebook.com/intern/testinfra/testrun/4222124679431330 ✓ ListingSuccess: caffe2/test:test_bundled_images - main (18.865) ✓ Pass: caffe2/test:test_bundled_images - test_single_tensors (test_bundled_images.TestBundledInputs) (18.060) ✓ Pass: caffe2/test:test_bundled_images - main (18.060) Summary Pass: 2 ListingSuccess: 1 Finished test run: https://our.intern.facebook.com/intern/testinfra/testrun/4222124679431330 Reviewed By: dreiss Differential Revision: D22046611 fbshipit-source-id: fabc604269a5a4d8a37135ce776200da2794a252
1 parent c790476 commit ac79c87

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed
 

‎test/test_bundled_images.py

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#!/usr/bin/env python3
2+
import torch
3+
import torch.utils.bundled_inputs
4+
import io
5+
import cv2
6+
from torch.testing._internal.common_utils import TestCase
7+
8+
torch.ops.load_library("//caffe2/torch/fb/operators:decode_bundled_image")
9+
10+
def model_size(sm):
11+
buffer = io.BytesIO()
12+
torch.jit.save(sm, buffer)
13+
return len(buffer.getvalue())
14+
15+
def save_and_load(sm):
16+
buffer = io.BytesIO()
17+
torch.jit.save(sm, buffer)
18+
buffer.seek(0)
19+
return torch.jit.load(buffer)
20+
21+
"""Return an InflatableArg that contains a tensor of the compressed image and the way to decode it
22+
23+
keyword arguments:
24+
img_tensor -- the raw image tensor in HWC or NCHW with pixel value of type unsigned int
25+
if in NCHW format, N should be 1
26+
quality -- the quality needed to compress the image
27+
"""
28+
def bundle_jpeg_image(img_tensor, quality):
29+
# turn NCHW to HWC
30+
if img_tensor.dim() == 4:
31+
assert(img_tensor.size(0) == 1)
32+
img_tensor = img_tensor[0].permute(1, 2, 0)
33+
pixels = img_tensor.numpy()
34+
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
35+
_, enc_img = cv2.imencode(".JPEG", pixels, encode_param)
36+
enc_img_tensor = torch.from_numpy(enc_img)
37+
enc_img_tensor = torch.flatten(enc_img_tensor).byte()
38+
obj = torch.utils.bundled_inputs.InflatableArg(enc_img_tensor, "torch.ops.fb.decode_bundled_image({})")
39+
return obj
40+
41+
class TestBundledInputs(TestCase):
42+
def test_single_tensors(self):
43+
class SingleTensorModel(torch.nn.Module):
44+
def forward(self, arg):
45+
return arg
46+
im = cv2.imread("caffe2/test/test_img/p1.jpg")
47+
tensor = torch.from_numpy(im)
48+
inflatable_arg = bundle_jpeg_image(tensor, 90)
49+
input = [(inflatable_arg,)]
50+
sm = torch.jit.script(SingleTensorModel())
51+
torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, input)
52+
loaded = save_and_load(sm)
53+
inflated = loaded.get_all_bundled_inputs()
54+
decoded_data = inflated[0][0]
55+
# raw image
56+
raw_data = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
57+
raw_data = torch.from_numpy(raw_data).float()
58+
raw_data = raw_data.permute(2, 0, 1)
59+
raw_data = torch.div(raw_data, 255).unsqueeze(0)
60+
self.assertEqual(len(inflated), 1)
61+
self.assertEqual(len(inflated[0]), 1)
62+
self.assertEqual(raw_data.shape, decoded_data.shape)
63+
self.assertTrue(torch.allclose(raw_data, decoded_data, atol=0.1, rtol= 1e-01))

‎test/test_img/p1.jpg

691 Bytes
Loading

0 commit comments

Comments
 (0)
Please sign in to comment.