mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 18:08:51 -06:00
Make WarmStartComponent use Extractor
This commit is contained in:
@@ -24,6 +24,23 @@ class Extractor(ABC):
|
||||
result[category] = []
|
||||
result[category] += [(var, index)]
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def merge(partial_results, vertical=False):
|
||||
results = {}
|
||||
all_categories = set()
|
||||
for pr in partial_results:
|
||||
all_categories |= pr.keys()
|
||||
for category in all_categories:
|
||||
results[category] = []
|
||||
for pr in partial_results:
|
||||
if category in pr.keys():
|
||||
results[category] += [pr[category]]
|
||||
if vertical:
|
||||
results[category] = np.vstack(results[category])
|
||||
else:
|
||||
results[category] = np.hstack(results[category])
|
||||
return results
|
||||
|
||||
|
||||
class UserFeaturesExtractor(Extractor):
|
||||
@@ -61,10 +78,20 @@ class SolutionExtractor(Extractor):
|
||||
if category not in result.keys():
|
||||
result[category] = []
|
||||
for (var, index) in var_index_pairs:
|
||||
result[category] += [[
|
||||
1 - var[index].value,
|
||||
var[index].value,
|
||||
]]
|
||||
v = var[index].value
|
||||
if v is None:
|
||||
result[category] += [[0, 0]]
|
||||
else:
|
||||
result[category] += [[1 - v, v]]
|
||||
for category in result.keys():
|
||||
result[category] = np.vstack(result[category])
|
||||
return result
|
||||
return result
|
||||
|
||||
|
||||
class CombinedExtractor(Extractor):
|
||||
def __init__(self, extractors):
|
||||
self.extractors = extractors
|
||||
|
||||
def extract(self, instances, models):
|
||||
return self.merge([ex.extract(instances, models)
|
||||
for ex in self.extractors])
|
||||
|
||||
Reference in New Issue
Block a user