Skip to content

Commit 4389067

Browse files
authored
Merge pull request #10113 from ellemouton/graphPerf2
[1] graph/db: add some SQL performance improvements
2 parents f7efc15 + 5a1184c commit 4389067

File tree

6 files changed

+277
-80
lines changed

6 files changed

+277
-80
lines changed

graph/db/sql_store.go

Lines changed: 82 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ type SQLQueries interface {
7373
DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
7474

7575
InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
76-
GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
76+
GetNodeAddresses(ctx context.Context, nodeID int64) ([]sqlc.GetNodeAddressesRow, error)
7777
DeleteNodeAddresses(ctx context.Context, nodeID int64) error
7878

7979
InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
@@ -103,6 +103,7 @@ type SQLQueries interface {
103103
HighestSCID(ctx context.Context, version int16) ([]byte, error)
104104
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
105105
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
106+
ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
106107
ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
107108
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
108109
GetChannelByOutpointWithPolicies(ctx context.Context, arg sqlc.GetChannelByOutpointWithPoliciesParams) (sqlc.GetChannelByOutpointWithPoliciesRow, error)
@@ -320,10 +321,21 @@ func (s *SQLStore) AddrsForNode(ctx context.Context,
320321
known bool
321322
)
322323
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
323-
var err error
324-
known, addresses, err = getNodeAddresses(
325-
ctx, db, nodePub.SerializeCompressed(),
324+
// First, check if the node exists and get its DB ID if it
325+
// does.
326+
dbID, err := db.GetNodeIDByPubKey(
327+
ctx, sqlc.GetNodeIDByPubKeyParams{
328+
Version: int16(ProtocolV1),
329+
PubKey: nodePub.SerializeCompressed(),
330+
},
326331
)
332+
if errors.Is(err, sql.ErrNoRows) {
333+
return nil
334+
}
335+
336+
known = true
337+
338+
addresses, err = getNodeAddresses(ctx, db, dbID)
327339
if err != nil {
328340
return fmt.Errorf("unable to fetch node addresses: %w",
329341
err)
@@ -1247,8 +1259,8 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
12471259

12481260
ctx := context.TODO()
12491261

1250-
handleChannel := func(db SQLQueries,
1251-
row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
1262+
handleChannel := func(
1263+
row sqlc.ListChannelsWithPoliciesForCachePaginatedRow) error {
12521264

12531265
node1, node2, err := buildNodeVertices(
12541266
row.Node1Pubkey, row.Node2Pubkey,
@@ -1258,7 +1270,7 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
12581270
}
12591271

12601272
edge := buildCacheableChannelInfo(
1261-
row.GraphChannel, node1, node2,
1273+
row.Scid, row.Capacity.Int64, node1, node2,
12621274
)
12631275

12641276
dbPol1, dbPol2, err := extractChannelPolicies(row)
@@ -1299,8 +1311,8 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
12991311
lastID := int64(-1)
13001312
for {
13011313
//nolint:ll
1302-
rows, err := db.ListChannelsWithPoliciesPaginated(
1303-
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
1314+
rows, err := db.ListChannelsWithPoliciesForCachePaginated(
1315+
ctx, sqlc.ListChannelsWithPoliciesForCachePaginatedParams{
13041316
Version: int16(ProtocolV1),
13051317
ID: lastID,
13061318
Limit: pageSize,
@@ -1315,12 +1327,12 @@ func (s *SQLStore) ForEachChannelCacheable(cb func(*models.CachedEdgeInfo,
13151327
}
13161328

13171329
for _, row := range rows {
1318-
err := handleChannel(db, row)
1330+
err := handleChannel(row)
13191331
if err != nil {
13201332
return err
13211333
}
13221334

1323-
lastID = row.GraphChannel.ID
1335+
lastID = row.ID
13241336
}
13251337
}
13261338

@@ -3018,7 +3030,8 @@ func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
30183030
}
30193031

30203032
edge := buildCacheableChannelInfo(
3021-
row.GraphChannel, node1, node2,
3033+
row.GraphChannel.Scid, row.GraphChannel.Capacity.Int64,
3034+
node1, node2,
30223035
)
30233036

30243037
dbPol1, dbPol2, err := extractChannelPolicies(row)
@@ -3321,16 +3334,15 @@ func getNodeByPubKey(ctx context.Context, db SQLQueries,
33213334
}
33223335

33233336
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
3324-
// provided database channel row and the public keys of the two nodes
3325-
// involved in the channel.
3326-
func buildCacheableChannelInfo(dbChan sqlc.GraphChannel, node1Pub,
3337+
// provided parameters.
3338+
func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
33273339
node2Pub route.Vertex) *models.CachedEdgeInfo {
33283340

33293341
return &models.CachedEdgeInfo{
3330-
ChannelID: byteOrder.Uint64(dbChan.Scid),
3342+
ChannelID: byteOrder.Uint64(scid),
33313343
NodeKey1Bytes: node1Pub,
33323344
NodeKey2Bytes: node2Pub,
3333-
Capacity: btcutil.Amount(dbChan.Capacity.Int64),
3345+
Capacity: btcutil.Amount(capacity),
33343346
}
33353347
}
33363348

@@ -3380,7 +3392,7 @@ func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
33803392
}
33813393

33823394
// Fetch the node's addresses.
3383-
_, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
3395+
node.Addresses, err = getNodeAddresses(ctx, db, dbNode.ID)
33843396
if err != nil {
33853397
return nil, fmt.Errorf("unable to fetch node(%d) "+
33863398
"addresses: %w", dbNode.ID, err)
@@ -3684,42 +3696,26 @@ func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
36843696
return nil
36853697
}
36863698

3687-
// getNodeAddresses fetches the addresses for a node with the given public key.
3688-
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
3689-
[]net.Addr, error) {
3699+
// getNodeAddresses fetches the addresses for a node with the given DB ID.
3700+
func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
3701+
error) {
36903702

3691-
// GetNodeAddressesByPubKey ensures that the addresses for a given type
3692-
// are returned in the same order as they were inserted.
3693-
rows, err := db.GetNodeAddressesByPubKey(
3694-
ctx, sqlc.GetNodeAddressesByPubKeyParams{
3695-
Version: int16(ProtocolV1),
3696-
PubKey: nodePub,
3697-
},
3698-
)
3703+
// GetNodeAddresses ensures that the addresses for a given type are
3704+
// returned in the same order as they were inserted.
3705+
rows, err := db.GetNodeAddresses(ctx, id)
36993706
if err != nil {
3700-
return false, nil, err
3701-
}
3702-
3703-
// GetNodeAddressesByPubKey uses a left join so there should always be
3704-
// at least one row returned if the node exists even if it has no
3705-
// addresses.
3706-
if len(rows) == 0 {
3707-
return false, nil, nil
3707+
return nil, err
37083708
}
37093709

37103710
addresses := make([]net.Addr, 0, len(rows))
3711-
for _, addr := range rows {
3712-
if !(addr.Type.Valid && addr.Address.Valid) {
3713-
continue
3714-
}
3715-
3716-
address := addr.Address.String
3711+
for _, row := range rows {
3712+
address := row.Address
37173713

3718-
switch dbAddressType(addr.Type.Int16) {
3714+
switch dbAddressType(row.Type) {
37193715
case addressTypeIPv4:
37203716
tcp, err := net.ResolveTCPAddr("tcp4", address)
37213717
if err != nil {
3722-
return false, nil, nil
3718+
return nil, err
37233719
}
37243720
tcp.IP = tcp.IP.To4()
37253721

@@ -3728,21 +3724,20 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
37283724
case addressTypeIPv6:
37293725
tcp, err := net.ResolveTCPAddr("tcp6", address)
37303726
if err != nil {
3731-
return false, nil, nil
3727+
return nil, err
37323728
}
37333729
addresses = append(addresses, tcp)
37343730

37353731
case addressTypeTorV3, addressTypeTorV2:
37363732
service, portStr, err := net.SplitHostPort(address)
37373733
if err != nil {
3738-
return false, nil, fmt.Errorf("unable to "+
3739-
"split tor v3 address: %v",
3740-
addr.Address)
3734+
return nil, fmt.Errorf("unable to "+
3735+
"split tor v3 address: %v", address)
37413736
}
37423737

37433738
port, err := strconv.Atoi(portStr)
37443739
if err != nil {
3745-
return false, nil, err
3740+
return nil, err
37463741
}
37473742

37483743
addresses = append(addresses, &tor.OnionAddr{
@@ -3753,17 +3748,17 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
37533748
case addressTypeOpaque:
37543749
opaque, err := hex.DecodeString(address)
37553750
if err != nil {
3756-
return false, nil, fmt.Errorf("unable to "+
3757-
"decode opaque address: %v", addr)
3751+
return nil, fmt.Errorf("unable to "+
3752+
"decode opaque address: %v", address)
37583753
}
37593754

37603755
addresses = append(addresses, &lnwire.OpaqueAddrs{
37613756
Payload: opaque,
37623757
})
37633758

37643759
default:
3765-
return false, nil, fmt.Errorf("unknown address "+
3766-
"type: %v", addr.Type)
3760+
return nil, fmt.Errorf("unknown address type: %v",
3761+
row.Type)
37673762
}
37683763
}
37693764

@@ -3773,7 +3768,7 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
37733768
addresses = nil
37743769
}
37753770

3776-
return true, addresses, nil
3771+
return addresses, nil
37773772
}
37783773

37793774
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
@@ -4368,6 +4363,38 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
43684363

43694364
var policy1, policy2 *sqlc.GraphChannelPolicy
43704365
switch r := row.(type) {
4366+
case sqlc.ListChannelsWithPoliciesForCachePaginatedRow:
4367+
if r.Policy1Timelock.Valid {
4368+
policy1 = &sqlc.GraphChannelPolicy{
4369+
Timelock: r.Policy1Timelock.Int32,
4370+
FeePpm: r.Policy1FeePpm.Int64,
4371+
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
4372+
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
4373+
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
4374+
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
4375+
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
4376+
Disabled: r.Policy1Disabled,
4377+
MessageFlags: r.Policy1MessageFlags,
4378+
ChannelFlags: r.Policy1ChannelFlags,
4379+
}
4380+
}
4381+
if r.Policy2Timelock.Valid {
4382+
policy2 = &sqlc.GraphChannelPolicy{
4383+
Timelock: r.Policy2Timelock.Int32,
4384+
FeePpm: r.Policy2FeePpm.Int64,
4385+
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
4386+
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
4387+
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
4388+
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
4389+
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
4390+
Disabled: r.Policy2Disabled,
4391+
MessageFlags: r.Policy2MessageFlags,
4392+
ChannelFlags: r.Policy2ChannelFlags,
4393+
}
4394+
}
4395+
4396+
return policy1, policy2, nil
4397+
43714398
case sqlc.GetChannelsBySCIDWithPoliciesRow:
43724399
if r.Policy1ID.Valid {
43734400
policy1 = &sqlc.GraphChannelPolicy{

0 commit comments

Comments
 (0)