Skip to content

Commit 3a8df49

Browse files
committed
update the observation trace unfolder towards generic support to compute conditional probs and fix bug in belief tracker
1 parent 3f77cdc commit 3a8df49

File tree

3 files changed

+61
-24
lines changed

3 files changed

+61
-24
lines changed

src/storm-pomdp/generator/NondeterministicBeliefTracker.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ NondeterministicBeliefTracker<ValueType, BeliefState>::NondeterministicBeliefTra
383383
template<typename ValueType, typename BeliefState>
384384
bool NondeterministicBeliefTracker<ValueType, BeliefState>::reset(uint32_t observation) {
385385
bool hit = false;
386+
beliefs.clear();
386387
for (auto state : pomdp.getInitialStates()) {
387388
if (observation == pomdp.getObservation(state)) {
388389
hit = true;
@@ -492,9 +493,6 @@ bool NondeterministicBeliefTracker<ValueType, BeliefState>::hasTimedOut() const
492493
template class SparseBeliefState<double>;
493494
template bool operator==(SparseBeliefState<double> const&, SparseBeliefState<double> const&);
494495
template class NondeterministicBeliefTracker<double, SparseBeliefState<double>>;
495-
// template class ObservationDenseBeliefState<double>;
496-
// template bool operator==(ObservationDenseBeliefState<double> const&, ObservationDenseBeliefState<double> const&);
497-
// template class NondeterministicBeliefTracker<double, ObservationDenseBeliefState<double>>;
498496

499497
template class SparseBeliefState<storm::RationalNumber>;
500498
template bool operator==(SparseBeliefState<storm::RationalNumber> const&, SparseBeliefState<storm::RationalNumber> const&);

src/storm-pomdp/transformer/ObservationTraceUnfolder.cpp

Lines changed: 48 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,34 @@
11
#include "storm-pomdp/transformer/ObservationTraceUnfolder.h"
2+
3+
#include "storm/adapters/RationalFunctionAdapter.h"
24
#include "storm/exceptions/InvalidArgumentException.h"
35
#include "storm/storage/expressions/ExpressionManager.h"
46
#include "storm/utility/ConstantsComparator.h"
57

6-
#include "storm/adapters/RationalFunctionAdapter.h"
78

89
#undef _VERBOSE_OBSERVATION_UNFOLDING
910

1011
namespace storm {
1112
namespace pomdp {
1213
template<typename ValueType>
1314
ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model, std::vector<ValueType> const& risk,
14-
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager)
15-
: model(model), risk(risk), exprManager(exprManager) {
15+
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager,
16+
ObservationTraceUnfolderOptions const& options)
17+
: model(model), risk(risk), exprManager(exprManager), options(options) {
1618
statesPerObservation = std::vector<storm::storage::BitVector>(model.getNrObservations() + 1, storm::storage::BitVector(model.getNumberOfStates()));
1719
for (uint64_t state = 0; state < model.getNumberOfStates(); ++state) {
1820
statesPerObservation[model.getObservation(state)].set(state, true);
1921
}
2022
svvar = exprManager->declareFreshIntegerVariable(false, "_s");
23+
tsvar = exprManager->declareFreshIntegerVariable(false, "_t");
2124
}
2225

2326
template<typename ValueType>
2427
std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<ValueType>::transform(const std::vector<uint32_t>& observations) {
2528
std::vector<uint32_t> modifiedObservations = observations;
2629
// First observation should be special.
2730
// This just makes the algorithm simpler because we do not treat the first step as a special case later.
31+
// We overwrite the observation with a non-existing obs z*
2832
modifiedObservations[0] = model.getNrObservations();
2933

3034
storm::storage::BitVector initialStates = model.getInitialStates();
@@ -36,15 +40,17 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
3640
}
3741
STORM_LOG_THROW(actualInitialStates.getNumberOfSetBits() == 1, storm::exceptions::InvalidArgumentException,
3842
"Must have unique initial state matching the observation");
39-
//
43+
// For this z* that only exists in the initial state, we now also define the states for this observation.
4044
statesPerObservation[model.getNrObservations()] = actualInitialStates;
4145

4246
#ifdef _VERBOSE_OBSERVATION_UNFOLDING
4347
std::cout << "build valution builder..\n";
4448
#endif
4549
storm::storage::sparse::StateValuationsBuilder svbuilder;
4650
svbuilder.addVariable(svvar);
51+
svbuilder.addVariable(tsvar);
4752

53+
// TODO: Do we need this as ordered maps? Is it better to make them unordered?
4854
std::map<uint64_t, uint64_t> unfoldedToOld;
4955
std::map<uint64_t, uint64_t> unfoldedToOldNextStep;
5056
std::map<uint64_t, uint64_t> oldToUnfolded;
@@ -53,15 +59,30 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
5359
std::cout << "start buildiing matrix...\n";
5460
#endif
5561

62+
uint64_t newStateIndex = 0;
63+
// TODO do not add violated state if we do rejection sampling.
64+
uint64_t violatedState = newStateIndex;
65+
++newStateIndex;
5666
// Add this initial state state:
57-
unfoldedToOldNextStep[0] = actualInitialStates.getNextSetIndex(0);
67+
uint64_t initialState = newStateIndex;
68+
++newStateIndex;
5869

70+
unfoldedToOldNextStep[initialState] = actualInitialStates.getNextSetIndex(0);
71+
72+
uint64_t resetDestination = options.rejectionSampling ? initialState : violatedState; // Should be initial state for the standard semantics.
5973
storm::storage::SparseMatrixBuilder<ValueType> transitionMatrixBuilder(0, 0, 0, true, true);
60-
uint64_t newStateIndex = 1;
61-
uint64_t newRowGroupStart = 0;
62-
uint64_t newRowCount = 0;
63-
// Notice that we are going to use a special last step
6474

75+
// TODO only add this state if it is actually reachable / rejection sampling
76+
// the violated state is a sink state
77+
transitionMatrixBuilder.newRowGroup(violatedState);
78+
transitionMatrixBuilder.addNextValue(violatedState, violatedState, storm::utility::one<ValueType>());
79+
svbuilder.addState(violatedState, {}, {-1, -1});
80+
81+
// Now we are starting to build the MDP from the initial state onwards.
82+
uint64_t newRowGroupStart = initialState;
83+
uint64_t newRowCount = initialState;
84+
85+
// Notice that we are going to use a special last step
6586
for (uint64_t step = 0; step < observations.size() - 1; ++step) {
6687
oldToUnfolded.clear();
6788
unfoldedToOld = unfoldedToOldNextStep;
@@ -73,7 +94,7 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
7394
std::cout << "\tconsider new state " << unfoldedToOldEntry.first << '\n';
7495
#endif
7596
assert(step == 0 || newRowCount == transitionMatrixBuilder.getLastRow() + 1);
76-
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second)});
97+
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second), static_cast<int64_t>(step)});
7798
uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second];
7899
uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices()[unfoldedToOldEntry.second + 1];
79100

