Skip to content

Yu-da-1/fl-demo

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

1 Commit
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

FL Demo — フェデレート学習デモ

データを中央に集めず、端末内で学習し、統計量だけをサーバーに送ることで全体モデルを更新するデモです。
参加実績に応じた報酬台帳(表示のみ)も一緒に扱います。

コンセプト

  • Raw 禁止: サーバーは生データ(X, y)を受け取らない。各クライアントがローカルに保持する。
  • 端末内学習: 各クライアントは w_global を受け取り、自分のデータで 1 step 学習して delta_w を生成する。
  • 統計量のみ送信: Client → Server は delta_w, n_samples, local_loss, receipt に限定。
  • FedAvg 風の集約: サーバーは受け取った delta_wn_samples で重み付け平均し、w_global を更新する。
  • 報酬台帳: update 受理ごとに台帳に加算し、/payout で参加実績と累積金額を表示(送金はしない)。

アーキテクチャ

  • Aggregator Server (1): グローバルモデル w_global とラウンドを保持。N 台分の update が揃うたびに集約して w_global を更新。
  • Client (N 台、デフォルト 3): 各自がローカルで合成データ (X, y) を保持し、GET /model → ローカル 1-step 学習 → POST /update を繰り返す。
[Client-1]  (X1,y1) → delta_w1, n1, loss1 ─┐
[Client-2]  (X2,y2) → delta_w2, n2, loss2 ─┼→ [Server] FedAvg → w_global 更新, round++
[Client-3]  (X3,y3) → delta_w3, n3, loss3 ─┘

計算の内容

モデル: 線形回帰(MSE)

  • 入力: 特徴量 (X \in \mathbb{R}^{n \times d})、目的変数 (y \in \mathbb{R}^{n})
  • 予測: (\hat{y} = X w)(重みベクトル (w \in \mathbb{R}^{d}))
  • 損失(MSE): 予測と実測の差の2乗の平均
    [ \mathrm{loss}(w) = \frac{1}{n} \sum_{j=1}^{n} (x_j^\top w - y_j)^2 ]
  • 損失が小さいほど「予測が実データに近い」=モデルが良い。

クライアント側(1 step SGD)

  1. サーバーから w_global を取得。
  2. local_loss = 更新前の w_global に対する損失: (\mathrm{loss}(w_{\mathrm{global}}))。
  3. 勾配: (\nabla \mathrm{loss}(w) = \frac{2}{n} X^\top (Xw - y))。
  4. 1-step 更新: (w_{\mathrm{local}} = w_{\mathrm{global}} - \mathrm{lr} \cdot \nabla \mathrm{loss}(w_{\mathrm{global}}))。
  5. 送信する差分: (\Delta w = w_{\mathrm{local}} - w_{\mathrm{global}} = - \mathrm{lr} \cdot \nabla \mathrm{loss}(w_{\mathrm{global}}))。
  6. サーバーへ送るのは delta_w, n_samples, local_loss, receipt のみ(X, y は送らない)。

サーバー側(FedAvg 集約)

  • 同一ラウンドで N 台分の update が揃ったら:
    1. 加重平均: (\bar{\Delta} = \frac{\sum_i n_i \cdot \Delta w_i}{\sum_i n_i})
    2. グローバル更新: (w_{\mathrm{global}} \leftarrow w_{\mathrm{global}} + \bar{\Delta})
    3. weighted_loss = 各クライアントの local_loss の加重平均: (\frac{\sum_i n_i \cdot \mathrm{local_loss}_i}{\sum_i n_i})
    4. round を 1 増やし、次のラウンドを受け付ける。

報酬台帳

  • POST /update が受理されるたびに、その client_id の台帳で:
    • rounds_participated += 1
    • amount += reward_per_round(デフォルト 100)
  • 実際の送金は行わず、参加実績の記録・表示のみ。

結果の見方

Metrics(GET /metrics

  • round: 集約が終わったラウンド番号(0, 1, …)。
  • weighted_loss: そのラウンドで報告された local_loss の加重平均。値が下がるほど学習が進んでいる
  • samples: そのラウンドに参加した全クライアントの合計サンプル数。

Payout(GET /payout

  • client_id: クライアント識別子。
  • rounds: update が受理された回数(参加ラウンド数)。
  • amount: 累積報酬額(rounds × reward_per_round)。表示のみで送金はしない。

起動方法

前提

  • Python 3.10 以上を推奨。
  • プロジェクトルートで以下を実行する想定です。

1. 依存関係のインストール(初回のみ)

cd /path/to/fl-demo
pip install -r requirements.txt

2. サーバーの起動

別ターミナルでサーバーを起動し、起動したままにします。

cd /path/to/fl-demo
python -m server.app
  • ポート 8000 で待ち受けます。
  • Uvicorn running on http://0.0.0.0:8000 と表示されれば OK です。
  • 終了するときは Ctrl+C で停止してください。

3. クライアントの実行

もう一つのターミナルで、クライアントを実行します。サーバー起動後に実行してください。

cd /path/to/fl-demo
python -m client.run_clients
  • 3 台のクライアントが並列で動作し、全ラウンド(デフォルト 10)が終わると、/metrics/payout を取得して結果を表示します。

起動順序のまとめ

順番 コマンド 説明
1 pip install -r requirements.txt 初回のみ
2 python -m server.app サーバーを先に起動(ターミナル1)
3 python -m client.run_clients クライアント実行(ターミナル2)

期待される出力例

  • クライアント側: 各ラウンドで round=0,1,...local_loss=...accepted=Trueamount=... のログ。最後に Metrics と Payout のサマリが表示される。
  • サーバー側: received=1/3, 2/3, 3/3 AGGREGATED -> round=... weighted_loss=... のようなログ。

実行後のサマリ例

===== Metrics =====
  round=0  weighted_loss=7.45...  samples=600
  round=1  weighted_loss=5.66...  samples=600
  ...
  round=9  weighted_loss=0.87...  samples=600

===== Payout =====
  client-001  rounds=10  amount=1000
  client-002  rounds=10  amount=1000
  client-003  rounds=10  amount=1000
  • weighted_loss がラウンド進行とともに概ね低下していれば、全体モデルが改善していることを示しています。
  • 各クライアントは 10 ラウンド参加し、1 ラウンドあたり 100 なので amount=1000 になります。

デフォルト設定(v0)

  • ジョブ ID: job-001
  • モデル次元: d=5(線形回帰)
  • 最大ラウンド: rounds_max=10
  • 学習率: lr=0.05
  • 必要クライアント数: n_required=3
  • 1 ラウンドあたり報酬: reward_per_round=100

ファイル構成

パス 役割
server/app.py FastAPI エンドポイント(GET /job, /model, POST /update, GET /metrics, /payout)
server/schemas.py リクエスト・レスポンスの型定義
server/state.py サーバー状態・バリデーション・FedAvg 集約・台帳
client/client.py 1 クライアントの実行ループ(GET /model → 学習 → POST /update)
client/data.py 合成データ生成、MSE loss・勾配の計算
client/run_clients.py N クライアントの並列起動と結果表示
design.md 詳細な設計仕様

参考

  • 詳細な API 仕様・Reject 条件・内部手順は design.md を参照してください。

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages