1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
use std::pin::Pin;

use smile::{
    autocxx::{c_int, WithinBox},
    ffi::DSL_pattern,
};

use crate::network::Network;

use super::{EdgeType, PatternError, PatternNode, Result};

/// Order two node indices so that the smaller one comes first.
///
/// Returns `(min, max)` of the two inputs; equal inputs are returned as-is.
fn sort_pair(a: usize, b: usize) -> (usize, usize) {
    (a.min(b), a.max(b))
}

/// A pattern, otherwise called an essential graph, is a collection of nodes and edges, where
/// each edge can be either directed or undirected.
///
/// Following methods have been implemented:
/// - [x] `enum EdgeType {None,Undirected,Directed}`
/// - [x] `int GetSize() const`
/// - [x] `EdgeType GetEdge(int from, int to) const`
/// - [x] `void SetEdge(int from, int to, EdgeType type)`
/// - [x] `bool HasDirectedPath(int from, int to) const`
/// - [x] `bool HasCycle() const`
/// - [x] `bool IsDAG() const`
/// - [x] `bool ToDAG()`
/// - [ ] `void Set(DSL_network &input)`
/// - [ ] `bool ToNetwork(const DSL_dataset &ds, DSL_network &net)`
/// - [ ] `void Print() const`
///
/// Differences from the original API:
/// - Method `void SetSize(int size)` has not been implemented, as the default way to construct
///   a [Pattern] is to call [Pattern::with_capacity]
/// - Following methods have been implemented on [super::PatternNode]:
///   - `void GetAdjacentNodes(const int node, std::vector<int>& adj) const`
///   - `void GetParents(const int node, std::vector<int>& par) const`
///   - `void GetChildren(const int node, std::vector<int>& child) const`
///   - `bool HasIncomingEdge(int to) const`
///   - `bool HasOutgoingEdge(int from) const`
pub struct Pattern {
    /// Pinned, heap-allocated SMILE `DSL_pattern` that all edge queries and updates go through.
    pub(super) dsl_pattern: Pin<Box<DSL_pattern>>,
    /// Number of nodes requested at construction; used to bounds-check node indices.
    node_count: usize,
}

impl Pattern {
    /// Create a new pattern with no nodes or edges
    ///
    /// Example usage:
    /// ```rust
    /// # use nice_smile::pattern::Pattern;
    /// // Create a pattern with 5 nodes
    /// let mut pattern = Pattern::with_capacity(5);
    /// # assert_eq!(pattern.get_node_count(), 5);
    /// ```
    pub fn with_capacity(node_count: usize) -> Self {
        let mut dsl_pattern = DSL_pattern::new().within_box();
        dsl_pattern.as_mut().SetSize(c_int(node_count as i32));
        Self {
            dsl_pattern,
            node_count,
        }
    }

    /// Get the number of nodes in the pattern
    ///
    /// Example usage:
    /// ```rust
    /// # use nice_smile::pattern::Pattern;
    /// let mut pattern = Pattern::with_capacity(5);
    /// assert_eq!(pattern.get_node_count(), 5);
    /// ```
    pub fn get_node_count(&self) -> usize {
        self.node_count
    }

    fn assert_node_index(&self, index: usize) -> Result<()> {
        if index >= self.get_node_count() {
            return Err(PatternError::NodeIndexOutOfBounds(index));
        }

        Ok(())
    }

    /// Get [EdgeType] between two nodes
    pub fn get_edge(&self, from: usize, to: usize) -> Result<EdgeType> {
        self.assert_node_index(from)?;
        self.assert_node_index(to)?;

        let (from, to) = sort_pair(from, to);

        let edge_type = self
            .dsl_pattern
            .as_ref()
            .GetEdge(c_int(from as i32), c_int(to as i32));

        Ok(EdgeType::from_dsl(edge_type, from, to))
    }

