|
85 | 85 | "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES",
|
86 | 86 | "AutoModelForDocumentQuestionAnswering",
|
87 | 87 | ),
|
| 88 | + ( |
| 89 | + "visual-question-answering", |
| 90 | + "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES", |
| 91 | + "AutoModelForVisualQuestionAnswering", |
| 92 | + ), |
| 93 | + ("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"), |
88 | 94 | ]
|
89 | 95 |
|
90 | 96 |
|
@@ -236,10 +242,35 @@ def update_metadata(token, commit_sha):
|
236 | 242 | repo.push_to_hub(commit_message)
|
237 | 243 |
|
238 | 244 |
|
| 245 | +def check_pipeline_tags(): |
| 246 | + in_table = {tag: cls for tag, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS} |
| 247 | + pipeline_tasks = transformers_module.pipelines.SUPPORTED_TASKS |
| 248 | + missing = [] |
| 249 | + for key in pipeline_tasks: |
| 250 | + if key not in in_table: |
| 251 | + model = pipeline_tasks[key]["pt"] |
| 252 | + if isinstance(model, (list, tuple)): |
| 253 | + model = model[0] |
| 254 | + model = model.__name__ |
| 255 | + if model not in in_table.values(): |
| 256 | + missing.append(key) |
| 257 | + |
| 258 | + if len(missing) > 0: |
| 259 | + msg = ", ".join(missing) |
| 260 | + raise ValueError( |
| 261 | + "The following pipeline tags are not present in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant inside " |
| 262 | + f"`utils/update_metadata.py`: {msg}. Please add them!" |
| 263 | + ) |
| 264 | + |
| 265 | + |
239 | 266 | if __name__ == "__main__":
|
240 | 267 | parser = argparse.ArgumentParser()
|
241 | 268 | parser.add_argument("--token", type=str, help="The token to use to push to the transformers-metadata dataset.")
|
242 | 269 | parser.add_argument("--commit_sha", type=str, help="The sha of the commit going with this update.")
|
| 270 | + parser.add_argument("--check-only", action="store_true", help="Activate to just check all pipelines are present.") |
243 | 271 | args = parser.parse_args()
|
244 | 272 |
|
245 |
| - update_metadata(args.token, args.commit_sha) |
| 273 | + if args.check_only: |
| 274 | + check_pipeline_tags() |
| 275 | + else: |
| 276 | + update_metadata(args.token, args.commit_sha) |
0 commit comments