Skip to content

Commit 0d93699

Browse files
committed
Add pipeline test case for gcp image label detection
1 parent ca138ec commit 0d93699

File tree

2 files changed

+84
-2
lines changed

2 files changed

+84
-2
lines changed

tests/fixtures/cassettes/pipeline_gcp_label_detection.yaml

Lines changed: 58 additions & 0 deletions
Large diffs are not rendered by default.

tests/test_pipeline.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22

33
import vcr
44

5-
from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate
6-
from auto_labeling_pipeline.models import AmazonComprehendSentimentRequestModel
5+
from auto_labeling_pipeline.mappings import AmazonComprehendSentimentTemplate, GCPImageLabelDetectionTemplate
6+
from auto_labeling_pipeline.models import AmazonComprehendSentimentRequestModel, GCPImageLabelDetectionRequestModel
77
from auto_labeling_pipeline.pipeline import pipeline
88
from auto_labeling_pipeline.postprocessing import PostProcessor
99

10+
from .test_models import load_image_as_b64
11+
1012

1113
def test_amazon_pipeline(cassettes_path):
1214
with vcr.use_cassette(str(cassettes_path / 'amazon_comprehend_sentiment.yaml'),
@@ -30,3 +32,25 @@ def test_amazon_pipeline(cassettes_path):
3032
assert isinstance(labels, list)
3133
assert len(labels) == 1
3234
assert 'label' in labels[0]
35+
36+
37+
def test_gcp_label_detection_pipeline(data_path, cassettes_path):
38+
with vcr.use_cassette(
39+
str(cassettes_path / 'pipeline_gcp_label_detection.yaml'),
40+
mode='once',
41+
filter_query_parameters=['key']
42+
):
43+
model = GCPImageLabelDetectionRequestModel(key=os.environ.get('API_KEY_GCP', ''))
44+
template = GCPImageLabelDetectionTemplate()
45+
b64_image = load_image_as_b64(data_path / 'images/1500x500.jpeg')
46+
post_processor = PostProcessor({})
47+
labels = pipeline(
48+
text=b64_image,
49+
request_model=model,
50+
mapping_template=template,
51+
post_processing=post_processor
52+
)
53+
labels = labels.dict()
54+
assert isinstance(labels, list)
55+
assert len(labels) == 1
56+
assert 'label' in labels[0]

0 commit comments

Comments
 (0)