🫁 Lung Cancer Classification with Grad-CAM
This project provides a complete pipeline for training a deep learning model to detect lung cancer in histopathological images and serving it through a Flask API. It uses Grad-CAM (Gradient-weighted Class Activation Mapping) to highlight the image regions that most influenced each prediction.
Features
Model: ResNet-18 (Transfer Learning).
Explainability: Integrated Grad-CAM to generate visual heatmaps.
API: Flask-based REST API for real-time analysis.
Dataset Support: Automatic download of Kaggle's Lung and Colon Cancer Histopathological Images dataset.
Getting Started
- Installation
Ensure you have Python 3.8+ installed, then install the required dependencies:
```bash
pip install torch torchvision flask flask-cors pillow numpy opencv-python-headless tqdm kaggle
```
- Prepare the Dataset
You can either place your data in ./lung_cancer_data/ or let the training script download it for you (this requires Kaggle API credentials, e.g. a ~/.kaggle/kaggle.json file).
Data structure:
```plaintext
lung_cancer_data/
└── lung_colon_image_set/
    └── lung_image_sets/
        ├── lung_aca/   (Cancerous)
        ├── lung_scc/   (Cancerous)
        └── lung_n/     (Normal)
```
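The internals of train.py are not shown here, so the following is only a sketch of how the layout above could be prepared and labelled. It assumes the Kaggle slug andrewmvd/lung-and-colon-cancer-histopathological-images (verify it against the one train.py uses) and a binary labelling in which lung_aca and lung_scc both count as cancer (1) and lung_n as normal (0). The helpers lung_dir, ensure_dataset and collect_samples are hypothetical names.

```python
from __future__ import annotations

from pathlib import Path

# Assumption: both carcinoma folders map to "cancer" (1) and lung_n to "normal" (0).
CLASS_TO_LABEL = {"lung_aca": 1, "lung_scc": 1, "lung_n": 0}
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}
# Assumption: Kaggle slug of the dataset; verify against the one train.py uses.
KAGGLE_DATASET = "andrewmvd/lung-and-colon-cancer-histopathological-images"


def lung_dir(root: str | Path) -> Path:
    """Path to the folder that holds lung_aca/, lung_scc/ and lung_n/."""
    return Path(root) / "lung_colon_image_set" / "lung_image_sets"


def ensure_dataset(root: str | Path = "./lung_cancer_data") -> Path:
    """Return the lung image folder, downloading from Kaggle only if it is missing."""
    target = lung_dir(root)
    if target.is_dir():
        return target  # data already in place, no credentials needed
    # Imported lazily so the kaggle package is only required when a download happens.
    from kaggle.api.kaggle_api_extended import KaggleApi

    api = KaggleApi()
    api.authenticate()  # reads ~/.kaggle/kaggle.json or KAGGLE_USERNAME / KAGGLE_KEY
    api.dataset_download_files(KAGGLE_DATASET, path=str(root), unzip=True)
    return target


def collect_samples(lung_root: str | Path) -> list[tuple[Path, int]]:
    """List (image_path, label) pairs, using the folder name as the class."""
    samples = []
    for class_name, label in sorted(CLASS_TO_LABEL.items()):
        for path in sorted((Path(lung_root) / class_name).glob("*")):
            if path.suffix.lower() in IMAGE_SUFFIXES:
                samples.append((path, label))
    return samples
```

Under a binary mapping like this, the cancer class draws on two folders and the normal class on one, so if the folders are of similar size the classes end up imbalanced roughly 2:1; class weighting or a stratified train/validation split may be worth considering.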
- Training the Model
Run the training script to train the ResNet-18 model. The script saves the best-performing model as lung_cancer_model.pth.
```bash
python train.py
```
- Running the API
Once the model is trained, start the Flask server:
```bash
python app.py
```
The server starts on http://localhost:7860.
API Reference
Analyze Image
Analyzes a lung histopathology image and returns the probability of cancer (as a percentage from 0 to 100) along with a Grad-CAM heatmap.
URL: /api/analyze
Method: POST
Data Params: image=[file] (multipart/form-data)
Success Response:
```json
{
  "success": true,
  "probability": 89.45,
  "risk_level": "high",
  "heatmap": "data:image/jpeg;base64,/9j/4AAQ..."
}
```
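As an illustration of calling the endpoint, here is a minimal client sketch that uses only the Python standard library. It assumes the server runs on the default http://localhost:7860 and that the heatmap field is a base64 data URL as in the example response; the function names build_multipart, decode_heatmap and analyze are hypothetical.

```python
from __future__ import annotations

import base64
import json
import mimetypes
import urllib.request
import uuid
from pathlib import Path

API_URL = "http://localhost:7860/api/analyze"  # assumption: default host and port


def build_multipart(field: str, filename: str, data: bytes) -> tuple[bytes, str]:
    """Encode one file field as multipart/form-data; returns (body, Content-Type)."""
    boundary = uuid.uuid4().hex
    mime = mimetypes.guess_type(filename)[0] or "application/octet-stream"
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {mime}\r\n\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def decode_heatmap(data_url: str) -> bytes:
    """Strip the 'data:image/jpeg;base64,' prefix and return the raw image bytes."""
    header, _, payload = data_url.partition(",")
    if not header.startswith("data:") or not header.endswith(";base64"):
        raise ValueError("expected a base64 data URL")
    return base64.b64decode(payload)


def analyze(image_path: str, url: str = API_URL) -> dict:
    """POST an image to /api/analyze (field name 'image') and return the parsed JSON."""
    path = Path(image_path)
    body, content_type = build_multipart("image", path.name, path.read_bytes())
    request = urllib.request.Request(
        url, data=body, headers={"Content-Type": content_type}, method="POST"
    )
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode("utf-8"))
```

For example, result = analyze("sample.jpeg") returns the response as a dict, and Path("heatmap.jpg").write_bytes(decode_heatmap(result["heatmap"])) saves the overlay when result["success"] is true. A library such as requests would shorten the multipart handling; the standard library is used here only to keep the sketch dependency-free.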
🔬 How Grad-CAM Works
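Grad-CAM uses the gradients of the target class flowing into a convolutional layer to produce a coarse localization map. This implementation uses layer3, the second-to-last residual stage of ResNet-18; the original Grad-CAM formulation uses the final convolutional layer (layer4). For a 224×224 input, layer3 yields a 14×14 map instead of layer4's 7×7, trading some semantic specificity for finer spatial detail.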
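For reference, the standard Grad-CAM definition (Selvaraju et al.) can be written as follows, where $A^k$ is the $k$-th feature map of the chosen layer (layer3 here), $y^c$ is the model's output for class $c$ (the paper uses the pre-softmax score), and $Z = H \cdot W$ is the number of spatial positions in each map:

$$
\alpha_k^{c} = \frac{1}{Z} \sum_{i} \sum_{j} \frac{\partial y^{c}}{\partial A^{k}_{ij}}
$$

$$
L^{c}_{\text{Grad-CAM}} = \mathrm{ReLU}\Big( \sum_{k} \alpha_k^{c} A^{k} \Big)
$$

The weights $\alpha_k^{c}$ measure how much each channel pushes the class output up, and the ReLU keeps only regions with positive influence on class $c$.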
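Forward Pass: Calculate the probability of the "Cancer" class while recording the layer3 feature maps.
Backward Pass: Calculate the gradients of that output with respect to the layer3 feature maps.
Weighting: Average each channel's gradients over its spatial positions to get one importance weight per feature map, then take the weighted sum of the feature maps and apply a ReLU so that only regions that raise the "Cancer" output remain.
Heatmap: Upsample the resulting map to the input size, normalize it to [0, 1], and overlay it on the original image using a JET color map.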
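The weighting and heatmap steps can be sketched in NumPy. This assumes the layer3 activations and their gradients for a single image have already been captured (in PyTorch this is commonly done with hooks such as register_forward_hook and register_full_backward_hook on model.layer3) and converted to arrays of shape (K, H, W). The function name grad_cam is hypothetical, and nearest-neighbour upsampling stands in for the cv2.resize call a real pipeline would more likely use.

```python
from __future__ import annotations

import numpy as np


def grad_cam(activations: np.ndarray, gradients: np.ndarray,
             out_size: tuple[int, int]) -> np.ndarray:
    """Combine (K, H, W) activations and gradients into a [0, 1] heatmap of shape out_size."""
    if activations.shape != gradients.shape or activations.ndim != 3:
        raise ValueError("expected matching (K, H, W) arrays")
    # Weighting: global-average-pool the gradients -> one importance weight per channel.
    weights = gradients.mean(axis=(1, 2))  # shape (K,)
    # Weighted sum over channels, then ReLU to keep only positive evidence.
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)  # shape (H, W)
    # Normalize to [0, 1]; an all-zero map is left as zeros instead of dividing by 0.
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    # Nearest-neighbour upsampling to the requested (rows, cols) size.
    rows = (np.arange(out_size[0]) * cam.shape[0]) // out_size[0]
    cols = (np.arange(out_size[1]) * cam.shape[1]) // out_size[1]
    return cam[np.ix_(rows, cols)]
```

The JET overlay would then typically be done with OpenCV, for example heat = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET) followed by cv2.addWeighted(image_bgr, 0.6, heat, 0.4, 0), where image_bgr is a uint8 image of the same size; the 0.6/0.4 blend is an arbitrary illustrative choice.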
📦 Project Structure
app.py: Flask API with Grad-CAM logic and image processing.
train.py: Training script with data augmentation and validation.
requirements.txt: List of dependencies.
models/: Directory where the trained .pth files are stored.
This tool is for educational and research purposes only. It is not a substitute for professional medical advice, diagnosis, or treatment. Always seek the advice of a qualified health provider with any questions regarding a medical condition.