use super::utils;
use crate::hard_policies::{HardPolicy, PolicyError};
use crate::modifier_ordering::RandomOrdering;
use crate::netsim::config::ConfigModifier;
use crate::netsim::Network;
use crate::optimizers::Optimizer;
use crate::soft_policies::SoftPolicy;
use crate::strategies::PushBackTreeStrategy;
use crate::{Error, Stopper};
use log::*;
use rand::prelude::*;
use std::time::{Duration, SystemTime};
/// Optimizer that searches for an ordering of configuration modifiers that keeps the hard policy
/// satisfied, preferring orderings with a low soft-policy cost.
///
/// Modifiers are handled in *groups*. Initially every modifier forms its own group; groups are
/// merged whenever the search learns that several modifiers depend on each other.
pub struct OptimizerTRTA<P: SoftPolicy + Clone> {
    /// Network in its initial state. Every (re)start of the search works on a clone of it.
    net: Network,
    /// Groups of modifiers; the modifiers of one group are always applied together, in order.
    groups: Vec<Vec<ConfigModifier>>,
    /// Hard policy, already stepped on the initial forwarding state by the constructor.
    hard_policy: HardPolicy,
    /// Soft policy in its initial state; used to rank the candidate groups by cost.
    soft_policy: P,
    /// Randomness used to break ties between equally-costed groups and to pick which invalid
    /// group is analysed for dependencies.
    rng: ThreadRng,
    /// Absolute deadline derived from the time budget, if a budget was given.
    stop_time: Option<SystemTime>,
    /// Time limit handed to each dependency search, if a budget was given.
    max_group_solve_time: Option<Duration>,
    /// Counter of evaluated states, incremented while probing modifiers.
    #[cfg(feature = "count-states")]
    num_states: usize,
}
impl<P> Optimizer<P> for OptimizerTRTA<P>
where
P: SoftPolicy + Clone,
{
fn new(
mut net: Network,
modifiers: Vec<ConfigModifier>,
mut hard_policy: HardPolicy,
soft_policy: P,
time_budget: Option<Duration>,
) -> Result<Box<Self>, Error> {
net.clear_undo_stack();
let num_modifiers = modifiers.len();
let mut groups: Vec<Vec<ConfigModifier>> = Vec::with_capacity(modifiers.len());
for modifier in modifiers {
groups.push(vec![modifier]);
}
let mut fw_state = net.get_forwarding_state();
hard_policy.set_num_mods_if_none(num_modifiers);
hard_policy.step(&mut net, &mut fw_state)?;
if !hard_policy.check() {
error!(
"Initial state errors::\n{}",
utils::fmt_err(&hard_policy.get_watch_errors(), &net)
);
return Err(Error::InvalidInitialState);
}
let max_group_solve_time: Option<Duration> =
time_budget.as_ref().map(|dur| *dur / super::TIME_FRACTION);
let stop_time: Option<SystemTime> = time_budget.map(|dur| SystemTime::now() + dur);
Ok(Box::new(Self {
net,
groups,
hard_policy,
soft_policy,
rng: rand::thread_rng(),
stop_time,
max_group_solve_time,
#[cfg(feature = "count-states")]
num_states: 0,
}))
}
fn work(&mut self, mut abort: Stopper) -> Result<(Vec<ConfigModifier>, f64), Error> {
let mut net = self.net.clone();
let mut hard_policy = self.hard_policy.clone();
let mut soft_policy = self.soft_policy.clone();
let (valid_groups, invalid_groups) = self.prepare_next_option(
&mut net,
&mut hard_policy,
&soft_policy,
0.0,
0..self.groups.len(),
);
let mut stack =
vec![StackFrame { num_undo: 0, valid_groups, invalid_groups, idx: 0, soft_policy }];
let mut current_sequence: Vec<usize> = vec![];
loop {
if self.stop_time.as_ref().map(|time| time.elapsed().is_ok()).unwrap_or(false) {
error!("Time budget is used up! No solution was found yet!");
return Err(Error::Timeout);
}
if abort.try_is_stop().unwrap_or(false) {
info!("Operation was aborted!");
return Err(Error::Abort);
}
let frame = match stack.last_mut() {
Some(frame) => frame,
None => {
error!("Could not find any valid ordering!");
return Err(Error::ProbablyNoSafeOrdering);
}
};
let action: StackAction<P> = if frame.idx < frame.valid_groups.len() {
let (next_group_idx, current_cost) = frame.valid_groups[frame.idx];
frame.idx += 1;
current_sequence.push(next_group_idx);
if current_sequence.len() == self.groups.len() {
info!(
"Valid solution was found! Learned {} groups",
self.groups.iter().filter(|g| g.len() > 1).count()
);
return Ok((
utils::finalize_ordering(&self.groups, ¤t_sequence),
current_cost,
));
}
let mut soft_policy = frame.soft_policy.clone();
for modifier in self.groups[next_group_idx].iter() {
net.apply_modifier(modifier).expect("Modifier should be ok!");
let mut fw_state = net.get_forwarding_state();
hard_policy.step(&mut net, &mut fw_state).expect("Modifier should be ok!");
soft_policy.update(&mut fw_state, &net);
}
let (valid_groups, invalid_groups) = self.prepare_next_option(
&mut net,
&mut hard_policy,
&soft_policy,
current_cost,
frame
.valid_groups
.iter()
.map(|(g, _)| g)
.filter(|g| **g != next_group_idx)
.chain(frame.invalid_groups.iter())
.cloned(),
);
StackAction::Push(StackFrame {
num_undo: self.groups[next_group_idx].len(),
valid_groups,
invalid_groups,
idx: 0,
soft_policy,
})
} else {
let random_group_pos = self.rng.gen_range(0, frame.invalid_groups.len());
match self.find_dependency(
&mut net,
&mut hard_policy,
¤t_sequence,
frame.invalid_groups[random_group_pos],
abort.clone(),
) {
Some((new_group, old_groups)) => {
info!("Found a new dependency group!");
utils::add_minimal_ordering_as_new_gorup(
&mut self.groups,
old_groups,
Some(new_group),
);
StackAction::Reset
}
None => {
info!("Could not find a new dependency group!");
if current_sequence.len() + 1 == self.groups.len() {
return Err(Error::NoSafeOrdering);
}
StackAction::Pop
}
}
};
match action {
StackAction::Pop => {
'backtrace: while let Some(frame) = stack.last() {
if frame.idx < frame.valid_groups.len() {
break 'backtrace;
} else {
current_sequence.pop();
(0..frame.num_undo).for_each(|_| {
net.undo_action().expect("Cannot undo the action on the network");
hard_policy.undo();
});
stack.pop();
}
}
}
StackAction::Push(new_frame) => stack.push(new_frame),
StackAction::Reset => {
net = self.net.clone();
hard_policy = self.hard_policy.clone();
soft_policy = self.soft_policy.clone();
let (valid_groups, invalid_groups) = self.prepare_next_option(
&mut net,
&mut hard_policy,
&soft_policy,
0.0,
0..self.groups.len(),
);
stack = vec![StackFrame {
num_undo: 0,
valid_groups,
invalid_groups,
idx: 0,
soft_policy,
}];
current_sequence.clear();
}
}
}
}
#[cfg(feature = "count-states")]
fn num_states(&self) -> usize {
self.num_states
}
}
impl<P: SoftPolicy + Clone> OptimizerTRTA<P> {
    /// Tentatively applies each group in `options` on top of the current `net`. It then splits
    /// the groups into those that keep the hard policy satisfied and those that do not.
    ///
    /// For every accepted group, the returned cost is `current_cost` plus the soft-policy cost
    /// observed after each of its modifiers. All changes are undone again before the next
    /// group is tried, so `net` and `hard_policy` are left as they were passed in.
    ///
    /// The accepted groups are shuffled and then stably sorted by cost. Equal costs therefore
    /// end up in random order.
    fn prepare_next_option(
        &mut self,
        net: &mut Network,
        hard_policy: &mut HardPolicy,
        soft_policy: &P,
        current_cost: f64,
        options: impl Iterator<Item = usize>,
    ) -> (Vec<(usize, f64)>, Vec<usize>) {
        let mut accepted: Vec<(usize, f64)> = Vec::new();
        let mut rejected: Vec<usize> = Vec::new();
        for candidate in options {
            let mut policy = soft_policy.clone();
            let mut total_cost: f64 = current_cost;
            let mut applied: usize = 0;
            let mut stepped: usize = 0;
            let mut acceptable = true;
            for modifier in self.groups[candidate].iter() {
                #[cfg(feature = "count-states")]
                {
                    self.num_states += 1;
                }
                // Counted before applying: a failed application is undone as well.
                applied += 1;
                if net.apply_modifier(modifier).is_err() {
                    acceptable = false;
                    break;
                }
                stepped += 1;
                let mut fw_state = net.get_forwarding_state();
                hard_policy.step(net, &mut fw_state).expect("cannot check policies!");
                if !hard_policy.check() {
                    acceptable = false;
                    break;
                }
                policy.update(&mut fw_state, net);
                total_cost += policy.cost();
            }
            // Roll back: first the policy steps, then the network actions.
            for _ in 0..stepped {
                hard_policy.undo();
            }
            for _ in 0..applied {
                net.undo_action().expect("Cannot perform undo!");
            }
            if acceptable {
                accepted.push((candidate, total_cost));
            } else {
                rejected.push(candidate);
            }
        }
        accepted.shuffle(&mut self.rng);
        accepted.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap());
        (accepted, rejected)
    }

    /// Re-applies `bad_group` on top of `net` to collect the errors it causes.
    ///
    /// It then delegates to `utils::find_dependency` for the ordering `good_ordering` followed
    /// by `bad_group`. The search runs from the optimizer's initial network and hard policy,
    /// limited by the stored deadline and per-search time limit. `net` and `hard_policy` are
    /// restored before the delegation.
    ///
    /// # Panics
    ///
    /// Panics if `bad_group` can be applied without any error.
    fn find_dependency(
        &mut self,
        net: &mut Network,
        hard_policy: &mut HardPolicy,
        good_ordering: &[usize],
        bad_group: usize,
        abort: Stopper,
    ) -> Option<(Vec<ConfigModifier>, Vec<usize>)> {
        let mut applied = 0;
        let mut stepped = 0;
        let mut observed = None;
        for modifier in self.groups[bad_group].iter() {
            applied += 1;
            if net.apply_modifier(modifier).is_err() {
                // The modifier could not be applied at all; report it as a convergence issue.
                observed = Some((Vec::new(), vec![Some(PolicyError::NoConvergence)]));
                break;
            }
            stepped += 1;
            let mut fw_state = net.get_forwarding_state();
            hard_policy.step(net, &mut fw_state).expect("cannot check policies!");
            if !hard_policy.check() {
                observed = Some(hard_policy.get_watch_errors());
                break;
            }
        }
        for _ in 0..stepped {
            hard_policy.undo();
        }
        for _ in 0..applied {
            net.undo_action().expect("Cannot perform undo!");
        }
        let errors = observed
            .unwrap_or_else(|| panic!("The bad group, passed into this function seems to be fine!"));
        let mut ordering: Vec<usize> = good_ordering.to_vec();
        ordering.push(bad_group);
        utils::find_dependency::<PushBackTreeStrategy<RandomOrdering>>(
            &self.net,
            &self.groups,
            &self.hard_policy,
            &ordering,
            errors,
            self.stop_time,
            self.max_group_solve_time,
            abort,
            #[cfg(feature = "count-states")]
            &mut self.num_states,
        )
    }
}
/// What the search loop in `work` does with its stack after processing the top frame.
#[derive(Debug, Clone)]
enum StackAction<P: SoftPolicy> {
    /// Backtrack: remove exhausted frames and undo the modifiers they applied.
    Pop,
    /// Descend: push a new frame for the state reached after applying one more group.
    Push(StackFrame<P>),
    /// Groups were merged: restart the search from the initial network.
    Reset,
}
/// One level of the depth-first search performed by `work`.
#[derive(Clone)]
struct StackFrame<P: SoftPolicy> {
    /// Number of modifiers applied when this frame was entered; undone when it is popped.
    num_undo: usize,
    /// Groups applicable at this level together with their accumulated cost, cheapest first.
    valid_groups: Vec<(usize, f64)>,
    /// Groups that could not be applied at this level without failing the hard policy.
    invalid_groups: Vec<usize>,
    /// Position of the next entry in `valid_groups` to try.
    idx: usize,
    /// Soft-policy state reached at this level.
    soft_policy: P,
}
/// Debug output for a search frame.
///
/// `soft_policy` is not printed because this impl only bounds `P` by `SoftPolicy`.
impl<P: SoftPolicy> std::fmt::Debug for StackFrame<P> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("StackFrame");
        builder.field("num_undo", &self.num_undo);
        builder.field("valid_groups", &self.valid_groups);
        builder.field("invalid_groups", &self.invalid_groups);
        builder.field("idx", &self.idx);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::example_networks::repetitions::*;
    use crate::example_networks::*;
    use crate::hard_policies::*;
    use crate::soft_policies::*;
    use assert_approx_eq::assert_approx_eq;
    use std::time::Duration;

    /// On the chain gadget, the expected cost of the found ordering is `count / (count + 2)`.
    #[test]
    fn chain_gadget() {
        type R = Repetition5;
        type T = ChainGadget<R>;
        let network = T::net(0);
        let target = T::final_config(&network, 0);
        let patch = network.current_config().get_diff(&target);
        let hard_policy = HardPolicy::reachability(
            network.get_routers().iter(),
            network.get_known_prefixes().iter(),
        );
        let soft_policy = MinimizeTrafficShift::new(&mut network.get_forwarding_state(), &network);
        let mut optimizer = OptimizerTRTA::new(
            network,
            patch.modifiers,
            hard_policy,
            soft_policy,
            Some(Duration::from_secs(1000)),
        )
        .unwrap();
        let count = R::get_count() as f64;
        let expected = count / (count + 2.0);
        let (_, cost) = optimizer.work(Stopper::new()).unwrap();
        assert_approx_eq!(expected, cost);
    }

    /// On the state-specific chain gadget, a zero-cost ordering is expected.
    #[test]
    fn state_specific_chain_gadget() {
        type R = Repetition5;
        type T = StateSpecificChainGadget<R>;
        let network = T::net(0);
        let target = T::final_config(&network, 0);
        let patch = network.current_config().get_diff(&target);
        let hard_policy = HardPolicy::reachability(
            network.get_routers().iter(),
            network.get_known_prefixes().iter(),
        );
        let soft_policy = MinimizeTrafficShift::new(&mut network.get_forwarding_state(), &network);
        let mut optimizer = OptimizerTRTA::new(
            network,
            patch.modifiers,
            hard_policy,
            soft_policy,
            Some(Duration::from_secs(1000)),
        )
        .unwrap();
        let (_, cost) = optimizer.work(Stopper::new()).unwrap();
        assert_approx_eq!(0.0, cost);
    }
}