diff --git a/prism/src/explicit/DTMC.java b/prism/src/explicit/DTMC.java index b936a9f842..ed6b593f60 100644 --- a/prism/src/explicit/DTMC.java +++ b/prism/src/explicit/DTMC.java @@ -67,7 +67,7 @@ public interface TransitionConsumer { } /** - * Iterate over the outgoing transitions of state {@code s} and call the accept method + * Iterate over the outgoing transitions of state {@code state} and call the accept method * of the consumer for each of them: *
* Call {@code accept(s,t,d)} where t is the successor state and, @@ -80,15 +80,136 @@ public interface TransitionConsumer { * computation methods (mvMult, etc). In derived classes, it may thus be worthwhile to * provide a specialised implementation for this method that avoids using the Iterator mechanism. * - * @param s the state s + * @param state the state * @param c the consumer */ - public default void forEachTransition(int s, TransitionConsumer c) + public default void forEachTransition(int state, TransitionConsumer c) { - for (Iterator> it = getTransitionsIterator(s); it.hasNext(); ) { + reduceTransitions(state, null, (r, s, t, d) -> {c.accept(s,t,d); return r;}); + } + + /** + * Quaternary function that maps an Object of type {@code T}, + * and a transition from state {@code s} to state {@code t} + * with probabiliy/rate {@code d} + * to an Object of type {@code T}. + * + * @param type of first argument and result + */ + @FunctionalInterface + public static interface ObjTransitionFunction + { + T apply(T result, int s, int t, double d); + } + + /** + * Iterate over the outgoing transitions of state {@code state} + * and apply the reducing function {@code fn} + * to the intermediate result and the transition: + *
+ * Call {@code apply(r,s,t,d)} where + * {@code r} is the intermediate result, + * {@code t} is the successor state and, + * in a DTMC, {@code d} = P(s,t) is the probability from {@code s} to {@code t}, + * while in CTMC, {@code d} = R(s,t) is the rate from {@code s} to {@code t}. + * The return value of apply is the intermediate result for the next transition. + *

+ * Default implementation: The default implementation relies on iterating over the + * iterator returned by {@code getTransitionsIterator()}. + *

Note: This method is the base for the default implementation of the numerical + * computation methods (mvMult, etc). In derived classes, it may thus be worthwhile to + * provide a specialised implementation for this method that avoids using the Iterator mechanism. + * + * @param state the state + * @param init initial result value + * @param fn the reducing function + */ + public default T reduceTransitions(int state, T init, ObjTransitionFunction fn) + { + T result = init; + for (Iterator> it = getTransitionsIterator(state); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; + } + + /** + * Primitive specialisation of {@code ObjTransitionFunction} for {@code double} results. + * + * @see ObjTransitionFunction + */ + @FunctionalInterface + public static interface DoubleTransitionFunction + { + double apply(double result, int s, int t, double d); + } + + /** + * Primitive specialisation of {@code reduce} for {@code double} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default double reduceTransitions(int state, double init, DoubleTransitionFunction fn) + { + double result = init; + for (Iterator> it = getTransitionsIterator(state); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; + } + + /** + * Primitive specialisation of {@code ObjTransitionFunction} for {@code int} values. + * + * @see ObjTransitionFunction + */ + @FunctionalInterface + public interface IntTransitionFunction + { + int apply(int result, int s, int t, double d); + } + + /** + * Primitive specialisation of {@code reduce} for {@code int} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default int reduceTransitions(int state, int init, IntTransitionFunction fn) + { + int result = init; + for (Iterator> it = getTransitionsIterator(state); it.hasNext(); ) { Entry e = it.next(); - c.accept(s, e.getKey(), e.getValue()); + result = fn.apply(result, state, e.getKey(), e.getValue()); } + return result; + } + + /** + * Primitive specialisation of {@code ObjTransitionFunction} for {@code long} values. + * + * @see ObjTransitionFunction + */ + @FunctionalInterface + public interface LongTransitionFunction + { + long apply(long result, int s, int t, double d); + } + + /** + * Primitive specialisation of {@code reduce} for {@code long} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default long reduceTransitions(int state, long init, LongTransitionFunction fn) + { + long result = init; + for (Iterator> it = getTransitionsIterator(state); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; } /** @@ -103,29 +224,17 @@ public interface TransitionToDoubleFunction { } /** - * Iterate over the outgoing transitions of state {@code s}, call the function {@code f} + * Iterate over the outgoing transitions of state {@code state}, call the function {@code f} * and return the sum of the result values: *
- * Return sum_t f(s, t, P(s,t)), where t ranges over the successors of s. + * Return sum_t f(state, t, P(s,t)), where t ranges over the successors of state. * - * @param s the state s - * @param c the consumer + * @param state the state + * @param f the function */ - public default double sumOverTransitions(final int s, final TransitionToDoubleFunction f) + public default double sumOverTransitions(int state, TransitionToDoubleFunction f) { - class Sum { - double sum = 0.0; - - void accept(int s, int t, double d) - { - sum += f.apply(s, t, d); - } - } - - Sum sum = new Sum(); - forEachTransition(s, sum::accept); - - return sum.sum; + return reduceTransitions(state, 0.0, (r, s, t, d) -> r + f.apply(s, t, d)); } /** diff --git a/prism/src/explicit/DTMCEmbeddedSimple.java b/prism/src/explicit/DTMCEmbeddedSimple.java index 717d6e9234..cbc2a0709c 100644 --- a/prism/src/explicit/DTMCEmbeddedSimple.java +++ b/prism/src/explicit/DTMCEmbeddedSimple.java @@ -257,17 +257,63 @@ public Double setValue(Double value) } @Override - public void forEachTransition(int s, TransitionConsumer c) + public T reduceTransitions(int state, T init, ObjTransitionFunction fn) { - final double er = exitRates[s]; + double er = exitRates[state]; + T result = init; if (er == 0) { // exit rate = 0 -> prob 1 self loop - c.accept(s, s, 1.0); + fn.apply(result, state, state, 1.0); } else { - ctmc.forEachTransition(s, (s_,t,rate) -> { - c.accept(s_, t, rate / er); - }); + ctmc.reduceTransitions(state, result, (r, s, t, rate) -> + fn.apply(r, s, t, rate / er)); } + return result; + } + + @Override + public double reduceTransitions(int state, double init, DoubleTransitionFunction fn) + { + double er = exitRates[state]; + double result = init; + if (er == 0) { + // exit rate = 0 -> prob 1 self loop + fn.apply(result, state, state, 1.0); + } else { + ctmc.reduceTransitions(state, result, (r, s, t, rate) -> + fn.apply(r, s, t, rate / er)); + } + return result; + } + + @Override + public int reduceTransitions(int state, int init, IntTransitionFunction fn) + { + double er = exitRates[state]; + int result = init; + if (er == 0) { + // exit rate = 0 -> prob 1 self loop + fn.apply(result, state, state, 1.0); + } else { + ctmc.reduceTransitions(state, result, (int r, int s, int t, double rate) -> + fn.apply(r, s, t, rate / er)); + } + return result; + } + + @Override + public long reduceTransitions(int state, long init, LongTransitionFunction fn) + { + double er = exitRates[state]; + long result = init; + if (er == 0) { + // exit rate = 0 -> prob 1 self loop + fn.apply(result, state, state, 1.0); + } else { + ctmc.reduceTransitions(state, result, (long r, int s, int t, double rate) -> + fn.apply(r, s, t, rate / er)); + } + return result; } public double mvMultSingle(int s, double vect[]) diff --git a/prism/src/explicit/DTMCFromMDPAndMDStrategy.java b/prism/src/explicit/DTMCFromMDPAndMDStrategy.java index fe9ca18ec8..8ee7a29388 100644 --- a/prism/src/explicit/DTMCFromMDPAndMDStrategy.java +++ b/prism/src/explicit/DTMCFromMDPAndMDStrategy.java @@ -182,14 +182,41 @@ public Iterator> getTransitionsIterator(int s) } @Override - public void forEachTransition(int s, TransitionConsumer c) + public T reduceTransitions(int state, T init, ObjTransitionFunction fn) { - if (!strat.isChoiceDefined(s)) { - return; + if (!strat.isChoiceDefined(state)) { + return init; } - mdp.forEachTransition(s, strat.getChoiceIndex(s), c::accept); + return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn); } + @Override + public double reduceTransitions(int state, double init, DoubleTransitionFunction fn) + { + if (!strat.isChoiceDefined(state)) { + return init; + } + return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn); + } + + @Override + public int reduceTransitions(int state, int init, IntTransitionFunction fn) + { + if (!strat.isChoiceDefined(state)) { + return init; + } + return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn); + } + + @Override + public long reduceTransitions(int state, long init, LongTransitionFunction fn) + { + if (!strat.isChoiceDefined(state)) { + return init; + } + return mdp.reduceTransitions(state, strat.getChoiceIndex(state), init, fn); +} + @Override public double mvMultSingle(int s, double vect[]) { diff --git a/prism/src/explicit/DTMCModelChecker.java b/prism/src/explicit/DTMCModelChecker.java index ef20196365..b97f3a802b 100644 --- a/prism/src/explicit/DTMCModelChecker.java +++ b/prism/src/explicit/DTMCModelChecker.java @@ -2491,11 +2491,7 @@ public ModelCheckerResult computeSteadyStateProbsForBSCC(DTMC dtmc, BitSet state // Note: diagsQ[state] = 0.0, as it was freshly created // Compute negative exit rate (ignoring a possible self-loop) - dtmc.forEachTransition(state, (s, t, prob) -> { - if (s != t) { - diagsQ[state] -= prob; - } - }); + diagsQ[state] -= dtmc.reduceTransitions(state, 0.0, (r, s, t, p) -> (s == t) ? r : r + p); // Note: If there are no outgoing transitions, diagsQ[state] = 0, which is fine diff --git a/prism/src/explicit/DTMCSparse.java b/prism/src/explicit/DTMCSparse.java index a0f33456e3..5b9f656f72 100644 --- a/prism/src/explicit/DTMCSparse.java +++ b/prism/src/explicit/DTMCSparse.java @@ -228,11 +228,43 @@ public void buildFromPrismExplicit(String filename) throws PrismException //--- DTMC --- @Override - public void forEachTransition(int state, TransitionConsumer consumer) + public T reduceTransitions(int state, T init, ObjTransitionFunction fn) { - for (int col = rows[state], stop = rows[state+1]; col < stop; col++) { - consumer.accept(state, columns[col], probabilities[col]); + T result = init; + for (int col = rows[state], stop = rows[state+1]; col < stop; col++){ + result = fn.apply(result, state, columns[col], probabilities[col]); } + return result; + } + + @Override + public double reduceTransitions(int state, double init, DoubleTransitionFunction fn) + { + double result = init; + for (int col = rows[state], stop = rows[state+1]; col < stop; col++){ + result = fn.apply(result, state, columns[col], probabilities[col]); + } + return result; + } + + @Override + public int reduceTransitions(int state, int init, IntTransitionFunction fn) + { + int result = init; + for (int col = rows[state], stop = rows[state+1]; col < stop; col++){ + result = fn.apply(result, state, columns[col], probabilities[col]); + } + return result; + } + + @Override + public long reduceTransitions(int state, long init, LongTransitionFunction fn) + { + long result = init; + for (int col = rows[state], stop = rows[state+1]; col < stop; col++){ + result = fn.apply(result, state, columns[col], probabilities[col]); + } + return result; } @Override diff --git a/prism/src/explicit/MDP.java b/prism/src/explicit/MDP.java index 613d0a2040..f456f20519 100644 --- a/prism/src/explicit/MDP.java +++ b/prism/src/explicit/MDP.java @@ -35,6 +35,11 @@ import java.util.PrimitiveIterator.OfInt; import common.IterableStateSet; +import explicit.DTMC.DoubleTransitionFunction; +import explicit.DTMC.IntTransitionFunction; +import explicit.DTMC.LongTransitionFunction; +import explicit.DTMC.ObjTransitionFunction; +import explicit.DTMC.TransitionConsumer; import explicit.rewards.MCRewards; import explicit.rewards.MDPRewards; import prism.PrismUtils; @@ -54,17 +59,7 @@ public interface MDP extends MDPGeneric public Iterator> getTransitionsIterator(int s, int i); /** - * Functional interface for a consumer, - * accepting transitions (s,t,d), i.e., - * from state s to state t with value d. - */ - @FunctionalInterface - public interface TransitionConsumer { - void accept(int s, int t, double d); - } - - /** - * Iterate over the outgoing transitions of state {@code s} and choice {@code i} + * Iterate over the outgoing transitions of state {@code state} and choice {@code i} * and call the accept method of the consumer for each of them: *
* Call {@code accept(s,t,d)} where t is the successor state d = P(s,i,t) @@ -80,12 +75,86 @@ public interface TransitionConsumer { * @param i the choice i * @param c the consumer */ - public default void forEachTransition(int s, int i, TransitionConsumer c) + public default void forEachTransition(int state, int choice, TransitionConsumer c) + { + reduceTransitions(state, choice, null, (r, s, t, d) -> {c.accept(s,t,d); return r;}); + } + + /** + * Iterate over the outgoing transitions of state {@code state} and choice {@code c} + * and apply the reducing function {@code fn} + * to the intermediate result and the transition: + *
+ * Call {@code apply(r,s,t,d)} where + * {@code r} is the intermediate result, + * {@code t} is the successor state and, + * {@code d} = P(s,c,t) is the probability from {@code s} to {@code t} with choice {@code c}, + * The return value of apply is the intermediate result for the next transition. + *

+ * Default implementation: The default implementation relies on iterating over the + * iterator returned by {@code getTransitionsIterator()}. + *

Note: This method is the base for the default implementation of the numerical + * computation methods (mvMult, etc). In derived classes, it may thus be worthwhile to + * provide a specialised implementation for this method that avoids using the Iterator mechanism. + * + * @param state the state + * @param choice the choice + * @param init initial result value + * @param fn the reducing function + */ + public default T reduceTransitions(int state, int choice, T init, ObjTransitionFunction fn) + { + T result = init; + for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; + } + + /** + * Primitive specialisation of {@code reduce} for {@code double} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default double reduceTransitions(int state, int choice, double init, DoubleTransitionFunction fn) { - for (Iterator> it = getTransitionsIterator(s, i); it.hasNext(); ) { + double result = init; + for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) { Entry e = it.next(); - c.accept(s, e.getKey(), e.getValue()); + result = fn.apply(result, state, e.getKey(), e.getValue()); } + return result; + } + + /** + * Primitive specialisation of {@code reduce} for {@code int} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default int reduceTransitions(int state, int choice, int init, IntTransitionFunction fn) + { + int result = init; + for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; + } + + /** + * Primitive specialisation of {@code reduce} for {@code long} values. + * + * @see #reduceTransitions(int, Object, ObjTransitionFunction) + */ + public default long reduceTransitions(int state, int choice, long init, LongTransitionFunction fn) + { + long result = init; + for (Iterator> it = getTransitionsIterator(state, choice); it.hasNext(); ) { + Entry e = it.next(); + result = fn.apply(result, state, e.getKey(), e.getValue()); + } + return result; } /** @@ -105,24 +174,12 @@ public interface TransitionToDoubleFunction { *
* Return sum_t f(s, t, P(s,i,t)), where t ranges over the i-successors of s. * - * @param s the state s - * @param c the consumer + * @param state the state s + * @param choice the consumer */ - public default double sumOverTransitions(final int s, final int i, final TransitionToDoubleFunction f) + public default double sumOverTransitions(int state, int choice, TransitionToDoubleFunction f) { - class Sum { - double sum = 0.0; - - void accept(int s, int t, double d) - { - sum += f.apply(s, t, d); - } - } - - Sum sum = new Sum(); - forEachTransition(s, i, sum::accept); - - return sum.sum; + return reduceTransitions(state, choice, 0.0, (r, s, t, d) -> r + f.apply(s, t, d)); } /** diff --git a/prism/src/explicit/MDPSparse.java b/prism/src/explicit/MDPSparse.java index 4180d0dfc8..e6b2aba5a7 100644 --- a/prism/src/explicit/MDPSparse.java +++ b/prism/src/explicit/MDPSparse.java @@ -41,6 +41,10 @@ import java.util.TreeMap; import common.IterableStateSet; +import explicit.DTMC.DoubleTransitionFunction; +import explicit.DTMC.IntTransitionFunction; +import explicit.DTMC.LongTransitionFunction; +import explicit.DTMC.ObjTransitionFunction; import explicit.rewards.MCRewards; import explicit.rewards.MDPRewards; import parser.State; @@ -574,6 +578,54 @@ public SuccessorsIterator getSuccessors(final int s, final int i) // Accessors (for MDP) + @Override + public T reduceTransitions(int state, int choice, T init, ObjTransitionFunction fn) + { + T result = init; + int start = choiceStarts[rowStarts[state] + choice]; + int stop = choiceStarts[rowStarts[state] + choice + 1]; + for (int col = start; col < stop; col++) { + result = fn.apply(result, state, cols[col], nonZeros[col]); + } + return result; + } + + @Override + public double reduceTransitions(int state, int choice, double init, DoubleTransitionFunction fn) + { + double result = init; + int start = choiceStarts[rowStarts[state] + choice]; + int stop = choiceStarts[rowStarts[state] + choice + 1]; + for (int col = start; col < stop; col++) { + result = fn.apply(result, state, cols[col], nonZeros[col]); + } + return result; + } + + @Override + public int reduceTransitions(int state, int choice, int init, IntTransitionFunction fn) + { + int result = init; + int start = choiceStarts[rowStarts[state] + choice]; + int stop = choiceStarts[rowStarts[state] + choice + 1]; + for (int col = start; col < stop; col++) { + result = fn.apply(result, state, cols[col], nonZeros[col]); + } + return result; + } + + @Override + public long reduceTransitions(int state, int choice, long init, LongTransitionFunction fn) + { + long result = init; + int start = choiceStarts[rowStarts[state] + choice]; + int stop = choiceStarts[rowStarts[state] + choice + 1]; + for (int col = start; col < stop; col++) { + result = fn.apply(result, state, cols[col], nonZeros[col]); + } + return result; + } + @Override public int getNumTransitions(int s, int i) {