1 /**
2  * Compute the cost of inlining a function call by counting expressions.
3  *
4  * Copyright:   Copyright (C) 1999-2023 by The D Language Foundation, All Rights Reserved
5  * Authors:     $(LINK2 https://www.digitalmars.com, Walter Bright)
6  * License:     $(LINK2 https://www.boost.org/LICENSE_1_0.txt, Boost License 1.0)
7  * Source:    $(LINK2 https://github.com/dlang/dmd/blob/master/src/dmd/inlinecost.d, _inlinecost.d)
8  * Documentation:  https://dlang.org/phobos/dmd_inlinecost.html
9  * Coverage:    https://codecov.io/gh/dlang/dmd/src/master/src/dmd/inlinecost.d
10  */
11 
12 module dmd.inlinecost;
13 
14 import core.stdc.stdio;
15 import core.stdc.string;
16 
17 import dmd.aggregate;
18 import dmd.arraytypes;
19 import dmd.astenums;
20 import dmd.attrib;
21 import dmd.dclass;
22 import dmd.declaration;
23 import dmd.dmodule;
24 import dmd.dscope;
25 import dmd.dstruct;
26 import dmd.dsymbol;
27 import dmd.expression;
28 import dmd.func;
29 import dmd.globals;
30 import dmd.id;
31 import dmd.identifier;
32 import dmd.init;
33 import dmd.mtype;
34 import dmd.opover;
35 import dmd.postordervisitor;
36 import dmd.statement;
37 import dmd.tokens;
38 import dmd.visitor;
39 
/// Expression cost at or above which inlining is considered too costly (see `tooCostly`).
enum COST_MAX = 250;

/* An inline cost packs two counters into one `int`: the bits below
 * STATEMENT_COST accumulate expression cost (the part `tooCostly` compares
 * against COST_MAX), while statements such as `for` and `throw` add
 * STATEMENT_COST, bumping the bits above it.
 */
private enum STATEMENT_COST = 0x1000;
// Derived from COST_MAX so the two limits cannot drift apart.
private enum STATEMENT_COST_MAX = COST_MAX * STATEMENT_COST;

// STATEMENT_COST must be a power of 2 and greater than COST_MAX, so that
// `cost & (STATEMENT_COST - 1)` isolates the expression part of a cost.
static assert((STATEMENT_COST & (STATEMENT_COST - 1)) == 0);
static assert(STATEMENT_COST > COST_MAX);
48 
/*********************************
 * Decide whether an accumulated inline cost exceeds the budget.
 *
 * Only the expression part of `cost` (the bits below STATEMENT_COST) is
 * compared against COST_MAX; the statement counter in the higher bits is
 * ignored here.
 * Params:
 *      cost = accumulated cost of inlining
 * Returns:
 *      true if the expression part of `cost` is at least COST_MAX
 */
bool tooCostly(int cost) pure nothrow @safe
{
    enum int expressionMask = STATEMENT_COST - 1;
    const expressionCost = cost & expressionMask;
    return expressionCost >= COST_MAX;
}
60 
/*********************************
 * Compute the inline cost of a single expression tree.
 *
 * The expression is costed on its own: there is no enclosing function,
 * no explicit 'this', header-scan rules apply, and `__alloca` calls are
 * permitted.
 * Params:
 *      e = expression to cost (a null expression costs 0)
 * Returns:
 *      accumulated inline cost of e
 */
int inlineCostExpression(Expression e)
{
    scope visitor = new InlineCostVisitor(/*hasthis*/ false, /*hdrscan*/ true,
                                          /*allowAlloca*/ true, /*fd*/ null);
    visitor.expressionInlineCost(e);
    return visitor.cost;
}
74 
75 
/*********************************
 * Compute the inline cost of a function's body.
 * Params:
 *      fd = function whose `fbody` is walked (dereferenced unconditionally)
 *      hasthis = if the function call has explicit 'this' expression
 *      hdrscan = if generating a header file
 * Returns:
 *      accumulated inline cost of the body of fd
 */
