6060#include  "aes_icm_ext.h" 
6161#endif 
6262
63+ #include  <stddef.h> 
64+ #include  <string.h> 
6365#include  <limits.h> 
6466#ifdef  HAVE_NETINET_IN_H 
6567#include  <netinet/in.h> 
6668#elif  defined(HAVE_WINSOCK2_H )
6769#include  <winsock2.h> 
6870#endif 
6971
72+ #if  defined(__SSE2__ )
73+ #include  <emmintrin.h> 
74+ #if  defined(_MSC_VER )
75+ #include  <intrin.h> 
76+ #endif 
77+ #endif 
78+ 
7079/* the debug module for srtp */ 
7180srtp_debug_module_t  mod_srtp  =  {
7281    0 ,     /* debugging is off by default */ 
@@ -79,6 +88,16 @@ srtp_debug_module_t mod_srtp = {
7988#define  uint32s_in_rtcp_header  2
8089#define  octets_in_rtp_extn_hdr  4
8190
91+ #ifndef  SRTP_NO_STREAM_LIST 
92+ static  inline  uint32_t  srtp_stream_list_size (srtp_stream_list_t  list );
93+ static  srtp_err_status_t  srtp_stream_list_reserve (srtp_stream_list_t  list ,
94+                                                   uint32_t  new_capacity );
95+ static  uint32_t  srtp_stream_list_find (srtp_stream_list_t  list , uint32_t  ssrc );
96+ static  inline  srtp_stream_t  srtp_stream_list_get_at (srtp_stream_list_t  list ,
97+                                                     uint32_t  pos );
98+ static  void  srtp_stream_list_remove_at (srtp_stream_list_t  list , uint32_t  pos );
99+ #endif  // SRTP_NO_STREAM_LIST 
100+ 
82101static  srtp_err_status_t  srtp_validate_rtp_header (void  * rtp_hdr ,
83102                                                  int  * pkt_octet_len )
84103{
@@ -3030,18 +3049,31 @@ srtp_err_status_t srtp_remove_stream(srtp_t session, uint32_t ssrc)
30303049{
30313050    srtp_stream_ctx_t  * stream ;
30323051    srtp_err_status_t  status ;
3052+ #if  !defined(SRTP_NO_STREAM_LIST )
3053+     uint32_t  pos ;
3054+ #endif 
30333055
30343056    /* sanity check arguments */ 
3035-     if  (session  ==  NULL )
3057+     if  (session  ==  NULL ) { 
30363058        return  srtp_err_status_bad_param ;
3059+     }
30373060
30383061    /* find and remove stream from the list */ 
3062+ #if  !defined(SRTP_NO_STREAM_LIST )
3063+     pos  =  srtp_stream_list_find (session -> stream_list , ssrc );
3064+     if  (pos  >= srtp_stream_list_size (session -> stream_list ))
3065+         return  srtp_err_status_no_ctx ;
3066+ 
3067+     stream  =  srtp_stream_list_get_at (session -> stream_list , pos );
3068+     srtp_stream_list_remove_at (session -> stream_list , pos );
3069+ #else 
30393070    stream  =  srtp_stream_list_get (session -> stream_list , ssrc );
30403071    if  (stream  ==  NULL ) {
30413072        return  srtp_err_status_no_ctx ;
30423073    }
30433074
30443075    srtp_stream_list_remove (session -> stream_list , stream );
3076+ #endif 
30453077
30463078    /* deallocate the stream */ 
30473079    status  =  srtp_stream_dealloc (stream , session -> stream_template );
@@ -4840,11 +4872,11 @@ srtp_err_status_t srtp_get_stream_roc(srtp_t session,
48404872
48414873#ifndef  SRTP_NO_STREAM_LIST 
48424874
4843- /* in the default implementation, we have an intrusive doubly-linked list */ 
48444875typedef  struct  srtp_stream_list_ctx_t_  {
4845-     /* a stub stream that just holds pointers to the beginning and end of the 
4846-      * list */ 
4847-     srtp_stream_ctx_t  data ;
4876+     uint32_t  * ssrcs ;
4877+     srtp_stream_ctx_t  * * streams ;
4878+     uint32_t  size ;
4879+     uint32_t  capacity ;
48484880} srtp_stream_list_ctx_t_ ;
48494881
48504882srtp_err_status_t  srtp_stream_list_alloc (srtp_stream_list_t  * list_ptr )
@@ -4855,73 +4887,204 @@ srtp_err_status_t srtp_stream_list_alloc(srtp_stream_list_t *list_ptr)
48554887        return  srtp_err_status_alloc_fail ;
48564888    }
48574889
4858-     list -> data .next  =  NULL ;
4859-     list -> data .prev  =  NULL ;
4860- 
48614890    * list_ptr  =  list ;
48624891    return  srtp_err_status_ok ;
48634892}
48644893
48654894srtp_err_status_t  srtp_stream_list_dealloc (srtp_stream_list_t  list )
48664895{
48674896    /* list must be empty */ 
4868-     if  (list -> data . next ) {
4897+     if  (list -> size   !=   0u ) {
48694898        return  srtp_err_status_fail ;
48704899    }
4900+     srtp_crypto_free (list -> streams );
4901+     srtp_crypto_free (list -> ssrcs );
48714902    srtp_crypto_free (list );
48724903    return  srtp_err_status_ok ;
48734904}
48744905
4906+ static  inline  uint32_t  srtp_stream_list_size (srtp_stream_list_t  list )
4907+ {
4908+     return  list -> size ;
4909+ }
4910+ 
4911+ static  srtp_err_status_t  srtp_stream_list_reserve (srtp_stream_list_t  list ,
4912+                                                   uint32_t  new_capacity )
4913+ {
4914+     if  (new_capacity  >  list -> capacity ) {
4915+         uint32_t  * ssrcs ;
4916+         srtp_stream_ctx_t  * * stream_ptrs ;
4917+ 
4918+         if  (new_capacity  >  (UINT32_MAX  -  15u ))
4919+             return  srtp_err_status_alloc_fail ;
4920+ 
4921+         new_capacity  =  (new_capacity  +  15u ) &  ~((uint32_t )15u );
4922+ 
4923+         ssrcs  =  (uint32_t  * )srtp_crypto_alloc ((size_t )new_capacity  * 
4924+                                               sizeof (uint32_t ));
4925+         if  (!ssrcs )
4926+             return  srtp_err_status_alloc_fail ;
4927+         stream_ptrs  =  (srtp_stream_ctx_t  * * )srtp_crypto_alloc (
4928+             (size_t )new_capacity  *  sizeof (srtp_stream_ctx_t  * ));
4929+         if  (!stream_ptrs ) {
4930+             srtp_crypto_free (ssrcs );
4931+             return  srtp_err_status_alloc_fail ;
4932+         }
4933+ 
4934+         if  (list -> size  >  0u ) {
4935+             memcpy (ssrcs , list -> ssrcs , (size_t )list -> size  *  sizeof (uint32_t ));
4936+             memcpy (stream_ptrs , list -> streams ,
4937+                    (size_t )list -> size  *  sizeof (srtp_stream_ctx_t  * ));
4938+         }
4939+ 
4940+         srtp_crypto_free (list -> ssrcs );
4941+         srtp_crypto_free (list -> streams );
4942+         list -> streams  =  stream_ptrs ;
4943+         list -> ssrcs  =  ssrcs ;
4944+ 
4945+         list -> capacity  =  new_capacity ;
4946+     }
4947+ 
4948+     return  srtp_err_status_ok ;
4949+ }
4950+ 
48754951srtp_err_status_t  srtp_stream_list_insert (srtp_stream_list_t  list ,
48764952                                          srtp_stream_t  stream )
48774953{
4878-     /* insert at the head of the list */ 
4879-     stream -> next   =   list -> data . next ;
4880-     if  (stream -> next   !=   NULL ) { 
4881-         stream -> next -> prev   =   stream ;
4882-     } 
4883-     list -> data . next  =  stream ;
4884-     stream -> prev  =  & ( list -> data ) ;
4954+     uint32_t   pos ; 
4955+     srtp_err_status_t   status   =   srtp_stream_list_reserve ( list ,  list -> size   +   1u ) ;
4956+     if  (status ) 
4957+         return   status ;
4958+     pos   =   list -> size ++ ; 
4959+     list -> ssrcs [ pos ]  =  stream -> ssrc ;
4960+     list -> streams [ pos ]  =  stream ;
48854961
48864962    return  srtp_err_status_ok ;
48874963}
48884964
4889- srtp_stream_t   srtp_stream_list_get (srtp_stream_list_t  list , uint32_t  ssrc )
4965+ static   uint32_t   srtp_stream_list_find (srtp_stream_list_t  list , uint32_t  ssrc )
48904966{
4891-     /* walk down list until ssrc is found */ 
4892-     srtp_stream_t  stream  =  list -> data .next ;
4893-     while  (stream  !=  NULL ) {
4894-         if  (stream -> ssrc  ==  ssrc ) {
4895-             return  stream ;
4967+ #if  defined(__SSE2__ )
4968+     const  uint32_t  * const  ssrcs  =  list -> ssrcs ;
4969+     const  __m128i  mm_ssrc  =  _mm_set1_epi32 (ssrc );
4970+     uint32_t  pos  =  0u , n  =  (list -> size  +  7u ) &  ~(uint32_t )(7u );
4971+     for  (uint32_t  m  =  n  &  ~(uint32_t )(15u ); pos  <  m ; pos  +=  16u ) {
4972+         __m128i  mm1  =  _mm_loadu_si128 ((const  __m128i  * )(ssrcs  +  pos ));
4973+         __m128i  mm2  =  _mm_loadu_si128 ((const  __m128i  * )(ssrcs  +  pos  +  4u ));
4974+         __m128i  mm3  =  _mm_loadu_si128 ((const  __m128i  * )(ssrcs  +  pos  +  8u ));
4975+         __m128i  mm4  =  _mm_loadu_si128 ((const  __m128i  * )(ssrcs  +  pos  +  12u ));
4976+         mm1  =  _mm_cmpeq_epi32 (mm1 , mm_ssrc );
4977+         mm2  =  _mm_cmpeq_epi32 (mm2 , mm_ssrc );
4978+         mm3  =  _mm_cmpeq_epi32 (mm3 , mm_ssrc );
4979+         mm4  =  _mm_cmpeq_epi32 (mm4 , mm_ssrc );
4980+         mm1  =  _mm_packs_epi32 (mm1 , mm2 );
4981+         mm3  =  _mm_packs_epi32 (mm3 , mm4 );
4982+         mm1  =  _mm_packs_epi16 (mm1 , mm3 );
4983+         uint32_t  mask  =  _mm_movemask_epi8 (mm1 );
4984+         if  (mask ) {
4985+ #if  defined(_MSC_VER )
4986+             unsigned long  bit_pos ;
4987+             _BitScanForward (& bit_pos , mask );
4988+             pos  +=  bit_pos ;
4989+ #else 
4990+             pos  +=  __builtin_ctz (mask );
4991+ #endif 
4992+ 
4993+             goto done ;
4994+         }
4995+     }
4996+ 
4997+     if  (pos  <  n ) {
4998+         __m128i  mm1  =  _mm_loadu_si128 ((const  __m128i  * )(ssrcs  +  pos ));
4999+         __m128i  mm2  =  _mm_loadu_si128 ((const  __m128i  * )(ssrcs  +  pos  +  4u ));
5000+         mm1  =  _mm_cmpeq_epi32 (mm1 , mm_ssrc );
5001+         mm2  =  _mm_cmpeq_epi32 (mm2 , mm_ssrc );
5002+         mm1  =  _mm_packs_epi32 (mm1 , mm2 );
5003+ 
5004+         uint32_t  mask  =  _mm_movemask_epi8 (mm1 );
5005+         if  (mask ) {
5006+ #if  defined(_MSC_VER )
5007+             unsigned long  bit_pos ;
5008+             _BitScanForward (& bit_pos , mask );
5009+             pos  +=  bit_pos  / 2u ;
5010+ #else 
5011+             pos  +=  __builtin_ctz (mask ) / 2u ;
5012+ #endif 
5013+             goto done ;
48965014        }
4897-         stream  =  stream -> next ;
5015+ 
5016+         pos  +=  8u ;
5017+     }
5018+ 
5019+ done :
5020+     return  pos ;
5021+ #else 
5022+     /* walk down list until ssrc is found */ 
5023+     uint32_t  pos  =  0u , n  =  list -> size ;
5024+     for  (; pos  <  n ; ++ pos ) {
5025+         if  (list -> ssrcs [pos ] ==  ssrc )
5026+             break ;
48985027    }
48995028
5029+     return  pos ;
5030+ #endif 
5031+ }
5032+ 
5033+ static  inline  srtp_stream_t  srtp_stream_list_get_at (srtp_stream_list_t  list ,
5034+                                                     uint32_t  pos )
5035+ {
5036+     return  list -> streams [pos ];
5037+ }
5038+ 
5039+ srtp_stream_t  srtp_stream_list_get (srtp_stream_list_t  list , uint32_t  ssrc )
5040+ {
5041+     uint32_t  pos  =  srtp_stream_list_find (list , ssrc );
5042+     if  (pos  <  list -> size )
5043+         return  list -> streams [pos ];
5044+ 
49005045    /* we haven't found our ssrc, so return a null */ 
49015046    return  NULL ;
49025047}
49035048
4904- void  srtp_stream_list_remove (srtp_stream_list_t  list ,
4905-                              srtp_stream_t  stream_to_remove )
5049+ static  void  srtp_stream_list_remove_at (srtp_stream_list_t  list , uint32_t  pos )
49065050{
4907-     ( void ) list ;
5051+     uint32_t   tail_size ,  last_pos ;
49085052
4909-     stream_to_remove -> prev -> next  =  stream_to_remove -> next ;
4910-     if  (stream_to_remove -> next  !=  NULL ) {
4911-         stream_to_remove -> next -> prev  =  stream_to_remove -> prev ;
5053+     last_pos  =  -- list -> size ;
5054+     tail_size  =  last_pos  -  pos ;
5055+     if  (tail_size  >  0u ) {
5056+         memmove (list -> streams  +  pos , list -> streams  +  pos  +  1 ,
5057+                 (size_t )tail_size  *  sizeof (* list -> streams ));
5058+         memmove (list -> ssrcs  +  pos , list -> ssrcs  +  pos  +  1 ,
5059+                 (size_t )tail_size  *  sizeof (* list -> ssrcs ));
49125060    }
5061+ 
5062+     list -> streams [last_pos ] =  NULL ;
5063+     list -> ssrcs [last_pos ] =  0u ;
5064+ }
5065+ 
5066+ void  srtp_stream_list_remove (srtp_stream_list_t  list ,
5067+                              srtp_stream_t  stream_to_remove )
5068+ {
5069+     uint32_t  pos  =  srtp_stream_list_find (list , stream_to_remove -> ssrc );
5070+     if  (pos  <  list -> size )
5071+         srtp_stream_list_remove_at (list , pos );
49135072}
49145073
49155074void  srtp_stream_list_for_each (srtp_stream_list_t  list ,
49165075                               int  (* callback )(srtp_stream_t , void  * ),
49175076                               void  * data )
49185077{
4919-     srtp_stream_t  stream  =  list -> data .next ;
4920-     while  (stream  !=  NULL ) {
4921-         srtp_stream_t  tmp  =  stream ;
4922-         stream  =  stream -> next ;
4923-         if  (callback (tmp , data ))
5078+     uint32_t  size  =  list -> size ;
5079+     for  (uint32_t  i  =  0u ; i  <  size ;) {
5080+         if  (callback (list -> streams [i ], data ))
49245081            break ;
5082+ 
5083+         /* check if the callback removed the current element */ 
5084+         if  (size  ==  list -> size )
5085+             ++ i ;
5086+         else 
5087+             size  =  list -> size ;
49255088    }
49265089}
49275090
0 commit comments