Commit 0b5ed273 authored by Alexandru Dura's avatar Alexandru Dura

Fix the number of classe in the retraining program

parent 5c278103
...@@ -159,8 +159,17 @@ def main() : ...@@ -159,8 +159,17 @@ def main() :
class_names = img_folder.classes class_names = img_folder.classes
# model for training # model for training
model_ft = model model_ft = model
out_features = 101
num_features = model_ft.classifier[1].in_features
model_ft.classifier = torch.nn.Sequential(
torch.nn.Dropout(p=0.2, inplace=False),
torch.nn.Linear(in_features=num_features,
out_features=out_features,
bias=True))
# num_ftrs = model_ft.fc.in_features # num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2. # Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)). # Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment