Skip to content

Commit d17c31e

Browse files
authored
Merge pull request #12 from doccano/feature/GCPLabelDetection
Feature/gcp label detection
2 parents 2b5b41a + 0d93699 commit d17c31e

File tree

14 files changed

+306
-26
lines changed

14 files changed

+306
-26
lines changed

auto_labeling_pipeline/mappings.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ class AmazonComprehendSentimentTemplate(MappingTemplate):
4444
template_file = 'amazon_comprehend_sentiment.jinja2'
4545

4646

47+
class GCPImageLabelDetectionTemplate(MappingTemplate):
48+
label_collection = ClassificationLabels
49+
template_file = 'gcp_image_label_detection.jinja2'
50+
51+
4752
class AmazonComprehendEntityTemplate(MappingTemplate):
4853
label_collection = SequenceLabels
4954
template_file = 'amazon_comprehend_entity.jinja2'

auto_labeling_pipeline/menu.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,17 @@
22

33
from pydantic import BaseModel
44

5-
from auto_labeling_pipeline.mappings import (AmazonComprehendEntityTemplate, AmazonComprehendSentimentTemplate,
6-
GCPEntitiesTemplate, MappingTemplate)
7-
from auto_labeling_pipeline.models import (AmazonComprehendEntityRequestModel, AmazonComprehendPIIEntityRequestModel,
8-
AmazonComprehendSentimentRequestModel, CustomRESTRequestModel,
9-
GCPEntitiesRequestModel, RequestModel)
10-
from auto_labeling_pipeline.task import DocumentClassification, GenericTask, SequenceLabeling, Task, TaskFactory
5+
from auto_labeling_pipeline import mappings as mp
6+
from auto_labeling_pipeline import models as mo
7+
from auto_labeling_pipeline import task as t
118

129

1310
class Option(BaseModel):
1411
name: str
1512
description: str
16-
task: Type[Task]
17-
model: Type[RequestModel]
18-
template: Type[MappingTemplate]
13+
task: Type[t.Task]
14+
model: Type[mo.RequestModel]
15+
template: Type[mp.MappingTemplate]
1916

2017
class Config:
2118
arbitrary_types_allowed = True
@@ -34,8 +31,8 @@ class Options:
3431

3532
@classmethod
3633
def filter_by_task(cls, task_name: str) -> List[Option]:
37-
task = TaskFactory.create(task_name)
38-
return [option for option in cls.options if option.task == task or option.task == GenericTask]
34+
task = t.TaskFactory.create(task_name)
35+
return [option for option in cls.options if option.task == task or option.task == t.GenericTask]
3936

4037
@classmethod
4138
def find(cls, option_name: str) -> Option:
@@ -45,7 +42,7 @@ def find(cls, option_name: str) -> Option:
4542
raise ValueError('Option {} is not found.'.format(option_name))
4643

