Ana6.1 網絡節點分群實作


本節將 Analysis 3 章節中介紹的 K-Means 分群演算法應用於社會網絡的節點分群，即社群偵測（community detection）：以每個節點在相鄰矩陣中對應的列向量作為其特徵，再用 K-Means 將特徵相近的節點歸為同一群。

網絡節點分群程式的概念流程如下：

# 顯示開始訊息

# 讀取網絡相關資訊的pickle檔
# 取得網絡G
# 建立網絡G的相鄰矩陣(adjacency-matrix)
# 使用K-means演算法對網絡節點做分群
# 計算分群準確度
# 儲存分群結果到pickle檔中
# 顯示分群結果

# 顯示結束訊息
  • 參考檔案: community_detection.py
# coding=utf-8
# 匯入模組
import bisect
import copy
import itertools
import pickle
import random
import sys

import forceatlas as fa
import matplotlib.pyplot as plt
import networkx as nx
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi


# 定義函數
def create_adjacency_matrix(g):
    """Return the dense 0/1 adjacency matrix of ``g`` as a list of row lists.

    Node labels are used directly as row/column indices, so ``g`` is assumed
    to have integer nodes 0..n-1 (TODO confirm for graphs other than Karate).
    Only edges reported by ``g.neighbors`` are marked; no self-loop entry is
    added on the diagonal.
    """
    size = g.number_of_nodes()

    # one independent zero row per node
    matrix = []
    for _ in range(size):
        matrix.append([0] * size)

    for node in g:
        for neighbor in g.neighbors(node):
            matrix[node][neighbor] = 1

    return matrix


def add_center(mi):
    """Create a new cluster-center record seeded from the data point ``mi``.

    The record holds an independent list copy of ``mi`` under 'dimensions'
    and a member counter 'num_members' that starts at 0.
    """
    return {'dimensions': list(mi), 'num_members': 0}


def binary_pick(d):
    """Draw one index of ``d`` with probability proportional to its weight.

    Builds the cumulative sum of the weights and binary-searches it with a
    uniform random number, so index ``i`` is returned with probability
    ``d[i] / sum(d)``. ``d`` itself is not modified.

    Args:
        d: non-empty sequence of non-negative numeric weights (e.g. degrees).

    Returns:
        An index into ``d``. If every weight is zero, the index is drawn
        uniformly instead.

    Raises:
        ValueError: if ``d`` is empty.
    """
    if len(d) == 0:
        raise ValueError("binary_pick() requires at least one weight")

    # running totals, built in a new list so the caller's weights stay intact
    cdf = list(itertools.accumulate(d))
    total = cdf[-1]

    if total <= 0:
        # nothing to weight by: pick uniformly rather than dividing by zero
        return random.randrange(len(cdf))

    f = random.uniform(0, total)
    # first index whose cumulative weight exceeds f, giving P(i) = d[i] / total
    # (zero-weight entries never qualify); the clamp only matters if
    # floating-point rounding lets f reach total exactly
    return min(bisect.bisect_right(cdf, f), len(cdf) - 1)


def distance_to_center(m, c, lamda=1):
    """Minkowski-style distance between data point ``m`` and center ``c``.

    Only dimensions where ``m`` is non-zero contribute. For an adjacency row
    this means only the node's neighbors are compared against the center.
    ``lamda`` is the exponent p; the default 1 yields a Manhattan-type sum.
    The result is ``(sum |m_k - c_k| ** p) ** (1 / p)`` over those dimensions.
    """
    p = float(lamda)
    total = sum((abs(x - y) ** p for x, y in zip(m, c) if x != 0), 0.0)
    return total ** (1.0 / p)


# ----------------------------------------------------------------------------------------------------
def random_initialization(g, m, num_community):
    """Choose ``num_community`` distinct row indices uniformly as initial centers.

    Indices are drawn from ``range(g.number_of_nodes())`` without
    replacement. Each chosen row of ``m`` is copied into a new center record.
    The chosen indices are printed.
    """
    candidates = range(g.number_of_nodes())
    chosen = random.sample(candidates, num_community)
    centers = [add_center(m[idx]) for idx in chosen]

    print("[Msg][Init] Pick nodes as initial centers: {0}".format(chosen))

    return centers


