Commit 0b5ed273 authored by Alexandru Dura's avatar Alexandru Dura
Browse files

Fix the number of classe in the retraining program

parent 5c278103
......@@ -159,8 +159,17 @@ def main() :
class_names = img_folder.classes
# model for training
model_ft = model
out_features = 101
num_features = model_ft.classifier[1].in_features
model_ft.classifier = torch.nn.Sequential(
torch.nn.Dropout(p=0.2, inplace=False),
# num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment