Почему ваши данные не должны путешествовать в облако
Представьте, что вы врач. У вас есть тысяча снимков МРТ с редкой патологией. Коллега из другой клиники - тоже. Вместе вы могли бы обучить модель, которая спасает жизни. Но ваши данные - под замком. Юристы, GDPR, HIPAA. Пересылать нельзя. Облако - риск. Что делать?
Традиционное машинное обучение ломается об эту стену. Оно требует централизации данных. Собрать все в одну кучу, залить в GPU, ждать. Федеративное обучение (FL) переворачивает эту логику. Модель сама отправляется в гости к данным, учится локально, возвращается с новыми знаниями. Данные никуда не уезжают.
Flower (Flwr) - не единственный фреймворк для FL. Но он самый гибкий. Не привязывает к TensorFlow или PyTorch. Работает с чем угодно. Даже с вашим кастомным кодом на NumPy. Это как швейцарский нож в мире распределенного ML.
Как Flower заставляет устройства учиться вместе
Архитектура Flower проста до гениальности. Есть сервер (стратег) и клиенты (устройства). Сервер рассылает глобальную модель. Каждый клиент тренирует ее на своих данных. Отправляет обратно только обновления весов - градиенты или новые параметры. Сервер агрегирует эти обновления (например, усредняет) и создает улучшенную глобальную модель. Цикл повторяется.
Клиентом может быть что угодно: смартфон, Raspberry Pi, сервер в больнице, датчик на заводе. Главное - Python и связь с сервером. Это открывает двери для сценариев, о которых раньше думали только теоретики. Например, обучение модели прямо на Raspberry Pi без отправки чувствительных показаний.
FL - не панацея. Если данные на клиентах радикально разные (не-IID распределение), модель может сойти с ума. Представьте, что одна больница лечит только детей, а другая - только пожилых. Модель будет метаться между двумя паттернами. Но об этом позже.
Собираем федеративную систему за 30 минут
Хватит теории. Давайте запустим живой пример. Мы создадим систему для классификации изображений (CIFAR-10), распределенную между тремя клиентами. Каждый клиент получит свой кусок данных. Сервер будет агрегировать результаты.
1 Готовим среду и данные
Установите Flower и PyTorch. Мы используем PyTorch для модели, но вы можете заменить на TensorFlow или JAX.
pip install flwr torch torchvision
Теперь разделим CIFAR-10 между клиентами. Создайте файл dataset.py:
import torch
from torchvision.transforms import Compose, ToTensor, Normalize
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, Subset
import numpy as np
def load_datasets(num_clients: int = 3):
transform = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = CIFAR10("./data", train=True, download=True, transform=transform)
testset = CIFAR10("./data", train=False, download=True, transform=transform)
# Делим тренировочные данные между клиентами
partition_size = len(trainset) // num_clients
lengths = [partition_size] * num_clients
datasets = torch.utils.data.random_split(trainset, lengths, torch.Generator().manual_seed(42))
trainloaders = []
valloaders = []
for ds in datasets:
len_val = len(ds) // 10 # 10% для валидации
len_train = len(ds) - len_val
lengths = [len_train, len_val]
ds_train, ds_val = torch.utils.data.random_split(ds, lengths, torch.Generator().manual_seed(42))
trainloaders.append(DataLoader(ds_train, batch_size=32, shuffle=True))
valloaders.append(DataLoader(ds_val, batch_size=32))
testloader = DataLoader(testset, batch_size=32)
return trainloaders, valloaders, testloader
Здесь мы нарезали данные на три части. Каждый клиент получит свой DataLoader. В реальной жизни данные уже распределены, делить ничего не нужно. Но для эксперимента сойдет.
2 Пишем клиента: мозг на edge-устройстве
Клиент в Flower - это класс, который наследует flwr.client.NumPyClient. Он должен уметь делать три вещи: возвращать текущие веса модели, тренироваться на локальных данных и обновлять веса после получения глобальной модели.
import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD
from typing import Dict, Tuple, List
import numpy as np
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def get_parameters(net) -> List[np.ndarray]:
return [val.cpu().numpy() for _, val in net.state_dict().items()]
def set_parameters(net, parameters: List[np.ndarray]):
params_dict = zip(net.state_dict().keys(), parameters)
state_dict = {k: torch.tensor(v) for k, v in params_dict}
net.load_state_dict(state_dict, strict=True)
class FlowerClient(fl.client.NumPyClient):
def __init__(self, trainloader, valloader):
self.net = Net()
self.trainloader = trainloader
self.valloader = valloader
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.net.to(self.device)
def get_parameters(self, config: Dict[str, str]):
return get_parameters(self.net)
def fit(self, parameters: List[np.ndarray], config: Dict[str, str]):
set_parameters(self.net, parameters)
optimizer = SGD(self.net.parameters(), lr=0.01, momentum=0.9)
criterion = nn.CrossEntropyLoss()
self.net.train()
for epoch in range(2): # Локальные эпохи
for images, labels in self.trainloader:
images, labels = images.to(self.device), labels.to(self.device)
optimizer.zero_grad()
outputs = self.net(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
return get_parameters(self.net), len(self.trainloader.dataset), {}
def evaluate(self, parameters: List[np.ndarray], config: Dict[str, str]):
set_parameters(self.net, parameters)
criterion = nn.CrossEntropyLoss()
loss, correct = 0.0, 0
self.net.eval()
with torch.no_grad():
for images, labels in self.valloader:
images, labels = images.to(self.device), labels.to(self.device)
outputs = self.net(images)
loss += criterion(outputs, labels).item()
_, predicted = torch.max(outputs.data, 1)
correct += (predicted == labels).sum().item()
accuracy = correct / len(self.valloader.dataset)
return float(loss), len(self.valloader.dataset), {"accuracy": accuracy}
def client_fn(cid: str) -> FlowerClient:
trainloader, valloader, _ = load_datasets(num_clients=3)
client_id = int(cid)
return FlowerClient(trainloader[client_id], valloader[client_id])
Клиент готов. Обратите внимание: метод fit тренирует модель всего 2 эпохи на своих данных. Это одно из ключевых отличий FL - короткие локальные тренировки. Мы не можем грузить edge-устройства недельными вычислениями.
3 Запускаем сервер: дирижер оркестра
Сервер в Flower управляет процессом. Он решает, когда запрашивать обновления у клиентов, как их агрегировать и когда остановиться. Самый простой способ - использовать встроенную стратегию FedAvg (Federated Averaging).
import flwr as fl
from dataset import load_datasets
# Загружаем данные для тестирования сервера
_, _, testloader = load_datasets()
# Функция для оценки на сервере
def get_evaluate_fn(testloader):
def evaluate_fn(server_round: int, parameters: fl.common.NDArrays, config: Dict[str, str]):
# Здесь можно загрузить модель и проверить на тестовых данных
# Для простоты пропустим, но в реальном проекте это важно
return 0.0, {"accuracy": 0.0} # Заглушка
return evaluate_fn
# Запускаем сервер
strategy = fl.server.strategy.FedAvg(
fraction_fit=1.0, # Доля клиентов для тренировки в каждом раунде (100%)
fraction_evaluate=0.5, # Доля для оценки
min_fit_clients=3, # Минимальное количество клиентов для тренировки
min_evaluate_clients=2, # Минимальное для оценки
min_available_clients=3, # Минимальное доступных клиентов для старта
evaluate_fn=get_evaluate_fn(testloader), # Функция оценки на сервере
)
fl.server.start_server(
server_address="0.0.0.0:8080",
config=fl.server.ServerConfig(num_rounds=3), # Три раунда обучения
strategy=strategy,
)
Сервер запустится на порту 8080. Теперь нужно запустить клиентов. Откройте новые терминалы и в каждом выполните:
import flwr as fl
from client import client_fn
# Запуск клиента с ID от 0 до 2
fl.client.start_client(server_address="127.0.0.1:8080", client=client_fn("0"))
Измените ID на 1 и 2 для других клиентов. Все. Система заработала. Клиенты подключатся, сервер начнет раунды обучения.
Где все ломается: 5 граблей, на которые наступают все
1. Не-IID данные
Самая большая проблема. Если у одного клиента только кошки, а у другого только собаки, модель не сойдется. Или сойдется, но будет бесполезной. Решения: увеличить количество раундов, использовать стратегии вроде FedProx, которые добавляют регуляризацию, или искусственно перемешивать данные (что не всегда возможно).
2. Отключения клиентов
Смартфон разрядился. Сеть пропала. В промышленном IoT такое постоянно. Сервер не должен висеть в ожидании. Настройте min_available_clients и таймауты. Flower умеет работать с частичными ответами.
3. Разная производительность устройств
Один клиент на Tesla V100, другой на Raspberry Pi Zero. Первый завершит тренировку за секунду, второй - за час. Стратегия FedAvg будет ждать самого медленного. Используйте FedAsync или настройте разное количество локальных эпох для разных устройств.
4. Безопасность обновлений
Отправка весов - не всегда безопасна. Из градиентов можно восстановить данные. Добавьте дифференциальную приватность (DP) или безопасное агрегирование (Secure Aggregation). Flower поддерживает оба подхода, но нужно покопаться в документации.
5. Масштабирование
Тысяча клиентов - это не три. Сервер упрется в CPU или память при агрегации. Подумайте о шардировании или иерархической структуре. Иногда проще запустить несколько независимых федераций и потом объединить модели. Как в статье про деление GPU на всех, но с серверами.
Частые вопросы, которые задают после первого запуска
| Вопрос | Ответ |
|---|---|
| Можно ли использовать Flower для тонкой настройки LLM? | Да, но осторожно. Размер модели убьет сеть. Сначала примените техники сжатия. Или настраивайте только последние слои. |
| Как добавить своего клиента на C++? | Через gRPC клиент. Flower использует gRPC, поэтому можно написать клиент на любом языке. Но проще обернуть C++ код в Python с помощью pybind11. |
| Сервер падает с ошибкой "Connection refused" | Клиенты пытаются подключиться раньше сервера. Запускайте сервер первым. И проверьте firewall. |
| Как визуализировать процесс обучения? | Flower имеет callback-систему. Можно писать метрики в TensorBoard или Comet.ml. Или просто в CSV файл. |
| Модель не улучшается после нескольких раундов | Скорее всего, данные не-IID. Попробуйте увеличить количество локальных эпох или использовать стратегию FedAvgM (с моментумом). |
Что дальше? Куда развивать проект
Вы запустили базовый пример. Теперь нужно адаптировать его под реальные задачи. Вот что я делаю дальше в production-проектах:
- Заменяю простую CNN на EfficientNet или Vision Transformer для изображений.
- Добавляю дифференциальную приватность через библиотеку Opacus для PyTorch.
- Внедряю безопасное агрегирование (хотя бы простое шифрование).
- Пишу свой стратег, который учитывает задержки сети и пропускную способность.
- Настраиваю мониторинг: кто из клиентов участвует, как меняется точность, нет ли аномалий в весах.
Самое важное - начать с простого. Не пытайтесь сразу сделать идеальную систему. Запустите три клиента на одной машине. Потом разнесите по разным виртуалкам. Потом добавьте реальные данные. И только потом думайте о безопасности и масштабировании.
Flower - это инструмент, который дает свободу. Но эта свобода требует ответственности. Вы сами должны решить, как агрегировать, как оценивать, как защищать. Это сложнее, чем нажать "Train" в Google Colab, но именно так создаются системы, которые работают в реальном мире, а не в идеальных условиях.
И последний совет: если ваша федеративная модель будет использоваться для инференса на edge-устройствах, подумайте о формате. GGUF отлично подходит для этого. Конвертируйте PyTorch модель в GGUF и запускайте на чем угодно.