YOLOv8での学習方法は別の投稿で紹介したので、ここでは学習させた結果から、実際に推定を行う方法を説明する。また、今回は抵抗値(カラーコード)の推定のために、画像から抵抗を推定し、抵抗の画像を切り取って保存する処理を行う。
推定テスト
まずは、学習させたモデルで推定を行う。学習した結果のモデルはtrainフォルダ内のweightsフォルダにある.ptファイルとなる。last.ptは学習した最後の結果で、best.ptは学習途中で最も誤差が少ない結果である。どちらを使うかはresults.pngを見るなどして決める。今回はbest.ptを使ってテストした。プログラムは公式を参考にした(https://docs.ultralytics.com/ja/usage/python/#predict)。処理はシンプルでbest.ptと画像の読み込み、推定した結果の保存である。結果はyolov8/runs/detect/predict に保存される。
Python
from ultralytics import YOLO
import cv2
model = YOLO("/content/drive/MyDrive/yolov8/runs/detect/train2/weights/best.pt")
im2 = cv2.imread("/content/drive/MyDrive/yolov8/data/originals/DSC_0178.JPG")
results = model.predict(source=im2, save=True)
推定した抵抗画像の切り取りと保存
抵抗画像からカラーコードを見つけるには、抵抗画像を使って学習させる必要がある。そのため、画像中から抵抗を切り取り保存を行う。大まかな処理の順序は以下のとおりである。
- モデルの読み込み
- 画像の読み込み
- 推定
- 画像の切り取り
- 保存
- 2-5を画像枚数分繰り返す
切り取りと保存を行うプログラムを以下に示す。プログラムを実行した結果、抵抗の画像が切り取られて保存されているのが分かる。
Python
import os
import cv2
from ultralytics import YOLO
import shutil
# パラメータ
IMG_SOURCE_PATH="/content/drive/MyDrive/yolov8/data/originals/"
IMG_EXT = ".JPG"
IMG_SAVE_PATH="/content/drive/MyDrive/yolov8/data/cropped_img/"
model = YOLO("/content/drive/MyDrive/yolov8/runs/detect/train2/weights/best.pt")
# 既存フォルダ削除
if os.path.exists(IMG_SAVE_PATH):
shutil.rmtree(IMG_SAVE_PATH)
# フォルダ作成
os.makedirs(IMG_SAVE_PATH,exist_ok=True)
# 画像リストを取得
img_list = os.listdir(IMG_SOURCE_PATH)
for img_file in img_list:
# ファイル名と拡張子に分離
filename,ex=os.path.splitext(img_file)
if ex==IMG_EXT:
# 画像読み込み
img = cv2.imread(IMG_SOURCE_PATH+img_file)
# 推定実行
results = model.predict(source=img)
# 画像切り取りと保存のループ
for i,result in enumerate(results[0].boxes.xyxy):
left,top,right,bottom=map(int,result.tolist())
# 切り取り
roi_img = img[top:bottom,left:right]
# 保存
cv2.imwrite(IMG_SAVE_PATH+filename+'_'+str(i)+IMG_EXT, roi_img)
抵抗値(カラーコード)推定
保存した抵抗の画像を使用して、アノテーションと学習を行ない、推定ができるまでは、抵抗の推定と同じように行なった。ここでは、複数の抵抗が写っている画像から、その抵抗値を推定するプログラムを説明する。
- 抵抗推定モデルの読み込み
- カラーコード推定モデルの読み込み
- 画像の読み込み
- 抵抗の推定
- 抵抗画像の切り取り
- カラーコードの推定
- カラーコードの中から許容差(金のカラーコード)の位置を計算
- 他のカラーコードと許容差のカラーコードの距離を計算
- 距離順にソート(カラーコードの順番がわかる)
- 抵抗値の計算
- 結果の表示
- 保存
この処理を実装したプログラムを以下に示す。このプログラムを実行した結果、抵抗値が表示された。
Python
from google.colab.patches import cv2_imshow
import os
import cv2
from ultralytics import YOLO
import numpy as np
#10が先頭の時,0のみの時
IMG_PATH="/content/drive/MyDrive/yolov8/data/originals/DSC_0178.JPG"
SAVE_PATH="/content/drive/MyDrive/yolov8/result.JPG"
resistance_model = YOLO("/content/drive/MyDrive/yolov8/runs/detect/train2/weights/best.pt")
color_code_model = YOLO("/content/drive/MyDrive/yolov8/runs/detect/train4/weights/best.pt")
SI_UNIT_LIST = [[1, ''], #0
[10, ''], #1
[0.1, 'k'], #2
[1, 'k'], #3
[10, 'k'], #4
[0.1, 'M'], #5
[1, 'M'], #6
[10, 'M'], #7
[0.1, 'G'], #8
[1, 'G'], #9
[0.1, '']] #10
# 画像読み込み
img = cv2.imread(IMG_PATH)
# 推定実行
resistance_results = resistance_model.predict(source=img)
# 抵抗画像の切り取りと抵抗値推定のループ
for resistance_result in resistance_results[0].boxes.xyxy:
left,top,right,bottom=map(int,resistance_result.tolist())
# 切り取り
roi_img = img[top:bottom,left:right]
# カラーコード推定実行
color_code_results = color_code_model.predict(source=roi_img)
# クラスと位置のリスト
cls_list = color_code_results[0].boxes.cls.int().tolist()
xyxy_list = color_code_results[0].boxes.xyxy.tolist()
error = True
resistance_value = ""
if len(cls_list)==4: #4つのカラーコードを検出
try:
# 許容差(金のカラーコード)を含む場合そのインデックスを取得
tolerance_index = cls_list.index(10)
except:
print("カラーコード未検出")
# 許容差のカラーコードの中心座標を計算
tolerance_center = np.array([(xyxy_list[tolerance_index][0]+xyxy_list[tolerance_index][2])/2,(xyxy_list[tolerance_index][1]+xyxy_list[tolerance_index][3])/2])
# 許容差のカラーコードからの距離を計算
distance_list = []
for j, (c_left,c_top,c_right,c_bottom) in enumerate(xyxy_list):
color_code_center = np.array([(c_left+c_right)/2, (c_top+c_bottom)/2])
distance_list.append([np.linalg.norm(color_code_center-tolerance_center), int(cls_list[j])])
# 距離順にソート
distance_list.sort(reverse=True)
# 許容差以外に金のカラーコードを含む場合対策
if(distance_list[1][1] == 10):
distance_list[1], distance_list[2] = distance_list[2], distance_list[1]
#抵抗値計算
resistance_value_num = round((distance_list[0][1]*10+distance_list[1][1])*SI_UNIT_LIST[distance_list[2][1]][0],2)
resistance_value = str(resistance_value_num) + SI_UNIT_LIST[distance_list[2][1]][1]
error = False
# 結果の描画
if(error):
cv2.rectangle(img,(left,top),(right,bottom),(0,0,255),2)
else:
cv2.rectangle(img,(left,top),(right,bottom),(0,255,0),2)
cv2.putText(img,resistance_value,(left,top-30),cv2.FONT_HERSHEY_SIMPLEX,5,(0,255,0),10,cv2.LINE_AA)
cv2.imwrite(SAVE_PATH, img)
cv2_imshow(img)