@@ -35,6 +35,7 @@ import (
3535 "google.golang.org/grpc/codes"
3636 "google.golang.org/grpc/mem"
3737 "google.golang.org/grpc/metadata"
38+ "google.golang.org/grpc/stats"
3839 "google.golang.org/grpc/status"
3940 "google.golang.org/protobuf/proto"
4041 "google.golang.org/protobuf/protoadapt"
@@ -246,7 +247,26 @@ type handleStreamTest struct {
246247 ht * serverHandlerTransport
247248}
248249
249- func newHandleStreamTest (t * testing.T ) * handleStreamTest {
250+ type mockStatsHandler struct {
251+ rpcStatsCh chan stats.RPCStats
252+ }
253+
254+ func (h * mockStatsHandler ) TagRPC (ctx context.Context , _ * stats.RPCTagInfo ) context.Context {
255+ return ctx
256+ }
257+
258+ func (h * mockStatsHandler ) HandleRPC (_ context.Context , s stats.RPCStats ) {
259+ h .rpcStatsCh <- s
260+ }
261+
262+ func (h * mockStatsHandler ) TagConn (ctx context.Context , _ * stats.ConnTagInfo ) context.Context {
263+ return ctx
264+ }
265+
266+ func (h * mockStatsHandler ) HandleConn (context.Context , stats.ConnStats ) {
267+ }
268+
269+ func newHandleStreamTest (t * testing.T , statsHandlers []stats.Handler ) * handleStreamTest {
250270 bodyr , bodyw := io .Pipe ()
251271 req := & http.Request {
252272 ProtoMajor : 2 ,
@@ -260,7 +280,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
260280 Body : bodyr ,
261281 }
262282 rw := newTestHandlerResponseWriter ().(testHandlerResponseWriter )
263- ht , err := NewServerHandlerTransport (rw , req , nil , mem .DefaultBufferPool ())
283+ ht , err := NewServerHandlerTransport (rw , req , statsHandlers , mem .DefaultBufferPool ())
264284 if err != nil {
265285 t .Fatal (err )
266286 }
@@ -273,7 +293,7 @@ func newHandleStreamTest(t *testing.T) *handleStreamTest {
273293}
274294
275295func (s ) TestHandlerTransport_HandleStreams (t * testing.T ) {
276- st := newHandleStreamTest (t )
296+ st := newHandleStreamTest (t , nil )
277297 handleStream := func (s * ServerStream ) {
278298 if want := "/service/foo.bar" ; s .method != want {
279299 t .Errorf ("stream method = %q; want %q" , s .method , want )
@@ -342,7 +362,7 @@ func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
342362}
343363
344364func handleStreamCloseBodyTest (t * testing.T , statusCode codes.Code , msg string ) {
345- st := newHandleStreamTest (t )
365+ st := newHandleStreamTest (t , nil )
346366
347367 handleStream := func (s * ServerStream ) {
348368 s .WriteStatus (status .New (statusCode , msg ))
@@ -451,7 +471,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
451471}
452472
453473func testHandlerTransportHandleStreams (t * testing.T , handleStream func (st * handleStreamTest , s * ServerStream )) {
454- st := newHandleStreamTest (t )
474+ st := newHandleStreamTest (t , nil )
455475 ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
456476 t .Cleanup (cancel )
457477 st .ht .HandleStreams (
@@ -483,7 +503,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
483503 t .Fatal (err )
484504 }
485505
486- hst := newHandleStreamTest (t )
506+ hst := newHandleStreamTest (t , nil )
487507 handleStream := func (s * ServerStream ) {
488508 s .WriteStatus (st )
489509 }
@@ -506,11 +526,81 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
506526 checkHeaderAndTrailer (t , hst .rw , wantHeader , wantTrailer )
507527}
508528
529+ // Tests the use of stats handlers and ensures there are no data races while
530+ // accessing trailers.
531+ func (s ) TestHandlerTransport_HandleStreams_StatsHandlers (t * testing.T ) {
532+ errDetails := []protoadapt.MessageV1 {
533+ & epb.RetryInfo {
534+ RetryDelay : & durationpb.Duration {Seconds : 60 },
535+ },
536+ & epb.ResourceInfo {
537+ ResourceType : "foo bar" ,
538+ ResourceName : "service.foo.bar" ,
539+ Owner : "User" ,
540+ },
541+ }
542+
543+ statusCode := codes .ResourceExhausted
544+ msg := "you are being throttled"
545+ st , err := status .New (statusCode , msg ).WithDetails (errDetails ... )
546+ if err != nil {
547+ t .Fatal (err )
548+ }
549+
550+ stBytes , err := proto .Marshal (st .Proto ())
551+ if err != nil {
552+ t .Fatal (err )
553+ }
554+ // Add mock stats handlers to exercise the stats handler code path.
555+ statsHandler := & mockStatsHandler {
556+ rpcStatsCh : make (chan stats.RPCStats , 2 ),
557+ }
558+ hst := newHandleStreamTest (t , []stats.Handler {statsHandler })
559+ handleStream := func (s * ServerStream ) {
560+ if err := s .SendHeader (metadata .New (map [string ]string {})); err != nil {
561+ t .Error (err )
562+ }
563+ if err := s .SetTrailer (metadata .Pairs ("custom-trailer" , "Custom trailer value" )); err != nil {
564+ t .Error (err )
565+ }
566+ s .WriteStatus (st )
567+ }
568+ ctx , cancel := context .WithTimeout (context .Background (), defaultTestTimeout )
569+ defer cancel ()
570+ hst .ht .HandleStreams (
571+ ctx , func (s * ServerStream ) { go handleStream (s ) },
572+ )
573+ wantHeader := http.Header {
574+ "Date" : nil ,
575+ "Content-Type" : {"application/grpc" },
576+ "Trailer" : {"Grpc-Status" , "Grpc-Message" , "Grpc-Status-Details-Bin" },
577+ }
578+ wantTrailer := http.Header {
579+ "Grpc-Status" : {fmt .Sprint (uint32 (statusCode ))},
580+ "Grpc-Message" : {encodeGrpcMessage (msg )},
581+ "Grpc-Status-Details-Bin" : {encodeBinHeader (stBytes )},
582+ "Custom-Trailer" : []string {"Custom trailer value" },
583+ }
584+
585+ checkHeaderAndTrailer (t , hst .rw , wantHeader , wantTrailer )
586+ wantStatTypes := []stats.RPCStats {& stats.OutHeader {}, & stats.OutTrailer {}}
587+ for _ , wantType := range wantStatTypes {
588+ select {
589+ case <- ctx .Done ():
590+ t .Fatal ("Context timed out waiting for statsHandler.HandleRPC() to be called." )
591+ case s := <- statsHandler .rpcStatsCh :
592+ if reflect .TypeOf (s ) != reflect .TypeOf (wantType ) {
593+ t .Fatalf ("Received RPCStats of type %T, want %T" , s , wantType )
594+ }
595+ }
596+ }
597+ }
598+
509599// TestHandlerTransport_Drain verifies that Drain() is not implemented
510600// by `serverHandlerTransport`.
511601func (s ) TestHandlerTransport_Drain (t * testing.T ) {
512602 defer func () { recover () }()
513- st := newHandleStreamTest (t )
603+ st := newHandleStreamTest (t , nil )
514604 st .ht .Drain ("whatever" )
515605 t .Errorf ("serverHandlerTransport.Drain() should have panicked" )
516606}
0 commit comments