Commit 0b5ed273 authored by Alexandru Dura's avatar Alexandru Dura

Fix the number of classe in the retraining program

parent 5c278103
......@@ -159,8 +159,17 @@ def main() :
class_names = img_folder.classes
# model for training
model_ft = model
out_features = 101
num_features = model_ft.classifier[1].in_features
model_ft.classifier = torch.nn.Sequential(
torch.nn.Dropout(p=0.2, inplace=False),
# num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment