Documentation of CSL
abreviation.h
Go to the documentation of this file.
1 // This file is part of MARTY.
2 //
3 // MARTY is free software: you can redistribute it and/or modify
4 // it under the terms of the GNU General Public License as published by
5 // the Free Software Foundation, either version 3 of the License, or
6 // (at your option) any later version.
7 //
8 // MARTY is distributed in the hope that it will be useful,
9 // but WITHOUT ANY WARRANTY; without even the implied warranty of
10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 // GNU General Public License for more details.
12 //
13 // You should have received a copy of the GNU General Public License
14 // along with MARTY. If not, see <https://www.gnu.org/licenses/>.
15 
23 #ifndef ABREVIATION_H_INCLUDED
24 #define ABREVIATION_H_INCLUDED
25 
26 #include "abstract.h"
27 #include "space.h"
28 #include "interface.h"
29 #include "index.h"
30 #include "utils.h"
31 #include "replace.h"
32 #include "algo.h"
33 
34 namespace csl {
35 
class AbstractParent;
template<class Base>
    class Abbreviation;

/*!
 * \brief Global registry of live abbreviations.
 *
 * \details Abbreviation<> objects register `this` here on construction
 * (through Abbrev::addAbbreviation) and unregister on destruction (through
 * Abbrev::removeAbbreviation). The stored pointers are therefore
 * non-owning. The key is presumably the abbreviation's base name, since
 * addAbbreviation receives it; TODO confirm against addAbbreviation.
 */
inline std::map<std::string, std::vector<AbstractParent*>> abbreviationData;
41 
42 class Abbrev {
43 
44  template<class Base>
45  friend class Abbreviation;
46 
47  public:
48 
49  static inline bool avoidDuplicates = true;
50  static inline bool useDichotomy = true;
51 
52  struct compareParents {
53  bool operator()(AbstractParent* const& A,
54  AbstractParent* const& B);
55  };
56 
57  static std::map<std::string, size_t> id_name;
58 
59  static bool isAnAbbreviation(Expr const &ab);
60 
61  static AbstractParent* find(std::string_view name);
62 
63  static AbstractParent* find(Expr const& abreviation);
64 
65  static AbstractParent* find_opt(std::string_view name);
66 
67  static AbstractParent* find_opt(Expr const& abreviation);
68 
69  static void compressAbbreviations(std::string const &name = "");
70 
71  private:
72 
73  static void compressAbbreviations_impl(
74  std::vector<AbstractParent*> &abbreviations
75  );
76 
77  static void addAbbreviation(
78  AbstractParent* ptr,
79  std::string const &t_name
80  );
81 
82  static void removeAbbreviation(
83  AbstractParent* ptr,
84  std::string const &t_name
85  );
86 
87  static std::string getFinalName(std::string_view initialName);
88 
89  static std::vector<AbstractParent*> &getAbbreviationsForName(
90  std::string_view name
91  );
92 
93  static void cleanEmptyAbbreviation();
94 
95  public:
96 
97  static void printAbbreviations(std::ostream& fout = std::cout);
98 
99  static void printAbbreviations(
100  std::string_view name,
101  std::ostream &fout = std::cout
102  );
103 
104  static void enableEvaluation(std::string_view name);
105 
106  static void disableEvaluation(std::string_view name);
107 
108  static void toggleEvaluation(std::string_view name);
109 
110  static void enableGenericEvaluation(std::string_view name);
111 
112  static void disableGenericEvaluation(std::string_view name);
113 
114  static void toggleGenericEvaluation(std::string_view name);
115 
116  static void enableEvaluation(Expr const& abreviation);
117 
118  static void disableEvaluation(Expr const& abreviation);
119 
120  static void toggleEvaluation(Expr const& abreviation);
121 
122  static csl::IndexStructure getFreeStructure(
123  csl::IndexStructure const& structure);
124 
125  static csl::IndexStructure getFreeStructure(
126  Expr const& expr);
127 
128  static std::optional<Expr> findExisting(
129  std::string_view name,
130  Expr const& encaps);
131 
132  static Expr makeSubAbbrev(
133  std::vector<csl::Expr> const &args,
134  bool isProd
135  );
136 
137  static Expr makeAbbreviation(
138  std::string name,
139  Expr const& encapsulated,
140  bool split = true
141  );
142 
143  static Expr makeAbbreviation(
144  Expr const& encapsulated,
145  bool split = true
146  );
147 
148  static void removeAbbreviations(
149  std::string const &name
150  );
151 
152  template<class ...Args>
153  static void replace(
154  Args &&...args
155  )
156  {
157  for (auto &el : abbreviationData)
158  for (auto &ab : el.second)
159  ab->setEncapsulated(
160  csl::Replaced(
161  ab->getEncapsulated(),
162  std::forward<Args>(args)...
163  )
164  );
165  }
166 
167  public:
168 
169  Abbrev() = delete;
170 };
171 
172 
173 template<class BaseParent>
174 class Abbreviation: public BaseParent {
175 
176  friend class Abbrev;
177 
178  public:
179 
180  template<class ...Args>
181  Abbreviation(Expr const& t_encapsulated,
182  std::string const& t_name,
183  Args&& ... args)
184  :BaseParent(Abbrev::getFinalName(t_name), std::forward<Args>(args)...),
185  encapsulated(t_encapsulated),
186  baseName(t_name),
187  initialStructure(Abbrev::getFreeStructure(t_encapsulated))
188  {
189  Abbrev::addAbbreviation(this, t_name);
190  }
191 
192  ~Abbreviation()
193  {
194  Abbrev::removeAbbreviation(this, baseName);
195  }
196 
197  bool isAnAbbreviation() const override { return true; }
198 
199  void printDefinition(
200  std::ostream &out = std::cout,
201  int indentSize = 4,
202  bool header = false
203  ) const override
204  {
205  std::string indent(indentSize, ' ');
206  std::string regName = csl::Abstract::regularName(this->name);
207  std::string regLite = csl::Abstract::regularLiteral(this->name);
208  out << indent;
209  if (header)
210  out << "inline ";
211  out << "csl::Expr " << regName << "_encaps = ";
212  encapsulated->printCode(1, out);
213  out << ";\n";
214  out << indent;
215  if (header)
216  out << "inline ";
217  out << "csl::Expr " << regName
218  << " = csl::Abbrev::makeAbbreviation(\"" << regLite
219  << "\", " << regName << "_encaps);\n";
220  }
221 
222  std::string const &getBaseName() const override {
223  return baseName;
224  }
225 
226  void enableEvaluation() override {
227  evaluation = true;
228  };
229 
230  void disableEvaluation() override {
231  evaluation = false;
232  };
233 
234  void toggleEvaluation() override {
235  evaluation = not evaluation;
236  };
237 
254  {
255  Expr res = DeepCopy(getEncapsulated());
256  if (self->isComplexConjugate())
257  res = csl::GetComplexConjugate(res);
258  csl::IndexStructure structure = self->getIndexStructure();
259  CSL_ASSERT_SPEC(structure.size() == initialStructure.size(),
260  CSLError::RuntimeError,
261  "Wrong indicial structure " + toString(structure)
262  + " to apply for " + toString(initialStructure)
263  + " initially.");
264  // csl::IndexStructure intermediate(structure);
265  // for (auto &index : intermediate)
266  // index = index.rename();
267  std::map<csl::Index, csl::Index> mapping;
268  csl::ForEachLeaf(res, [&](Expr &sub)
269  {
270  if (!IsIndicialTensor(sub))
271  return;
272  IndexStructure &index = sub->getIndexStructureView();
273  for (auto &i : index) {
274  auto pos = std::find(
275  initialStructure.begin(),
276  initialStructure.end(),
277  i);
278  if (pos == initialStructure.end()) {
279  auto pos2 = mapping.find(i);
280  const bool sign = i.getSign();
281  if (pos2 == mapping.end()) {
282  mapping[i] = i.rename();
283  mapping[i].setSign(false);
284  i = mapping[i];
285  if (sign)
286  i = +i;
287  }
288  else {
289  i = mapping[i];
290  if (sign)
291  i = +i;
292  }
293  }
294  }
295  });
296  Replace(res, initialStructure, structure);
297  return res;
298  }
299 
300  std::optional<Expr> evaluate(
301  Expr_info self,
302  csl::eval::mode user_mode = csl::eval::base) const override {
303  if (evaluation or csl::eval::isContained(user_mode, csl::eval::abbreviation)){
304  auto res = getExactEncapsulated(self);
305  return Evaluated(res, user_mode);
306  }
307  return std::nullopt;
308  }
309 
316  Expr const &getEncapsulated() const override {
317  return encapsulated;
318  }
319 
320  void setEncapsulated(Expr const &t_encapsulated) override {
321  encapsulated = t_encapsulated;
322  }
323 
324  private:
325 
326  Expr encapsulated;
327 
328  std::string baseName;
329 
330  csl::IndexStructure initialStructure;
331 
332  bool evaluation = false;
333 };
334 
335 
336 } // End of namespace csl
337 
338 #endif
Namespace for csl library.
Definition: abreviation.h:34
Expr DeepCopy(const Abstract *expr)
See DeepCopy(const Expr& expr).
Definition: utils.cpp:113
void ForEachLeaf(Expr &init, std::function< void(Expr &)> const &f, int depth=-1)
Applies a user function to each leaf of an expression. The expression may be modified.
Definition: algo.cpp:370
Definition: abreviation.h:52
Root class of the inheritance tree of abstracts.
Definition: abstract.h:76
csl::Expr getExactEncapsulated(Expr_info self) const override
Returns the encapsulated expression, applying to it the correct index structure.
Definition: abreviation.h:253
Expr const & getEncapsulated() const override
Definition: abreviation.h:316
Base class for all parents (indicial, fields etc). All parents derive from this class.
Definition: parent.h:81
Definition: abreviation.h:38
Contains algorithms that look over (and possibly modify on the go) expressions for you...
Base classes for all exprs in the program.
Manages a std::vector of Index, to be used by a TensorElement.
Definition: index.h:472
Definition: abreviation.h:42
Expression type.
Definition: abstract.h:1573