2022-03-01 14:17:37 +01:00

33 lines
789 B
Python

import os
from image_prediction.config import CONFIG
from image_prediction.locations import DATA_DIR, TORCH_HOME
from image_prediction.predictor import Predictor
def suppress_userwarnings(category=UserWarning):
    """Install a process-wide filter that ignores warnings of ``category``.

    By default only ``UserWarning`` (and its subclasses) is silenced, matching
    the function's name. Pass ``category=Warning`` to silence every warning
    category.

    Args:
        category: Warning class to ignore; subclasses are ignored as well.
    """
    import warnings

    # Scope the filter to one category so unrelated warnings
    # (e.g. DeprecationWarning, RuntimeWarning) remain visible.
    warnings.filterwarnings("ignore", category=category)
def load_classes(classes=None):
    """Return a mapping from 1-based class id to class name.

    Ids start at 1, not 0. Presumably id 0 is reserved by the estimator
    (e.g. for background); NOTE(review): confirm against the model.

    Args:
        classes: Iterable of class names, in id order. When ``None``
            (the default), ``CONFIG.estimator.classes`` is used.

    Returns:
        dict mapping ``1, 2, ...`` to the corresponding class names.
    """
    if classes is None:
        classes = CONFIG.estimator.classes
    return dict(enumerate(classes, start=1))
def get_checkpoint():
    """Return the default estimator checkpoint path.

    The checkpoint name is read from ``CONFIG.estimator.checkpoint`` and
    joined onto ``DATA_DIR`` with the ``/`` operator.
    """
    checkpoint_name = CONFIG.estimator.checkpoint
    checkpoint_path = DATA_DIR / checkpoint_name
    return checkpoint_path
def set_torch_env():
    """Export ``TORCH_HOME`` as the ``TORCH_HOME`` environment variable.

    Presumably this makes torch place its cache/hub downloads under the
    project's ``TORCH_HOME`` location; NOTE(review): confirm intent.
    """
    torch_home = str(TORCH_HOME)  # environment values must be strings
    os.environ["TORCH_HOME"] = torch_home
def initialize_predictor(resume):
    """Construct a ``Predictor`` for the configured estimator.

    Args:
        resume: Checkpoint to load instead of the default. Any falsy value
            (e.g. ``None`` or ``""``) falls back to ``get_checkpoint()``.

    Returns:
        The newly constructed ``Predictor``.
    """
    # TORCH_HOME is exported before anything else runs, presumably so torch
    # picks it up when the Predictor loads the checkpoint.
    set_torch_env()
    return Predictor(
        resume or get_checkpoint(),
        classes=load_classes(),
        rejection_class=CONFIG.estimator.rejection_class,
    )