def degree_initialization(g, m, num_community):
    """Choose ``num_community`` distinct initial centers, weighted by node degree.

    Indices are drawn one at a time via ``binary_pick`` over the node
    degrees. A draw that repeats an already-chosen index is discarded and
    redrawn. The chosen indices are printed, and the corresponding rows of
    ``m`` become the center records.

    The position in ``[g.degree(i) for i in g]`` is used as the row index
    into ``m``. This presumes nodes iterate as 0..n-1 (TODO confirm for
    non-Karate graphs).

    Raises:
        ValueError: if more centers are requested than there are nodes; the
            redraw loop could otherwise never finish.
    """
    degree_list = [g.degree(i) for i in g]

    if num_community > len(degree_list):
        raise ValueError("num_community ({0}) exceeds the number of nodes ({1})"
                         .format(num_community, len(degree_list)))

    picked_node = []

    # NOTE(review): if fewer nodes than num_community have non-zero degree,
    # this loop can still spin, depending on how binary_pick treats zero weights.
    while len(picked_node) < num_community:
        # hand over a fresh copy: binary_pick may overwrite its argument with
        # a running sum, which would skew every subsequent draw
        i = binary_pick(list(degree_list))

        if i not in picked_node:
            picked_node.append(i)

    print("[Msg][Init] Pick nodes as initial centers: {0}".format(picked_node))

    return [add_center(m[i]) for i in picked_node]


def find_closest_center(m, center_list):
    """Assign every row of ``m`` to its nearest center.

    Distance is ``distance_to_center`` with its default exponent. On ties the
    earlier center wins. Side effect: each winning center's ``num_members``
    counter is incremented, and ``update_cluster_center`` later divides by it.

    Returns:
        dict mapping row index -> 1-based center id.
    """
    assignment = {}

    for row_idx, row in enumerate(m):
        best_j, best_d = None, sys.maxsize

        for j, center in enumerate(center_list):
            d = distance_to_center(row, center['dimensions'])
            if d < best_d:
                best_j, best_d = j, d

        center_list[best_j]['num_members'] += 1
        assignment[row_idx] = best_j + 1

    return assignment


def update_cluster_center(m, cluster_result, center_list):
    """Move every center to the mean of the rows currently assigned to it.

    ``cluster_result`` maps a row index of ``m`` to a 1-based center id. The
    divisor is each center's ``num_members`` as counted during the preceding
    assignment step. A center with zero members is reset to all zeros, and
    member counters are cleared afterwards. ``center_list`` is modified in
    place and also returned.
    """
    # per-center running sums, one zero vector per center
    totals = [[0.0] * len(center['dimensions']) for center in center_list]

    for row_index, cluster_id in cluster_result.items():
        acc = totals[cluster_id - 1]
        row = m[row_index]
        for k in range(len(acc)):
            acc[k] += row[k]

    for center, acc in zip(center_list, totals):
        count = center['num_members']
        if count != 0:
            center['dimensions'] = [float(v / count) for v in acc]
        else:
            center['dimensions'] = [0.0 for _ in acc]
        center['num_members'] = 0

    return center_list


def calculate_square_error(m, cluster_result, center_list):
    """Return the total squared difference between each row and its assigned center.

    ``cluster_result`` maps a row index of ``m`` to a 1-based index into
    ``center_list``. The sum runs over every assigned row and every paired
    dimension.
    """
    return sum(
        abs(x - y) ** 2
        for row_index, cluster_id in cluster_result.items()
        for x, y in zip(m[row_index], center_list[cluster_id - 1]['dimensions'])
    )


# ----------------------------------------------------------------------------------------------------
def k_means_community_detection(g=None, m=None, num_community=2, num_iteration=100, initial_type="random"):
    """Cluster the rows of adjacency matrix ``m`` into communities with K-means.

    Args:
        g: graph used only to choose the initial centers.
        m: adjacency matrix (list of rows); each row is one data point.
        num_community: number of clusters (k).
        num_iteration: maximum number of assign/update rounds.
        initial_type: "random" for uniform initial centers; any other value
            selects degree-weighted initialization.

    Returns:
        ``(cluster_result, center_list)``. ``cluster_result`` maps a row index
        to a 1-based cluster id. ``center_list`` holds the centers recomputed
        from the final assignment.

    Exits the process with status 1 when ``g`` or ``m`` is missing or empty.
    """
    if not (g and m):
        print("[ERR][Msg] No data for community detection.")
        # non-zero status: this is a failure, not a normal termination
        sys.exit(1)

    # 1. initialize k centers
    if initial_type == "random":
        print("[Msg][Init] Randomly initialize community centers.")
        center_list = random_initialization(g, m, num_community)
    else:
        print("[Msg][Init] Initialize community centers weighted by node degree.")
        center_list = degree_initialization(g, m, num_community)

    cluster_result = {}
    square_error = sys.maxsize

    for i in range(num_iteration):
        old_square_error = square_error

        # 2. assign each row to its closest center (also counts members per center)
        cluster_result = find_closest_center(m, center_list)

        # 3. move each center to the mean of its members
        center_list = update_cluster_center(m, cluster_result, center_list)

        square_error = calculate_square_error(m, cluster_result, center_list)
        print("[Msg][Loop] Iteration {0}, square-error: {1:.2f}".format(i + 1, square_error))

        # stop once the square error no longer changes between rounds
        if square_error == old_square_error:
            print("[Msg][Stop] Stop community detection, centers are stable!")
            return cluster_result, center_list

    # stop by reaching maximal iterations
    print("[Msg] STOP community detection, reach maximal iterations!")
    return cluster_result, center_list


