Implement load; update fit

This commit is contained in:
2022-02-25 08:26:33 -06:00
parent 522f3a7e18
commit 04dd3ad5d5
4 changed files with 54 additions and 19 deletions

View File

@@ -6,7 +6,7 @@ import gc
import gzip
import os
import pickle
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict
from typing import Optional, Any, List, cast, IO, TYPE_CHECKING, Dict, Callable
import numpy as np
from overrides import overrides
@@ -177,3 +177,9 @@ def save(objs: List[Any], dirname: str) -> List[str]:
filenames.append(filename)
write_pickle_gz(obj, filename)
return filenames
def load(filename: str, build_model: Callable) -> Any:
with gzip.GzipFile(filename, "rb") as file:
data = pickle.load(cast(IO[bytes], file))
return build_model(data)