class Model:
    # NOTE(review): the body is "..." in SOURCE, either a stub or code elided
    # from this view. load_model() calls load_state_dict(), to() and eval()
    # on instances, so the real class presumably subclasses torch.nn.Module;
    # confirm.
    ...
def load_model():
    """Build a fresh ``Model``, load the checkpoint at ``modelLoader.path``
    into it, move it to ``device``, put it in eval mode and register it with
    ``modelLoader`` via ``update_model``.
    """
    model = Model()
    # map_location remaps stored tensors onto `device`, so a checkpoint saved
    # on another device (e.g. CUDA) still loads on a host where `device`
    # differs (e.g. CPU-only). Without it torch.load tries to restore tensors
    # to the device they were saved from.
    # NOTE(review): torch.load unpickles the file. Only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(modelLoader.path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # eval() switches layers such as dropout / batch-norm to inference behaviour.
    model.eval()
    modelLoader.update_model(model)
def _run_inference(model, files):
    """Classify each entry of *files* with *model* and build the response dict.

    Shared by ``predict`` and ``predict_file``, which differ only in how the
    files are obtained.

    Args:
        model: the model returned by ``modelLoader.get_model()``.
        files: iterable of items accepted by ``load_image``.

    Returns:
        ``{'data': {'predict': [...], 'file_name': [...]}}``. The i-th entry of
        each list belongs to the i-th file: the predicted class index, and the
        label returned by ``load_image``.
    """
    predictions = []
    file_names = []
    # Pure inference: disable autograd so no computation graph is built or
    # kept alive for each image.
    with torch.no_grad():
        for file in files:
            img, label = load_image(file)
            # NOTE(review): img is not moved to `device` here. Presumably
            # load_image already does that; confirm.
            output = model(img)
            # Index of the highest score along dim 1.
            _, pred = torch.max(output, 1)
            # int() on a tensor requires exactly one element, so this assumes
            # load_image yields a batch of one image (TODO confirm).
            predictions.append(int(pred))
            file_names.append(label)
    return {
        'data': {
            'predict': predictions,
            'file_name': file_names,
        }
    }


def predict(data):
    """Download ``data['images']`` via ``magic.download`` and classify them.

    Returns the dict built by ``_run_inference``.
    """
    model = modelLoader.get_model()
    files = magic.download(data['images'])
    return _run_inference(model, files)


def predict_file(files):
    """Persist *files* via ``magic.save`` and classify the result.

    Returns the dict built by ``_run_inference``.
    """
    model = modelLoader.get_model()
    saved = magic.save(files)
    return _run_inference(model, saved)
def on_train_completed(metric, config):
    """Promote the newly trained model if its metric beats the recorded one.

    When *metric* is strictly greater than ``pipelineHepler.last_metric``,
    the model is saved via ``modelLoader.save_model()``, reloaded for serving
    with ``load_model()``, and *metric* is recorded with
    ``pipelineHepler.update_metric``.

    Args:
        metric: score of the finished training run. Higher is treated as
            better (strict ``>`` comparison).
        config: not used by this function.
    """
    # NOTE(review): SOURCE lost its indentation. update_metric is assumed to
    # be inside the `if` (so last_metric tracks the best score seen); confirm.
    # `pipelineHepler` (sic) is an external name and is referenced as spelled.
    if metric > pipelineHepler.last_metric:
        modelLoader.save_model()
        load_model()
        pipelineHepler.update_metric(metric)