## Copyleft 2021, Alex Markham, see https://medil.causal.dev/license.html
# Tested with versions:
# python: 3.9.5
# requests: 2.25.1
# numpy: 1.20.3
# scipy: 1.6.3
# medil: 0.6.0
# matplotlib: 3.4.2
# networkx: 2.5
import os

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import requests
from medil.ecc_algorithms import find_clique_min_cover as find_cm
from numpy import linalg as LA
from scipy.spatial.distance import pdist, squareform
from scipy.stats import chi2


## Load data into Python, and download first if necessary
path = "./"  # change if desired
if not os.path.exists(path + "data.txt"):
    url = "http://genome-www.stanford.edu/serum/data/fig2clusterdata.txt"
    r = requests.get(url, timeout=60)
    # Fail loudly instead of caching an HTTP error page as data.txt, which
    # would otherwise be reused on every later run.
    r.raise_for_status()
    with open(path + "data.txt", "w") as f:
        f.write(r.text)
cols = np.arange(5, 16)
data = np.loadtxt(
    path + "data.txt",
    skiprows=2,
    usecols=cols,
    delimiter="\t",
)


## Perform clustering
def _doubly_centered_distances(column):
    """Return the doubly centered distance matrix of one feature.

    Parameters
    ----------
    column : 1-d array of the samples of a single feature.

    Returns
    -------
    (centered, d_bar) : the pairwise "cityblock" distance matrix with its
        column means and row means subtracted and its grand mean added back,
        and the grand mean ``d_bar`` of the raw distance matrix.
    """
    d = squareform(pdist(column.reshape(-1, 1), "cityblock"))
    d_bar = d.mean()
    # The right-hand side is fully evaluated (from the raw distances) before
    # the in-place subtraction happens.
    d -= d.mean(0)[np.newaxis, :] + d.mean(1)[:, np.newaxis] - d_bar
    return d, d_bar


def dep_contrib_kernel(X, alpha=0.1):
    """Compute the dependence contribution kernel between the samples of X.

    Parameters
    ----------
    X : 2-d array of shape (num_samps, num_feats).
    alpha : significance level used to build the threshold matrix, or None
        to use the identity matrix as threshold.

    Returns
    -------
    (kappa, gamma) : two (num_samps, num_samps) arrays. ``gamma`` is the
        helper kernel; ``kappa`` is ``gamma`` normalized by the square roots
        of its diagonal entries (cosine similarity), with values above 1
        clipped to 1.
    """
    num_samps, num_feats = X.shape
    thresh = np.eye(num_feats)
    if alpha is not None:
        # Off-diagonal entries: critical value corresponding to alpha.
        thresh[thresh == 0] = chi2(1).ppf(1 - alpha) / num_samps
        # NOTE(review): the original indentation was lost; the diagonal reset
        # is assumed to belong to this branch -- confirm.
        thresh[thresh == 1] = 0

    Z = np.zeros((num_feats, num_samps, num_samps))
    for j in range(num_feats):
        centered, d_bar = _doubly_centered_distances(X[:, j])
        Z[j] = centered / d_bar  # standardized

    F = Z.reshape(num_feats * num_samps, num_samps)
    left = np.tensordot(Z, thresh, axes=([0], [0]))
    left_right = np.tensordot(left, Z, axes=([2, 1], [0, 1]))
    gamma = (F.T @ F) ** 2 - 2 * left_right + LA.norm(thresh)  # helper kernel

    diag = np.diag(gamma)
    kappa = gamma / np.sqrt(np.outer(diag, diag))  # cosine similarity
    kappa[kappa > 1] = 1  # correct numerical errors
    return kappa, gamma


