Source code for graphpro.graph

import networkx as nx
import numpy as np
import torch

from torch_geometric.data import Data
from dataclasses import dataclass
from .model import Target, ProteinMetadata

class Graph:
    """Graph provides a representation of a graph and required helpers.

    Node ``i`` corresponds to ``res_map[i]``; every node carries a ``resid``
    attribute and a reverse index maps residue ids back to node numbers.
    """

    # Equality is content-based on the mutable adjacency matrix, so instances
    # are deliberately unhashable. This is identical to Python's implicit
    # behaviour when __eq__ is overridden without __hash__.
    __hash__ = None

    def __init__(self, name: str, adjacency: np.ndarray, positions: np.ndarray,
                 res_map: list[dict], metadata: "ProteinMetadata" = None):
        """Build a graph from an adjacency matrix and per-node residue info.

        Args:
            name: Identifier of the graph; also returned by ``repr``.
            adjacency: Adjacency matrix whose non-zero entries are edges
                (presumably square ``(n, n)`` -- TODO confirm with callers).
            positions: Per-node coordinates indexed by node number; ``plot``
                treats each entry as a 3D point.
            res_map: Sequence of dicts, one per node in node order, each with
                at least a ``'resid'`` key.
            metadata: Optional protein metadata attached to the graph.
        """
        self.name = name
        self.adjacency = adjacency
        self.positions = positions
        # node number -> attribute dict, seeded with the residue id
        self._n_attr = {i: {"resid": res_attr['resid']}
                        for i, res_attr in enumerate(res_map)}
        # residue id -> node number (reverse lookup)
        self._resid_to_node = {res_attr['resid']: i
                               for i, res_attr in enumerate(res_map)}
        self.metadata = metadata

    def __eq__(self, other):
        """Return True when *other* is a Graph with an identical adjacency.

        Only the adjacency matrices are compared (same shape and all entries
        equal); names, positions, node attributes and metadata are ignored.
        Comparing against a non-Graph (e.g. ``None``) is delegated back to
        Python, which falls back to identity and yields False.
        """
        # TODO: may need to compare more than the adjacency
        if not isinstance(other, Graph):
            return NotImplemented
        return np.array_equal(self.adjacency, other.adjacency)

    def node_attr(self, node_id: int):
        """Return the attribute dict of *node_id*, or None if it is unknown."""
        return self._n_attr.get(node_id)

    def node_attr_add(self, node_id: int, attribute_name: str, attribute: object):
        """Add (or overwrite) an attribute on a node in the graph.

        Unknown node ids are silently ignored.
        """
        attrs = self.node_attr(node_id)
        if attrs is not None:
            attrs[attribute_name] = attribute

    def get_node_by_resid(self, resid: int) -> int:
        """Return the node number for a residue id, or None if it is unknown."""
        return self._resid_to_node.get(resid)

    def communities(self) -> list[tuple[float, tuple[set, ...]]]:
        """Run Girvan-Newman community detection down to the last edge.

        Returns:
            One ``(modularity, partition)`` pair per level of the
            Girvan-Newman hierarchy, where ``partition`` is the tuple of node
            sets produced at that level. Edges keep being removed until none
            are left.
        """
        from networkx.algorithms import community

        G = self.to_networkx()
        # girvan_newman works on its own copy of G, so a single conversion
        # can serve every modularity evaluation.
        return [(community.modularity(G, partition), partition)
                for partition in community.girvan_newman(G)]

    def to_networkx(self) -> "nx.Graph":
        """Return an undirected networkx graph with node attributes populated."""
        G = nx.from_numpy_array(self.adjacency)
        nx.set_node_attributes(G, self._n_attr)
        return G

    def to_data(self, node_encoders=None, target: "Target" = None) -> "Data":
        """Return a PyG ``Data`` object built from this graph.

        Args:
            node_encoders: Iterable of objects exposing ``encode(graph)``.
                Their outputs are concatenated along dim 1 into ``x``
                (presumably ``(num_nodes, features)`` tensors -- confirm).
                ``None`` or empty leaves ``x`` as None.
            target: Optional object exposing ``encode(graph)`` whose result
                becomes ``y``.

        Returns:
            ``Data(x=..., edge_index=..., y=...)``. The edge index has shape
            ``(2, num_edges)``; adjacency weights are not included.
        """
        # np.nonzero returns indices in row-major order, which is exactly the
        # lexicographic order of a coalesced COO tensor, so no sparse
        # round-trip is needed to obtain a sorted edge_index.
        row, col = np.nonzero(self.adjacency)
        edge_index = torch.tensor(np.array([row, col], dtype=int),
                                  dtype=torch.long)

        # Concat a list of node features into a X tensor
        x = None
        for encoder in node_encoders or ():
            encoded_attr = encoder.encode(self)
            x = encoded_attr if x is None else torch.cat((x, encoded_attr), 1)

        y = target.encode(self) if target is not None else None
        return Data(x=x, edge_index=edge_index, y=y)

    def nodes(self) -> list[int]:
        """Return the list of node numbers (one per known residue id)."""
        return list(self._resid_to_node.values())

    def plot(self, figsize: tuple[int, int] = (8, 10),
             communities: list[set[int]] = None, show: bool = True) -> None:
        """Plot the graph representation in 3D using residue positions.

        Args:
            figsize: Matplotlib figure size.
            communities: Optional node sets; each node is coloured by the
                index of the community containing it. The sets are expected
                to cover every node so the colour list matches the scatter.
            show: Call ``plt.show()`` when True.
        """
        import matplotlib.pyplot as plt

        G = self.to_networkx()
        node_xyz = np.array([self.positions[v] for v in sorted(G)])
        edge_xyz = np.array([(self.positions[u], self.positions[v])
                             for u, v in G.edges()])

        node_colors = None
        if communities:
            # (node, community index) pairs sorted by node -> colour per node
            community_node = sorted((n, i) for i, c in enumerate(communities)
                                    for n in c)
            node_colors = [idx for _, idx in community_node]

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter(*node_xyz.T, s=100, ec="w", c=node_colors)

        # Plot the edges
        for vizedge in edge_xyz:
            ax.plot(*vizedge.T, color="tab:gray")

        # Hide grid and ticks for a cleaner structural view
        ax.grid(False)
        for dim in (ax.xaxis, ax.yaxis, ax.zaxis):
            dim.set_ticks([])
        fig.tight_layout()
        if show:
            plt.show()

    def __repr__(self) -> str:
        return self.name