@@ -1355,15 +1355,15 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
     std::vector<int32_t> ids;
     std::vector<ggml_bitset_t> used_ids;
 
-    for (int i = 0; i < sched->n_splits; i++) {
-        struct ggml_backend_sched_split * split = &splits[i];
+    for (int split_id = 0; split_id < sched->n_splits; split_id++) {
+        struct ggml_backend_sched_split * split = &splits[split_id];
         int split_backend_id = split->backend_id;
         ggml_backend_t split_backend = sched->backends[split_backend_id];
 
         // copy the input tensors to the split backend
-        for (int j = 0; j < split->n_inputs; j++) {
-            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]);
-            struct ggml_tensor * input = split->inputs[j];
+        for (int input_id = 0; input_id < split->n_inputs; input_id++) {
+            ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]);
+            struct ggml_tensor * input = split->inputs[input_id];
             struct ggml_tensor * input_cpy = tensor_copy(input, split_backend_id, sched->cur_copy);
 
             if (input->flags & GGML_TENSOR_FLAG_INPUT) {
@@ -1398,17 +1398,30 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
 
                 // get the ids
                 ggml_tensor * ids_tensor = node->src[2];
+                ggml_backend_t ids_backend = split_backend;
+
+                // if the ids tensor is also an input of the split, it may not have been copied yet to the split backend
+                // in that case, we use the original ids tensor
+                for (int i = input_id + 1; i < split->n_inputs; i++) {
+                    if (ids_tensor == tensor_copy(split->inputs[i], split_backend_id, sched->cur_copy)) {
+                        ids_tensor = split->inputs[i];
+                        ids_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[i]);
+                        break;
+                    }
+                }
+
                 if (ids_tensor != prev_ids_tensor) {
                     ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t));
-                    ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
-                    ggml_backend_synchronize(split_backend);
+                    ggml_backend_tensor_get_async(ids_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor));
+                    ggml_backend_synchronize(ids_backend);
 
                     // find the used experts
                     used_ids.clear();
                     used_ids.resize(ggml_bitset_size(n_expert));
                     for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) {
                         for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) {
                             int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)];
+                            GGML_ASSERT(id >= 0 && id < n_expert);
                             ggml_bitset_set(used_ids.data(), id);
                         }
                     }
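For reference, below is a minimal standalone sketch of the "find the used experts" scan in the second hunk. The bitset helpers are re-implemented here to match their definitions in ggml-impl.h so the example compiles on its own; the ids contents, tensor shape (top-2 experts per token, 3 tokens), and n_expert value are made up for illustration.

// Standalone sketch of the used-expert scan; helper definitions mirror
// ggml-impl.h, while the data below is hypothetical.
#include <cstdint>
#include <cstdio>
#include <vector>

typedef uint32_t ggml_bitset_t;

static size_t ggml_bitset_size(size_t n) { return (n + 31) >> 5; }
static void   ggml_bitset_set(ggml_bitset_t * bitset, size_t i) { bitset[i >> 5] |= 1u << (i & 31); }
static bool   ggml_bitset_get(const ggml_bitset_t * bitset, size_t i) { return !!(bitset[i >> 5] & (1u << (i & 31))); }

int main() {
    // ids laid out like a contiguous 2D ggml tensor:
    // ne[0] = experts selected per token (2), ne[1] = tokens (3),
    // nb[] = byte strides, as in ggml_tensor
    const int     n_expert = 8;
    const int64_t ne[2]    = { 2, 3 };
    const size_t  nb[2]    = { sizeof(int32_t), 2 * sizeof(int32_t) };
    const std::vector<int32_t> ids = { 0, 5,  5, 2,  7, 0 };

    std::vector<ggml_bitset_t> used_ids(ggml_bitset_size(n_expert), 0);

    // same strided walk as the scheduler: byte strides are converted to
    // element strides by dividing by sizeof(int32_t)
    for (int64_t i1 = 0; i1 < ne[1]; i1++) {
        for (int64_t i0 = 0; i0 < ne[0]; i0++) {
            int32_t id = ids[i1 * nb[1]/sizeof(int32_t) + i0 * nb[0]/sizeof(int32_t)];
            ggml_bitset_set(used_ids.data(), id);
        }
    }

    for (int id = 0; id < n_expert; id++) {
        if (ggml_bitset_get(used_ids.data(), id)) {
            printf("expert %d is used\n", id); // prints experts 0, 2, 5, 7
        }
    }
    return 0;
}

Indexing through the byte strides (nb) divided by the element size mirrors how the scheduler reads a possibly non-contiguous 2D ids tensor, and the bitset gives a one-bit-per-expert usage map, presumably so the surrounding copy logic can skip experts whose bit is never set.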