11#include " storm-pomdp/transformer/ObservationTraceUnfolder.h"
2+
3+ #include " storm/adapters/RationalFunctionAdapter.h"
24#include " storm/exceptions/InvalidArgumentException.h"
35#include " storm/storage/expressions/ExpressionManager.h"
46#include " storm/utility/ConstantsComparator.h"
57
6- #include " storm/adapters/RationalFunctionAdapter.h"
78
89#undef _VERBOSE_OBSERVATION_UNFOLDING
910
1011namespace storm {
1112namespace pomdp {
1213template <typename ValueType>
1314ObservationTraceUnfolder<ValueType>::ObservationTraceUnfolder(storm::models::sparse::Pomdp<ValueType> const & model, std::vector<ValueType> const & risk,
14- std::shared_ptr<storm::expressions::ExpressionManager>& exprManager)
15- : model(model), risk(risk), exprManager(exprManager) {
15+ std::shared_ptr<storm::expressions::ExpressionManager>& exprManager,
16+ ObservationTraceUnfolderOptions const & options)
17+ : model(model), risk(risk), exprManager(exprManager), options(options) {
1618 statesPerObservation = std::vector<storm::storage::BitVector>(model.getNrObservations () + 1 , storm::storage::BitVector (model.getNumberOfStates ()));
1719 for (uint64_t state = 0 ; state < model.getNumberOfStates (); ++state) {
1820 statesPerObservation[model.getObservation (state)].set (state, true );
1921 }
2022 svvar = exprManager->declareFreshIntegerVariable (false , " _s" );
23+ tsvar = exprManager->declareFreshIntegerVariable (false , " _t" );
2124}
2225
2326template <typename ValueType>
2427std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<ValueType>::transform(const std::vector<uint32_t >& observations) {
2528 std::vector<uint32_t > modifiedObservations = observations;
2629 // First observation should be special.
2730 // This just makes the algorithm simpler because we do not treat the first step as a special case later.
31+ // We overwrite the observation with a non-existing obs z*
2832 modifiedObservations[0 ] = model.getNrObservations ();
2933
3034 storm::storage::BitVector initialStates = model.getInitialStates ();
@@ -36,15 +40,17 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
3640 }
3741 STORM_LOG_THROW (actualInitialStates.getNumberOfSetBits () == 1 , storm::exceptions::InvalidArgumentException,
3842 " Must have unique initial state matching the observation" );
39- //
43+ // For this z* that only exists in the initial state, we now also define the states for this observation.
4044 statesPerObservation[model.getNrObservations ()] = actualInitialStates;
4145
4246#ifdef _VERBOSE_OBSERVATION_UNFOLDING
4347 std::cout << " build valution builder..\n " ;
4448#endif
4549 storm::storage::sparse::StateValuationsBuilder svbuilder;
4650 svbuilder.addVariable (svvar);
51+ svbuilder.addVariable (tsvar);
4752
53+ // TODO: Do we need this as ordered maps? Is it better to make them unordered?
4854 std::map<uint64_t , uint64_t > unfoldedToOld;
4955 std::map<uint64_t , uint64_t > unfoldedToOldNextStep;
5056 std::map<uint64_t , uint64_t > oldToUnfolded;
@@ -53,15 +59,30 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
5359 std::cout << " start buildiing matrix...\n " ;
5460#endif
5561
62+ uint64_t newStateIndex = 0 ;
63+ // TODO do not add violated state if we do rejection sampling.
64+ uint64_t violatedState = newStateIndex;
65+ ++newStateIndex;
5666 // Add this initial state state:
57- unfoldedToOldNextStep[0 ] = actualInitialStates.getNextSetIndex (0 );
67+ uint64_t initialState = newStateIndex;
68+ ++newStateIndex;
5869
70+ unfoldedToOldNextStep[initialState] = actualInitialStates.getNextSetIndex (0 );
71+
72+ uint64_t resetDestination = options.rejectionSampling ? initialState : violatedState; // Should be initial state for the standard semantics.
5973 storm::storage::SparseMatrixBuilder<ValueType> transitionMatrixBuilder (0 , 0 , 0 , true , true );
60- uint64_t newStateIndex = 1 ;
61- uint64_t newRowGroupStart = 0 ;
62- uint64_t newRowCount = 0 ;
63- // Notice that we are going to use a special last step
6474
75+ // TODO only add this state if it is actually reachable / rejection sampling
76+ // the violated state is a sink state
77+ transitionMatrixBuilder.newRowGroup (violatedState);
78+ transitionMatrixBuilder.addNextValue (violatedState, violatedState, storm::utility::one<ValueType>());
79+ svbuilder.addState (violatedState, {}, {-1 , -1 });
80+
81+ // Now we are starting to build the MDP from the initial state onwards.
82+ uint64_t newRowGroupStart = initialState;
83+ uint64_t newRowCount = initialState;
84+
85+ // Notice that we are going to use a special last step
6586 for (uint64_t step = 0 ; step < observations.size () - 1 ; ++step) {
6687 oldToUnfolded.clear ();
6788 unfoldedToOld = unfoldedToOldNextStep;
@@ -73,7 +94,7 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
7394 std::cout << " \t consider new state " << unfoldedToOldEntry.first << ' \n ' ;
7495#endif
7596 assert (step == 0 || newRowCount == transitionMatrixBuilder.getLastRow () + 1 );
76- svbuilder.addState (unfoldedToOldEntry.first , {}, {static_cast <int64_t >(unfoldedToOldEntry.second )});
97+ svbuilder.addState (unfoldedToOldEntry.first , {}, {static_cast <int64_t >(unfoldedToOldEntry.second ), static_cast < int64_t >(step) });
7798 uint64_t oldRowIndexStart = model.getNondeterministicChoiceIndices ()[unfoldedToOldEntry.second ];
7899 uint64_t oldRowIndexEnd = model.getNondeterministicChoiceIndices ()[unfoldedToOldEntry.second + 1 ];
79100
@@ -96,7 +117,7 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
96117
97118 // Add the resets
98119 if (resetProb != storm::utility::zero<ValueType>()) {
99- transitionMatrixBuilder.addNextValue (newRowCount, 0 , resetProb);
120+ transitionMatrixBuilder.addNextValue (newRowCount, resetDestination , resetProb);
100121 }
101122#ifdef _VERBOSE_OBSERVATION_UNFOLDING
102123 std::cout << " \t\t\t add other transitions...\n " ;
@@ -125,22 +146,20 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
125146 }
126147 newRowCount++;
127148 }
128-
129149 newRowGroupStart = transitionMatrixBuilder.getLastRow () + 1 ;
130150 }
131151 }
132152 // Now, take care of the last step.
133153 uint64_t sinkState = newStateIndex;
134154 uint64_t targetState = newStateIndex + 1 ;
135- auto cc = storm::utility::ConstantsComparator<ValueType>();
155+ [[maybe_unused]] auto cc = storm::utility::ConstantsComparator<ValueType>();
136156 for (auto const & unfoldedToOldEntry : unfoldedToOldNextStep) {
137- svbuilder.addState (unfoldedToOldEntry.first , {}, {static_cast <int64_t >(unfoldedToOldEntry.second )});
157+ svbuilder.addState (unfoldedToOldEntry.first , {}, {static_cast <int64_t >(unfoldedToOldEntry.second ), static_cast < int64_t >(observations. size () - 1 ) });
138158
139159 transitionMatrixBuilder.newRowGroup (newRowGroupStart);
140160 STORM_LOG_ASSERT (risk.size () > unfoldedToOldEntry.second , " Must be a state" );
141161 STORM_LOG_ASSERT (!cc.isLess (storm::utility::one<ValueType>(), risk[unfoldedToOldEntry.second ]), " Risk must be a probability" );
142162 STORM_LOG_ASSERT (!cc.isLess (risk[unfoldedToOldEntry.second ], storm::utility::zero<ValueType>()), " Risk must be a probability" );
143- // std::cout << "risk is" << risk[unfoldedToOldEntry.second] << '\n';
144163 if (!storm::utility::isOne (risk[unfoldedToOldEntry.second ])) {
145164 transitionMatrixBuilder.addNextValue (newRowGroupStart, sinkState, storm::utility::one<ValueType>() - risk[unfoldedToOldEntry.second ]);
146165 }
@@ -152,13 +171,13 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
152171 // sink state
153172 transitionMatrixBuilder.newRowGroup (newRowGroupStart);
154173 transitionMatrixBuilder.addNextValue (newRowGroupStart, sinkState, storm::utility::one<ValueType>());
155- svbuilder.addState (sinkState, {}, {-1 });
174+ svbuilder.addState (sinkState, {}, {-1 , - 1 });
156175
157176 newRowGroupStart++;
158177 transitionMatrixBuilder.newRowGroup (newRowGroupStart);
159178 // target state
160179 transitionMatrixBuilder.addNextValue (newRowGroupStart, targetState, storm::utility::one<ValueType>());
161- svbuilder.addState (targetState, {}, {-1 });
180+ svbuilder.addState (targetState, {}, {-1 , - 1 });
162181#ifdef _VERBOSE_OBSERVATION_UNFOLDING
163182 std::cout << " build matrix...\n " ;
164183#endif
@@ -168,14 +187,19 @@ std::shared_ptr<storm::models::sparse::Mdp<ValueType>> ObservationTraceUnfolder<
168187#ifdef _VERBOSE_OBSERVATION_UNFOLDING
169188 std::cout << components.transitionMatrix << ' \n ' ;
170189#endif
171- STORM_LOG_ASSERT (components.transitionMatrix .getRowGroupCount () == targetState + 1 ,
190+ STORM_LOG_ASSERT (components.transitionMatrix .getRowGroupCount () == targetState + 2 ,
172191 " Expect row group count (" << components.transitionMatrix .getRowGroupCount () << " ) one more as target state index " << targetState << " )" );
173192
174193 storm::models::sparse::StateLabeling labeling (components.transitionMatrix .getRowGroupCount ());
175194 labeling.addLabel (" _goal" );
176195 labeling.addLabelToState (" _goal" , targetState);
196+ labeling.addLabel (" _violated" );
197+ labeling.addLabelToState (" _violated" , violatedState);
198+ labeling.addLabel (" _end" );
199+ labeling.addLabelToState (" _end" , sinkState);
200+ labeling.addLabelToState (" _end" , targetState);
177201 labeling.addLabel (" init" );
178- labeling.addLabelToState (" init" , 0 );
202+ labeling.addLabelToState (" init" , initialState );
179203 components.stateLabeling = labeling;
180204 components.stateValuations = svbuilder.build ();
181205 return std::make_shared<storm::models::sparse::Mdp<ValueType>>(std::move (components));
@@ -192,6 +216,11 @@ void ObservationTraceUnfolder<ValueType>::reset(uint32_t observation) {
192216 traceSoFar = {observation};
193217}
194218
219+ template <typename ValueType>
220+ bool ObservationTraceUnfolder<ValueType>::isRejectionSamplingSet() const {
221+ return options.rejectionSampling ;
222+ }
223+
195224template class ObservationTraceUnfolder <double >;
196225template class ObservationTraceUnfolder <storm::RationalNumber>;
197226template class ObservationTraceUnfolder <storm::RationalFunction>;
0 commit comments