|
5 | 5 | "testing" |
6 | 6 |
|
7 | 7 | "github.com/lightningnetwork/lnd/lnwire" |
| 8 | + "github.com/stretchr/testify/require" |
8 | 9 | ) |
9 | 10 |
|
10 | 11 | type depTest struct { |
@@ -164,3 +165,170 @@ func testValidateDeps(t *testing.T, test depTest) { |
164 | 165 | test.expErr, err) |
165 | 166 | } |
166 | 167 | } |
| 168 | + |
| 169 | +// TestSettingDepBits sets that the SetBit function correctly sets a bit along |
| 170 | +// with its dependencies in a feature vector. Specifically, we want to check |
| 171 | +// that any existing optional bits are upgraded to required if the main bit |
| 172 | +// being set is required. Similarly, if the main bit is optional, then any |
| 173 | +// existing bits that depend on it should not be downgraded from required to |
| 174 | +// optional. |
| 175 | +func TestSettingDepBits(t *testing.T) { |
| 176 | + t.Parallel() |
| 177 | + |
| 178 | + tests := []struct { |
| 179 | + name string |
| 180 | + existingVector *lnwire.RawFeatureVector |
| 181 | + newBit lnwire.FeatureBit |
| 182 | + expectedVector *lnwire.RawFeatureVector |
| 183 | + }{ |
| 184 | + { |
| 185 | + name: "Optional bit with no dependants", |
| 186 | + existingVector: lnwire.NewRawFeatureVector(), |
| 187 | + newBit: lnwire.ExplicitChannelTypeOptional, |
| 188 | + expectedVector: lnwire.NewRawFeatureVector( |
| 189 | + lnwire.ExplicitChannelTypeOptional, |
| 190 | + ), |
| 191 | + }, |
| 192 | + { |
| 193 | + name: "Required bit with no dependants", |
| 194 | + existingVector: lnwire.NewRawFeatureVector(), |
| 195 | + newBit: lnwire.ExplicitChannelTypeRequired, |
| 196 | + expectedVector: lnwire.NewRawFeatureVector( |
| 197 | + lnwire.ExplicitChannelTypeRequired, |
| 198 | + ), |
| 199 | + }, |
| 200 | + { |
| 201 | + name: "Optional bit with single " + |
| 202 | + "level dependant", |
| 203 | + existingVector: lnwire.NewRawFeatureVector(), |
| 204 | + newBit: lnwire.RouteBlindingOptional, |
| 205 | + expectedVector: lnwire.NewRawFeatureVector( |
| 206 | + lnwire.RouteBlindingOptional, |
| 207 | + lnwire.TLVOnionPayloadOptional, |
| 208 | + ), |
| 209 | + }, |
| 210 | + { |
| 211 | + name: "Required bit with single " + |
| 212 | + "level dependant", |
| 213 | + existingVector: lnwire.NewRawFeatureVector(), |
| 214 | + newBit: lnwire.RouteBlindingRequired, |
| 215 | + expectedVector: lnwire.NewRawFeatureVector( |
| 216 | + lnwire.RouteBlindingRequired, |
| 217 | + lnwire.TLVOnionPayloadRequired, |
| 218 | + ), |
| 219 | + }, |
| 220 | + { |
| 221 | + name: "Optional bit with multi level " + |
| 222 | + "dependants", |
| 223 | + existingVector: lnwire.NewRawFeatureVector(), |
| 224 | + newBit: lnwire.Bolt11BlindedPathsOptional, |
| 225 | + expectedVector: lnwire.NewRawFeatureVector( |
| 226 | + lnwire.Bolt11BlindedPathsOptional, |
| 227 | + lnwire.RouteBlindingOptional, |
| 228 | + lnwire.TLVOnionPayloadOptional, |
| 229 | + ), |
| 230 | + }, |
| 231 | + { |
| 232 | + name: "Required bit with multi level " + |
| 233 | + "dependants", |
| 234 | + existingVector: lnwire.NewRawFeatureVector(), |
| 235 | + newBit: lnwire.Bolt11BlindedPathsRequired, |
| 236 | + expectedVector: lnwire.NewRawFeatureVector( |
| 237 | + lnwire.Bolt11BlindedPathsRequired, |
| 238 | + lnwire.RouteBlindingRequired, |
| 239 | + lnwire.TLVOnionPayloadRequired, |
| 240 | + ), |
| 241 | + }, |
| 242 | + { |
| 243 | + name: "Existing required bit should not be " + |
| 244 | + "overridden if new bit is optional", |
| 245 | + existingVector: lnwire.NewRawFeatureVector( |
| 246 | + lnwire.TLVOnionPayloadRequired, |
| 247 | + ), |
| 248 | + newBit: lnwire.Bolt11BlindedPathsOptional, |
| 249 | + expectedVector: lnwire.NewRawFeatureVector( |
| 250 | + lnwire.Bolt11BlindedPathsOptional, |
| 251 | + lnwire.RouteBlindingOptional, |
| 252 | + lnwire.TLVOnionPayloadRequired, |
| 253 | + ), |
| 254 | + }, |
| 255 | + { |
| 256 | + name: "Existing optional bit should be overridden if " + |
| 257 | + "new bit is required", |
| 258 | + existingVector: lnwire.NewRawFeatureVector( |
| 259 | + lnwire.TLVOnionPayloadOptional, |
| 260 | + ), |
| 261 | + newBit: lnwire.Bolt11BlindedPathsRequired, |
| 262 | + expectedVector: lnwire.NewRawFeatureVector( |
| 263 | + lnwire.Bolt11BlindedPathsRequired, |
| 264 | + lnwire.RouteBlindingRequired, |
| 265 | + lnwire.TLVOnionPayloadRequired, |
| 266 | + ), |
| 267 | + }, |
| 268 | + { |
| 269 | + name: "Unrelated bits should not be affected", |
| 270 | + existingVector: lnwire.NewRawFeatureVector( |
| 271 | + lnwire.AMPOptional, |
| 272 | + lnwire.TLVOnionPayloadOptional, |
| 273 | + ), |
| 274 | + newBit: lnwire.Bolt11BlindedPathsRequired, |
| 275 | + expectedVector: lnwire.NewRawFeatureVector( |
| 276 | + lnwire.AMPOptional, |
| 277 | + lnwire.Bolt11BlindedPathsRequired, |
| 278 | + lnwire.RouteBlindingRequired, |
| 279 | + lnwire.TLVOnionPayloadRequired, |
| 280 | + ), |
| 281 | + }, |
| 282 | + } |
| 283 | + |
| 284 | + for _, test := range tests { |
| 285 | + t.Run(test.name, func(t *testing.T) { |
| 286 | + fv := lnwire.NewFeatureVector( |
| 287 | + test.existingVector, lnwire.Features, |
| 288 | + ) |
| 289 | + |
| 290 | + resultFV := SetBit(fv, test.newBit) |
| 291 | + require.Equal( |
| 292 | + t, test.expectedVector, |
| 293 | + resultFV.RawFeatureVector, |
| 294 | + ) |
| 295 | + }) |
| 296 | + } |
| 297 | +} |
| 298 | + |
| 299 | +// TestSetBitNoCycles tests the SetBit call for each feature bit that we know of |
| 300 | +// in both its optional and required form. This ensures that the SetBit call |
| 301 | +// never gets stuck in a recursion cycle for any feature bit. |
| 302 | +func TestSetBitNoCycles(t *testing.T) { |
| 303 | + t.Parallel() |
| 304 | + |
| 305 | + // For each feature-bit that we are aware of (both optional and |
| 306 | + // required), we will create a feature vector that is empty, and then |
| 307 | + // we will call SetBit with the given feature bit. We then check that |
| 308 | + // all the dependent features are also added in the appropriate form |
| 309 | + // (optional vs required). This test completing demonstrates that the |
| 310 | + // recursion in SetBit is not a problem since no feature bits should |
| 311 | + // create a dependency cycle. |
| 312 | + for bit := range lnwire.Features { |
| 313 | + fv := lnwire.NewFeatureVector( |
| 314 | + lnwire.NewRawFeatureVector(), lnwire.Features, |
| 315 | + ) |
| 316 | + |
| 317 | + resultFV := SetBit(fv, bit) |
| 318 | + |
| 319 | + // Ensure that all the dependent feature bits are in fact set |
| 320 | + // in the resulting set. Here we just check that some form |
| 321 | + // (optional or required) is set. The expected type is asserted |
| 322 | + // later on in the test. |
| 323 | + for expectedBit := range deps[bit] { |
| 324 | + require.True(t, resultFV.IsSet(expectedBit) || |
| 325 | + resultFV.IsSet(mapToRequired(expectedBit))) |
| 326 | + } |
| 327 | + |
| 328 | + // Make sure all the resulting feature bits have the correct |
| 329 | + // form (optional vs required). |
| 330 | + for depBit := range resultFV.Features() { |
| 331 | + require.Equal(t, bit.IsRequired(), depBit.IsRequired()) |
| 332 | + } |
| 333 | + } |
| 334 | +} |
0 commit comments