Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions src/track/voting/best.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,101 @@ where
res
}
}

pub struct BestFitVotingWithFallback<OA>
where
OA: ObservationAttributes,
{
pub max_distance: f32,
pub min_votes: usize,
_phantom: PhantomData<OA>,
}

impl<OA> BestFitVotingWithFallback<OA>
where
OA: ObservationAttributes,
{
pub fn new(max_distance: f32, min_votes: usize) -> Self {
Self {
max_distance,
min_votes,
_phantom: PhantomData,
}
}
}

impl<OA> Voting<OA> for BestFitVotingWithFallback<OA>
where
OA: ObservationAttributes,
{
type WinnerObject = TopNVotingElt;

fn winners<T>(&self, distances: T) -> HashMap<u64, Vec<TopNVotingElt>>
where
T: IntoIterator<Item = ObservationMetricOk<OA>>,
{
let mut max_dist = -1.0_f32;

// Step 1: group all distances by (from, to), filter by max_distance
let grouped: HashMap<(u64, u64), Vec<f32>> = distances
.into_iter()
.filter_map(|d| match d.feature_distance {
Some(f) if f <= self.max_distance => {
max_dist = max_dist.max(f);
Some(((d.from, d.to), f))
}
_ => None,
})
.into_group_map();

// Step 2: filter by min_votes
let filtered: Vec<TopNVotingElt> = grouped
.into_iter()
.filter(|(_, v)| v.len() >= self.min_votes)
.map(|((from, to), dists)| {
let weight = dists.into_iter().map(|d| (max_dist - d) as f64).sum();
TopNVotingElt {
query_track: from,
winner_track: to,
weight,
}
})
.collect();

// Step 3: group by query (from), and sort each list by descending weight
let mut per_query = filtered.into_iter().into_group_map_by(|e| e.query_track);

for candidates in per_query.values_mut() {
candidates.sort_by(|a, b| b.weight.partial_cmp(&a.weight).unwrap());
}

// Step 4: assign each query to its best available winner (fallback to self)
let mut used_winners = HashSet::new();
let mut final_map = HashMap::new();

for (query_id, candidates) in per_query {
let mut assigned = false;
for mut cand in candidates {
if !used_winners.contains(&cand.winner_track) {
used_winners.insert(cand.winner_track);
final_map.insert(query_id, vec![cand]);
assigned = true;
break;
}
}

if !assigned {
final_map.insert(
query_id,
vec![TopNVotingElt {
query_track: query_id,
winner_track: query_id,
weight: 0.0,
}],
);
}
}

final_map
}
}
43 changes: 29 additions & 14 deletions src/trackers/visual_sort/voting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::trackers::sort::voting::SortVoting;
use crate::trackers::sort::VotingType;
use crate::trackers::visual_sort::observation_attributes::VisualObservationAttributes;
use crate::utils::bbox::Universal2DBox;
use crate::voting::best::BestFitVoting;
use crate::voting::best::BestFitVotingWithFallback;
use crate::voting::Voting;
use itertools::Itertools;
use log::debug;
Expand Down Expand Up @@ -49,25 +49,39 @@ impl Voting<VisualObservationAttributes> for VisualVoting {
where
T: IntoIterator<Item = ObservationMetricOk<VisualObservationAttributes>>,
{
let topn_feature_voting: BestFitVoting<VisualObservationAttributes> = BestFitVoting::new(
self.max_allowed_feature_distance,
self.min_winner_feature_votes,
);
let topn_feature_voting: BestFitVotingWithFallback<VisualObservationAttributes> =
BestFitVotingWithFallback::new(
self.max_allowed_feature_distance,
self.min_winner_feature_votes,
);

let (distances, distances_clone) = distances.into_iter().tee();

let feature_winners = topn_feature_voting.winners(distances);
debug!("TopN winners: {:#?}", &feature_winners);
// First round: feature-based voting
let raw_feature_winners: HashMap<u64, Vec<TopNVotingElt>> =
topn_feature_voting.winners(distances);
debug!("TopN raw_feature_winners: {:#?}", &raw_feature_winners);

let mut excluded_tracks = HashSet::new();
let mut feature_winners = feature_winners
let mut feature_winners = HashMap::new();

raw_feature_winners
.into_iter()
.map(|(from, w)| {
let winner_track = w[0].winner_track;
excluded_tracks.insert(winner_track);
(from, vec![(winner_track, VotingType::Visual)])
})
.collect::<HashMap<_, _>>();
.for_each(|(from, winner_list)| {
let TopNVotingElt {
winner_track,
query_track,
..
} = winner_list[0];

if winner_track != query_track {
excluded_tracks.insert(winner_track);
feature_winners.insert(from, vec![(winner_track, VotingType::Visual)]);
}
});

debug!("TopN winners: {:#?}", &feature_winners);
debug!("Excluded tracks: {:#?}", &excluded_tracks);

let mut remaining_candidates = HashSet::new();
let mut remaining_tracks = HashSet::new();
Expand Down Expand Up @@ -95,6 +109,7 @@ impl Voting<VisualObservationAttributes> for VisualVoting {
.into_iter()
.map(|(from, winner)| (from, vec![(winner[0], VotingType::Positional)]));

debug!("positional_winners: {:#?}", &positional_winners);
feature_winners.extend(positional_winners);
feature_winners
}
Expand Down