Skip to content

Commit 18a2f5e

Browse files
authored
fix: be able to invoke Close in SSE callback (#1048)
1 parent a9ef53e commit 18a2f5e

File tree

2 files changed

+29
-11
lines changed

2 files changed

+29
-11
lines changed

sse.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -585,14 +585,16 @@ func (es *EventSource) processEvent(scanner *bufio.Scanner) error {
585585
}
586586

587587
func (es *EventSource) handleCallback(e *Event) {
588-
es.lock.RLock()
589-
defer es.lock.RUnlock()
590-
591588
eventName := e.Name
592589
if len(eventName) == 0 {
593590
eventName = defaultEventName
594591
}
595-
if cb, found := es.onEvent[eventName]; found {
592+
593+
es.lock.RLock()
594+
cb, found := es.onEvent[eventName]
595+
es.lock.RUnlock()
596+
597+
if found {
596598
if cb.Result == nil {
597599
cb.Func(e)
598600
return

sse_test.go

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,26 @@ import (
2020
)
2121

2222
func TestEventSourceSimpleFlow(t *testing.T) {
23+
es := createEventSource(t, "", nil, nil)
24+
2325
messageCounter := 0
2426
messageFunc := func(e any) {
2527
event := e.(*Event)
2628
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
2729
assertEqual(t, true, strings.HasPrefix(event.Data, "The time is"))
2830
messageCounter++
31+
if messageCounter == 100 {
32+
es.Close()
33+
}
2934
}
35+
es.OnMessage(messageFunc, nil)
3036

3137
counter := 0
32-
es := createEventSource(t, "", messageFunc, nil)
3338
ts := createSSETestServer(
3439
t,
3540
10*time.Millisecond,
3641
func(w io.Writer) error {
3742
if counter == 100 {
38-
es.Close()
3943
return fmt.Errorf("stop sending events")
4044
}
4145
_, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate))
@@ -129,22 +133,25 @@ func TestEventSourceOverwriteFuncs(t *testing.T) {
129133
messageFunc1 := func(e any) {
130134
assertNotNil(t, e)
131135
}
136+
es := createEventSource(t, "", messageFunc1, nil)
137+
132138
message2Counter := 0
133139
messageFunc2 := func(e any) {
134140
event := e.(*Event)
135141
assertEqual(t, strconv.Itoa(message2Counter), event.ID)
136142
assertEqual(t, true, strings.HasPrefix(event.Data, "The time is"))
137143
message2Counter++
144+
if message2Counter == 50 {
145+
es.Close()
146+
}
138147
}
139148

140149
counter := 0
141-
es := createEventSource(t, "", messageFunc1, nil)
142150
ts := createSSETestServer(
143151
t,
144152
10*time.Millisecond,
145153
func(w io.Writer) error {
146154
if counter == 50 {
147-
es.Close()
148155
return fmt.Errorf("stop sending events")
149156
}
150157
_, err := fmt.Fprintf(w, "id: %v\ndata: The time is %s\n\n", counter, time.Now().Format(time.UnixDate))
@@ -177,16 +184,21 @@ func TestEventSourceOverwriteFuncs(t *testing.T) {
177184
}
178185

179186
func TestEventSourceRetry(t *testing.T) {
187+
es := createEventSource(t, "", nil, nil)
188+
180189
messageCounter := 2 // 0 & 1 connection failure
181190
messageFunc := func(e any) {
182191
event := e.(*Event)
183192
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
184193
assertEqual(t, true, strings.HasPrefix(event.Data, "The time is"))
185194
messageCounter++
195+
if messageCounter == 15 {
196+
es.Close()
197+
}
186198
}
199+
es.OnMessage(messageFunc, nil)
187200

188201
counter := 0
189-
es := createEventSource(t, "", messageFunc, nil)
190202
ts := createTestServer(func(w http.ResponseWriter, r *http.Request) {
191203
if counter == 1 && r.URL.Query().Get("reconnect") == "1" {
192204
w.WriteHeader(http.StatusTooManyRequests)
@@ -445,19 +457,24 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {
445457

446458
for _, tc := range testCases {
447459
t.Run(tc.name, func(t *testing.T) {
460+
es := createEventSource(t, "", nil, nil)
461+
448462
messageCounter := 0
449463
messageFunc := func(e any) {
450464
event := e.(*Event)
451465
assertEqual(t, strconv.Itoa(messageCounter), event.ID)
452466
assertEqual(t, true, strings.HasPrefix(event.Data, fmt.Sprintf("%s method test:", tc.method)))
453467
messageCounter++
468+
if messageCounter == 20 {
469+
es.Close()
470+
}
454471
}
472+
es.OnMessage(messageFunc, nil)
455473

456474
counter := 0
457475
methodVerified := false
458476
bodyVerified := false
459477

460-
es := createEventSource(t, "", messageFunc, nil)
461478
ts := createMethodVerifyingSSETestServer(
462479
t,
463480
10*time.Millisecond,
@@ -467,7 +484,6 @@ func TestEventSourceWithDifferentMethods(t *testing.T) {
467484
&bodyVerified,
468485
func(w io.Writer) error {
469486
if counter == 20 {
470-
es.Close()
471487
return fmt.Errorf("stop sending events")
472488
}
473489
_, err := fmt.Fprintf(w, "id: %v\ndata: %s method test: %s\n\n", counter, tc.method, time.Now().Format(time.RFC3339))

0 commit comments

Comments
 (0)