AlexNet ImageNet Training
1. Introduction
This repository contains a from-scratch PyTorch implementation of AlexNet trained on the ImageNet-1K dataset. It reproduces the classic 2012 network with modern training utilities such as data augmentation, learning-rate warm-up, and cosine/step decay scheduling.
2. Project Structure
├── model.py       # AlexNet architecture (5 conv + 3 fc)
├── load_data.py   # ImageNet dataloaders & preprocessing
├── train.py       # Training / validation loop & scheduler setup
├── models/        # (auto-created) checkpoints & logs
└── README.md      # You are here
model.py
- Features block: 5 convolutional layers (see the sketch after this list):
  - 96 × 11×11 conv, stride 4
  - 256 × 5×5 conv, padding 2
  - 384 × 3×3 conv, padding 1
  - 384 × 3×3 conv, padding 1
  - 256 × 3×3 conv, padding 1
- Classifier: flatten → 4096 → 4096 → 1000 with ReLU and Dropout.
- Optional Kaiming/Xavier weight initialisation via --init_weights.
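For orientation, a minimal sketch of that architecture in PyTorch. It is not a copy of model.py: the class and argument names, the adaptive 6×6 pooling before the classifier, and the Kaiming init helper are assumptions.
import torch
import torch.nn as nn

class AlexNet(nn.Module):
    """5 conv + 3 fc layers, as listed above (illustrative sketch)."""
    def __init__(self, num_classes: int = 1000, dropout: float = 0.5, init_weights: bool = False):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 96, kernel_size=11, stride=4), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(96, 256, kernel_size=5, padding=2), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(256, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(384, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))   # keeps the flatten size at 256*6*6
        self.classifier = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(256 * 6 * 6, 4096), nn.ReLU(inplace=True),
            nn.Dropout(dropout), nn.Linear(4096, 4096), nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        if init_weights:                              # mirrors the --init_weights flag
            for m in self.modules():
                if isinstance(m, (nn.Conv2d, nn.Linear)):
                    nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                    if m.bias is not None:
                        nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.avgpool(self.features(x))   # (N, 256, 6, 6)
        x = torch.flatten(x, 1)              # (N, 9216)
        return self.classifier(x)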
load_data.py
- Training augmentations: resize shorter side to 256 px → random 224-px crop → horizontal flip.
- Validation augmentations: resize shorter side to 256 px → TenCrop(224) (5 crops plus their horizontal mirrors) → normalisation.
- Returns two PyTorch DataLoaders (train and validation); a sketch of the pipeline follows below.
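A sketch of how these pipelines are typically built with torchvision. The helper name build_loaders and the normalisation statistics are assumptions, not necessarily what load_data.py does.
import torch
from torchvision import datasets, transforms

# Standard ImageNet channel statistics (assumed; load_data.py may use different values).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

train_tf = transforms.Compose([
    transforms.Resize(256),              # shorter side -> 256 px
    transforms.RandomCrop(224),          # random 224-px crop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# TenCrop returns a tuple of 10 PIL images (4 corners + centre, plus mirrors),
# stacked here into a single (10, 3, 224, 224) tensor per sample.
val_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),
    transforms.Lambda(lambda crops: torch.stack(
        [normalize(transforms.ToTensor()(c)) for c in crops])),
])

def build_loaders(root, batch_size=128, workers=8):
    train_ds = datasets.ImageFolder(f"{root}/train", transform=train_tf)
    val_ds = datasets.ImageFolder(f"{root}/val", transform=val_tf)
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=workers, pin_memory=True)
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=workers, pin_memory=True)
    return train_loader, val_loader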
train.py
- Implements the epoch/iteration loop, backward pass, accuracy calculation and checkpointing.
- Supports learning-rate warm-up for the first N epochs (--warmup_epochs); see the scheduler sketch below.
- Choose between step decay or cosine annealing via --scheduler.
- Logs Top-1 accuracy & loss to models/top1_accuracy.txt and saves a checkpoint every 10 epochs.
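A minimal sketch of how warm-up plus step/cosine decay can be combined with torch.optim.lr_scheduler. The use of SequentialLR and the specific default hyperparameters are assumptions; train.py may wire the warm-up differently.
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR, StepLR

def build_scheduler(optimizer, scheduler="cosine", epochs=100,
                    warmup_epochs=5, lr_step_size=30, lr_gamma=0.1):
    # Linear warm-up: scale the LR from 1/warmup_epochs up to its nominal value.
    warmup = LambdaLR(optimizer, lr_lambda=lambda e: (e + 1) / warmup_epochs)
    if scheduler == "cosine":
        decay = CosineAnnealingLR(optimizer, T_max=epochs - warmup_epochs)
    else:  # "step"
        decay = StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
    # Hand over from warm-up to the decay schedule after warmup_epochs.
    return SequentialLR(optimizer, schedulers=[warmup, decay], milestones=[warmup_epochs])

# Example: attach to an SGD optimizer and call sched.step() once per epoch.
params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for model.parameters()
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=5e-4)
sched = build_scheduler(optimizer, scheduler="cosine", epochs=100, warmup_epochs=5)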
3. Dataset
The code expects the ImageNet directory in the original layout:
ILSVRC2012
├── train
│   └── n01440764
│       ├── n01440764_10026.JPEG
│       └── ...
└── val
    └── n01440764
        ├── ILSVRC2012_val_00000293.JPEG
        └── ...
Pass the root directory with --root /path/to/ILSVRC2012.
💡 ImageNet licence: obtaining the dataset requires registration with the ImageNet website.
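A quick, optional sanity check of the layout (not part of the repository); it only counts the class folders, which should be 1000 per split for ImageNet-1K:
import os

root = "/path/to/ILSVRC2012"   # replace with your --root path
for split in ("train", "val"):
    split_dir = os.path.join(root, split)
    classes = [d for d in os.listdir(split_dir)
               if os.path.isdir(os.path.join(split_dir, d))]
    print(f"{split}: {len(classes)} class folders")   # expect 1000 for ImageNet-1K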
4. Installation
# (Optional) create a virtual environment
python -m venv .venv && source .venv/bin/activate
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# or the CUDA wheels if you have a GPU
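To confirm the install picked up the backend you expect, a quick check (illustrative only; the MPS query assumes a reasonably recent PyTorch build):
import torch
print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("MPS available:", torch.backends.mps.is_available())   # Apple Silicon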
5. Training
Run:
python train.py \
--root /datasets/ILSVRC2012 \
--device cuda:0 # or cpu / mps
Common flags:
- --epochs (default 100)
- --batch_size (default 128)
- --lr, --momentum, --weight_decay
- --scheduler step|cosine, plus --lr_step_size and --lr_gamma
- --warmup_epochs: linear warm-up length
- --save_dir: directory for checkpoints & logs
Resuming / fine-tuning
To resume from a checkpoint:
python train.py --root /datasets/ILSVRC2012 --device cuda \
--init_weights False \
--save_dir models \
--epochs 30
# then inside train.py adapt: model.load_state_dict(torch.load('models/model_XX.pth'))
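If you adapt train.py as suggested, a hedged sketch of the loading step. It assumes the checkpoint stores a plain state_dict, as the comment above implies, and that model.py exposes a class named AlexNet; if the checkpoint also holds optimizer state, load the relevant keys instead.
import torch
from model import AlexNet   # assumes model.py exposes a class named AlexNet

model = AlexNet()
state = torch.load("models/model_XX.pth", map_location="cpu")   # keep XX as the epoch placeholder
model.load_state_dict(state)
model.train()   # put the network back into training mode before resuming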
6. Metrics
The script prints Top-1 Accuracy after every epoch. You can extend it to Top-5 with:
maxk = 5
_, pred = logits.topk(maxk, 1, True, True)             # (batch, 5) indices of the top-5 classes
correct = pred.eq(labels.view(-1, 1).expand_as(pred))  # (batch, 5) boolean hits
correct_top5 += correct.any(1).float().sum().item()    # samples whose label is in the top 5
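Note that with the TenCrop validation pipeline, each sample arrives as 10 crops. One common pattern (an assumption about how you might wire it, not a description of train.py) is to average the logits over the crops before computing accuracy; the model and batch below are placeholders just to make the shapes concrete.
import torch

# Placeholder model and batch, for illustration only.
model = torch.nn.Sequential(torch.nn.AdaptiveAvgPool2d(1),
                            torch.nn.Flatten(), torch.nn.Linear(3, 1000))
images = torch.randn(4, 10, 3, 224, 224)           # (batch, 10 crops, C, H, W) from TenCrop

bs, ncrops, c, h, w = images.shape
logits = model(images.view(-1, c, h, w))           # (batch * 10, 1000)
logits = logits.view(bs, ncrops, -1).mean(dim=1)   # (batch, 1000) crop-averaged scores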
7. Citation
If you use this code in your research, please cite:
Krizhevsky, Alex, Ilya Sutskever, and Geoffrey E. Hinton. "ImageNet Classification with Deep Convolutional Neural Networks." Advances in Neural Information Processing Systems 25 (NeurIPS 2012).
8. License
This project is released under the MIT License.