Commit e057a019 authored by Alexandru Dura's avatar Alexandru Dura
Browse files

Add argument to the retrain program

parent 307314bd
......@@ -2,6 +2,7 @@
# https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
#
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
......@@ -92,6 +93,8 @@ def train_model(device, dataloaders, dataset_sizes, model, criterion, optimizer,
best_model_wts = copy.deepcopy(model.state_dict())
print()
torch.save(model.state_dict(), "food_model_e{}".format(epoch))
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(
......@@ -104,6 +107,16 @@ def train_model(device, dataloaders, dataset_sizes, model, criterion, optimizer,
def main() :
n_epochs = 25
# set scale to 1 to train on the entire dataset
scale = 0.001
if len(sys.argv) == 3:
n_epochs = int(sys.argv[1])
scale = float(sys.argv[2])
print("Training during {} epochs, using {}% of the dataset.".format(n_epochs, scale * 100))
preprocess = tv.transforms.Compose([
tv.transforms.Resize(256),
tv.transforms.CenterCrop(224),
......@@ -118,7 +131,8 @@ def main() :
n_img_eval = len(img_folder) - n_img_train
scale = 0.001
n_img_train = int(scale * n_img_train)
n_img_eval = int(scale * n_img_eval)
n_img_rest = len(img_folder) - n_img_train - n_img_eval
......@@ -164,7 +178,7 @@ def main() :
# Train the model
model_ft = train_model(device, dataloaders, dataset_sizes, model_ft, criterion, optimizer_ft, exp_lr_scheduler,
num_epochs=25)
num_epochs=n_epochs)
torch.save(model_ft.state_dict(), "food_model")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment