22
33import vcr
44
5- from auto_labeling_pipeline .mappings import AmazonComprehendSentimentTemplate
6- from auto_labeling_pipeline .models import AmazonComprehendSentimentRequestModel
5+ from auto_labeling_pipeline .mappings import AmazonComprehendSentimentTemplate , GCPImageLabelDetectionTemplate
6+ from auto_labeling_pipeline .models import AmazonComprehendSentimentRequestModel , GCPImageLabelDetectionRequestModel
77from auto_labeling_pipeline .pipeline import pipeline
88from auto_labeling_pipeline .postprocessing import PostProcessor
99
10+ from .test_models import load_image_as_b64
11+
1012
1113def test_amazon_pipeline (cassettes_path ):
1214 with vcr .use_cassette (str (cassettes_path / 'amazon_comprehend_sentiment.yaml' ),
@@ -30,3 +32,25 @@ def test_amazon_pipeline(cassettes_path):
3032 assert isinstance (labels , list )
3133 assert len (labels ) == 1
3234 assert 'label' in labels [0 ]
35+
36+
37+ def test_gcp_label_detection_pipeline (data_path , cassettes_path ):
38+ with vcr .use_cassette (
39+ str (cassettes_path / 'pipeline_gcp_label_detection.yaml' ),
40+ mode = 'once' ,
41+ filter_query_parameters = ['key' ]
42+ ):
43+ model = GCPImageLabelDetectionRequestModel (key = os .environ .get ('API_KEY_GCP' , '' ))
44+ template = GCPImageLabelDetectionTemplate ()
45+ b64_image = load_image_as_b64 (data_path / 'images/1500x500.jpeg' )
46+ post_processor = PostProcessor ({})
47+ labels = pipeline (
48+ text = b64_image ,
49+ request_model = model ,
50+ mapping_template = template ,
51+ post_processing = post_processor
52+ )
53+ labels = labels .dict ()
54+ assert isinstance (labels , list )
55+ assert len (labels ) == 1
56+ assert 'label' in labels [0 ]
0 commit comments