Skip to content

Added PLR capability#6

Draft
adityalj wants to merge 5 commits intob-shi:tf32_perffrom
adityalj:adijoshi_plr_tf32
Draft

Added PLR capability#6
adityalj wants to merge 5 commits intob-shi:tf32_perffrom
adityalj:adijoshi_plr_tf32

Conversation

@adityalj
Copy link

No description provided.

reject(state, printRejectionReason, "MIWaveTile0(%u) should be multiple of VectorWidthB(%u)" % (state["MIWaveTile"][1], state["VectorWidthB"]))
return

if ((state["DepthU"] ==state["MatrixInstK"] )and state["PrefetchGlobalRead"]):
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add Wavesizes not supported here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rename D_U_iseqMI_K as DuEqMIK

mfmaiter2 = math.ceil(kernel["MIWaveTile"][0]/2) * math.floor(kernel["MIWaveTile"][1]/2)
writer.states.syncPlrMfmaIndex = (mfmaiter0 + mfmaiter1 + mfmaiter2)
if ( kernel["UseF32XEmulation"]) :
writer.states.syncPlrMfmaIndex = writer.states.syncPlrMfmaIndex *3 # TF32
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Complex *4

if kernel["1LDSBuffer"] or kernel["DirectToLds"]:
writer.states.sync1LdsMfmaIndex = max(writer.states.lwStartMfmaIndex - 1, 0)
startIter = writer.states.lwStartMfmaIndex//numMfmaPerIter
if kernel["D_U_iseqMI_K"]:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make it more flexible

else:
for vIdx in range(0, numVectorsPerTile):
for eIdx in range(0, numReadsPerVector):
eIdxCnt = numReadsPerVector
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test on usual cases

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generate an example to demonstrate the issue

isBarrier = kernel["LoopIters"] - self.states.numItersPLR
writeItems = list(localWriteCode.items())
macIterItems = macIterCode.flatitems()
numMfmaPerIter = len(macIterItems)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

print subiter instead pf iter at appropriate place

itemCounter = 0
for i in range(numMfmaPerIter):
mfmaIndex = iteration * numMfmaPerIter + i
kernel["mfmaIndex"] = kernel["mfmaIndex"] + 1
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

state instead of kernel

iterCode.add(SSetPrior(prior=3, comment="store optimization"))
if (mfmaIndex >= self.states.lwStartMfmaIndex):
numLoops, itemCounter = calculateRangeAndUpdateCounter(itemCounter, localWriteCodeCounts, self.states.numLocalWriteModPerMfma)
if kernel["D_U_iseqMI_K"]:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider DTL scenario

self.makeSchedule(kernel, tensorParametersA, tensorParametersB, localWriteEndIter, skipGlobalReadInc=False, lastLoop=NLLlast, isNGLL=isNGLL)
module.add(self.codes.unrollLoopHeader)

if kernel["D_U_iseqMI_K"]:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depend on quad cycle count

@adityalj adityalj force-pushed the adijoshi_plr_tf32 branch from 580587b to a3fd9d7 Compare July 23, 2025 20:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant