|
1 | 1 | package model
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "cmp" |
4 | 5 | "errors"
|
5 | 6 | "fmt"
|
| 7 | + "slices" |
6 | 8 | "sort"
|
7 | 9 | "strings"
|
8 | 10 |
|
@@ -261,49 +263,238 @@ func (c *Channel) Validate() error {
|
261 | 263 | return result.orNil()
|
262 | 264 | }
|
263 | 265 |
|
264 |
| -// validateReplacesChain checks the replaces chain of a channel. |
265 |
| -// Specifically the following rules must be followed: |
266 |
| -// 1. There must be exactly 1 channel head. |
267 |
| -// 2. Beginning at the head, the replaces chain must reach all non-skipped entries. |
268 |
| -// Non-skipped entries are defined as entries that are not skipped by any other entry in the channel. |
269 |
| -// 3. There must be no cycles in the replaces chain. |
270 |
| -// 4. The tail entry in the replaces chain is permitted to replace a non-existent entry. |
271 |
| -func (c *Channel) validateReplacesChain() error { |
272 |
| - head, err := c.Head() |
273 |
| - if err != nil { |
274 |
| - return err |
| 266 | +type node struct { |
| 267 | + name string |
| 268 | + version semver.Version |
| 269 | + replacedBy map[string]*node |
| 270 | + replaces *node |
| 271 | + skippedBy map[string]*node |
| 272 | + skips map[string]*node |
| 273 | + skipRange string |
| 274 | + hasEntry bool |
| 275 | +} |
| 276 | + |
| 277 | +type graph struct { |
| 278 | + nodes map[string]*node |
| 279 | + root *node |
| 280 | +} |
| 281 | + |
| 282 | +func newGraph(c *Channel) (*graph, error) { |
| 283 | + nodes := map[string]*node{} |
| 284 | + for _, b := range c.Bundles { |
| 285 | + nodes[b.Name] = &node{ |
| 286 | + name: b.Name, |
| 287 | + version: b.Version, |
| 288 | + skipRange: b.SkipRange, |
| 289 | + replacedBy: make(map[string]*node), |
| 290 | + skippedBy: make(map[string]*node), |
| 291 | + skips: make(map[string]*node), |
| 292 | + } |
275 | 293 | }
|
276 | 294 |
|
277 |
| - allBundles := sets.NewString() |
278 |
| - skippedBundles := sets.NewString() |
279 | 295 | for _, b := range c.Bundles {
|
280 |
| - allBundles = allBundles.Insert(b.Name) |
281 |
| - skippedBundles = skippedBundles.Insert(b.Skips...) |
| 296 | + n := nodes[b.Name] |
| 297 | + |
| 298 | + if b.Replaces != "" { |
| 299 | + replaces, ok := nodes[b.Replaces] |
| 300 | + if !ok { |
| 301 | + replaces = &node{ |
| 302 | + name: b.Replaces, |
| 303 | + replacedBy: make(map[string]*node), |
| 304 | + hasEntry: false, |
| 305 | + } |
| 306 | + nodes[b.Replaces] = replaces |
| 307 | + } |
| 308 | + n.replaces = replaces |
| 309 | + n.replaces.replacedBy[n.name] = n |
| 310 | + } |
| 311 | + |
| 312 | + for _, skipName := range b.Skips { |
| 313 | + skip, ok := nodes[skipName] |
| 314 | + if !ok { |
| 315 | + skip = &node{ |
| 316 | + name: skipName, |
| 317 | + skippedBy: make(map[string]*node), |
| 318 | + skips: make(map[string]*node), |
| 319 | + hasEntry: false, |
| 320 | + } |
| 321 | + } |
| 322 | + skip.skippedBy[b.Name] = n |
| 323 | + n.skips[skipName] = skip |
| 324 | + } |
| 325 | + } |
| 326 | + |
| 327 | + return &graph{ |
| 328 | + nodes: nodes, |
| 329 | + }, nil |
| 330 | +} |
| 331 | + |
| 332 | +func (g *graph) validate() error { |
| 333 | + result := newValidationError("invalid upgrade graph") |
| 334 | + if err := g.validateNoCycles(); err != nil { |
| 335 | + result.subErrors = append(result.subErrors, err) |
| 336 | + } |
| 337 | + if err := g.validateNoStranded(); err != nil { |
| 338 | + result.subErrors = append(result.subErrors, err) |
| 339 | + } |
| 340 | + return result.orNil() |
| 341 | +} |
| 342 | + |
| 343 | +func (g *graph) validateNoCycles() error { |
| 344 | + result := newValidationError("cycles found in graph") |
| 345 | + allCycles := [][]*node{} |
| 346 | + for _, n := range g.nodes { |
| 347 | + ancestors := map[string]*node{} |
| 348 | + maps.Copy(ancestors, n.replacedBy) |
| 349 | + maps.Copy(ancestors, n.skippedBy) |
| 350 | + allCycles = append(allCycles, paths([]*node{n}, ancestors, n)...) |
| 351 | + } |
| 352 | + dedupPaths(&allCycles) |
| 353 | + for _, cycle := range allCycles { |
| 354 | + cycleStr := strings.Join(mapSlice(cycle, nodeName), " -> ") |
| 355 | + result.subErrors = append(result.subErrors, errors.New(cycleStr)) |
282 | 356 | }
|
283 | 357 |
|
284 |
| - chainFrom := map[string][]string{} |
285 |
| - replacesChainFromHead := sets.NewString(head.Name) |
| 358 | + return result.orNil() |
| 359 | +} |
| 360 | + |
| 361 | +func (g *graph) validateNoStranded() error { |
| 362 | + head, err := g.head() |
| 363 | + if err != nil { |
| 364 | + return err |
| 365 | + } |
| 366 | + all := sets.New[*node](maps.Values(g.nodes)...) |
| 367 | + chain := sets.New[*node]() |
| 368 | + skipped := sets.New[*node]() |
| 369 | + |
286 | 370 | cur := head
|
287 |
| - for cur != nil { |
288 |
| - if _, ok := chainFrom[cur.Name]; !ok { |
289 |
| - chainFrom[cur.Name] = []string{cur.Name} |
| 371 | + for cur != nil && !skipped.Has(cur) && !chain.Has(cur) { |
| 372 | + chain.Insert(cur) |
| 373 | + skipped.Insert(maps.Values(cur.skips)...) |
| 374 | + cur = cur.replaces |
| 375 | + } |
| 376 | + |
| 377 | + stranded := all.Difference(chain).Difference(skipped) |
| 378 | + if stranded.Len() > 0 { |
| 379 | + strandedNames := mapSlice(stranded.UnsortedList(), func(n *node) string { |
| 380 | + return n.name |
| 381 | + }) |
| 382 | + slices.Sort(strandedNames) |
| 383 | + return fmt.Errorf("channel contains one or more stranded bundles: %s", strings.Join(strandedNames, ", ")) |
| 384 | + } |
| 385 | + |
| 386 | + return nil |
| 387 | +} |
| 388 | + |
| 389 | +func (g *graph) head() (*node, error) { |
| 390 | + heads := []*node{} |
| 391 | + for _, n := range g.nodes { |
| 392 | + if len(n.replacedBy) == 0 && len(n.skippedBy) == 0 { |
| 393 | + heads = append(heads, n) |
290 | 394 | }
|
291 |
| - for k := range chainFrom { |
292 |
| - chainFrom[k] = append(chainFrom[k], cur.Replaces) |
| 395 | + } |
| 396 | + if len(heads) == 0 { |
| 397 | + return nil, fmt.Errorf("no channel head found in graph") |
| 398 | + } |
| 399 | + if len(heads) > 1 { |
| 400 | + var headNames []string |
| 401 | + for _, head := range heads { |
| 402 | + headNames = append(headNames, head.name) |
293 | 403 | }
|
294 |
| - if replacesChainFromHead.Has(cur.Replaces) { |
295 |
| - return fmt.Errorf("detected cycle in replaces chain of upgrade graph: %s", strings.Join(chainFrom[cur.Replaces], " -> ")) |
| 404 | + sort.Strings(headNames) |
| 405 | + return nil, fmt.Errorf("multiple channel heads found in graph: %s", strings.Join(headNames, ", ")) |
| 406 | + } |
| 407 | + return heads[0], nil |
| 408 | +} |
| 409 | + |
| 410 | +func nodeName(n *node) string { |
| 411 | + return n.name |
| 412 | +} |
| 413 | + |
| 414 | +func mapSlice[I, O any](s []I, fn func(I) O) []O { |
| 415 | + var result []O |
| 416 | + for _, i := range s { |
| 417 | + result = append(result, fn(i)) |
| 418 | + } |
| 419 | + return result |
| 420 | +} |
| 421 | + |
| 422 | +func paths(existingPath []*node, froms map[string]*node, to *node) [][]*node { |
| 423 | + if len(froms) == 0 { |
| 424 | + // we never found a path to "to" |
| 425 | + return nil |
| 426 | + } |
| 427 | + var allPaths [][]*node |
| 428 | + for _, f := range froms { |
| 429 | + path := append(slices.Clone(existingPath), f) |
| 430 | + if f == to { |
| 431 | + // we found "to"! |
| 432 | + allPaths = append(allPaths, path) |
| 433 | + } else { |
| 434 | + for _, subPath := range paths(path, f.replacedBy, to) { |
| 435 | + allPaths = append(allPaths, subPath) |
| 436 | + } |
296 | 437 | }
|
297 |
| - replacesChainFromHead = replacesChainFromHead.Insert(cur.Replaces) |
298 |
| - cur = c.Bundles[cur.Replaces] |
299 | 438 | }
|
| 439 | + return allPaths |
| 440 | +} |
300 | 441 |
|
301 |
| - strandedBundles := allBundles.Difference(replacesChainFromHead).Difference(skippedBundles).List() |
302 |
| - if len(strandedBundles) > 0 { |
303 |
| - return fmt.Errorf("channel contains one or more stranded bundles: %s", strings.Join(strandedBundles, ", ")) |
| 442 | +// dedupPaths removes rotations of the same cycle. |
| 443 | +// For example there are three paths: |
| 444 | +// 1. a -> b -> c -> a |
| 445 | +// 2. b -> c -> a -> b |
| 446 | +// 3. c -> a -> b -> c |
| 447 | +// |
| 448 | +// These are all the same cycle, so we want to choose just one of them. |
| 449 | +// dedupPaths chooses to keep the one whose first node has the highest version. |
| 450 | +func dedupPaths(paths *[][]*node) { |
| 451 | + slices.SortFunc(*paths, func(a, b []*node) int { |
| 452 | + if v := cmp.Compare(len(a), len(b)); v != 0 { |
| 453 | + return v |
| 454 | + } |
| 455 | + return b[0].version.Compare(a[0].version) |
| 456 | + }) |
| 457 | + deleteIndices := sets.New[int]() |
| 458 | + for i, path := range *paths { |
| 459 | + for j, other := range (*paths)[i+1:] { |
| 460 | + if isSameRotation(path, other) { |
| 461 | + deleteIndices.Insert(j + i + 1) |
| 462 | + } |
| 463 | + } |
304 | 464 | }
|
305 | 465 |
|
306 |
| - return nil |
| 466 | + toDelete := sets.List(deleteIndices) |
| 467 | + slices.Reverse(toDelete) |
| 468 | + for _, i := range toDelete { |
| 469 | + (*paths) = slices.Delete(*paths, i, i+1) |
| 470 | + } |
| 471 | +} |
| 472 | + |
| 473 | +func isSameRotation(a, b []*node) bool { |
| 474 | + if len(a) != len(b) { |
| 475 | + return false |
| 476 | + } |
| 477 | + aStr := strings.Join(mapSlice(a[:len(a)-1], nodeName), " -> ") |
| 478 | + bStr := strings.Join(mapSlice(b[:len(b)-1], nodeName), " -> ") |
| 479 | + aPlusA := aStr + " -> " + aStr |
| 480 | + return strings.Contains(aPlusA, bStr) |
| 481 | +} |
| 482 | + |
| 483 | +// validateReplacesChain checks the replaces chain of a channel. |
| 484 | +// Specifically the following rules must be followed: |
| 485 | +// 1. There must be exactly 1 channel head. |
| 486 | +// 2. Beginning at the head, the replaces chain traversal must reach all entries. |
| 487 | +// Unreached entries are considered "stranded" and cause a channel to be invalid. |
| 488 | +// 3. Skipped entries are always leaf nodes. We never follow replaces or skips edges |
| 489 | +// of skipped entries during replaces chain traversal. |
| 490 | +// 4. There must be no cycles in the replaces chain. |
| 491 | +// 5. The tail entry in the replaces chain is permitted to replace a non-existent entry. |
| 492 | +func (c *Channel) validateReplacesChain() error { |
| 493 | + g, err := newGraph(c) |
| 494 | + if err != nil { |
| 495 | + return err |
| 496 | + } |
| 497 | + return g.validate() |
307 | 498 | }
|
308 | 499 |
|
309 | 500 | type Bundle struct {
|
|
0 commit comments