First Impressions of Triton

About Triton

Triton is a language and compiler for parallel programming. It aims to provide a Python-based programming environment for productively writing custom DNN compute kernels capable of running at maximal throughput on modern GPU hardware.

Installing the Development Package

git clone https://github.com/triton-lang/triton.git
cd triton

pip install -r python/requirements.txt # build-time dependencies
pip install -e .
  • Before running pip, create a dedicated venv so the installation does not affect the system-wide Python environment.

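  • On a PC with limited memory, the default pip install -e . can run out of memory (OOM) during the build. Limit the number of parallel jobs instead, e.g. MAX_JOBS=8 pip install -e .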

Running a Sample

python python/tutorials/01-vector-add.py
  • The sample depends on torch, numpy, matplotlib, and pandas.

  • If pip fails to install packages because of network (DNS) problems, add the line nameserver 8.8.8.8 to /etc/resolv.conf.
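
Before launching the tutorial, it can help to confirm that the four packages listed above are importable from the active venv. The following is only a minimal sketch using the standard library; the helper name missing_packages is made up for illustration and is not part of Triton.

```python
import importlib.util

# Packages the note above lists as dependencies of the tutorial script.
REQUIRED = ["torch", "numpy", "matplotlib", "pandas"]


def missing_packages(names):
    """Return the top-level package names that cannot be found on the import path.

    Hypothetical helper for illustration; it only checks that a package can be
    located, not that it imports cleanly or that a GPU is available.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All tutorial dependencies can be located.")
```

Any name it reports can then be installed with pip inside the same venv.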

Output

tensor([1.3713, 1.3076, 0.4940,  ..., 0.4705, 1.6737, 1.6400], device='cuda:0')
tensor([1.3713, 1.3076, 0.4940,  ..., 0.4705, 1.6737, 1.6400], device='cuda:0')
The maximum difference between torch and triton is 0.0
vector-add-performance:
           size      Triton       Torch
0        4096.0   13.837838   13.963636
1        8192.0   29.538462   36.141177
2       16384.0   63.999998   60.235295
3       32768.0  101.553721  100.310203
4       65536.0  166.054047  164.939598
5      131072.0  244.537310  244.537310
6      262144.0  303.407414  301.546004
7      524288.0  331.827848  323.368435
8     1048576.0  360.417951  360.417951
9     2097152.0  378.092307  379.003379
10    4194304.0  386.358140  386.453091
11    8388608.0  390.095241  390.095241
12   16777216.0  392.798670  392.578071
13   33554432.0  393.295877  393.215981
14   67108864.0  393.692747  393.575769
15  134217728.0  394.034837  393.674278
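
The output above does not label its units. Assuming the Triton and Torch columns are effective memory bandwidth in GB/s, which is how the tutorial's benchmark computed them as far as I can tell (three float32 transfers per element: read x, read y, write the result), the figures can be converted back to kernel times with a small sketch like this. The function names are made up for illustration.

```python
def vector_add_gbps(num_elements, elapsed_ms, bytes_per_element=4):
    """Effective bandwidth in GB/s for z = x + y.

    Assumes each element costs two reads and one write of
    bytes_per_element bytes (4 for float32).
    """
    bytes_moved = 3 * num_elements * bytes_per_element
    return bytes_moved * 1e-9 / (elapsed_ms * 1e-3)


def elapsed_ms_from_gbps(num_elements, gbps, bytes_per_element=4):
    """Invert vector_add_gbps: recover the elapsed time in milliseconds."""
    bytes_moved = 3 * num_elements * bytes_per_element
    return bytes_moved * 1e-9 / gbps * 1e3


if __name__ == "__main__":
    # Last row of the table: 134217728 elements at 394.034837 (assumed GB/s).
    ms = elapsed_ms_from_gbps(134217728, 394.034837)
    print(f"approximate kernel time: {ms:.2f} ms")
```

Under that assumption, the last row corresponds to roughly 4 ms per vector add. The plateau near 394 at large sizes would then reflect the GPU's memory bandwidth limit, which is why Triton and Torch converge to nearly identical numbers.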