    /// Check if there is a directed path from one node to another.
    ///
    /// Example usage:
    /// ```rust
    /// # use nice_smile::pattern::{Pattern, EdgeType};
    /// let mut pattern = Pattern::with_capacity(3);
    /// pattern.remove_edge(0, 2); // Remove edge A --- C
    /// pattern.direct_edge(0, 1); // Direct edge A --> B
    /// pattern.direct_edge(1, 2); // Direct edge B --> C
    ///
    /// // Assert directed path A --> B --> C
    /// assert_eq!(pattern.has_directed_path(0, 2), Ok(true));
    /// ```
    pub fn has_directed_path(&self, from: usize, to: usize) -> Result<bool> {
        self.assert_node_index(from)?;
        self.assert_node_index(to)?;

        Ok(self
            .dsl_pattern
            .as_ref()
            .HasDirectedPath(c_int(from as i32), c_int(to as i32)))
    }

    /// Set an edge between two nodes to a specific [EdgeType].
    pub fn set_edge(&mut self, from: usize, to: usize, edge_type: EdgeType) -> Result<()> {
        // NOTE: We can skip assertions here, as all of following calls do the asserts for us
        match edge_type {
            EdgeType::Absent => self.remove_edge(from, to),
            EdgeType::Undirected => self.undirect_edge(from, to),
            EdgeType::Directed(from, to) => self.direct_edge(from, to),
        }
    }

    fn set_edge_unchecked(&mut self, from: usize, to: usize, edge_type: EdgeType) {
        self.dsl_pattern
            .as_mut()
            .SetEdge(c_int(from as i32), c_int(to as i32), edge_type.into());
    }

    /// Remove an edge between two nodes
    pub fn remove_edge(&mut self, from: usize, to: usize) -> Result<()> {
        self.assert_node_index(from)?;
        self.assert_node_index(to)?;

        let (from, to) = sort_pair(from, to);
        self.set_edge_unchecked(from, to, EdgeType::Absent);

        Ok(())
    }

    /// Direct an edge between two nodes
    pub fn direct_edge(&mut self, from: usize, to: usize) -> Result<()> {
        self.assert_node_index(from)?;
        self.assert_node_index(to)?;

        // NOTE: We do not shadow the variables here, as we need to call the function with the correct order
        let (a, b) = sort_pair(from, to);
        self.set_edge_unchecked(a, b, EdgeType::Directed(from, to));

        Ok(())
    }

    /// Undirect an edge between two nodes
    pub fn undirect_edge(&mut self, from: usize, to: usize) -> Result<()> {
        self.assert_node_index(from)?;
        self.assert_node_index(to)?;

        let (from, to) = sort_pair(from, to);
        self.set_edge_unchecked(from, to, EdgeType::Undirected);

        Ok(())
    }

    /// Try to convert the pattern to a Directed Acyclic Graph
    pub fn to_dag(&mut self) -> Result<()> {
        match self.dsl_pattern.as_mut().ToDAG() {
            true => Ok(()),
            false => Err(PatternError::ContainsCycle),
        }
    }

    /// Check if the pattern is a Directed Acyclic Graph
    pub fn is_dag(&self) -> bool {
        self.dsl_pattern.as_ref().IsDAG()
    }

    /// Check if the pattern has a cycle
    pub fn has_cycle(&self) -> bool {
        self.dsl_pattern.as_ref().HasCycle()
    }

    /// Get a [PatternNode] for a specific node index
    pub fn get_node(&self, index: usize) -> Result<PatternNode> {
        self.assert_node_index(index)?;

        Ok(PatternNode {
            index,
            pattern: self,
        })
    }

    /// Convert the pattern to a network
    pub fn into_network(self) -> Result<Network> {
        // TODO: Implement this
        // let network = Network::load(dataset);
        // self.dsl_pattern.as_ref().ToNetwork(dataset, network);
        Ok(Network::new_uninitialized())
    }
}