def kernel_k_means(data, num_clus=6, kernel=dep_contrib_kernel, max_iters=100):
    """Cluster the rows of ``data`` with kernel k-means.

    Initial centers are ``num_clus`` distinct samples drawn with a fixed seed
    (Forgy method), so repeated runs on the same input give the same result.
    Progress is printed on every iteration.

    Parameters
    ----------
    data : 2-d array of shape (num_samps, num_feats).
    num_clus : number of clusters.
    kernel : callable returning ``(inner_products, _)`` for ``data``; only
        the first element is used.
    max_iters : maximum number of reassignment iterations.

    Returns
    -------
    labels : integer array of length num_samps with values in
        ``range(num_clus)``.
    """
    num_samps = data.shape[0]
    rng = np.random.default_rng(1312)
    # choose initial clusters using Forgy method
    init = rng.choice(num_samps, num_clus, replace=False)

    inner_prods, _ = kernel(data)
    left = np.tile(np.diag(inner_prods)[:, np.newaxis], (1, num_clus))
    distances = (
        left
        - 2 * inner_prods[:, init]
        + np.tile(inner_prods[init, init], (num_samps, 1))
    )
    # use law of cosines to get angle instead of Euc dist;
    # clip corrects for numerical error, e.g. 1.0000004 instead of 1.0
    # NOTE(review): `distances` already follows the k(x,x) - 2k(x,mu) +
    # k(mu,mu) formula and is squared again here; kept as in the original.
    arc_distances = np.arccos(np.clip(1 - (distances ** 2 / 2), -1, 1))
    labels = np.argmin(arc_distances, axis=1)

    cluster_ids = np.arange(num_clus)[:, np.newaxis]
    for itr in range(max_iters):
        # compute kernel distance using
        # ||x - mu|| = k(x,x) - 2k(x,mu).mean() + k(mu,mu).mean()
        #            = left - 2*middle + right
        member = labels == cluster_ids  # member[c, i]: sample i is in cluster c
        counts = member.sum(axis=1)

        # Keep k(i, j) only where j is in cluster c (same mask the original
        # built element by element with a Python generator).
        to_cluster = np.where(member[:, np.newaxis, :], inner_prods, 0.0)
        # Keep k(i, j) only where both i and j are in cluster c.
        within_cluster = np.where(
            member[:, :, np.newaxis] & member[:, np.newaxis, :], inner_prods, 0.0
        )

        # An empty cluster gives 0 / 0 below; it is masked out afterwards.
        with np.errstate(divide="ignore", invalid="ignore"):
            # sum / counts, because the zeroed entries would throw off mean
            middle = to_cluster.sum(2).T / counts
            right = within_cluster.sum((1, 2)) / (counts ** 2)
            distances = left - 2 * middle + right  # law of cosines
            arc_distances = np.arccos(np.clip(1 - (distances ** 2 / 2), -1, 1))
        # Without this, argmin would pick the NaN column and move every
        # sample into the empty cluster.
        arc_distances[:, counts == 0] = np.inf

        new_labels = np.argmin(arc_distances, axis=1)
        if (labels == new_labels).all():
            print("converged")
            break
        print("iteration {} with cluster sizes {}".format(itr, counts))
        labels = new_labels
    return labels


cluster_labels = kernel_k_means(data)


## Generate plots
def compute_d_cov(X):
    """Compute pairwise distance covariances between the features of X.

    Parameters
    ----------
    X : 2-d array of shape (num_samps, num_feats).

    Returns
    -------
    (d_cov, d_bars) : ``d_cov`` is the (num_feats, num_feats) matrix of inner
        products of the flattened, doubly centered distance matrices (each
        divided by num_samps); ``d_bars`` holds the mean raw distance of
        every feature.
    """
    num_samps, num_feats = X.shape
    dists = np.zeros((num_feats, num_samps ** 2))
    d_bars = np.zeros(num_feats)
    for feat_idx in range(num_feats):
        centered, d_bar = _doubly_centered_distances(X[:, feat_idx])
        dists[feat_idx] = centered.flatten() / num_samps
        d_bars[feat_idx] = d_bar
    return dists @ dists.T, d_bars


def get_dag_from_biadj(biadj_mat):
    """Embed a (num_latent, num_obs) biadjacency matrix in a square adjacency
    matrix whose only nonzero block is latent rows -> observed columns."""
    num_latent, num_obs = biadj_mat.shape
    num_nodes = num_latent + num_obs
    dag_adj_mat = np.zeros((num_nodes, num_nodes))
    dag_adj_mat[:num_latent, num_latent:] = biadj_mat
    return dag_adj_mat


def plot_dag(biadj_mat, ax):
    """Draw the two-layer graph given by ``biadj_mat`` on ``ax``.

    Latent nodes (rows) are placed on the top layer and labelled L_1, L_2, ...;
    observed nodes (columns) are placed on the bottom layer and labelled
    M_0, M_1, ....
    """
    num_latent, num_obs = biadj_mat.shape
    latent_pos_dict = {
        idx: (val, 1) for idx, val in enumerate(np.linspace(0, 1, num_latent))
    }
    obs_pos_dict = {
        idx + num_latent: (val, 0)
        for idx, val in enumerate(np.linspace(0, 1, num_obs))
    }
    pos_dict = {**latent_pos_dict, **obs_pos_dict}
    node_color = num_latent * [0] + num_obs * [1]

    G = nx.DiGraph(get_dag_from_biadj(biadj_mat))
    nx.draw_networkx(G, pos=pos_dict, with_labels=False, ax=ax, node_size=500)
    nx.draw_networkx_labels(
        G,
        pos=latent_pos_dict,
        labels={idx: "$L_{{{}}}$".format(idx + 1) for idx in range(num_latent)},
        font_color="w",
        ax=ax,
    )
    nx.draw_networkx_labels(
        G,
        pos=obs_pos_dict,
        labels={
            idx + num_latent: "$M_{{{}}}$".format(idx) for idx in range(num_obs)
        },
        font_color="w",
        ax=ax,
    )
    nx.draw_networkx_nodes(
        G, pos=pos_dict, node_color=node_color, ax=ax, node_size=0
    )
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-0.5, 1.85)


