Skip to content

Commit 5fdd1f0

Browse files
committed
[LLVM] Improve the DemandedBits Analysis
This patch adds support for missing operators inside the DemandedBits Analysis. Those operators are SDiv, UDiv, URem, SRem. Also, other operators such as Shl and Ashr are improved to handle non constant argument shift amount. Multiplication is also improved. Comparison with the upstream version of llvm with the Oz pipeline showed up to 10% code size reduction in the llvm test suite.
1 parent 977cfea commit 5fdd1f0

File tree

6 files changed

+448
-27
lines changed

6 files changed

+448
-27
lines changed

llvm/lib/Analysis/DemandedBits.cpp

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include "llvm/Support/Casting.h"
3737
#include "llvm/Support/Debug.h"
3838
#include "llvm/Support/KnownBits.h"
39+
#include "llvm/Support/MathExtras.h"
3940
#include "llvm/Support/raw_ostream.h"
4041
#include <algorithm>
4142
#include <cstdint>
@@ -164,10 +165,24 @@ void DemandedBits::determineLiveOperandBits(
164165
}
165166
break;
166167
case Instruction::Mul:
167-
// Find the highest live output bit. We don't need any more input
168-
// bits than that (adds, and thus subtracts, ripple only to the
169-
// left).
170-
AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
168+
const APInt *C;
169+
if (OperandNo == 0) {
170+
// to have output bits 0...H-1 we need the input bits
171+
// 0...(H - ceiling(log_2))
172+
if (match(UserI->getOperand(1), m_APInt(C))) {
173+
auto LogC = C->isOne() ? 0 : C->logBase2() + 1;
174+
unsigned Need =
175+
AOut.getActiveBits() > LogC ? AOut.getActiveBits() - LogC : 0;
176+
AB = APInt::getLowBitsSet(BitWidth, Need);
177+
} else { // TODO: we can possibly check for Op0 constant too
178+
AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
179+
}
180+
} else {
181+
// Find the highest live output bit. We don't need any more input
182+
// bits than that (adds, and thus subtracts, ripple only to the
183+
// left).
184+
AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits());
185+
}
171186
break;
172187
case Instruction::Shl:
173188
if (OperandNo == 0) {
@@ -183,6 +198,17 @@ void DemandedBits::determineLiveOperandBits(
183198
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
184199
else if (S->hasNoUnsignedWrap())
185200
AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
201+
} else {
202+
ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
203+
unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
204+
unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
205+
// similar to Lshr case
206+
AB = (AOut.lshr(Min) | AOut.lshr(Max));
207+
const auto *S = cast<ShlOperator>(UserI);
208+
if (S->hasNoSignedWrap())
209+
AB |= APInt::getHighBitsSet(BitWidth, Max + 1);
210+
else if (S->hasNoUnsignedWrap())
211+
AB |= APInt::getHighBitsSet(BitWidth, Max);
186212
}
187213
}
188214
break;
@@ -197,6 +223,19 @@ void DemandedBits::determineLiveOperandBits(
197223
// (they must be zero).
198224
if (cast<LShrOperator>(UserI)->isExact())
199225
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
226+
} else {
227+
ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
228+
unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
229+
unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
230+
// Suppose AOut == 0b0000 0011
231+
// [min, max] = [1, 3]
232+
// shift by 1 we get 0b0000 0110
233+
// shift by 2 we get 0b0000 1100
234+
// shift by 3 we get 0b0001 1000
235+
// we take the or here because need to cover all the above possibilities
236+
AB = (AOut.shl(Min) | AOut.shl(Max));
237+
if (cast<LShrOperator>(UserI)->isExact())
238+
AB |= APInt::getLowBitsSet(BitWidth, Max);
200239
}
201240
}
202241
break;
@@ -217,6 +256,27 @@ void DemandedBits::determineLiveOperandBits(
217256
// (they must be zero).
218257
if (cast<AShrOperator>(UserI)->isExact())
219258
AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
259+
} else {
260+
ComputeKnownBits(BitWidth, UserI->getOperand(1), nullptr);
261+
unsigned Min = Known.getMinValue().getLimitedValue(BitWidth - 1);
262+
unsigned Max = Known.getMaxValue().getLimitedValue(BitWidth - 1);
263+
AB = (AOut.shl(Min) | AOut.shl(Max));
264+
265+
if (Max) {
266+
// Suppose AOut = 0011 1100
267+
// [min, max] = [1, 3]
268+
// ShiftAmount = 1 : Mask is 1000 0000
269+
// ShiftAmount = 2 : Mask is 1100 0000
270+
// ShiftAmount = 3 : Mask is 1110 0000
271+
// The Mask with Max covers every case in [min, max],
272+
// so we are done
273+
if ((AOut & APInt::getHighBitsSet(BitWidth, Max)).getBoolValue())
274+
AB.setSignBit();
275+
}
276+
// If the shift is exact, then the low bits are not dead
277+
// (they must be zero).
278+
if (cast<AShrOperator>(UserI)->isExact())
279+
AB |= APInt::getLowBitsSet(BitWidth, Max);
220280
}
221281
}
222282
break;
@@ -246,6 +306,35 @@ void DemandedBits::determineLiveOperandBits(
246306
else
247307
AB &= ~(Known.One & ~Known2.One);
248308
break;
309+
case Instruction::UDiv:
310+
case Instruction::URem:
311+
case Instruction::SDiv:
312+
case Instruction::SRem: {
313+
auto Opc = UserI->getOpcode();
314+
auto IsDiv = Opc == Instruction::UDiv || Opc == Instruction::SDiv;
315+
bool IsSigned = Opc == Instruction::SDiv || Opc == Instruction::SRem;
316+
if (OperandNo == 0) {
317+
const APInt *DivAmnt;
318+
if (match(UserI->getOperand(1), m_APInt(DivAmnt))) {
319+
uint64_t D = DivAmnt->getZExtValue();
320+
if (isPowerOf2_64(D)) {
321+
unsigned Sh = Log2_64(D);
322+
if (IsDiv) {
323+
AB = AOut.shl(Sh);
324+
} else {
325+
AB = AOut & APInt::getLowBitsSet(BitWidth, Sh);
326+
}
327+
} else { // Non power of 2 constant div
328+
unsigned LowQ = AOut.getActiveBits();
329+
unsigned Need = LowQ + Log2_64_Ceil(D);
330+
if (IsSigned)
331+
Need++;
332+
AB = APInt::getLowBitsSet(BitWidth, std::min(BitWidth, Need));
333+
}
334+
}
335+
}
336+
break;
337+
}
249338
case Instruction::Xor:
250339
case Instruction::PHI:
251340
AB = AOut;

llvm/test/Analysis/DemandedBits/basic.ll

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,28 @@ define i8 @test_mul(i32 %a, i32 %b) {
2525
%6 = add nsw i8 %3, %5
2626
ret i8 %6
2727
}
28+
; CHECK-LABEL: Printing analysis 'Demanded Bits Analysis' for function 'test_mul_constant':
29+
; CHECK-DAG: DemandedBits: 0xff for %3 = trunc i32 %2 to i8
30+
; CHECK-DAG: DemandedBits: 0xff for %2 in %3 = trunc i32 %2 to i8
31+
; CHECK-DAG: DemandedBits: 0xff for %2 = mul nsw i32 %1, 6
32+
; CHECK-DAG: DemandedBits: 0x1f for %1 in %2 = mul nsw i32 %1, 6
33+
; CHECK-DAG: DemandedBits: 0xff for 6 in %2 = mul nsw i32 %1, 6
34+
; CHECK-DAG: DemandedBits: 0x1 for %4 = trunc i32 %2 to i1
35+
; CHECK-DAG: DemandedBits: 0x1 for %2 in %4 = trunc i32 %2 to i1
36+
; CHECK-DAG: DemandedBits: 0x1f for %1 = add nsw i32 %a, 12
37+
; CHECK-DAG: DemandedBits: 0x1f for %a in %1 = add nsw i32 %a, 12
38+
; CHECK-DAG: DemandedBits: 0x1f for 12 in %1 = add nsw i32 %a, 12
39+
; CHECK-DAG: DemandedBits: 0xff for %5 = zext i1 %4 to i8
40+
; CHECK-DAG: DemandedBits: 0x1 for %4 in %5 = zext i1 %4 to i8
41+
; CHECK-DAG: DemandedBits: 0xff for %6 = add nsw i8 %3, %5
42+
; CHECK-DAG: DemandedBits: 0xff for %3 in %6 = add nsw i8 %3, %5
43+
; CHECK-DAG: DemandedBits: 0xff for %5 in %6 = add nsw i8 %3, %5
44+
define i8 @test_mul_constant(i32 %a, i32 %b){
45+
%1 = add nsw i32 %a, 12
46+
%2 = mul nsw i32 %1, 6
47+
%3 = trunc i32 %2 to i8
48+
%4 = trunc i32 %2 to i1
49+
%5 = zext i1 %4 to i8
50+
%6 = add nsw i8 %3, %5
51+
ret i8 %6
52+
}

0 commit comments

Comments
 (0)