Convert TrainingSample to dataclass

This commit is contained in:
2021-04-05 20:36:04 -05:00
parent aeed338837
commit b11779817a
15 changed files with 122 additions and 129 deletions

View File

@@ -38,8 +38,8 @@ def test_xy() -> None:
}
}
)
sample: TrainingSample = {
"Solution": {
sample = TrainingSample(
solution={
"x": {
0: 0.0,
1: 1.0,
@@ -47,7 +47,7 @@ def test_xy() -> None:
3: 0.0,
}
},
"LP solution": {
lp_solution={
"x": {
0: 0.1,
1: 0.1,
@@ -55,7 +55,7 @@ def test_xy() -> None:
3: 0.1,
}
},
}
)
x_expected = {
"default": [
[0.0, 0.0, 0.1],
@@ -99,8 +99,8 @@ def test_xy_without_lp_solution() -> None:
}
}
)
sample: TrainingSample = {
"Solution": {
sample = TrainingSample(
solution={
"x": {
0: 0.0,
1: 1.0,
@@ -108,7 +108,7 @@ def test_xy_without_lp_solution() -> None:
3: 0.0,
}
},
}
)
x_expected = {
"default": [
[0.0, 0.0],
@@ -161,15 +161,15 @@ def test_predict() -> None:
}
}
)
sample: TrainingSample = {
"LP solution": {
sample = TrainingSample(
lp_solution={
"x": {
0: 0.1,
1: 0.5,
2: 0.9,
}
}
}
)
x, _ = PrimalSolutionComponent.sample_xy(features, sample)
comp = PrimalSolutionComponent()
comp.classifiers = {"default": clf}
@@ -254,8 +254,8 @@ def test_evaluate() -> None:
}
}
)
sample: TrainingSample = {
"Solution": {
sample = TrainingSample(
solution={
"x": {
0: 1.0,
1: 1.0,
@@ -264,7 +264,7 @@ def test_evaluate() -> None:
4: 1.0,
}
}
}
)
ev = comp.sample_evaluate(features, sample)
assert ev == {
0: classifier_evaluation_dict(tp=1, fp=1, tn=3, fn=0),