@@ -73,7 +73,7 @@ type SQLQueries interface {
73
73
DeleteExtraNodeType (ctx context.Context , arg sqlc.DeleteExtraNodeTypeParams ) error
74
74
75
75
InsertNodeAddress (ctx context.Context , arg sqlc.InsertNodeAddressParams ) error
76
- GetNodeAddressesByPubKey (ctx context.Context , arg sqlc. GetNodeAddressesByPubKeyParams ) ([]sqlc.GetNodeAddressesByPubKeyRow , error )
76
+ GetNodeAddresses (ctx context.Context , nodeID int64 ) ([]sqlc.GetNodeAddressesRow , error )
77
77
DeleteNodeAddresses (ctx context.Context , nodeID int64 ) error
78
78
79
79
InsertNodeFeature (ctx context.Context , arg sqlc.InsertNodeFeatureParams ) error
@@ -103,6 +103,7 @@ type SQLQueries interface {
103
103
HighestSCID (ctx context.Context , version int16 ) ([]byte , error )
104
104
ListChannelsByNodeID (ctx context.Context , arg sqlc.ListChannelsByNodeIDParams ) ([]sqlc.ListChannelsByNodeIDRow , error )
105
105
ListChannelsWithPoliciesPaginated (ctx context.Context , arg sqlc.ListChannelsWithPoliciesPaginatedParams ) ([]sqlc.ListChannelsWithPoliciesPaginatedRow , error )
106
+ ListChannelsWithPoliciesForCachePaginated (ctx context.Context , arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams ) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow , error )
106
107
ListChannelsPaginated (ctx context.Context , arg sqlc.ListChannelsPaginatedParams ) ([]sqlc.ListChannelsPaginatedRow , error )
107
108
GetChannelsByPolicyLastUpdateRange (ctx context.Context , arg sqlc.GetChannelsByPolicyLastUpdateRangeParams ) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow , error )
108
109
GetChannelByOutpointWithPolicies (ctx context.Context , arg sqlc.GetChannelByOutpointWithPoliciesParams ) (sqlc.GetChannelByOutpointWithPoliciesRow , error )
@@ -320,10 +321,21 @@ func (s *SQLStore) AddrsForNode(ctx context.Context,
320
321
known bool
321
322
)
322
323
err := s .db .ExecTx (ctx , sqldb .ReadTxOpt (), func (db SQLQueries ) error {
323
- var err error
324
- known , addresses , err = getNodeAddresses (
325
- ctx , db , nodePub .SerializeCompressed (),
324
+ // First, check if the node exists and get its DB ID if it
325
+ // does.
326
+ dbID , err := db .GetNodeIDByPubKey (
327
+ ctx , sqlc.GetNodeIDByPubKeyParams {
328
+ Version : int16 (ProtocolV1 ),
329
+ PubKey : nodePub .SerializeCompressed (),
330
+ },
326
331
)
332
+ if errors .Is (err , sql .ErrNoRows ) {
333
+ return nil
334
+ }
335
+
336
+ known = true
337
+
338
+ addresses , err = getNodeAddresses (ctx , db , dbID )
327
339
if err != nil {
328
340
return fmt .Errorf ("unable to fetch node addresses: %w" ,
329
341
err )
@@ -1247,8 +1259,8 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1247
1259
1248
1260
ctx := context .TODO ()
1249
1261
1250
- handleChannel := func (db SQLQueries ,
1251
- row sqlc.ListChannelsWithPoliciesPaginatedRow ) error {
1262
+ handleChannel := func (
1263
+ row sqlc.ListChannelsWithPoliciesForCachePaginatedRow ) error {
1252
1264
1253
1265
node1 , node2 , err := buildNodeVertices (
1254
1266
row .Node1Pubkey , row .Node2Pubkey ,
@@ -1258,7 +1270,7 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1258
1270
}
1259
1271
1260
1272
edge := buildCacheableChannelInfo (
1261
- row .GraphChannel , node1 , node2 ,
1273
+ row .Scid , row . Capacity . Int64 , node1 , node2 ,
1262
1274
)
1263
1275
1264
1276
dbPol1 , dbPol2 , err := extractChannelPolicies (row )
@@ -1299,8 +1311,8 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1299
1311
lastID := int64 (- 1 )
1300
1312
for {
1301
1313
//nolint:ll
1302
- rows , err := db .ListChannelsWithPoliciesPaginated (
1303
- ctx , sqlc.ListChannelsWithPoliciesPaginatedParams {
1314
+ rows , err := db .ListChannelsWithPoliciesForCachePaginated (
1315
+ ctx , sqlc.ListChannelsWithPoliciesForCachePaginatedParams {
1304
1316
Version : int16 (ProtocolV1 ),
1305
1317
ID : lastID ,
1306
1318
Limit : pageSize ,
@@ -1315,12 +1327,12 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
1315
1327
}
1316
1328
1317
1329
for _ , row := range rows {
1318
- err := handleChannel (db , row )
1330
+ err := handleChannel (row )
1319
1331
if err != nil {
1320
1332
return err
1321
1333
}
1322
1334
1323
- lastID = row .GraphChannel . ID
1335
+ lastID = row .ID
1324
1336
}
1325
1337
}
1326
1338
@@ -3018,7 +3030,8 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
3018
3030
}
3019
3031
3020
3032
edge := buildCacheableChannelInfo (
3021
- row .GraphChannel , node1 , node2 ,
3033
+ row .GraphChannel .Scid , row .GraphChannel .Capacity .Int64 ,
3034
+ node1 , node2 ,
3022
3035
)
3023
3036
3024
3037
dbPol1 , dbPol2 , err := extractChannelPolicies (row )
@@ -3321,16 +3334,15 @@ func getNodeByPubKey(ctx context.Context, db SQLQueries,
3321
3334
}
3322
3335
3323
3336
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3324
- // provided database channel row and the public keys of the two nodes
3325
- // involved in the channel.
3326
- func buildCacheableChannelInfo (dbChan sqlc.GraphChannel , node1Pub ,
3337
+ // provided parameters.
3338
+ func buildCacheableChannelInfo (scid []byte , capacity int64 , node1Pub ,
3327
3339
node2Pub route.Vertex ) * models.CachedEdgeInfo {
3328
3340
3329
3341
return & models.CachedEdgeInfo {
3330
- ChannelID : byteOrder .Uint64 (dbChan . Scid ),
3342
+ ChannelID : byteOrder .Uint64 (scid ),
3331
3343
NodeKey1Bytes : node1Pub ,
3332
3344
NodeKey2Bytes : node2Pub ,
3333
- Capacity : btcutil .Amount (dbChan . Capacity . Int64 ),
3345
+ Capacity : btcutil .Amount (capacity ),
3334
3346
}
3335
3347
}
3336
3348
@@ -3380,7 +3392,7 @@ func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
3380
3392
}
3381
3393
3382
3394
// Fetch the node's addresses.
3383
- _ , node .Addresses , err = getNodeAddresses (ctx , db , pub [:] )
3395
+ node .Addresses , err = getNodeAddresses (ctx , db , dbNode . ID )
3384
3396
if err != nil {
3385
3397
return nil , fmt .Errorf ("unable to fetch node(%d) " +
3386
3398
"addresses: %w" , dbNode .ID , err )
@@ -3684,42 +3696,26 @@ func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
3684
3696
return nil
3685
3697
}
3686
3698
3687
- // getNodeAddresses fetches the addresses for a node with the given public key .
3688
- func getNodeAddresses (ctx context.Context , db SQLQueries , nodePub [] byte ) (bool ,
3689
- []net. Addr , error ) {
3699
+ // getNodeAddresses fetches the addresses for a node with the given DB ID .
3700
+ func getNodeAddresses (ctx context.Context , db SQLQueries , id int64 ) ([]net. Addr ,
3701
+ error ) {
3690
3702
3691
- // GetNodeAddressesByPubKey ensures that the addresses for a given type
3692
- // are returned in the same order as they were inserted.
3693
- rows , err := db .GetNodeAddressesByPubKey (
3694
- ctx , sqlc.GetNodeAddressesByPubKeyParams {
3695
- Version : int16 (ProtocolV1 ),
3696
- PubKey : nodePub ,
3697
- },
3698
- )
3703
+ // GetNodeAddresses ensures that the addresses for a given type are
3704
+ // returned in the same order as they were inserted.
3705
+ rows , err := db .GetNodeAddresses (ctx , id )
3699
3706
if err != nil {
3700
- return false , nil , err
3701
- }
3702
-
3703
- // GetNodeAddressesByPubKey uses a left join so there should always be
3704
- // at least one row returned if the node exists even if it has no
3705
- // addresses.
3706
- if len (rows ) == 0 {
3707
- return false , nil , nil
3707
+ return nil , err
3708
3708
}
3709
3709
3710
3710
addresses := make ([]net.Addr , 0 , len (rows ))
3711
- for _ , addr := range rows {
3712
- if ! (addr .Type .Valid && addr .Address .Valid ) {
3713
- continue
3714
- }
3715
-
3716
- address := addr .Address .String
3711
+ for _ , row := range rows {
3712
+ address := row .Address
3717
3713
3718
- switch dbAddressType (addr .Type . Int16 ) {
3714
+ switch dbAddressType (row .Type ) {
3719
3715
case addressTypeIPv4 :
3720
3716
tcp , err := net .ResolveTCPAddr ("tcp4" , address )
3721
3717
if err != nil {
3722
- return false , nil , nil
3718
+ return nil , err
3723
3719
}
3724
3720
tcp .IP = tcp .IP .To4 ()
3725
3721
@@ -3728,21 +3724,20 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3728
3724
case addressTypeIPv6 :
3729
3725
tcp , err := net .ResolveTCPAddr ("tcp6" , address )
3730
3726
if err != nil {
3731
- return false , nil , nil
3727
+ return nil , err
3732
3728
}
3733
3729
addresses = append (addresses , tcp )
3734
3730
3735
3731
case addressTypeTorV3 , addressTypeTorV2 :
3736
3732
service , portStr , err := net .SplitHostPort (address )
3737
3733
if err != nil {
3738
- return false , nil , fmt .Errorf ("unable to " +
3739
- "split tor v3 address: %v" ,
3740
- addr .Address )
3734
+ return nil , fmt .Errorf ("unable to " +
3735
+ "split tor v3 address: %v" , address )
3741
3736
}
3742
3737
3743
3738
port , err := strconv .Atoi (portStr )
3744
3739
if err != nil {
3745
- return false , nil , err
3740
+ return nil , err
3746
3741
}
3747
3742
3748
3743
addresses = append (addresses , & tor.OnionAddr {
@@ -3753,17 +3748,17 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3753
3748
case addressTypeOpaque :
3754
3749
opaque , err := hex .DecodeString (address )
3755
3750
if err != nil {
3756
- return false , nil , fmt .Errorf ("unable to " +
3757
- "decode opaque address: %v" , addr )
3751
+ return nil , fmt .Errorf ("unable to " +
3752
+ "decode opaque address: %v" , address )
3758
3753
}
3759
3754
3760
3755
addresses = append (addresses , & lnwire.OpaqueAddrs {
3761
3756
Payload : opaque ,
3762
3757
})
3763
3758
3764
3759
default :
3765
- return false , nil , fmt .Errorf ("unknown address " +
3766
- "type: %v" , addr .Type )
3760
+ return nil , fmt .Errorf ("unknown address type: %v" ,
3761
+ row .Type )
3767
3762
}
3768
3763
}
3769
3764
@@ -3773,7 +3768,7 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3773
3768
addresses = nil
3774
3769
}
3775
3770
3776
- return true , addresses , nil
3771
+ return addresses , nil
3777
3772
}
3778
3773
3779
3774
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
@@ -4368,6 +4363,38 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
4368
4363
4369
4364
var policy1 , policy2 * sqlc.GraphChannelPolicy
4370
4365
switch r := row .(type ) {
4366
+ case sqlc.ListChannelsWithPoliciesForCachePaginatedRow :
4367
+ if r .Policy1Timelock .Valid {
4368
+ policy1 = & sqlc.GraphChannelPolicy {
4369
+ Timelock : r .Policy1Timelock .Int32 ,
4370
+ FeePpm : r .Policy1FeePpm .Int64 ,
4371
+ BaseFeeMsat : r .Policy1BaseFeeMsat .Int64 ,
4372
+ MinHtlcMsat : r .Policy1MinHtlcMsat .Int64 ,
4373
+ MaxHtlcMsat : r .Policy1MaxHtlcMsat ,
4374
+ InboundBaseFeeMsat : r .Policy1InboundBaseFeeMsat ,
4375
+ InboundFeeRateMilliMsat : r .Policy1InboundFeeRateMilliMsat ,
4376
+ Disabled : r .Policy1Disabled ,
4377
+ MessageFlags : r .Policy1MessageFlags ,
4378
+ ChannelFlags : r .Policy1ChannelFlags ,
4379
+ }
4380
+ }
4381
+ if r .Policy2Timelock .Valid {
4382
+ policy2 = & sqlc.GraphChannelPolicy {
4383
+ Timelock : r .Policy2Timelock .Int32 ,
4384
+ FeePpm : r .Policy2FeePpm .Int64 ,
4385
+ BaseFeeMsat : r .Policy2BaseFeeMsat .Int64 ,
4386
+ MinHtlcMsat : r .Policy2MinHtlcMsat .Int64 ,
4387
+ MaxHtlcMsat : r .Policy2MaxHtlcMsat ,
4388
+ InboundBaseFeeMsat : r .Policy2InboundBaseFeeMsat ,
4389
+ InboundFeeRateMilliMsat : r .Policy2InboundFeeRateMilliMsat ,
4390
+ Disabled : r .Policy2Disabled ,
4391
+ MessageFlags : r .Policy2MessageFlags ,
4392
+ ChannelFlags : r .Policy2ChannelFlags ,
4393
+ }
4394
+ }
4395
+
4396
+ return policy1 , policy2 , nil
4397
+
4371
4398
case sqlc.GetChannelsBySCIDWithPoliciesRow :
4372
4399
if r .Policy1ID .Valid {
4373
4400
policy1 = & sqlc.GraphChannelPolicy {
0 commit comments