@@ -329,6 +329,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
329
329
def __init__ (self , hs : "HomeServer" ):
330
330
super ().__init__ ()
331
331
self .hs = hs
332
+ self ._auth = hs .get_auth ()
332
333
self .server_name = hs .hostname
333
334
self .registration_handler = hs .get_registration_handler ()
334
335
self .ratelimiter = FederationRateLimiter (
@@ -361,7 +362,7 @@ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
361
362
if self .inhibit_user_in_use_error :
362
363
return 200 , {"available" : True }
363
364
364
- ip = request . getClientAddress (). host
365
+ ip = self . _auth . get_ip_address_from_request ( request )
365
366
with self .ratelimiter .ratelimit (ip ) as wait_deferred :
366
367
await wait_deferred
367
368
@@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
395
396
def __init__ (self , hs : "HomeServer" ):
396
397
super ().__init__ ()
397
398
self .hs = hs
399
+ self ._auth = hs .get_auth ()
398
400
self .store = hs .get_datastores ().main
399
401
self .ratelimiter = Ratelimiter (
400
402
store = self .store ,
@@ -403,7 +405,8 @@ def __init__(self, hs: "HomeServer"):
403
405
)
404
406
405
407
async def on_GET (self , request : Request ) -> Tuple [int , JsonDict ]:
406
- await self .ratelimiter .ratelimit (None , (request .getClientAddress ().host ,))
408
+ ip_address = self ._auth .get_ip_address_from_request (request )
409
+ await self .ratelimiter .ratelimit (None , (ip_address ,))
407
410
408
411
if not self .hs .config .registration .enable_registration :
409
412
raise SynapseError (
@@ -456,7 +459,7 @@ def __init__(self, hs: "HomeServer"):
456
459
async def on_POST (self , request : SynapseRequest ) -> Tuple [int , JsonDict ]:
457
460
body = parse_json_object_from_request (request )
458
461
459
- client_addr = request . getClientAddress (). host
462
+ client_addr = self . auth . get_ip_address_from_request ( request )
460
463
461
464
await self .ratelimiter .ratelimit (None , client_addr , update = False )
462
465
@@ -916,7 +919,7 @@ def __init__(self, hs: "HomeServer"):
916
919
async def on_POST (self , request : SynapseRequest ) -> Tuple [int , JsonDict ]:
917
920
body = parse_json_object_from_request (request )
918
921
919
- client_addr = request . getClientAddress (). host
922
+ client_addr = self . auth . get_ip_address_from_request ( request )
920
923
921
924
await self .ratelimiter .ratelimit (None , client_addr , update = False )
922
925
0 commit comments