Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
Alexandru Dura
AS1 - Object Detection
Commits
307314bd
Commit
307314bd
authored
Mar 17, 2020
by
Alexandru Dura
Browse files
Retrain a model on the full dataset
parent
67452edd
Changes
1
Hide whitespace changes
Inline
Side-by-side
src/retrain.py
View file @
307314bd
...
...
@@ -33,7 +33,7 @@ def imshow(inp, title=None):
def
train_model
(
device
,
dataloaders
,
model
,
criterion
,
optimizer
,
scheduler
,
num_epochs
=
2
5
):
def
train_model
(
device
,
dataloaders
,
dataset_sizes
,
model
,
criterion
,
optimizer
,
scheduler
,
num_epochs
=
2
):
since
=
time
.
time
()
best_model_wts
=
copy
.
deepcopy
(
model
.
state_dict
())
...
...
@@ -114,18 +114,30 @@ def main() :
# the folder where the images are
img_folder
=
tv
.
datasets
.
ImageFolder
(
IMAGE_FOLDER
,
preprocess
)
n_img_train
=
int
(
len
(
img_folder
)
*
0.8
)
n_img_eval
=
len
(
img_folder
)
-
n_img_train
scale
=
0.001
n_img_train
=
int
(
scale
*
n_img_train
)
n_img_eval
=
int
(
scale
*
n_img_eval
)
n_img_rest
=
len
(
img_folder
)
-
n_img_train
-
n_img_eval
[
img_folder_train
,
img_folder_eval
,
_
]
=
torch
.
utils
.
data
.
dataset
.
random_split
(
img_folder
,
[
n_img_train
,
n_img_eval
,
n_img_rest
])
# load the training set in random order
data_loader_train
=
torch
.
utils
.
data
.
DataLoader
(
img_folder
,
batch_size
=
8
,
data_loader_train
=
torch
.
utils
.
data
.
DataLoader
(
img_folder
_train
,
batch_size
=
8
,
shuffle
=
True
)
data_loader_eval
=
torch
.
utils
.
data
.
DataLoader
(
img_folder
,
batch_size
=
8
,
data_loader_eval
=
torch
.
utils
.
data
.
DataLoader
(
img_folder
_eval
,
batch_size
=
8
,
shuffle
=
True
)
dataloaders
=
{
'train'
:
data_loader_train
,
'eval'
:
data_loader_eval
}
dataset_sizes
=
{
'train'
:
n_img_train
,
'val'
:
n_img_eval
}
dataloaders
=
{
'train'
:
data_loader_train
,
'val'
:
data_loader_eval
}
device
=
torch
.
device
(
"cuda:0"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
# load the model from torch hub
model
=
torch
.
hub
.
load
(
'pytorch/vision:v0.5.0'
,
'mobilenet_v2'
,
pretrained
=
True
)
...
...
@@ -151,9 +163,11 @@ def main() :
exp_lr_scheduler
=
lr_scheduler
.
StepLR
(
optimizer_ft
,
step_size
=
7
,
gamma
=
0.1
)
# Train the model
model_ft
=
train_model
(
device
,
dataloaders
,
model_ft
,
criterion
,
optimizer_ft
,
exp_lr_scheduler
,
model_ft
=
train_model
(
device
,
dataloaders
,
dataset_sizes
,
model_ft
,
criterion
,
optimizer_ft
,
exp_lr_scheduler
,
num_epochs
=
25
)
torch
.
save
(
model_ft
.
state_dict
(),
"food_model"
)
return
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment