GWPF

Official PyTorch implementation of "Communication-Efficient Federated Learning with Gradient-Wise Parameter Freezing"

Communication is a critical bottleneck in federated learning due to limited network bandwidth. Existing compression methods often slow model convergence because they discard necessary information, while most parameter-freezing mechanisms lack a thawing design, which degrades model accuracy. This work proposes gradient-wise parameter freezing (GWPF), a mechanism with a thawing strategy for parameter synchronization in federated learning. The server makes a global freezing or thawing decision based on the global aggregation of gradients or on a trigger from locally accumulated gradients. Insignificant gradients are excluded from transmission between the server and workers during the frozen period, which greatly reduces communication overhead and accelerates model training. The frozen period can be ended early to aggregate valuable locally accumulated gradients for global synchronization, or extended to further reduce communication overhead. Compared with related frozen-period control methods, GWPF yields better synchronization efficiency and computational resource utilization. We provide a theoretical convergence analysis of GWPF for non-convex objectives with non-IID data distributions and conduct extensive experiments that demonstrate its effectiveness and its superiority in different scenarios.
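As a rough illustration of the freezing and thawing decision described above, here is a minimal sketch, not the repository's actual API: the function names, the thresholds freeze_threshold and thaw_threshold, and the per-parameter boolean mask are all assumptions made for this example. Entries whose globally aggregated gradient magnitude falls below the freezing threshold are frozen and excluded from transmission; a frozen entry is thawed once its locally accumulated gradient magnitude exceeds the thawing threshold.

import torch

def update_freeze_state(agg_grad, accum_grad, frozen,
                        freeze_threshold=1e-4, thaw_threshold=1e-2):
    # Illustrative per-parameter freeze/thaw decision (thresholds are hypothetical).
    # agg_grad:   globally aggregated gradient for the current round
    # accum_grad: gradient accumulated locally while an entry is frozen
    # frozen:     boolean mask marking entries that are currently frozen
    newly_frozen = (~frozen) & (agg_grad.abs() < freeze_threshold)
    newly_thawed = frozen & (accum_grad.abs() > thaw_threshold)
    return (frozen | newly_frozen) & ~newly_thawed

def sparsify_for_upload(grad, frozen):
    # Only non-frozen entries are transmitted between workers and the server.
    active = ~frozen
    return active.nonzero(as_tuple=True), grad[active]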

Quick Start

Cloning

git clone https://github.com/Dora233/GWPF
cd GWPF

Installation

pip install -r requirements.txt

Dataset Preparation

from torchvision import datasets
# Download the MNIST, CIFAR-10, and EMNIST (balanced split) training sets into ./data
datasets.MNIST('./data', train=True, download=True)
datasets.CIFAR10('./data', train=True, download=True)
datasets.EMNIST('./data', train=True, download=True, split="balanced")

WikiText-2 can be downloaded from https://blog.salesforceairesearch.com/the-wikitext-long-term-dependency-language-modeling-dataset/
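Once the archive is downloaded and extracted, the raw token files can be read with plain Python. The sketch below assumes the files keep their standard names (wiki.train.tokens, wiki.valid.tokens, wiki.test.tokens) and sit under ./data/wikitext-2; the directory layout and the helper function are assumptions for illustration, not a documented requirement of this repository.

import os

def load_wikitext2(split, root='./data/wikitext-2'):
    # Each file is whitespace-tokenized plain text; return the token stream for one split.
    path = os.path.join(root, 'wiki.{}.tokens'.format(split))
    with open(path, encoding='utf-8') as f:
        return f.read().split()

train_tokens = load_wikitext2('train')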

Run Simulation

Our testbed is an NVIDIA DGX-1 with 32 CPU cores and 5 Tesla V100 GPUs (32 GB memory each), running CUDA 10.2. In the base settings, the number of workers participating in each round of training is set to 10.

The GWPF algorithm can be tested with ResNet18 on non-IID CIFAR-10 by running the training script and answering its prompts (C1 selects the CIFAR-10/ResNet18 configuration, and F selects a non-IID data distribution):

sh train10n.sh
(Enter dataset_model: ) C1
(Enter is_iid: ) F
