Make WarmStartComponent use Extractor

This commit is contained in:
2020-02-04 13:29:06 -06:00
parent 17c21382c5
commit dbea4aa988
5 changed files with 88 additions and 49 deletions

View File

@@ -24,6 +24,23 @@ class Extractor(ABC):
result[category] = []
result[category] += [(var, index)]
return result
@staticmethod
def merge(partial_results, vertical=False):
results = {}
all_categories = set()
for pr in partial_results:
all_categories |= pr.keys()
for category in all_categories:
results[category] = []
for pr in partial_results:
if category in pr.keys():
results[category] += [pr[category]]
if vertical:
results[category] = np.vstack(results[category])
else:
results[category] = np.hstack(results[category])
return results
class UserFeaturesExtractor(Extractor):
@@ -61,10 +78,20 @@ class SolutionExtractor(Extractor):
if category not in result.keys():
result[category] = []
for (var, index) in var_index_pairs:
result[category] += [[
1 - var[index].value,
var[index].value,
]]
v = var[index].value
if v is None:
result[category] += [[0, 0]]
else:
result[category] += [[1 - v, v]]
for category in result.keys():
result[category] = np.vstack(result[category])
return result
return result
class CombinedExtractor(Extractor):
def __init__(self, extractors):
self.extractors = extractors
def extract(self, instances, models):
return self.merge([ex.extract(instances, models)
for ex in self.extractors])