mirror of
https://github.com/ANL-CEEESA/MIPLearn.git
synced 2025-12-09 19:08:51 -06:00
Refactor PrimalSolutionComponent
This commit is contained in:
@@ -6,22 +6,28 @@ import gzip
|
||||
import logging
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union, cast, IO
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from miplearn.instance import Instance
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InstanceIterator:
|
||||
def __init__(self, instances):
|
||||
def __init__(
|
||||
self,
|
||||
instances: Union[List[str], List[Instance]],
|
||||
) -> None:
|
||||
self.instances = instances
|
||||
self.current = 0
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
def __next__(self) -> Instance:
|
||||
if self.current >= len(self.instances):
|
||||
raise StopIteration
|
||||
result = self.instances[self.current]
|
||||
@@ -30,13 +36,14 @@ class InstanceIterator:
|
||||
logger.debug("Read: %s" % result)
|
||||
try:
|
||||
if result.endswith(".gz"):
|
||||
with gzip.GzipFile(result, "rb") as file:
|
||||
result = pickle.load(file)
|
||||
with gzip.GzipFile(result, "rb") as gzfile:
|
||||
result = pickle.load(cast(IO[bytes], gzfile))
|
||||
else:
|
||||
with open(result, "rb") as file:
|
||||
result = pickle.load(file)
|
||||
result = pickle.load(cast(IO[bytes], file))
|
||||
except pickle.UnpicklingError:
|
||||
raise Exception(f"Invalid instance file: {result}")
|
||||
assert isinstance(result, Instance)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user