mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-07 01:48:51 -06:00
Always pass (n,)-shaped arrays to regressors instead of (n,1)
This commit is contained in:
@@ -41,18 +41,22 @@ class ObjectiveValueComponent(Component):
|
|||||||
features = InstanceFeaturesExtractor().extract(training_instances)
|
features = InstanceFeaturesExtractor().extract(training_instances)
|
||||||
ub = ObjectiveValueExtractor(kind="upper bound").extract(training_instances)
|
ub = ObjectiveValueExtractor(kind="upper bound").extract(training_instances)
|
||||||
lb = ObjectiveValueExtractor(kind="lower bound").extract(training_instances)
|
lb = ObjectiveValueExtractor(kind="lower bound").extract(training_instances)
|
||||||
|
assert ub.shape == (len(training_instances), 1)
|
||||||
|
assert lb.shape == (len(training_instances), 1)
|
||||||
self.ub_regressor = deepcopy(self.regressor_prototype)
|
self.ub_regressor = deepcopy(self.regressor_prototype)
|
||||||
self.lb_regressor = deepcopy(self.regressor_prototype)
|
self.lb_regressor = deepcopy(self.regressor_prototype)
|
||||||
logger.debug("Fitting ub_regressor...")
|
logger.debug("Fitting ub_regressor...")
|
||||||
self.ub_regressor.fit(features, ub)
|
self.ub_regressor.fit(features, ub.ravel())
|
||||||
logger.debug("Fitting ub_regressor...")
|
logger.debug("Fitting ub_regressor...")
|
||||||
self.lb_regressor.fit(features, lb)
|
self.lb_regressor.fit(features, lb.ravel())
|
||||||
|
|
||||||
def predict(self, instances):
|
def predict(self, instances):
|
||||||
features = InstanceFeaturesExtractor().extract(instances)
|
features = InstanceFeaturesExtractor().extract(instances)
|
||||||
lb = self.lb_regressor.predict(features)
|
lb = self.lb_regressor.predict(features)
|
||||||
ub = self.ub_regressor.predict(features)
|
ub = self.ub_regressor.predict(features)
|
||||||
return np.hstack([lb, ub])
|
assert lb.shape == (len(instances),)
|
||||||
|
assert ub.shape == (len(instances),)
|
||||||
|
return np.array([lb, ub]).T
|
||||||
|
|
||||||
def evaluate(self, instances):
|
def evaluate(self, instances):
|
||||||
y_pred = self.predict(instances)
|
y_pred = self.predict(instances)
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ def test_usage():
|
|||||||
def test_obj_evaluate():
|
def test_obj_evaluate():
|
||||||
instances, models = get_training_instances_and_models()
|
instances, models = get_training_instances_and_models()
|
||||||
reg = Mock(spec=Regressor)
|
reg = Mock(spec=Regressor)
|
||||||
reg.predict = Mock(return_value=np.array([[1000.0], [1000.0]]))
|
reg.predict = Mock(return_value=np.array([1000.0, 1000.0]))
|
||||||
comp = ObjectiveValueComponent(regressor=reg)
|
comp = ObjectiveValueComponent(regressor=reg)
|
||||||
comp.fit(instances)
|
comp.fit(instances)
|
||||||
ev = comp.evaluate(instances)
|
ev = comp.evaluate(instances)
|
||||||
|
|||||||
Reference in New Issue
Block a user