mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-08 10:28:52 -06:00
Refactor PrimalSolutionComponent
This commit is contained in:
@@ -5,11 +5,11 @@
|
||||
import gzip
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List
|
||||
from typing import Any, List, Optional, Hashable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from miplearn.types import TrainingSample
|
||||
from miplearn.types import TrainingSample, VarIndex
|
||||
|
||||
|
||||
class Instance(ABC):
|
||||
@@ -34,9 +34,9 @@ class Instance(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_instance_features(self):
|
||||
def get_instance_features(self) -> List[float]:
|
||||
"""
|
||||
Returns a 1-dimensional Numpy array of (numerical) features describing the
|
||||
Returns a 1-dimensional array of (numerical) features describing the
|
||||
entire instance.
|
||||
|
||||
The array is used by LearningSolver to determine how similar two instances
|
||||
@@ -56,17 +56,13 @@ class Instance(ABC):
|
||||
|
||||
By default, returns [0].
|
||||
"""
|
||||
return np.zeros(1)
|
||||
return [0]
|
||||
|
||||
def get_variable_features(self, var, index):
|
||||
def get_variable_features(self, var_name: str, index: VarIndex) -> List[float]:
|
||||
"""
|
||||
Returns a 1-dimensional array of (numerical) features describing a particular
|
||||
decision variable.
|
||||
|
||||
The argument `var` is a pyomo.core.Var object, which represents a collection
|
||||
of decision variables. The argument `index` specifies which variable in the
|
||||
collection is the relevant one.
|
||||
|
||||
In combination with instance features, variable features are used by
|
||||
LearningSolver to predict, among other things, the optimal value of each
|
||||
decision variable before the optimization takes place. In the knapsack
|
||||
@@ -79,12 +75,15 @@ class Instance(ABC):
|
||||
|
||||
By default, returns [0].
|
||||
"""
|
||||
return np.zeros(1)
|
||||
return [0]
|
||||
|
||||
def get_variable_category(self, var, index):
|
||||
def get_variable_category(
|
||||
self,
|
||||
var_name: str,
|
||||
index: VarIndex,
|
||||
) -> Optional[Hashable]:
|
||||
"""
|
||||
Returns the category (a string, an integer or any hashable type) for each
|
||||
decision variable.
|
||||
Returns the category for each decision variable.
|
||||
|
||||
If two variables have the same category, LearningSolver will use the same
|
||||
internal ML model to predict the values of both variables. If the returned
|
||||
|
||||
Reference in New Issue
Block a user