From 6edc22a8d1db0cb8f0c251271688a7cd33d139e7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 12 Dec 2024 19:43:56 -0600 Subject: [PATCH 1/3] Refactor Monarch variable names to boost readability --- bert/src/mm/blockdiag_butterfly_multiply.py | 104 +++++++++++++------- bert/src/mm/blockdiag_multiply.py | 42 ++++---- 2 files changed, 90 insertions(+), 56 deletions(-) diff --git a/bert/src/mm/blockdiag_butterfly_multiply.py b/bert/src/mm/blockdiag_butterfly_multiply.py index 49b285b..b91df56 100644 --- a/bert/src/mm/blockdiag_butterfly_multiply.py +++ b/bert/src/mm/blockdiag_butterfly_multiply.py @@ -46,68 +46,98 @@ def blockdiag_butterfly_multiply_reference(x, w1_bfly, w2_bfly, version=2): out2 = rearrange(out2, 'b (l s) -> b (s l)', l=l) return out2 - class BlockdiagButterflyMultiply(torch.autograd.Function): - """This is a faster implementation, with careful memory copies for the fastest bmm performance. The backward pass is also written manually with careful memory copies. Arguments: x: (batch, n) - w1_bfly: (k, q, p), where k = n / p - w2_bfly: (l, s, r), where l = k * q / r = n * q / (p * r) + w1_bfly: (nblocks1, blk1_out, blk1_in) + w2_bfly: (nblocks2, blk2_out, blk2_in) Outputs: out: (batch, m), where m = l * s = n * s * q / (p * r) """ @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16) - def forward(ctx, x, w1_bfly, w2_bfly): + @torch.amp.custom_fwd(cast_inputs=torch.bfloat16) + def forward(ctx, x, w1_bfly, w2_bfly, debug_out1=False): batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = np.prod(batch_shape) - k, q, p = w1_bfly.shape - l, s, r = w2_bfly.shape - assert k * p == n - assert l * r == k * q - x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1) - out1 = torch.empty(batch_dim, k, q, device=x.device, dtype=x.dtype).transpose(0, 1) - out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1) - out1 = out1.transpose(0, 1).reshape(batch_dim, r, l).transpose(-1, -2).contiguous().transpose(0, 1) - out2 = 
torch.empty(batch_dim, l, s, device=x.device, dtype=x.dtype).transpose(0, 1) - out2 = torch.bmm(out1, w2_bfly.transpose(-1, -2), out=out2) - out2 = out2.permute(1, 2, 0).reshape(*batch_shape, s * l) - ctx.save_for_backward(x, w1_bfly, w2_bfly, out1) + seq_dim = np.prod(batch_shape) + + w1_bfly = w1_bfly.to(x.dtype) + w2_bfly = w2_bfly.to(x.dtype) + + # Typically blk1_out = blk2_in and nblocks1 = nblocks2 + # e.g. (4, 4, 1024) + nblocks1, blk1_out, blk1_in = w1_bfly.shape + nblocks2, blk2_out, blk2_in = w2_bfly.shape + assert nblocks1 * blk1_in == n + assert nblocks2 * blk2_in == nblocks1 * blk1_out + + # Typical shape for Llama 7B on Math reasoning: (4, 666, 1024) + x_reshaped = x.reshape(seq_dim, nblocks1, blk1_in).transpose(0, 1) + out1 = torch.empty(nblocks1, seq_dim, blk1_out, device=x.device, dtype=x.dtype) + + # (nblocks1, seq_dim, blk1_in) @ (nblocks1, blk1_in, blk1_out) + out1 = torch.bmm(x_reshaped, w1_bfly.transpose(-1, -2), out=out1) # -> (nblocks1, seq_dim, blk1_out) + del x_reshaped + + # Feature shuffling + out1 = ( + out1.transpose(0, 1).reshape(seq_dim, blk2_in, nblocks2).permute(2, 0, 1) + ) # (seq_dim, nblocks2, blk1_out) -> (.., blk2_in, nblocks2) -> (nblocks2, seq_dim, blk2_in) + + out2 = torch.empty(nblocks2, seq_dim, blk2_out, device=x.device, dtype=x.dtype) + out2 = torch.bmm( + out1, w2_bfly.transpose(-1, -2), out=out2 + ) # (nblocks2, seq_dim, blk2_in) @ (nblocks2, blk2_in, blk2_out) -> (nblocks2, seq_dim, blk2_out) + + out2 = out2.permute(1, 2, 0).reshape( + *batch_shape, blk2_out * nblocks2 + ) # (nblocks2, seq_dim, blk2_out) -> (seq_dim, nblocks2 * blk2_out ) + + ctx.save_for_backward(x, w1_bfly, w2_bfly, out1, None, None) + if debug_out1: + return out2, out1 return out2 @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.custom_bwd def backward(ctx, dout): - x, w1_bfly, w2_bfly, out1 = ctx.saved_tensors + x, w1_bfly, w2_bfly, out1, *_ = ctx.saved_tensors batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = np.prod(batch_shape) - 
k, q, p = w1_bfly.shape - l, s, r = w2_bfly.shape - # assert k * p == n - # assert l * r == k * q + seq_dim = np.prod(batch_shape) + nblocks1, blk1_out, blk1_in = w1_bfly.shape + nblocks2, blk2_out, blk2_in = w2_bfly.shape + dx, dw1_bfly, dw2_bfly = None, None, None - # dout_reshaped = dout.reshape(batch_dim, sqrtn, sqrtn).permute(2, 1, 0).contiguous() - dout_reshaped = dout.reshape(batch_dim, s, l).transpose(-1, -2).contiguous() - dout_reshaped = dout_reshaped.transpose(0, 1) + + dout_reshaped = dout.reshape(seq_dim, blk2_out, nblocks2).transpose(-1, -2) + dout_reshaped = dout_reshaped.transpose(0, 1).contiguous() # (nblocks2, seq_dim, blk2_out) if ctx.needs_input_grad[2]: - # dw2_bfly = torch.empty(l, s, r, device=w2_bfly.device, dtype=w2_bfly.dtype) + # dw2_bfly = torch.empty(nblocks2, blk2_out, blk2_in, device=w2_bfly.device, dtype=w2_bfly.dtype) # dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1, out=dw2_bfly) + + # (nblocks2, blk2_out, seq_dim) @ (nblocks2, seq_dim, blk1_out) -> (nblocks2, blk2_out, blk1_out) dw2_bfly = torch.bmm(dout_reshaped.transpose(-1, -2), out1.conj()) if ctx.needs_input_grad[1] or ctx.needs_input_grad[0]: - dout1 = torch.empty(batch_dim, l, r, device=x.device, dtype=x.dtype).transpose(0, 1) - dout1 = torch.bmm(dout_reshaped, w2_bfly.conj(), out=dout1) - dout1 = dout1.transpose(0, 1).transpose(-1, -2).contiguous().reshape(batch_dim, k, q).transpose(0, 1) - # dout1 = dout1.permute(1, 2, 0).contiguous().transpose(0, 1) + dout1 = torch.empty(nblocks2, seq_dim, blk2_in, device=x.device, dtype=x.dtype) + dout1 = torch.bmm(dout_reshaped, w2_bfly.conj(), out=dout1) # -> (nblocks2, seq_dim, blk2_in) + del dout_reshaped + # dout1 = dout1.transpose(0, 1).transpose(-1, -2).contiguous().reshape(seq_dim, nblocks1, blk1_out).transpose(0, 1) + # NOTE: We do NOT need contiguous in between? 
This should save memory & time dout1 = ( dout1.permute(1, 2, 0).reshape(seq_dim, nblocks1, blk1_out).transpose(0, 1) ) # -> (nblocks1, seq_dim, blk1_out) if ctx.needs_input_grad[0]: - dx = torch.empty(batch_dim, k, p, device=x.device, dtype=x.dtype) + dx = torch.empty(seq_dim, nblocks1, blk1_in, device=x.device, dtype=x.dtype) + # (nblocks1, seq_dim, blk1_out) @ (nblocks1, blk1_out, blk1_in) -> (nblocks1, seq_dim, blk1_in) dx = torch.bmm(dout1, w1_bfly.conj(), out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n) if ctx.needs_input_grad[1]: - x_reshaped = x.reshape(batch_dim, k, p).transpose(0, 1) + x_reshaped = x.reshape(seq_dim, nblocks1, blk1_in).transpose(0, 1) + # (nblocks1, blk1_out, seq_dim) @ (nblocks1, seq_dim, blk1_in) -> (nblocks1, blk1_out, blk1_in) dw1_bfly = torch.bmm(dout1.transpose(-1, -2), x_reshaped.conj()) - return dx, dw1_bfly, dw2_bfly + return dx, dw1_bfly, dw2_bfly, None, None + blockdiag_butterfly_multiply = BlockdiagButterflyMultiply.apply \ No newline at end of file diff --git a/bert/src/mm/blockdiag_multiply.py b/bert/src/mm/blockdiag_multiply.py index 2b21abe..df1fe60 100644 --- a/bert/src/mm/blockdiag_multiply.py +++ b/bert/src/mm/blockdiag_multiply.py @@ -35,48 +35,52 @@ def blockdiag_multiply_reference(x, weight): class BlockdiagMultiply(torch.autograd.Function): - """This is a faster implementation, with careful memory copies for the fastest bmm performance. The backward pass is also written manually with careful memory copies. 
Arguments: x: (..., n) - weight: (nblocks, q, n / nblocks) + weight: (nblocks, q, n / nblockblk2_out) Outputs: - out: (..., nblocks * q) + out: (..., nblocks * blk1_out) """ @staticmethod - @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16) + @torch.cuda.amp.custom_fwd() def forward(ctx, x, weight): ctx.save_for_backward(x, weight) batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = np.prod(batch_shape) - nblocks, q, p = weight.shape - assert nblocks * p == n - x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1) - out = torch.empty(batch_dim, nblocks, q, device=x.device, dtype=x.dtype).transpose(0, 1) - out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose(0, 1) - return out.reshape(*batch_shape, nblocks * q) + seq_dim = np.prod(batch_shape) + nblocks, blk_out, blk_in = weight.shape + assert nblocks * blk_in == n + x_reshaped = x.view(seq_dim, nblocks, blk_in).transpose(0, 1) # (nblocks, seq_dim, p) + + out = torch.empty(nblocks, seq_dim, blk_out, device=x.device, dtype=x.dtype) + out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose( + 0, 1 + ) # (nblocks, seq_dim, blk_sz) @ (nblocks, blk_sz, blk_r) -> (nblocks, seq_dim, blk1_out) + return out.reshape(*batch_shape, nblocks * blk_out) @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout): x, weight = ctx.saved_tensors batch_shape, n = x.shape[:-1], x.shape[-1] - batch_dim = np.prod(batch_shape) - nblocks, q, p = weight.shape - assert nblocks * p == n + seq_dim = np.prod(batch_shape) + nblocks, blk_out, blk_in = weight.shape + assert nblocks * blk_in == n dx, dweight = None, None - dout_reshaped = dout.reshape(batch_dim, nblocks, q).transpose(0, 1) + dout_reshaped = dout.reshape(seq_dim, nblocks, blk_out).transpose(0, 1) if ctx.needs_input_grad[0]: - dx = torch.empty(batch_dim, nblocks, p, device=x.device, dtype=x.dtype) - dx = torch.bmm(dout_reshaped, weight.conj(), - out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n) + dx = 
torch.empty(seq_dim, nblocks, blk_in, device=x.device, dtype=x.dtype) + dx = ( + torch.bmm(dout_reshaped, weight.conj(), out=dx.transpose(0, 1)).transpose(0, 1).reshape(*batch_shape, n) + ) if ctx.needs_input_grad[1]: - x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1) + x_reshaped = x.reshape(seq_dim, nblocks, blk_in).transpose(0, 1) dweight = torch.bmm(dout_reshaped.transpose(-1, -2), x_reshaped.conj()) return dx, dweight + blockdiag_multiply = BlockdiagMultiply.apply \ No newline at end of file From 23bea4e4976b9471307f8db8a98d535e7b9f37b7 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Thu, 12 Dec 2024 19:46:55 -0600 Subject: [PATCH 2/3] Fix typo --- bert/src/mm/blockdiag_multiply.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bert/src/mm/blockdiag_multiply.py b/bert/src/mm/blockdiag_multiply.py index df1fe60..5a20da4 100644 --- a/bert/src/mm/blockdiag_multiply.py +++ b/bert/src/mm/blockdiag_multiply.py @@ -40,7 +40,7 @@ class BlockdiagMultiply(torch.autograd.Function): The backward pass is also written manually with careful memory copies. 
Arguments: x: (..., n) - weight: (nblocks, q, n / nblockblk2_out) + weight: (nblocks, blk_out, n / nblocks) Outputs: out: (..., nblocks * blk1_out) """ From 3f5274c500fe69f06c82d9e7677265bcf986e3a9 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Fri, 17 Jan 2025 18:23:13 -0600 Subject: [PATCH 3/3] Minor dtype fix --- bert/src/mm/blockdiag_multiply.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/bert/src/mm/blockdiag_multiply.py b/bert/src/mm/blockdiag_multiply.py index 5a20da4..00fd8b2 100644 --- a/bert/src/mm/blockdiag_multiply.py +++ b/bert/src/mm/blockdiag_multiply.py @@ -48,7 +48,9 @@ class BlockdiagMultiply(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd() def forward(ctx, x, weight): + weight = weight.to(x.dtype) ctx.save_for_backward(x, weight) + batch_shape, n = x.shape[:-1], x.shape[-1] seq_dim = np.prod(batch_shape) nblocks, blk_out, blk_in = weight.shape @@ -59,6 +61,7 @@ def forward(ctx, x, weight): out = torch.empty(nblocks, seq_dim, blk_out, device=x.device, dtype=x.dtype) out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose( 0, 1 ) # (nblocks, seq_dim, blk_sz) @ (nblocks, blk_sz, blk_r) -> (nblocks, seq_dim, blk1_out) + return out.reshape(*batch_shape, nblocks * blk_out)