Refer to variables by varname instead of (vname, index)

This commit is contained in:
2021-04-07 10:56:31 -05:00
parent 856b595d5e
commit 1cf6124757
22 changed files with 467 additions and 516 deletions

View File

@@ -6,14 +6,16 @@ import logging
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Hashable
from overrides import EnforceOverrides
from miplearn.features import TrainingSample, Features
from miplearn.types import VarIndex
from miplearn.types import VariableName, Category
logger = logging.getLogger(__name__)
# noinspection PyMethodMayBeStatic
class Instance(ABC):
class Instance(ABC, EnforceOverrides):
"""
Abstract class holding all the data necessary to generate a concrete model of the
proble.
@@ -60,9 +62,9 @@ class Instance(ABC):
"""
return [0]
def get_variable_features(self, var_name: str, index: VarIndex) -> List[float]:
def get_variable_features(self, var_name: VariableName) -> List[float]:
"""
Returns a 1-dimensional array of (numerical) features describing a particular
Returns a (1-dimensional) list of numerical features describing a particular
decision variable.
In combination with instance features, variable features are used by
@@ -79,11 +81,7 @@ class Instance(ABC):
"""
return [0]
def get_variable_category(
self,
var_name: str,
index: VarIndex,
) -> Optional[Hashable]:
def get_variable_category(self, var_name: VariableName) -> Optional[Category]:
"""
Returns the category for each decision variable.
@@ -91,6 +89,7 @@ class Instance(ABC):
internal ML model to predict the values of both variables. If the returned
category is None, ML models will ignore the variable.
A category can be any hashable type, such as strings, numbers or tuples.
By default, returns "default".
"""
return "default"

View File

@@ -2,14 +2,16 @@
# Copyright (C) 2020-2021, UChicago Argonne, LLC. All rights reserved.
# Released under the modified BSD license. See COPYING.md for more details.
import gc
import gzip
import os
import pickle
import gc
from typing import Optional, Any, List, Hashable, cast, IO, Callable
from typing import Optional, Any, List, Hashable, cast, IO
from overrides import overrides
from miplearn.instance.base import logger, Instance
from miplearn.types import VarIndex
from miplearn.types import VariableName, Category
class PickleGzInstance(Instance):
@@ -31,62 +33,72 @@ class PickleGzInstance(Instance):
self.instance: Optional[Instance] = None
self.filename: str = filename
@overrides
def to_model(self) -> Any:
assert self.instance is not None
return self.instance.to_model()
@overrides
def get_instance_features(self) -> List[float]:
assert self.instance is not None
return self.instance.get_instance_features()
def get_variable_features(self, var_name: str, index: VarIndex) -> List[float]:
@overrides
def get_variable_features(self, var_name: VariableName) -> List[float]:
assert self.instance is not None
return self.instance.get_variable_features(var_name, index)
return self.instance.get_variable_features(var_name)
def get_variable_category(
self,
var_name: str,
index: VarIndex,
) -> Optional[Hashable]:
@overrides
def get_variable_category(self, var_name: VariableName) -> Optional[Category]:
assert self.instance is not None
return self.instance.get_variable_category(var_name, index)
return self.instance.get_variable_category(var_name)
@overrides
def get_constraint_features(self, cid: str) -> Optional[List[float]]:
assert self.instance is not None
return self.instance.get_constraint_features(cid)
@overrides
def get_constraint_category(self, cid: str) -> Optional[Hashable]:
assert self.instance is not None
return self.instance.get_constraint_category(cid)
@overrides
def has_static_lazy_constraints(self) -> bool:
assert self.instance is not None
return self.instance.has_static_lazy_constraints()
@overrides
def has_dynamic_lazy_constraints(self) -> bool:
assert self.instance is not None
return self.instance.has_dynamic_lazy_constraints()
@overrides
def is_constraint_lazy(self, cid: str) -> bool:
assert self.instance is not None
return self.instance.is_constraint_lazy(cid)
@overrides
def find_violated_lazy_constraints(self, model: Any) -> List[Hashable]:
assert self.instance is not None
return self.instance.find_violated_lazy_constraints(model)
@overrides
def build_lazy_constraint(self, model: Any, violation: Hashable) -> Any:
assert self.instance is not None
return self.instance.build_lazy_constraint(model, violation)
@overrides
def find_violated_user_cuts(self, model: Any) -> List[Hashable]:
assert self.instance is not None
return self.instance.find_violated_user_cuts(model)
@overrides
def build_user_cut(self, model: Any, violation: Hashable) -> Any:
assert self.instance is not None
return self.instance.build_user_cut(model, violation)
@overrides
def load(self) -> None:
if self.instance is None:
obj = read_pickle_gz(self.filename)
@@ -95,12 +107,14 @@ class PickleGzInstance(Instance):
self.features = self.instance.features
self.training_data = self.instance.training_data
@overrides
def free(self) -> None:
self.instance = None # type: ignore
self.features = None # type: ignore
self.training_data = None # type: ignore
gc.collect()
@overrides
def flush(self) -> None:
write_pickle_gz(self.instance, self.filename)