Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
79 commits
Select commit Hold shift + click to select a range
357152c
create load meg data
inoue26 Apr 10, 2023
8550ef1
run training. simple case
inoue26 Apr 11, 2023
b3e7f21
check montage
inoue26 Apr 13, 2023
92484b6
BUGFIX: calculate loss backward
inoue26 Apr 13, 2023
6b3a989
ADD: seq2static model
inoue26 Apr 14, 2023
c7e948b
ADD: seq2static model
inoue26 Apr 17, 2023
8227711
ADD: sampler
inoue26 Apr 17, 2023
120a7f1
Merge branch 'main' of https://github.com/arayabrain/MEG-decoding
inoue26 Apr 17, 2023
4a4e1e5
notebooks
inoue26 Apr 17, 2023
3e5a95f
add: resample_fs
inoue26 Apr 17, 2023
79565c5
add: bandpass filter
inoue26 Apr 17, 2023
721b3a5
add: topk evaluate
inoue26 Apr 17, 2023
3052b25
move: evaluate.py
inoue26 Apr 17, 2023
e6c885b
mod: load weight in evaluate.py
inoue26 Apr 17, 2023
185bc25
mod
inoue26 Apr 17, 2023
cb7395c
mod: total test sample is 50
inoue26 Apr 17, 2023
97f3ad2
mod: use label instead of features
inoue26 Apr 17, 2023
fcf359d
merge
inoue26 Apr 17, 2023
6cf7cb3
merge
inoue26 Apr 17, 2023
9bd19a7
FIX: zeroshot-classification
inoue26 Apr 17, 2023
3d441e9
maege
inoue26 Apr 18, 2023
05ece4d
ADD: another EXP
inoue26 Apr 18, 2023
f439098
MOD: baseline correction last
inoue26 Apr 18, 2023
f0a1c73
FIX: not closed )
inoue26 Apr 18, 2023
26100e4
ADD: classification like liss
inoue26 Apr 18, 2023
5bba8d1
FIX: new loss bugs
inoue26 Apr 18, 2023
91cf622
FIX: createion target for val
inoue26 Apr 19, 2023
050bf2c
ADD: matlab style acc calculation
inoue26 Apr 19, 2023
69bdd78
FIX: bugs
inoue26 Apr 19, 2023
9822c86
ADD: category
inoue26 Apr 19, 2023
e655f55
ADD: corr check between train image and test image
inoue26 Apr 19, 2023
e8dec42
ADD: binary cross entropy
inoue26 Apr 20, 2023
674d535
ADD: pairwise accuracy via similarity
inoue26 Apr 20, 2023
64725a4
FIX: modulation of evaluation code
inoue26 Apr 20, 2023
c5fd7de
ADD: regression loss and similarity loss
inoue26 Apr 20, 2023
3a0d28a
MOD: LinearEmcoder input args
inoue26 Apr 20, 2023
f0f8bc0
MOD
inoue26 Apr 20, 2023
5b78099
Merge branch 'main' of https://github.com/arayabrain/MEG-decoding
inoue26 Apr 20, 2023
96ffa88
FIX: roi ch diff between matlab and python
inoue26 Apr 21, 2023
20b832a
ADD: kamitani lab regression
inoue26 Apr 24, 2023
9054f63
ADD: various loss
inoue26 Apr 27, 2023
b846981
ADD: corr of corr
inoue26 Apr 27, 2023
7997573
ADD: corr_corr of image and meg is higher when 200-400, 2-5Hz occipital
inoue26 Apr 27, 2023
2017504
ADD: EEGNet
inoue26 Apr 27, 2023
58c72f6
ADD: prevent overfitting
inoue26 Apr 28, 2023
c8af3fd
ADD: evaluate
inoue26 Apr 28, 2023
b77e5b0
MOD: can load weight
inoue26 Apr 28, 2023
3482d51
ADD: outlier analysis
inoue26 Apr 30, 2023
5cde617
ADD: same label loss
inoue26 Apr 30, 2023
25a32a8
ADD: regression. this is the best model I have ever seen, including R…
inoue26 May 1, 2023
8624f0f
:q
inoue26 May 1, 2023
fff87ce
ADD: regression cvfor all subj
inoue26 May 1, 2023
922ba79
ADD: regression for subs
inoue26 May 1, 2023
5922861
Report
inoue26 May 9, 2023
19dbd96
pred from database including imagenet cal
inoue26 May 9, 2023
78d603a
sbj01/02/03 is trained in single sub way
inoue26 May 16, 2023
df451e4
ADD: cogitat implementation
inoue26 May 16, 2023
ddcf89b
MOD: dataset can normalize per subject
inoue26 May 16, 2023
397ac7c
merge
inoue26 May 16, 2023
8e08951
Merge branch 'source_reconst' of https://github.com/arayabrain/MEG-de…
inoue26 May 16, 2023
1eead4f
merge
inoue26 May 16, 2023
5efefa0
MOD: gather
inoue26 May 16, 2023
c7b0c3c
merge
inoue26 May 16, 2023
c458901
merge done
inoue26 May 16, 2023
cdb8e37
FIX: device_erroe
inoue26 May 16, 2023
646e16b
FIX bug
inoue26 May 16, 2023
99d1a9f
MOD: deepsets
inoue26 May 16, 2023
0cd0ec7
ADD: check kernel notebook
inoue26 May 19, 2023
ec74f8f
ADD: source reconstruction dataset class
inoue26 May 19, 2023
ed5b84b
MOD: compatibility granted for old ver config
inoue26 May 19, 2023
4bd38b7
MOD: order of preprocess. src reconstruction is placed before bandpas…
inoue26 May 19, 2023
608e770
FIX: zfill
inoue26 May 19, 2023
b4476eb
FIX
inoue26 May 19, 2023
34ebf59
FIX
inoue26 May 19, 2023
34b3868
FIX: model num_channels
inoue26 May 19, 2023
2e91b4d
MOD: transpose common_kernel
inoue26 May 23, 2023
cf9ae86
MOD: for run all recons
inoue26 May 31, 2023
fd124bf
ADD: cross subject corelation
inoue26 May 31, 2023
0ba76b9
zero-shot learning with text
inoue26 Jun 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ wandb/
train.sh
*_debug*
nohup.out
logs/*
data/GOD/*.npy
data/ImageNet/*
data/prompts/*
35 changes: 35 additions & 0 deletions assets/evaluate.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
function acc_category_identification = identify_category(predicted_Y)
% Category-identification accuracy for decoded CLIP vectors.
% predicted_Y : trials x feature-dim matrix of predicted vectors.
% Uses globals "D" (directory paths) and "P" (subject information).
% NOTE(review): the file is evaluate.m but the function is named
% identify_category; MATLAB dispatches on the file name — confirm the
% intended entry-point name.
global D P
% val_index maps each validation trial to its true image index (1 x 300 trials).
load(fullfile(D.dp.exp_data, 'rand', P.sbj.name{P.ind.sbj}, 'val_index.mat'), 'val_index');
% space.vec: 50 (images) x 512 (CLIP dimension)
space = load(fullfile(D.dp.exp_data, 'val', 'clip_image.vec.mat'));
n_images = size(space.vec, 1);          % number of candidate image vectors
n_trials = size(predicted_Y, 1);
trial_acc = zeros(n_trials, 1);
for trial = 1:n_trials
    % Correlate this prediction with every candidate image vector.
    corr_all = zeros(n_images, 1);
    for img = 1:n_images
        R = corrcoef(predicted_Y(trial, :), space.vec(img, :));
        corr_all(img) = R(1, 2);
    end
    true_img = val_index(trial);
    % Per-trial score: fraction of competitor images whose correlation is
    % strictly below the true image's correlation (pairwise-style accuracy,
    % hence the division by n_images - 1).
    trial_acc(trial) = numel(find(corr_all < corr_all(true_img))) / (n_images - 1);
end
% Mean identification accuracy over all predicted vectors.
acc_category_identification = mean(trial_acc);
end
247 changes: 247 additions & 0 deletions classfication_by_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,247 @@

import numpy as np
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
from PIL import ImageFile # so PIL can load large images as well
import pickle
import json
import os, sys, random
import torch
import torch.nn as nn
from time import time
from tqdm import tqdm, trange
from termcolor import cprint
import pandas as pd

from torch.utils.data import DataLoader, RandomSampler, BatchSampler
try:
from meg_decoding.models import get_model, Classifier
from meg_decoding.utils.get_dataloaders import get_dataloaders, get_samplers
from meg_decoding.dataclass.god import GODDatasetBase, GODCollator
from meg_decoding.utils.loggers import Pickleogger
from meg_decoding.utils.vis_grad import get_grad
from torch.utils.data.dataset import Subset
import matplotlib.pyplot as plt
import seaborn as sns
except ModuleNotFoundError :
pass




def get_language_model(prompt_dict: dict, savedir):
    """Return CLIP text features and prompt strings, caching them in ``savedir``.

    On a cache miss this encodes the prompts with CLIP ViT-B/32 and writes
    ``text_features`` (via torch.save) and ``prompts.txt`` to ``savedir``;
    subsequent calls load the cache instead of re-encoding.

    Args:
        prompt_dict: mapping with a 'prefix' entry plus prompt entries; the
            prefix is prepended to every prompt before tokenization.
        savedir: directory used for the two cache files.

    Returns:
        (text_features, prompts): the encoded feature tensor and the list of
        prompt lines (each line ends with a newline).
    """
    feat_path = os.path.join(savedir, 'text_features')
    prompt_path = os.path.join(savedir, 'prompts.txt')
    # Use the cache only when BOTH files exist. The old check looked at
    # text_features alone, so a partial cache (features saved, prompts.txt
    # missing) crashed with FileNotFoundError.
    if os.path.exists(feat_path) and os.path.exists(prompt_path):
        text_features = torch.load(feat_path)
        with open(prompt_path, 'r') as f:
            prompts = f.readlines()
    else:
        import clip
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model, preprocess = clip.load("ViT-B/32", device=device)
        model = model.eval()
        prompts = []
        prefix = prompt_dict['prefix']
        for key, t in prompt_dict.items():
            if key == 'prefix':
                continue  # the prefix itself is not a class prompt
            prompts.append(t + '\n')
        text = clip.tokenize([prefix + s.replace('\n', '') for s in prompts]).to(device)
        with torch.no_grad():
            text_features = model.encode_text(text)
        torch.save(text_features, feat_path)
        with open(prompt_path, 'w') as f:
            f.writelines(prompts)
    return text_features, prompts

def evaluate(args, text_features, prompts, savedir, eval_sbj='1'):
    """Zero-shot prompt classification with a pretrained brain encoder.

    Encodes held-out MEG trials, then compares (a) MEG latents vs. text
    features and (b) ground-truth image features vs. text features by cosine
    similarity. Per-trial predictions are printed and saved to
    ``<savedir>/preds.csv``.

    Args:
        args: hydra config (dataset, model, save_root, batch_size, ...).
        text_features: (n_prompts, dim) numpy array of CLIP text embeddings
            in the unnormalized CLIP space.
        prompts: prompt strings aligned with the rows of ``text_features``.
        savedir: output directory for ``preds.csv``.
        eval_sbj: subject selector ('1', '2', '3'; anything else = all).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    from meg_decoding.utils.reproducibility import seed_worker
    # NOTE: We do need it (IMHO).
    if args.reproducible:
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        torch.use_deterministic_algorithms(True)
        random.seed(0)
        np.random.seed(0)
        torch.manual_seed(0)
        g = torch.Generator()
        g.manual_seed(0)
        seed_worker = seed_worker
    else:
        g = None
        seed_worker = None

    source_dataset = GODDatasetBase(args, 'train', return_label=True)
    # Normalize the val split with *train* statistics so that test-set
    # statistics do not leak into evaluation.
    outlier_dataset = GODDatasetBase(args, 'val', return_label=True,
                                     mean_X=source_dataset.mean_X,
                                     mean_Y=source_dataset.mean_Y,
                                     std_X=source_dataset.std_X,
                                     std_Y=source_dataset.std_Y)
    # Standardize the text features with the train image-feature statistics.
    # BUGFIX: copy instead of in-place '-='/'/=' so the caller's numpy array
    # is not silently mutated.
    text_features = (text_features - source_dataset.mean_Y) / source_dataset.std_Y
    text_features = torch.Tensor(text_features).to(device)

    # Index ranges hard-coded to the GOD dataset layout (per-subject splits).
    if eval_sbj == '1':
        ind_tr = list(range(100))
        ind_te = list(range(3000, 3600)) + list(range(6600, 7200))
        ind_out = list(range(0, 50))
    elif eval_sbj == '2':
        ind_tr = list(range(100))
        ind_te = list(range(7200 + 3000, 7200 + 3600)) + list(range(10800 + 3000, 10800 + 3600))
        ind_out = list(range(50, 100))
    elif eval_sbj == '3':
        ind_tr = list(range(100))
        ind_te = list(range(14400 + 3000, 14400 + 3600)) + list(range(14400 + 6600, 14400 + 7200))
        ind_out = list(range(100, 150))
    else:  # all subjects
        ind_tr = list(range(0, 3000)) + list(range(3600, 6600)) + list(range(7200, 7200 + 3000)) + list(range(10800, 10800 + 3000)) + list(range(14400, 14400 + 3000)) + list(range(14400 + 3600, 14400 + 6600))
        ind_te = list(range(3000, 3600)) + list(range(6600, 7200)) + list(range(7200 + 3000, 7200 + 3600)) + list(range(10800 + 3000, 10800 + 3600)) + list(range(14400 + 3000, 14400 + 3600)) + list(range(14400 + 6600, 14400 + 7200))
        ind_out = list(range(0, 150))
    outlier_dataset = Subset(outlier_dataset, ind_out)
    train_dataset = Subset(source_dataset, ind_tr)
    val_dataset = Subset(source_dataset, ind_te)
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        drop_last=True,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )
    test_loader = DataLoader(
        outlier_dataset,  # the 50-trial validation subset of the chosen subject(s)
        batch_size=50,    # one batch covers the whole per-subject set
        drop_last=True,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )

    brain_encoder = get_model(args).to(device)  # BrainEncoder(args).to(device)

    # Prefer the best checkpoint; fall back to the last one.
    weight_dir = os.path.join(args.save_root, 'weights')
    last_weight_file = os.path.join(weight_dir, "model_last.pt")
    best_weight_file = os.path.join(weight_dir, "model_best.pt")
    if os.path.exists(best_weight_file):
        brain_encoder.load_state_dict(torch.load(best_weight_file))
        print('weight is loaded from ', best_weight_file)
    else:
        brain_encoder.load_state_dict(torch.load(last_weight_file))
        print('weight is loaded from ', last_weight_file)

    classifier = Classifier(args)  # NOTE(review): unused below — kept for parity

    Zs = []
    Ys = []
    Ls = []
    brain_encoder.eval()
    for batch in test_loader:
        with torch.no_grad():
            if len(batch) == 4:
                X, Y, subject_idxs, Labels = batch
            elif len(batch) == 3:
                # BUGFIX: the old code fell through here without binding
                # Labels, so Ls.append(Labels) raised NameError on the first
                # batch (or silently reused a stale value on later batches).
                # Labels are required below, so fail loudly instead.
                raise ValueError("Dataloader must yield labels; construct the dataset with return_label=True.")
            else:
                raise ValueError("Unexpected number of items from dataloader.")

            X, Y = X.to(device), Y.to(device)
            Z = brain_encoder(X, subject_idxs)  # 0.96 GB
            Zs.append(Z)
            Ys.append(Y)
            Ls.append(Labels)
    Zs = torch.cat(Zs, dim=0)
    Ys = torch.cat(Ys, dim=0)
    Ls = torch.cat(Ls, dim=0).detach().cpu().numpy()

    # Hypothesis 1: the decision is biased — the criterion may amount to no
    # more than "does this image's feature resemble the MEG latent space?".
    # Standardize twice (feature-wise, then sample-wise) to reduce that bias.
    Zs = Zs - Zs.mean(dim=0, keepdims=True)
    Zs = Zs / Zs.std(dim=0, keepdims=True)
    Zs = Zs - Zs.mean(dim=1, keepdims=True)
    Zs = Zs / Zs.std(dim=1, keepdims=True)

    similarity_meg_text = calc_similarity(Zs, text_features)

    # Undo the train-set standardization before the image<->text comparison.
    text_features *= torch.Tensor(source_dataset.std_Y).to(device)
    Ys *= torch.Tensor(source_dataset.std_Y).to(device)
    text_features += torch.Tensor(source_dataset.mean_Y).to(device)
    Ys += torch.Tensor(source_dataset.mean_Y).to(device)
    similarity_image_text = calc_similarity(Ys, text_features)
    preds_image_text = np.argmax(similarity_image_text, axis=1)
    preds_meg_text = np.argmax(similarity_meg_text, axis=1)
    pred_dict = {'image_text_label': [], 'image_text_similarity': [], 'meg_text_label': [], 'meg_text_similarity': []}

    for i in range(len(preds_image_text)):
        p_it = preds_image_text[i]
        p_mt = preds_meg_text[i]
        print('{}th: image2text {}'.format(i, prompts[p_it]), similarity_image_text[i], 'meg2text, {}'.format(prompts[p_mt]), similarity_meg_text[i])
        pred_dict['image_text_label'].append(prompts[p_it])
        pred_dict['image_text_similarity'].append(np.max(similarity_image_text[i]))
        pred_dict['meg_text_label'].append(prompts[p_mt])
        pred_dict['meg_text_similarity'].append(np.max(similarity_meg_text[i]))
    pd.DataFrame(pred_dict).to_csv(os.path.join(savedir, 'preds.csv'))
    # Agreement rate between image-based and MEG-based prompt predictions.
    print('Compatibility image and meg', np.mean([it == mt for it, mt in zip(pred_dict['image_text_label'], pred_dict['meg_text_label'])]))
    print('chance is {}'.format(1 / len(prompts)))
    print('save to ', os.path.join(savedir, 'preds.csv'))


def calc_similarity(x, y):
    """Pairwise cosine similarity between two batches of vectors.

    Args:
        x: (batch_size, dim) tensor.
        y: (gt_size, dim) tensor.

    Returns:
        (batch_size, gt_size) numpy array whose (i, j) entry is the cosine
        similarity of x[i] and y[j]; the denominator is clamped at 1e-8 to
        avoid division by zero for (near-)zero vectors.
    """
    # Vectorized replacement for the old O(batch*gt) Python double loop:
    # one matmul plus an outer product of row norms.
    # BUGFIX: computed on x's device instead of hard-coded 'cuda', so this
    # also works on CPU-only machines.
    norms = x.norm(dim=1).unsqueeze(1) * y.norm(dim=1).unsqueeze(0)
    similarity = (x @ y.T) / norms.clamp(min=1e-8)
    return similarity.cpu().numpy()




if __name__ == '__main__':
    # Entry point: load a prompt set, fetch (or compute) its CLIP text
    # features, then run zero-shot evaluation for subject 1.
    # NOTE(review): paths are hard-coded to one machine — parameterize before reuse.
    prompt_root = '/home/yainoue/meg2image/codes/MEG-decoding/data/prompts'
    prompt_sub_dir = 'prompt5'
    prompt_dir = os.path.join(prompt_root, prompt_sub_dir)
    # JSON mapping: a 'prefix' entry plus one entry per class prompt.
    prompt_dict_file = os.path.join(prompt_dir, 'classification1.json')
    with open(prompt_dict_file, 'r') as f:
        prompt_dict = json.load(f)
    text_features, prompts = get_language_model(prompt_dict, prompt_dir)
    # exit()
    # Compose the experiment config via hydra (relative to this file's location).
    from hydra import initialize, compose
    with initialize(version_base=None, config_path="../configs/"):
        args = compose(config_name='20230429_sbj01_eegnet_regression')
    savedir = os.path.join(args.save_root, 'classification', prompt_sub_dir)
    if not os.path.exists(savedir):
        os.makedirs(savedir)

    # Strip trailing newlines saved by get_language_model before display/CSV use.
    prompts = [p.strip() for p in prompts]
    evaluate(args, text_features.cpu().numpy(), prompts, savedir, eval_sbj='1')
59 changes: 59 additions & 0 deletions configs/config_GOD.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# ==== Dataset ==== #
dataset: GOD # for Brennan2018, override with Brennan2018 in CLI
rebuild_dataset: False

# ==== Training ==== #
use_wandb: False
wandb:
project: speech_decoding
entity: sensho
run_name: word-onsets
use_sampler: True # applicable to Gwilliams only
reproducible: False # NOTE: do we need it at all?
split_ratio: 0.8 # train. FIXME for valid
split_mode: shallow # sentence, shallow, deep
num_workers: 6
batch_size: 64
updates: 1200
lr: 3e-4
lr_scheduler: none # cosine or multistep or none
lr_multistep_mlstns: [0.4, 0.6, 0.8, 0.9]
lr_step_gamma: 0.5
epochs: 300
reduction: mean

# ==== Architecture ==== #
D1: 270
D2: 320
F: 512 # NOTE: because if you set last4layers=False, then it's set to 1024 in the dataset class
K: 32
d_drop: 0.1 # for spatial attention, drop channels within d_drop of a randomly selected channel

init_temperature: 5.1
wav2vec_model: facebook/wav2vec2-large-xlsr-53 # (HuggingFace) # xlsr_53_56k (FAIR)

# == Data pre-processing parameters === #
preprocs:
audio_resample_rate: 16000 # before wav2vec
lowpass_filter_width: 128
brain_resample_rate: 120 # Hz
brain_filter_low: 1.0 # Hz
brain_filter_high: 60 # Hz
seq_len_sec: 3 # segment length in seconds
baseline_len_sec: 0.5 # baseline period in seconds
shift_brain: True # whether to shift M/EEG into the future relative to audio
shift_len: 150 # if True, by how many ms
last4layers: True # if True, the brain_encoder's emsize will be 1024, not 512
subject_wise: True # whether to scale each subject's EEG dataset individually (only for Brennan2018)
clamp: True
clamp_lim: 20

memory_efficient: True

# ====
num_subjects: 2

# ==== Logging ==== #
hydra:
job:
chdir: True
Loading