33 lines
789 B
Python
33 lines
789 B
Python
import os
|
|
|
|
from image_prediction.config import CONFIG
|
|
from image_prediction.locations import DATA_DIR, TORCH_HOME
|
|
from image_prediction.predictor import Predictor
|
|
|
|
|
|
def suppress_userwarnings():
    """Install an "ignore" filter for all warnings.

    Despite the name, no ``category`` is passed, so this silences every
    warning category (the ``filterwarnings`` default is ``Warning``), not
    only ``UserWarning``. The filter is prepended to the process-wide
    warnings filter list and stays in effect until changed.
    """
    from warnings import filterwarnings

    filterwarnings(action="ignore")
|
|
|
|
|
|
def load_classes():
    """Build an id -> class-name mapping from ``CONFIG.estimator.classes``.

    Ids are assigned in configuration order starting at 1, so id 0 is never
    used (presumably reserved by the predictor — TODO confirm).

    Returns:
        dict mapping int ids (1..len(classes)) to the configured class names.
    """
    class_names = CONFIG.estimator.classes
    return {class_id: name for class_id, name in enumerate(class_names, start=1)}
|
|
|
|
|
|
def get_checkpoint():
    """Return the default checkpoint location.

    The result is ``DATA_DIR`` joined (via ``/``) with the checkpoint name
    configured in ``CONFIG.estimator.checkpoint``.
    """
    base_dir = DATA_DIR
    checkpoint_name = CONFIG.estimator.checkpoint
    return base_dir / checkpoint_name
|
|
|
|
|
|
def set_torch_env():
    """Export ``TORCH_HOME`` to the process environment.

    Overwrites any existing ``TORCH_HOME`` value with the string form of the
    project's ``TORCH_HOME`` location.
    """
    torch_home = str(TORCH_HOME)
    os.environ["TORCH_HOME"] = torch_home
|
|
|
|
|
|
def initialize_predictor(resume):
    """Create a ``Predictor`` configured from ``CONFIG``.

    Args:
        resume: Optional checkpoint override. Any falsy value (e.g. ``None``
            or an empty string) falls back to ``get_checkpoint()``.

    Returns:
        A ``Predictor`` built with the selected checkpoint, the 1-based class
        mapping from ``load_classes()`` and
        ``CONFIG.estimator.rejection_class``.
    """
    # Set TORCH_HOME before the Predictor is constructed — presumably so any
    # torch downloads/caches resolve to the project location (TODO confirm).
    set_torch_env()

    checkpoint = resume or get_checkpoint()
    id2class = load_classes()
    rejection_class = CONFIG.estimator.rejection_class

    return Predictor(checkpoint, classes=id2class, rejection_class=rejection_class)
|