Line data Source code
1 : //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
2 : // instrinsics
3 : //
4 : // The LLVM Compiler Infrastructure
5 : //
6 : // This file is distributed under the University of Illinois Open Source
7 : // License. See LICENSE.TXT for details.
8 : //
9 : //===----------------------------------------------------------------------===//
10 : //
11 : // This pass replaces masked memory intrinsics - when unsupported by the target
12 : // - with a chain of basic blocks, that deal with the elements one-by-one if the
13 : // appropriate mask bit is set.
14 : //
15 : //===----------------------------------------------------------------------===//
16 :
17 : #include "llvm/ADT/Twine.h"
18 : #include "llvm/Analysis/TargetTransformInfo.h"
19 : #include "llvm/CodeGen/TargetSubtargetInfo.h"
20 : #include "llvm/IR/BasicBlock.h"
21 : #include "llvm/IR/Constant.h"
22 : #include "llvm/IR/Constants.h"
23 : #include "llvm/IR/DerivedTypes.h"
24 : #include "llvm/IR/Function.h"
25 : #include "llvm/IR/IRBuilder.h"
26 : #include "llvm/IR/InstrTypes.h"
27 : #include "llvm/IR/Instruction.h"
28 : #include "llvm/IR/Instructions.h"
29 : #include "llvm/IR/IntrinsicInst.h"
30 : #include "llvm/IR/Intrinsics.h"
31 : #include "llvm/IR/Type.h"
32 : #include "llvm/IR/Value.h"
33 : #include "llvm/Pass.h"
34 : #include "llvm/Support/Casting.h"
35 : #include <algorithm>
36 : #include <cassert>
37 :
38 : using namespace llvm;
39 :
40 : #define DEBUG_TYPE "scalarize-masked-mem-intrin"
41 :
42 : namespace {
43 :
44 : class ScalarizeMaskedMemIntrin : public FunctionPass {
45 : const TargetTransformInfo *TTI = nullptr;
46 :
47 : public:
48 : static char ID; // Pass identification, replacement for typeid
49 :
50 27457 : explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
51 27457 : initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
52 27457 : }
53 :
54 : bool runOnFunction(Function &F) override;
55 :
56 21 : StringRef getPassName() const override {
57 21 : return "Scalarize Masked Memory Intrinsics";
58 : }
59 :
60 27319 : void getAnalysisUsage(AnalysisUsage &AU) const override {
61 : AU.addRequired<TargetTransformInfoWrapperPass>();
62 27319 : }
63 :
64 : private:
65 : bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
66 : bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
67 : };
68 :
69 : } // end anonymous namespace
70 :
71 : char ScalarizeMaskedMemIntrin::ID = 0;
72 :
73 151909 : INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
74 : "Scalarize unsupported masked memory intrinsics", false, false)
75 :
76 27453 : FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
77 27453 : return new ScalarizeMaskedMemIntrin();
78 : }
79 :
80 129 : static bool isConstantIntVector(Value *Mask) {
81 : Constant *C = dyn_cast<Constant>(Mask);
82 : if (!C)
83 : return false;
84 :
85 35 : unsigned NumElts = Mask->getType()->getVectorNumElements();
86 288 : for (unsigned i = 0; i != NumElts; ++i) {
87 253 : Constant *CElt = C->getAggregateElement(i);
88 253 : if (!CElt || !isa<ConstantInt>(CElt))
89 : return false;
90 : }
91 :
92 : return true;
93 : }
94 :
95 : // Translate a masked load intrinsic like
96 : // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
97 : // <16 x i1> %mask, <16 x i32> %passthru)
98 : // to a chain of basic blocks, with loading element one-by-one if
99 : // the appropriate mask bit is set
100 : //
101 : // %1 = bitcast i8* %addr to i32*
102 : // %2 = extractelement <16 x i1> %mask, i32 0
103 : // br i1 %2, label %cond.load, label %else
104 : //
105 : // cond.load: ; preds = %0
106 : // %3 = getelementptr i32* %1, i32 0
107 : // %4 = load i32* %3
108 : // %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
109 : // br label %else
110 : //
111 : // else: ; preds = %0, %cond.load
112 : // %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
113 : // %6 = extractelement <16 x i1> %mask, i32 1
114 : // br i1 %6, label %cond.load1, label %else2
115 : //
116 : // cond.load1: ; preds = %else
117 : // %7 = getelementptr i32* %1, i32 1
118 : // %8 = load i32* %7
119 : // %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
120 : // br label %else2
121 : //
122 : // else2: ; preds = %else, %cond.load1
123 : // %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
124 : // %10 = extractelement <16 x i1> %mask, i32 2
125 : // br i1 %10, label %cond.load4, label %else5
126 : //
127 16 : static void scalarizeMaskedLoad(CallInst *CI) {
128 16 : Value *Ptr = CI->getArgOperand(0);
129 : Value *Alignment = CI->getArgOperand(1);
130 : Value *Mask = CI->getArgOperand(2);
131 : Value *Src0 = CI->getArgOperand(3);
132 :
133 16 : unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
134 16 : VectorType *VecType = cast<VectorType>(CI->getType());
135 :
136 16 : Type *EltTy = VecType->getElementType();
137 :
138 16 : IRBuilder<> Builder(CI->getContext());
139 : Instruction *InsertPt = CI;
140 16 : BasicBlock *IfBlock = CI->getParent();
141 : BasicBlock *CondBlock = nullptr;
142 : BasicBlock *PrevIfBlock = CI->getParent();
143 :
144 16 : Builder.SetInsertPoint(InsertPt);
145 16 : Builder.SetCurrentDebugLocation(CI->getDebugLoc());
146 :
147 : // Short-cut if the mask is all-true.
148 16 : if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
149 1 : Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
150 1 : CI->replaceAllUsesWith(NewI);
151 1 : CI->eraseFromParent();
152 1 : return;
153 : }
154 :
155 : // Adjust alignment for the scalar instruction.
156 15 : AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
157 : // Bitcast %addr fron i8* to EltTy*
158 : Type *NewPtrType =
159 30 : EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
160 15 : Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
161 15 : unsigned VectorWidth = VecType->getNumElements();
162 :
163 : // The result vector
164 : Value *VResult = Src0;
165 :
166 15 : if (isConstantIntVector(Mask)) {
167 6 : for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
168 4 : if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
169 : continue;
170 : Value *Gep =
171 1 : Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
172 1 : LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
173 : VResult =
174 1 : Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
175 : }
176 2 : CI->replaceAllUsesWith(VResult);
177 2 : CI->eraseFromParent();
178 2 : return;
179 : }
180 :
181 35 : for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
182 : // Fill the "else" block, created in the previous iteration
183 : //
184 : // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
185 : // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
186 : // br i1 %mask_1, label %cond.load, label %else
187 : //
188 :
189 : Value *Predicate =
190 22 : Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
191 :
192 : // Create "cond" block
193 : //
194 : // %EltAddr = getelementptr i32* %1, i32 0
195 : // %Elt = load i32* %EltAddr
196 : // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
197 : //
198 22 : CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
199 22 : Builder.SetInsertPoint(InsertPt);
200 :
201 : Value *Gep =
202 22 : Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
203 22 : LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
204 22 : Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
205 22 : Builder.getInt32(Idx));
206 :
207 : // Create "else" block, fill it in the next iteration
208 : BasicBlock *NewIfBlock =
209 22 : CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
210 22 : Builder.SetInsertPoint(InsertPt);
211 : Instruction *OldBr = IfBlock->getTerminator();
212 22 : BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
213 22 : OldBr->eraseFromParent();
214 : PrevIfBlock = IfBlock;
215 : IfBlock = NewIfBlock;
216 :
217 : // Create the phi to join the new and previous value.
218 22 : PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
219 22 : Phi->addIncoming(NewVResult, CondBlock);
220 22 : Phi->addIncoming(VResult, PrevIfBlock);
221 : VResult = Phi;
222 : }
223 :
224 13 : CI->replaceAllUsesWith(VResult);
225 13 : CI->eraseFromParent();
226 : }
227 :
228 : // Translate a masked store intrinsic, like
229 : // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
230 : // <16 x i1> %mask)
231 : // to a chain of basic blocks, that stores element one-by-one if
232 : // the appropriate mask bit is set
233 : //
234 : // %1 = bitcast i8* %addr to i32*
235 : // %2 = extractelement <16 x i1> %mask, i32 0
236 : // br i1 %2, label %cond.store, label %else
237 : //
238 : // cond.store: ; preds = %0
239 : // %3 = extractelement <16 x i32> %val, i32 0
240 : // %4 = getelementptr i32* %1, i32 0
241 : // store i32 %3, i32* %4
242 : // br label %else
243 : //
244 : // else: ; preds = %0, %cond.store
245 : // %5 = extractelement <16 x i1> %mask, i32 1
246 : // br i1 %5, label %cond.store1, label %else2
247 : //
248 : // cond.store1: ; preds = %else
249 : // %6 = extractelement <16 x i32> %val, i32 1
250 : // %7 = getelementptr i32* %1, i32 1
251 : // store i32 %6, i32* %7
252 : // br label %else2
253 : // . . .
254 10 : static void scalarizeMaskedStore(CallInst *CI) {
255 10 : Value *Src = CI->getArgOperand(0);
256 : Value *Ptr = CI->getArgOperand(1);
257 : Value *Alignment = CI->getArgOperand(2);
258 : Value *Mask = CI->getArgOperand(3);
259 :
260 10 : unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
261 10 : VectorType *VecType = cast<VectorType>(Src->getType());
262 :
263 10 : Type *EltTy = VecType->getElementType();
264 :
265 10 : IRBuilder<> Builder(CI->getContext());
266 : Instruction *InsertPt = CI;
267 10 : BasicBlock *IfBlock = CI->getParent();
268 10 : Builder.SetInsertPoint(InsertPt);
269 10 : Builder.SetCurrentDebugLocation(CI->getDebugLoc());
270 :
271 : // Short-cut if the mask is all-true.
272 10 : if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
273 : Builder.CreateAlignedStore(Src, Ptr, AlignVal);
274 1 : CI->eraseFromParent();
275 1 : return;
276 : }
277 :
278 : // Adjust alignment for the scalar instruction.
279 9 : AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
280 : // Bitcast %addr fron i8* to EltTy*
281 : Type *NewPtrType =
282 18 : EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
283 9 : Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
284 9 : unsigned VectorWidth = VecType->getNumElements();
285 :
286 9 : if (isConstantIntVector(Mask)) {
287 6 : for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
288 4 : if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
289 : continue;
290 1 : Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
291 : Value *Gep =
292 1 : Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
293 : Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
294 : }
295 2 : CI->eraseFromParent();
296 2 : return;
297 : }
298 :
299 21 : for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
300 : // Fill the "else" block, created in the previous iteration
301 : //
302 : // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
303 : // br i1 %mask_1, label %cond.store, label %else
304 : //
305 : Value *Predicate =
306 14 : Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
307 :
308 : // Create "cond" block
309 : //
310 : // %OneElt = extractelement <16 x i32> %Src, i32 Idx
311 : // %EltAddr = getelementptr i32* %1, i32 0
312 : // %store i32 %OneElt, i32* %EltAddr
313 : //
314 : BasicBlock *CondBlock =
315 14 : IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
316 14 : Builder.SetInsertPoint(InsertPt);
317 :
318 14 : Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
319 : Value *Gep =
320 14 : Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
321 : Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
322 :
323 : // Create "else" block, fill it in the next iteration
324 : BasicBlock *NewIfBlock =
325 14 : CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
326 14 : Builder.SetInsertPoint(InsertPt);
327 : Instruction *OldBr = IfBlock->getTerminator();
328 14 : BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
329 14 : OldBr->eraseFromParent();
330 : IfBlock = NewIfBlock;
331 : }
332 7 : CI->eraseFromParent();
333 : }
334 :
335 : // Translate a masked gather intrinsic like
336 : // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
337 : // <16 x i1> %Mask, <16 x i32> %Src)
338 : // to a chain of basic blocks, with loading element one-by-one if
339 : // the appropriate mask bit is set
340 : //
341 : // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
342 : // %Mask0 = extractelement <16 x i1> %Mask, i32 0
343 : // br i1 %Mask0, label %cond.load, label %else
344 : //
345 : // cond.load:
346 : // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
347 : // %Load0 = load i32, i32* %Ptr0, align 4
348 : // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
349 : // br label %else
350 : //
351 : // else:
352 : // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
353 : // %Mask1 = extractelement <16 x i1> %Mask, i32 1
354 : // br i1 %Mask1, label %cond.load1, label %else2
355 : //
356 : // cond.load1:
357 : // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
358 : // %Load1 = load i32, i32* %Ptr1, align 4
359 : // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
360 : // br label %else2
361 : // . . .
362 : // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
363 : // ret <16 x i32> %Result
364 76 : static void scalarizeMaskedGather(CallInst *CI) {
365 76 : Value *Ptrs = CI->getArgOperand(0);
366 : Value *Alignment = CI->getArgOperand(1);
367 : Value *Mask = CI->getArgOperand(2);
368 : Value *Src0 = CI->getArgOperand(3);
369 :
370 76 : VectorType *VecType = cast<VectorType>(CI->getType());
371 :
372 76 : IRBuilder<> Builder(CI->getContext());
373 : Instruction *InsertPt = CI;
374 76 : BasicBlock *IfBlock = CI->getParent();
375 : BasicBlock *CondBlock = nullptr;
376 : BasicBlock *PrevIfBlock = CI->getParent();
377 76 : Builder.SetInsertPoint(InsertPt);
378 76 : unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
379 :
380 76 : Builder.SetCurrentDebugLocation(CI->getDebugLoc());
381 :
382 : // The result vector
383 : Value *VResult = Src0;
384 76 : unsigned VectorWidth = VecType->getNumElements();
385 :
386 : // Shorten the way if the mask is a vector of constants.
387 76 : if (isConstantIntVector(Mask)) {
388 264 : for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
389 235 : if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
390 : continue;
391 219 : Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
392 219 : "Ptr" + Twine(Idx));
393 : LoadInst *Load =
394 219 : Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
395 219 : VResult = Builder.CreateInsertElement(
396 438 : VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
397 : }
398 29 : CI->replaceAllUsesWith(VResult);
399 29 : CI->eraseFromParent();
400 : return;
401 : }
402 :
403 348 : for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
404 : // Fill the "else" block, created in the previous iteration
405 : //
406 : // %Mask1 = extractelement <16 x i1> %Mask, i32 1
407 : // br i1 %Mask1, label %cond.load, label %else
408 : //
409 :
410 301 : Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
411 301 : "Mask" + Twine(Idx));
412 :
413 : // Create "cond" block
414 : //
415 : // %EltAddr = getelementptr i32* %1, i32 0
416 : // %Elt = load i32* %EltAddr
417 : // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
418 : //
419 301 : CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
420 301 : Builder.SetInsertPoint(InsertPt);
421 :
422 301 : Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
423 301 : "Ptr" + Twine(Idx));
424 : LoadInst *Load =
425 301 : Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
426 301 : Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
427 301 : Builder.getInt32(Idx),
428 301 : "Res" + Twine(Idx));
429 :
430 : // Create "else" block, fill it in the next iteration
431 301 : BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
432 301 : Builder.SetInsertPoint(InsertPt);
433 : Instruction *OldBr = IfBlock->getTerminator();
434 301 : BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
435 301 : OldBr->eraseFromParent();
436 : PrevIfBlock = IfBlock;
437 : IfBlock = NewIfBlock;
438 :
439 301 : PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
440 301 : Phi->addIncoming(NewVResult, CondBlock);
441 301 : Phi->addIncoming(VResult, PrevIfBlock);
442 : VResult = Phi;
443 : }
444 :
445 47 : CI->replaceAllUsesWith(VResult);
446 47 : CI->eraseFromParent();
447 : }
448 :
449 : // Translate a masked scatter intrinsic, like
450 : // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
451 : // <16 x i1> %Mask)
452 : // to a chain of basic blocks, that stores element one-by-one if
453 : // the appropriate mask bit is set.
454 : //
455 : // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
456 : // %Mask0 = extractelement <16 x i1> %Mask, i32 0
457 : // br i1 %Mask0, label %cond.store, label %else
458 : //
459 : // cond.store:
460 : // %Elt0 = extractelement <16 x i32> %Src, i32 0
461 : // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
462 : // store i32 %Elt0, i32* %Ptr0, align 4
463 : // br label %else
464 : //
465 : // else:
466 : // %Mask1 = extractelement <16 x i1> %Mask, i32 1
467 : // br i1 %Mask1, label %cond.store1, label %else2
468 : //
469 : // cond.store1:
470 : // %Elt1 = extractelement <16 x i32> %Src, i32 1
471 : // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
472 : // store i32 %Elt1, i32* %Ptr1, align 4
473 : // br label %else2
474 : // . . .
475 29 : static void scalarizeMaskedScatter(CallInst *CI) {
476 29 : Value *Src = CI->getArgOperand(0);
477 : Value *Ptrs = CI->getArgOperand(1);
478 : Value *Alignment = CI->getArgOperand(2);
479 : Value *Mask = CI->getArgOperand(3);
480 :
481 : assert(isa<VectorType>(Src->getType()) &&
482 : "Unexpected data type in masked scatter intrinsic");
483 : assert(isa<VectorType>(Ptrs->getType()) &&
484 : isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
485 : "Vector of pointers is expected in masked scatter intrinsic");
486 :
487 29 : IRBuilder<> Builder(CI->getContext());
488 : Instruction *InsertPt = CI;
489 29 : BasicBlock *IfBlock = CI->getParent();
490 29 : Builder.SetInsertPoint(InsertPt);
491 29 : Builder.SetCurrentDebugLocation(CI->getDebugLoc());
492 :
493 29 : unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
494 29 : unsigned VectorWidth = Src->getType()->getVectorNumElements();
495 :
496 : // Shorten the way if the mask is a vector of constants.
497 29 : if (isConstantIntVector(Mask)) {
498 12 : for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
499 10 : if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
500 : continue;
501 10 : Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
502 10 : "Elt" + Twine(Idx));
503 10 : Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
504 10 : "Ptr" + Twine(Idx));
505 : Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
506 : }
507 2 : CI->eraseFromParent();
508 : return;
509 : }
510 :
511 180 : for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
512 : // Fill the "else" block, created in the previous iteration
513 : //
514 : // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
515 : // br i1 %Mask1, label %cond.store, label %else
516 : //
517 153 : Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
518 153 : "Mask" + Twine(Idx));
519 :
520 : // Create "cond" block
521 : //
522 : // %Elt1 = extractelement <16 x i32> %Src, i32 1
523 : // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
524 : // %store i32 %Elt1, i32* %Ptr1
525 : //
526 153 : BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
527 153 : Builder.SetInsertPoint(InsertPt);
528 :
529 153 : Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
530 153 : "Elt" + Twine(Idx));
531 153 : Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
532 153 : "Ptr" + Twine(Idx));
533 : Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
534 :
535 : // Create "else" block, fill it in the next iteration
536 153 : BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
537 153 : Builder.SetInsertPoint(InsertPt);
538 : Instruction *OldBr = IfBlock->getTerminator();
539 153 : BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
540 153 : OldBr->eraseFromParent();
541 : IfBlock = NewIfBlock;
542 : }
543 27 : CI->eraseFromParent();
544 : }
545 :
546 406591 : bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
547 : bool EverMadeChange = false;
548 :
549 406591 : TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
550 :
551 : bool MadeChange = true;
552 813313 : while (MadeChange) {
553 : MadeChange = false;
554 3630094 : for (Function::iterator I = F.begin(); I != F.end();) {
555 : BasicBlock *BB = &*I++;
556 3223503 : bool ModifiedDTOnIteration = false;
557 3223503 : MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
558 :
559 : // Restart BB iteration if the dominator tree of the Function was changed
560 3223503 : if (ModifiedDTOnIteration)
561 : break;
562 : }
563 :
564 406722 : EverMadeChange |= MadeChange;
565 : }
566 :
567 406591 : return EverMadeChange;
568 : }
569 :
570 3223503 : bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
571 : bool MadeChange = false;
572 :
573 : BasicBlock::iterator CurInstIterator = BB.begin();
574 40229555 : while (CurInstIterator != BB.end()) {
575 : if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
576 2391686 : MadeChange |= optimizeCallInst(CI, ModifiedDT);
577 37006183 : if (ModifiedDT)
578 : return true;
579 : }
580 :
581 : return MadeChange;
582 : }
583 :
584 0 : bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
585 : bool &ModifiedDT) {
586 : IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
587 : if (II) {
588 0 : switch (II->getIntrinsicID()) {
589 : default:
590 : break;
591 0 : case Intrinsic::masked_load:
592 : // Scalarize unsupported vector masked load
593 0 : if (!TTI->isLegalMaskedLoad(CI->getType())) {
594 0 : scalarizeMaskedLoad(CI);
595 0 : ModifiedDT = true;
596 0 : return true;
597 : }
598 0 : return false;
599 0 : case Intrinsic::masked_store:
600 0 : if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
601 0 : scalarizeMaskedStore(CI);
602 0 : ModifiedDT = true;
603 0 : return true;
604 : }
605 0 : return false;
606 0 : case Intrinsic::masked_gather:
607 0 : if (!TTI->isLegalMaskedGather(CI->getType())) {
608 0 : scalarizeMaskedGather(CI);
609 0 : ModifiedDT = true;
610 0 : return true;
611 : }
612 0 : return false;
613 0 : case Intrinsic::masked_scatter:
614 0 : if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
615 0 : scalarizeMaskedScatter(CI);
616 0 : ModifiedDT = true;
617 0 : return true;
618 : }
619 0 : return false;
620 : }
621 : }
622 :
623 : return false;
624 : }
|