YOLOv8で学習データから推定を行う

YOLOv8での学習方法は別の投稿で紹介したので、ここでは学習させた結果から、実際に推定を行う方法を説明する。また、今回は抵抗値(カラーコード)の推定のために、画像から抵抗を推定し、抵抗の画像を切り取って保存する処理を行う。

推定テスト

まずは、学習させたモデルで推定を行う。学習した結果のモデルはtrainフォルダ内のweightsフォルダにある.ptファイルとなる。last.ptは学習した最後の結果で、best.ptは学習途中で最も誤差が少ない結果である。どちらを使うかはresults.pngを見るなどして決める。今回はbest.ptを使ってテストした。プログラムは公式を参考にした(https://docs.ultralytics.com/ja/usage/python/#predict)。処理はシンプルでbest.ptと画像の読み込み、推定した結果の保存である。結果はyolov8/runs/detect/predict に保存される。

Python
from ultralytics import YOLO
import cv2

model = YOLO("/content/drive/MyDrive/yolov8/runs/detect/train2/weights/best.pt")

im2 = cv2.imread("/content/drive/MyDrive/yolov8/data/originals/DSC_0178.JPG")
results = model.predict(source=im2, save=True)

推定した抵抗画像の切り取りと保存

抵抗画像からカラーコードを見つけるには、抵抗画像を使って学習させる必要がある。そのため、画像中から抵抗を切り取り保存を行う。大まかな処理の順序は以下のとおりである。

  1. モデルの読み込み
  2. 画像の読み込み
  3. 推定
  4. 画像の切り取り
  5. 保存
  6. 2-5を画像枚数分繰り返す

切り取りと保存を行うプログラムを以下に示す。プログラムを実行した結果、抵抗の画像が切り取られて保存されているのが分かる。

Python
import os
import cv2
from ultralytics import YOLO
import shutil

# パラメータ
IMG_SOURCE_PATH="/content/drive/MyDrive/yolov8/data/originals/"
IMG_EXT = ".JPG"
IMG_SAVE_PATH="/content/drive/MyDrive/yolov8/data/cropped_img/"

model = YOLO("/content/drive/MyDrive/yolov8/runs/detect/train2/weights/best.pt")

# 既存フォルダ削除
if os.path.exists(IMG_SAVE_PATH):
  shutil.rmtree(IMG_SAVE_PATH)
# フォルダ作成
os.makedirs(IMG_SAVE_PATH,exist_ok=True)

# 画像リストを取得
img_list = os.listdir(IMG_SOURCE_PATH)
for img_file in img_list:
  # ファイル名と拡張子に分離
  filename,ex=os.path.splitext(img_file)
  if ex==IMG_EXT:
    # 画像読み込み
    img = cv2.imread(IMG_SOURCE_PATH+img_file)
    # 推定実行
    results = model.predict(source=img)

    # 画像切り取りと保存のループ
    for i,result in enumerate(results[0].boxes.xyxy):
      left,top,right,bottom=map(int,result.tolist())
      # 切り取り
      roi_img = img[top:bottom,left:right]
      # 保存
      cv2.imwrite(IMG_SAVE_PATH+filename+'_'+str(i)+IMG_EXT, roi_img)

抵抗値(カラーコード)推定

保存した抵抗の画像を使用して、アノテーションと学習を行ない、推定ができるまでは、抵抗の推定と同じように行なった。ここでは、複数の抵抗が写っている画像から、その抵抗値を推定するプログラムを説明する。

  1. 抵抗推定モデルの読み込み
  2. カラーコード推定モデルの読み込み
  3. 画像の読み込み
  4. 抵抗の推定
  5. 抵抗画像の切り取り
  6. カラーコードの推定
  7. カラーコードの中から許容差(金のカラーコード)の位置を計算
  8. 他のカラーコードと許容差のカラーコードの距離を計算
  9. 距離順にソート(カラーコードの順番がわかる)
  10. 抵抗値の計算
  11. 結果の表示
  12. 保存

この処理を実装したプログラムを以下に示す。このプログラムを実行した結果、抵抗値が表示された。

Python
from google.colab.patches import cv2_imshow
import os
import cv2
from ultralytics import YOLO
import numpy as np
#10が先頭の時,0のみの時

IMG_PATH="/content/drive/MyDrive/yolov8/data/originals/DSC_0178.JPG"
SAVE_PATH="/content/drive/MyDrive/yolov8/result.JPG"

resistance_model = YOLO("/content/drive/MyDrive/yolov8/runs/detect/train2/weights/best.pt")
color_code_model = YOLO("/content/drive/MyDrive/yolov8/runs/detect/train4/weights/best.pt")

SI_UNIT_LIST = [[1, ''],   #0
                [10, ''],   #1
                [0.1, 'k'], #2
                [1, 'k'],   #3
                [10, 'k'],  #4
                [0.1, 'M'], #5
                [1, 'M'],   #6
                [10, 'M'],  #7
                [0.1, 'G'], #8
                [1, 'G'],   #9
                [0.1, '']]  #10
                
# 画像読み込み
img = cv2.imread(IMG_PATH)

# 推定実行
resistance_results = resistance_model.predict(source=img)

# 抵抗画像の切り取りと抵抗値推定のループ
for resistance_result in resistance_results[0].boxes.xyxy:
  left,top,right,bottom=map(int,resistance_result.tolist())
  # 切り取り
  roi_img = img[top:bottom,left:right]
  # カラーコード推定実行
  color_code_results = color_code_model.predict(source=roi_img)

  # クラスと位置のリスト
  cls_list = color_code_results[0].boxes.cls.int().tolist()
  xyxy_list = color_code_results[0].boxes.xyxy.tolist()

  error = True
  resistance_value = ""

  if len(cls_list)==4: #4つのカラーコードを検出
    try:
      # 許容差(金のカラーコード)を含む場合そのインデックスを取得
      tolerance_index = cls_list.index(10)
    except:
      print("カラーコード未検出")

    # 許容差のカラーコードの中心座標を計算
    tolerance_center = np.array([(xyxy_list[tolerance_index][0]+xyxy_list[tolerance_index][2])/2,(xyxy_list[tolerance_index][1]+xyxy_list[tolerance_index][3])/2])
    # 許容差のカラーコードからの距離を計算
    distance_list = []
    for j, (c_left,c_top,c_right,c_bottom) in enumerate(xyxy_list):
      color_code_center = np.array([(c_left+c_right)/2, (c_top+c_bottom)/2])
      distance_list.append([np.linalg.norm(color_code_center-tolerance_center), int(cls_list[j])])
    # 距離順にソート
    distance_list.sort(reverse=True)

    # 許容差以外に金のカラーコードを含む場合対策
    if(distance_list[1][1] == 10):
      distance_list[1], distance_list[2] = distance_list[2], distance_list[1]
    #抵抗値計算
    resistance_value_num = round((distance_list[0][1]*10+distance_list[1][1])*SI_UNIT_LIST[distance_list[2][1]][0],2)
    resistance_value = str(resistance_value_num) + SI_UNIT_LIST[distance_list[2][1]][1]
    error = False

  # 結果の描画
  if(error):
    cv2.rectangle(img,(left,top),(right,bottom),(0,0,255),2)
  else:
    cv2.rectangle(img,(left,top),(right,bottom),(0,255,0),2)
    cv2.putText(img,resistance_value,(left,top-30),cv2.FONT_HERSHEY_SIMPLEX,5,(0,255,0),10,cv2.LINE_AA)

cv2.imwrite(SAVE_PATH, img)
cv2_imshow(img)

コメントする