Gihak111 Navbar

Expo 앱과 Flask 서버를 활용한 이미지 분류 고도화

이전 글에서는 Expo 앱과 Flask 서버를 연동하여 간단한 MNIST 숫자 분류 모델을 구현했다. 이번 글에서는 이 프로젝트를 한 단계 더 발전시켜 보자. CIFAR-10 데이터셋을 활용해 이미지 전처리와 모델의 복잡성을 높이고, 추가적인 기능을 구현하여 더 정교한 예측이 가능하도록 할 것이다.

프로젝트 확장 개요

이번 확장에서는 CIFAR-10 데이터셋을 사용한 이미지 분류 모델을 Flask 서버로 배포하고, Expo 앱에서 이를 활용해 더 복잡한 이미지 분류 작업을 수행할 것이다. CIFAR-10 데이터셋은 10개의 클래스(비행기, 자동차, 새, 고양이 등)로 구성된 컬러 이미지 데이터셋으로, 이미지 분류 작업에 널리 사용된다. 또한, 서버와의 통신 최적화와 앱의 기능을 확장하여 사용자 경험을 개선하는 방법도 함께 살펴보자.

1. CIFAR-10 데이터셋 소개

CIFAR-10 데이터셋은 10개의 클래스와 6만 장의 32x32 픽셀 컬러 이미지로 구성되어 있다. 각 클래스는 다음과 같은 객체를 나타낸다:

이 데이터셋을 사용해 더 복잡한 이미지를 분류할 수 있는 모델을 구현하고, 이를 Flask 서버에 배포해보자. CIFAR-10은 MNIST에 비해 이미지 해상도가 높고, 컬러 이미지를 사용하므로, 모델의 복잡성도 자연스럽게 증가한다.

2. Flask 서버에서 CNN 모델 구축

CIFAR-10 데이터셋을 처리하기 위해 CNN(Convolutional Neural Network) 모델을 구축한다. CNN은 이미지 처리에 탁월한 성능을 보이며, 특히 패턴 인식에 강점이 있다. 아래 코드는 CIFAR-10 데이터를 처리할 수 있는 간단한 CNN 모델을 정의하고, 이를 Flask 서버에 배포하는 예제이다.

from flask import Flask, request, jsonify
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

app = Flask(__name__)

# CNN 모델 정의
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1)
        self.conv2 = nn.Conv2d(16, 32, 3, 1)
        self.fc1 = nn.Linear(32 * 6 * 6, 128)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 32 * 6 * 6)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

# 모델 초기화 및 로드
model = CNN()
model.load_state_dict(torch.load('cifar10_model.pth'))  # 사전 훈련된 가중치 로드
model.eval()

# 전처리 함수
def transform_image(image_bytes):
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    image = transform(image_bytes)
    image = image.unsqueeze(0)  # 배치 차원 추가
    return image

# API 엔드포인트
@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400

    file = request.files['file']
    image = transform_image(file)
    with torch.no_grad():
        output = model(image)
        prediction = output.argmax(dim=1, keepdim=True)
    return jsonify({'prediction': prediction.item()})

if __name__ == '__main__':
    app.run(debug=True)

코드 설명

3. Expo 앱 확장

이제 Expo 앱에서도 CIFAR-10 모델의 예측 결과를 처리할 수 있도록 코드를 확장하자. 이전에 만든 Expo 앱에서 몇 가지 기능을 추가하고 개선한다.

import React, { useState } from 'react';
import { Button, Image, View, Text } from 'react-native';
import * as ImagePicker from 'expo-image-picker';
import axios from 'axios';

export default function App() {
  const [image, setImage] = useState(null);
  const [prediction, setPrediction] = useState(null);

  const pickImage = async () => {
    let result = await ImagePicker.launchImageLibraryAsync({
      mediaTypes: ImagePicker.MediaTypeOptions.Images,
      allowsEditing: true,
      aspect: [4, 3],
      quality: 1,
    });

    if (!result.canceled) {
      setImage(result.uri);
    }
  };

  const uploadImage = async () => {
    if (!image) return;

    let formData = new FormData();
    formData.append('file', {
      uri: image,
      name: 'photo.jpg',
      type: 'image/jpg',
    });

    try {
      const response = await axios.post('http://<your-server-ip>:5000/predict', formData, {
        headers: {
          'Content-Type': 'multipart/form-data',
        },
      });
      const classNames = ["비행기", "자동차", "", "고양이", "사슴", "", "개구리", "", "", "트럭"];
      setPrediction(classNames[response.data.prediction]);
    } catch (error) {
      console.error(error);
    }
  };

  return (
    <View style=>
      <Button title="Pick an image from camera roll" onPress={pickImage} />
      {image && <Image source= style= />}
      <Button title="Upload to Flask Server" onPress={uploadImage} />
      {prediction !== null && <Text>Prediction: {prediction}</Text>}
    </View>
  );
}

코드 설명

4. 추가 개선 사항

a. 실시간 모델 학습 기능

추가로, Flask 서버에 실시간 모델 학습 기능을 구현할 수도 있다. 사용자가 올리는 이미지를 기반으로 모델이 점차 학습하도록 하여, 시간이 지남에 따라 성능이 향상되는 모델을 구축할 수 있다. 이를 위해서는 새로운 이미지가 업로드될 때마다 해당 이미지를 사용해 모델을 추가 학습시키는 기능을 서버에 구현하면 된다.

@app.route('/train', methods=['POST'])
def train():
    if 'file' not in request.files or 'label' not in request.form:
        return jsonify({'error': 'File or label missing'}), 400

    file = request.files['file']
    label = int(request.form['label'])
    image = transform_image(file)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    optimizer.zero_grad()
    output = model(image)
    loss = F.nll_loss(output, torch.tensor([label]))
    loss.backward()
    optimizer.step()

    return jsonify({'status': 'Model updated with new data'})

b. 다양한 이미지 소스 지원

Expo 앱에서 사진첩뿐만 아니라 카메라를 통해 직접 촬영한 이미지를 업로드할 수 있도록 기능을 확장할 수도 있다. 이렇게 하면 사용자가 보다

다양한 환경에서 이미지를 업로드하고 예측 결과를 받을 수 있게 된다.

const pickFromCamera = async () => {
  let result = await ImagePicker.launchCameraAsync({
    allowsEditing: true,
    aspect: [4, 3],
    quality: 1,
  });

  if (!result.canceled) {
    setImage(result.uri);
  }
};

// UI에 추가
<Button title="Take a photo" onPress={pickFromCamera} />

결론

이 글에서는 CIFAR-10 데이터셋을 활용하여 Expo 앱과 Flask 서버 간의 이미지 분류 기능을 확장하고 고도화하는 방법을 다뤘다. 단순한 MNIST 분류에서 벗어나, 실제로 사용될 수 있는 복잡한 이미지 분류 작업을 구현했다. 이러한 기술을 바탕으로 다양한 AI 기반 애플리케이션을 개발하고, 사용자 경험을 지속적으로 개선할 수 있다. 다음 글에서는 더욱 고급화된 이미지 처리 기법과 다양한 AI 모델을 활용하는 방법을 탐구해보자.

위 글은 내가 적은 글을 GPT 에게 어투 통일과 내용 적립을 시켜달라 하고 나온 글이다.
참 기술이 많이 발전한 것 같다.