Commit 307314bd authored by Alexandru Dura's avatar Alexandru Dura
Browse files

Retrain a model on the full dataset

parent 67452edd
......@@ -33,7 +33,7 @@ def imshow(inp, title=None):
def train_model(device, dataloaders, model, criterion, optimizer, scheduler, num_epochs=25):
def train_model(device, dataloaders, dataset_sizes, model, criterion, optimizer, scheduler, num_epochs=2):
since = time.time()
best_model_wts = copy.deepcopy(model.state_dict())
......@@ -114,18 +114,30 @@ def main() :
# the folder where the images are
img_folder = tv.datasets.ImageFolder(IMAGE_FOLDER, preprocess)
n_img_train = int(len(img_folder) * 0.8)
n_img_eval = len(img_folder) - n_img_train
scale = 0.001
n_img_train = int(scale * n_img_train)
n_img_eval = int(scale * n_img_eval)
n_img_rest = len(img_folder) - n_img_train - n_img_eval
[img_folder_train, img_folder_eval, _] = torch.utils.data.dataset.random_split(img_folder, [n_img_train, n_img_eval, n_img_rest])
# load the training set in random order
data_loader_train = torch.utils.data.DataLoader(img_folder, batch_size=8,
data_loader_train = torch.utils.data.DataLoader(img_folder_train, batch_size=8,
shuffle=True)
data_loader_eval = torch.utils.data.DataLoader(img_folder, batch_size=8,
data_loader_eval = torch.utils.data.DataLoader(img_folder_eval, batch_size=8,
shuffle=True)
dataloaders = {'train' : data_loader_train, 'eval' : data_loader_eval}
dataset_sizes = {'train' : n_img_train, 'val' : n_img_eval}
dataloaders = {'train' : data_loader_train, 'val' : data_loader_eval}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load the model from torch hub
model = torch.hub.load('pytorch/vision:v0.5.0', 'mobilenet_v2', pretrained=True)
......@@ -151,9 +163,11 @@ def main() :
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# Train the model
model_ft = train_model(device, dataloaders, model_ft, criterion, optimizer_ft, exp_lr_scheduler,
model_ft = train_model(device, dataloaders, dataset_sizes, model_ft, criterion, optimizer_ft, exp_lr_scheduler,
num_epochs=25)
torch.save(model_ft.state_dict(), "food_model")
return
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment