Skip to content

Commit d7321f5

Browse files
author
Srikumar Sastry
committed
Added function for assigning pseudo labels for high confidence samples
1 parent 2f1e866 commit d7321f5

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

examples/cost_effective_active_learning.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@
3939
U_y = np.delete(U_y, ind, axis=0)
4040

4141

42+
def assign_pseudo_labels(active_learner, X, confidence_idx):
43+
conf_samples = X[confidence_idx]
44+
labels = active_learner.predict(conf_samples)
45+
return labels
46+
47+
4248
def max_entropy(active_learner, X, K=16, N=16):
4349

4450
class_prob = active_learner.predict_proba(X)
@@ -67,7 +73,15 @@ def max_entropy(active_learner, X, K=16, N=16):
6773

6874
query_idx, query_instance = active_learner.query(U_x, K_MAX_ENTROPY, N_MIN_ENTROPY)
6975

70-
active_learner.teach(U_x[query_idx], U_y[query_idx])
76+
uncertain_idx = query_idx[:K_MAX_ENTROPY]
77+
confidence_idx = query_idx[K_MAX_ENTROPY:]
78+
79+
conf_labels = assign_pseudo_labels(active_learner, U_x, confidence_idx)
80+
81+
L_x = U_x[query_idx]
82+
L_y = np.concatenate((U_y[uncertain_idx], conf_labels), axis=0)
83+
84+
active_learner.teach(L_x, L_y)
7185

7286
U_x = np.delete(U_x, query_idx, axis=0)
7387
U_y = np.delete(U_y, query_idx, axis=0)

0 commit comments

Comments
 (0)