@@ -96,7 +117,7 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
96117

97118
// Add the resets
98119
if (resetProb != storm::utility::zero<ValueType>()) {
99-
transitionMatrixBuilder.addNextValue(newRowCount, 0, resetProb);
120+
transitionMatrixBuilder.addNextValue(newRowCount, resetDestination, resetProb);
100121
}
101122
#ifdef _VERBOSE_OBSERVATION_UNFOLDING
102123
std::cout << "\t\t\t add other transitions...\n";
@@ -125,22 +146,20 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
125146
}
126147
newRowCount++;
127148
}
128-
129149
newRowGroupStart = transitionMatrixBuilder.getLastRow() + 1;
130150
}
131151
}
132152
// Now, take care of the last step.
133153
uint64_t sinkState = newStateIndex;
134154
uint64_t targetState = newStateIndex + 1;
135-
auto cc = storm::utility::ConstantsComparator<ValueType>();
155+
[[maybe_unused]] auto cc = storm::utility::ConstantsComparator<ValueType>();
136156
for (auto const& unfoldedToOldEntry : unfoldedToOldNextStep) {
137-
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second)});
157+
svbuilder.addState(unfoldedToOldEntry.first, {}, {static_cast<int64_t>(unfoldedToOldEntry.second), static_cast<int64_t>(observations.size() - 1)});
138158

139159
transitionMatrixBuilder.newRowGroup(newRowGroupStart);
140160
STORM_LOG_ASSERT(risk.size() > unfoldedToOldEntry.second, "Must be a state");
141161
STORM_LOG_ASSERT(!cc.isLess(storm::utility::one<ValueType>(), risk[unfoldedToOldEntry.second]), "Risk must be a probability");
142162
STORM_LOG_ASSERT(!cc.isLess(risk[unfoldedToOldEntry.second], storm::utility::zero<ValueType>()), "Risk must be a probability");
143-
// std::cout << "risk is" << risk[unfoldedToOldEntry.second] << '\n';
144163
if (!storm::utility::isOne(risk[unfoldedToOldEntry.second])) {
145164
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second]);
146165
}
@@ -152,13 +171,13 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
152171
// sink state
153172
transitionMatrixBuilder.newRowGroup(newRowGroupStart);
154173
transitionMatrixBuilder.addNextValue(newRowGroupStart, sinkState, storm::utility::one<ValueType>());
155-
svbuilder.addState(sinkState, {}, {-1});
174+
svbuilder.addState(sinkState, {}, {-1, -1});
156175