int inlineCostFunction(FuncDeclaration fd, bool hasthis, bool hdrscan)
{
    // A call to C `__alloca` is never accepted when costing a whole function body.
    scope visitor = new InlineCostVisitor(hasthis, hdrscan, /*allowAlloca*/ false, fd);
    auto functionBody = fd.fbody;
    functionBody.accept(visitor);
    return visitor.cost;
}
91 
/**
 * Determine whether an aggregate declared inside a function body still allows
 * that function to be inlined. The result is used to compute the inline cost,
 * and also lets the inliner avoid copying the aggregate.
 *
 * An anonymous attribute declaration is accepted only if it wraps exactly one
 * aggregate that meets all of these conditions:
 * - it is not a class;
 * - it is not a struct with `dtor` or `fieldDtor`;
 * - if it is a struct or union, it is declared `static`.
 *
 * A plain (unwrapped) struct or union declaration is accepted as is.
 *
 * Params:
 *      e = the declaration expression that may represent an aggregate.
 *
 * Returns: the aggregate if it permits inlining; `null` if `e` does not
 *      declare an aggregate, or declares one that prevents inlining.
 */
AggregateDeclaration isInlinableNestedAggregate(DeclarationExp e)
{
    if (e.declaration.isAnonymous())
    {
        if (AttribDeclaration ad = e.declaration.isAttribDeclaration())
        {
            // `decl` is a pointer to the member list; guard against null before
            // reading its length. Only a wrapper around exactly one member is considered.
            if (!ad.decl || ad.decl.length != 1)
                return null;

            AggregateDeclaration result = (*ad.decl)[0].isAggregateDeclaration();
            if (!result)
                return null;

            // classes would have to be destroyed
            if (result.isClassDeclaration())
                return null;

            // if it's a struct: must not have dtor
            StructDeclaration sdecl = result.isStructDeclaration();
            if (sdecl && (sdecl.fieldDtor || sdecl.dtor))
                return null;

            // the aggregate must be static
            UnionDeclaration udecl = result.isUnionDeclaration();
            if ((sdecl || udecl) && !(result.storage_class & STC.static_))
                return null;

            return result;
        }
    }

    // A plain struct or union declaration is accepted without further checks.
    if (auto sd = e.declaration.isStructDeclaration())
        return sd;
    if (auto ud = e.declaration.isUnionDeclaration())
        return ud;
    return null;
}
139 
140 private:
141 
/***********************************************************
 * Compute cost of inlining.
 *
 * Walk trees to determine if inlining can be done, and if so,
 * if it is too complex to be worth inlining or not.
 *
 * Costs accumulate in `cost`:
 * - expression nodes add to the low bits;
 * - `for` and `throw` statements add `STATEMENT_COST`;
 * - constructs that cannot be inlined set or add `COST_MAX`.
 */
