Skip to content

Commit c2e85c6

Browse files
authored
[Costs] Success probability (#970)
* success prob * docs
1 parent 8ea0267 commit c2e85c6

File tree

4 files changed

+87
-1
lines changed

4 files changed

+87
-1
lines changed

qualtran/bloqs/for_testing/costing.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,23 @@ def _convert_callees(callees: Sequence[BloqCountT]) -> Tuple[BloqCountT, ...]:
2424
return tuple(callees)
2525

2626

27+
def _convert_static_costs(
28+
static_costs: Sequence[Tuple[CostKey, Any]]
29+
) -> Tuple[Tuple[CostKey, Any], ...]:
30+
# Convert to tuples in a type-checked way.
31+
return tuple(static_costs)
32+
33+
2734
@frozen
2835
class CostingBloq(Bloq):
2936
"""A bloq that lets you set the costs via attributes."""
3037

3138
name: str
3239
num_qubits: int
3340
callees: Sequence[BloqCountT] = field(converter=_convert_callees, factory=tuple)
34-
static_costs: Sequence[Tuple[CostKey, Any]] = field(converter=tuple, factory=tuple)
41+
static_costs: Sequence[Tuple[CostKey, Any]] = field(
42+
converter=_convert_static_costs, factory=tuple
43+
)
3544

3645
@property
3746
def signature(self) -> 'Signature':

qualtran/resource_counting/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,6 @@
3131

3232
from ._costing import GeneralizerT, get_cost_value, get_cost_cache, query_costs, CostKey, CostValT
3333

34+
from ._success_prob import SuccessProb
35+
3436
from . import generalizers
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import logging
15+
from typing import Callable
16+
17+
from attrs import frozen
18+
19+
from qualtran import Bloq
20+
21+
from ._call_graph import get_bloq_callee_counts
22+
from ._costing import CostKey
23+
24+
logger = logging.getLogger(__name__)
25+
26+
27+
@frozen
28+
class SuccessProb(CostKey[float]):
29+
"""The success probability of a bloq.
30+
31+
A bloq's success probability is the multiplicative product of its callees'
32+
success probabilities. Bloqs that have a specific success probability should override
33+
`my_static_costs` to provide their actual success probability.
34+
"""
35+
36+
def compute(self, bloq: 'Bloq', get_callee_cost: Callable[['Bloq'], float]) -> float:
37+
tot: float = 1.0
38+
callees = get_bloq_callee_counts(bloq)
39+
logger.info("Computing %s for %s from %d callee(s)", self, bloq, len(callees))
40+
for callee, n in callees:
41+
v = get_callee_cost(callee)
42+
tot *= v**n
43+
return tot
44+
45+
def zero(self) -> float:
46+
return 1.0 # under multiplication, 1 is the identity.
47+
48+
def __str__(self):
49+
return 'success prob'
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from qualtran.bloqs.for_testing.costing import CostingBloq
15+
from qualtran.resource_counting import get_cost_cache, get_cost_value, SuccessProb
16+
17+
18+
def test_coin_flip():
19+
flip = CostingBloq('CoinFlip', num_qubits=1, static_costs=[(SuccessProb(), 0.5)])
20+
algo = CostingBloq('Algo', num_qubits=0, callees=[(flip, 4)])
21+
22+
p = get_cost_value(algo, SuccessProb())
23+
assert p == 0.5**4
24+
25+
costs = get_cost_cache(algo, SuccessProb())
26+
assert costs == {algo: p, flip: 0.5}

0 commit comments

Comments
 (0)