def make_heatmaps_and_dags(path, data, labels):
    """Save ``heatmaps.png`` and ``dags.png`` for the full data and 3 clusters.

    Parameters
    ----------
    path : prefix prepended to the file names; ``path + "data.txt"`` is read
        for the axis tick labels and the two images are written next to it.
    data : 2-d array of shape (num_samps, num_feats).
    labels : integer cluster label of every row of ``data``. The clusters
        with labels 2, 3 and 4 are plotted as K1, K2 and K3.
    """
    with open(path + "data.txt") as f:
        first_line = f.readline()
    x = first_line.split("\t")[5:16]
    x[0:2] = ["0.25", "0.5"]
    x[2:] = [time[:-2] for time in x[2:]]

    plt.rcParams.update(
        {
            "text.usetex": True,
            "font.family": "sans-serif",
            "font.sans-serif": ["Helvetica"],
        }
    )

    alpha = 0.1
    crit = chi2(1).ppf(1 - alpha)
    # counts[0] is the size of the full data set, counts[k + 1] the size of
    # cluster k, so the panel keys below index it directly.
    counts = np.append(data.shape[0], np.bincount(labels))
    panel_keys = (0, 3, 4, 5)  # 0: all data; key k > 0: cluster k - 1

    ## Dcov
    covs = {}
    deps = {}
    for key in panel_keys:
        subset = data if key == 0 else data[labels == key - 1]
        cov, d_bars = compute_d_cov(subset)
        test = counts[key] * cov / np.outer(d_bars, d_bars)
        dep = np.zeros_like(cov)
        dep[test > crit] = 1
        covs[key] = cov
        deps[key] = dep

    fig, axs = plt.subplots(
        4, 2, figsize=(4, 9.5), sharex=True, sharey=True, constrained_layout=True
    )
    ims = {}
    for row, key in enumerate(panel_keys):
        panels = ((covs[key], "YlOrBr_r"), (-deps[key], "binary"))
        for col, (im, cmap) in enumerate(panels):
            ax = axs[row, col]
            ims[(row, col)] = ax.imshow(im, cmap=cmap)
            ax.set_yticks(np.arange(len(x)))
            ax.set_yticklabels(x)
            ax.set_xticks(np.arange(len(x)))
            ax.set_xticklabels(x, rotation=80)

    box = dict(facecolor="none", edgecolor="black", boxstyle="round")
    fig.text(0.54, 0.99, "Unclustered data", ha="center", va="center", bbox=box)
    fig.text(0.54, 0.765, "Cluster K1", ha="center", va="center", bbox=box)
    fig.text(0.54, 0.54, "Cluster K2", ha="center", va="center", bbox=box)
    fig.text(0.54, 0.295, "Cluster K3", ha="center", va="center", bbox=box)

    # The colorbars reuse the same two images as before: the unclustered
    # covariance panel and the thresholded panel of cluster K1.
    cbar = fig.colorbar(
        ims[(0, 0)],
        ax=axs[3, 0],
        location="bottom",
        shrink=0.6,
        label="distance covariance",
        ticks=[0.005, 0.8],
    )
    cbar.ax.set_xticklabels(["0", "0.8"])
    cbar2 = fig.colorbar(
        ims[(1, 1)],
        ax=axs[3, 1],
        location="bottom",
        shrink=0.6,
        label=r"dependent",
        ticks=[-0.75, -0.25],
        values=[-1, 0],
    )
    cbar2.ax.set_xticklabels(["true", "false"])
    fig.savefig(path + "heatmaps.png", dpi=200, bbox_inches="tight")
    fig.clf()

    biadj_mats = {key: find_cm(dep) for key, dep in deps.items()}

    fig, axs = plt.subplots(4, 1, figsize=(5, 5.5), constrained_layout=True)
    for idx, key in enumerate(panel_keys):
        b_mat = biadj_mats[key]
        plot_dag(b_mat[np.lexsort(b_mat.T)], axs[idx])

    box = dict(facecolor="none", edgecolor="black", boxstyle="round")
    fig.text(0.5, 0.96, "Unclustered data", ha="center", va="center", bbox=box)
    fig.text(0.5, 0.71, "Cluster K1", ha="center", va="center", bbox=box)
    fig.text(0.5, 0.46, "Cluster K2", ha="center", va="center", bbox=box)
    fig.text(0.5, 0.21, "Cluster K3", ha="center", va="center", bbox=box)
    fig.savefig(path + "dags.png", dpi=200, bbox_inches="tight")


make_heatmaps_and_dags(path, data, cluster_labels)