From 962707e8b7057d0b5e9b19d35d750c3416652bff Mon Sep 17 00:00:00 2001 From: "Alinson S. Xavier" Date: Tue, 27 Jul 2021 09:25:40 -0500 Subject: [PATCH] Replace push_sample by create_sample --- miplearn/instance/base.py | 6 ++++-- miplearn/instance/picklegz.py | 4 ++-- miplearn/solvers/learning.py | 3 +-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/miplearn/instance/base.py b/miplearn/instance/base.py index 5544cf0..3f0e7b2 100644 --- a/miplearn/instance/base.py +++ b/miplearn/instance/base.py @@ -6,7 +6,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, List, TYPE_CHECKING, Dict -from miplearn.features.sample import Sample +from miplearn.features.sample import Sample, MemorySample logger = logging.getLogger(__name__) @@ -192,5 +192,7 @@ class Instance(ABC): def get_samples(self) -> List[Sample]: return self._samples - def push_sample(self, sample: Sample) -> None: + def create_sample(self) -> Sample: + sample = MemorySample() self._samples.append(sample) + return sample diff --git a/miplearn/instance/picklegz.py b/miplearn/instance/picklegz.py index a73b176..8472a9d 100644 --- a/miplearn/instance/picklegz.py +++ b/miplearn/instance/picklegz.py @@ -137,9 +137,9 @@ class PickleGzInstance(Instance): return self.instance.get_samples() @overrides - def push_sample(self, sample: Sample) -> None: + def create_sample(self) -> Sample: assert self.instance is not None - self.instance.push_sample(sample) + return self.instance.create_sample() def write_pickle_gz(obj: Any, filename: str) -> None: diff --git a/miplearn/solvers/learning.py b/miplearn/solvers/learning.py index efb10b7..648020a 100644 --- a/miplearn/solvers/learning.py +++ b/miplearn/solvers/learning.py @@ -150,8 +150,7 @@ class LearningSolver: # Initialize training sample # ------------------------------------------------------- - sample = MemorySample() - instance.push_sample(sample) + sample = instance.create_sample() # Initialize stats # -------------------------------------------------------