Store cuts and lazy constraints as JSON in H5

dev
Alinson S. Xavier 2 years ago
parent 2774edae8c
commit 281508f44c

@ -2,8 +2,9 @@
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import json
import logging import logging
from typing import List, Dict, Any, Hashable, Union from typing import List, Dict, Any, Hashable
import numpy as np import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer from sklearn.preprocessing import MultiLabelBinarizer
@ -15,6 +16,15 @@ from miplearn.solvers.abstract import AbstractModel
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def convert_lists_to_tuples(obj: Any) -> Any:
if isinstance(obj, list):
return tuple(convert_lists_to_tuples(item) for item in obj)
elif isinstance(obj, dict):
return {key: convert_lists_to_tuples(value) for key, value in obj.items()}
else:
return obj
class _BaseMemorizingConstrComponent: class _BaseMemorizingConstrComponent:
def __init__(self, clf: Any, extractor: FeaturesExtractor, field: str) -> None: def __init__(self, clf: Any, extractor: FeaturesExtractor, field: str) -> None:
self.clf = clf self.clf = clf
@ -38,8 +48,7 @@ class _BaseMemorizingConstrComponent:
sample_constrs_str = h5.get_scalar(self.field) sample_constrs_str = h5.get_scalar(self.field)
assert sample_constrs_str is not None assert sample_constrs_str is not None
assert isinstance(sample_constrs_str, str) assert isinstance(sample_constrs_str, str)
sample_constrs = eval(sample_constrs_str) sample_constrs = convert_lists_to_tuples(json.loads(sample_constrs_str))
assert isinstance(sample_constrs, list)
y_sample = [] y_sample = []
for c in sample_constrs: for c in sample_constrs:
if c not in constr_to_idx: if c not in constr_to_idx:

@ -170,11 +170,11 @@ def _stab_read(data: Union[str, MaxWeightStableSetData]) -> MaxWeightStableSetDa
return data return data
def _stab_separate(data: MaxWeightStableSetData, x_val: List[float]) -> List[Hashable]: def _stab_separate(data: MaxWeightStableSetData, x_val: List[float]) -> List:
# Check that we selected at most one vertex for each # Check that we selected at most one vertex for each
# clique in the graph (sum <= 1) # clique in the graph (sum <= 1)
violations: List[Hashable] = [] violations: List[Any] = []
for clique in nx.find_cliques(data.graph): for clique in nx.find_cliques(data.graph):
if sum(x_val[i] for i in clique) > 1.0001: if sum(x_val[i] for i in clique) > 1.0001:
violations.append(tuple(sorted(clique))) violations.append(sorted(clique))
return violations return violations

@ -231,18 +231,18 @@ def _tsp_separate(
x_val: dict[Tuple[int, int], float], x_val: dict[Tuple[int, int], float],
edges: List[Tuple[int, int]], edges: List[Tuple[int, int]],
n_cities: int, n_cities: int,
) -> List[Tuple[Tuple[int, int], ...]]: ) -> List:
violations = [] violations = []
selected_edges = [e for e in edges if x_val[e] > 0.5] selected_edges = [e for e in edges if x_val[e] > 0.5]
graph = nx.Graph() graph = nx.Graph()
graph.add_edges_from(selected_edges) graph.add_edges_from(selected_edges)
for component in list(nx.connected_components(graph)): for component in list(nx.connected_components(graph)):
if len(component) < n_cities: if len(component) < n_cities:
cut_edges = tuple( cut_edges = [
(e[0], e[1]) [e[0], e[1]]
for e in edges for e in edges
if (e[0] in component and e[1] not in component) if (e[0] in component and e[1] not in component)
or (e[0] not in component and e[1] in component) or (e[0] not in component and e[1] in component)
) ]
violations.append(cut_edges) violations.append(cut_edges)
return violations return violations

@ -1,7 +1,9 @@
# MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization # MIPLearn: Extensible Framework for Learning-Enhanced Mixed-Integer Optimization
# Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved. # Copyright (C) 2020-2022, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details. # Released under the modified BSD license. See COPYING.md for more details.
import logging import logging
import json
from typing import Dict, Optional, Callable, Any, List from typing import Dict, Optional, Callable, Any, List
import gurobipy as gp import gurobipy as gp
@ -167,9 +169,9 @@ class GurobiModel(AbstractModel):
pass pass
self._extract_after_mip_solution_pool(h5) self._extract_after_mip_solution_pool(h5)
if self.lazy_ is not None: if self.lazy_ is not None:
h5.put_scalar("mip_lazy", repr(self.lazy_)) h5.put_scalar("mip_lazy", json.dumps(self.lazy_))
if self.cuts_ is not None: if self.cuts_ is not None:
h5.put_scalar("mip_cuts", repr(self.cuts_)) h5.put_scalar("mip_cuts", json.dumps(self.cuts_))
def fix_variables( def fix_variables(
self, self,

@ -28,17 +28,17 @@ def test_mem_component_gp(
clf.fit.assert_called() clf.fit.assert_called()
x, y = clf.fit.call_args.args x, y = clf.fit.call_args.args
assert x.shape == (3, 50) assert x.shape == (3, 50)
assert y.shape == (3, 415) assert y.shape == (3, 412)
y = y.tolist() y = y.tolist()
assert y[0][:20] == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] assert y[0][40:50] == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
assert y[1][:20] == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] assert y[1][40:50] == [1, 1, 0, 1, 1, 1, 1, 1, 1, 1]
assert y[2][:20] == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1] assert y[2][40:50] == [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
# Should store violations # Should store violations
assert comp.constrs_ is not None assert comp.constrs_ is not None
assert comp.n_features_ == 50 assert comp.n_features_ == 50
assert comp.n_targets_ == 415 assert comp.n_targets_ == 412
assert len(comp.constrs_) == 415 assert len(comp.constrs_) == 412
# Call before-mip # Call before-mip
stats: Dict[str, Any] = {} stats: Dict[str, Any] = {}
@ -52,7 +52,7 @@ def test_mem_component_gp(
# Should set cuts_aot_ # Should set cuts_aot_
assert model.cuts_aot_ is not None assert model.cuts_aot_ is not None
assert len(model.cuts_aot_) == 285 assert len(model.cuts_aot_) == 256
def test_usage_stab( def test_usage_stab(

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.
Loading…
Cancel
Save