Hands-on with Segmentation Models
Published: 2025-11-16 | Category: BioAI

1. Introduction to Segmentation Models

Segmentation models built on Transformer backbones (such as Nucleotide Transformer, Enformer, and Borzoi) can predict genomic elements at single-nucleotide resolution. For example, SegmentNT predicts 14 classes of human genomic elements in sequences of up to 30 kb (extensible to 50 kbp) with excellent performance.

All models are paired with a 1D U-Net segmentation head that predicts the location of several types of genomic elements in a sequence at single-nucleotide resolution. These include gene elements (protein-coding genes, lncRNAs, 5'UTRs, 3'UTRs, exons, introns, splice acceptor and donor sites) and regulatory elements (polyA signals, tissue-invariant and tissue-specific promoters and enhancers, and CTCF-bound sites).
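For reference, the 14 feature names are exposed by the package (the same list appears later in this post as config.features and FEATURES); a quick way to print them, assuming nucleotide_transformer is installed:

from nucleotide_transformer.enformer.features import FEATURES

print(len(FEATURES))  # 14
print(FEATURES)       # e.g. 'protein_coding_gene', 'intron', 'splice_acceptor', ...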

[Figure: Performance on downstream tasks — Fig. 1: SegmentNT localizes genomic elements at nucleotide resolution.]

2. How to use 🚀

2.1 Install and load the modules

!pip install boto3
!pip install matplotlib
!pip install biopython
!pip install dm-haiku
import os

try:
    import nucleotide_transformer
except ImportError:
    !pip install git+https://github.com/instadeepai/nucleotide-transformer@main |tail -n 1
    import nucleotide_transformer

if "COLAB_TPU_ADDR" in os.environ:
    from jax.tools import colab_tpu

    colab_tpu.setup_tpu()

2.2 SegmentNT

⚠️ The SegmentNT models use the Nucleotide Transformer (NT) as their backbone and were trained on sequences of 30,000 nucleotides, i.e., 5,001 tokens including the CLS token (30,000 / 6 = 5,000 6-mer tokens, plus CLS). SegmentNT has nevertheless been shown to generalize to sequences of up to 50,000 bp. Because 30,000 bp exceeds the maximum length the Nucleotide Transformer natively handles (2,048 6-mer tokens), YaRN rescaling is used to extend the context.

By default, the rescaling factor is set to the value used during training. To run inference on sequences between 30 kbp and 50 kbp, pass the rescaling_factor argument to the get_pretrained_segment_nt_model function, computed as:

rescaling_factor = num_dna_tokens_inference / max_num_tokens_nt

where num_dna_tokens_inference is the total number of tokens at inference time (e.g., 6,669 tokens for a 40,008 bp sequence) and max_num_tokens_nt is the maximum number of tokens the backbone Nucleotide Transformer was trained on, i.e., 2,048.
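As a worked example (a small sketch using the 40,008 bp case mentioned above):

# 40,008 bp tokenize into 40,008 / 6 = 6,668 6-mer tokens, plus the CLS token.
num_dna_tokens_inference = 40_008 // 6 + 1   # 6669 tokens
max_num_tokens_nt = 2048                     # backbone NT training length
rescaling_factor = num_dna_tokens_inference / max_num_tokens_nt
print(rescaling_factor)  # ~3.26; pass this to get_pretrained_segment_nt_model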

🚧 The SegmentNT models do not accept any "N" bases in input sequences, because every nucleotide must be tokenized as part of a 6-mer, which is impossible for sequences containing one or more "N" bases.
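A minimal pre-check (not part of the original tutorial) to reject such sequences before tokenization might look like this:

def check_no_n(sequence: str) -> None:
    # SegmentNT tokenizes nucleotides as 6-mers, which fails on 'N' bases.
    if "N" in sequence.upper():
        raise ValueError("SegmentNT input sequences must not contain 'N' bases.")

check_no_n("ATTCCGATTCCG")  # passes silently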

2.2.1 🔍 Example code

The code below uses the pretrained segment_nt model to predict genomic features of DNA sequences (e.g., whether a position is part of an intron), covering the full workflow from model loading through data preprocessing to parallel inference. Specifically:

  • Environment setup: make JAX default to CPU devices to avoid memory leakage, and prepare the multi-device environment.
  • Model loading: load the pretrained segment_nt model (parameters, forward function, tokenizer, and config) and adapt it for multi-device computation via Haiku and JAX's pmap.
  • Data processing: convert the input DNA sequences into token ids with the tokenizer, then into JAX arrays.
  • Parallel inference: run the model on all CPU devices in parallel, obtain the predicted probabilities of each genomic feature, and extract the intron probabilities in particular.
# Import the required libraries: haiku for building neural networks, jax for
# high-performance numerical computing, and jax.numpy as jax's numpy interface.
# get_pretrained_segment_nt_model from nucleotide_transformer.pretrained loads
# the pretrained segmentation model.
import haiku as hk
import jax
import jax.numpy as jnp
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model

# Make CPU the default JAX device. This makes the code robust to memory
# leakage on other devices (e.g., GPU).
jax.config.update("jax_platform_name", "cpu")

# Select the CPU backend, collect the available devices, and count them.
backend = "cpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

# The number of DNA tokens (excluding the prepended CLS token) needs to be
# divisible by 2 to the power of the number of downsampling blocks, i.e., 4.
max_num_nucleotides = 8

assert max_num_nucleotides % 4 == 0, (
    "The number of DNA tokens (excluding the CLS token prepended) needs to be "
    "divisible by 2 to the power of the number of downsampling blocks, i.e., 4.")

# Load the pretrained segment_nt model: parameters, forward function, tokenizer,
# and config, where:
# - model_name selects the "segment_nt" checkpoint
# - embeddings_layers_to_save saves the embeddings of layer 29
# - attention_maps_to_save saves the attention maps of (layer 1, head 4) and (layer 7, head 10)
# - max_positions is the maximum number of positions (including the CLS token, hence max_num_nucleotides + 1)
parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
)
# Transform the forward function into haiku's functional form.
forward_fn = hk.transform(forward_fn)
# Parallel-map the forward pass over devices with jax.pmap;
# donate_argnums=(0,) lets JAX reuse the memory of the parameters argument.
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))


# Prepare the input data and tokenize it.
sequences = ["ATTCCGATTCCGATTCCAACGGATTATTCCGATTAACCGATTCCAATT", "ATTTCTCTCTCTCTCTGAGATCGATGATTTCTCTCTCATCGAACTATG"]  # two example DNA sequences
# Batch-tokenize the sequences, keeping the token ids (second element of each tuple).
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
# Convert the token ids to a jax int32 array as model input.
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)

# Initialize a random key and replicate it across devices
# (for consistency in multi-device parallelism).
random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
# Replicate the model parameters across devices.
parameters = jax.device_put_replicated(parameters, devices=devices)
# Replicate the input tokens across devices.
tokens = jax.device_put_replicated(tokens, devices=devices)

# Run inference on the input sequences.
outs = apply_fn(parameters, keys, tokens)
# Retrieve the logits over the genomic features.
logits = outs["logits"]
# Turn the logits into probabilities with softmax, keeping the positive-class
# channel (the last element of the final, binary dimension).
probabilities = jnp.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]
print(f"Probabilities shape: {probabilities.shape}")

# Print the genomic features the model predicts (e.g., intron, exon).
print(f"Features inferred: {config.features}")

# Get the index of the "intron" feature in the feature list.
idx_intron = config.features.index("intron")
# Extract the intron probabilities.
probabilities_intron = probabilities[..., idx_intron]
print(f"Intron probabilities shape: {probabilities_intron.shape}")

Supported model names (the sketch after this list shows loading the second variant):

  • segment_nt
  • segment_nt_multi_species
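A hedged sketch of loading the multi-species variant: only model_name changes here; the other arguments are carried over from the example above and may be tuned as needed.

parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt_multi_species",
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
)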

2.2.2 Full code

The code below shows how to run inference on 10 kb and 50 kb sequences, and how to plot the probability tracks to reproduce Fig. 1d from the paper.

from Bio import SeqIO
import gzip
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import seaborn as sns
from typing import List
import matplotlib.pyplot as plt
from nucleotide_transformer.pretrained import get_pretrained_segment_nt_model

# Specify "cpu" as default (but you can decide to use GPU or TPU in the next cell)
jax.config.update("jax_platform_name", "cpu")

(1) Choose the backend device

# Use either "cpu", "gpu" or "tpu"
backend = "cpu"
devices = jax.devices(backend)
num_devices = len(devices)
print(f"Devices found: {devices}")

Devices found: [CpuDevice(id=0)]

(2) Define the function used to plot the probabilities

# seaborn settings
sns.set_style("whitegrid")
sns.set_context(
    "notebook",
    font_scale=1,
    rc={
        "font.size": 14,
        "axes.titlesize": 18,
        "axes.labelsize": 18,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "legend.fontsize": 16,
    }
)

plt.rcParams['xtick.bottom'] = True
plt.rcParams['ytick.left'] = True

# set colors
colors = sns.color_palette("Set2").as_hex()
colors2 = sns.color_palette("husl").as_hex()

# Rearrange order of the features to match Fig.3 from the paper.
features_rearranged = [
    'protein_coding_gene',
    'lncRNA',
    '5UTR',
    '3UTR',
    'exon',
    'intron',
    'splice_donor',
    'splice_acceptor',
    'promoter_Tissue_specific',
    'promoter_Tissue_invariant',
    'enhancer_Tissue_specific',
    'enhancer_Tissue_invariant',
    'CTCF-bound',
    'polyA_signal',
]

def plot_features(
    predicted_probabilities_all,
    seq_length: int,
    features: List[str],
    order_to_plot: List[str],
    fig_width=8,
):
    """
    Function to plot labels and predicted probabilities.

    Args:
        predicted_probabilities_all: Probabilities per genomic feature for each
            nucleotide in the DNA sequence.
        seq_length: DNA sequence length.
        features: Genomic features to plot.
        order_to_plot: Order in which to plot the genomic features. This needs to
            be specified in order to match the order presented in Fig.3 of the paper.
        fig_width: Width of the figure.
    """

    sc = 1.8
    n_panels = 7

    _, axes = plt.subplots(n_panels, 1, figsize=(fig_width * sc, (n_panels + 4) * sc))

    for n, feat in enumerate(order_to_plot):
        feat_id = features.index(feat)
        prob_dist = predicted_probabilities_all[:, feat_id]

        # Use the appropriate subplot (two features per panel)
        ax = axes[n // 2]

        # Set2 has 8 colors; fall back to the second palette beyond that.
        try:
            id_color = colors[feat_id]
        except IndexError:
            id_color = colors2[feat_id - 8]
        ax.plot(
            prob_dist,
            color=id_color,
            label=feat,
            linestyle="-",
            linewidth=1.5,
        )
        ax.set_xlim(0, seq_length)
        ax.grid(False)
        ax.spines['bottom'].set_color('black')
        ax.spines['top'].set_color('black')
        ax.spines['right'].set_color('black')
        ax.spines['left'].set_color('black')

    for a in range(0, n_panels):
        axes[a].set_ylim(0, 1.05)
        axes[a].set_ylabel("Prob.")
        axes[a].legend(loc="upper left", bbox_to_anchor=(1, 1), borderaxespad=0)
        if a != (n_panels - 1):
            axes[a].tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=False)

    # Set common x-axis label
    axes[-1].set_xlabel("Nucleotides")
    axes[n_panels - 1].grid(False)
    axes[n_panels - 1].tick_params(axis='y', which='both', left=True, right=False, labelleft=True, labelright=False)

    axes[0].set_title("Probabilities predicted over all genomic features", fontweight="bold")

    plt.show()

(3) Retrieve the human chromosome 20 sequence

To reproduce the figures from the Segment-NT paper, we download the human chromosome 20 FASTA file here.

! wget https://ftp.ensembl.org/pub/release-111/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz
2024-07-23 09:38:35 (1,75 MB/s) - 'Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz' saved [18833053/18833053]
fasta_path = "Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz"

with gzip.open(fasta_path, "rt") as handle:
    record = next(SeqIO.parse(handle, "fasta"))
    chr20 = str(record.seq)
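Optionally (a small sketch, not from the original notebook), you can sanity-check the loaded sequence and confirm that the 10 kb window used below contains no "N" bases, which SegmentNT cannot tokenize:

print(len(chr20))                            # chromosome 20 length in bp
window = chr20[2650520:2650520 + 1668 * 6]   # the 10 kb window used below
assert "N" not in window.upper(), "window contains 'N' bases"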

(4) Inference on a 10 kb genomic sequence (no change to the rescaling factor of the forward function needed)

① Instantiate the SegmentNT inference function

The code below downloads the weights of one of the Segment-NT models. It returns the parameter dictionary, the haiku forward function, the tokenizer, and a config dictionary.

As with the get_pretrained_nucleotide_transformer function, you can also specify:

  1. the layers whose embeddings you want to collect (e.g., (5, 10, 20) collects the embeddings of layers 5, 10, and 20);
  2. the attention maps you want to collect (e.g., ((1, 4), (7, 18)) collects the attention maps of layer 1/head 4 and layer 7/head 18; refer to the config for the number of layers and heads in the model);
  3. the maximum number of tokens in the sequences you will run inference on. This can be at most the value specified in the model config (counting the class token automatically prepended to the sequence), but to optimize memory and inference time we recommend setting it as small as possible.
# The number of DNA tokens (excluding the CLS token prepended) needs to be
# divisible by 2 to the power of the number of downsampling blocks, i.e., 4.
max_num_nucleotides = 1668

assert max_num_nucleotides % 4 == 0, (
    "The number of DNA tokens (excluding the CLS token prepended) needs to be "
    "divisible by 2 to the power of the number of downsampling blocks, i.e., 4.")

# If max_num_nucleotides is larger than what was used to train Segment-NT, the rescaling
# factor needs to be adapted.
if max_num_nucleotides + 1 > 5001:
    inference_rescaling_factor = (max_num_nucleotides + 1) / 2048
else:
    inference_rescaling_factor = None

# If this download fails at one point, restarting it will not work.
# Before rerunning the cell, make sure to delete the cache by executing:
# ! rm -rf ~/.cache/nucleotide_transformer/
parameters, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    rescaling_factor=inference_rescaling_factor,
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
)
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

# Put required quantities for the inference on the devices. This step is not
# reproduced in the second inference since the quantities will already be loaded
# on the devices!
random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)
Downloading model's hyperparameters json file...
Downloaded model's hyperparameters.
Downloading model's weights...
Downloaded model's weights...

② Tokenize the DNA sequence

idx_start = 2650520
idx_stop = idx_start + max_num_nucleotides * 6

sequences = [chr20[idx_start:idx_stop]]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]

# This stacks the batch so that it is repeated across the devices. This is done
# in order to allow for replication even if one has more than one device.
# To take advantage of the multiple devices and infer different sequences on
# each of the devices, make sure to change this line into a reshape.
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0)
tokens.shape

(1, 1, 1669)

The 1,669 tokens are the 1,668 6-mer DNA tokens plus the prepended CLS token.

③ Run inference on the batch

# Infer
outs = apply_fn(parameters, keys, tokens)

# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them into probabilities, keeping the positive-class channel
probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]
(At this point JAX emits a lengthy UserWarning that some donated buffers were not usable; see https://jax.readthedocs.io/en/latest/faq.html#buffer-donation for an explanation.)
probabilities.shape

(1, 1, 10008, 14)

The 1,668 DNA tokens map back to 1,668 × 6 = 10,008 nucleotide positions, with one probability per genomic feature.

④ Plot the probabilities of the 14 genomic features over this DNA sequence

Note that Fig. 1 in the SegmentNT paper was produced with SegmentNT-10kb, whereas SegmentNT-30kb is used here, which explains why the probabilities are not exactly identical.

plot_features(
    probabilities[0, 0],
    probabilities.shape[-2],
    fig_width=20,
    features=config.features,
    order_to_plot=features_rearranged,
)

[Figure: predicted probabilities for the 14 genomic features over the 10 kb sequence]

(5) Inference and plotting on a 50 kb genomic sequence (requires changing the rescaling factor of the forward function)

① Instantiate the SegmentNT inference function

# The number of DNA tokens (excluding the CLS token prepended) needs to be
# divisible by 2 to the power of the number of downsampling blocks, i.e., 4.
max_num_nucleotides = 8332

assert max_num_nucleotides % 4 == 0, (
    "The number of DNA tokens (excluding the CLS token prepended) needs to be "
    "divisible by 2 to the power of the number of downsampling blocks, i.e., 4.")

# If max_num_nucleotides is larger than what was used to train Segment-NT, the rescaling
# factor needs to be adapted.
if max_num_nucleotides + 1 > 5001:
    inference_rescaling_factor = (max_num_nucleotides + 1) / 2048
else:
    inference_rescaling_factor = None

# The parameters have already been downloaded above, so we do not instantiate
# them again. However, we instantiate a new forward function in which the
# required context-length extension is in effect.
_, forward_fn, tokenizer, config = get_pretrained_segment_nt_model(
    model_name="segment_nt",
    rescaling_factor=inference_rescaling_factor,
    embeddings_layers_to_save=(29,),
    attention_maps_to_save=((1, 4), (7, 10)),
    max_positions=max_num_nucleotides + 1,
)
forward_fn = hk.transform(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

② Tokenize the DNA sequence

idx_start = 5099984
idx_stop = idx_start + max_num_nucleotides * 6

sequences = [chr20[idx_start:idx_stop]]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]

# Here the batch dimension is added with None-indexing rather than by stacking
# across devices; to infer different sequences on multiple devices, reshape
# accordingly.
tokens = jnp.asarray(tokens_ids, dtype=jnp.int32)[None, :]
tokens.shape, idx_stop

((1, 1, 8333), 5149976)

③ Run inference on the batch

# Infer
outs = apply_fn(parameters, keys, tokens)

# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them into probabilities, keeping the positive-class channel
probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]

④ Plot the probabilities of the 14 genomic features over this DNA sequence

plot_features(
    probabilities[0, 0],
    probabilities.shape[-2],
    fig_width=20,
    features=config.features,
    order_to_plot=features_rearranged,
)

[Figure: predicted probabilities for the 14 genomic features over the 50 kb sequence]

2.3 SegmentEnformer

SegmentEnformer builds on Enformer: the prediction head is removed and replaced with a 1D U-Net segmentation head that predicts the location of several types of genomic elements in a sequence at single-nucleotide resolution.

2.3.1 🔍 Example

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from nucleotide_transformer.enformer.pretrained import get_pretrained_segment_enformer_model
from nucleotide_transformer.enformer.features import FEATURES

# Initialize CPU as default JAX device. This makes the code robust to memory leakage on
# the devices.
jax.config.update("jax_platform_name", "cpu")

backend = "cpu"
devices = jax.devices(backend)
num_devices = len(devices)

# Load model
parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_enformer_model()
forward_fn = hk.transform_with_state(forward_fn)

apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))
random_key = jax.random.PRNGKey(seed=0)

# Replicate over devices
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)
state = jax.device_put_replicated(state, devices=devices)

# Get data and tokenize it
sequences = ["A" * 196_608]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0)

# Infer
outs, state = apply_fn(parameters, state, keys, tokens)

# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them on probabilities
probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]

# Get probabilities associated with intron
idx_intron = FEATURES.index("intron")
probabilities_intron = probabilities[..., idx_intron]
print(f"Intron probabilities shape: {probabilities_intron.shape}")

2.3.2 Full code

The code below shows how to run inference on a 196,608 bp sequence and plot the probabilities. Module imports, genome download, and the plotting function are the same as above and are not repeated. Note that Enformer operates at single-nucleotide input resolution, so a 196,608 bp sequence maps to 196,608 tokens with no CLS token prepended.

(1) Instantiate the SegmentEnformer inference function

The code below downloads the SegmentEnformer weights. It returns the parameter dictionary, the haiku forward function, the tokenizer, and a config dictionary.

parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_enformer_model()
Downloading model's weights...
forward_fn = hk.transform_with_state(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

# Put required quantities for the inference on the devices. This step is not
# reproduced in the second inference since the quantities will already be loaded
# on the devices!
random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)
state = jax.device_put_replicated(state, devices=devices)

(2) Tokenize the DNA sequence

num_nucleotides = 196_608
idx_start = 2650520

idx_stop = idx_start + num_nucleotides

sequences = [chr20[idx_start:idx_stop]]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]

# This stacks the batch so that it is repeated across the devices. This is done
# in order to allow for replication even if one has more than one device.
# To take advantage of the multiple devices and infer different sequences on
# each of the devices, make sure to change this line into a reshape.
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0)
tokens.shape

(1, 1, 196608)

(3) Run inference on the batch

# Infer
outs, state = apply_fn(parameters, state, keys, tokens)

# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them into probabilities, keeping the positive-class channel
probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]
probabilities.shape

(1, 1, 196608, 14)

(4) Plot the probabilities of the 14 genomic features over this DNA sequence

plot_features(
    probabilities[0, 0],
    probabilities.shape[-2],
    fig_width=20,
    features=FEATURES,
    order_to_plot=features_rearranged,
)

[Figure: predicted probabilities for the 14 genomic features over the 196,608 bp sequence]

2.4 SegmentBorzoi

SegmentBorzoi builds on Borzoi: the prediction head is removed and replaced with a 1D U-Net segmentation head that predicts the location of several types of genomic elements in a sequence.

2.4.1 🔍 Example

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

from nucleotide_transformer.borzoi.pretrained import get_pretrained_segment_borzoi_model
from nucleotide_transformer.enformer.features import FEATURES

# Initialize CPU as default JAX device. This makes the code robust to memory leakage on
# the devices.
jax.config.update("jax_platform_name", "cpu")

backend = "cpu"
devices = jax.devices(backend)
num_devices = len(devices)

# Load model
parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_borzoi_model()
forward_fn = hk.transform_with_state(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))
random_key = jax.random.PRNGKey(seed=0)

# Replicate over devices
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)
state = jax.device_put_replicated(state, devices=devices)

# Get data and tokenize it (a dummy 524,288 bp sequence of A's)
sequences = ["A" * 524_288]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0)

# Infer
outs, state = apply_fn(parameters, state, keys, tokens)

# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them into probabilities, keeping the positive-class channel
probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]

# Get probabilities associated with intron
idx_intron = FEATURES.index("intron")
probabilities_intron = probabilities[..., idx_intron]
print(f"Intron probabilities shape: {probabilities_intron.shape}")

2.4.2 Full code

The code below shows how to run inference on a 524,288 bp input sequence and plot the probabilities; note that SegmentBorzoi returns predictions over 196,608 positions, matching Borzoi's central output window rather than the full input length. Module imports, genome download, and the plotting function are the same as above and are not repeated.
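A small sketch of that input/output relationship (the symmetric crop is an assumption based on Borzoi's central output window and the shapes printed below):

input_len = 524_288        # nucleotides fed to SegmentBorzoi
output_len = 196_608       # positions covered by the returned probabilities
crop_per_side = (input_len - output_len) // 2
print(crop_per_side)       # 163,840 bp cropped on each side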

(1) Instantiate the SegmentBorzoi inference function

The code below downloads the SegmentBorzoi weights. It returns the parameter dictionary, the haiku forward function, the tokenizer, and a config dictionary.

parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_borzoi_model()
Downloading model's weights...
forward_fn = hk.transform_with_state(forward_fn)
apply_fn = jax.pmap(forward_fn.apply, devices=devices, donate_argnums=(0,))

# Put required quantities for the inference on the devices. This step is not
# reproduced in the second inference since the quantities will already be loaded
# on the devices!
random_key = jax.random.PRNGKey(seed=0)
keys = jax.device_put_replicated(random_key, devices=devices)
parameters = jax.device_put_replicated(parameters, devices=devices)
state = jax.device_put_replicated(state, devices=devices)

(2) Tokenize the DNA sequence

num_nucleotides = 524_288

idx_start = 2650520
idx_stop = idx_start + num_nucleotides

sequences = [chr20[idx_start:idx_stop]]
tokens_ids = [b[1] for b in tokenizer.batch_tokenize(sequences)]
tokens_str = [b[0] for b in tokenizer.batch_tokenize(sequences)]

# This stacks the batch so that it is repeated across the devices. This is done
# in order to allow for replication even if one has more than one device.
# To take advantage of the multiple devices and infer different sequences on
# each of the devices, make sure to change this line into a reshape.
tokens = jnp.stack([jnp.asarray(tokens_ids, dtype=jnp.int32)] * num_devices, axis=0)
tokens.shape

(1, 1, 524288)

(3) Run inference on the batch

# Infer
outs, state = apply_fn(parameters, state, keys, tokens)

# Obtain the logits over the genomic features
logits = outs["logits"]
# Transform them into probabilities, keeping the positive-class channel
probabilities = np.asarray(jax.nn.softmax(logits, axis=-1))[..., -1]
probabilities.shape

(1, 1, 196608, 14)

As noted above, the predictions cover 196,608 positions of the 524,288 bp input.

(4) Plot the probabilities of the 14 genomic features over this DNA sequence

plot_features(
    probabilities[0, 0],
    probabilities.shape[-2],
    fig_width=20,
    features=FEATURES,
    order_to_plot=features_rearranged,
)

[Figure: predicted probabilities for the 14 genomic features over the central 196,608 bp of the sequence]