157176
newRowGroupStart++;
158177
transitionMatrixBuilder.newRowGroup(newRowGroupStart);
159178
// target state
160179
transitionMatrixBuilder.addNextValue(newRowGroupStart, targetState, storm::utility::one<ValueType>());
161-
svbuilder.addState(targetState, {}, {-1});
180+
svbuilder.addState(targetState, {}, {-1, -1});
162181
#ifdef _VERBOSE_OBSERVATION_UNFOLDING
163182
std::cout << "build matrix...\n";
164183
#endif
@@ -168,14 +187,19 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
168187
#ifdef _VERBOSE_OBSERVATION_UNFOLDING
169188
std::cout << components.transitionMatrix << '\n';
170189
#endif
171-
STORM_LOG_ASSERT(components.transitionMatrix.getRowGroupCount() == targetState + 1,
190+
STORM_LOG_ASSERT(components.transitionMatrix.getRowGroupCount() == targetState + 2,
172191
"Expect row group count (" << components.transitionMatrix.getRowGroupCount() << ") one more as target state index " << targetState << ")");
173192

174193
storm::models::sparse::StateLabeling labeling(components.transitionMatrix.getRowGroupCount());
175194
labeling.addLabel("_goal");
176195
labeling.addLabelToState("_goal", targetState);
196+
labeling.addLabel("_violated");
197+
labeling.addLabelToState("_violated", violatedState);
198+
labeling.addLabel("_end");
199+
labeling.addLabelToState("_end", sinkState);
200+
labeling.addLabelToState("_end", targetState);
177201
labeling.addLabel("init");
178-
labeling.addLabelToState("init", 0);
202+
labeling.addLabelToState("init", initialState);
179203
components.stateLabeling = labeling;
180204
components.stateValuations = svbuilder.build();
181205
return std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move(components));
@@ -192,6 +216,11 @@ void ObservationTraceUnfolder<ValueType>::reset(uint32_t observation) {
192216
traceSoFar = {observation};
193217
}
194218

219+
template<typename ValueType>
220+
bool ObservationTraceUnfolder<ValueType>::isRejectionSamplingSet() const {
221+
return options.rejectionSampling;
222+
}
223+
195224
template class ObservationTraceUnfolder<double>;
196225
template class ObservationTraceUnfolder<storm::RationalNumber>;
197226
template class ObservationTraceUnfolder<storm::RationalFunction>;

src/storm-pomdp/transformer/ObservationTraceUnfolder.h

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
namespace storm {
44
namespace pomdp {
5+
6+
class ObservationTraceUnfolderOptions {
7+
public:
8+
bool rejectionSampling = true;
9+
};
10+
511
/**
612
* Observation-trace unrolling to allow model checking for monitoring.
713
* This approach is outlined in Junges, Hazem, Seshia -- Runtime Monitoring for Markov Decision Processes
@@ -17,7 +23,7 @@ class ObservationTraceUnfolder {
1723
* @param exprManager an Expression Manager
1824
*/
1925
ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const& model, std::vector<ValueType> const& risk,
20-
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager);
26+
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager, ObservationTraceUnfolderOptions const& options);
2127
/**
2228
* Transform in one shot
2329
* @param observations
@@ -36,13 +42,17 @@ class ObservationTraceUnfolder {
3642
*/
3743
void reset(uint32_t observation);
3844

45+
bool isRejectionSamplingSet() const;
46+
3947
private:
4048
storm::models::sparse::Pomdp<ValueType> const& model;
4149
std::vector<ValueType> risk; // TODO reconsider holding this as a reference, but there were some strange bugs
4250
std::shared_ptr<storm::expressions::ExpressionManager>& exprManager;
4351
std::vector<storm::storage::BitVector> statesPerObservation;
4452
std::vector<uint32_t> traceSoFar;
45-
storm::expressions::Variable svvar;
53+
storm::expressions::Variable svvar; // Maps to the old state (explicit encoding)
54+
storm::expressions::Variable tsvar; // Maps to the time step
55+
ObservationTraceUnfolderOptions options;
4656
};
4757

4858
} // namespace pomdp

0 commit comments

Comments
 (0)