def save_result(file_name="community-detctionk-results.txt", network_info=None, cluster_result=None, center_list=None):
    """Write centers, per-node cluster labels and the NMI score to a text file.

    Output sections, in order:
    - "Centers": one line per center with its dimension vector.
    - "Clusters": ``<row index + 1>: <cluster id>`` in ascending row order.
    - NMI between ``network_info['ground-truth']`` and the detected labels.

    Exits the process with status 1 when any input is missing or empty.
    """
    if not (network_info and cluster_result and center_list):
        print("[ERR][Msg] No data to save.")
        # non-zero status: nothing was written
        sys.exit(1)

    # labels in ascending row-index order, so they pair position-by-position
    # with the ground truth (assumed to be ordered by node index -- TODO confirm)
    ordered_nodes = sorted(cluster_result)
    identified_cluster = [cluster_result[i] for i in ordered_nodes]
    # read before opening the file so a missing key does not leave a partial report
    true_cluster = network_info['ground-truth']

    with open(file_name, "w", encoding="utf-8") as f:
        # centers' info
        f.write("Centers\n")
        f.write("--------------------------------------------------\n")

        for idx, c in enumerate(center_list, start=1):
            f.write("center-{0}: {1}\n".format(idx, c['dimensions']))

        # network nodes' cluster (1-based numbering in the report)
        f.write("\n")
        f.write("Clusters\n")
        f.write("--------------------------------------------------\n")

        for i in ordered_nodes:
            f.write("{0}: {1}\n".format(i + 1, cluster_result[i]))

        # NMI value: community detection accuracy
        f.write("\n")
        f.write("Normalized Mutual Information (NMI) as accuracy\n")
        f.write("--------------------------------------------------\n")
        f.write("{0}\n".format(nmi(true_cluster, identified_cluster)))


def visualize_result(file_name="community-detctionk-results.png", network_info=None, cluster_result=None, center_list=None, image_show=False, image_save=False):
    """Draw the network with nodes colored by their detected community.

    Reads the graph from ``network_info['graph']``, node positions from
    ``network_info['axis-pos']`` and axis limits from
    ``network_info['axis-display']``. The figure is saved to ``file_name``
    when ``image_save`` is true and shown when ``image_show`` is true. It is
    always closed afterwards. If any input is missing or empty, prints an
    error and returns without drawing.

    Args:
        file_name: output image path, used only when ``image_save`` is true.
        network_info: dict with 'graph', 'axis-pos' and 'axis-display' keys.
        cluster_result: dict mapping row index (used as node id) -> 1-based cluster id.
        center_list: list of centers; only its length is used (to size the palette).
        image_show: call ``plt.show()`` when true.
        image_save: save the figure when true.
    """
    if not (network_info and cluster_result and center_list):
        print("[ERR][Msg] No data to visualize.")
        return

    # plot parameters
    plot_x_size = 6
    plot_y_size = 6
    plot_dpi = 300

    # group node ids by cluster id; row index == node id presumes 0..n-1 node labels
    cluster_dict = {}
    for node, cluster_id in cluster_result.items():
        cluster_dict.setdefault(cluster_id, []).append(node)

    # draw data figure
    g = network_info['graph']
    pos = network_info['axis-pos']
    # repeating palette; 'orange' takes the slot white ('w') had, since white
    # nodes are hard to see on the white background
    color_list = ['r', 'b', 'g', 'c', 'm', 'y', 'k', 'orange'] * len(center_list)
    fig, axes = plt.subplots(figsize=(plot_x_size, plot_y_size), facecolor='w')

    for c in sorted(cluster_dict):
        nx.draw_networkx_nodes(g, pos=pos, nodelist=cluster_dict[c], node_size=250,
                               node_color=color_list[c - 1], ax=axes)

    nx.draw_networkx_edges(g, pos=pos, edgelist=g.edges(), ax=axes)
    # draw_networkx_labels has no `label` keyword; with labels=None it labels
    # every node with its own id
    nx.draw_networkx_labels(g, pos=pos, ax=axes)

    # plot setting
    axes.axis(network_info['axis-display'])
    axes.xaxis.set_visible(False)
    axes.yaxis.set_visible(False)

    # save and show figure
    plt.tight_layout()

    if image_save:
        plt.savefig(file_name, dpi=plot_dpi, bbox_inches='tight', pad_inches=0.05)

    if image_show:
        plt.show()

    plt.close(fig)


