2121from libp2p .abc import (
2222 IHost ,
2323)
24+ from libp2p .custom_types import TProtocol
2425from libp2p .discovery .random_walk .rt_refresh_manager import RTRefreshManager
2526from libp2p .kad_dht .utils import maybe_consume_signed_record
2627from libp2p .network .stream .net_stream import (
3435 PeerInfo ,
3536)
3637from libp2p .peer .peerstore import env_to_send_in_RPC
38+ from libp2p .records .pubkey import PublicKeyValidator
39+ from libp2p .records .validator import NamespacedValidator , Validator
3740from libp2p .tools .async_service import (
3841 Service ,
3942)
4043
4144from .common import (
4245 ALPHA ,
46+ BUCKET_SIZE ,
4347 PROTOCOL_ID ,
48+ PROTOCOL_PREFIX ,
4449 QUERY_TIMEOUT ,
4550)
4651from .pb .kademlia_pb2 import (
@@ -92,7 +97,17 @@ class KadDHT(Service):
9297
9398 """
9499
95- def __init__ (self , host : IHost , mode : DHTMode , enable_random_walk : bool = False ):
100+ def __init__ (
101+ self ,
102+ host : IHost ,
103+ mode : DHTMode ,
104+ enable_random_walk : bool = False ,
105+ validator : NamespacedValidator | None = None ,
106+ validator_changed : bool = False ,
107+ protocol_prefix : TProtocol = PROTOCOL_PREFIX ,
108+ enable_providers : bool = True ,
109+ enable_values : bool = True ,
110+ ):
96111 """
97112 Initialize a new Kademlia DHT node.
98113
@@ -115,6 +130,18 @@ def __init__(self, host: IHost, mode: DHTMode, enable_random_walk: bool = False)
115130 # Initialize the routing table
116131 self .routing_table = RoutingTable (self .local_peer_id , self .host )
117132
133+ self .protocol_prefix = protocol_prefix
134+ self .enable_providers = enable_providers
135+ self .enable_values = enable_values
136+ self .validator = validator
137+
138+ if validator is None :
139+ self .validator = NamespacedValidator ({"pk" : PublicKeyValidator ()})
140+
141+ # If true implies that the validator has been changed and that
142+ # Defaults should not be used
143+ self .validator_changed = validator_changed
144+
118145 # Initialize peer routing
119146 self .peer_routing = PeerRouting (host , self .routing_table )
120147
@@ -208,6 +235,84 @@ async def stop(self) -> None:
208235 else :
209236 logger .info ("RT Refresh Manager was not running (Random Walk disabled)" )
210237
238+ def apply_fallbacks (self ) -> None :
239+ """
240+ Apply fallback validators if not explicitely changed by the user
241+
242+ This sets default validators like 'pk' and 'ipns' if they are missing and
243+ the default validator set hasn't been overridden.
244+ """
245+ if not self .validator_changed :
246+ if not isinstance (self .validator , NamespacedValidator ):
247+ raise ValueError (
248+ "Default validator was changed without marking it True"
249+ )
250+
251+ if "pk" not in self .validator ._validators :
252+ self .validator ._validators ["pk" ] = PublicKeyValidator ()
253+
254+ # TODO: Do the same thing for ipns, but need to implement first.
255+
256+ def validate_config (self ) -> None :
257+ """
258+ Validate the DHT config.
259+ """
260+ if self .protocol_prefix != PROTOCOL_PREFIX :
261+ return # Skip validation for non-standart prefixes
262+
263+ for bucket in self .routing_table .buckets :
264+ if bucket .bucket_size != BUCKET_SIZE :
265+ raise ValueError (
266+ f"{ PROTOCOL_PREFIX } prefix must use bucket size { BUCKET_SIZE } "
267+ )
268+
269+ if not self .enable_providers :
270+ raise ValueError (f"{ PROTOCOL_PREFIX } prefix must have providers enabled" )
271+
272+ if not self .enable_values :
273+ raise ValueError (f"{ PROTOCOL_PREFIX } prefix must have values enabled" )
274+
275+ if not isinstance (self .validator , NamespacedValidator ):
276+ raise ValueError (
277+ f"{ PROTOCOL_PREFIX } prefix must use a namespace type validator"
278+ )
279+
280+ vmap = self .validator ._validators
281+
282+ # TODO: Need to add ipns also in the check
283+ if set (vmap .keys ()) != {"pk" }:
284+ raise ValueError (f"{ PROTOCOL_PREFIX } must have 'pk' and 'ipns' validators" )
285+
286+ pk_validator = vmap .get ("pk" )
287+ if not isinstance (pk_validator , PublicKeyValidator ):
288+ raise TypeError ("'pk' namesapce must use PublicKeyValidator" )
289+
290+ # TODO: ipns checks
291+
292+ def set_validator (self , val : NamespacedValidator ) -> None :
293+ """
294+ Set a custom validator for the DHT config.
295+
296+ This marks the validator as explicitly changed, so the default
297+ validators (pk and ipns) will not be automatically applied later.
298+ """
299+ self .validator = val
300+ self .validator_changed = True
301+ return
302+
303+ def set_namespace_validator (self , ns : str , val : Validator ) -> None :
304+ """
305+ Adds a validator under a specofic namespace to the current DHT config.
306+
307+ Raises an error if the current validator is not a NamespacedValidator
308+ """
309+ if not isinstance (self .validator , NamespacedValidator ):
310+ raise TypeError (
311+ "Can only add namespaced validators to a NamespacedValidator"
312+ )
313+
314+ self .validator ._validators ["ns" ] = val
315+
211316 async def switch_mode (self , new_mode : DHTMode ) -> DHTMode :
212317 """
213318 Switch the DHT mode.
@@ -511,8 +616,8 @@ async def handle_stream(self, stream: INetStream) -> None:
511616 await stream .close ()
512617 return
513618
514- value = self .value_store .get (key )
515- if value :
619+ value_record = self .value_store .get (key )
620+ if value_record :
516621 logger .debug (f"Found value for key { key .hex ()} " )
517622
518623 # Create response using protobuf
@@ -521,9 +626,7 @@ async def handle_stream(self, stream: INetStream) -> None:
521626
522627 # Create record
523628 response .key = key
524- response .record .key = key
525- response .record .value = value
526- response .record .timeReceived = str (time .time ())
629+ response .record .CopyFrom (value_record )
527630
528631 # Create sender_signed_peer_record
529632 envelope_bytes , _ = env_to_send_in_RPC (self .host )
@@ -666,6 +769,26 @@ async def put_value(self, key: bytes, value: bytes) -> None:
666769 """
667770 logger .debug (f"Storing value for key { key .hex ()} " )
668771
772+ if key .decode ("utf-8" ).startswith ("/" ):
773+ if self .validator is not None :
774+ # Dont allow local users to put bad values
775+ self .validator .validate (key .decode ("utf-8" ), value )
776+
777+ old_value_record = self .value_store .get (key )
778+ if old_value_record is not None and old_value_record .value != value :
779+ # Select which value is better
780+ try :
781+ index = self .validator .select (
782+ key .decode ("utf-8" ), [value , old_value_record .value ]
783+ )
784+ if index != 0 :
785+ raise ValueError (
786+ "Refusing to replace newer value with the older one"
787+ )
788+ except Exception as e :
789+ logger .debug (f"Validation error for key { key .hex ()} : { e } " )
790+ raise
791+
669792 # 1. Store locally first
670793 self .value_store .put (key , value )
671794 try :
@@ -716,10 +839,10 @@ async def get_value(self, key: bytes) -> bytes | None:
716839 logger .debug (f"Getting value for key: { key .hex ()} " )
717840
718841 # 1. Check local store first
719- value = self .value_store .get (key )
720- if value :
842+ value_record = self .value_store .get (key )
843+ if value_record :
721844 logger .debug ("Found value locally" )
722- return value
845+ return value_record . value
723846
724847 # 2. Get closest peers, excluding self
725848 closest_peers = [
0 commit comments