Commit 1125513

Blip2 fixes (#39080)
* Fixed some devices errors
* Fixed other device issues and more expectations
* Reverted support flags
* style
* More granular support
* Fixed some rebase stuff
* add a not None check before .to
1 parent 28df7f8 commit 1125513
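
The common thread in these fixes is keeping every tensor that feeds masked_scatter on one device and guarding optional tensors before calling .to. Below is a minimal, self-contained sketch of that pattern; the shapes, devices, and variable values are made up for illustration and only mirror the structure of the change, not the full BLIP-2 code path.

```python
import torch

# Stand-in devices: on a multi-accelerator split, the projected image features
# (Q-Former output) and the text embeddings can live on different devices.
qformer_device = torch.device("cpu")
lm_device = torch.device("cpu")

language_model_inputs = torch.randn(1, 32, 16, device=qformer_device)  # projected image features
inputs_embeds = torch.randn(1, 40, 16, device=lm_device)               # text embeddings with image placeholders
special_image_mask = torch.zeros(1, 40, dtype=torch.bool, device=lm_device)
special_image_mask[:, :32] = True  # 32 image-token placeholder positions

# Align mask and embeddings on the device of the image features before scattering,
# mirroring the updated forward()/generate() logic in this commit.
special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
    special_image_mask, language_model_inputs
)

# "add a not None check before .to": input_ids may legitimately be None.
input_ids = None
if input_ids is not None:
    input_ids = input_ids.to(language_model_inputs.device)
```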

2 files changed, +50 −18 lines


src/transformers/models/blip_2/modeling_blip_2.py

Lines changed: 24 additions & 6 deletions
@@ -415,6 +415,7 @@ class Blip2PreTrainedModel(PreTrainedModel):
     _no_split_modules = [
         "Blip2Attention",
         "Blip2QFormerMultiHeadAttention",
+        "Blip2EncoderLayer",
         "Blip2TextEmbeddings",
         "T5Block",
         "OPTDecoderLayer",
@@ -1262,6 +1263,7 @@ class Blip2Model(Blip2PreTrainedModel):
     config_class = Blip2Config
     main_input_name = "pixel_values"
     _keep_in_fp32_modules = ["query_tokens", "qformer"]
+    _supports_flash_attn_2 = False  # because self.qformer does not support FA2
 
     def __init__(self, config: Blip2Config):
         super().__init__(config)
@@ -1646,6 +1648,7 @@ def forward(
 class Blip2TextModelWithProjection(Blip2PreTrainedModel):
     supports_gradient_checkpointing = False
     _keep_in_fp32_modules = ["query_tokens", "qformer"]
+    _supports_flash_attn_2 = False  # because self.qformer does not support FA2
 
     def __init__(self, config: Blip2Config):
         super().__init__(config)
@@ -1738,6 +1741,7 @@ def forward(
 class Blip2VisionModelWithProjection(Blip2PreTrainedModel):
     main_input_name = "pixel_values"
     _keep_in_fp32_modules = ["query_tokens", "qformer"]
+    _supports_flash_attn_2 = False  # because self.qformer does not support FA2
 
     def __init__(self, config: Blip2Config):
         super().__init__(config)
@@ -1857,6 +1861,7 @@ class Blip2ForConditionalGeneration(Blip2PreTrainedModel, GenerationMixin):
     _supports_quantized_cache = False  # not all LM bacbones support (e.g. T5)
 
     _keep_in_fp32_modules = ["query_tokens", "qformer"]
+    _supports_flash_attn_2 = False  # because self.qformer does not support FA2
 
     def __init__(self, config: Blip2Config):
         super().__init__(config)
@@ -2086,9 +2091,13 @@ def forward(
             else:
                 special_image_mask = input_ids == self.config.image_token_id
 
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+            )
+            language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
+                special_image_mask, language_model_inputs
+            )
         else:
             logger.warning_once(
                 "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
@@ -2234,9 +2243,15 @@ def generate(
             else:
                 special_image_mask = input_ids == self.config.image_token_id
 
-            special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(language_model_inputs.device)
+            )
+            language_model_inputs = language_model_inputs.to(inputs_embeds.dtype)
+            inputs_embeds = inputs_embeds.to(language_model_inputs.device).masked_scatter(
+                special_image_mask, language_model_inputs
+            )
+
+            attention_mask = attention_mask.to(language_attention_mask.device)
         else:
             logger.warning_once(
                 "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
@@ -2259,6 +2274,8 @@ def generate(
 
         inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
         if not self.language_model.config.is_encoder_decoder:
+            if input_ids is not None:
+                input_ids = input_ids.to(language_model_inputs.device)
             inputs["input_ids"] = input_ids
 
         outputs = self.language_model.generate(**inputs, **generate_kwargs)
@@ -2275,6 +2292,7 @@ def generate(
 class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
     main_input_name = "pixel_values"
     _keep_in_fp32_modules = ["query_tokens", "qformer"]
+    _supports_flash_attn_2 = False  # because self.qformer does not support FA2
 
     def __init__(self, config: Blip2Config):
         super().__init__(config)
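
Several of the hunks above add _supports_flash_attn_2 = False so that the Q-Former-bearing classes opt out of FlashAttention-2 dispatch. As a hedged illustration of the intended effect (the checkpoint name and the exact exception type are assumptions, not part of this commit), requesting FA2 on such a class should be rejected at load time rather than failing inside the Q-Former:

```python
from transformers import Blip2Model

# Sketch only: with _supports_flash_attn_2 = False on the class, asking for
# FlashAttention-2 is expected to be refused during from_pretrained.
try:
    model = Blip2Model.from_pretrained(
        "Salesforce/blip2-opt-2.7b",          # assumed checkpoint, for illustration
        attn_implementation="flash_attention_2",
    )
except (ValueError, ImportError) as err:      # exact exception type is an assumption
    print(f"FlashAttention-2 not usable for this architecture: {err}")
```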

tests/models/blip_2/test_modeling_blip_2.py

Lines changed: 26 additions & 12 deletions
@@ -1786,7 +1786,8 @@ def test_inference_opt_multi_accelerator(self):
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
 
         # Test output
-        self.assertEqual(predictions[0].tolist(), [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118])
+        expected_ids = [2, 102, 693, 2828, 15, 5, 4105, 19, 10, 2335, 50118]
+        self.assertEqual(predictions[0].tolist(), [50265] * 32 + expected_ids)  # 50265 is the img token id
         self.assertEqual("a woman sitting on the beach with a dog", generated_text)
 
         # image and context
@@ -1797,10 +1798,8 @@ def test_inference_opt_multi_accelerator(self):
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
 
         # Test output
-        self.assertEqual(
-            predictions[0].tolist(),
-            [2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118],
-        )
+        expected_ids = [2, 45641, 35, 61, 343, 16, 42, 116, 31652, 35, 24, 18, 45, 10, 343, 6, 24, 18, 10, 4105, 50118]
+        self.assertEqual(predictions[0].tolist(), [50265] * 32 + expected_ids)  # 50265 is the img token id
         self.assertEqual(generated_text, "Question: which city is this? Answer: it's not a city, it's a beach")
 
     @require_torch_multi_accelerator
@@ -1826,8 +1825,17 @@ def test_inference_t5_multi_accelerator(self):
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
 
         # Test output
-        self.assertEqual(predictions[0].tolist(), [0, 2335, 1556, 28, 1782, 30, 8, 2608, 1])
-        self.assertEqual("woman playing with dog on the beach", generated_text)
+        expected_ids_and_text = Expectations(
+            {
+                ("cuda", None): ([0, 2335, 1556, 28, 1782, 30, 8, 2608, 1], "woman playing with dog on the beach"),
+                ("rocm", (9, 5)): (
+                    [0, 3, 9, 2335, 19, 1556, 28, 160, 1782, 30, 8, 2608, 1],
+                    "a woman is playing with her dog on the beach",
+                ),
+            }
+        ).get_expectation()
+        self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0])
+        self.assertEqual(generated_text, expected_ids_and_text[1])
 
         # image and context
         prompt = "Question: which city is this? Answer:"
@@ -1837,11 +1845,17 @@ def test_inference_t5_multi_accelerator(self):
         generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
 
         # Test output
-        self.assertEqual(
-            predictions[0].tolist(),
-            [0, 3, 7, 152, 67, 839, 1],
-        )
-        self.assertEqual(generated_text, "san diego")
+        expected_ids_and_text = Expectations(
+            {
+                ("cuda", None): ([0, 3, 7, 152, 67, 839, 1], "san diego"),
+                ("rocm", (9, 5)): (
+                    [0, 3, 7, 152, 2515, 11389, 3523, 1],
+                    "san francisco",  # TODO: check if this is ok
+                ),
+            }
+        ).get_expectation()
+        self.assertEqual(predictions[0].tolist(), expected_ids_and_text[0])
+        self.assertEqual(generated_text, expected_ids_and_text[1])
 
     def test_expansion_in_processing(self):
         processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
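
The test changes swap hard-coded expected outputs for an Expectations lookup keyed by accelerator type and version, so CUDA and ROCm runs can assert different token ids and text. A stripped-down, hypothetical re-implementation of that lookup idea (not the real helper from the transformers test utilities) looks roughly like this:

```python
# Hypothetical mini-version of a device-keyed expectation lookup, for illustration
# only; the real Expectations helper in transformers' test utilities is richer.
def get_expectation(expectations, device_type, version=None):
    """Prefer an exact (device, version) match, then (device, None), then the CUDA default."""
    for key in ((device_type, version), (device_type, None), ("cuda", None)):
        if key in expectations:
            return expectations[key]
    raise KeyError(f"No expectation registered for {device_type} {version}")


expected_ids_and_text = get_expectation(
    {
        ("cuda", None): ([0, 3, 7, 152, 67, 839, 1], "san diego"),
        ("rocm", (9, 5)): ([0, 3, 7, 152, 2515, 11389, 3523, 1], "san francisco"),
    },
    device_type="cuda",
)
print(expected_ids_and_text)  # ([0, 3, 7, 152, 67, 839, 1], 'san diego')
```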

0 commit comments

Comments
 (0)