Securing Your FastAPI ML Application

This tutorial shows how to add authentication to a FastAPI-based machine learning (ML) application. It is a follow-up to an earlier tutorial on building a FastAPI application for image classification inference with a fine-tuned ResNet18 model. In that implementation, the API endpoints were open to anyone who could reach them. This guide secures the endpoints by requiring an API key on every request.


Why Add Authentication to FastAPI Applications?

When exposing machine learning models through APIs, it is crucial to ensure that only authorized users can access your endpoints. Adding authentication not only improves security but also allows you to manage access permissions effectively. Here are some benefits of adding authentication to your FastAPI application:

  • Security: Prevent unauthorized access to your APIs and protect sensitive data.
  • Access Control: Grant access only to users or clients with valid credentials.
  • Resource Management: Avoid misuse or overloading of resources by limiting access to legitimate users.
  • Scalability: Secure APIs make it easier to deploy and scale your application in production environments.

Implementing API Key-Based Authentication

API keys are a simple and effective way to secure your FastAPI endpoints. In this section, we will learn how to add API key authentication to a simple FastAPI application.

Step 1: Define the API Key and Header Field

The first step is to define an API key and specify the header name where the key will be passed. This API key is used to authenticate user requests.

from fastapi import FastAPI, Depends, HTTPException, Security, status
from fastapi.security.api_key import APIKeyHeader
from typing import Optional

app = FastAPI()

# Define the expected API key and header name.
API_KEY = "2rq82hasdflawsk"
API_KEY_NAME = "access_token"

# Create an APIKeyHeader instance; setting auto_error=False allows us to customize the error.
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)

Step 2: Add an Authentication Dependency

Create a dependency function to validate the API key. The function checks whether the key supplied in the request header matches the expected key. If the key is missing or invalid, it raises an HTTPException with a 401 Unauthorized status.

async def get_api_key(api_key: Optional[str] = Security(api_key_header)):
    if api_key == API_KEY:
        return api_key
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Invalid API Key",
        headers={"WWW-Authenticate": "Bearer"},
    )
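The `==` comparison in get_api_key is fine for a demo, but ordinary string equality can stop at the first mismatching character, which in principle leaks timing information. The following is a minimal sketch of a constant-time check using the standard library's secrets.compare_digest. The helper name is_valid_key is hypothetical and not part of the tutorial's code.

```python
import secrets
from typing import Optional

# Assumption: the same demo key defined in Step 1.
API_KEY = "2rq82hasdflawsk"


def is_valid_key(candidate: Optional[str], expected: str = API_KEY) -> bool:
    """Return True only when a key was supplied and it matches the expected key."""
    if candidate is None:
        # With auto_error=False, a missing header reaches the dependency as None.
        return False
    # compare_digest runs in time that does not depend on where the inputs differ.
    # Encoding to bytes avoids the TypeError raised for non-ASCII str arguments.
    return secrets.compare_digest(
        candidate.encode("utf-8"), expected.encode("utf-8")
    )
```

Inside get_api_key, the condition `if api_key == API_KEY:` could then become `if is_valid_key(api_key):` without changing the rest of the function.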

Step 3: Secure Your Endpoints

Apply the authentication dependency to your FastAPI endpoints to protect them. Only requests with a valid API key will be granted access.


@app.get("/secure-data")
async def secure_data(api_key: str = Depends(get_api_key)):
    """Protected endpoint that returns a message when a valid API key is provided."""
    return {"message": "Your Bank is Secure with Bank name: ABL and Bank ID: 123456789"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)

Testing the Basic Authentication

Once the API key-based authentication is implemented, you can test it using tools like curl or Postman. Try the following scenarios:

  1. Access Without an API Key: Confirm that the endpoint denies access with an appropriate error message.
$ curl -X GET http://127.0.0.1:8000/secure-data
>>> {"detail":"Invalid API Key"}
  2. Access With an Invalid API Key: Verify that the endpoint returns an error for incorrect API keys.
$ curl -X GET -H "access_token: sfqwfqfqfaswqwq" http://127.0.0.1:8000/secure-data
>>> {"detail":"Invalid API Key"}
  3. Access With a Valid API Key: Ensure that authorized requests succeed and return the expected response.
$ curl -X GET -H "access_token: 2rq82hasdflawsk" http://127.0.0.1:8000/secure-data
>>> {"message":"Your Bank is Secure with Bank name: ABL and Bank ID: 123456789"}
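The same checks can be scripted from Python instead of curl. The sketch below uses only the standard library and assumes the Step 3 server is running on 127.0.0.1:8000 with the access_token header from Step 1. The function names build_request and call_endpoint are illustrative.

```python
import json
import urllib.error
import urllib.request
from typing import Optional, Tuple

# Assumptions: local server address and header name taken from the steps above.
BASE_URL = "http://127.0.0.1:8000"
API_KEY_NAME = "access_token"


def build_request(path: str, api_key: Optional[str] = None) -> urllib.request.Request:
    """Build a GET request, attaching the API key header only when a key is given."""
    request = urllib.request.Request(BASE_URL + path, method="GET")
    if api_key is not None:
        request.add_header(API_KEY_NAME, api_key)
    return request


def call_endpoint(path: str, api_key: Optional[str] = None) -> Tuple[int, dict]:
    """Send the request and return (status_code, parsed JSON body)."""
    try:
        with urllib.request.urlopen(build_request(path, api_key)) as response:
            return response.status, json.loads(response.read())
    except urllib.error.HTTPError as err:
        # urllib raises on 4xx/5xx responses; the JSON error body is still readable.
        return err.code, json.loads(err.read())
```

With the server running, calling call_endpoint("/secure-data") with no key, with a wrong key, and with the valid key should mirror the three curl scenarios.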

Adding Authentication to the Machine Learning API

We will now apply the same authentication process to the machine learning application. Most of the code comes from the previous tutorial; the addition is authentication for the model information and prediction endpoints. Instead of hardcoding the API key, the application reads it from a .env file, and clients send it in the X-API-Key header.
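If you do not yet have a key to put in the .env file, one option is to generate a random one with Python's secrets module. This is a small sketch rather than a required step. The helper names are hypothetical, and the variable name API_KEY matches what the application code reads with os.getenv.

```python
import secrets


def generate_api_key(num_bytes: int = 32) -> str:
    """Return a random URL-safe key built from num_bytes of OS randomness."""
    return secrets.token_urlsafe(num_bytes)


def env_line(key: str, name: str = "API_KEY") -> str:
    """Format the key as a NAME=value line suitable for a .env file."""
    return f"{name}={key}"


if __name__ == "__main__":
    # Paste the printed line into .env and keep that file out of version control.
    print(env_line(generate_api_key()))
```

Keys produced this way contain only letters, digits, hyphens, and underscores, so they can be pasted into shell commands and .env files without quoting.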

Securing the /model-info Endpoint

The /model-info endpoint provides metadata about the ML model, such as its architecture and class names. By adding API key authentication, you can restrict access to this information.


import io
import os
import logging

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile, Depends, Query
from fastapi.responses import JSONResponse
from fastapi.security.api_key import APIKeyHeader
from dotenv import load_dotenv
from PIL import Image, UnidentifiedImageError
from torchvision import models

# Load environment variables from .env file
if not load_dotenv():
    raise ValueError("Failed to load .env file")

# Get API key from environment variable
API_KEY = os.getenv("API_KEY")
if not API_KEY:
    raise ValueError("API_KEY environment variable not set in .env file")

# Initialize FastAPI app
app = FastAPI(
    title="CIFAR10 Image Classification APP",
    description="A production-ready API for image classification using a fine-tuned model on CIFAR10.",
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define API key security scheme
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

# Define CIFAR10 class names
class_names = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
num_classes = len(class_names)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned ResNet18 model
model_path = "finetuned_model.pth"
if not os.path.exists(model_path):
    raise FileNotFoundError(f"Model file not found at {model_path}")

model = models.resnet18(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()

# Preprocessing transforms
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


# API key validation dependency function
def get_api_key(api_key: str = Depends(api_key_header)):
    if not api_key:
        logger.warning("API key is missing")
        raise HTTPException(status_code=403, detail="API key is missing")
    if api_key != API_KEY:
        logger.warning(f"Invalid API key attempt: {api_key}")
        raise HTTPException(status_code=403, detail="Invalid API key")
    return api_key


@app.get("/health", summary="Health Check", tags=["Status"])
async def health_check():
    """Endpoint for checking if the API is running."""
    return {"status": "ok", "message": "API is running", "device": str(device)}


@app.get("/model-info", summary="Get Model Information", tags=["Metadata"])
async def get_model_info(api_key: str = Depends(get_api_key)):
    """Combined endpoint for retrieving model metadata and class names."""
    model_info = {
        "model_architecture": "ResNet18",
        "num_classes": num_classes,
        "class_names": class_names,
        "device": str(device),
        "model_weights_file": model_path,
        "description": "Model fine-tuned on CIFAR10 dataset",
    }
    return JSONResponse(model_info)
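When logging rejected attempts, it can be safer to record a short fingerprint of the submitted key rather than the key itself, since a near-miss of a real key would otherwise end up in the log files. The following is a minimal sketch using hashlib. The helper names key_fingerprint and log_rejected_key are hypothetical and not part of the application code.

```python
import hashlib
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def key_fingerprint(api_key: str, length: int = 8) -> str:
    """Return the first `length` hex characters of the key's SHA-256 digest."""
    digest = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    return digest[:length]


def log_rejected_key(api_key: str) -> None:
    """Log a rejected attempt without writing the submitted key to the log."""
    logger.warning(
        "Invalid API key attempt (sha256 prefix: %s)", key_fingerprint(api_key)
    )
```

The fingerprint still lets you correlate repeated attempts with the same wrong key while keeping the submitted value out of the logs.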

Securing the /predict Endpoint

The /predict endpoint performs inference using the trained ML model. Securing this endpoint ensures that only authorized users can submit images for prediction.

@app.post("/predict", summary="Predict Image Class", tags=["Inference"])
async def predict(
    file: UploadFile = File(...),
    include_confidence: bool = Query(
        False, description="Include confidence scores for top predictions"
    ),
    top_k: int = Query(
        1, ge=1, le=10, description="Number of top predictions to return"
    ),
    api_key: str = Depends(get_api_key),
):
    """
    Unified prediction endpoint that can return either simple class prediction
    or detailed predictions with confidence scores.
    """
    # Validate file type
    # filename can be None for some uploads, so guard before calling .lower()
    if not file.filename or not file.filename.lower().endswith(
        (".png", ".jpg", ".jpeg")
    ):
        logger.error(f"Invalid file format: {file.filename}")
        raise HTTPException(
            status_code=400,
            detail="Invalid image format. Only PNG and JPEG are supported.",
        )

    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except UnidentifiedImageError:
        logger.error("Uploaded file is not a valid image")
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        )
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise HTTPException(status_code=400, detail="Error processing image.")

    # Preprocess the image
    input_tensor = preprocess(image).unsqueeze(0).to(device)

    try:
        with torch.no_grad():
            outputs = model(input_tensor)

            if include_confidence:
                # Return predictions with confidence scores
                probabilities = torch.nn.functional.softmax(outputs, dim=1)
                top_probs, top_idxs = torch.topk(
                    probabilities, k=min(top_k, num_classes)
                )
                top_probs = top_probs.cpu().numpy().tolist()[0]
                top_idxs = top_idxs.cpu().numpy().tolist()[0]
                predictions = [
                    {"class": class_names[idx], "confidence": prob}
                    for idx, prob in zip(top_idxs, top_probs)
                ]
                return JSONResponse({"predictions": predictions})
            else:
                # Return simple prediction (just the class)
                _, preds = torch.max(outputs, 1)
                predicted_class = class_names[preds[0]]
                return JSONResponse({"predicted_class": predicted_class})
    except Exception as e:
        logger.error(f"Error during model inference: {str(e)}")
        raise HTTPException(status_code=500, detail="Error during model inference.")


if __name__ == "__main__":
    # The import string assumes this file is saved as secure_app.py.
    uvicorn.run("secure_app:app", port=5454)
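To make the include_confidence branch more concrete, here is a dependency-free sketch of softmax followed by top-k selection, the two operations the endpoint performs with torch.nn.functional.softmax and torch.topk. Plain Python lists stand in for tensors, the function names are illustrative, and any logits you pass in are made-up inputs rather than real model outputs.

```python
import math
from typing import Dict, List, Union

# Assumption: same class order as the class_names list in the API code.
CLASS_NAMES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def softmax(logits: List[float]) -> List[float]:
    """Convert raw scores to probabilities; subtracting the max keeps exp() stable."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_predictions(
    logits: List[float], k: int = 1
) -> List[Dict[str, Union[str, float]]]:
    """Return the k most probable classes with their confidence scores."""
    probs = softmax(logits)
    k = min(k, len(probs))  # mirrors min(top_k, num_classes) in the endpoint
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [{"class": CLASS_NAMES[i], "confidence": probs[i]} for i in ranked]
```

Calling top_k_predictions with ten logits and k=3 returns a list shaped like the "predictions" field of the /predict response, ordered from most to least confident.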

Testing the ML API

After adding authentication to the ML API, test the endpoints with different scenarios:

  1. Accessing /model-info Without an API Key: Confirm that the endpoint denies access when no API key is provided.
$ curl -X GET 'http://127.0.0.1:5454/model-info'
>>> {"detail":"API key is missing"}
  2. Accessing /model-info With an Invalid API Key: Verify that the endpoint returns an error for invalid keys.
$ curl -X GET 'http://127.0.0.1:5454/model-info' \
  -H 'accept: application/json' \
  -H 'X-API-Key: asfasfwefasasfasdf'
>>> {"detail":"Invalid API key"}
  3. Accessing /model-info With a Valid API Key: Ensure that authorized requests return the expected model metadata.
$ curl -X GET 'http://127.0.0.1:5454/model-info' \
  -H 'accept: application/json' \
  -H 'X-API-Key: q3hewf#onio12$r032'
>>> {
    "model_architecture": "ResNet18",
    "num_classes": 10,
    "class_names": [
        "airplane",
        "automobile",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck"
    ],
    "device": "cuda:0",
    "model_weights_file": "finetuned_model.pth",
    "description": "Model fine-tuned on CIFAR10 dataset"
}
  4. Submitting a Prediction With an Invalid API Key: Confirm that the /predict endpoint denies access for unauthorized users.
$ curl -X POST 'http://127.0.0.1:5454/predict' \
  -H 'accept: application/json' \
  -H 'X-API-Key: asfasfwefasasfasdf' \
  -F 'file=@./data/sample/cat.png'
>>> {"detail":"Invalid API key"}
  5. Submitting a Prediction With a Valid API Key: Ensure that authorized requests return the correct predictions.
$ curl -X POST 'http://127.0.0.1:5454/predict' \
  -H 'accept: application/json' \
  -H 'X-API-Key: q3hewf#onio12$r032' \
  -F 'file=@./data/sample/cat.png'
>>> {"predicted_class":"cat"}

Conclusion

Adding authentication to your FastAPI application is a crucial step in securing your APIs and managing access permissions. By using API key-based authentication, you can protect sensitive endpoints and ensure that only authorized users can interact with your application.

This tutorial shows how to implement authentication in a FastAPI ML application and integrate it into an existing ML inference API. Whether you are building an ML-powered API or any other application, authentication is an essential feature for production-grade deployments.

The complete source code for both tutorials is available in our kingabzpro/image-classification-fastapi GitHub repository. Feel free to explore the implementation, raise issues, or reach out if you need assistance with implementing authentication in your FastAPI application.

