/*
 * Decompiled with CFR 0.152.
 */
package com.bmw.hmm;

import com.bmw.hmm.SequenceState;
import com.bmw.hmm.Transition;
import com.bmw.hmm.Utils;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Incremental Viterbi algorithm: computes the most likely sequence of hidden states for a
 * sequence of observations, processing one time step at a time.
 *
 * <p>Usage: call {@link #startWithInitialStateProbabilities} or
 * {@link #startWithInitialObservation} exactly once, then {@link #nextStep} for each further
 * observation, and finally {@link #computeMostLikelySequence()}. All probabilities passed in and
 * stored are log probabilities.
 *
 * <p>If after some step every candidate has log probability {@code -Infinity}, the HMM is
 * "broken": {@link #isBroken()} returns {@code true}, that step is discarded (internal state
 * still reflects the last accepted step), and {@link #nextStep} must not be called again.
 * {@link #computeMostLikelySequence()} still returns the sequence up to the last accepted step.
 *
 * @param <S> candidate (hidden state) type
 * @param <O> observation type
 * @param <D> transition descriptor type
 */
public class ViterbiAlgorithm<S, O, D> {
    /** Back-pointer chains for each candidate of the last accepted time step. */
    private Map<S, ExtendedState<S, O, D>> lastExtendedStates;
    /** Candidates of the last accepted time step. */
    private Collection<S> prevCandidates;
    /**
     * For each candidate of the last accepted step, the log probability of the most likely
     * sequence ending in it. {@code null} until initialization succeeds.
     */
    private Map<S, Double> message;
    private boolean isBroken = false;
    /** Messages of all accepted time steps, or {@code null} if history is not recorded. */
    private List<Map<S, Double>> messageHistory;

    /** Creates an instance that does not record the message history. */
    public ViterbiAlgorithm() {
        this(false);
    }

    /**
     * Creates an instance.
     *
     * @param keepMessageHistory whether to record the message of every accepted time step
     *     (retrievable via {@link #messageHistory()} and {@link #messageHistoryString()})
     */
    public ViterbiAlgorithm(boolean keepMessageHistory) {
        if (keepMessageHistory) {
            this.messageHistory = new ArrayList<Map<S, Double>>();
        }
    }

    /**
     * Initializes the algorithm with initial state log probabilities and no observation.
     *
     * @param initialStates candidates of the first time step
     * @param initialLogProbabilities initial log probability for each candidate
     * @throws IllegalStateException if initial probabilities have already been set
     * @throws NullPointerException if a candidate has no entry in
     *     {@code initialLogProbabilities}
     */
    public void startWithInitialStateProbabilities(Collection<S> initialStates, Map<S, Double> initialLogProbabilities) {
        this.initializeStateProbabilities(null, initialStates, initialLogProbabilities);
    }

    /**
     * Initializes the algorithm with the first observation and its emission log probabilities.
     *
     * @param observation the first observation
     * @param candidates candidates for the first observation
     * @param emissionLogProbabilities emission log probability for each candidate
     * @throws IllegalStateException if initial probabilities have already been set
     * @throws NullPointerException if a candidate has no entry in
     *     {@code emissionLogProbabilities}
     */
    public void startWithInitialObservation(O observation, Collection<S> candidates, Map<S, Double> emissionLogProbabilities) {
        this.initializeStateProbabilities(observation, candidates, emissionLogProbabilities);
    }

    /**
     * Processes the next time step.
     *
     * <p>If the step breaks the HMM (all resulting log probabilities are {@code -Infinity}), the
     * step is discarded and {@link #isBroken()} becomes {@code true}.
     *
     * @param observation the observation of this time step
     * @param candidates candidates for this observation
     * @param emissionLogProbabilities emission log probability for each candidate
     * @param transitionLogProbabilities log probabilities of transitions from the previous
     *     step's candidates to this step's candidates; missing transitions count as
     *     {@code -Infinity}
     * @param transitionDescriptors optional descriptors attached to transitions; missing entries
     *     yield {@code null} descriptors
     * @throws IllegalStateException if the HMM is already broken, or if neither start method has
     *     been called successfully
     * @throws NullPointerException if a candidate has no entry in
     *     {@code emissionLogProbabilities}
     */
    public void nextStep(O observation, Collection<S> candidates, Map<S, Double> emissionLogProbabilities, Map<Transition<S>, Double> transitionLogProbabilities, Map<Transition<S>, D> transitionDescriptors) {
        // Check for a break first: a break during initialization leaves message == null, and
        // reporting "must be called first" in that case would be misleading.
        if (this.isBroken) {
            throw new IllegalStateException("Method must not be called after an HMM break.");
        }
        if (this.message == null) {
            throw new IllegalStateException("startWithInitialStateProbabilities() or startWithInitialObservation() must be called first.");
        }
        ForwardStepResult<S, O, D> forwardStepResult = this.forwardStep(observation, this.prevCandidates, candidates, this.message, emissionLogProbabilities, transitionLogProbabilities, transitionDescriptors);
        this.isBroken = this.hmmBreak(forwardStepResult.newMessage);
        if (this.isBroken) {
            // Discard this step; state keeps describing the last accepted step.
            return;
        }
        if (this.messageHistory != null) {
            this.messageHistory.add(forwardStepResult.newMessage);
        }
        this.message = forwardStepResult.newMessage;
        this.lastExtendedStates = forwardStepResult.newExtendedStates;
        this.prevCandidates = new ArrayList<S>(candidates);
    }

    /**
     * Processes the next time step without transition descriptors.
     *
     * @see #nextStep(Object, Collection, Map, Map, Map)
     */
    public void nextStep(O observation, Collection<S> candidates, Map<S, Double> emissionLogProbabilities, Map<Transition<S>, Double> transitionLogProbabilities) {
        this.nextStep(observation, candidates, emissionLogProbabilities, transitionLogProbabilities, new LinkedHashMap<Transition<S>, D>());
    }

    /**
     * Returns the most likely sequence of states for all accepted time steps.
     *
     * @return the most likely sequence, in time order; an empty list if the algorithm was never
     *     successfully initialized
     */
    public List<SequenceState<S, O, D>> computeMostLikelySequence() {
        if (this.message == null) {
            return new ArrayList<SequenceState<S, O, D>>();
        }
        return this.retrieveMostLikelySequence();
    }

    /** Returns whether the HMM broke during initialization or a later step. */
    public boolean isBroken() {
        return this.isBroken;
    }

    /**
     * Returns the recorded messages of all accepted time steps.
     *
     * @return the internal history list (not a copy), or {@code null} if history recording was
     *     not enabled in the constructor
     */
    public List<Map<S, Double>> messageHistory() {
        return this.messageHistory;
    }

    /**
     * Formats the recorded message history as human-readable text.
     *
     * @throws IllegalStateException if history recording was not enabled
     */
    public String messageHistoryString() {
        if (this.messageHistory == null) {
            throw new IllegalStateException("Message history was not recorded.");
        }
        StringBuilder sb = new StringBuilder();
        sb.append("Message history with log probabilies\n\n");
        int i = 0;
        for (Map<S, Double> message : this.messageHistory) {
            sb.append("Time step " + i + "\n");
            ++i;
            for (Map.Entry<S, Double> entry : message.entrySet()) {
                sb.append(entry.getKey() + ": " + entry.getValue() + "\n");
            }
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Returns {@code true} if no state has a log probability other than {@code -Infinity}
     * (including the case of an empty message).
     */
    private boolean hmmBreak(Map<S, Double> message) {
        for (double logProbability : message.values()) {
            if (logProbability != Double.NEGATIVE_INFINITY) {
                return false;
            }
        }
        return true;
    }

    /**
     * Shared initialization for both start methods.
     *
     * <p>If the resulting message breaks the HMM, {@code isBroken} is set and {@code message}
     * stays {@code null}.
     */
    private void initializeStateProbabilities(O observation, Collection<S> candidates, Map<S, Double> initialLogProbabilities) {
        if (this.message != null) {
            throw new IllegalStateException("Initial probabilities have already been set.");
        }
        LinkedHashMap<S, Double> initialMessage = new LinkedHashMap<S, Double>();
        for (S candidate : candidates) {
            Double logProbability = initialLogProbabilities.get(candidate);
            if (logProbability == null) {
                throw new NullPointerException("No initial probability for " + candidate);
            }
            initialMessage.put(candidate, logProbability);
        }
        this.isBroken = this.hmmBreak(initialMessage);
        if (this.isBroken) {
            return;
        }
        this.message = initialMessage;
        if (this.messageHistory != null) {
            this.messageHistory.add(this.message);
        }
        this.lastExtendedStates = new LinkedHashMap<S, ExtendedState<S, O, D>>();
        for (S candidate : candidates) {
            // First step: no back pointer and no transition descriptor.
            this.lastExtendedStates.put(candidate, new ExtendedState<S, O, D>(candidate, null, observation, null));
        }
        this.prevCandidates = new ArrayList<S>(candidates);
    }

    /**
     * Computes the new message and back pointers for the current candidates, given the message
     * of the previous step. Does not modify instance state.
     *
     * @throws NullPointerException if a current candidate has no emission log probability
     */
    private ForwardStepResult<S, O, D> forwardStep(O observation, Collection<S> prevCandidates, Collection<S> curCandidates, Map<S, Double> message, Map<S, Double> emissionLogProbabilities, Map<Transition<S>, Double> transitionLogProbabilities, Map<Transition<S>, D> transitionDescriptors) {
        ForwardStepResult<S, O, D> result = new ForwardStepResult<S, O, D>(curCandidates.size());
        assert (!prevCandidates.isEmpty());
        for (S curState : curCandidates) {
            double maxLogProbability = Double.NEGATIVE_INFINITY;
            S maxPrevState = null;
            for (S prevState : prevCandidates) {
                double logProbability = message.get(prevState) + this.transitionLogProbability(prevState, curState, transitionLogProbabilities);
                if (logProbability > maxLogProbability) {
                    maxLogProbability = logProbability;
                    maxPrevState = prevState;
                }
            }
            Double emissionLogProbability = emissionLogProbabilities.get(curState);
            if (emissionLogProbability == null) {
                // Same exception type as the implicit unboxing failure, but with context.
                throw new NullPointerException("No emission probability for " + curState);
            }
            result.newMessage.put(curState, maxLogProbability + emissionLogProbability);
            // maxPrevState is null when every incoming transition has log probability -Infinity;
            // such a state gets no back pointer.
            if (maxPrevState == null) {
                continue;
            }
            Transition<S> transition = new Transition<S>(maxPrevState, curState);
            ExtendedState<S, O, D> extendedState = new ExtendedState<S, O, D>(curState, this.lastExtendedStates.get(maxPrevState), observation, transitionDescriptors.get(transition));
            result.newExtendedStates.put(curState, extendedState);
        }
        return result;
    }

    /** Looks up a transition log probability; a missing transition counts as -Infinity. */
    private double transitionLogProbability(S prevState, S curState, Map<Transition<S>, Double> transitionLogProbabilities) {
        Double transitionLogProbability = transitionLogProbabilities.get(new Transition<S>(prevState, curState));
        if (transitionLogProbability == null) {
            return Double.NEGATIVE_INFINITY;
        }
        return transitionLogProbability;
    }

    /**
     * Returns the candidate with the highest log probability in the current message; on ties
     * the first one in iteration order wins.
     */
    private S mostLikelyState() {
        assert (!this.message.isEmpty());
        S result = null;
        double maxLogProbability = Double.NEGATIVE_INFINITY;
        for (Map.Entry<S, Double> entry : this.message.entrySet()) {
            if (entry.getValue() > maxLogProbability) {
                result = entry.getKey();
                maxLogProbability = entry.getValue();
            }
        }
        assert (result != null);
        return result;
    }

    /** Follows back pointers from the most likely last state and returns the path in time order. */
    private List<SequenceState<S, O, D>> retrieveMostLikelySequence() {
        assert (!this.message.isEmpty());
        S lastState = this.mostLikelyState();
        ArrayList<SequenceState<S, O, D>> result = new ArrayList<SequenceState<S, O, D>>();
        ExtendedState<S, O, D> es = this.lastExtendedStates.get(lastState);
        while (es != null) {
            result.add(new SequenceState<S, O, D>(es.state, es.observation, es.transitionDescriptor));
            es = es.backPointer;
        }
        // Back pointers run from the last step to the first.
        Collections.reverse(result);
        return result;
    }

    /** Output of a single forward step: the new message and the new back-pointer entries. */
    private static class ForwardStepResult<S, O, D> {
        final Map<S, Double> newMessage;
        final Map<S, ExtendedState<S, O, D>> newExtendedStates;

        ForwardStepResult(int numberStates) {
            this.newMessage = new LinkedHashMap<S, Double>(Utils.initialHashMapCapacity(numberStates));
            this.newExtendedStates = new LinkedHashMap<S, ExtendedState<S, O, D>>(Utils.initialHashMapCapacity(numberStates));
        }
    }

    /** A state at one time step plus a back pointer to its predecessor on the best path. */
    private static class ExtendedState<S, O, D> {
        final S state;
        /** Predecessor on the most likely path, or {@code null} for the first time step. */
        final ExtendedState<S, O, D> backPointer;
        final O observation;
        final D transitionDescriptor;

        ExtendedState(S state, ExtendedState<S, O, D> backPointer, O observation, D transitionDescriptor) {
            this.state = state;
            this.backPointer = backPointer;
            this.observation = observation;
            this.transitionDescriptor = transitionDescriptor;
        }
    }
}

