|
|
|
@ -30,20 +30,20 @@ def test_branching():
|
|
|
|
|
assert component.y_train[key].shape == (8, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_branch_priority_save_load():
|
|
|
|
|
state_file = tempfile.NamedTemporaryFile(mode="r")
|
|
|
|
|
solver = LearningSolver(components={"branch-priority": BranchPriorityComponent()})
|
|
|
|
|
solver.parallel_solve(_get_instances(), n_jobs=2)
|
|
|
|
|
solver.fit()
|
|
|
|
|
comp = solver.components["branch-priority"]
|
|
|
|
|
assert comp.x_train["default"].shape == (8, 4)
|
|
|
|
|
assert comp.y_train["default"].shape == (8, 1)
|
|
|
|
|
assert "default" in comp.predictors.keys()
|
|
|
|
|
solver.save_state(state_file.name)
|
|
|
|
|
|
|
|
|
|
solver = LearningSolver(components={"branch-priority": BranchPriorityComponent()})
|
|
|
|
|
solver.load_state(state_file.name)
|
|
|
|
|
comp = solver.components["branch-priority"]
|
|
|
|
|
assert comp.x_train["default"].shape == (8, 4)
|
|
|
|
|
assert comp.y_train["default"].shape == (8, 1)
|
|
|
|
|
assert "default" in comp.predictors.keys()
|
|
|
|
|
# def test_branch_priority_save_load():
|
|
|
|
|
# state_file = tempfile.NamedTemporaryFile(mode="r")
|
|
|
|
|
# solver = LearningSolver(components={"branch-priority": BranchPriorityComponent()})
|
|
|
|
|
# solver.parallel_solve(_get_instances(), n_jobs=2)
|
|
|
|
|
# solver.fit()
|
|
|
|
|
# comp = solver.components["branch-priority"]
|
|
|
|
|
# assert comp.x_train["default"].shape == (8, 4)
|
|
|
|
|
# assert comp.y_train["default"].shape == (8, 1)
|
|
|
|
|
# assert "default" in comp.predictors.keys()
|
|
|
|
|
# solver.save_state(state_file.name)
|
|
|
|
|
#
|
|
|
|
|
# solver = LearningSolver(components={"branch-priority": BranchPriorityComponent()})
|
|
|
|
|
# solver.load_state(state_file.name)
|
|
|
|
|
# comp = solver.components["branch-priority"]
|
|
|
|
|
# assert comp.x_train["default"].shape == (8, 4)
|
|
|
|
|
# assert comp.y_train["default"].shape == (8, 1)
|
|
|
|
|
# assert "default" in comp.predictors.keys()
|
|
|
|
|