extern (C++) final class InlineCostVisitor : Visitor
{
    alias visit = Visitor.visit;
public:
    int nested;         // depth of enclosing `if` bodies outside the return/else-return pattern
    bool hasthis;       // if the call has an explicit 'this' expression
    bool hdrscan;       // if inline scan for 'header' content
    bool allowAlloca;   // if a C-linkage `__alloca` call may be inlined
    FuncDeclaration fd; // function being costed; null when costing a lone expression
    int cost;           // zero start for subsequent AST

    extern (D) this() scope @safe
    {
    }

    extern (D) this(bool hasthis, bool hdrscan, bool allowAlloca, FuncDeclaration fd) scope @safe
    {
        this.hasthis = hasthis;
        this.hdrscan = hdrscan;
        this.allowAlloca = allowAlloca;
        this.fd = fd;
    }

    /// Copy the configuration (but not the cost) of `icv`, to cost a sub-tree separately.
    extern (D) this(InlineCostVisitor icv) scope @safe
    {
        nested = icv.nested;
        hasthis = icv.hasthis;
        hdrscan = icv.hdrscan;
        allowAlloca = icv.allowAlloca;
        fd = icv.fd;
    }

    override void visit(Statement s)
    {
        //printf("Statement.inlineCost = %d\n", COST_MAX);
        //printf("%p\n", s.isScopeStatement());
        //printf("%s\n", s.toChars());
        cost += COST_MAX; // default is we can't inline it
    }

    override void visit(ExpStatement s)
    {
        expressionInlineCost(s.exp);
    }

    override void visit(CompoundStatement s)
    {
        // Cost the statements in a separate visitor so the loop can stop as
        // soon as the running total becomes too costly.
        scope InlineCostVisitor icv = new InlineCostVisitor(this);
        foreach (i; 0 .. s.statements.length)
        {
            if (Statement s2 = (*s.statements)[i])
            {
                /* Specifically allow:
                 *  if (condition)
                 *      return exp1;
                 *  return exp2;
                 */
                IfStatement ifs;
                Statement s3;
                if ((ifs = s2.isIfStatement()) !is null &&
                    ifs.ifbody &&
                    ifs.ifbody.endsWithReturnStatement() &&
                    !ifs.elsebody &&
                    i + 1 < s.statements.length &&
                    (s3 = (*s.statements)[i + 1]) !is null &&
                    s3.endsWithReturnStatement()
                   )
                {
                    if (ifs.prm)       // if variables are declared
                    {
                        cost = COST_MAX;
                        return;
                    }
                    // Charge this pattern to `icv` too, so that the early-exit
                    // check below takes it into account.
                    icv.expressionInlineCost(ifs.condition);
                    ifs.ifbody.accept(icv);
                    s3.accept(icv);
                }
                else
                    s2.accept(icv);
                if (tooCostly(icv.cost))
                    break;
            }
        }
        cost += icv.cost;
    }

    override void visit(UnrolledLoopStatement s)
    {
        scope InlineCostVisitor icv = new InlineCostVisitor(this);
        foreach (s2; *s.statements)
        {
            if (s2)
            {
                s2.accept(icv);
                if (tooCostly(icv.cost))
                    break;
            }
        }
        cost += icv.cost;
    }

    override void visit(ScopeStatement s)
    {
        cost++;
        if (s.statement)
            s.statement.accept(this);
    }

    override void visit(IfStatement s)
    {
        /* Can't declare variables inside ?: expressions, so
         * we cannot inline if a variable is declared.
         */
        if (s.prm)
        {
            cost = COST_MAX;
            return;
        }
        expressionInlineCost(s.condition);

        if (s.isIfCtfeBlock())
        {
            cost = COST_MAX;
            return;
        }

        /* Specifically allow:
         *  if (condition)
         *      return exp1;
         *  else
         *      return exp2;
         * Otherwise, we can't handle return statements nested in if's.
         */
        if (s.elsebody && s.ifbody && s.ifbody.endsWithReturnStatement() && s.elsebody.endsWithReturnStatement())
        {
            s.ifbody.accept(this);
            s.elsebody.accept(this);
            //printf("cost = %d\n", cost);
        }
        else
        {
            nested += 1;
            if (s.ifbody)
                s.ifbody.accept(this);
            if (s.elsebody)
                s.elsebody.accept(this);
            nested -= 1;
        }
        //printf("IfStatement.inlineCost = %d\n", cost);
    }

    override void visit(ReturnStatement s)
    {
        // Can't handle return statements nested in if's
        if (nested)
        {
            cost = COST_MAX;
        }
        else
        {
            expressionInlineCost(s.exp);
        }
    }

    override void visit(ImportStatement s)
    {
        // imports add no cost
    }

    override void visit(ForStatement s)
    {
        cost += STATEMENT_COST;
        if (s._init)
            s._init.accept(this);
        // Walk the full condition and increment trees, as is done for every
        // other expression, so that nested functions, lambdas, `this` and
        // similar sub-expressions are checked too. A null expression costs nothing.
        expressionInlineCost(s.condition);
        expressionInlineCost(s.increment);
        if (s._body)
            s._body.accept(this);
        //printf("ForStatement: inlineCost = %d\n", cost);
    }

    override void visit(ThrowStatement s)
    {
        cost += STATEMENT_COST;
        // Walk the whole thrown expression, not just its root node.
        expressionInlineCost(s.exp);
    }

    /* -------------------------- */
    /// Add the cost of every node of expression tree `e` (no-op if `e` is null),
    /// stopping the walk once the sub-total reaches COST_MAX.
    void expressionInlineCost(Expression e)
    {
        //printf("expressionInlineCost()\n");
        //e.print();
        if (e)
        {
            extern (C++) final class LambdaInlineCost : StoppableVisitor
            {
                alias visit = typeof(super).visit;
                InlineCostVisitor icv;

            public:
                extern (D) this(InlineCostVisitor icv) @safe
                {
                    this.icv = icv;
                }

                override void visit(Expression e)
                {
                    e.accept(icv);
                    stop = icv.cost >= COST_MAX;
                }
            }

            scope InlineCostVisitor icv = new InlineCostVisitor(this);
            scope LambdaInlineCost lic = new LambdaInlineCost(icv);
            walkPostorder(e, lic);
            cost += icv.cost;
        }
    }

    override void visit(Expression e)
    {
        cost++;
    }

    override void visit(VarExp e)
    {
        //printf("VarExp.inlineCost3() %s\n", toChars());
        Type tb = e.type.toBasetype();
        if (auto ts = tb.isTypeStruct())
        {
            StructDeclaration sd = ts.sym;
            if (sd.isNested())
            {
                /* An inner struct will be nested inside another function hierarchy than where
                 * we're inlining into, so don't inline it.
                 * At least not until we figure out how to 'move' the struct to be nested
                 * locally. Example:
                 *   struct S(alias pred) { void unused_func(); }
                 *   void abc() { int w; S!(w) m; }
                 *   void bar() { abc(); }
                 */
                cost = COST_MAX;
                return;
            }
        }
        FuncDeclaration fd = e.var.isFuncDeclaration();
        if (fd && fd.isNested()) // https://issues.dlang.org/show_bug.cgi?id=7199 for test case
            cost = COST_MAX;
        else
            cost++;
    }

    override void visit(ThisExp e)
    {
        //printf("ThisExp.inlineCost3() %s\n", toChars());
        if (!fd)
        {
            cost = COST_MAX;
            return;
        }
        if (!hdrscan)
        {
            if (fd.isNested() || !hasthis)
            {
                cost = COST_MAX;
                return;
            }
        }
        cost++;
    }

    override void visit(StructLiteralExp e)
    {
        //printf("StructLiteralExp.inlineCost3() %s\n", toChars());
        if (e.sd.isNested())
            cost = COST_MAX;
        else
            cost++;
    }

    override void visit(NewExp e)
    {
        //printf("NewExp.inlineCost3() %s\n", e.toChars());
        AggregateDeclaration ad = isAggregate(e.newtype);
        if (ad && ad.isNested())
            cost = COST_MAX;
        else
            cost++;
    }

    override void visit(FuncExp e)
    {
        //printf("FuncExp.inlineCost3()\n");
        // Right now, this makes the function be output to the .obj file twice.
        cost = COST_MAX;
    }

    override void visit(DelegateExp e)
    {
        //printf("DelegateExp.inlineCost3()\n");
        cost = COST_MAX;
    }

    override void visit(DeclarationExp e)
    {
        //printf("DeclarationExp.inlineCost3()\n");
        if (auto vd = e.declaration.isVarDeclaration())
        {
            if (vd.toAlias().isTupleDeclaration())
            {
                cost = COST_MAX; // finish DeclarationExp.doInlineAs
                return;
            }
            if (!hdrscan && vd.isDataseg())
            {
                cost = COST_MAX;
                return;
            }
            if (vd.edtor)
            {
                // if destructor required
                // needs work to make this work
                cost = COST_MAX;
                return;
            }
            // Scan initializer (vd.init)
            if (vd._init)
            {
                if (auto ie = vd._init.isExpInitializer())
                {
                    expressionInlineCost(ie.exp);
                }
            }
            ++cost;
        }

        // aggregates are accepted under certain circumstances
        if (isInlinableNestedAggregate(e))
        {
            cost++;
            return;
        }

        // These can contain functions, which when copied, get output twice.
        if (e.declaration.isStructDeclaration() ||
            e.declaration.isClassDeclaration()  ||
            e.declaration.isFuncDeclaration()   ||
            e.declaration.isAttribDeclaration() ||
            e.declaration.isTemplateMixin())
        {
            cost = COST_MAX;
            return;
        }
        //printf("DeclarationExp.inlineCost3('%s')\n", toChars());
    }

    override void visit(CallExp e)
    {
        //printf("CallExp.inlineCost3() %s\n", toChars());
        // https://issues.dlang.org/show_bug.cgi?id=3500
        // super.func() calls must be devirtualized, and the inliner
        // can't handle that at present.
        if (e.e1.op == EXP.dotVariable && (cast(DotVarExp)e.e1).e1.op == EXP.super_)
            cost = COST_MAX;
        else if (e.f && e.f.ident == Id.__alloca && e.f._linkage == LINK.c && !allowAlloca)
            cost = COST_MAX; // inlining alloca may cause stack overflows
        else
            cost++;
    }
}