Federated Learning for Hospital Readmission Prediction with Flower and PyTorch
Hospital 30-day readmission rates are tightly linked to patient outcomes, quality metrics, and reimbursement.
If you can predict which patients are at high risk of coming back, you can trigger targeted follow-up and prevent avoidable harm.
The problem is that the richest signals live in many hospitals’ electronic health records (EHRs).
Pooling those data into a single warehouse is often blocked by privacy laws, data governance, and institutional politics.
Federated learning (FL) offers a different route: you move the model to the data, train locally at each hospital, and only send model updates back to a central server.
Frameworks like Flower make it practical to experiment with FL using familiar tools such as PyTorch. In this article, we’ll build a realistic simulation of multiple “hospitals” that collaboratively train a readmission model without sharing raw data.
Background and Prerequisites
What You Should Already Know
You’ll be comfortable here if you know basic Python and ML. You should understand training vs test sets, loss functions, and have done some supervised learning on tabular data.
Some light math helps: vectors, probability, and logistic regression intuition. You don’t need advanced deep learning experience, but being able to read simple PyTorch code is important.
Basic healthcare knowledge is useful but not mandatory. Knowing what an admission, discharge, or comorbidity is will make the examples easier to follow.
Hospital Readmission and EHR Data
A 30-day unplanned readmission is a new admission soon after discharge that wasn’t planned. Health systems track this as a quality signal and often face financial penalties for high rates.
Each hospital stay can be represented as a row in a dataset. Typical columns include demographics, diagnoses, length of stay, prior utilisation, and discharge disposition.
The label is simple but clinically meaningful:
readmitted_30d = 1 if the patient returned unexpectedly within 30 days, and 0 otherwise.
This gives us a binary classification problem. The challenge is that positives are relatively rare and errors are not symmetric in their clinical impact.
Technical Background: PyTorch, Flower, and FL
PyTorch is a flexible deep learning framework that works well beyond images and text. For this problem, we’ll use it to build a small neural network over tabular EHR features.
Federated learning changes the training setup rather than the model. Instead of training on one big dataset, we train on separate local datasets and periodically aggregate model updates.
Flower is a federated learning framework that orchestrates this process. It provides client and server abstractions, a simulation engine, and integrates smoothly with PyTorch.
Core Intuition and Theory
Readmission Prediction as a Binary Classifier
Let each encounter be a feature vector $x_i \in \mathbb{R}^d$.
The label is $y_i \in \{0,1\}$, where $1$ means “readmitted within 30 days”.
We build a model $f_\theta(x_i)$ that outputs a logit (a real number). We map it to a probability with the sigmoid function:
$$ \hat{p}_i = \sigma\!\left(f_\theta(x_i)\right) = \frac{1}{1 + e^{-f_\theta(x_i)}} $$
We interpret $\hat{p}_i$ as the estimated readmission risk. High values should correspond to patients who are likely to be readmitted.
The standard loss is binary cross-entropy:
$$ L(\theta) = -\frac{1}{N}\sum_{i=1}^{N}\left[y_i\log(\hat{p}_i) + (1-y_i)\log(1-\hat{p}_i)\right] $$
This penalises the model when predicted probabilities disagree with the true outcomes. Because readmissions are rare, we often use a class-weighted variant to up-weight missed positives.
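For instance, a minimal PyTorch sketch of that weighted variant (the label counts below are made up; in practice you would compute pos_weight from your own data):

# Illustrative class weighting with BCEWithLogitsLoss:
# pos_weight up-weights the rare positive (readmitted) class.
import torch
import torch.nn as nn

n_neg, n_pos = 9000, 1000                    # hypothetical label counts
pos_weight = torch.tensor([n_neg / n_pos])   # = 9.0 here
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = torch.tensor([0.2, -1.3])           # raw model outputs
labels = torch.tensor([1.0, 0.0])
loss = criterion(logits, labels)             # missed positives now cost ~9x more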
Centralised vs Federated Objective
In a centralised setup, we minimise loss over all encounters. We treat the dataset as if it came from one huge hospital.
In federated learning, data is partitioned by hospital. Hospital $k$ has dataset $D_k$ with $n_k$ samples, and there are $K$ hospitals.
The local objective at hospital $k$ is:
$$ F_k(\theta) = \frac{1}{n_k}\sum_{i \in D_k} \ell\!\left(f_\theta(x_i), y_i\right) $$
The global objective is a weighted sum of local objectives:
$$ F(\theta) = \sum_{k=1}^{K}\frac{n_k}{N}F_k(\theta), \quad N=\sum_{k=1}^{K}n_k $$
We still aim to minimise a single function. We just compute its pieces at each hospital without pooling raw data.
Federated Averaging (FedAvg)
FedAvg runs in discrete rounds coordinated by a server. Each round has four main steps:
- The server sends the current global parameters $\theta_t$ to a subset of hospitals.
- Each selected hospital initialises its local model with $\theta_t$.
- Locally, each hospital runs several epochs of SGD on its own data. This yields updated parameters $\theta_{t+1}^{(k)}$ for hospital $k$.
- The server aggregates the updates with a weighted average:
$$ \theta_{t+1} = \sum_{k\in S_t}\frac{n_k}{\sum_{j\in S_t}n_j}\,\theta_{t+1}^{(k)} $$
Hospitals with more data contribute more to the new global model. Over many rounds, this approximates centralised training under reasonable assumptions.
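As a minimal NumPy sketch of the aggregation step (the client parameters and sample counts are invented for illustration):

import numpy as np

# Hypothetical updated parameters from three hospitals (a single layer each, for brevity)
client_params = [np.array([0.1, 0.2]), np.array([0.3, 0.0]), np.array([0.2, 0.4])]
client_sizes = [5000, 1500, 3500]            # n_k: local training set sizes

total = sum(client_sizes)
weights = [n / total for n in client_sizes]

# Weighted average: hospitals with more data contribute more to theta_{t+1}
global_params = sum(w * p for w, p in zip(weights, client_params))
print(global_params)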
Clinical Interpretation of the Algorithm
Each hospital’s local optimisation reflects its own case mix and practice patterns. For example, a tertiary centre may see more complex cases than a community hospital.
Aggregation blends these perspectives into a global model. Small hospitals benefit from patterns learned at larger peers while keeping their own data private.
Hyperparameters like local epochs and the number of rounds matter clinically. Too much local training lets each hospital drift toward its own local optimum, while too few rounds might underfit rare patterns.
Metrics That Matter in Healthcare
For readmissions, plain accuracy is not enough. The positive class is imbalanced, and errors have different costs.
Common metrics include:
- ROC AUC for ranking quality across all thresholds.
- Average precision / AUPRC for performance on the minority class.
You also care a lot about calibration. If the model says a group has a 20% risk, about 1 in 5 should actually be readmitted.
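A quick sketch of computing these with scikit-learn, using toy labels and predicted risks as stand-ins for your own validation outputs:

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.calibration import calibration_curve

y_true = np.array([0, 0, 1, 0, 1, 0, 0, 1])                     # observed readmissions (toy)
y_prob = np.array([0.1, 0.2, 0.7, 0.3, 0.6, 0.05, 0.4, 0.8])    # predicted risks (toy)

print("ROC AUC:", roc_auc_score(y_true, y_prob))
print("AUPRC:", average_precision_score(y_true, y_prob))

# Calibration: within each probability bin, the observed rate should match the predicted risk
obs_rate, pred_mean = calibration_curve(y_true, y_prob, n_bins=4)
print(list(zip(pred_mean, obs_rate)))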
Implementation: Federated Readmission Prediction with Flower and PyTorch
We’ll now build a simulation on a single machine. It will mimic five hospitals training a shared model using Flower and PyTorch.
High-level steps:
- Load and preprocess per-hospital tabular data.
- Define a small PyTorch model for risk prediction.
- Wrap the model in a Flower client that performs local training.
- Configure a Flower server running FedAvg and start the simulation.
Environment and Project Layout
Create and activate a virtual environment:
python -m venv .venv
source .venv/bin/activate # Windows: .venv\Scripts\activate
Install dependencies:
pip install "flwr" "torch" "pandas" "scikit-learn"
Assume a folder structure like:
project/
  hospital_fl.py
  data/
    hospital_0.csv
    hospital_1.csv
    hospital_2.csv
    hospital_3.csv
    hospital_4.csv
Each CSV contains the same schema but different rows. That lets us simulate hospitals with different populations.
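If you don’t have real extracts to hand, one option is to generate synthetic CSVs with the expected schema. The sketch below (the file name make_synthetic_data.py and all distributions are purely illustrative, not real EHR statistics) produces five toy hospital files:

# make_synthetic_data.py -- generate toy per-hospital CSVs (purely synthetic values)
import numpy as np
import pandas as pd
from pathlib import Path

rng = np.random.default_rng(0)
Path("data").mkdir(exist_ok=True)

for hospital_id in range(5):
    n = rng.integers(2000, 5000)
    df = pd.DataFrame({
        "age": rng.integers(18, 95, n),
        "los_days": rng.exponential(4, n).round(1),
        "num_prior_adm": rng.poisson(1.0, n),
        "num_chronic_conditions": rng.poisson(2.0, n),
        "ed_visits_12m": rng.poisson(0.8, n),
        "icu_stay": rng.integers(0, 2, n),
        "discharge_to_home": rng.integers(0, 2, n),
    })
    # Crude risk signal so the label is learnable, plus noise
    risk = 0.05 + 0.03 * df["num_prior_adm"] + 0.02 * df["num_chronic_conditions"]
    df["readmitted_30d"] = (rng.random(n) < risk.clip(0, 0.6)).astype(int)
    df.to_csv(f"data/hospital_{hospital_id}.csv", index=False)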
Step 1: Loading and Preprocessing Hospital Data
We start by defining feature and label columns. For simplicity, use numeric features that are easy to interpret.
# hospital_fl.py
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import flwr as fl
HOSPITAL_IDS = [0, 1, 2, 3, 4]
DATA_DIR = "data"
FEATURE_COLUMNS = [
"age",
"los_days", # length of stay in days
"num_prior_adm",
"num_chronic_conditions",
"ed_visits_12m",
"icu_stay", # 0/1
"discharge_to_home", # 0/1
]
TARGET_COLUMN = "readmitted_30d" # 0/1
Define a simple PyTorch dataset for tabular data. It will hold tensors for features and labels.
class ReadmissionDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
Now we write a helper to load one hospital’s CSV. It returns train and validation dataloaders and a fitted scaler.
def load_hospital_dataloaders(
    hospital_id: int, batch_size: int = 64, test_size: float = 0.2
):
    path = f"{DATA_DIR}/hospital_{hospital_id}.csv"
    df = pd.read_csv(path)

    # Basic cleaning: drop rows with missing key fields
    df = df.dropna(subset=FEATURE_COLUMNS + [TARGET_COLUMN])

    X = df[FEATURE_COLUMNS].values
    y = df[TARGET_COLUMN].values

    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=42
    )

    # Scale features per hospital
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_val_scaled = scaler.transform(X_val)

    train_ds = ReadmissionDataset(X_train_scaled, y_train)
    val_ds = ReadmissionDataset(X_val_scaled, y_val)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, scaler
Each hospital learns its own feature scaling. That avoids sharing global statistics and matches many real deployments.
Step 2: The PyTorch Readmission Model
We’ll build a small multilayer perceptron. It’s simple, reasonably expressive, and runs fine on CPUs.
class ReadmissionNet(nn.Module):
    def __init__(self, input_dim: int):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 32),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(32, 16),
            nn.ReLU(),
            nn.Linear(16, 1),  # logits
        )

    def forward(self, x):
        return self.model(x).squeeze(1)  # shape: [batch]
Define training and evaluation helpers. We’ll use BCEWithLogitsLoss, which expects logits and handles the sigmoid internally.
def train_one_epoch(model, loader, optimizer, device):
    model.train()
    criterion = nn.BCEWithLogitsLoss()
    total_loss = 0.0
    num_examples = 0

    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        optimizer.zero_grad()
        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()

        batch_size = y_batch.size(0)
        total_loss += loss.item() * batch_size
        num_examples += batch_size

    return total_loss / num_examples
Evaluation returns both loss and raw predictions. We’ll convert them into clinical metrics later if desired.
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    criterion = nn.BCEWithLogitsLoss()
    total_loss = 0.0
    num_examples = 0
    all_probs = []
    all_labels = []

    for X_batch, y_batch in loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        logits = model(X_batch)
        loss = criterion(logits, y_batch)
        probs = torch.sigmoid(logits)

        batch_size = y_batch.size(0)
        total_loss += loss.item() * batch_size
        num_examples += batch_size

        all_probs.append(probs.cpu().numpy())
        all_labels.append(y_batch.cpu().numpy())

    avg_loss = total_loss / num_examples
    y_true = np.concatenate(all_labels)
    y_prob = np.concatenate(all_probs)
    return avg_loss, y_true, y_prob
Step 3: Converting Between PyTorch and Flower Parameters
Flower expects parameters as a list of NumPy arrays. We need helper functions to go back and forth.
def get_model_parameters(model: nn.Module):
    return [val.cpu().numpy() for _, val in model.state_dict().items()]


def set_model_parameters(model: nn.Module, parameters):
    state_dict = model.state_dict()
    for (key, _), param in zip(state_dict.items(), parameters):
        state_dict[key] = torch.tensor(param)
    model.load_state_dict(state_dict, strict=True)
These functions let Flower read and write model weights. They keep the mapping between state dict keys and weight arrays consistent.
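An optional sanity check, for example in a quick interactive session, is to round-trip the weights through both helpers and confirm nothing changes:

# Optional sanity check: set-then-get should reproduce the same arrays
model = ReadmissionNet(input_dim=len(FEATURE_COLUMNS))
params = get_model_parameters(model)
set_model_parameters(model, params)
params_after = get_model_parameters(model)
assert all(np.allclose(a, b) for a, b in zip(params, params_after))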
Step 4: Implementing the Flower Hospital Client
We implement a NumPyClient that wraps a local model and data. It defines how a hospital participates in training and evaluation.
class HospitalClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, val_loader, device):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device

    def get_parameters(self, config):
        # Called before the first round or when the server needs parameters
        return get_model_parameters(self.model)
The fit method runs local training. It returns updated parameters and some metadata.
    def fit(self, parameters, config):
        # Receive the global model from the server
        set_model_parameters(self.model, parameters)
        self.model.to(self.device)

        lr = config.get("lr", 1e-3)
        local_epochs = config.get("local_epochs", 1)
        optimizer = optim.Adam(self.model.parameters(), lr=lr)

        for _ in range(local_epochs):
            train_loss = train_one_epoch(
                self.model, self.train_loader, optimizer, self.device
            )

        num_examples = len(self.train_loader.dataset)
        metrics = {"train_loss": float(train_loss)}

        # Send updated weights back to the server
        return get_model_parameters(self.model), num_examples, metrics
The evaluate method computes validation loss and AUC. This helps the server track progress over rounds.
    def evaluate(self, parameters, config):
        # Receive the global parameters for evaluation
        set_model_parameters(self.model, parameters)
        self.model.to(self.device)

        val_loss, y_true, y_prob = evaluate(
            self.model, self.val_loader, self.device
        )

        try:
            from sklearn.metrics import roc_auc_score
            auc = roc_auc_score(y_true, y_prob)
        except Exception:
            auc = float("nan")

        num_examples = len(self.val_loader.dataset)
        metrics = {"val_loss": float(val_loss), "val_auc": float(auc)}

        # Flower expects (loss, num_examples, metrics)
        return float(val_loss), num_examples, metrics
Step 5: Wiring Up the Flower Simulation
We now define a factory that creates a HospitalClient given a client ID. Flower will call this for each simulated hospital.
def client_fn(cid: str) -> fl.client.Client:
    hospital_id = int(cid)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, val_loader, _ = load_hospital_dataloaders(hospital_id)
    model = ReadmissionNet(input_dim=len(FEATURE_COLUMNS))

    return HospitalClient(model, train_loader, val_loader, device)
Configure the FedAvg strategy and run the simulation. We specify how many clients to sample per round and how many rounds to run.
def main():
    num_clients = len(HOSPITAL_IDS)

    strategy = fl.server.strategy.FedAvg(
        fraction_fit=0.6,
        fraction_evaluate=0.6,
        min_fit_clients=3,
        min_evaluate_clients=3,
        min_available_clients=num_clients,
        on_fit_config_fn=lambda rnd: {
            "lr": 1e-3,
            "local_epochs": 1,
        },
    )

    fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_clients,
        config=fl.server.ServerConfig(num_rounds=10),
        strategy=strategy,
    )


if __name__ == "__main__":
    main()
Run the script with:
python hospital_fl.py
You’ll see Flower logs showing training and evaluation rounds. Each round includes metrics like validation loss and AUC aggregated across clients.
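One caveat: depending on your Flower version, per-client metrics such as val_auc only appear as aggregated values if you give the strategy a metrics aggregation function. A minimal sketch of a weighted average you could pass to FedAvg via evaluate_metrics_aggregation_fn (it assumes clients return the val_auc and val_loss keys, as ours do):

# Weighted average of per-client evaluation metrics, weighted by example counts
def weighted_average(metrics):
    # metrics: list of (num_examples, metrics_dict) tuples from the clients
    total = sum(num for num, _ in metrics)
    return {
        "val_auc": sum(num * m["val_auc"] for num, m in metrics) / total,
        "val_loss": sum(num * m["val_loss"] for num, m in metrics) / total,
    }

# Then pass it when building the strategy:
# strategy = fl.server.strategy.FedAvg(..., evaluate_metrics_aggregation_fn=weighted_average)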
Interpreting the Simulation Results
At the end of training, the global model reflects data from all hospitals. You can inspect per-hospital validation AUC and see how performance varies by site.
Clinically, you’d look for:
- AUC high enough to be useful (for example, $> 0.7$).
- Reasonable calibration, especially for high-risk ranges.
You’d also compare to a centrally trained baseline when possible. That helps you quantify the cost or benefit of federated training.
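In the simulation you can build that baseline by pooling the CSVs directly, which is exactly what a real collaboration could not do. A rough sketch, reusing the helpers from hospital_fl.py (for example in a notebook that imports them):

# Centralised baseline: pool all hospital data (only possible in simulation)
from sklearn.metrics import roc_auc_score

frames = [pd.read_csv(f"{DATA_DIR}/hospital_{h}.csv") for h in HOSPITAL_IDS]
df_all = pd.concat(frames).dropna(subset=FEATURE_COLUMNS + [TARGET_COLUMN])

X = df_all[FEATURE_COLUMNS].values
y = df_all[TARGET_COLUMN].values
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

scaler = StandardScaler()
train_loader = DataLoader(
    ReadmissionDataset(scaler.fit_transform(X_train), y_train), batch_size=64, shuffle=True
)
val_loader = DataLoader(ReadmissionDataset(scaler.transform(X_val), y_val), batch_size=64)

device = torch.device("cpu")
model = ReadmissionNet(input_dim=len(FEATURE_COLUMNS)).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
for _ in range(10):   # a handful of epochs is enough for a rough comparison
    train_one_epoch(model, train_loader, optimizer, device)

_, y_true, y_prob = evaluate(model, val_loader, device)
print("Centralised baseline AUC:", roc_auc_score(y_true, y_prob))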
From Notebook to Production in Healthcare
The simulation is a learning tool, not a production system. Real deployments involve infrastructure, governance, and workflow integration.
Deployment Architectures
One pattern is a central aggregator with local clients. A coordinating institution or vendor hosts the Flower server in a secure environment.
Each hospital runs a Flower client inside its own network. Clients connect outbound to the server over secure channels and never expose raw EHR data.
Another pattern uses a neutral third-party consortium. Academic or public health bodies host the aggregator while hospitals remain data controllers.
Your architecture must respect legal agreements, risk assessments, and network constraints. Federated learning changes where computation happens, but not the need for strong governance.
Data Pipelines: Training vs Inference
For training, each hospital typically uses batch extracts. Every month or quarter, they generate de-identified tables with relevant features and labels.
The local FL client then reads these tables to run training. That process can be scheduled at low-impact times, like nights or weekends.
For inference, you usually need near-real-time predictions. When a discharge order is placed, the EHR calls a local prediction service that loads the latest model.
The feature pipeline must be consistent between training and inference. Versioning both models and feature transformations is critical.
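One lightweight way to keep them consistent is to persist the fitted scaler and the model weights together under a version tag. A sketch under that assumption (the file names and version label are illustrative, and model and scaler are assumed to come from the training pipeline above):

import joblib
import torch

MODEL_VERSION = "v1"   # illustrative version tag
torch.save(model.state_dict(), f"readmission_model_{MODEL_VERSION}.pt")
joblib.dump(scaler, f"feature_scaler_{MODEL_VERSION}.joblib")

# At inference time, load both and apply them in the same order as during training
model = ReadmissionNet(input_dim=len(FEATURE_COLUMNS))
model.load_state_dict(torch.load(f"readmission_model_{MODEL_VERSION}.pt"))
model.eval()
scaler = joblib.load(f"feature_scaler_{MODEL_VERSION}.joblib")

def predict_risk(raw_features):
    # raw_features: list of values in the same order as FEATURE_COLUMNS
    x = scaler.transform([raw_features])
    with torch.no_grad():
        logit = model(torch.tensor(x, dtype=torch.float32))
    return torch.sigmoid(logit).item()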
MLOps and Observability
At a minimum, you should monitor:
- Training metrics (loss, AUC) over FL rounds.
- Per-hospital performance and participation rates.
You also want drift detection on inputs and outputs. If a hospital changes its coding or patient mix, your model may become miscalibrated.
Logs should capture failures and anomalies in client updates. If one hospital sends obviously corrupted updates, you need to isolate and investigate.
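A very simple starting point for input drift is comparing recent feature means against a training-time reference; the reference values and tolerance below are placeholders:

# Naive input-drift check: compare recent feature means to a training-time reference
import numpy as np

reference_means = {"age": 64.2, "los_days": 4.1, "num_prior_adm": 1.1}   # from training data

def check_drift(recent_df, tolerance=0.15):
    # recent_df: pandas DataFrame of recent encounters with the same feature columns
    alerts = []
    for col, ref in reference_means.items():
        current = recent_df[col].mean()
        if abs(current - ref) / (abs(ref) + 1e-9) > tolerance:
            alerts.append(f"{col}: reference {ref:.2f}, current {current:.2f}")
    return alerts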
Performance and Cost Trade-offs
Tabular MLPs are lightweight and run fine on CPUs. That’s important for hospitals with limited hardware.
Major costs include engineering time, network communication, and governance overhead. You trade off central storage costs and legal complexity for distributed orchestration complexity.
Model size and update frequency influence bandwidth and compute costs. Smaller models and less frequent rounds reduce costs but may slow learning.
Risks, Ethics, Safety, and Governance
Privacy and Security in Federated Settings
Federated learning reduces the need to centralise sensitive data, but it does not eliminate all privacy risks.
Model updates can theoretically leak information about training data. Attackers might exploit gradients or weight changes to infer membership or attributes.
Mitigations include:
- Encrypting communication channels and authenticating clients.
- Using secure aggregation so the server sees only combined updates.
- Adding differential privacy by injecting noise into updates (mathematical guarantees, typically at some accuracy cost).
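As a purely conceptual sketch of the last point, clipping a client update and adding Gaussian noise might look like this; real differential privacy needs calibrated noise and formal privacy accounting, so treat it as an illustration, not a guarantee:

# Conceptual only: clip a client's update and add Gaussian noise before sending it
import numpy as np

def privatise_update(update_arrays, clip_norm=1.0, noise_std=0.01):
    flat = np.concatenate([a.ravel() for a in update_arrays])
    scale = min(1.0, clip_norm / (np.linalg.norm(flat) + 1e-12))   # clip total update norm
    return [a * scale + np.random.normal(0, noise_std, a.shape) for a in update_arrays]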

Bias and Fairness Across Hospitals
Readmission risk is deeply tied to social and structural factors, including socioeconomic status, race, and access to care.
Hospitals serving different populations may have very different distributions. Federated aggregation can overweight larger, wealthier institutions.
You should monitor performance by hospital and by subgroup. Look for systematic under- or over-prediction in vulnerable populations.
Addressing fairness might involve reweighting contributions, adding constraints, or adjusting thresholds. It also requires clinician and ethicist input, not just technical fixes.
Robustness and Failure Modes
Federated systems introduce new failure modes. For example, a hospital’s data pipeline might break and silently produce nonsensical features.
A misconfigured client could send extreme or adversarial updates. Without safeguards, this can destabilise the global model.
Mitigation strategies include:
- Sanity checks on client losses and gradient norms.
- Robust aggregation methods that down-weight outlier updates.
- Rollback mechanisms for global models (restore a known good version).
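The first bullet can start very simply. For example, a server-side check that flags client updates whose overall norm deviates wildly from the median (the factor is arbitrary):

# Flag client updates whose overall norm is far from the median across clients
import numpy as np

def flag_outlier_updates(client_updates, factor=5.0):
    norms = [np.linalg.norm(np.concatenate([a.ravel() for a in arrays]))
             for arrays in client_updates]
    median = np.median(norms)
    return [i for i, n in enumerate(norms) if n > factor * median]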
Human-in-the-Loop Clinical Use
Risk scores should support clinicians, not replace them. They can help prioritise limited post-discharge resources.
You must design UIs and workflows that encourage critical interpretation. That includes showing explanations, ranges, and confidence where possible.
Policies should state clearly how the model should and shouldn’t be used. For example: “This score helps triage follow-up calls but does not decide admission.”
Case Study: Simulated 30-Day Readmission Across Five Hospitals
Scenario Overview
Imagine five regional hospitals in a shared network. They want to build a joint readmission model, but cannot pool data.
Each hospital has a few tens of thousands of encounters. Alone, their models are unstable; together, they could be robust.
They agree on a common schema and label. They deploy the Flower client in each hospital and a central aggregator at the network level.

Data Sources and Challenges
Each site extracts data from its own EHR and data warehouse. They compute features like age, length of stay, prior admissions, and ICU stays.
Differences appear quickly. Urban hospitals see more complex multi-morbid patients, while others see more elective surgeries.
Some sites have a more complete recording of social factors. Others have patchy lab data or inconsistent coding of discharge destinations.
These differences create non-IID data across clients. That makes FL more realistic and also more challenging.
Model Training and Results
They run the federated training script overnight. Ten to twenty rounds of FedAvg complete without incident.
After training, they evaluate the global model per hospital. They see AUCs in the low- to mid-$0.7$ range, with modest variation by site.
Calibration plots show reasonable alignment between predicted and observed risks. Some under-prediction for the very elderly suggests room for feature engineering improvements.
They also compare to local models trained on each site’s data only. Most hospitals see improved performance with the federated model, especially smaller ones.
Translating to Clinical Action
The network defines a risk threshold for flagging patients. This is tuned based on local capacity for follow-up interventions.
At discharge, each hospital’s EHR calls a local prediction service. Patients above the threshold are added to care management queues for closer follow-up.
Clinicians receive both the risk score and a brief explanation. For example, key contributing factors such as frequent prior admissions and long stays.
The team monitors how many patients are flagged weekly. They also track readmission rates over time and across subgroups.
Skills Mapping and Learning Path
Programming and Data Skills
This project exercises core Python skills. You manage virtual environments, run scripts, and structure a small codebase.
On the data side, you manipulate tabular data with pandas. You handle missing values, basic feature engineering, and train/test splits.
You also use scikit-learn for preprocessing, such as standardisation. That reinforces the idea of separating raw input from model-ready features.
ML and Federated Learning Skills
You build and train a small neural network in PyTorch. You work with losses, optimisers, and training loops.
You learn how to evaluate models properly on imbalanced data. ROC AUC, precision-recall, and calibration become part of your vocabulary.
Federated learning introduces new concepts. You understand clients, servers, parameter aggregation, and non-IID challenges.
Flower gives you practical experience with FL tooling. You implement a client, configure a strategy, and run multi-client simulations.
Systems and MLOps Skills
The project hints at distributed system design. You think about where training runs, where models are stored, and how predictions are served.
You start to see the importance of logging and monitoring. Training metrics, client participation, and drift all need observability.
You encounter trade-offs among compute, bandwidth, and latency. That sets you up for more advanced work in MLOps and production ML.

Healthcare Domain Skills
You learn how clinical concepts become features. Length of stay, prior utilisation, and discharge destination move from jargon to data columns.
You see how model metrics tie back to patient outcomes. Better AUC or calibration is only useful if it leads to safer, more efficient care.
You also become more aware of equity and fairness issues, including subgroup performance and the risk of encoding structural biases.
If you want this article to become a portfolio-ready project, pick one extension and ship it end-to-end. Treat each bullet below as a “mini milestone” you can demo in a repo, a short write-up, and (ideally) a small dashboard.
- Compare logistic regression, gradient boosting, and neural nets in the FL setting. Start with a strong baseline (logistic regression), then add XGBoost/LightGBM and a small MLP to quantify trade-offs. If you want a structured path for model selection and evaluation, explore our [Data Science & AI Bootcamp](/en/courses/data-science-and-ai).
- Simulate more extreme non-IID scenarios and try alternative strategies like FedProx. Make one hospital “older + higher comorbidity” and another “short LOS + fewer readmissions,” then measure per-site calibration drift. To deepen your understanding of privacy, threat models, and secure training pipelines, see our Cyber Security Bootcamp.
- Add richer evaluation: calibration plots, decision curves, and cost-sensitive metrics. In healthcare, ranking is not enough - you need calibrated probabilities and thresholds tied to real care capacity.
- Integrate simple explainability, such as feature importance or SHAP, at prediction time. Build a short “why this patient is high-risk” view and validate it with a clinician-style sanity check (no PHI, no overclaims). If you want to learn how to communicate model outputs clearly and design usable dashboards, our UX/UI Design Bootcamp can help.
- Prototype a tiny REST API for inference (closest step toward EHR integration). Wrap your model in a FastAPI service with input validation, logging, and versioned endpoints (v1/v2), then add a smoke-test script. For practical backend and deployment fundamentals, explore our Web Development Bootcamp.
If you want feedback and a clear plan, choose your track (ML, security, backend, or UX) and map the next 4 weeks. Explore all bootcamps and formats here:
Our Courses
Conclusion
Federated learning fits hospital readmission prediction naturally because the data is sensitive and siloed, but the objective is shared. You can train a useful global model while keeping EHR data inside each hospital’s environment.
With FedAvg + Flower + PyTorch, you now have a practical blueprint: local training, secure aggregation, and clinically meaningful evaluation. That foundation is strong enough to evolve into a real-world pilot when paired with governance, monitoring, and careful workflow design.
Production success depends on more than model code.
You’ll need reliable data pipelines, access control, audit-friendly logging, drift monitoring, and clinician review loops.
If you’re ready to build skills toward real healthcare ML systems, pick a program that matches your goal and take the next step.