Rewrite StaticLazy.sample_xy

This commit is contained in:
2021-04-12 07:35:51 -05:00
parent 2979bd157c
commit bccf0e9860
3 changed files with 98 additions and 15 deletions

View File

@@ -12,7 +12,7 @@ from miplearn.classifiers import Classifier
from miplearn.classifiers.counting import CountingClassifier
from miplearn.classifiers.threshold import MinProbabilityThreshold, Threshold
from miplearn.components.component import Component
from miplearn.features import TrainingSample, Features, Constraint
from miplearn.features import TrainingSample, Features, Constraint, Sample
from miplearn.types import LearningSolveStats
logger = logging.getLogger(__name__)
@@ -199,6 +199,46 @@ class StaticLazyConstraintsComponent(Component):
y[category] += [[True, False]]
return x, y
@overrides
def sample_xy(
self,
sample: Sample,
) -> Tuple[Dict[Hashable, List[List[float]]], Dict[Hashable, List[List[float]]]]:
x: Dict = {}
y: Dict = {}
assert sample.after_load is not None
assert sample.after_load.constraints is not None
for (cid, constr) in sample.after_load.constraints.items():
# Initialize categories
if not constr.lazy:
continue
category = constr.category
if category is None:
continue
if category not in x:
x[category] = []
y[category] = []
# Features
sf = sample.after_load
if sample.after_lp is not None:
sf = sample.after_lp
assert sf.instance is not None
features = list(sf.instance.to_list())
assert sf.constraints is not None
assert sf.constraints[cid] is not None
features.extend(sf.constraints[cid].to_list())
x[category].append(features)
# Labels
if sample.after_mip is not None:
assert sample.after_mip.extra is not None
if cid in sample.after_mip.extra["lazy_enforced"]:
y[category] += [[False, True]]
else:
y[category] += [[True, False]]
return x, y
@overrides
def fit_xy(
self,

View File

@@ -130,6 +130,7 @@ class Features:
constraints: Optional[Dict[str, Constraint]] = None
lp_solve: Optional["LPSolveStats"] = None
mip_solve: Optional["MIPSolveStats"] = None
extra: Optional[Dict] = None
@dataclass