# -*- coding: utf-8 -*-
'''
Created on Jan 13, 2017

@author: David

Partial Tree Kernel (PTK) similarity between two sentences.

Each sentence is annotated with ``call_sparv`` and turned into three trees by
``build_trees``. ``calculate_similarity`` compares the corresponding trees of
the two sentences with the Partial Tree Kernel implemented below, which is a
Python port of a Java PTK implementation.
'''
from tree_representation import TreeNodePair, DeltaMatrix
from HitEx import call_sparv
from tree_builder import build_trees
from math import sqrt

# Every matched node pair multiplies its delta by mu (see ptk_delta_function).
param_mu = 0.6
# Gap decay used by the child-sequence DP in string_kernel_delta_function.
param_lambda = 0.6
param_lambda2 = param_lambda * param_lambda
# Extra multiplier for matches in which at least one node has no children.
param_terminal_factor = 1

# Upper bound on the length of common child subsequences that are counted.
MAX_CHILDREN = 100

# Memo table keyed by (node id, node id) holding computed delta values.
# A stored value of -1 means "not computed yet".
# NOTE(review): this is module-global and is cleared by evaluate_kernel, so
# concurrent kernel evaluations would clobber each other.
delta_matrix = DeltaMatrix()


def determine_sub_list(a, b):
    """Return all pairs of nodes of ``a`` and ``b`` whose contents are equal.

    The two node lists are walked merge-style, so every match is found only
    if ``getNodes()`` returns nodes ordered by content. NOTE(review): that
    ordering is assumed here, not checked -- confirm in tree_representation.
    Runs of equal content are paired exhaustively (cartesian product).

    Side effect: every returned pair is registered in ``delta_matrix`` with
    the placeholder -1 ("not computed yet").

    :param a: first tree (must provide ``getNodes()``)
    :param b: second tree (must provide ``getNodes()``)
    :return: list of ``TreeNodePair`` objects
    """
    intersect = []
    nodes_a = a.getNodes()
    nodes_b = b.getNodes()
    n_a = len(nodes_a)
    n_b = len(nodes_b)
    i = j = 0
    while i < n_a and j < n_b:
        content_a = nodes_a[i].getContent()
        content_b = nodes_b[j].getContent()
        if content_a > content_b:
            j += 1
        elif content_a < content_b:
            i += 1
        else:
            # nodes_b[j_start:j_end] is the run in b with this content;
            # pair it with every node of the corresponding run in a.
            j_start = j
            j_end = j
            while (i < n_a and
                   nodes_a[i].getContent() == nodes_b[j_start].getContent()):
                j = j_start
                while (j < n_b and
                       nodes_a[i].getContent() == nodes_b[j].getContent()):
                    intersect.append(TreeNodePair(nodes_a[i], nodes_b[j]))
                    delta_matrix.add(nodes_a[i].getId(),
                                     nodes_b[j].getId(), -1)
                    j += 1
                j_end = j
                i += 1
            j = j_end
    return intersect


def evaluate_kernel(a, b):
    """Return the unnormalized Partial Tree Kernel value K(a, b).

    K(a, b) is the sum of ``ptk_delta_function`` over every pair of nodes
    with equal content. ``delta_matrix`` is reset first so that cached
    deltas from a previous pair of trees are never reused.
    """
    delta_matrix.clear()
    k = 0
    for pair in determine_sub_list(a, b):
        k += ptk_delta_function(pair.getNx(), pair.getNz())
    return k


def ptk_delta_function(nx, nz):
    """Return the (memoized) PTK delta for nodes ``nx`` and ``nz``.

    * different content            -> 0
    * either node has no children  -> mu * lambda^2 * terminal_factor
    * otherwise                    -> mu * (lambda^2 + string kernel over the
                                      two child sequences)

    The result is stored in ``delta_matrix`` under the two node ids.
    """
    cached = delta_matrix.get(nx.getId(), nz.getId())
    if cached != -1:
        return cached
    if nx.getContent() != nz.getContent():
        delta = 0
    elif len(nx.getChildren()) == 0 or len(nz.getChildren()) == 0:
        delta = param_mu * param_lambda2 * param_terminal_factor
    else:
        delta_sk = string_kernel_delta_function(nx.getChildren(),
                                                nz.getChildren())
        delta = param_mu * (param_lambda2 + delta_sk)
    delta_matrix.add(nx.getId(), nz.getId(), delta)
    return delta


