+ {
+ T apply(T result, int s, int t, double d);
+ }
+
+ /**
+ * Iterate over the outgoing transitions of state {@code state}
+ * and apply the reducing function {@code fn}
+ * to the intermediate result and the transition:
+ *
+ * Call {@code apply(r,s,t,d)} where
+ * {@code r} is the intermediate result,
+ * {@code t} is the successor state and,
+ * in a DTMC, {@code d} = P(s,t) is the probability from {@code s} to {@code t},
+ * while in CTMC, {@code d} = R(s,t) is the rate from {@code s} to {@code t}.
+ * The return value of apply is the intermediate result for the next transition.
+ *
+ * Default implementation: The default implementation relies on iterating over the
+ * iterator returned by {@code getTransitionsIterator()}.
+ *
Note: This method is the base for the default implementation of the numerical
+ * computation methods (mvMult, etc). In derived classes, it may thus be worthwhile to
+ * provide a specialised implementation for this method that avoids using the Iterator mechanism.
+ *
+ * @param state the state
+ * @param init initial result value
+ * @param fn the reducing function
+ */
+ public default T reduceTransitions(int state, T init, ObjTransitionFunction fn)
+ {
+ T result = init;
+ for (Iterator> it = getTransitionsIterator(state); it.hasNext(); ) {
+ Entry e = it.next();
+ result = fn.apply(result, state, e.getKey(), e.getValue());
+ }
+ return result;
+ }
+
+ /**
+ * Primitive specialisation of {@code ObjTransitionFunction} for {@code double} results.
+ *
+ * @see ObjTransitionFunction
+ */
+ @FunctionalInterface
+ public static interface DoubleTransitionFunction
+ {
+ double apply(double result, int s, int t, double d);
+ }
+
+ /**
+ * Primitive specialisation of {@code reduce} for {@code double} values.
+ *
+ * @see #reduceTransitions(int, Object, ObjTransitionFunction)
+ */
+ public default double reduceTransitions(int state, double init, DoubleTransitionFunction fn)
+ {
+ double result = init;
+ for (Iterator> it = getTransitionsIterator(state); it.hasNext(); ) {
+ Entry e = it.next();
+ result = fn.apply(result, state, e.getKey(), e.getValue());
+ }
+ return result;
+ }
+
+ /**
+ * Primitive specialisation of {@code ObjTransitionFunction} for {@code int} values.
+ *
+ * @see ObjTransitionFunction
+ */
+ @FunctionalInterface
+ public interface IntTransitionFunction
+ {
+ int apply(int result, int s, int t, double d);
+ }
+
+ /**
+ * Primitive specialisation of {@code reduce} for {@code int} values.
+ *
+ * @see #reduceTransitions(int, Object, ObjTransitionFunction)
+ */
+ public default int reduceTransitions(int state, int init, IntTransitionFunction fn)
+ {
+ int result = init;
+ for (Iterator> it = getTransitionsIterator(state); it.hasNext(); ) {
Entry e = it.next();
- c.accept(s, e.getKey(), e.getValue());
+ result = fn.apply(result, state, e.getKey(), e.getValue());
}
+ return result;
+ }
+
+ /**
+ * Primitive specialisation of {@code ObjTransitionFunction} for {@code long} values.
+ *
+ * @see ObjTransitionFunction
+ */
+ @FunctionalInterface
+ public interface LongTransitionFunction
+ {
+ long apply(long result, int s, int t, double d);
+ }
+
+ /**
+ * Primitive specialisation of {@code reduce} for {@code long} values.
+ *
+ * @see #reduceTransitions(int, Object, ObjTransitionFunction)
+ */
+ public default long reduceTransitions(int state, long init, LongTransitionFunction fn)
+ {
+ long result = init;
+ for (Iterator> it = getTransitionsIterator(state); it.hasNext(); ) {
+ Entry e = it.next();
+ result = fn.apply(result, state, e.getKey(), e.getValue());
+ }
+ return result;
}
/**
@@ -103,29 +224,17 @@ public interface TransitionToDoubleFunction {
}
/**
- * Iterate over the outgoing transitions of state {@code s}, call the function {@code f}
+ * Iterate over the outgoing transitions of state {@code state}, call the function {@code f}
* and return the sum of the result values:
*
- * Return sum_t f(s, t, P(s,t)), where t ranges over the successors of s.
+ * Return sum_t f(state, t, P(s,t)), where t ranges over the successors of state.
*
- * @param s the state s
- * @param c the consumer
+ * @param state the state
+ * @param f the function
*/
- public default double sumOverTransitions(final int s, final TransitionToDoubleFunction f)
+ public default double sumOverTransitions(int state, TransitionToDoubleFunction f)
{
- class Sum {
- double sum = 0.0;
-
- void accept(int s, int t, double d)
- {
- sum += f.apply(s, t, d);
- }
- }
-
- Sum sum = new Sum();
- forEachTransition(s, sum::accept);
-
- return sum.sum;
+ return reduceTransitions(state, 0.0, (r, s, t, d) -> r + f.apply(s, t, d));
}
/**
diff --git a/prism/src/explicit/DTMCEmbeddedSimple.java b/prism/src/explicit/DTMCEmbeddedSimple.java
index 717d6e9234..cbc2a0709c 100644
--- a/prism/src/explicit/DTMCEmbeddedSimple.java
+++ b/prism/src/explicit/DTMCEmbeddedSimple.java
@@ -257,17 +257,63 @@ public Double setValue(Double value)
}
@Override
- public void forEachTransition(int s, TransitionConsumer c)
+ public T reduceTransitions(int state, T init, ObjTransitionFunction fn)
{
- final double er = exitRates[s];
+ double er = exitRates[state];
+ T result = init;
if (er == 0) {
// exit rate = 0 -> prob 1 self loop
- c.accept(s, s, 1.0);
+ fn.apply(result, state, state, 1.0);
} else {
- ctmc.forEachTransition(s, (s_,t,rate) -> {
- c.accept(s_, t, rate / er);
- });
+ ctmc.reduceTransitions(state, result, (r, s, t, rate) ->
+ fn.apply(r, s, t, rate / er));
}
+ return result;
+ }
+
+ @Override
+ public double reduceTransitions(int state, double init, DoubleTransitionFunction fn)
+ {
+ double er = exitRates[state];
+ double result = init;
+ if (er == 0) {
+ // exit rate = 0 -> prob 1 self loop
+ fn.apply(result, state, state, 1.0);
+ } else {
+ ctmc.reduceTransitions(state, result, (r, s, t, rate) ->
+ fn.apply(r, s, t, rate / er));
+ }
+ return result;
+ }
+
+ @Override
+ public int reduceTransitions(int state, int init, IntTransitionFunction fn)
+ {
+ double er = exitRates[state];
+ int result = init;
+ if (er == 0) {
+ // exit rate = 0 -> prob 1 self loop
+ fn.apply(result, state, state, 1.0);
+ } else {
+ ctmc.reduceTransitions(state, result, (int r, int s, int t, double rate) ->
+ fn.apply(r, s, t, rate / er));
+ }
+ return result;
+ }
+
+ @Override
+ public long reduceTransitions(int state, long init, LongTransitionFunction fn)
+ {
+ double er = exitRates[state];
+ long result = init;
+ if (er == 0) {
+ // exit rate = 0 -> prob 1 self loop
+ fn.apply(result, state, state, 1.0);
+ } else {
+ ctmc.reduceTransitions(state, result, (long r, int s, int t, double rate) ->
+ fn.apply(r, s, t, rate / er));
+ }
+ return result;
}
public double mvMultSingle(int s, double vect[])
diff --git a/prism/src/explicit/DTMCFromMDPAndMDStrategy.java b/prism/src/explicit/DTMCFromMDPAndMDStrategy.java
index fe9ca18ec8..8ee7a29388 100644
--- a/prism/src/explicit/DTMCFromMDPAndMDStrategy.java
+++ b/prism/src/explicit/DTMCFromMDPAndMDStrategy.java
@@ -182,14 +182,41 @@ public Iterator> getTransitionsIterator(int s)
}
@Override
- public void forEachTransition(int s, TransitionConsumer c)
+ public T reduceTransitions(int state, T init, ObjTransitionFunction fn)
{
- if (!strat.isChoiceDefined(s)) {
- return;
+ if (!strat.isChoiceDefined(state)) {
+ return init;
}
- mdp.forEachTransition(s, strat.getChoiceIndex(s), c::accept);
+ return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn);
}
+ @Override
+ public double reduceTransitions(int state, double init, DoubleTransitionFunction fn)
+ {
+ if (!strat.isChoiceDefined(state)) {
+ return init;
+ }
+ return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn);
+ }
+
+ @Override
+ public int reduceTransitions(int state, int init, IntTransitionFunction fn)
+ {
+ if (!strat.isChoiceDefined(state)) {
+ return init;
+ }
+ return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn);
+ }
+
+ @Override
+ public long reduceTransitions(int state, long init, LongTransitionFunction fn)
+ {
+ if (!strat.isChoiceDefined(state)) {
+ return init;
+ }
+ return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn);
+}
+
@Override
public double mvMultSingle(int s, double vect[])
{
diff --git a/prism/src/explicit/DTMCModelChecker.java b/prism/src/explicit/DTMCModelChecker.java
index ef20196365..b97f3a802b 100644
--- a/prism/src/explicit/DTMCModelChecker.java
+++ b/prism/src/explicit/DTMCModelChecker.java
@@ -2491,11 +2491,7 @@ public ModelCheckerResult computeSteadyStateProbsForBSCC(DTMC dtmc, BitSet state
// Note: diagsQ[state] = 0.0, as it was freshly created
// Compute negative exit rate (ignoring a possible self-loop)
- dtmc.forEachTransition(state, (s, t, prob) -> {
- if (s != t) {
- diagsQ[state] -= prob;
- }
- });
+ diagsQ[state] -= dtmc.reduceTransitions(state, 0.0, (r, s, t, p) -> (s == t) ? r : r + p);
// Note: If there are no outgoing transitions, diagsQ[state] = 0, which is fine
diff --git a/prism/src/explicit/DTMCSparse.java b/prism/src/explicit/DTMCSparse.java
index a0f33456e3..5b9f656f72 100644
--- a/prism/src/explicit/DTMCSparse.java
+++ b/prism/src/explicit/DTMCSparse.java
@@ -228,11 +228,43 @@ public void buildFromPrismExplicit(String filename) throws PrismException
//--- DTMC ---
@Override
- public void forEachTransition(int state, TransitionConsumer consumer)
+ public T reduceTransitions(int state, T init, ObjTransitionFunction fn)
{
- for (int col = rows[state], stop = rows[state+1]; col < stop; col++) {
- consumer.accept(state, columns[col], probabilities[col]);
+ T result = init;
+ for (int col = rows[state], stop = rows[state+1]; col < stop; col++){
+ result = fn.apply(result, state, columns[col], probabilities[col]);
}
+ return result;
+ }
+
+ @Override
+ public double reduceTransitions(int state, double init, DoubleTransitionFunction fn)
+ {
+ double result = init;
+ for (int col = rows[state], stop = rows[state+1]; col < stop; col++){
+ result = fn.apply(result, state, columns[col], probabilities[col]);
+ }
+ return result;
+ }
+
+ @Override
+ public int reduceTransitions(int state, int init, IntTransitionFunction fn)
+ {
+ int result = init;
+ for (int col = rows[state], stop = rows[state+1]; col < stop; col++){
+ result = fn.apply(result, state, columns[col], probabilities[col]);
+ }
+ return result;
+ }
+
+ @Override
+ public long reduceTransitions(int state, long init, LongTransitionFunction fn)
+ {
+ long result = init;
+ for (int col = rows[state], stop = rows[state+1]; col < stop; col++){
+ result = fn.apply(result, state, columns[col], probabilities[col]);
+ }
+ return result;
}
@Override
diff --git a/prism/src/explicit/MDP.java b/prism/src/explicit/MDP.java
index 613d0a2040..f456f20519 100644
--- a/prism/src/explicit/MDP.java
+++ b/prism/src/explicit/MDP.java
@@ -35,6 +35,11 @@
import java.util.PrimitiveIterator.OfInt;
import common.IterableStateSet;
+import explicit.DTMC.DoubleTransitionFunction;
+import explicit.DTMC.IntTransitionFunction;
+import explicit.DTMC.LongTransitionFunction;
+import explicit.DTMC.ObjTransitionFunction;
+import explicit.DTMC.TransitionConsumer;
import explicit.rewards.MCRewards;
import explicit.rewards.MDPRewards;
import prism.PrismUtils;
@@ -54,17 +59,7 @@ public interface MDP extends MDPGeneric
public Iterator> getTransitionsIterator(int s, int i);
/**
- * Functional interface for a consumer,
- * accepting transitions (s,t,d), i.e.,
- * from state s to state t with value d.
- */
- @FunctionalInterface
- public interface TransitionConsumer {
- void accept(int s, int t, double d);
- }
-
- /**
- * Iterate over the outgoing transitions of state {@code s} and choice {@code i}
+ * Iterate over the outgoing transitions of state {@code state} and choice {@code i}
* and call the accept method of the consumer for each of them:
*
* Call {@code accept(s,t,d)} where t is the successor state d = P(s,i,t)
@@ -80,12 +75,86 @@ public interface TransitionConsumer {
* @param i the choice i
* @param c the consumer
*/
- public default void forEachTransition(int s, int i, TransitionConsumer c)
+ public default void forEachTransition(int state, int choice, TransitionConsumer c)
+ {
+ reduceTransitions(state, choice, null, (r, s, t, d) -> {c.accept(s,t,d); return r;});
+ }
+
+ /**
+ * Iterate over the outgoing transitions of state {@code state} and choice {@code c}
+ * and apply the reducing function {@code fn}
+ * to the intermediate result and the transition:
+ *
+ * Call {@code apply(r,s,t,d)} where
+ * {@code r} is the intermediate result,
+ * {@code t} is the successor state and,
+ * {@code d} = P(s,c,t) is the probability from {@code s} to {@code t} with choice {@code c},
+ * The return value of apply is the intermediate result for the next transition.
+ *
+ * Default implementation: The default implementation relies on iterating over the
+ * iterator returned by {@code getTransitionsIterator()}.
+ *
Note: This method is the base for the default implementation of the numerical
+ * computation methods (mvMult, etc). In derived classes, it may thus be worthwhile to
+ * provide a specialised implementation for this method that avoids using the Iterator mechanism.
+ *
+ * @param state the state
+ * @param choice the choice
+ * @param init initial result value
+ * @param fn the reducing function
+ */
+ public default T reduceTransitions(int state, int choice, T init, ObjTransitionFunction fn)
+ {
+ T result = init;
+ for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) {
+ Entry e = it.next();
+ result = fn.apply(result, state, e.getKey(), e.getValue());
+ }
+ return result;
+ }
+
+ /**
+ * Primitive specialisation of {@code reduce} for {@code double} values.
+ *
+ * @see #reduceTransitions(int, Object, ObjTransitionFunction)
+ */
+ public default double reduceTransitions(int state, int choice, double init, DoubleTransitionFunction fn)
{
- for (Iterator> it = getTransitionsIterator(s, i); it.hasNext(); ) {
+ double result = init;
+ for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) {
Entry e = it.next();
- c.accept(s, e.getKey(), e.getValue());
+ result = fn.apply(result, state, e.getKey(), e.getValue());
}
+ return result;
+ }
+
+ /**
+ * Primitive specialisation of {@code reduce} for {@code int} values.
+ *
+ * @see #reduceTransitions(int, Object, ObjTransitionFunction)
+ */
+ public default int reduceTransitions(int state, int choice, int init, IntTransitionFunction fn)
+ {
+ int result = init;
+ for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) {
+ Entry e = it.next();
+ result = fn.apply(result, state, e.getKey(), e.getValue());
+ }
+ return result;
+ }
+
+ /**
+ * Primitive specialisation of {@code reduce} for {@code long} values.
+ *
+ * @see #reduceTransitions(int, Object, ObjTransitionFunction)
+ */
+ public default long reduceTransitions(int state, int choice, long init, LongTransitionFunction fn)
+ {
+ long result = init;
+ for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) {
+ Entry e = it.next();
+ result = fn.apply(result, state, e.getKey(), e.getValue());
+ }
+ return result;
}
/**
@@ -105,24 +174,12 @@ public interface TransitionToDoubleFunction {
*
* Return sum_t f(s, t, P(s,i,t)), where t ranges over the i-successors of s.
*
- * @param s the state s
- * @param c the consumer
+ * @param state the state s
+ * @param choice the consumer
*/
- public default double sumOverTransitions(final int s, final int i, final TransitionToDoubleFunction f)
+ public default double sumOverTransitions(int state, int choice, TransitionToDoubleFunction f)
{
- class Sum {
- double sum = 0.0;
-
- void accept(int s, int t, double d)
- {
- sum += f.apply(s, t, d);
- }
- }
-
- Sum sum = new Sum();
- forEachTransition(s, i, sum::accept);
-
- return sum.sum;
+ return reduceTransitions(state, choice, 0.0, (r, s, t, d) -> r + f.apply(s, t, d));
}
/**
diff --git a/prism/src/explicit/MDPSparse.java b/prism/src/explicit/MDPSparse.java
index 4180d0dfc8..e6b2aba5a7 100644
--- a/prism/src/explicit/MDPSparse.java
+++ b/prism/src/explicit/MDPSparse.java
@@ -41,6 +41,10 @@
import java.util.TreeMap;
import common.IterableStateSet;
+import explicit.DTMC.DoubleTransitionFunction;
+import explicit.DTMC.IntTransitionFunction;
+import explicit.DTMC.LongTransitionFunction;
+import explicit.DTMC.ObjTransitionFunction;
import explicit.rewards.MCRewards;
import explicit.rewards.MDPRewards;
import parser.State;
@@ -574,6 +578,54 @@ public SuccessorsIterator getSuccessors(final int s, final int i)
// Accessors (for MDP)
+ @Override
+ public T reduceTransitions(int state, int choice, T init, ObjTransitionFunction fn)
+ {
+ T result = init;
+ int start = choiceStarts[rowStarts[state] + choice];
+ int stop = choiceStarts[rowStarts[state] + choice + 1];
+ for (int col = start; col < stop; col++) {
+ result = fn.apply(result, state, cols[col], nonZeros[col]);
+ }
+ return result;
+ }
+
+ @Override
+ public double reduceTransitions(int state, int choice, double init, DoubleTransitionFunction fn)
+ {
+ double result = init;
+ int start = choiceStarts[rowStarts[state] + choice];
+ int stop = choiceStarts[rowStarts[state] + choice + 1];
+ for (int col = start; col < stop; col++) {
+ result = fn.apply(result, state, cols[col], nonZeros[col]);
+ }
+ return result;
+ }
+
+ @Override
+ public int reduceTransitions(int state, int choice, int init, IntTransitionFunction fn)
+ {
+ int result = init;
+ int start = choiceStarts[rowStarts[state] + choice];
+ int stop = choiceStarts[rowStarts[state] + choice + 1];
+ for (int col = start; col < stop; col++) {
+ result = fn.apply(result, state, cols[col], nonZeros[col]);
+ }
+ return result;
+ }
+
+ @Override
+ public long reduceTransitions(int state, int choice, long init, LongTransitionFunction fn)
+ {
+ long result = init;
+ int start = choiceStarts[rowStarts[state] + choice];
+ int stop = choiceStarts[rowStarts[state] + choice + 1];
+ for (int col = start; col < stop; col++) {
+ result = fn.apply(result, state, cols[col], nonZeros[col]);
+ }
+ return result;
+ }
+
@Override
public int getNumTransitions(int s, int i)
{