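1. Introduction to segmentation models
Segmentation models built on Transformer-based backbones (such as Nucleotide Transformer, Enformer, and Borzoi) can predict genomic elements at single-nucleotide resolution. For example, SegmentNT predicts 14 classes of human genomic elements in sequences of up to 30 kb (and generalizes to 50 kbp) with strong reported performance.
All models pair the backbone with a 1D U-Net segmentation head that predicts, at single-nucleotide resolution, where several types of genomic elements lie in a sequence. These include gene elements (protein-coding genes, lncRNAs, 5' UTRs, 3' UTRs, exons, introns, splice acceptor sites, and splice donor sites) and regulatory elements (polyA signals, tissue-invariant and tissue-specific promoters and enhancers, and CTCF-bound sites).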
- 📜 Read the Paper (Nature Methods 2025)
- 🤗 SegmentNT Hugging Face Collection
- 🚀 SegmentNT Inference Notebook (HF)
Fig. 1: SegmentNT localizes genomic elements at nucleotide resolution.
2. How to use 🚀
2.1 Install and load the modules
!pip install boto3
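import os
2.2 SegmentNT
⚠️ SegmentNT models use the Nucleotide Transformer (NT) as their backbone and were trained on sequences of 30,000 nucleotides, i.e. 5,001 tokens including the CLS token. SegmentNT has nevertheless been shown to generalize to sequences of up to 50,000 bp. Because the 30,000 bp training length exceeds the maximum length the Nucleotide Transformer can handle (2,048 6-mer tokens), YaRN rescaling is used.
By default, the rescaling factor is set to the value used during training. To run inference on sequences between 30 kbp and 50 kbp, pass the rescaling_factor argument to the get_pretrained_segment_nt_model function, computed as:
rescaling_factor = num_dna_tokens_inference / max_num_tokens_nt
Here num_dna_tokens_inference is the total number of tokens at inference (for example, 6,669 tokens for a 40,008 bp sequence: 40,008 / 6 = 6,668 6-mers plus the CLS token), and max_num_tokens_nt is the maximum number of tokens the Nucleotide Transformer backbone was trained on, i.e. 2,048.
🚧 SegmentNT models do not support any "N" base in the input sequence. Every position has to be tokenized as part of a 6-mer, which is not possible for sequences that contain one or more "N" bases.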
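The arithmetic above can be checked with a few lines of plain Python. This is only a sketch: the helper names are hypothetical (they are not part of the nucleotide_transformer package), and for simplicity it assumes the sequence length is a multiple of 6.

```python
# Sketch of the rescaling-factor arithmetic described above.
# Helper names are hypothetical; they are not part of the library.
MAX_NUM_TOKENS_NT = 2048  # backbone training limit quoted in the text
KMER = 6                  # SegmentNT tokenizes DNA as 6-mers


def num_dna_tokens_inference(sequence: str) -> int:
    """Token count at inference: one token per 6-mer plus the CLS token."""
    seq = sequence.upper()
    if "N" in seq:
        raise ValueError("sequence contains 'N', which cannot be tokenized as 6-mers")
    if len(seq) % KMER != 0:
        # Simplifying assumption of this sketch, not a documented library rule.
        raise ValueError("this sketch expects a length that is a multiple of 6")
    return len(seq) // KMER + 1


def rescaling_factor(sequence: str) -> float:
    """rescaling_factor = num_dna_tokens_inference / max_num_tokens_nt."""
    return num_dna_tokens_inference(sequence) / MAX_NUM_TOKENS_NT


example = "ACGTAC" * 6668                      # 40,008 bp, as in the text
print(num_dna_tokens_inference(example))       # 6669
print(round(rescaling_factor(example), 3))     # 3.256
```

The resulting value would then be passed as rescaling_factor to get_pretrained_segment_nt_model, as the text describes.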
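2.2.1 🔍 Example code
The code below uses the pretrained segment_nt model to predict genomic features of DNA sequences (for example, whether positions belong to an intron) and walks through the full flow from model loading and data preprocessing to parallel inference:
- Environment setup: sets the CPU as the default JAX device, which makes the code robust to memory leakage on the devices, and prepares the environment for multi-device parallel computation.
- Model loading: loads the pretrained segment_nt model (parameters, forward function, tokenizer, and config) and adapts it to multi-device computation with Haiku and JAX's parallel interface (pmap).
- Data processing: the tokenizer converts the input DNA sequences into token ids the model understands, which are then converted to a JAX array.
- Parallel inference: runs the model in parallel on the available CPU devices, outputs the predicted probability of each genomic feature along the DNA sequence, and extracts the probabilities for "intron".
# Import the required libraries: haiku builds the neural network, jax provides high-performance numerical computing, and jax.numpy is JAX's NumPy-compatible interface
Supported model names: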
- segment_nt
- segment_nt_multi_species
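To make the data-processing step of the walkthrough more concrete, here is a toy 6-mer tokenizer in plain Python. It is an illustration only: the vocabulary, the "<cls>" token name, and the id assignment are assumptions made for this sketch and will not match the ids produced by the tokenizer that get_pretrained_segment_nt_model returns.

```python
from itertools import product

K = 6
# Hypothetical vocabulary: a CLS token followed by all 4**6 = 4096 6-mers in
# lexicographic order. The real tokenizer's vocabulary and ids differ.
VOCAB = ["<cls>"] + ["".join(p) for p in product("ACGT", repeat=K)]
TOKEN_TO_ID = {token: i for i, token in enumerate(VOCAB)}


def tokenize(sequence: str):
    """Split a DNA sequence into 6-mers, prepend CLS, and map tokens to ids."""
    seq = sequence.upper()
    if len(seq) % K != 0:
        raise ValueError("this sketch expects a length that is a multiple of 6")
    kmers = [seq[i:i + K] for i in range(0, len(seq), K)]
    unknown = [kmer for kmer in kmers if kmer not in TOKEN_TO_ID]
    if unknown:
        # For example any 6-mer containing 'N' ends up here.
        raise ValueError(f"cannot tokenize 6-mers: {unknown}")
    tokens = ["<cls>"] + kmers
    return tokens, [TOKEN_TO_ID[token] for token in tokens]


tokens, ids = tokenize("AAAAAATTTTTT")
print(tokens)  # ['<cls>', 'AAAAAA', 'TTTTTT']
print(ids)     # [0, 1, 4096]
```

A 12-nucleotide input yields three tokens here, which mirrors the "number of 6-mers plus one CLS token" counting used throughout this post.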
2.2.2 Full code
The code below shows how to run inference on 10 kb and 50 kb sequences and how to plot the probabilities to reproduce Fig. 1 of the paper (Fig. 1d).
from Bio import SeqIO
Devices found: [CpuDevice(id=0)]
(1) Specify the backend device
# Use either "cpu", "gpu" or "tpu"
(2) Define the function used to plot the probabilities
# seaborn settings
(3) Get the sequence of human chromosome 20
To reproduce the figures from the SegmentNT paper, we download the file for human chromosome 20.
! wget https://ftp.ensembl.org/pub/release-111/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz
fasta_path = "Homo_sapiens.GRCh38.dna.chromosome.20.fa.gz"
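The notebook imports Bio.SeqIO for this step, but the rest of the cell is not shown here. As a dependency-free alternative, the sketch below reads the first record of a gzipped FASTA file with the standard library only, and counts "N" bases in a window before it is sent to SegmentNT. The function names are hypothetical and this is not the notebook's own code.

```python
import gzip


def read_first_fasta_record(path: str):
    """Return (header, sequence) for the first record of a gzipped FASTA file."""
    header, chunks = None, []
    with gzip.open(path, "rt") as handle:
        for line in handle:
            line = line.strip()
            if line.startswith(">"):
                if header is not None:
                    break  # a second record starts: stop reading
                header = line[1:]
            elif header is not None and line:
                chunks.append(line)
    if header is None:
        raise ValueError("no FASTA record found")
    return header, "".join(chunks)


def count_n(sequence: str, start: int, stop: int) -> int:
    """Number of 'N' bases in sequence[start:stop] (case-insensitive)."""
    return sequence[start:stop].upper().count("N")
```

With the file downloaded above, `header, chrom20 = read_first_fasta_record(fasta_path)` would load the chromosome, and `count_n(chrom20, idx_start, idx_stop) == 0` is a cheap check that a window contains no "N" bases.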
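(4) Inference on a 10 kb genomic sequence (no need to change the rescaling factor in the forward function)
① Instantiate the SegmentNT inference function
The following code downloads the weights of one of the SegmentNT models. It returns the weights dictionary, the Haiku forward function, the tokenizer, and the config dictionary.
As with the get_pretrained_nucleotide_transformer function, you can also specify:
- The layers at which to collect embeddings (for example, (5, 10, 20) collects the embeddings at layers 5, 10, and 20).
- The attention maps to collect (for example, ((1, 4), (7, 18)) collects the attention maps for layer 1, head 4 and layer 7, head 18). Refer to the config for the number of layers and heads in the model.
- The maximum number of tokens in the sequences you will run inference on. You can set it up to the value specified in the model config (counting the class token that is automatically prepended to the sequence), but we recommend keeping it as small as possible to save memory and inference time.
# The number of DNA tokens (excluding the CLS token prepended) needs to be divisible by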
Downloading model's hyperparameters json file...
Downloaded model's hyperparameters.
Downloading model's weights...
Downloaded model's weights...
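The divisibility constraint mentioned in the code comment can be turned into a small sizing helper. The sketch below assumes the required multiple is 4 (two downsampling blocks in the U-Net head, as stated in the upstream README as far as I recall); the comment in this post is cut off, so treat that value, and the helper name, as assumptions.

```python
KMER = 6
DOWNSAMPLING_MULTIPLE = 4  # assumption: 2 ** (number of U-Net downsampling blocks)


def fit_window(target_bp: int, round_up: bool = False) -> dict:
    """Pick a 6-mer token count near target_bp that satisfies the constraint."""
    if round_up:
        tokens = -(-target_bp // KMER)  # ceiling division
        tokens = -(-tokens // DOWNSAMPLING_MULTIPLE) * DOWNSAMPLING_MULTIPLE
    else:
        tokens = target_bp // KMER
        tokens = tokens // DOWNSAMPLING_MULTIPLE * DOWNSAMPLING_MULTIPLE
    return {
        "num_kmer_tokens": tokens,        # excludes the CLS token
        "num_nucleotides": tokens * KMER,
        "max_positions": tokens + 1,      # includes the CLS token
    }


print(fit_window(10_000, round_up=True))
# {'num_kmer_tokens': 1668, 'num_nucleotides': 10008, 'max_positions': 1669}
print(fit_window(50_000))
# {'num_kmer_tokens': 8332, 'num_nucleotides': 49992, 'max_positions': 8333}
```

Both results agree with the token counts printed in this post (1,669 tokens for the 10 kb window and 8,333 for the 50 kb window).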
② Tokenize the DNA sequence
idx_start = 2650520
(1, 1, 1669)
③ Run inference on the resulting batch
# Infer
probabilities.shape
(1, 1, 10008, 14)
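The shape above presumably reads as (devices, batch, nucleotides, features): one probability per nucleotide for each of the 14 features. The sketch below shows, in plain Python, one way such probabilities can be derived from logits. It assumes the model emits two logits per nucleotide and feature (absent, present) and that the probability is the softmax weight of the second one; this follows my recollection of the upstream example and is not confirmed by the truncated cell here. The feature names and logit values are toy data.

```python
import math


def positive_probability(logit_pair):
    """Softmax over (absent, present) logits; returns the 'present' weight."""
    absent, present = logit_pair
    shift = max(absent, present)  # subtract the max for numerical stability
    e_absent = math.exp(absent - shift)
    e_present = math.exp(present - shift)
    return e_present / (e_absent + e_present)


def feature_track(logits, feature_names, feature):
    """Per-nucleotide probabilities for one feature.

    logits[position][feature_index] is an (absent, present) pair.
    """
    idx = feature_names.index(feature)
    return [positive_probability(position[idx]) for position in logits]


# Toy example: 3 nucleotides, 2 features (hypothetical values).
names = ["exon", "intron"]
toy_logits = [
    [(2.0, -2.0), (0.0, 0.0)],
    [(0.0, 0.0), (-1.0, 3.0)],
    [(-3.0, 3.0), (1.0, -1.0)],
]
print([round(p, 3) for p in feature_track(toy_logits, names, "intron")])
```

Looking up the feature index by name, as done here, corresponds to the "intron" extraction step described in the example walkthrough.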
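④ Plot the probabilities of the 14 genomic features along this DNA sequence
Note that Fig. 1 in the SegmentNT paper was produced with SegmentNT-10kb, whereas SegmentNT-30kb is used here, which explains why the probabilities are not exactly the same.
plot_features(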

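(5) Inference and plotting on a 50 kb genomic sequence (the rescaling factor in the forward function must be changed)
① Instantiate the SegmentNT inference function
# The number of DNA tokens (excluding the CLS token prepended) needs to be divisible by
② Tokenize the DNA sequence
idx_start = 5099984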
((1, 1, 8333), 5149976)
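The two numbers printed above can be reproduced by hand: the window end is the start plus six nucleotides per 6-mer token, and the token axis has one extra entry for the CLS token. A minimal sketch, with a hypothetical helper name:

```python
KMER = 6


def window_end_and_tokens(idx_start: int, num_kmer_tokens: int):
    """Return (idx_stop, total tokens including CLS) for a 6-mer window."""
    idx_stop = idx_start + num_kmer_tokens * KMER
    return idx_stop, num_kmer_tokens + 1


print(window_end_and_tokens(5099984, 8332))  # (5149976, 8333)
```

The same arithmetic applied to the 10 kb example (start 2650520, 1,668 6-mer tokens) gives 1,669 tokens, matching the (1, 1, 1669) shape printed earlier.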
③ Run inference on the resulting batch
# Infer
④ Plot the probabilities of the 14 genomic features along this DNA sequence
plot_features(

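2.3 SegmentEnformer
SegmentEnformer builds on Enformer: the prediction head is removed and replaced with a 1D U-Net segmentation head that predicts, at single-nucleotide resolution, where several types of genomic elements lie in a sequence.
2.3.1 🔍 Example
import haiku as hk
2.3.2 Full code
The code below shows how to run inference on a 196,608 bp sequence and plot the probabilities. The module imports, genome download, and plotting function are the same as above and are not repeated.
(1) Instantiate the SegmentEnformer inference function
The following code downloads the SegmentEnformer weights. It returns the weights dictionary, the state, the Haiku forward function, the tokenizer, and the config dictionary.
parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_enformer_model()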
Downloading model's weights...
forward_fn = hk.transform_with_state(forward_fn)
(2) Tokenize the DNA sequence
num_nucleotides = 196_608
(1, 1, 196608)
(3) Run inference on the resulting batch
# Infer
probabilities.shape
(1, 1, 196608, 14)
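Per-nucleotide probabilities like these are often easier to inspect as intervals. The sketch below is not part of the original notebook: it is a simple post-processing idea that thresholds one feature's probability track and merges consecutive positions into half-open intervals. The 0.5 threshold and the function name are arbitrary assumptions.

```python
def call_intervals(probs, threshold=0.5, offset=0):
    """Turn a per-nucleotide probability track into [start, stop) intervals.

    offset shifts the result to genomic coordinates (e.g. the window start).
    """
    intervals = []
    start = None
    for i, p in enumerate(probs):
        if p >= threshold and start is None:
            start = i                                  # an interval opens
        elif p < threshold and start is not None:
            intervals.append((offset + start, offset + i))
            start = None                               # the interval closes
    if start is not None:                              # still open at the end
        intervals.append((offset + start, offset + len(probs)))
    return intervals


print(call_intervals([0.1, 0.9, 0.8, 0.2, 0.7]))        # [(1, 3), (4, 5)]
print(call_intervals([0.1, 0.9, 0.8, 0.2, 0.7], offset=100))
```

A threshold tuned per feature would likely be needed in practice, since the 14 features differ in how confident the predictions are.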
(4) Plot the probabilities of the 14 genomic features along this DNA sequence
plot_features(

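2.4 SegmentBorzoi
SegmentBorzoi builds on Borzoi: the prediction head is removed and replaced with a 1D U-Net segmentation head that predicts where several types of genomic elements lie in a sequence.
2.4.1 🔍 Example
import haiku as hk
2.4.2 Full code
The code below shows how to run inference on a 524,288 bp input sequence and plot the probabilities. Judging by the shapes printed in steps (2) and (3), the model takes 524,288 nucleotides as input and returns predictions for 196,608 positions. The module imports, genome download, and plotting function are the same as above and are not repeated.
(1) Instantiate the SegmentBorzoi inference function
The following code downloads the SegmentBorzoi weights. It returns the weights dictionary, the state, the Haiku forward function, the tokenizer, and the config dictionary.
parameters, state, forward_fn, tokenizer, config = get_pretrained_segment_borzoi_model()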
Downloading model's weights...
forward_fn = hk.transform_with_state(forward_fn)
(2) Tokenize the DNA sequence
num_nucleotides = 524_288
(1, 1, 524288)
(3) Run inference on the resulting batch
# Infer
probabilities.shape
(1, 1, 196608, 14)
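The output covers fewer positions (196,608) than the input (524,288). If the predictions correspond to the central part of the input, which is how Borzoi itself crops its output but is an assumption here since this post does not state it, output indices can be mapped back to input coordinates as sketched below. The function name is hypothetical.

```python
def output_to_input_index(out_idx: int,
                          input_len: int = 524_288,
                          output_len: int = 196_608) -> int:
    """Map an output position to an input position, assuming a centered crop."""
    if not 0 <= out_idx < output_len:
        raise IndexError("output index out of range")
    flank = (input_len - output_len) // 2  # nucleotides trimmed on each side
    return flank + out_idx


print(output_to_input_index(0))        # 163840
print(output_to_input_index(196_607))  # 360447
```

Adding the genomic start coordinate of the input window to the returned value would give the chromosome position of each prediction, under the same centered-crop assumption.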
(4) Plot the probabilities of the 14 genomic features along this DNA sequence
plot_features(

References 📚
- [1] de Almeida, B.P., Dalla-Torre, H., Richard, G. et al. Annotating the genome at single-nucleotide resolution with DNA foundation models. Nat Methods (2025). https://doi.org/10.1038/s41592-025-02881-2