4744
@classmethod
48-
def register(cls, task: Type[Task], model: Type[RequestModel], template: Type[MappingTemplate]):
45+
def register(cls, task: Type[t.Task], model: Type[mo.RequestModel], template: Type[mp.MappingTemplate]):
4946
schema = model.schema()
5047
cls.options.append(
5148
Option(
@@ -58,8 +55,33 @@ def register(cls, task: Type[Task], model: Type[RequestModel], template: Type[Ma
5855
)
5956

6057

61-
Options.register(GenericTask, CustomRESTRequestModel, MappingTemplate)
62-
Options.register(DocumentClassification, AmazonComprehendSentimentRequestModel, AmazonComprehendSentimentTemplate)
63-
Options.register(SequenceLabeling, GCPEntitiesRequestModel, GCPEntitiesTemplate)
64-
Options.register(SequenceLabeling, AmazonComprehendEntityRequestModel, AmazonComprehendEntityTemplate)
65-
Options.register(SequenceLabeling, AmazonComprehendPIIEntityRequestModel, AmazonComprehendEntityTemplate)
58+
Options.register(
59+
t.GenericTask,
60+
mo.CustomRESTRequestModel,
61+
mp.MappingTemplate
62+
)
63+
Options.register(
64+
t.DocumentClassification,
65+
mo.AmazonComprehendSentimentRequestModel,
66+
mp.AmazonComprehendSentimentTemplate
67+
)
68+
Options.register(
69+
t.SequenceLabeling,
70+
mo.GCPEntitiesRequestModel,
71+
mp.GCPEntitiesTemplate
72+
)
73+
Options.register(
74+
t.SequenceLabeling,
75+
mo.AmazonComprehendEntityRequestModel,
76+
mp.AmazonComprehendEntityTemplate
77+
)
78+
Options.register(
79+
t.SequenceLabeling,
80+
mo.AmazonComprehendPIIEntityRequestModel,
81+
mp.AmazonComprehendEntityTemplate
82+
)
83+
Options.register(
84+
t.ImageClassification,
85+
mo.GCPImageLabelDetectionRequestModel,
86+
mp.GCPImageLabelDetectionTemplate
87+
)

auto_labeling_pipeline/models.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,36 @@ def send(self, text: str):
206206
LanguageCode=self.language_code
207207
)
208208
return response
209+
210+
211+
class GCPImageLabelDetectionRequestModel(RequestModel):
212+
"""
213+
This allow you to detect labels for a image by
214+
<a href="https://cloud.google.com/vision/docs/labels">Cloud Vision API</a>.
215+
"""
216+
key: str
217+
218+
class Config:
219+
title = 'GCP Image Label Detection'
220+
221+
def send(self, b64_image: str):
222+
url = 'https://vision.googleapis.com/v1/images:annotate'
223+
headers = {'Content-Type': 'application/json'}
224+
params = {'key': self.key}
225+
body = {
226+
'requests': [
227+
{
228+
'image': {
229+
'content': b64_image
230+
},
231+
'features': [
232+
{
233+
'maxResults': 5,
234+
'type': 'LABEL_DETECTION'
235+
}
236+
]
237+
}
238+
]
239+
}
240+
response = requests.post(url, headers=headers, params=params, json=body).json()
241+
return response

auto_labeling_pipeline/task.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,17 @@ class Seq2seq(Task):
2424
label_collection = Seq2seqLabels
2525

2626

27+
class ImageClassification(Task):
28+
label_collection = ClassificationLabels
29+
30+
2731
class TaskFactory:
2832

2933
@classmethod
3034
def create(cls, task_name: str) -> Type[Task]:
3135
return {
3236
'DocumentClassification': DocumentClassification,
3337
'SequenceLabeling': SequenceLabeling,
34-
'Seq2seq': Seq2seq
38+
'Seq2seq': Seq2seq,
39+
'ImageClassification': ImageClassification
3540
}.get(task_name, GenericTask)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[
2+
{
3+
"label": "{{ input.responses[0].labelAnnotations[0].description }}"
4+
}
5+
]
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
2+
"responses":[
3+
{
4+
"labelAnnotations":[
5+
{
6+
"description":"Cat",
7+
"mid":"/m/01yrx",
8+
"score":0.945612,
9+
"topicality":0.945612
10+
},
11+
{
12+
"description":"Eye",
13+
"mid":"/m/014sv8",
14+
"score":0.9400194,
15+
"topicality":0.9400194
16+
},
17+
{
18+
"description":"Felidae",
19+
"mid":"/m/0307l",
20+
"score":0.8835683,
21+
"topicality":0.8835683
22+
},
23+
{
24+
"description":"Carnivore",
25+
"mid":"/m/01lrl",
26+
"score":0.8821837,
27+
"topicality":0.8821837
28+
},
29+
{
30+
"description":"Plant",
31+
"mid":"/m/05s2s",
32+
"score":0.8714797,
33+
"topicality":0.8714797
34+
}
35+
]
36+
}
37+
]
38+
}

tests/data/images/1500x500.jpeg

82.8 KB
Loading

tests/fixtures/cassettes/gcp_label_detection.yaml

Lines changed: 57 additions & 0 deletions
Large diffs are not rendered by default.

tests/fixtures/cassettes/pipeline_gcp_label_detection.yaml

Lines changed: 58 additions & 0 deletions
Large diffs are not rendered by default.

tests/test_mappings.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import json
22

33
from auto_labeling_pipeline.mappings import (AmazonComprehendEntityTemplate, AmazonComprehendSentimentTemplate,
4-
GCPEntitiesTemplate)
4+
GCPEntitiesTemplate, GCPImageLabelDetectionTemplate)
55

66

77
def load_json(filepath):
@@ -86,3 +86,12 @@ def test_amazon_comprehend_entities(data_path):
8686
}
8787
]
8888
assert labels == expected
89+
90+
91+
def test_gcp_image_label_detection(data_path):
92+
response = load_json(data_path / 'gcp_image_label_detection.json')
93+
template = GCPImageLabelDetectionTemplate()
94+
labels = template.render(response)
95+
labels = labels.dict()
96+
expected = [{'label': 'Cat'}]
97+
assert labels == expected

0 commit comments

Comments
 (0)