機械学習とか

条件付きC RNN GANで音楽を生成

前置き

昨年のGANの動向を見ると,半教師あり学習・条件付きでの生成がトレンドだったように思います. そこでC-RNN-GANのモデルを元にを条件付きで音楽を生成したいと思います.

条件付きGAN

条件付きのGANについてはこちらの論文を参照します.
モデル図は下記の通りのようになっています.
f:id:yakuta55:20180205102708p:plain 生成器には,ランダムな分布zとラベルyを元に生成しています. 一方識別器は,生成されてもの/学習データxとラベルyを入力しxであるかの真偽を出力しています.

最近ではこれとCycleGANなどを組み合わせたStarGANなどがあります.

C-RNN-GAN

C-RNN-GANは以下のようなモデル図になっています. f:id:yakuta55:20180205120421p:plain
基本的に生成器も識別器もRNNになった感じです.
ただ識別器側はBiRNNになっています.
また損失関数も通常のGANと変わりはありません.

Conditional C-RNN-GAN

モデル形状はC RNN GANで同じで,両方のモデルにラベルを加えて入力しています. 損失はWGANを使用しました.

生成結果はこちらにおいておきます.

Conditional_C_RNN_GAN/generated_mid at wgan · TrsNium/Conditional_C_RNN_GAN · GitHub

聞いてみればわかるのですが,mode collapseがおきています.
これを回避するためにWGANなどを入れて見たのですが上手く学習をすることができなかったようです.

コードは以下に置いておきます.

データセット

freemidiからデータを集めました.
一応カテゴライズされているので,そのカテゴリーに沿ってラベルを作成しデータセットを作成しました.
スクレイピングのコードは下記にあります.
実行するときは,chrome driverが必要なのでご用意してください.

from bs4 import BeautifulSoup
from urllib.request import urlopen
from urllib.request import urlretrieve
import re
import os
from selenium import webdriver
from selenium.webdriver.common.keys import Keys
import time

content_url = "https://freemidi.org/"
html_doc = urlopen(content_url+"genre").read()
sp = BeautifulSoup(html_doc)

genres = sp.find_all("div", {"class":"genre-big-ones"})
genres_href = [[tag["href"] for tag in genre.find_all("a")] for genre in genres]

'''
[['genre-rock', 'genre-pop', 'genre-hip-hop-rap', 'genre-rnb-soul'],
 ['genre-classical', 'genre-country', 'genre-jazz', 'genre-blues'],
 ['genre-dance-eletric', 'genre-folk', 'genre-punk', 'genre-newage'],
 ['genre-reggae-ska', 'genre-metal', 'genre-disco', 'genre-bluegrass']]
'''

genre_artist = {}
#お好きなジャンルをお選びください
selected_genre = []
for hrefs in genres_href:
    for href in hrefs:
        if not href in selected_genre:
            continue
        content = {}
        url = content_url + href
        html_doc = BeautifulSoup(urlopen(url).read(), "lxml")
        artist_hrefs = html_doc.find_all("div", {"class": "genre-link-text"})
        print(href)
        for artist_href in artist_hrefs:
            artist_href_ = artist_href.a.get("href")
            a_html_doc = BeautifulSoup(urlopen(content_url+artist_href_).read())
            content[artist_href.a.string] = {re.sub("\r\n\s{2,}", "", artist_href.a.string):a.get("href") for a in a_html_doc.find_all("a", {"itemprop":"url"})[1:]}
        genre_artist[href] = content

cwd = os.getcwd()
for key, item in genre_artist.items():
    if not os.path.exists(key):
        os.mkdir(key)
        
    chromeOptions = webdriver.ChromeOptions()
    prefs = {"download.default_directory" : cwd+"/"+key+"/"}
    chromeOptions.add_experimental_option("prefs",prefs)
    
    for artist_n, song_dict in item.items():                           
        for song_n, song_href in song_dict.items():
            try:
                browser = webdriver.Chrome(executable_path="chromedriver", chrome_options=chromeOptions)
                browser.get(content_url+song_href)
                browser.find_element_by_link_text('Download MIDI').click()
                time.sleep(5)
                browser.quit()
            except:
                continue

おわり

想定していた結果が得れませんでしたが,考え直すとデータ数が少なかったのかもしれません. データ数を増やしたり損失関数を工夫してもう一度挑戦して行きたいと思います.