@@ -358,46 +358,50 @@ impl LazyEvpCipherAead {
358358 aad : Option < Aad < ' _ > > ,
359359 nonce : Option < & [ u8 ] > ,
360360 ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: types:: PyBytes > > {
361+ if ciphertext. len ( ) < self . tag_len {
362+ return Err ( CryptographyError :: from ( exceptions:: InvalidTag :: new_err ( ( ) ) ) ) ;
363+ }
364+ Ok ( pyo3:: types:: PyBytes :: new_with (
365+ py,
366+ ciphertext. len ( ) - self . tag_len ,
367+ |b| {
368+ self . decrypt_into ( py, ciphertext, aad, nonce, b) ?;
369+ Ok ( ( ) )
370+ } ,
371+ ) ?)
372+ }
373+
374+ fn decrypt_into (
375+ & self ,
376+ py : pyo3:: Python < ' _ > ,
377+ ciphertext : & [ u8 ] ,
378+ aad : Option < Aad < ' _ > > ,
379+ nonce : Option < & [ u8 ] > ,
380+ buf : & mut [ u8 ] ,
381+ ) -> CryptographyResult < ( ) > {
361382 let key_buf = self . key . bind ( py) . extract :: < CffiBuf < ' _ > > ( ) ?;
362383
363384 let mut decryption_ctx = openssl:: cipher_ctx:: CipherCtx :: new ( ) ?;
364385 if self . is_ccm {
365386 decryption_ctx. decrypt_init ( Some ( self . cipher ) , None , None ) ?;
366387 decryption_ctx. set_iv_length ( nonce. as_ref ( ) . unwrap ( ) . len ( ) ) ?;
367-
368- if ciphertext. len ( ) < self . tag_len {
369- return Err ( CryptographyError :: from ( exceptions:: InvalidTag :: new_err ( ( ) ) ) ) ;
370- }
371-
372388 let ( _, tag) = ciphertext. split_at ( ciphertext. len ( ) - self . tag_len ) ;
373389 decryption_ctx. set_tag ( tag) ?;
374-
375390 decryption_ctx. decrypt_init ( None , Some ( key_buf. as_bytes ( ) ) , nonce) ?;
376391 } else {
377392 decryption_ctx. decrypt_init ( Some ( self . cipher ) , Some ( key_buf. as_bytes ( ) ) , None ) ?;
378393 }
379394
380- if ciphertext. len ( ) < self . tag_len {
381- return Err ( CryptographyError :: from ( exceptions:: InvalidTag :: new_err ( ( ) ) ) ) ;
382- }
383-
384- Ok ( pyo3:: types:: PyBytes :: new_with (
385- py,
386- ciphertext. len ( ) - self . tag_len ,
387- |b| {
388- EvpCipherAead :: decrypt_with_context (
389- decryption_ctx,
390- ciphertext,
391- aad,
392- nonce,
393- self . tag_len ,
394- self . tag_first ,
395- self . is_ccm ,
396- b,
397- ) ?;
398- Ok ( ( ) )
399- } ,
400- ) ?)
395+ EvpCipherAead :: decrypt_with_context (
396+ decryption_ctx,
397+ ciphertext,
398+ aad,
399+ nonce,
400+ self . tag_len ,
401+ self . tag_first ,
402+ self . is_ccm ,
403+ buf,
404+ )
401405 }
402406}
403407
@@ -474,6 +478,29 @@ impl EvpAead {
474478 } ,
475479 ) ?)
476480 }
481+
482+ fn decrypt_into (
483+ & self ,
484+ _py : pyo3:: Python < ' _ > ,
485+ ciphertext : & [ u8 ] ,
486+ aad : Option < Aad < ' _ > > ,
487+ nonce : Option < & [ u8 ] > ,
488+ buf : & mut [ u8 ] ,
489+ ) -> CryptographyResult < ( ) > {
490+ let ad = if let Some ( Aad :: Single ( ad) ) = & aad {
491+ check_length ( ad. as_bytes ( ) ) ?;
492+ ad. as_bytes ( )
493+ } else {
494+ assert ! ( aad. is_none( ) ) ;
495+ b""
496+ } ;
497+
498+ self . ctx
499+ . decrypt ( ciphertext, nonce. unwrap_or ( b"" ) , ad, buf)
500+ . map_err ( |_| exceptions:: InvalidTag :: new_err ( ( ) ) ) ?;
501+
502+ Ok ( ( ) )
503+ }
477504}
478505
479506#[ pyo3:: pyclass( frozen, module = "cryptography.hazmat.bindings._rust.openssl.aead" ) ]
@@ -618,7 +645,36 @@ impl ChaCha20Poly1305 {
618645 data : CffiBuf < ' _ > ,
619646 associated_data : Option < CffiBuf < ' _ > > ,
620647 ) -> CryptographyResult < pyo3:: Bound < ' p , pyo3:: types:: PyBytes > > {
648+ if nonce. as_bytes ( ) . len ( ) != 12 {
649+ return Err ( CryptographyError :: from (
650+ pyo3:: exceptions:: PyValueError :: new_err ( "Nonce must be 12 bytes" ) ,
651+ ) ) ;
652+ }
653+ if data. as_bytes ( ) . len ( ) < self . ctx . tag_len {
654+ return Err ( CryptographyError :: from ( exceptions:: InvalidTag :: new_err ( ( ) ) ) ) ;
655+ }
656+ Ok ( pyo3:: types:: PyBytes :: new_with (
657+ py,
658+ data. as_bytes ( ) . len ( ) - self . ctx . tag_len ,
659+ |b| {
660+ let buf = CffiMutBuf :: from_bytes ( py, b) ;
661+ self . decrypt_into ( py, nonce, data, associated_data, buf) ?;
662+ Ok ( ( ) )
663+ } ,
664+ ) ?)
665+ }
666+
667+ #[ pyo3( signature = ( nonce, data, associated_data, buf) ) ]
668+ fn decrypt_into (
669+ & self ,
670+ py : pyo3:: Python < ' _ > ,
671+ nonce : CffiBuf < ' _ > ,
672+ data : CffiBuf < ' _ > ,
673+ associated_data : Option < CffiBuf < ' _ > > ,
674+ mut buf : CffiMutBuf < ' _ > ,
675+ ) -> CryptographyResult < usize > {
621676 let nonce_bytes = nonce. as_bytes ( ) ;
677+ let data_bytes = data. as_bytes ( ) ;
622678 let aad = associated_data. map ( Aad :: Single ) ;
623679
624680 if nonce_bytes. len ( ) != 12 {
@@ -627,8 +683,24 @@ impl ChaCha20Poly1305 {
627683 ) ) ;
628684 }
629685
686+ if data. as_bytes ( ) . len ( ) < self . ctx . tag_len {
687+ return Err ( CryptographyError :: from ( exceptions:: InvalidTag :: new_err ( ( ) ) ) ) ;
688+ }
689+
690+ let expected_len = data_bytes. len ( ) - self . ctx . tag_len ;
691+ if buf. as_mut_bytes ( ) . len ( ) != expected_len {
692+ return Err ( CryptographyError :: from (
693+ pyo3:: exceptions:: PyValueError :: new_err ( format ! (
694+ "buffer must be {} bytes" ,
695+ expected_len
696+ ) ) ,
697+ ) ) ;
698+ }
699+
630700 self . ctx
631- . decrypt ( py, data. as_bytes ( ) , aad, Some ( nonce_bytes) )
701+ . decrypt_into ( py, data_bytes, aad, Some ( nonce_bytes) , buf. as_mut_bytes ( ) ) ?;
702+
703+ Ok ( expected_len)
632704 }
633705}
634706
0 commit comments