1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
//! Satisfiability count and active node count.

use std::collections::hash_map::Entry::{Occupied, Vacant};

use num_bigint::BigUint;
use num_traits::{One, Zero};

use super::{hash_select::HashMap, DDManager};
use crate::bdd_node::{NodeID, VarID};

impl DDManager {
    #[allow(dead_code)]
    fn is_sat(&self, node: u32) -> bool {
        node != 0
    }

    pub fn sat_count(&self, f: NodeID) -> BigUint {
        self.sat_count_rec(f, &mut HashMap::default())
    }

    fn sat_count_rec(&self, f: NodeID, cache: &mut HashMap<NodeID, BigUint>) -> BigUint {
        let mut total: BigUint = Zero::zero();
        let node_id = f;

        if node_id == NodeID(0) {
            return Zero::zero();
        } else if node_id == NodeID(1) {
            return One::one();
        } else {
            let node = &self.nodes.get(&node_id).unwrap();

            let low = &self.nodes.get(&node.low).unwrap();
            let high = &self.nodes.get(&node.high).unwrap();

            let low_jump = if low.var == VarID(0) {
                self.order.len() as u32 - self.order[node.var.0 as usize] - 1
            } else {
                self.order[low.var.0 as usize] - self.order[node.var.0 as usize] - 1
            };

            let high_jump = if high.var == VarID(0) {
                self.order.len() as u32 - self.order[node.var.0 as usize] - 1
            } else {
                self.order[high.var.0 as usize] - self.order[node.var.0 as usize] - 1
            };

            let low_fac = BigUint::parse_bytes(b"2", 10).unwrap().pow(low_jump);
            let high_fac = BigUint::parse_bytes(b"2", 10).unwrap().pow(high_jump);

            total += match cache.get(&node.low) {
                Some(x) => x * low_fac,
                None => self.sat_count_rec(node.low, cache) * low_fac,
            };

            total += match cache.get(&node.high) {
                Some(x) => x * high_fac,
                None => self.sat_count_rec(node.high, cache) * high_fac,
            };
        };

        cache.insert(f, total.clone());

        total
    }

    pub fn count_active(&self, f: NodeID) -> u32 {
        // We use HashMap<NodeID, ()> instead of HashSet<NodeID> to be able to use the .entry()
        // API below. This turns out to be faster, since it avoids the double lookup if the
        // ID is not yet known (!contains -> insert).
        let mut nodes = HashMap::<NodeID, ()>::default();
        nodes.reserve(self.nodes.len());

        let mut stack = vec![f];
        stack.reserve(self.nodes.len());

        while !stack.is_empty() {
            let x = stack.pop().unwrap();
            let entry = nodes.entry(x);

            match entry {
                Occupied(_) => continue, // Node already counted
                Vacant(vacant_entry) => {
                    // Store node, add children to stack
                    let node = self.nodes.get(&x).unwrap();
                    stack.push(node.low);
                    stack.push(node.high);
                    vacant_entry.insert(());
                }
            }
        }

        nodes.len() as u32
    }
}