import codecs  # NOTE(review): unused in this file; kept since other tooling may expect it
import math

import corpus
import paradigms


class Classifier(object):
    """Rank inflectional paradigms as candidate analyses for an unseen word.

    The confidence score combines up to three signals, toggled by the
    ``profile`` flags passed to :meth:`classify`:

    1. a distributional prior (share of paradigm instances),
    2. a corpus-frequency score over the candidate's inflection table,
    3. a "name-class" score based on the longest common suffix between the
       word and each paradigm instance's base form.
    """

    def __init__(self):
        # Both are set by train(); every scoring method assumes train() ran first.
        self.paradigms = None
        self.corpus = None

    def train(self, paradigms, corpus):
        """Attach the paradigm collection and the corpus used for scoring."""
        self.paradigms = paradigms
        self.corpus = corpus

    def classify(self, word, profile=(True, True, True)):
        """Return (paradigm_id, confidence) pairs for ``word``, best first.

        ``profile`` holds three flags enabling the (distribution, corpus,
        name-class) components of the confidence score.  Returns [] when no
        paradigm pattern matches ``word``.
        """
        candidates = []
        for _name, p in self.paradigms.paradigms.items():
            matches = p.match_patterns(word, True)
            if matches:
                candidates.extend(matches)
        # Score every candidate, sort best-first, then flip the tuples so the
        # caller receives (paradigm_id, score).
        scored = sorted(
            ((self.__confidence_score(dist, c, profile), c['p'].id)
             for dist, c in self.distribution(candidates)),
            reverse=True)
        return [(pid, score) for score, pid in scored]

    def __confidence_score(self, dist, c, profile):
        """Combine the enabled components: dist + corpus_score * nc_score."""
        d = dist if profile[0] else 0
        c_score = self.__corpus_score(c) if profile[1] else 1
        nc_score = self.__nc_score(c['w'], c) if profile[2] else 1
        return d + c_score * nc_score

    def __corpus_score(self, c):
        """Plus-one-smoothed sum of log frequencies over the candidate's table."""
        score = 1
        freq_table = self.corpus.add_frequency_to_table(set(c['table']))
        for _word, n in freq_table:
            score += math.log(n + 1)
        return score

    def __len_common_prefix(self, a, b):
        """Length of the common prefix shared by sequences ``a`` and ``b``."""
        count = 0
        for x, y in zip(a, b):
            if x != y:
                break
            count += 1
        return count

    def __nc_score(self, baseform, c):
        """Name-class score: longest common suffix with the paradigm's 0-forms.

        Comparing reversed strings turns the common-prefix length into the
        matching-suffix length.  The result is weighted by 100 so it
        dominates the distributional component; the count of instances that
        attain the maximum acts as a small (0.001-weighted) tiebreaker.
        """
        suffix_lens = [
            self.__len_common_prefix(baseform[::-1], inst['0'][::-1])
            for _, inst in c['p'].instances  # inst['0'] is the paradigm's base (0) form
        ]
        maxscore = max(suffix_lens)
        numhits = suffix_lens.count(maxscore)  # number of hits of the max suffix length
        return 100 * (maxscore + 0.001 * numhits)

    def distribution(self, cs):
        """Annotate candidates with distributional info (confidence tiebreaker).

        Returns (fraction_of_instances, candidate) pairs.  An empty candidate
        list yields [] instead of dividing by zero.
        """
        total = sum(len(c['p'].instances) for c in cs)
        if total == 0:
            return []
        return [(len(c['p'].instances) / float(total), c) for c in cs]


if __name__ == "__main__":
    cl = Classifier()
    P = paradigms.Paradigms('../paradigms/de_nouns_train.para')
    C = corpus.Corpus(None)
    cl.train(P, C)
    print("\n".join('%s %.2f' % (w, d) for (w, d) in cl.classify('Buch')))