Edinburgh Speech Tools  2.4-release
wagon_aux.cc
1 /*************************************************************************/
2 /* */
3 /* Centre for Speech Technology Research */
4 /* University of Edinburgh, UK */
5 /* Copyright (c) 1996,1997 */
6 /* All Rights Reserved. */
7 /* */
8 /* Permission is hereby granted, free of charge, to use and distribute */
9 /* this software and its documentation without restriction, including */
10 /* without limitation the rights to use, copy, modify, merge, publish, */
11 /* distribute, sublicense, and/or sell copies of this work, and to */
12 /* permit persons to whom this work is furnished to do so, subject to */
13 /* the following conditions: */
14 /* 1. The code must retain the above copyright notice, this list of */
15 /* conditions and the following disclaimer. */
16 /* 2. Any modifications must be clearly marked as such. */
17 /* 3. Original authors' names are not deleted. */
18 /* 4. The authors' names are not used to endorse or promote products */
19 /* derived from this software without specific prior written */
20 /* permission. */
21 /* */
22 /* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
23 /* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
24 /* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
25 /* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
26 /* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
27 /* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
28 /* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
29 /* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
30 /* THIS SOFTWARE. */
31 /* */
32 /*************************************************************************/
33 /* Author : Alan W Black */
34 /* Date : May 1996 */
35 /*-----------------------------------------------------------------------*/
36 /* */
37 /* Various method functions */
38 /*=======================================================================*/
39 
40 #include <cstdlib>
41 #include <iostream>
42 #include <cstring>
43 #include "EST_unix.h"
44 #include "EST_cutils.h"
45 #include "EST_Token.h"
46 #include "EST_Wagon.h"
47 #include "EST_math.h"
48 
49 
50 EST_Val WNode::predict(const WVector &d)
51 {
52  if (leaf())
53  return impurity.value();
54  else if (question.ask(d))
55  return left->predict(d);
56  else
57  return right->predict(d);
58 }
59 
60 WNode *WNode::predict_node(const WVector &d)
61 {
62  if (leaf())
63  return this;
64  else if (question.ask(d))
65  return left->predict_node(d);
66  else
67  return right->predict_node(d);
68 }
69 
70 int WNode::pure(void)
71 {
72  // A node is pure if it has no sub-nodes or its not of type class
73 
74  if ((left == 0) && (right == 0))
75  return TRUE;
76  else if (get_impurity().type() != wnim_class)
77  return TRUE;
78  else
79  return FALSE;
80 }
81 
82 void WNode::prune(void)
83 {
84  // Check all sub-nodes and if they are all of the same class
85  // delete their sub nodes. Returns pureness of this node
86 
87  if (pure() == FALSE)
88  {
89  // Ok lets try and make it pure
90  if (left != 0) left->prune();
91  if (right != 0) right->prune();
92 
93  // Have to check purity as well as values to ensure left and right
94  // don't further split
95  if ((left->pure() == TRUE) && ((right->pure() == TRUE)) &&
96  (left->get_impurity().value() == right->get_impurity().value()))
97  {
98  delete left; left = 0;
99  delete right; right = 0;
100  }
101  }
102 
103 }
104 
105 void WNode::held_out_prune()
106 {
107  // prune tree with held out data
108  // Check if node's questions differentiates for the held out data
109  // if not, prune all sub_nodes
110 
111  // Rescore with prune data
112  set_impurity(WImpurity(get_data())); // for this new data
113 
114  if (left != 0)
115  {
116  wgn_score_question(question,get_data());
117  if (question.get_score() < get_impurity().measure())
118  { // its worth goint ot the next level
119  wgn_find_split(question,get_data(),
120  left->get_data(),
121  right->get_data());
122  left->held_out_prune();
123  right->held_out_prune();
124  }
125  else
126  { // not worth the split so prune both sub_nodes
127  delete left; left = 0;
128  delete right; right = 0;
129  }
130  }
131 }
132 
133 void WNode::print_out(ostream &s, int margin)
134 {
135  int i;
136 
137  s << endl;
138  for (i=0;i<margin;i++) s << " ";
139  s << "(";
140  if (left==0) // base case
141  s << impurity;
142  else
143  {
144  s << question;
145  left->print_out(s,margin+1);
146  right->print_out(s,margin+1);
147  }
148  s << ")";
149 }
150 
151 ostream & operator <<(ostream &s, WNode &n)
152 {
153  // Output this node and its sub-node
154 
155  n.print_out(s,0);
156  s << endl;
157  return s;
158 }
159 
160 void WDataSet::ignore_non_numbers()
161 {
162  /* For ols we want to ignore anything that is categorial */
163  int i;
164 
165  for (i=0; i<dlength; i++)
166  {
167  if ((p_type[i] == wndt_binary) ||
168  (p_type[i] == wndt_float))
169  continue;
170  else
171  {
172  p_ignore[i] = TRUE;
173  }
174  }
175 
176  return;
177 }
178 
179 void WDataSet::load_description(const EST_String &fname, LISP ignores)
180 {
181  // Initialise a dataset with sizes and types
182  EST_String tname;
183  int i;
184  LISP description,d;
185 
186  description = car(vload(fname,1));
187  dlength = siod_llength(description);
188 
189  p_type.resize(dlength);
190  p_ignore.resize(dlength);
191  p_name.resize(dlength);
192 
193  if (wgn_predictee_name == "")
194  wgn_predictee = 0; // default predictee is first field
195  else
196  wgn_predictee = -1;
197 
198  for (i=0,d=description; d != NIL; d=cdr(d),i++)
199  {
200  p_name[i] = get_c_string(car(car(d)));
201  tname = get_c_string(car(cdr(car(d))));
202  p_ignore[i] = FALSE;
203  if ((wgn_predictee_name != "") && (wgn_predictee_name == p_name[i]))
204  wgn_predictee = i;
205  if ((wgn_count_field_name != "") &&
206  (wgn_count_field_name == p_name[i]))
207  wgn_count_field = i;
208  if ((tname == "count") || (i == wgn_count_field))
209  {
210  // The count must be ignored, repeat it if you want it too
211  p_type[i] = wndt_ignore; // the count must be ignored
212  p_ignore[i] = TRUE;
213  wgn_count_field = i;
214  }
215  else if ((tname == "ignore") || (siod_member_str(p_name[i],ignores)))
216  {
217  p_type[i] = wndt_ignore; // user specified ignore
218  p_ignore[i] = TRUE;
219  if (i == wgn_predictee)
220  wagon_error(EST_String("predictee \"")+p_name[i]+
221  "\" can't be ignored \n");
222  }
223  else if (siod_llength(car(d)) > 2)
224  {
225  LISP rest = cdr(car(d));
226  EST_StrList sl;
227  siod_list_to_strlist(rest,sl);
228  p_type[i] = wgn_discretes.def(sl);
229  if (streq(get_c_string(car(rest)),"_other_"))
230  wgn_discretes[p_type[i]].def_val("_other_");
231  }
232  else if (tname == "binary")
233  p_type[i] = wndt_binary;
234  else if (tname == "cluster")
235  p_type[i] = wndt_cluster;
236  else if (tname == "vector")
237  p_type[i] = wndt_vector;
238  else if (tname == "trajectory")
239  p_type[i] = wndt_trajectory;
240  else if (tname == "ols")
241  p_type[i] = wndt_ols;
242  else if (tname == "matrix")
243  p_type[i] = wndt_matrix;
244  else if (tname == "float")
245  p_type[i] = wndt_float;
246  else
247  {
248  wagon_error(EST_String("Unknown type \"")+tname+
249  "\" for field number "+itoString(i)+
250  "/"+p_name[i]+" in description file \""+fname+"\"");
251  }
252  }
253 
254  if (wgn_predictee == -1)
255  {
256  wagon_error(EST_String("predictee field \"")+wgn_predictee_name+
257  "\" not found in description ");
258  }
259 }
260 
261 const int WQuestion::ask(const WVector &w) const
262 {
263  // Ask this question of the given vector
264  switch (op)
265  {
266  case wnop_equal: // for numbers
267  if (w.get_flt_val(feature_pos) == operand1.Float())
268  return TRUE;
269  else
270  return FALSE;
271  case wnop_binary: // for numbers
272  if (w.get_int_val(feature_pos) == 1)
273  return TRUE;
274  else
275  return FALSE;
276  case wnop_greaterthan:
277  if (w.get_flt_val(feature_pos) > operand1.Float())
278  return TRUE;
279  else
280  return FALSE;
281  case wnop_lessthan:
282  if (w.get_flt_val(feature_pos) < operand1.Float())
283  return TRUE;
284  else
285  return FALSE;
286  case wnop_is: // for classes
287  if (w.get_int_val(feature_pos) == operand1.Int())
288  return TRUE;
289  else
290  return FALSE;
291  case wnop_in: // for subsets -- note operand is list of ints
292  if (ilist_member(operandl,w.get_int_val(feature_pos)))
293  return TRUE;
294  else
295  return FALSE;
296  default:
297  wagon_error("Unknown test operator");
298  }
299 
300  return FALSE;
301 }
302 
303 ostream& operator<<(ostream& s, const WQuestion &q)
304 {
305  EST_String name;
306  static EST_Regex needquotes(".*[()'\";., \t\n\r].*");
307 
308  s << "(" << wgn_dataset.feat_name(q.get_fp());
309  switch (q.get_op())
310  {
311  case wnop_equal:
312  s << " = " << q.get_operand1().string();
313  break;
314  case wnop_binary:
315  break;
316  case wnop_greaterthan:
317  s << " > " << q.get_operand1().Float();
318  break;
319  case wnop_lessthan:
320  s << " < " << q.get_operand1().Float();
321  break;
322  case wnop_is:
323  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
324  name(q.get_operand1().Int());
325  s << " is ";
326  if (name.matches(needquotes))
327  s << quote_string(name,"\"","\\",1);
328  else
329  s << name;
330  break;
331  case wnop_matches:
332  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
333  name(q.get_operand1().Int());
334  s << " matches " << quote_string(name,"\"","\\",1);
335  break;
336  case wnop_in:
337  s << " in (";
338  for (int l=0; l < q.get_operandl().length(); l++)
339  {
340  name = wgn_discretes[wgn_dataset.ftype(q.get_fp())].
341  name(q.get_operandl().nth(l));
342  if (name.matches(needquotes))
343  s << quote_string(name,"\"","\\",1);
344  else
345  s << name;
346  s << " ";
347  }
348  s << ")";
349  break;
350  // SunCC wont let me add this
351 // default:
352 // s << " unknown operation ";
353  }
354  s << ")";
355 
356  return s;
357 }
358 
359 EST_Val WImpurity::value(void)
360 {
361  // Returns the recommended value for this
362  EST_String s;
363  double prob;
364 
365  if (t==wnim_unset)
366  {
367  cerr << "WImpurity: no value currently set\n";
368  return EST_Val(0.0);
369  }
370  else if (t==wnim_class)
371  return EST_Val(p.most_probable(&prob));
372  else if (t==wnim_cluster)
373  return EST_Val(a.mean());
374  else if (t==wnim_ols) /* OLS TBA */
375  return EST_Val(a.mean());
376  else if (t==wnim_vector)
377  return EST_Val(a.mean()); /* wnim_vector */
378  else if (t==wnim_trajectory)
379  return EST_Val(a.mean()); /* NOT YET WRITTEN */
380  else
381  return EST_Val(a.mean());
382 }
383 
384 double WImpurity::samples(void)
385 {
386  if (t==wnim_float)
387  return a.samples();
388  else if (t==wnim_class)
389  return (int)p.samples();
390  else if (t==wnim_cluster)
391  return members.length();
392  else if (t==wnim_ols)
393  return members.length();
394  else if (t==wnim_vector)
395  return members.length();
396  else if (t==wnim_trajectory)
397  return members.length();
398  else
399  return 0;
400 }
401 
402 WImpurity::WImpurity(const WVectorVector &ds)
403 {
404  int i;
405 
406  t=wnim_unset;
407  a.reset(); trajectory=0; l=0; width=0;
408  data = &ds; // for ols, model calculation
409  for (i=0; i < ds.n(); i++)
410  {
411  if (t == wnim_ols)
412  cumulate(i,1);
413  else if (wgn_count_field == -1)
414  cumulate((*(ds(i)))[wgn_predictee],1);
415  else
416  cumulate((*(ds(i)))[wgn_predictee],
417  (*(ds(i)))[wgn_count_field]);
418  }
419 }
420 
421 float WImpurity::measure(void)
422 {
423  if (t == wnim_float)
424  return a.variance()*a.samples();
425  else if (t == wnim_vector)
426  return vector_impurity();
427  else if (t == wnim_trajectory)
428  return trajectory_impurity();
429  else if (t == wnim_matrix)
430  return a.variance()*a.samples();
431  else if (t == wnim_class)
432  return p.entropy()*p.samples();
433  else if (t == wnim_cluster)
434  return cluster_impurity();
435  else if (t == wnim_ols)
436  return ols_impurity(); /* RMSE for OLS model */
437  else
438  {
439  cerr << "WImpurity: can't measure unset object" << endl;
440  return 0.0;
441  }
442 }
443 
444 float WImpurity::vector_impurity()
445 {
446  // Find the mean/stddev for all values in all vectors
447  // sum the variances and multiply them by the number of members
448  EST_Litem *pp;
449  EST_Litem *countpp;
450  int i,j;
451  EST_SuffStats b;
452  double count = 1;
453 
454  a.reset();
455 #if 1
456  /* simple distance */
457  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
458  {
459  if (wgn_VertexFeats.a(0,j) > 0.0)
460  {
461  b.reset();
462  for (pp=members.head(), countpp=member_counts.head(); pp != 0; pp=pp->next(), countpp=countpp->next())
463  {
464  i = members.item(pp);
465 
466  // Accumulate the value with count
467  b.cumulate(wgn_VertexTrack.a(i,j), member_counts.item(countpp)) ;
468  }
469  a += b.stddev();
470  count = b.samples();
471  }
472  }
473 #endif
474 
475 #if 0
476  EST_SuffStats *c;
477  float x, lshift, rshift, ushift;
478  /* Find base mean, then measure do fshift to find best match */
479  c = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
480  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
481  {
482  if (wgn_VertexFeats.a(0,j) > 0.0)
483  {
484  c[j].reset();
485  for (pp=members.head(), countpp=member_counts.head(); pp != 0;
486  pp=pp->next(), countpp=countpp->next())
487  {
488  i = members.item(pp);
489  // Accumulate the value with count
490  c[j].cumulate(wgn_VertexTrack.a(i,j),member_counts.item(countpp));
491  }
492  count = c[j].samples();
493  }
494  }
495 
496  /* Pass through again but vary the num_channels offset (hardcoded) */
497  for (pp=members.head(), countpp=member_counts.head(); pp != 0;
498  pp=pp->next(), countpp=countpp->next())
499  {
500  int q;
501  float bshift, qshift;
502  /* For each sample */
503  i = members.item(pp);
504  /* Find the value left shifted, unshifted, and right shifted */
505  lshift = 0; ushift = 0; rshift = 0;
506  bshift = 0;
507  for (q=-20; q<=20; q++)
508  {
509  qshift = 0;
510  for (j=67+q; j<147+q/*hardcoded*/; j++)
511  {
512  x = c[j].mean() - wgn_VertexTrack(i,j);
513  qshift += sqrt(x*x);
514  if ((bshift > 0) && (qshift > bshift))
515  break;
516  }
517  if ((bshift == 0) || (qshift < bshift))
518  bshift = qshift;
519  }
520  a += bshift;
521  }
522 
523 #endif
524 
525 #if 0
526  /* full covariance */
527  /* worse in listening experiments */
528  EST_SuffStats **cs;
529  int mmm;
530  cs = new EST_SuffStats *[wgn_VertexTrack.num_channels()+1];
531  for (j=0; j<=wgn_VertexTrack.num_channels(); j++)
532  cs[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
533  /* Find means for diagonal */
534  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
535  {
536  if (wgn_VertexFeats.a(0,j) > 0.0)
537  {
538  for (pp=members.head(); pp != 0; pp=pp->next())
539  cs[j][j] += wgn_VertexTrack.a(members.item(pp),j);
540  }
541  }
542  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
543  {
544  for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
545  if (wgn_VertexFeats.a(0,j) > 0.0)
546  {
547  for (pp=members.head(); pp != 0; pp=pp->next())
548  {
549  mmm = members.item(pp);
550  cs[i][j] += (wgn_VertexTrack.a(mmm,i)-cs[j][j].mean())*
551  (wgn_VertexTrack.a(mmm,j)-cs[j][j].mean());
552  }
553  }
554  }
555  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
556  {
557  for (i=j+1; i<wgn_VertexFeats.num_channels(); i++)
558  if (wgn_VertexFeats.a(0,j) > 0.0)
559  a += cs[i][j].stddev();
560  }
561  count = cs[0][0].samples();
562 #endif
563 
564 #if 0
565  // look at mean euclidean distance between vectors
566  EST_Litem *qq;
567  int x,y;
568  double d,q;
569  count = 0;
570  for (pp=members.head(); pp != 0; pp=pp->next())
571  {
572  x = members.item(pp);
573  count++;
574  for (qq=pp->next(); qq != 0; qq=qq->next())
575  {
576  y = members.item(qq);
577  for (q=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
578  if (wgn_VertexFeats.a(0,j) > 0.0)
579  {
580  d = wgn_VertexTrack(x,j)-wgn_VertexTrack(y,j);
581  q += d*d;
582  }
583  a += sqrt(q);
584  }
585 
586  }
587 #endif
588 
589  // This is sum of stddev*samples
590  return a.mean() * count;
591 }
592 
593 WImpurity::~WImpurity()
594 {
595  int j;
596 
597  if (trajectory != 0)
598  {
599  for (j=0; j<l; j++)
600  delete [] trajectory[j];
601  delete [] trajectory;
602  trajectory = 0;
603  l = 0;
604  }
605 }
606 
607 
608 float WImpurity::trajectory_impurity()
609 {
610  // Find the mean length of all the units in the cluster
611  // Create that number of points
612  // Interpolate each unit to that number of points
613  // collect means and standard deviations for each point
614  // impurity is sum of the variance for each point and each coef
615  // multiplied by the number of units.
616  EST_Litem *pp;
617  int i, j;
618  int s, ti, ni, q;
619  int s1l, s2l;
620  double n, m, m1, m2, w;
621  EST_SuffStats lss, stdss;
622  EST_SuffStats l1ss, l2ss;
623  int l1, l2;
624  int ola=0;
625 
626  if (trajectory != 0)
627  { /* already done this */
628  return score;
629  }
630 
631  lss.reset();
632  l = 0;
633  for (pp=members.head(); pp != 0; pp=pp->next())
634  {
635  i = members.item(pp);
636  for (q=0; q<wgn_UnitTrack.a(i,1); q++)
637  {
638  ni = (int)wgn_UnitTrack.a(i,0)+q;
639  if (wgn_VertexTrack.a(ni,0) == -1.0)
640  {
641  l1ss += q;
642  ola = 1;
643  break;
644  }
645  }
646  if (q==wgn_UnitTrack.a(i,1))
647  { /* can't find -1 center point, so put all in l2 */
648  l1ss += 0;
649  l2ss += q;
650  }
651  else
652  l2ss += wgn_UnitTrack.a(i,1) - (q+1) - 1;
653  lss += wgn_UnitTrack.a(i,1); /* length of each unit in the cluster */
654  if (wgn_UnitTrack.a(i,1) > l)
655  l = (int)wgn_UnitTrack.a(i,1);
656  }
657 
658  if (ola==0) /* no -1's so its not an ola type cluster */
659  {
660  l = ((int)lss.mean() < 7) ? 7 : (int)lss.mean();
661 
662  /* a list of SuffStats on for each point in the trajectory */
663  trajectory = new EST_SuffStats *[l];
664  width = wgn_VertexTrack.num_channels()+1;
665  for (j=0; j<l; j++)
666  trajectory[j] = new EST_SuffStats[width];
667 
668  for (pp=members.head(); pp != 0; pp=pp->next())
669  { /* for each unit */
670  i = members.item(pp);
671  m = (float)wgn_UnitTrack.a(i,1)/(float)l; /* find interpolation */
672  s = (int)wgn_UnitTrack.a(i,0); /* start point */
673  for (ti=0,n=0.0; ti<l; ti++,n+=m)
674  {
675  ni = (int)n; // hmm floor or nint ??
676  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
677  {
678  if (wgn_VertexFeats.a(0,j) > 0.0)
679  trajectory[ti][j] += wgn_VertexTrack.a(s+ni,j);
680  }
681  }
682  }
683 
684  /* find sum of sum of stddev for all coefs of all traj points */
685  stdss.reset();
686  for (ti=0; ti<l; ti++)
687  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
688  {
689  if (wgn_VertexFeats.a(0,j) > 0.0)
690  stdss += trajectory[ti][j].stddev();
691  }
692 
693  // This is sum of all stddev * samples
694  score = stdss.mean() * members.length();
695  }
696  else
697  { /* OLA model */
698  l1 = (l1ss.mean() < 10.0) ? 10 : (int)l1ss.mean();
699  l2 = (l2ss.mean() < 10.0) ? 10 : (int)l2ss.mean();
700  l = l1 + l2 + 1 + 1;
701 
702  /* a list of SuffStats on for each point in the trajectory */
703  trajectory = new EST_SuffStats *[l];
704  for (j=0; j<l; j++)
705  trajectory[j] = new EST_SuffStats[wgn_VertexTrack.num_channels()+1];
706 
707  for (pp=members.head(); pp != 0; pp=pp->next())
708  { /* for each unit */
709  i = members.item(pp);
710  s1l = 0;
711  s = (int)wgn_UnitTrack.a(i,0); /* start point */
712  for (q=0; q<wgn_UnitTrack.a(i,1); q++)
713  if (wgn_VertexTrack.a(s+q,0) == -1.0)
714  {
715  s1l = q; /* printf("awb q is -1 at %d\n",q); */
716  break;
717  }
718  s2l = (int)wgn_UnitTrack.a(i,1) - (s1l + 2);
719  m1 = (float)(s1l)/(float)l1; /* find interpolation step */
720  m2 = (float)(s2l)/(float)l2; /* find interpolation step */
721  /* First half */
722  for (ti=0,n=0.0; s1l > 0 && ti<l1; ti++,n+=m1)
723  {
724  ni = s + (((int)n < s1l) ? (int)n : s1l - 1);
725  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
726  if (wgn_VertexFeats.a(0,j) > 0.0)
727  trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
728  }
729  ti = l1; /* do it explicitly in case s1l < 1 */
730  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
731  if (wgn_VertexFeats.a(0,j) > 0.0)
732  trajectory[ti][j] += -1;
733  /* Second half */
734  s += s1l+1;
735  for (ti++,n=0.0; s2l > 0 && ti<l-1; ti++,n+=m2)
736  {
737  ni = s + (((int)n < s2l) ? (int)n : s2l - 1);
738  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
739  if (wgn_VertexFeats.a(0,j) > 0.0)
740  trajectory[ti][j] += wgn_VertexTrack.a(ni,j);
741  }
742  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
743  if (wgn_VertexFeats.a(0,j) > 0.0)
744  trajectory[ti][j] += -2;
745  }
746 
747  /* find sum of sum of stddev for all coefs of all traj points */
748  /* windowing the sums with a triangular weight window */
749  stdss.reset();
750  m = 1.0/(float)l1;
751  for (w=0.0,ti=0; ti<l1; ti++,w+=m)
752  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
753  if (wgn_VertexFeats.a(0,j) > 0.0)
754  stdss += trajectory[ti][j].stddev() * w;
755  m = 1.0/(float)l2;
756  for (w=1.0,ti++; ti<l-1; ti++,w-=m)
757  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
758  if (wgn_VertexFeats.a(0,j) > 0.0)
759  stdss += trajectory[ti][j].stddev() * w;
760 
761  // This is sum of all stddev * samples
762  score = stdss.mean() * members.length();
763  }
764  return score;
765 }
766 
767 static void part_to_ols_data(EST_FMatrix &X, EST_FMatrix &Y,
768  EST_IVector &included,
769  EST_StrList &feat_names,
770  const EST_IList &members,
771  const WVectorVector &d)
772 {
773  int m,n,p;
774  int w, xm=0;
775  EST_Litem *pp;
776  WVector *wv;
777 
778  w = wgn_dataset.width();
779  included.resize(w);
780  X.resize(members.length(),w);
781  Y.resize(members.length(),1);
782  feat_names.append("Intercept");
783  included[0] = TRUE;
784 
785  for (p=0,pp=members.head(); pp; p++,pp=pp->next())
786  {
787  n = members.item(pp);
788  if (n < 0)
789  {
790  p--;
791  continue;
792  }
793  wv = d(n);
794  Y.a_no_check(p,0) = (*wv)[0];
795  X.a_no_check(p,0) = 1;
796  for (m=1,xm=1; m < w; m++)
797  {
798  if (wgn_dataset.ftype(m) == wndt_float)
799  {
800  if (p == 0) // only do this once
801  {
802  feat_names.append(wgn_dataset.feat_name(m));
803  }
804  X.a_no_check(p,xm) = (*wv)[m];
805  included.a_no_check(xm) = FALSE;
806  included.a_no_check(xm) = TRUE;
807  xm++;
808  }
809  }
810  }
811 
812  included.resize(xm);
813  X.resize(p,xm);
814  Y.resize(p,1);
815 }
816 
817 float WImpurity::ols_impurity()
818 {
819  // Build an OLS model for the current data and measure it against
820  // the data itself and give a RMSE
821  EST_FMatrix X,Y;
822  EST_IVector included;
823  EST_FMatrix coeffs;
824  EST_StrList feat_names;
825  float best_score;
826  EST_FMatrix coeffsl;
827  EST_FMatrix pred;
828  float cor,rmse;
829 
830  // Load the sample members into matrices for ols
831  part_to_ols_data(X,Y,included,feat_names,members,*data);
832 
833  // Find the best ols model.
834  // Far too computationally expensive
835  // if (!stepwise_ols(X,Y,feat_names,0.0,coeffs,
836  // X,Y,included,best_score))
837  // return WGN_HUGE_VAL; // couldn't find a model
838 
839  // Non stepwise model
840  if (!robust_ols(X,Y,included,coeffsl))
841  {
842  // printf("no robust ols\n");
843  return WGN_HUGE_VAL;
844  }
845  ols_apply(X,coeffsl,pred);
846  ols_test(Y,pred,cor,rmse);
847  best_score = cor;
848 
849  printf("Impurity OLS X(%d,%d) Y(%d,%d) %f, %f, %f\n",
850  X.num_rows(),X.num_columns(),Y.num_rows(),Y.num_columns(),
851  rmse,cor,
852  1-best_score);
853  if (fabs(coeffsl[0]) > 10000)
854  {
855  // printf("weird sized Intercept %f\n",coeffsl[0]);
856  return WGN_HUGE_VAL;
857  }
858 
859  return (1-best_score) *members.length();
860 }
861 
862 float WImpurity::cluster_impurity()
863 {
864  // Find the mean distance between all members of the dataset
865  // Uses the global DistMatrix for distances between members of
866  // the cluster set. Distances are assumed to be symmetric thus only
867  // the bottom half of the distance matrix is filled
868  EST_Litem *pp, *q;
869  int i,j;
870  double dist;
871 
872  a.reset();
873  for (pp=members.head(); pp != 0; pp=pp->next())
874  {
875  i = members.item(pp);
876  for (q=pp->next(); q != 0; q=q->next())
877  {
878  j = members.item(q);
879  dist = (j < i ? wgn_DistMatrix.a_no_check(i,j) :
880  wgn_DistMatrix.a_no_check(j,i));
881  a+=dist; // cumulate for whole cluster
882  }
883  }
884 
885  // This is sum distance between cross product of members
886 // return a.sum();
887  if (a.samples() > 1)
888  return a.stddev() * a.samples();
889  else
890  return 0.0;
891 }
892 
893 float WImpurity::cluster_distance(int i)
894 {
895  // Distance this unit is from all others in this cluster
896  // in absolute standard deviations from the the mean.
897  float dist = cluster_member_mean(i);
898  float mdist = dist-a.mean();
899 
900  if (mdist == 0.0)
901  return 0.0;
902  else
903  return fabs((dist-a.mean())/a.stddev());
904 
905 }
906 
907 int WImpurity::in_cluster(int i)
908 {
909  // Would this be a member of this cluster?. Returns 1 if
910  // its distance is less than at least one other
911  float dist = cluster_member_mean(i);
912  EST_Litem *pp;
913 
914  for (pp=members.head(); pp != 0; pp=pp->next())
915  {
916  if (dist < cluster_member_mean(members.item(pp)))
917  return 1;
918  }
919  return 0;
920 }
921 
922 float WImpurity::cluster_ranking(int i)
923 {
924  // Position in ranking closest to centre
925  float dist = cluster_distance(i);
926  EST_Litem *pp;
927  int ranking = 1;
928 
929  for (pp=members.head(); pp != 0; pp=pp->next())
930  {
931  if (dist >= cluster_distance(members.item(pp)))
932  ranking++;
933  }
934 
935  return ranking;
936 }
937 
938 float WImpurity::cluster_member_mean(int i)
939 {
940  // Returns the mean difference between this member and all others
941  // in cluster
942  EST_Litem *q;
943  int j,n;
944  double dist,sum;
945 
946  for (sum=0.0,n=0,q=members.head(); q != 0; q=q->next())
947  {
948  j = members.item(q);
949  if (i != j)
950  {
951  dist = (j < i ? wgn_DistMatrix(i,j) : wgn_DistMatrix(j,i));
952  sum += dist;
953  n++;
954  }
955  }
956 
957  return ( n == 0 ? 0.0 : sum/n );
958 }
959 
960 void WImpurity::cumulate(const float pv,double count)
961 {
962  // Cumulate data for impurity calculation
963 
964  if (wgn_dataset.ftype(wgn_predictee) == wndt_cluster)
965  {
966  t = wnim_cluster;
967  members.append((int)pv);
968  }
969  else if (wgn_dataset.ftype(wgn_predictee) == wndt_ols)
970  {
971  t = wnim_ols;
972  members.append((int)pv);
973  }
974  else if (wgn_dataset.ftype(wgn_predictee) == wndt_vector)
975  {
976  t = wnim_vector;
977 
978  // AUP: Implement counts in vectors
979  members.append((int)pv);
980  member_counts.append((float)count);
981  }
982  else if (wgn_dataset.ftype(wgn_predictee) == wndt_trajectory)
983  {
984  t = wnim_trajectory;
985  members.append((int)pv);
986  }
987  else if (wgn_dataset.ftype(wgn_predictee) >= wndt_class)
988  {
989  if (t == wnim_unset)
990  p.init(&wgn_discretes[wgn_dataset.ftype(wgn_predictee)]);
991  t = wnim_class;
992  p.cumulate((int)pv,count);
993  }
994  else if (wgn_dataset.ftype(wgn_predictee) == wndt_binary)
995  {
996  t = wnim_float;
997  a.cumulate((int)pv,count);
998  }
999  else if (wgn_dataset.ftype(wgn_predictee) == wndt_float)
1000  {
1001  t = wnim_float;
1002  a.cumulate(pv,count);
1003  }
1004  else
1005  {
1006  wagon_error("WImpurity: cannot cumulate EST_Val type");
1007  }
1008 }
1009 
1010 ostream & operator <<(ostream &s, WImpurity &imp)
1011 {
1012  int j,i;
1013  EST_SuffStats b;
1014 
1015  if (imp.t == wnim_float)
1016  s << "(" << imp.a.stddev() << " " << imp.a.mean() << ")";
1017  else if (imp.t == wnim_vector)
1018  {
1019  EST_Litem *p, *countp;
1020  s << "((";
1021  imp.vector_impurity();
1022  if (wgn_vertex_output == "mean") //output means
1023  {
1024  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1025  {
1026  b.reset();
1027  for (p=imp.members.head(), countp=imp.member_counts.head(); p != 0; p=p->next(), countp=countp->next())
1028  {
1029  // Accumulate the members with their counts
1030  b.cumulate(wgn_VertexTrack.a(imp.members.item(p),j), imp.member_counts.item(countp));
1031  //b += wgn_VertexTrack.a(imp.members.item(p),j);
1032  }
1033  s << "(" << b.mean() << " ";
1034  if (isfinite(b.stddev()))
1035  s << b.stddev() << ")";
1036  else
1037  s << "0.001" << ")";
1038  if (j+1<wgn_VertexTrack.num_channels())
1039  s << " ";
1040  }
1041  }
1042  else /* output best in the cluster */
1043  {
1044  /* print out vector closest to center, rather than average */
1045  /* printf("awb_debug outputing best\n"); */
1046  double best = WGN_HUGE_VAL;
1047  double x,d;
1048  int bestp = 0;
1049  EST_SuffStats *cs;
1050 
1051  cs = new EST_SuffStats [wgn_VertexTrack.num_channels()+1];
1052 
1053  for (j=0; j<wgn_VertexFeats.num_channels(); j++)
1054  {
1055  cs[j].reset();
1056  for (p=imp.members.head(); p != 0; p=p->next())
1057  {
1058  cs[j] += wgn_VertexTrack.a(imp.members.item(p),j);
1059  }
1060  }
1061 
1062  for (p=imp.members.head(); p != 0; p=p->next())
1063  {
1064  for (x=0.0,j=0; j<wgn_VertexFeats.num_channels(); j++)
1065  if (wgn_VertexFeats.a(0,j) > 0.0)
1066  {
1067  d = (wgn_VertexTrack.a(imp.members.item(p),j)-cs[j].mean())
1068  /* / cs[j].stddev() */ ; /* seems worse 061218 */
1069  x += d*d;
1070  }
1071  if (x < best)
1072  {
1073  /* printf("awb_debug updating best %d %f %d %f\n",
1074  bestp, best, imp.members.item(p), x); */
1075  bestp = imp.members.item(p);
1076  best = x;
1077  }
1078  }
1079  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1080  {
1081  s << "( ";
1082  s << wgn_VertexTrack.a(bestp,j);
1083  // s << " 0 "; // fake stddev
1084  s << " ";
1085  if (isfinite(cs[j].stddev()))
1086  s << cs[j].stddev();
1087  else
1088  s << "0";
1089  s << " ) ";
1090  if (j+1<wgn_VertexTrack.num_channels())
1091  s << " ";
1092  }
1093 
1094  delete [] cs;
1095  }
1096  s << ") ";
1097  s << imp.a.mean() << ")";
1098  }
1099  else if (imp.t == wnim_trajectory)
1100  {
1101  s << "((";
1102  imp.trajectory_impurity();
1103  for (i=0; i<imp.l; i++)
1104  {
1105  s << "(";
1106  for (j=0; j<wgn_VertexTrack.num_channels(); j++)
1107  {
1108  s << "(" << imp.trajectory[i][j].mean() << " "
1109  << imp.trajectory[i][j].stddev() << " " << ")";
1110  }
1111  s << ")\n";
1112  }
1113  s << ") ";
1114  // Mean of cross product of distances (cluster score)
1115  s << imp.a.mean() << ")";
1116  }
1117  else if (imp.t == wnim_cluster)
1118  {
1119  EST_Litem *p;
1120  s << "((";
1121  for (p=imp.members.head(); p != 0; p=p->next())
1122  {
1123  // Ouput cluster member and its mean distance to others
1124  s << "(" << imp.members.item(p) << " " <<
1125  imp.cluster_member_mean(imp.members.item(p)) << ")";
1126  if (p->next() != 0)
1127  s << " ";
1128  }
1129  s << ") ";
1130  // Mean of cross product of distances (cluster score)
1131  s << imp.a.mean() << ")";
1132  }
1133  else if (imp.t == wnim_ols)
1134  {
1135  /* Output intercept, feature names and coefficients for ols model */
1136  EST_FMatrix X,Y;
1137  EST_IVector included;
1138  EST_FMatrix coeffs;
1139  EST_StrList feat_names;
1140  EST_FMatrix coeffsl;
1141  EST_FMatrix pred;
1142  float cor=0.0,rmse;
1143 
1144  s << "((";
1145  // Load the sample members into matrices for ols
1146  part_to_ols_data(X,Y,included,feat_names,imp.members,*(imp.data));
1147  if (!robust_ols(X,Y,included,coeffsl))
1148  {
1149  printf("no robust ols\n");
1150  // shouldn't happen
1151  }
1152  else
1153  {
1154  ols_apply(X,coeffsl,pred);
1155  ols_test(Y,pred,cor,rmse);
1156  for (i=0; i<coeffsl.num_rows(); i++)
1157  {
1158  s << "(";
1159  s << feat_names.nth(i);
1160  s << " ";
1161  s << coeffsl[i];
1162  s << ") ";
1163  }
1164  }
1165 
1166  // Mean of cross product of distances (cluster score)
1167  s << ") " << cor << ")";
1168  }
1169  else if (imp.t == wnim_class)
1170  {
1171  EST_Litem *i;
1172  EST_String name;
1173  double prob;
1174 
1175  s << "(";
1176  for (i=imp.p.item_start(); !imp.p.item_end(i); i=imp.p.item_next(i))
1177  {
1178  imp.p.item_prob(i,name,prob);
1179  s << "(" << name << " " << prob << ") ";
1180  }
1181  s << imp.p.most_probable(&prob) << ")";
1182  }
1183  else
1184  s << "([WImpurity unset])";
1185 
1186  return s;
1187 }
1188 
1189 
1190 
1191 
EST_Litem * item_next(EST_Litem *idx) const
Used for iterating through members of the distribution.
EST_Litem * item_start() const
Used for iterating through members of the distribution.
void item_prob(EST_Litem *idx, EST_String &s, double &prob) const
During iteration returns name and probability given index.
const EST_String & most_probable(double *prob=NULL) const
Return the most probable member of the distribution.
double samples(void) const
Total number of example found.
double entropy(void) const
int item_end(EST_Litem *idx) const
Used for iterating through members of the distribution.
int matches(const char *e, int pos=0) const
Exactly match this string?
Definition: EST_String.cc:652
double stddev(void) const
standard deviation of currently cummulated values
double variance(void) const
variance of currently cummulated values
double mean(void) const
mean of currently cummulated values
void reset(void)
reset internal values
double samples(void)
number of samples in set
T & item(const EST_Litem *p)
Definition: EST_TList.h:133
void append(const T &item)
add item onto end of list
Definition: EST_TList.h:191
T & nth(int n)
return the Nth value
Definition: EST_TList.h:139
int num_columns() const
return number of columns
Definition: EST_TMatrix.h:181
int num_rows() const
return number of rows
Definition: EST_TMatrix.h:179
INLINE const T & a_no_check(int row, int col) const
const access with no bounds check, care recommend
Definition: EST_TMatrix.h:184
void resize(int rows, int cols, int set=1)
resize matrix
void resize(int n, int set=1)
resize vector
void resize(int n, int set=1)
Definition: EST_TVector.cc:196
INLINE const T & a_no_check(int n) const
read-only const access operator: without bounds checking
Definition: EST_TVector.h:257
INLINE int n() const
number of items in vector.
Definition: EST_TVector.h:254
float & a(int i, int c=0)
Definition: EST_Track.cc:1022
int num_channels() const
return number of channels in track
Definition: EST_Track.h:656
const int Int(void) const
Definition: EST_Val.h:130
const EST_String & string(void) const
Definition: EST_Val.h:150
const float Float(void) const
Definition: EST_Val.h:138