「論文通りのアルゴリズムを実装したのに、現場の病理画像では全く精度が出ない」
医療AI、特にデジタルパソロジーのプロジェクトにおいて、このような課題に直面するケースは少なくありません。特に、免疫チェックポイント阻害剤(ICI)の奏効予測マーカーとして注目される腫瘍浸潤リンパ球(TILs: Tumor-Infiltrating Lymphocytes)の定量化は、その難易度の高さで知られています。
理由は明白です。多くの開発現場において、きれいなデータセットでモデルを訓練することに終始し、数ギガバイトにも及ぶWSI(Whole Slide Image)という巨大な画像をどう効率的に処理し、そこから「臨床的に意味のある数値」をどう算出するかという、実践的なエンジニアリングが軽視されがちだからです。
AIモデルの精度(AUCやAccuracy)が高いことと、そのシステムが臨床現場で実用的に機能することは、全く別の問題です。AIはあくまで課題解決の手段であり、ROI(投資対効果)や実運用を見据えた設計が不可欠です。
今回は、プロジェクトマネジメントの実践的な視点も交えつつ、「実務で使えるTILs自動定量化パイプライン」をPython(PyTorch)で実装していきます。理論の解説は最小限に留め、WSIのハンドリングから予後予測スコアの算出まで、実運用に直結するコードを体系的に解説します。
1. 免疫療法におけるTILs評価の課題と自動化の意義
まず、これから解決しようとしている課題の「臨床的な重み」を理解することが重要です。ビジネス課題や現場のニーズを正確に把握せずに開発を進めると、実務では無意味な数値を計算するだけのシステムが出来上がってしまいます。
手動カウントの限界:観察者間変動と時間の壁
病理医は顕微鏡を覗き、腫瘍組織の中にどれくらいリンパ球が入り込んでいるかを目視で評価します。これがTILs評価です。しかし、手動での評価にはいくつかの構造的な課題が存在します。
- 観察者間変動(Inter-observer variability): 同じスライドを評価しても、複数の医師間でスコアが異なることが頻繁に発生します。
- 再現性の欠如: 同一の医師であっても、体調や時間帯によって評価にばらつきが生じる可能性があります。
- 物理的な限界: スライド全体(数万×数万ピクセル)をくまなく数えることは現実的ではありません。一部の視野(ROI: Region of Interest)のみで判断せざるを得ないのが実情です。
免疫療法の適応を判断する際、この評価の「ブレ」は患者の治療方針に直結する重大な問題となり得ます。そのため、AIを活用した「全自動かつ再現性のある定量化」が強く求められています。
デジタルパソロジーによる空間的解析の可能性
AIを導入することで、スライド全体の細胞を網羅的にカウントすることが可能になります。しかし、単に「リンパ球の総数」を算出するだけでは不十分です。
臨床的に重要なのは、リンパ球が「どこに」存在するかという空間的な情報です。
- Intratumoral TILs: 癌細胞の集団の中に直接入り込んでいるリンパ球
- Stromal TILs: 癌細胞を取り囲む間質(Stroma)にあるリンパ球
これら2つを明確に区別し、それぞれの密度を算出することが、予後予測の精度を大きく左右します。本記事では、この空間的な位置関係も考慮に入れたパイプラインの構築を目指します。
本記事で実装するパイプラインの全体像
今回構築するシステムのフローは以下の通りです。論理的かつ段階的に処理を進める設計としています。
- WSI前処理: 巨大画像から組織部分だけを抽出し、パッチ(小画像)に分割。
- 推論: 学習済みモデルを用いて、各パッチ内のリンパ球を検出。
- 定量化: 腫瘍領域と間質領域を区別し、密度スコアを算出。
- 可視化: ヒートマップを作成し、医師が検証可能な形にする。
それでは、具体的な実装手順を解説します。
2. WSI(Whole Slide Image)処理環境のセットアップ
WSIは通常の画像処理ライブラリ(OpenCVやPIL)でそのまま開こうとすると、メモリ不足を引き起こしシステムがクラッシュする原因となります。そのため、専用のライブラリであるopenslideを使用するのが標準的なアプローチです。
OpenSlideとPyTorchによる解析環境構築
まずは必要なライブラリをインポートします。ここでは、画像処理にalbumentations、WSI操作にopenslide-pythonを使用します。
# 必要なライブラリのインポート
import os
import numpy as np
import openslide
from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
# WSI解析用の設定
CONFIG = {
'WSI_PATH': './data/wsi_samples/sample_01.svs', # 架空のパス
'PATCH_SIZE': 256, # モデルに入力する画像サイズ
'LEVEL': 0, # 解析する倍率レベル(0が最高倍率)
'TISSUE_THRESH': 240, # 背景除去の閾値(白に近い部分を除去)
'BATCH_SIZE': 32,
'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu'
}
print(f"Using device: {CONFIG['DEVICE']}")
巨大画像のメモリ効率的な読み込み戦略
WSI全体を一度にメモリへ展開することは、リソース管理の観点から避けるべきです。サムネイル(低倍率画像)を使って組織の位置(座標)を特定し、必要な部分だけを高倍率で読み込む「オンデマンド・ローディング」を実装することで、メモリ効率を最適化します。
組織領域の自動抽出と背景除去
スライドガラス上の「何もない白い背景」に対してAIの推論を実行することは、計算リソースの著しい浪費につながります。Otsuの二値化や単純な輝度閾値を用いて、組織が存在する領域(Tissue Mask)を事前に抽出することが重要です。
以下のコードは、WSIから組織マスクを作成し、解析対象となるパッチの座標リストを生成する関数です。
def get_tissue_coordinates(wsi_path, patch_size, level=0, tissue_thresh=240):
"""
WSIから組織領域を検出し、パッチ切り出し用の座標リストを返す
"""
slide = openslide.OpenSlide(wsi_path)
# 処理高速化のため、低倍率(サムネイル)でマスクを作成
# レベル2程度(1/16サイズなど)で全体像を把握するのが一般的
thumbnail_level = min(2, slide.level_count - 1)
downsample = slide.level_downsamples[thumbnail_level]
w, h = slide.level_dimensions[thumbnail_level]
thumbnail = slide.read_region((0, 0), thumbnail_level, (w, h)).convert("RGB")
thumbnail_np = np.array(thumbnail)
# グレースケール化して二値化(組織部分は色が濃い=値が小さい)
gray = cv2.cvtColor(thumbnail_np, cv2.COLOR_RGB2GRAY)
_, tissue_mask = cv2.threshold(gray, tissue_thresh, 255, cv2.THRESH_BINARY_INV)
# 輪郭抽出でノイズ除去
contours, _ = cv2.findContours(tissue_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
patch_coords = []
# スライディングウィンドウで座標を収集
# 注意: 座標はレベル0(最高倍率)の座標系に変換する必要がある
step = int(patch_size / downsample) # サムネイル上でのステップサイズ
for y in range(0, h, step):
for x in range(0, w, step):
# サムネイル上で組織マスク内であれば座標リストに追加
# 中心点がマスク内にあるか簡易判定
cy, cx = y + step//2, x + step//2
if cy < h and cx < w and tissue_mask[cy, cx] > 0:
# レベル0座標へ変換
l0_x = int(x * downsample)
l0_y = int(y * downsample)
patch_coords.append((l0_x, l0_y))
print(f"Total patches to process: {len(patch_coords)}")
return slide, patch_coords
# 実行例
# slide_obj, coords = get_tissue_coordinates(CONFIG['WSI_PATH'], CONFIG['PATCH_SIZE'])
実装上の注意点: 座標変換に誤りがあると、AIは分析対象外の領域(ガラスの汚れなど)を解析してしまい、結果の信頼性が損なわれます。「サムネイルの座標」と「レベル0の座標」の倍率関係(downsample)は、正確に計算・確認する必要があります。
3. リンパ球検出モデルの構築と推論ループ
次に、切り出したパッチに対して実際にAIモデルを適用し、リンパ球の有無を判定する推論パイプラインを構築します。WSI(Whole Slide Image)の解析においては、膨大な数のパッチを効率よく処理するためのシステム設計が、全体のパフォーマンスを左右する重要な鍵となります。
パッチベース処理の実装:Datasetクラス
PyTorchのDatasetクラスを継承し、WSIから動的にパッチを読み込む専用のクラスを作成します。全パッチを一度にメモリへ展開することは現実的ではないため、推論時に必要な領域だけをオンデマンドで取得するアプローチが必須です。これにより、ディスクI/Oとメモリ使用量を最適化しながら、安定した処理を実現できます。
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch
class WSIPatchDataset(Dataset):
def __init__(self, slide_obj, coords, patch_size, transform=None):
self.slide = slide_obj
self.coords = coords
self.patch_size = patch_size
self.transform = transform
def __len__(self):
return len(self.coords)
def __getitem__(self, idx):
x, y = self.coords[idx]
# レベル0(最高解像度)でパッチを読み込む
patch = self.slide.read_region((x, y), 0, (self.patch_size, self.patch_size)).convert("RGB")
if self.transform:
patch = self.transform(patch)
else:
# デフォルトの前処理(Tensor化とImageNet基準の正規化)
t = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
patch = t(patch)
return patch, (x, y)
転移学習モデルと推論の高速化
リンパ球の分類には、軽量な畳み込みニューラルネットワークを転移学習させる手法が一般的です。長らくベースラインとしてResNet18が広く使われており、現在でもPyTorchの公式リポジトリからオリジナルのアーキテクチャ(18層の残差ブロック構成)を簡単にロードして利用できます。実装時は以下のように呼び出すのが標準的です。
# ResNet18の事前学習済みモデルのロード例
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
一方で、最新の研究トレンドにおいては、MobileNetV3などのより最適化された軽量モデルがResNet18の精度を上回るケースが報告されています。さらに、エッジ環境や計算リソースが限られた環境での推論を想定する場合、モデルの量子化(4bit化など)やプルーニングを適用して計算コストを下げるアプローチも、ROIを最大化する上で有効な選択肢となります。
ここではモデルの内部定義は省略し、推論ループの実装に焦点を当てます。数万から数十万に及ぶパッチを現実的な時間で処理するためには、DataLoaderのnum_workersを適切に設定し、GPUを活用したバッチ処理を構築することが不可欠です。
def run_inference(model, dataset, batch_size=32, device='cuda'):
# num_workersは環境のCPUコア数に合わせて調整
loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)
model.eval()
model.to(device)
results = []
with torch.no_grad():
for batch_imgs, batch_coords in loader:
batch_imgs = batch_imgs.to(device)
outputs = model(batch_imgs)
# Softmaxで確率に変換(インデックス1がリンパ球クラスと仮定)
probs = torch.nn.functional.softmax(outputs, dim=1)[:, 1]
# 結果をCPUに戻してNumPy配列化
probs_np = probs.cpu().numpy()
# バッチ内の各画像の結果を保存
# batch_coordsは (x_tensor, y_tensor) のリストになっているため転置が必要
coords_x = batch_coords[0].numpy()
coords_y = batch_coords[1].numpy()
for i in range(len(probs_np)):
results.append({
'x': coords_x[i],
'y': coords_y[i],
'til_prob': probs_np[i]
})
return results
推論結果は、空間座標(x, y)とリンパ球の確率(til_prob)を紐付けた辞書のリストとして保存されます。このデータ構造を維持することで、後のステップでヒートマップとして可視化したり、空間的な分布を解析したりする作業を体系的かつスムーズに進めることができます。
4. 予後予測のための定量化ロジックと可視化
推論が完了すると、大量の「座標と確率」のデータが得られます。しかし、この段階ではまだ臨床的に意味のあるデータとは言えません。ここから、実務で活用できる「TILs密度スコア」を算出するプロセスに移行します。
TILs密度の算出アルゴリズム:腫瘍内 vs 間質
単純に「検出されたリンパ球の数」をカウントするだけでは不十分です。予後予測において本質的に重要な指標は、「腫瘍面積あたりのリンパ球数(密度)」です。
さらに高度な解析を実装する場合、別のモデル(セグメンテーションモデル)を用いて「腫瘍領域(Tumor)」と「間質領域(Stroma)」を識別し、それぞれの領域内での密度を個別に計算します。今回は基礎的なアプローチとして、スライド全体の組織面積に対する密度を計算するコード例を示します。
def calculate_tils_score(results, patch_size_mm_sq=0.01):
"""
TILsスコア(密度)を計算する
patch_size_mm_sq: 1パッチあたりの物理面積(mm^2)。倍率と解像度から計算。
"""
total_patches = len(results)
if total_patches == 0: return 0.0
# 確率0.5以上をリンパ球ありと判定
til_positive_count = sum(1 for r in results if r['til_prob'] > 0.5)
# 総組織面積 (mm^2)
total_area_mm2 = total_patches * patch_size_mm_sq
# 密度 (cells / mm^2)
# 注: ここでは1パッチに1細胞という簡易計算だが、実際はパッチ内の細胞数を回帰するモデルを使うか、
# セグメンテーションで個数をカウントする
density = til_positive_count / total_area_mm2
return density
# 実際には、WSIのメタデータからmpp(microns per pixel)を取得して面積を計算する
# mpp = float(slide.properties.get(openslide.PROPERTY_NAME_MPP_X, 0.25))
ヒートマップの生成とオーバーレイ表示
算出した数値の妥当性を現場の医師が直感的に検証できるよう、結果の可視化は不可欠なプロセスです。元画像の上にヒートマップを重ねることで、AIが「どの領域を根拠に」判断を下したのかを明確に提示します。
def generate_heatmap(slide, results, patch_size, downsample_factor=32):
w, h = slide.dimensions
# ヒートマップ用の縮小サイズ
hm_w, hm_h = w // downsample_factor, h // downsample_factor
heatmap = np.zeros((hm_h, hm_w), dtype=np.float32)
for r in results:
# 座標をヒートマップのサイズに変換
hx, hy = r['x'] // downsample_factor, r['y'] // downsample_factor
if 0 <= hy < hm_h and 0 <= hx < hm_w:
heatmap[hy, hx] = r['til_prob']
# カラーマップ適用
heatmap_color = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
# 元のサムネイル画像とブレンドして表示(省略)
return heatmap_color
5. 実運用に向けた最適化と精度検証
最後に、プロジェクトをPoC(概念実証)で終わらせず、実際の業務に定着させるための実務的な最適化ポイントを解説します。このフェーズの設計が、投資対効果を生む「実用的なAI」を実現するための分かれ道となります。
染色バリエーションへの対応(Color Normalization)
病理画像は、施設や担当する技師によって染色の濃さ(H&E染色の色味)に大きなばらつきが生じます。特定の医療機関のデータのみで学習したモデルを別の施設のデータに適用した場合、精度が著しく低下する現象は一般的な課題として広く認識されています。
この課題への対策として、Macenko法やVahadane法を用いた染色正規化(Color Normalization)を前処理パイプラインに組み込むことが、モデルの汎化性能を担保する上で極めて有効です。Python環境ではtorchstainライブラリを活用することで、効率的に実装可能です。
import torchstain
# Macenko法による正規化器の初期化
normalizer = torchstain.normalizers.MacenkoNormalizer(backend='torch')
# データセットのtransform内で適用
# target_imageは基準となるきれいな染色の画像
normalizer.fit(target_image)
# 各パッチ読み込み時に正規化を実行
# normalized_patch, H, E = normalizer.normalize(I=patch, stains=True)
偽陽性を減らすための後処理テクニック
AIモデルは、スライド上のインクの汚れ、組織の折り目、気泡などのアーティファクトを「細胞」として誤認するリスクを孕んでいます。この偽陽性を抑制するためには、単一の分類モデルに依存するのではなく、異常検知(Anomaly Detection)の仕組みを統合するか、ヒューリスティックなルール(極端に色が濃い、あるいは薄い領域を機械的に除外するなど)を組み合わせるアプローチが実務上有効です。
さらに、算出したTILsスコアを生存期間データ(Kaplan-Meier曲線など)と突き合わせ、統計的に有意な差が確認できる最適な閾値を探索・設定することが、臨床応用における最終的なゴールとなります。
まとめ
今回は、TILs自動定量化のためのパイプライン構築について、実践的なコードを交えて体系的に解説しました。
- WSIの効率的なハンドリング: 適切なメモリ管理と正確な座標変換の徹底。
- パッチベースの推論: DatasetクラスとDataLoaderを活用した処理の高速化。
- 臨床指標への変換: 単純なカウントに留まらない、密度や空間的分布を考慮したロジック設計。
- 実運用を見据えた最適化: 染色正規化やノイズ除去といった、堅牢性を高めるためのエンジニアリング。
AIによる病理画像解析は、決して魔法のような技術ではなく、こうした論理的かつ地道なエンジニアリングの積み重ねによって成立します。しかし、現場の課題に即して正しく設計・実装されれば、医療従事者の負担を大幅に軽減し、患者に対してより適切な治療方針を提供するための強力な基盤となります。
PoCの枠を超え、実際のビジネスや医療現場で価値を生み出すAIシステムの構築に向けて、本記事の解説がプロジェクト推進の一助となれば幸いです。
コメント