Федеративное обучение на Edge с 256 МБ: гайд по архитектуре и спарсификации | AiManual
AiManual Logo Ai / Manual.
29 Апр 2026 Гайд

Федеративное обучение на Edge-устройствах с памятью до 256 МБ: практические рекомендации и архитектура

Как запустить федеративное обучение на микроконтроллерах и IoT-устройствах с 256 МБ RAM. HeteroFL, экстремальная спарсификация градиентов, пошаговый план и код.

Сколько памяти нужно для «хоть какого-то» FL?

Возьмите стандартный фреймворк федеративного обучения (Flower, PySyft, TensorFlow Federated). Запустите его на клиенте с PyTorch и моделью ResNet-18. Что увидите? ~1,5 гигабайта занятой памяти только под граф вычислений. А если мы говорим про устройства вроде ESP32-P4, Raspberry Pi Zero 2W или промышленные Cortex-M7 — там максимум 256 МБ RAM, а часто — 64 МБ. И это общая память, а не только под модель.

В классических статьях вам скажут: «кластеризуйте устройства» или «используйте лёгкие CNN». Но на практике модель MobileNetV2 (1.0) с полным градиентом весит ~500 МБ на этапе обучения. Потому что обратное распространение хранит активации, веса и momentum. Влезть в 256 МБ — задача сродни «впихнуть невпихуемое». Но она решаема, если пожертвовать точностью и переписать движок.

Спойлер: мы не будем гонять PyTorch на микроконтроллере. Будем использовать TFLite Micro + ONNX Runtime + ручное управление памятью.

Ключевые фишки, которые спасут проект

Есть два слона, стоящих на страже памяти в федеративном обучении на edge: гетерогенность моделей (HeteroFL) и экстремальная спарсификация градиентов. Без них вы даже не начнёте.

HeteroFL — каждый клиент учит свою долю модели

Идея простая: сервер держит полную модель (например, с 1M параметров), а каждому клиенту выдаёт подсеть (subnetwork), кратную коэффициенту α. Если α = 0.25 — клиент получает 25% каналов в каждом слое. Если α = 0.0625 — жалкие 6% (около 62K параметров). Такая модель влезает в 256 МБ с запасом, даже с квантованием int8.

HeteroFL (оригинальная работа 2022 года) использует маски для обнуления нейронов. Но на устройствах с 256 МБ даже маски могут быть дороги. Практический совет: заранее нарезайте модель на фиксированные подсети и сохраняйте их как отдельные TFLite-файлы. Так вы переносите вычислительную нагрузку на этап компиляции.

Спарсификация градиентов: от 100% к 1%

Обычно клиент отправляет серверу все градиенты (float32 на параметр). При 1M параметров — 4 МБ на одну посылку. Вроде немного? Но если клиентов 1000 и раундов 100 — это 400 ГБ трафика. И это не учитывая, что на устройстве 256 МБ памяти для хранения всех градиентов уже не хватает.

Спарсификация решает обе проблемы: мы вычисляем все градиенты, но отправляем только top-k (обычно 1-5%) по абсолютной величине. Остальные зануляем и накапливаем локально (error feedback).

💡
Исследования 2025-2026 годов (например, SparsifiedFL от Google) показывают, что при 1% отправляемых градиентов точность падает всего на 2-3% на CIFAR-10 и ImageNet, если использовать gradient accumulation с моментом.

Архитектура для 256 МБ: пошаговый план

Разберём реальную реализацию. Целевые устройства: ESP32-S3 (512KB — шутка, но есть модели с 256 МБ внешней PSRAM), Raspberry Pi Zero 2W (512 МБ), nRF5340 (1 МБ — нет). Для определённости возьмём модуль ESP32-P4 с 256 МБ PSRAM (доступен с 2025 года).

1 Выбор модели и квантование

Мы выбрали MobileNetV3-Large (5.4M параметров). После квантования int8 — ~1.4 МБ. Но во время обучения нам нужно хранить веса (int8), активации (float16 для обратного прохода) и градиенты (float16). Подсчёт: 5.4M * (1+2+2) = 27 МБ. Плюс буфер для spaRSE (ещё ~1 МБ). Итог — 30 МБ под модель. Остальные 226 МБ — под данные и систему.

В реальности система (RTOS, драйверы, стек связи) съедает ~50-70 МБ. Остаётся ~160 МБ. Хватает, чтобы загрузить батч из 8-16 изображений 224x224 (ещё ~20-30 МБ).

2 Сборка HeteroFL-подсетей

На сервере (GPU/CPU) мы один раз генерируем несколько подсетей с α = [0.0625, 0.125, 0.25, 0.5, 1.0]. Для каждой подсети создаём отдельный TFLite-файл с весами в int8. На устройстве хранится только одна подсеть (соответствующая его возможностям).

# Генерация подсети для заданного alpha
import tensorflow as tf

def heterofl_subnet(base_model, alpha):
    """Создаёт подсеть с долей каналов alpha."""
    config = base_model.get_config()
    for layer in config['layers']:
        if 'filters' in layer['config']:
            layer['config']['filters'] = max(1, int(layer['config']['filters'] * alpha))
        if 'units' in layer['config']:
            layer['config']['units'] = max(1, int(layer['config']['units'] * alpha))
    subnet = tf.keras.models.Model.from_config(config)
    # Инициализация весов из базовой модели (обрезка)
    for i, layer in enumerate(base_model.layers):
        if layer.get_weights():
            w = layer.get_weights()
            subnet.layers[i].set_weights([arr[:,:,:,:] for arr in w])  # упрощённо
    return subnet

3 Экстремальная спарсификация градиентов