# ----------------------------------------------------------------------------------------------------
def community_detection(num_community=2):
    """Run K-means community detection on the Karate network end to end.

    Steps:
    1. Load ``karate.pickle`` from the working directory.
    2. Build the adjacency matrix of ``network_info['graph']``.
    3. Cluster it with degree-weighted initialization.
    4. Print the NMI against ``network_info['ground-truth']``.
    5. Write the text report and the figure.

    Args:
        num_community: number of communities (k) to detect.
    """
    print(">>> Community detection")
    print()

    # Load the Karate network pickle file.
    # NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
    print("[Msg] Load network information of Karate network.")
    with open('karate.pickle', 'rb') as f:
        network_info = pickle.load(f)

    # create adjacency-matrix
    print("[Msg] Create adjacency-matrix of Karate network.")
    G = network_info['graph']
    M = create_adjacency_matrix(G)

    # k-means for community detection
    print("[Msg] Community detection using K-means.")
    cluster_result, center_list = k_means_community_detection(G, M, num_community, initial_type="degree")

    # NMI: order labels by row index (as save_result does) instead of relying
    # on dict insertion order, so they align with the ground truth
    identified = [cluster_result[i] for i in sorted(cluster_result)]
    ground_truth = network_info['ground-truth']
    print("[Msg][NMI] {0}".format(nmi(ground_truth, identified)))

    # save community detection results
    print("[Msg] Save community detection result.")
    save_result("community-detection-results.txt", network_info, cluster_result, center_list)

    # visualize community detection results
    print("[Msg] Visualize community detection result.")
    visualize_result("community-detctionk-results.png", network_info, cluster_result, center_list, True, True)

    print()
    print(">>> STOP Community detection")

if __name__ == "__main__":
    # script entry point: detect two communities in the Karate network
    community_detection(2)
  • 參考檔案: community-detection-results.txt
Centers
--------------------------------------------------
center-1: [0.875, 0.5, 0.3125, 0.375, 0.1875, 0.25, 0.25, 0.25, 0.125, 0.0625, 0.1875, 0.0625, 0.125, 0.25, 0.0, 0.0, 0.125, 0.125, 0.0, 0.125, 0.0, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0625, 0.0625, 0.0, 0.0625, 0.0625, 0.0625, 0.125]
center-2: [0.1111111111111111, 0.05555555555555555, 0.2777777777777778, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16666666666666666, 0.05555555555555555, 0.0, 0.0, 0.0, 0.05555555555555555, 0.1111111111111111, 0.1111111111111111, 0.0, 0.0, 0.1111111111111111, 0.05555555555555555, 0.1111111111111111, 0.0, 0.1111111111111111, 0.2777777777777778, 0.16666666666666666, 0.16666666666666666, 0.1111111111111111, 0.16666666666666666, 0.1111111111111111, 0.2222222222222222, 0.16666666666666666, 0.2777777777777778, 0.6111111111111112, 0.8333333333333334]

Clusters
--------------------------------------------------
1: 1
2: 1
3: 1
4: 1
5: 1
6: 1
7: 1
8: 1
9: 2
10: 2
11: 1
12: 1
13: 1
14: 1
15: 2
16: 2
17: 1
18: 1
19: 2
20: 1
21: 2
22: 1
23: 2
24: 2
25: 2
26: 2
27: 2
28: 2
29: 2
30: 2
31: 2
32: 2
33: 2
34: 2

Normalized Mutual Information (NMI) as accuracy
--------------------------------------------------
1.0

參考資料

results matching ""

    No results matching ""