diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index 83dd1e20b21..1c50edab409 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -33,11 +33,15 @@ https://github.com/lightningnetwork/lnd/pull/9993) are returned when receiving empty route hints or a non-UTF-8-encoded description. -- [Fixed](https://github.com/lightningnetwork/lnd/pull/10027) an issue where +- [Fixed](https://github.com/lightningnetwork/lnd/pull/10140) an issue where known TLV fields were incorrectly encoded into the `ExtraData` field of messages in the dynamic commitment set. +- [Fixed](https://github.com/lightningnetwork/lnd/pull/10072) an issue where + known TLV fields were incorrectly encoded into the `ExtraData` field of + messages in the gossip set. + # New Features - Added [NoOp HTLCs](https://github.com/lightningnetwork/lnd/pull/9871). This diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index afb2f141221..789c48e98d6 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -89,6 +89,8 @@ type AcceptChannel struct { // within the commitment transaction of the sender. FirstCommitmentPoint *btcec.PublicKey + // NOTE: The following fields are TLV records. + // // UpfrontShutdownScript is the script to which the channel funds should // be paid when mutually closing the channel. This field is optional, and // and has a length prefix, so a zero will be written if it is not set @@ -138,17 +140,27 @@ var _ SizeableMessage = (*AcceptChannel)(nil) // // This is part of the lnwire.Message interface. func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error { - recordProducers := []tlv.RecordProducer{&a.UpfrontShutdownScript} + // Get producers from extra data. + producers, err := a.ExtraData.RecordProducers() + if err != nil { + return err + } + + // Append known producers. + producers = append(producers, &a.UpfrontShutdownScript) if a.ChannelType != nil { - recordProducers = append(recordProducers, a.ChannelType) + producers = append(producers, a.ChannelType) } if a.LeaseExpiry != nil { - recordProducers = append(recordProducers, a.LeaseExpiry) + producers = append(producers, a.LeaseExpiry) } a.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { - recordProducers = append(recordProducers, &localNonce) + producers = append(producers, &localNonce) }) - err := EncodeMessageExtraData(&a.ExtraData, recordProducers...) + + // Pack all records into a new TLV stream. + var tlvData ExtraOpaqueData + err = tlvData.PackRecords(producers...) if err != nil { return err } @@ -209,7 +221,7 @@ func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, a.ExtraData) + return WriteBytes(w, tlvData) } // Decode deserializes the serialized AcceptChannel stored in the passed @@ -254,8 +266,8 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { leaseExpiry LeaseExpiry localNonce = a.LocalNonce.Zero() ) - typeMap, err := tlvRecords.ExtractRecords( - &a.UpfrontShutdownScript, &chanType, &leaseExpiry, + knownRecords, extraData, err := ParseAndExtractExtraData( + tlvRecords, &a.UpfrontShutdownScript, &chanType, &leaseExpiry, &localNonce, ) if err != nil { @@ -263,17 +275,17 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[ChannelTypeRecordType]; ok && val == nil { + if _, ok := knownRecords[ChannelTypeRecordType]; ok { a.ChannelType = &chanType } - if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { + if _, ok := knownRecords[LeaseExpiryRecordType]; ok { a.LeaseExpiry = &leaseExpiry } - if val, ok := typeMap[a.LocalNonce.TlvType()]; ok && val == nil { + if _, ok := knownRecords[a.LocalNonce.TlvType()]; ok { a.LocalNonce = tlv.SomeRecordT(localNonce) } - a.ExtraData = tlvRecords + a.ExtraData = extraData return nil } diff --git a/lnwire/accept_channel_test.go b/lnwire/accept_channel_test.go index 87d9dc029cc..e0eded16136 100644 --- a/lnwire/accept_channel_test.go +++ b/lnwire/accept_channel_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/btcsuite/btcd/btcec/v2" + "github.com/stretchr/testify/require" ) // TestDecodeAcceptChannel tests decoding of an accept channel wire message with @@ -70,3 +71,119 @@ func TestDecodeAcceptChannel(t *testing.T) { }) } } + +// TestAcceptChannelEncodeDecode tests that a raw byte stream can be +// decoded, then re-encoded to the same exact byte stream. +func TestAcceptChannelEncodeDecode(t *testing.T) { + t.Parallel() + + // Create a new private key and its corresponding public key. + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + pk := priv.PubKey() + + // Create a sample AcceptChannel message with all fields populated. + // The exact values are not important, only that they are of the + // correct size. + var rawBytes []byte + + // PendingChannelID + rawBytes = append(rawBytes, make([]byte, 32)...) + + // DustLimit + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 1}...) + + // MaxValueInFlight + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 2}...) + + // ChannelReserve + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 3}...) + + // HtlcMinimum + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 4}...) + + // MinAcceptDepth + rawBytes = append(rawBytes, []byte{0, 0, 0, 5}...) + + // CsvDelay + rawBytes = append(rawBytes, []byte{0, 6}...) + + // MaxAcceptedHTLCs + rawBytes = append(rawBytes, []byte{0, 7}...) + + // FundingKey + rawBytes = append(rawBytes, pk.SerializeCompressed()...) + + // RevocationPoint + rawBytes = append(rawBytes, pk.SerializeCompressed()...) + + // PaymentPoint + rawBytes = append(rawBytes, pk.SerializeCompressed()...) + + // DelayedPaymentPoint + rawBytes = append(rawBytes, pk.SerializeCompressed()...) + + // HtlcPoint + rawBytes = append(rawBytes, pk.SerializeCompressed()...) + + // FirstCommitmentPoint + rawBytes = append(rawBytes, pk.SerializeCompressed()...) + + // Add TLV data, including known and unknown records. + tlvData := []byte{ + // UpfrontShutdownScript (known, type 0) + 0, // type + 2, // length + 0xaa, 0xbb, // value + + // ChannelType (known, type 1) + 1, // type + 1, // length + 0x02, // value (feature bit 1 set) + + // Unknown odd-type TLV record. + 0x3, // type + 0x2, // length + 0xab, 0xcd, // value + + // LocalNonce (known, type 4) + 4, // type + 66, // length + // 66 bytes of dummy data + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, + + // Another unknown odd-type TLV record. + 0x6f, // type + 0x2, // length + 0x79, 0x79, // value + + // LeaseExpiry (known, type 65536) + 0xfe, 0x00, 0x01, 0x00, 0x00, // type + 4, // length + 0x12, 0x34, 0x56, 0x78, // value + } + rawBytes = append(rawBytes, tlvData...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &AcceptChannel{} + r := bytes.NewReader(rawBytes) + err = msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index f388db1d133..f9a62fe431e 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -22,6 +22,8 @@ type ChannelReady struct { // next commitment transaction for the channel. NextPerCommitmentPoint *btcec.PublicKey + // NOTE: The following fields are TLV records. + // // AliasScid is an alias ShortChannelID used to refer to the underlying // channel. It can be used instead of the confirmed on-chain // ShortChannelID for forwarding. @@ -95,8 +97,8 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { nodeNonce = tlv.ZeroRecordT[tlv.TlvType0, Musig2Nonce]() btcNonce = tlv.ZeroRecordT[tlv.TlvType2, Musig2Nonce]() ) - typeMap, err := tlvRecords.ExtractRecords( - &btcNonce, &aliasScid, &nodeNonce, &localNonce, + knownRecords, extraData, err := ParseAndExtractExtraData( + tlvRecords, &btcNonce, &aliasScid, &nodeNonce, &localNonce, ) if err != nil { return err @@ -104,24 +106,20 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { // We'll only set AliasScid if the corresponding TLV type was included // in the stream. - if val, ok := typeMap[AliasScidRecordType]; ok && val == nil { + if _, ok := knownRecords[AliasScidRecordType]; ok { c.AliasScid = &aliasScid } - if val, ok := typeMap[c.NextLocalNonce.TlvType()]; ok && val == nil { + if _, ok := knownRecords[c.NextLocalNonce.TlvType()]; ok { c.NextLocalNonce = tlv.SomeRecordT(localNonce) } - val, ok := typeMap[c.AnnouncementBitcoinNonce.TlvType()] - if ok && val == nil { + if _, ok := knownRecords[c.AnnouncementBitcoinNonce.TlvType()]; ok { c.AnnouncementBitcoinNonce = tlv.SomeRecordT(btcNonce) } - val, ok = typeMap[c.AnnouncementNodeNonce.TlvType()] - if ok && val == nil { + if _, ok := knownRecords[c.AnnouncementNodeNonce.TlvType()]; ok { c.AnnouncementNodeNonce = tlv.SomeRecordT(nodeNonce) } - if len(tlvRecords) != 0 { - c.ExtraData = tlvRecords - } + c.ExtraData = extraData return nil } @@ -140,31 +138,38 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { return err } + // Get producers from extra data. + producers, err := c.ExtraData.RecordProducers() + if err != nil { + return err + } + // We'll only encode the AliasScid in a TLV segment if it exists. - recordProducers := make([]tlv.RecordProducer, 0, 4) if c.AliasScid != nil { - recordProducers = append(recordProducers, c.AliasScid) + producers = append(producers, c.AliasScid) } c.NextLocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { - recordProducers = append(recordProducers, &localNonce) + producers = append(producers, &localNonce) }) c.AnnouncementBitcoinNonce.WhenSome( func(nonce tlv.RecordT[tlv.TlvType2, Musig2Nonce]) { - recordProducers = append(recordProducers, &nonce) + producers = append(producers, &nonce) }, ) c.AnnouncementNodeNonce.WhenSome( func(nonce tlv.RecordT[tlv.TlvType0, Musig2Nonce]) { - recordProducers = append(recordProducers, &nonce) + producers = append(producers, &nonce) }, ) - err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + // Pack all records into a new TLV stream. + var tlvData ExtraOpaqueData + err = tlvData.PackRecords(producers...) if err != nil { return err } - return WriteBytes(w, c.ExtraData) + return WriteBytes(w, tlvData) } // MsgType returns the uint32 code which uniquely identifies this message as a diff --git a/lnwire/channel_ready_test.go b/lnwire/channel_ready_test.go new file mode 100644 index 00000000000..85985dc02dd --- /dev/null +++ b/lnwire/channel_ready_test.go @@ -0,0 +1,105 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/stretchr/testify/require" +) + +// TestChannelReadyEncodeDecode tests that a raw byte stream can be +// decoded, then re-encoded to the same exact byte stream. +func TestChannelReadyEncodeDecode(t *testing.T) { + t.Parallel() + + // Create a new private key and its corresponding public key. + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + pk := priv.PubKey() + + // Create a sample ChannelReady message with all fields populated. + var rawBytes []byte + + // ChanID + rawBytes = append(rawBytes, make([]byte, 32)...) + + // NextPerCommitmentPoint + rawBytes = append(rawBytes, pk.SerializeCompressed()...) + + // Add TLV data, including known and unknown records. + tlvData := []byte{ + // AnnouncementNodeNonce (known, type 0) + 0, // type + 66, // length + // 66 bytes of dummy data + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, + + // AliasScid (known, type 1). + 1, // type + 8, // length + 0, 0, 0, 0, 0, 0, 0, 1, // value + + // AnnouncementBitcoinNonce (known, type 2) + 2, // type + 66, // length + // 66 bytes of dummy data + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, + + // Unknown odd-type TLV record. + 0x3, // type + 0x2, // length + 0xab, 0xcd, // value + + // NextLocalNonce (known, type 4) + 4, // type + 66, // length + // 66 bytes of dummy data + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, + + // Another unknown odd-type TLV record at the end. + 0x6f, // type + 0x2, // length + 0x79, 0x79, // value + } + rawBytes = append(rawBytes, tlvData...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &ChannelReady{} + r := bytes.NewReader(rawBytes) + err = msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index f26a2fc5d22..4f4417fb846 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -78,6 +78,8 @@ type ChannelReestablish struct { // current un-revoked commitment transaction of the sending party. LocalUnrevokedCommitPoint *btcec.PublicKey + // NOTE: The following fields are TLV records. + // // LocalNonce is an optional field that stores a local musig2 nonce. // This will only be populated if the simple taproot channels type was // negotiated. @@ -140,20 +142,27 @@ func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error { return err } - recordProducers := make([]tlv.RecordProducer, 0, 1) + // Get producers from extra data. + producers, err := a.ExtraData.RecordProducers() + if err != nil { + return err + } + a.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { - recordProducers = append(recordProducers, &localNonce) + producers = append(producers, &localNonce) }) a.DynHeight.WhenSome(func(h DynHeight) { - recordProducers = append(recordProducers, &h) + producers = append(producers, &h) }) - err := EncodeMessageExtraData(&a.ExtraData, recordProducers...) + // Pack all records into a new TLV stream. + var tlvData ExtraOpaqueData + err = tlvData.PackRecords(producers...) if err != nil { return err } - return WriteBytes(w, a.ExtraData) + return WriteBytes(w, tlvData) } // Decode deserializes a serialized ChannelReestablish stored in the passed @@ -210,23 +219,21 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { dynHeight DynHeight localNonce = a.LocalNonce.Zero() ) - typeMap, err := tlvRecords.ExtractRecords( - &localNonce, &dynHeight, + knownRecords, extraData, err := ParseAndExtractExtraData( + tlvRecords, &localNonce, &dynHeight, ) if err != nil { return err } - if val, ok := typeMap[a.LocalNonce.TlvType()]; ok && val == nil { + if _, ok := knownRecords[a.LocalNonce.TlvType()]; ok { a.LocalNonce = tlv.SomeRecordT(localNonce) } - if val, ok := typeMap[CRDynHeight]; ok && val == nil { + if _, ok := knownRecords[CRDynHeight]; ok { a.DynHeight = fn.Some(dynHeight) } - if len(tlvRecords) != 0 { - a.ExtraData = tlvRecords - } + a.ExtraData = extraData return nil } diff --git a/lnwire/channel_reestablish_test.go b/lnwire/channel_reestablish_test.go new file mode 100644 index 00000000000..4b67e8a55f5 --- /dev/null +++ b/lnwire/channel_reestablish_test.go @@ -0,0 +1,86 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcd/btcec/v2" + "github.com/stretchr/testify/require" +) + +// TestChannelReestablishEncodeDecode tests that a raw byte stream can be +// decoded, then re-encoded to the same exact byte stream. +func TestChannelReestablishEncodeDecode(t *testing.T) { + t.Parallel() + + // Create a new private key and its corresponding public key. + priv, err := btcec.NewPrivateKey() + require.NoError(t, err) + pk := priv.PubKey() + + // Create a sample ChannelReestablish message. + var rawBytes []byte + + // ChanID + rawBytes = append(rawBytes, make([]byte, 32)...) + + // NextLocalCommitHeight + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 1}...) + + // RemoteCommitTailHeight + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 2}...) + + // LastRemoteCommitSecret + rawBytes = append(rawBytes, make([]byte, 32)...) + + // LocalUnrevokedCommitPoint + rawBytes = append(rawBytes, pk.SerializeCompressed()...) + + // Add TLV data, including known and unknown records. + tlvData := []byte{ + // Unknown odd-type TLV record. + 0x3, // type + 0x2, // length + 0xab, 0xcd, // value + + // LocalNonce (type 4). + 0x04, // type + 0x42, // length (66) + // value (66 bytes) + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + + // DynHeight (type 20). + 0x14, // type + 0x08, // length + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, // value + + // Another unknown odd-type TLV record at the end. + 0x6f, // type + 0x2, // length + 0x79, 0x79, // value + } + rawBytes = append(rawBytes, tlvData...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &ChannelReestablish{} + r := bytes.NewReader(rawBytes) + err = msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/closing_complete.go b/lnwire/closing_complete.go index 7980ef1ee18..216b53d56be 100644 --- a/lnwire/closing_complete.go +++ b/lnwire/closing_complete.go @@ -57,30 +57,31 @@ type ClosingComplete struct { // decodeClosingSigs decodes the closing sig TLV records in the passed // ExtraOpaqueData. -func decodeClosingSigs(c *ClosingSigs, tlvRecords ExtraOpaqueData) error { +func decodeClosingSigs(c *ClosingSigs, tlvRecords ExtraOpaqueData) ( + ExtraOpaqueData, error) { + sig1 := c.CloserNoClosee.Zero() sig2 := c.NoCloserClosee.Zero() sig3 := c.CloserAndClosee.Zero() - typeMap, err := tlvRecords.ExtractRecords(&sig1, &sig2, &sig3) + knownRecords, extraData, err := ParseAndExtractExtraData( + tlvRecords, &sig1, &sig2, &sig3, + ) if err != nil { - return err + return nil, err } - // TODO(roasbeef): helper func to made decode of the optional vals - // easier? - - if val, ok := typeMap[c.CloserNoClosee.TlvType()]; ok && val == nil { + if _, ok := knownRecords[c.CloserNoClosee.TlvType()]; ok { c.CloserNoClosee = tlv.SomeRecordT(sig1) } - if val, ok := typeMap[c.NoCloserClosee.TlvType()]; ok && val == nil { + if _, ok := knownRecords[c.NoCloserClosee.TlvType()]; ok { c.NoCloserClosee = tlv.SomeRecordT(sig2) } - if val, ok := typeMap[c.CloserAndClosee.TlvType()]; ok && val == nil { + if _, ok := knownRecords[c.CloserAndClosee.TlvType()]; ok { c.CloserAndClosee = tlv.SomeRecordT(sig3) } - return nil + return extraData, nil } // Decode deserializes a serialized ClosingComplete message stored in the @@ -102,13 +103,12 @@ func (c *ClosingComplete) Decode(r io.Reader, _ uint32) error { return err } - if err := decodeClosingSigs(&c.ClosingSigs, tlvRecords); err != nil { + extraData, err := decodeClosingSigs(&c.ClosingSigs, tlvRecords) + if err != nil { return err } - if len(tlvRecords) != 0 { - c.ExtraData = tlvRecords - } + c.ExtraData = extraData return nil } @@ -151,14 +151,22 @@ func (c *ClosingComplete) Encode(w *bytes.Buffer, _ uint32) error { return err } - recordProducers := closingSigRecords(&c.ClosingSigs) + // Get producers from extra data. + producers, err := c.ExtraData.RecordProducers() + if err != nil { + return err + } + + producers = append(producers, closingSigRecords(&c.ClosingSigs)...) - err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + // Pack all records into a new TLV stream. + var tlvData ExtraOpaqueData + err = tlvData.PackRecords(producers...) if err != nil { return err } - return WriteBytes(w, c.ExtraData) + return WriteBytes(w, tlvData) } // MsgType returns the uint32 code which uniquely identifies this message as a diff --git a/lnwire/closing_complete_test.go b/lnwire/closing_complete_test.go new file mode 100644 index 00000000000..cde69c09aa0 --- /dev/null +++ b/lnwire/closing_complete_test.go @@ -0,0 +1,100 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestClosingCompleteEncodeDecode tests that a raw byte stream can be +// decoded, then re-encoded to the same exact byte stream. +func TestClosingCompleteEncodeDecode(t *testing.T) { + t.Parallel() + + // Create a sample ClosingComplete message. + var rawBytes []byte + + // ChannelID + rawBytes = append(rawBytes, make([]byte, 32)...) + + // CloserScript + rawBytes = append(rawBytes, []byte{0, 1, 0xaa}...) + + // CloseeScript + rawBytes = append(rawBytes, []byte{0, 1, 0xbb}...) + + // FeeSatoshis + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 1}...) + + // LockTime + rawBytes = append(rawBytes, []byte{0, 0, 0, 2}...) + + // Add TLV data, including known and unknown records. + tlvData := []byte{ + // CloserNoClosee (known, type 1) + 1, // type + 64, // length + // 64 bytes of dummy signature data + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + + // NoCloserClosee (known, type 2) + 2, // type + 64, // length + // 64 bytes of dummy signature data + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + + // CloserAndClosee (known, type 3) + 3, // type + 64, // length + // 64 bytes of dummy signature data + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + + // Unknown odd-type TLV record. + 0x5, // type + 0x2, // length + 0xab, 0xcd, // value + + // Another unknown odd-type TLV record at the end. + 0x6f, // type + 0x2, // length + 0x79, 0x79, // value + } + rawBytes = append(rawBytes, tlvData...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &ClosingComplete{} + r := bytes.NewReader(rawBytes) + err := msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/closing_sig.go b/lnwire/closing_sig.go index 94a35606638..67954739430 100644 --- a/lnwire/closing_sig.go +++ b/lnwire/closing_sig.go @@ -57,13 +57,12 @@ func (c *ClosingSig) Decode(r io.Reader, _ uint32) error { return err } - if err := decodeClosingSigs(&c.ClosingSigs, tlvRecords); err != nil { + extraData, err := decodeClosingSigs(&c.ClosingSigs, tlvRecords) + if err != nil { return err } - if len(tlvRecords) != 0 { - c.ExtraData = tlvRecords - } + c.ExtraData = extraData return nil } @@ -89,14 +88,22 @@ func (c *ClosingSig) Encode(w *bytes.Buffer, _ uint32) error { return err } - recordProducers := closingSigRecords(&c.ClosingSigs) + // Get producers from extra data. + producers, err := c.ExtraData.RecordProducers() + if err != nil { + return err + } + + producers = append(producers, closingSigRecords(&c.ClosingSigs)...) - err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + // Pack all records into a new TLV stream. + var tlvData ExtraOpaqueData + err = tlvData.PackRecords(producers...) if err != nil { return err } - return WriteBytes(w, c.ExtraData) + return WriteBytes(w, tlvData) } // MsgType returns the uint32 code which uniquely identifies this message as a diff --git a/lnwire/closing_sig_test.go b/lnwire/closing_sig_test.go new file mode 100644 index 00000000000..b8e26d7d7c3 --- /dev/null +++ b/lnwire/closing_sig_test.go @@ -0,0 +1,100 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestClosingSigEncodeDecode tests that a raw byte stream can be +// decoded, then re-encoded to the same exact byte stream. +func TestClosingSigEncodeDecode(t *testing.T) { + t.Parallel() + + // Create a sample ClosingSig message. + var rawBytes []byte + + // ChannelID + rawBytes = append(rawBytes, make([]byte, 32)...) + + // CloserScript + rawBytes = append(rawBytes, []byte{0, 1, 0xaa}...) + + // CloseeScript + rawBytes = append(rawBytes, []byte{0, 1, 0xbb}...) + + // FeeSatoshis + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 1}...) + + // LockTime + rawBytes = append(rawBytes, []byte{0, 0, 0, 2}...) + + // Add TLV data, including known and unknown records. + tlvData := []byte{ + // CloserNoClosee (known, type 1) + 1, // type + 64, // length + // 64 bytes of dummy signature data + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + + // NoCloserClosee (known, type 2) + 2, // type + 64, // length + // 64 bytes of dummy signature data + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, 0x22, + + // CloserAndClosee (known, type 3) + 3, // type + 64, // length + // 64 bytes of dummy signature data + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, 0x33, + + // Unknown odd-type TLV record. + 0x5, // type + 0x2, // length + 0xab, 0xcd, // value + + // Another unknown odd-type TLV record at the end. + 0x6f, // type + 0x2, // length + 0x79, 0x79, // value + } + rawBytes = append(rawBytes, tlvData...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &ClosingSig{} + r := bytes.NewReader(rawBytes) + err := msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index c247cfe0a55..fe1377e4446 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -81,19 +81,19 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { } partialSig := c.PartialSig.Zero() - typeMap, err := tlvRecords.ExtractRecords(&partialSig) + knownRecords, extraData, err := ParseAndExtractExtraData( + tlvRecords, &partialSig, + ) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil { + if _, ok := knownRecords[c.PartialSig.TlvType()]; ok { c.PartialSig = tlv.SomeRecordT(partialSig) } - if len(tlvRecords) != 0 { - c.ExtraData = tlvRecords - } + c.ExtraData = extraData return nil } @@ -103,11 +103,19 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { // // This is part of the lnwire.Message interface. func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { - recordProducers := make([]tlv.RecordProducer, 0, 1) + // Get producers from extra data. + producers, err := c.ExtraData.RecordProducers() + if err != nil { + return err + } + c.PartialSig.WhenSome(func(sig PartialSigTLV) { - recordProducers = append(recordProducers, &sig) + producers = append(producers, &sig) }) - err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) + + // Pack all records into a new TLV stream. + var tlvData ExtraOpaqueData + err = tlvData.PackRecords(producers...) if err != nil { return err } @@ -124,7 +132,7 @@ func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { return err } - return WriteBytes(w, c.ExtraData) + return WriteBytes(w, tlvData) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/closing_signed_test.go b/lnwire/closing_signed_test.go new file mode 100644 index 00000000000..0255d6e9ea8 --- /dev/null +++ b/lnwire/closing_signed_test.go @@ -0,0 +1,64 @@ +package lnwire + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestClosingSignedEncodeDecode tests that a raw byte stream can be +// decoded, then re-encoded to the same exact byte stream. +func TestClosingSignedEncodeDecode(t *testing.T) { + t.Parallel() + + // Create a sample ClosingSigned message. + var rawBytes []byte + + // ChannelID + rawBytes = append(rawBytes, make([]byte, 32)...) + + // FeeSatoshis + rawBytes = append(rawBytes, []byte{0, 0, 0, 0, 0, 0, 0, 1}...) + + // Signature + rawBytes = append(rawBytes, make([]byte, 64)...) + + // Add TLV data, including known and unknown records. + tlvData := []byte{ + // Unknown odd-type TLV record. + 0x3, // type + 0x2, // length + 0xab, 0xcd, // value + + // PartialSig (known, type 6) + 6, // type + 32, // length + // 32 bytes of dummy signature data + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, 0x44, + + // Another unknown odd-type TLV record at the end. + 0x6f, // type + 0x2, // length + 0x79, 0x79, // value + } + rawBytes = append(rawBytes, tlvData...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &ClosingSigned{} + r := bytes.NewReader(rawBytes) + err := msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/commit_sig_test.go b/lnwire/commit_sig_test.go index 0772a2fb831..87d54a6372e 100644 --- a/lnwire/commit_sig_test.go +++ b/lnwire/commit_sig_test.go @@ -131,7 +131,7 @@ func generateCommitSigTestCases(t *testing.T) []commitSigTestCase { // TestCommitSigEncodeDecode tests CommitSig message encoding and decoding for // all supported field values. -func TestCommitSigEncodeDecode(t *testing.T) { +func TestCommitSigEncodeDecodeFields(t *testing.T) { t.Parallel() // Generate test cases. @@ -166,3 +166,56 @@ func TestCommitSigEncodeDecode(t *testing.T) { }) } } + +// TestCommitSigEncodeDecode tests that a raw byte stream can be decoded, then +// re-encoded to the same exact byte stream. +func TestCommitSigEncodeDecode(t *testing.T) { + t.Parallel() + + // We'll create a raw byte stream that represents a valid CommitSig + // message. This includes the fixed-size fields and a TLV stream with + // both known and unknown records. + var rawBytes []byte + + // ChanID + rawBytes = append(rawBytes, make([]byte, 32)...) + + // CommitSig + rawBytes = append(rawBytes, make([]byte, 64)...) + + // HtlcSigs + rawBytes = append(rawBytes, []byte{0, 0}...) + + // Add TLV data, including known and unknown records. + tlvData := []byte{ + // PartialSig record. + 2, 98, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, + + // Unknown odd record: Type=3, Length=1, Value=0. + 3, 1, 0, + + // CustomRecords: Type=65536, Length=1, Value=0. + 0xfe, 0x00, 0x01, 0x00, 0x00, 1, 0, + } + rawBytes = append(rawBytes, tlvData...) + + // Now, create a new empty message and decode the raw bytes into it. + msg := &CommitSig{} + r := bytes.NewReader(rawBytes) + err := msg.Decode(r, 0) + require.NoError(t, err) + + // Next, encode the message back into a new byte buffer. + var b bytes.Buffer + err = msg.Encode(&b, 0) + require.NoError(t, err) + + // The re-encoded bytes should be exactly the same as the original raw + // bytes. + require.Equal(t, rawBytes, b.Bytes()) +} diff --git a/lnwire/test_message.go b/lnwire/test_message.go index 7eb712b743a..2ccb0f1789a 100644 --- a/lnwire/test_message.go +++ b/lnwire/test_message.go @@ -59,7 +59,7 @@ func (a *AcceptChannel) RandTestMessage(t *rapid.T) Message { if includeLocalNonce { nonce := RandMusig2Nonce(t) localNonce = tlv.SomeRecordT( - tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + tlv.NewRecordT[NonceRecordTypeT](nonce), ) } @@ -98,7 +98,12 @@ func (a *AcceptChannel) RandTestMessage(t *rapid.T) Message { ChannelType: channelType, LeaseExpiry: leaseExpiry, LocalNonce: localNonce, - ExtraData: RandExtraOpaqueData(t, nil), + ExtraData: RandExtraRecords( + t, uint64(DeliveryAddrType), + uint64(ChannelTypeRecordType), + uint64(LeaseExpiryRecordType), + uint64(nonceRecordType), + ), } } @@ -214,22 +219,22 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { msg := &ChannelAnnouncement2{ Signature: RandSignature(t), - ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( + ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0]( chainHashObj, ), - Features: tlv.NewRecordT[tlv.TlvType2, RawFeatureVector]( + Features: tlv.NewRecordT[tlv.TlvType2]( *features, ), - ShortChannelID: tlv.NewRecordT[tlv.TlvType4, ShortChannelID]( + ShortChannelID: tlv.NewRecordT[tlv.TlvType4]( shortChanID, ), - Capacity: tlv.NewPrimitiveRecord[tlv.TlvType6, uint64]( + Capacity: tlv.NewPrimitiveRecord[tlv.TlvType6]( capacity, ), - NodeID1: tlv.NewPrimitiveRecord[tlv.TlvType8, [33]byte]( + NodeID1: tlv.NewPrimitiveRecord[tlv.TlvType8]( nodeID1, ), - NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10, [33]byte]( + NodeID2: tlv.NewPrimitiveRecord[tlv.TlvType10]( nodeID2, ), ExtraOpaqueData: RandExtraOpaqueData(t, nil), @@ -242,7 +247,7 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { var bitcoinKey1 [33]byte copy(bitcoinKey1[:], RandPubKey(t).SerializeCompressed()) msg.BitcoinKey1 = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType12, [33]byte]( + tlv.NewPrimitiveRecord[tlv.TlvType12]( bitcoinKey1, ), ) @@ -252,7 +257,7 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { var bitcoinKey2 [33]byte copy(bitcoinKey2[:], RandPubKey(t).SerializeCompressed()) msg.BitcoinKey2 = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType14, [33]byte]( + tlv.NewPrimitiveRecord[tlv.TlvType14]( bitcoinKey2, ), ) @@ -263,7 +268,7 @@ func (c *ChannelAnnouncement2) RandTestMessage(t *rapid.T) Message { var merkleRootHash [32]byte copy(merkleRootHash[:], hash[:]) msg.MerkleRootHash = tlv.SomeRecordT( - tlv.NewPrimitiveRecord[tlv.TlvType16, [32]byte]( + tlv.NewPrimitiveRecord[tlv.TlvType16]( merkleRootHash, ), ) @@ -284,7 +289,10 @@ func (c *ChannelReady) RandTestMessage(t *rapid.T) Message { msg := &ChannelReady{ ChanID: RandChannelID(t), NextPerCommitmentPoint: RandPubKey(t), - ExtraData: RandExtraOpaqueData(t, nil), + ExtraData: RandExtraRecords( + t, uint64(AliasScidRecordType), + uint64(nonceRecordType), 0, 2, + ), } includeAliasScid := rapid.Bool().Draw(t, "includeAliasScid") @@ -309,14 +317,14 @@ func (c *ChannelReady) RandTestMessage(t *rapid.T) Message { if includeAnnouncementNodeNonce { nonce := RandMusig2Nonce(t) msg.AnnouncementNodeNonce = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType0, Musig2Nonce](nonce), + tlv.NewRecordT[tlv.TlvType0](nonce), ) } if includeAnnouncementBitcoinNonce { nonce := RandMusig2Nonce(t) msg.AnnouncementBitcoinNonce = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType2, Musig2Nonce](nonce), + tlv.NewRecordT[tlv.TlvType2](nonce), ) } @@ -342,7 +350,9 @@ func (a *ChannelReestablish) RandTestMessage(t *rapid.T) Message { ), LastRemoteCommitSecret: RandPaymentPreimage(t), LocalUnrevokedCommitPoint: RandPubKey(t), - ExtraData: RandExtraOpaqueData(t, nil), + ExtraData: RandExtraRecords( + t, uint64(nonceRecordType), uint64(CRDynHeight), + ), } // Randomly decide whether to include optional fields @@ -431,7 +441,7 @@ func (a *ChannelUpdate1) RandTestMessage(t *rapid.T) Message { FeeRate: inFeeProp, } inboundFee = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType55555, Fee](fee), + tlv.NewRecordT[tlv.TlvType55555](fee), ) var b bytes.Buffer @@ -514,31 +524,31 @@ func (c *ChannelUpdate2) RandTestMessage(t *rapid.T) Message { //nolint:ll msg := &ChannelUpdate2{ Signature: RandSignature(t), - ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0, chainhash.Hash]( + ChainHash: tlv.NewPrimitiveRecord[tlv.TlvType0]( chainHashObj, ), - ShortChannelID: tlv.NewRecordT[tlv.TlvType2, ShortChannelID]( + ShortChannelID: tlv.NewRecordT[tlv.TlvType2]( shortChanID, ), - BlockHeight: tlv.NewPrimitiveRecord[tlv.TlvType4, uint32]( + BlockHeight: tlv.NewPrimitiveRecord[tlv.TlvType4]( blockHeight, ), - DisabledFlags: tlv.NewPrimitiveRecord[tlv.TlvType6, ChanUpdateDisableFlags]( //nolint:ll + DisabledFlags: tlv.NewPrimitiveRecord[tlv.TlvType6]( disabledFlags, ), - CLTVExpiryDelta: tlv.NewPrimitiveRecord[tlv.TlvType10, uint16]( + CLTVExpiryDelta: tlv.NewPrimitiveRecord[tlv.TlvType10]( cltvExpiryDelta, ), - HTLCMinimumMsat: tlv.NewPrimitiveRecord[tlv.TlvType12, MilliSatoshi]( + HTLCMinimumMsat: tlv.NewPrimitiveRecord[tlv.TlvType12]( htlcMinMsat, ), - HTLCMaximumMsat: tlv.NewPrimitiveRecord[tlv.TlvType14, MilliSatoshi]( + HTLCMaximumMsat: tlv.NewPrimitiveRecord[tlv.TlvType14]( htlcMaxMsat, ), - FeeBaseMsat: tlv.NewPrimitiveRecord[tlv.TlvType16, uint32]( + FeeBaseMsat: tlv.NewPrimitiveRecord[tlv.TlvType16]( feeBaseMsat, ), - FeeProportionalMillionths: tlv.NewPrimitiveRecord[tlv.TlvType18, uint32]( + FeeProportionalMillionths: tlv.NewPrimitiveRecord[tlv.TlvType18]( feeProportionalMillionths, ), ExtraOpaqueData: RandExtraOpaqueData(t, nil), @@ -574,7 +584,7 @@ func (c *ClosingComplete) RandTestMessage(t *rapid.T) Message { ), CloseeScript: RandDeliveryAddress(t), CloserScript: RandDeliveryAddress(t), - ExtraData: RandExtraOpaqueData(t, nil), + ExtraData: RandExtraRecords(t, 1, 2, 3), } includeCloserNoClosee := rapid.Bool().Draw(t, "includeCloserNoClosee") @@ -600,21 +610,21 @@ func (c *ClosingComplete) RandTestMessage(t *rapid.T) Message { if includeCloserNoClosee { sig := RandSignature(t) msg.CloserNoClosee = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType1, Sig](sig), + tlv.NewRecordT[tlv.TlvType1](sig), ) } if includeNoCloserClosee { sig := RandSignature(t) msg.NoCloserClosee = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType2, Sig](sig), + tlv.NewRecordT[tlv.TlvType2](sig), ) } if includeCloserAndClosee { sig := RandSignature(t) msg.CloserAndClosee = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType3, Sig](sig), + tlv.NewRecordT[tlv.TlvType3](sig), ) } @@ -634,7 +644,7 @@ func (c *ClosingSig) RandTestMessage(t *rapid.T) Message { ChannelID: RandChannelID(t), CloseeScript: RandDeliveryAddress(t), CloserScript: RandDeliveryAddress(t), - ExtraData: RandExtraOpaqueData(t, nil), + ExtraData: RandExtraRecords(t, 1, 2, 3), } includeCloserNoClosee := rapid.Bool().Draw(t, "includeCloserNoClosee") @@ -660,21 +670,21 @@ func (c *ClosingSig) RandTestMessage(t *rapid.T) Message { if includeCloserNoClosee { sig := RandSignature(t) msg.CloserNoClosee = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType1, Sig](sig), + tlv.NewRecordT[tlv.TlvType1](sig), ) } if includeNoCloserClosee { sig := RandSignature(t) msg.NoCloserClosee = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType2, Sig](sig), + tlv.NewRecordT[tlv.TlvType2](sig), ) } if includeCloserAndClosee { sig := RandSignature(t) msg.CloserAndClosee = tlv.SomeRecordT( - tlv.NewRecordT[tlv.TlvType3, Sig](sig), + tlv.NewRecordT[tlv.TlvType3](sig), ) } @@ -700,7 +710,9 @@ func (c *ClosingSigned) RandTestMessage(t *rapid.T) Message { FeeSatoshis: btcutil.Amount( rapid.Int64Range(0, 1000000).Draw(t, "feeSatoshis"), ), - ExtraData: RandExtraOpaqueData(t, nil), + ExtraData: RandExtraRecords( + t, uint64((PartialSigType)(nil).TypeVal()), + ), } if usePartialSig { @@ -737,7 +749,7 @@ func (c *CommitSig) RandTestMessage(t *rapid.T) Message { numHtlcSigs := rapid.IntRange(0, 20).Draw(t, "numHtlcSigs") htlcSigs := make([]Sig, numHtlcSigs) - for i := 0; i < numHtlcSigs; i++ { + for i := range numHtlcSigs { htlcSigs[i] = RandSignature(t) } @@ -802,17 +814,7 @@ func (da *DynAck) RandTestMessage(t *rapid.T) Message { msg.LocalNonce = tlv.SomeRecordT(rec) } - // Create a tlv type lists to hold all known records which will be - // ignored when creating ExtraData records. - ignoreRecords := fn.NewSet[uint64]() - for i := range uint64(15) { - // Ignore known records. - if i%2 == 0 { - ignoreRecords.Add(i) - } - } - - msg.ExtraData = RandExtraOpaqueData(t, ignoreRecords) + msg.ExtraData = RandUnknownRecords(t, 14) return msg } @@ -883,17 +885,7 @@ func (dp *DynPropose) RandTestMessage(t *rapid.T) Message { msg.ChannelType = tlv.SomeRecordT(chanType) } - // Create a tlv type lists to hold all known records which will be - // ignored when creating ExtraData records. - ignoreRecords := fn.NewSet[uint64]() - for i := range uint64(13) { - // Ignore known records. - if i%2 == 0 { - ignoreRecords.Add(i) - } - } - - msg.ExtraData = RandExtraOpaqueData(t, ignoreRecords) + msg.ExtraData = RandUnknownRecords(t, 12) return msg } @@ -910,7 +902,7 @@ func (dr *DynReject) RandTestMessage(t *rapid.T) Message { featureVec := NewRawFeatureVector() numFeatures := rapid.IntRange(0, 8).Draw(t, "numRejections") - for i := 0; i < numFeatures; i++ { + for i := range numFeatures { bit := FeatureBit( rapid.IntRange(0, 31).Draw( t, fmt.Sprintf("rejectionBit-%d", i), @@ -1011,21 +1003,12 @@ func (dc *DynCommit) RandTestMessage(t *rapid.T) Message { da.LocalNonce = tlv.SomeRecordT(rec) } - // Create a tlv type lists to hold all known records which will be - // ignored when creating ExtraData records. - ignoreRecords := fn.NewSet[uint64]() - for i := range uint64(15) { - // Ignore known records. - if i%2 == 0 { - ignoreRecords.Add(i) - } - } msg := &DynCommit{ DynPropose: *dp, DynAck: *da, } - msg.ExtraData = RandExtraOpaqueData(t, ignoreRecords) + msg.ExtraData = RandUnknownRecords(t, 14) return msg } @@ -1147,7 +1130,7 @@ func (msg *Init) RandTestMessage(t *rapid.T) Message { local := NewRawFeatureVector() numGlobalFeatures := rapid.IntRange(0, 20).Draw(t, "numGlobalFeatures") - for i := 0; i < numGlobalFeatures; i++ { + for i := range numGlobalFeatures { bit := FeatureBit( rapid.IntRange(0, 100).Draw( t, fmt.Sprintf("globalFeatureBit%d", i), @@ -1157,7 +1140,7 @@ func (msg *Init) RandTestMessage(t *rapid.T) Message { } numLocalFeatures := rapid.IntRange(0, 20).Draw(t, "numLocalFeatures") - for i := 0; i < numLocalFeatures; i++ { + for i := range numLocalFeatures { bit := FeatureBit( rapid.IntRange(0, 100).Draw( t, fmt.Sprintf("localFeatureBit%d", i), @@ -1252,7 +1235,7 @@ func (o *OpenChannel) RandTestMessage(t *rapid.T) Message { if includeLocalNonce { nonce := RandMusig2Nonce(t) localNonce = tlv.SomeRecordT( - tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + tlv.NewRecordT[NonceRecordTypeT](nonce), ) } @@ -1335,7 +1318,7 @@ func (p *Ping) RandTestMessage(t *rapid.T) Message { padding := make(PingPayload, paddingLen) // Fill padding with random bytes - for i := 0; i < paddingLen; i++ { + for i := range paddingLen { padding[i] = byte(rapid.IntRange(0, 255).Draw( t, fmt.Sprintf("paddingByte%d", i)), ) @@ -1430,7 +1413,7 @@ func (q *QueryShortChanIDs) RandTestMessage(t *rapid.T) Message { // Generate sorted short channel IDs. shortChanIDs := make([]ShortChannelID, numIDs) - for i := 0; i < numIDs; i++ { + for i := range numIDs { shortChanIDs[i] = RandShortChannelID(t) // Ensure they're properly sorted. @@ -1481,7 +1464,7 @@ func (c *ReplyChannelRange) RandTestMessage(t *rapid.T) Message { scidSet := fn.NewSet[ShortChannelID]() scids := make([]ShortChannelID, numShortChanIDs) - for i := 0; i < numShortChanIDs; i++ { + for i := range numShortChanIDs { scid := RandShortChannelID(t) for scidSet.Contains(scid) { scid = RandShortChannelID(t) @@ -1497,7 +1480,7 @@ func (c *ReplyChannelRange) RandTestMessage(t *rapid.T) Message { if rapid.Bool().Draw(t, "includeTimestamps") && numShortChanIDs > 0 { msg.Timestamps = make(Timestamps, numShortChanIDs) - for i := 0; i < numShortChanIDs; i++ { + for i := range numShortChanIDs { msg.Timestamps[i] = ChanUpdateTimestamps{ Timestamp1: uint32(rapid.IntRange(0, math.MaxInt32).Draw(t, fmt.Sprintf("timestamp-1-%d", i))), //nolint:ll Timestamp2: uint32(rapid.IntRange(0, math.MaxInt32).Draw(t, fmt.Sprintf("timestamp-2-%d", i))), //nolint:ll @@ -1554,7 +1537,7 @@ func (c *RevokeAndAck) RandTestMessage(t *rapid.T) Message { copy(nonce[:], nonceBytes) msg.LocalNonce = tlv.SomeRecordT( - tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + tlv.NewRecordT[NonceRecordTypeT](nonce), ) } @@ -1772,7 +1755,7 @@ func (c *Warning) RandTestMessage(t *rapid.T) Message { if useASCII { length := rapid.IntRange(1, 100).Draw(t, "warningDataLength") data := make([]byte, length) - for i := 0; i < length; i++ { + for i := range length { data[i] = byte( rapid.IntRange(32, 126).Draw( t, fmt.Sprintf("warningDataByte-%d", i), @@ -1807,7 +1790,7 @@ func (c *Error) RandTestMessage(t *rapid.T) Message { if useASCII { length := rapid.IntRange(1, 100).Draw(t, "errorDataLength") data := make([]byte, length) - for i := 0; i < length; i++ { + for i := range length { data[i] = byte( rapid.IntRange(32, 126).Draw( t, fmt.Sprintf("errorDataByte-%d", i), diff --git a/lnwire/test_utils.go b/lnwire/test_utils.go index 07c9d795bce..731819b2119 100644 --- a/lnwire/test_utils.go +++ b/lnwire/test_utils.go @@ -270,6 +270,55 @@ func RandExtraOpaqueData(t *rapid.T, return ExtraOpaqueData(recordBytes) } +// RandUnknownRecords creates random unknown records. It will skip creating +// records that have known types. +func RandUnknownRecords(t *rapid.T, maxKnownType int) ExtraOpaqueData { + // Create a tlv type lists to hold all known records which will be + // ignored when creating ExtraData records. + ignoreRecords := fn.NewSet[uint64]() + for i := range uint64(maxKnownType + 1) { + // Ignore known records. + if i%2 == 0 { + ignoreRecords.Add(i) + } + } + + // Make some random records. + cRecords, _ := RandCustomRecords(t, ignoreRecords, false) + if cRecords == nil { + return ExtraOpaqueData{} + } + + // Encode those records as opaque data. + recordBytes, err := cRecords.Serialize() + require.NoError(t, err) + + return ExtraOpaqueData(recordBytes) +} + +// RandExtraRecords creates random records. It will skip creating records +// specified by the knownRecords. +func RandExtraRecords(t *rapid.T, knownRecords ...uint64) ExtraOpaqueData { + // Create a tlv type lists to hold all known records which will be + // ignored when creating ExtraData records. + ignoreRecords := fn.NewSet[uint64]() + for _, recordType := range knownRecords { + ignoreRecords.Add(recordType) + } + + // Make some random records. + cRecords, _ := RandCustomRecords(t, ignoreRecords, false) + if cRecords == nil { + return ExtraOpaqueData{} + } + + // Encode those records as opaque data. + recordBytes, err := cRecords.Serialize() + require.NoError(t, err) + + return ExtraOpaqueData(recordBytes) +} + // RandOpaqueReason generates a random opaque reason for HTLC failures. func RandOpaqueReason(t *rapid.T) OpaqueReason { reasonLen := rapid.IntRange(32, 300).Draw(t, "reasonLen")