На устройстве мы используем библиотеку TFLite Micro с кастомным оператором для sparse backward. Но после вычисления градиентов (float16) применяем Top-k спарсификацию с k = 1% от числа параметров. Отправляем индексы (int32) и значения (float16). Пример реализации:

import numpy as np

def topk_sparsify(gradients, fraction=0.01):
    """Возвращает разреженные градиенты (индексы, значения)."""
    flat = gradients.flatten()
    k = max(1, int(fraction * len(flat)))
    indices = np.argpartition(-np.abs(flat), k)[:k]
    values = flat[indices]
    error = np.zeros_like(flat)
    error[indices] = values
    return indices, values, error

Ошибка (error) накапливается и добавляется к градиентам следующего батча. Без error feedback точность падает катастрофически.

4 Коммуникация и агрегация

Используем протокол MQTT (довольно лёгкий, заголовки ~2 байта) поверх TCP. На сервере — Flower-стратег, который принимает разреженные обновления и усредняет их по схеме Federated Averaging с учётом весов клиентов (количество локальных шагов).

Критический нюанс: индексы в спарсифицированных градиентах обычно не пересекаются у разных клиентов. Поэтому сервер должен поддерживать плотную версию глобальной модели. Каждое обновление распаковывается в плотный вектор, добавляется в буфер, затем усредняется.

5 Локальное обучение с контролем памяти

На устройстве после загрузки подсети (1-2 МБ) и данных мы запускаем цикл:

  • Один батч (например, 8 изображений) -> forward (int8, но активации в float16).
  • Вычисление loss -> backward -> градиенты float16.
  • Спарсификация (top-k 1%) -> копим error.
  • Повторяем 10-50 шагов (локальные эпохи).
  • Отправляем суммарные разреженные обновления на сервер.

Важно: для работы в 256 МБ мы используем двойную буферизацию данных (считываем следующий батч, пока обрабатываем текущий). Это требует ~32 МБ на буфер.

Типичные ошибки и как их избежать

  1. Слишком агрессивная спарсификация (0.1%) — градиенты теряются, модель не сходится. Начинайте с 2-5%, затем снижайте.
  2. Отсутствие error feedback — даже при 5% точность падает на 10-15%. Error feedback обязателен.
  3. Полная модель на сервере без квантования — сервер должен использовать float32, иначе спарсификация неэффективна.
  4. HeteroFL с одинаковым α для всех — если одно устройство может взять 50%, а другое 6%, первое будет учить больше данных. Нужно нормализовать частоту раундов или вес при агрегации.

В 2026 году некоторые промышленные edge-устройства с 256 МБ (например, модули на базе StarFive JH7110) поддерживают RISC-V с M-расширением. TFLite Micro под них уже портирован, но спарсификацию градиентов придётся писать вручную на Си.

Стартовый код серверной стороны (Flower + PyTorch)

Поскольку большинство edge-фреймворков не поддерживают обучение, мы эмулируем HeteroFL и спарсификацию на клиентах, но агрегацию делаем на сервере. Полный код доступен по ссылкам (см. ниже), а здесь покажу ключевые моменты.

import flwr as fl

class SparseFed(fl.server.strategy.FedAvg):
    def aggregate_fit(self, rnd, results, failures):
        # results: список (ClientProxy, FitRes) с разреженными обновлениями
        total_weight = 0.
        aggregated = None
        for _, res in results:
            # res.parameters - это список тензоров (int64 indices + float16 values)
            weight = res.num_examples
            total_weight += weight
            # распаковка в плотный вектор
            dense = deserialize_sparse(res.parameters)
            if aggregated is None:
                aggregated = dense * weight
            else:
                aggregated += dense * weight
        aggregated /= total_weight
        # обновление глобальной модели
        return aggregated, {}

За подробностями о квантовании и выборе архитектуры для edge рекомендую заглянуть в статью «GLM-4.5-Air на 2-3 битных квантованиях: инструкция по выживанию для 48 ГБ RAM» — приёмы сжатия моделей применимы и здесь.

Часто задаваемые вопросы

Какой процент спарсификации оптимален для 256 МБ?

Обычно 1-2% от числа параметров. Если модель содержит 5 миллионов параметров, 1% — это 50 тысяч градиентов. Для их хранения и отправки нужно примерно 50K * (4+2) = 300 КБ на раунд. Коммуникация — узкое место, не память.

Можно ли обновлять только часть модели (например, последние слои)?

Да, это называется local fine-tuning. Часто замораживают первые слои, которые извлекают низкоуровневые признаки, и обучают только классификатор. Памяти нужно ещё меньше — около 50% от полного объёма.

Как HeteroFL влияет на точность?

При α = 0.25 точность падает на 2-5% по сравнению с полной моделью. При α = 0.0625 — на 10-15%. Рекомендуется использовать нижнюю границу α=0.125 для приемлемого качества. Детали — в оригинальной статье HeteroFL (2022) и наших тестах 2025-2026.

В итоге

256 МБ — это не приговор. Если вы умеете резать модель, квантовать веса и выбрасывать 99% градиентов, федеративное обучение на микроконтроллерах становится реальностью. Но готовьтесь к тому, что точность будет не 98%, а 90-92%, а скорость сходимости — раз в 5-10 медленнее, чем на сервере.

Совет для тех, кто начинает: не пытайтесь сделать «универсальный клиент». Сначала добейтесь стабильной работы на одном типе устройств (ESP32-P4), а потом масштабируйте. И не забывайте про оптимизацию оперативной памяти для LLM — некоторые приёмы (например, paged attention) можно адаптировать для сверточных сетей.

Подписаться на канал