def string_kernel_delta_function(sx, sz):
    """Return the string-kernel term over the child sequences ``sx``, ``sz``.

    Sums, over all common child subsequences (matched by content) of length
    1..p, the products of the children's PTK deltas, with gaps decayed by
    ``param_lambda``. Here p = min(len(sx), len(sz), MAX_CHILDREN).

    The DP tables are indexed 1-based in i/j (row 0 and column 0 are the
    zero boundary), hence the (n + 1) x (m + 1) sizes and inclusive bounds.
    """
    n = len(sx)
    m = len(sz)
    # A common subsequence can be no longer than the *shorter* sequence.
    p = min(n, m, MAX_CHILDREN)
    if p == 0:
        return 0
    dps = [[0] * (m + 1) for _ in range(n + 1)]
    dp = [[0] * (m + 1) for _ in range(n + 1)]
    # kernel_mat[l] accumulates the contribution of subsequences of length l+1.
    kernel_mat = [0] * p

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            if sx[i - 1].getContent() == sz[j - 1].getContent():
                dps[i][j] = ptk_delta_function(sx[i - 1], sz[j - 1])
                kernel_mat[0] += dps[i][j]
            # otherwise dps[i][j] stays 0

    for l in range(1, p):
        # Reset the boundary row/column below which this pass reads.
        for j in range(m + 1):
            dp[l - 1][j] = 0
        for i in range(n + 1):
            dp[i][l - 1] = 0
        for i in range(l, n + 1):
            for j in range(l, m + 1):
                dp[i][j] = (dps[i][j]
                            + param_lambda * dp[i - 1][j]
                            + param_lambda * dp[i][j - 1]
                            - param_lambda2 * dp[i - 1][j - 1])
                if sx[i - 1].getContent() == sz[j - 1].getContent():
                    dps[i][j] = (ptk_delta_function(sx[i - 1], sz[j - 1])
                                 * dp[i - 1][j - 1])
                    kernel_mat[l] += dps[i][j]

    K = 0
    for value in kernel_mat:
        K += value
    return K


def _normalized_kernel(k_ab, k_aa, k_bb):
    """Return k_ab / sqrt(k_aa * k_bb), or 0.0 when the denominator is 0."""
    denominator = sqrt(k_aa * k_bb)
    if denominator == 0:
        # A tree with no nodes has no node pairs, so its self-kernel is 0;
        # score it as dissimilar instead of raising ZeroDivisionError.
        return 0.0
    return k_ab / denominator


def normalize(a, b):
    """Return the cosine-normalized kernel K(a,b) / sqrt(K(a,a) * K(b,b)).

    Returns 0.0 instead of raising when either self-kernel is 0.
    """
    return _normalized_kernel(evaluate_kernel(a, b),
                              evaluate_kernel(a, a),
                              evaluate_kernel(b, b))


def _kernel_mean(tree_a, tree_b):
    """Return the mean of K(a,b)/K(a,a) and the normalized K(a,b).

    Each kernel is evaluated exactly once (calling ``normalize`` here would
    recompute K(a,b) and K(a,a)). Both ratios fall back to 0.0 when their
    denominator is 0.
    """
    k_aa = evaluate_kernel(tree_a, tree_a)
    k_ab = evaluate_kernel(tree_a, tree_b)
    k_bb = evaluate_kernel(tree_b, tree_b)
    k_ratio = k_ab / k_aa if k_aa else 0.0
    k_norm = _normalized_kernel(k_ab, k_aa, k_bb)
    return (k_ratio + k_norm) / 2


def calculate_similarity(sent1, sent2):
    """Compare two sentences with the Partial Tree Kernel.

    Both sentences are passed through ``call_sparv.call_sparv`` and
    ``build_trees``, each yielding three trees (word, POS and dependency,
    per the names used here). For each tree kind the score is the mean of
    K(a,b)/K(a,a) and the cosine-normalized K(a,b).

    :return: tuple ``(word_kernel_mean, pos_kernel_mean, dep_kernel_mean,
        node_number, tree_height, height_nodes_ratio)``; the last three
        describe the word tree of ``sent1``.
    """
    wta, pta, dta = build_trees(call_sparv.call_sparv(sent1))
    wtb, ptb, dtb = build_trees(call_sparv.call_sparv(sent2))

    word_kernel_mean = _kernel_mean(wta, wtb)
    pos_kernel_mean = _kernel_mean(pta, ptb)
    dep_kernel_mean = _kernel_mean(dta, dtb)

    tree_height = wta.getHeight()
    node_number = len(wta.getNodes())
    # float() keeps this a true division under Python 2 as well.
    height_nodes_ratio = (float(tree_height) / node_number
                          if node_number else 0.0)
    return (word_kernel_mean, pos_kernel_mean, dep_kernel_mean,
            node_number, tree_height, height_nodes_ratio)


if __name__ == "__main__":
    # Demo: only runs when executed as a script, not on import.
    sen1 = "En egyptisk katt ligger på fönsterbrädan i köket."
    sen2 = "En egyptisk katt ligger i köket på fönsterbrädan."
    res = calculate_similarity(sen1, sen2)
    print(res)