]> gerrit.simantics Code Review - simantics/platform.git/blob - bundles/org.simantics.graph/src/org/simantics/graph/matching/CanonicalizingMatchingStrategy.java
Migrated source code from Simantics SVN
[simantics/platform.git] / bundles / org.simantics.graph / src / org / simantics / graph / matching / CanonicalizingMatchingStrategy.java
1 package org.simantics.graph.matching;\r
2 \r
3 import gnu.trove.list.array.TIntArrayList;\r
4 import gnu.trove.map.hash.TObjectIntHashMap;\r
5 import gnu.trove.set.hash.TIntHashSet;\r
6 \r
7 import java.util.Arrays;\r
8 import java.util.Comparator;\r
9 \r
10 import org.simantics.databoard.binding.mutable.Variant;\r
11 \r
12 public enum CanonicalizingMatchingStrategy implements GraphMatchingStrategy {\r
13         INSTANCE;\r
14 \r
15         private static class Vertex {\r
16                 int graph;\r
17                 int original;\r
18                 int pos;\r
19                 Stat[] stats;\r
20                 \r
21                 public Vertex(int graph, int original, int pos, Stat[] stats) {\r
22                         this.graph = graph;\r
23                         this.original = original;\r
24                         this.pos = pos;\r
25                         this.stats = stats;\r
26                 }\r
27         }\r
28         \r
29         private static final Comparator<Vertex> VERTEX_COMPARATOR = new Comparator<Vertex>() {\r
30                 @Override\r
31                 public int compare(Vertex o1, Vertex o2) {\r
32                         int pos1 = o1.pos;\r
33                         int pos2 = o2.pos;\r
34                         if(pos1 < pos2)\r
35                                 return -1;\r
36                         if(pos1 > pos2)\r
37                                 return 1;\r
38                         Stat[] stats1 = o1.stats;\r
39                         Stat[] stats2 = o2.stats;\r
40                         if(stats1.length < stats2.length)\r
41                                 return -1;\r
42                         if(stats1.length > stats2.length)\r
43                                 return 1;\r
44                         for(int i=0;i<stats1.length;++i) {\r
45                                 int comp = Stat.STAT_COMPARATOR.compare(stats1[i], stats2[i]);\r
46                                 if(comp != 0)\r
47                                         return comp;\r
48                         }\r
49                         if(o1.graph < o2.graph)\r
50                                 return -1;\r
51                         if(o1.graph > o2.graph)\r
52                                 return 1;                       \r
53                         if(o1.original < o2.original)\r
54                                 return -1;\r
55                         if(o1.original > o2.original)\r
56                                 return 1;\r
57                         return 0;\r
58                 }\r
59         };\r
60         \r
61         private static int[] generateMapA(int[] aToB) {\r
62                 int[] map = new int[aToB.length];\r
63                 for(int i=0;i<aToB.length;++i) {\r
64                         int c = aToB[i];\r
65                         if(c >= 0)\r
66                                 map[i] = -1 - c;\r
67                         else\r
68                                 map[i] = 0;\r
69                 }\r
70                 return map;\r
71         }\r
72         \r
73         private static int[] generateMapB(int[] bToA) {\r
74                 int[] map = new int[bToA.length];\r
75                 for(int i=0;i<bToA.length;++i) {\r
76                         int c = bToA[i];\r
77                         if(c >= 0)\r
78                                 map[i] = -1 - i;\r
79                         else\r
80                                 map[i] = 0;\r
81                 }\r
82                 return map;\r
83         }\r
84         \r
85         private static Vertex[] generateVertices(int graph, int[] map, Stat[][] statements) {\r
86                 int size = 0;\r
87                 for(int s=0;s<map.length;++s)\r
88                         if(map[s] == 0)\r
89                                 ++size;\r
90                 Vertex[] vertices = new Vertex[size];\r
91                 for(int s=0,i=0;s<map.length;++s)\r
92                         if(map[s] == 0) {\r
93                                 Stat[] ns = statements[s];\r
94                                 Stat[] stats = new Stat[ns.length];\r
95                                 for(int j=0;j<ns.length;++j) {\r
96                                         Stat n = ns[j];\r
97                                         stats[j] = new Stat(map[n.p], map[n.o]);\r
98                                 }\r
99                                 Arrays.sort(stats, Stat.STAT_COMPARATOR);\r
100                                 vertices[i++] = new Vertex(graph, s, 0, stats);\r
101                         }\r
102                 return vertices;\r
103         }\r
104         \r
105         private static void updateVertices(Vertex[] vertices, int[] map, Stat[][] statements) {\r
106                 for(int i=0;i<vertices.length;++i) {\r
107                         int s = vertices[i].original;\r
108                         Stat[] ns = statements[s];\r
109                         Stat[] stats = vertices[i].stats;\r
110                         for(int j=0;j<ns.length;++j) {\r
111                                 Stat n = ns[j];\r
112                                 Stat stat = stats[j];\r
113                                 stat.p = map[n.p];\r
114                                 stat.o = map[n.o];\r
115                         }\r
116                         Arrays.sort(stats, Stat.STAT_COMPARATOR);\r
117                 }\r
118         }\r
119         \r
120         private static Vertex[] concat(Vertex[] as, Vertex[] bs) {\r
121                 Vertex[] result = new Vertex[as.length + bs.length];\r
122                 System.arraycopy(as, 0, result, 0, as.length);\r
123                 System.arraycopy(bs, 0, result, as.length, bs.length);\r
124                 return result;\r
125         }\r
126         \r
127         static boolean equals(Stat[] stats1, Stat[] stats2) {\r
128                 if(stats1.length != stats2.length)\r
129                         return false;\r
130                 for(int i=0;i<stats1.length;++i) {\r
131                         Stat stat1 = stats1[i];\r
132                         Stat stat2 = stats2[i];\r
133                         if(stat1.p != stat2.p || stat1.o != stat2.o)\r
134                                 return false;\r
135                 }\r
136                 return true;\r
137         }\r
138         \r
139         private static boolean updatePositions(Vertex[] can) {\r
140                 boolean modified = false;\r
141                 int oldPos = can[0].pos;\r
142                 Vertex oldVertex = can[0];\r
143                 for(int i=1;i<can.length;++i) {\r
144                         Vertex curVertex = can[i];\r
145                         int curPos = curVertex.pos;\r
146                         if(curPos == oldPos) {\r
147                                 if(equals(oldVertex.stats, curVertex.stats))\r
148                                         curVertex.pos = oldVertex.pos;\r
149                                 else {\r
150                                         curVertex.pos = i;\r
151                                         modified = true;\r
152                                 }\r
153                         }\r
154                         else\r
155                                 oldPos = curPos;\r
156                         oldVertex = curVertex;\r
157                 }\r
158                 return modified;\r
159         }\r
160         \r
161         private static void updateMap(Vertex[] vertices, int[] map) {\r
162                 for(Vertex vertex : vertices)\r
163                         map[vertex.original] = vertex.pos;\r
164         }\r
165 \r
166         private static int[] groupPositions(Vertex[] can) {\r
167                 TIntArrayList result = new TIntArrayList();\r
168                 for(int i=0;i<can.length;++i)\r
169                         if(can[i].pos == i)\r
170                                 result.add(i);\r
171                 result.add(can.length);\r
172                 return result.toArray();\r
173         }       \r
174                 \r
175         static class TByteArrayIntHashMap extends TObjectIntHashMap<byte[]> {\r
176                 @Override\r
177                 protected boolean equals(Object one, Object two) {\r
178                         return Arrays.equals((byte[])one, (byte[])two);\r
179                 }\r
180                 \r
181                 @Override\r
182                 protected int hash(Object obj) {\r
183                         return Arrays.hashCode((byte[])obj);\r
184                 }\r
185         }\r
186         \r
187         private boolean separateByValues(Vertex[] can, int begin, int end, Variant[] aValues, Variant[] bValues) {              \r
188                 int valueCount = 0;\r
189                 TObjectIntHashMap<Variant> valueIds = new TObjectIntHashMap<Variant>();\r
190                 int[] ids = new int[end-begin];\r
191                 for(int i=begin;i<end;++i) {\r
192                         Vertex v = can[i];\r
193                         Variant value = v.graph==0 ? aValues[v.original] : bValues[v.original];\r
194                         int valueId = valueIds.get(value);\r
195                         if(valueId == 0) {\r
196                                 valueIds.put(value, ++valueCount);\r
197                                 ids[i-begin] = valueCount-1;\r
198                         }\r
199                         else\r
200                                 ids[i-begin] = valueId-1;\r
201                 }\r
202                 if(valueCount > 1) {\r
203                         Vertex[] vs = Arrays.copyOfRange(can, begin, end);\r
204                         int[] temp = new int[valueCount];\r
205                         for(int id : ids)\r
206                                 ++temp[id];\r
207                         int cur = begin;\r
208                         for(int i=0;i<temp.length;++i) {\r
209                                 int count = temp[i];\r
210                                 temp[i] = cur;\r
211                                 cur += count;\r
212                         }\r
213                         for(int i=0;i<ids.length;++i)\r
214                                 vs[i].pos = temp[ids[i]];\r
215                         for(int i=0;i<ids.length;++i)\r
216                                 can[temp[ids[i]]++] = vs[i];\r
217                         return true;\r
218                 }\r
219                 else\r
220                         return false;\r
221         }\r
222         \r
223         private boolean separateByValues(Vertex[] can, int[] groupPos, Variant[] aValues, Variant[] bValues) {\r
224                 boolean modified = false;\r
225                 for(int i=0;i<groupPos.length-1;++i) {\r
226                         int begin = groupPos[i];\r
227                         int end = groupPos[i+1];\r
228                         if(end - begin > 2)\r
229                                 modified |= separateByValues(can, begin, end, aValues, bValues);                                        \r
230                 }\r
231                 return modified;\r
232         }\r
233         \r
234         private boolean hasBigGroups(Vertex[] can, int[] groupPos) {\r
235                 for(int i=0;i<groupPos.length-1;++i) {\r
236                         int begin = groupPos[i];\r
237                         int end = groupPos[i+1];\r
238                         if(end - begin > 2 && can[begin].graph == 0 && can[end-1].graph == 1)\r
239                                 return true;\r
240                 }\r
241                 return false;\r
242         }\r
243         \r
244         static class UnionFind {\r
245                 int[] canonical;\r
246                 \r
247                 public UnionFind(int size) {\r
248                         canonical = new int[size];\r
249                         for(int i=0;i<size;++i)\r
250                                 canonical[i] = i;\r
251                 }\r
252                 \r
253                 public int canonical(int a) {\r
254                         int b = canonical[a];\r
255                         if(b != a) {\r
256                                 int c = canonical[b];\r
257                                 if(b != c) {\r
258                                         c = canonical(c);\r
259                                         canonical[b] = c;                                       \r
260                                         canonical[a] = c;\r
261                                         return c;\r
262                                 }\r
263                         }\r
264                         return b;\r
265                 }\r
266                 \r
267                 public void merge(int a, int b) {\r
268                         a = canonical(a);\r
269                         b = canonical(b);\r
270                         canonical[a] = b;\r
271                 }\r
272         }\r
273         \r
274         private static void guessIsomorphism(Vertex[] can, int[] groupPos) {\r
275                 UnionFind uf = new UnionFind(can.length);\r
276                 for(int i=0;i<can.length;++i) {\r
277                         uf.merge(i, can[i].pos);\r
278                         for(Stat stat : can[i].stats) {\r
279                                 if(stat.p >= 0)\r
280                                         uf.merge(i, stat.p);\r
281                                 if(stat.o >= 0)\r
282                                         uf.merge(i, stat.o);\r
283                         }\r
284                 }\r
285                 \r
286                 TIntHashSet done = new TIntHashSet();\r
287                 for(int i=0;i<groupPos.length-1;++i) {\r
288                         int begin = groupPos[i];\r
289                         int end = groupPos[i+1];\r
290                         if(end - begin > 2 && can[begin].graph == 0 && can[end-1].graph == 1) {\r
291                                 int c = uf.canonical(begin);\r
292                                 if(done.add(c)) {\r
293                                         int middle = begin+1;\r
294                                         while(can[middle].graph==0)\r
295                                                 ++middle;\r
296                                         int j;\r
297                                         for(j=0;begin+j < middle && middle+j < end;++j) {\r
298                                                 can[begin+j].pos = begin + j*2;\r
299                                                 can[middle+j].pos = begin + j*2;\r
300                                         }\r
301                                         int pos = begin+j*2;                                    \r
302                                         for(;begin+j < middle;++j)\r
303                                                 can[begin+j].pos = pos;\r
304                                         for(;middle+j < end;++j)\r
305                                                 can[middle+j].pos = pos;\r
306                                 }\r
307                         }\r
308                 }\r
309         }\r
310         \r
311         @Override\r
312         public void applyTo(GraphMatching matching) {\r
313                 if(matching.size == matching.aGraph.resourceCount ||\r
314                                 matching.size == matching.bGraph.resourceCount)\r
315                         return;\r
316                 long begin1 = System.nanoTime();\r
317                 int[] aMap = generateMapA(matching.aToB);\r
318                 int[] bMap = generateMapB(matching.bToA);\r
319                 Vertex[] aVertices = generateVertices(0,\r
320                                 aMap, matching.aGraph.statements);\r
321                 Vertex[] bVertices = generateVertices(1,\r
322                                 bMap, matching.bGraph.statements);\r
323                 Vertex[] can = concat(aVertices, bVertices);\r
324                 if(GraphMatching.TIMING)\r
325                         System.out.println("    Init:    " + (System.nanoTime()-begin1)*1e-6 + "ms");\r
326                 \r
327                 int[] groupPos = null;\r
328                 boolean doneSeparationByValues = false;\r
329                 while(true) {\r
330                         long begin2 = System.nanoTime();\r
331                         Arrays.sort(can, VERTEX_COMPARATOR);\r
332                         if(GraphMatching.TIMING)\r
333                                 System.out.println("    Sort:    " + (System.nanoTime()-begin2)*1e-6 + "ms");\r
334                         \r
335                         long begin3 = System.nanoTime();\r
336                         if(!updatePositions(can)) {                     \r
337                                 groupPos = groupPositions(can);                         \r
338                                 if(!hasBigGroups(can, groupPos))\r
339                                         break;\r
340                                 \r
341                                 boolean modified = false;\r
342                                 if(!doneSeparationByValues) {\r
343                                         modified = separateByValues(can, groupPos, matching.aGraph.values, matching.bGraph.values);                                                                             \r
344                                         doneSeparationByValues = true;\r
345                                         if(GraphMatching.TIMING)\r
346                                                 System.out.println("    - separate by values");\r
347                                 }\r
348                                 \r
349                                 if(!modified) {\r
350                                         guessIsomorphism(can, groupPos);\r
351                                         if(GraphMatching.TIMING)\r
352                                                 System.out.println("    - guess isomorphism");\r
353                                 }\r
354                         }\r
355                         if(GraphMatching.TIMING)\r
356                                 System.out.println("    Update1: " + (System.nanoTime()-begin3)*1e-6 + "ms");\r
357                         \r
358                         long begin4 = System.nanoTime();\r
359                         updateMap(aVertices, aMap);                     \r
360                         updateMap(bVertices, bMap);\r
361                         if(GraphMatching.TIMING)\r
362                                 System.out.println("    Update2: " + (System.nanoTime()-begin4)*1e-6 + "ms");                   \r
363                         long begin5 = System.nanoTime();\r
364                         updateVertices(aVertices, aMap, matching.aGraph.statements);\r
365                         updateVertices(bVertices, bMap, matching.bGraph.statements);\r
366                         if(GraphMatching.TIMING)\r
367                                 System.out.println("    Update3: " + (System.nanoTime()-begin5)*1e-6 + "ms");\r
368                 }\r
369                 \r
370                 for(int i=0;i<groupPos.length-1;++i) {\r
371                         int begin = groupPos[i];\r
372                         int end = groupPos[i+1];\r
373                         if(end - begin == 2) {\r
374                                 Vertex a = can[begin];\r
375                                 Vertex b = can[end-1];\r
376                                 if(a.graph == 0 && b.graph == 1)\r
377                                         matching.map(a.original, b.original);\r
378                         }\r
379                 }\r
380                 \r
381                 if(GraphMatching.DEBUG)\r
382                         for(int i=0;i<groupPos.length-1;++i) {\r
383                                 int begin = groupPos[i];\r
384                                 int end = groupPos[i+1];\r
385                                 if(end - begin > 2) {                           \r
386                                         System.out.println("----");\r
387                                         for(int j=begin;j<end;++j) {\r
388                                                 if(can[j].graph == 0) {\r
389                                                         int org = can[j].original;\r
390                                                         String name = matching.aGraph.names[org];\r
391                                                         System.out.println(name + " (A)");\r
392                                                         for(Stat stat : matching.aGraph.statements[org])\r
393                                                                 System.out.println("    " + stat.toString(matching.aGraph.names));\r
394                                                         Variant value = matching.aGraph.values[org];\r
395                                                         if(value != null)\r
396                                                                 System.out.println("    " + value);\r
397                                                 }\r
398                                                 else {\r
399                                                         int org = can[j].original;\r
400                                                         String name = matching.bGraph.names[org];\r
401                                                         System.out.println(name + " (B)");\r
402                                                         for(Stat stat : matching.bGraph.statements[org])\r
403                                                                 System.out.println("    " + stat.toString(matching.bGraph.names));\r
404                                                         Variant value = matching.bGraph.values[org];\r
405                                                         if(value != null)\r
406                                                                 System.out.println("    " + value);\r
407                                                 }\r
408                                         }\r
409                                 }\r
410                         }\r
411         }\r
412 }\r