1 : //
2 : // Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
3 : // Use of this source code is governed by a BSD-style license that can be
4 : // found in the LICENSE file.
5 : //
6 :
7 : #include "compiler/ForLoopUnroll.h"
8 :
9 : namespace {
10 :
11 0 : class IntegerForLoopUnrollMarker : public TIntermTraverser {
12 : public:
13 :
14 0 : virtual bool visitLoop(Visit, TIntermLoop* node)
15 : {
16 : // This is called after ValidateLimitations pass, so all the ASSERT
17 : // should never fail.
18 : // See ValidateLimitations::validateForLoopInit().
19 0 : ASSERT(node);
20 0 : ASSERT(node->getType() == ELoopFor);
21 0 : ASSERT(node->getInit());
22 0 : TIntermAggregate* decl = node->getInit()->getAsAggregate();
23 0 : ASSERT(decl && decl->getOp() == EOpDeclaration);
24 0 : TIntermSequence& declSeq = decl->getSequence();
25 0 : ASSERT(declSeq.size() == 1);
26 0 : TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
27 0 : ASSERT(declInit && declInit->getOp() == EOpInitialize);
28 0 : ASSERT(declInit->getLeft());
29 0 : TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
30 0 : ASSERT(symbol);
31 0 : TBasicType type = symbol->getBasicType();
32 0 : ASSERT(type == EbtInt || type == EbtFloat);
33 0 : if (type == EbtInt)
34 0 : node->setUnrollFlag(true);
35 0 : return true;
36 : }
37 :
38 : };
39 :
40 : } // anonymous namepsace
41 :
42 0 : void ForLoopUnroll::FillLoopIndexInfo(TIntermLoop* node, TLoopIndexInfo& info)
43 : {
44 0 : ASSERT(node->getType() == ELoopFor);
45 0 : ASSERT(node->getUnrollFlag());
46 :
47 0 : TIntermNode* init = node->getInit();
48 0 : ASSERT(init != NULL);
49 0 : TIntermAggregate* decl = init->getAsAggregate();
50 0 : ASSERT((decl != NULL) && (decl->getOp() == EOpDeclaration));
51 0 : TIntermSequence& declSeq = decl->getSequence();
52 0 : ASSERT(declSeq.size() == 1);
53 0 : TIntermBinary* declInit = declSeq[0]->getAsBinaryNode();
54 0 : ASSERT((declInit != NULL) && (declInit->getOp() == EOpInitialize));
55 0 : TIntermSymbol* symbol = declInit->getLeft()->getAsSymbolNode();
56 0 : ASSERT(symbol != NULL);
57 0 : ASSERT(symbol->getBasicType() == EbtInt);
58 :
59 0 : info.id = symbol->getId();
60 :
61 0 : ASSERT(declInit->getRight() != NULL);
62 0 : TIntermConstantUnion* initNode = declInit->getRight()->getAsConstantUnion();
63 0 : ASSERT(initNode != NULL);
64 :
65 0 : info.initValue = evaluateIntConstant(initNode);
66 0 : info.currentValue = info.initValue;
67 :
68 0 : TIntermNode* cond = node->getCondition();
69 0 : ASSERT(cond != NULL);
70 0 : TIntermBinary* binOp = cond->getAsBinaryNode();
71 0 : ASSERT(binOp != NULL);
72 0 : ASSERT(binOp->getRight() != NULL);
73 0 : ASSERT(binOp->getRight()->getAsConstantUnion() != NULL);
74 :
75 0 : info.incrementValue = getLoopIncrement(node);
76 : info.stopValue = evaluateIntConstant(
77 0 : binOp->getRight()->getAsConstantUnion());
78 0 : info.op = binOp->getOp();
79 0 : }
80 :
81 0 : void ForLoopUnroll::Step()
82 : {
83 0 : ASSERT(mLoopIndexStack.size() > 0);
84 0 : TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
85 0 : info.currentValue += info.incrementValue;
86 0 : }
87 :
88 0 : bool ForLoopUnroll::SatisfiesLoopCondition()
89 : {
90 0 : ASSERT(mLoopIndexStack.size() > 0);
91 0 : TLoopIndexInfo& info = mLoopIndexStack[mLoopIndexStack.size() - 1];
92 : // Relational operator is one of: > >= < <= == or !=.
93 0 : switch (info.op) {
94 : case EOpEqual:
95 0 : return (info.currentValue == info.stopValue);
96 : case EOpNotEqual:
97 0 : return (info.currentValue != info.stopValue);
98 : case EOpLessThan:
99 0 : return (info.currentValue < info.stopValue);
100 : case EOpGreaterThan:
101 0 : return (info.currentValue > info.stopValue);
102 : case EOpLessThanEqual:
103 0 : return (info.currentValue <= info.stopValue);
104 : case EOpGreaterThanEqual:
105 0 : return (info.currentValue >= info.stopValue);
106 : default:
107 0 : UNREACHABLE();
108 : }
109 : return false;
110 : }
111 :
112 0 : bool ForLoopUnroll::NeedsToReplaceSymbolWithValue(TIntermSymbol* symbol)
113 : {
114 0 : for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
115 0 : i != mLoopIndexStack.end();
116 : ++i) {
117 0 : if (i->id == symbol->getId())
118 0 : return true;
119 : }
120 0 : return false;
121 : }
122 :
123 0 : int ForLoopUnroll::GetLoopIndexValue(TIntermSymbol* symbol)
124 : {
125 0 : for (TVector<TLoopIndexInfo>::iterator i = mLoopIndexStack.begin();
126 0 : i != mLoopIndexStack.end();
127 : ++i) {
128 0 : if (i->id == symbol->getId())
129 0 : return i->currentValue;
130 : }
131 0 : UNREACHABLE();
132 : return false;
133 : }
134 :
135 0 : void ForLoopUnroll::Push(TLoopIndexInfo& info)
136 : {
137 0 : mLoopIndexStack.push_back(info);
138 0 : }
139 :
140 0 : void ForLoopUnroll::Pop()
141 : {
142 0 : mLoopIndexStack.pop_back();
143 0 : }
144 :
145 : // static
146 0 : void ForLoopUnroll::MarkForLoopsWithIntegerIndicesForUnrolling(
147 : TIntermNode* root)
148 : {
149 0 : ASSERT(root);
150 :
151 0 : IntegerForLoopUnrollMarker marker;
152 0 : root->traverse(&marker);
153 0 : }
154 :
155 0 : int ForLoopUnroll::getLoopIncrement(TIntermLoop* node)
156 : {
157 0 : TIntermNode* expr = node->getExpression();
158 0 : ASSERT(expr != NULL);
159 : // for expression has one of the following forms:
160 : // loop_index++
161 : // loop_index--
162 : // loop_index += constant_expression
163 : // loop_index -= constant_expression
164 : // ++loop_index
165 : // --loop_index
166 : // The last two forms are not specified in the spec, but I am assuming
167 : // its an oversight.
168 0 : TIntermUnary* unOp = expr->getAsUnaryNode();
169 0 : TIntermBinary* binOp = unOp ? NULL : expr->getAsBinaryNode();
170 :
171 0 : TOperator op = EOpNull;
172 0 : TIntermConstantUnion* incrementNode = NULL;
173 0 : if (unOp != NULL) {
174 0 : op = unOp->getOp();
175 0 : } else if (binOp != NULL) {
176 0 : op = binOp->getOp();
177 0 : ASSERT(binOp->getRight() != NULL);
178 0 : incrementNode = binOp->getRight()->getAsConstantUnion();
179 0 : ASSERT(incrementNode != NULL);
180 : }
181 :
182 0 : int increment = 0;
183 : // The operator is one of: ++ -- += -=.
184 0 : switch (op) {
185 : case EOpPostIncrement:
186 : case EOpPreIncrement:
187 0 : ASSERT((unOp != NULL) && (binOp == NULL));
188 0 : increment = 1;
189 0 : break;
190 : case EOpPostDecrement:
191 : case EOpPreDecrement:
192 0 : ASSERT((unOp != NULL) && (binOp == NULL));
193 0 : increment = -1;
194 0 : break;
195 : case EOpAddAssign:
196 0 : ASSERT((unOp == NULL) && (binOp != NULL));
197 0 : increment = evaluateIntConstant(incrementNode);
198 0 : break;
199 : case EOpSubAssign:
200 0 : ASSERT((unOp == NULL) && (binOp != NULL));
201 0 : increment = - evaluateIntConstant(incrementNode);
202 0 : break;
203 : default:
204 0 : ASSERT(false);
205 : }
206 :
207 0 : return increment;
208 : }
209 :
210 0 : int ForLoopUnroll::evaluateIntConstant(TIntermConstantUnion* node)
211 : {
212 0 : ASSERT((node != NULL) && (node->getUnionArrayPointer() != NULL));
213 0 : return node->getUnionArrayPointer()->getIConst();
214 : }
215 :
|