1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
use std::pin::Pin;
use smile::{
autocxx::{c_int, WithinBox},
ffi::DSL_pattern,
};
use crate::network::Network;
use super::{EdgeType, PatternError, PatternNode, Result};
/// Put an unordered pair of node indices into canonical `(low, high)` order.
///
/// Equal inputs are returned unchanged.
fn sort_pair(a: usize, b: usize) -> (usize, usize) {
    (a.min(b), a.max(b))
}
/// A pattern, otherwise called an essential graph, is a collection of nodes and edges,
/// where each edge can be either directed or undirected.
///
/// The following methods of SMILE's `DSL_pattern` have been implemented:
/// - [x] `enum EdgeType {None,Undirected,Directed}`
/// - [x] `int GetSize() const`
/// - [x] `EdgeType GetEdge(int from, int to) const`
/// - [x] `void SetEdge(int from, int to, EdgeType type)`
/// - [x] `bool HasDirectedPath(int from, int to) const`
/// - [x] `bool HasCycle() const`
/// - [x] `bool IsDAG() const`
/// - [x] `bool ToDAG()`
/// - [ ] `void Set(DSL_network &input)`
/// - [ ] `bool ToNetwork(const DSL_dataset &ds, DSL_network &net)`
/// - [ ] `void Print() const`
///
/// Differences from the original API:
/// - Method `void SetSize(int size)` has not been implemented, as the default way to construct
///   a [Pattern] is to call [Pattern::with_capacity]
/// - The following methods have been implemented on [super::PatternNode] instead:
///   - `void GetAdjacentNodes(const int node, std::vector<int>& adj) const`
///   - `void GetParents(const int node, std::vector<int>& par) const`
///   - `void GetChildren(const int node, std::vector<int>& child) const`
///   - `bool HasIncomingEdge(int to) const`
///   - `bool HasOutgoingEdge(int from) const`
pub struct Pattern {
    /// The wrapped SMILE `DSL_pattern` object, heap-allocated and pinned.
    pub(super) dsl_pattern: Pin<Box<DSL_pattern>>,
    /// Number of nodes the pattern was created with (see [Pattern::with_capacity]);
    /// node indices are checked against it before being handed to SMILE.
    node_count: usize,
}
impl Pattern {
/// Create a new pattern with no nodes or edges
///
/// Example usage:
/// ```rust
/// # use nice_smile::pattern::Pattern;
/// // Create a pattern with 5 nodes
/// let mut pattern = Pattern::with_capacity(5);
/// # assert_eq!(pattern.get_node_count(), 5);
/// ```
pub fn with_capacity(node_count: usize) -> Self {
let mut dsl_pattern = DSL_pattern::new().within_box();
dsl_pattern.as_mut().SetSize(c_int(node_count as i32));
Self {
dsl_pattern,
node_count,
}
}
/// Get the number of nodes in the pattern
///
/// Example usage:
/// ```rust
/// # use nice_smile::pattern::Pattern;
/// let mut pattern = Pattern::with_capacity(5);
/// assert_eq!(pattern.get_node_count(), 5);
/// ```
pub fn get_node_count(&self) -> usize {
self.node_count
}
fn assert_node_index(&self, index: usize) -> Result<()> {
if index >= self.get_node_count() {
return Err(PatternError::NodeIndexOutOfBounds(index));
}
Ok(())
}
/// Get [EdgeType] between two nodes
pub fn get_edge(&self, from: usize, to: usize) -> Result<EdgeType> {
self.assert_node_index(from)?;
self.assert_node_index(to)?;
let (from, to) = sort_pair(from, to);
let edge_type = self
.dsl_pattern
.as_ref()
.GetEdge(c_int(from as i32), c_int(to as i32));
Ok(EdgeType::from_dsl(edge_type, from, to))
}
/// Check if there is a directed path from one node to another.
///
/// Example usage:
/// ```rust
/// # use nice_smile::pattern::{Pattern, EdgeType};
/// let mut pattern = Pattern::with_capacity(3);
/// pattern.remove_edge(0, 2); // Remove edge A --- C
/// pattern.direct_edge(0, 1); // Direct edge A --> B
/// pattern.direct_edge(1, 2); // Direct edge B --> C
///
/// // Assert directed path A --> B --> C
/// assert_eq!(pattern.has_directed_path(0, 2), Ok(true));
/// ```
pub fn has_directed_path(&self, from: usize, to: usize) -> Result<bool> {
self.assert_node_index(from)?;
self.assert_node_index(to)?;
Ok(self
.dsl_pattern
.as_ref()
.HasDirectedPath(c_int(from as i32), c_int(to as i32)))
}
/// Set an edge between two nodes to a specific [EdgeType].
pub fn set_edge(&mut self, from: usize, to: usize, edge_type: EdgeType) -> Result<()> {
// NOTE: We can skip assertions here, as all of following calls do the asserts for us
match edge_type {
EdgeType::Absent => self.remove_edge(from, to),
EdgeType::Undirected => self.undirect_edge(from, to),
EdgeType::Directed(from, to) => self.direct_edge(from, to),
}
}
fn set_edge_unchecked(&mut self, from: usize, to: usize, edge_type: EdgeType) {
self.dsl_pattern
.as_mut()
.SetEdge(c_int(from as i32), c_int(to as i32), edge_type.into());
}
/// Remove an edge between two nodes
pub fn remove_edge(&mut self, from: usize, to: usize) -> Result<()> {
self.assert_node_index(from)?;
self.assert_node_index(to)?;
let (from, to) = sort_pair(from, to);
self.set_edge_unchecked(from, to, EdgeType::Absent);
Ok(())
}
/// Direct an edge between two nodes
pub fn direct_edge(&mut self, from: usize, to: usize) -> Result<()> {
self.assert_node_index(from)?;
self.assert_node_index(to)?;
// NOTE: We do not shadow the variables here, as we need to call the function with the correct order
let (a, b) = sort_pair(from, to);
self.set_edge_unchecked(a, b, EdgeType::Directed(from, to));
Ok(())
}
/// Undirect an edge between two nodes
pub fn undirect_edge(&mut self, from: usize, to: usize) -> Result<()> {
self.assert_node_index(from)?;
self.assert_node_index(to)?;
let (from, to) = sort_pair(from, to);
self.set_edge_unchecked(from, to, EdgeType::Undirected);
Ok(())
}
/// Try to convert the pattern to a Directed Acyclic Graph
pub fn to_dag(&mut self) -> Result<()> {
match self.dsl_pattern.as_mut().ToDAG() {
true => Ok(()),
false => Err(PatternError::ContainsCycle),
}
}
/// Check if the pattern is a Directed Acyclic Graph
pub fn is_dag(&self) -> bool {
self.dsl_pattern.as_ref().IsDAG()
}
/// Check if the pattern has a cycle
pub fn has_cycle(&self) -> bool {
self.dsl_pattern.as_ref().HasCycle()
}
/// Get a [PatternNode] for a specific node index
pub fn get_node(&self, index: usize) -> Result<PatternNode> {
self.assert_node_index(index)?;
Ok(PatternNode {
index,
pattern: self,
})
}
/// Convert the pattern to a network
pub fn into_network(self) -> Result<Network> {
// TODO: Implement this
// let network = Network::load(dataset);
// self.dsl_pattern.as_ref().ToNetwork(dataset, network);
Ok(Network::new_uninitialized())
}
}